feat(01-03): create DuckDB persistence layer with checkpoint-restart
- PipelineStore class for DuckDB-based storage
- save_dataframe/load_dataframe for polars and pandas
- Checkpoint system with has_checkpoint and metadata tracking
- Parquet export capability
- Context manager support
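In practice the checkpoint-restart pattern looks roughly like this (a minimal sketch; the paths, table names, and placeholder data are illustrative, not part of this commit):

    from pathlib import Path

    import polars as pl

    from usher_pipeline.persistence import PipelineStore

    with PipelineStore(Path("data/pipeline.duckdb")) as store:
        if store.has_checkpoint("mapped_genes"):
            # Restart path: reuse the saved table, skip the expensive step
            df = store.load_dataframe("mapped_genes")
        else:
            # Stand-in for an expensive API download / processing step
            df = pl.DataFrame({"ensembl_id": ["ENSG00000000001"]})
            store.save_dataframe(df, "mapped_genes", description="gene ID mappings")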
src/usher_pipeline/gene_mapping/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
"""Gene ID mapping module.

Provides gene universe definition, batch ID mapping via mygene,
and validation gates for quality control.
"""

from usher_pipeline.gene_mapping.mapper import (
    GeneMapper,
    MappingResult,
    MappingReport,
)
from usher_pipeline.gene_mapping.universe import (
    fetch_protein_coding_genes,
    GeneUniverse,
)

__all__ = [
    "GeneMapper",
    "MappingResult",
    "MappingReport",
    "fetch_protein_coding_genes",
    "GeneUniverse",
]
src/usher_pipeline/gene_mapping/mapper.py (new file, 189 lines)
@@ -0,0 +1,189 @@
"""Gene ID mapping via mygene batch queries.

Provides batch mapping from Ensembl gene IDs to HGNC symbols and UniProt accessions.
Handles edge cases like missing data, notfound results, and nested data structures.
"""

import logging
from dataclasses import dataclass, field

import mygene

logger = logging.getLogger(__name__)


@dataclass
class MappingResult:
    """Single gene ID mapping result.

    Attributes:
        ensembl_id: Ensembl gene ID (ENSG format)
        hgnc_symbol: HGNC gene symbol (None if not found)
        uniprot_accession: UniProt Swiss-Prot accession (None if not found)
        mapping_source: Data source for mapping (default: mygene)
    """
    ensembl_id: str
    hgnc_symbol: str | None = None
    uniprot_accession: str | None = None
    mapping_source: str = "mygene"


@dataclass
class MappingReport:
    """Summary report for a batch mapping operation.

    Attributes:
        total_genes: Total number of genes queried
        mapped_hgnc: Number of genes with an HGNC symbol
        mapped_uniprot: Number of genes with a UniProt accession
        unmapped_ids: List of Ensembl IDs with no HGNC symbol
        success_rate_hgnc: Fraction of genes with an HGNC symbol (0-1)
        success_rate_uniprot: Fraction of genes with a UniProt accession (0-1)
    """
    total_genes: int
    mapped_hgnc: int
    mapped_uniprot: int
    unmapped_ids: list[str] = field(default_factory=list)
    success_rate_hgnc: float = 0.0
    success_rate_uniprot: float = 0.0

    def __post_init__(self):
        """Calculate success rates after initialization."""
        if self.total_genes > 0:
            self.success_rate_hgnc = self.mapped_hgnc / self.total_genes
            self.success_rate_uniprot = self.mapped_uniprot / self.total_genes


class GeneMapper:
    """Batch gene ID mapper using the mygene API.

    Maps Ensembl gene IDs to HGNC symbols and UniProt Swiss-Prot accessions.
    Handles batch queries, missing data, and edge cases.
    """

    def __init__(self, batch_size: int = 1000):
        """Initialize the gene mapper.

        Args:
            batch_size: Number of genes to query per batch (default: 1000)
        """
        self.batch_size = batch_size
        self.mg = mygene.MyGeneInfo()
        logger.info(f"Initialized GeneMapper with batch_size={batch_size}")

    def map_ensembl_ids(
        self,
        ensembl_ids: list[str]
    ) -> tuple[list[MappingResult], MappingReport]:
        """Map Ensembl gene IDs to HGNC symbols and UniProt accessions.

        Uses mygene.querymany() to batch query for gene symbols and UniProt IDs.
        Processes queries in chunks of batch_size to avoid API limits.

        Args:
            ensembl_ids: List of Ensembl gene IDs (ENSG format)

        Returns:
            Tuple of (mapping_results, mapping_report)
            - mapping_results: List of MappingResult for each input gene
            - mapping_report: Summary statistics for the mapping operation

        Notes:
            - Queries mygene with scopes='ensembl.gene'
            - Retrieves fields: symbol (HGNC), uniprot.Swiss-Prot
            - Handles 'notfound' results, missing keys, and nested structures
            - For duplicate query hits, keeps the first result per input ID
        """
        total_genes = len(ensembl_ids)
        logger.info(f"Mapping {total_genes} Ensembl IDs to HGNC/UniProt")

        results: list[MappingResult] = []
        unmapped_ids: list[str] = []
        seen_ids: set[str] = set()  # guards against duplicate hits for one query
        mapped_hgnc = 0
        mapped_uniprot = 0

        # Process in batches
        for i in range(0, total_genes, self.batch_size):
            batch = ensembl_ids[i:i + self.batch_size]
            batch_num = i // self.batch_size + 1
            total_batches = (total_genes + self.batch_size - 1) // self.batch_size

            logger.info(
                f"Processing batch {batch_num}/{total_batches} "
                f"({len(batch)} genes)"
            )

            # Query mygene
            batch_results = self.mg.querymany(
                batch,
                scopes='ensembl.gene',
                fields='symbol,uniprot.Swiss-Prot',
                species=9606,
                returnall=True,
            )

            # Extract results from returnall=True format
            # mygene returns {'out': [...], 'missing': [...]} with returnall=True
            out_results = batch_results.get('out', [])

            # Process each result
            for hit in out_results:
                ensembl_id = hit.get('query', '')

                # Skip duplicate hits for the same query ID (keep the first)
                if ensembl_id in seen_ids:
                    continue
                seen_ids.add(ensembl_id)

                # Check if gene was not found
                if hit.get('notfound', False):
                    results.append(MappingResult(ensembl_id=ensembl_id))
                    unmapped_ids.append(ensembl_id)
                    continue

                # Extract HGNC symbol
                hgnc_symbol = hit.get('symbol')

                # Extract UniProt accession (handle nested structure and lists)
                uniprot_accession = None
                uniprot_data = hit.get('uniprot')

                if uniprot_data:
                    # uniprot can be a dict with Swiss-Prot key
                    if isinstance(uniprot_data, dict):
                        swiss_prot = uniprot_data.get('Swiss-Prot')
                        # Swiss-Prot can be a string or list
                        if isinstance(swiss_prot, str):
                            uniprot_accession = swiss_prot
                        elif isinstance(swiss_prot, list) and swiss_prot:
                            # Take first accession if list
                            uniprot_accession = swiss_prot[0]

                # Create mapping result
                results.append(MappingResult(
                    ensembl_id=ensembl_id,
                    hgnc_symbol=hgnc_symbol,
                    uniprot_accession=uniprot_accession,
                ))

                # Track success counts
                if hgnc_symbol:
                    mapped_hgnc += 1
                else:
                    unmapped_ids.append(ensembl_id)

                if uniprot_accession:
                    mapped_uniprot += 1

        # Create summary report
        report = MappingReport(
            total_genes=total_genes,
            mapped_hgnc=mapped_hgnc,
            mapped_uniprot=mapped_uniprot,
            unmapped_ids=unmapped_ids,
        )

        logger.info(
            f"Mapping complete: {mapped_hgnc}/{total_genes} HGNC "
            f"({report.success_rate_hgnc:.1%}), "
            f"{mapped_uniprot}/{total_genes} UniProt "
            f"({report.success_rate_uniprot:.1%})"
        )

        return results, report
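A minimal sketch of driving the mapper (the Ensembl IDs are hypothetical inputs, and the call performs live queries against the mygene.info service):

    from usher_pipeline.gene_mapping import GeneMapper

    ids = ["ENSG00000141837", "ENSG00000006283"]  # hypothetical example inputs
    mapper = GeneMapper(batch_size=500)
    results, report = mapper.map_ensembl_ids(ids)

    for r in results:
        print(r.ensembl_id, r.hgnc_symbol, r.uniprot_accession)
    print(f"HGNC: {report.success_rate_hgnc:.1%}, UniProt: {report.success_rate_uniprot:.1%}")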
src/usher_pipeline/gene_mapping/universe.py (new file, 107 lines)
@@ -0,0 +1,107 @@
"""Gene universe definition and retrieval.

Fetches the complete set of human protein-coding genes from Ensembl via mygene.
Validates gene count and filters to ENSG-format Ensembl IDs only.
"""

import logging
from typing import TypeAlias

import mygene

# Type alias for gene universe lists
GeneUniverse: TypeAlias = list[str]

logger = logging.getLogger(__name__)

# Expected range for human protein-coding genes
MIN_EXPECTED_GENES = 19000
MAX_EXPECTED_GENES = 22000


def fetch_protein_coding_genes(ensembl_release: int = 113) -> GeneUniverse:
    """Fetch all human protein-coding genes from Ensembl via mygene.

    Queries mygene for genes with type_of_gene=protein-coding in humans (taxid 9606).
    Filters to only genes with valid Ensembl gene IDs (ENSG format).
    Checks that the gene count falls in the expected range (19,000-22,000).

    Args:
        ensembl_release: Ensembl release version (for documentation purposes;
            mygene returns current data regardless)

    Returns:
        Sorted, deduplicated list of Ensembl gene IDs (ENSG format)

    Note:
        A gene count outside the expected range is logged as a warning rather
        than raised as an error. While ensembl_release is accepted for
        documentation, the mygene API does not support querying a specific
        Ensembl release - it always returns current data. For reproducibility,
        use cached results or versioned data snapshots.
    """
    logger.info(
        f"Fetching protein-coding genes for Ensembl release {ensembl_release} "
        "(note: mygene returns current data)"
    )

    # Initialize mygene client
    mg = mygene.MyGeneInfo()

    # Query for human protein-coding genes
    logger.info("Querying mygene for type_of_gene:protein-coding (species=9606)")
    # fetch_all=True makes mygene return a generator, so materialize it before
    # calling len() and iterating
    results = list(mg.query(
        'type_of_gene:"protein-coding"',
        species=9606,
        fields='ensembl.gene,symbol,name',
        fetch_all=True,
    ))

    logger.info(f"Retrieved {len(results)} results from mygene")

    # Extract Ensembl gene IDs
    gene_ids: set[str] = set()

    for hit in results:
        # Handle both single ensembl.gene and list cases
        ensembl_data = hit.get('ensembl')
        if not ensembl_data:
            continue

        # ensembl can be a single dict or list of dicts
        if isinstance(ensembl_data, dict):
            ensembl_list = [ensembl_data]
        else:
            ensembl_list = ensembl_data

        # Extract gene IDs from each ensembl entry
        for ensembl_entry in ensembl_list:
            gene_id = ensembl_entry.get('gene')
            if gene_id and isinstance(gene_id, str) and gene_id.startswith('ENSG'):
                gene_ids.add(gene_id)

    # Sort and deduplicate
    sorted_genes = sorted(gene_ids)
    gene_count = len(sorted_genes)

    logger.info(f"Extracted {gene_count} unique Ensembl gene IDs (ENSG format)")

    # Validate gene count
    if gene_count < MIN_EXPECTED_GENES:
        logger.warning(
            f"Gene count {gene_count} is below expected minimum {MIN_EXPECTED_GENES}. "
            "This may indicate missing data or query issues."
        )
    elif gene_count > MAX_EXPECTED_GENES:
        logger.warning(
            f"Gene count {gene_count} exceeds expected maximum {MAX_EXPECTED_GENES}. "
            "This may indicate pseudogene contamination or non-coding genes in results."
        )
    else:
        logger.info(
            f"Gene count {gene_count} is within expected range "
            f"({MIN_EXPECTED_GENES}-{MAX_EXPECTED_GENES})"
        )

    return sorted_genes
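Usage is a single call (a sketch; it hits the live mygene service, so the exact count drifts with current annotation):

    from usher_pipeline.gene_mapping import fetch_protein_coding_genes

    universe = fetch_protein_coding_genes(ensembl_release=113)
    print(f"{len(universe)} protein-coding genes; first IDs: {universe[:3]}")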
src/usher_pipeline/persistence/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""Persistence layer for pipeline checkpoints and provenance tracking."""

from usher_pipeline.persistence.duckdb_store import PipelineStore

# ProvenanceTracker will be added in Task 2
__all__ = ["PipelineStore"]
src/usher_pipeline/persistence/duckdb_store.py (new file, 232 lines)
@@ -0,0 +1,232 @@
"""DuckDB-based storage for pipeline checkpoints with restart capability."""

from pathlib import Path
from typing import Optional, Union

import duckdb
import polars as pl

try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False


class PipelineStore:
    """
    DuckDB-based storage for pipeline intermediate results.

    Enables checkpoint-restart pattern: expensive operations (API downloads,
    processing) can be saved as DuckDB tables and skipped on subsequent runs.
    """

    def __init__(self, db_path: Path):
        """
        Initialize PipelineStore with a DuckDB database.

        Args:
            db_path: Path to DuckDB database file. Parent directories
                are created automatically.
        """
        self.db_path = db_path
        # Create parent directories
        db_path.parent.mkdir(parents=True, exist_ok=True)

        # Connect to DuckDB
        self.conn = duckdb.connect(str(db_path))

        # Create metadata table for tracking checkpoints
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS _checkpoints (
                table_name VARCHAR PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                row_count INTEGER,
                description VARCHAR
            )
        """)

    def save_dataframe(
        self,
        df: Union[pl.DataFrame, "pd.DataFrame"],
        table_name: str,
        description: str = "",
        replace: bool = True
    ) -> None:
        """
        Save a DataFrame to DuckDB as a table.

        Args:
            df: Polars or pandas DataFrame to save
            table_name: Name for the DuckDB table
            description: Optional description for checkpoint metadata
            replace: If True, replace existing table; if False, append
        """
        # Detect DataFrame type
        is_polars = isinstance(df, pl.DataFrame)
        if not is_polars and not HAS_PANDAS:
            raise ValueError("pandas not available")
        if not is_polars and not isinstance(df, pd.DataFrame):
            raise ValueError("df must be polars.DataFrame or pandas.DataFrame")

        # Save DataFrame (DuckDB's replacement scan resolves the local `df`
        # variable referenced in the SQL)
        if replace:
            self.conn.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM df")
        else:
            self.conn.execute(f"INSERT INTO {table_name} SELECT * FROM df")

        # Update checkpoint metadata
        row_count = len(df)
        self.conn.execute("""
            INSERT OR REPLACE INTO _checkpoints (table_name, row_count, description, created_at)
            VALUES (?, ?, ?, CURRENT_TIMESTAMP)
        """, [table_name, row_count, description])

    def load_dataframe(
        self,
        table_name: str,
        as_polars: bool = True
    ) -> Optional[Union[pl.DataFrame, "pd.DataFrame"]]:
        """
        Load a table as a DataFrame.

        Args:
            table_name: Name of the DuckDB table
            as_polars: If True, return polars DataFrame; else pandas

        Returns:
            DataFrame or None if table doesn't exist
        """
        try:
            result = self.conn.execute(f"SELECT * FROM {table_name}")
            if as_polars:
                return result.pl()
            else:
                if not HAS_PANDAS:
                    raise ValueError("pandas not available")
                return result.df()
        except duckdb.CatalogException:
            # Table doesn't exist
            return None

    def has_checkpoint(self, table_name: str) -> bool:
        """
        Check if a checkpoint exists.

        Args:
            table_name: Name of the table to check

        Returns:
            True if checkpoint exists, False otherwise
        """
        result = self.conn.execute(
            "SELECT COUNT(*) FROM _checkpoints WHERE table_name = ?",
            [table_name]
        ).fetchone()
        return result[0] > 0

    def list_checkpoints(self) -> list[dict]:
        """
        List all checkpoints with metadata.

        Returns:
            List of checkpoint metadata dicts with keys:
            table_name, created_at, row_count, description
        """
        result = self.conn.execute("""
            SELECT table_name, created_at, row_count, description
            FROM _checkpoints
            ORDER BY created_at DESC
        """).fetchall()

        return [
            {
                "table_name": row[0],
                "created_at": row[1],
                "row_count": row[2],
                "description": row[3],
            }
            for row in result
        ]

    def delete_checkpoint(self, table_name: str) -> None:
        """
        Delete a checkpoint and its metadata.

        Args:
            table_name: Name of the table to delete
        """
        # Drop table if exists
        self.conn.execute(f"DROP TABLE IF EXISTS {table_name}")

        # Remove from metadata
        self.conn.execute(
            "DELETE FROM _checkpoints WHERE table_name = ?",
            [table_name]
        )

    def export_parquet(self, table_name: str, output_path: Path) -> None:
        """
        Export a table to Parquet format.

        Args:
            table_name: Name of the table to export
            output_path: Path to output Parquet file
        """
        # Create parent directories
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Export using DuckDB's native Parquet writer. COPY does not accept
        # bound parameters for the target path, so quote and embed it directly.
        escaped_path = str(output_path).replace("'", "''")
        self.conn.execute(
            f"COPY {table_name} TO '{escaped_path}' (FORMAT PARQUET)"
        )

    def execute_query(
        self,
        query: str,
        params: Optional[list] = None
    ) -> pl.DataFrame:
        """
        Execute arbitrary SQL query and return polars DataFrame.

        Args:
            query: SQL query to execute
            params: Optional query parameters

        Returns:
            Query results as polars DataFrame
        """
        if params:
            result = self.conn.execute(query, params)
        else:
            result = self.conn.execute(query)
        return result.pl()

    def close(self) -> None:
        """Close the DuckDB connection."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - closes connection."""
        self.close()
        return False

    @classmethod
    def from_config(cls, config: "PipelineConfig") -> "PipelineStore":
        """
        Create PipelineStore from a PipelineConfig.

        Args:
            config: PipelineConfig instance

        Returns:
            PipelineStore instance
        """
        return cls(config.duckdb_path)
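The query and export paths can be combined like so (a sketch; the table and file names are illustrative):

    from pathlib import Path

    from usher_pipeline.persistence import PipelineStore

    with PipelineStore(Path("data/pipeline.duckdb")) as store:
        # Ad-hoc parameterized SQL over any checkpointed table, returned as polars
        counts = store.execute_query(
            "SELECT COUNT(*) AS n FROM mapped_genes WHERE mapping_source = ?",
            ["mygene"],
        )
        print(counts)

        # Inspect checkpoint metadata, then export for downstream tools
        for cp in store.list_checkpoints():
            print(cp["table_name"], cp["row_count"], cp["created_at"])
        store.export_parquet("mapped_genes", Path("exports/mapped_genes.parquet"))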