diff --git a/src/usher_pipeline/evidence/gnomad/__init__.py b/src/usher_pipeline/evidence/gnomad/__init__.py index 911054d..3795578 100644 --- a/src/usher_pipeline/evidence/gnomad/__init__.py +++ b/src/usher_pipeline/evidence/gnomad/__init__.py @@ -2,10 +2,18 @@ from usher_pipeline.evidence.gnomad.models import ConstraintRecord, GNOMAD_CONSTRAINT_URL from usher_pipeline.evidence.gnomad.fetch import download_constraint_metrics, parse_constraint_tsv +from usher_pipeline.evidence.gnomad.transform import ( + filter_by_coverage, + normalize_scores, + process_gnomad_constraint, +) __all__ = [ "ConstraintRecord", "GNOMAD_CONSTRAINT_URL", "download_constraint_metrics", "parse_constraint_tsv", + "filter_by_coverage", + "normalize_scores", + "process_gnomad_constraint", ] diff --git a/src/usher_pipeline/evidence/gnomad/fetch.py b/src/usher_pipeline/evidence/gnomad/fetch.py index 63d81fa..b3b9f58 100644 --- a/src/usher_pipeline/evidence/gnomad/fetch.py +++ b/src/usher_pipeline/evidence/gnomad/fetch.py @@ -144,11 +144,15 @@ def parse_constraint_tsv(tsv_path: Path) -> pl.LazyFrame: ) # Map actual columns to our standardized names + # Track which source columns we've already used to avoid duplicates column_mapping = {} + used_source_cols = set() + for our_name, variants in COLUMN_VARIANTS.items(): for variant in variants: - if variant in actual_columns: + if variant in actual_columns and variant not in used_source_cols: column_mapping[variant] = our_name + used_source_cols.add(variant) break if not column_mapping: @@ -164,4 +168,13 @@ def parse_constraint_tsv(tsv_path: Path) -> pl.LazyFrame: # Select and rename mapped columns lf = lf.select([pl.col(old).alias(new) for old, new in column_mapping.items()]) + # Special case: if loeuf is mapped but loeuf_upper is not, duplicate loeuf to loeuf_upper + # (In gnomAD, the "upper" value IS the LOEUF we use) + mapped_names = set(column_mapping.values()) + if "loeuf" in mapped_names and "loeuf_upper" not in mapped_names: + lf = lf.with_columns(pl.col("loeuf").alias("loeuf_upper")) + elif "loeuf_upper" in mapped_names and "loeuf" not in mapped_names: + # If we only got loeuf_upper, copy it to loeuf + lf = lf.with_columns(pl.col("loeuf_upper").alias("loeuf")) + return lf diff --git a/src/usher_pipeline/evidence/gnomad/models.py b/src/usher_pipeline/evidence/gnomad/models.py index d4162d9..824f8e2 100644 --- a/src/usher_pipeline/evidence/gnomad/models.py +++ b/src/usher_pipeline/evidence/gnomad/models.py @@ -9,15 +9,18 @@ GNOMAD_CONSTRAINT_URL = ( ) # Column name mapping for different gnomAD versions -# v2.1.1 uses: gene, transcript, pLI, oe_lof_upper (LOEUF), mean_proportion_covered_bases -# v4.x uses: gene, transcript, mane_select, lof.pLI, lof.oe_ci.upper, mean_proportion_covered +# v2.1.1 uses: gene, transcript, pLI, oe_lof_upper (LOEUF upper CI), mean_proportion_covered_bases +# v4.x uses: gene, transcript, mane_select, lof.pLI, lof.oe_ci.upper (LOEUF), mean_proportion_covered +# NOTE: In gnomAD data, what's called "upper" is actually the LOEUF value we want (observed/expected upper bound) COLUMN_VARIANTS = { "gene_id": ["gene", "gene_id"], "gene_symbol": ["gene_symbol", "gene"], "transcript": ["transcript", "canonical_transcript", "mane_select"], "pli": ["pLI", "lof.pLI", "pli"], - "loeuf": ["oe_lof_upper", "lof.oe_ci.upper", "oe_lof", "loeuf"], - "loeuf_upper": ["oe_lof_upper_ci", "lof.oe_ci.upper", "oe_lof_upper"], + # LOEUF is the "upper" column in gnomAD (oe_lof_upper = observed/expected upper bound) + "loeuf": ["lof.oe_ci.upper", "oe_lof_upper", "oe_lof", "loeuf"], + # loeuf_upper is typically the same as loeuf in gnomAD data (they report the upper CI) + "loeuf_upper": ["lof.oe_ci.upper", "oe_lof_upper_ci", "oe_lof_upper"], "mean_depth": ["mean_coverage", "mean_depth", "mean_cov"], "cds_covered_pct": [ "mean_proportion_covered_bases", diff --git a/src/usher_pipeline/evidence/gnomad/transform.py b/src/usher_pipeline/evidence/gnomad/transform.py new file mode 100644 index 0000000..19e7fd2 --- /dev/null +++ b/src/usher_pipeline/evidence/gnomad/transform.py @@ -0,0 +1,150 @@ +"""Transform and normalize gnomAD constraint metrics.""" + +from pathlib import Path + +import polars as pl +import structlog + +from usher_pipeline.evidence.gnomad.fetch import parse_constraint_tsv + +logger = structlog.get_logger() + + +def filter_by_coverage( + lf: pl.LazyFrame, + min_depth: float = 30.0, + min_cds_pct: float = 0.9, +) -> pl.LazyFrame: + """Add quality_flag column based on coverage thresholds. + + Does NOT drop any genes - preserves all rows with quality categorization. + "Unknown" constraint is semantically different from "zero" constraint. + + Args: + lf: LazyFrame with gnomAD constraint data + min_depth: Minimum mean sequencing depth (default: 30x) + min_cds_pct: Minimum CDS coverage fraction (default: 0.9 = 90%) + + Returns: + LazyFrame with quality_flag column added: + - "measured": Good coverage AND has LOEUF estimate + - "incomplete_coverage": Coverage below thresholds + - "no_data": Both LOEUF and pLI are NULL + """ + # Ensure numeric columns are properly cast to float (handles edge cases with mixed types) + lf = lf.with_columns([ + pl.col("mean_depth").cast(pl.Float64, strict=False), + pl.col("cds_covered_pct").cast(pl.Float64, strict=False), + pl.col("loeuf").cast(pl.Float64, strict=False), + pl.col("pli").cast(pl.Float64, strict=False), + ]) + + return lf.with_columns( + pl.when( + pl.col("mean_depth").is_not_null() + & pl.col("cds_covered_pct").is_not_null() + & (pl.col("mean_depth") >= min_depth) + & (pl.col("cds_covered_pct") >= min_cds_pct) + & pl.col("loeuf").is_not_null() + ) + .then(pl.lit("measured")) + .when(pl.col("loeuf").is_null() & pl.col("pli").is_null()) + .then(pl.lit("no_data")) + .when( + pl.col("mean_depth").is_not_null() + & pl.col("cds_covered_pct").is_not_null() + & ((pl.col("mean_depth") < min_depth) | (pl.col("cds_covered_pct") < min_cds_pct)) + ) + .then(pl.lit("incomplete_coverage")) + .otherwise(pl.lit("incomplete_coverage")) + .alias("quality_flag") + ) + + +def normalize_scores(lf: pl.LazyFrame) -> pl.LazyFrame: + """Normalize LOEUF scores to 0-1 range with inversion. + + Lower LOEUF = more constrained = HIGHER normalized score. + Only genes with quality_flag="measured" get normalized scores. + Others get NULL (not 0.0 - "unknown" != "zero constraint"). + + Args: + lf: LazyFrame with gnomAD constraint data and quality_flag column + + Returns: + LazyFrame with loeuf_normalized column added + """ + # Compute min/max from measured genes only + measured = lf.filter(pl.col("quality_flag") == "measured") + + # Aggregate min/max in a single pass + stats = measured.select( + pl.col("loeuf").min().alias("loeuf_min"), + pl.col("loeuf").max().alias("loeuf_max"), + ).collect() + + if len(stats) == 0: + # No measured genes - all get NULL + return lf.with_columns(pl.lit(None).cast(pl.Float64).alias("loeuf_normalized")) + + loeuf_min = stats["loeuf_min"][0] + loeuf_max = stats["loeuf_max"][0] + + if loeuf_min is None or loeuf_max is None or loeuf_min == loeuf_max: + # Handle edge case: all measured genes have same LOEUF + return lf.with_columns(pl.lit(None).cast(pl.Float64).alias("loeuf_normalized")) + + # Invert: lower LOEUF -> higher score + # Formula: (max - value) / (max - min) + return lf.with_columns( + pl.when(pl.col("quality_flag") == "measured") + .then((loeuf_max - pl.col("loeuf")) / (loeuf_max - loeuf_min)) + .otherwise(pl.lit(None)) + .alias("loeuf_normalized") + ) + + +def process_gnomad_constraint( + tsv_path: Path, + min_depth: float = 30.0, + min_cds_pct: float = 0.9, +) -> pl.DataFrame: + """Full gnomAD constraint processing pipeline. + + Composes: parse -> filter_by_coverage -> normalize_scores -> collect + + Args: + tsv_path: Path to gnomAD constraint TSV file + min_depth: Minimum mean sequencing depth (default: 30x) + min_cds_pct: Minimum CDS coverage fraction (default: 0.9) + + Returns: + Materialized DataFrame ready for DuckDB storage + """ + logger.info("gnomad_process_start", tsv_path=str(tsv_path)) + + # Parse with lazy evaluation + lf = parse_constraint_tsv(tsv_path) + + # Filter and normalize + lf = filter_by_coverage(lf, min_depth=min_depth, min_cds_pct=min_cds_pct) + lf = normalize_scores(lf) + + # Materialize + df = lf.collect() + + # Log summary statistics + stats = df.group_by("quality_flag").len().sort("quality_flag") + total = len(df) + + logger.info( + "gnomad_process_complete", + total_genes=total, + measured=df.filter(pl.col("quality_flag") == "measured").height, + incomplete_coverage=df.filter( + pl.col("quality_flag") == "incomplete_coverage" + ).height, + no_data=df.filter(pl.col("quality_flag") == "no_data").height, + ) + + return df diff --git a/tests/test_gnomad.py b/tests/test_gnomad.py new file mode 100644 index 0000000..279b744 --- /dev/null +++ b/tests/test_gnomad.py @@ -0,0 +1,343 @@ +"""Unit tests for gnomAD constraint evidence layer.""" + +from pathlib import Path +from unittest.mock import Mock, patch + +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +from usher_pipeline.evidence.gnomad.models import ConstraintRecord +from usher_pipeline.evidence.gnomad.fetch import download_constraint_metrics, parse_constraint_tsv +from usher_pipeline.evidence.gnomad.transform import ( + filter_by_coverage, + normalize_scores, + process_gnomad_constraint, +) + + +@pytest.fixture +def sample_constraint_tsv(tmp_path: Path) -> Path: + """Create a sample gnomAD constraint TSV for testing. + + Covers edge cases: + - Normal genes with good coverage (measured) + - Low depth genes (<30x) + - Low CDS coverage genes (<90%) + - NULL LOEUF/pLI (no_data) + - Extreme LOEUF values for normalization bounds + """ + tsv_path = tmp_path / "constraint.tsv" + + # Use gnomAD v4.x-style column names + content = """gene\ttranscript\tgene_symbol\tlof.pLI\tlof.oe_ci.upper\tmean_coverage\tmean_proportion_covered +ENSG00000001\tENST00000001\tGENE1\t0.95\t0.15\t45.0\t0.98 +ENSG00000002\tENST00000002\tGENE2\t0.80\t0.85\t50.0\t0.95 +ENSG00000003\tENST00000003\tGENE3\t0.10\t2.50\t40.0\t0.92 +ENSG00000004\tENST00000004\tGENE4\t0.50\t0.0\t55.0\t0.97 +ENSG00000005\tENST00000005\tGENE5\t0.20\t1.20\t25.0\t0.85 +ENSG00000006\tENST00000006\tGENE6\t0.70\t0.45\t35.0\t0.75 +ENSG00000007\tENST00000007\tGENE7\tNA\tNA\t60.0\t0.99 +ENSG00000008\tENST00000008\tGENE8\t0.60\t0.30\t50.0\t0.90 +ENSG00000009\tENST00000009\tGENE9\t.\t.\t.\t. +ENSG00000010\tENST00000010\tGENE10\t0.90\t0.18\t48.0\t0.94 +ENSG00000011\tENST00000011\tGENE11\t0.15\t1.80\t32.0\t0.91 +ENSG00000012\tENST00000012\tGENE12\t0.85\t0.22\t10.0\t0.50 +ENSG00000013\tENST00000013\tGENE13\t0.40\t0.65\t38.0\t0.88 +ENSG00000014\tENST00000014\tGENE14\tNA\t0.75\t42.0\t0.93 +ENSG00000015\tENST00000015\tGENE15\t0.75\tNA\t47.0\t0.96 +""" + + tsv_path.write_text(content) + return tsv_path + + +def test_parse_constraint_tsv_returns_lazyframe(sample_constraint_tsv: Path): + """Verify parse returns LazyFrame with expected columns.""" + lf = parse_constraint_tsv(sample_constraint_tsv) + + assert isinstance(lf, pl.LazyFrame) + + # Collect to check columns + df = lf.collect() + expected_columns = {"gene_id", "gene_symbol", "transcript", "pli", "loeuf", "mean_depth", "cds_covered_pct"} + assert expected_columns.issubset(set(df.columns)) + + +def test_parse_constraint_tsv_null_handling(sample_constraint_tsv: Path): + """NA/empty values become polars null, not zero.""" + lf = parse_constraint_tsv(sample_constraint_tsv) + df = lf.collect() + + # GENE7 has "NA" for pli and loeuf + gene7 = df.filter(pl.col("gene_symbol") == "GENE7") + assert gene7["pli"][0] is None + assert gene7["loeuf"][0] is None + + # GENE9 has "." for all values + gene9 = df.filter(pl.col("gene_symbol") == "GENE9") + assert gene9["pli"][0] is None + assert gene9["loeuf"][0] is None + assert gene9["mean_depth"][0] is None + + +def test_filter_by_coverage_measured(sample_constraint_tsv: Path): + """Good coverage genes get quality_flag="measured".""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + df = lf.collect() + + # GENE1: depth=45, coverage=0.98, has LOEUF -> measured + gene1 = df.filter(pl.col("gene_symbol") == "GENE1") + assert gene1["quality_flag"][0] == "measured" + + # GENE8: depth=50, coverage=0.90 (exactly at threshold), has LOEUF -> measured + gene8 = df.filter(pl.col("gene_symbol") == "GENE8") + assert gene8["quality_flag"][0] == "measured" + + +def test_filter_by_coverage_incomplete(sample_constraint_tsv: Path): + """Low depth/CDS genes get quality_flag="incomplete_coverage".""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + df = lf.collect() + + # GENE5: depth=25 (< 30) -> incomplete_coverage + gene5 = df.filter(pl.col("gene_symbol") == "GENE5") + assert gene5["quality_flag"][0] == "incomplete_coverage" + + # GENE6: coverage=0.75 (< 0.9) -> incomplete_coverage + gene6 = df.filter(pl.col("gene_symbol") == "GENE6") + assert gene6["quality_flag"][0] == "incomplete_coverage" + + # GENE12: depth=10 (very low) -> incomplete_coverage + gene12 = df.filter(pl.col("gene_symbol") == "GENE12") + assert gene12["quality_flag"][0] == "incomplete_coverage" + + +def test_filter_by_coverage_no_data(sample_constraint_tsv: Path): + """NULL loeuf+pli genes get quality_flag="no_data".""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + df = lf.collect() + + # GENE7: both pli and loeuf are NULL -> no_data + gene7 = df.filter(pl.col("gene_symbol") == "GENE7") + assert gene7["quality_flag"][0] == "no_data" + + # GENE9: both pli and loeuf are NULL -> no_data + gene9 = df.filter(pl.col("gene_symbol") == "GENE9") + assert gene9["quality_flag"][0] == "no_data" + + +def test_filter_preserves_all_genes(sample_constraint_tsv: Path): + """Row count before == row count after (no genes dropped).""" + lf = parse_constraint_tsv(sample_constraint_tsv) + df_before = lf.collect() + count_before = len(df_before) + + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + df_after = lf.collect() + count_after = len(df_after) + + assert count_before == count_after, "Filter should preserve all genes" + + +def test_normalize_scores_range(sample_constraint_tsv: Path): + """All non-null normalized scores are in [0, 1].""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + lf = normalize_scores(lf) + df = lf.collect() + + # Filter to non-null normalized scores + normalized = df.filter(pl.col("loeuf_normalized").is_not_null()) + + if len(normalized) > 0: + min_score = normalized["loeuf_normalized"].min() + max_score = normalized["loeuf_normalized"].max() + + assert min_score >= 0.0, f"Min normalized score {min_score} < 0" + assert max_score <= 1.0, f"Max normalized score {max_score} > 1" + + +def test_normalize_scores_inversion(sample_constraint_tsv: Path): + """Lower LOEUF -> higher normalized score.""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + lf = normalize_scores(lf) + df = lf.collect() + + # GENE4: LOEUF=0.0 (most constrained) should have highest normalized score + gene4 = df.filter(pl.col("gene_symbol") == "GENE4") + if gene4["quality_flag"][0] == "measured": + assert gene4["loeuf_normalized"][0] is not None + # Should be close to 1.0 (most constrained) + assert gene4["loeuf_normalized"][0] >= 0.95 + + # GENE3: LOEUF=2.50 (least constrained) should have lowest normalized score + gene3 = df.filter(pl.col("gene_symbol") == "GENE3") + if gene3["quality_flag"][0] == "measured": + assert gene3["loeuf_normalized"][0] is not None + # Should be close to 0.0 (least constrained) + assert gene3["loeuf_normalized"][0] <= 0.05 + + +def test_normalize_scores_null_preserved(sample_constraint_tsv: Path): + """NULL loeuf stays NULL after normalization.""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + lf = normalize_scores(lf) + df = lf.collect() + + # GENE7: NULL loeuf -> NULL normalized + gene7 = df.filter(pl.col("gene_symbol") == "GENE7") + assert gene7["loeuf"][0] is None + assert gene7["loeuf_normalized"][0] is None + + +def test_normalize_scores_incomplete_stays_null(sample_constraint_tsv: Path): + """incomplete_coverage genes get NULL normalized score.""" + lf = parse_constraint_tsv(sample_constraint_tsv) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + lf = normalize_scores(lf) + df = lf.collect() + + # GENE5: incomplete_coverage -> NULL normalized + gene5 = df.filter(pl.col("gene_symbol") == "GENE5") + assert gene5["quality_flag"][0] == "incomplete_coverage" + assert gene5["loeuf_normalized"][0] is None + + # GENE6: incomplete_coverage -> NULL normalized + gene6 = df.filter(pl.col("gene_symbol") == "GENE6") + assert gene6["quality_flag"][0] == "incomplete_coverage" + assert gene6["loeuf_normalized"][0] is None + + +def test_process_gnomad_constraint_end_to_end(sample_constraint_tsv: Path): + """Full pipeline returns DataFrame with all expected columns.""" + df = process_gnomad_constraint(sample_constraint_tsv, min_depth=30.0, min_cds_pct=0.9) + + # Check it's a materialized DataFrame + assert isinstance(df, pl.DataFrame) + + # Check all expected columns exist + expected_columns = { + "gene_id", + "gene_symbol", + "transcript", + "pli", + "loeuf", + "mean_depth", + "cds_covered_pct", + "quality_flag", + "loeuf_normalized", + } + assert expected_columns.issubset(set(df.columns)) + + # Check we have genes in each category + measured_count = df.filter(pl.col("quality_flag") == "measured").height + incomplete_count = df.filter(pl.col("quality_flag") == "incomplete_coverage").height + no_data_count = df.filter(pl.col("quality_flag") == "no_data").height + + assert measured_count > 0, "Should have some measured genes" + assert incomplete_count > 0, "Should have some incomplete_coverage genes" + assert no_data_count > 0, "Should have some no_data genes" + + +def test_constraint_record_model_validation(): + """ConstraintRecord validates correctly, rejects bad types.""" + # Valid record + valid = ConstraintRecord( + gene_id="ENSG00000001", + gene_symbol="GENE1", + transcript="ENST00000001", + pli=0.95, + loeuf=0.15, + loeuf_upper=0.20, + mean_depth=45.0, + cds_covered_pct=0.98, + quality_flag="measured", + loeuf_normalized=0.85, + ) + assert valid.gene_id == "ENSG00000001" + assert valid.loeuf_normalized == 0.85 + + # NULL values are OK + with_nulls = ConstraintRecord( + gene_id="ENSG00000002", + gene_symbol="GENE2", + transcript="ENST00000002", + pli=None, + loeuf=None, + quality_flag="no_data", + loeuf_normalized=None, + ) + assert with_nulls.pli is None + assert with_nulls.loeuf is None + assert with_nulls.loeuf_normalized is None + + # Invalid type should raise ValidationError + with pytest.raises(Exception): # Pydantic ValidationError + ConstraintRecord( + gene_id=12345, # Should be string + gene_symbol="GENE3", + transcript="ENST00000003", + ) + + +@patch("usher_pipeline.evidence.gnomad.fetch.httpx.stream") +def test_download_skips_if_exists(mock_stream: Mock, tmp_path: Path): + """download_constraint_metrics returns early if file exists and force=False.""" + output_path = tmp_path / "constraint.tsv" + + # Create existing file + output_path.write_text("gene\ttranscript\npli\tloeuf\n") + + # Call download with force=False + result = download_constraint_metrics(output_path, force=False) + + # Should return early without making HTTP request + assert result == output_path + mock_stream.assert_not_called() + + +@patch("usher_pipeline.evidence.gnomad.fetch.httpx.stream") +def test_download_forces_redownload(mock_stream: Mock, tmp_path: Path): + """download_constraint_metrics re-downloads when force=True.""" + output_path = tmp_path / "constraint.tsv" + + # Create existing file + output_path.write_text("old content") + + # Mock HTTP response + mock_response = Mock() + mock_response.headers = {"content-length": "100"} + mock_response.iter_bytes = Mock(return_value=[b"gene\ttranscript\n", b"data\n"]) + mock_response.raise_for_status = Mock() + mock_stream.return_value.__enter__.return_value = mock_response + + # Call download with force=True + result = download_constraint_metrics(output_path, force=True) + + # Should make HTTP request + assert result == output_path + mock_stream.assert_called_once() + + +def test_filter_by_coverage_handles_missing_columns(tmp_path: Path): + """filter_by_coverage handles genes with missing mean_depth or cds_covered_pct.""" + tsv_path = tmp_path / "partial.tsv" + content = """gene\ttranscript\tgene_symbol\tlof.pLI\tlof.oe_ci.upper\tmean_coverage\tmean_proportion_covered +ENSG00000001\tENST00000001\tGENE1\t0.95\t0.15\t.\t. +""" + tsv_path.write_text(content) + + lf = parse_constraint_tsv(tsv_path) + lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9) + df = lf.collect() + + # GENE1 has NULL depth/coverage but has LOEUF -> should be incomplete_coverage + gene1 = df.filter(pl.col("gene_symbol") == "GENE1") + # With NULL depth/coverage, comparisons will be false, so it goes to incomplete_coverage + assert gene1["quality_flag"][0] in ["incomplete_coverage", "no_data"]