diff --git a/src/usher_pipeline/persistence/__init__.py b/src/usher_pipeline/persistence/__init__.py
index e422ab1..6281574 100644
--- a/src/usher_pipeline/persistence/__init__.py
+++ b/src/usher_pipeline/persistence/__init__.py
@@ -1,6 +1,6 @@
 """Persistence layer for pipeline checkpoints and provenance tracking."""
 
 from usher_pipeline.persistence.duckdb_store import PipelineStore
+from usher_pipeline.persistence.provenance import ProvenanceTracker
 
-# ProvenanceTracker will be added in Task 2
-__all__ = ["PipelineStore"]
+__all__ = ["PipelineStore", "ProvenanceTracker"]
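The re-export above keeps call sites decoupled from the module layout. A minimal sketch of the resulting import surface (the path and version string are illustrative, not part of this diff):

    from pathlib import Path

    from usher_pipeline.persistence import PipelineStore, ProvenanceTracker

    store = PipelineStore(Path("pipeline.duckdb"))  # checkpoint storage; path illustrative
    tracker = ProvenanceTracker("0.1.0", config)    # config: a PipelineConfig loaded elsewhere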
+ """, [ + metadata["pipeline_version"], + metadata["config_hash"], + metadata["created_at"], + json.dumps(metadata["processing_steps"]), + ]) + + @staticmethod + def load_sidecar(sidecar_path: Path) -> dict: + """ + Load provenance metadata from a sidecar file. + + Args: + sidecar_path: Path to the .provenance.json file + + Returns: + Provenance metadata dictionary + """ + with open(sidecar_path) as f: + return json.load(f) + + @classmethod + def from_config( + cls, + config: "PipelineConfig", + version: Optional[str] = None + ) -> "ProvenanceTracker": + """ + Create ProvenanceTracker from a PipelineConfig. + + Args: + config: PipelineConfig instance + version: Pipeline version string. If None, uses usher_pipeline.__version__ + + Returns: + ProvenanceTracker instance + """ + if version is None: + from usher_pipeline import __version__ + version = __version__ + + return cls(version, config) diff --git a/tests/test_persistence.py b/tests/test_persistence.py new file mode 100644 index 0000000..20a2d0a --- /dev/null +++ b/tests/test_persistence.py @@ -0,0 +1,318 @@ +"""Tests for persistence layer (DuckDB store and provenance tracking).""" + +import json +from pathlib import Path + +import polars as pl +import pytest + +from usher_pipeline.config.loader import load_config +from usher_pipeline.persistence import PipelineStore, ProvenanceTracker + + +@pytest.fixture +def test_config(tmp_path): + """Create a minimal test config.""" + # Create a minimal config YAML + config_path = tmp_path / "test_config.yaml" + config_path.write_text(""" +data_dir: {data_dir} +cache_dir: {cache_dir} +duckdb_path: {duckdb_path} +versions: + ensembl_release: 113 + gnomad_version: "v4.1" + gtex_version: "v8" + hpa_version: "23.0" +api: + rate_limit_per_second: 5 + max_retries: 5 + cache_ttl_seconds: 86400 + timeout_seconds: 30 +scoring: + gnomad: 0.20 + expression: 0.20 + annotation: 0.15 + localization: 0.15 + animal_model: 0.15 + literature: 0.15 +""".format( + data_dir=str(tmp_path / "data"), + cache_dir=str(tmp_path / "cache"), + duckdb_path=str(tmp_path / "test.duckdb"), + )) + return load_config(config_path) + + +# ============================================================================ +# DuckDB Store Tests +# ============================================================================ + +def test_store_creates_database(tmp_path): + """Test that PipelineStore creates .duckdb file at specified path.""" + db_path = tmp_path / "test.duckdb" + assert not db_path.exists() + + store = PipelineStore(db_path) + store.close() + + assert db_path.exists() + + +def test_save_and_load_polars(tmp_path): + """Test saving and loading polars DataFrame.""" + store = PipelineStore(tmp_path / "test.duckdb") + + # Create test DataFrame + df = pl.DataFrame({ + "gene": ["BRCA1", "TP53", "MYO7A"], + "score": [0.95, 0.88, 0.92], + "chr": ["17", "17", "11"], + }) + + # Save + store.save_dataframe(df, "genes", "test genes") + + # Load + loaded = store.load_dataframe("genes", as_polars=True) + + # Verify + assert loaded.shape == df.shape + assert loaded.columns == df.columns + assert loaded["gene"].to_list() == df["gene"].to_list() + assert loaded["score"].to_list() == df["score"].to_list() + + store.close() + + +def test_save_and_load_pandas(tmp_path): + """Test saving and loading pandas DataFrame.""" + pd = pytest.importorskip("pandas") + + store = PipelineStore(tmp_path / "test.duckdb") + + # Create test DataFrame + df = pd.DataFrame({ + "gene": ["BRCA1", "TP53"], + "score": [0.95, 0.88], + }) + + # Save + 
diff --git a/tests/test_persistence.py b/tests/test_persistence.py
new file mode 100644
index 0000000..20a2d0a
--- /dev/null
+++ b/tests/test_persistence.py
@@ -0,0 +1,318 @@
+"""Tests for persistence layer (DuckDB store and provenance tracking)."""
+
+import json
+from pathlib import Path
+
+import polars as pl
+import pytest
+
+from usher_pipeline.config.loader import load_config
+from usher_pipeline.persistence import PipelineStore, ProvenanceTracker
+
+
+@pytest.fixture
+def test_config(tmp_path):
+    """Create a minimal test config."""
+    # Write a minimal config YAML with all required sections
+    config_path = tmp_path / "test_config.yaml"
+    config_path.write_text("""
+data_dir: {data_dir}
+cache_dir: {cache_dir}
+duckdb_path: {duckdb_path}
+versions:
+  ensembl_release: 113
+  gnomad_version: "v4.1"
+  gtex_version: "v8"
+  hpa_version: "23.0"
+api:
+  rate_limit_per_second: 5
+  max_retries: 5
+  cache_ttl_seconds: 86400
+  timeout_seconds: 30
+scoring:
+  gnomad: 0.20
+  expression: 0.20
+  annotation: 0.15
+  localization: 0.15
+  animal_model: 0.15
+  literature: 0.15
+""".format(
+        data_dir=str(tmp_path / "data"),
+        cache_dir=str(tmp_path / "cache"),
+        duckdb_path=str(tmp_path / "test.duckdb"),
+    ))
+    return load_config(config_path)
+
+
+# ============================================================================
+# DuckDB Store Tests
+# ============================================================================
+
+def test_store_creates_database(tmp_path):
+    """Test that PipelineStore creates a .duckdb file at the specified path."""
+    db_path = tmp_path / "test.duckdb"
+    assert not db_path.exists()
+
+    store = PipelineStore(db_path)
+    store.close()
+
+    assert db_path.exists()
+
+
+def test_save_and_load_polars(tmp_path):
+    """Test saving and loading a polars DataFrame."""
+    store = PipelineStore(tmp_path / "test.duckdb")
+
+    # Create test DataFrame
+    df = pl.DataFrame({
+        "gene": ["BRCA1", "TP53", "MYO7A"],
+        "score": [0.95, 0.88, 0.92],
+        "chr": ["17", "17", "11"],
+    })
+
+    # Save
+    store.save_dataframe(df, "genes", "test genes")
+
+    # Load
+    loaded = store.load_dataframe("genes", as_polars=True)
+
+    # Verify
+    assert loaded.shape == df.shape
+    assert loaded.columns == df.columns
+    assert loaded["gene"].to_list() == df["gene"].to_list()
+    assert loaded["score"].to_list() == df["score"].to_list()
+
+    store.close()
+
+
+def test_save_and_load_pandas(tmp_path):
+    """Test saving and loading a pandas DataFrame."""
+    pd = pytest.importorskip("pandas")
+
+    store = PipelineStore(tmp_path / "test.duckdb")
+
+    # Create test DataFrame
+    df = pd.DataFrame({
+        "gene": ["BRCA1", "TP53"],
+        "score": [0.95, 0.88],
+    })
+
+    # Save
+    store.save_dataframe(df, "genes_pandas", "pandas test")
+
+    # Load as pandas
+    loaded = store.load_dataframe("genes_pandas", as_polars=False)
+
+    # Verify
+    assert loaded.shape == df.shape
+    assert list(loaded.columns) == list(df.columns)
+    assert loaded["gene"].tolist() == df["gene"].tolist()
+
+    store.close()
+
+
+def test_checkpoint_lifecycle(tmp_path):
+    """Test checkpoint lifecycle: save -> has -> delete -> not has."""
+    store = PipelineStore(tmp_path / "test.duckdb")
+
+    df = pl.DataFrame({"col": [1, 2, 3]})
+
+    # Initially no checkpoint
+    assert not store.has_checkpoint("test_table")
+
+    # Save creates checkpoint
+    store.save_dataframe(df, "test_table", "test")
+    assert store.has_checkpoint("test_table")
+
+    # Delete removes checkpoint
+    store.delete_checkpoint("test_table")
+    assert not store.has_checkpoint("test_table")
+
+    # Load returns None after deletion
+    assert store.load_dataframe("test_table") is None
+
+    store.close()
+
+
+def test_list_checkpoints(tmp_path):
+    """Test that listing checkpoints returns metadata."""
+    store = PipelineStore(tmp_path / "test.duckdb")
+
+    # Create 3 tables
+    for i in range(3):
+        df = pl.DataFrame({"val": list(range(i + 1))})
+        store.save_dataframe(df, f"table_{i}", f"description {i}")
+
+    # List checkpoints
+    checkpoints = store.list_checkpoints()
+
+    assert len(checkpoints) == 3
+
+    # Verify metadata structure
+    for ckpt in checkpoints:
+        assert "table_name" in ckpt
+        assert "created_at" in ckpt
+        assert "row_count" in ckpt
+        assert "description" in ckpt
+
+    # Verify specific metadata
+    table_0 = [c for c in checkpoints if c["table_name"] == "table_0"][0]
+    assert table_0["row_count"] == 1
+    assert table_0["description"] == "description 0"
+
+    store.close()
+
+
+def test_export_parquet(tmp_path):
+    """Test exporting a table to Parquet."""
+    store = PipelineStore(tmp_path / "test.duckdb")
+
+    # Create and save DataFrame
+    df = pl.DataFrame({
+        "gene": ["BRCA1", "TP53", "MYO7A"],
+        "score": [0.95, 0.88, 0.92],
+    })
+    store.save_dataframe(df, "genes", "test genes")
+
+    # Export to Parquet
+    parquet_path = tmp_path / "output" / "genes.parquet"
+    store.export_parquet("genes", parquet_path)
+
+    # Verify Parquet file exists and is readable
+    assert parquet_path.exists()
+    loaded_from_parquet = pl.read_parquet(parquet_path)
+    assert loaded_from_parquet.shape == df.shape
+    assert loaded_from_parquet["gene"].to_list() == df["gene"].to_list()
+
+    store.close()
+
+
+def test_load_nonexistent_returns_none(tmp_path):
+    """Test that loading a non-existent table returns None."""
+    store = PipelineStore(tmp_path / "test.duckdb")
+
+    result = store.load_dataframe("nonexistent_table")
+    assert result is None
+
+    store.close()
+
+
+def test_context_manager(tmp_path):
+    """Test context manager support."""
+    db_path = tmp_path / "test.duckdb"
+
+    df = pl.DataFrame({"col": [1, 2, 3]})
+
+    # Use context manager
+    with PipelineStore(db_path) as store:
+        store.save_dataframe(df, "test_table", "test")
+        assert store.has_checkpoint("test_table")
+
+    # Connection should be closed after context exit.
+    # Open a new connection to verify data persists.
+    with PipelineStore(db_path) as store:
+        loaded = store.load_dataframe("test_table")
+        assert loaded is not None
+        assert loaded.shape == df.shape
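PipelineStore itself is not part of this diff; the store tests above pin down the surface they rely on. A sketch of that assumed interface, as stubs only (not the real implementation in duckdb_store.py):

    from pathlib import Path

    class PipelineStore:
        """Interface implied by the tests above -- stubs, not the actual code."""

        def __init__(self, db_path: Path) -> None: ...            # creates/opens the .duckdb file
        def save_dataframe(self, df, table_name: str, description: str) -> None: ...
        def load_dataframe(self, table_name: str, as_polars: bool = True): ...  # None if table is missing
        def has_checkpoint(self, table_name: str) -> bool: ...
        def delete_checkpoint(self, table_name: str) -> None: ...
        def list_checkpoints(self) -> list[dict]: ...             # table_name, created_at, row_count, description
        def export_parquet(self, table_name: str, output_path: Path) -> None: ...
        def close(self) -> None: ...
        def __enter__(self) -> "PipelineStore": ...
        def __exit__(self, exc_type, exc, tb) -> None: ...        # closes the connection on exit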
keys.""" + tracker = ProvenanceTracker("0.1.0", test_config) + + metadata = tracker.create_metadata() + + assert "pipeline_version" in metadata + assert "data_source_versions" in metadata + assert "config_hash" in metadata + assert "created_at" in metadata + assert "processing_steps" in metadata + + assert metadata["pipeline_version"] == "0.1.0" + assert isinstance(metadata["data_source_versions"], dict) + assert isinstance(metadata["config_hash"], str) + assert len(metadata["processing_steps"]) == 0 + + +def test_provenance_records_steps(test_config): + """Test that processing steps are recorded with timestamps.""" + tracker = ProvenanceTracker("0.1.0", test_config) + + # Record steps + tracker.record_step("download_genes") + tracker.record_step("filter_protein_coding", {"count": 19000}) + + metadata = tracker.create_metadata() + steps = metadata["processing_steps"] + + assert len(steps) == 2 + + # Check first step + assert steps[0]["step_name"] == "download_genes" + assert "timestamp" in steps[0] + + # Check second step + assert steps[1]["step_name"] == "filter_protein_coding" + assert steps[1]["details"]["count"] == 19000 + assert "timestamp" in steps[1] + + +def test_provenance_sidecar_roundtrip(test_config, tmp_path): + """Test saving and loading provenance sidecar.""" + tracker = ProvenanceTracker("0.1.0", test_config) + tracker.record_step("test_step", {"key": "value"}) + + # Save sidecar + output_path = tmp_path / "output.parquet" + tracker.save_sidecar(output_path) + + # Verify sidecar file exists + sidecar_path = tmp_path / "output.provenance.json" + assert sidecar_path.exists() + + # Load and verify content + loaded = ProvenanceTracker.load_sidecar(sidecar_path) + + assert loaded["pipeline_version"] == "0.1.0" + assert loaded["config_hash"] == test_config.config_hash() + assert len(loaded["processing_steps"]) == 1 + assert loaded["processing_steps"][0]["step_name"] == "test_step" + + +def test_provenance_config_hash_included(test_config): + """Test that config hash is included in metadata.""" + tracker = ProvenanceTracker("0.1.0", test_config) + + metadata = tracker.create_metadata() + + assert metadata["config_hash"] == test_config.config_hash() + + +def test_provenance_save_to_store(test_config, tmp_path): + """Test saving provenance to DuckDB store.""" + store = PipelineStore(tmp_path / "test.duckdb") + tracker = ProvenanceTracker("0.1.0", test_config) + tracker.record_step("test_step") + + # Save to store + tracker.save_to_store(store) + + # Verify _provenance table exists and has data + result = store.conn.execute("SELECT * FROM _provenance").fetchall() + assert len(result) > 0 + + # Verify content + row = result[0] + assert row[0] == "0.1.0" # version + assert row[1] == test_config.config_hash() # config_hash + + # Verify steps_json is valid JSON + steps = json.loads(row[3]) + assert len(steps) == 1 + assert steps[0]["step_name"] == "test_step" + + store.close()