feat: Add dynamic category system for attribute analysis
Backend: - Add CategoryMode enum with 4 modes (fixed_only, fixed_plus_custom, custom_only, dynamic_auto) - Add Step 0 for LLM category analysis before attribute generation - Implement dynamic prompts for Step 1/2 that work with N categories - Add execute_step0(), resolve_final_categories(), assemble_dynamic_attribute_tree() - Update SSE events to include step0_start, step0_complete, categories_resolved Frontend: - Add CategorySelector component with mode selection, custom category input, and category count slider - Update types with CategoryDefinition, Step0Result, DynamicStep1Result, DynamicCausalChain - Update api.ts with new SSE event handlers - Update useAttribute hook with category parameters - Integrate CategorySelector into InputPanel - Fix mindmap to dynamically extract and display N categories (was hardcoded to 4) - Add CSS styles for depth 5-8 to support more category levels 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AttributeNode(BaseModel):
|
||||
@@ -46,12 +47,17 @@ class CausalChain(BaseModel):
|
||||
|
||||
|
||||
class StreamAnalyzeRequest(BaseModel):
|
||||
"""多步驟分析請求"""
|
||||
"""多步驟分析請求(更新為支持動態類別)"""
|
||||
query: str
|
||||
model: Optional[str] = None
|
||||
temperature: Optional[float] = 0.7
|
||||
chain_count: int = 5 # 用戶可設定要生成多少條因果鏈
|
||||
|
||||
# 新增:動態類別支持
|
||||
category_mode: Optional[str] = "dynamic_auto" # CategoryMode enum 值
|
||||
custom_categories: Optional[List[str]] = None
|
||||
suggested_category_count: int = 3 # 建議 LLM 生成的類別數量
|
||||
|
||||
|
||||
class StreamAnalyzeResponse(BaseModel):
|
||||
"""最終完整結果"""
|
||||
@@ -59,3 +65,36 @@ class StreamAnalyzeResponse(BaseModel):
|
||||
step1_result: Step1Result
|
||||
causal_chains: List[CausalChain]
|
||||
attributes: AttributeNode
|
||||
|
||||
|
||||
# ===== Dynamic category system schemas =====
|
||||
|
||||
class CategoryMode(str, Enum):
|
||||
"""類別模式"""
|
||||
FIXED_ONLY = "fixed_only"
|
||||
FIXED_PLUS_CUSTOM = "fixed_plus_custom"
|
||||
CUSTOM_ONLY = "custom_only"
|
||||
DYNAMIC_AUTO = "dynamic_auto"
|
||||
|
||||
|
||||
class CategoryDefinition(BaseModel):
|
||||
"""類別定義"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
is_fixed: bool = True # LLM 生成的為 False
|
||||
order: int = 0
|
||||
|
||||
|
||||
class Step0Result(BaseModel):
|
||||
"""Step 0: LLM 分析建議類別"""
|
||||
categories: List[CategoryDefinition]
|
||||
|
||||
|
||||
class DynamicStep1Result(BaseModel):
|
||||
"""動態版本的 Step 1 結果"""
|
||||
attributes: Dict[str, List[str]] # {類別名: [屬性列表]}
|
||||
|
||||
|
||||
class DynamicCausalChain(BaseModel):
|
||||
"""動態版本的因果鏈"""
|
||||
chain: Dict[str, str] # {類別名: 選中屬性}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict
|
||||
import json
|
||||
|
||||
DEFAULT_CATEGORIES = ["材料", "功能", "用途", "使用族群", "特性"]
|
||||
|
||||
@@ -115,3 +116,100 @@ def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None
|
||||
用戶輸入:{query}"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
# ===== Dynamic category system prompts =====
|
||||
|
||||
def get_step0_category_analysis_prompt(query: str, suggested_count: int = 3) -> str:
|
||||
"""Step 0: LLM 分析建議類別"""
|
||||
return f"""/no_think
|
||||
分析「{query}」,建議 {suggested_count} 個最適合的屬性類別來描述它。
|
||||
|
||||
【常見類別參考】材料、功能、用途、使用族群、特性、形狀、顏色、尺寸、品牌、價格區間
|
||||
|
||||
【重要】
|
||||
1. 選擇最能描述此物件本質的類別
|
||||
2. 類別之間應該有邏輯關係(如:材料→功能→用途)
|
||||
3. 不要選擇過於抽象或重複的類別
|
||||
|
||||
只回傳 JSON:
|
||||
{{
|
||||
"categories": [
|
||||
{{"name": "類別1", "description": "說明1", "order": 0}},
|
||||
{{"name": "類別2", "description": "說明2", "order": 1}}
|
||||
]
|
||||
}}
|
||||
|
||||
物件:{query}"""
|
||||
|
||||
|
||||
def get_step1_dynamic_attributes_prompt(
|
||||
query: str,
|
||||
categories: List # List[CategoryDefinition]
|
||||
) -> str:
|
||||
"""動態 Step 1 - 根據類別列表生成屬性"""
|
||||
# 按 order 排序並構建描述
|
||||
sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0))
|
||||
|
||||
category_desc = "\n".join([
|
||||
f"- {cat.name if hasattr(cat, 'name') else cat['name']}: {cat.description if hasattr(cat, 'description') else cat.get('description', '相關屬性')}"
|
||||
for cat in sorted_cats
|
||||
])
|
||||
|
||||
category_keys = [cat.name if hasattr(cat, 'name') else cat['name'] for cat in sorted_cats]
|
||||
json_template = {cat: ["屬性1", "屬性2", "屬性3"] for cat in category_keys}
|
||||
|
||||
return f"""/no_think
|
||||
分析「{query}」,列出以下類別的屬性。每個類別列出 3-5 個常見屬性。
|
||||
|
||||
【類別列表】
|
||||
{category_desc}
|
||||
|
||||
只回傳 JSON:
|
||||
{json.dumps(json_template, ensure_ascii=False, indent=2)}
|
||||
|
||||
物件:{query}"""
|
||||
|
||||
|
||||
def get_step2_dynamic_causal_chain_prompt(
|
||||
query: str,
|
||||
categories: List, # List[CategoryDefinition]
|
||||
attributes_by_category: Dict[str, List[str]],
|
||||
existing_chains: List[Dict[str, str]],
|
||||
chain_index: int
|
||||
) -> str:
|
||||
"""動態 Step 2 - 生成動態類別的因果鏈"""
|
||||
sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0))
|
||||
|
||||
# 構建可選屬性
|
||||
available_attrs = "\n".join([
|
||||
f"【{cat.name if hasattr(cat, 'name') else cat['name']}】{', '.join(attributes_by_category.get(cat.name if hasattr(cat, 'name') else cat['name'], []))}"
|
||||
for cat in sorted_cats
|
||||
])
|
||||
|
||||
# 已生成的因果鏈
|
||||
existing_text = ""
|
||||
if existing_chains:
|
||||
chains_list = [
|
||||
" → ".join([chain.get(cat.name if hasattr(cat, 'name') else cat['name'], '?') for cat in sorted_cats])
|
||||
for chain in existing_chains
|
||||
]
|
||||
existing_text = f"\n【已生成,請勿重複】\n" + "\n".join([f"- {c}" for c in chains_list])
|
||||
|
||||
# JSON 模板
|
||||
json_template = {cat.name if hasattr(cat, 'name') else cat['name']: f"選擇的{cat.name if hasattr(cat, 'name') else cat['name']}" for cat in sorted_cats}
|
||||
|
||||
return f"""/no_think
|
||||
為「{query}」生成第 {chain_index} 條因果鏈。
|
||||
|
||||
【可選屬性】
|
||||
{available_attrs}
|
||||
{existing_text}
|
||||
|
||||
【規則】
|
||||
1. 從每個類別選擇一個屬性
|
||||
2. 因果關係必須合理
|
||||
3. 不要重複
|
||||
|
||||
只回傳 JSON:
|
||||
{json.dumps(json_template, ensure_ascii=False, indent=2)}"""
|
||||
|
||||
@@ -11,10 +11,18 @@ from ..models.schemas import (
|
||||
Step1Result,
|
||||
CausalChain,
|
||||
AttributeNode,
|
||||
CategoryMode,
|
||||
CategoryDefinition,
|
||||
Step0Result,
|
||||
DynamicStep1Result,
|
||||
DynamicCausalChain,
|
||||
)
|
||||
from ..prompts.attribute_prompt import (
|
||||
get_step1_attributes_prompt,
|
||||
get_step2_causal_chain_prompt,
|
||||
get_step0_category_analysis_prompt,
|
||||
get_step1_dynamic_attributes_prompt,
|
||||
get_step2_dynamic_causal_chain_prompt,
|
||||
)
|
||||
from ..services.llm_service import ollama_provider, extract_json_from_response
|
||||
|
||||
@@ -22,6 +30,117 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api", tags=["attributes"])
|
||||
|
||||
|
||||
# Fixed categories definition
|
||||
FIXED_CATEGORIES = [
|
||||
CategoryDefinition(name="材料", description="物件材料", is_fixed=True, order=0),
|
||||
CategoryDefinition(name="功能", description="物件功能", is_fixed=True, order=1),
|
||||
CategoryDefinition(name="用途", description="使用場景", is_fixed=True, order=2),
|
||||
CategoryDefinition(name="使用族群", description="目標用戶", is_fixed=True, order=3),
|
||||
]
|
||||
|
||||
|
||||
async def execute_step0(request: StreamAnalyzeRequest) -> Step0Result | None:
|
||||
"""Execute Step 0 - LLM category analysis"""
|
||||
if request.category_mode == CategoryMode.FIXED_ONLY:
|
||||
return None
|
||||
|
||||
prompt = get_step0_category_analysis_prompt(
|
||||
request.query,
|
||||
request.suggested_category_count
|
||||
)
|
||||
temperature = request.temperature if request.temperature is not None else 0.7
|
||||
response = await ollama_provider.generate(
|
||||
prompt, model=request.model, temperature=temperature
|
||||
)
|
||||
|
||||
data = extract_json_from_response(response)
|
||||
step0_result = Step0Result(**data)
|
||||
|
||||
# Mark as LLM generated
|
||||
for cat in step0_result.categories:
|
||||
cat.is_fixed = False
|
||||
|
||||
return step0_result
|
||||
|
||||
|
||||
def resolve_final_categories(
|
||||
request: StreamAnalyzeRequest,
|
||||
step0_result: Step0Result | None
|
||||
) -> List[CategoryDefinition]:
|
||||
"""Determine final categories based on mode"""
|
||||
if request.category_mode == CategoryMode.FIXED_ONLY:
|
||||
return FIXED_CATEGORIES
|
||||
|
||||
elif request.category_mode == CategoryMode.FIXED_PLUS_CUSTOM:
|
||||
categories = FIXED_CATEGORIES.copy()
|
||||
if request.custom_categories:
|
||||
for i, name in enumerate(request.custom_categories):
|
||||
categories.append(
|
||||
CategoryDefinition(
|
||||
name=name, is_fixed=False,
|
||||
order=len(FIXED_CATEGORIES) + i
|
||||
)
|
||||
)
|
||||
return categories
|
||||
|
||||
elif request.category_mode == CategoryMode.CUSTOM_ONLY:
|
||||
return step0_result.categories if step0_result else FIXED_CATEGORIES
|
||||
|
||||
elif request.category_mode == CategoryMode.DYNAMIC_AUTO:
|
||||
return step0_result.categories if step0_result else FIXED_CATEGORIES
|
||||
|
||||
return FIXED_CATEGORIES
|
||||
|
||||
|
||||
def assemble_dynamic_attribute_tree(
|
||||
query: str,
|
||||
chains: List[DynamicCausalChain],
|
||||
categories: List[CategoryDefinition]
|
||||
) -> AttributeNode:
|
||||
"""Assemble dynamic N-level tree from causal chains"""
|
||||
sorted_cats = sorted(categories, key=lambda x: x.order)
|
||||
|
||||
if not chains:
|
||||
return AttributeNode(name=query, children=[])
|
||||
|
||||
def build_recursive(
|
||||
level: int,
|
||||
parent_path: dict,
|
||||
remaining_chains: List[DynamicCausalChain]
|
||||
) -> List[AttributeNode]:
|
||||
if level >= len(sorted_cats):
|
||||
return []
|
||||
|
||||
current_cat = sorted_cats[level]
|
||||
grouped = {}
|
||||
|
||||
for chain in remaining_chains:
|
||||
# Check if this chain matches the parent path
|
||||
if all(chain.chain.get(k) == v for k, v in parent_path.items()):
|
||||
attr_val = chain.chain.get(current_cat.name)
|
||||
if attr_val:
|
||||
if attr_val not in grouped:
|
||||
grouped[attr_val] = []
|
||||
grouped[attr_val].append(chain)
|
||||
|
||||
nodes = []
|
||||
for attr_val, child_chains in grouped.items():
|
||||
new_path = {**parent_path, current_cat.name: attr_val}
|
||||
children = build_recursive(level + 1, new_path, child_chains)
|
||||
|
||||
node = AttributeNode(
|
||||
name=attr_val,
|
||||
category=current_cat.name,
|
||||
children=children if children else None
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
return nodes
|
||||
|
||||
root_children = build_recursive(0, {}, chains)
|
||||
return AttributeNode(name=query, children=root_children)
|
||||
|
||||
|
||||
def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
|
||||
"""將因果鏈組裝成樹狀結構"""
|
||||
# 以材料為第一層分組
|
||||
@@ -74,14 +193,28 @@ def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeN
|
||||
|
||||
|
||||
async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
|
||||
"""生成 SSE 事件流"""
|
||||
"""Generate SSE events with dynamic category support"""
|
||||
try:
|
||||
temperature = request.temperature if request.temperature is not None else 0.7
|
||||
|
||||
# ========== Step 1: 生成屬性列表 ==========
|
||||
yield f"event: step1_start\ndata: {json.dumps({'message': '正在分析屬性列表...'}, ensure_ascii=False)}\n\n"
|
||||
# ========== Step 0: Category Analysis (if needed) ==========
|
||||
step0_result = None
|
||||
if request.category_mode != CategoryMode.FIXED_ONLY:
|
||||
yield f"event: step0_start\ndata: {json.dumps({'message': '分析類別...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
step1_prompt = get_step1_attributes_prompt(request.query)
|
||||
step0_result = await execute_step0(request)
|
||||
|
||||
if step0_result:
|
||||
yield f"event: step0_complete\ndata: {json.dumps({'result': step0_result.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ========== Resolve Final Categories ==========
|
||||
final_categories = resolve_final_categories(request, step0_result)
|
||||
yield f"event: categories_resolved\ndata: {json.dumps({'categories': [c.model_dump() for c in final_categories]}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ========== Step 1: Generate Attributes (Dynamic) ==========
|
||||
yield f"event: step1_start\ndata: {json.dumps({'message': '生成屬性...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
step1_prompt = get_step1_dynamic_attributes_prompt(request.query, final_categories)
|
||||
logger.info(f"Step 1 prompt: {step1_prompt[:200]}")
|
||||
|
||||
step1_response = await ollama_provider.generate(
|
||||
@@ -90,29 +223,27 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
|
||||
logger.info(f"Step 1 response: {step1_response[:500]}")
|
||||
|
||||
step1_data = extract_json_from_response(step1_response)
|
||||
step1_result = Step1Result(**step1_data)
|
||||
step1_result = DynamicStep1Result(attributes=step1_data)
|
||||
|
||||
yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ========== Step 2: 逐條生成因果鏈 ==========
|
||||
causal_chains: List[CausalChain] = []
|
||||
# ========== Step 2: Generate Causal Chains (Dynamic) ==========
|
||||
causal_chains: List[DynamicCausalChain] = []
|
||||
|
||||
for i in range(request.chain_count):
|
||||
chain_index = i + 1
|
||||
|
||||
yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
step2_prompt = get_step2_causal_chain_prompt(
|
||||
step2_prompt = get_step2_dynamic_causal_chain_prompt(
|
||||
query=request.query,
|
||||
materials=step1_result.materials,
|
||||
functions=step1_result.functions,
|
||||
usages=step1_result.usages,
|
||||
users=step1_result.users,
|
||||
existing_chains=[c.model_dump() for c in causal_chains],
|
||||
categories=final_categories,
|
||||
attributes_by_category=step1_result.attributes,
|
||||
existing_chains=[c.chain for c in causal_chains],
|
||||
chain_index=chain_index,
|
||||
)
|
||||
|
||||
# 逐漸提高 temperature 增加多樣性
|
||||
# Gradually increase temperature for diversity
|
||||
chain_temperature = min(temperature + 0.05 * i, 1.0)
|
||||
|
||||
max_retries = 2
|
||||
@@ -125,7 +256,7 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
|
||||
logger.info(f"Chain {chain_index} response: {step2_response[:300]}")
|
||||
|
||||
chain_data = extract_json_from_response(step2_response)
|
||||
chain = CausalChain(**chain_data)
|
||||
chain = DynamicCausalChain(chain=chain_data)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}")
|
||||
@@ -136,13 +267,15 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
|
||||
causal_chains.append(chain)
|
||||
yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'第 {chain_index} 條因果鏈生成失敗'}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'生成失敗'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# ========== 組裝最終結構 ==========
|
||||
final_tree = assemble_attribute_tree(request.query, causal_chains)
|
||||
# ========== Assemble Final Tree (Dynamic) ==========
|
||||
final_tree = assemble_dynamic_attribute_tree(request.query, causal_chains, final_categories)
|
||||
|
||||
final_result = {
|
||||
"query": request.query,
|
||||
"step0_result": step0_result.model_dump() if step0_result else None,
|
||||
"categories_used": [c.model_dump() for c in final_categories],
|
||||
"step1_result": step1_result.model_dump(),
|
||||
"causal_chains": [c.model_dump() for c in causal_chains],
|
||||
"attributes": final_tree.model_dump(),
|
||||
|
||||
Reference in New Issue
Block a user