Frontend:
- Migrate MindmapDAG from D3.js to React Flow (@xyflow/react)
- Add custom node components (QueryNode, CategoryHeaderNode, AttributeNode)
- Add useDAGLayout hook for column-based layout
- Add "AI" badge for LLM-suggested categories
- Update CategorySelector with Fixed + Dynamic mode option
- Improve dark/light theme support

Backend:
- Add FIXED_PLUS_DYNAMIC category mode
- Filter duplicate category names in LLM suggestions
- Update prompts to exclude fixed categories when suggesting new ones
- Improve LLM service with better error handling and logging
- Auto-remove /no_think prefix for non-Qwen models
- Add smart JSON format detection for model compatibility
- Improve JSON extraction with multiple parsing strategies

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
import json
import logging
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode

logger = logging.getLogger(__name__)


class LLMProvider(ABC):
    """Abstract interface for LLM backends."""

    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass


class OllamaProvider(LLMProvider):
    """Provider backed by an Ollama server (/api/generate, /api/tags)."""

    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Increase timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"

        # Remove /no_think prefix for non-qwen models (it's qwen-specific)
        clean_prompt = prompt
        if not model.lower().startswith("qwen") and prompt.startswith("/no_think"):
            clean_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "")
            logger.info(f"Removed /no_think prefix for model {model}")

        # Models known to support JSON format well
        json_capable_models = ["qwen", "llama", "mistral", "gemma", "phi"]
        model_lower = model.lower()
        use_json_format = any(m in model_lower for m in json_capable_models)

        payload = {
            "model": model,
            "prompt": clean_prompt,
            "stream": False,
            "options": {
                "temperature": temperature,
            },
        }

        # Only use format: json for models that support it
        if use_json_format:
            payload["format"] = "json"
        else:
            logger.info(
                f"Model {model} may not support JSON format, requesting without format constraint"
            )

        # Retry logic for larger models that may return empty responses
        max_retries = 3
        response_text = ""
        for attempt in range(max_retries):
            logger.info(f"LLM request attempt {attempt + 1}/{max_retries} to model {model}")
            response = await self.client.post(url, json=payload)
            response.raise_for_status()

            result = response.json()
            response_text = result.get("response", "")

            logger.info(
                f"LLM response (first 500 chars): {response_text[:500] if response_text else '(empty)'}"
            )

            # Check if response is valid (not empty or just "{}")
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            logger.warning(f"Empty or invalid response on attempt {attempt + 1}, retrying...")

            # If empty, retry with slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)

        # Return whatever we got on the last attempt
        logger.error(f"All {max_retries} attempts returned empty response from model {model}")
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"

        response = await self.client.get(url)
        response.raise_for_status()

        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()

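# Illustrative behaviour of the model handling in OllamaProvider.generate (the model
# names below are examples, not a verified or exhaustive list):
#   model="qwen3:14b"        -> "/no_think" prefix kept, payload gets "format": "json"
#   model="llama3.1:8b"      -> "/no_think" prefix stripped, payload gets "format": "json"
#   model="deepseek-r1:32b"  -> "/no_think" prefix stripped, no "format" constraint;
#                               JSON is recovered later by extract_json_from_response().
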
class OpenAICompatibleProvider(LLMProvider):
    """Provider for OpenAI-compatible APIs (/v1/chat/completions, /v1/models)."""

    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Increase timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }

        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        response = await self.client.get(url, headers=headers)
        response.raise_for_status()

        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()


def extract_json_from_response(response: str) -> dict:
    """Extract JSON from LLM response, handling markdown code blocks and extra whitespace."""
    if not response or not response.strip():
        logger.error("LLM returned empty response")
        raise ValueError(
            "LLM returned empty response - the model may not support JSON format "
            "or the prompt was unclear"
        )

    # Try multiple extraction strategies
    extraction_attempts = []

    # Strategy 1: Look for markdown code blocks
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        extraction_attempts.append(json_match.group(1))

    # Strategy 2: Look for JSON object pattern { ... }
    json_obj_match = re.search(r"(\{[\s\S]*\})", response)
    if json_obj_match:
        extraction_attempts.append(json_obj_match.group(1))

    # Strategy 3: Original response
    extraction_attempts.append(response)

    # Try each extraction attempt
    for attempt_str in extraction_attempts:
        # Clean up: remove extra whitespace, normalize spaces
        cleaned = attempt_str.strip()
        # Remove trailing whitespace before closing braces/brackets
        cleaned = re.sub(r"\s+([}\]])", r"\1", cleaned)
        # Normalize tabs and carriage returns but keep structure
        cleaned = re.sub(r"[\t\r]+", " ", cleaned)

        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            continue

    # All attempts failed
    logger.error("Failed to parse JSON from response")
    logger.error(f"Raw response: {response[:1000]}")
    raise ValueError(
        "Failed to parse LLM response as JSON. The model may not support structured "
        f"output. Raw response: {response[:300]}..."
    )

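# Illustrative inputs that extract_json_from_response() is meant to handle (example
# payloads only, not taken from a test suite):
#   '```json\n{"name": "color", "children": []}\n```'  -> code-block strategy
#   'Here is the result: {"name": "color"}'            -> brace-pattern strategy
# The raw response itself is tried last as a fallback.
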
def parse_attribute_response(response: str) -> AttributeNode:
    """Parse LLM response into AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


ollama_provider = OllamaProvider()
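
# Minimal usage sketch (illustrative only, not part of the service API). It shows how a
# caller might combine get_llm_provider(), generate(), and parse_attribute_response().
# Assumptions: an Ollama server is reachable at settings.ollama_base_url, the module is
# run within its package (the relative imports above require it), and the prompt text
# below is invented for the example.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        provider = get_llm_provider("ollama")
        try:
            models = await provider.list_models()
            print("Available models:", models)

            raw = await provider.generate(
                'Return a JSON object with "name" and "children" describing '
                'attributes of the query "running shoes".',
                temperature=0.3,
            )
            attribute_tree = parse_attribute_response(raw)
            print(attribute_tree)
        finally:
            await provider.close()

    asyncio.run(_demo())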