diff --git a/backend/app/models/schemas.py b/backend/app/models/schemas.py index 9c59ab1..fbf2f78 100644 --- a/backend/app/models/schemas.py +++ b/backend/app/models/schemas.py @@ -1,5 +1,6 @@ from pydantic import BaseModel -from typing import Optional, List +from typing import Optional, List, Dict +from enum import Enum class AttributeNode(BaseModel): @@ -46,12 +47,17 @@ class CausalChain(BaseModel): class StreamAnalyzeRequest(BaseModel): - """多步驟分析請求""" + """多步驟分析請求(更新為支持動態類別)""" query: str model: Optional[str] = None temperature: Optional[float] = 0.7 chain_count: int = 5 # 用戶可設定要生成多少條因果鏈 + # 新增:動態類別支持 + category_mode: Optional[str] = "dynamic_auto" # CategoryMode enum 值 + custom_categories: Optional[List[str]] = None + suggested_category_count: int = 3 # 建議 LLM 生成的類別數量 + class StreamAnalyzeResponse(BaseModel): """最終完整結果""" @@ -59,3 +65,36 @@ class StreamAnalyzeResponse(BaseModel): step1_result: Step1Result causal_chains: List[CausalChain] attributes: AttributeNode + + +# ===== Dynamic category system schemas ===== + +class CategoryMode(str, Enum): + """類別模式""" + FIXED_ONLY = "fixed_only" + FIXED_PLUS_CUSTOM = "fixed_plus_custom" + CUSTOM_ONLY = "custom_only" + DYNAMIC_AUTO = "dynamic_auto" + + +class CategoryDefinition(BaseModel): + """類別定義""" + name: str + description: Optional[str] = None + is_fixed: bool = True # LLM 生成的為 False + order: int = 0 + + +class Step0Result(BaseModel): + """Step 0: LLM 分析建議類別""" + categories: List[CategoryDefinition] + + +class DynamicStep1Result(BaseModel): + """動態版本的 Step 1 結果""" + attributes: Dict[str, List[str]] # {類別名: [屬性列表]} + + +class DynamicCausalChain(BaseModel): + """動態版本的因果鏈""" + chain: Dict[str, str] # {類別名: 選中屬性} diff --git a/backend/app/prompts/attribute_prompt.py b/backend/app/prompts/attribute_prompt.py index 6cd4e53..07f6255 100644 --- a/backend/app/prompts/attribute_prompt.py +++ b/backend/app/prompts/attribute_prompt.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from typing import List, Optional, Dict +import json DEFAULT_CATEGORIES = ["材料", "功能", "用途", "使用族群", "特性"] @@ -115,3 +116,100 @@ def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None 用戶輸入:{query}""" return prompt + + +# ===== Dynamic category system prompts ===== + +def get_step0_category_analysis_prompt(query: str, suggested_count: int = 3) -> str: + """Step 0: LLM 分析建議類別""" + return f"""/no_think +分析「{query}」,建議 {suggested_count} 個最適合的屬性類別來描述它。 + +【常見類別參考】材料、功能、用途、使用族群、特性、形狀、顏色、尺寸、品牌、價格區間 + +【重要】 +1. 選擇最能描述此物件本質的類別 +2. 類別之間應該有邏輯關係(如:材料→功能→用途) +3. 不要選擇過於抽象或重複的類別 + +只回傳 JSON: +{{ + "categories": [ + {{"name": "類別1", "description": "說明1", "order": 0}}, + {{"name": "類別2", "description": "說明2", "order": 1}} + ] +}} + +物件:{query}""" + + +def get_step1_dynamic_attributes_prompt( + query: str, + categories: List # List[CategoryDefinition] +) -> str: + """動態 Step 1 - 根據類別列表生成屬性""" + # 按 order 排序並構建描述 + sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0)) + + category_desc = "\n".join([ + f"- {cat.name if hasattr(cat, 'name') else cat['name']}: {cat.description if hasattr(cat, 'description') else cat.get('description', '相關屬性')}" + for cat in sorted_cats + ]) + + category_keys = [cat.name if hasattr(cat, 'name') else cat['name'] for cat in sorted_cats] + json_template = {cat: ["屬性1", "屬性2", "屬性3"] for cat in category_keys} + + return f"""/no_think +分析「{query}」,列出以下類別的屬性。每個類別列出 3-5 個常見屬性。 + +【類別列表】 +{category_desc} + +只回傳 JSON: +{json.dumps(json_template, ensure_ascii=False, indent=2)} + +物件:{query}""" + + +def get_step2_dynamic_causal_chain_prompt( + query: str, + categories: List, # List[CategoryDefinition] + attributes_by_category: Dict[str, List[str]], + existing_chains: List[Dict[str, str]], + chain_index: int +) -> str: + """動態 Step 2 - 生成動態類別的因果鏈""" + sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0)) + + # 構建可選屬性 + available_attrs = "\n".join([ + f"【{cat.name if hasattr(cat, 'name') else cat['name']}】{', '.join(attributes_by_category.get(cat.name if hasattr(cat, 'name') else cat['name'], []))}" + for cat in sorted_cats + ]) + + # 已生成的因果鏈 + existing_text = "" + if existing_chains: + chains_list = [ + " → ".join([chain.get(cat.name if hasattr(cat, 'name') else cat['name'], '?') for cat in sorted_cats]) + for chain in existing_chains + ] + existing_text = f"\n【已生成,請勿重複】\n" + "\n".join([f"- {c}" for c in chains_list]) + + # JSON 模板 + json_template = {cat.name if hasattr(cat, 'name') else cat['name']: f"選擇的{cat.name if hasattr(cat, 'name') else cat['name']}" for cat in sorted_cats} + + return f"""/no_think +為「{query}」生成第 {chain_index} 條因果鏈。 + +【可選屬性】 +{available_attrs} +{existing_text} + +【規則】 +1. 從每個類別選擇一個屬性 +2. 因果關係必須合理 +3. 不要重複 + +只回傳 JSON: +{json.dumps(json_template, ensure_ascii=False, indent=2)}""" diff --git a/backend/app/routers/attributes.py b/backend/app/routers/attributes.py index 2759af2..6c351f1 100644 --- a/backend/app/routers/attributes.py +++ b/backend/app/routers/attributes.py @@ -11,10 +11,18 @@ from ..models.schemas import ( Step1Result, CausalChain, AttributeNode, + CategoryMode, + CategoryDefinition, + Step0Result, + DynamicStep1Result, + DynamicCausalChain, ) from ..prompts.attribute_prompt import ( get_step1_attributes_prompt, get_step2_causal_chain_prompt, + get_step0_category_analysis_prompt, + get_step1_dynamic_attributes_prompt, + get_step2_dynamic_causal_chain_prompt, ) from ..services.llm_service import ollama_provider, extract_json_from_response @@ -22,6 +30,117 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["attributes"]) +# Fixed categories definition +FIXED_CATEGORIES = [ + CategoryDefinition(name="材料", description="物件材料", is_fixed=True, order=0), + CategoryDefinition(name="功能", description="物件功能", is_fixed=True, order=1), + CategoryDefinition(name="用途", description="使用場景", is_fixed=True, order=2), + CategoryDefinition(name="使用族群", description="目標用戶", is_fixed=True, order=3), +] + + +async def execute_step0(request: StreamAnalyzeRequest) -> Step0Result | None: + """Execute Step 0 - LLM category analysis""" + if request.category_mode == CategoryMode.FIXED_ONLY: + return None + + prompt = get_step0_category_analysis_prompt( + request.query, + request.suggested_category_count + ) + temperature = request.temperature if request.temperature is not None else 0.7 + response = await ollama_provider.generate( + prompt, model=request.model, temperature=temperature + ) + + data = extract_json_from_response(response) + step0_result = Step0Result(**data) + + # Mark as LLM generated + for cat in step0_result.categories: + cat.is_fixed = False + + return step0_result + + +def resolve_final_categories( + request: StreamAnalyzeRequest, + step0_result: Step0Result | None +) -> List[CategoryDefinition]: + """Determine final categories based on mode""" + if request.category_mode == CategoryMode.FIXED_ONLY: + return FIXED_CATEGORIES + + elif request.category_mode == CategoryMode.FIXED_PLUS_CUSTOM: + categories = FIXED_CATEGORIES.copy() + if request.custom_categories: + for i, name in enumerate(request.custom_categories): + categories.append( + CategoryDefinition( + name=name, is_fixed=False, + order=len(FIXED_CATEGORIES) + i + ) + ) + return categories + + elif request.category_mode == CategoryMode.CUSTOM_ONLY: + return step0_result.categories if step0_result else FIXED_CATEGORIES + + elif request.category_mode == CategoryMode.DYNAMIC_AUTO: + return step0_result.categories if step0_result else FIXED_CATEGORIES + + return FIXED_CATEGORIES + + +def assemble_dynamic_attribute_tree( + query: str, + chains: List[DynamicCausalChain], + categories: List[CategoryDefinition] +) -> AttributeNode: + """Assemble dynamic N-level tree from causal chains""" + sorted_cats = sorted(categories, key=lambda x: x.order) + + if not chains: + return AttributeNode(name=query, children=[]) + + def build_recursive( + level: int, + parent_path: dict, + remaining_chains: List[DynamicCausalChain] + ) -> List[AttributeNode]: + if level >= len(sorted_cats): + return [] + + current_cat = sorted_cats[level] + grouped = {} + + for chain in remaining_chains: + # Check if this chain matches the parent path + if all(chain.chain.get(k) == v for k, v in parent_path.items()): + attr_val = chain.chain.get(current_cat.name) + if attr_val: + if attr_val not in grouped: + grouped[attr_val] = [] + grouped[attr_val].append(chain) + + nodes = [] + for attr_val, child_chains in grouped.items(): + new_path = {**parent_path, current_cat.name: attr_val} + children = build_recursive(level + 1, new_path, child_chains) + + node = AttributeNode( + name=attr_val, + category=current_cat.name, + children=children if children else None + ) + nodes.append(node) + + return nodes + + root_children = build_recursive(0, {}, chains) + return AttributeNode(name=query, children=root_children) + + def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode: """將因果鏈組裝成樹狀結構""" # 以材料為第一層分組 @@ -74,14 +193,28 @@ def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeN async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]: - """生成 SSE 事件流""" + """Generate SSE events with dynamic category support""" try: temperature = request.temperature if request.temperature is not None else 0.7 - # ========== Step 1: 生成屬性列表 ========== - yield f"event: step1_start\ndata: {json.dumps({'message': '正在分析屬性列表...'}, ensure_ascii=False)}\n\n" + # ========== Step 0: Category Analysis (if needed) ========== + step0_result = None + if request.category_mode != CategoryMode.FIXED_ONLY: + yield f"event: step0_start\ndata: {json.dumps({'message': '分析類別...'}, ensure_ascii=False)}\n\n" - step1_prompt = get_step1_attributes_prompt(request.query) + step0_result = await execute_step0(request) + + if step0_result: + yield f"event: step0_complete\ndata: {json.dumps({'result': step0_result.model_dump()}, ensure_ascii=False)}\n\n" + + # ========== Resolve Final Categories ========== + final_categories = resolve_final_categories(request, step0_result) + yield f"event: categories_resolved\ndata: {json.dumps({'categories': [c.model_dump() for c in final_categories]}, ensure_ascii=False)}\n\n" + + # ========== Step 1: Generate Attributes (Dynamic) ========== + yield f"event: step1_start\ndata: {json.dumps({'message': '生成屬性...'}, ensure_ascii=False)}\n\n" + + step1_prompt = get_step1_dynamic_attributes_prompt(request.query, final_categories) logger.info(f"Step 1 prompt: {step1_prompt[:200]}") step1_response = await ollama_provider.generate( @@ -90,29 +223,27 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s logger.info(f"Step 1 response: {step1_response[:500]}") step1_data = extract_json_from_response(step1_response) - step1_result = Step1Result(**step1_data) + step1_result = DynamicStep1Result(attributes=step1_data) yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n" - # ========== Step 2: 逐條生成因果鏈 ========== - causal_chains: List[CausalChain] = [] + # ========== Step 2: Generate Causal Chains (Dynamic) ========== + causal_chains: List[DynamicCausalChain] = [] for i in range(request.chain_count): chain_index = i + 1 yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n" - step2_prompt = get_step2_causal_chain_prompt( + step2_prompt = get_step2_dynamic_causal_chain_prompt( query=request.query, - materials=step1_result.materials, - functions=step1_result.functions, - usages=step1_result.usages, - users=step1_result.users, - existing_chains=[c.model_dump() for c in causal_chains], + categories=final_categories, + attributes_by_category=step1_result.attributes, + existing_chains=[c.chain for c in causal_chains], chain_index=chain_index, ) - # 逐漸提高 temperature 增加多樣性 + # Gradually increase temperature for diversity chain_temperature = min(temperature + 0.05 * i, 1.0) max_retries = 2 @@ -125,7 +256,7 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s logger.info(f"Chain {chain_index} response: {step2_response[:300]}") chain_data = extract_json_from_response(step2_response) - chain = CausalChain(**chain_data) + chain = DynamicCausalChain(chain=chain_data) break except Exception as e: logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}") @@ -136,13 +267,15 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s causal_chains.append(chain) yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n" else: - yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'第 {chain_index} 條因果鏈生成失敗'}, ensure_ascii=False)}\n\n" + yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'生成失敗'}, ensure_ascii=False)}\n\n" - # ========== 組裝最終結構 ========== - final_tree = assemble_attribute_tree(request.query, causal_chains) + # ========== Assemble Final Tree (Dynamic) ========== + final_tree = assemble_dynamic_attribute_tree(request.query, causal_chains, final_categories) final_result = { "query": request.query, + "step0_result": step0_result.model_dump() if step0_result else None, + "categories_used": [c.model_dump() for c in final_categories], "step1_result": step1_result.model_dump(), "causal_chains": [c.model_dump() for c in causal_chains], "attributes": final_tree.model_dump(), diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index b096717..5e1010a 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -5,6 +5,7 @@ import { InputPanel } from './components/InputPanel'; import { MindmapPanel } from './components/MindmapPanel'; import { useAttribute } from './hooks/useAttribute'; import type { MindmapD3Ref } from './components/MindmapD3'; +import type { CategoryMode } from './types'; const { Header, Sider, Content } = Layout; const { Title } = Typography; @@ -27,9 +28,12 @@ function App() { query: string, model?: string, temperature?: number, - chainCount?: number + chainCount?: number, + categoryMode?: CategoryMode, + customCategories?: string[], + suggestedCategoryCount?: number ) => { - await analyze(query, model, temperature, chainCount); + await analyze(query, model, temperature, chainCount, categoryMode, customCategories, suggestedCategoryCount); }; const handleExpandAll = useCallback(() => { diff --git a/frontend/src/components/CategorySelector.tsx b/frontend/src/components/CategorySelector.tsx new file mode 100644 index 0000000..421ffbd --- /dev/null +++ b/frontend/src/components/CategorySelector.tsx @@ -0,0 +1,157 @@ +import { useState } from 'react'; +import { Radio, Space, Input, Button, Tag, Tooltip, Slider, Typography } from 'antd'; +import { InfoCircleOutlined, PlusOutlined } from '@ant-design/icons'; +import type { CategoryDefinition, CategoryMode, Step0Result } from '../types'; + +const { Text } = Typography; + +interface CategorySelectorProps { + mode: CategoryMode; + onModeChange: (mode: CategoryMode) => void; + customCategories: string[]; + onCustomCategoriesChange: (cats: string[]) => void; + suggestedCount: number; + onSuggestedCountChange: (count: number) => void; + step0Result?: Step0Result; + onStep0Edit?: (cats: CategoryDefinition[]) => void; + disabled?: boolean; +} + +export function CategorySelector({ + mode, + onModeChange, + customCategories, + onCustomCategoriesChange, + suggestedCount, + onSuggestedCountChange, + step0Result, + onStep0Edit, + disabled +}: CategorySelectorProps) { + const [inputValue, setInputValue] = useState(''); + + return ( + + onModeChange(e.target.value as CategoryMode)} + disabled={disabled} + > + + + Fixed (材料、功能、用途、使用族群) + + + Fixed + Custom + + + Custom Only (LLM suggests) + + + Dynamic (LLM suggests, editable) + + + + + {/* 動態模式:類別數量調整 */} + {(mode === 'custom_only' || mode === 'dynamic_auto') && ( +
+ Suggested Category Count: {suggestedCount} + +
+ )} + + {/* 固定+自定義模式 */} + {mode === 'fixed_plus_custom' && ( +
+ Add custom categories: + + setInputValue(e.target.value)} + onPressEnter={() => { + if (inputValue.trim()) { + onCustomCategoriesChange([...customCategories, inputValue.trim()]); + setInputValue(''); + } + }} + disabled={disabled} + /> + + + + {customCategories.length > 0 && ( +
+ {customCategories.map((cat, i) => ( + { + onCustomCategoriesChange(customCategories.filter((_, idx) => idx !== i)); + }} + > + {cat} + + ))} +
+ )} +
+ )} + + {/* Step 0 結果顯示 */} + {step0Result && (mode === 'custom_only' || mode === 'dynamic_auto') && ( +
+ LLM Suggested: +
+ {step0Result.categories.map((cat, i) => ( + { + if (onStep0Edit) { + onStep0Edit(step0Result.categories.filter((_, idx) => idx !== i)); + } + } : undefined} + > + {cat.name} + {cat.description && ( + + + + )} + + ))} +
+ + {mode === 'dynamic_auto' && ( + + You can remove tags or proceed + + )} +
+ )} +
+ ); +} diff --git a/frontend/src/components/InputPanel.tsx b/frontend/src/components/InputPanel.tsx index 92ceb2f..a75b817 100644 --- a/frontend/src/components/InputPanel.tsx +++ b/frontend/src/components/InputPanel.tsx @@ -22,8 +22,9 @@ import { LoadingOutlined, CheckCircleOutlined, } from '@ant-design/icons'; -import type { HistoryItem, AttributeNode, StreamProgress } from '../types'; +import type { HistoryItem, AttributeNode, StreamProgress, CategoryMode, DynamicCausalChain, CausalChain } from '../types'; import { getModels } from '../services/api'; +import { CategorySelector } from './CategorySelector'; const { TextArea } = Input; const { Text } = Typography; @@ -38,7 +39,15 @@ interface InputPanelProps { progress: StreamProgress; history: HistoryItem[]; currentResult: AttributeNode | null; - onAnalyze: (query: string, model?: string, temperature?: number, chainCount?: number) => Promise; + onAnalyze: ( + query: string, + model?: string, + temperature?: number, + chainCount?: number, + categoryMode?: CategoryMode, + customCategories?: string[], + suggestedCategoryCount?: number + ) => Promise; onLoadHistory: (item: HistoryItem) => void; onExpandAll?: () => void; onCollapseAll?: () => void; @@ -64,6 +73,10 @@ export function InputPanel({ const [loadingModels, setLoadingModels] = useState(false); const [temperature, setTemperature] = useState(0.7); const [chainCount, setChainCount] = useState(5); + // Category settings + const [categoryMode, setCategoryMode] = useState('dynamic_auto' as CategoryMode); + const [customCategories, setCustomCategories] = useState([]); + const [suggestedCategoryCount, setSuggestedCategoryCount] = useState(3); useEffect(() => { async function fetchModels() { @@ -92,7 +105,15 @@ export function InputPanel({ } try { - await onAnalyze(query.trim(), selectedModel, temperature, chainCount); + await onAnalyze( + query.trim(), + selectedModel, + temperature, + chainCount, + categoryMode, + customCategories.length > 0 ? customCategories : undefined, + suggestedCategoryCount + ); setQuery(''); } catch { message.error('Analysis failed'); @@ -191,14 +212,27 @@ export function InputPanel({ img.src = 'data:image/svg+xml;base64,' + btoa(unescape(encodeURIComponent(svgData))); }; + // Helper to format chain display (supports both fixed and dynamic chains) + const formatChain = (chain: CausalChain | DynamicCausalChain): string => { + if ('chain' in chain) { + // Dynamic chain + return Object.values(chain.chain).join(' → '); + } else { + // Fixed chain + return `${chain.material} → ${chain.function} → ${chain.usage} → ${chain.user}`; + } + }; + const renderProgressIndicator = () => { if (progress.step === 'idle' || progress.step === 'done') return null; - const percent = progress.step === 'step1' - ? 10 - : progress.step === 'chains' - ? 10 + (progress.currentChainIndex / progress.totalChains) * 90 - : 100; + const percent = progress.step === 'step0' + ? 5 + : progress.step === 'step1' + ? 10 + : progress.step === 'chains' + ? 10 + (progress.currentChainIndex / progress.totalChains) * 90 + : 100; return (
@@ -213,6 +247,20 @@ export function InputPanel({ + {/* Show categories used */} + {progress.categoriesUsed && progress.categoriesUsed.length > 0 && ( +
+ Categories: +
+ {progress.categoriesUsed.map((cat, i) => ( + + {cat.name} + + ))} +
+
+ )} + {progress.completedChains.length > 0 && (
Completed chains: @@ -220,7 +268,7 @@ export function InputPanel({ {progress.completedChains.map((chain, i) => (
- {chain.material} → {chain.function} → {chain.usage} → {chain.user} + {formatChain(chain)}
))}
@@ -232,6 +280,22 @@ export function InputPanel({ }; const collapseItems = [ + { + key: 'categories', + label: 'Category Settings', + children: ( + + ), + }, { key: 'llm', label: 'LLM Parameters', diff --git a/frontend/src/components/MindmapD3.tsx b/frontend/src/components/MindmapD3.tsx index 06838f4..b374b8e 100644 --- a/frontend/src/components/MindmapD3.tsx +++ b/frontend/src/components/MindmapD3.tsx @@ -117,8 +117,19 @@ export const MindmapD3 = forwardRef( d._children = undefined; }); - // Category labels for header - const categoryLabels = ['', '材料', '功能', '用途', '使用族群']; + // Dynamically extract category labels from the tree based on depth + // Each depth level corresponds to a category + const categoryByDepth: Record = {}; + root.descendants().forEach((d: TreeNode) => { + if (d.depth > 0 && d.data.category && !categoryByDepth[d.depth]) { + categoryByDepth[d.depth] = d.data.category; + } + }); + const maxDepthWithCategory = Math.max(...Object.keys(categoryByDepth).map(Number), 0); + const categoryLabels = ['']; + for (let i = 1; i <= maxDepthWithCategory; i++) { + categoryLabels.push(categoryByDepth[i] || ''); + } const headerHeight = 40; function update(source: TreeNode) { @@ -143,12 +154,27 @@ export const MindmapD3 = forwardRef( // Draw category headers with background g.selectAll('.category-header-group').remove(); const maxDepth = Math.max(...descendants.map(d => d.depth)); - const categoryColors: Record = { - '材料': isDark ? '#854eca' : '#722ed1', - '功能': isDark ? '#13a8a8' : '#13c2c2', - '用途': isDark ? '#d87a16' : '#fa8c16', - '使用族群': isDark ? '#49aa19' : '#52c41a', - }; + + // Dynamic color palette for categories + const colorPalette = [ + { dark: '#854eca', light: '#722ed1' }, // purple + { dark: '#13a8a8', light: '#13c2c2' }, // cyan + { dark: '#d87a16', light: '#fa8c16' }, // orange + { dark: '#49aa19', light: '#52c41a' }, // green + { dark: '#1677ff', light: '#1890ff' }, // blue + { dark: '#eb2f96', light: '#f759ab' }, // magenta + { dark: '#faad14', light: '#ffc53d' }, // gold + { dark: '#a0d911', light: '#bae637' }, // lime + ]; + + // Generate colors dynamically based on category position + const categoryColors: Record = {}; + categoryLabels.forEach((label, index) => { + if (label && index > 0) { + const colorIndex = (index - 1) % colorPalette.length; + categoryColors[label] = isDark ? colorPalette[colorIndex].dark : colorPalette[colorIndex].light; + } + }); for (let depth = 1; depth <= Math.min(maxDepth, categoryLabels.length - 1); depth++) { const label = categoryLabels[depth]; diff --git a/frontend/src/hooks/useAttribute.ts b/frontend/src/hooks/useAttribute.ts index 2987852..4a36178 100644 --- a/frontend/src/hooks/useAttribute.ts +++ b/frontend/src/hooks/useAttribute.ts @@ -3,9 +3,9 @@ import type { AttributeNode, HistoryItem, StreamProgress, - StreamAnalyzeResponse, - CausalChain + StreamAnalyzeResponse } from '../types'; +import { CategoryMode } from '../types'; import { analyzeAttributesStream } from '../services/api'; export function useAttribute() { @@ -24,7 +24,10 @@ export function useAttribute() { query: string, model?: string, temperature?: number, - chainCount: number = 5 + chainCount: number = 5, + categoryMode: CategoryMode = CategoryMode.DYNAMIC_AUTO, + customCategories?: string[], + suggestedCategoryCount: number = 3 ) => { // 重置狀態 setProgress({ @@ -39,8 +42,40 @@ export function useAttribute() { try { await analyzeAttributesStream( - { query, chain_count: chainCount, model, temperature }, { + query, + chain_count: chainCount, + model, + temperature, + category_mode: categoryMode, + custom_categories: customCategories, + suggested_category_count: suggestedCategoryCount + }, + { + onStep0Start: () => { + setProgress(prev => ({ + ...prev, + step: 'step0', + message: '正在分析類別...', + })); + }, + + onStep0Complete: (result) => { + setProgress(prev => ({ + ...prev, + step0Result: result, + message: '類別分析完成', + })); + }, + + onCategoriesResolved: (categories) => { + setProgress(prev => ({ + ...prev, + categoriesUsed: categories, + message: `使用 ${categories.length} 個類別`, + })); + }, + onStep1Start: () => { setProgress(prev => ({ ...prev, @@ -148,7 +183,7 @@ export function useAttribute() { }); }, []); - const isLoading = progress.step === 'step1' || progress.step === 'chains'; + const isLoading = progress.step === 'step0' || progress.step === 'step1' || progress.step === 'chains'; return { loading: isLoading, diff --git a/frontend/src/services/api.ts b/frontend/src/services/api.ts index e4cc845..634cf72 100644 --- a/frontend/src/services/api.ts +++ b/frontend/src/services/api.ts @@ -3,17 +3,24 @@ import type { StreamAnalyzeRequest, StreamAnalyzeResponse, Step1Result, - CausalChain + CausalChain, + Step0Result, + CategoryDefinition, + DynamicStep1Result, + DynamicCausalChain } from '../types'; // 自動使用當前瀏覽器的 hostname,支援遠端存取 const API_BASE_URL = `http://${window.location.hostname}:8000/api`; export interface SSECallbacks { + onStep0Start?: () => void; + onStep0Complete?: (result: Step0Result) => void; + onCategoriesResolved?: (categories: CategoryDefinition[]) => void; onStep1Start?: () => void; - onStep1Complete?: (result: Step1Result) => void; + onStep1Complete?: (result: Step1Result | DynamicStep1Result) => void; onChainStart?: (index: number, total: number) => void; - onChainComplete?: (index: number, chain: CausalChain) => void; + onChainComplete?: (index: number, chain: CausalChain | DynamicCausalChain) => void; onChainError?: (index: number, error: string) => void; onDone?: (response: StreamAnalyzeResponse) => void; onError?: (error: string) => void; @@ -65,6 +72,15 @@ export async function analyzeAttributesStream( const eventData = JSON.parse(dataMatch[1]); switch (eventType) { + case 'step0_start': + callbacks.onStep0Start?.(); + break; + case 'step0_complete': + callbacks.onStep0Complete?.(eventData.result); + break; + case 'categories_resolved': + callbacks.onCategoriesResolved?.(eventData.categories); + break; case 'step1_start': callbacks.onStep1Start?.(); break; diff --git a/frontend/src/styles/mindmap.css b/frontend/src/styles/mindmap.css index 7053c6a..da36ab1 100644 --- a/frontend/src/styles/mindmap.css +++ b/frontend/src/styles/mindmap.css @@ -126,6 +126,38 @@ fill: #fff; } +.mindmap-light .node-rect.depth-5 { + fill: #1890ff; + stroke: #096dd9; +} +.mindmap-light .node-text.depth-5 { + fill: #fff; +} + +.mindmap-light .node-rect.depth-6 { + fill: #f759ab; + stroke: #eb2f96; +} +.mindmap-light .node-text.depth-6 { + fill: #fff; +} + +.mindmap-light .node-rect.depth-7 { + fill: #ffc53d; + stroke: #faad14; +} +.mindmap-light .node-text.depth-7 { + fill: #fff; +} + +.mindmap-light .node-rect.depth-8 { + fill: #bae637; + stroke: #a0d911; +} +.mindmap-light .node-text.depth-8 { + fill: #fff; +} + .mindmap-light .link { stroke: #bfbfbf; } @@ -221,6 +253,38 @@ fill: #fff; } +.mindmap-dark .node-rect.depth-5 { + fill: #1677ff; + stroke: #4096ff; +} +.mindmap-dark .node-text.depth-5 { + fill: #fff; +} + +.mindmap-dark .node-rect.depth-6 { + fill: #eb2f96; + stroke: #f759ab; +} +.mindmap-dark .node-text.depth-6 { + fill: #fff; +} + +.mindmap-dark .node-rect.depth-7 { + fill: #faad14; + stroke: #ffc53d; +} +.mindmap-dark .node-text.depth-7 { + fill: #fff; +} + +.mindmap-dark .node-rect.depth-8 { + fill: #a0d911; + stroke: #bae637; +} +.mindmap-dark .node-text.depth-8 { + fill: #fff; +} + .mindmap-dark .link { stroke: #434343; } diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index aec7f5b..3013754 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -52,26 +52,64 @@ export interface CausalChain { user: string; } +// ===== Dynamic category system types ===== + +export interface CategoryDefinition { + name: string; + description?: string; + is_fixed: boolean; + order: number; +} + +export interface Step0Result { + categories: CategoryDefinition[]; +} + +export interface DynamicStep1Result { + attributes: Record; +} + +export interface DynamicCausalChain { + chain: Record; +} + +export const CategoryMode = { + FIXED_ONLY: 'fixed_only', + FIXED_PLUS_CUSTOM: 'fixed_plus_custom', + CUSTOM_ONLY: 'custom_only', + DYNAMIC_AUTO: 'dynamic_auto', +} as const; + +export type CategoryMode = typeof CategoryMode[keyof typeof CategoryMode]; + export interface StreamAnalyzeRequest { query: string; model?: string; temperature?: number; chain_count: number; + // Dynamic category support + category_mode?: CategoryMode; + custom_categories?: string[]; + suggested_category_count?: number; } export interface StreamProgress { - step: 'idle' | 'step1' | 'chains' | 'done' | 'error'; - step1Result?: Step1Result; + step: 'idle' | 'step0' | 'step1' | 'chains' | 'done' | 'error'; + step0Result?: Step0Result; + categoriesUsed?: CategoryDefinition[]; + step1Result?: Step1Result | DynamicStep1Result; currentChainIndex: number; totalChains: number; - completedChains: CausalChain[]; + completedChains: (CausalChain | DynamicCausalChain)[]; message: string; error?: string; } export interface StreamAnalyzeResponse { query: string; - step1_result: Step1Result; - causal_chains: CausalChain[]; + step0_result?: Step0Result; + categories_used: CategoryDefinition[]; + step1_result: Step1Result | DynamicStep1Result; + causal_chains: (CausalChain | DynamicCausalChain)[]; attributes: AttributeNode; }