feat: Migrate to React Flow and add Fixed + Dynamic category mode
Frontend: - Migrate MindmapDAG from D3.js to React Flow (@xyflow/react) - Add custom node components (QueryNode, CategoryHeaderNode, AttributeNode) - Add useDAGLayout hook for column-based layout - Add "AI" badge for LLM-suggested categories - Update CategorySelector with Fixed + Dynamic mode option - Improve dark/light theme support Backend: - Add FIXED_PLUS_DYNAMIC category mode - Filter duplicate category names in LLM suggestions - Update prompts to exclude fixed categories when suggesting new ones - Improve LLM service with better error handling and logging - Auto-remove /no_think prefix for non-Qwen models - Add smart JSON format detection for model compatibility - Improve JSON extraction with multiple parsing strategies 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -73,6 +73,7 @@ class CategoryMode(str, Enum):
|
||||
"""類別模式"""
|
||||
FIXED_ONLY = "fixed_only"
|
||||
FIXED_PLUS_CUSTOM = "fixed_plus_custom"
|
||||
FIXED_PLUS_DYNAMIC = "fixed_plus_dynamic" # Fixed + LLM suggested
|
||||
CUSTOM_ONLY = "custom_only"
|
||||
DYNAMIC_AUTO = "dynamic_auto"
|
||||
|
||||
@@ -98,3 +99,35 @@ class DynamicStep1Result(BaseModel):
|
||||
class DynamicCausalChain(BaseModel):
|
||||
"""動態版本的因果鏈"""
|
||||
chain: Dict[str, str] # {類別名: 選中屬性}
|
||||
|
||||
|
||||
# ===== DAG (Directed Acyclic Graph) schemas =====
|
||||
|
||||
class DAGNode(BaseModel):
|
||||
"""DAG 節點 - 每個屬性只出現一次"""
|
||||
id: str # 唯一 ID: "{category}_{index}"
|
||||
name: str # 顯示名稱
|
||||
category: str # 所屬類別
|
||||
order: int # 欄位內位置
|
||||
|
||||
|
||||
class DAGEdge(BaseModel):
|
||||
"""DAG 邊 - 節點之間的連接"""
|
||||
source_id: str
|
||||
target_id: str
|
||||
|
||||
|
||||
class AttributeDAG(BaseModel):
|
||||
"""完整 DAG 結構"""
|
||||
query: str
|
||||
categories: List[CategoryDefinition]
|
||||
nodes: List[DAGNode]
|
||||
edges: List[DAGEdge]
|
||||
|
||||
|
||||
class DAGRelationship(BaseModel):
|
||||
"""Step 2 輸出 - 單一關係"""
|
||||
source_category: str
|
||||
source: str # source attribute name
|
||||
target_category: str
|
||||
target: str # target attribute name
|
||||
|
||||
@@ -120,17 +120,26 @@ def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None
|
||||
|
||||
# ===== Dynamic category system prompts =====
|
||||
|
||||
def get_step0_category_analysis_prompt(query: str, suggested_count: int = 3) -> str:
|
||||
def get_step0_category_analysis_prompt(
|
||||
query: str,
|
||||
suggested_count: int = 3,
|
||||
exclude_categories: List[str] | None = None
|
||||
) -> str:
|
||||
"""Step 0: LLM 分析建議類別"""
|
||||
exclude_text = ""
|
||||
if exclude_categories:
|
||||
exclude_text = f"\n【禁止使用的類別】{', '.join(exclude_categories)}(這些已經是固定類別,不要重複建議)\n"
|
||||
|
||||
return f"""/no_think
|
||||
分析「{query}」,建議 {suggested_count} 個最適合的屬性類別來描述它。
|
||||
|
||||
【常見類別參考】材料、功能、用途、使用族群、特性、形狀、顏色、尺寸、品牌、價格區間
|
||||
|
||||
【常見類別參考】特性、形狀、顏色、尺寸、品牌、價格區間、重量、風格、場合、季節、技術規格
|
||||
{exclude_text}
|
||||
【重要】
|
||||
1. 選擇最能描述此物件本質的類別
|
||||
2. 類別之間應該有邏輯關係(如:材料→功能→用途)
|
||||
2. 類別之間應該有邏輯關係
|
||||
3. 不要選擇過於抽象或重複的類別
|
||||
4. 必須建議與參考列表不同的、有創意的類別
|
||||
|
||||
只回傳 JSON:
|
||||
{{
|
||||
@@ -213,3 +222,47 @@ def get_step2_dynamic_causal_chain_prompt(
|
||||
|
||||
只回傳 JSON:
|
||||
{json.dumps(json_template, ensure_ascii=False, indent=2)}"""
|
||||
|
||||
|
||||
# ===== DAG relationship prompt =====
|
||||
|
||||
def get_step2_dag_relationships_prompt(
|
||||
query: str,
|
||||
categories: List, # List[CategoryDefinition]
|
||||
attributes_by_category: Dict[str, List[str]],
|
||||
) -> str:
|
||||
"""生成相鄰類別之間的自然關係"""
|
||||
sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0))
|
||||
|
||||
# Build attribute listing
|
||||
attr_listing = "\n".join([
|
||||
f"【{cat.name if hasattr(cat, 'name') else cat['name']}】{', '.join(attributes_by_category.get(cat.name if hasattr(cat, 'name') else cat['name'], []))}"
|
||||
for cat in sorted_cats
|
||||
])
|
||||
|
||||
# Build direction hints
|
||||
direction_hints = " → ".join([cat.name if hasattr(cat, 'name') else cat['name'] for cat in sorted_cats])
|
||||
|
||||
return f"""/no_think
|
||||
分析「{query}」的屬性關係。
|
||||
|
||||
{attr_listing}
|
||||
|
||||
【關係方向】{direction_hints}
|
||||
|
||||
【規則】
|
||||
1. 只建立相鄰類別之間的關係(例如:材料→功能,功能→用途)
|
||||
2. 只輸出真正有因果或關聯關係的配對
|
||||
3. 一個屬性可連接多個下游屬性,也可以不連接任何屬性
|
||||
4. 不需要每個屬性都有連接
|
||||
5. 關係應該合理且有意義
|
||||
|
||||
回傳 JSON:
|
||||
{{
|
||||
"relationships": [
|
||||
{{"source_category": "類別A", "source": "屬性名", "target_category": "類別B", "target": "屬性名"}},
|
||||
...
|
||||
]
|
||||
}}
|
||||
|
||||
只回傳 JSON。"""
|
||||
|
||||
@@ -16,6 +16,10 @@ from ..models.schemas import (
|
||||
Step0Result,
|
||||
DynamicStep1Result,
|
||||
DynamicCausalChain,
|
||||
DAGNode,
|
||||
DAGEdge,
|
||||
AttributeDAG,
|
||||
DAGRelationship,
|
||||
)
|
||||
from ..prompts.attribute_prompt import (
|
||||
get_step1_attributes_prompt,
|
||||
@@ -23,6 +27,7 @@ from ..prompts.attribute_prompt import (
|
||||
get_step0_category_analysis_prompt,
|
||||
get_step1_dynamic_attributes_prompt,
|
||||
get_step2_dynamic_causal_chain_prompt,
|
||||
get_step2_dag_relationships_prompt,
|
||||
)
|
||||
from ..services.llm_service import ollama_provider, extract_json_from_response
|
||||
|
||||
@@ -39,14 +44,21 @@ FIXED_CATEGORIES = [
|
||||
]
|
||||
|
||||
|
||||
async def execute_step0(request: StreamAnalyzeRequest) -> Step0Result | None:
|
||||
"""Execute Step 0 - LLM category analysis"""
|
||||
if request.category_mode == CategoryMode.FIXED_ONLY:
|
||||
return None
|
||||
async def execute_step0(
|
||||
request: StreamAnalyzeRequest,
|
||||
exclude_categories: List[str] | None = None
|
||||
) -> Step0Result | None:
|
||||
"""Execute Step 0 - LLM category analysis
|
||||
|
||||
Called only for modes that need LLM to suggest categories:
|
||||
- FIXED_PLUS_DYNAMIC
|
||||
- CUSTOM_ONLY
|
||||
- DYNAMIC_AUTO
|
||||
"""
|
||||
prompt = get_step0_category_analysis_prompt(
|
||||
request.query,
|
||||
request.suggested_category_count
|
||||
request.suggested_category_count,
|
||||
exclude_categories=exclude_categories
|
||||
)
|
||||
temperature = request.temperature if request.temperature is not None else 0.7
|
||||
response = await ollama_provider.generate(
|
||||
@@ -83,6 +95,34 @@ def resolve_final_categories(
|
||||
)
|
||||
return categories
|
||||
|
||||
elif request.category_mode == CategoryMode.FIXED_PLUS_DYNAMIC:
|
||||
# Fixed categories + LLM suggested categories
|
||||
categories = [
|
||||
CategoryDefinition(
|
||||
name=cat.name,
|
||||
description=cat.description,
|
||||
is_fixed=True,
|
||||
order=i
|
||||
)
|
||||
for i, cat in enumerate(FIXED_CATEGORIES)
|
||||
]
|
||||
if step0_result:
|
||||
# Filter out LLM categories that duplicate fixed category names
|
||||
fixed_names = {cat.name for cat in FIXED_CATEGORIES}
|
||||
added_count = 0
|
||||
for cat in step0_result.categories:
|
||||
if cat.name not in fixed_names:
|
||||
categories.append(
|
||||
CategoryDefinition(
|
||||
name=cat.name,
|
||||
description=cat.description,
|
||||
is_fixed=False,
|
||||
order=len(FIXED_CATEGORIES) + added_count
|
||||
)
|
||||
)
|
||||
added_count += 1
|
||||
return categories
|
||||
|
||||
elif request.category_mode == CategoryMode.CUSTOM_ONLY:
|
||||
return step0_result.categories if step0_result else FIXED_CATEGORIES
|
||||
|
||||
@@ -192,17 +232,73 @@ def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeN
|
||||
return root
|
||||
|
||||
|
||||
def build_dag_from_relationships(
|
||||
query: str,
|
||||
categories: List[CategoryDefinition],
|
||||
attributes_by_category: dict,
|
||||
relationships: List[DAGRelationship],
|
||||
) -> AttributeDAG:
|
||||
"""從屬性和關係建構 DAG"""
|
||||
sorted_cats = sorted(categories, key=lambda x: x.order)
|
||||
|
||||
# 建立節點 - 每個屬性只出現一次
|
||||
nodes: List[DAGNode] = []
|
||||
node_id_map: dict = {} # (category, name) -> id
|
||||
|
||||
for cat in sorted_cats:
|
||||
cat_name = cat.name
|
||||
for idx, attr_name in enumerate(attributes_by_category.get(cat_name, [])):
|
||||
node_id = f"{cat_name}_{idx}"
|
||||
nodes.append(DAGNode(
|
||||
id=node_id,
|
||||
name=attr_name,
|
||||
category=cat_name,
|
||||
order=idx
|
||||
))
|
||||
node_id_map[(cat_name, attr_name)] = node_id
|
||||
|
||||
# 建立邊
|
||||
edges: List[DAGEdge] = []
|
||||
for rel in relationships:
|
||||
source_key = (rel.source_category, rel.source)
|
||||
target_key = (rel.target_category, rel.target)
|
||||
if source_key in node_id_map and target_key in node_id_map:
|
||||
edges.append(DAGEdge(
|
||||
source_id=node_id_map[source_key],
|
||||
target_id=node_id_map[target_key]
|
||||
))
|
||||
|
||||
return AttributeDAG(
|
||||
query=query,
|
||||
categories=sorted_cats,
|
||||
nodes=nodes,
|
||||
edges=edges
|
||||
)
|
||||
|
||||
|
||||
async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
|
||||
"""Generate SSE events with dynamic category support"""
|
||||
try:
|
||||
temperature = request.temperature if request.temperature is not None else 0.7
|
||||
|
||||
# ========== Step 0: Category Analysis (if needed) ==========
|
||||
# Only these modes need LLM category analysis
|
||||
needs_step0 = request.category_mode in [
|
||||
CategoryMode.FIXED_PLUS_DYNAMIC,
|
||||
CategoryMode.CUSTOM_ONLY,
|
||||
CategoryMode.DYNAMIC_AUTO,
|
||||
]
|
||||
|
||||
step0_result = None
|
||||
if request.category_mode != CategoryMode.FIXED_ONLY:
|
||||
if needs_step0:
|
||||
yield f"event: step0_start\ndata: {json.dumps({'message': '分析類別...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
step0_result = await execute_step0(request)
|
||||
# For FIXED_PLUS_DYNAMIC, exclude the fixed category names
|
||||
exclude_cats = None
|
||||
if request.category_mode == CategoryMode.FIXED_PLUS_DYNAMIC:
|
||||
exclude_cats = [cat.name for cat in FIXED_CATEGORIES]
|
||||
|
||||
step0_result = await execute_step0(request, exclude_categories=exclude_cats)
|
||||
|
||||
if step0_result:
|
||||
yield f"event: step0_complete\ndata: {json.dumps({'result': step0_result.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
@@ -227,58 +323,58 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
|
||||
|
||||
yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ========== Step 2: Generate Causal Chains (Dynamic) ==========
|
||||
causal_chains: List[DynamicCausalChain] = []
|
||||
# ========== Step 2: Generate Relationships (DAG) ==========
|
||||
yield f"event: relationships_start\ndata: {json.dumps({'message': '生成關係...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
for i in range(request.chain_count):
|
||||
chain_index = i + 1
|
||||
step2_prompt = get_step2_dag_relationships_prompt(
|
||||
query=request.query,
|
||||
categories=final_categories,
|
||||
attributes_by_category=step1_result.attributes,
|
||||
)
|
||||
logger.info(f"Step 2 (relationships) prompt: {step2_prompt[:300]}")
|
||||
|
||||
yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n"
|
||||
relationships: List[DAGRelationship] = []
|
||||
max_retries = 2
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
step2_response = await ollama_provider.generate(
|
||||
step2_prompt, model=request.model, temperature=temperature
|
||||
)
|
||||
logger.info(f"Relationships response: {step2_response[:500]}")
|
||||
|
||||
step2_prompt = get_step2_dynamic_causal_chain_prompt(
|
||||
query=request.query,
|
||||
categories=final_categories,
|
||||
attributes_by_category=step1_result.attributes,
|
||||
existing_chains=[c.chain for c in causal_chains],
|
||||
chain_index=chain_index,
|
||||
)
|
||||
rel_data = extract_json_from_response(step2_response)
|
||||
raw_relationships = rel_data.get("relationships", [])
|
||||
|
||||
# Gradually increase temperature for diversity
|
||||
chain_temperature = min(temperature + 0.05 * i, 1.0)
|
||||
for rel in raw_relationships:
|
||||
relationships.append(DAGRelationship(
|
||||
source_category=rel.get("source_category", ""),
|
||||
source=rel.get("source", ""),
|
||||
target_category=rel.get("target_category", ""),
|
||||
target=rel.get("target", ""),
|
||||
))
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Relationships attempt {attempt + 1} failed: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
temperature = min(temperature + 0.1, 1.0)
|
||||
|
||||
max_retries = 2
|
||||
chain = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
step2_response = await ollama_provider.generate(
|
||||
step2_prompt, model=request.model, temperature=chain_temperature
|
||||
)
|
||||
logger.info(f"Chain {chain_index} response: {step2_response[:300]}")
|
||||
yield f"event: relationships_complete\ndata: {json.dumps({'count': len(relationships)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
chain_data = extract_json_from_response(step2_response)
|
||||
chain = DynamicCausalChain(chain=chain_data)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
chain_temperature = min(chain_temperature + 0.1, 1.0)
|
||||
|
||||
if chain:
|
||||
causal_chains.append(chain)
|
||||
yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'生成失敗'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ========== Assemble Final Tree (Dynamic) ==========
|
||||
final_tree = assemble_dynamic_attribute_tree(request.query, causal_chains, final_categories)
|
||||
# ========== Build DAG ==========
|
||||
dag = build_dag_from_relationships(
|
||||
query=request.query,
|
||||
categories=final_categories,
|
||||
attributes_by_category=step1_result.attributes,
|
||||
relationships=relationships,
|
||||
)
|
||||
|
||||
final_result = {
|
||||
"query": request.query,
|
||||
"step0_result": step0_result.model_dump() if step0_result else None,
|
||||
"categories_used": [c.model_dump() for c in final_categories],
|
||||
"step1_result": step1_result.model_dump(),
|
||||
"causal_chains": [c.model_dump() for c in causal_chains],
|
||||
"attributes": final_tree.model_dump(),
|
||||
"relationships": [r.model_dump() for r in relationships],
|
||||
"dag": dag.model_dump(),
|
||||
}
|
||||
yield f"event: done\ndata: {json.dumps(final_result, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
@@ -8,6 +9,8 @@ import httpx
|
||||
from ..config import settings
|
||||
from ..models.schemas import AttributeNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
@abstractmethod
|
||||
@@ -35,34 +38,56 @@ class OllamaProvider(LLMProvider):
|
||||
model = model or settings.default_model
|
||||
url = f"{self.base_url}/api/generate"
|
||||
|
||||
# Remove /no_think prefix for non-qwen models (it's qwen-specific)
|
||||
clean_prompt = prompt
|
||||
if not model.lower().startswith("qwen") and prompt.startswith("/no_think"):
|
||||
clean_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "")
|
||||
logger.info(f"Removed /no_think prefix for model {model}")
|
||||
|
||||
# Models known to support JSON format well
|
||||
json_capable_models = ["qwen", "llama", "mistral", "gemma", "phi"]
|
||||
model_lower = model.lower()
|
||||
use_json_format = any(m in model_lower for m in json_capable_models)
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"prompt": clean_prompt,
|
||||
"stream": False,
|
||||
"format": "json",
|
||||
"options": {
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
|
||||
# Only use format: json for models that support it
|
||||
if use_json_format:
|
||||
payload["format"] = "json"
|
||||
else:
|
||||
logger.info(f"Model {model} may not support JSON format, requesting without format constraint")
|
||||
|
||||
# Retry logic for larger models that may return empty responses
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
logger.info(f"LLM request attempt {attempt + 1}/{max_retries} to model {model}")
|
||||
response = await self.client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
response_text = result.get("response", "")
|
||||
|
||||
logger.info(f"LLM response (first 500 chars): {response_text[:500] if response_text else '(empty)'}")
|
||||
|
||||
# Check if response is valid (not empty or just "{}")
|
||||
if response_text and response_text.strip() not in ["", "{}", "{ }"]:
|
||||
return response_text
|
||||
|
||||
logger.warning(f"Empty or invalid response on attempt {attempt + 1}, retrying...")
|
||||
|
||||
# If empty, retry with slightly higher temperature
|
||||
if attempt < max_retries - 1:
|
||||
payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)
|
||||
|
||||
# Return whatever we got on last attempt
|
||||
logger.error(f"All {max_retries} attempts returned empty response from model {model}")
|
||||
return response_text
|
||||
|
||||
async def list_models(self) -> List[str]:
|
||||
@@ -124,21 +149,46 @@ class OpenAICompatibleProvider(LLMProvider):
|
||||
|
||||
def extract_json_from_response(response: str) -> dict:
|
||||
"""Extract JSON from LLM response, handling markdown code blocks and extra whitespace."""
|
||||
# Remove markdown code blocks if present
|
||||
if not response or not response.strip():
|
||||
logger.error("LLM returned empty response")
|
||||
raise ValueError("LLM returned empty response - the model may not support JSON format or the prompt was unclear")
|
||||
|
||||
json_str = response
|
||||
|
||||
# Try multiple extraction strategies
|
||||
extraction_attempts = []
|
||||
|
||||
# Strategy 1: Look for markdown code blocks
|
||||
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
|
||||
if json_match:
|
||||
json_str = json_match.group(1)
|
||||
else:
|
||||
json_str = response
|
||||
extraction_attempts.append(json_match.group(1))
|
||||
|
||||
# Clean up: remove extra whitespace, normalize spaces
|
||||
json_str = json_str.strip()
|
||||
# Remove trailing whitespace before closing braces/brackets
|
||||
json_str = re.sub(r'\s+([}\]])', r'\1', json_str)
|
||||
# Remove multiple spaces/tabs/newlines
|
||||
json_str = re.sub(r'[\t\n\r]+', ' ', json_str)
|
||||
# Strategy 2: Look for JSON object pattern { ... }
|
||||
json_obj_match = re.search(r'(\{[\s\S]*\})', response)
|
||||
if json_obj_match:
|
||||
extraction_attempts.append(json_obj_match.group(1))
|
||||
|
||||
return json.loads(json_str)
|
||||
# Strategy 3: Original response
|
||||
extraction_attempts.append(response)
|
||||
|
||||
# Try each extraction attempt
|
||||
for attempt_str in extraction_attempts:
|
||||
# Clean up: remove extra whitespace, normalize spaces
|
||||
cleaned = attempt_str.strip()
|
||||
# Remove trailing whitespace before closing braces/brackets
|
||||
cleaned = re.sub(r'\s+([}\]])', r'\1', cleaned)
|
||||
# Normalize newlines but keep structure
|
||||
cleaned = re.sub(r'[\t\r]+', ' ', cleaned)
|
||||
|
||||
try:
|
||||
return json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# All attempts failed
|
||||
logger.error(f"Failed to parse JSON from response")
|
||||
logger.error(f"Raw response: {response[:1000]}")
|
||||
raise ValueError(f"Failed to parse LLM response as JSON. The model may not support structured output. Raw response: {response[:300]}...")
|
||||
|
||||
|
||||
def parse_attribute_response(response: str) -> AttributeNode:
|
||||
|
||||
Reference in New Issue
Block a user