import json
import logging
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode

logger = logging.getLogger(__name__)


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass


class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Increase timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"

        # Remove /no_think prefix for non-qwen models (it's qwen-specific)
        clean_prompt = prompt
        if not model.lower().startswith("qwen") and prompt.startswith("/no_think"):
            clean_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "")
            logger.info(f"Removed /no_think prefix for model {model}")

        # Models known to support JSON format well
        json_capable_models = ["qwen", "llama", "mistral", "gemma", "phi"]
        model_lower = model.lower()
        use_json_format = any(m in model_lower for m in json_capable_models)

        payload = {
            "model": model,
            "prompt": clean_prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
            },
        }

        # Only use format: json for models that support it
        if use_json_format:
            payload["format"] = "json"
        else:
            logger.info(
                f"Model {model} may not support JSON format, "
                "requesting without format constraint"
            )

        # Retry logic for larger models that may return empty responses
        max_retries = 3
        response_text = ""
        for attempt in range(max_retries):
            logger.info(f"LLM request attempt {attempt + 1}/{max_retries} to model {model}")
            response = await self.client.post(url, json=payload)
            response.raise_for_status()

            result = response.json()
            response_text = result.get("response", "")
            logger.info(
                "LLM response (first 500 chars): "
                f"{response_text[:500] if response_text else '(empty)'}"
            )

            # Check if response is valid (not empty or just "{}")
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            logger.warning(f"Empty or invalid response on attempt {attempt + 1}, retrying...")
            # If empty, retry with slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(
                    temperature + 0.1 * (attempt + 1), 1.0
                )

        # Return whatever we got on the last attempt
        logger.error(f"All {max_retries} attempts returned empty response from model {model}")
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"
        response = await self.client.get(url)
        response.raise_for_status()
        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()


class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Increase timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }

        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = await self.client.get(url, headers=headers)
        response.raise_for_status()
        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()


def extract_json_from_response(response: str) -> dict:
    """Extract JSON from LLM response, handling markdown code blocks and extra whitespace."""
    if not response or not response.strip():
        logger.error("LLM returned empty response")
        raise ValueError(
            "LLM returned empty response - the model may not support JSON format "
            "or the prompt was unclear"
        )

    # Try multiple extraction strategies
    extraction_attempts = []

    # Strategy 1: Look for markdown code blocks
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        extraction_attempts.append(json_match.group(1))

    # Strategy 2: Look for JSON object pattern { ... }
    json_obj_match = re.search(r"(\{[\s\S]*\})", response)
    if json_obj_match:
        extraction_attempts.append(json_obj_match.group(1))

    # Strategy 3: Original response
    extraction_attempts.append(response)

    # Try each extraction attempt
    for attempt_str in extraction_attempts:
        # Clean up: remove extra whitespace
        cleaned = attempt_str.strip()
        # Remove trailing whitespace before closing braces/brackets
        cleaned = re.sub(r"\s+([}\]])", r"\1", cleaned)
        # Normalize tabs and carriage returns but keep newline structure
        cleaned = re.sub(r"[\t\r]+", " ", cleaned)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            continue

    # All attempts failed
    logger.error("Failed to parse JSON from response")
    logger.error(f"Raw response: {response[:1000]}")
    raise ValueError(
        "Failed to parse LLM response as JSON. The model may not support "
        f"structured output. Raw response: {response[:300]}..."
    )


def parse_attribute_response(response: str) -> AttributeNode:
    """Parse LLM response into AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


# Module-level default provider instance
ollama_provider = OllamaProvider()