feat(01-03): create DuckDB persistence layer with checkpoint-restart
- PipelineStore class for DuckDB-based storage
- save_dataframe/load_dataframe for polars and pandas
- Checkpoint system with has_checkpoint and metadata tracking
- Parquet export capability
- Context manager support
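Usage sketch (illustrative only, not part of the diff below; the database path, table name, and sample data are assumptions):

from pathlib import Path
import polars as pl

from usher_pipeline.persistence import PipelineStore

# Hypothetical path and table name, used only to illustrate the API added in this commit.
with PipelineStore(Path("data/pipeline.duckdb")) as store:
    genes = pl.DataFrame({"ensembl_id": ["ENSG00000139618"], "symbol": ["BRCA2"]})

    # Persist a checkpoint with metadata, then read it back as polars.
    store.save_dataframe(genes, "gene_universe", description="demo checkpoint")
    restored = store.load_dataframe("gene_universe")

    # Export the same table for downstream tools.
    store.export_parquet("gene_universe", Path("data/exports/gene_universe.parquet"))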
src/usher_pipeline/gene_mapping/__init__.py (new file, 23 lines added)
@@ -0,0 +1,23 @@
"""Gene ID mapping module.

Provides gene universe definition, batch ID mapping via mygene,
and validation gates for quality control.
"""

from usher_pipeline.gene_mapping.mapper import (
    GeneMapper,
    MappingResult,
    MappingReport,
)
from usher_pipeline.gene_mapping.universe import (
    fetch_protein_coding_genes,
    GeneUniverse,
)

__all__ = [
    "GeneMapper",
    "MappingResult",
    "MappingReport",
    "fetch_protein_coding_genes",
    "GeneUniverse",
]
src/usher_pipeline/gene_mapping/mapper.py (new file, 189 lines added)
@@ -0,0 +1,189 @@
"""Gene ID mapping via mygene batch queries.

Provides batch mapping from Ensembl gene IDs to HGNC symbols and UniProt accessions.
Handles edge cases like missing data, notfound results, and nested data structures.
"""

import logging
from dataclasses import dataclass, field
from typing import Any

import mygene

logger = logging.getLogger(__name__)


@dataclass
class MappingResult:
    """Single gene ID mapping result.

    Attributes:
        ensembl_id: Ensembl gene ID (ENSG format)
        hgnc_symbol: HGNC gene symbol (None if not found)
        uniprot_accession: UniProt Swiss-Prot accession (None if not found)
        mapping_source: Data source for mapping (default: mygene)
    """
    ensembl_id: str
    hgnc_symbol: str | None = None
    uniprot_accession: str | None = None
    mapping_source: str = "mygene"


@dataclass
class MappingReport:
    """Summary report for batch mapping operation.

    Attributes:
        total_genes: Total number of genes queried
        mapped_hgnc: Number of genes with HGNC symbol
        mapped_uniprot: Number of genes with UniProt accession
        unmapped_ids: List of Ensembl IDs with no HGNC symbol
        success_rate_hgnc: Fraction of genes with HGNC symbol (0-1)
        success_rate_uniprot: Fraction of genes with UniProt accession (0-1)
    """
    total_genes: int
    mapped_hgnc: int
    mapped_uniprot: int
    unmapped_ids: list[str] = field(default_factory=list)
    success_rate_hgnc: float = 0.0
    success_rate_uniprot: float = 0.0

    def __post_init__(self):
        """Calculate success rates after initialization."""
        if self.total_genes > 0:
            self.success_rate_hgnc = self.mapped_hgnc / self.total_genes
            self.success_rate_uniprot = self.mapped_uniprot / self.total_genes


class GeneMapper:
    """Batch gene ID mapper using mygene API.

    Maps Ensembl gene IDs to HGNC symbols and UniProt Swiss-Prot accessions.
    Handles batch queries, missing data, and edge cases.
    """

    def __init__(self, batch_size: int = 1000):
        """Initialize gene mapper.

        Args:
            batch_size: Number of genes to query per batch (default: 1000)
        """
        self.batch_size = batch_size
        self.mg = mygene.MyGeneInfo()
        logger.info(f"Initialized GeneMapper with batch_size={batch_size}")

    def map_ensembl_ids(
        self,
        ensembl_ids: list[str]
    ) -> tuple[list[MappingResult], MappingReport]:
        """Map Ensembl gene IDs to HGNC symbols and UniProt accessions.

        Uses mygene.querymany() to batch query for gene symbols and UniProt IDs.
        Processes queries in chunks of batch_size to avoid API limits.

        Args:
            ensembl_ids: List of Ensembl gene IDs (ENSG format)

        Returns:
            Tuple of (mapping_results, mapping_report)
            - mapping_results: List of MappingResult for each input gene
            - mapping_report: Summary statistics for the mapping operation

        Notes:
            - Queries mygene with scopes='ensembl.gene'
            - Retrieves fields: symbol (HGNC), uniprot.Swiss-Prot
            - Handles 'notfound' results, missing keys, and nested structures
            - For duplicate query results, takes first non-null value
        """
        total_genes = len(ensembl_ids)
        logger.info(f"Mapping {total_genes} Ensembl IDs to HGNC/UniProt")

        results: list[MappingResult] = []
        unmapped_ids: list[str] = []
        mapped_hgnc = 0
        mapped_uniprot = 0

        # Process in batches
        for i in range(0, total_genes, self.batch_size):
            batch = ensembl_ids[i:i + self.batch_size]
            batch_num = i // self.batch_size + 1
            total_batches = (total_genes + self.batch_size - 1) // self.batch_size

            logger.info(
                f"Processing batch {batch_num}/{total_batches} "
                f"({len(batch)} genes)"
            )

            # Query mygene
            batch_results = self.mg.querymany(
                batch,
                scopes='ensembl.gene',
                fields='symbol,uniprot.Swiss-Prot',
                species=9606,
                returnall=True,
            )

            # Extract results from returnall=True format
            # mygene returns {'out': [...], 'missing': [...]} with returnall=True
            out_results = batch_results.get('out', [])

            # Process each result
            for hit in out_results:
                ensembl_id = hit.get('query', '')

                # Check if gene was not found
                if hit.get('notfound', False):
                    results.append(MappingResult(ensembl_id=ensembl_id))
                    unmapped_ids.append(ensembl_id)
                    continue

                # Extract HGNC symbol
                hgnc_symbol = hit.get('symbol')

                # Extract UniProt accession (handle nested structure and lists)
                uniprot_accession = None
                uniprot_data = hit.get('uniprot')

                if uniprot_data:
                    # uniprot can be a dict with Swiss-Prot key
                    if isinstance(uniprot_data, dict):
                        swiss_prot = uniprot_data.get('Swiss-Prot')
                        # Swiss-Prot can be a string or list
                        if isinstance(swiss_prot, str):
                            uniprot_accession = swiss_prot
                        elif isinstance(swiss_prot, list) and swiss_prot:
                            # Take first accession if list
                            uniprot_accession = swiss_prot[0]

                # Create mapping result
                results.append(MappingResult(
                    ensembl_id=ensembl_id,
                    hgnc_symbol=hgnc_symbol,
                    uniprot_accession=uniprot_accession,
                ))

                # Track success counts
                if hgnc_symbol:
                    mapped_hgnc += 1
                else:
                    unmapped_ids.append(ensembl_id)

                if uniprot_accession:
                    mapped_uniprot += 1

        # Create summary report
        report = MappingReport(
            total_genes=total_genes,
            mapped_hgnc=mapped_hgnc,
            mapped_uniprot=mapped_uniprot,
            unmapped_ids=unmapped_ids,
        )

        logger.info(
            f"Mapping complete: {mapped_hgnc}/{total_genes} HGNC "
            f"({report.success_rate_hgnc:.1%}), "
            f"{mapped_uniprot}/{total_genes} UniProt "
            f"({report.success_rate_uniprot:.1%})"
        )

        return results, report
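A minimal sketch of how the mapper is meant to be called, assuming network access to mygene.info; the example IDs are illustrative:

from usher_pipeline.gene_mapping import GeneMapper

# Illustrative IDs only; a real run would pass the full gene universe.
ids = ["ENSG00000139618", "ENSG00000141510"]

mapper = GeneMapper(batch_size=1000)
results, report = mapper.map_ensembl_ids(ids)  # performs live mygene.info queries

for r in results:
    print(r.ensembl_id, r.hgnc_symbol, r.uniprot_accession)

print(f"HGNC coverage: {report.success_rate_hgnc:.1%}, unmapped: {report.unmapped_ids}")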
src/usher_pipeline/gene_mapping/universe.py (new file, 107 lines added)
@@ -0,0 +1,107 @@
"""Gene universe definition and retrieval.

Fetches the complete set of human protein-coding genes from Ensembl via mygene.
Validates gene count and filters to ENSG-format Ensembl IDs only.
"""

import logging
from typing import TypeAlias

import mygene

# Type alias for gene universe lists
GeneUniverse: TypeAlias = list[str]

logger = logging.getLogger(__name__)

# Expected range for human protein-coding genes
MIN_EXPECTED_GENES = 19000
MAX_EXPECTED_GENES = 22000


def fetch_protein_coding_genes(ensembl_release: int = 113) -> GeneUniverse:
    """Fetch all human protein-coding genes from Ensembl via mygene.

    Queries mygene for genes with type_of_gene=protein-coding in humans (taxid 9606).
    Filters to only genes with valid Ensembl gene IDs (ENSG format).
    Validates gene count is in expected range (19,000-22,000).

    Args:
        ensembl_release: Ensembl release version (for documentation purposes;
            mygene returns current data regardless)

    Returns:
        Sorted, deduplicated list of Ensembl gene IDs (ENSG format)

    Note:
        Gene counts outside the expected range are logged as warnings rather
        than raised as errors.
        While ensembl_release is passed for documentation, mygene API doesn't
        support querying specific Ensembl versions - it returns current data.
        For reproducibility, use cached results or versioned data snapshots.
    """
    logger.info(
        f"Fetching protein-coding genes for Ensembl release {ensembl_release} "
        "(note: mygene returns current data)"
    )

    # Initialize mygene client
    mg = mygene.MyGeneInfo()

    # Query for human protein-coding genes
    logger.info("Querying mygene for type_of_gene:protein-coding (species=9606)")
    # fetch_all=True makes mygene return a generator (scroll API), so materialize
    # it before counting and iterating
    results = list(
        mg.query(
            'type_of_gene:"protein-coding"',
            species=9606,
            fields='ensembl.gene,symbol,name',
            fetch_all=True,
        )
    )

    logger.info(f"Retrieved {len(results)} results from mygene")

    # Extract Ensembl gene IDs
    gene_ids: set[str] = set()

    for hit in results:
        # Handle both single ensembl.gene and list cases
        ensembl_data = hit.get('ensembl')
        if not ensembl_data:
            continue

        # ensembl can be a single dict or list of dicts
        if isinstance(ensembl_data, dict):
            ensembl_list = [ensembl_data]
        else:
            ensembl_list = ensembl_data

        # Extract gene IDs from each ensembl entry
        for ensembl_entry in ensembl_list:
            gene_id = ensembl_entry.get('gene')
            if gene_id and isinstance(gene_id, str) and gene_id.startswith('ENSG'):
                gene_ids.add(gene_id)

    # Sort and deduplicate
    sorted_genes = sorted(gene_ids)
    gene_count = len(sorted_genes)

    logger.info(f"Extracted {gene_count} unique Ensembl gene IDs (ENSG format)")

    # Validate gene count
    if gene_count < MIN_EXPECTED_GENES:
        logger.warning(
            f"Gene count {gene_count} is below expected minimum {MIN_EXPECTED_GENES}. "
            "This may indicate missing data or query issues."
        )
    elif gene_count > MAX_EXPECTED_GENES:
        logger.warning(
            f"Gene count {gene_count} exceeds expected maximum {MAX_EXPECTED_GENES}. "
            "This may indicate pseudogene contamination or non-coding genes in results."
        )
    else:
        logger.info(
            f"Gene count {gene_count} is within expected range "
            f"({MIN_EXPECTED_GENES}-{MAX_EXPECTED_GENES})"
        )

    return sorted_genes
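A sketch of the caching approach the docstring recommends, since mygene always reflects current annotation; the snapshot path is a hypothetical choice:

from pathlib import Path

import polars as pl

from usher_pipeline.gene_mapping import fetch_protein_coding_genes

# mygene returns current data, so snapshot the universe for reproducibility.
universe = fetch_protein_coding_genes(ensembl_release=113)
print(f"{len(universe)} protein-coding genes")

# Hypothetical snapshot location; any versioned path works.
snapshot = Path("data/snapshots/gene_universe_e113.parquet")
snapshot.parent.mkdir(parents=True, exist_ok=True)
pl.DataFrame({"ensembl_id": universe}).write_parquet(snapshot)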
src/usher_pipeline/persistence/__init__.py (new file, 6 lines added)
@@ -0,0 +1,6 @@
"""Persistence layer for pipeline checkpoints and provenance tracking."""

from usher_pipeline.persistence.duckdb_store import PipelineStore

# ProvenanceTracker will be added in Task 2
__all__ = ["PipelineStore"]
src/usher_pipeline/persistence/duckdb_store.py (new file, 232 lines added)
@@ -0,0 +1,232 @@
"""DuckDB-based storage for pipeline checkpoints with restart capability."""

from pathlib import Path
from typing import Optional, Union

import duckdb
import polars as pl

try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False


class PipelineStore:
    """
    DuckDB-based storage for pipeline intermediate results.

    Enables checkpoint-restart pattern: expensive operations (API downloads,
    processing) can be saved as DuckDB tables and skipped on subsequent runs.
    """

    def __init__(self, db_path: Path):
        """
        Initialize PipelineStore with a DuckDB database.

        Args:
            db_path: Path to DuckDB database file. Parent directories
                are created automatically.
        """
        self.db_path = db_path
        # Create parent directories
        db_path.parent.mkdir(parents=True, exist_ok=True)

        # Connect to DuckDB
        self.conn = duckdb.connect(str(db_path))

        # Create metadata table for tracking checkpoints
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS _checkpoints (
                table_name VARCHAR PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                row_count INTEGER,
                description VARCHAR
            )
        """)

    def save_dataframe(
        self,
        df: Union[pl.DataFrame, "pd.DataFrame"],
        table_name: str,
        description: str = "",
        replace: bool = True
    ) -> None:
        """
        Save a DataFrame to DuckDB as a table.

        Args:
            df: Polars or pandas DataFrame to save
            table_name: Name for the DuckDB table
            description: Optional description for checkpoint metadata
            replace: If True, replace existing table; if False, append
        """
        # Detect DataFrame type
        is_polars = isinstance(df, pl.DataFrame)
        if not is_polars and not HAS_PANDAS:
            raise ValueError("pandas not available")
        if not is_polars and not isinstance(df, pd.DataFrame):
            raise ValueError("df must be polars.DataFrame or pandas.DataFrame")

        # Save DataFrame (DuckDB's replacement scan resolves the local `df` variable)
        if replace:
            self.conn.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
        else:
            self.conn.execute(f"INSERT INTO {table_name} SELECT * FROM df")

        # Update checkpoint metadata
        row_count = len(df)
        self.conn.execute("""
            INSERT OR REPLACE INTO _checkpoints (table_name, row_count, description, created_at)
            VALUES (?, ?, ?, CURRENT_TIMESTAMP)
        """, [table_name, row_count, description])

    def load_dataframe(
        self,
        table_name: str,
        as_polars: bool = True
    ) -> Optional[Union[pl.DataFrame, "pd.DataFrame"]]:
        """
        Load a table as a DataFrame.

        Args:
            table_name: Name of the DuckDB table
            as_polars: If True, return polars DataFrame; else pandas

        Returns:
            DataFrame or None if table doesn't exist
        """
        try:
            result = self.conn.execute(f"SELECT * FROM {table_name}")
            if as_polars:
                return result.pl()
            else:
                if not HAS_PANDAS:
                    raise ValueError("pandas not available")
                return result.df()
        except duckdb.CatalogException:
            # Table doesn't exist
            return None

    def has_checkpoint(self, table_name: str) -> bool:
        """
        Check if a checkpoint exists.

        Args:
            table_name: Name of the table to check

        Returns:
            True if checkpoint exists, False otherwise
        """
        result = self.conn.execute(
            "SELECT COUNT(*) FROM _checkpoints WHERE table_name = ?",
            [table_name]
        ).fetchone()
        return result[0] > 0

    def list_checkpoints(self) -> list[dict]:
        """
        List all checkpoints with metadata.

        Returns:
            List of checkpoint metadata dicts with keys:
            table_name, created_at, row_count, description
        """
        result = self.conn.execute("""
            SELECT table_name, created_at, row_count, description
            FROM _checkpoints
            ORDER BY created_at DESC
        """).fetchall()

        return [
            {
                "table_name": row[0],
                "created_at": row[1],
                "row_count": row[2],
                "description": row[3],
            }
            for row in result
        ]

    def delete_checkpoint(self, table_name: str) -> None:
        """
        Delete a checkpoint and its metadata.

        Args:
            table_name: Name of the table to delete
        """
        # Drop table if exists
        self.conn.execute(f"DROP TABLE IF EXISTS {table_name}")

        # Remove from metadata
        self.conn.execute(
            "DELETE FROM _checkpoints WHERE table_name = ?",
            [table_name]
        )

    def export_parquet(self, table_name: str, output_path: Path) -> None:
        """
        Export a table to Parquet format.

        Args:
            table_name: Name of the table to export
            output_path: Path to output Parquet file
        """
        # Create parent directories
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Export using DuckDB's native Parquet writer. COPY targets cannot be
        # passed as prepared-statement parameters, so embed the path as an
        # escaped SQL string literal.
        escaped_path = str(output_path).replace("'", "''")
        self.conn.execute(
            f"COPY {table_name} TO '{escaped_path}' (FORMAT PARQUET)"
        )

    def execute_query(
        self,
        query: str,
        params: Optional[list] = None
    ) -> pl.DataFrame:
        """
        Execute arbitrary SQL query and return polars DataFrame.

        Args:
            query: SQL query to execute
            params: Optional query parameters

        Returns:
            Query results as polars DataFrame
        """
        if params:
            result = self.conn.execute(query, params)
        else:
            result = self.conn.execute(query)
        return result.pl()

    def close(self) -> None:
        """Close the DuckDB connection."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - closes connection."""
        self.close()
        return False

    @classmethod
    def from_config(cls, config: "PipelineConfig") -> "PipelineStore":
        """
        Create PipelineStore from a PipelineConfig.

        Args:
            config: PipelineConfig instance

        Returns:
            PipelineStore instance
        """
        return cls(config.duckdb_path)
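A sketch of the checkpoint-restart flow tying the store to the gene-mapping step; the database path, table name, and dataclass-to-DataFrame conversion are illustrative assumptions, not part of the commit:

from dataclasses import asdict
from pathlib import Path

import polars as pl

from usher_pipeline.gene_mapping import GeneMapper, fetch_protein_coding_genes
from usher_pipeline.persistence import PipelineStore

# Hypothetical database path and table name.
with PipelineStore(Path("data/pipeline.duckdb")) as store:
    if store.has_checkpoint("gene_id_map"):
        # Restart path: reuse the saved table and skip the API calls.
        id_map = store.load_dataframe("gene_id_map")
    else:
        # First run: do the expensive work once, then checkpoint it.
        universe = fetch_protein_coding_genes()
        results, report = GeneMapper().map_ensembl_ids(universe)
        id_map = pl.DataFrame([asdict(r) for r in results])
        store.save_dataframe(
            id_map,
            "gene_id_map",
            description=f"HGNC coverage {report.success_rate_hgnc:.1%}",
        )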