@ZenithClown
Last active April 13, 2023 07:08
A set of Utility Function(s) to Visualize Correlation of a Data Frame

Correlation Visualization

better understand correlation with bar-charts and heatmaps

Colab Notebook

Understanding the correlations in your data is one of the first things to do with a new dataset. Computing them is as easy as calling df.corr(), but it is often difficult to pick out the important features from all the information in the resulting matrix. Typically, for business requirements, one uses sns.heatmap(df.corr()) to represent the correlation visually. For the Titanic dataset, the correlation heatmap looks like this:

Correlation Heatmap
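As a minimal sketch of this starting point, the snippet below computes a correlation matrix on a small synthetic frame. The column names and values are made up for illustration. Selecting the numeric columns first is an assumption on my part: it keeps the call working on pandas versions where `DataFrame.corr()` no longer skips text columns by default.

```python
import pandas as pd

# Hypothetical toy data; the column names and values are illustrative only.
df = pd.DataFrame({
    "age": [22, 38, 26, 35, 54],
    "fare": [7.25, 71.28, 7.92, 53.10, 51.86],
    "survived": [0, 1, 1, 1, 0],
    "sex": ["male", "female", "female", "female", "male"],
})

# Keep the numeric columns only, then compute pairwise Pearson correlation.
numeric = df.select_dtypes(include="number")
corr = numeric.corr(method="pearson")

print(corr.round(2))
```

The resulting `corr` frame is what would be handed to `sns.heatmap` for the baseline plot.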

For a small, well-known dataset this is adequate. Yet we can upgrade it to something more informative:

Better Heatmap

At a glance, we now have much more information: (i) only values whose absolute correlation meets a chosen threshold are annotated, and (ii) a bar plot shows the correlation of each feature against a target feature. Let's showcase some more functionality of the heatmap.py module.
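The threshold-based annotation in point (i) can be sketched with pandas alone. This is a vectorised variant of the idea, not the module's own code; the helper name `threshold_labels` and the matrix values are hypothetical.

```python
import pandas as pd

# Hypothetical correlation matrix; the values are illustrative only.
names = ["age", "fare", "survived"]
corr = pd.DataFrame(
    [[1.00, 0.10, -0.07],
     [0.10, 1.00, 0.26],
     [-0.07, 0.26, 1.00]],
    index=names, columns=names,
)

def threshold_labels(corr, thresh, ndigits=2):
    """Return a frame of strings: the rounded value where |corr| >= thresh, else ""."""
    rounded = corr.round(ndigits).astype(str)
    return rounded.where(corr.abs() >= thresh, "")

labels = threshold_labels(corr, thresh=0.2)
print(labels)
```

A frame like `labels` can be passed as `annot` to `sns.heatmap` together with `fmt = ""`, so that only the cells above the threshold carry text.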

Getting Started

The code is publicly available as a GitHub gist, a simple platform for sharing code snippets with the community. To use it, clone the gist:

git clone https://gist.github.com/ZenithClown/253fa1592acd024a2e7e022c99b6a154.git corr_heatmap
export PYTHONPATH="${PYTHONPATH}:corr_heatmap"

Done! You can now import the function in Python notebooks or scripts:

from heatmap import corr_heatmap
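If exporting PYTHONPATH is inconvenient, for example inside a notebook, a common alternative is to extend `sys.path` at runtime. The sketch below assumes the gist was cloned into a `corr_heatmap` directory as above; the helper name `add_to_path` is hypothetical.

```python
import sys
from pathlib import Path

def add_to_path(directory):
    """Append `directory` to sys.path once and return its absolute path."""
    resolved = str(Path(directory).resolve())
    if resolved not in sys.path:
        sys.path.append(resolved)
    return resolved

clone_dir = add_to_path("corr_heatmap")
# After this, `from heatmap import corr_heatmap` resolves if the clone exists.
```

Unlike the exported variable, this only affects the running interpreter.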

Advanced Imports

The heatmap.py file can also be downloaded at runtime, for example in a Colab notebook, using the requests module:

import requests

CODE_URI = "https://gist.githubusercontent.com/ZenithClown/253fa1592acd024a2e7e022c99b6a154/raw/c6232465acc24ca55d9ab5da0b02930db79daf2e/heatmap.py"
with open("heatmap.py", "w") as f:
    f.write(requests.get(CODE_URI).text)
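The snippet above writes whatever the server returns, including an error page. A slightly more defensive sketch follows. It swaps `requests` for the standard library's `urllib.request`, whose `urlopen` raises on HTTP errors, and loads the saved file with `importlib`. The helper names `download_module` and `load_module` are hypothetical.

```python
import importlib.util
import urllib.request
from pathlib import Path

def download_module(url, dest):
    """Save the text at `url` to `dest`; urlopen raises on HTTP errors."""
    with urllib.request.urlopen(url) as response:
        source = response.read().decode("utf-8")
    path = Path(dest)
    path.write_text(source, encoding="utf-8")
    return path

def load_module(path, name):
    """Import the Python file at `path` as a module called `name`."""
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

# Usage (needs network access and the plotting libraries installed):
# heatmap = load_module(download_module(CODE_URI, "heatmap.py"), "heatmap")
# fig = heatmap.corr_heatmap(data, target_column = "survived")
```

Pinning the URL to a specific revision, as `CODE_URI` does, keeps the downloaded code from changing unexpectedly.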

Better Visualization

The function has an extensive set of parameters that control most aspects of the plot. It also makes the heatmap display more appealing, for example:

Displaying either Upper/Lower Triangle

The tri = all|upper|lower argument uses np.triu or np.tril to create a mask for the heatmap plot. Thus, we can display only the upper or only the lower triangle of the heatmap:

fig = corr_heatmap(
    data,
    target_column = "survived", # display the correlation value of other features
    annot_thresh = 0.1, # based on which annotation is marked
    tri = "lower", # only display lower triangle
    round = 3, linewidths = 0.2, linecolor = "black", square = True, orient = "h", cmap = "RdYlBu"
)

tri = lower

fig = corr_heatmap(
    data,
    target_column = "survived", # display the correlation value of other features
    annot_thresh = 0.2, # based on which annotation is marked
    tri = "upper", # only display upper triangle
    round = 3, linewidths = 0.2, linecolor = "black", square = True, orient = "h", cmap = "RdYlBu"
)

tri = upper
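When building such masks, keep in mind that `sns.heatmap` hides the cells where the mask is True, so showing one triangle means masking the other. The sketch below builds both masks for a hypothetical 4-feature matrix. Hiding the diagonal as well is a design choice made here, not a requirement.

```python
import numpy as np

n = 4  # number of features in a hypothetical correlation matrix
ones = np.ones((n, n), dtype=bool)

# True = hidden. To show the strict upper triangle,
# hide the lower triangle and the diagonal.
show_upper = np.tril(ones)

# To show the strict lower triangle,
# hide the upper triangle and the diagonal.
show_lower = np.triu(ones)

print(show_upper.astype(int))
print(show_lower.astype(int))
```

To keep the diagonal visible, pass `k = -1` to `np.tril` or `k = 1` to `np.triu` so that the diagonal cells stay unmasked.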

This is just the beginning! Check the Colab notebook for a complete example.

Known Issues

  • To change the figure display options, edit the rcParams or use a custom stylesheet (recommended stylesheet) to control parameters like figsize.
  • Jupyter notebooks' auto-display shows the same figure twice. This is a known issue and can be avoided by:
    • either assigning the result to a variable like fig when plotting, or
    • passing _environment = "jupyter" in the function call.
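One way to change the figure size without touching the global rcParams is matplotlib's `rc_context`. The sketch below is illustrative only: the size values are arbitrary, and the `Agg` backend is selected just so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Temporarily override rcParams; figures created inside the block pick them up.
with plt.rc_context({"figure.figsize": (12, 5), "figure.dpi": 100}):
    fig, axs = plt.subplots(nrows=1, ncols=2)

width, height = fig.get_size_inches()
print(width, height)

# Closing the figure explicitly mirrors what the module does
# when `_environment = "jupyter"` is passed.
plt.close(fig)
```

Since `corr_heatmap` creates its figure with `plt.subplots`, calling it inside such a `with` block should pick up the same overrides.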
# -*- encoding: utf-8 -*-

"""
A set of Utility Function(s) to Visualize Correlation of a DataFrame

"Correlation is a statistical measure that expresses the extent to
which two variables are linearly related." Popular methods like
`Pearson's R` quantify the strength of relation between features.
Programmatically, there are various in-built functions like
`df.corr()` which calculate correlation.

Often, it is easier to just visualize the correlation information
with the help of Heat Maps to understand the relationship for many
variables. The code uses `seaborn` and `matplotlib` libraries to
showcase correlation heat-map of a dataframe.

@author: Debmalya Pramanik
"""

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def corr_heatmap(
        df : pd.DataFrame,
        plot_bar : bool = True,
        target_column : str = None,
        annot : bool = True,
        annot_thresh : float = 0.65,
        orient : str = "v",
        tri : str = "all",
        **kwargs
    ) -> plt.Figure:
    """
    Calculate Correlation of a DataFrame (`df`) and Plot Heat-Map

    The function uses the in-built `pandas` method `df.corr()` to
    calculate correlation of all numeric columns and then returns a
    visualization heatmap using `seaborn` for better understanding of
    inter-relationship between all features.

    :type  df: object
    :param df: Original dataframe on which correlation heatmap is to
               be performed. The function uses the in-built pandas
               method.

    :type  plot_bar: bool
    :param plot_bar: Bar plot of correlation of `target_column`
                     against all other related numeric features.
                     ! this parameter is not used currently
                     TODO set controls to plot bar based on user

    :type  target_column: str
    :param target_column: Name of the target column when plotting
                          correlation values in bar plot against all
                          other features available in the dataframe.
                          If `plot_bar == True` then this parameter
                          is required, else ignored by the function.

    :type  annot: bool
    :param annot: Annotate heatmap labels. Defaults to True. When set
                  to True, the `annot_thresh` argument can be used to
                  control which cells are annotated.

    :type  annot_thresh: float
    :param annot_thresh: Threshold value for annotation. The range of
                         value is [0, 1] i.e. for any given value (x)
                         the annotation is done only when correlation
                         is greater than or equal to `x` or less than
                         or equal to `-x`. Defaults to 0.65. This
                         parameter is passed to `corr_barplot()` as
                         `threshold` parameter for deciding important
                         features.

    :type  orient: str
    :param orient: Orientation (h|v) of the heatmap and the bar plot
                   of correlated terms whose value is
                   `abs(corr) >= annot_thresh`. Defaults to `v` i.e.
                   heatmap and barplot are stacked vertically, else
                   pass `h` for horizontal stacking.

    :type  tri: str
    :param tri: Which part of the correlation matrix to display.
                Accepted values (all|upper|lower) are used to build
                the `mask` parameter of `sns.heatmap`, and defaults
                to `all` i.e. both upper and lower triangle along
                with the diagonal are displayed in heatmap.

    Keyword Arguments
    -----------------
    The function accepts the keyword arguments of `df.corr()` and
    `sns.heatmap()` listed below. Additionally, the behaviour of the
    plot and correlation can be controlled with the below arguments.

        * *method* (`str`): Method of correlation. Accepts all values
          as supported by `df.corr()` function.

        * *min_periods* (`int`): Minimum number of observations
          required per pair of columns to have a valid result. Check
          documentation of `df.corr()` for more information.

        * *round* (`int`): Round a number to a given precision in
          decimal digits. Typically used in plot annotations.

        * *_environment* (`str`): Pass `jupyter` to close the figure
          before it is returned, which avoids duplicate display in
          notebooks. Defaults to `terminal`.

        * *vmin* (`float`): As accepted by `sns.heatmap` function.
        * *vmax* (`float`): As accepted by `sns.heatmap` function.
        * *cmap* (`str` or colormap): As accepted by `sns.heatmap` function.
        * *cbar* (`bool`): As accepted by `sns.heatmap` function.
        * *square* (`bool`): As accepted by `sns.heatmap` function.
        * *linecolor* (color): As accepted by `sns.heatmap` function.
        * *linewidths* (`float`): As accepted by `sns.heatmap` function.
    """

    corr = df.corr(
        method = kwargs.get("method", "pearson"),
        min_periods = kwargs.get("min_periods", 1)
    )

    # * get additional keyword arguments for control
    round_ = kwargs.get("round", 2)
    _environment = kwargs.get("_environment", "terminal")

    # if annotation is true, then define `labels`
    if annot:
        labels = corr.applymap(lambda x : str(round(x, round_)) if abs(x) >= annot_thresh else "")
    else:
        # annotation is not required in heatmap, so set default `None` for all labels
        labels = None

    # define masking for heatmap
    # `sns.heatmap` hides the cells where the mask is True, so to show
    # one triangle we mask the opposite triangle and the diagonal
    # https://stackoverflow.com/q/57414771/6623589
    if tri == "all":
        mask = None # default, show both upper and lower triangle
    elif tri == "upper":
        # https://numpy.org/doc/stable/reference/generated/numpy.tril.html
        mask = np.tril(np.ones_like(corr, dtype = bool)) # show upper triangle in heatmap
    elif tri == "lower":
        # https://numpy.org/doc/stable/reference/generated/numpy.triu.html
        mask = np.triu(np.ones_like(corr, dtype = bool)) # show lower triangle in heatmap
    else:
        raise ValueError(f"tri ( == {tri}) is not understood. Accepted values: all|upper|lower.")

    # plot the actual heatmap and set other attributes
    if orient == "h":
        # horizontally stack heatmap and barplot
        fig, axs = plt.subplots(nrows = 1, ncols = 2)
    elif orient == "v":
        # vertically stack heatmap and barplot
        fig, axs = plt.subplots(nrows = 2, ncols = 1)
    else:
        raise ValueError(f"orient ( == {orient}) is not understood. Accepted values: h|v.")

    # plot heatmap using seaborn library
    _ = sns.heatmap(
        corr,
        fmt = "",
        ax = axs[0],
        mask = mask,
        annot = labels,
        vmin = kwargs.get("vmin", None),
        vmax = kwargs.get("vmax", None),
        cmap = kwargs.get("cmap", None),
        cbar = kwargs.get("cbar", True),
        square = kwargs.get("square", False),
        linewidths = kwargs.get("linewidths", 0),
        linecolor = kwargs.get("linecolor", "white"),
    )

    # plot bar using defined corr_barplot()
    _ = corr_barplot(
        corr,
        target_column = target_column,
        keep_all_feature = kwargs.get("keep_all_feature", False),
        threshold = annot_thresh,
        ax = axs[1],
        round = round_,
        y_annot_pos_adjust = kwargs.get("y_annot_pos_adjust", (5e-4, -25e-3))
    )

    if _environment == "jupyter":
        # https://stackoverflow.com/q/35422988/6623589
        plt.close()

    return fig


def corr_barplot(
        correlations : pd.DataFrame,
        target_column : str,
        keep_all_feature : bool = False,
        threshold : float = 0.65,
        **kwargs
    ) -> plt.Figure:
    """
    Bar Plot the Correlation Values of all Features against Target

    Given a target column name from the `correlations = df.corr()`
    the function plots a bar plot, where the bar height is the
    correlation coefficient. The figure is returned only when the
    function is called standalone, i.e. without an `ax` keyword.

    :type  correlations: object
    :param correlations: Correlation values, typically this is
                         obtained by using the `df.corr()` function,
                         and is controlled externally.

    :type  target_column: str
    :param target_column: Name of the target column, must be present
                          in both column and index of the correlation
                          dataframe.

    :type  keep_all_feature: bool
    :param keep_all_feature: Keep all the numeric feature columns in
                             bar plot, or show only those columns
                             (or features) whose correlation is above
                             `threshold` or below `-threshold`.
                             Defaults to False, i.e. only essential
                             features are displayed.

    :type  threshold: float
    :param threshold: Threshold value based on which important
                      features are decided. The range of value is
                      [0, 1] i.e. for any given value (x) the
                      feature is important iff correlation value is
                      greater than or equal to `x` or less than or
                      equal to `-x`. Defaults to 0.65.

    Keyword Arguments
    -----------------
    The behaviour of the plot and correlation can be controlled with
    the below arguments.

        * *ascending* (`bool`): Sort the features correlation in
          ascending order. Defaults to True. This parameter is passed
          directly to `df.sort_values()` for sorting.

        * *y_annot_pos_adjust* (`array-like`): A pair of `(up, low)`
          values to adjust the annotation in bar plot. `up` is added
          to the text position of non-negative bars and `low` to the
          text position of negative bars.
    """

    # format the correlation dataframe
    correlations = correlations[target_column].reset_index() \
        .rename(columns = {"index" : "feature"}) \
        .sort_values(target_column, ascending = kwargs.get("ascending", True))
    correlations = correlations[correlations.feature != target_column]

    if not keep_all_feature:
        # keep features with `abs(corr) >= threshold`, as documented
        correlations = correlations[correlations[target_column].abs() >= threshold]

    # plot the bar, but first decide axis parameters
    ax = kwargs.get("ax", None)
    _environment = kwargs.get("_environment", "terminal")
    palette = sns.color_palette("RdYlBu", correlations.shape[0]).as_hex()

    if ax is not None:
        # we have axis object, no need for separate definition
        # ! no figure is returned, as controlled by the parent `fig` object
        fig = None
        ax = sns.barplot(
            x = "feature", y = target_column, data = correlations,
            palette = palette,
            ax = ax,
        )
    else:
        # this function is called in standalone mode
        # thus, we define axis object, and figure is returned
        fig, ax = plt.subplots(nrows = 1, ncols = 1)
        ax = sns.barplot(
            x = "feature", y = target_column, data = correlations,
            palette = palette,
            ax = ax,
        )

        if _environment == "jupyter":
            # https://stackoverflow.com/q/35422988/6623589
            plt.close()

    # get y-annotation position adjustment value from kwargs
    y_pos_u, y_pos_l = kwargs.get("y_annot_pos_adjust", (0, 0))
    corr_ = correlations[target_column].values.round(kwargs.get("round", 2))

    for tick in range(len(ax.get_xticklabels())):
        y_pos = corr_[tick] + (y_pos_u if corr_[tick] >= 0 else y_pos_l)
        ax.text(tick, y_pos, str(corr_[tick]), ha = "center", weight = "bold")

    ax.set(xlabel = kwargs.get("bar_xlabel", "Feature Names"))
    ax.set(ylabel = kwargs.get("bar_ylabel", f"Correlation Value with `{target_column}`"))

    # `fig` is None when an `ax` was supplied by the caller
    return fig
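
The data preparation inside `corr_barplot` can be followed in isolation with pandas alone. The sketch below is a rough illustration on a hypothetical correlation matrix (the names and values are made up) and follows the documented rule that a feature is kept when `abs(corr) >= threshold`.

```python
import pandas as pd

# Hypothetical correlation matrix; the values are illustrative only.
names = ["age", "fare", "pclass", "survived"]
corr = pd.DataFrame(
    [[1.00, 0.10, -0.37, -0.07],
     [0.10, 1.00, -0.55, 0.26],
     [-0.37, -0.55, 1.00, -0.34],
     [-0.07, 0.26, -0.34, 1.00]],
    index=names, columns=names,
)

target, threshold = "survived", 0.2

# One column of the matrix as a (feature, value) table, sorted by value.
table = (
    corr[target]
    .rename_axis("feature")
    .reset_index()
    .sort_values(target)
)

# Drop the target itself, then keep features with |corr| >= threshold.
table = table[table["feature"] != target]
table = table[table[target].abs() >= threshold]

print(table)
```

The rows of `table` are what end up as bars: one bar per remaining feature, ordered by correlation value.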