How to Create a Publication-Quality Heatmap in Python

Author: Murphy  |  Time: 2025-03-23 12:54:15

Introduction

Heatmaps are informative figures that convey quantitative data in an easy-to-read format, providing a concise summary of a dataset.

Python has a number of tools to facilitate the production of publication-quality heatmaps. These include the Seaborn and Matplotlib libraries, along with Matplotlib's subplot2grid function, which provides a convenient way to organise the sections of a heatmap.

In this tutorial, I will detail the steps required to produce a heatmap which focuses on the presence/absence of key elements. To do this, I will use a CSV file containing fictitious data about a selection of bacterial isolates. These bacterial strains have a number of features including antibiotic resistance genes, virulence genes, and certain capsule types. A heatmap allows quick inspection and comparison of the various strains.

While the example used focuses on bacterial strains, the techniques applied can be used more broadly for other datasets to help you visualise your data using a heatmap. Throughout the following tutorial, all images are by the author.


Objective

To create a publication quality heatmap displaying the presence/absence of key genes from fictitious bacterial strains.

This tutorial uses the CSV file 'Bacterial_strain_heatmap_tutorial_data.csv', available from the GitHub repository.

Getting started

To begin, a few imports are necessary to read in the data and stylise the figure later. We will begin by including all of the import statements together.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
from matplotlib.colors import ListedColormap
import seaborn as sns 
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle

Next, we read in the dataframe, set the index to the 'Strain' column and view the first 5 rows.

df = pd.read_csv('Bacterial_strain_heatmap_tutorial_data.csv').set_index('Strain')
df.head()

From the first 5 rows, we can see the data is organised with gene names as columns, while the index refers to each specific strain. A number from 0 to 5 encodes the presence or absence of each gene and, where a feature is present, which category it belongs to. These codes are later matched to colours and to a key denoting where the genes are found in the strains.
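If you do not have the CSV file to hand, the encoding can be explored with a small stand-in table. The sketch below is only an illustration: the three strains and their values are hypothetical, and the meaning of each code is inferred from the colour dictionaries and legend labels used later in this tutorial.

```python
import pandas as pd

# Hypothetical stand-in for the tutorial CSV: three made-up strains, three columns.
toy = pd.DataFrame(
    {
        "Strain": ["S1", "S2", "S3"],
        "OqxA": [2, 0, 1],
        "clpK": [0, 1, 0],
        "Capsule K2": [4, 5, 3],
    }
).set_index("Strain")

# Codes inferred from the colour dictionaries and legend later in the tutorial:
# 0 absent, 1 present on plasmid, 2 present on chromosome,
# 3 capsule K1, 4 capsule K2/K57, 5 absent (capsule columns).
allowed_codes = {0, 1, 2, 3, 4, 5}

# Collect every distinct value in the table and flag anything outside the encoding.
found_codes = {int(value) for value in toy.to_numpy().ravel()}
unexpected = found_codes - allowed_codes

print(toy.head())
print("Unexpected codes:", unexpected)
```

Running the same check on the real dataframe before plotting is a cheap way to catch typos in the raw data.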

To improve readability, the heatmap will be organised into 3 sections. These include a section for antibiotic resistance genes, another section for virulence genes, and a third section for the capsule type. In order to achieve this, we will first need to extract the key columns for each section as a separate list.

amr_columns_to_extract = ['OqxA', 'OqxB', 'emrAB', 'nfsA', 'nfsB', 'MexAB-OprM',
       'mdt Efflux pump genes','blaSHV-28', 'blaSHV-148', 'aac(3)-IId', 'aac(3)-IIe',
       "aac(6')-Ib-cr", "strA", "strB", 'blaCTX-M-14',
       'blaCTX-M-15', 'blaLAP-2', 'blaOXA-1', 'blaTEM-1B', 'dfrA1', 'dfrA14',
       'qnrB1', 'qnrS1', 'sul1', 'sul2', 'tet(A)']

virulence_columns_to_extract = ['clpK', 'traT',
       'fecIRABCDE', 'acrAB', 'acrREF', 'mrkD', 'iutA', 'entABCDEFHS', 'treC',
       'pgaABC', 'ureDABCEFG']

capsule_type = ['Capsule K2', 'Capsule K3','Capsule K123']
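Because these lists are typed by hand, a misspelt gene name would only surface later as a KeyError when the columns are extracted. As a minimal sketch, a small check can report any names that are absent from the dataframe first. The helper name missing_columns and the toy dataframe are assumptions made for illustration.

```python
import pandas as pd


def missing_columns(frame, columns):
    """Return the requested column names that are absent from the frame."""
    return [name for name in columns if name not in frame.columns]


# Hypothetical two-strain table used only to exercise the helper.
toy = pd.DataFrame({"OqxA": [1, 0], "clpK": [0, 2]}, index=["S1", "S2"])

print(missing_columns(toy, ["OqxA", "clpK"]))
print(missing_columns(toy, ["OqxA", "Capsule K2"]))
```

With the tutorial data, the same helper could be called once per list, for example missing_columns(df, amr_columns_to_extract).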

Next, we will assign a unique colour to each category of gene. To do this, we create a dictionary in which each key matches a value encoded in the raw CSV file and each value is a unique hex colour code. We then use the ListedColormap class, imported from matplotlib.colors at the beginning, to build a colour map with a list comprehension. This colour map can then be passed as an argument to the cmap parameter in the call to sns.heatmap.

cmap_dict = {0: '#E9EEEC', 1: '#6B83A9',2: '#868B89'}
cmap = ListedColormap([cmap_dict[i] for i in range(0,3)])

cmap_dict_2 = {3: '#D43B34', 4: '#79D170',5: '#E9EEEC'}
cmap2 = ListedColormap([cmap_dict_2[i] for i in range(3,6)])
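One caveat is worth illustrating. A colour map on its own does not tie a code to a colour: the data limits decide which colour each value receives. If the limits are taken from the data and a set of columns happens to contain only some of the codes, the colours can shift. The sketch below uses Matplotlib's Normalize class to show the effect; the scenario of a subset containing only the codes 0 and 1 is an assumption chosen for illustration.

```python
from matplotlib.colors import ListedColormap, Normalize, to_hex

cmap_dict = {0: '#E9EEEC', 1: '#6B83A9', 2: '#868B89'}
cmap = ListedColormap([cmap_dict[i] for i in range(0, 3)])

# Limits taken from the data: a subset that holds only the codes 0 and 1.
auto_norm = Normalize(vmin=0, vmax=1)

# Limits pinned to the full code range, whatever the subset contains.
fixed_norm = Normalize(vmin=0, vmax=2)

# Colour assigned to code 1 under each set of limits.
auto_colour = to_hex(cmap(auto_norm(1)))
fixed_colour = to_hex(cmap(fixed_norm(1)))

print("code 1 with data-driven limits:", auto_colour)
print("code 1 with pinned limits:     ", fixed_colour)
```

With pinned limits, code 1 keeps the colour assigned to it in cmap_dict. In the heatmap calls this can be achieved by passing vmin=0, vmax=2 alongside cmap, and vmin=3, vmax=5 alongside cmap2. The calls in this tutorial do not set these parameters, which is fine as long as every set of columns contains the lowest and highest code.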

Figure grid pattern organisation

The next task involves deciding how the heatmap should be organised. We will divide the heatmap into 3 sections, which include an initial heatmap for the antibiotic resistance genes, followed by a heatmap for the virulence genes and finally a heatmap for the capsule type.

To divide the heatmap into these 3 sections, we will use the subplot2grid function from the matplotlib.pyplot module.

fig = plt.figure(figsize=(15, 4))

ax1 = plt.subplot2grid((1, 12), (0, 0), colspan=6)
ax2 = plt.subplot2grid((1, 12), (0, 6), colspan=3)
ax3 = plt.subplot2grid((1, 12), (0, 9), colspan=2)

First, we decide the figure size. The first tuple in each subplot2grid call gives the number of rows and columns for the whole figure, the second tuple gives the starting position of the section, and the colspan argument sets how many columns the section spans. This produces the layout shown below.
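To see how the grid arguments translate into panel widths without drawing any data, the layout can be built on its own and the position of each axes printed. This is a sketch under a few assumptions: the dictionary keys are illustrative names, and the Agg backend is selected only so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(15, 4))
grid = (1, 12)  # 1 row, 12 columns for the whole figure

# Section name -> (starting column, number of columns spanned).
panels = {"amr": (0, 6), "virulence": (6, 3), "capsule": (9, 2)}

axes = {
    name: plt.subplot2grid(grid, (0, start), colspan=span)
    for name, (start, span) in panels.items()
}

# Report where each panel starts and how wide it is, in figure coordinates.
for name, ax in axes.items():
    bounds = ax.get_position()
    print(name, round(bounds.x0, 3), round(bounds.width, 3))
```

The three spans add up to 11 of the 12 grid columns, so the final column is left empty on the right of the figure.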

We can now begin work on the first section of the heatmap. We will begin working on ax1.

In the call to sns.heatmap, we first extract the columns listed in amr_columns_to_extract from the original dataframe, pass the custom colour map to the cmap parameter, remove the colour bar, set a line width and line colour, and add an alpha value to give the colours a softer tone.

We then loop through the x-axis tick labels, assigning them a different font and italicising them. By convention, gene names are italicised in genetics. We also loop through the y-axis tick labels to set their font, and give the tick labels on both axes a size of 10.5.

The legend elements that follow form the key for the figure. We create a list of Patch objects, each assigned a face colour, an edge colour and a corresponding label. The list is assigned to the variable legend_elements, which is then passed to the handles parameter of the ax1.legend method. The 5 legend elements are arranged in a single row by setting ncol to 5.

Finally, we use axhline and axvline to add a black border around the first heatmap to improve clarity.
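As a sketch of an alternative way to build the same key, the labels and colours can be kept in one dictionary and the Patch objects generated from it. The dictionary name legend_colours is an assumption; the labels, colours and edge colour are those used in this tutorial.

```python
from matplotlib.patches import Patch

# Label -> face colour, in the order the entries should appear in the key.
legend_colours = {
    'Present on chromosome': '#868B89',
    'Present on plasmid': '#6B83A9',
    'Capsule K1': '#D43B34',
    'Capsule K2/K57': '#79D170',
    'Absent': '#E9EEEC',
}

# One Patch per entry, all sharing the same edge colour.
legend_elements = [
    Patch(facecolor=colour, edgecolor='#323436', label=label)
    for label, colour in legend_colours.items()
]

print([patch.get_label() for patch in legend_elements])
```

Keeping the pairs in one place makes it easier to check that the key agrees with the colour dictionaries defined earlier.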

fig = plt.figure(figsize=(15, 4))

ax1 = plt.subplot2grid((1, 12), (0, 0), colspan=6)
ax2 = plt.subplot2grid((1, 12), (0, 6), colspan=3)
ax3 = plt.subplot2grid((1, 12), (0, 9), colspan=2)

ax1.set_title('Antimicrobial resistance genes',size=12, pad=30, fontname='Ubuntu')

ax1 = sns.heatmap(df[amr_columns_to_extract], cmap=cmap, cbar=False,linewidths=0.4, 
                  linecolor='black', alpha=0.8, ax=ax1)

for tick in ax1.get_xticklabels():
    tick.set_fontname('Ubuntu')
    tick.set_style('italic')
ax1.tick_params(axis='x', labelsize=10.5)

for tick in ax1.get_yticklabels():
    tick.set_fontname('Ubuntu')
ax1.tick_params(axis='y', labelsize=10.5)

ax1.set_ylabel("Strain", fontname='Ubuntu', fontsize=14)

legend_elements = [ Patch(facecolor='#868B89', edgecolor='#323436',
                         label='Present on chromosome'),
                   Patch(facecolor='#6B83A9', edgecolor='#323436',
                         label='Present on plasmid'),
                  Patch(facecolor='#D43B34', edgecolor='#323436',
                         label='Capsule K1'),
                  Patch(facecolor='#79D170', edgecolor='#323436',
                         label='Capsule K2/K57'),
                  Patch(facecolor='#E9EEEC', edgecolor='#323436',
                         label='Absent')]

ax1.legend(handles=legend_elements, 
          bbox_to_anchor=[0.9, -0.7], 
          title='Strain key', 
          ncol=5,
          frameon=False,
          prop={'size': 12, 'family': 'Ubuntu'},
          title_fontsize=12,
          handleheight=3, 
          handlelength=3,
          handletextpad=0.5,
          labelspacing=1.2,
          loc='center')

ax1.axhline(y=0, color='k',linewidth=3)
# The bottom border sits at the number of rows (strains), i.e. shape[0]
ax1.axhline(y=df[amr_columns_to_extract].shape[0], color='k',linewidth=3)
ax1.axvline(x=0, color='k',linewidth=3)
ax1.axvline(x=df[amr_columns_to_extract].shape[1], color='k',linewidth=3); 

This produces the first section of the heatmap.
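The border logic can also be wrapped in a small helper so that the row and column counts are not mixed up: horizontal lines are placed at row positions (0 and the number of rows), and vertical lines at column positions (0 and the number of columns). The sketch below is hypothetical; the function name add_border and the dummy 4 x 7 grid are assumptions used only to demonstrate it.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np


def add_border(ax, n_rows, n_cols, color='k', linewidth=3):
    """Draw lines along the four edges of an n_rows x n_cols heatmap."""
    return [
        ax.axhline(y=0, color=color, linewidth=linewidth),       # top edge
        ax.axhline(y=n_rows, color=color, linewidth=linewidth),  # bottom edge
        ax.axvline(x=0, color=color, linewidth=linewidth),       # left edge
        ax.axvline(x=n_cols, color=color, linewidth=linewidth),  # right edge
    ]


# Dummy grid with 4 rows and 7 columns standing in for a heatmap.
fig, ax = plt.subplots()
data = np.zeros((4, 7))
ax.pcolormesh(data)
lines = add_border(ax, *data.shape)
```

For the tutorial data, the call would look like add_border(ax1, *df[amr_columns_to_extract].shape).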

We can now add the code for the second section of the heatmap. Here, the key virulence genes are extracted as a second dataframe and drawn on the ax2 axes in the call to sns.heatmap. Crucially, for the second heatmap, we remove the y-axis tick labels and the axis label, as these are already provided by the first heatmap. Horizontal and vertical lines are again added to provide definition to the figure.

ax2 = sns.heatmap(df[virulence_columns_to_extract], cmap=cmap, cbar=False,linewidths=0.4, 
                  linecolor='black', alpha=0.8, ax=ax2)

ax2.set_title('Virulence genes',size=12, pad=30, fontname='Ubuntu')

for tick in ax2.get_xticklabels():
    tick.set_fontname('Ubuntu')
    tick.set_style('italic')
ax2.tick_params(axis='x', labelsize=10.5)

for tick in ax2.get_yticklabels():
    tick.set_visible(False)

ax2.tick_params(left=False)
ax2.set_ylabel('')    

ax2.axhline(y=0, color='k',linewidth=3)
ax2.axhline(y=df[virulence_columns_to_extract].shape[0], color='k',linewidth=3)
ax2.axvline(x=0, color='k',linewidth=3)
ax2.axvline(x=df[virulence_columns_to_extract].shape[1], color='k',linewidth=3); 

This produces the following heatmap.
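Hiding the y-axis tick labels can also be done without a loop. As a sketch, the labelleft argument of tick_params switches the labels off in the same call that removes the tick marks. The two-panel figure and strain names below are placeholders for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

fig, (left_ax, right_ax) = plt.subplots(1, 2)
for ax in (left_ax, right_ax):
    ax.set_yticks([0.5, 1.5])
    ax.set_yticklabels(["Strain A", "Strain B"])  # placeholder names

# Hide both the tick marks and the tick labels on the second panel only.
right_ax.tick_params(left=False, labelleft=False)
right_ax.set_ylabel('')

left_visible = [t.label1.get_visible() for t in left_ax.yaxis.get_major_ticks()]
right_visible = [t.label1.get_visible() for t in right_ax.yaxis.get_major_ticks()]
print(left_visible, right_visible)
```

Either approach gives the same visual result; the loop used in this tutorial is more explicit about which labels are being hidden.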

The third section of the heatmap can now be added, using the same techniques discussed. This time, different columns from the original dataframe are extracted.

ax3 = sns.heatmap(df[capsule_type], cmap=cmap2, cbar=False,linewidths=0.4, 
                  linecolor='black', alpha=0.8, ax=ax3)

ax3.set_title('Capsule type',size=12, pad=30, fontname='Ubuntu')

for tick in ax3.get_xticklabels():
    tick.set_fontname('Ubuntu')
    tick.set_style('italic')
ax3.tick_params(axis='x', labelsize=10.5)

for tick in ax3.get_yticklabels():
    tick.set_visible(False)

ax3.tick_params(left=False)
ax3.set_ylabel('')

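ax3.axhline(y=0, color='k',linewidth=3)
ax3.axhline(y=df[capsule_type].shape[0], color='k',linewidth=3)
ax3.axvline(x=0, color='k',linewidth=3)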
ax3.axvline(x=df[capsule_type].shape[1], color='k',linewidth=3); 

Combining all of the code produces the following figure. Through simple figure organisation and colour coordination with a key, a publication-quality heatmap is ready for your next piece of work!
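To use the figure in a manuscript it needs to be written to disk. The following is a minimal sketch using Matplotlib's savefig method; the file name, the resolution of 300 dpi and the temporary output folder are assumptions, and the small line plot is only a placeholder for the heatmap figure.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt

# Placeholder figure; in the tutorial this would be the heatmap figure `fig`.
fig, ax = plt.subplots(figsize=(3, 2))
ax.plot([0, 1], [0, 1])

# Hypothetical output location and file name.
out_path = os.path.join(tempfile.mkdtemp(), "heatmap.png")

# dpi controls the resolution; bbox_inches='tight' trims surrounding whitespace.
fig.savefig(out_path, dpi=300, bbox_inches="tight")
print("Saved to", out_path)
```

Journals differ in their requirements, so the format and resolution should be checked against the relevant author guidelines.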

Conclusion

Through the use of Python plotting libraries and the subplot2grid function from matplotlib.pyplot, informative heatmaps can be generated to provide a comprehensive data summary. Combining the heatmaps with a key can also increase the level of information conveyed. While the example presented here is specific to genetics, the techniques are broadly applicable and will work for other datasets provided the raw data is encoded appropriately. The code, CSV file and image to support this tutorial can be found on my GitHub repository.

