usher-exploring/tests/test_persistence.py
gbanyan 98a1a750dd feat(01-03): create provenance tracker with comprehensive tests
- ProvenanceTracker class for metadata tracking
- Records pipeline version, data source versions, config hash, timestamps
- Sidecar JSON export alongside outputs
- DuckDB _provenance table support
- 13 comprehensive tests (8 DuckDB + 5 provenance)
- All tests pass: 12 passed, 1 skipped (pandas not installed)
2026-02-11 16:31:51 +08:00
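
Taken together, the tests below pin down a small persistence API. As a quick
orientation, here is a minimal usage sketch; the paths and pipeline wiring are
illustrative assumptions, but every method call shown is exercised by a test
in this file:

    from pathlib import Path

    import polars as pl

    from usher_pipeline.config.loader import load_config
    from usher_pipeline.persistence import PipelineStore, ProvenanceTracker

    config = load_config(Path("config.yaml"))     # illustrative config path
    tracker = ProvenanceTracker("0.1.0", config)

    with PipelineStore(Path("pipeline.duckdb")) as store:
        df = pl.DataFrame({"gene": ["MYO7A"], "score": [0.92]})
        store.save_dataframe(df, "genes", "scored gene universe")
        tracker.record_step("score_genes", {"count": df.height})

        store.export_parquet("genes", Path("genes.parquet"))
        tracker.save_sidecar(Path("genes.parquet"))  # writes genes.provenance.json
        tracker.save_to_store(store)                 # also recorded in _provenance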


"""Tests for persistence layer (DuckDB store and provenance tracking)."""
import json
from pathlib import Path
import polars as pl
import pytest
from usher_pipeline.config.loader import load_config
from usher_pipeline.persistence import PipelineStore, ProvenanceTracker


@pytest.fixture
def test_config(tmp_path):
    """Create a minimal test config."""
    # Create a minimal config YAML
    config_path = tmp_path / "test_config.yaml"
    config_path.write_text("""
data_dir: {data_dir}
cache_dir: {cache_dir}
duckdb_path: {duckdb_path}
versions:
  ensembl_release: 113
  gnomad_version: "v4.1"
  gtex_version: "v8"
  hpa_version: "23.0"
api:
  rate_limit_per_second: 5
  max_retries: 5
  cache_ttl_seconds: 86400
  timeout_seconds: 30
scoring:
  gnomad: 0.20
  expression: 0.20
  annotation: 0.15
  localization: 0.15
  animal_model: 0.15
  literature: 0.15
""".format(
        data_dir=str(tmp_path / "data"),
        cache_dir=str(tmp_path / "cache"),
        duckdb_path=str(tmp_path / "test.duckdb"),
    ))
    return load_config(config_path)
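

# Note: the provenance tests below assume the object returned by load_config
# exposes a config_hash() method (see test_provenance_config_hash_included);
# nothing else about its shape is exercised here beyond the YAML keys above.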


# ============================================================================
# DuckDB Store Tests
# ============================================================================


def test_store_creates_database(tmp_path):
    """Test that PipelineStore creates .duckdb file at specified path."""
    db_path = tmp_path / "test.duckdb"
    assert not db_path.exists()

    store = PipelineStore(db_path)
    store.close()
    assert db_path.exists()


def test_save_and_load_polars(tmp_path):
    """Test saving and loading a polars DataFrame."""
    store = PipelineStore(tmp_path / "test.duckdb")

    # Create test DataFrame
    df = pl.DataFrame({
        "gene": ["BRCA1", "TP53", "MYO7A"],
        "score": [0.95, 0.88, 0.92],
        "chr": ["17", "17", "11"],
    })

    # Save
    store.save_dataframe(df, "genes", "test genes")

    # Load
    loaded = store.load_dataframe("genes", as_polars=True)

    # Verify
    assert loaded.shape == df.shape
    assert loaded.columns == df.columns
    assert loaded["gene"].to_list() == df["gene"].to_list()
    assert loaded["score"].to_list() == df["score"].to_list()

    store.close()


def test_save_and_load_pandas(tmp_path):
    """Test saving and loading a pandas DataFrame."""
    pd = pytest.importorskip("pandas")
    store = PipelineStore(tmp_path / "test.duckdb")

    # Create test DataFrame
    df = pd.DataFrame({
        "gene": ["BRCA1", "TP53"],
        "score": [0.95, 0.88],
    })

    # Save
    store.save_dataframe(df, "genes_pandas", "pandas test")

    # Load as pandas
    loaded = store.load_dataframe("genes_pandas", as_polars=False)

    # Verify
    assert loaded.shape == df.shape
    assert list(loaded.columns) == list(df.columns)
    assert loaded["gene"].tolist() == df["gene"].tolist()

    store.close()


def test_checkpoint_lifecycle(tmp_path):
    """Test checkpoint lifecycle: save -> has -> delete -> not has."""
    store = PipelineStore(tmp_path / "test.duckdb")
    df = pl.DataFrame({"col": [1, 2, 3]})

    # Initially no checkpoint
    assert not store.has_checkpoint("test_table")

    # Save creates checkpoint
    store.save_dataframe(df, "test_table", "test")
    assert store.has_checkpoint("test_table")

    # Delete removes checkpoint
    store.delete_checkpoint("test_table")
    assert not store.has_checkpoint("test_table")

    # Load returns None after deletion
    assert store.load_dataframe("test_table") is None

    store.close()
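

# The lifecycle above is what enables resumable runs. A sketch of the idiom
# pipeline code would use (compute_genes is hypothetical, not part of the API):
#
#     if store.has_checkpoint("genes"):
#         df = store.load_dataframe("genes")
#     else:
#         df = compute_genes()
#         store.save_dataframe(df, "genes", "gene universe")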


def test_list_checkpoints(tmp_path):
    """Test listing checkpoints returns metadata."""
    store = PipelineStore(tmp_path / "test.duckdb")

    # Create 3 tables
    for i in range(3):
        df = pl.DataFrame({"val": list(range(i + 1))})
        store.save_dataframe(df, f"table_{i}", f"description {i}")

    # List checkpoints
    checkpoints = store.list_checkpoints()
    assert len(checkpoints) == 3

    # Verify metadata structure
    for ckpt in checkpoints:
        assert "table_name" in ckpt
        assert "created_at" in ckpt
        assert "row_count" in ckpt
        assert "description" in ckpt

    # Verify specific metadata
    table_0 = [c for c in checkpoints if c["table_name"] == "table_0"][0]
    assert table_0["row_count"] == 1
    assert table_0["description"] == "description 0"

    store.close()


def test_export_parquet(tmp_path):
    """Test exporting table to Parquet."""
    store = PipelineStore(tmp_path / "test.duckdb")

    # Create and save DataFrame
    df = pl.DataFrame({
        "gene": ["BRCA1", "TP53", "MYO7A"],
        "score": [0.95, 0.88, 0.92],
    })
    store.save_dataframe(df, "genes", "test genes")

    # Export to Parquet (the output/ directory does not exist yet, so
    # export_parquet is expected to create parent directories)
    parquet_path = tmp_path / "output" / "genes.parquet"
    store.export_parquet("genes", parquet_path)

    # Verify Parquet file exists and is readable
    assert parquet_path.exists()
    loaded_from_parquet = pl.read_parquet(parquet_path)
    assert loaded_from_parquet.shape == df.shape
    assert loaded_from_parquet["gene"].to_list() == df["gene"].to_list()

    store.close()


def test_load_nonexistent_returns_none(tmp_path):
    """Test that loading non-existent table returns None."""
    store = PipelineStore(tmp_path / "test.duckdb")
    result = store.load_dataframe("nonexistent_table")
    assert result is None
    store.close()


def test_context_manager(tmp_path):
    """Test context manager support."""
    db_path = tmp_path / "test.duckdb"
    df = pl.DataFrame({"col": [1, 2, 3]})

    # Use context manager
    with PipelineStore(db_path) as store:
        store.save_dataframe(df, "test_table", "test")
        assert store.has_checkpoint("test_table")

    # Connection should be closed after context exit.
    # Open a new connection to verify data persists.
    with PipelineStore(db_path) as store:
        loaded = store.load_dataframe("test_table")
        assert loaded is not None
        assert loaded.shape == df.shape


# ============================================================================
# Provenance Tests
# ============================================================================


def test_provenance_metadata_structure(test_config):
    """Test that provenance metadata has all required keys."""
    tracker = ProvenanceTracker("0.1.0", test_config)
    metadata = tracker.create_metadata()

    assert "pipeline_version" in metadata
    assert "data_source_versions" in metadata
    assert "config_hash" in metadata
    assert "created_at" in metadata
    assert "processing_steps" in metadata

    assert metadata["pipeline_version"] == "0.1.0"
    assert isinstance(metadata["data_source_versions"], dict)
    assert isinstance(metadata["config_hash"], str)
    assert len(metadata["processing_steps"]) == 0


def test_provenance_records_steps(test_config):
    """Test that processing steps are recorded with timestamps."""
    tracker = ProvenanceTracker("0.1.0", test_config)

    # Record steps
    tracker.record_step("download_genes")
    tracker.record_step("filter_protein_coding", {"count": 19000})

    metadata = tracker.create_metadata()
    steps = metadata["processing_steps"]
    assert len(steps) == 2

    # Check first step
    assert steps[0]["step_name"] == "download_genes"
    assert "timestamp" in steps[0]

    # Check second step
    assert steps[1]["step_name"] == "filter_protein_coding"
    assert steps[1]["details"]["count"] == 19000
    assert "timestamp" in steps[1]


def test_provenance_sidecar_roundtrip(test_config, tmp_path):
    """Test saving and loading provenance sidecar."""
    tracker = ProvenanceTracker("0.1.0", test_config)
    tracker.record_step("test_step", {"key": "value"})

    # Save sidecar
    output_path = tmp_path / "output.parquet"
    tracker.save_sidecar(output_path)

    # Verify sidecar file exists
    sidecar_path = tmp_path / "output.provenance.json"
    assert sidecar_path.exists()

    # Load and verify content
    loaded = ProvenanceTracker.load_sidecar(sidecar_path)
    assert loaded["pipeline_version"] == "0.1.0"
    assert loaded["config_hash"] == test_config.config_hash()
    assert len(loaded["processing_steps"]) == 1
    assert loaded["processing_steps"][0]["step_name"] == "test_step"
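

# For reference, a sidecar written by the roundtrip above would look roughly
# like this (values illustrative; the keys are the ones asserted in
# test_provenance_metadata_structure and test_provenance_records_steps):
#
#     {
#       "pipeline_version": "0.1.0",
#       "data_source_versions": {"ensembl_release": 113, "gnomad_version": "v4.1"},
#       "config_hash": "3f2a...",
#       "created_at": "2026-02-11T16:31:51+08:00",
#       "processing_steps": [
#         {"step_name": "test_step", "details": {"key": "value"}, "timestamp": "..."}
#       ]
#     }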


def test_provenance_config_hash_included(test_config):
    """Test that config hash is included in metadata."""
    tracker = ProvenanceTracker("0.1.0", test_config)
    metadata = tracker.create_metadata()
    assert metadata["config_hash"] == test_config.config_hash()


def test_provenance_save_to_store(test_config, tmp_path):
    """Test saving provenance to the DuckDB store."""
    store = PipelineStore(tmp_path / "test.duckdb")
    tracker = ProvenanceTracker("0.1.0", test_config)
    tracker.record_step("test_step")

    # Save to store
    tracker.save_to_store(store)

    # Verify _provenance table exists and has data
    result = store.conn.execute("SELECT * FROM _provenance").fetchall()
    assert len(result) > 0

    # Verify content. Column order asserted here: 0 = pipeline version,
    # 1 = config hash, 3 = steps JSON (index 2 is presumably a timestamp).
    row = result[0]
    assert row[0] == "0.1.0"  # version
    assert row[1] == test_config.config_hash()  # config_hash

    # Verify steps_json is valid JSON
    steps = json.loads(row[3])
    assert len(steps) == 1
    assert steps[0]["step_name"] == "test_step"

    store.close()