feat: Add dynamic category system for attribute analysis

Backend:
- Add CategoryMode enum with 4 modes (fixed_only, fixed_plus_custom, custom_only, dynamic_auto)
- Add Step 0 for LLM category analysis before attribute generation
- Implement dynamic prompts for Step 1/2 that work with N categories
- Add execute_step0(), resolve_final_categories(), assemble_dynamic_attribute_tree()
- Update SSE events to include step0_start, step0_complete, categories_resolved

Frontend:
- Add CategorySelector component with mode selection, custom category input, and category count slider
- Update types with CategoryDefinition, Step0Result, DynamicStep1Result, DynamicCausalChain
- Update api.ts with new SSE event handlers
- Update useAttribute hook with category parameters
- Integrate CategorySelector into InputPanel
- Fix mindmap to dynamically extract and display N categories (was hardcoded to 4)
- Add CSS styles for depth 5-8 to support more category levels

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-12-02 23:04:35 +08:00
parent eb6c0c51fa
commit 91f7f41bc1
11 changed files with 727 additions and 53 deletions

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List, Dict
from enum import Enum
class AttributeNode(BaseModel): class AttributeNode(BaseModel):
@@ -46,12 +47,17 @@ class CausalChain(BaseModel):
class StreamAnalyzeRequest(BaseModel): class StreamAnalyzeRequest(BaseModel):
"""多步驟分析請求""" """多步驟分析請求(更新為支持動態類別)"""
query: str query: str
model: Optional[str] = None model: Optional[str] = None
temperature: Optional[float] = 0.7 temperature: Optional[float] = 0.7
chain_count: int = 5 # 用戶可設定要生成多少條因果鏈 chain_count: int = 5 # 用戶可設定要生成多少條因果鏈
# 新增:動態類別支持
category_mode: Optional[str] = "dynamic_auto" # CategoryMode enum 值
custom_categories: Optional[List[str]] = None
suggested_category_count: int = 3 # 建議 LLM 生成的類別數量
class StreamAnalyzeResponse(BaseModel): class StreamAnalyzeResponse(BaseModel):
"""最終完整結果""" """最終完整結果"""
@@ -59,3 +65,36 @@ class StreamAnalyzeResponse(BaseModel):
step1_result: Step1Result step1_result: Step1Result
causal_chains: List[CausalChain] causal_chains: List[CausalChain]
attributes: AttributeNode attributes: AttributeNode
# ===== Dynamic category system schemas =====
class CategoryMode(str, Enum):
"""類別模式"""
FIXED_ONLY = "fixed_only"
FIXED_PLUS_CUSTOM = "fixed_plus_custom"
CUSTOM_ONLY = "custom_only"
DYNAMIC_AUTO = "dynamic_auto"
class CategoryDefinition(BaseModel):
"""類別定義"""
name: str
description: Optional[str] = None
is_fixed: bool = True # LLM 生成的為 False
order: int = 0
class Step0Result(BaseModel):
"""Step 0: LLM 分析建議類別"""
categories: List[CategoryDefinition]
class DynamicStep1Result(BaseModel):
"""動態版本的 Step 1 結果"""
attributes: Dict[str, List[str]] # {類別名: [屬性列表]}
class DynamicCausalChain(BaseModel):
"""動態版本的因果鏈"""
chain: Dict[str, str] # {類別名: 選中屬性}

View File

@@ -1,4 +1,5 @@
from typing import List, Optional from typing import List, Optional, Dict
import json
DEFAULT_CATEGORIES = ["材料", "功能", "用途", "使用族群", "特性"] DEFAULT_CATEGORIES = ["材料", "功能", "用途", "使用族群", "特性"]
@@ -115,3 +116,100 @@ def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None
用戶輸入:{query}""" 用戶輸入:{query}"""
return prompt return prompt
# ===== Dynamic category system prompts =====
def get_step0_category_analysis_prompt(query: str, suggested_count: int = 3) -> str:
"""Step 0: LLM 分析建議類別"""
return f"""/no_think
分析「{query}」,建議 {suggested_count} 個最適合的屬性類別來描述它。
【常見類別參考】材料、功能、用途、使用族群、特性、形狀、顏色、尺寸、品牌、價格區間
【重要】
1. 選擇最能描述此物件本質的類別
2. 類別之間應該有邏輯關係(如:材料→功能→用途)
3. 不要選擇過於抽象或重複的類別
只回傳 JSON
{{
"categories": [
{{"name": "類別1", "description": "說明1", "order": 0}},
{{"name": "類別2", "description": "說明2", "order": 1}}
]
}}
物件:{query}"""
def get_step1_dynamic_attributes_prompt(
query: str,
categories: List # List[CategoryDefinition]
) -> str:
"""動態 Step 1 - 根據類別列表生成屬性"""
# 按 order 排序並構建描述
sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0))
category_desc = "\n".join([
f"- {cat.name if hasattr(cat, 'name') else cat['name']}: {cat.description if hasattr(cat, 'description') else cat.get('description', '相關屬性')}"
for cat in sorted_cats
])
category_keys = [cat.name if hasattr(cat, 'name') else cat['name'] for cat in sorted_cats]
json_template = {cat: ["屬性1", "屬性2", "屬性3"] for cat in category_keys}
return f"""/no_think
分析「{query}」,列出以下類別的屬性。每個類別列出 3-5 個常見屬性。
【類別列表】
{category_desc}
只回傳 JSON
{json.dumps(json_template, ensure_ascii=False, indent=2)}
物件:{query}"""
def get_step2_dynamic_causal_chain_prompt(
query: str,
categories: List, # List[CategoryDefinition]
attributes_by_category: Dict[str, List[str]],
existing_chains: List[Dict[str, str]],
chain_index: int
) -> str:
"""動態 Step 2 - 生成動態類別的因果鏈"""
sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0))
# 構建可選屬性
available_attrs = "\n".join([
f"{cat.name if hasattr(cat, 'name') else cat['name']}{', '.join(attributes_by_category.get(cat.name if hasattr(cat, 'name') else cat['name'], []))}"
for cat in sorted_cats
])
# 已生成的因果鏈
existing_text = ""
if existing_chains:
chains_list = [
"".join([chain.get(cat.name if hasattr(cat, 'name') else cat['name'], '?') for cat in sorted_cats])
for chain in existing_chains
]
existing_text = f"\n【已生成,請勿重複】\n" + "\n".join([f"- {c}" for c in chains_list])
# JSON 模板
json_template = {cat.name if hasattr(cat, 'name') else cat['name']: f"選擇的{cat.name if hasattr(cat, 'name') else cat['name']}" for cat in sorted_cats}
return f"""/no_think
為「{query}」生成第 {chain_index} 條因果鏈。
【可選屬性】
{available_attrs}
{existing_text}
【規則】
1. 從每個類別選擇一個屬性
2. 因果關係必須合理
3. 不要重複
只回傳 JSON
{json.dumps(json_template, ensure_ascii=False, indent=2)}"""

View File

@@ -11,10 +11,18 @@ from ..models.schemas import (
Step1Result, Step1Result,
CausalChain, CausalChain,
AttributeNode, AttributeNode,
CategoryMode,
CategoryDefinition,
Step0Result,
DynamicStep1Result,
DynamicCausalChain,
) )
from ..prompts.attribute_prompt import ( from ..prompts.attribute_prompt import (
get_step1_attributes_prompt, get_step1_attributes_prompt,
get_step2_causal_chain_prompt, get_step2_causal_chain_prompt,
get_step0_category_analysis_prompt,
get_step1_dynamic_attributes_prompt,
get_step2_dynamic_causal_chain_prompt,
) )
from ..services.llm_service import ollama_provider, extract_json_from_response from ..services.llm_service import ollama_provider, extract_json_from_response
@@ -22,6 +30,117 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["attributes"]) router = APIRouter(prefix="/api", tags=["attributes"])
# Fixed categories definition
FIXED_CATEGORIES = [
CategoryDefinition(name="材料", description="物件材料", is_fixed=True, order=0),
CategoryDefinition(name="功能", description="物件功能", is_fixed=True, order=1),
CategoryDefinition(name="用途", description="使用場景", is_fixed=True, order=2),
CategoryDefinition(name="使用族群", description="目標用戶", is_fixed=True, order=3),
]
async def execute_step0(request: StreamAnalyzeRequest) -> Step0Result | None:
"""Execute Step 0 - LLM category analysis"""
if request.category_mode == CategoryMode.FIXED_ONLY:
return None
prompt = get_step0_category_analysis_prompt(
request.query,
request.suggested_category_count
)
temperature = request.temperature if request.temperature is not None else 0.7
response = await ollama_provider.generate(
prompt, model=request.model, temperature=temperature
)
data = extract_json_from_response(response)
step0_result = Step0Result(**data)
# Mark as LLM generated
for cat in step0_result.categories:
cat.is_fixed = False
return step0_result
def resolve_final_categories(
request: StreamAnalyzeRequest,
step0_result: Step0Result | None
) -> List[CategoryDefinition]:
"""Determine final categories based on mode"""
if request.category_mode == CategoryMode.FIXED_ONLY:
return FIXED_CATEGORIES
elif request.category_mode == CategoryMode.FIXED_PLUS_CUSTOM:
categories = FIXED_CATEGORIES.copy()
if request.custom_categories:
for i, name in enumerate(request.custom_categories):
categories.append(
CategoryDefinition(
name=name, is_fixed=False,
order=len(FIXED_CATEGORIES) + i
)
)
return categories
elif request.category_mode == CategoryMode.CUSTOM_ONLY:
return step0_result.categories if step0_result else FIXED_CATEGORIES
elif request.category_mode == CategoryMode.DYNAMIC_AUTO:
return step0_result.categories if step0_result else FIXED_CATEGORIES
return FIXED_CATEGORIES
def assemble_dynamic_attribute_tree(
query: str,
chains: List[DynamicCausalChain],
categories: List[CategoryDefinition]
) -> AttributeNode:
"""Assemble dynamic N-level tree from causal chains"""
sorted_cats = sorted(categories, key=lambda x: x.order)
if not chains:
return AttributeNode(name=query, children=[])
def build_recursive(
level: int,
parent_path: dict,
remaining_chains: List[DynamicCausalChain]
) -> List[AttributeNode]:
if level >= len(sorted_cats):
return []
current_cat = sorted_cats[level]
grouped = {}
for chain in remaining_chains:
# Check if this chain matches the parent path
if all(chain.chain.get(k) == v for k, v in parent_path.items()):
attr_val = chain.chain.get(current_cat.name)
if attr_val:
if attr_val not in grouped:
grouped[attr_val] = []
grouped[attr_val].append(chain)
nodes = []
for attr_val, child_chains in grouped.items():
new_path = {**parent_path, current_cat.name: attr_val}
children = build_recursive(level + 1, new_path, child_chains)
node = AttributeNode(
name=attr_val,
category=current_cat.name,
children=children if children else None
)
nodes.append(node)
return nodes
root_children = build_recursive(0, {}, chains)
return AttributeNode(name=query, children=root_children)
def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode: def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
"""將因果鏈組裝成樹狀結構""" """將因果鏈組裝成樹狀結構"""
# 以材料為第一層分組 # 以材料為第一層分組
@@ -74,14 +193,28 @@ def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeN
async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]: async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
"""生成 SSE 事件流""" """Generate SSE events with dynamic category support"""
try: try:
temperature = request.temperature if request.temperature is not None else 0.7 temperature = request.temperature if request.temperature is not None else 0.7
# ========== Step 1: 生成屬性列表 ========== # ========== Step 0: Category Analysis (if needed) ==========
yield f"event: step1_start\ndata: {json.dumps({'message': '正在分析屬性列表...'}, ensure_ascii=False)}\n\n" step0_result = None
if request.category_mode != CategoryMode.FIXED_ONLY:
yield f"event: step0_start\ndata: {json.dumps({'message': '分析類別...'}, ensure_ascii=False)}\n\n"
step1_prompt = get_step1_attributes_prompt(request.query) step0_result = await execute_step0(request)
if step0_result:
yield f"event: step0_complete\ndata: {json.dumps({'result': step0_result.model_dump()}, ensure_ascii=False)}\n\n"
# ========== Resolve Final Categories ==========
final_categories = resolve_final_categories(request, step0_result)
yield f"event: categories_resolved\ndata: {json.dumps({'categories': [c.model_dump() for c in final_categories]}, ensure_ascii=False)}\n\n"
# ========== Step 1: Generate Attributes (Dynamic) ==========
yield f"event: step1_start\ndata: {json.dumps({'message': '生成屬性...'}, ensure_ascii=False)}\n\n"
step1_prompt = get_step1_dynamic_attributes_prompt(request.query, final_categories)
logger.info(f"Step 1 prompt: {step1_prompt[:200]}") logger.info(f"Step 1 prompt: {step1_prompt[:200]}")
step1_response = await ollama_provider.generate( step1_response = await ollama_provider.generate(
@@ -90,29 +223,27 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
logger.info(f"Step 1 response: {step1_response[:500]}") logger.info(f"Step 1 response: {step1_response[:500]}")
step1_data = extract_json_from_response(step1_response) step1_data = extract_json_from_response(step1_response)
step1_result = Step1Result(**step1_data) step1_result = DynamicStep1Result(attributes=step1_data)
yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n" yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n"
# ========== Step 2: 逐條生成因果鏈 ========== # ========== Step 2: Generate Causal Chains (Dynamic) ==========
causal_chains: List[CausalChain] = [] causal_chains: List[DynamicCausalChain] = []
for i in range(request.chain_count): for i in range(request.chain_count):
chain_index = i + 1 chain_index = i + 1
yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n" yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n"
step2_prompt = get_step2_causal_chain_prompt( step2_prompt = get_step2_dynamic_causal_chain_prompt(
query=request.query, query=request.query,
materials=step1_result.materials, categories=final_categories,
functions=step1_result.functions, attributes_by_category=step1_result.attributes,
usages=step1_result.usages, existing_chains=[c.chain for c in causal_chains],
users=step1_result.users,
existing_chains=[c.model_dump() for c in causal_chains],
chain_index=chain_index, chain_index=chain_index,
) )
# 逐漸提高 temperature 增加多樣性 # Gradually increase temperature for diversity
chain_temperature = min(temperature + 0.05 * i, 1.0) chain_temperature = min(temperature + 0.05 * i, 1.0)
max_retries = 2 max_retries = 2
@@ -125,7 +256,7 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
logger.info(f"Chain {chain_index} response: {step2_response[:300]}") logger.info(f"Chain {chain_index} response: {step2_response[:300]}")
chain_data = extract_json_from_response(step2_response) chain_data = extract_json_from_response(step2_response)
chain = CausalChain(**chain_data) chain = DynamicCausalChain(chain=chain_data)
break break
except Exception as e: except Exception as e:
logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}") logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}")
@@ -136,13 +267,15 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
causal_chains.append(chain) causal_chains.append(chain)
yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n" yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n"
else: else:
yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'{chain_index} 條因果鏈生成失敗'}, ensure_ascii=False)}\n\n" yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'生成失敗'}, ensure_ascii=False)}\n\n"
# ========== 組裝最終結構 ========== # ========== Assemble Final Tree (Dynamic) ==========
final_tree = assemble_attribute_tree(request.query, causal_chains) final_tree = assemble_dynamic_attribute_tree(request.query, causal_chains, final_categories)
final_result = { final_result = {
"query": request.query, "query": request.query,
"step0_result": step0_result.model_dump() if step0_result else None,
"categories_used": [c.model_dump() for c in final_categories],
"step1_result": step1_result.model_dump(), "step1_result": step1_result.model_dump(),
"causal_chains": [c.model_dump() for c in causal_chains], "causal_chains": [c.model_dump() for c in causal_chains],
"attributes": final_tree.model_dump(), "attributes": final_tree.model_dump(),

View File

@@ -5,6 +5,7 @@ import { InputPanel } from './components/InputPanel';
import { MindmapPanel } from './components/MindmapPanel'; import { MindmapPanel } from './components/MindmapPanel';
import { useAttribute } from './hooks/useAttribute'; import { useAttribute } from './hooks/useAttribute';
import type { MindmapD3Ref } from './components/MindmapD3'; import type { MindmapD3Ref } from './components/MindmapD3';
import type { CategoryMode } from './types';
const { Header, Sider, Content } = Layout; const { Header, Sider, Content } = Layout;
const { Title } = Typography; const { Title } = Typography;
@@ -27,9 +28,12 @@ function App() {
query: string, query: string,
model?: string, model?: string,
temperature?: number, temperature?: number,
chainCount?: number chainCount?: number,
categoryMode?: CategoryMode,
customCategories?: string[],
suggestedCategoryCount?: number
) => { ) => {
await analyze(query, model, temperature, chainCount); await analyze(query, model, temperature, chainCount, categoryMode, customCategories, suggestedCategoryCount);
}; };
const handleExpandAll = useCallback(() => { const handleExpandAll = useCallback(() => {

View File

@@ -0,0 +1,157 @@
import { useState } from 'react';
import { Radio, Space, Input, Button, Tag, Tooltip, Slider, Typography } from 'antd';
import { InfoCircleOutlined, PlusOutlined } from '@ant-design/icons';
import type { CategoryDefinition, CategoryMode, Step0Result } from '../types';
const { Text } = Typography;
interface CategorySelectorProps {
mode: CategoryMode;
onModeChange: (mode: CategoryMode) => void;
customCategories: string[];
onCustomCategoriesChange: (cats: string[]) => void;
suggestedCount: number;
onSuggestedCountChange: (count: number) => void;
step0Result?: Step0Result;
onStep0Edit?: (cats: CategoryDefinition[]) => void;
disabled?: boolean;
}
export function CategorySelector({
mode,
onModeChange,
customCategories,
onCustomCategoriesChange,
suggestedCount,
onSuggestedCountChange,
step0Result,
onStep0Edit,
disabled
}: CategorySelectorProps) {
const [inputValue, setInputValue] = useState('');
return (
<Space direction="vertical" style={{ width: '100%' }}>
<Radio.Group
value={mode}
onChange={(e) => onModeChange(e.target.value as CategoryMode)}
disabled={disabled}
>
<Space direction="vertical">
<Radio value="fixed_only">
Fixed (使)
</Radio>
<Radio value="fixed_plus_custom">
Fixed + Custom
</Radio>
<Radio value="custom_only">
Custom Only (LLM suggests)
</Radio>
<Radio value="dynamic_auto">
Dynamic (LLM suggests, editable)
</Radio>
</Space>
</Radio.Group>
{/* 動態模式:類別數量調整 */}
{(mode === 'custom_only' || mode === 'dynamic_auto') && (
<div>
<Text>Suggested Category Count: {suggestedCount}</Text>
<Slider
min={2}
max={8}
step={1}
value={suggestedCount}
onChange={onSuggestedCountChange}
marks={{ 2: '2', 3: '3', 5: '5', 8: '8' }}
disabled={disabled}
/>
</div>
)}
{/* 固定+自定義模式 */}
{mode === 'fixed_plus_custom' && (
<div>
<Text type="secondary">Add custom categories:</Text>
<Space.Compact style={{ width: '100%', marginTop: 8 }}>
<Input
placeholder="Category name"
value={inputValue}
onChange={(e) => setInputValue(e.target.value)}
onPressEnter={() => {
if (inputValue.trim()) {
onCustomCategoriesChange([...customCategories, inputValue.trim()]);
setInputValue('');
}
}}
disabled={disabled}
/>
<Button
type="primary"
icon={<PlusOutlined />}
onClick={() => {
if (inputValue.trim()) {
onCustomCategoriesChange([...customCategories, inputValue.trim()]);
setInputValue('');
}
}}
disabled={disabled}
>
Add
</Button>
</Space.Compact>
{customCategories.length > 0 && (
<div style={{ marginTop: 8 }}>
{customCategories.map((cat, i) => (
<Tag
key={i}
closable
onClose={() => {
onCustomCategoriesChange(customCategories.filter((_, idx) => idx !== i));
}}
>
{cat}
</Tag>
))}
</div>
)}
</div>
)}
{/* Step 0 結果顯示 */}
{step0Result && (mode === 'custom_only' || mode === 'dynamic_auto') && (
<div style={{ marginTop: 8, padding: 12, background: 'rgba(0,0,0,0.04)', borderRadius: 4 }}>
<Text strong>LLM Suggested:</Text>
<div style={{ marginTop: 8 }}>
{step0Result.categories.map((cat, i) => (
<Tag
key={i}
color="blue"
closable={mode === 'dynamic_auto'}
onClose={mode === 'dynamic_auto' ? () => {
if (onStep0Edit) {
onStep0Edit(step0Result.categories.filter((_, idx) => idx !== i));
}
} : undefined}
>
{cat.name}
{cat.description && (
<Tooltip title={cat.description}>
<InfoCircleOutlined style={{ marginLeft: 4 }} />
</Tooltip>
)}
</Tag>
))}
</div>
{mode === 'dynamic_auto' && (
<Text type="secondary" style={{ fontSize: 12, marginTop: 8, display: 'block' }}>
You can remove tags or proceed
</Text>
)}
</div>
)}
</Space>
);
}

View File

@@ -22,8 +22,9 @@ import {
LoadingOutlined, LoadingOutlined,
CheckCircleOutlined, CheckCircleOutlined,
} from '@ant-design/icons'; } from '@ant-design/icons';
import type { HistoryItem, AttributeNode, StreamProgress } from '../types'; import type { HistoryItem, AttributeNode, StreamProgress, CategoryMode, DynamicCausalChain, CausalChain } from '../types';
import { getModels } from '../services/api'; import { getModels } from '../services/api';
import { CategorySelector } from './CategorySelector';
const { TextArea } = Input; const { TextArea } = Input;
const { Text } = Typography; const { Text } = Typography;
@@ -38,7 +39,15 @@ interface InputPanelProps {
progress: StreamProgress; progress: StreamProgress;
history: HistoryItem[]; history: HistoryItem[];
currentResult: AttributeNode | null; currentResult: AttributeNode | null;
onAnalyze: (query: string, model?: string, temperature?: number, chainCount?: number) => Promise<void>; onAnalyze: (
query: string,
model?: string,
temperature?: number,
chainCount?: number,
categoryMode?: CategoryMode,
customCategories?: string[],
suggestedCategoryCount?: number
) => Promise<void>;
onLoadHistory: (item: HistoryItem) => void; onLoadHistory: (item: HistoryItem) => void;
onExpandAll?: () => void; onExpandAll?: () => void;
onCollapseAll?: () => void; onCollapseAll?: () => void;
@@ -64,6 +73,10 @@ export function InputPanel({
const [loadingModels, setLoadingModels] = useState(false); const [loadingModels, setLoadingModels] = useState(false);
const [temperature, setTemperature] = useState(0.7); const [temperature, setTemperature] = useState(0.7);
const [chainCount, setChainCount] = useState(5); const [chainCount, setChainCount] = useState(5);
// Category settings
const [categoryMode, setCategoryMode] = useState<CategoryMode>('dynamic_auto' as CategoryMode);
const [customCategories, setCustomCategories] = useState<string[]>([]);
const [suggestedCategoryCount, setSuggestedCategoryCount] = useState(3);
useEffect(() => { useEffect(() => {
async function fetchModels() { async function fetchModels() {
@@ -92,7 +105,15 @@ export function InputPanel({
} }
try { try {
await onAnalyze(query.trim(), selectedModel, temperature, chainCount); await onAnalyze(
query.trim(),
selectedModel,
temperature,
chainCount,
categoryMode,
customCategories.length > 0 ? customCategories : undefined,
suggestedCategoryCount
);
setQuery(''); setQuery('');
} catch { } catch {
message.error('Analysis failed'); message.error('Analysis failed');
@@ -191,14 +212,27 @@ export function InputPanel({
img.src = 'data:image/svg+xml;base64,' + btoa(unescape(encodeURIComponent(svgData))); img.src = 'data:image/svg+xml;base64,' + btoa(unescape(encodeURIComponent(svgData)));
}; };
// Helper to format chain display (supports both fixed and dynamic chains)
const formatChain = (chain: CausalChain | DynamicCausalChain): string => {
if ('chain' in chain) {
// Dynamic chain
return Object.values(chain.chain).join(' → ');
} else {
// Fixed chain
return `${chain.material}${chain.function}${chain.usage}${chain.user}`;
}
};
const renderProgressIndicator = () => { const renderProgressIndicator = () => {
if (progress.step === 'idle' || progress.step === 'done') return null; if (progress.step === 'idle' || progress.step === 'done') return null;
const percent = progress.step === 'step1' const percent = progress.step === 'step0'
? 10 ? 5
: progress.step === 'chains' : progress.step === 'step1'
? 10 + (progress.currentChainIndex / progress.totalChains) * 90 ? 10
: 100; : progress.step === 'chains'
? 10 + (progress.currentChainIndex / progress.totalChains) * 90
: 100;
return ( return (
<div style={{ marginBottom: 16, padding: 12, background: 'rgba(0,0,0,0.04)', borderRadius: 8 }}> <div style={{ marginBottom: 16, padding: 12, background: 'rgba(0,0,0,0.04)', borderRadius: 8 }}>
@@ -213,6 +247,20 @@ export function InputPanel({
</Space> </Space>
<Progress percent={Math.round(percent)} size="small" status={progress.step === 'error' ? 'exception' : 'active'} /> <Progress percent={Math.round(percent)} size="small" status={progress.step === 'error' ? 'exception' : 'active'} />
{/* Show categories used */}
{progress.categoriesUsed && progress.categoriesUsed.length > 0 && (
<div>
<Text type="secondary" style={{ fontSize: 12 }}>Categories:</Text>
<div style={{ marginTop: 4 }}>
{progress.categoriesUsed.map((cat, i) => (
<Tag key={i} color={cat.is_fixed ? 'default' : 'blue'}>
{cat.name}
</Tag>
))}
</div>
</div>
)}
{progress.completedChains.length > 0 && ( {progress.completedChains.length > 0 && (
<div style={{ marginTop: 8 }}> <div style={{ marginTop: 8 }}>
<Text type="secondary" style={{ fontSize: 12 }}>Completed chains:</Text> <Text type="secondary" style={{ fontSize: 12 }}>Completed chains:</Text>
@@ -220,7 +268,7 @@ export function InputPanel({
{progress.completedChains.map((chain, i) => ( {progress.completedChains.map((chain, i) => (
<div key={i} style={{ fontSize: 11, padding: '2px 0' }}> <div key={i} style={{ fontSize: 11, padding: '2px 0' }}>
<CheckCircleOutlined style={{ color: '#52c41a', marginRight: 4 }} /> <CheckCircleOutlined style={{ color: '#52c41a', marginRight: 4 }} />
{chain.material} {chain.function} {chain.usage} {chain.user} {formatChain(chain)}
</div> </div>
))} ))}
</div> </div>
@@ -232,6 +280,22 @@ export function InputPanel({
}; };
const collapseItems = [ const collapseItems = [
{
key: 'categories',
label: 'Category Settings',
children: (
<CategorySelector
mode={categoryMode}
onModeChange={setCategoryMode}
customCategories={customCategories}
onCustomCategoriesChange={setCustomCategories}
suggestedCount={suggestedCategoryCount}
onSuggestedCountChange={setSuggestedCategoryCount}
step0Result={progress.step0Result}
disabled={loading}
/>
),
},
{ {
key: 'llm', key: 'llm',
label: 'LLM Parameters', label: 'LLM Parameters',

View File

@@ -117,8 +117,19 @@ export const MindmapD3 = forwardRef<MindmapD3Ref, MindmapD3Props>(
d._children = undefined; d._children = undefined;
}); });
// Category labels for header // Dynamically extract category labels from the tree based on depth
const categoryLabels = ['', '材料', '功能', '用途', '使用族群']; // Each depth level corresponds to a category
const categoryByDepth: Record<number, string> = {};
root.descendants().forEach((d: TreeNode) => {
if (d.depth > 0 && d.data.category && !categoryByDepth[d.depth]) {
categoryByDepth[d.depth] = d.data.category;
}
});
const maxDepthWithCategory = Math.max(...Object.keys(categoryByDepth).map(Number), 0);
const categoryLabels = [''];
for (let i = 1; i <= maxDepthWithCategory; i++) {
categoryLabels.push(categoryByDepth[i] || '');
}
const headerHeight = 40; const headerHeight = 40;
function update(source: TreeNode) { function update(source: TreeNode) {
@@ -143,12 +154,27 @@ export const MindmapD3 = forwardRef<MindmapD3Ref, MindmapD3Props>(
// Draw category headers with background // Draw category headers with background
g.selectAll('.category-header-group').remove(); g.selectAll('.category-header-group').remove();
const maxDepth = Math.max(...descendants.map(d => d.depth)); const maxDepth = Math.max(...descendants.map(d => d.depth));
const categoryColors: Record<string, string> = {
'材料': isDark ? '#854eca' : '#722ed1', // Dynamic color palette for categories
'功能': isDark ? '#13a8a8' : '#13c2c2', const colorPalette = [
'用途': isDark ? '#d87a16' : '#fa8c16', { dark: '#854eca', light: '#722ed1' }, // purple
'使用族群': isDark ? '#49aa19' : '#52c41a', { dark: '#13a8a8', light: '#13c2c2' }, // cyan
}; { dark: '#d87a16', light: '#fa8c16' }, // orange
{ dark: '#49aa19', light: '#52c41a' }, // green
{ dark: '#1677ff', light: '#1890ff' }, // blue
{ dark: '#eb2f96', light: '#f759ab' }, // magenta
{ dark: '#faad14', light: '#ffc53d' }, // gold
{ dark: '#a0d911', light: '#bae637' }, // lime
];
// Generate colors dynamically based on category position
const categoryColors: Record<string, string> = {};
categoryLabels.forEach((label, index) => {
if (label && index > 0) {
const colorIndex = (index - 1) % colorPalette.length;
categoryColors[label] = isDark ? colorPalette[colorIndex].dark : colorPalette[colorIndex].light;
}
});
for (let depth = 1; depth <= Math.min(maxDepth, categoryLabels.length - 1); depth++) { for (let depth = 1; depth <= Math.min(maxDepth, categoryLabels.length - 1); depth++) {
const label = categoryLabels[depth]; const label = categoryLabels[depth];

View File

@@ -3,9 +3,9 @@ import type {
AttributeNode, AttributeNode,
HistoryItem, HistoryItem,
StreamProgress, StreamProgress,
StreamAnalyzeResponse, StreamAnalyzeResponse
CausalChain
} from '../types'; } from '../types';
import { CategoryMode } from '../types';
import { analyzeAttributesStream } from '../services/api'; import { analyzeAttributesStream } from '../services/api';
export function useAttribute() { export function useAttribute() {
@@ -24,7 +24,10 @@ export function useAttribute() {
query: string, query: string,
model?: string, model?: string,
temperature?: number, temperature?: number,
chainCount: number = 5 chainCount: number = 5,
categoryMode: CategoryMode = CategoryMode.DYNAMIC_AUTO,
customCategories?: string[],
suggestedCategoryCount: number = 3
) => { ) => {
// 重置狀態 // 重置狀態
setProgress({ setProgress({
@@ -39,8 +42,40 @@ export function useAttribute() {
try { try {
await analyzeAttributesStream( await analyzeAttributesStream(
{ query, chain_count: chainCount, model, temperature },
{ {
query,
chain_count: chainCount,
model,
temperature,
category_mode: categoryMode,
custom_categories: customCategories,
suggested_category_count: suggestedCategoryCount
},
{
onStep0Start: () => {
setProgress(prev => ({
...prev,
step: 'step0',
message: '正在分析類別...',
}));
},
onStep0Complete: (result) => {
setProgress(prev => ({
...prev,
step0Result: result,
message: '類別分析完成',
}));
},
onCategoriesResolved: (categories) => {
setProgress(prev => ({
...prev,
categoriesUsed: categories,
message: `使用 ${categories.length} 個類別`,
}));
},
onStep1Start: () => { onStep1Start: () => {
setProgress(prev => ({ setProgress(prev => ({
...prev, ...prev,
@@ -148,7 +183,7 @@ export function useAttribute() {
}); });
}, []); }, []);
const isLoading = progress.step === 'step1' || progress.step === 'chains'; const isLoading = progress.step === 'step0' || progress.step === 'step1' || progress.step === 'chains';
return { return {
loading: isLoading, loading: isLoading,

View File

@@ -3,17 +3,24 @@ import type {
StreamAnalyzeRequest, StreamAnalyzeRequest,
StreamAnalyzeResponse, StreamAnalyzeResponse,
Step1Result, Step1Result,
CausalChain CausalChain,
Step0Result,
CategoryDefinition,
DynamicStep1Result,
DynamicCausalChain
} from '../types'; } from '../types';
// 自動使用當前瀏覽器的 hostname支援遠端存取 // 自動使用當前瀏覽器的 hostname支援遠端存取
const API_BASE_URL = `http://${window.location.hostname}:8000/api`; const API_BASE_URL = `http://${window.location.hostname}:8000/api`;
export interface SSECallbacks { export interface SSECallbacks {
onStep0Start?: () => void;
onStep0Complete?: (result: Step0Result) => void;
onCategoriesResolved?: (categories: CategoryDefinition[]) => void;
onStep1Start?: () => void; onStep1Start?: () => void;
onStep1Complete?: (result: Step1Result) => void; onStep1Complete?: (result: Step1Result | DynamicStep1Result) => void;
onChainStart?: (index: number, total: number) => void; onChainStart?: (index: number, total: number) => void;
onChainComplete?: (index: number, chain: CausalChain) => void; onChainComplete?: (index: number, chain: CausalChain | DynamicCausalChain) => void;
onChainError?: (index: number, error: string) => void; onChainError?: (index: number, error: string) => void;
onDone?: (response: StreamAnalyzeResponse) => void; onDone?: (response: StreamAnalyzeResponse) => void;
onError?: (error: string) => void; onError?: (error: string) => void;
@@ -65,6 +72,15 @@ export async function analyzeAttributesStream(
const eventData = JSON.parse(dataMatch[1]); const eventData = JSON.parse(dataMatch[1]);
switch (eventType) { switch (eventType) {
case 'step0_start':
callbacks.onStep0Start?.();
break;
case 'step0_complete':
callbacks.onStep0Complete?.(eventData.result);
break;
case 'categories_resolved':
callbacks.onCategoriesResolved?.(eventData.categories);
break;
case 'step1_start': case 'step1_start':
callbacks.onStep1Start?.(); callbacks.onStep1Start?.();
break; break;

View File

@@ -126,6 +126,38 @@
fill: #fff; fill: #fff;
} }
.mindmap-light .node-rect.depth-5 {
fill: #1890ff;
stroke: #096dd9;
}
.mindmap-light .node-text.depth-5 {
fill: #fff;
}
.mindmap-light .node-rect.depth-6 {
fill: #f759ab;
stroke: #eb2f96;
}
.mindmap-light .node-text.depth-6 {
fill: #fff;
}
.mindmap-light .node-rect.depth-7 {
fill: #ffc53d;
stroke: #faad14;
}
.mindmap-light .node-text.depth-7 {
fill: #fff;
}
.mindmap-light .node-rect.depth-8 {
fill: #bae637;
stroke: #a0d911;
}
.mindmap-light .node-text.depth-8 {
fill: #fff;
}
.mindmap-light .link { .mindmap-light .link {
stroke: #bfbfbf; stroke: #bfbfbf;
} }
@@ -221,6 +253,38 @@
fill: #fff; fill: #fff;
} }
.mindmap-dark .node-rect.depth-5 {
fill: #1677ff;
stroke: #4096ff;
}
.mindmap-dark .node-text.depth-5 {
fill: #fff;
}
.mindmap-dark .node-rect.depth-6 {
fill: #eb2f96;
stroke: #f759ab;
}
.mindmap-dark .node-text.depth-6 {
fill: #fff;
}
.mindmap-dark .node-rect.depth-7 {
fill: #faad14;
stroke: #ffc53d;
}
.mindmap-dark .node-text.depth-7 {
fill: #fff;
}
.mindmap-dark .node-rect.depth-8 {
fill: #a0d911;
stroke: #bae637;
}
.mindmap-dark .node-text.depth-8 {
fill: #fff;
}
.mindmap-dark .link { .mindmap-dark .link {
stroke: #434343; stroke: #434343;
} }

View File

@@ -52,26 +52,64 @@ export interface CausalChain {
user: string; user: string;
} }
// ===== Dynamic category system types =====
export interface CategoryDefinition {
name: string;
description?: string;
is_fixed: boolean;
order: number;
}
export interface Step0Result {
categories: CategoryDefinition[];
}
export interface DynamicStep1Result {
attributes: Record<string, string[]>;
}
export interface DynamicCausalChain {
chain: Record<string, string>;
}
export const CategoryMode = {
FIXED_ONLY: 'fixed_only',
FIXED_PLUS_CUSTOM: 'fixed_plus_custom',
CUSTOM_ONLY: 'custom_only',
DYNAMIC_AUTO: 'dynamic_auto',
} as const;
export type CategoryMode = typeof CategoryMode[keyof typeof CategoryMode];
export interface StreamAnalyzeRequest { export interface StreamAnalyzeRequest {
query: string; query: string;
model?: string; model?: string;
temperature?: number; temperature?: number;
chain_count: number; chain_count: number;
// Dynamic category support
category_mode?: CategoryMode;
custom_categories?: string[];
suggested_category_count?: number;
} }
export interface StreamProgress { export interface StreamProgress {
step: 'idle' | 'step1' | 'chains' | 'done' | 'error'; step: 'idle' | 'step0' | 'step1' | 'chains' | 'done' | 'error';
step1Result?: Step1Result; step0Result?: Step0Result;
categoriesUsed?: CategoryDefinition[];
step1Result?: Step1Result | DynamicStep1Result;
currentChainIndex: number; currentChainIndex: number;
totalChains: number; totalChains: number;
completedChains: CausalChain[]; completedChains: (CausalChain | DynamicCausalChain)[];
message: string; message: string;
error?: string; error?: string;
} }
export interface StreamAnalyzeResponse { export interface StreamAnalyzeResponse {
query: string; query: string;
step1_result: Step1Result; step0_result?: Step0Result;
causal_chains: CausalChain[]; categories_used: CategoryDefinition[];
step1_result: Step1Result | DynamicStep1Result;
causal_chains: (CausalChain | DynamicCausalChain)[];
attributes: AttributeNode; attributes: AttributeNode;
} }