import matplotlib.pyplot as plt
# from IPython.display import HTML
from IPython.display import display
# import seaborn as sns
# import plotly.express as px
from matplotlib.ticker import MaxNLocator
from scipy.stats import linregress
import numpy as np
import pandas as pd
import os
from .model import (
VEAnalysisCalibrationResult,
VEAnalysisResult
)
from .util import Config
from .date_util import now_str_compact
from .file_util import (
unique_file_name,
create_folder
)
from .plot_util import (barchart, create_feature_palette,
get_colors, heatmap)
from .report_util import GeneMetricSorter
from adjustText import adjust_text
[docs]
class VEAnalysisPlotter:
    """Plot results of an analysis.

    The methods of this class render ROC and precision/recall curves,
    Mann-Whitney U bar charts, gene-level heatmaps, metric tables and
    calibration plots. Each plotting entry point takes an optional
    ``dir`` argument: when it is None the output is shown on screen,
    otherwise it is written to files.
    """
def __init__(self, config: Config):
    """
    Cache the plot-specific sections of the configuration.

    Parameters
    ----------
    config : Config
        Configuration object. Its ``roc_pr_line``, ``mwu_bar`` and
        ``calibration_line`` attributes are read here.
    """
    # NOTE(review): several plotting methods of this class read
    # ``self._config`` (line_width, line_style, file_dpi, bbox_inches),
    # but no ``_config`` attribute is assigned here -- confirm where it
    # is expected to come from before relying on those methods.
    self._roc_pr_config = config.roc_pr_line
    self._mwu_config = config.mwu_bar
    self._calibration_line_config = config.calibration_line
[docs]
def _plot_roc_curves(self, aucs: pd.DataFrame, user_vep_name: str,
                     curve_coords: pd.DataFrame,
                     batch_no: int, num_batches: int,
                     num_top_labelled_veps: int,
                     ves_color_palette: dict,
                     file_name: str = None):
    """
    Draw one batch of ROC curves on a single figure.

    Parameters
    ----------
    aucs : pd.DataFrame
        One row per curve; the SCORE_SOURCE, SOURCE_NAME and ROC_AUC
        columns are read. Rows whose ROC_AUC is missing are skipped.
    user_vep_name : str
        Not used by this method.
    curve_coords : pd.DataFrame
        Curve points; the SCORE_SOURCE, THRESHOLD, FALSE_POSITIVE_RATE
        and TRUE_POSITIVE_RATE columns are read.
    batch_no : int
        Index of this batch; shown as ``batch_no + 1`` in the title.
    num_batches : int
        Total number of batches. The "(i of n)" title suffix is only
        added when this is greater than 1.
    num_top_labelled_veps : int
        If not None, only the first this-many plotted curves get a
        legend entry.
    ves_color_palette : dict
        Maps SOURCE_NAME to the line color.
    file_name : str, optional
        If None the figure is shown. Otherwise it is saved as PNG to
        this path and as SVG to the same path with an ``.svg``
        extension, and the figure is then closed.
    """
    title = self._roc_pr_config.roc_title + (
        f" ({batch_no+1} of {num_batches})" if num_batches > 1 else "")
    plt.figure(figsize=(10, 8))
    plot_count = 0
    for _, ve_auc in aucs.iterrows():
        if pd.isna(ve_auc['ROC_AUC']):
            continue
        ve_curve_coords = curve_coords[
            curve_coords['SCORE_SOURCE'] == ve_auc['SCORE_SOURCE']
        ].sort_values('THRESHOLD', ascending=False)
        if num_top_labelled_veps is None or \
                plot_count < num_top_labelled_veps:
            label = ve_auc['SOURCE_NAME'] + \
                ' (AUC=' + str(round(ve_auc['ROC_AUC'], 4)) + ')'
        else:
            label = None
        plt.plot(
            ve_curve_coords['FALSE_POSITIVE_RATE'],
            ve_curve_coords['TRUE_POSITIVE_RATE'],
            label=label,
            color=ves_color_palette[ve_auc['SOURCE_NAME']],
            lw=self._config.line_width,
            linestyle=self._config.line_style)
        plot_count += 1
    plt.xlabel('False Positive Rate',
               fontsize=self._roc_pr_config.x_axis_font_size)
    plt.ylabel('True Positive Rate',
               fontsize=self._roc_pr_config.y_axis_font_size)
    plt.tick_params(axis='both', labelsize=self._roc_pr_config.label_size)
    plt.title(title, fontsize=self._roc_pr_config.title_font_size)
    legend = plt.legend(loc="lower right",
                        fontsize=self._roc_pr_config.legend_font_size)
    for line in legend.get_lines():
        line.set_linewidth(self._roc_pr_config.legend_line_width)
    if file_name is None:
        plt.show()
        return
    # Swap only the extension: str.replace would also rewrite ".png"
    # occurring in a directory name.
    svg_file_name = os.path.splitext(file_name)[0] + ".svg"
    try:
        plt.savefig(file_name, dpi=self._config.file_dpi,
                    format='png',
                    bbox_inches=self._config.bbox_inches)
        plt.savefig(svg_file_name,
                    format='svg',
                    bbox_inches=self._config.bbox_inches)
    finally:
        # One figure is created per batch; release it once written so
        # figures do not pile up across batches.
        plt.close()
[docs]
def _plot_pr_curves(self, aucs: pd.DataFrame, user_vep_name: str,
                    curve_coords: pd.DataFrame,
                    batch_no: int, num_batches: int,
                    num_top_labelled_veps: int,
                    ves_color_palette: dict, file_name: str = None):
    """
    Draw one batch of precision/recall curves on a single figure.

    Parameters
    ----------
    aucs : pd.DataFrame
        One row per curve; the SCORE_SOURCE, SOURCE_NAME and PR_AUC
        columns are read. Rows whose PR_AUC is missing are skipped,
        mirroring the ROC curve plot.
    user_vep_name : str
        Not used by this method.
    curve_coords : pd.DataFrame
        Curve points; the SCORE_SOURCE, THRESHOLD, RECALL and PRECISION
        columns are read.
    batch_no : int
        Index of this batch; shown as ``batch_no + 1`` in the title.
    num_batches : int
        Total number of batches. The "(i of n)" title suffix is only
        added when this is greater than 1.
    num_top_labelled_veps : int
        If not None, only the first this-many plotted curves get a
        legend entry.
    ves_color_palette : dict
        Maps SOURCE_NAME to the line color.
    file_name : str, optional
        If None the figure is shown. Otherwise it is saved as PNG to
        this path and as SVG to the same path with an ``.svg``
        extension, and the figure is then closed.
    """
    title = self._roc_pr_config.pr_title + \
        (f" ({batch_no+1} of {num_batches})" if num_batches > 1 else "")
    plt.figure(figsize=(10, 8))
    plot_count = 0
    for _, ve_auc in aucs.iterrows():
        if pd.isna(ve_auc['PR_AUC']):
            continue
        ve_curve_coords = curve_coords[
            curve_coords['SCORE_SOURCE'] == ve_auc['SCORE_SOURCE']
        ].sort_values('THRESHOLD', ascending=False)
        if num_top_labelled_veps is None or \
                plot_count < num_top_labelled_veps:
            label = ve_auc['SOURCE_NAME'] + \
                ' (AUC=' + str(round(ve_auc['PR_AUC'], 4)) + ')'
        else:
            label = None
        # The first point (highest threshold after sorting) is left out
        # of the drawn curve; .iloc makes the positional slice explicit.
        plt.plot(
            ve_curve_coords['RECALL'].iloc[1:],
            ve_curve_coords['PRECISION'].iloc[1:],
            label=label,
            color=ves_color_palette[ve_auc['SOURCE_NAME']],
            lw=self._config.line_width,
            linestyle=self._config.line_style)
        plot_count += 1
    plt.xlabel('Recall', fontsize=self._roc_pr_config.x_axis_font_size)
    plt.ylabel('Precision', fontsize=self._roc_pr_config.y_axis_font_size)
    plt.tick_params(axis='both', labelsize=self._roc_pr_config.label_size)
    plt.title(title, fontsize=self._roc_pr_config.title_font_size)
    plt.ylim([0, 1.05])
    legend = plt.legend(loc="lower center",
                        fontsize=self._roc_pr_config.legend_font_size)
    for line in legend.get_lines():
        line.set_linewidth(self._roc_pr_config.legend_line_width)
    if file_name is None:
        plt.show()
        return
    # Swap only the extension: str.replace would also rewrite ".png"
    # occurring in a directory name.
    svg_file_name = os.path.splitext(file_name)[0] + ".svg"
    try:
        plt.savefig(file_name, dpi=self._config.file_dpi,
                    format='png',
                    bbox_inches=self._config.bbox_inches)
        plt.savefig(svg_file_name,
                    format='svg',
                    bbox_inches=self._config.bbox_inches)
    finally:
        # One figure is created per batch; release it once written so
        # figures do not pile up across batches.
        plt.close()
[docs]
def _display_mwu_table(self, results: VEAnalysisResult,
                       file_name: str = None):
    """
    Show or save the Mann-Whitney U summary table.

    The general metrics are inner-joined with the MWU metrics on
    SCORE_SOURCE and sorted by NEG_LOG10_MWU_PVAL, highest first.

    Parameters
    ----------
    results : VEAnalysisResult
        Result object; ``general_metrics`` and ``mwu_metrics`` are read.
    file_name : str, optional
        If given, the styled table is written there as HTML; otherwise
        it is displayed.
    """
    right_aligned = ["NUM_VARIANTS",
                     "NUM_POSITIVE_LABELS", "NUM_NEGATIVE_LABELS",
                     "NEG_LOG10_MWU_PVAL"]
    joined = results.general_metrics.merge(
        results.mwu_metrics, how="inner", suffixes=(None, "_y"),
        on='SCORE_SOURCE')
    table_output = joined[['SOURCE_NAME'] + right_aligned]
    table_output = table_output.sort_values('NEG_LOG10_MWU_PVAL',
                                            ascending=False)
    styler = table_output.style.set_properties(
        subset=right_aligned, **{"text-align": "right"})
    # Hide the row index and give the columns display headings.
    styler = styler.hide()
    styler = styler.relabel_index(
        ['VEP', "Variant Total", "Positive Labels", "Negative Labels",
         "MWU -log10(pval)"],
        axis=1)
    styler = styler.set_caption("Mann-Whitney U -log10(p value)")
    if not file_name:
        display(styler)
        return
    styler.to_html(file_name)
[docs]
def _display_pr_table(self, results: VEAnalysisResult,
                      file_name: str = None):
    """
    Show or save the precision/recall summary table.

    The general metrics are inner-joined with the PR metrics on
    SCORE_SOURCE and sorted by PR_AUC, highest first.

    Parameters
    ----------
    results : VEAnalysisResult
        Result object; ``general_metrics`` and ``pr_metrics`` are read.
    file_name : str, optional
        If given, the styled table is written there as HTML; otherwise
        it is displayed.
    """
    table_output = results.general_metrics.merge(
        results.pr_metrics, how="inner", suffixes=(None, "_y"),
        on='SCORE_SOURCE')[[
            'SOURCE_NAME', 'NUM_VARIANTS', 'NUM_POSITIVE_LABELS',
            'NUM_NEGATIVE_LABELS', 'PR_AUC']]
    table_output = table_output.sort_values('PR_AUC', ascending=False)
    # Right-align the numeric columns. This is applied once only, in
    # line with the ROC and MWU table methods.
    style = table_output.style.set_properties(
        subset=["NUM_VARIANTS",
                "NUM_POSITIVE_LABELS", "NUM_NEGATIVE_LABELS",
                "PR_AUC"], **{"text-align": "right"})
    style = style.hide().relabel_index(
        ['VEP', "Variant Total", "Positive Labels", "Negative Labels",
         "PR AUC"],
        axis=1).set_caption("Precision/Recall")
    if file_name:
        style.to_html(file_name)
    else:
        display(style)
[docs]
def _display_roc_table(self, results: VEAnalysisResult,
                       file_name: str = None):
    """
    Show or save the ROC summary table.

    The general metrics are inner-joined with the ROC metrics on
    SCORE_SOURCE and sorted by ROC_AUC, highest first.

    Parameters
    ----------
    results : VEAnalysisResult
        Result object; ``general_metrics`` and ``roc_metrics`` are read.
    file_name : str, optional
        If given, the styled table is written there as HTML; otherwise
        it is displayed.
    """
    right_aligned = ["NUM_VARIANTS",
                     "NUM_POSITIVE_LABELS", "NUM_NEGATIVE_LABELS",
                     "ROC_AUC"]
    joined = results.general_metrics.merge(
        results.roc_metrics, how="inner", suffixes=(None, "_y"),
        on='SCORE_SOURCE')
    table_output = joined[['SOURCE_NAME'] + right_aligned]
    table_output = table_output.sort_values('ROC_AUC', ascending=False)
    styler = table_output.style.set_properties(
        subset=right_aligned, **{"text-align": "right"})
    # Hide the row index and give the columns display headings.
    styler = styler.hide()
    styler = styler.relabel_index(
        ['VEP', "Variant Total", "Positive Labels", "Negative Labels",
         "ROC AUC"],
        axis=1)
    styler = styler.set_caption("ROC")
    if not file_name:
        display(styler)
        return
    styler.to_html(file_name)
[docs]
def _plot_mwu_bar(self, mwus: pd.DataFrame,
                  batch_no: int, num_batches: int,
                  file_name: str = None, palette=None):
    """
    Draw one batch of Mann-Whitney U results as a bar chart.

    Parameters
    ----------
    mwus : pd.DataFrame
        Rows to plot; SOURCE_NAME is used for the bars and
        NEG_LOG10_MWU_PVAL for their height.
    batch_no : int
        Index of this batch; shown as ``batch_no + 1`` in the title.
    num_batches : int
        Total number of batches. The "(i of n)" title suffix is only
        added when this is greater than 1.
    file_name : str, optional
        Passed to ``barchart`` as ``filename``.
    palette : optional
        Passed to ``barchart`` as ``palette``.
    """
    config = self._mwu_config
    title = config.title + \
        (f" ({batch_no+1} of {num_batches})" if num_batches > 1 else "")
    # The plotted column is the NEGATIVE log10 p-value, so the axis
    # label carries the minus sign, as the MWU table caption does.
    barchart(mwus, 'SOURCE_NAME', 'NEG_LOG10_MWU_PVAL',
             palette=palette,
             y_label='Mann-Whitney U -log10(p value)',
             filename=file_name, title=title,
             title_fontsize=config.title_font_size,
             y_label_fontsize=config.y_label_font_size,
             x_label_fontsize=config.x_label_font_size,
             y_tick_label_size=config.y_tick_label_size,
             x_tick_label_size=config.x_tick_label_size,
             xtick_rotation=config.xtick_rotation,
             xtick_rotation_mode=config.xtick_rotation_mode,
             file_dpi=self._config.file_dpi,
             bbox_inches=self._config.bbox_inches)
[docs]
def plot_pr_results(self, results: VEAnalysisResult,
                    num_top_labelled_veps: int,
                    ves_color_palette: dict,
                    dir: str = None):
    """
    Plot precision/recall curves in batches, then the PR table.

    The PR metrics are sorted by PR_AUC (highest first) and split into
    consecutive batches of ``num_curves_per_plot`` rows taken from the
    ROC/PR configuration; each batch becomes one figure.

    Parameters
    ----------
    results : VEAnalysisResult
        Analysis result object.
    num_top_labelled_veps : int
        Forwarded to the curve plot to limit the legend entries.
    ves_color_palette : dict
        Forwarded to the curve plot.
    dir : str, optional
        If given, figures are written to ``pr_curves_<batch>.png`` and
        the table to ``pr_table.html`` inside it; otherwise everything
        is shown on screen.
    """
    batch_size = self._roc_pr_config.num_curves_per_plot
    ranked = results.pr_metrics.sort_values('PR_AUC', ascending=False)
    plot_batches = [ranked.iloc[start:start + batch_size]
                    for start in range(0, len(ranked), batch_size)]
    num_batches = len(plot_batches)
    for batch_no, batch in enumerate(plot_batches):
        if dir is None:
            batch_file_name = None
        else:
            batch_file_name = os.path.join(
                dir, "pr_curves_" + str(batch_no) + ".png")
        self._plot_pr_curves(batch, results.user_vep_name,
                             results.pr_curve_coordinates,
                             batch_no, num_batches,
                             num_top_labelled_veps,
                             ves_color_palette, batch_file_name)
    if dir is None:
        pr_table_file_name = None
    else:
        pr_table_file_name = os.path.join(dir, "pr_table" + ".html")
    self._display_pr_table(results, pr_table_file_name)
[docs]
def plot_roc_results(self, results: VEAnalysisResult,
                     num_top_labelled_veps: int,
                     ves_color_palette: dict,
                     dir: str = None):
    """
    Plot ROC curves in batches, then the ROC table.

    The ROC metrics are sorted by ROC_AUC (highest first) and split
    into consecutive batches of ``num_curves_per_plot`` rows taken from
    the ROC/PR configuration; each batch becomes one figure.

    Parameters
    ----------
    results : VEAnalysisResult
        Analysis result object.
    num_top_labelled_veps : int
        Forwarded to the curve plot to limit the legend entries.
    ves_color_palette : dict
        Forwarded to the curve plot.
    dir : str, optional
        If given, figures are written to ``roc_curves_<batch>.png`` and
        the table to ``roc_table.html`` inside it; otherwise everything
        is shown on screen.
    """
    batch_size = self._roc_pr_config.num_curves_per_plot
    ranked = results.roc_metrics.sort_values('ROC_AUC', ascending=False)
    roc_metric_batches = [ranked.iloc[start:start + batch_size]
                          for start in range(0, len(ranked), batch_size)]
    num_batches = len(roc_metric_batches)
    for batch_no, batch in enumerate(roc_metric_batches):
        if dir is None:
            batch_file_name = None
        else:
            batch_file_name = os.path.join(
                dir, "roc_curves_" + str(batch_no) + ".png")
        self._plot_roc_curves(batch, results.user_vep_name,
                              results.roc_curve_coordinates,
                              batch_no, num_batches,
                              num_top_labelled_veps,
                              ves_color_palette, batch_file_name)
    if dir is None:
        roc_table_file_name = None
    else:
        roc_table_file_name = os.path.join(dir, "roc_table" + ".html")
    self._display_roc_table(results, roc_table_file_name)
[docs]
def plot_mwu_results(self, results: VEAnalysisResult,
                     ves_color_palette: dict,
                     dir: str = None):
    """
    Plot Mann-Whitney U bar charts in batches, then the MWU table.

    Only MWU rows whose EXCEPTION value is missing are charted. They
    are sorted by NEG_LOG10_MWU_PVAL (highest first) and split into
    consecutive batches of ``num_bars_per_plot`` rows taken from the
    MWU configuration; each batch becomes one chart.

    Parameters
    ----------
    results : VEAnalysisResult
        Analysis result object.
    ves_color_palette : dict
        Forwarded to the bar chart as its palette.
    dir : str, optional
        If given, charts are written to ``mwu_bar_<batch>.png`` and the
        table to ``mwu_table.html`` inside it; otherwise everything is
        shown on screen.
    """
    batch_size = self._mwu_config.num_bars_per_plot
    ranked = results.mwu_metrics.query('EXCEPTION.isna()').sort_values(
        'NEG_LOG10_MWU_PVAL', ascending=False)
    batches = [ranked.iloc[start:start + batch_size]
               for start in range(0, len(ranked), batch_size)]
    num_batches = len(batches)
    for batch_no, batch in enumerate(batches):
        if dir is None:
            batch_file_name = None
        else:
            batch_file_name = os.path.join(
                dir, "mwu_bar_" + str(batch_no) + ".png")
        self._plot_mwu_bar(batch, batch_no, num_batches, batch_file_name,
                           ves_color_palette)
    if dir is None:
        mwu_table_file_name = None
    else:
        mwu_table_file_name = os.path.join(dir, "mwu_table" + ".html")
    self._display_mwu_table(results, mwu_table_file_name)
[docs]
def plot_results(self, results: VEAnalysisResult,
                 metrics: str | list[str] = ["roc", "pr", "mwu"],
                 num_top_labelled_veps: int = None,
                 num_top_genes: int = None,
                 dir: str = None):
    """
    Plot the results of an analysis either to the screen or to
    files.

    Parameters
    ----------
    results : VEAnalysisResult
        Analysis result object
    metrics : str or list[str]
        Specifies which metrics to plot. Can be a string
        indicating a single metric or a list of strings for
        multiple metrics. The metrics are: roc, pr, mwu.
    num_top_labelled_veps : int
        If not None, only this many of the top performing veps will be
        labelled in the auc plot legends. This is useful when there are
        many veps and the legend becomes too cluttered.
    num_top_genes : int
        If compute_gene_metrics was set to True in call to compute_metrics,
        then only include this many top genes in the plot. The top genes
        are the ones for which the most variants were observed.
    dir : str, optional
        Directory to place the plot files. The files will
        be placed in a subdirectory off of this directory
        whose name begins with ve_analysis_plots and suffixed
        by a unique timestamp. If not specified will plot
        to screen.
    """
    # isinstance also accepts str subclasses; the default list is only
    # read, never mutated.
    if isinstance(metrics, str):
        metrics = [metrics]
    if dir is not None:
        dir = unique_file_name(dir, "ve_analysis_plots_")
        create_folder(dir)
    ves_color_palette = create_feature_palette(
        results.general_metrics["SOURCE_NAME"])
    # A metric is plotted only when it was requested AND computed.
    if "roc" in metrics and results.roc_metrics is not None:
        self.plot_roc_results(results, num_top_labelled_veps,
                              ves_color_palette, dir)
    if "pr" in metrics and results.pr_metrics is not None:
        self.plot_pr_results(results, num_top_labelled_veps,
                             ves_color_palette, dir)
    if "mwu" in metrics and results.mwu_metrics is not None:
        self.plot_mwu_results(results, ves_color_palette, dir)
    if results.gene_general_metrics is not None:
        self.plot_gene_results(results, metrics, num_top_genes, dir)
[docs]
def plot_gene_results(self, results: VEAnalysisResult,
                      metrics: list[str],
                      num_top_genes: int = None,
                      dir: str = None):
    """
    Plot gene-level results of an analysis.

    For each requested metric whose gene-level table is present on
    ``results``, one heatmap and one table are produced.

    Parameters
    ----------
    results : VEAnalysisResult
        Analysis result object containing gene-level metrics
    metrics : list[str]
        List of metrics to plot (roc, pr, mwu)
    num_top_genes : int
        Number of top genes to plot; passed to the GeneMetricSorter
        together with the per-gene unique variant counts.
    dir : str, optional
        Directory to place the plot files
    """
    ves_color_palette = create_feature_palette(
        results.gene_general_metrics["SOURCE_NAME"].unique())
    gene_metric_sorter = GeneMetricSorter(
        results.gene_unique_variant_counts_df, num_top_genes)
    # (metric key, results attribute, metric column, display name,
    #  title, heatmap file, table file)
    gene_plot_specs = (
        ("roc", "gene_roc_metrics", "ROC_AUC", "ROC AUC",
         "Gene-Level ROC",
         "gene_roc_heatmap.png", "gene_roc_table.html"),
        ("pr", "gene_pr_metrics", "PR_AUC", "PR AUC",
         "Gene-Level Precision/Recall",
         "gene_pr_heatmap.png", "gene_pr_table.html"),
        ("mwu", "gene_mwu_metrics", "NEG_LOG10_MWU_PVAL",
         "MWU -log10(pval)",
         "Gene-Level Mann-Whitney U",
         "gene_mwu_heatmap.png", "gene_mwu_table.html"),
    )
    for (metric_key, attr_name, metric_column, display_name, title,
         heatmap_file, table_file) in gene_plot_specs:
        if metric_key not in metrics:
            continue
        gene_metrics = getattr(results, attr_name)
        if gene_metrics is None:
            continue
        self.plot_gene_level_results(results.gene_general_metrics,
                                     gene_metrics,
                                     metric_column,
                                     display_name,
                                     title,
                                     gene_metric_sorter,
                                     ves_color_palette,
                                     dir,
                                     heatmap_file,
                                     table_file)
[docs]
def plot_gene_level_results(self, gene_general_metrics: pd.DataFrame,
                            gene_metrics: pd.DataFrame,
                            metric_column: str,
                            metric_display_name: str,
                            title: str,
                            gene_metric_sorter: GeneMetricSorter,
                            ves_color_palette: dict,
                            dir: str = None,
                            figure_file_name: str = None,
                            table_file_name: str = None):
    """
    Plot one gene-level metric as a heatmap and as a table.

    Parameters
    ----------
    gene_general_metrics : pd.DataFrame
        General gene-level metrics, joined with ``gene_metrics`` on
        SCORE_SOURCE and GENE_SYMBOL.
    gene_metrics : pd.DataFrame
        Gene-level values of the metric being plotted.
    metric_column : str
        Name of the metric column in ``gene_metrics``.
    metric_display_name : str
        Heading used for the color bar and the table column.
    title : str
        Title of the heatmap and caption of the table.
    gene_metric_sorter : GeneMetricSorter
        Its ``sort_gene_metrics`` is applied to the joined rows.
    ves_color_palette : dict
        Not used by this method.
    dir : str, optional
        If given, outputs are written inside it; otherwise they are
        shown on screen.
    figure_file_name : str, optional
        File name of the heatmap inside ``dir``.
    table_file_name : str, optional
        File name of the table inside ``dir``.
    """
    wanted_columns = ['SOURCE_NAME', 'GENE_SYMBOL', 'NUM_VARIANTS',
                      'NUM_POSITIVE_LABELS',
                      'NUM_NEGATIVE_LABELS', metric_column]
    joined = gene_general_metrics.merge(
        gene_metrics, how="inner", suffixes=(None, "_y"),
        on=['SCORE_SOURCE', 'GENE_SYMBOL'])
    gene_metrics = gene_metric_sorter.sort_gene_metrics(
        joined[wanted_columns])
    if dir is None:
        heatmap_file_name = None
    else:
        heatmap_file_name = os.path.join(dir, figure_file_name)
    # Rows without a metric value are left out of the heatmap only;
    # the table below still lists them.
    gene_metrics_heatmap = gene_metrics.query(f'{metric_column}.notna()')
    heatmap(gene_metrics_heatmap, 'SOURCE_NAME', 'GENE_SYMBOL',
            metric_column,
            title=title, x_label='Gene', y_label='VEP',
            cbar_kws={'label': metric_display_name,
                      'shrink': 0.8},
            filename=heatmap_file_name)
    # Create the gene-level table for this metric
    if dir is None:
        table_file_name = None
    else:
        table_file_name = os.path.join(dir, table_file_name)
    self._display_gene_metric_table(gene_metrics, metric_column,
                                    metric_display_name, title,
                                    table_file_name)
[docs]
def _display_gene_metric_table(self, metrics_df: pd.DataFrame,
                               metric_column: str,
                               metric_display_name: str,
                               title: str,
                               file_name: str = None):
    """
    Show or save a gene-level metric table.

    Parameters
    ----------
    metrics_df : pd.DataFrame
        Table to render. Its six columns are relabelled, in order, as
        VEP, Gene, Variant Total, Positive Labels, Negative Labels and
        ``metric_display_name``.
    metric_column : str
        Name of the metric column; right-aligned with the count columns.
    metric_display_name : str
        Heading shown for the metric column.
    title : str
        Table caption.
    file_name : str, optional
        If given, the styled table is written there as HTML; otherwise
        it is displayed.
    """
    right_aligned = ["NUM_VARIANTS",
                     "NUM_POSITIVE_LABELS", "NUM_NEGATIVE_LABELS",
                     metric_column]
    headings = ['VEP', "Gene", "Variant Total", "Positive Labels",
                "Negative Labels",
                metric_display_name]
    styler = metrics_df.style.set_properties(
        subset=right_aligned, **{"text-align": "right"})
    # Hide the row index and give the columns display headings.
    styler = styler.hide()
    styler = styler.relabel_index(headings, axis=1)
    styler = styler.set_caption(title)
    if not file_name:
        display(styler)
        return
    styler.to_html(file_name)
[docs]
def _plot_score_vs_pathogenic_fraction(self, axes, results,
                                       x_lower_limit: float = None,
                                       x_upper_limit: float = None,
                                       annotate: bool = False,
                                       dir: str = None):
    """
    Plot the fraction of pathogenic variants versus mean score as a line
    plot, with a dashed linear-regression line through the points.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        Matplotlib axes object to plot on
    results : VEAnalysisCalibrationResult
        Calibration result object containing score bins; its
        ``score_pathogenic_fraction_df`` is read (MEAN_SCORE,
        NUM_POSITIVE_LABELS, NUM_VARIANTS and, when annotating,
        SCORE_RANGE).
    x_lower_limit : float, optional
        Lower x-axis limit passed to ``set_xlim``.
    x_upper_limit : float, optional
        Upper x-axis limit passed to ``set_xlim``.
    annotate : bool, optional
        If True, annotate each point with the score range
    dir : str, optional
        Not used by this method.
    """
    config = self._calibration_line_config
    sorted_df = results.score_pathogenic_fraction_df.sort_values(
        "MEAN_SCORE")
    pathogenic_fraction = (
        sorted_df["NUM_POSITIVE_LABELS"] / sorted_df["NUM_VARIANTS"]
    ).round(1)
    mean_score = sorted_df["MEAN_SCORE"].round(2)
    # Plot pathogenic fraction vs mean score as a line plot
    axes.plot(mean_score, pathogenic_fraction,
              marker='o', linestyle='-', linewidth=2,
              markersize=6, color='blue')
    if annotate:
        # NOTE(review): annotations are anchored at the unrounded
        # values while the markers above use rounded ones -- confirm
        # this offset is intended.
        for row in sorted_df.itertuples():
            axes.annotate(row.SCORE_RANGE,
                          (row.MEAN_SCORE,
                           row.NUM_POSITIVE_LABELS/row.NUM_VARIANTS),
                          xytext=(5, 5),  # Offset from point
                          textcoords='offset points',
                          fontsize=8,
                          ha='left',
                          va='bottom',
                          bbox=dict(boxstyle='round,pad=0.3',
                                    facecolor='yellow',
                                    alpha=0.7),
                          rotation=45)  # Rotate text to avoid overlap
    # linregress raises ValueError unless there are at least two
    # distinct x values, so skip the trend line for degenerate input
    # (e.g. a single score bin) instead of failing the whole plot.
    if mean_score.nunique() > 1:
        fit = linregress(mean_score, pathogenic_fraction)
        regression_line_y = fit.slope * mean_score + fit.intercept
        axes.plot(mean_score, regression_line_y, linestyle='--',
                  color='gray', alpha=0.7)
    # Set labels and title
    axes.set_ylabel('Pathogenic Fraction', fontsize=config.label_size)
    axes.set_title(config.pathogenic_fraction_title,
                   fontsize=config.title_font_size)
    # Remove x tick values
    axes.set_xticklabels([])
    axes.tick_params(axis='y', labelsize=config.label_size)
    # Set axis limits
    axes.set_xlim(x_lower_limit, x_upper_limit)
    axes.set_ylim(0, 1.05)
    # Add grid
    axes.grid(True, linestyle='--', alpha=0.7)
[docs]
def _plot_precision_recall_vs_thresholds1(
        self, results: VEAnalysisCalibrationResult, dir: str = None):
    """
    Obsolete to be removed.

    Plot precision, recall, and F1 score versus threshold values, using
    the ``negative_pr_curve_coordinates`` of ``results``.

    Parameters
    ----------
    results : VEAnalysisCalibrationResult
        Calibration result object containing precision-recall curve data
    dir : str, optional
        Directory to save the plot file. If None the plot is shown.
    """
    coords = results.negative_pr_curve_coordinates.sort_values('THRESHOLD')
    thresholds = coords['THRESHOLD']
    precision = coords['PRECISION']
    recall = coords['RECALL']
    # The small constant keeps the denominator non-zero.
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)

    plt.figure(figsize=(10, 6))
    curves = (
        (precision, 'Precision', 'blue'),
        (recall, 'Recall', 'green'),
        (f1_scores, 'F1 Score', 'red'),
    )
    for values, curve_label, curve_color in curves:
        plt.plot(thresholds, values,
                 label=curve_label, color=curve_color, linewidth=2)

    plt.xlabel('Decision Threshold', fontsize=14)
    plt.ylabel('Score', fontsize=14)
    plt.title('Performance Metrics vs. Decision Threshold', fontsize=16)
    plt.legend(fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.ylim(0, 1.05)

    # Show on screen, or write PNG + SVG and release the figure.
    if dir is None:
        plt.show()
        return
    file_path = os.path.join(dir, "precision_recall_vs_threshold.png")
    plt.savefig(file_path, dpi=self._config.file_dpi,
                bbox_inches=self._config.bbox_inches)
    plt.savefig(file_path.replace(".png", ".svg"),
                format='svg',
                bbox_inches=self._config.bbox_inches)
    plt.close()
[docs]
def _plot_precision_recall_vs_threshold(
        self, results: VEAnalysisCalibrationResult,
        threshold_boundary: float = 0.5,
        precision_cutoff: float = 0.9,
        dir: str = None):
    """
    Obsolete to be removed.

    Plot precision and recall versus threshold for the negative
    ("Benign") and positive ("Pathogenic") curves on one figure.

    Parameters
    ----------
    results : VEAnalysisCalibrationResult
        Calibration result object containing precision-recall curve data
    threshold_boundary : float, optional
        Negative-curve points are kept when THRESHOLD is at or below
        this value, positive-curve points when at or above it.
    precision_cutoff : float, optional
        A horizontal guide line is drawn at this precision. For each
        curve that has more than one point and whose precision range
        includes the cutoff, a vertical guide line is drawn at the
        lowest threshold whose precision reaches it.
    dir : str, optional
        Directory to save the plot file. If None the plot is shown.
    """
    guide_style = dict(color='gray', linestyle='--', linewidth=1,
                       alpha=0.7)

    def draw_class_curves(coords, precision_label, recall_label, color):
        # Draw precision (solid) and recall (dash-dot) for one class,
        # plus the vertical cutoff guide when it applies.
        coords = coords.sort_values('THRESHOLD')
        precision = coords['PRECISION']
        recall = coords['RECALL']
        plt.plot(coords['THRESHOLD'], precision,
                 label=precision_label, color=color, linewidth=2)
        plt.plot(coords['THRESHOLD'], recall,
                 label=recall_label, color=color, linewidth=2,
                 linestyle='dashdot')
        if len(coords) <= 1:
            return
        if precision.max() >= precision_cutoff \
                and precision.min() <= precision_cutoff:
            plt.axvline(
                x=coords[precision >= precision_cutoff]['THRESHOLD'].min(),
                **guide_style)

    benign_coords = results.negative_pr_curve_coordinates.query(
        "THRESHOLD <= @threshold_boundary")
    plt.figure(figsize=(10, 6))
    draw_class_curves(benign_coords,
                      'Benign Precision', 'Benign Recall', 'blue')
    pathogenic_coords = results.positive_pr_curve_coordinates.query(
        "THRESHOLD >= @threshold_boundary")
    draw_class_curves(pathogenic_coords,
                      'Pathogenic Precision', 'Pathogenic Recall', 'red')
    plt.axhline(y=precision_cutoff, **guide_style)

    plt.xlabel('Decision Threshold', fontsize=14)
    plt.ylabel('Precision or Recall', fontsize=14)
    plt.title('Performance Metrics vs. Decision Threshold', fontsize=16)
    plt.legend(fontsize=12)
    plt.ylim(0, 1.05)

    # Show on screen, or write PNG + SVG and release the figure.
    if dir is None:
        plt.show()
        return
    file_path = os.path.join(dir, "precision_recall_vs_threshold.png")
    plt.savefig(file_path, dpi=self._config.file_dpi,
                bbox_inches=self._config.bbox_inches)
    plt.savefig(file_path.replace(".png", ".svg"),
                format='svg',
                bbox_inches=self._config.bbox_inches)
    plt.close()
[docs]
def _plot_score_vs_variant_counts(self, axes,
                                  results: VEAnalysisCalibrationResult,
                                  bins: int,
                                  x_lower_limit: float = None,
                                  x_upper_limit: float = None,
                                  dir: str = None):
    """
    Plot histograms showing the distribution of RANK_SCORE values
    for positive (BINARY_LABEL=1) and negative (BINARY_LABEL=0) variants.

    Both histograms are binned over the same score range so that their
    bars line up and the counts per bin are directly comparable.

    Parameters
    ----------
    axes : matplotlib.axes.Axes
        Matplotlib axes object to plot on
    results : VEAnalysisCalibrationResult
        Calibration result object containing variant scores and labels
        (``scores_and_labels_df`` is read).
    bins : int
        Number of bins to use for the histogram
    x_lower_limit : float, optional
        Lower x-axis limit passed to ``set_xlim``.
    x_upper_limit : float, optional
        Upper x-axis limit passed to ``set_xlim``.
    dir : str, optional
        Not used by this method.
    """
    config = self._calibration_line_config
    # Get the raw data with individual variant scores
    df = results.scores_and_labels_df
    all_scores = df['RANK_SCORE']
    # Separate positive and negative variants
    positive_variants = df[df['BINARY_LABEL'] == 1]['RANK_SCORE']
    negative_variants = df[df['BINARY_LABEL'] == 0]['RANK_SCORE']
    # Shared binning range: without it each hist() call would derive
    # its bin edges from its own data, giving misaligned bars. Falls
    # back to per-series binning when no score is available.
    lowest, highest = all_scores.min(), all_scores.max()
    if pd.notna(lowest) and pd.notna(highest):
        shared_range = (lowest, highest)
    else:
        shared_range = None
    # Define common histogram parameters
    alpha = 0.6
    # Plot histograms on the same axis with counts (not density)
    axes.hist(negative_variants, bins=bins, range=shared_range,
              alpha=alpha, color='blue',
              label=f'Benign variants (n={len(negative_variants)})')
    axes.hist(positive_variants, bins=bins, range=shared_range,
              alpha=alpha, color='red',
              label=f'Pathogenic variants (n={len(positive_variants)})')
    # Force y-axis ticks to be integers
    axes.yaxis.set_major_locator(MaxNLocator(integer=True))
    # Set axis limits
    axes.set_xlim(x_lower_limit, x_upper_limit)
    # Add formatting
    axes.tick_params(axis='both', labelsize=config.label_size)
    axes.set_xlabel('Score', fontsize=config.label_size)
    axes.set_ylabel('Variant Count', fontsize=config.label_size)
    axes.grid(True, linestyle='--', alpha=0.7)
    axes.legend(fontsize=config.legend_font_size, loc='upper left')
[docs]
def plot_calibration_curves(self, results: VEAnalysisCalibrationResult,
                            target_precision: float = None,
                            target_recall: float = None,
                            target_f1: float = None,
                            dir: str = None):
    """
    Plot the results of calling VEAnalyzer.compute_calibration_metrics.

    Two figures are produced:

    1. Pathogenic fraction by score, stacked above the distribution of
       variant scores by pathogenicity.
    2. Precision, recall, and F1 score versus threshold values.

    Parameters
    ----------
    results : VEAnalysisCalibrationResult
        Calibration result object returned by calling
        VEAnalyzer.compute_calibration_metrics.
    target_precision: float, optional
        If specified, will plot a vertical line at the threshold
        that achieves the target precision.
    target_recall: float, optional
        If specified, will plot a vertical line at the threshold
        that achieves the target recall.
    target_f1: float, optional
        Forwarded to the metrics-versus-threshold plot.
    dir : str, optional
        Directory to place the plot files. The files will
        be placed in a subdirectory off of this directory
        whose name begins with ve_calibration_plots and suffixed
        by a unique timestamp. If not specified will plot
        to screen.
    """
    output_dir = dir
    if output_dir is not None:
        # Write into a fresh, uniquely named sub-directory.
        output_dir = unique_file_name(output_dir, "ve_calibration_plots_")
        create_folder(output_dir)
    self._plot_binned_data(results, dir=output_dir)
    self._plot_metrics_vs_threshold(results, target_precision,
                                    target_recall, target_f1,
                                    dir=output_dir)
[docs]
def _plot_binned_data(self, results: VEAnalysisCalibrationResult,
                      dir: str = None):
    """
    Plot the binned data as 2 subplots vertically stacked.

    The top subplot shows pathogenic fraction versus score, the bottom
    one the variant-count histograms. Both share x-axis limits that
    extend 1% of the RANK_SCORE range beyond its minimum and maximum.

    Parameters
    ----------
    results : VEAnalysisCalibrationResult
        Calibration result object; ``scores_and_labels_df`` and
        ``score_pathogenic_fraction_df`` are read.
    dir : str, optional
        Directory to save the plot file. If None the plot is shown.
    """
    _, (top_axes, bottom_axes) = plt.subplots(nrows=2, ncols=1,
                                              figsize=(10, 8))
    rank_scores = results.scores_and_labels_df["RANK_SCORE"]
    min_score = rank_scores.min()
    max_score = rank_scores.max()
    margin = 0.01 * (max_score - min_score)
    x_lower_limit = min_score - margin
    x_upper_limit = max_score + margin

    self._plot_score_vs_pathogenic_fraction(top_axes, results,
                                            x_lower_limit,
                                            x_upper_limit,
                                            dir=dir)
    # One histogram bin per row of the pathogenic-fraction table.
    num_bins = len(results.score_pathogenic_fraction_df)
    self._plot_score_vs_variant_counts(
        bottom_axes, results, num_bins,
        x_lower_limit, x_upper_limit, dir
    )
    plt.tight_layout()

    # Show on screen, or write PNG + SVG and release the figure.
    if dir is None:
        plt.show()
        return
    file_path = os.path.join(dir, "pathogenic_fraction_by_score.png")
    plt.savefig(file_path, dpi=self._config.file_dpi,
                bbox_inches=self._config.bbox_inches)
    plt.savefig(file_path.replace(".png", ".svg"),
                format='svg',
                bbox_inches=self._config.bbox_inches)
    plt.close()
[docs]
def _plot_metrics_vs_threshold(
        self, results: VEAnalysisCalibrationResult,
        target_precision: float = None,
        target_recall: float = None,
        target_f1: float = None,
        dir: str = None):
    """
    Plot precision, recall, F1 score, versus threshold values.

    A dashed vertical marker, annotated with the threshold value, is
    added at the threshold of the maximum F1 score and, when requested,
    at the thresholds for the target recall and target precision.

    Parameters
    ----------
    results : VEAnalysisCalibrationResult
        Calibration result object; ``pr_curve_coordinates_df`` and
        ``f1_curve_coordinates_df`` are read. After sorting by
        THRESHOLD the last row of each is left out.
    target_precision : float, optional
        If given, mark the threshold of the smallest plotted precision
        that is at least this value (no marker if none qualifies).
    target_recall : float, optional
        If given, mark the threshold of the smallest plotted recall
        that is at least this value (no marker if none qualifies).
    target_f1 : float, optional
        Not read by this method; the F1 marker is placed at the
        maximum F1 score.
    dir : str, optional
        Directory to save the plot file. If None the plot is shown.
    """
    config = self._calibration_line_config

    def sorted_without_last(coords):
        # Order by threshold and leave out the final row.
        return coords.sort_values('THRESHOLD').iloc[:-1]

    def mark_threshold(threshold, y_anchor, color):
        # Dashed vertical line plus a boxed value label placed
        # 10 points to the right of it.
        plt.axvline(x=threshold,
                    color=color, linestyle='--', linewidth=2,
                    alpha=0.7)
        plt.annotate(
            f"{threshold:.2f}",
            xy=(threshold, y_anchor),
            xytext=(10, 0),
            textcoords='offset points',
            fontsize=config.label_size,
            color=color,
            ha='left',
            va='bottom',
            bbox=dict(boxstyle='round,pad=0.2', facecolor='white',
                      alpha=0.8)
        )

    def first_threshold_with(coords, column, value):
        # Threshold of the first row whose column equals value.
        return coords[coords[column] == value]["THRESHOLD"].iloc[0]

    plt.figure(figsize=(12, 8))
    pr_coords = sorted_without_last(results.pr_curve_coordinates_df)
    plt.plot(pr_coords['THRESHOLD'], pr_coords['PRECISION'],
             label='Precision', color='blue', linewidth=2)
    plt.plot(pr_coords['THRESHOLD'], pr_coords['RECALL'],
             label='Recall', color='green', linewidth=2)
    f1_coords = sorted_without_last(results.f1_curve_coordinates_df)
    plt.plot(f1_coords['THRESHOLD'], f1_coords['F1_SCORE'],
             label='F1 Score', color='purple', linewidth=2)

    if target_recall is not None:
        reached_recall = pr_coords[
            pr_coords["RECALL"] >= target_recall]['RECALL'].min()
        if not np.isnan(reached_recall):
            mark_threshold(
                first_threshold_with(pr_coords, "RECALL", reached_recall),
                config.recall_annotation_offset, 'green')

    max_f1 = f1_coords['F1_SCORE'].max()
    if not np.isnan(max_f1):
        mark_threshold(
            first_threshold_with(f1_coords, "F1_SCORE", max_f1),
            config.f1_annotation_offset, 'purple')

    if target_precision is not None:
        reached_precision = pr_coords[
            pr_coords["PRECISION"] >= target_precision]['PRECISION'].min()
        if not np.isnan(reached_precision):
            mark_threshold(
                first_threshold_with(pr_coords, "PRECISION",
                                     reached_precision),
                config.precision_annotation_offset, 'blue')

    plt.tick_params(axis='both', labelsize=config.label_size)
    plt.xlabel('Decision Threshold', fontsize=config.label_size)
    plt.ylabel('Metric Value', fontsize=config.label_size)
    plt.title(config.metrics_vs_threshold_title,
              fontsize=config.title_font_size)
    plt.legend(fontsize=config.legend_font_size, loc='best')
    plt.ylim(0, 1.05)

    # Show on screen, or write PNG + SVG and release the figure.
    if dir is None:
        plt.show()
        return
    file_path = os.path.join(dir, "metrics_vs_threshold.png")
    plt.savefig(file_path, dpi=self._config.file_dpi,
                bbox_inches=self._config.bbox_inches)
    plt.savefig(file_path.replace(".png", ".svg"),
                format='svg',
                bbox_inches=self._config.bbox_inches)
    plt.close()