diff --git a/src/usher_pipeline/scoring/__init__.py b/src/usher_pipeline/scoring/__init__.py
index fb1b0ae..43d19a0 100644
--- a/src/usher_pipeline/scoring/__init__.py
+++ b/src/usher_pipeline/scoring/__init__.py
@@ -17,6 +17,14 @@ from usher_pipeline.scoring.quality_control import (
 from usher_pipeline.scoring.validation import (
     validate_known_gene_ranking,
     generate_validation_report,
+    compute_recall_at_k,
+    validate_positive_controls_extended,
+)
+from usher_pipeline.scoring.negative_controls import (
+    HOUSEKEEPING_GENES_CORE,
+    compile_housekeeping_genes,
+    validate_negative_controls,
+    generate_negative_control_report,
 )
 
 __all__ = [
@@ -30,4 +38,10 @@ __all__ = [
     "run_qc_checks",
     "validate_known_gene_ranking",
     "generate_validation_report",
+    "compute_recall_at_k",
+    "validate_positive_controls_extended",
+    "HOUSEKEEPING_GENES_CORE",
+    "compile_housekeeping_genes",
+    "validate_negative_controls",
+    "generate_negative_control_report",
 ]
diff --git a/src/usher_pipeline/scoring/validation.py b/src/usher_pipeline/scoring/validation.py
index 81eddf5..a104563 100644
--- a/src/usher_pipeline/scoring/validation.py
+++ b/src/usher_pipeline/scoring/validation.py
@@ -225,3 +225,251 @@ This indicates either:
         )
 
     return "\n".join(report)
+
+
+def _recall_in_top_k(
+    ranked_symbols: list[str],
+    known_genes: set[str],
+    k: int,
+) -> tuple[int, float]:
+    """
+    Count known genes among the first k ranked symbols and derive recall.
+
+    Args:
+        ranked_symbols: Gene symbols ordered best-first
+        known_genes: Deduplicated set of known gene symbols
+        k: Number of leading ranked symbols to inspect (non-negative)
+
+    Returns:
+        Tuple of (known genes found in the top k, recall as fraction of known_genes);
+        recall is 0.0 when known_genes is empty
+    """
+    known_in_top_k = len(known_genes.intersection(ranked_symbols[:k]))
+    recall = known_in_top_k / len(known_genes) if known_genes else 0.0
+    return known_in_top_k, recall
+
+
+def compute_recall_at_k(
+    store: PipelineStore,
+    k_values: list[int] | None = None
+) -> dict:
+    """
+    Compute recall@k metrics for known genes at various thresholds.
+
+    Measures what fraction of known genes appear in the top-k ranked candidates,
+    providing specific metrics like ">70% recall in top 10%" required by success criteria.
+
+    Args:
+        store: PipelineStore with scored_genes table
+        k_values: Absolute top-k thresholds (default: [100, 500, 1000, 2000])
+
+    Returns:
+        Dict with keys:
+        - recalls_absolute: dict mapping k -> recall float (e.g., {100: 0.65, 500: 0.85})
+        - recalls_percentage: dict mapping pct_string -> recall float (e.g., {"5%": 0.58, "10%": 0.72})
+        - total_known_unique: int - count of unique known genes (deduplicated)
+        - total_scored: int - count of genes with non-NULL composite scores
+
+    Raises:
+        ValueError: If any entry of k_values is negative
+
+    Notes:
+        - Known genes are deduplicated on gene_symbol (genes in both sources count once)
+        - Recall@k = (known genes in top-k) / total_known_unique
+        - Percentage thresholds computed at 5%, 10%, 20% of total_scored (rounded down)
+        - Genes without composite_score (NULL) are excluded
+        - Ordered by composite_score DESC (highest scores first); ties are broken by
+          gene_symbol so every threshold sees the same deterministic ranking
+        - The ranking is queried once, up to the largest threshold, and sliced per k
+    """
+    logger.info("compute_recall_at_k_start")
+
+    # Default k values
+    if k_values is None:
+        k_values = [100, 500, 1000, 2000]
+    if any(k < 0 for k in k_values):
+        raise ValueError(f"k_values must be non-negative, got {k_values}")
+
+    # Compile known genes and deduplicate on gene_symbol
+    known_df = compile_known_genes()
+    known_genes_set = set(known_df["gene_symbol"].unique())
+    total_known_unique = len(known_genes_set)
+
+    # Get total count of scored genes
+    total_scored = store.conn.execute("""
+        SELECT COUNT(*) as total
+        FROM scored_genes
+        WHERE composite_score IS NOT NULL
+    """).fetchone()[0]
+
+    # Compute percentage thresholds with integer arithmetic (no float rounding)
+    percentage_k_values = {
+        f"{pct}%": total_scored * pct // 100
+        for pct in (5, 10, 20)
+    }
+
+    # Fetch the ranking once, deep enough for the largest threshold; int() keeps
+    # the interpolated LIMIT a plain integer literal
+    max_k = max([*k_values, *percentage_k_values.values()])
+    ranked_symbols = store.conn.execute(f"""
+        SELECT gene_symbol
+        FROM scored_genes
+        WHERE composite_score IS NOT NULL
+        ORDER BY composite_score DESC, gene_symbol
+        LIMIT {int(max_k)}
+    """).pl()["gene_symbol"].to_list()
+
+    # Recall at absolute thresholds
+    recalls_absolute = {}
+    for k in k_values:
+        known_in_top_k, recall = _recall_in_top_k(ranked_symbols, known_genes_set, k)
+        recalls_absolute[k] = recall
+
+        logger.info(
+            "recall_at_k_absolute",
+            k=k,
+            recall=f"{recall:.4f}",
+            known_in_top_k=known_in_top_k,
+            total_known=total_known_unique,
+        )
+
+    # Recall at percentage thresholds
+    recalls_percentage = {}
+    for pct_string, k in percentage_k_values.items():
+        known_in_top_k, recall = _recall_in_top_k(ranked_symbols, known_genes_set, k)
+        recalls_percentage[pct_string] = recall
+
+        logger.info(
+            "recall_at_k_percentage",
+            threshold=pct_string,
+            k=k,
+            recall=f"{recall:.4f}",
+            known_in_top_k=known_in_top_k,
+            total_known=total_known_unique,
+        )
+
+    return {
+        "recalls_absolute": recalls_absolute,
+        "recalls_percentage": recalls_percentage,
+        "total_known_unique": total_known_unique,
+        "total_scored": total_scored,
+    }
+
+
+def validate_positive_controls_extended(
+    store: PipelineStore,
+    percentile_threshold: float = 0.75
+) -> dict:
+    """
+    Extended positive control validation with recall@k and per-source breakdown.
+
+    Combines base percentile validation with recall@k metrics and per-source analysis
+    to provide comprehensive validation for Phase 6.
+
+    Args:
+        store: PipelineStore with scored_genes table
+        percentile_threshold: Minimum median percentile for validation (default 0.75)
+
+    Returns:
+        Dict with keys:
+        - All keys from validate_known_gene_ranking() (base metrics)
+        - recall_at_k: dict from compute_recall_at_k() (recalls_absolute, recalls_percentage, etc.)
+        - per_source_breakdown: dict mapping source -> {median_percentile, count, top_quartile_count}
+
+    Notes:
+        - Per-source breakdown separates OMIM Usher (10 genes) from SYSCILIA SCGS v2 (28 genes)
+        - Uses same PERCENT_RANK CTE pattern, computed once and split by source afterwards
+        - Allows detecting if one gene set validates better than the other
+        - top_quartile_count uses a fixed 0.75 cut-off, independent of percentile_threshold
+        - The temporary _known_genes table is dropped even if the ranking query fails
+    """
+    logger.info("validate_positive_controls_extended_start", threshold=percentile_threshold)
+
+    # Get base metrics from existing validation function
+    base_metrics = validate_known_gene_ranking(store, percentile_threshold)
+
+    # Compute recall@k metrics
+    recall_metrics = compute_recall_at_k(store)
+
+    # Compile known genes; keep source order stable so the breakdown is reproducible
+    known_df = compile_known_genes()
+    sources = known_df["source"].unique(maintain_order=True).to_list()
+
+    # Register as temp table (DuckDB resolves "known_df" to the local DataFrame)
+    store.conn.execute("DROP TABLE IF EXISTS _known_genes")
+    store.conn.execute(
+        "CREATE TEMP TABLE _known_genes AS SELECT gene_symbol, source FROM known_df"
+    )
+
+    # Rank all scored genes once, keeping only rows that match a known gene
+    query = """
+        WITH ranked_genes AS (
+            SELECT
+                gene_symbol,
+                composite_score,
+                PERCENT_RANK() OVER (ORDER BY composite_score) AS percentile_rank
+            FROM scored_genes
+            WHERE composite_score IS NOT NULL
+        )
+        SELECT
+            kg.source,
+            rg.gene_symbol,
+            rg.composite_score,
+            rg.percentile_rank
+        FROM ranked_genes rg
+        INNER JOIN _known_genes kg ON rg.gene_symbol = kg.gene_symbol
+        ORDER BY rg.percentile_rank DESC
+    """
+
+    try:
+        ranked_known = store.conn.execute(query).pl()
+    finally:
+        # Clean up temp table even when the query raises
+        store.conn.execute("DROP TABLE IF EXISTS _known_genes")
+
+    per_source_breakdown = {}
+
+    for source in sources:
+        # Filter ranked known genes to current source
+        result = ranked_known.filter(pl.col("source") == source)
+
+        if result.height == 0:
+            per_source_breakdown[source] = {
+                "median_percentile": None,
+                "count": 0,
+                "top_quartile_count": 0,
+            }
+            continue
+
+        median_percentile = float(result["percentile_rank"].median())
+        count = result.height
+        top_quartile_count = result.filter(pl.col("percentile_rank") >= 0.75).height
+
+        per_source_breakdown[source] = {
+            "median_percentile": median_percentile,
+            "count": count,
+            "top_quartile_count": top_quartile_count,
+        }
+
+        logger.info(
+            "per_source_validation",
+            source=source,
+            median_percentile=f"{median_percentile:.4f}",
+            count=count,
+            top_quartile_count=top_quartile_count,
+        )
+
+    # Combine all metrics
+    extended_metrics = {
+        **base_metrics,
+        "recall_at_k": recall_metrics,
+        "per_source_breakdown": per_source_breakdown,
+    }
+
+    logger.info(
+        "validate_positive_controls_extended_complete",
+        validation_passed=base_metrics["validation_passed"],
+        recall_at_10pct=f"{recall_metrics['recalls_percentage'].get('10%', 0.0):.4f}",
+    )
+
+    return extended_metrics