feat(03-01): implement annotation evidence fetch and transform modules

- Create AnnotationRecord model with GO counts, UniProt scores, tier classification
- Implement fetch_go_annotations using mygene.info batch queries
- Implement fetch_uniprot_scores using UniProt REST API
- Add classify_annotation_tier with 3-tier system (well/partial/poor)
- Add normalize_annotation_score with weighted composite (GO 50%, UniProt 30%, Pathway 20%); a worked example follows below
- Implement process_annotation_evidence end-to-end pipeline
- Follow NULL preservation pattern from gnomAD (unknown != zero)
- Use lazy polars evaluation where applicable
2026-02-11 18:58:45 +08:00
parent 0d252da348
commit adbb74b965
4 changed files with 563 additions and 0 deletions
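
To make the weighted composite concrete, a worked example with invented numbers (a hypothetical gene with 20 GO terms, UniProt score 4, in a pathway; the dataset maximum of log2(go_term_count + 1) is assumed to be log2(51)):

go_component      = (log2(21) / log2(51)) * 0.5   # (4.392 / 5.672) * 0.5 ~= 0.387
uniprot_component = (4 / 5.0) * 0.3               # 0.240
pathway_component = 1 * 0.2                       # 0.200
composite         = 0.387 + 0.240 + 0.200         # ~= 0.827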

usher_pipeline/evidence/annotation/__init__.py

@@ -0,0 +1,22 @@
"""Gene annotation completeness evidence layer."""
from usher_pipeline.evidence.annotation.models import AnnotationRecord, ANNOTATION_TABLE_NAME
from usher_pipeline.evidence.annotation.fetch import (
fetch_go_annotations,
fetch_uniprot_scores,
)
from usher_pipeline.evidence.annotation.transform import (
classify_annotation_tier,
normalize_annotation_score,
process_annotation_evidence,
)
__all__ = [
"AnnotationRecord",
"ANNOTATION_TABLE_NAME",
"fetch_go_annotations",
"fetch_uniprot_scores",
"classify_annotation_tier",
"normalize_annotation_score",
"process_annotation_evidence",
]
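
# Example (illustrative, not part of the module): the package surface in use.
# The gene ID below is a made-up placeholder.
from usher_pipeline.evidence.annotation import fetch_go_annotations

go_df = fetch_go_annotations(["ENSG00000012345"])
print(go_df.select(["gene_id", "go_term_count"]))  # NULL counts where mygene has no GO data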

usher_pipeline/evidence/annotation/fetch.py

@@ -0,0 +1,290 @@
"""Fetch gene annotation data from mygene.info and UniProt APIs."""
from typing import Optional
import math
import httpx
import mygene
import polars as pl
import structlog
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
logger = structlog.get_logger()
# Initialize mygene client (singleton pattern - reuse across calls)
_mg_client = None


def _get_mygene_client() -> mygene.MyGeneInfo:
    """Get or create mygene client singleton."""
    global _mg_client
    if _mg_client is None:
        _mg_client = mygene.MyGeneInfo()
    return _mg_client
def fetch_go_annotations(gene_ids: list[str], batch_size: int = 1000) -> pl.DataFrame:
    """Fetch GO annotations and pathway memberships from mygene.info.

    Uses mygene.querymany to batch-query GO terms and pathway data.
    Processes in batches to avoid API timeouts.

    Args:
        gene_ids: List of Ensembl gene IDs
        batch_size: Number of genes per batch query (default: 1000)

    Returns:
        DataFrame with columns:
        - gene_id: Ensembl gene ID
        - gene_symbol: HGNC symbol (NULL if not found)
        - go_term_count: Total GO term count across all ontologies (NULL if no GO data)
        - go_biological_process_count: GO BP term count (NULL if no GO data)
        - go_molecular_function_count: GO MF term count (NULL if no GO data)
        - go_cellular_component_count: GO CC term count (NULL if no GO data)
        - has_pathway_membership: Boolean indicating presence in KEGG/Reactome
          (NULL if no pathway data)

    Note: Genes with no GO annotations get NULL counts (not zero).
    """
    logger.info("fetch_go_annotations_start", gene_count=len(gene_ids))
    mg = _get_mygene_client()
    all_results = []

    # Process in batches to avoid mygene timeout
    num_batches = math.ceil(len(gene_ids) / batch_size)
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(gene_ids))
        batch = gene_ids[start_idx:end_idx]
        logger.info(
            "fetch_go_batch",
            batch_num=i + 1,
            total_batches=num_batches,
            batch_size=len(batch),
        )
        # Query mygene for GO terms, pathways, and symbol
        try:
            results = mg.querymany(
                batch,
                scopes="ensembl.gene",
                fields="go,pathway.kegg,pathway.reactome,symbol",
                species="human",
                returnall=False,
            )
            # Process each gene's result
            for result in results:
                gene_id = result.get("query")
                gene_symbol = result.get("symbol", None)

                # Extract GO term counts by category
                go_data = result.get("go", {})
                if isinstance(go_data, dict):
                    # Count GO terms by ontology
                    bp_terms = go_data.get("BP", [])
                    mf_terms = go_data.get("MF", [])
                    cc_terms = go_data.get("CC", [])
                    # Convert to list if single dict (mygene sometimes
                    # returns a dict for a single term)
                    bp_list = bp_terms if isinstance(bp_terms, list) else ([bp_terms] if bp_terms else [])
                    mf_list = mf_terms if isinstance(mf_terms, list) else ([mf_terms] if mf_terms else [])
                    cc_list = cc_terms if isinstance(cc_terms, list) else ([cc_terms] if cc_terms else [])
                    bp_count = len(bp_list) if bp_list else None
                    mf_count = len(mf_list) if mf_list else None
                    cc_count = len(cc_list) if cc_list else None
                    # Total GO count (sum of non-NULL counts, or NULL if all NULL)
                    counts = [c for c in [bp_count, mf_count, cc_count] if c is not None]
                    total_count = sum(counts) if counts else None
                else:
                    # No GO data
                    bp_count = None
                    mf_count = None
                    cc_count = None
                    total_count = None

                # Check pathway membership: True/False when pathway data exists,
                # NULL when there is no pathway field at all
                pathway_data = result.get("pathway", {})
                has_kegg = bool(pathway_data.get("kegg"))
                has_reactome = bool(pathway_data.get("reactome"))
                has_pathway = (has_kegg or has_reactome) if (has_kegg or has_reactome or pathway_data) else None

                all_results.append({
                    "gene_id": gene_id,
                    "gene_symbol": gene_symbol,
                    "go_term_count": total_count,
                    "go_biological_process_count": bp_count,
                    "go_molecular_function_count": mf_count,
                    "go_cellular_component_count": cc_count,
                    "has_pathway_membership": has_pathway,
                })
        except Exception as e:
            logger.warning(
                "fetch_go_batch_error",
                batch_num=i + 1,
                error=str(e),
            )
            # Add NULL entries for the failed batch
            for gene_id in batch:
                all_results.append({
                    "gene_id": gene_id,
                    "gene_symbol": None,
                    "go_term_count": None,
                    "go_biological_process_count": None,
                    "go_molecular_function_count": None,
                    "go_cellular_component_count": None,
                    "has_pathway_membership": None,
                })

    logger.info("fetch_go_annotations_complete", result_count=len(all_results))
    return pl.DataFrame(all_results)
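
# Example (illustrative, not part of the module): a hypothetical mygene hit
# and the row the loop above would derive from it. All values are invented.
#
#   result = {
#       "query": "ENSG00000012345",
#       "symbol": "GENE1",
#       "go": {
#           "BP": {"id": "GO:0008150"},                          # single term arrives as a dict
#           "MF": [{"id": "GO:0003674"}, {"id": "GO:0005515"}],
#           # no "CC" key -> cc_count stays NULL
#       },
#       "pathway": {"kegg": {"id": "hsa04512"}},
#   }
#
# Parsed row: bp_count=1, mf_count=2, cc_count=None,
#             go_term_count=3, has_pathway_membership=True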
@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception_type(
        (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)
    ),
)
def _query_uniprot_batch(accessions: list[str]) -> dict:
    """Query UniProt REST API for annotation scores (with retry).

    Args:
        accessions: List of UniProt accession IDs (max 100)

    Returns:
        Dict mapping accession -> annotation_score
    """
    if not accessions:
        return {}

    # Build OR query for batch lookup
    query = " OR ".join([f"accession:{acc}" for acc in accessions])
    url = "https://rest.uniprot.org/uniprotkb/search"
    params = {
        "query": query,
        "fields": "accession,annotation_score",
        "format": "json",
        "size": len(accessions),
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.get(url, params=params)
        response.raise_for_status()
        data = response.json()

    # Parse results into mapping
    score_map = {}
    for entry in data.get("results", []):
        accession = entry.get("primaryAccession")
        score = entry.get("annotationScore")
        if accession and score is not None:
            score_map[accession] = int(score)
    return score_map
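
# Example (illustrative, not part of the module): what a two-accession call
# sends and might return. Accessions and scores are invented.
#
#   _query_uniprot_batch(["P12345", "Q67890"])
#   -> query string: "accession:P12345 OR accession:Q67890"
#   -> GET https://rest.uniprot.org/uniprotkb/search
#        ?query=...&fields=accession,annotation_score&format=json&size=2
#   -> {"P12345": 5, "Q67890": 3}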
def fetch_uniprot_scores(
    gene_ids: list[str],
    uniprot_mapping: pl.DataFrame,
    batch_size: int = 100,
) -> pl.DataFrame:
    """Fetch UniProt annotation scores for genes.

    Uses the UniProt REST API to query annotation scores in batches.
    Transient failures are retried with exponential backoff via tenacity.

    Args:
        gene_ids: List of Ensembl gene IDs
        uniprot_mapping: DataFrame with gene_id and uniprot_accession columns
        batch_size: Number of UniProt accessions per batch (default: 100)

    Returns:
        DataFrame with columns:
        - gene_id: Ensembl gene ID
        - uniprot_annotation_score: UniProt annotation score 1-5 (NULL if no mapping/score)

    Note: Genes without UniProt mapping get NULL (not zero).
    """
    logger.info("fetch_uniprot_scores_start", gene_count=len(gene_ids))

    # Filter mapping to requested genes
    mapping_filtered = uniprot_mapping.filter(pl.col("gene_id").is_in(gene_ids))
    if mapping_filtered.height == 0:
        logger.warning("fetch_uniprot_no_mappings")
        # Return all genes with NULL scores (explicit schema keeps dtypes
        # consistent with the normal path)
        return pl.DataFrame(
            {
                "gene_id": gene_ids,
                "uniprot_annotation_score": [None] * len(gene_ids),
            },
            schema={"gene_id": pl.Utf8, "uniprot_annotation_score": pl.Int64},
        )

    # Get unique accessions
    accessions = mapping_filtered.select("uniprot_accession").unique().to_series().to_list()
    logger.info("fetch_uniprot_accessions", accession_count=len(accessions))

    # Batch query UniProt API
    all_scores = {}
    num_batches = math.ceil(len(accessions) / batch_size)
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(accessions))
        batch = accessions[start_idx:end_idx]
        logger.info(
            "fetch_uniprot_batch",
            batch_num=i + 1,
            total_batches=num_batches,
            batch_size=len(batch),
        )
        try:
            batch_scores = _query_uniprot_batch(batch)
            all_scores.update(batch_scores)
        except Exception as e:
            logger.warning(
                "fetch_uniprot_batch_error",
                batch_num=i + 1,
                error=str(e),
            )
            # Continue with other batches - failed batch will have NULL scores

    # Create accession -> score mapping (explicit schema so the join below
    # still works when every batch failed and all_scores is empty)
    score_df = pl.DataFrame(
        {
            "uniprot_accession": list(all_scores.keys()),
            "uniprot_annotation_score": list(all_scores.values()),
        },
        schema={"uniprot_accession": pl.Utf8, "uniprot_annotation_score": pl.Int64},
    )

    # Join back to gene IDs
    result = (
        mapping_filtered
        .select(["gene_id", "uniprot_accession"])
        .join(score_df, on="uniprot_accession", how="left")
        .group_by("gene_id")
        .agg(
            # Take first score if multiple accessions (consistent with
            # gene universe pattern)
            pl.col("uniprot_annotation_score").first()
        )
    )

    # Ensure all requested genes are present (add NULL for missing)
    all_genes = pl.DataFrame({"gene_id": gene_ids})
    result = all_genes.join(result, on="gene_id", how="left")

    logger.info("fetch_uniprot_scores_complete", result_count=result.height)
    return result
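
# Example (illustrative, not part of the module): expected mapping input.
# Gene IDs and accessions are placeholders.
mapping = pl.DataFrame({
    "gene_id": ["ENSG00000012345", "ENSG00000067890"],
    "uniprot_accession": ["P12345", "Q67890"],
})
scores = fetch_uniprot_scores(
    ["ENSG00000012345", "ENSG00000067890", "ENSG00000099999"],
    mapping,
)
# The third gene has no mapping row, so its uniprot_annotation_score is NULL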

usher_pipeline/evidence/annotation/models.py

@@ -0,0 +1,38 @@
"""Data models for gene annotation completeness evidence."""
from pydantic import BaseModel
# Table name for DuckDB storage
ANNOTATION_TABLE_NAME = "annotation_completeness"
class AnnotationRecord(BaseModel):
"""Gene annotation completeness metrics for a single gene.
Attributes:
gene_id: Ensembl gene ID (e.g., ENSG00000...)
gene_symbol: HGNC gene symbol
go_term_count: Total number of GO terms (all ontologies) - NULL if no data
go_biological_process_count: Number of GO Biological Process terms - NULL if no data
go_molecular_function_count: Number of GO Molecular Function terms - NULL if no data
go_cellular_component_count: Number of GO Cellular Component terms - NULL if no data
uniprot_annotation_score: UniProt annotation score 1-5 - NULL if no mapping or score
has_pathway_membership: Present in any KEGG/Reactome pathway - NULL if no data
annotation_tier: Classification: "well_annotated", "partially_annotated", "poorly_annotated"
annotation_score_normalized: Composite annotation score 0-1 (higher = better annotated) - NULL if all inputs NULL
CRITICAL: NULL values represent missing data and are preserved as None.
Do NOT convert NULL to 0 - "unknown annotation" is semantically different from "zero annotation".
Conservative approach: NULL GO counts treated as zero for tier classification (assume unannotated).
"""
gene_id: str
gene_symbol: str
go_term_count: int | None = None
go_biological_process_count: int | None = None
go_molecular_function_count: int | None = None
go_cellular_component_count: int | None = None
uniprot_annotation_score: int | None = None
has_pathway_membership: bool | None = None
annotation_tier: str = "poorly_annotated"
annotation_score_normalized: float | None = None
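
# Example (illustrative, not part of the module): validating one pipeline row.
# All values are invented; unset counts default to None.
record = AnnotationRecord(
    gene_id="ENSG00000012345",
    gene_symbol="GENE1",
    go_term_count=12,
    go_biological_process_count=7,
    go_molecular_function_count=5,
    uniprot_annotation_score=3,
    has_pathway_membership=None,  # unknown, deliberately not False
    annotation_tier="partially_annotated",
    annotation_score_normalized=0.41,
)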

usher_pipeline/evidence/annotation/transform.py

@@ -0,0 +1,213 @@
"""Transform and normalize gene annotation completeness metrics."""
import math
from pathlib import Path
import polars as pl
import structlog
from usher_pipeline.evidence.annotation.fetch import (
fetch_go_annotations,
fetch_uniprot_scores,
)
logger = structlog.get_logger()
def classify_annotation_tier(df: pl.DataFrame) -> pl.DataFrame:
    """Classify genes into annotation tiers based on composite metrics.

    Tier definitions:
    - "well_annotated": go_term_count >= 20 AND uniprot_annotation_score >= 4
    - "partially_annotated": go_term_count >= 5 OR uniprot_annotation_score >= 3
    - "poorly_annotated": everything else (including NULLs)

    Conservative approach: NULL GO counts are treated as zero for tier
    classification (assume unannotated until proven otherwise).

    Args:
        df: DataFrame with go_term_count and uniprot_annotation_score columns

    Returns:
        DataFrame with annotation_tier column added
    """
    logger.info("classify_annotation_tier_start", row_count=df.height)

    # Fill NULL counts/scores with 0 for tier classification (conservative),
    # but keep the original columns intact for downstream NULL handling
    df = df.with_columns([
        pl.col("go_term_count").fill_null(0).alias("_go_count_filled"),
        pl.col("uniprot_annotation_score").fill_null(0).alias("_uniprot_score_filled"),
    ])

    # Apply tier classification logic
    df = df.with_columns(
        pl.when(
            (pl.col("_go_count_filled") >= 20) & (pl.col("_uniprot_score_filled") >= 4)
        )
        .then(pl.lit("well_annotated"))
        .when(
            (pl.col("_go_count_filled") >= 5) | (pl.col("_uniprot_score_filled") >= 3)
        )
        .then(pl.lit("partially_annotated"))
        .otherwise(pl.lit("poorly_annotated"))
        .alias("annotation_tier")
    )

    # Drop temporary filled columns
    df = df.drop(["_go_count_filled", "_uniprot_score_filled"])

    # Log tier distribution
    tier_counts = df.group_by("annotation_tier").len().sort("annotation_tier")
    logger.info("classify_annotation_tier_complete", tier_distribution=tier_counts.to_dicts())
    return df
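
# Example (illustrative, not part of the module): tier boundaries on a toy
# frame, with values chosen to hit each branch.
toy = pl.DataFrame(
    {
        "go_term_count": [25, 10, None],
        "uniprot_annotation_score": [4, None, None],
    },
    schema={"go_term_count": pl.Int64, "uniprot_annotation_score": pl.Int64},
)
classify_annotation_tier(toy)["annotation_tier"].to_list()
# -> ["well_annotated", "partially_annotated", "poorly_annotated"]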
def normalize_annotation_score(df: pl.DataFrame) -> pl.DataFrame:
    """Compute normalized composite annotation score (0-1 range).

    Formula: weighted sum of three components:
    - GO component (50%): log2(go_term_count + 1) normalized by the dataset max
    - UniProt component (30%): uniprot_annotation_score / 5.0
    - Pathway component (20%): has_pathway_membership as 0/1

    Result is clamped to [0, 1]. NULL if ALL three inputs are NULL.

    Args:
        df: DataFrame with go_term_count, uniprot_annotation_score,
            has_pathway_membership columns

    Returns:
        DataFrame with annotation_score_normalized column added
    """
    logger.info("normalize_annotation_score_start", row_count=df.height)

    # Component weights
    WEIGHT_GO = 0.5
    WEIGHT_UNIPROT = 0.3
    WEIGHT_PATHWAY = 0.2

    # Compute GO component: log2(count + 1) normalized by max
    df = df.with_columns(
        pl.when(pl.col("go_term_count").is_not_null())
        .then((pl.col("go_term_count") + 1).log(base=2))
        .otherwise(None)
        .alias("_go_log")
    )
    # Get max for normalization (from non-NULL values)
    go_max = df.filter(pl.col("_go_log").is_not_null()).select(pl.col("_go_log").max()).item()
    if go_max is None or go_max == 0:
        # No GO data in dataset - all genes get NULL for the GO component
        df = df.with_columns(pl.lit(None).cast(pl.Float64).alias("_go_component"))
    else:
        df = df.with_columns(
            pl.when(pl.col("_go_log").is_not_null())
            .then((pl.col("_go_log") / go_max) * WEIGHT_GO)
            .otherwise(None)
            .alias("_go_component")
        )

    # Compute UniProt component: score / 5.0
    df = df.with_columns(
        pl.when(pl.col("uniprot_annotation_score").is_not_null())
        .then((pl.col("uniprot_annotation_score") / 5.0) * WEIGHT_UNIPROT)
        .otherwise(None)
        .alias("_uniprot_component")
    )

    # Compute pathway component: boolean as 0/1
    df = df.with_columns(
        pl.when(pl.col("has_pathway_membership").is_not_null())
        .then(
            pl.when(pl.col("has_pathway_membership"))
            .then(WEIGHT_PATHWAY)
            .otherwise(0.0)
        )
        .otherwise(None)
        .alias("_pathway_component")
    )

    # Composite score: sum of non-NULL components (NULL treated as 0 inside
    # the sum), but a row where ALL components are NULL stays NULL
    df = df.with_columns(
        pl.when(
            pl.col("_go_component").is_not_null()
            | pl.col("_uniprot_component").is_not_null()
            | pl.col("_pathway_component").is_not_null()
        )
        .then(
            pl.col("_go_component").fill_null(0.0)
            + pl.col("_uniprot_component").fill_null(0.0)
            + pl.col("_pathway_component").fill_null(0.0)
        )
        .otherwise(None)
        .alias("annotation_score_normalized")
    )

    # Clamp to [0, 1] (defensive; clip passes NULL through unchanged)
    df = df.with_columns(
        pl.col("annotation_score_normalized").clip(0.0, 1.0)
    )

    # Drop temporary columns
    df = df.drop(["_go_log", "_go_component", "_uniprot_component", "_pathway_component"])

    # Log score statistics (check for valid rows first: aggregating an empty
    # frame would still yield one all-NULL row)
    valid = df.filter(pl.col("annotation_score_normalized").is_not_null())
    if valid.height > 0:
        stats = valid.select([
            pl.col("annotation_score_normalized").mean().alias("mean"),
            pl.col("annotation_score_normalized").median().alias("median"),
            pl.col("annotation_score_normalized").min().alias("min"),
            pl.col("annotation_score_normalized").max().alias("max"),
        ])
        logger.info("normalize_annotation_score_complete", stats=stats.to_dicts()[0])
    else:
        logger.warning("normalize_annotation_score_complete", message="No valid scores computed")
    return df
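
# Example (illustrative, not part of the module): NULL propagation. A row
# whose inputs are all NULL gets a NULL composite, never 0. Values invented.
toy = pl.DataFrame(
    {
        "go_term_count": [15, None],
        "uniprot_annotation_score": [4, None],
        "has_pathway_membership": [True, None],
    },
    schema={
        "go_term_count": pl.Int64,
        "uniprot_annotation_score": pl.Int64,
        "has_pathway_membership": pl.Boolean,
    },
)
normalize_annotation_score(toy)
# Row 1: (log2(16)/log2(16)) * 0.5 + (4/5) * 0.3 + 0.2 = 0.94
# Row 2: NULL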
def process_annotation_evidence(
    gene_ids: list[str],
    uniprot_mapping: pl.DataFrame,
) -> pl.DataFrame:
    """End-to-end annotation evidence processing pipeline.

    Composes: fetch GO -> fetch UniProt -> join -> classify tier -> normalize.

    Args:
        gene_ids: List of Ensembl gene IDs to process
        uniprot_mapping: DataFrame with gene_id and uniprot_accession columns

    Returns:
        Materialized DataFrame with all annotation completeness metrics,
        ready for DuckDB storage
    """
    logger.info("process_annotation_evidence_start", gene_count=len(gene_ids))

    # Fetch GO annotations and pathway memberships
    go_df = fetch_go_annotations(gene_ids)

    # Fetch UniProt annotation scores
    uniprot_df = fetch_uniprot_scores(gene_ids, uniprot_mapping)

    # Join GO and UniProt data
    df = go_df.join(uniprot_df, on="gene_id", how="left")

    # Classify annotation tiers
    df = classify_annotation_tier(df)

    # Normalize composite score
    df = normalize_annotation_score(df)

    logger.info("process_annotation_evidence_complete", result_count=df.height)
    return df
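
A hedged sketch of persisting the result under ANNOTATION_TABLE_NAME; the database path is a placeholder and the pipeline's actual storage layer may differ:

import duckdb
from usher_pipeline.evidence.annotation import ANNOTATION_TABLE_NAME

df = process_annotation_evidence(gene_ids, uniprot_mapping)
con = duckdb.connect("pipeline.duckdb")  # hypothetical path
# DuckDB's replacement scan picks up the in-scope polars DataFrame `df`
con.execute(f"CREATE OR REPLACE TABLE {ANNOTATION_TABLE_NAME} AS SELECT * FROM df")
con.close()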