Initial commit
backend/app/services/__init__.py: 0 additions (Normal file)
backend/app/services/llm_service.py: 160 additions (Normal file)
@@ -0,0 +1,160 @@
import json
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass


class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Increase timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"

        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
            "options": {
                "temperature": temperature,
            },
        }

        # Retry logic for larger models that may return empty responses
        max_retries = 3
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()

            result = response.json()
            response_text = result.get("response", "")

            # Check if response is valid (not empty or just "{}")
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            # If empty, retry with slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)

        # Return whatever we got on last attempt
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"

        response = await self.client.get(url)
        response.raise_for_status()

        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()


class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Increase timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }

        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        response = await self.client.get(url, headers=headers)
        response.raise_for_status()

        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()


def extract_json_from_response(response: str) -> dict:
    """Extract JSON from LLM response, handling markdown code blocks and extra whitespace."""
    # Remove markdown code blocks if present
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        json_str = json_match.group(1)
    else:
        json_str = response

    # Clean up: remove extra whitespace, normalize spaces
    json_str = json_str.strip()
    # Remove trailing whitespace before closing braces/brackets
    json_str = re.sub(r'\s+([}\]])', r'\1', json_str)
    # Remove multiple spaces/tabs/newlines
    json_str = re.sub(r'[\t\n\r]+', ' ', json_str)

    return json.loads(json_str)


def parse_attribute_response(response: str) -> AttributeNode:
    """Parse LLM response into AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


ollama_provider = OllamaProvider()
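
For reference, a minimal usage sketch of the new service (not part of the commit). It assumes an async entry point, a reachable Ollama server configured via settings, a prompt that asks for JSON matching the AttributeNode schema, and an import path guessed from the file layout; adjust names to the real app.

import asyncio

from backend.app.services.llm_service import get_llm_provider, parse_attribute_response


async def main():
    # Pick a provider via the factory; "ollama" or "openai" are the supported types.
    provider = get_llm_provider("ollama")
    try:
        # Hypothetical prompt; the real prompts live elsewhere in the app.
        raw = await provider.generate("Describe a 'laptop' as a JSON attribute tree.")
        # extract_json_from_response + Pydantic validation into AttributeNode
        node = parse_attribute_response(raw)
        print(node)
    finally:
        # close() is defined on both concrete providers, not on the ABC.
        await provider.close()


asyncio.run(main())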