diff --git a/backend/app/models/schemas.py b/backend/app/models/schemas.py index fbf2f78..90a851f 100644 --- a/backend/app/models/schemas.py +++ b/backend/app/models/schemas.py @@ -73,6 +73,7 @@ class CategoryMode(str, Enum): """類別模式""" FIXED_ONLY = "fixed_only" FIXED_PLUS_CUSTOM = "fixed_plus_custom" + FIXED_PLUS_DYNAMIC = "fixed_plus_dynamic" # Fixed + LLM suggested CUSTOM_ONLY = "custom_only" DYNAMIC_AUTO = "dynamic_auto" @@ -98,3 +99,35 @@ class DynamicStep1Result(BaseModel): class DynamicCausalChain(BaseModel): """動態版本的因果鏈""" chain: Dict[str, str] # {類別名: 選中屬性} + + +# ===== DAG (Directed Acyclic Graph) schemas ===== + +class DAGNode(BaseModel): + """DAG 節點 - 每個屬性只出現一次""" + id: str # 唯一 ID: "{category}_{index}" + name: str # 顯示名稱 + category: str # 所屬類別 + order: int # 欄位內位置 + + +class DAGEdge(BaseModel): + """DAG 邊 - 節點之間的連接""" + source_id: str + target_id: str + + +class AttributeDAG(BaseModel): + """完整 DAG 結構""" + query: str + categories: List[CategoryDefinition] + nodes: List[DAGNode] + edges: List[DAGEdge] + + +class DAGRelationship(BaseModel): + """Step 2 輸出 - 單一關係""" + source_category: str + source: str # source attribute name + target_category: str + target: str # target attribute name diff --git a/backend/app/prompts/attribute_prompt.py b/backend/app/prompts/attribute_prompt.py index 07f6255..a7c0e5a 100644 --- a/backend/app/prompts/attribute_prompt.py +++ b/backend/app/prompts/attribute_prompt.py @@ -120,17 +120,26 @@ def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None # ===== Dynamic category system prompts ===== -def get_step0_category_analysis_prompt(query: str, suggested_count: int = 3) -> str: +def get_step0_category_analysis_prompt( + query: str, + suggested_count: int = 3, + exclude_categories: List[str] | None = None +) -> str: """Step 0: LLM 分析建議類別""" + exclude_text = "" + if exclude_categories: + exclude_text = f"\n【禁止使用的類別】{', '.join(exclude_categories)}(這些已經是固定類別,不要重複建議)\n" + return f"""/no_think 分析「{query}」,建議 {suggested_count} 個最適合的屬性類別來描述它。 -【常見類別參考】材料、功能、用途、使用族群、特性、形狀、顏色、尺寸、品牌、價格區間 - +【常見類別參考】特性、形狀、顏色、尺寸、品牌、價格區間、重量、風格、場合、季節、技術規格 +{exclude_text} 【重要】 1. 選擇最能描述此物件本質的類別 -2. 類別之間應該有邏輯關係(如:材料→功能→用途) +2. 類別之間應該有邏輯關係 3. 不要選擇過於抽象或重複的類別 +4. 必須建議與參考列表不同的、有創意的類別 只回傳 JSON: {{ @@ -213,3 +222,47 @@ def get_step2_dynamic_causal_chain_prompt( 只回傳 JSON: {json.dumps(json_template, ensure_ascii=False, indent=2)}""" + + +# ===== DAG relationship prompt ===== + +def get_step2_dag_relationships_prompt( + query: str, + categories: List, # List[CategoryDefinition] + attributes_by_category: Dict[str, List[str]], +) -> str: + """生成相鄰類別之間的自然關係""" + sorted_cats = sorted(categories, key=lambda x: x.order if hasattr(x, 'order') else x.get('order', 0)) + + # Build attribute listing + attr_listing = "\n".join([ + f"【{cat.name if hasattr(cat, 'name') else cat['name']}】{', '.join(attributes_by_category.get(cat.name if hasattr(cat, 'name') else cat['name'], []))}" + for cat in sorted_cats + ]) + + # Build direction hints + direction_hints = " → ".join([cat.name if hasattr(cat, 'name') else cat['name'] for cat in sorted_cats]) + + return f"""/no_think +分析「{query}」的屬性關係。 + +{attr_listing} + +【關係方向】{direction_hints} + +【規則】 +1. 只建立相鄰類別之間的關係(例如:材料→功能,功能→用途) +2. 只輸出真正有因果或關聯關係的配對 +3. 一個屬性可連接多個下游屬性,也可以不連接任何屬性 +4. 不需要每個屬性都有連接 +5. 關係應該合理且有意義 + +回傳 JSON: +{{ + "relationships": [ + {{"source_category": "類別A", "source": "屬性名", "target_category": "類別B", "target": "屬性名"}}, + ... + ] +}} + +只回傳 JSON。""" diff --git a/backend/app/routers/attributes.py b/backend/app/routers/attributes.py index 6c351f1..a4dfa32 100644 --- a/backend/app/routers/attributes.py +++ b/backend/app/routers/attributes.py @@ -16,6 +16,10 @@ from ..models.schemas import ( Step0Result, DynamicStep1Result, DynamicCausalChain, + DAGNode, + DAGEdge, + AttributeDAG, + DAGRelationship, ) from ..prompts.attribute_prompt import ( get_step1_attributes_prompt, @@ -23,6 +27,7 @@ from ..prompts.attribute_prompt import ( get_step0_category_analysis_prompt, get_step1_dynamic_attributes_prompt, get_step2_dynamic_causal_chain_prompt, + get_step2_dag_relationships_prompt, ) from ..services.llm_service import ollama_provider, extract_json_from_response @@ -39,14 +44,21 @@ FIXED_CATEGORIES = [ ] -async def execute_step0(request: StreamAnalyzeRequest) -> Step0Result | None: - """Execute Step 0 - LLM category analysis""" - if request.category_mode == CategoryMode.FIXED_ONLY: - return None +async def execute_step0( + request: StreamAnalyzeRequest, + exclude_categories: List[str] | None = None +) -> Step0Result | None: + """Execute Step 0 - LLM category analysis + Called only for modes that need LLM to suggest categories: + - FIXED_PLUS_DYNAMIC + - CUSTOM_ONLY + - DYNAMIC_AUTO + """ prompt = get_step0_category_analysis_prompt( request.query, - request.suggested_category_count + request.suggested_category_count, + exclude_categories=exclude_categories ) temperature = request.temperature if request.temperature is not None else 0.7 response = await ollama_provider.generate( @@ -83,6 +95,34 @@ def resolve_final_categories( ) return categories + elif request.category_mode == CategoryMode.FIXED_PLUS_DYNAMIC: + # Fixed categories + LLM suggested categories + categories = [ + CategoryDefinition( + name=cat.name, + description=cat.description, + is_fixed=True, + order=i + ) + for i, cat in enumerate(FIXED_CATEGORIES) + ] + if step0_result: + # Filter out LLM categories that duplicate fixed category names + fixed_names = {cat.name for cat in FIXED_CATEGORIES} + added_count = 0 + for cat in step0_result.categories: + if cat.name not in fixed_names: + categories.append( + CategoryDefinition( + name=cat.name, + description=cat.description, + is_fixed=False, + order=len(FIXED_CATEGORIES) + added_count + ) + ) + added_count += 1 + return categories + elif request.category_mode == CategoryMode.CUSTOM_ONLY: return step0_result.categories if step0_result else FIXED_CATEGORIES @@ -192,17 +232,73 @@ def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeN return root +def build_dag_from_relationships( + query: str, + categories: List[CategoryDefinition], + attributes_by_category: dict, + relationships: List[DAGRelationship], +) -> AttributeDAG: + """從屬性和關係建構 DAG""" + sorted_cats = sorted(categories, key=lambda x: x.order) + + # 建立節點 - 每個屬性只出現一次 + nodes: List[DAGNode] = [] + node_id_map: dict = {} # (category, name) -> id + + for cat in sorted_cats: + cat_name = cat.name + for idx, attr_name in enumerate(attributes_by_category.get(cat_name, [])): + node_id = f"{cat_name}_{idx}" + nodes.append(DAGNode( + id=node_id, + name=attr_name, + category=cat_name, + order=idx + )) + node_id_map[(cat_name, attr_name)] = node_id + + # 建立邊 + edges: List[DAGEdge] = [] + for rel in relationships: + source_key = (rel.source_category, rel.source) + target_key = (rel.target_category, rel.target) + if source_key in node_id_map and target_key in node_id_map: + edges.append(DAGEdge( + source_id=node_id_map[source_key], + target_id=node_id_map[target_key] + )) + + return AttributeDAG( + query=query, + categories=sorted_cats, + nodes=nodes, + edges=edges + ) + + async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]: """Generate SSE events with dynamic category support""" try: temperature = request.temperature if request.temperature is not None else 0.7 # ========== Step 0: Category Analysis (if needed) ========== + # Only these modes need LLM category analysis + needs_step0 = request.category_mode in [ + CategoryMode.FIXED_PLUS_DYNAMIC, + CategoryMode.CUSTOM_ONLY, + CategoryMode.DYNAMIC_AUTO, + ] + step0_result = None - if request.category_mode != CategoryMode.FIXED_ONLY: + if needs_step0: yield f"event: step0_start\ndata: {json.dumps({'message': '分析類別...'}, ensure_ascii=False)}\n\n" - step0_result = await execute_step0(request) + # For FIXED_PLUS_DYNAMIC, exclude the fixed category names + exclude_cats = None + if request.category_mode == CategoryMode.FIXED_PLUS_DYNAMIC: + exclude_cats = [cat.name for cat in FIXED_CATEGORIES] + + step0_result = await execute_step0(request, exclude_categories=exclude_cats) if step0_result: yield f"event: step0_complete\ndata: {json.dumps({'result': step0_result.model_dump()}, ensure_ascii=False)}\n\n" @@ -227,58 +323,58 @@ async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[s yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n" - # ========== Step 2: Generate Causal Chains (Dynamic) ========== - causal_chains: List[DynamicCausalChain] = [] + # ========== Step 2: Generate Relationships (DAG) ========== + yield f"event: relationships_start\ndata: {json.dumps({'message': '生成關係...'}, ensure_ascii=False)}\n\n" - for i in range(request.chain_count): - chain_index = i + 1 + step2_prompt = get_step2_dag_relationships_prompt( + query=request.query, + categories=final_categories, + attributes_by_category=step1_result.attributes, + ) + logger.info(f"Step 2 (relationships) prompt: {step2_prompt[:300]}") - yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n" + relationships: List[DAGRelationship] = [] + max_retries = 2 + for attempt in range(max_retries): + try: + step2_response = await ollama_provider.generate( + step2_prompt, model=request.model, temperature=temperature + ) + logger.info(f"Relationships response: {step2_response[:500]}") - step2_prompt = get_step2_dynamic_causal_chain_prompt( - query=request.query, - categories=final_categories, - attributes_by_category=step1_result.attributes, - existing_chains=[c.chain for c in causal_chains], - chain_index=chain_index, - ) + rel_data = extract_json_from_response(step2_response) + raw_relationships = rel_data.get("relationships", []) - # Gradually increase temperature for diversity - chain_temperature = min(temperature + 0.05 * i, 1.0) + for rel in raw_relationships: + relationships.append(DAGRelationship( + source_category=rel.get("source_category", ""), + source=rel.get("source", ""), + target_category=rel.get("target_category", ""), + target=rel.get("target", ""), + )) + break + except Exception as e: + logger.warning(f"Relationships attempt {attempt + 1} failed: {e}") + if attempt < max_retries - 1: + temperature = min(temperature + 0.1, 1.0) - max_retries = 2 - chain = None - for attempt in range(max_retries): - try: - step2_response = await ollama_provider.generate( - step2_prompt, model=request.model, temperature=chain_temperature - ) - logger.info(f"Chain {chain_index} response: {step2_response[:300]}") + yield f"event: relationships_complete\ndata: {json.dumps({'count': len(relationships)}, ensure_ascii=False)}\n\n" - chain_data = extract_json_from_response(step2_response) - chain = DynamicCausalChain(chain=chain_data) - break - except Exception as e: - logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}") - if attempt < max_retries - 1: - chain_temperature = min(chain_temperature + 0.1, 1.0) - - if chain: - causal_chains.append(chain) - yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n" - else: - yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'生成失敗'}, ensure_ascii=False)}\n\n" - - # ========== Assemble Final Tree (Dynamic) ========== - final_tree = assemble_dynamic_attribute_tree(request.query, causal_chains, final_categories) + # ========== Build DAG ========== + dag = build_dag_from_relationships( + query=request.query, + categories=final_categories, + attributes_by_category=step1_result.attributes, + relationships=relationships, + ) final_result = { "query": request.query, "step0_result": step0_result.model_dump() if step0_result else None, "categories_used": [c.model_dump() for c in final_categories], "step1_result": step1_result.model_dump(), - "causal_chains": [c.model_dump() for c in causal_chains], - "attributes": final_tree.model_dump(), + "relationships": [r.model_dump() for r in relationships], + "dag": dag.model_dump(), } yield f"event: done\ndata: {json.dumps(final_result, ensure_ascii=False)}\n\n" diff --git a/backend/app/services/llm_service.py b/backend/app/services/llm_service.py index 8a42ad7..20b779b 100644 --- a/backend/app/services/llm_service.py +++ b/backend/app/services/llm_service.py @@ -1,4 +1,5 @@ import json +import logging import re from abc import ABC, abstractmethod from typing import List, Optional @@ -8,6 +9,8 @@ import httpx from ..config import settings from ..models.schemas import AttributeNode +logger = logging.getLogger(__name__) + class LLMProvider(ABC): @abstractmethod @@ -35,34 +38,56 @@ class OllamaProvider(LLMProvider): model = model or settings.default_model url = f"{self.base_url}/api/generate" + # Remove /no_think prefix for non-qwen models (it's qwen-specific) + clean_prompt = prompt + if not model.lower().startswith("qwen") and prompt.startswith("/no_think"): + clean_prompt = prompt.replace("/no_think\n", "").replace("/no_think", "") + logger.info(f"Removed /no_think prefix for model {model}") + + # Models known to support JSON format well + json_capable_models = ["qwen", "llama", "mistral", "gemma", "phi"] + model_lower = model.lower() + use_json_format = any(m in model_lower for m in json_capable_models) + payload = { "model": model, - "prompt": prompt, + "prompt": clean_prompt, "stream": False, - "format": "json", "options": { "temperature": temperature, }, } + # Only use format: json for models that support it + if use_json_format: + payload["format"] = "json" + else: + logger.info(f"Model {model} may not support JSON format, requesting without format constraint") + # Retry logic for larger models that may return empty responses max_retries = 3 for attempt in range(max_retries): + logger.info(f"LLM request attempt {attempt + 1}/{max_retries} to model {model}") response = await self.client.post(url, json=payload) response.raise_for_status() result = response.json() response_text = result.get("response", "") + logger.info(f"LLM response (first 500 chars): {response_text[:500] if response_text else '(empty)'}") + # Check if response is valid (not empty or just "{}") if response_text and response_text.strip() not in ["", "{}", "{ }"]: return response_text + logger.warning(f"Empty or invalid response on attempt {attempt + 1}, retrying...") + # If empty, retry with slightly higher temperature if attempt < max_retries - 1: payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0) # Return whatever we got on last attempt + logger.error(f"All {max_retries} attempts returned empty response from model {model}") return response_text async def list_models(self) -> List[str]: @@ -124,21 +149,46 @@ class OpenAICompatibleProvider(LLMProvider): def extract_json_from_response(response: str) -> dict: """Extract JSON from LLM response, handling markdown code blocks and extra whitespace.""" - # Remove markdown code blocks if present + if not response or not response.strip(): + logger.error("LLM returned empty response") + raise ValueError("LLM returned empty response - the model may not support JSON format or the prompt was unclear") + + json_str = response + + # Try multiple extraction strategies + extraction_attempts = [] + + # Strategy 1: Look for markdown code blocks json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response) if json_match: - json_str = json_match.group(1) - else: - json_str = response + extraction_attempts.append(json_match.group(1)) - # Clean up: remove extra whitespace, normalize spaces - json_str = json_str.strip() - # Remove trailing whitespace before closing braces/brackets - json_str = re.sub(r'\s+([}\]])', r'\1', json_str) - # Remove multiple spaces/tabs/newlines - json_str = re.sub(r'[\t\n\r]+', ' ', json_str) + # Strategy 2: Look for JSON object pattern { ... } + json_obj_match = re.search(r'(\{[\s\S]*\})', response) + if json_obj_match: + extraction_attempts.append(json_obj_match.group(1)) - return json.loads(json_str) + # Strategy 3: Original response + extraction_attempts.append(response) + + # Try each extraction attempt + for attempt_str in extraction_attempts: + # Clean up: remove extra whitespace, normalize spaces + cleaned = attempt_str.strip() + # Remove trailing whitespace before closing braces/brackets + cleaned = re.sub(r'\s+([}\]])', r'\1', cleaned) + # Normalize newlines but keep structure + cleaned = re.sub(r'[\t\r]+', ' ', cleaned) + + try: + return json.loads(cleaned) + except json.JSONDecodeError: + continue + + # All attempts failed + logger.error(f"Failed to parse JSON from response") + logger.error(f"Raw response: {response[:1000]}") + raise ValueError(f"Failed to parse LLM response as JSON. The model may not support structured output. Raw response: {response[:300]}...") def parse_attribute_response(response: str) -> AttributeNode: diff --git a/frontend/package-lock.json b/frontend/package-lock.json index ff3a7d3..95fb22c 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -10,6 +10,7 @@ "dependencies": { "@ant-design/icons": "^6.1.0", "@types/d3": "^7.4.3", + "@xyflow/react": "^12.9.3", "antd": "^6.0.0", "d3": "^7.9.0", "react": "^19.2.0", @@ -2454,7 +2455,7 @@ "version": "19.2.7", "resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.7.tgz", "integrity": "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==", - "dev": true, + "devOptional": true, "license": "MIT", "peer": true, "dependencies": { @@ -2763,6 +2764,38 @@ "vite": "^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0" } }, + "node_modules/@xyflow/react": { + "version": "12.9.3", + "resolved": "https://registry.npmjs.org/@xyflow/react/-/react-12.9.3.tgz", + "integrity": "sha512-PSWoJ8vHiEqSIkLIkge+0eiHWiw4C6dyFDA03VKWJkqbU4A13VlDIVwKqf/Znuysn2GQw/zA61zpHE4rGgax7Q==", + "license": "MIT", + "dependencies": { + "@xyflow/system": "0.0.73", + "classcat": "^5.0.3", + "zustand": "^4.4.0" + }, + "peerDependencies": { + "react": ">=17", + "react-dom": ">=17" + } + }, + "node_modules/@xyflow/system": { + "version": "0.0.73", + "resolved": "https://registry.npmjs.org/@xyflow/system/-/system-0.0.73.tgz", + "integrity": "sha512-C2ymH2V4mYDkdVSiRx0D7R0s3dvfXiupVBcko6tXP5K4tVdSBMo22/e3V9yRNdn+2HQFv44RFKzwOyCcUUDAVQ==", + "license": "MIT", + "dependencies": { + "@types/d3-drag": "^3.0.7", + "@types/d3-interpolate": "^3.0.4", + "@types/d3-selection": "^3.0.10", + "@types/d3-transition": "^3.0.8", + "@types/d3-zoom": "^3.0.8", + "d3-drag": "^3.0.0", + "d3-interpolate": "^3.0.1", + "d3-selection": "^3.0.0", + "d3-zoom": "^3.0.0" + } + }, "node_modules/acorn": { "version": "8.15.0", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", @@ -3001,6 +3034,12 @@ "url": "https://github.com/chalk/chalk?sponsor=1" } }, + "node_modules/classcat": { + "version": "5.0.5", + "resolved": "https://registry.npmjs.org/classcat/-/classcat-5.0.5.tgz", + "integrity": "sha512-JhZUT7JFcQy/EzW605k/ktHtncoo9vnyW/2GspNYwFlN1C/WmjuV/xtS04e9SOkL2sTdw0VAZ2UGCcQ9lR6p6w==", + "license": "MIT" + }, "node_modules/classnames": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/classnames/-/classnames-2.5.1.tgz", @@ -4809,6 +4848,15 @@ "punycode": "^2.1.0" } }, + "node_modules/use-sync-external-store": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", + "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", + "license": "MIT", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/vite": { "version": "7.2.6", "resolved": "https://registry.npmjs.org/vite/-/vite-7.2.6.tgz", @@ -4954,6 +5002,34 @@ "peerDependencies": { "zod": "^3.25.0 || ^4.0.0" } + }, + "node_modules/zustand": { + "version": "4.5.7", + "resolved": "https://registry.npmjs.org/zustand/-/zustand-4.5.7.tgz", + "integrity": "sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==", + "license": "MIT", + "dependencies": { + "use-sync-external-store": "^1.2.2" + }, + "engines": { + "node": ">=12.7.0" + }, + "peerDependencies": { + "@types/react": ">=16.8", + "immer": ">=9.0.6", + "react": ">=16.8" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "immer": { + "optional": true + }, + "react": { + "optional": true + } + } } } } diff --git a/frontend/package.json b/frontend/package.json index eea21f2..646e405 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -12,6 +12,7 @@ "dependencies": { "@ant-design/icons": "^6.1.0", "@types/d3": "^7.4.3", + "@xyflow/react": "^12.9.3", "antd": "^6.0.0", "d3": "^7.9.0", "react": "^19.2.0", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 5e1010a..dba3547 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,10 +1,11 @@ import { useState, useRef, useCallback } from 'react'; -import { ConfigProvider, Layout, theme, Typography } from 'antd'; +import { ConfigProvider, Layout, theme, Typography, Space } from 'antd'; +import { ApartmentOutlined } from '@ant-design/icons'; import { ThemeToggle } from './components/ThemeToggle'; import { InputPanel } from './components/InputPanel'; import { MindmapPanel } from './components/MindmapPanel'; import { useAttribute } from './hooks/useAttribute'; -import type { MindmapD3Ref } from './components/MindmapD3'; +import type { MindmapDAGRef } from './components/MindmapDAG'; import type { CategoryMode } from './types'; const { Header, Sider, Content } = Layout; @@ -22,7 +23,7 @@ function App() { nodeSpacing: 32, fontSize: 14, }); - const mindmapRef = useRef(null); + const mindmapRef = useRef(null); const handleAnalyze = async ( query: string, @@ -36,12 +37,8 @@ function App() { await analyze(query, model, temperature, chainCount, categoryMode, customCategories, suggestedCategoryCount); }; - const handleExpandAll = useCallback(() => { - mindmapRef.current?.expandAll(); - }, []); - - const handleCollapseAll = useCallback(() => { - mindmapRef.current?.collapseAll(); + const handleResetView = useCallback(() => { + mindmapRef.current?.resetView(); }, []); return ( @@ -57,11 +54,37 @@ function App() { alignItems: 'center', justifyContent: 'space-between', padding: '0 24px', + background: isDark + ? 'linear-gradient(90deg, #141414 0%, #1f1f1f 50%, #141414 100%)' + : 'linear-gradient(90deg, #fff 0%, #fafafa 50%, #fff 100%)', + borderBottom: isDark ? '1px solid #303030' : '1px solid #f0f0f0', + boxShadow: isDark + ? '0 2px 8px rgba(0, 0, 0, 0.3)' + : '0 2px 8px rgba(0, 0, 0, 0.06)', }} > - - Attribute Agent - + + + + Attribute Agent + + @@ -96,8 +119,7 @@ function App() { currentResult={currentResult} onAnalyze={handleAnalyze} onLoadHistory={loadFromHistory} - onExpandAll={handleExpandAll} - onCollapseAll={handleCollapseAll} + onResetView={handleResetView} visualSettings={visualSettings} onVisualSettingsChange={setVisualSettings} /> diff --git a/frontend/src/components/CategorySelector.tsx b/frontend/src/components/CategorySelector.tsx index 421ffbd..80728f2 100644 --- a/frontend/src/components/CategorySelector.tsx +++ b/frontend/src/components/CategorySelector.tsx @@ -42,28 +42,37 @@ export function CategorySelector({ Fixed (材料、功能、用途、使用族群) - Fixed + Custom + Fixed + Custom (手動新增) + + + Fixed + Dynamic (LLM 建議額外類別) - Custom Only (LLM suggests) + Custom Only (LLM 建議) - Dynamic (LLM suggests, editable) + Dynamic (LLM 建議, 可編輯) {/* 動態模式:類別數量調整 */} - {(mode === 'custom_only' || mode === 'dynamic_auto') && ( + {(mode === 'custom_only' || mode === 'dynamic_auto' || mode === 'fixed_plus_dynamic') && (
- Suggested Category Count: {suggestedCount} + + {mode === 'fixed_plus_dynamic' + ? `額外建議類別數: ${suggestedCount}` + : `Suggested Category Count: ${suggestedCount}`} +
@@ -120,7 +129,7 @@ export function CategorySelector({ )} {/* Step 0 結果顯示 */} - {step0Result && (mode === 'custom_only' || mode === 'dynamic_auto') && ( + {step0Result && (mode === 'custom_only' || mode === 'dynamic_auto' || mode === 'fixed_plus_dynamic') && (
LLM Suggested:
diff --git a/frontend/src/components/InputPanel.tsx b/frontend/src/components/InputPanel.tsx index a75b817..0096a65 100644 --- a/frontend/src/components/InputPanel.tsx +++ b/frontend/src/components/InputPanel.tsx @@ -11,23 +11,36 @@ import { Divider, Collapse, Progress, - Tag, + Card, + Alert, } from 'antd'; import { SearchOutlined, HistoryOutlined, - DownloadOutlined, - ExpandAltOutlined, - ShrinkOutlined, + ReloadOutlined, LoadingOutlined, - CheckCircleOutlined, + FileImageOutlined, + FileTextOutlined, + CodeOutlined, } from '@ant-design/icons'; -import type { HistoryItem, AttributeNode, StreamProgress, CategoryMode, DynamicCausalChain, CausalChain } from '../types'; +import type { AttributeDAG, CategoryMode } from '../types'; import { getModels } from '../services/api'; import { CategorySelector } from './CategorySelector'; +interface DAGProgress { + step: 'idle' | 'step0' | 'step1' | 'relationships' | 'done' | 'error'; + message: string; + error?: string; +} + +interface DAGHistoryItem { + query: string; + result: AttributeDAG; + timestamp: Date; +} + const { TextArea } = Input; -const { Text } = Typography; +const { Text, Title } = Typography; interface VisualSettings { nodeSpacing: number; @@ -36,9 +49,9 @@ interface VisualSettings { interface InputPanelProps { loading: boolean; - progress: StreamProgress; - history: HistoryItem[]; - currentResult: AttributeNode | null; + progress: DAGProgress; + history: DAGHistoryItem[]; + currentResult: AttributeDAG | null; onAnalyze: ( query: string, model?: string, @@ -48,9 +61,8 @@ interface InputPanelProps { customCategories?: string[], suggestedCategoryCount?: number ) => Promise; - onLoadHistory: (item: HistoryItem) => void; - onExpandAll?: () => void; - onCollapseAll?: () => void; + onLoadHistory: (item: DAGHistoryItem) => void; + onResetView?: () => void; visualSettings: VisualSettings; onVisualSettingsChange: (settings: VisualSettings) => void; } @@ -62,8 +74,7 @@ export function InputPanel({ currentResult, onAnalyze, onLoadHistory, - onExpandAll, - onCollapseAll, + onResetView, visualSettings, onVisualSettingsChange, }: InputPanelProps) { @@ -137,7 +148,7 @@ export function InputPanel({ const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; - a.download = `${currentResult.name || 'mindmap'}.json`; + a.download = `${currentResult.query || 'dag'}.json`; a.click(); URL.revokeObjectURL(url); }; @@ -148,134 +159,131 @@ export function InputPanel({ return; } - const nodeToMarkdown = (node: AttributeNode, level: number = 0): string => { - const indent = ' '.repeat(level); - let md = `${indent}- ${node.name}\n`; - if (node.children) { - node.children.forEach((child) => { - md += nodeToMarkdown(child, level + 1); - }); + // Group nodes by category + const nodesByCategory: Record = {}; + for (const node of currentResult.nodes) { + if (!nodesByCategory[node.category]) { + nodesByCategory[node.category] = []; } - return md; - }; + nodesByCategory[node.category].push(node.name); + } + + let markdown = `# ${currentResult.query}\n\n`; + for (const cat of currentResult.categories) { + const nodes = nodesByCategory[cat.name] || []; + markdown += `## ${cat.name}\n`; + for (const name of nodes) { + markdown += `- ${name}\n`; + } + markdown += '\n'; + } - const markdown = `# ${currentResult.name}\n\n${currentResult.children?.map((c) => nodeToMarkdown(c, 0)).join('') || ''}`; const blob = new Blob([markdown], { type: 'text/markdown' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; - a.download = `${currentResult.name || 'mindmap'}.md`; + a.download = `${currentResult.query || 'dag'}.md`; a.click(); URL.revokeObjectURL(url); }; const handleExportSVG = () => { - const svg = document.querySelector('.mindmap-svg'); - if (!svg) { + const reactFlowWrapper = document.querySelector('.react-flow'); + if (!reactFlowWrapper) { message.warning('No mindmap to export'); return; } + + const viewport = reactFlowWrapper.querySelector('.react-flow__viewport'); + if (!viewport) { + message.warning('No mindmap to export'); + return; + } + + const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); + svg.setAttribute('width', String(reactFlowWrapper.clientWidth)); + svg.setAttribute('height', String(reactFlowWrapper.clientHeight)); + svg.setAttribute('xmlns', 'http://www.w3.org/2000/svg'); + + const viewportClone = viewport.cloneNode(true) as SVGGElement; + svg.appendChild(viewportClone); + const svgData = new XMLSerializer().serializeToString(svg); const blob = new Blob([svgData], { type: 'image/svg+xml' }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; - a.download = `${currentResult?.name || 'mindmap'}.svg`; + a.download = `${currentResult?.query || 'dag'}.svg`; a.click(); URL.revokeObjectURL(url); }; const handleExportPNG = () => { - const svg = document.querySelector('.mindmap-svg') as SVGSVGElement; - if (!svg) { + const reactFlowWrapper = document.querySelector('.react-flow') as HTMLElement; + if (!reactFlowWrapper) { message.warning('No mindmap to export'); return; } + const viewport = reactFlowWrapper.querySelector('.react-flow__viewport'); + if (!viewport) { + message.warning('No mindmap to export'); + return; + } + + const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); + svg.setAttribute('width', String(reactFlowWrapper.clientWidth)); + svg.setAttribute('height', String(reactFlowWrapper.clientHeight)); + svg.setAttribute('xmlns', 'http://www.w3.org/2000/svg'); + svg.appendChild(viewport.cloneNode(true)); + const svgData = new XMLSerializer().serializeToString(svg); const canvas = document.createElement('canvas'); const ctx = canvas.getContext('2d'); const img = new Image(); img.onload = () => { - canvas.width = svg.clientWidth * 2; - canvas.height = svg.clientHeight * 2; + canvas.width = reactFlowWrapper.clientWidth * 2; + canvas.height = reactFlowWrapper.clientHeight * 2; ctx?.scale(2, 2); ctx?.drawImage(img, 0, 0); const pngUrl = canvas.toDataURL('image/png'); const a = document.createElement('a'); a.href = pngUrl; - a.download = `${currentResult?.name || 'mindmap'}.png`; + a.download = `${currentResult?.query || 'dag'}.png`; a.click(); }; img.src = 'data:image/svg+xml;base64,' + btoa(unescape(encodeURIComponent(svgData))); }; - // Helper to format chain display (supports both fixed and dynamic chains) - const formatChain = (chain: CausalChain | DynamicCausalChain): string => { - if ('chain' in chain) { - // Dynamic chain - return Object.values(chain.chain).join(' → '); - } else { - // Fixed chain - return `${chain.material} → ${chain.function} → ${chain.usage} → ${chain.user}`; - } - }; - const renderProgressIndicator = () => { if (progress.step === 'idle' || progress.step === 'done') return null; const percent = progress.step === 'step0' - ? 5 + ? 15 : progress.step === 'step1' - ? 10 - : progress.step === 'chains' - ? 10 + (progress.currentChainIndex / progress.totalChains) * 90 + ? 50 + : progress.step === 'relationships' + ? 85 : 100; return ( -
- - - {progress.step === 'error' ? ( - Error - ) : ( - - )} - {progress.message} - - - - {/* Show categories used */} - {progress.categoriesUsed && progress.categoriesUsed.length > 0 && ( -
- Categories: -
- {progress.categoriesUsed.map((cat, i) => ( - - {cat.name} - - ))} -
-
- )} - - {progress.completedChains.length > 0 && ( -
- Completed chains: -
- {progress.completedChains.map((chain, i) => ( -
- - {formatChain(chain)} -
- ))} -
-
- )} -
-
+ } + message={progress.message} + description={ + + } + style={{ marginBottom: 16 }} + showIcon + /> ); }; @@ -291,7 +299,6 @@ export function InputPanel({ onCustomCategoriesChange={setCustomCategories} suggestedCount={suggestedCategoryCount} onSuggestedCountChange={setSuggestedCategoryCount} - step0Result={progress.step0Result} disabled={loading} /> ), @@ -300,9 +307,9 @@ export function InputPanel({ key: 'llm', label: 'LLM Parameters', children: ( - +
- Temperature: {temperature} + Temperature: {temperature}
- Chain Count: {chainCount} + Chain Count: {chainCount} +
- Node Spacing: {visualSettings.nodeSpacing} + Node Spacing: {visualSettings.nodeSpacing}px
- Font Size: {visualSettings.fontSize}px + Font Size: {visualSettings.fontSize}px onVisualSettingsChange({ ...visualSettings, fontSize: v })} />
- - - - +
), }, @@ -364,83 +371,135 @@ export function InputPanel({ key: 'export', label: 'Export', children: ( - - - - - + + + + + + + + + ), }, ]; return ( -
- - Model -