Source code for aigct.plot_util

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


[docs] def barchart(df: pd.DataFrame, x: str, y: str, x_label: str = '', y_label: str = '', ci: pd.DataFrame = None, filename: str = None, hatch: str = None, edgecolor=None, palette=None, errorbar=None, estimator: str = 'mean', title: str = '', title_fontsize: int = 20, y_label_fontsize: int = 15, x_label_fontsize: int = 15, y_tick_label_size: int = 15, x_tick_label_size: int = 15, xtick_rotation: int = None, xtick_rotation_mode: str = None, file_dpi: int = 300, bbox_inches: str = 'tight'): fig, ax = plt.subplots(figsize=(6, 4)) ax = sns.barplot(df, x=x, y=y, palette=palette, hatch=hatch, edgecolor=edgecolor, errorbar=errorbar, estimator=estimator, hue=x) ax.tick_params('y', labelsize=y_tick_label_size) ax.tick_params('x', labelsize=x_tick_label_size) ax.set_ylabel(y_label, fontsize=y_label_fontsize) ax.set_xlabel(x_label, fontsize=x_label_fontsize) if xtick_rotation: ax.set_xticklabels(ax.get_xticklabels(), rotation=xtick_rotation, ha="right", rotation_mode=xtick_rotation_mode) if ci is not None and len(ci) != 0: plt.errorbar(x=df[x], y=df[y], yerr=ci["CI"], fmt="none", c="k") plt.title(title, fontsize=title_fontsize) if filename is None: plt.show() return plt.savefig(filename, dpi=file_dpi, bbox_inches=bbox_inches) plt.savefig(filename.replace('.png', '.svg'), bbox_inches=bbox_inches)
[docs] def create_feature_palette(values: list) -> dict: if len(values) > 255: colors = sns.color_palette('husl', len(values)) else: # Get a list of all available colors in Seaborn's color palette all_colors = sns.color_palette("husl", 256) # Randomly select colors from the list random_colors = np.random.choice(range(len(all_colors)), len(values), replace=False) colors = [all_colors[i] for i in random_colors] return dict(zip(values, colors))
[docs] def get_colors(num_colors: int) -> list: if num_colors > 255: colors = sns.color_palette('husl', num_colors) else: # Get a list of all available colors in Seaborn's color palette all_colors = sns.color_palette("husl", 256) # Randomly select colors from the list random_colors = np.random.choice(range(len(all_colors)), num_colors, replace=False) colors = [all_colors[i] for i in random_colors] return colors
[docs] def heatmap(df: pd.DataFrame, y_column: str = None, x_column: str = None, value_column: str = None, x_label: str = '', y_label: str = '', title: str = '', cmap: str = 'viridis', annot: bool = True, fmt: str = '.2f', linewidths: float = 0.5, title_fontsize: int = 20, label_fontsize: int = 15, tick_fontsize: int = 12, annot_fontsize: int = 10, figsize: tuple = (10, 8), filename: str = None, vmin: float = None, vmax: float = None, cbar: bool = True, cbar_kws: dict = None, square: bool = False, mask: pd.DataFrame = None, file_dpi: int = 300, bbox_inches: str = 'tight'): """ Plot a heatmap using seaborn. Parameters ---------- df : pd.DataFrame DataFrame to plot as heatmap y_column : str, optional Column to use for y-axis categories. If provided along with x_column and value_column, the function will pivot the data before plotting. x_column : str, optional Column to use for x-axis categories. If provided along with y_column and value_column, the function will pivot the data before plotting. value_column : str, optional Column containing values to plot. If provided along with x_column and y_column, the function will pivot the data before plotting. x_label : str, optional Label for x-axis y_label : str, optional Label for y-axis title : str, optional Title of the plot cmap : str, optional Colormap to use annot : bool, optional Whether to annotate cells with values fmt : str, optional Format string for annotations linewidths : float, optional Width of lines between cells title_fontsize : int, optional Font size for title label_fontsize : int, optional Font size for axis labels tick_fontsize : int, optional Font size for tick labels annot_fontsize : int, optional Font size for annotations figsize : tuple, optional Figure size (width, height) in inches filename : str, optional If provided, save figure to this file vmin : float, optional Minimum value for colormap scaling vmax : float, optional Maximum value for colormap scaling cbar : bool, optional Whether to draw a colorbar cbar_kws : dict, optional Additional arguments for colorbar square : bool, optional Whether to make cells square-shaped mask : pd.DataFrame, optional Boolean DataFrame of same shape as df, True values will not be plotted file_dpi : int, optional DPI for saved figure bbox_inches : str, optional Bounding box in inches for saved figure Returns ------- matplotlib.axes.Axes The matplotlib axes containing the plot """ if (x_column is not None and y_column is not None and value_column is not None): df = df.pivot(index=y_column, columns=x_column, values=value_column) fig, ax = plt.subplots(figsize=figsize) # Set default colorbar keywords if not provided if cbar_kws is None: cbar_kws = {'shrink': 0.8} # Create the heatmap hm = sns.heatmap(df, ax=ax, cmap=cmap, annot=annot, fmt=fmt, linewidths=linewidths, vmin=vmin, vmax=vmax, cbar=cbar, cbar_kws=cbar_kws, square=square, mask=mask, annot_kws={'size': annot_fontsize}) # Set labels and title ax.set_xlabel(x_label, fontsize=label_fontsize) ax.set_ylabel(y_label, fontsize=label_fontsize) ax.set_title(title, fontsize=title_fontsize) # Set tick parameters ax.tick_params(axis='both', labelsize=tick_fontsize) # Rotate x-axis labels if needed plt.xticks(rotation=45, ha='right') # Tight layout plt.tight_layout() # Save figure if filename is provided if filename: plt.savefig(filename, dpi=file_dpi, bbox_inches=bbox_inches) # Also save as SVG for vector graphics plt.savefig(filename.replace('.png', '.svg'), bbox_inches=bbox_inches) return ax