# Patch feat(03-01): annotation evidence fetch and transform modules.
#
# Reconstructed, properly formatted content of the four new modules carried by
# this patch. Fixes applied relative to the original patch body:
#   1. AnnotationRecord.gene_symbol is Optional -- fetch_go_annotations emits
#      None for unmatched genes, so a required `str` field would fail
#      validation on real data.
#   2. fetch_uniprot_scores builds its accession->score frame with an explicit
#      schema, so the join does not fail on Null-dtype columns when every
#      UniProt API batch fails.
#   3. fetch_go_annotations de-duplicates mygene hits: querymany may return
#      more than one document per query id, which would otherwise produce
#      duplicate gene_id rows and fan out the downstream join.
#   4. Removed unused `typing.Optional` import; dropped a redundant
#      when/otherwise wrapper around clip (clip already propagates nulls).

# ============================================================================
# src/usher_pipeline/evidence/annotation/__init__.py
# ============================================================================
"""Gene annotation completeness evidence layer."""

from usher_pipeline.evidence.annotation.models import (
    AnnotationRecord,
    ANNOTATION_TABLE_NAME,
)
from usher_pipeline.evidence.annotation.fetch import (
    fetch_go_annotations,
    fetch_uniprot_scores,
)
from usher_pipeline.evidence.annotation.transform import (
    classify_annotation_tier,
    normalize_annotation_score,
    process_annotation_evidence,
)

__all__ = [
    "AnnotationRecord",
    "ANNOTATION_TABLE_NAME",
    "fetch_go_annotations",
    "fetch_uniprot_scores",
    "classify_annotation_tier",
    "normalize_annotation_score",
    "process_annotation_evidence",
]

# ============================================================================
# src/usher_pipeline/evidence/annotation/fetch.py
# ============================================================================
"""Fetch gene annotation data from mygene.info and UniProt APIs."""

import math

import httpx
import mygene
import polars as pl
import structlog
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

logger = structlog.get_logger()

# Initialize mygene client (singleton pattern - reuse across calls)
_mg_client = None


def _get_mygene_client() -> mygene.MyGeneInfo:
    """Get or create mygene client singleton."""
    global _mg_client
    if _mg_client is None:
        _mg_client = mygene.MyGeneInfo()
    return _mg_client


def fetch_go_annotations(gene_ids: list[str], batch_size: int = 1000) -> pl.DataFrame:
    """Fetch GO annotations and pathway memberships from mygene.info.

    Uses mygene.querymany to batch query GO terms and pathway data.
    Processes in batches to avoid API timeout. Only the first hit per query
    id is kept, so the result has at most one row per input gene.

    Args:
        gene_ids: List of Ensembl gene IDs
        batch_size: Number of genes per batch query (default: 1000)

    Returns:
        DataFrame with columns:
        - gene_id: Ensembl gene ID
        - gene_symbol: HGNC symbol (NULL if not found)
        - go_term_count: Total GO term count across all ontologies (NULL if no GO data)
        - go_biological_process_count: GO BP term count (NULL if no GO data)
        - go_molecular_function_count: GO MF term count (NULL if no GO data)
        - go_cellular_component_count: GO CC term count (NULL if no GO data)
        - has_pathway_membership: Boolean indicating presence in KEGG/Reactome
          (NULL if no pathway data)

    Note: Genes with no GO annotations get NULL counts (not zero).
    """
    logger.info("fetch_go_annotations_start", gene_count=len(gene_ids))

    mg = _get_mygene_client()
    all_results = []
    # querymany may return multiple hits for one query id; keep the first only
    # so the output stays one-row-per-gene (downstream joins assume this).
    seen_ids: set[str] = set()

    # Process in batches to avoid mygene timeout
    num_batches = math.ceil(len(gene_ids) / batch_size)

    for i in range(num_batches):
        batch = gene_ids[i * batch_size : min((i + 1) * batch_size, len(gene_ids))]

        logger.info(
            "fetch_go_batch",
            batch_num=i + 1,
            total_batches=num_batches,
            batch_size=len(batch),
        )

        # Query mygene for GO terms, pathways, and symbol
        try:
            results = mg.querymany(
                batch,
                scopes="ensembl.gene",
                fields="go,pathway.kegg,pathway.reactome,symbol",
                species="human",
                returnall=False,
            )

            # Process each gene's result
            for result in results:
                gene_id = result.get("query")
                if gene_id in seen_ids:
                    continue  # duplicate hit for an already-processed gene
                seen_ids.add(gene_id)

                gene_symbol = result.get("symbol", None)

                # Extract GO term counts by category
                go_data = result.get("go", {})
                if isinstance(go_data, dict):
                    bp_terms = go_data.get("BP", [])
                    mf_terms = go_data.get("MF", [])
                    cc_terms = go_data.get("CC", [])

                    # mygene returns a bare dict (not a list) for a single term
                    bp_list = bp_terms if isinstance(bp_terms, list) else ([bp_terms] if bp_terms else [])
                    mf_list = mf_terms if isinstance(mf_terms, list) else ([mf_terms] if mf_terms else [])
                    cc_list = cc_terms if isinstance(cc_terms, list) else ([cc_terms] if cc_terms else [])

                    # NULL (not 0) when a category is absent: unknown != zero
                    bp_count = len(bp_list) if bp_list else None
                    mf_count = len(mf_list) if mf_list else None
                    cc_count = len(cc_list) if cc_list else None

                    # Total GO count: sum of non-NULL counts, or NULL if all NULL
                    counts = [c for c in [bp_count, mf_count, cc_count] if c is not None]
                    total_count = sum(counts) if counts else None
                else:
                    # No GO data at all
                    bp_count = None
                    mf_count = None
                    cc_count = None
                    total_count = None

                # Check pathway membership; NULL when the pathway field is absent
                pathway_data = result.get("pathway", {})
                has_kegg = bool(pathway_data.get("kegg"))
                has_reactome = bool(pathway_data.get("reactome"))
                has_pathway = (has_kegg or has_reactome) if (has_kegg or has_reactome or pathway_data) else None

                all_results.append({
                    "gene_id": gene_id,
                    "gene_symbol": gene_symbol,
                    "go_term_count": total_count,
                    "go_biological_process_count": bp_count,
                    "go_molecular_function_count": mf_count,
                    "go_cellular_component_count": cc_count,
                    "has_pathway_membership": has_pathway,
                })

        except Exception as e:
            logger.warning(
                "fetch_go_batch_error",
                batch_num=i + 1,
                error=str(e),
            )
            # Best-effort: NULL entries for the failed batch, keep going
            for gene_id in batch:
                all_results.append({
                    "gene_id": gene_id,
                    "gene_symbol": None,
                    "go_term_count": None,
                    "go_biological_process_count": None,
                    "go_molecular_function_count": None,
                    "go_cellular_component_count": None,
                    "has_pathway_membership": None,
                })

    logger.info("fetch_go_annotations_complete", result_count=len(all_results))

    return pl.DataFrame(all_results)


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception_type(
        (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)
    ),
)
def _query_uniprot_batch(accessions: list[str]) -> dict:
    """Query UniProt REST API for annotation scores (with retry).

    Args:
        accessions: List of UniProt accession IDs (max 100)

    Returns:
        Dict mapping accession -> annotation_score
    """
    if not accessions:
        return {}

    # Build OR query for batch lookup
    query = " OR ".join(f"accession:{acc}" for acc in accessions)
    url = "https://rest.uniprot.org/uniprotkb/search"

    params = {
        "query": query,
        "fields": "accession,annotation_score",
        "format": "json",
        "size": len(accessions),
    }

    with httpx.Client(timeout=30.0) as client:
        response = client.get(url, params=params)
        response.raise_for_status()
        data = response.json()

    # Parse results into mapping; entries without a score are skipped (NULL later)
    score_map = {}
    for entry in data.get("results", []):
        accession = entry.get("primaryAccession")
        score = entry.get("annotationScore")
        if accession and score is not None:
            score_map[accession] = int(score)

    return score_map


def fetch_uniprot_scores(
    gene_ids: list[str],
    uniprot_mapping: pl.DataFrame,
    batch_size: int = 100,
) -> pl.DataFrame:
    """Fetch UniProt annotation scores for genes.

    Uses UniProt REST API to query annotation scores in batches.
    Rate-limited to avoid overwhelming the API (built-in via tenacity retry).

    Args:
        gene_ids: List of Ensembl gene IDs
        uniprot_mapping: DataFrame with gene_id and uniprot_accession columns
        batch_size: Number of UniProt accessions per batch (default: 100)

    Returns:
        DataFrame with columns:
        - gene_id: Ensembl gene ID
        - uniprot_annotation_score: UniProt annotation score 1-5 (NULL if no mapping/score)

    Note: Genes without UniProt mapping get NULL (not zero).
    """
    logger.info("fetch_uniprot_scores_start", gene_count=len(gene_ids))

    # Filter mapping to requested genes
    mapping_filtered = uniprot_mapping.filter(pl.col("gene_id").is_in(gene_ids))

    if mapping_filtered.height == 0:
        logger.warning("fetch_uniprot_no_mappings")
        # Return all genes with NULL scores
        return pl.DataFrame({
            "gene_id": gene_ids,
            "uniprot_annotation_score": [None] * len(gene_ids),
        })

    # Get unique accessions
    accessions = mapping_filtered.select("uniprot_accession").unique().to_series().to_list()
    logger.info("fetch_uniprot_accessions", accession_count=len(accessions))

    # Batch query UniProt API
    all_scores: dict[str, int] = {}
    num_batches = math.ceil(len(accessions) / batch_size)

    for i in range(num_batches):
        batch = accessions[i * batch_size : min((i + 1) * batch_size, len(accessions))]

        logger.info(
            "fetch_uniprot_batch",
            batch_num=i + 1,
            total_batches=num_batches,
            batch_size=len(batch),
        )

        try:
            all_scores.update(_query_uniprot_batch(batch))
        except Exception as e:
            logger.warning(
                "fetch_uniprot_batch_error",
                batch_num=i + 1,
                error=str(e),
            )
            # Continue with other batches - failed batch will have NULL scores

    # Create accession -> score mapping. Explicit schema so the join below
    # still works when all_scores is empty (otherwise both columns get Null
    # dtype and the join on a Utf8 key fails).
    score_df = pl.DataFrame(
        {
            "uniprot_accession": list(all_scores.keys()),
            "uniprot_annotation_score": list(all_scores.values()),
        },
        schema={"uniprot_accession": pl.Utf8, "uniprot_annotation_score": pl.Int64},
    )

    # Join back to gene IDs
    result = (
        mapping_filtered
        .select(["gene_id", "uniprot_accession"])
        .join(score_df, on="uniprot_accession", how="left")
        .group_by("gene_id")
        .agg(
            # Take first score if multiple accessions (consistent with gene universe pattern)
            pl.col("uniprot_annotation_score").first()
        )
    )

    # Ensure all requested genes are present (add NULL for missing)
    all_genes = pl.DataFrame({"gene_id": gene_ids})
    result = all_genes.join(result, on="gene_id", how="left")

    logger.info("fetch_uniprot_scores_complete", result_count=result.height)

    return result

# ============================================================================
# src/usher_pipeline/evidence/annotation/models.py
# ============================================================================
"""Data models for gene annotation completeness evidence."""

from pydantic import BaseModel

# Table name for DuckDB storage
ANNOTATION_TABLE_NAME = "annotation_completeness"


class AnnotationRecord(BaseModel):
    """Gene annotation completeness metrics for a single gene.

    Attributes:
        gene_id: Ensembl gene ID (e.g., ENSG00000...)
        gene_symbol: HGNC gene symbol - NULL if the gene could not be resolved
        go_term_count: Total number of GO terms (all ontologies) - NULL if no data
        go_biological_process_count: Number of GO Biological Process terms - NULL if no data
        go_molecular_function_count: Number of GO Molecular Function terms - NULL if no data
        go_cellular_component_count: Number of GO Cellular Component terms - NULL if no data
        uniprot_annotation_score: UniProt annotation score 1-5 - NULL if no mapping or score
        has_pathway_membership: Present in any KEGG/Reactome pathway - NULL if no data
        annotation_tier: Classification: "well_annotated", "partially_annotated", "poorly_annotated"
        annotation_score_normalized: Composite annotation score 0-1 (higher = better
            annotated) - NULL if all inputs NULL

    CRITICAL: NULL values represent missing data and are preserved as None.
    Do NOT convert NULL to 0 - "unknown annotation" is semantically different
    from "zero annotation". Conservative approach: NULL GO counts treated as
    zero for tier classification (assume unannotated).
    """

    gene_id: str
    # Optional: fetch_go_annotations emits None for genes mygene cannot resolve.
    gene_symbol: str | None = None
    go_term_count: int | None = None
    go_biological_process_count: int | None = None
    go_molecular_function_count: int | None = None
    go_cellular_component_count: int | None = None
    uniprot_annotation_score: int | None = None
    has_pathway_membership: bool | None = None
    annotation_tier: str = "poorly_annotated"
    annotation_score_normalized: float | None = None

# ============================================================================
# src/usher_pipeline/evidence/annotation/transform.py
# ============================================================================
"""Transform and normalize gene annotation completeness metrics."""

import polars as pl
import structlog

from usher_pipeline.evidence.annotation.fetch import (
    fetch_go_annotations,
    fetch_uniprot_scores,
)

logger = structlog.get_logger()


def classify_annotation_tier(df: pl.DataFrame) -> pl.DataFrame:
    """Classify genes into annotation tiers based on composite metrics.

    Tier definitions:
    - "well_annotated": go_term_count >= 20 AND uniprot_annotation_score >= 4
    - "partially_annotated": go_term_count >= 5 OR uniprot_annotation_score >= 3
    - "poorly_annotated": Everything else (including NULLs)

    Conservative approach: NULL GO counts treated as zero for tier
    classification (assume unannotated until proven otherwise).

    Args:
        df: DataFrame with go_term_count and uniprot_annotation_score columns

    Returns:
        DataFrame with annotation_tier column added
    """
    logger.info("classify_annotation_tier_start", row_count=df.height)

    # Fill NULL with 0 in temporary columns only, so the original NULLs are
    # preserved for downstream NULL-aware scoring.
    df = df.with_columns([
        pl.col("go_term_count").fill_null(0).alias("_go_count_filled"),
        pl.col("uniprot_annotation_score").fill_null(0).alias("_uniprot_score_filled"),
    ])

    # Apply tier classification logic
    df = df.with_columns(
        pl.when(
            (pl.col("_go_count_filled") >= 20) & (pl.col("_uniprot_score_filled") >= 4)
        )
        .then(pl.lit("well_annotated"))
        .when(
            (pl.col("_go_count_filled") >= 5) | (pl.col("_uniprot_score_filled") >= 3)
        )
        .then(pl.lit("partially_annotated"))
        .otherwise(pl.lit("poorly_annotated"))
        .alias("annotation_tier")
    )

    # Drop temporary filled columns
    df = df.drop(["_go_count_filled", "_uniprot_score_filled"])

    # Log tier distribution
    tier_counts = df.group_by("annotation_tier").len().sort("annotation_tier")
    logger.info("classify_annotation_tier_complete", tier_distribution=tier_counts.to_dicts())

    return df


def normalize_annotation_score(df: pl.DataFrame) -> pl.DataFrame:
    """Compute normalized composite annotation score (0-1 range).

    Formula: Weighted average of three components:
    - GO component (50%): log2(go_term_count + 1) normalized by max across dataset
    - UniProt component (30%): uniprot_annotation_score / 5.0
    - Pathway component (20%): has_pathway_membership as 0/1

    Result clamped to [0, 1]. NULL if ALL three inputs are NULL.

    Args:
        df: DataFrame with go_term_count, uniprot_annotation_score, has_pathway_membership

    Returns:
        DataFrame with annotation_score_normalized column added
    """
    logger.info("normalize_annotation_score_start", row_count=df.height)

    # Component weights (GO 50%, UniProt 30%, Pathway 20%)
    WEIGHT_GO = 0.5
    WEIGHT_UNIPROT = 0.3
    WEIGHT_PATHWAY = 0.2

    # GO component: log2(count + 1), normalized by dataset max below
    df = df.with_columns(
        pl.when(pl.col("go_term_count").is_not_null())
        .then((pl.col("go_term_count") + 1).log(base=2))
        .otherwise(None)
        .alias("_go_log")
    )

    # Max of the non-NULL log values for normalization
    go_max = df.filter(pl.col("_go_log").is_not_null()).select(pl.col("_go_log").max()).item()

    if go_max is None or go_max == 0:
        # No GO data in dataset - all get NULL for GO component
        df = df.with_columns(pl.lit(None).cast(pl.Float64).alias("_go_component"))
    else:
        df = df.with_columns(
            pl.when(pl.col("_go_log").is_not_null())
            .then((pl.col("_go_log") / go_max) * WEIGHT_GO)
            .otherwise(None)
            .alias("_go_component")
        )

    # UniProt component: score / 5.0, weighted
    df = df.with_columns(
        pl.when(pl.col("uniprot_annotation_score").is_not_null())
        .then((pl.col("uniprot_annotation_score") / 5.0) * WEIGHT_UNIPROT)
        .otherwise(None)
        .alias("_uniprot_component")
    )

    # Pathway component: boolean as full weight / 0, NULL preserved
    df = df.with_columns(
        pl.when(pl.col("has_pathway_membership").is_not_null())
        .then(
            pl.when(pl.col("has_pathway_membership"))
            .then(WEIGHT_PATHWAY)
            .otherwise(0.0)
        )
        .otherwise(None)
        .alias("_pathway_component")
    )

    # Composite score: sum of non-NULL components, NULL only if all three NULL
    df = df.with_columns(
        pl.when(
            pl.col("_go_component").is_not_null()
            | pl.col("_uniprot_component").is_not_null()
            | pl.col("_pathway_component").is_not_null()
        )
        .then(
            # Treat NULL as 0 within the sum (at least one component is real)
            pl.col("_go_component").fill_null(0.0)
            + pl.col("_uniprot_component").fill_null(0.0)
            + pl.col("_pathway_component").fill_null(0.0)
        )
        .otherwise(None)
        .alias("annotation_score_normalized")
    )

    # Clamp to [0, 1] (defensive; clip propagates NULLs, so no when/otherwise needed)
    df = df.with_columns(
        pl.col("annotation_score_normalized").clip(0.0, 1.0).alias("annotation_score_normalized")
    )

    # Drop temporary columns
    df = df.drop(["_go_log", "_go_component", "_uniprot_component", "_pathway_component"])

    # Log score statistics
    stats = df.filter(pl.col("annotation_score_normalized").is_not_null()).select([
        pl.col("annotation_score_normalized").mean().alias("mean"),
        pl.col("annotation_score_normalized").median().alias("median"),
        pl.col("annotation_score_normalized").min().alias("min"),
        pl.col("annotation_score_normalized").max().alias("max"),
    ])

    if stats.height > 0:
        logger.info("normalize_annotation_score_complete", stats=stats.to_dicts()[0])
    else:
        logger.warning("normalize_annotation_score_complete", message="No valid scores computed")

    return df


def process_annotation_evidence(
    gene_ids: list[str],
    uniprot_mapping: pl.DataFrame,
) -> pl.DataFrame:
    """End-to-end annotation evidence processing pipeline.

    Composes: fetch GO -> fetch UniProt -> join -> classify tier -> normalize -> collect.

    Args:
        gene_ids: List of Ensembl gene IDs to process
        uniprot_mapping: DataFrame with gene_id and uniprot_accession columns

    Returns:
        Materialized DataFrame with all annotation completeness metrics ready for DuckDB
    """
    logger.info("process_annotation_evidence_start", gene_count=len(gene_ids))

    # Fetch GO annotations and pathway memberships
    go_df = fetch_go_annotations(gene_ids)

    # Fetch UniProt annotation scores
    uniprot_df = fetch_uniprot_scores(gene_ids, uniprot_mapping)

    # Join GO and UniProt data (go_df is one-row-per-gene, so no fan-out)
    df = go_df.join(uniprot_df, on="gene_id", how="left")

    # Classify annotation tiers
    df = classify_annotation_tier(df)

    # Normalize composite score
    df = normalize_annotation_score(df)

    logger.info("process_annotation_evidence_complete", result_count=df.height)

    return df