feat(06-01): add recall@k metrics and extended positive control validation

- compute_recall_at_k() measures recall at absolute (100, 500, 1000, 2000) and percentage (5%, 10%, 20%) thresholds
- validate_positive_controls_extended() combines percentile + recall + per-source metrics
- Per-source breakdown separates OMIM Usher from SYSCILIA SCGS v2 validation
- Export all new functions in __init__.py including negative control imports
This commit is contained in:
2026-02-12 04:41:00 +08:00
parent a7589d9bf1
commit 0f615c0d53
2 changed files with 240 additions and 0 deletions

View File

@@ -17,6 +17,14 @@ from usher_pipeline.scoring.quality_control import (
from usher_pipeline.scoring.validation import (
validate_known_gene_ranking,
generate_validation_report,
compute_recall_at_k,
validate_positive_controls_extended,
)
from usher_pipeline.scoring.negative_controls import (
HOUSEKEEPING_GENES_CORE,
compile_housekeeping_genes,
validate_negative_controls,
generate_negative_control_report,
)
__all__ = [
@@ -30,4 +38,10 @@ __all__ = [
"run_qc_checks",
"validate_known_gene_ranking",
"generate_validation_report",
"compute_recall_at_k",
"validate_positive_controls_extended",
"HOUSEKEEPING_GENES_CORE",
"compile_housekeeping_genes",
"validate_negative_controls",
"generate_negative_control_report",
]

View File

@@ -225,3 +225,229 @@ This indicates either:
)
return "\n".join(report)
def compute_recall_at_k(
    store: PipelineStore,
    k_values: list[int] | None = None,
) -> dict:
    """
    Compute recall@k metrics for known genes at various thresholds.

    Measures what fraction of known genes appear in the top-k ranked candidates,
    providing specific metrics like ">70% recall in top 10%" required by success criteria.

    Args:
        store: PipelineStore with scored_genes table
        k_values: Absolute top-k thresholds (default: [100, 500, 1000, 2000])

    Returns:
        Dict with keys:
        - recalls_absolute: dict mapping k -> recall float (e.g., {100: 0.65, 500: 0.85})
        - recalls_percentage: dict mapping pct_string -> recall float (e.g., {"5%": 0.58, "10%": 0.72})
        - total_known_unique: int - count of unique known genes (deduplicated)
        - total_scored: int - count of genes with non-NULL composite scores

    Notes:
        - Known genes are deduplicated on gene_symbol (genes in both sources count once)
        - Recall@k = (known genes in top-k) / total_known_unique
        - Percentage thresholds computed at 5%, 10%, 20% of total_scored
        - Genes without composite_score (NULL) are excluded
        - Ordered by composite_score DESC (highest scores first)
    """
    logger.info("compute_recall_at_k_start")

    # Default absolute k thresholds.
    if k_values is None:
        k_values = [100, 500, 1000, 2000]

    # Compile known genes and deduplicate on gene_symbol.
    known_df = compile_known_genes()
    known_genes_set = set(known_df["gene_symbol"].unique())
    total_known_unique = len(known_genes_set)

    # Total count of genes that actually received a composite score.
    total_scored = store.conn.execute(
        """
        SELECT COUNT(*) AS total
        FROM scored_genes
        WHERE composite_score IS NOT NULL
        """
    ).fetchone()[0]

    def _recall_at(k: int) -> tuple[float, int]:
        """Return (recall, known_in_top_k) for the top-k scored genes.

        Uses a parameterized LIMIT instead of f-string interpolation so the
        SQL text is constant and k is always bound as a value.
        """
        top_k_genes = store.conn.execute(
            """
            SELECT gene_symbol
            FROM scored_genes
            WHERE composite_score IS NOT NULL
            ORDER BY composite_score DESC
            LIMIT ?
            """,
            [k],
        ).pl()
        known_in_top_k = len(known_genes_set & set(top_k_genes["gene_symbol"]))
        recall = known_in_top_k / total_known_unique if total_known_unique > 0 else 0.0
        return recall, known_in_top_k

    # Recall at absolute top-k thresholds.
    recalls_absolute: dict[int, float] = {}
    for k in k_values:
        recall, known_in_top_k = _recall_at(k)
        recalls_absolute[k] = recall
        logger.info(
            "recall_at_k_absolute",
            k=k,
            recall=f"{recall:.4f}",
            known_in_top_k=known_in_top_k,
            total_known=total_known_unique,
        )

    # Recall at percentage-of-total thresholds (5%, 10%, 20%).
    percentage_thresholds = [0.05, 0.10, 0.20]
    recalls_percentage: dict[str, float] = {}
    for pct in percentage_thresholds:
        pct_string = f"{int(pct * 100)}%"
        k = int(total_scored * pct)
        recall, known_in_top_k = _recall_at(k)
        recalls_percentage[pct_string] = recall
        logger.info(
            "recall_at_k_percentage",
            threshold=pct_string,
            k=k,
            recall=f"{recall:.4f}",
            known_in_top_k=known_in_top_k,
            total_known=total_known_unique,
        )

    return {
        "recalls_absolute": recalls_absolute,
        "recalls_percentage": recalls_percentage,
        "total_known_unique": total_known_unique,
        "total_scored": total_scored,
    }
def validate_positive_controls_extended(
    store: PipelineStore,
    percentile_threshold: float = 0.75,
) -> dict:
    """
    Extended positive control validation with recall@k and per-source breakdown.

    Combines base percentile validation with recall@k metrics and per-source analysis
    to provide comprehensive validation for Phase 6.

    Args:
        store: PipelineStore with scored_genes table
        percentile_threshold: Minimum median percentile for validation (default 0.75)

    Returns:
        Dict with keys:
        - All keys from validate_known_gene_ranking() (base metrics)
        - recall_at_k: dict from compute_recall_at_k() (recalls_absolute, recalls_percentage, etc.)
        - per_source_breakdown: dict mapping source -> {median_percentile, count, top_quartile_count}

    Notes:
        - Per-source breakdown separates OMIM Usher from SYSCILIA SCGS v2 validation
        - Uses same PERCENT_RANK CTE pattern but filters JOIN by source
        - Allows detecting if one gene set validates better than the other
    """
    logger.info("validate_positive_controls_extended_start", threshold=percentile_threshold)

    # Base metrics from the existing percentile validation.
    base_metrics = validate_known_gene_ranking(store, percentile_threshold)

    # Recall@k metrics at the default thresholds.
    recall_metrics = compute_recall_at_k(store)

    # Per-source breakdown: rank all scored genes once per source, then join
    # against only the genes belonging to that source.
    known_df = compile_known_genes()
    sources = known_df["source"].unique().to_list()
    per_source_breakdown: dict = {}

    query = """
        WITH ranked_genes AS (
            SELECT
                gene_symbol,
                composite_score,
                PERCENT_RANK() OVER (ORDER BY composite_score) AS percentile_rank
            FROM scored_genes
            WHERE composite_score IS NOT NULL
        )
        SELECT
            rg.gene_symbol,
            rg.composite_score,
            rg.percentile_rank
        FROM ranked_genes rg
        INNER JOIN _source_genes sg ON rg.gene_symbol = sg.gene_symbol
        ORDER BY rg.percentile_rank DESC
    """

    for source in sources:
        # Filter known genes to the current source.
        source_genes = known_df.filter(pl.col("source") == source)

        # Explicitly register the filtered frame as a DuckDB view instead of
        # relying on the implicit replacement scan resolving a local Python
        # variable by name; unregister in finally so a failing query cannot
        # leave a stale registration behind.
        store.conn.register("_source_genes", source_genes)
        try:
            result = store.conn.execute(query).pl()
        finally:
            store.conn.unregister("_source_genes")

        if result.height == 0:
            # Source contributed no scored genes; record an empty breakdown.
            per_source_breakdown[source] = {
                "median_percentile": None,
                "count": 0,
                "top_quartile_count": 0,
            }
            continue

        median_percentile = float(result["percentile_rank"].median())
        count = result.height
        top_quartile_count = result.filter(pl.col("percentile_rank") >= 0.75).height

        per_source_breakdown[source] = {
            "median_percentile": median_percentile,
            "count": count,
            "top_quartile_count": top_quartile_count,
        }

        logger.info(
            "per_source_validation",
            source=source,
            median_percentile=f"{median_percentile:.4f}",
            count=count,
            top_quartile_count=top_quartile_count,
        )

    # Merge base metrics with the extended ones.
    extended_metrics = {
        **base_metrics,
        "recall_at_k": recall_metrics,
        "per_source_breakdown": per_source_breakdown,
    }

    logger.info(
        "validate_positive_controls_extended_complete",
        validation_passed=base_metrics["validation_passed"],
        recall_at_10pct=f"{recall_metrics['recalls_percentage'].get('10%', 0.0):.4f}",
    )

    return extended_metrics