feat: Add dynamic category system for attribute analysis

Backend:
- Add CategoryMode enum with 4 modes (fixed_only, fixed_plus_custom, custom_only, dynamic_auto); a schema sketch follows this list
- Add Step 0 for LLM category analysis before attribute generation
- Implement dynamic prompts for Step 1/2 that work with N categories
- Add execute_step0(), resolve_final_categories(), assemble_dynamic_attribute_tree()
- Update SSE events to include step0_start, step0_complete, categories_resolved
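
A minimal sketch of the new schemas as the router code uses them; the field types for `attributes` and `chain` are inferred from usage, not copied from `models/schemas.py`:

```python
# Sketch only: reconstructed from how the router constructs and reads these
# models; the real definitions in models/schemas.py may carry extra fields.
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel


class CategoryMode(str, Enum):
    FIXED_ONLY = "fixed_only"                # only the 4 built-in categories
    FIXED_PLUS_CUSTOM = "fixed_plus_custom"  # built-ins plus user-supplied names
    CUSTOM_ONLY = "custom_only"              # only custom/LLM categories
    DYNAMIC_AUTO = "dynamic_auto"            # let Step 0 propose the categories


class CategoryDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    is_fixed: bool = False
    order: int = 0


class Step0Result(BaseModel):
    categories: List[CategoryDefinition]


class DynamicStep1Result(BaseModel):
    # assumed shape: category name -> list of candidate attribute values
    attributes: Dict[str, List[str]]


class DynamicCausalChain(BaseModel):
    # assumed shape: category name -> the attribute chosen for this chain
    chain: Dict[str, str]
```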

Frontend:
- Add CategorySelector component with mode selection, custom category input, and category count slider
- Update types with CategoryDefinition, Step0Result, DynamicStep1Result, DynamicCausalChain
- Update api.ts with new SSE event handlers (event payload shapes sketched after this list)
- Update useAttribute hook with category parameters
- Integrate CategorySelector into InputPanel
- Fix mindmap to dynamically extract and display N categories (was hardcoded to 4)
- Add CSS styles for depth 5-8 to support more category levels
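
For the api.ts handlers, the event names and payload keys below are read off the backend's `yield` statements in this diff; the values are placeholders only:

```python
# Sketch of the SSE events the streaming endpoint now emits and the JSON
# payload keys each one carries; values here are illustrative, not literal.
SSE_EVENT_SHAPES = {
    # emitted only when category_mode != "fixed_only"
    "step0_start": {"message": "分析類別..."},
    "step0_complete": {"result": "Step0Result (LLM-proposed categories)"},
    # always emitted once the final category list is decided
    "categories_resolved": {"categories": ["CategoryDefinition", "..."]},
    "step1_start": {"message": "生成屬性..."},
    "step1_complete": {"result": "DynamicStep1Result (attributes per category)"},
    # repeated once per requested causal chain
    "chain_start": {"index": 1, "total": 5, "message": "..."},
    "chain_complete": {"index": 1, "chain": "DynamicCausalChain"},
    "chain_error": {"index": 1, "error": "生成失敗"},
}
```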

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 23:04:35 +08:00
parent eb6c0c51fa
commit 91f7f41bc1
11 changed files with 727 additions and 53 deletions


@@ -11,10 +11,18 @@ from ..models.schemas import (
Step1Result,
CausalChain,
AttributeNode,
CategoryMode,
CategoryDefinition,
Step0Result,
DynamicStep1Result,
DynamicCausalChain,
)
from ..prompts.attribute_prompt import (
get_step1_attributes_prompt,
get_step2_causal_chain_prompt,
get_step0_category_analysis_prompt,
get_step1_dynamic_attributes_prompt,
get_step2_dynamic_causal_chain_prompt,
)
from ..services.llm_service import ollama_provider, extract_json_from_response
@@ -22,6 +30,117 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["attributes"])
# Fixed categories definition
FIXED_CATEGORIES = [
CategoryDefinition(name="材料", description="物件材料", is_fixed=True, order=0),
CategoryDefinition(name="功能", description="物件功能", is_fixed=True, order=1),
CategoryDefinition(name="用途", description="使用場景", is_fixed=True, order=2),
CategoryDefinition(name="使用族群", description="目標用戶", is_fixed=True, order=3),
]
async def execute_step0(request: StreamAnalyzeRequest) -> Step0Result | None:
"""Execute Step 0 - LLM category analysis"""
if request.category_mode == CategoryMode.FIXED_ONLY:
return None
prompt = get_step0_category_analysis_prompt(
request.query,
request.suggested_category_count
)
temperature = request.temperature if request.temperature is not None else 0.7
response = await ollama_provider.generate(
prompt, model=request.model, temperature=temperature
)
data = extract_json_from_response(response)
step0_result = Step0Result(**data)
# Mark as LLM generated
for cat in step0_result.categories:
cat.is_fixed = False
return step0_result
def resolve_final_categories(
request: StreamAnalyzeRequest,
step0_result: Step0Result | None
) -> List[CategoryDefinition]:
"""Determine final categories based on mode"""
if request.category_mode == CategoryMode.FIXED_ONLY:
return FIXED_CATEGORIES
elif request.category_mode == CategoryMode.FIXED_PLUS_CUSTOM:
categories = FIXED_CATEGORIES.copy()
if request.custom_categories:
for i, name in enumerate(request.custom_categories):
categories.append(
CategoryDefinition(
name=name, is_fixed=False,
order=len(FIXED_CATEGORIES) + i
)
)
return categories
elif request.category_mode == CategoryMode.CUSTOM_ONLY:
return step0_result.categories if step0_result else FIXED_CATEGORIES
elif request.category_mode == CategoryMode.DYNAMIC_AUTO:
return step0_result.categories if step0_result else FIXED_CATEGORIES
return FIXED_CATEGORIES
def assemble_dynamic_attribute_tree(
query: str,
chains: List[DynamicCausalChain],
categories: List[CategoryDefinition]
) -> AttributeNode:
"""Assemble dynamic N-level tree from causal chains"""
sorted_cats = sorted(categories, key=lambda x: x.order)
if not chains:
return AttributeNode(name=query, children=[])
def build_recursive(
level: int,
parent_path: dict,
remaining_chains: List[DynamicCausalChain]
) -> List[AttributeNode]:
if level >= len(sorted_cats):
return []
current_cat = sorted_cats[level]
grouped = {}
for chain in remaining_chains:
# Check if this chain matches the parent path
if all(chain.chain.get(k) == v for k, v in parent_path.items()):
attr_val = chain.chain.get(current_cat.name)
if attr_val:
if attr_val not in grouped:
grouped[attr_val] = []
grouped[attr_val].append(chain)
nodes = []
for attr_val, child_chains in grouped.items():
new_path = {**parent_path, current_cat.name: attr_val}
children = build_recursive(level + 1, new_path, child_chains)
node = AttributeNode(
name=attr_val,
category=current_cat.name,
children=children if children else None
)
nodes.append(node)
return nodes
root_children = build_recursive(0, {}, chains)
return AttributeNode(name=query, children=root_children)
def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
"""將因果鏈組裝成樹狀結構"""
# 以材料為第一層分組
@@ -74,14 +193,28 @@ def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeN
async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
"""生成 SSE 事件流"""
"""Generate SSE events with dynamic category support"""
try:
temperature = request.temperature if request.temperature is not None else 0.7
# ========== Step 1: 生成屬性列表 ==========
yield f"event: step1_start\ndata: {json.dumps({'message': '正在分析屬性列表...'}, ensure_ascii=False)}\n\n"
# ========== Step 0: Category Analysis (if needed) ==========
step0_result = None
if request.category_mode != CategoryMode.FIXED_ONLY:
yield f"event: step0_start\ndata: {json.dumps({'message': '分析類別...'}, ensure_ascii=False)}\n\n"
step1_prompt = get_step1_attributes_prompt(request.query)
step0_result = await execute_step0(request)
if step0_result:
yield f"event: step0_complete\ndata: {json.dumps({'result': step0_result.model_dump()}, ensure_ascii=False)}\n\n"
# ========== Resolve Final Categories ==========
final_categories = resolve_final_categories(request, step0_result)
yield f"event: categories_resolved\ndata: {json.dumps({'categories': [c.model_dump() for c in final_categories]}, ensure_ascii=False)}\n\n"
# ========== Step 1: Generate Attributes (Dynamic) ==========
yield f"event: step1_start\ndata: {json.dumps({'message': '生成屬性...'}, ensure_ascii=False)}\n\n"
step1_prompt = get_step1_dynamic_attributes_prompt(request.query, final_categories)
logger.info(f"Step 1 prompt: {step1_prompt[:200]}")
step1_response = await ollama_provider.generate(
@@ -90,29 +223,27 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
logger.info(f"Step 1 response: {step1_response[:500]}")
step1_data = extract_json_from_response(step1_response)
step1_result = Step1Result(**step1_data)
step1_result = DynamicStep1Result(attributes=step1_data)
yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n"
# ========== Step 2: 逐條生成因果鏈 ==========
causal_chains: List[CausalChain] = []
# ========== Step 2: Generate Causal Chains (Dynamic) ==========
causal_chains: List[DynamicCausalChain] = []
for i in range(request.chain_count):
chain_index = i + 1
yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n"
step2_prompt = get_step2_causal_chain_prompt(
step2_prompt = get_step2_dynamic_causal_chain_prompt(
query=request.query,
materials=step1_result.materials,
functions=step1_result.functions,
usages=step1_result.usages,
users=step1_result.users,
existing_chains=[c.model_dump() for c in causal_chains],
categories=final_categories,
attributes_by_category=step1_result.attributes,
existing_chains=[c.chain for c in causal_chains],
chain_index=chain_index,
)
# 逐漸提高 temperature 增加多樣性
# Gradually increase temperature for diversity
chain_temperature = min(temperature + 0.05 * i, 1.0)
max_retries = 2
@@ -125,7 +256,7 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
logger.info(f"Chain {chain_index} response: {step2_response[:300]}")
chain_data = extract_json_from_response(step2_response)
chain = CausalChain(**chain_data)
chain = DynamicCausalChain(chain=chain_data)
break
except Exception as e:
logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}")
@@ -136,13 +267,15 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s
causal_chains.append(chain)
yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n"
else:
yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'{chain_index} 條因果鏈生成失敗'}, ensure_ascii=False)}\n\n"
yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'生成失敗'}, ensure_ascii=False)}\n\n"
# ========== 組裝最終結構 ==========
final_tree = assemble_attribute_tree(request.query, causal_chains)
# ========== Assemble Final Tree (Dynamic) ==========
final_tree = assemble_dynamic_attribute_tree(request.query, causal_chains, final_categories)
final_result = {
"query": request.query,
"step0_result": step0_result.model_dump() if step0_result else None,
"categories_used": [c.model_dump() for c in final_categories],
"step1_result": step1_result.model_dump(),
"causal_chains": [c.model_dump() for c in causal_chains],
"attributes": final_tree.model_dump(),