import json
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass


class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Increase timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
            "options": {
                "temperature": temperature,
            },
        }

        # Retry logic for larger models that may return empty responses
        max_retries = 3
        response_text = ""
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()
            result = response.json()
            response_text = result.get("response", "")

            # Check if the response is valid (not empty or just "{}")
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            # If empty, retry with a slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(
                    temperature + 0.1 * (attempt + 1), 1.0
                )

        # Return whatever we got on the last attempt
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"
        response = await self.client.get(url)
        response.raise_for_status()
        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()


class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Increase timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        # Match the LLMProvider signature and forward temperature to the API
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }
        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = await self.client.get(url, headers=headers)
        response.raise_for_status()
        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()


def extract_json_from_response(response: str) -> dict:
    """Extract JSON from an LLM response, handling markdown code blocks and extra whitespace."""
    # Strip a markdown code fence (``` or ```json) if present
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        json_str = json_match.group(1)
    else:
        json_str = response

    json_str = json_str.strip()
    # Remove whitespace immediately before closing braces/brackets
    json_str = re.sub(r"\s+([}\]])", r"\1", json_str)
    # Collapse tabs/newlines/carriage returns into single spaces
    json_str = re.sub(r"[\t\n\r]+", " ", json_str)

    return json.loads(json_str)


def parse_attribute_response(response: str) -> AttributeNode:
    """Parse an LLM response into an AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get an LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


ollama_provider = OllamaProvider()
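
# Usage sketch (commented out because this module uses relative imports and a
# module-level singleton, so it is not runnable as a script). It assumes an
# Ollama server reachable at settings.ollama_base_url and a model named by
# settings.default_model; the prompt and the demo() wrapper are hypothetical,
# not part of this package's API.
#
# import asyncio
#
# async def demo():
#     provider = get_llm_provider("ollama")
#     try:
#         print("available models:", await provider.list_models())
#         raw = await provider.generate(
#             prompt='Return JSON like {"name": "root", "children": []}',
#             temperature=0.2,
#         )
#         node = parse_attribute_response(raw)
#         print(node)
#     finally:
#         # Release the underlying httpx.AsyncClient
#         await provider.close()
#
# asyncio.run(demo())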