"""LLM provider abstractions for Ollama and OpenAI-compatible backends."""
import json
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional
|
|
|
|
import httpx
|
|
|
|
from ..config import settings
|
|
from ..models.schemas import AttributeNode
|
|
|
|
|
|
class LLMProvider(ABC):
    """Abstract interface every LLM backend must implement."""

    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        ...

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        ...
|
|
|
|
|
|
class OllamaProvider(LLMProvider):
    """LLM provider backed by an Ollama server's HTTP API."""

    def __init__(self, base_url: Optional[str] = None):
        """Create a client for *base_url* (defaults to settings.ollama_base_url)."""
        self.base_url = base_url or settings.ollama_base_url
        # Generous timeout: larger models (14B, 32B, ...) can be slow to answer.
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a JSON-formatted completion for *prompt*.

        Retries up to 3 times with a slightly increased temperature when the
        model returns an empty or trivial ("{}") response, which larger
        models occasionally do with `format: json`.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response from the server.
        """
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"

        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",  # ask Ollama to constrain output to valid JSON
            "options": {
                "temperature": temperature,
            },
        }

        # Retry logic for larger models that may return empty responses.
        max_retries = 3
        response_text = ""  # defined even if the loop never yields a value
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()

            result = response.json()
            response_text = result.get("response", "")

            # Accept only a non-trivial response (not empty and not "{}").
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            # If empty, retry with a slightly higher temperature to nudge the
            # model out of the degenerate output (capped at 1.0).
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(
                    temperature + 0.1 * (attempt + 1), 1.0
                )

        # Last attempt was still trivial: return whatever we got.
        return response_text

    async def list_models(self) -> List[str]:
        """Return the names of the models installed on the Ollama server."""
        url = f"{self.base_url}/api/tags"

        response = await self.client.get(url)
        response.raise_for_status()

        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()
|
|
|
|
|
|
class OpenAICompatibleProvider(LLMProvider):
    """LLM provider for any OpenAI-compatible ``/v1`` HTTP API."""

    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        """Create a client; arguments default to the configured settings."""
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Generous timeout: larger models can be slow to answer.
        self.client = httpx.AsyncClient(timeout=300.0)

    def _auth_headers(self) -> dict:
        """Build the Authorization header when an API key is configured."""
        if self.api_key:
            return {"Authorization": f"Bearer {self.api_key}"}
        return {}

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a chat completion for *prompt*.

        The ``temperature`` parameter matches the LLMProvider interface (the
        previous signature omitted it, so calling through the base class with
        a temperature raised TypeError) and is forwarded to the API.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
            KeyError: if the response body lacks the expected structure.
        """
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }

        response = await self.client.post(
            url, json=payload, headers=self._auth_headers()
        )
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        """Return the model ids exposed by the ``/v1/models`` endpoint."""
        url = f"{self.base_url}/v1/models"

        response = await self.client.get(url, headers=self._auth_headers())
        response.raise_for_status()

        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()
|
|
|
|
|
|
def extract_json_from_response(response: str) -> dict:
    """Extract and parse JSON from an LLM response.

    Handles responses wrapped in markdown code fences (``` or ```json) and,
    only as a fallback for malformed model output, collapses stray
    whitespace before re-parsing.

    Raises:
        json.JSONDecodeError: if no parseable JSON can be recovered.
    """
    # Prefer the contents of a fenced code block when one is present.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    json_str = (fence.group(1) if fence else response).strip()

    try:
        # Parse as-is first: the cleanup below is lossy — it can alter
        # whitespace *inside* string values — so it must not run on
        # already-valid JSON.
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass

    # Rescue pass for sloppy output: drop whitespace before closing
    # braces/brackets and flatten tabs/newlines to single spaces.
    cleaned = re.sub(r'\s+([}\]])', r'\1', json_str)
    cleaned = re.sub(r'[\t\n\r]+', ' ', cleaned)
    return json.loads(cleaned)
|
|
|
|
|
|
def parse_attribute_response(response: str) -> AttributeNode:
    """Parse an LLM response into a validated AttributeNode."""
    return AttributeNode.model_validate(extract_json_from_response(response))
|
|
|
|
|
|
def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory returning the provider instance for *provider_type*.

    Raises:
        ValueError: if *provider_type* is not a known provider.
    """
    factories = {
        "ollama": OllamaProvider,
        "openai": OpenAICompatibleProvider,
    }
    try:
        return factories[provider_type]()
    except KeyError:
        raise ValueError(f"Unknown provider type: {provider_type}") from None
|
|
|
|
|
|
# Module-level default provider instance, created at import time.
ollama_provider = OllamaProvider()
|