# novelty-seeking/backend/app/services/llm_service.py
import json
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass

class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Generous timeout so larger models (14B, 32B, etc.) have time to respond.
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
            "options": {
                "temperature": temperature,
            },
        }
        # Larger models occasionally return an empty body; retry a few times.
        max_retries = 3
        response_text = ""
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()
            result = response.json()
            response_text = result.get("response", "")
            # Accept anything that is not empty or a bare "{}".
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text
            # Otherwise retry with a slightly higher temperature
            # (0.7 -> 0.8 -> 0.9 by default, capped at 1.0).
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(
                    temperature + 0.1 * (attempt + 1), 1.0
                )
        # All retries exhausted: return whatever the last attempt produced.
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"
        response = await self.client.get(url)
        response.raise_for_status()
        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()

class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Generous timeout for larger models.
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        # Accept `temperature` to match the LLMProvider interface and forward it.
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }
        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = await self.client.get(url, headers=headers)
        response.raise_for_status()
        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()
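
# Hedged aside: because this provider only assumes the OpenAI wire format,
# it can also point at OpenAI-compatible local servers. The URL below is an
# illustrative assumption, not project configuration:
#
#   provider = OpenAICompatibleProvider(base_url="http://localhost:8000")
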
def extract_json_from_response(response: str) -> dict:
    """Extract JSON from an LLM response, handling markdown code fences and extra whitespace."""
    # Pull the contents out of a markdown code fence (``` or ```json) if present.
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    json_str = json_match.group(1) if json_match else response
    json_str = json_str.strip()
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        pass
    # Fallback cleanup for near-JSON output: drop whitespace before closing
    # braces/brackets and collapse tabs/newlines into single spaces. This can
    # alter whitespace inside string values, so it only runs after strict
    # parsing has failed.
    json_str = re.sub(r"\s+([}\]])", r"\1", json_str)
    json_str = re.sub(r"[\t\n\r]+", " ", json_str)
    return json.loads(json_str)
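
# Illustrative sketch (not part of the service): the two input shapes the
# helper above is written to absorb. The sample strings are hypothetical.
#
#   extract_json_from_response('```json\n{"name": "color"}\n```')
#   # -> {'name': 'color'}
#
#   extract_json_from_response('  {"name": "color"}  ')
#   # -> {'name': 'color'}
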
def parse_attribute_response(response: str) -> AttributeNode:
    """Parse LLM response into AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


ollama_provider = OllamaProvider()
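

# Minimal usage sketch, assuming an Ollama server is reachable at
# settings.ollama_base_url and that settings.default_model names an
# installed model. The prompt string is a made-up example.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        provider = get_llm_provider("ollama")
        try:
            print("models:", await provider.list_models())
            raw = await provider.generate(
                'Return a JSON object with a "name" and a "children" list.'
            )
            print("parsed:", extract_json_from_response(raw))
        finally:
            await provider.close()

    asyncio.run(_demo())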