Initial commit

This commit is contained in:
2025-12-02 02:06:51 +08:00
commit eb6c0c51fa
37 changed files with 7454 additions and 0 deletions

View File

View File

@@ -0,0 +1,160 @@
import json
import re
from abc import ABC, abstractmethod
from typing import List, Optional
import httpx
from ..config import settings
from ..models.schemas import AttributeNode
class LLMProvider(ABC):
    """Abstract interface implemented by every chat/completion backend."""

    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Return the raw text completion for *prompt*.

        Args:
            prompt: Full prompt text to send to the backend.
            model: Backend model identifier; implementations fall back to a
                configured default when omitted.
            temperature: Sampling temperature forwarded to the backend.
        """
        ...

    @abstractmethod
    async def list_models(self) -> List[str]:
        """Return the identifiers of all models the backend exposes."""
        ...
class OllamaProvider(LLMProvider):
    """LLM provider backed by an Ollama server's HTTP API."""

    def __init__(self, base_url: Optional[str] = None):
        """Create a provider.

        Args:
            base_url: Root URL of the Ollama server; defaults to
                ``settings.ollama_base_url`` when omitted.
        """
        self.base_url = base_url or settings.ollama_base_url
        # Generous timeout: larger models (14B, 32B, ...) can take minutes
        # to produce a full non-streamed response.
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Call ``/api/generate`` and return the response text.

        Requests JSON-formatted, non-streamed output. Larger models
        occasionally return an empty body or a bare ``{}``; in that case the
        request is retried up to two more times with a slightly higher
        temperature, and whatever the final attempt produced is returned.

        Args:
            prompt: Full prompt text.
            model: Model name; defaults to ``settings.default_model``.
            temperature: Initial sampling temperature.

        Returns:
            The model's response text (possibly empty if every retry failed).

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
            "options": {
                "temperature": temperature,
            },
        }
        # Retry logic for larger models that may return empty responses.
        max_retries = 3
        response_text = ""
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()
            result = response.json()
            response_text = result.get("response", "")
            # Accept only a non-empty response that is not a bare empty object.
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text
            # Retry with slightly higher temperature to nudge the model.
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)
        # Last attempt still looked empty; return it anyway so callers can decide.
        return response_text

    async def list_models(self) -> List[str]:
        """Return the names of models installed on the server (``/api/tags``)."""
        url = f"{self.base_url}/api/tags"
        response = await self.client.get(url)
        response.raise_for_status()
        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        """Release the underlying HTTP client and its connections."""
        await self.client.aclose()
class OpenAICompatibleProvider(LLMProvider):
    """LLM provider for any server speaking the OpenAI chat-completions API."""

    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        """Create a provider.

        Args:
            base_url: API root; defaults to ``settings.openai_base_url``.
            api_key: Bearer token; defaults to ``settings.openai_api_key``.
                If empty/None, no Authorization header is sent.
        """
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Generous timeout: larger models can take minutes per response.
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Call ``/v1/chat/completions`` and return the assistant message text.

        Note: the ``temperature`` parameter was missing here even though the
        ``LLMProvider`` ABC declares it, so interface-level callers passing it
        would crash; it is now accepted and forwarded in the request payload.

        Args:
            prompt: Sent as a single user message.
            model: Model name; defaults to ``settings.default_model``.
            temperature: Sampling temperature forwarded to the server.

        Returns:
            The content of the first choice's message.

        Raises:
            httpx.HTTPStatusError: If the server returns an error status.
        """
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }
        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        """Return model ids reported by ``/v1/models``."""
        url = f"{self.base_url}/v1/models"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = await self.client.get(url, headers=headers)
        response.raise_for_status()
        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        """Release the underlying HTTP client and its connections."""
        await self.client.aclose()
def extract_json_from_response(response: str) -> dict:
    """Extract and parse the JSON document embedded in an LLM response.

    Handles responses wrapped in markdown code fences (``` or ```json) and
    surrounding whitespace. Parsing uses ``strict=False`` so literal control
    characters (e.g. raw newlines some models emit inside string values) are
    tolerated instead of raising.

    The previous regex-based "cleanup" (stripping whitespace before ``}``/``]``
    and flattening tabs/newlines to spaces) could silently corrupt whitespace
    *inside* string values; ``json.loads`` already ignores insignificant
    whitespace, so no normalization is needed.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        The parsed JSON document (typically a dict).

    Raises:
        json.JSONDecodeError: If the extracted text is not valid JSON.
    """
    # Prefer the contents of a fenced code block when one is present.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    json_str = fence.group(1) if fence else response
    return json.loads(json_str.strip(), strict=False)
def parse_attribute_response(response: str) -> AttributeNode:
    """Validate an LLM response as an ``AttributeNode``.

    The raw text is first reduced to its embedded JSON payload, which is then
    validated against the ``AttributeNode`` schema.

    Raises:
        json.JSONDecodeError: If the response contains no parseable JSON.
        pydantic.ValidationError: If the JSON does not match the schema.
    """
    payload = extract_json_from_response(response)
    return AttributeNode.model_validate(payload)
def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
"""Factory function to get LLM provider."""
if provider_type == "ollama":
return OllamaProvider()
elif provider_type == "openai":
return OpenAICompatibleProvider()
else:
raise ValueError(f"Unknown provider type: {provider_type}")
ollama_provider = OllamaProvider()