Source code for votekit.matrices.heatmap

import numpy as np
from matplotlib.axes import Axes
from typing import Optional, Dict, List, Tuple, Union
from matplotlib.colors import Colormap, TwoSlopeNorm
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches


def _validate_heatmap_inputs(
    *,
    matrix: np.ndarray,
    row_labels: Optional[List[str]] = None,
    column_labels: Optional[List[str]] = None,
    row_legend: Optional[Dict[str, str]] = None,
    column_legend: Optional[Dict[str, str]] = None,
) -> None:
    """
    Used internally for the `matrix_heatmap` function to validate the inputs to the function.

    Args:
        matrix (np.ndarray): A 2D numpy array containing the data to be plotted.
        row_labels (Optional[List[str]]): A list of strings containing the labels for the rows
            of the heatmap. Defaults to None.
        column_labels (Optional[List[str]]): A list of strings containing the labels for the
            columns of the heatmap. Defaults to None.
        row_legend (Optional[Dict[str, str]]): A dictionary mapping row labels to legend
            descriptions. Defaults to None.
        column_legend (Optional[Dict[str, str]]): A dictionary mapping column labels to legend
            descriptions. Defaults to None.

    Raises:
        ValueError: If the matrix is not a 2D numpy array.
        ValueError: If the row labels are not of the correct length.
        ValueError: If the column labels are not of the correct length.
        ValueError: If the keys of the row legend are do not match the row labels.
        ValueError: If the keys of the column legend are do not match the column labels.
        ValueError: If the legend descriptions in the row legend and column legend are in conflict.
    """
    if matrix.ndim != 2:
        raise ValueError(
            f"Please provide a 2D matrix to plot. Found a {matrix.ndim}-D matrix."
        )

    n_rows = matrix.shape[0]
    n_cols = matrix.shape[1]

    if row_labels is not None:
        if n_rows != len(row_labels):
            raise ValueError(
                f"Please provide {n_rows} labels for the rows of the "
                f"matrix. Found {len(row_labels)} labels."
            )

    if column_labels is not None:
        if n_cols != len(column_labels):
            raise ValueError(
                f"Please provide {n_cols} labels for the columns of the "
                f"matrix. Found {len(column_labels)} labels."
            )

    if row_legend is not None and column_legend is not None:
        for label, description in column_legend.items():
            desc = row_legend.get(label, description)
            if desc != description:
                raise ValueError(
                    f"Conflicting legend descriptions for '{label}': "
                    f"got '{description}' and '{desc}'."
                )

    if row_legend is not None and list(row_legend.keys()) != row_labels:
        raise ValueError("Row labels do not match row legend keys.")

    if column_legend is not None and list(column_legend.keys()) != column_labels:
        raise ValueError("Column labels do not match column legend keys.")


def _add_text_to_heatmap(
    *,
    heatmap: Axes,
    n_decimals_to_display: int,
    matrix: np.ndarray,
    cell_font_size: Optional[int] = None,
) -> Axes:
    """
    Adds the text values to the heatmap cells. This function dynamically determines
    the font size based on the number of cells in the figure, the figure size, and
    the length of the numbers to ensure readability.

    Args:
        heatmap (matplotlib.axes.Axes): The matplotlib axis containing the heatmap.
        n_decimals_to_display (int): The number of decimal places to display for the values
            in the heatmap.
        matrix (np.ndarray): A 2D numpy array containing the data to be plotted.
        cell_font_size (Optional[int]): The base font size to use for the cell values.
            If None, the font size will be dynamically determined.

    Returns:
        matplotlib.axes.Axes: The matplotlib axis containing the heatmap with the cell values
            and text values added.
    """
    nrows, ncols = matrix.shape
    quadmesh = heatmap.collections[0]

    max_chars = max(
        len(f"{val:.{n_decimals_to_display}f}") + (3 if val < 0 else 2)
        for val in matrix.flatten()
    )

    if cell_font_size is not None:
        font_size = cell_font_size
    else:
        fig = heatmap.get_figure()
        if fig:
            fig.canvas.draw()

            # ignoring mypy error, mypy not up to date with matplotlib get_renderer
            renderer = fig.canvas.get_renderer()  # type: ignore[attr-defined]

            # bounding box in display coords
            bbox = heatmap.get_window_extent(renderer=renderer)
            width_pts = bbox.width

            cell_width_pts = width_pts / ncols

            font_size = int(cell_width_pts / max_chars)

    for i in range(nrows):
        for j in range(ncols):
            val = matrix[i, j]
            txt = "N/A" if np.isnan(val) else f"{val:.{n_decimals_to_display}f}"
            txt = "inf" if np.isinf(val) else f"{val:.{n_decimals_to_display}f}"

            # Normalize the cell value between 0 and 1, then get the RGBA color
            norm_val = quadmesh.norm(val)
            r, g, b, _ = quadmesh.cmap(norm_val)  # type: ignore[misc]

            # Simple brightness measure: average of R, G, B
            brightness = (r + g + b) / 3
            txt_color = "black" if brightness > 0.5 else "white"
            txt_color = "white" if np.isnan(val) else txt_color
            txt_color = "white" if np.isinf(val) else txt_color

            heatmap.text(
                j + 0.5,
                i + 0.5,
                txt,
                ha="center",
                va="center",
                color=txt_color,
                fontsize=font_size,
            )

    return heatmap


def _add_legend_to_heatmap(
    *,
    row_and_col_legends: Dict[str, str],
    ax: Axes,
    legend_font_size: float,
    legend_loc: str,
    legend_bbox_to_anchor: Tuple[float, float],
) -> Axes:
    """
    Adds a legend to a heatmap.

    Args:
        row_and_col_legends (Dict[str, str]): A dictionary mapping row and column labels to
            legend descriptions.
        ax (matplotlib.axes.Axes): The matplotlib axis to add the legend to.
        legend_font_size (float): The font size to use for the legend.
        legend_loc (str): The location to place the legend.
        legend_bbox_to_anchor (Tuple[float, float]): The bounding box to anchor the legend to.

    Returns:
        matplotlib.axes.Axes: The matplotlib axis containing with the updated legend.
    """
    proxy_artists = []
    proxy_labels = []

    if len(row_and_col_legends) != 0:
        for label, description in row_and_col_legends.items():
            patch = mpatches.Patch(color="white", label=f"{label}: {description}")
            proxy_artists.append(patch)
            proxy_labels.append(f"{label}: {description}")

    if proxy_artists:
        leg = ax.legend(
            handles=proxy_artists,
            labels=proxy_labels,
            loc=legend_loc,
            bbox_to_anchor=legend_bbox_to_anchor,
            fontsize=legend_font_size,
            ncol=len(proxy_labels) // 15 + 1,
            frameon=True,
            borderaxespad=0.0,
            handlelength=0,
            handletextpad=0,
            fancybox=True,
        )

        for item in leg.legend_handles:
            if item:
                item.set_visible(False)

    return ax


[docs] def matrix_heatmap( matrix: np.ndarray, *, ax: Optional[Axes] = None, show_cell_values: bool = True, n_decimals_to_display: int = 2, row_labels: Optional[List[str]] = None, row_label_rotation: Optional[float] = None, row_legend: Optional[dict[str, str]] = None, column_labels: Optional[List[str]] = None, column_label_rotation: Optional[float] = None, column_legend: Optional[dict[str, str]] = None, cell_color_map: Optional[Union[str, Colormap]] = None, cell_font_size: Optional[int] = None, cell_spacing: float = 0.5, cell_divider_color: str = "white", show_colorbar: bool = False, legend_font_size: float = 10.0, legend_location: str = "center left", legend_bbox_to_anchor: Tuple[float, float] = (1.03, 0.5), ) -> Axes: """ Basic function for plotting a matrix as a heatmap. Args: matrix (np.ndarray): A 2D numpy array containing the data to be plotted. ax (matplotlib.axes.Axes, optional): The matplotlib axis to plot on. Defaults to None, in which case an axis is created. show_cell_values (bool): Whether to show the values of the cells in the heatmap. These values are shown in the center of each cell and are dynamically formatted to be human-readable. Defaults to True. n_decimals_to_display (int): The number of decimal places to display for the values in the heatmap. Defaults to 2. row_labels (Optional(List[str])): A list of strings containing the labels for the rows of the heatmap. Defaults to None. row_label_rotation (Optional(float)): The rotation to apply to the row labels. Defaults to None. row_legend (Optional(Dict[str, str])): A dictionary mapping row labels to legend descriptions. Defaults to None. column_labels (Optional(List[str])): A list of strings containing the labels for the columns of the heatmap. Defaults to None. column_label_rotation (Optional(float)): The rotation to apply to the column labels. Defaults to None. column_legend (Optional(Dict[str, str])): A dictionary mapping column labels to legend descriptions. Defaults to None. cell_color_map (Optional(Union[str, matplotlib.colors.Colormap])): The color map to use for the heatmap. Defaults to `PRGn` if the matrix contains negative values and `Greens` otherwise. cell_font_size (Optional(int)): The font size to use for the cell values. Defaults to None, which will then use dynamic font size based on the number of cells and the figure size. cell_spacing (float): The spacing between the cells in the heatmap. Defaults to 0.5. cell_divider_color (str): The color to use for the cell dividers for spacing cells. Defaults to "white". show_colorbar (bool): Whether to show the colorbar for the heatmap. Defaults to False. legend_font_size (float): The font size to use for the legend. Defaults to 10.0. legend_location (str): The location to place the legend. Defaults to "center left". legend_bbox_to_anchor (Tuple[float, float]): The bounding box to anchor the legend to. Defaults to (1.03, 0.5). Returns: matplotlib.axes.Axes: The matplotlib axis containing the heatmap. """ if ax is None: _, ax = plt.subplots() _validate_heatmap_inputs( matrix=matrix, row_labels=row_labels, column_labels=column_labels, row_legend=row_legend, column_legend=column_legend, ) row_and_col_legends = dict() if row_legend is not None: row_and_col_legends.update(row_legend) if column_legend is not None: row_and_col_legends.update(column_legend) ax.xaxis.set_ticks_position("top") ax.yaxis.set_ticks_position("left") vmin = np.nanmin(matrix) vmax = np.nanmax(matrix) norm: Union[TwoSlopeNorm, str, None] = ( TwoSlopeNorm(vmin=vmin, vcenter=0.0, vmax=vmax) if (vmin < 0 < vmax) else None ) if cell_color_map is None: if np.nanmin(matrix) < 0: cell_color_map = sns.color_palette("PRGn", as_cmap=True) norm = TwoSlopeNorm( vmin=np.nanmin(matrix), vcenter=0.0, vmax=np.nanmax(matrix) ) else: cell_color_map = sns.color_palette("Greens", as_cmap=True) norm = "linear" heatmap = sns.heatmap( matrix, ax=ax, cmap=cell_color_map, norm=norm, fmt=f".{n_decimals_to_display}f", linewidths=cell_spacing, linecolor=cell_divider_color, cbar=show_colorbar, yticklabels=row_labels if row_labels is not None else False, xticklabels=column_labels if column_labels is not None else False, ) plt.gca().set_facecolor("black") # inf cells if show_cell_values: heatmap = _add_text_to_heatmap( heatmap=heatmap, n_decimals_to_display=n_decimals_to_display, matrix=matrix, cell_font_size=cell_font_size, ) if len(row_and_col_legends) > 0: heatmap = _add_legend_to_heatmap( row_and_col_legends=row_and_col_legends, ax=heatmap, legend_font_size=legend_font_size, legend_loc=legend_location, legend_bbox_to_anchor=legend_bbox_to_anchor, ) if column_label_rotation is not None: ax.set_xticklabels( ax.get_xticklabels(), rotation=column_label_rotation, ha="left", rotation_mode="anchor", ) if row_label_rotation is not None: ax.set_yticklabels( ax.get_yticklabels(), rotation=row_label_rotation, rotation_mode="anchor" ) return ax