feat(03-01): implement annotation evidence fetch and transform modules
- Create AnnotationRecord model with GO counts, UniProt scores, tier classification
- Implement fetch_go_annotations using mygene.info batch queries
- Implement fetch_uniprot_scores using UniProt REST API
- Add classify_annotation_tier with 3-tier system (well/partial/poor)
- Add normalize_annotation_score with weighted composite (GO 50%, UniProt 30%, Pathway 20%)
- Implement process_annotation_evidence end-to-end pipeline
- Follow NULL preservation pattern from gnomAD (unknown != zero)
- Use lazy polars evaluation where applicable
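
A minimal usage sketch of the new entry point (illustrative only; the gene IDs, accessions, and mapping below are example inputs, not part of this commit):

import polars as pl
from usher_pipeline.evidence.annotation import process_annotation_evidence

# BRCA2 and BRCA1 with their canonical UniProt accessions (example mapping)
gene_ids = ["ENSG00000139618", "ENSG00000012048"]
uniprot_mapping = pl.DataFrame({
    "gene_id": gene_ids,
    "uniprot_accession": ["P51587", "P38398"],
})

df = process_annotation_evidence(gene_ids, uniprot_mapping)
# df carries go_term_count, uniprot_annotation_score, has_pathway_membership,
# annotation_tier, and annotation_score_normalized per gene.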

src/usher_pipeline/evidence/annotation/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
"""Gene annotation completeness evidence layer."""

from usher_pipeline.evidence.annotation.models import AnnotationRecord, ANNOTATION_TABLE_NAME
from usher_pipeline.evidence.annotation.fetch import (
    fetch_go_annotations,
    fetch_uniprot_scores,
)
from usher_pipeline.evidence.annotation.transform import (
    classify_annotation_tier,
    normalize_annotation_score,
    process_annotation_evidence,
)

__all__ = [
    "AnnotationRecord",
    "ANNOTATION_TABLE_NAME",
    "fetch_go_annotations",
    "fetch_uniprot_scores",
    "classify_annotation_tier",
    "normalize_annotation_score",
    "process_annotation_evidence",
]

src/usher_pipeline/evidence/annotation/fetch.py (new file, 290 lines)
@@ -0,0 +1,290 @@
"""Fetch gene annotation data from mygene.info and UniProt APIs."""

from typing import Optional
import math

import httpx
import mygene
import polars as pl
import structlog
from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

logger = structlog.get_logger()

# Initialize mygene client (singleton pattern - reuse across calls)
_mg_client = None


def _get_mygene_client() -> mygene.MyGeneInfo:
    """Get or create mygene client singleton."""
    global _mg_client
    if _mg_client is None:
        _mg_client = mygene.MyGeneInfo()
    return _mg_client


def fetch_go_annotations(gene_ids: list[str], batch_size: int = 1000) -> pl.DataFrame:
    """Fetch GO annotations and pathway memberships from mygene.info.

    Uses mygene.querymany to batch query GO terms and pathway data.
    Processes in batches to avoid API timeout.

    Args:
        gene_ids: List of Ensembl gene IDs
        batch_size: Number of genes per batch query (default: 1000)

    Returns:
        DataFrame with columns:
        - gene_id: Ensembl gene ID
        - gene_symbol: HGNC symbol (NULL if not found)
        - go_term_count: Total GO term count across all ontologies (NULL if no GO data)
        - go_biological_process_count: GO BP term count (NULL if no GO data)
        - go_molecular_function_count: GO MF term count (NULL if no GO data)
        - go_cellular_component_count: GO CC term count (NULL if no GO data)
        - has_pathway_membership: Boolean indicating presence in KEGG/Reactome (NULL if no pathway data)

    Note: Genes with no GO annotations get NULL counts (not zero).
    """
    logger.info("fetch_go_annotations_start", gene_count=len(gene_ids))

    mg = _get_mygene_client()
    all_results = []

    # Process in batches to avoid mygene timeout
    num_batches = math.ceil(len(gene_ids) / batch_size)

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(gene_ids))
        batch = gene_ids[start_idx:end_idx]

        logger.info(
            "fetch_go_batch",
            batch_num=i + 1,
            total_batches=num_batches,
            batch_size=len(batch),
        )

        # Query mygene for GO terms, pathways, and symbol
        try:
            results = mg.querymany(
                batch,
                scopes="ensembl.gene",
                fields="go,pathway.kegg,pathway.reactome,symbol",
                species="human",
                returnall=False,
            )

            # Process each gene's result
            for result in results:
                gene_id = result.get("query")
                gene_symbol = result.get("symbol", None)

                # Extract GO term counts by category
                go_data = result.get("go", {})
                if isinstance(go_data, dict):
                    # Count GO terms by ontology
                    bp_terms = go_data.get("BP", [])
                    mf_terms = go_data.get("MF", [])
                    cc_terms = go_data.get("CC", [])

                    # Convert to list if single dict (mygene sometimes returns dict for single term)
                    bp_list = bp_terms if isinstance(bp_terms, list) else ([bp_terms] if bp_terms else [])
                    mf_list = mf_terms if isinstance(mf_terms, list) else ([mf_terms] if mf_terms else [])
                    cc_list = cc_terms if isinstance(cc_terms, list) else ([cc_terms] if cc_terms else [])

                    bp_count = len(bp_list) if bp_list else None
                    mf_count = len(mf_list) if mf_list else None
                    cc_count = len(cc_list) if cc_list else None

                    # Total GO count (sum of non-NULL counts, or NULL if all NULL)
                    counts = [c for c in [bp_count, mf_count, cc_count] if c is not None]
                    total_count = sum(counts) if counts else None
                else:
                    # No GO data
                    bp_count = None
                    mf_count = None
                    cc_count = None
                    total_count = None

                # Check pathway membership
                pathway_data = result.get("pathway", {})
                has_kegg = bool(pathway_data.get("kegg"))
                has_reactome = bool(pathway_data.get("reactome"))
                has_pathway = (has_kegg or has_reactome) if (has_kegg or has_reactome or pathway_data) else None

                all_results.append({
                    "gene_id": gene_id,
                    "gene_symbol": gene_symbol,
                    "go_term_count": total_count,
                    "go_biological_process_count": bp_count,
                    "go_molecular_function_count": mf_count,
                    "go_cellular_component_count": cc_count,
                    "has_pathway_membership": has_pathway,
                })

        except Exception as e:
            logger.warning(
                "fetch_go_batch_error",
                batch_num=i + 1,
                error=str(e),
            )
            # Add NULL entries for failed batch
            for gene_id in batch:
                all_results.append({
                    "gene_id": gene_id,
                    "gene_symbol": None,
                    "go_term_count": None,
                    "go_biological_process_count": None,
                    "go_molecular_function_count": None,
                    "go_cellular_component_count": None,
                    "has_pathway_membership": None,
                })

    logger.info("fetch_go_annotations_complete", result_count=len(all_results))

    return pl.DataFrame(all_results)


@retry(
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=2, max=30),
    retry=retry_if_exception_type(
        (httpx.HTTPStatusError, httpx.ConnectError, httpx.TimeoutException)
    ),
)
def _query_uniprot_batch(accessions: list[str]) -> dict:
    """Query UniProt REST API for annotation scores (with retry).

    Args:
        accessions: List of UniProt accession IDs (max 100)

    Returns:
        Dict mapping accession -> annotation_score
    """
    if not accessions:
        return {}

    # Build OR query for batch lookup
    query = " OR ".join([f"accession:{acc}" for acc in accessions])
    url = "https://rest.uniprot.org/uniprotkb/search"

    params = {
        "query": query,
        "fields": "accession,annotation_score",
        "format": "json",
        "size": len(accessions),
    }

    with httpx.Client(timeout=30.0) as client:
        response = client.get(url, params=params)
        response.raise_for_status()
        data = response.json()

    # Parse results into mapping
    score_map = {}
    for entry in data.get("results", []):
        accession = entry.get("primaryAccession")
        score = entry.get("annotationScore")
        if accession and score is not None:
            score_map[accession] = int(score)

    return score_map


def fetch_uniprot_scores(
    gene_ids: list[str],
    uniprot_mapping: pl.DataFrame,
    batch_size: int = 100,
) -> pl.DataFrame:
    """Fetch UniProt annotation scores for genes.

    Uses UniProt REST API to query annotation scores in batches.
    Rate-limited to avoid overwhelming the API (built-in via tenacity retry).

    Args:
        gene_ids: List of Ensembl gene IDs
        uniprot_mapping: DataFrame with gene_id and uniprot_accession columns
        batch_size: Number of UniProt accessions per batch (default: 100)

    Returns:
        DataFrame with columns:
        - gene_id: Ensembl gene ID
        - uniprot_annotation_score: UniProt annotation score 1-5 (NULL if no mapping/score)

    Note: Genes without UniProt mapping get NULL (not zero).
    """
    logger.info("fetch_uniprot_scores_start", gene_count=len(gene_ids))

    # Filter mapping to requested genes
    mapping_filtered = uniprot_mapping.filter(pl.col("gene_id").is_in(gene_ids))

    if mapping_filtered.height == 0:
        logger.warning("fetch_uniprot_no_mappings")
        # Return all genes with NULL scores
        return pl.DataFrame({
            "gene_id": gene_ids,
            "uniprot_annotation_score": [None] * len(gene_ids),
        })

    # Get unique accessions
    accessions = mapping_filtered.select("uniprot_accession").unique().to_series().to_list()
    logger.info("fetch_uniprot_accessions", accession_count=len(accessions))

    # Batch query UniProt API
    all_scores = {}
    num_batches = math.ceil(len(accessions) / batch_size)

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(accessions))
        batch = accessions[start_idx:end_idx]

        logger.info(
            "fetch_uniprot_batch",
            batch_num=i + 1,
            total_batches=num_batches,
            batch_size=len(batch),
        )

        try:
            batch_scores = _query_uniprot_batch(batch)
            all_scores.update(batch_scores)
        except Exception as e:
            logger.warning(
                "fetch_uniprot_batch_error",
                batch_num=i + 1,
                error=str(e),
            )
            # Continue with other batches - failed batch will have NULL scores

    # Create accession -> score mapping
    score_df = pl.DataFrame({
        "uniprot_accession": list(all_scores.keys()),
        "uniprot_annotation_score": list(all_scores.values()),
    })

    # Join back to gene IDs
    result = (
        mapping_filtered
        .select(["gene_id", "uniprot_accession"])
        .join(score_df, on="uniprot_accession", how="left")
        .group_by("gene_id")
        .agg(
            # Take first score if multiple accessions (consistent with gene universe pattern)
            pl.col("uniprot_annotation_score").first()
        )
    )

    # Ensure all requested genes are present (add NULL for missing)
    all_genes = pl.DataFrame({"gene_id": gene_ids})
    result = all_genes.join(result, on="gene_id", how="left")

    logger.info("fetch_uniprot_scores_complete", result_count=result.height)

    return result
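
For reference, the two parsers above consume roughly the following response shapes. These are abbreviated illustrations inferred from the parsing code, not captured API output; the specific IDs and term names are invented:

# mygene.info querymany hit, as consumed by fetch_go_annotations;
# each GO category may be a list of terms or a single dict, and absent keys become NULL counts
mygene_hit = {
    "query": "ENSG00000139618",
    "symbol": "BRCA2",
    "go": {
        "BP": [{"id": "GO:0006281", "term": "DNA repair"}],
        "MF": {"id": "GO:0005515", "term": "protein binding"},  # single-term dict case
        # "CC" absent -> go_cellular_component_count stays NULL
    },
    "pathway": {"kegg": {"id": "hsa03440"}},  # any KEGG/Reactome entry -> has_pathway_membership=True
}

# UniProt search response, as consumed by _query_uniprot_batch
uniprot_response = {
    "results": [
        {"primaryAccession": "P51587", "annotationScore": 5.0},
    ]
}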

src/usher_pipeline/evidence/annotation/models.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""Data models for gene annotation completeness evidence."""

from pydantic import BaseModel

# Table name for DuckDB storage
ANNOTATION_TABLE_NAME = "annotation_completeness"


class AnnotationRecord(BaseModel):
    """Gene annotation completeness metrics for a single gene.

    Attributes:
        gene_id: Ensembl gene ID (e.g., ENSG00000...)
        gene_symbol: HGNC gene symbol
        go_term_count: Total number of GO terms (all ontologies) - NULL if no data
        go_biological_process_count: Number of GO Biological Process terms - NULL if no data
        go_molecular_function_count: Number of GO Molecular Function terms - NULL if no data
        go_cellular_component_count: Number of GO Cellular Component terms - NULL if no data
        uniprot_annotation_score: UniProt annotation score 1-5 - NULL if no mapping or score
        has_pathway_membership: Present in any KEGG/Reactome pathway - NULL if no data
        annotation_tier: Classification: "well_annotated", "partially_annotated", "poorly_annotated"
        annotation_score_normalized: Composite annotation score 0-1 (higher = better annotated) - NULL if all inputs NULL

    CRITICAL: NULL values represent missing data and are preserved as None.
    Do NOT convert NULL to 0 - "unknown annotation" is semantically different from "zero annotation".
    Conservative approach: NULL GO counts treated as zero for tier classification (assume unannotated).
    """

    gene_id: str
    gene_symbol: str
    go_term_count: int | None = None
    go_biological_process_count: int | None = None
    go_molecular_function_count: int | None = None
    go_cellular_component_count: int | None = None
    uniprot_annotation_score: int | None = None
    has_pathway_membership: bool | None = None
    annotation_tier: str = "poorly_annotated"
    annotation_score_normalized: float | None = None
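
A sketch of the model in use (field values invented for illustration):

from usher_pipeline.evidence.annotation.models import AnnotationRecord

record = AnnotationRecord(
    gene_id="ENSG00000139618",
    gene_symbol="BRCA2",
    go_term_count=31,
    go_biological_process_count=20,
    go_molecular_function_count=8,
    go_cellular_component_count=3,
    uniprot_annotation_score=5,
    has_pathway_membership=True,
    annotation_tier="well_annotated",
    annotation_score_normalized=0.81,
)

# With no evidence supplied, NULLs are preserved and the defaults apply:
sparse = AnnotationRecord(gene_id="ENSG00000000000", gene_symbol="EXAMPLE")  # placeholder IDs
assert sparse.go_term_count is None
assert sparse.annotation_tier == "poorly_annotated"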

src/usher_pipeline/evidence/annotation/transform.py (new file, 213 lines)
@@ -0,0 +1,213 @@
"""Transform and normalize gene annotation completeness metrics."""

import math
from pathlib import Path

import polars as pl
import structlog

from usher_pipeline.evidence.annotation.fetch import (
    fetch_go_annotations,
    fetch_uniprot_scores,
)

logger = structlog.get_logger()


def classify_annotation_tier(df: pl.DataFrame) -> pl.DataFrame:
    """Classify genes into annotation tiers based on composite metrics.

    Tier definitions:
    - "well_annotated": go_term_count >= 20 AND uniprot_annotation_score >= 4
    - "partially_annotated": go_term_count >= 5 OR uniprot_annotation_score >= 3
    - "poorly_annotated": Everything else (including NULLs)

    Conservative approach: NULL GO counts treated as zero for tier classification
    (assume unannotated until proven otherwise).

    Args:
        df: DataFrame with go_term_count and uniprot_annotation_score columns

    Returns:
        DataFrame with annotation_tier column added
    """
    logger.info("classify_annotation_tier_start", row_count=df.height)

    # Fill NULL GO counts with 0 for tier classification (conservative)
    # But preserve original NULL for downstream NULL handling
    df = df.with_columns([
        pl.col("go_term_count").fill_null(0).alias("_go_count_filled"),
        pl.col("uniprot_annotation_score").fill_null(0).alias("_uniprot_score_filled"),
    ])

    # Apply tier classification logic
    df = df.with_columns(
        pl.when(
            (pl.col("_go_count_filled") >= 20) & (pl.col("_uniprot_score_filled") >= 4)
        )
        .then(pl.lit("well_annotated"))
        .when(
            (pl.col("_go_count_filled") >= 5) | (pl.col("_uniprot_score_filled") >= 3)
        )
        .then(pl.lit("partially_annotated"))
        .otherwise(pl.lit("poorly_annotated"))
        .alias("annotation_tier")
    )

    # Drop temporary filled columns
    df = df.drop(["_go_count_filled", "_uniprot_score_filled"])

    # Log tier distribution
    tier_counts = df.group_by("annotation_tier").len().sort("annotation_tier")
    logger.info("classify_annotation_tier_complete", tier_distribution=tier_counts.to_dicts())

    return df


def normalize_annotation_score(df: pl.DataFrame) -> pl.DataFrame:
    """Compute normalized composite annotation score (0-1 range).

    Formula: Weighted average of three components:
    - GO component (50%): log2(go_term_count + 1) normalized by max across dataset
    - UniProt component (30%): uniprot_annotation_score / 5.0
    - Pathway component (20%): has_pathway_membership as 0/1

    Result clamped to [0, 1]. NULL if ALL three inputs are NULL.

    Args:
        df: DataFrame with go_term_count, uniprot_annotation_score, has_pathway_membership

    Returns:
        DataFrame with annotation_score_normalized column added
    """
    logger.info("normalize_annotation_score_start", row_count=df.height)

    # Component weights
    WEIGHT_GO = 0.5
    WEIGHT_UNIPROT = 0.3
    WEIGHT_PATHWAY = 0.2

    # Compute GO component: log2(count + 1) normalized by max
    df = df.with_columns(
        pl.when(pl.col("go_term_count").is_not_null())
        .then((pl.col("go_term_count") + 1).log(base=2))
        .otherwise(None)
        .alias("_go_log")
    )

    # Get max for normalization (from non-NULL values)
    go_max = df.filter(pl.col("_go_log").is_not_null()).select(pl.col("_go_log").max()).item()

    if go_max is None or go_max == 0:
        # No GO data in dataset - all get NULL for GO component
        df = df.with_columns(pl.lit(None).cast(pl.Float64).alias("_go_component"))
    else:
        df = df.with_columns(
            pl.when(pl.col("_go_log").is_not_null())
            .then((pl.col("_go_log") / go_max) * WEIGHT_GO)
            .otherwise(None)
            .alias("_go_component")
        )

    # Compute UniProt component: score / 5.0
    df = df.with_columns(
        pl.when(pl.col("uniprot_annotation_score").is_not_null())
        .then((pl.col("uniprot_annotation_score") / 5.0) * WEIGHT_UNIPROT)
        .otherwise(None)
        .alias("_uniprot_component")
    )

    # Compute pathway component: boolean as 0/1
    df = df.with_columns(
        pl.when(pl.col("has_pathway_membership").is_not_null())
        .then(
            pl.when(pl.col("has_pathway_membership"))
            .then(WEIGHT_PATHWAY)
            .otherwise(0.0)
        )
        .otherwise(None)
        .alias("_pathway_component")
    )

    # Composite score: sum of non-NULL components, NULL if all NULL
    # Need to handle NULL properly: only compute if at least one component is non-NULL
    df = df.with_columns(
        pl.when(
            pl.col("_go_component").is_not_null()
            | pl.col("_uniprot_component").is_not_null()
            | pl.col("_pathway_component").is_not_null()
        )
        .then(
            # Sum components, treating NULL as 0 for the sum
            pl.col("_go_component").fill_null(0.0)
            + pl.col("_uniprot_component").fill_null(0.0)
            + pl.col("_pathway_component").fill_null(0.0)
        )
        .otherwise(None)
        .alias("annotation_score_normalized")
    )

    # Clamp to [0, 1] range (shouldn't exceed but defensive)
    df = df.with_columns(
        pl.when(pl.col("annotation_score_normalized").is_not_null())
        .then(
            pl.col("annotation_score_normalized").clip(0.0, 1.0)
        )
        .otherwise(None)
        .alias("annotation_score_normalized")
    )

    # Drop temporary columns
    df = df.drop(["_go_log", "_go_component", "_uniprot_component", "_pathway_component"])

    # Log score statistics
    stats = df.filter(pl.col("annotation_score_normalized").is_not_null()).select([
        pl.col("annotation_score_normalized").mean().alias("mean"),
        pl.col("annotation_score_normalized").median().alias("median"),
        pl.col("annotation_score_normalized").min().alias("min"),
        pl.col("annotation_score_normalized").max().alias("max"),
    ])

    if stats.height > 0:
        logger.info("normalize_annotation_score_complete", stats=stats.to_dicts()[0])
    else:
        logger.warning("normalize_annotation_score_complete", message="No valid scores computed")

    return df


def process_annotation_evidence(
    gene_ids: list[str],
    uniprot_mapping: pl.DataFrame,
) -> pl.DataFrame:
    """End-to-end annotation evidence processing pipeline.

    Composes: fetch GO -> fetch UniProt -> join -> classify tier -> normalize -> collect.

    Args:
        gene_ids: List of Ensembl gene IDs to process
        uniprot_mapping: DataFrame with gene_id and uniprot_accession columns

    Returns:
        Materialized DataFrame with all annotation completeness metrics ready for DuckDB
    """
    logger.info("process_annotation_evidence_start", gene_count=len(gene_ids))

    # Fetch GO annotations and pathway memberships
    go_df = fetch_go_annotations(gene_ids)

    # Fetch UniProt annotation scores
    uniprot_df = fetch_uniprot_scores(gene_ids, uniprot_mapping)

    # Join GO and UniProt data
    df = go_df.join(uniprot_df, on="gene_id", how="left")

    # Classify annotation tiers
    df = classify_annotation_tier(df)

    # Normalize composite score
    df = normalize_annotation_score(df)

    logger.info("process_annotation_evidence_complete", result_count=df.height)

    return df
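
To make the weighting concrete, a worked example plus a small self-contained demo (all input numbers invented; the GO normalizer go_max depends on the dataset):

# For a gene with go_term_count=31, uniprot_annotation_score=5, has_pathway_membership=True,
# in a dataset whose max log2(go_term_count + 1) is 8.0 (i.e., some gene has 255 GO terms):
#   GO component:      (log2(32) / 8.0) * 0.5 = (5.0 / 8.0) * 0.5 = 0.3125
#   UniProt component: (5 / 5.0) * 0.3                            = 0.3
#   Pathway component: True                                       = 0.2
#   annotation_score_normalized = 0.3125 + 0.3 + 0.2 = 0.8125
#   annotation_tier = "well_annotated"  (31 >= 20 AND 5 >= 4)

import polars as pl
from usher_pipeline.evidence.annotation.transform import (
    classify_annotation_tier,
    normalize_annotation_score,
)

demo = pl.DataFrame({
    "go_term_count": [31, 255, 6, None],
    "uniprot_annotation_score": [5, 4, None, None],
    "has_pathway_membership": [True, True, False, None],
})
demo = normalize_annotation_score(classify_annotation_tier(demo))
# Tiers: well_annotated, well_annotated, partially_annotated, poorly_annotated;
# the all-NULL row keeps a NULL annotation_score_normalized rather than 0.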