feat(05-02): implement visualization module with matplotlib/seaborn plots
- Add matplotlib>=3.8.0 and seaborn>=0.13.0 to dependencies
- Create visualizations.py with 3 plot functions and an orchestrator
- plot_score_distribution: histogram colored by confidence tier
- plot_layer_contributions: bar chart of evidence layer coverage
- plot_tier_breakdown: pie chart of tier distribution
- Use Agg backend for headless/CLI safety
- All plots saved at 300 DPI with proper figure cleanup
- 6 tests covering file creation, edge cases, and return values
This commit is contained in:
@@ -37,6 +37,8 @@ dependencies = [
|
||||
"structlog>=25.0",
|
||||
"biopython>=1.84",
|
||||
"scipy>=1.14",
|
||||
"matplotlib>=3.8.0",
|
||||
"seaborn>=0.13.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
244
src/usher_pipeline/output/visualizations.py
Normal file
244
src/usher_pipeline/output/visualizations.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""Visualization generation for pipeline outputs."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import polars as pl
|
||||
|
||||
# Use Agg backend (non-interactive, safe for headless/CLI use)
|
||||
matplotlib.use("Agg")
|
||||
|
||||
import matplotlib.pyplot as plt # noqa: E402
|
||||
import seaborn as sns # noqa: E402
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def plot_score_distribution(df: pl.DataFrame, output_path: Path) -> Path:
    """
    Create histogram of composite scores colored by confidence tier.

    Args:
        df: DataFrame with composite_score and confidence_tier columns
        output_path: Path where PNG will be saved

    Returns:
        Path to the saved PNG file

    Notes:
        - Converts to pandas for seaborn compatibility
        - Uses tier-specific color coding (HIGH=green, MEDIUM=orange, LOW=red)
        - Saves at 300 DPI for publication quality
        - The figure is closed even when plotting or saving raises, so a
          failed plot does not leave an open figure behind
    """
    # Insertion order of this mapping doubles as the hue (stacking) order.
    tier_palette = {
        "HIGH": "#2ecc71",
        "MEDIUM": "#f39c12",
        "LOW": "#e74c3c",
    }

    # Convert to pandas for seaborn
    pdf = df.to_pandas()

    # Set seaborn theme (this changes matplotlib's global defaults, so it
    # also applies to figures created after this call)
    sns.set_theme(style="whitegrid", context="paper")

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        # Stacked histogram: one colored segment per confidence tier
        sns.histplot(
            data=pdf,
            x="composite_score",
            hue="confidence_tier",
            hue_order=list(tier_palette),
            palette=tier_palette,
            bins=30,
            multiple="stack",
            ax=ax,
        )

        ax.set_xlabel("Composite Score")
        ax.set_ylabel("Candidate Count")
        ax.set_title("Score Distribution by Confidence Tier")

        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # CRITICAL: always close the figure, including on failure, to
        # prevent figures from accumulating in pyplot's registry
        plt.close(fig)

    logger.info("Saved score distribution plot to %s", output_path)
    return output_path
|
||||
|
||||
|
||||
def plot_layer_contributions(df: pl.DataFrame, output_path: Path) -> Path:
    """
    Create bar chart showing evidence layer coverage.

    Args:
        df: DataFrame with layer score columns
        output_path: Path where PNG will be saved

    Returns:
        Path to the saved PNG file

    Notes:
        - Counts non-null values per layer
        - Shows which layers have the most/least coverage
        - Layer columns missing from df are skipped; if none are present an
          empty chart is saved and a warning is logged
        - The figure is closed even when plotting or saving raises
    """
    layer_columns = [
        "gnomad_score",
        "expression_score",
        "annotation_score",
        "localization_score",
        "animal_model_score",
        "literature_score",
    ]

    # Map display label -> number of rows with a non-null score.
    layer_counts = {}
    for col in layer_columns:
        if col not in df.columns:
            continue
        non_null = df.height - df.get_column(col).null_count()
        # Clean label: strip only the trailing "_score", then prettify
        label = col.removesuffix("_score").replace("_", " ").title()
        layer_counts[label] = non_null

    labels = list(layer_counts)
    values = list(layer_counts.values())

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if labels:
            # hue=labels with legend=False gives one palette color per bar
            sns.barplot(
                x=labels,
                y=values,
                hue=labels,
                palette="viridis",
                ax=ax,
                legend=False,
            )
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        else:
            logger.warning(
                "No evidence layer columns found; layer contributions plot will be empty"
            )

        ax.set_xlabel("Evidence Layer")
        ax.set_ylabel("Candidates with Evidence")
        ax.set_title("Evidence Layer Coverage")

        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, including when plotting/saving fails
        plt.close(fig)

    logger.info("Saved layer contributions plot to %s", output_path)
    return output_path
|
||||
|
||||
|
||||
def plot_tier_breakdown(df: pl.DataFrame, output_path: Path) -> Path:
    """
    Create pie chart showing tier distribution.

    Args:
        df: DataFrame with confidence_tier column
        output_path: Path where PNG will be saved

    Returns:
        Path to the saved PNG file

    Notes:
        - Shows percentage breakdown of HIGH/MEDIUM/LOW tiers
        - Uses same color scheme as score distribution plot
        - When there are no HIGH/MEDIUM/LOW rows (empty frame or missing
          column) a placeholder message is drawn instead of a pie
        - The figure is closed even when plotting or saving raises
    """
    # Tier order and colors (insertion order is the wedge order)
    tier_colors = {
        "HIGH": "#2ecc71",
        "MEDIUM": "#f39c12",
        "LOW": "#e74c3c",
    }

    # Count genes per tier
    tier_dict = {}
    if "confidence_tier" in df.columns:
        tier_counts = df.group_by("confidence_tier").agg(pl.len().alias("count"))
        tier_dict = {
            row["confidence_tier"]: row["count"] for row in tier_counts.to_dicts()
        }

    tiers = list(tier_colors)
    colors = list(tier_colors.values())
    # Counts in display order (0 if tier not present)
    counts = [tier_dict.get(tier, 0) for tier in tiers]

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        if sum(counts) > 0:
            ax.pie(
                counts,
                labels=tiers,
                colors=colors,
                autopct="%1.1f%%",
                startangle=90,
            )
        else:
            # Wedge sizes are fractions of the total, so an all-zero input
            # cannot be drawn as a pie; show an explicit placeholder instead.
            ax.text(
                0.5,
                0.5,
                "No candidates to display",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
            ax.set_axis_off()

        ax.set_title("Candidate Tier Breakdown")

        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, including when plotting/saving fails
        plt.close(fig)

    logger.info("Saved tier breakdown plot to %s", output_path)
    return output_path
|
||||
|
||||
|
||||
def generate_all_plots(df: pl.DataFrame, output_dir: Path) -> dict[str, Path]:
    """
    Generate all visualization plots.

    Args:
        df: DataFrame with scoring results
        output_dir: Directory where plots will be saved

    Returns:
        Dictionary mapping plot name to file path. Plots that failed are
        omitted from the dictionary.

    Notes:
        - Creates output directory if needed
        - Wraps each plot in try/except to continue on individual failures;
          failures are logged as warnings together with the traceback
        - Uses standard filenames for each plot type (``<plot name>.png``)
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # (result key and file stem, label used in log messages, plot function)
    plot_specs = (
        ("score_distribution", "score distribution", plot_score_distribution),
        ("layer_contributions", "layer contributions", plot_layer_contributions),
        ("tier_breakdown", "tier breakdown", plot_tier_breakdown),
    )

    plots: dict[str, Path] = {}
    for name, label, plot_func in plot_specs:
        try:
            plots[name] = plot_func(df, output_dir / f"{name}.png")
        except Exception as e:
            # Deliberately broad: one broken plot must not block the others.
            # exc_info keeps the traceback so the failure stays debuggable.
            logger.warning(
                "Failed to create %s plot: %s", label, e, exc_info=True
            )

    logger.info("Generated %d plots in %s", len(plots), output_dir)
    return plots
|
||||
112
tests/test_visualizations.py
Normal file
112
tests/test_visualizations.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Tests for visualization generation."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from usher_pipeline.output.visualizations import (
|
||||
generate_all_plots,
|
||||
plot_layer_contributions,
|
||||
plot_score_distribution,
|
||||
plot_tier_breakdown,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def synthetic_results_df():
    """Build a 30-gene scored results frame with sparsely populated layers.

    Tiers are split evenly (10 HIGH, 10 MEDIUM, 10 LOW); each layer score is
    present only on every k-th row so the layers differ in coverage.
    """
    n_genes = 30
    row_ids = range(n_genes)

    def every_kth(value, k):
        # value on rows whose index is a multiple of k, null elsewhere
        return [value if i % k == 0 else None for i in row_ids]

    columns = {
        "gene_symbol": [f"GENE{i}" for i in row_ids],
        "composite_score": [0.1 + i * 0.03 for i in row_ids],
        "confidence_tier": ["HIGH"] * 10 + ["MEDIUM"] * 10 + ["LOW"] * 10,
        "gnomad_score": every_kth(0.5, 2),
        "expression_score": every_kth(0.6, 3),
        "annotation_score": every_kth(0.7, 4),
        "localization_score": every_kth(0.8, 5),
        "animal_model_score": every_kth(0.9, 6),
        "literature_score": every_kth(0.85, 7),
    }
    return pl.DataFrame(columns)
|
||||
|
||||
|
||||
def test_plot_score_distribution_creates_file(synthetic_results_df, tmp_path):
    """The score distribution plot is written as a non-empty PNG and its path returned."""
    png_path = tmp_path / "score_dist.png"

    returned = plot_score_distribution(synthetic_results_df, png_path)

    # The function hands back exactly the path it was given ...
    assert returned == png_path
    # ... and that path now holds a file with content.
    assert png_path.exists()
    size_in_bytes = png_path.stat().st_size
    assert size_in_bytes > 0
|
||||
|
||||
|
||||
def test_plot_layer_contributions_creates_file(synthetic_results_df, tmp_path):
    """The layer contributions plot is written as a non-empty PNG and its path returned."""
    png_path = tmp_path / "layer_contrib.png"

    returned = plot_layer_contributions(synthetic_results_df, png_path)

    # The function hands back exactly the path it was given ...
    assert returned == png_path
    # ... and that path now holds a file with content.
    assert png_path.exists()
    size_in_bytes = png_path.stat().st_size
    assert size_in_bytes > 0
|
||||
|
||||
|
||||
def test_plot_tier_breakdown_creates_file(synthetic_results_df, tmp_path):
    """The tier breakdown plot is written as a non-empty PNG and its path returned."""
    png_path = tmp_path / "tier_breakdown.png"

    returned = plot_tier_breakdown(synthetic_results_df, png_path)

    # The function hands back exactly the path it was given ...
    assert returned == png_path
    # ... and that path now holds a file with content.
    assert png_path.exists()
    size_in_bytes = png_path.stat().st_size
    assert size_in_bytes > 0
|
||||
|
||||
|
||||
def test_generate_all_plots_creates_all_files(synthetic_results_df, tmp_path):
    """Test that generate_all_plots creates all 3 PNG files."""
    output_dir = tmp_path / "plots"

    # Return value is covered by test_generate_all_plots_returns_paths;
    # here only the files on disk matter.
    generate_all_plots(synthetic_results_df, output_dir)

    expected_files = (
        "score_distribution.png",
        "layer_contributions.png",
        "tier_breakdown.png",
    )
    for filename in expected_files:
        assert (output_dir / filename).exists(), f"missing plot file: {filename}"
|
||||
|
||||
|
||||
def test_generate_all_plots_returns_paths(synthetic_results_df, tmp_path):
    """generate_all_plots returns a mapping with exactly the three plot names."""
    plots_dir = tmp_path / "plots"

    generated = generate_all_plots(synthetic_results_df, plots_dir)

    expected_names = ("score_distribution", "layer_contributions", "tier_breakdown")
    assert len(generated) == 3
    for plot_name in expected_names:
        assert plot_name in generated
|
||||
|
||||
|
||||
def test_plots_handle_empty_dataframe(tmp_path):
    """Test that plots handle empty DataFrames without crashing."""
    # Declare the dtypes explicitly so the zero-row frame has the same
    # schema as real scored results instead of whatever polars infers
    # for bare empty lists.
    score_columns = (
        "composite_score",
        "gnomad_score",
        "expression_score",
        "annotation_score",
        "localization_score",
        "animal_model_score",
        "literature_score",
    )
    schema = {"gene_symbol": pl.Utf8, "confidence_tier": pl.Utf8}
    schema.update({name: pl.Float64 for name in score_columns})
    empty_df = pl.DataFrame(schema=schema)

    output_dir = tmp_path / "empty_plots"

    # Should not crash
    plots = generate_all_plots(empty_df, output_dir)

    # At minimum, the function should return without error.
    # Some plots may succeed (empty plot) or fail gracefully.
    assert isinstance(plots, dict)
    # Any plot reported as generated must actually be on disk.
    for plot_path in plots.values():
        assert plot_path.exists()
|
||||
Reference in New Issue
Block a user