feat(06-02): implement sensitivity analysis module with weight perturbation and Spearman correlation
- Add perturb_weight() function with renormalization to maintain sum=1.0 - Add run_sensitivity_analysis() for parameter sweep across all layers and deltas - Add summarize_sensitivity() for stability classification - Add generate_sensitivity_report() for human-readable output - Default perturbations: ±5% and ±10% with stability threshold 0.85
This commit is contained in:
378
src/usher_pipeline/scoring/sensitivity.py
Normal file
378
src/usher_pipeline/scoring/sensitivity.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""Parameter sweep sensitivity analysis for scoring weight validation."""
|
||||
|
||||
import polars as pl
|
||||
import structlog
|
||||
from scipy.stats import spearmanr
|
||||
|
||||
from usher_pipeline.config.schema import ScoringWeights
|
||||
from usher_pipeline.persistence.duckdb_store import PipelineStore
|
||||
from usher_pipeline.scoring.integration import compute_composite_scores
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Evidence layer names (must match ScoringWeights fields)
|
||||
EVIDENCE_LAYERS = [
|
||||
"gnomad",
|
||||
"expression",
|
||||
"annotation",
|
||||
"localization",
|
||||
"animal_model",
|
||||
"literature",
|
||||
]
|
||||
|
||||
# Default perturbation deltas (±5% and ±10%)
|
||||
DEFAULT_DELTAS = [-0.10, -0.05, 0.05, 0.10]
|
||||
|
||||
# Spearman correlation threshold for stability classification
|
||||
STABILITY_THRESHOLD = 0.85
|
||||
|
||||
|
||||
def perturb_weight(baseline: ScoringWeights, layer: str, delta: float) -> ScoringWeights:
|
||||
"""
|
||||
Perturb a single weight and renormalize to maintain sum=1.0 constraint.
|
||||
|
||||
Args:
|
||||
baseline: Baseline ScoringWeights instance
|
||||
layer: Evidence layer name to perturb (must be in EVIDENCE_LAYERS)
|
||||
delta: Perturbation amount (can be negative)
|
||||
|
||||
Returns:
|
||||
New ScoringWeights instance with perturbed and renormalized weights
|
||||
|
||||
Raises:
|
||||
ValueError: If layer not in EVIDENCE_LAYERS
|
||||
|
||||
Notes:
|
||||
- Clamps perturbed weight to [0.0, 1.0] before renormalization
|
||||
- Renormalizes ALL weights so they sum to 1.0
|
||||
- Maintains weights.validate_sum() guarantee
|
||||
"""
|
||||
if layer not in EVIDENCE_LAYERS:
|
||||
raise ValueError(
|
||||
f"Invalid layer '{layer}'. Must be one of {EVIDENCE_LAYERS}"
|
||||
)
|
||||
|
||||
# Get baseline weights as dict
|
||||
w_dict = baseline.model_dump()
|
||||
|
||||
# Apply perturbation with clamping
|
||||
w_dict[layer] = max(0.0, min(1.0, w_dict[layer] + delta))
|
||||
|
||||
# Renormalize to sum=1.0
|
||||
total = sum(w_dict[k] for k in EVIDENCE_LAYERS)
|
||||
if total > 0:
|
||||
for k in EVIDENCE_LAYERS:
|
||||
w_dict[k] = w_dict[k] / total
|
||||
else:
|
||||
# Edge case: all weights became zero (should not happen in practice)
|
||||
# Revert to uniform distribution
|
||||
uniform = 1.0 / len(EVIDENCE_LAYERS)
|
||||
for k in EVIDENCE_LAYERS:
|
||||
w_dict[k] = uniform
|
||||
|
||||
# Return new ScoringWeights instance
|
||||
return ScoringWeights(**w_dict)
|
||||
|
||||
|
||||
def run_sensitivity_analysis(
|
||||
store: PipelineStore,
|
||||
baseline_weights: ScoringWeights,
|
||||
deltas: list[float] | None = None,
|
||||
top_n: int = 100,
|
||||
) -> dict:
|
||||
"""
|
||||
Run sensitivity analysis by perturbing each weight and measuring rank stability.
|
||||
|
||||
For each layer and each delta, perturbs the weight, recomputes composite scores,
|
||||
and measures Spearman rank correlation on the top-N genes compared to baseline.
|
||||
|
||||
Args:
|
||||
store: PipelineStore with evidence layer tables
|
||||
baseline_weights: Baseline ScoringWeights to perturb
|
||||
deltas: List of perturbation amounts (default: DEFAULT_DELTAS)
|
||||
top_n: Number of top-ranked genes to compare (default: 100)
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
- baseline_weights: dict - baseline weights as dict
|
||||
- results: list[dict] - per-perturbation results with:
|
||||
- layer: str
|
||||
- delta: float
|
||||
- perturbed_weights: dict
|
||||
- spearman_rho: float or None
|
||||
- spearman_pval: float or None
|
||||
- overlap_count: int - genes in both top-N lists
|
||||
- top_n: int
|
||||
- top_n: int
|
||||
- total_perturbations: int
|
||||
|
||||
Notes:
|
||||
- compute_composite_scores re-queries DB each time (by design)
|
||||
- Spearman correlation computed on composite_score of overlapping genes
|
||||
- If overlap < 10 genes, records rho=None and logs warning
|
||||
"""
|
||||
if deltas is None:
|
||||
deltas = DEFAULT_DELTAS
|
||||
|
||||
logger.info(
|
||||
"run_sensitivity_analysis_start",
|
||||
baseline_weights=baseline_weights.model_dump(),
|
||||
deltas=deltas,
|
||||
top_n=top_n,
|
||||
total_perturbations=len(EVIDENCE_LAYERS) * len(deltas),
|
||||
)
|
||||
|
||||
# Compute baseline scores and get top-N genes
|
||||
baseline_scores = compute_composite_scores(store, baseline_weights)
|
||||
baseline_top_n = (
|
||||
baseline_scores
|
||||
.filter(pl.col("composite_score").is_not_null())
|
||||
.sort("composite_score", descending=True)
|
||||
.head(top_n)
|
||||
.select(["gene_symbol", "composite_score"])
|
||||
.rename({"composite_score": "baseline_score"})
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
# For each layer, for each delta, compute perturbation
|
||||
for layer in EVIDENCE_LAYERS:
|
||||
for delta in deltas:
|
||||
# Create perturbed weights
|
||||
perturbed_weights = perturb_weight(baseline_weights, layer, delta)
|
||||
|
||||
# Compute perturbed scores
|
||||
perturbed_scores = compute_composite_scores(store, perturbed_weights)
|
||||
perturbed_top_n = (
|
||||
perturbed_scores
|
||||
.filter(pl.col("composite_score").is_not_null())
|
||||
.sort("composite_score", descending=True)
|
||||
.head(top_n)
|
||||
.select(["gene_symbol", "composite_score"])
|
||||
.rename({"composite_score": "perturbed_score"})
|
||||
)
|
||||
|
||||
# Inner join to get overlapping genes
|
||||
joined = baseline_top_n.join(perturbed_top_n, on="gene_symbol", how="inner")
|
||||
overlap_count = joined.height
|
||||
|
||||
# Compute Spearman correlation if sufficient overlap
|
||||
if overlap_count < 10:
|
||||
logger.warning(
|
||||
"run_sensitivity_analysis_low_overlap",
|
||||
layer=layer,
|
||||
delta=delta,
|
||||
overlap_count=overlap_count,
|
||||
message="Insufficient overlap for Spearman correlation (need >= 10)",
|
||||
)
|
||||
spearman_rho = None
|
||||
spearman_pval = None
|
||||
else:
|
||||
# Extract paired scores
|
||||
baseline_vals = joined["baseline_score"].to_numpy()
|
||||
perturbed_vals = joined["perturbed_score"].to_numpy()
|
||||
|
||||
# Compute Spearman correlation
|
||||
rho, pval = spearmanr(baseline_vals, perturbed_vals)
|
||||
spearman_rho = float(rho)
|
||||
spearman_pval = float(pval)
|
||||
|
||||
# Record result
|
||||
result = {
|
||||
"layer": layer,
|
||||
"delta": delta,
|
||||
"perturbed_weights": perturbed_weights.model_dump(),
|
||||
"spearman_rho": spearman_rho,
|
||||
"spearman_pval": spearman_pval,
|
||||
"overlap_count": overlap_count,
|
||||
"top_n": top_n,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
# Log each perturbation result
|
||||
logger.info(
|
||||
"run_sensitivity_analysis_perturbation",
|
||||
layer=layer,
|
||||
delta=f"{delta:+.2f}",
|
||||
spearman_rho=f"{spearman_rho:.4f}" if spearman_rho is not None else "N/A",
|
||||
spearman_pval=f"{spearman_pval:.4e}" if spearman_pval is not None else "N/A",
|
||||
overlap_count=overlap_count,
|
||||
stable=spearman_rho >= STABILITY_THRESHOLD if spearman_rho is not None else None,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"run_sensitivity_analysis_complete",
|
||||
total_perturbations=len(results),
|
||||
layers=len(EVIDENCE_LAYERS),
|
||||
deltas=len(deltas),
|
||||
)
|
||||
|
||||
return {
|
||||
"baseline_weights": baseline_weights.model_dump(),
|
||||
"results": results,
|
||||
"top_n": top_n,
|
||||
"total_perturbations": len(results),
|
||||
}
|
||||
|
||||
|
||||
def summarize_sensitivity(analysis_result: dict) -> dict:
|
||||
"""
|
||||
Summarize sensitivity analysis results with stability classification.
|
||||
|
||||
Args:
|
||||
analysis_result: Dict returned from run_sensitivity_analysis()
|
||||
|
||||
Returns:
|
||||
Dict with keys:
|
||||
- min_rho: float - minimum Spearman rho (excluding None)
|
||||
- max_rho: float - maximum Spearman rho (excluding None)
|
||||
- mean_rho: float - mean Spearman rho (excluding None)
|
||||
- stable_count: int - count of perturbations with rho >= STABILITY_THRESHOLD
|
||||
- unstable_count: int - count of perturbations with rho < STABILITY_THRESHOLD
|
||||
- total_perturbations: int
|
||||
- overall_stable: bool - True if all non-None rhos >= STABILITY_THRESHOLD
|
||||
- most_sensitive_layer: str - layer with lowest mean rho
|
||||
- most_robust_layer: str - layer with highest mean rho
|
||||
|
||||
Notes:
|
||||
- Excludes None rho values from all statistics
|
||||
- most_sensitive/robust computed from per-layer mean rho
|
||||
"""
|
||||
results = analysis_result["results"]
|
||||
|
||||
# Filter out None rho values
|
||||
valid_results = [r for r in results if r["spearman_rho"] is not None]
|
||||
|
||||
if not valid_results:
|
||||
# Edge case: all perturbations had insufficient overlap
|
||||
return {
|
||||
"min_rho": None,
|
||||
"max_rho": None,
|
||||
"mean_rho": None,
|
||||
"stable_count": 0,
|
||||
"unstable_count": 0,
|
||||
"total_perturbations": analysis_result["total_perturbations"],
|
||||
"overall_stable": False,
|
||||
"most_sensitive_layer": None,
|
||||
"most_robust_layer": None,
|
||||
}
|
||||
|
||||
# Compute global statistics
|
||||
rho_values = [r["spearman_rho"] for r in valid_results]
|
||||
min_rho = min(rho_values)
|
||||
max_rho = max(rho_values)
|
||||
mean_rho = sum(rho_values) / len(rho_values)
|
||||
|
||||
# Count stable/unstable
|
||||
stable_count = sum(1 for rho in rho_values if rho >= STABILITY_THRESHOLD)
|
||||
unstable_count = len(rho_values) - stable_count
|
||||
|
||||
# Overall stability: all non-None rhos must be >= threshold
|
||||
overall_stable = all(rho >= STABILITY_THRESHOLD for rho in rho_values)
|
||||
|
||||
# Compute per-layer mean rho
|
||||
layer_rho_map = {}
|
||||
for layer in EVIDENCE_LAYERS:
|
||||
layer_results = [
|
||||
r["spearman_rho"]
|
||||
for r in valid_results
|
||||
if r["layer"] == layer and r["spearman_rho"] is not None
|
||||
]
|
||||
if layer_results:
|
||||
layer_rho_map[layer] = sum(layer_results) / len(layer_results)
|
||||
|
||||
# Find most sensitive (lowest mean rho) and most robust (highest mean rho)
|
||||
if layer_rho_map:
|
||||
most_sensitive_layer = min(layer_rho_map, key=layer_rho_map.get)
|
||||
most_robust_layer = max(layer_rho_map, key=layer_rho_map.get)
|
||||
else:
|
||||
most_sensitive_layer = None
|
||||
most_robust_layer = None
|
||||
|
||||
return {
|
||||
"min_rho": min_rho,
|
||||
"max_rho": max_rho,
|
||||
"mean_rho": mean_rho,
|
||||
"stable_count": stable_count,
|
||||
"unstable_count": unstable_count,
|
||||
"total_perturbations": analysis_result["total_perturbations"],
|
||||
"overall_stable": overall_stable,
|
||||
"most_sensitive_layer": most_sensitive_layer,
|
||||
"most_robust_layer": most_robust_layer,
|
||||
}
|
||||
|
||||
|
||||
def generate_sensitivity_report(analysis_result: dict, summary: dict) -> str:
|
||||
"""
|
||||
Generate human-readable sensitivity analysis report.
|
||||
|
||||
Args:
|
||||
analysis_result: Dict returned from run_sensitivity_analysis()
|
||||
summary: Dict returned from summarize_sensitivity()
|
||||
|
||||
Returns:
|
||||
Multi-line text report with perturbation table and summary
|
||||
|
||||
Notes:
|
||||
- Follows formatting pattern from generate_validation_report()
|
||||
- Shows table with Layer | Delta | Spearman rho | p-value | Stable?
|
||||
- Includes interpretation text
|
||||
"""
|
||||
status = "STABLE ✓" if summary["overall_stable"] else "UNSTABLE ✗"
|
||||
|
||||
report = [
|
||||
f"Sensitivity Analysis: {status}",
|
||||
"",
|
||||
"Summary:",
|
||||
f" Total perturbations: {summary['total_perturbations']}",
|
||||
f" Stable perturbations: {summary['stable_count']} (rho >= {STABILITY_THRESHOLD})",
|
||||
f" Unstable perturbations: {summary['unstable_count']}",
|
||||
f" Mean Spearman rho: {summary['mean_rho']:.4f}" if summary['mean_rho'] is not None else " Mean Spearman rho: N/A",
|
||||
f" Range: [{summary['min_rho']:.4f}, {summary['max_rho']:.4f}]" if summary['min_rho'] is not None else " Range: N/A",
|
||||
"",
|
||||
]
|
||||
|
||||
# Add interpretation
|
||||
if summary["overall_stable"]:
|
||||
report.append(
|
||||
f"All weight perturbations (±5-10%) produce stable rankings (rho >= {STABILITY_THRESHOLD}), "
|
||||
"validating result robustness."
|
||||
)
|
||||
else:
|
||||
report.append(
|
||||
f"Warning: Some perturbations produce unstable rankings (rho < {STABILITY_THRESHOLD}). "
|
||||
"Results may be sensitive to weight choices."
|
||||
)
|
||||
|
||||
if summary["most_sensitive_layer"] and summary["most_robust_layer"]:
|
||||
report.append("")
|
||||
report.append(f" Most sensitive layer: {summary['most_sensitive_layer']}")
|
||||
report.append(f" Most robust layer: {summary['most_robust_layer']}")
|
||||
|
||||
report.append("")
|
||||
report.append("Perturbation Results:")
|
||||
report.append("-" * 100)
|
||||
report.append(f"{'Layer':<15} {'Delta':>8} {'Spearman rho':>14} {'p-value':>12} {'Overlap':>10} {'Stable?':>10}")
|
||||
report.append("-" * 100)
|
||||
|
||||
for result in analysis_result["results"]:
|
||||
layer = result["layer"]
|
||||
delta = result["delta"]
|
||||
rho = result["spearman_rho"]
|
||||
pval = result["spearman_pval"]
|
||||
overlap = result["overlap_count"]
|
||||
|
||||
if rho is not None:
|
||||
stable_mark = "✓" if rho >= STABILITY_THRESHOLD else "✗"
|
||||
rho_str = f"{rho:.4f}"
|
||||
pval_str = f"{pval:.2e}"
|
||||
else:
|
||||
stable_mark = "N/A"
|
||||
rho_str = "N/A"
|
||||
pval_str = "N/A"
|
||||
|
||||
report.append(
|
||||
f"{layer:<15} {delta:>+8.2f} {rho_str:>14} {pval_str:>12} {overlap:>10} {stable_mark:>10}"
|
||||
)
|
||||
|
||||
return "\n".join(report)
|
||||
Reference in New Issue
Block a user