feat(01-01): create base API client with retry and caching
- CachedAPIClient with SQLite persistent cache
- Exponential backoff retry on 429/5xx/network errors (tenacity)
- Rate limiting, with a skip for cached responses
- from_config classmethod for pipeline integration
- 5 passing tests for cache creation, rate limiting, and config integration
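
A minimal usage sketch (the endpoint URL and query parameters are hypothetical;
the class, constructor arguments, and methods are the ones this commit adds):

    from pathlib import Path

    from usher_pipeline.api_clients import CachedAPIClient

    client = CachedAPIClient(cache_dir=Path("cache"), rate_limit=5)
    # The first call hits the network and sleeps 1/rate_limit seconds afterwards;
    # repeat calls are served from the SQLite cache and skip the sleep.
    data = client.get_json("https://api.example.com/items", params={"q": "example"})
    print(client.cache_stats())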
src/usher_pipeline/api_clients/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .base import CachedAPIClient

__all__ = ["CachedAPIClient"]

src/usher_pipeline/api_clients/base.py (new file, 207 lines)
@@ -0,0 +1,207 @@
"""Base API client with retry logic and persistent caching."""

import logging
import time
from pathlib import Path
from typing import Any

import requests
import requests_cache
from requests.exceptions import ConnectionError, HTTPError, Timeout
from tenacity import (
    retry,
    retry_if_exception,
    stop_after_attempt,
    wait_exponential,
)

from usher_pipeline.config.schema import PipelineConfig

logger = logging.getLogger(__name__)


class CachedAPIClient:
    """
    HTTP client with rate limiting, retry logic, and persistent SQLite caching.

    Features:
    - Automatic retry on 429/5xx/network errors with exponential backoff
    - Persistent SQLite cache with configurable TTL
    - Rate limiting to avoid overwhelming APIs
    - Cache statistics tracking
    """

    def __init__(
        self,
        cache_dir: Path,
        rate_limit: int = 5,
        max_retries: int = 5,
        cache_ttl: int = 86400,
        timeout: int = 30,
    ):
        """
        Initialize API client with caching and retry logic.

        Args:
            cache_dir: Directory for SQLite cache storage
            rate_limit: Maximum requests per second
            max_retries: Maximum retry attempts on failure
            cache_ttl: Cache time-to-live in seconds (0 = infinite)
            timeout: Request timeout in seconds
        """
        self.cache_dir = Path(cache_dir)
        self.rate_limit = rate_limit
        self.max_retries = max_retries
        self.timeout = timeout

        # Create cache directory
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Initialize requests_cache session
        cache_path = self.cache_dir / "api_cache"
        expire_after = cache_ttl if cache_ttl > 0 else None

        self.session = requests_cache.CachedSession(
            cache_name=str(cache_path),
            backend="sqlite",
            expire_after=expire_after,
        )

    def _should_rate_limit(self, response: requests.Response) -> bool:
        """Check if response came from cache (no rate limit needed)."""
        return not getattr(response, "from_cache", False)

    @staticmethod
    def _is_retryable_error(exc: BaseException) -> bool:
        """Return True for errors worth retrying: network failures and 429/5xx."""
        if isinstance(exc, (Timeout, ConnectionError)):
            return True
        if isinstance(exc, HTTPError) and exc.response is not None:
            status = exc.response.status_code
            return status == 429 or status >= 500
        return False

    def _create_retry_decorator(self):
        """Create retry decorator with exponential backoff."""
        return retry(
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=2, max=60),
            # Retry only 429/5xx and network errors; other 4xx raise immediately
            retry=retry_if_exception(self._is_retryable_error),
            reraise=True,
        )

    def get(
        self,
        url: str,
        params: dict[str, Any] | None = None,
        **kwargs,
    ) -> requests.Response:
        """
        Make GET request with retry logic and caching.

        Args:
            url: Request URL
            params: Query parameters
            **kwargs: Additional arguments passed to requests

        Returns:
            Response object

        Raises:
            HTTPError: On HTTP error after retries exhausted
            Timeout: On timeout after retries exhausted
            ConnectionError: On connection error after retries exhausted
        """
        # Apply retry decorator dynamically so it reflects instance settings
        @self._create_retry_decorator()
        def _get_with_retry():
            response = self.session.get(
                url,
                params=params,
                timeout=self.timeout,
                **kwargs,
            )

            # Check for HTTP errors
            try:
                response.raise_for_status()
            except HTTPError:
                # Log warning for rate limiting
                if response.status_code == 429:
                    logger.warning(
                        f"Rate limited by API (429). "
                        f"URL: {url}. Will retry with backoff."
                    )
                raise

            return response

        # Make request with retry
        response = _get_with_retry()

        # Rate limit only non-cached requests
        if self._should_rate_limit(response):
            time.sleep(1 / self.rate_limit)

        return response

    def get_json(
        self,
        url: str,
        params: dict[str, Any] | None = None,
        **kwargs,
    ) -> dict[str, Any]:
        """
        Make GET request and return JSON response.

        Args:
            url: Request URL
            params: Query parameters
            **kwargs: Additional arguments passed to requests

        Returns:
            Parsed JSON response as dict

        Raises:
            HTTPError: On HTTP error
            JSONDecodeError: If response is not valid JSON
        """
        response = self.get(url, params=params, **kwargs)
        return response.json()

    @classmethod
    def from_config(cls, config: PipelineConfig) -> "CachedAPIClient":
        """
        Create client from pipeline configuration.

        Args:
            config: PipelineConfig instance

        Returns:
            Configured CachedAPIClient instance
        """
        return cls(
            cache_dir=config.cache_dir,
            rate_limit=config.api.rate_limit_per_second,
            max_retries=config.api.max_retries,
            cache_ttl=config.api.cache_ttl_seconds,
            timeout=config.api.timeout_seconds,
        )

    def clear_cache(self) -> None:
        """Clear all cached responses."""
        self.session.cache.clear()
        logger.info("API cache cleared")

    def cache_stats(self) -> dict[str, Any]:
        """
        Get basic cache statistics.

        Returns:
            Dictionary describing cache location, existence, and size
        """
        # requests_cache doesn't provide built-in hit/miss stats,
        # so we return basic info about cache state
        cache_path = self.cache_dir / "api_cache.sqlite"

        stats = {
            "cache_enabled": True,
            "cache_path": str(cache_path),
            "cache_exists": cache_path.exists(),
        }

        # Get cache size if it exists
        if cache_path.exists():
            stats["cache_size_bytes"] = cache_path.stat().st_size

        return stats

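For reference, a sketch of the backoff schedule these tenacity settings produce
(assuming tenacity's default exp_base of 2; waits are clamped to the min/max window):

    # wait_exponential(multiplier=1, min=2, max=60) + stop_after_attempt(5):
    #   attempt 1 fails -> wait ~2 s
    #   attempt 2 fails -> wait ~4 s
    #   attempt 3 fails -> wait ~8 s
    #   attempt 4 fails -> wait ~16 s
    #   attempt 5 fails -> exception re-raised (reraise=True)
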
tests/test_api_client.py (new file, 149 lines)
@@ -0,0 +1,149 @@
"""Tests for API client with caching and retry logic."""

from unittest.mock import Mock, patch

import pytest

from usher_pipeline.api_clients.base import CachedAPIClient
from usher_pipeline.config import load_config


def test_client_creates_cache_dir(tmp_path):
    """Test that client creates cache directory if it doesn't exist."""
    cache_dir = tmp_path / "nonexistent_cache"

    # Directory should not exist before creating client
    assert not cache_dir.exists()

    # Creating the client should create the directory
    CachedAPIClient(cache_dir=cache_dir)

    assert cache_dir.exists()
    assert cache_dir.is_dir()


def test_client_caches_response(tmp_path):
    """Test that the client handles both fresh and cache-flagged responses."""
    cache_dir = tmp_path / "cache"
    client = CachedAPIClient(cache_dir=cache_dir, rate_limit=100)

    test_url = "https://api.example.com/test"
    mock_response_data = {"data": "test"}

    # Mock the underlying session.get method
    with patch.object(client.session, "get") as mock_get:
        # First response simulates a network hit, second a cache hit
        mock_response_1 = Mock()
        mock_response_1.status_code = 200
        mock_response_1.json.return_value = mock_response_data
        mock_response_1.from_cache = False
        mock_response_1.raise_for_status = Mock()

        mock_response_2 = Mock()
        mock_response_2.status_code = 200
        mock_response_2.json.return_value = mock_response_data
        mock_response_2.from_cache = True
        mock_response_2.raise_for_status = Mock()

        # First call: not from cache
        mock_get.return_value = mock_response_1
        response_1 = client.get(test_url)
        assert response_1.status_code == 200

        # Second call: from cache
        mock_get.return_value = mock_response_2
        response_2 = client.get(test_url)
        assert response_2.status_code == 200

        # Verify both calls were made to session.get
        assert mock_get.call_count == 2


def test_client_from_config(tmp_path):
    """Test creating client from PipelineConfig."""
    # Create a test config file
    config_file = tmp_path / "test_config.yaml"
    config_file.write_text(f"""
data_dir: {tmp_path / "data"}
cache_dir: {tmp_path / "cache"}
duckdb_path: {tmp_path / "test.duckdb"}
versions:
  ensembl_release: 113
  gnomad_version: v4.1
api:
  rate_limit_per_second: 10
  max_retries: 3
  cache_ttl_seconds: 3600
  timeout_seconds: 60
scoring:
  gnomad: 0.20
  expression: 0.20
  annotation: 0.15
  localization: 0.15
  animal_model: 0.15
  literature: 0.15
""")

    # Load config and create client
    config = load_config(config_file)
    client = CachedAPIClient.from_config(config)

    # Verify settings were applied
    assert client.rate_limit == 10
    assert client.max_retries == 3
    assert client.timeout == 60
    assert client.cache_dir == tmp_path / "cache"


def test_rate_limit_respected(tmp_path):
    """Test that rate limiting sleeps between non-cached requests."""
    cache_dir = tmp_path / "cache"
    client = CachedAPIClient(cache_dir=cache_dir, rate_limit=10)

    test_url = "https://api.example.com/test"

    with patch("time.sleep") as mock_sleep, patch.object(
        client.session, "get"
    ) as mock_get:
        # Configure mock to return a non-cached response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.from_cache = False
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response

        # Make request
        client.get(test_url)

        # Verify sleep was called with the correct rate-limit interval
        mock_sleep.assert_called_once()
        # Rate limit of 10 req/sec means 1/10 = 0.1 seconds between requests
        assert mock_sleep.call_args[0][0] == pytest.approx(0.1)


def test_rate_limit_skipped_for_cached(tmp_path):
    """Test that cached responses don't trigger the rate-limiting sleep."""
    cache_dir = tmp_path / "cache"
    client = CachedAPIClient(cache_dir=cache_dir, rate_limit=10)

    test_url = "https://api.example.com/test"

    with patch("time.sleep") as mock_sleep, patch.object(
        client.session, "get"
    ) as mock_get:
        # Configure mock to return a cached response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.from_cache = True
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response

        # Make request
        client.get(test_url)

        # Verify sleep was NOT called for the cached response
        mock_sleep.assert_not_called()
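
The suite runs under pytest in the usual way (command assumed, not part of the commit):

    pytest tests/test_api_client.py -v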