feat(06-01): add recall@k metrics and extended positive control validation
- compute_recall_at_k() measures recall at absolute (100, 500, 1000, 2000) and percentage (5%, 10%, 20%) thresholds
- validate_positive_controls_extended() combines percentile + recall + per-source metrics
- Per-source breakdown separates OMIM Usher from SYSCILIA SCGS v2 validation
- Export all new functions in __init__.py, including the negative control imports
This commit is contained in:
@@ -17,6 +17,14 @@ from usher_pipeline.scoring.quality_control import (
|
||||
from usher_pipeline.scoring.validation import (
|
||||
validate_known_gene_ranking,
|
||||
generate_validation_report,
|
||||
compute_recall_at_k,
|
||||
validate_positive_controls_extended,
|
||||
)
|
||||
from usher_pipeline.scoring.negative_controls import (
|
||||
HOUSEKEEPING_GENES_CORE,
|
||||
compile_housekeeping_genes,
|
||||
validate_negative_controls,
|
||||
generate_negative_control_report,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -30,4 +38,10 @@ __all__ = [
|
||||
"run_qc_checks",
|
||||
"validate_known_gene_ranking",
|
||||
"generate_validation_report",
|
||||
"compute_recall_at_k",
|
||||
"validate_positive_controls_extended",
|
||||
"HOUSEKEEPING_GENES_CORE",
|
||||
"compile_housekeeping_genes",
|
||||
"validate_negative_controls",
|
||||
"generate_negative_control_report",
|
||||
]
|
||||
|
||||
@@ -225,3 +225,229 @@ This indicates either:
|
||||
)
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
def compute_recall_at_k(
    store: PipelineStore,
    k_values: list[int] | None = None
) -> dict:
    """
    Compute recall@k metrics for known genes at various thresholds.

    Measures what fraction of known genes appear in the top-k ranked candidates,
    providing specific metrics like ">70% recall in top 10%" required by success criteria.

    Args:
        store: PipelineStore with scored_genes table
        k_values: Absolute top-k thresholds (default: [100, 500, 1000, 2000])

    Returns:
        Dict with keys:
        - recalls_absolute: dict mapping k -> recall float (e.g., {100: 0.65, 500: 0.85})
        - recalls_percentage: dict mapping pct_string -> recall float (e.g., {"5%": 0.58, "10%": 0.72})
        - total_known_unique: int - count of unique known genes (deduplicated)
        - total_scored: int - count of genes with non-NULL composite scores

    Raises:
        ValueError: If any entry of k_values is negative.

    Notes:
        - Known genes are deduplicated on gene_symbol (genes in both sources count once)
        - Recall@k = (known genes in top-k) / total_known_unique; 0.0 when there
          are no known genes
        - Percentage thresholds computed at 5%, 10%, 20% of total_scored (floored)
        - Genes without composite_score (NULL) are excluded
        - Ordered by composite_score DESC (highest scores first); ties are broken
          by gene_symbol so every threshold sees the same, reproducible ranking
        - The ranking is fetched once and sliced in Python, so the whole
          scored_genes ranking (one symbol per row) is held in memory
    """
    logger.info("compute_recall_at_k_start")

    # Default k values
    if k_values is None:
        k_values = [100, 500, 1000, 2000]

    # A negative k has no meaning as a top-k cutoff, and Python slicing would
    # silently reinterpret it as "all but the last |k|" instead of failing.
    for k in k_values:
        if k < 0:
            raise ValueError(f"k_values must be non-negative, got {k}")

    # Compile known genes and deduplicate on gene_symbol
    known_df = compile_known_genes()
    known_genes_set = set(known_df["gene_symbol"].unique())
    total_known_unique = len(known_genes_set)

    # Rank every scored gene exactly once. The secondary sort key makes the
    # order deterministic when composite scores tie, so the top-k sets for
    # increasing k are guaranteed to be nested.
    ranked_rows = store.conn.execute("""
        SELECT gene_symbol
        FROM scored_genes
        WHERE composite_score IS NOT NULL
        ORDER BY composite_score DESC, gene_symbol ASC
    """).fetchall()
    ranked_symbols = [row[0] for row in ranked_rows]

    # One row per scored gene, so this equals COUNT(*) over the same filter.
    total_scored = len(ranked_symbols)

    def _recall_at(cutoff: int) -> tuple[float, int]:
        """Return (recall, number of known genes) within the top `cutoff` ranks."""
        hits = len(known_genes_set.intersection(ranked_symbols[:cutoff]))
        recall = hits / total_known_unique if total_known_unique > 0 else 0.0
        return recall, hits

    # Recall at absolute thresholds
    recalls_absolute = {}
    for k in k_values:
        recall, known_in_top_k = _recall_at(k)
        recalls_absolute[k] = recall

        logger.info(
            "recall_at_k_absolute",
            k=k,
            recall=f"{recall:.4f}",
            known_in_top_k=known_in_top_k,
            total_known=total_known_unique,
        )

    # Recall at percentage thresholds (5%, 10%, 20% of the scored genes).
    # Integer arithmetic gives the floored cutoff without relying on float
    # rounding.
    recalls_percentage = {}
    for pct in (5, 10, 20):
        pct_string = f"{pct}%"
        k = total_scored * pct // 100
        recall, known_in_top_k = _recall_at(k)
        recalls_percentage[pct_string] = recall

        logger.info(
            "recall_at_k_percentage",
            threshold=pct_string,
            k=k,
            recall=f"{recall:.4f}",
            known_in_top_k=known_in_top_k,
            total_known=total_known_unique,
        )

    return {
        "recalls_absolute": recalls_absolute,
        "recalls_percentage": recalls_percentage,
        "total_known_unique": total_known_unique,
        "total_scored": total_scored,
    }
|
||||
|
||||
|
||||
def validate_positive_controls_extended(
    store: PipelineStore,
    percentile_threshold: float = 0.75
) -> dict:
    """
    Extended positive control validation with recall@k and per-source breakdown.

    Combines base percentile validation with recall@k metrics and per-source analysis
    to provide comprehensive validation for Phase 6.

    Args:
        store: PipelineStore with scored_genes table
        percentile_threshold: Minimum median percentile for validation (default 0.75);
            passed through to validate_known_gene_ranking()

    Returns:
        Dict with keys:
        - All keys from validate_known_gene_ranking() (base metrics)
        - recall_at_k: dict from compute_recall_at_k() (recalls_absolute, recalls_percentage, etc.)
        - per_source_breakdown: dict mapping source -> {median_percentile, count, top_quartile_count}

    Notes:
        - One breakdown entry is produced per distinct value of the "source"
          column returned by compile_known_genes(), in first-seen order
        - A source with no genes present in the ranking gets
          median_percentile=None and zero counts
        - Percentiles use PERCENT_RANK over ascending composite_score, so a
          higher percentile means a higher score; NULL scores are excluded
        - top_quartile_count always uses a fixed 0.75 cutoff, independent of
          percentile_threshold
        - Allows detecting if one gene set validates better than the other
    """
    logger.info("validate_positive_controls_extended_start", threshold=percentile_threshold)

    # Get base metrics from existing validation function
    base_metrics = validate_known_gene_ranking(store, percentile_threshold)

    # Compute recall@k metrics
    recall_metrics = compute_recall_at_k(store)

    # Compute per-source breakdown. maintain_order keeps the breakdown (and the
    # log output) in a stable, reproducible order across runs.
    known_df = compile_known_genes()
    sources = known_df["source"].unique(maintain_order=True).to_list()

    # Loop-invariant: the same ranking query is reused for every source; only
    # the contents of the _source_genes temp table change between iterations.
    per_source_query = """
        WITH ranked_genes AS (
            SELECT
                gene_symbol,
                composite_score,
                PERCENT_RANK() OVER (ORDER BY composite_score) AS percentile_rank
            FROM scored_genes
            WHERE composite_score IS NOT NULL
        )
        SELECT
            rg.gene_symbol,
            rg.composite_score,
            rg.percentile_rank
        FROM ranked_genes rg
        INNER JOIN _source_genes sg ON rg.gene_symbol = sg.gene_symbol
        ORDER BY rg.percentile_rank DESC
    """

    top_quartile_cutoff = 0.75
    per_source_breakdown = {}

    for source in sources:
        # Filter known genes to current source.
        # NOTE(review): the CREATE statement below refers to this local
        # DataFrame by its variable name `source_genes`; this presumably relies
        # on the connection resolving Python variables in SQL (DuckDB
        # replacement scan) - do not rename without updating the SQL.
        source_genes = known_df.filter(pl.col("source") == source)

        # Register as temp table; always drop it again, even if the query
        # fails, so a failed run does not leave the table on the connection.
        store.conn.execute("DROP TABLE IF EXISTS _source_genes")
        try:
            store.conn.execute("CREATE TEMP TABLE _source_genes AS SELECT * FROM source_genes")
            result = store.conn.execute(per_source_query).pl()
        finally:
            store.conn.execute("DROP TABLE IF EXISTS _source_genes")

        if result.height == 0:
            per_source_breakdown[source] = {
                "median_percentile": None,
                "count": 0,
                "top_quartile_count": 0,
            }
            continue

        median_percentile = float(result["percentile_rank"].median())
        count = result.height
        top_quartile_count = result.filter(
            pl.col("percentile_rank") >= top_quartile_cutoff
        ).height

        per_source_breakdown[source] = {
            "median_percentile": median_percentile,
            "count": count,
            "top_quartile_count": top_quartile_count,
        }

        logger.info(
            "per_source_validation",
            source=source,
            median_percentile=f"{median_percentile:.4f}",
            count=count,
            top_quartile_count=top_quartile_count,
        )

    # Combine all metrics
    extended_metrics = {
        **base_metrics,
        "recall_at_k": recall_metrics,
        "per_source_breakdown": per_source_breakdown,
    }

    logger.info(
        "validate_positive_controls_extended_complete",
        validation_passed=base_metrics["validation_passed"],
        recall_at_10pct=f"{recall_metrics['recalls_percentage'].get('10%', 0.0):.4f}",
    )

    return extended_metrics
|
||||
|
||||
Reference in New Issue
Block a user