feat(01-03): create DuckDB persistence layer with checkpoint-restart

- PipelineStore class for DuckDB-based storage
- save_dataframe/load_dataframe for polars and pandas
- Checkpoint system with has_checkpoint and metadata tracking
- Parquet export capability
- Context manager support
2026-02-11 16:30:25 +08:00
parent 9ee3ec2e84
commit d51141f7d5
5 changed files with 557 additions and 0 deletions
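For orientation, here is a minimal sketch of the checkpoint-restart pattern this commit introduces. The database path, table name, and the expensive build step are illustrative, not part of the commit:

from pathlib import Path

from usher_pipeline.persistence import PipelineStore

# Illustrative path and table name; build_gene_universe_df is a
# hypothetical expensive step (API download, heavy processing).
with PipelineStore(Path("data/pipeline.duckdb")) as store:
    if store.has_checkpoint("gene_universe"):
        genes = store.load_dataframe("gene_universe")  # restart: skip the work
    else:
        genes = build_gene_universe_df()
        store.save_dataframe(genes, "gene_universe", description="Ensembl gene IDs")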

usher_pipeline/gene_mapping/__init__.py (View File)

@@ -0,0 +1,23 @@
"""Gene ID mapping module.
Provides gene universe definition, batch ID mapping via mygene,
and validation gates for quality control.
"""
from usher_pipeline.gene_mapping.mapper import (
GeneMapper,
MappingResult,
MappingReport,
)
from usher_pipeline.gene_mapping.universe import (
fetch_protein_coding_genes,
GeneUniverse,
)
__all__ = [
"GeneMapper",
"MappingResult",
"MappingReport",
"fetch_protein_coding_genes",
"GeneUniverse",
]

usher_pipeline/gene_mapping/mapper.py (View File)

@@ -0,0 +1,189 @@
"""Gene ID mapping via mygene batch queries.
Provides batch mapping from Ensembl gene IDs to HGNC symbols and UniProt accessions.
Handles edge cases like missing data, notfound results, and nested data structures.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
import mygene
logger = logging.getLogger(__name__)
@dataclass
class MappingResult:
"""Single gene ID mapping result.
Attributes:
ensembl_id: Ensembl gene ID (ENSG format)
hgnc_symbol: HGNC gene symbol (None if not found)
uniprot_accession: UniProt Swiss-Prot accession (None if not found)
mapping_source: Data source for mapping (default: mygene)
"""
ensembl_id: str
hgnc_symbol: str | None = None
uniprot_accession: str | None = None
mapping_source: str = "mygene"
@dataclass
class MappingReport:
"""Summary report for batch mapping operation.
Attributes:
total_genes: Total number of genes queried
mapped_hgnc: Number of genes with HGNC symbol
mapped_uniprot: Number of genes with UniProt accession
unmapped_ids: List of Ensembl IDs with no HGNC symbol
success_rate_hgnc: Fraction of genes with HGNC symbol (0-1)
success_rate_uniprot: Fraction of genes with UniProt accession (0-1)
"""
total_genes: int
mapped_hgnc: int
mapped_uniprot: int
unmapped_ids: list[str] = field(default_factory=list)
success_rate_hgnc: float = 0.0
success_rate_uniprot: float = 0.0
def __post_init__(self):
"""Calculate success rates after initialization."""
if self.total_genes > 0:
self.success_rate_hgnc = self.mapped_hgnc / self.total_genes
self.success_rate_uniprot = self.mapped_uniprot / self.total_genes
class GeneMapper:
"""Batch gene ID mapper using mygene API.
Maps Ensembl gene IDs to HGNC symbols and UniProt Swiss-Prot accessions.
Handles batch queries, missing data, and edge cases.
"""
def __init__(self, batch_size: int = 1000):
"""Initialize gene mapper.
Args:
batch_size: Number of genes to query per batch (default: 1000)
"""
self.batch_size = batch_size
self.mg = mygene.MyGeneInfo()
logger.info(f"Initialized GeneMapper with batch_size={batch_size}")
    def map_ensembl_ids(
        self,
        ensembl_ids: list[str],
    ) -> tuple[list[MappingResult], MappingReport]:
        """Map Ensembl gene IDs to HGNC symbols and UniProt accessions.

        Uses mygene.querymany() to batch query for gene symbols and UniProt IDs.
        Processes queries in chunks of batch_size to avoid API limits.

        Args:
            ensembl_ids: List of Ensembl gene IDs (ENSG format)

        Returns:
            Tuple of (mapping_results, mapping_report)
            - mapping_results: List of MappingResult for each input gene
            - mapping_report: Summary statistics for the mapping operation

        Notes:
            - Queries mygene with scopes='ensembl.gene'
            - Retrieves fields: symbol (HGNC), uniprot.Swiss-Prot
            - Handles 'notfound' results, missing keys, and nested structures
            - For duplicate query results, keeps the first hit
        """
        total_genes = len(ensembl_ids)
        logger.info(f"Mapping {total_genes} Ensembl IDs to HGNC/UniProt")

        results: list[MappingResult] = []
        unmapped_ids: list[str] = []
        seen: set[str] = set()
        mapped_hgnc = 0
        mapped_uniprot = 0

        # Process in batches
        for i in range(0, total_genes, self.batch_size):
            batch = ensembl_ids[i:i + self.batch_size]
            batch_num = i // self.batch_size + 1
            total_batches = (total_genes + self.batch_size - 1) // self.batch_size
            logger.info(
                f"Processing batch {batch_num}/{total_batches} "
                f"({len(batch)} genes)"
            )

            # Query mygene
            batch_results = self.mg.querymany(
                batch,
                scopes='ensembl.gene',
                fields='symbol,uniprot.Swiss-Prot',
                species=9606,
                returnall=True,
            )

            # With returnall=True, mygene returns a dict like
            # {'out': [...], 'dup': [...], 'missing': [...]}
            out_results = batch_results.get('out', [])

            # Process each result
            for hit in out_results:
                ensembl_id = hit.get('query', '')

                # mygene may return multiple hits per query; keep the first
                if ensembl_id in seen:
                    continue
                seen.add(ensembl_id)

                # Check if gene was not found
                if hit.get('notfound', False):
                    results.append(MappingResult(ensembl_id=ensembl_id))
                    unmapped_ids.append(ensembl_id)
                    continue

                # Extract HGNC symbol
                hgnc_symbol = hit.get('symbol')

                # Extract UniProt accession (handle nested structure and lists)
                uniprot_accession = None
                uniprot_data = hit.get('uniprot')
                if uniprot_data:
                    # uniprot can be a dict with a Swiss-Prot key
                    if isinstance(uniprot_data, dict):
                        swiss_prot = uniprot_data.get('Swiss-Prot')
                        # Swiss-Prot can be a string or a list
                        if isinstance(swiss_prot, str):
                            uniprot_accession = swiss_prot
                        elif isinstance(swiss_prot, list) and swiss_prot:
                            # Take the first accession if it is a list
                            uniprot_accession = swiss_prot[0]

                # Create mapping result
                results.append(MappingResult(
                    ensembl_id=ensembl_id,
                    hgnc_symbol=hgnc_symbol,
                    uniprot_accession=uniprot_accession,
                ))

                # Track success counts
                if hgnc_symbol:
                    mapped_hgnc += 1
                else:
                    unmapped_ids.append(ensembl_id)
                if uniprot_accession:
                    mapped_uniprot += 1

        # Create summary report
        report = MappingReport(
            total_genes=total_genes,
            mapped_hgnc=mapped_hgnc,
            mapped_uniprot=mapped_uniprot,
            unmapped_ids=unmapped_ids,
        )
        logger.info(
            f"Mapping complete: {mapped_hgnc}/{total_genes} HGNC "
            f"({report.success_rate_hgnc:.1%}), "
            f"{mapped_uniprot}/{total_genes} UniProt "
            f"({report.success_rate_uniprot:.1%})"
        )
        return results, report
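A usage sketch for GeneMapper follows; it assumes network access to mygene.info, and the two Ensembl IDs (TP53 and BRCA1) are illustrative:

from usher_pipeline.gene_mapping import GeneMapper

# Illustrative IDs; real runs pass the full gene universe
mapper = GeneMapper(batch_size=500)
results, report = mapper.map_ensembl_ids(
    ["ENSG00000141510", "ENSG00000012048"]
)
for r in results:
    print(r.ensembl_id, r.hgnc_symbol, r.uniprot_accession)
print(f"HGNC coverage: {report.success_rate_hgnc:.1%}")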

usher_pipeline/gene_mapping/universe.py (View File)

@@ -0,0 +1,107 @@
"""Gene universe definition and retrieval.
Fetches the complete set of human protein-coding genes from Ensembl via mygene.
Validates gene count and filters to ENSG-format Ensembl IDs only.
"""
import logging
from typing import TypeAlias
import mygene
# Type alias for gene universe lists
GeneUniverse: TypeAlias = list[str]
logger = logging.getLogger(__name__)
# Expected range for human protein-coding genes
MIN_EXPECTED_GENES = 19000
MAX_EXPECTED_GENES = 22000
def fetch_protein_coding_genes(ensembl_release: int = 113) -> GeneUniverse:
"""Fetch all human protein-coding genes from Ensembl via mygene.
Queries mygene for genes with type_of_gene=protein-coding in humans (taxid 9606).
Filters to only genes with valid Ensembl gene IDs (ENSG format).
Validates gene count is in expected range (19,000-22,000).
Args:
ensembl_release: Ensembl release version (for documentation purposes;
mygene returns current data regardless)
Returns:
Sorted, deduplicated list of Ensembl gene IDs (ENSG format)
Raises:
ValueError: If gene count is outside expected range
Note:
While ensembl_release is passed for documentation, mygene API doesn't
support querying specific Ensembl versions - it returns current data.
For reproducibility, use cached results or versioned data snapshots.
"""
logger.info(
f"Fetching protein-coding genes for Ensembl release {ensembl_release} "
"(note: mygene returns current data)"
)
# Initialize mygene client
mg = mygene.MyGeneInfo()
# Query for human protein-coding genes
logger.info("Querying mygene for type_of_gene:protein-coding (species=9606)")
results = mg.query(
'type_of_gene:"protein-coding"',
species=9606,
fields='ensembl.gene,symbol,name',
fetch_all=True,
)
logger.info(f"Retrieved {len(results)} results from mygene")
# Extract Ensembl gene IDs
gene_ids: set[str] = set()
for hit in results:
# Handle both single ensembl.gene and list cases
ensembl_data = hit.get('ensembl')
if not ensembl_data:
continue
# ensembl can be a single dict or list of dicts
if isinstance(ensembl_data, dict):
ensembl_list = [ensembl_data]
else:
ensembl_list = ensembl_data
# Extract gene IDs from each ensembl entry
for ensembl_entry in ensembl_list:
gene_id = ensembl_entry.get('gene')
if gene_id and isinstance(gene_id, str) and gene_id.startswith('ENSG'):
gene_ids.add(gene_id)
# Sort and deduplicate
sorted_genes = sorted(gene_ids)
gene_count = len(sorted_genes)
logger.info(f"Extracted {gene_count} unique Ensembl gene IDs (ENSG format)")
# Validate gene count
if gene_count < MIN_EXPECTED_GENES:
logger.warning(
f"Gene count {gene_count} is below expected minimum {MIN_EXPECTED_GENES}. "
"This may indicate missing data or query issues."
)
elif gene_count > MAX_EXPECTED_GENES:
logger.warning(
f"Gene count {gene_count} exceeds expected maximum {MAX_EXPECTED_GENES}. "
"This may indicate pseudogene contamination or non-coding genes in results."
)
else:
logger.info(
f"Gene count {gene_count} is within expected range "
f"({MIN_EXPECTED_GENES}-{MAX_EXPECTED_GENES})"
)
return sorted_genes
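Because mygene always returns current data, callers that need reproducibility can pin a snapshot on disk, as the docstring suggests. A sketch with a hypothetical cache path:

import json
from pathlib import Path

from usher_pipeline.gene_mapping import fetch_protein_coding_genes

# Hypothetical cache location; any versioned snapshot store works
cache = Path("data/gene_universe_release113.json")
if cache.exists():
    genes = json.loads(cache.read_text())
else:
    genes = fetch_protein_coding_genes(ensembl_release=113)
    cache.parent.mkdir(parents=True, exist_ok=True)
    cache.write_text(json.dumps(genes))
print(f"{len(genes)} protein-coding genes")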

usher_pipeline/persistence/__init__.py (View File)

@@ -0,0 +1,6 @@
"""Persistence layer for pipeline checkpoints and provenance tracking."""
from usher_pipeline.persistence.duckdb_store import PipelineStore
# ProvenanceTracker will be added in Task 2
__all__ = ["PipelineStore"]

usher_pipeline/persistence/duckdb_store.py (View File)

@@ -0,0 +1,232 @@
"""DuckDB-based storage for pipeline checkpoints with restart capability."""
from pathlib import Path
from typing import Optional, Union
import duckdb
import polars as pl
try:
import pandas as pd
HAS_PANDAS = True
except ImportError:
HAS_PANDAS = False
class PipelineStore:
"""
DuckDB-based storage for pipeline intermediate results.
Enables checkpoint-restart pattern: expensive operations (API downloads,
processing) can be saved as DuckDB tables and skipped on subsequent runs.
"""
def __init__(self, db_path: Path):
"""
Initialize PipelineStore with a DuckDB database.
Args:
db_path: Path to DuckDB database file. Parent directories
are created automatically.
"""
self.db_path = db_path
# Create parent directories
db_path.parent.mkdir(parents=True, exist_ok=True)
# Connect to DuckDB
self.conn = duckdb.connect(str(db_path))
# Create metadata table for tracking checkpoints
self.conn.execute("""
CREATE TABLE IF NOT EXISTS _checkpoints (
table_name VARCHAR PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
row_count INTEGER,
description VARCHAR
)
""")
def save_dataframe(
self,
df: Union[pl.DataFrame, "pd.DataFrame"],
table_name: str,
description: str = "",
replace: bool = True
) -> None:
"""
Save a DataFrame to DuckDB as a table.
Args:
df: Polars or pandas DataFrame to save
table_name: Name for the DuckDB table
description: Optional description for checkpoint metadata
replace: If True, replace existing table; if False, append
"""
# Detect DataFrame type
is_polars = isinstance(df, pl.DataFrame)
if not is_polars and not HAS_PANDAS:
raise ValueError("pandas not available")
if not is_polars and not isinstance(df, pd.DataFrame):
raise ValueError("df must be polars.DataFrame or pandas.DataFrame")
# Save DataFrame
if replace:
self.conn.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
else:
self.conn.execute(f"INSERT INTO {table_name} SELECT * FROM df")
# Update checkpoint metadata
row_count = len(df)
self.conn.execute("""
INSERT OR REPLACE INTO _checkpoints (table_name, row_count, description, created_at)
VALUES (?, ?, ?, CURRENT_TIMESTAMP)
""", [table_name, row_count, description])
    def load_dataframe(
        self,
        table_name: str,
        as_polars: bool = True,
    ) -> Optional[Union[pl.DataFrame, "pd.DataFrame"]]:
        """
        Load a table as a DataFrame.

        Args:
            table_name: Name of the DuckDB table
            as_polars: If True, return a polars DataFrame; else pandas

        Returns:
            DataFrame, or None if the table doesn't exist
        """
        try:
            result = self.conn.execute(f"SELECT * FROM {table_name}")
            if as_polars:
                return result.pl()
            else:
                if not HAS_PANDAS:
                    raise ValueError("pandas not available")
                return result.df()
        except duckdb.CatalogException:
            # Table doesn't exist
            return None

    def has_checkpoint(self, table_name: str) -> bool:
        """
        Check if a checkpoint exists.

        Args:
            table_name: Name of the table to check

        Returns:
            True if the checkpoint exists, False otherwise
        """
        result = self.conn.execute(
            "SELECT COUNT(*) FROM _checkpoints WHERE table_name = ?",
            [table_name]
        ).fetchone()
        return result[0] > 0

    def list_checkpoints(self) -> list[dict]:
        """
        List all checkpoints with metadata.

        Returns:
            List of checkpoint metadata dicts with keys:
            table_name, created_at, row_count, description
        """
        result = self.conn.execute("""
            SELECT table_name, created_at, row_count, description
            FROM _checkpoints
            ORDER BY created_at DESC
        """).fetchall()
        return [
            {
                "table_name": row[0],
                "created_at": row[1],
                "row_count": row[2],
                "description": row[3],
            }
            for row in result
        ]

    def delete_checkpoint(self, table_name: str) -> None:
        """
        Delete a checkpoint and its metadata.

        Args:
            table_name: Name of the table to delete
        """
        # Drop the table if it exists
        self.conn.execute(f"DROP TABLE IF EXISTS {table_name}")

        # Remove from metadata
        self.conn.execute(
            "DELETE FROM _checkpoints WHERE table_name = ?",
            [table_name]
        )

    def export_parquet(self, table_name: str, output_path: Path) -> None:
        """
        Export a table to Parquet format.

        Args:
            table_name: Name of the table to export
            output_path: Path to the output Parquet file
        """
        # Create parent directories
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Export using DuckDB's native Parquet writer. COPY does not accept
        # bound parameters, so quote the path literal (escaping single quotes).
        path_literal = str(output_path).replace("'", "''")
        self.conn.execute(
            f"COPY {table_name} TO '{path_literal}' (FORMAT PARQUET)"
        )

    def execute_query(
        self,
        query: str,
        params: Optional[list] = None,
    ) -> pl.DataFrame:
        """
        Execute an arbitrary SQL query and return a polars DataFrame.

        Args:
            query: SQL query to execute
            params: Optional query parameters

        Returns:
            Query results as a polars DataFrame
        """
        if params:
            result = self.conn.execute(query, params)
        else:
            result = self.conn.execute(query)
        return result.pl()

    def close(self) -> None:
        """Close the DuckDB connection."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: closes the connection."""
        self.close()
        return False

    @classmethod
    def from_config(cls, config: "PipelineConfig") -> "PipelineStore":
        """
        Create a PipelineStore from a PipelineConfig.

        Args:
            config: PipelineConfig instance

        Returns:
            PipelineStore instance
        """
        return cls(config.duckdb_path)
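Finally, a sketch of inspecting and exporting checkpoints with the context-manager API; the table name and output path are illustrative:

from pathlib import Path

from usher_pipeline.persistence import PipelineStore

# Illustrative table name and export path
with PipelineStore(Path("data/pipeline.duckdb")) as store:
    for cp in store.list_checkpoints():
        print(cp["table_name"], cp["row_count"], cp["created_at"])
    store.export_parquet("gene_universe", Path("exports/gene_universe.parquet"))
    counts = store.execute_query("SELECT COUNT(*) AS n FROM gene_universe")
    print(counts)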