From d51141f7d591e6d6372413f9d22840d204485eb8 Mon Sep 17 00:00:00 2001
From: gbanyan
Date: Wed, 11 Feb 2026 16:30:25 +0800
Subject: [PATCH] feat(01-03): add gene mapping and DuckDB persistence layer
 with checkpoint-restart

- GeneMapper class for batch Ensembl-to-HGNC/UniProt mapping via mygene
- Gene universe retrieval with expected-count validation
- PipelineStore class for DuckDB-based storage
- save_dataframe/load_dataframe for polars and pandas
- Checkpoint system with has_checkpoint and metadata tracking
- Parquet export capability
- Context manager support
---
 src/usher_pipeline/gene_mapping/__init__.py  |  23 ++
 src/usher_pipeline/gene_mapping/mapper.py    | 188 ++++++++++++++
 src/usher_pipeline/gene_mapping/universe.py  | 112 ++++++++
 src/usher_pipeline/persistence/__init__.py   |   6 +
 .../persistence/duckdb_store.py              | 237 ++++++++++++++++++
 5 files changed, 566 insertions(+)
 create mode 100644 src/usher_pipeline/gene_mapping/__init__.py
 create mode 100644 src/usher_pipeline/gene_mapping/mapper.py
 create mode 100644 src/usher_pipeline/gene_mapping/universe.py
 create mode 100644 src/usher_pipeline/persistence/__init__.py
 create mode 100644 src/usher_pipeline/persistence/duckdb_store.py

diff --git a/src/usher_pipeline/gene_mapping/__init__.py b/src/usher_pipeline/gene_mapping/__init__.py
new file mode 100644
index 0000000..3fa5a13
--- /dev/null
+++ b/src/usher_pipeline/gene_mapping/__init__.py
@@ -0,0 +1,23 @@
+"""Gene ID mapping module.
+
+Provides gene universe definition, batch ID mapping via mygene,
+and validation gates for quality control.
+"""
+
+from usher_pipeline.gene_mapping.mapper import (
+    GeneMapper,
+    MappingResult,
+    MappingReport,
+)
+from usher_pipeline.gene_mapping.universe import (
+    fetch_protein_coding_genes,
+    GeneUniverse,
+)
+
+__all__ = [
+    "GeneMapper",
+    "MappingResult",
+    "MappingReport",
+    "fetch_protein_coding_genes",
+    "GeneUniverse",
+]
diff --git a/src/usher_pipeline/gene_mapping/mapper.py b/src/usher_pipeline/gene_mapping/mapper.py
new file mode 100644
index 0000000..47698a3
--- /dev/null
+++ b/src/usher_pipeline/gene_mapping/mapper.py
@@ -0,0 +1,188 @@
+"""Gene ID mapping via mygene batch queries.
+
+Provides batch mapping from Ensembl gene IDs to HGNC symbols and UniProt accessions.
+Handles edge cases like missing data, notfound results, and nested data structures.
+"""
+
+import logging
+from dataclasses import dataclass, field
+
+import mygene
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MappingResult:
+    """Single gene ID mapping result.
+
+    Attributes:
+        ensembl_id: Ensembl gene ID (ENSG format)
+        hgnc_symbol: HGNC gene symbol (None if not found)
+        uniprot_accession: UniProt Swiss-Prot accession (None if not found)
+        mapping_source: Data source for mapping (default: mygene)
+    """
+    ensembl_id: str
+    hgnc_symbol: str | None = None
+    uniprot_accession: str | None = None
+    mapping_source: str = "mygene"
+
+
+@dataclass
+class MappingReport:
+    """Summary report for batch mapping operation.
+
+    Attributes:
+        total_genes: Total number of genes queried
+        mapped_hgnc: Number of genes with HGNC symbol
+        mapped_uniprot: Number of genes with UniProt accession
+        unmapped_ids: List of Ensembl IDs with no HGNC symbol
+        success_rate_hgnc: Fraction of genes with HGNC symbol (0-1)
+        success_rate_uniprot: Fraction of genes with UniProt accession (0-1)
+    """
+    total_genes: int
+    mapped_hgnc: int
+    mapped_uniprot: int
+    unmapped_ids: list[str] = field(default_factory=list)
+    success_rate_hgnc: float = 0.0
+    success_rate_uniprot: float = 0.0
+
+    def __post_init__(self):
+        """Calculate success rates after initialization."""
+        if self.total_genes > 0:
+            self.success_rate_hgnc = self.mapped_hgnc / self.total_genes
+            self.success_rate_uniprot = self.mapped_uniprot / self.total_genes
+
+
+class GeneMapper:
+    """Batch gene ID mapper using mygene API.
+
+    Maps Ensembl gene IDs to HGNC symbols and UniProt Swiss-Prot accessions.
+    Handles batch queries, missing data, and edge cases.
+    """
+
+    def __init__(self, batch_size: int = 1000):
+        """Initialize gene mapper.
+
+        Args:
+            batch_size: Number of genes to query per batch (default: 1000)
+        """
+        self.batch_size = batch_size
+        self.mg = mygene.MyGeneInfo()
+        logger.info(f"Initialized GeneMapper with batch_size={batch_size}")
+
+    def map_ensembl_ids(
+        self,
+        ensembl_ids: list[str]
+    ) -> tuple[list[MappingResult], MappingReport]:
+        """Map Ensembl gene IDs to HGNC symbols and UniProt accessions.
+
+        Uses mygene.querymany() to batch query for gene symbols and UniProt IDs.
+        Processes queries in chunks of batch_size to avoid API limits.
+
+        Args:
+            ensembl_ids: List of Ensembl gene IDs (ENSG format)
+
+        Returns:
+            Tuple of (mapping_results, mapping_report)
+            - mapping_results: List of MappingResult for each input gene
+            - mapping_report: Summary statistics for the mapping operation
+
+        Notes:
+            - Queries mygene with scopes='ensembl.gene'
+            - Retrieves fields: symbol (HGNC), uniprot.Swiss-Prot
+            - Handles 'notfound' results, missing keys, and nested structures
+            - Duplicate hits for one query each yield a separate MappingResult
+        """
+        total_genes = len(ensembl_ids)
+        logger.info(f"Mapping {total_genes} Ensembl IDs to HGNC/UniProt")
+
+        results: list[MappingResult] = []
+        unmapped_ids: list[str] = []
+        mapped_hgnc = 0
+        mapped_uniprot = 0
+
+        # Process in batches
+        for i in range(0, total_genes, self.batch_size):
+            batch = ensembl_ids[i:i + self.batch_size]
+            batch_num = i // self.batch_size + 1
+            total_batches = (total_genes + self.batch_size - 1) // self.batch_size
+
+            logger.info(
+                f"Processing batch {batch_num}/{total_batches} "
+                f"({len(batch)} genes)"
+            )
+
+            # Query mygene
+            batch_results = self.mg.querymany(
+                batch,
+                scopes='ensembl.gene',
+                fields='symbol,uniprot.Swiss-Prot',
+                species=9606,
+                returnall=True,
+            )
+
+            # Extract results from returnall=True format
+            # mygene returns {'out': [...], 'dup': [...], 'missing': [...]}
+            out_results = batch_results.get('out', [])
+
+            # Process each result
+            for hit in out_results:
+                ensembl_id = hit.get('query', '')
+
+                # Check if gene was not found
+                if hit.get('notfound', False):
+                    results.append(MappingResult(ensembl_id=ensembl_id))
+                    unmapped_ids.append(ensembl_id)
+                    continue
+
+                # Extract HGNC symbol
+                hgnc_symbol = hit.get('symbol')
+
+                # Extract UniProt accession (handle nested structure and lists)
+                uniprot_accession = None
+                uniprot_data = hit.get('uniprot')
+
+                if uniprot_data:
+                    # uniprot can be a dict with Swiss-Prot key
+                    if isinstance(uniprot_data, dict):
+                        swiss_prot = uniprot_data.get('Swiss-Prot')
+                        # Swiss-Prot can be a string or list
+                        if isinstance(swiss_prot, str):
+                            uniprot_accession = swiss_prot
+                        elif isinstance(swiss_prot, list) and swiss_prot:
+                            # Take first accession if list
+                            uniprot_accession = swiss_prot[0]
+
+                # Create mapping result
+                results.append(MappingResult(
+                    ensembl_id=ensembl_id,
+                    hgnc_symbol=hgnc_symbol,
+                    uniprot_accession=uniprot_accession,
+                ))
+
+                # Track success counts
+                if hgnc_symbol:
+                    mapped_hgnc += 1
+                else:
+                    unmapped_ids.append(ensembl_id)
+
+                if uniprot_accession:
+                    mapped_uniprot += 1
+
+        # Create summary report
+        report = MappingReport(
+            total_genes=total_genes,
+            mapped_hgnc=mapped_hgnc,
+            mapped_uniprot=mapped_uniprot,
+            unmapped_ids=unmapped_ids,
+        )
+
+        logger.info(
+            f"Mapping complete: {mapped_hgnc}/{total_genes} HGNC "
+            f"({report.success_rate_hgnc:.1%}), "
+            f"{mapped_uniprot}/{total_genes} UniProt "
+            f"({report.success_rate_uniprot:.1%})"
+        )
+
+        return results, report
diff --git a/src/usher_pipeline/gene_mapping/universe.py b/src/usher_pipeline/gene_mapping/universe.py
new file mode 100644
index 0000000..1d21a9c
--- /dev/null
+++ b/src/usher_pipeline/gene_mapping/universe.py
@@ -0,0 +1,112 @@
+"""Gene universe definition and retrieval.
+
+Fetches the complete set of human protein-coding genes from Ensembl via mygene.
+Validates gene count and filters to ENSG-format Ensembl IDs only.
+"""
+
+import logging
+from typing import TypeAlias
+
+import mygene
+
+# Type alias for gene universe lists
+GeneUniverse: TypeAlias = list[str]
+
+logger = logging.getLogger(__name__)
+
+# Expected range for human protein-coding genes
+MIN_EXPECTED_GENES = 19000
+MAX_EXPECTED_GENES = 22000
+
+
+def fetch_protein_coding_genes(ensembl_release: int = 113) -> GeneUniverse:
+    """Fetch all human protein-coding genes from Ensembl via mygene.
+
+    Queries mygene for genes with type_of_gene=protein-coding in humans (taxid 9606).
+    Filters to only genes with valid Ensembl gene IDs (ENSG format).
+    Validates gene count is in expected range (19,000-22,000).
+
+    Args:
+        ensembl_release: Ensembl release version (for documentation purposes;
+            mygene returns current data regardless)
+
+    Returns:
+        Sorted, deduplicated list of Ensembl gene IDs (ENSG format)
+
+    Warning:
+        Out-of-range gene counts are logged as warnings rather than raised
+
+    Note:
+        While ensembl_release is passed for documentation, mygene API doesn't
+        support querying specific Ensembl versions - it returns current data.
+        For reproducibility, use cached results or versioned data snapshots.
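+
+    Example (illustrative; requires network access):
+        >>> universe = fetch_protein_coding_genes()  # doctest: +SKIP
+        >>> all(g.startswith('ENSG') for g in universe)  # doctest: +SKIP
+        True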
+ """ + logger.info( + f"Fetching protein-coding genes for Ensembl release {ensembl_release} " + "(note: mygene returns current data)" + ) + + # Initialize mygene client + mg = mygene.MyGeneInfo() + + # Query for human protein-coding genes + logger.info("Querying mygene for type_of_gene:protein-coding (species=9606)") + results = mg.query( + 'type_of_gene:"protein-coding"', + species=9606, + fields='ensembl.gene,symbol,name', + fetch_all=True, + ) + + logger.info(f"Retrieved {len(results)} results from mygene") + + # Extract Ensembl gene IDs + gene_ids: set[str] = set() + + for hit in results: + # Handle both single ensembl.gene and list cases + ensembl_data = hit.get('ensembl') + if not ensembl_data: + continue + + # ensembl can be a single dict or list of dicts + if isinstance(ensembl_data, dict): + ensembl_list = [ensembl_data] + else: + ensembl_list = ensembl_data + + # Extract gene IDs from each ensembl entry + for ensembl_entry in ensembl_list: + gene_id = ensembl_entry.get('gene') + if gene_id and isinstance(gene_id, str) and gene_id.startswith('ENSG'): + gene_ids.add(gene_id) + + # Sort and deduplicate + sorted_genes = sorted(gene_ids) + gene_count = len(sorted_genes) + + logger.info(f"Extracted {gene_count} unique Ensembl gene IDs (ENSG format)") + + # Validate gene count + if gene_count < MIN_EXPECTED_GENES: + logger.warning( + f"Gene count {gene_count} is below expected minimum {MIN_EXPECTED_GENES}. " + "This may indicate missing data or query issues." + ) + elif gene_count > MAX_EXPECTED_GENES: + logger.warning( + f"Gene count {gene_count} exceeds expected maximum {MAX_EXPECTED_GENES}. " + "This may indicate pseudogene contamination or non-coding genes in results." + ) + else: + logger.info( + f"Gene count {gene_count} is within expected range " + f"({MIN_EXPECTED_GENES}-{MAX_EXPECTED_GENES})" + ) + + return sorted_genes diff --git a/src/usher_pipeline/persistence/__init__.py b/src/usher_pipeline/persistence/__init__.py new file mode 100644 index 0000000..e422ab1 --- /dev/null +++ b/src/usher_pipeline/persistence/__init__.py @@ -0,0 +1,6 @@ +"""Persistence layer for pipeline checkpoints and provenance tracking.""" + +from usher_pipeline.persistence.duckdb_store import PipelineStore + +# ProvenanceTracker will be added in Task 2 +__all__ = ["PipelineStore"] diff --git a/src/usher_pipeline/persistence/duckdb_store.py b/src/usher_pipeline/persistence/duckdb_store.py new file mode 100644 index 0000000..6dffc50 --- /dev/null +++ b/src/usher_pipeline/persistence/duckdb_store.py @@ -0,0 +1,232 @@ +"""DuckDB-based storage for pipeline checkpoints with restart capability.""" + +from pathlib import Path +from typing import Optional, Union + +import duckdb +import polars as pl + +try: + import pandas as pd + HAS_PANDAS = True +except ImportError: + HAS_PANDAS = False + + +class PipelineStore: + """ + DuckDB-based storage for pipeline intermediate results. + + Enables checkpoint-restart pattern: expensive operations (API downloads, + processing) can be saved as DuckDB tables and skipped on subsequent runs. + """ + + def __init__(self, db_path: Path): + """ + Initialize PipelineStore with a DuckDB database. + + Args: + db_path: Path to DuckDB database file. Parent directories + are created automatically. 
+ """ + self.db_path = db_path + # Create parent directories + db_path.parent.mkdir(parents=True, exist_ok=True) + + # Connect to DuckDB + self.conn = duckdb.connect(str(db_path)) + + # Create metadata table for tracking checkpoints + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS _checkpoints ( + table_name VARCHAR PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + row_count INTEGER, + description VARCHAR + ) + """) + + def save_dataframe( + self, + df: Union[pl.DataFrame, "pd.DataFrame"], + table_name: str, + description: str = "", + replace: bool = True + ) -> None: + """ + Save a DataFrame to DuckDB as a table. + + Args: + df: Polars or pandas DataFrame to save + table_name: Name for the DuckDB table + description: Optional description for checkpoint metadata + replace: If True, replace existing table; if False, append + """ + # Detect DataFrame type + is_polars = isinstance(df, pl.DataFrame) + if not is_polars and not HAS_PANDAS: + raise ValueError("pandas not available") + if not is_polars and not isinstance(df, pd.DataFrame): + raise ValueError("df must be polars.DataFrame or pandas.DataFrame") + + # Save DataFrame + if replace: + self.conn.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df") + else: + self.conn.execute(f"INSERT INTO {table_name} SELECT * FROM df") + + # Update checkpoint metadata + row_count = len(df) + self.conn.execute(""" + INSERT OR REPLACE INTO _checkpoints (table_name, row_count, description, created_at) + VALUES (?, ?, ?, CURRENT_TIMESTAMP) + """, [table_name, row_count, description]) + + def load_dataframe( + self, + table_name: str, + as_polars: bool = True + ) -> Optional[Union[pl.DataFrame, "pd.DataFrame"]]: + """ + Load a table as a DataFrame. + + Args: + table_name: Name of the DuckDB table + as_polars: If True, return polars DataFrame; else pandas + + Returns: + DataFrame or None if table doesn't exist + """ + try: + result = self.conn.execute(f"SELECT * FROM {table_name}") + if as_polars: + return result.pl() + else: + if not HAS_PANDAS: + raise ValueError("pandas not available") + return result.df() + except duckdb.CatalogException: + # Table doesn't exist + return None + + def has_checkpoint(self, table_name: str) -> bool: + """ + Check if a checkpoint exists. + + Args: + table_name: Name of the table to check + + Returns: + True if checkpoint exists, False otherwise + """ + result = self.conn.execute( + "SELECT COUNT(*) FROM _checkpoints WHERE table_name = ?", + [table_name] + ).fetchone() + return result[0] > 0 + + def list_checkpoints(self) -> list[dict]: + """ + List all checkpoints with metadata. + + Returns: + List of checkpoint metadata dicts with keys: + table_name, created_at, row_count, description + """ + result = self.conn.execute(""" + SELECT table_name, created_at, row_count, description + FROM _checkpoints + ORDER BY created_at DESC + """).fetchall() + + return [ + { + "table_name": row[0], + "created_at": row[1], + "row_count": row[2], + "description": row[3], + } + for row in result + ] + + def delete_checkpoint(self, table_name: str) -> None: + """ + Delete a checkpoint and its metadata. + + Args: + table_name: Name of the table to delete + """ + # Drop table if exists + self.conn.execute(f"DROP TABLE IF EXISTS {table_name}") + + # Remove from metadata + self.conn.execute( + "DELETE FROM _checkpoints WHERE table_name = ?", + [table_name] + ) + + def export_parquet(self, table_name: str, output_path: Path) -> None: + """ + Export a table to Parquet format. 
+
+        Args:
+            table_name: Name of the table to export
+            output_path: Path to output Parquet file
+        """
+        # Create parent directories
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # COPY cannot take the target path as a bound parameter, so escape and inline it
+        escaped_path = str(output_path).replace("'", "''")
+        self.conn.execute(
+            f"COPY {table_name} TO '{escaped_path}' (FORMAT PARQUET)"
+        )
+
+    def execute_query(
+        self,
+        query: str,
+        params: Optional[list] = None
+    ) -> pl.DataFrame:
+        """
+        Execute arbitrary SQL query and return polars DataFrame.
+
+        Args:
+            query: SQL query to execute
+            params: Optional query parameters
+
+        Returns:
+            Query results as polars DataFrame
+        """
+        if params:
+            result = self.conn.execute(query, params)
+        else:
+            result = self.conn.execute(query)
+        return result.pl()
+
+    def close(self) -> None:
+        """Close the DuckDB connection."""
+        if self.conn:
+            self.conn.close()
+            self.conn = None
+
+    def __enter__(self):
+        """Context manager entry."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit - closes connection."""
+        self.close()
+        return False
+
+    @classmethod
+    def from_config(cls, config: "PipelineConfig") -> "PipelineStore":
+        """
+        Create PipelineStore from a PipelineConfig.
+
+        Args:
+            config: PipelineConfig instance
+
+        Returns:
+            PipelineStore instance
+        """
+        return cls(config.duckdb_path)
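
Usage sketch (illustrative, not part of the diff; paths and table names are example values): how the checkpoint-restart pattern ties the two new modules together.

    from pathlib import Path

    import polars as pl

    from usher_pipeline.gene_mapping import GeneMapper, fetch_protein_coding_genes
    from usher_pipeline.persistence import PipelineStore

    with PipelineStore(Path("data/pipeline.duckdb")) as store:
        # Skip the expensive mygene queries when a checkpoint already exists
        if store.has_checkpoint("gene_mappings"):
            mappings = store.load_dataframe("gene_mappings")
        else:
            universe = fetch_protein_coding_genes()
            results, report = GeneMapper(batch_size=1000).map_ensembl_ids(universe)
            # MappingResult is a plain dataclass, so vars() gives row dicts
            mappings = pl.DataFrame([vars(r) for r in results])
            store.save_dataframe(
                mappings,
                "gene_mappings",
                description=f"HGNC rate {report.success_rate_hgnc:.1%}",
            )
        store.export_parquet("gene_mappings", Path("exports/gene_mappings.parquet"))

has_checkpoint gates the network-bound mapping step, so a restarted run pays only for a local DuckDB read.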