import math
import statistics
from pathlib import Path
from typing import Dict, List
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.transforms import blended_transform_factory
from BLPlot._heatmap import (
_draw_section,
_setup_heatmap_axes,
)
from BLPlot.plotter import Plotter, iter_datasets_with_runs, random_classifier_baseline
# Module-level side effect: importing this module sets matplotlib's
# "font.size" rcParam to 12 for subsequently created figures.
plt.rcParams["font.size"] = 12
def _load_auprc_ratios(
    dataset_ids: List[str],
    dataset_paths: List[Path],
    gt_paths: List[Path],
) -> Dict[str, Dict[str, float]]:
    """
    Compute median AUPRC ratio per algorithm per dataset.

    For each dataset, reads AUPRC.csv (first column is the algorithm index,
    remaining columns are per-run values), divides each per-run value by the
    random classifier baseline for that dataset, then takes the median across
    runs. A dataset is skipped with a printed warning when its CSV or its
    ground truth file is missing, or when its baseline is nan or zero.

    Parameters
    ----------
    dataset_ids : list[str]
        Ordered dataset identifiers.
    dataset_paths : list[Path]
        Output directory for each dataset (same order as dataset_ids).
    gt_paths : list[Path]
        Ground truth CSV path for each dataset (same order as dataset_ids).

    Returns
    -------
    dict[str, dict[str, float]]
        algo -> {dataset_id -> median AUPRC ratio}. Skipped datasets have no
        entry in the inner dicts. The ratio is nan when every run value of an
        algorithm is nan.
    """
    result: Dict[str, Dict[str, float]] = {}
    for dataset_id, dataset_path, gt_path in zip(dataset_ids, dataset_paths, gt_paths):
        csv_path = dataset_path / 'AUPRC.csv'
        if not csv_path.exists():
            print(f"Warning: {csv_path} not found, skipping.")
            continue
        if not gt_path.exists():
            print(f"Warning: {gt_path} not found, skipping.")
            continue

        baseline = random_classifier_baseline(gt_path)
        # A nan or zero baseline would make every ratio meaningless (or raise
        # on division), so the whole dataset is dropped.
        if math.isnan(baseline) or baseline == 0.0:
            print(f"Warning: random baseline undefined for {gt_path}, skipping.")
            continue

        df = pd.read_csv(csv_path, index_col=0)
        # Walk the rows directly instead of looking each label up with
        # df.loc[label]: a label lookup returns a DataFrame when the label is
        # repeated, which breaks the per-row handling below. Rows are visited
        # in file order, so a repeated label's later row replaces the earlier.
        for algo, row in df.iterrows():
            vals = [v for v in row.tolist() if not math.isnan(v)]
            if vals:
                ratio = statistics.median([v / baseline for v in vals])
            else:
                ratio = float('nan')
            result.setdefault(str(algo), {})[dataset_id] = ratio
    return result
def _load_spearman(
    dataset_ids: List[str],
    dataset_paths: List[Path],
) -> Dict[str, Dict[str, float]]:
    """
    Load median Spearman stability per algorithm per dataset.

    Reads Spearman.csv from each dataset's output directory, using the first
    column as the algorithm index and the 'MedianSpearman' column as the
    value. A dataset is skipped with a printed warning when the CSV is
    missing or does not contain a 'MedianSpearman' column.

    Parameters
    ----------
    dataset_ids : list[str]
        Ordered dataset identifiers.
    dataset_paths : list[Path]
        Output directory for each dataset (same order as dataset_ids).

    Returns
    -------
    dict[str, dict[str, float]]
        algo -> {dataset_id -> MedianSpearman}. Skipped datasets have no
        entry in the inner dicts; nan cells in the CSV are kept as nan.
    """
    result: Dict[str, Dict[str, float]] = {}
    for dataset_id, dataset_path in zip(dataset_ids, dataset_paths):
        csv_path = dataset_path / 'Spearman.csv'
        if not csv_path.exists():
            print(f"Warning: {csv_path} not found, skipping.")
            continue

        df = pd.read_csv(csv_path, index_col=0)
        # Treat a CSV without the expected column like a missing CSV, so one
        # malformed dataset does not abort the whole summary.
        if 'MedianSpearman' not in df.columns:
            print(f"Warning: {csv_path} has no 'MedianSpearman' column, skipping.")
            continue

        # Iterating the column yields one scalar per row even when an
        # algorithm label is repeated (the later row then wins), whereas a
        # df.loc[label, column] lookup would return a Series in that case.
        for algo, val in df['MedianSpearman'].items():
            result.setdefault(str(algo), {})[dataset_id] = float(val)
    return result
class PlotSummaryHeatmap(Plotter):
    """
    Replicates Figure 2 of Pratapa et al. 2020 (BEELINE).

    Produces a two-section heatmap: median AUPRC ratios (left) and median
    Spearman stability scores (right). Algorithms (rows) are sorted by
    decreasing median AUPRC ratio; datasets are columns. Each cell contains a
    rounded square sized and colored by its value. Values above the random
    predictor baseline (ratio >= 1) are drawn full-size with their raw value
    as white text. Alternating row backgrounds aid readability. Writes
    Summary.pdf and Summary.png to the output directory.
    """

    def __call__(self, config: dict, output_dir: Path, root: Path) -> None:
        """
        Generate the summary heatmap and write it to output_dir.

        Two files are written: Summary.pdf and Summary.png. Nothing is
        written (and a message is printed) when the config yields no
        datasets or when no AUPRC / Spearman data could be loaded.

        Parameters
        ----------
        config : dict
            Parsed YAML configuration.
        output_dir : Path
            Directory where Summary.pdf and Summary.png are written.
        root : Path
            Working directory from which config paths are resolved.

        Returns
        -------
        None

        Raises
        ------
        TypeError
            If config is not a dict, or output_dir or root is not a Path.
        """
        if not isinstance(config, dict):
            raise TypeError(f"config must be dict, got {type(config)}")
        if not isinstance(output_dir, Path):
            raise TypeError(f"output_dir must be Path, got {type(output_dir)}")
        if not isinstance(root, Path):
            raise TypeError(f"root must be Path, got {type(root)}")

        rows = list(iter_datasets_with_runs(config, root))
        if not rows:
            print("No datasets found for summary heatmap.")
            return

        dataset_ids = [r[0] for r in rows]
        # dataset_labels : list[str] — per-dataset display labels for column
        # headers; uses 'nickname' from config when set, else falls back to
        # dataset_id.
        dataset_labels = [r[1] for r in rows]
        dataset_paths = [r[2] for r in rows]
        gt_paths = [r[3] for r in rows]

        auprc_ratios = _load_auprc_ratios(dataset_ids, dataset_paths, gt_paths)
        spearman = _load_spearman(dataset_ids, dataset_paths)

        all_algos = sorted(set(auprc_ratios) | set(spearman))
        if not all_algos:
            print("No data found for summary heatmap.")
            return

        # Sort algorithms by decreasing median-of-medians AUPRC ratio.
        # Algorithms without any non-nan AUPRC ratio get key 0.0.
        def _median_auprc_ratio(algo: str) -> float:
            vals = [v for v in auprc_ratios.get(algo, {}).values() if not math.isnan(v)]
            return statistics.median(vals) if vals else 0.0

        sorted_algos = sorted(all_algos, key=_median_auprc_ratio, reverse=True)
        n_algos = len(sorted_algos)
        n_datasets = len(dataset_ids)

        # Build raw value arrays (n_algos × n_datasets), sorted_algos[0] = best.
        # Missing (algorithm, dataset) combinations become nan.
        auprc_arr = np.array([
            [auprc_ratios.get(a, {}).get(d, float('nan')) for d in dataset_ids]
            for a in sorted_algos
        ])
        spear_arr = np.array([
            [spearman.get(a, {}).get(d, float('nan')) for d in dataset_ids]
            for a in sorted_algos
        ])

        # --- Figure layout (mirrors old_heatmap.py) ---
        # Section 1 cols: x = 1 .. n_datasets
        # Gap:            x = n_datasets + 1
        # Section 2 cols: x = n_datasets + 2 .. n_datasets * 2 + 1
        total_cols = n_datasets * 2 + 1
        pad = 2
        height = 7
        asp_ratio = (total_cols + pad) / (n_algos + pad)
        fig_size = (height * asp_ratio + 0.5, height)

        fig = plt.figure(figsize=(fig_size[0], fig_size[1] + 0.5))
        # The figure is closed in `finally` so that a failure while drawing
        # or saving does not leave it open.
        try:
            ax = fig.add_subplot(1, 1, 1)
            _setup_heatmap_axes(ax, n_algos, sorted_algos, total_cols, pad)

            # Measure the maximum y-tick label width in axes-fraction
            # coordinates so the row backgrounds extend precisely to cover
            # the algorithm labels. draw() forces text layout before
            # get_window_extent() is called.
            fig.canvas.draw()
            renderer = fig.canvas.get_renderer()
            max_label_px = max(
                (t.get_window_extent(renderer).width for t in ax.get_yticklabels()),
                default=0.0,
            )
            # Small padding (0.01) leaves a sliver of space between label and edge.
            label_frac = max_label_px / ax.get_window_extent(renderer).width + 0.01

            # x in axes fraction (so the stripe can reach left of the axes),
            # y in data coordinates (so each stripe lines up with its row).
            row_trans = blended_transform_factory(ax.transAxes, ax.transData)
            for row_idx in range(n_algos):
                bg = (0.9, 0.9, 0.9) if row_idx % 2 == 0 else (1.0, 1.0, 1.0)
                ax.add_artist(patches.Rectangle(
                    (-label_frac, n_algos - row_idx - 0.5),
                    width=1.0 + label_frac, height=1,
                    transform=row_trans,
                    clip_on=False,
                    edgecolor=(1, 1, 1), facecolor=bg,
                ))

            # Color palettes: viridis for AUPRC, reversed cubehelix for stability.
            auprc_palette = sns.color_palette("viridis", 11)
            spear_palette = sns.cubehelix_palette(11, reverse=True)

            # NOTE(review): both sections pass rand_cutoff=0.0, while the class
            # docstring describes a "ratio >= 1" baseline for AUPRC — confirm
            # against _draw_section's handling of rand_cutoff.
            _draw_section(
                ax, auprc_arr, n_algos, n_datasets,
                col_x_start=1,
                palette=auprc_palette,
                rand_cutoff=0.0,
                dataset_ids=dataset_labels,
                section_label='Median AUPRC ratios',
            )
            _draw_section(
                ax, spear_arr, n_algos, n_datasets,
                col_x_start=n_datasets + 2,
                palette=spear_palette,
                rand_cutoff=0.0,
                dataset_ids=dataset_labels,
                section_label='Median stability scores',
                switch_text=False,
            )

            # Legend geometry — same fixed-width approach as PlotEPRHeatmap.
            lw = 5.0 / (total_cols + 1)   # legend width as axes fraction
            lh = 0.5 / (n_algos + pad)    # legend height as axes fraction
            ly = -lh - 0.05               # just below the heatmap

            # Two legends placed at the centre of each section's half of the axes.
            for cx, palette in [(1/4, auprc_palette), (3/4, spear_palette)]:
                legend_ax = ax.inset_axes([cx - lw / 2, ly, lw, lh])
                # A 1 × len(palette) strip showing every palette color once.
                legend_ax.imshow(
                    np.arange(len(palette)).reshape(1, len(palette)),
                    cmap=mpl.colors.ListedColormap(list(palette)),
                    interpolation='nearest', aspect='auto',
                )
                legend_ax.yaxis.set_ticks_position('none')
                legend_ax.xaxis.set_ticks_position('none')
                legend_ax.set_yticklabels([])
                legend_ax.set_xticks([0.5, len(palette) - 2])
                legend_ax.set_xticklabels(['Low/Poor', 'High/Good'], fontsize=12)

            stem = output_dir / 'Summary'
            # Save through the figure object rather than pyplot's implicit
            # "current figure", so the figure built above is what is written.
            fig.savefig(stem.with_suffix('.pdf'), bbox_inches='tight')
            fig.savefig(stem.with_suffix('.png'), bbox_inches='tight')
        finally:
            plt.close(fig)

        print(f"Saved summary heatmap to {stem}.pdf and .png")