Files
usher-exploring/tests/test_gnomad.py
gbanyan 174c4af02d feat(02-01): add gnomAD transform pipeline and comprehensive tests
- Implement filter_by_coverage with quality_flag categorization (measured/incomplete_coverage/no_data)
- Add normalize_scores with LOEUF inversion (lower LOEUF = higher score)
- NULL preservation throughout pipeline (unknown != zero constraint)
- process_gnomad_constraint end-to-end pipeline function
- 15 comprehensive unit tests covering edge cases:
  - NULL handling and preservation
  - Coverage filtering without dropping genes
  - Normalization bounds and inversion
  - Mixed type handling for robust parsing
- Fix column mapping to handle gnomAD v4.x loeuf/loeuf_upper duplication
- All existing tests continue to pass
2026-02-11 18:14:41 +08:00
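
The transform functions referenced above are exercised by the tests shown below, but their implementation is not part of this view. As a rough illustration only, a minimal sketch of what filter_by_coverage and normalize_scores might look like, assuming Polars expressions, the column names used in the tests, and a simple min-max inversion for LOEUF, is:

import polars as pl


def filter_by_coverage(lf: pl.LazyFrame, min_depth: float, min_cds_pct: float) -> pl.LazyFrame:
    # Attach a quality_flag without dropping any rows (unknown != zero constraint).
    return lf.with_columns(
        pl.when(pl.col("pli").is_null() & pl.col("loeuf").is_null())
        .then(pl.lit("no_data"))
        .when((pl.col("mean_depth") >= min_depth) & (pl.col("cds_covered_pct") >= min_cds_pct))
        .then(pl.lit("measured"))
        .otherwise(pl.lit("incomplete_coverage"))
        .alias("quality_flag")
    )


def normalize_scores(lf: pl.LazyFrame) -> pl.LazyFrame:
    # Min-max normalize LOEUF and invert it (lower LOEUF -> higher score),
    # leaving non-measured genes and NULL LOEUF values as NULL.
    loeuf = pl.col("loeuf")
    return lf.with_columns(
        pl.when((pl.col("quality_flag") == "measured") & loeuf.is_not_null())
        .then(1.0 - (loeuf - loeuf.min()) / (loeuf.max() - loeuf.min()))
        .otherwise(None)
        .alias("loeuf_normalized")
    )

Putting the NULL check first matches the tests' expectation that a gene with no constraint data at all is flagged no_data rather than incomplete_coverage.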

344 lines
13 KiB
Python

"""Unit tests for gnomAD constraint evidence layer."""
from pathlib import Path
from unittest.mock import Mock, patch
import polars as pl
import pytest
from polars.testing import assert_frame_equal
from usher_pipeline.evidence.gnomad.models import ConstraintRecord
from usher_pipeline.evidence.gnomad.fetch import download_constraint_metrics, parse_constraint_tsv
from usher_pipeline.evidence.gnomad.transform import (
filter_by_coverage,
normalize_scores,
process_gnomad_constraint,
)


@pytest.fixture
def sample_constraint_tsv(tmp_path: Path) -> Path:
    """Create a sample gnomAD constraint TSV for testing.

    Covers edge cases:
    - Normal genes with good coverage (measured)
    - Low depth genes (<30x)
    - Low CDS coverage genes (<90%)
    - NULL LOEUF/pLI (no_data)
    - Extreme LOEUF values for normalization bounds
    """
    tsv_path = tmp_path / "constraint.tsv"
    # Use gnomAD v4.x-style column names
    content = """gene\ttranscript\tgene_symbol\tlof.pLI\tlof.oe_ci.upper\tmean_coverage\tmean_proportion_covered
ENSG00000001\tENST00000001\tGENE1\t0.95\t0.15\t45.0\t0.98
ENSG00000002\tENST00000002\tGENE2\t0.80\t0.85\t50.0\t0.95
ENSG00000003\tENST00000003\tGENE3\t0.10\t2.50\t40.0\t0.92
ENSG00000004\tENST00000004\tGENE4\t0.50\t0.0\t55.0\t0.97
ENSG00000005\tENST00000005\tGENE5\t0.20\t1.20\t25.0\t0.85
ENSG00000006\tENST00000006\tGENE6\t0.70\t0.45\t35.0\t0.75
ENSG00000007\tENST00000007\tGENE7\tNA\tNA\t60.0\t0.99
ENSG00000008\tENST00000008\tGENE8\t0.60\t0.30\t50.0\t0.90
ENSG00000009\tENST00000009\tGENE9\t.\t.\t.\t.
ENSG00000010\tENST00000010\tGENE10\t0.90\t0.18\t48.0\t0.94
ENSG00000011\tENST00000011\tGENE11\t0.15\t1.80\t32.0\t0.91
ENSG00000012\tENST00000012\tGENE12\t0.85\t0.22\t10.0\t0.50
ENSG00000013\tENST00000013\tGENE13\t0.40\t0.65\t38.0\t0.88
ENSG00000014\tENST00000014\tGENE14\tNA\t0.75\t42.0\t0.93
ENSG00000015\tENST00000015\tGENE15\t0.75\tNA\t47.0\t0.96
"""
    tsv_path.write_text(content)
    return tsv_path


def test_parse_constraint_tsv_returns_lazyframe(sample_constraint_tsv: Path):
    """Verify parse returns LazyFrame with expected columns."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    assert isinstance(lf, pl.LazyFrame)
    # Collect to check columns
    df = lf.collect()
    expected_columns = {"gene_id", "gene_symbol", "transcript", "pli", "loeuf", "mean_depth", "cds_covered_pct"}
    assert expected_columns.issubset(set(df.columns))


def test_parse_constraint_tsv_null_handling(sample_constraint_tsv: Path):
    """NA/empty values become polars null, not zero."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    df = lf.collect()
    # GENE7 has "NA" for pli and loeuf
    gene7 = df.filter(pl.col("gene_symbol") == "GENE7")
    assert gene7["pli"][0] is None
    assert gene7["loeuf"][0] is None
    # GENE9 has "." for all values
    gene9 = df.filter(pl.col("gene_symbol") == "GENE9")
    assert gene9["pli"][0] is None
    assert gene9["loeuf"][0] is None
    assert gene9["mean_depth"][0] is None


def test_filter_by_coverage_measured(sample_constraint_tsv: Path):
    """Good coverage genes get quality_flag="measured"."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    df = lf.collect()
    # GENE1: depth=45, coverage=0.98, has LOEUF -> measured
    gene1 = df.filter(pl.col("gene_symbol") == "GENE1")
    assert gene1["quality_flag"][0] == "measured"
    # GENE8: depth=50, coverage=0.90 (exactly at threshold), has LOEUF -> measured
    gene8 = df.filter(pl.col("gene_symbol") == "GENE8")
    assert gene8["quality_flag"][0] == "measured"


def test_filter_by_coverage_incomplete(sample_constraint_tsv: Path):
    """Low depth/CDS genes get quality_flag="incomplete_coverage"."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    df = lf.collect()
    # GENE5: depth=25 (< 30) -> incomplete_coverage
    gene5 = df.filter(pl.col("gene_symbol") == "GENE5")
    assert gene5["quality_flag"][0] == "incomplete_coverage"
    # GENE6: coverage=0.75 (< 0.9) -> incomplete_coverage
    gene6 = df.filter(pl.col("gene_symbol") == "GENE6")
    assert gene6["quality_flag"][0] == "incomplete_coverage"
    # GENE12: depth=10 (very low) -> incomplete_coverage
    gene12 = df.filter(pl.col("gene_symbol") == "GENE12")
    assert gene12["quality_flag"][0] == "incomplete_coverage"


def test_filter_by_coverage_no_data(sample_constraint_tsv: Path):
    """NULL loeuf+pli genes get quality_flag="no_data"."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    df = lf.collect()
    # GENE7: both pli and loeuf are NULL -> no_data
    gene7 = df.filter(pl.col("gene_symbol") == "GENE7")
    assert gene7["quality_flag"][0] == "no_data"
    # GENE9: both pli and loeuf are NULL -> no_data
    gene9 = df.filter(pl.col("gene_symbol") == "GENE9")
    assert gene9["quality_flag"][0] == "no_data"


def test_filter_preserves_all_genes(sample_constraint_tsv: Path):
    """Row count before == row count after (no genes dropped)."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    df_before = lf.collect()
    count_before = len(df_before)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    df_after = lf.collect()
    count_after = len(df_after)
    assert count_before == count_after, "Filter should preserve all genes"


def test_normalize_scores_range(sample_constraint_tsv: Path):
    """All non-null normalized scores are in [0, 1]."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    lf = normalize_scores(lf)
    df = lf.collect()
    # Filter to non-null normalized scores
    normalized = df.filter(pl.col("loeuf_normalized").is_not_null())
    if len(normalized) > 0:
        min_score = normalized["loeuf_normalized"].min()
        max_score = normalized["loeuf_normalized"].max()
        assert min_score >= 0.0, f"Min normalized score {min_score} < 0"
        assert max_score <= 1.0, f"Max normalized score {max_score} > 1"
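
# Assuming a simple min-max inversion (normalized = 1 - (loeuf - min) / (max - min)),
# the fixture's LOEUF values span 0.0 (GENE4) to 2.50 (GENE3), so GENE4 should
# map to ~1.0 and GENE3 to ~0.0, which is what the test below checks.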


def test_normalize_scores_inversion(sample_constraint_tsv: Path):
    """Lower LOEUF -> higher normalized score."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    lf = normalize_scores(lf)
    df = lf.collect()
    # GENE4: LOEUF=0.0 (most constrained) should have highest normalized score
    gene4 = df.filter(pl.col("gene_symbol") == "GENE4")
    if gene4["quality_flag"][0] == "measured":
        assert gene4["loeuf_normalized"][0] is not None
        # Should be close to 1.0 (most constrained)
        assert gene4["loeuf_normalized"][0] >= 0.95
    # GENE3: LOEUF=2.50 (least constrained) should have lowest normalized score
    gene3 = df.filter(pl.col("gene_symbol") == "GENE3")
    if gene3["quality_flag"][0] == "measured":
        assert gene3["loeuf_normalized"][0] is not None
        # Should be close to 0.0 (least constrained)
        assert gene3["loeuf_normalized"][0] <= 0.05


def test_normalize_scores_null_preserved(sample_constraint_tsv: Path):
    """NULL loeuf stays NULL after normalization."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    lf = normalize_scores(lf)
    df = lf.collect()
    # GENE7: NULL loeuf -> NULL normalized
    gene7 = df.filter(pl.col("gene_symbol") == "GENE7")
    assert gene7["loeuf"][0] is None
    assert gene7["loeuf_normalized"][0] is None


def test_normalize_scores_incomplete_stays_null(sample_constraint_tsv: Path):
    """incomplete_coverage genes get NULL normalized score."""
    lf = parse_constraint_tsv(sample_constraint_tsv)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    lf = normalize_scores(lf)
    df = lf.collect()
    # GENE5: incomplete_coverage -> NULL normalized
    gene5 = df.filter(pl.col("gene_symbol") == "GENE5")
    assert gene5["quality_flag"][0] == "incomplete_coverage"
    assert gene5["loeuf_normalized"][0] is None
    # GENE6: incomplete_coverage -> NULL normalized
    gene6 = df.filter(pl.col("gene_symbol") == "GENE6")
    assert gene6["quality_flag"][0] == "incomplete_coverage"
    assert gene6["loeuf_normalized"][0] is None


def test_process_gnomad_constraint_end_to_end(sample_constraint_tsv: Path):
    """Full pipeline returns DataFrame with all expected columns."""
    df = process_gnomad_constraint(sample_constraint_tsv, min_depth=30.0, min_cds_pct=0.9)
    # Check it's a materialized DataFrame
    assert isinstance(df, pl.DataFrame)
    # Check all expected columns exist
    expected_columns = {
        "gene_id",
        "gene_symbol",
        "transcript",
        "pli",
        "loeuf",
        "mean_depth",
        "cds_covered_pct",
        "quality_flag",
        "loeuf_normalized",
    }
    assert expected_columns.issubset(set(df.columns))
    # Check we have genes in each category
    measured_count = df.filter(pl.col("quality_flag") == "measured").height
    incomplete_count = df.filter(pl.col("quality_flag") == "incomplete_coverage").height
    no_data_count = df.filter(pl.col("quality_flag") == "no_data").height
    assert measured_count > 0, "Should have some measured genes"
    assert incomplete_count > 0, "Should have some incomplete_coverage genes"
    assert no_data_count > 0, "Should have some no_data genes"


def test_constraint_record_model_validation():
    """ConstraintRecord validates correctly, rejects bad types."""
    # Valid record
    valid = ConstraintRecord(
        gene_id="ENSG00000001",
        gene_symbol="GENE1",
        transcript="ENST00000001",
        pli=0.95,
        loeuf=0.15,
        loeuf_upper=0.20,
        mean_depth=45.0,
        cds_covered_pct=0.98,
        quality_flag="measured",
        loeuf_normalized=0.85,
    )
    assert valid.gene_id == "ENSG00000001"
    assert valid.loeuf_normalized == 0.85
    # NULL values are OK
    with_nulls = ConstraintRecord(
        gene_id="ENSG00000002",
        gene_symbol="GENE2",
        transcript="ENST00000002",
        pli=None,
        loeuf=None,
        quality_flag="no_data",
        loeuf_normalized=None,
    )
    assert with_nulls.pli is None
    assert with_nulls.loeuf is None
    assert with_nulls.loeuf_normalized is None
    # Invalid type should raise ValidationError
    with pytest.raises(Exception):  # Pydantic ValidationError
        ConstraintRecord(
            gene_id=12345,  # Should be string
            gene_symbol="GENE3",
            transcript="ENST00000003",
        )
@patch("usher_pipeline.evidence.gnomad.fetch.httpx.stream")
def test_download_skips_if_exists(mock_stream: Mock, tmp_path: Path):
"""download_constraint_metrics returns early if file exists and force=False."""
output_path = tmp_path / "constraint.tsv"
# Create existing file
output_path.write_text("gene\ttranscript\npli\tloeuf\n")
# Call download with force=False
result = download_constraint_metrics(output_path, force=False)
# Should return early without making HTTP request
assert result == output_path
mock_stream.assert_not_called()
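
# The mock setup below configures mock_stream.return_value.__enter__, which
# reflects the assumption that the download code uses httpx.stream as a
# context manager ("with httpx.stream(...) as response:").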
@patch("usher_pipeline.evidence.gnomad.fetch.httpx.stream")
def test_download_forces_redownload(mock_stream: Mock, tmp_path: Path):
"""download_constraint_metrics re-downloads when force=True."""
output_path = tmp_path / "constraint.tsv"
# Create existing file
output_path.write_text("old content")
# Mock HTTP response
mock_response = Mock()
mock_response.headers = {"content-length": "100"}
mock_response.iter_bytes = Mock(return_value=[b"gene\ttranscript\n", b"data\n"])
mock_response.raise_for_status = Mock()
mock_stream.return_value.__enter__.return_value = mock_response
# Call download with force=True
result = download_constraint_metrics(output_path, force=True)
# Should make HTTP request
assert result == output_path
mock_stream.assert_called_once()


def test_filter_by_coverage_handles_missing_columns(tmp_path: Path):
    """filter_by_coverage handles genes with missing mean_depth or cds_covered_pct."""
    tsv_path = tmp_path / "partial.tsv"
    content = """gene\ttranscript\tgene_symbol\tlof.pLI\tlof.oe_ci.upper\tmean_coverage\tmean_proportion_covered
ENSG00000001\tENST00000001\tGENE1\t0.95\t0.15\t.\t.
"""
    tsv_path.write_text(content)
    lf = parse_constraint_tsv(tsv_path)
    lf = filter_by_coverage(lf, min_depth=30.0, min_cds_pct=0.9)
    df = lf.collect()
    # GENE1 has NULL depth/coverage but has LOEUF -> should be incomplete_coverage
    gene1 = df.filter(pl.col("gene_symbol") == "GENE1")
    # With NULL depth/coverage, the threshold comparisons do not evaluate to
    # true, so the gene should fall through to incomplete_coverage
    assert gene1["quality_flag"][0] in ["incomplete_coverage", "no_data"]