feat(06-02): implement sensitivity analysis module with weight perturbation and Spearman correlation
- Add perturb_weight() function with renormalization to maintain sum=1.0
- Add run_sensitivity_analysis() for parameter sweep across all layers and deltas
- Add summarize_sensitivity() for stability classification
- Add generate_sensitivity_report() for human-readable output
- Default perturbations: ±5% and ±10% with stability threshold 0.85
This commit is contained in:
378
src/usher_pipeline/scoring/sensitivity.py
Normal file
378
src/usher_pipeline/scoring/sensitivity.py
Normal file
@@ -0,0 +1,378 @@
|
|||||||
|
"""Parameter sweep sensitivity analysis for scoring weight validation."""
|
||||||
|
|
||||||
|
import polars as pl
|
||||||
|
import structlog
|
||||||
|
from scipy.stats import spearmanr
|
||||||
|
|
||||||
|
from usher_pipeline.config.schema import ScoringWeights
|
||||||
|
from usher_pipeline.persistence.duckdb_store import PipelineStore
|
||||||
|
from usher_pipeline.scoring.integration import compute_composite_scores
|
||||||
|
|
||||||
|
# Module-level structured logger; event names below follow "function_event" style.
logger = structlog.get_logger(__name__)

# Evidence layer names (must match ScoringWeights fields)
EVIDENCE_LAYERS: list[str] = [
    "gnomad",
    "expression",
    "annotation",
    "localization",
    "animal_model",
    "literature",
]

# Default perturbation deltas (±5% and ±10%)
DEFAULT_DELTAS: list[float] = [-0.10, -0.05, 0.05, 0.10]

# Spearman correlation threshold for stability classification
STABILITY_THRESHOLD: float = 0.85
||||||
|
def perturb_weight(baseline: ScoringWeights, layer: str, delta: float) -> ScoringWeights:
    """
    Apply a delta to one evidence-layer weight and rescale all weights to sum=1.0.

    Args:
        baseline: Starting ScoringWeights instance to perturb
        layer: Name of the evidence layer to adjust (must appear in EVIDENCE_LAYERS)
        delta: Signed amount added to that layer's weight

    Returns:
        A new ScoringWeights instance with the adjusted, renormalized weights

    Raises:
        ValueError: When ``layer`` is not a known evidence layer

    Notes:
        - The perturbed weight is clamped into [0.0, 1.0] before rescaling
        - Rescaling touches every layer so the validate_sum() guarantee holds
        - Falls back to uniform weights if rescaling would divide by zero
    """
    if layer not in EVIDENCE_LAYERS:
        raise ValueError(
            f"Invalid layer '{layer}'. Must be one of {EVIDENCE_LAYERS}"
        )

    # Work on a plain dict copy of the baseline weights.
    weights = baseline.model_dump()

    # Nudge the target layer, keeping the result inside the unit interval.
    weights[layer] = min(1.0, max(0.0, weights[layer] + delta))

    # Rescale so the evidence-layer weights sum to exactly 1.0 again.
    total = sum(weights[name] for name in EVIDENCE_LAYERS)
    if total > 0:
        for name in EVIDENCE_LAYERS:
            weights[name] /= total
    else:
        # Degenerate case (every weight clamped to zero, not expected in
        # practice): spread the mass evenly across all layers.
        share = 1.0 / len(EVIDENCE_LAYERS)
        for name in EVIDENCE_LAYERS:
            weights[name] = share

    return ScoringWeights(**weights)
|
def _select_top_n(scores: pl.DataFrame, top_n: int, score_alias: str) -> pl.DataFrame:
    """Return the top-N genes by composite_score with the score column renamed.

    Args:
        scores: DataFrame with gene_symbol and composite_score columns
        top_n: Number of top-ranked genes to keep
        score_alias: New name for the composite_score column

    Returns:
        DataFrame with columns [gene_symbol, <score_alias>], sorted descending
        by score with null scores removed
    """
    return (
        scores
        .filter(pl.col("composite_score").is_not_null())
        .sort("composite_score", descending=True)
        .head(top_n)
        .select(["gene_symbol", "composite_score"])
        .rename({"composite_score": score_alias})
    )


def run_sensitivity_analysis(
    store: PipelineStore,
    baseline_weights: ScoringWeights,
    deltas: list[float] | None = None,
    top_n: int = 100,
) -> dict:
    """
    Run sensitivity analysis by perturbing each weight and measuring rank stability.

    For each layer and each delta, perturbs the weight, recomputes composite scores,
    and measures Spearman rank correlation on the top-N genes compared to baseline.

    Args:
        store: PipelineStore with evidence layer tables
        baseline_weights: Baseline ScoringWeights to perturb
        deltas: List of perturbation amounts (default: DEFAULT_DELTAS)
        top_n: Number of top-ranked genes to compare (default: 100)

    Returns:
        Dict with keys:
        - baseline_weights: dict - baseline weights as dict
        - results: list[dict] - per-perturbation results with:
            - layer: str
            - delta: float
            - perturbed_weights: dict
            - spearman_rho: float or None
            - spearman_pval: float or None
            - overlap_count: int - genes in both top-N lists
            - top_n: int
        - top_n: int
        - total_perturbations: int

    Notes:
        - compute_composite_scores re-queries DB each time (by design)
        - Spearman correlation computed on composite_score of overlapping genes
        - If overlap < 10 genes, records rho=None and logs warning
    """
    if deltas is None:
        deltas = DEFAULT_DELTAS

    logger.info(
        "run_sensitivity_analysis_start",
        baseline_weights=baseline_weights.model_dump(),
        deltas=deltas,
        top_n=top_n,
        total_perturbations=len(EVIDENCE_LAYERS) * len(deltas),
    )

    # Compute baseline scores and get top-N genes
    baseline_scores = compute_composite_scores(store, baseline_weights)
    baseline_top_n = _select_top_n(baseline_scores, top_n, "baseline_score")

    results = []

    # For each layer, for each delta, compute perturbation
    for layer in EVIDENCE_LAYERS:
        for delta in deltas:
            # Create perturbed weights and recompute scores from the store
            perturbed_weights = perturb_weight(baseline_weights, layer, delta)
            perturbed_scores = compute_composite_scores(store, perturbed_weights)
            perturbed_top_n = _select_top_n(perturbed_scores, top_n, "perturbed_score")

            # Inner join to get overlapping genes
            joined = baseline_top_n.join(perturbed_top_n, on="gene_symbol", how="inner")
            overlap_count = joined.height

            # Compute Spearman correlation only with sufficient overlap
            if overlap_count < 10:
                logger.warning(
                    "run_sensitivity_analysis_low_overlap",
                    layer=layer,
                    delta=delta,
                    overlap_count=overlap_count,
                    message="Insufficient overlap for Spearman correlation (need >= 10)",
                )
                spearman_rho = None
                spearman_pval = None
            else:
                # Extract paired scores for the overlapping genes
                baseline_vals = joined["baseline_score"].to_numpy()
                perturbed_vals = joined["perturbed_score"].to_numpy()

                # NOTE(review): spearmanr returns nan rho for constant inputs —
                # assumed not to occur since composite scores vary; confirm upstream.
                rho, pval = spearmanr(baseline_vals, perturbed_vals)
                spearman_rho = float(rho)
                spearman_pval = float(pval)

            # Record result
            result = {
                "layer": layer,
                "delta": delta,
                "perturbed_weights": perturbed_weights.model_dump(),
                "spearman_rho": spearman_rho,
                "spearman_pval": spearman_pval,
                "overlap_count": overlap_count,
                "top_n": top_n,
            }
            results.append(result)

            # Log each perturbation result
            logger.info(
                "run_sensitivity_analysis_perturbation",
                layer=layer,
                delta=f"{delta:+.2f}",
                spearman_rho=f"{spearman_rho:.4f}" if spearman_rho is not None else "N/A",
                spearman_pval=f"{spearman_pval:.4e}" if spearman_pval is not None else "N/A",
                overlap_count=overlap_count,
                stable=spearman_rho >= STABILITY_THRESHOLD if spearman_rho is not None else None,
            )

    logger.info(
        "run_sensitivity_analysis_complete",
        total_perturbations=len(results),
        layers=len(EVIDENCE_LAYERS),
        deltas=len(deltas),
    )

    return {
        "baseline_weights": baseline_weights.model_dump(),
        "results": results,
        "top_n": top_n,
        "total_perturbations": len(results),
    }
|
def summarize_sensitivity(analysis_result: dict) -> dict:
    """
    Build summary statistics and a stability classification from a sensitivity run.

    Args:
        analysis_result: Output of run_sensitivity_analysis()

    Returns:
        Dict with keys:
        - min_rho: float - minimum Spearman rho (excluding None)
        - max_rho: float - maximum Spearman rho (excluding None)
        - mean_rho: float - mean Spearman rho (excluding None)
        - stable_count: int - perturbations with rho >= STABILITY_THRESHOLD
        - unstable_count: int - perturbations with rho < STABILITY_THRESHOLD
        - total_perturbations: int
        - overall_stable: bool - True only if every non-None rho clears the threshold
        - most_sensitive_layer: str - layer with the lowest mean rho
        - most_robust_layer: str - layer with the highest mean rho

    Notes:
        - None rho values (insufficient overlap) are excluded from all statistics
        - Layer rankings are based on the per-layer mean of valid rhos
    """
    total = analysis_result["total_perturbations"]

    # Keep only perturbations where a correlation could be computed.
    usable = [
        entry for entry in analysis_result["results"]
        if entry["spearman_rho"] is not None
    ]

    if not usable:
        # Edge case: every perturbation lacked sufficient top-N overlap.
        return {
            "min_rho": None,
            "max_rho": None,
            "mean_rho": None,
            "stable_count": 0,
            "unstable_count": 0,
            "total_perturbations": total,
            "overall_stable": False,
            "most_sensitive_layer": None,
            "most_robust_layer": None,
        }

    # Global statistics over all valid rho values.
    rhos = [entry["spearman_rho"] for entry in usable]
    stable_count = sum(1 for value in rhos if value >= STABILITY_THRESHOLD)

    # Mean rho per layer drives the sensitive/robust rankings.
    per_layer: dict[str, float] = {}
    for name in EVIDENCE_LAYERS:
        layer_rhos = [e["spearman_rho"] for e in usable if e["layer"] == name]
        if layer_rhos:
            per_layer[name] = sum(layer_rhos) / len(layer_rhos)

    # Lowest mean rho = most sensitive; highest = most robust.
    if per_layer:
        most_sensitive = min(per_layer, key=lambda k: per_layer[k])
        most_robust = max(per_layer, key=lambda k: per_layer[k])
    else:
        most_sensitive = None
        most_robust = None

    return {
        "min_rho": min(rhos),
        "max_rho": max(rhos),
        "mean_rho": sum(rhos) / len(rhos),
        "stable_count": stable_count,
        "unstable_count": len(rhos) - stable_count,
        "total_perturbations": total,
        "overall_stable": stable_count == len(rhos),
        "most_sensitive_layer": most_sensitive,
        "most_robust_layer": most_robust,
    }
|
def generate_sensitivity_report(analysis_result: dict, summary: dict) -> str:
    """
    Render a human-readable text report for a sensitivity analysis run.

    Args:
        analysis_result: Output of run_sensitivity_analysis()
        summary: Output of summarize_sensitivity()

    Returns:
        Multi-line report string with summary, interpretation, and a
        per-perturbation table (Layer | Delta | Spearman rho | p-value | Stable?)

    Notes:
        - Mirrors the formatting conventions of generate_validation_report()
    """
    overall_ok = summary["overall_stable"]
    status = "STABLE ✓" if overall_ok else "UNSTABLE ✗"

    lines: list[str] = [f"Sensitivity Analysis: {status}", ""]

    # Summary section; fall back to N/A when no rho could be computed.
    lines.append("Summary:")
    lines.append(f" Total perturbations: {summary['total_perturbations']}")
    lines.append(f" Stable perturbations: {summary['stable_count']} (rho >= {STABILITY_THRESHOLD})")
    lines.append(f" Unstable perturbations: {summary['unstable_count']}")
    if summary["mean_rho"] is not None:
        lines.append(f" Mean Spearman rho: {summary['mean_rho']:.4f}")
    else:
        lines.append(" Mean Spearman rho: N/A")
    if summary["min_rho"] is not None:
        lines.append(f" Range: [{summary['min_rho']:.4f}, {summary['max_rho']:.4f}]")
    else:
        lines.append(" Range: N/A")
    lines.append("")

    # Interpretation text depends on overall stability.
    if overall_ok:
        lines.append(
            f"All weight perturbations (±5-10%) produce stable rankings (rho >= {STABILITY_THRESHOLD}), "
            "validating result robustness."
        )
    else:
        lines.append(
            f"Warning: Some perturbations produce unstable rankings (rho < {STABILITY_THRESHOLD}). "
            "Results may be sensitive to weight choices."
        )

    if summary["most_sensitive_layer"] and summary["most_robust_layer"]:
        lines.extend([
            "",
            f" Most sensitive layer: {summary['most_sensitive_layer']}",
            f" Most robust layer: {summary['most_robust_layer']}",
        ])

    # Per-perturbation table.
    divider = "-" * 100
    lines.extend([
        "",
        "Perturbation Results:",
        divider,
        f"{'Layer':<15} {'Delta':>8} {'Spearman rho':>14} {'p-value':>12} {'Overlap':>10} {'Stable?':>10}",
        divider,
    ])

    for row in analysis_result["results"]:
        rho = row["spearman_rho"]
        if rho is not None:
            rho_text = f"{rho:.4f}"
            pval_text = f"{row['spearman_pval']:.2e}"
            stable_text = "✓" if rho >= STABILITY_THRESHOLD else "✗"
        else:
            rho_text = "N/A"
            pval_text = "N/A"
            stable_text = "N/A"

        lines.append(
            f"{row['layer']:<15} {row['delta']:>+8.2f} {rho_text:>14} "
            f"{pval_text:>12} {row['overlap_count']:>10} {stable_text:>10}"
        )

    return "\n".join(lines)
||||||
Reference in New Issue
Block a user