novelty-seeking/backend/app/services/llm_service.py
gbanyan 1ed1dab78f feat: Migrate to React Flow and add Fixed + Dynamic category mode
Frontend:
- Migrate MindmapDAG from D3.js to React Flow (@xyflow/react)
- Add custom node components (QueryNode, CategoryHeaderNode, AttributeNode)
- Add useDAGLayout hook for column-based layout
- Add "AI" badge for LLM-suggested categories
- Update CategorySelector with Fixed + Dynamic mode option
- Improve dark/light theme support

Backend:
- Add FIXED_PLUS_DYNAMIC category mode
- Filter duplicate category names in LLM suggestions
- Update prompts to exclude fixed categories when suggesting new ones
- Improve LLM service with better error handling and logging
- Auto-remove /no_think prefix for non-Qwen models
- Add smart JSON format detection for model compatibility
- Improve JSON extraction with multiple parsing strategies

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-03 01:22:57 +08:00

211 lines · 7.3 KiB · Python

import json
import logging
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode

logger = logging.getLogger(__name__)

class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass

class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Increase timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"

        # Remove /no_think prefix for non-qwen models (it's qwen-specific)
        clean_prompt = prompt
        if not model.lower().startswith("qwen") and prompt.startswith("/no_think"):
            clean_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "")
            logger.info(f"Removed /no_think prefix for model {model}")

        # Models known to support JSON format well
        json_capable_models = ["qwen", "llama", "mistral", "gemma", "phi"]
        model_lower = model.lower()
        use_json_format = any(m in model_lower for m in json_capable_models)

        payload = {
            "model": model,
            "prompt": clean_prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
            },
        }
        # Only use format: json for models that support it
        if use_json_format:
            payload["format"] = "json"
        else:
            logger.info(
                f"Model {model} may not support JSON format, requesting without format constraint"
            )

        # Retry logic for larger models that may return empty responses
        max_retries = 3
        for attempt in range(max_retries):
            logger.info(f"LLM request attempt {attempt + 1}/{max_retries} to model {model}")
            response = await self.client.post(url, json=payload)
            response.raise_for_status()
            result = response.json()
            response_text = result.get("response", "")
            logger.info(
                f"LLM response (first 500 chars): {response_text[:500] if response_text else '(empty)'}"
            )

            # Check if response is valid (not empty or just "{}")
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            logger.warning(f"Empty or invalid response on attempt {attempt + 1}, retrying...")
            # If empty, retry with slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)

        # Return whatever we got on the last attempt
        logger.error(f"All {max_retries} attempts returned empty response from model {model}")
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"
        response = await self.client.get(url)
        response.raise_for_status()
        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()

class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Increase timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        # Accept temperature to match the LLMProvider signature and pass it through to the API.
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }
        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = await self.client.get(url, headers=headers)
        response.raise_for_status()
        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()

def extract_json_from_response(response: str) -> dict:
    """Extract JSON from LLM response, handling markdown code blocks and extra whitespace."""
    if not response or not response.strip():
        logger.error("LLM returned empty response")
        raise ValueError(
            "LLM returned empty response - the model may not support JSON format or the prompt was unclear"
        )

    # Try multiple extraction strategies
    extraction_attempts = []

    # Strategy 1: Look for markdown code blocks
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        extraction_attempts.append(json_match.group(1))

    # Strategy 2: Look for JSON object pattern { ... }
    json_obj_match = re.search(r'(\{[\s\S]*\})', response)
    if json_obj_match:
        extraction_attempts.append(json_obj_match.group(1))

    # Strategy 3: Original response
    extraction_attempts.append(response)

    # Try each extraction attempt
    for attempt_str in extraction_attempts:
        # Clean up: remove extra whitespace, normalize spaces
        cleaned = attempt_str.strip()
        # Remove trailing whitespace before closing braces/brackets
        cleaned = re.sub(r'\s+([}\]])', r'\1', cleaned)
        # Normalize tabs and carriage returns but keep structure
        cleaned = re.sub(r'[\t\r]+', ' ', cleaned)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            continue

    # All attempts failed
    logger.error("Failed to parse JSON from response")
    logger.error(f"Raw response: {response[:1000]}")
    raise ValueError(
        f"Failed to parse LLM response as JSON. The model may not support structured output. "
        f"Raw response: {response[:300]}..."
    )

def parse_attribute_response(response: str) -> AttributeNode:
    """Parse LLM response into AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)

def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


ollama_provider = OllamaProvider()
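
A minimal usage sketch of this module (not part of the file above). The functions and signatures come from the file itself; the import path, the prompt text, and the expected JSON fields are assumptions for illustration and depend on the project layout and the AttributeNode schema.

import asyncio

from app.services.llm_service import get_llm_provider, parse_attribute_response


async def demo():
    # Provider choice mirrors get_llm_provider's factory options ("ollama" or "openai").
    provider = get_llm_provider("ollama")
    try:
        # Prompt content is illustrative only; real prompts are built elsewhere in the app.
        raw = await provider.generate(
            "Return a JSON object describing attributes of the topic 'coffee'.",
            temperature=0.3,
        )
        # Validates the extracted JSON against the AttributeNode schema.
        node = parse_attribute_response(raw)
        print(node)
    finally:
        await provider.close()


asyncio.run(demo())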