import json
import logging
from typing import Any, AsyncGenerator, Callable, Dict, Iterable, List, TypeVar

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from ..models.schemas import (
    ModelListResponse,
    StreamAnalyzeRequest,
    Step1Result,
    CausalChain,
    AttributeNode,
    CategoryMode,
    CategoryDefinition,
    Step0Result,
    DynamicStep1Result,
    DynamicCausalChain,
    DAGNode,
    DAGEdge,
    AttributeDAG,
    DAGRelationship,
)
from ..prompts.attribute_prompt import (
    get_step1_attributes_prompt,
    get_step2_causal_chain_prompt,
    get_step0_category_analysis_prompt,
    get_step1_dynamic_attributes_prompt,
    get_step2_dynamic_causal_chain_prompt,
    get_step2_dag_relationships_prompt,
)
from ..services.llm_service import ollama_provider, extract_json_from_response

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api", tags=["attributes"])

# Fixed categories definition (names/descriptions are user-facing data, kept as-is).
FIXED_CATEGORIES = [
    CategoryDefinition(name="材料", description="物件材料", is_fixed=True, order=0),
    CategoryDefinition(name="功能", description="物件功能", is_fixed=True, order=1),
    CategoryDefinition(name="用途", description="使用場景", is_fixed=True, order=2),
    CategoryDefinition(name="使用族群", description="目標用戶", is_fixed=True, order=3),
]

# Temperature used when the request does not specify one.
_DEFAULT_TEMPERATURE = 0.7

# Category modes that require an LLM Step 0 call to propose categories.
# A tuple (not a set) keeps the original ``==``-based membership semantics.
_STEP0_MODES = (
    CategoryMode.FIXED_PLUS_DYNAMIC,
    CategoryMode.CUSTOM_ONLY,
    CategoryMode.DYNAMIC_AUTO,
)

_T = TypeVar("_T")


def _resolve_temperature(request: StreamAnalyzeRequest) -> float:
    """Return the request's temperature, or the 0.7 default when it is unset."""
    if request.temperature is not None:
        return request.temperature
    return _DEFAULT_TEMPERATURE


def _sse_event(event: str, payload: dict) -> str:
    """Format one Server-Sent Events frame whose data line is ``payload`` as JSON."""
    return f"event: {event}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _group_by(items: Iterable[_T], key: Callable[[_T], Any]) -> Dict[Any, List[_T]]:
    """Group ``items`` by ``key(item)``, preserving first-seen key order."""
    groups: Dict[Any, List[_T]] = {}
    for item in items:
        groups.setdefault(key(item), []).append(item)
    return groups


async def execute_step0(
    request: StreamAnalyzeRequest, exclude_categories: List[str] | None = None
) -> Step0Result | None:
    """Execute Step 0: ask the LLM to propose categories for ``request.query``.

    Called only for modes that need the LLM to suggest categories:
    FIXED_PLUS_DYNAMIC, CUSTOM_ONLY and DYNAMIC_AUTO.

    Args:
        request: The analysis request (query, model, temperature,
            suggested category count).
        exclude_categories: Category names the prompt should tell the LLM
            not to propose (e.g. the fixed category names).

    Returns:
        The parsed ``Step0Result`` with every category marked
        ``is_fixed=False``. It currently never returns ``None``; the optional
        return type is kept for callers.

    Raises:
        Whatever the LLM call, ``extract_json_from_response`` or
        ``Step0Result`` validation raises; the caller is expected to handle it.
    """
    prompt = get_step0_category_analysis_prompt(
        request.query,
        request.suggested_category_count,
        exclude_categories=exclude_categories,
    )
    response = await ollama_provider.generate(
        prompt, model=request.model, temperature=_resolve_temperature(request)
    )
    data = extract_json_from_response(response)
    step0_result = Step0Result(**data)

    # Categories proposed by the LLM are never treated as fixed categories.
    for cat in step0_result.categories:
        cat.is_fixed = False

    return step0_result


def resolve_final_categories(
    request: StreamAnalyzeRequest, step0_result: Step0Result | None
) -> List[CategoryDefinition]:
    """Determine the final ordered category list for ``request.category_mode``.

    Category names in the result are unique. Custom or LLM-proposed names that
    repeat a fixed name (or an earlier extra name) are dropped, because
    duplicate names would produce duplicate node ids in the DAG. A new list is
    always returned, so callers never receive the module-level
    ``FIXED_CATEGORIES`` list itself.
    """
    mode = request.category_mode

    if mode == CategoryMode.FIXED_ONLY:
        return list(FIXED_CATEGORIES)

    if mode == CategoryMode.FIXED_PLUS_CUSTOM:
        categories = list(FIXED_CATEGORIES)
        seen_names = {cat.name for cat in FIXED_CATEGORIES}
        for name in request.custom_categories or []:
            # Skip blank names and names already present.
            if not name or not name.strip() or name in seen_names:
                continue
            seen_names.add(name)
            categories.append(
                CategoryDefinition(name=name, is_fixed=False, order=len(categories))
            )
        return categories

    if mode == CategoryMode.FIXED_PLUS_DYNAMIC:
        # Fixed categories followed by the LLM-suggested categories.
        categories = [
            CategoryDefinition(
                name=cat.name, description=cat.description, is_fixed=True, order=i
            )
            for i, cat in enumerate(FIXED_CATEGORIES)
        ]
        if step0_result is not None:
            seen_names = {cat.name for cat in FIXED_CATEGORIES}
            for cat in step0_result.categories:
                # Drop LLM categories that duplicate a fixed or earlier LLM name.
                if cat.name in seen_names:
                    continue
                seen_names.add(cat.name)
                categories.append(
                    CategoryDefinition(
                        name=cat.name,
                        description=cat.description,
                        is_fixed=False,
                        order=len(categories),
                    )
                )
        return categories

    if mode in (CategoryMode.CUSTOM_ONLY, CategoryMode.DYNAMIC_AUTO):
        if step0_result is not None:
            return step0_result.categories
        return list(FIXED_CATEGORIES)

    return list(FIXED_CATEGORIES)


def assemble_dynamic_attribute_tree(
    query: str, chains: List[DynamicCausalChain], categories: List[CategoryDefinition]
) -> AttributeNode:
    """Assemble an N-level attribute tree from dynamic causal chains.

    Level ``i`` of the tree groups chains by their value for the ``i``-th
    category (categories sorted by ``order``). Chains with no (or an empty)
    value at a level are not placed under that level. Leaf nodes have
    ``children=None``.
    """
    sorted_cats = sorted(categories, key=lambda x: x.order)

    if not chains:
        return AttributeNode(name=query, children=[])

    def build_recursive(
        level: int, parent_path: dict, remaining_chains: List[DynamicCausalChain]
    ) -> List[AttributeNode]:
        """Build the nodes for ``level`` from chains matching ``parent_path``."""
        if level >= len(sorted_cats):
            return []

        current_cat = sorted_cats[level]
        grouped: Dict[Any, List[DynamicCausalChain]] = {}
        for chain in remaining_chains:
            # Only keep chains whose earlier-level values match this branch.
            if not all(chain.chain.get(k) == v for k, v in parent_path.items()):
                continue
            attr_val = chain.chain.get(current_cat.name)
            if attr_val:
                grouped.setdefault(attr_val, []).append(chain)

        nodes = []
        for attr_val, child_chains in grouped.items():
            new_path = {**parent_path, current_cat.name: attr_val}
            children = build_recursive(level + 1, new_path, child_chains)
            nodes.append(
                AttributeNode(
                    name=attr_val,
                    category=current_cat.name,
                    children=children if children else None,
                )
            )
        return nodes

    return AttributeNode(name=query, children=build_recursive(0, {}, chains))


def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
    """Assemble fixed four-level causal chains into a tree.

    The levels are material -> function -> usage -> user group. Each level
    groups chains by value in first-seen order. User-group leaves are created
    once per chain, without deduplication.
    """
    material_nodes = []
    # Level 1: group by material.
    for material, material_chains in _group_by(chains, lambda c: c.material).items():
        function_nodes = []
        # Level 2: group by function.
        for function, function_chains in _group_by(
            material_chains, lambda c: c.function
        ).items():
            usage_nodes = []
            # Level 3: group by usage; level 4 (user group) becomes the leaves.
            for usage, usage_chains in _group_by(
                function_chains, lambda c: c.usage
            ).items():
                usage_nodes.append(
                    AttributeNode(
                        name=usage,
                        category="用途",
                        children=[
                            AttributeNode(name=c.user, category="使用族群")
                            for c in usage_chains
                        ],
                    )
                )
            function_nodes.append(
                AttributeNode(name=function, category="功能", children=usage_nodes)
            )
        material_nodes.append(
            AttributeNode(name=material, category="材料", children=function_nodes)
        )

    return AttributeNode(name=query, children=material_nodes)


def build_dag_from_relationships(
    query: str,
    categories: List[CategoryDefinition],
    attributes_by_category: dict,
    relationships: List[DAGRelationship],
) -> AttributeDAG:
    """Build an attribute DAG from per-category attributes and relationships.

    Nodes:
        One node per unique ``(category, attribute)`` pair, with id
        ``f"{category}_{position}"``. ``position`` counts unique attributes
        within the category. Repeated category names and repeated attributes
        are skipped, so node ids are unique.

    Edges:
        One edge per unique ``(source_id, target_id)`` pair. Relationships
        that reference an attribute with no node are dropped.
    """
    sorted_cats = sorted(categories, key=lambda x: x.order)

    # Create nodes; each (category, attribute) pair appears only once.
    nodes: List[DAGNode] = []
    node_id_map: dict = {}  # (category, name) -> node id
    processed_categories: set = set()
    for cat in sorted_cats:
        cat_name = cat.name
        if cat_name in processed_categories:
            continue
        processed_categories.add(cat_name)

        position = 0
        for attr_name in attributes_by_category.get(cat_name) or []:
            key = (cat_name, attr_name)
            if key in node_id_map:
                continue
            node_id = f"{cat_name}_{position}"
            nodes.append(
                DAGNode(id=node_id, name=attr_name, category=cat_name, order=position)
            )
            node_id_map[key] = node_id
            position += 1

    # Create edges, ignoring dangling references and duplicates.
    edges: List[DAGEdge] = []
    seen_edges: set = set()
    for rel in relationships:
        source_id = node_id_map.get((rel.source_category, rel.source))
        target_id = node_id_map.get((rel.target_category, rel.target))
        if source_id is None or target_id is None:
            continue
        if (source_id, target_id) in seen_edges:
            continue
        seen_edges.add((source_id, target_id))
        edges.append(DAGEdge(source_id=source_id, target_id=target_id))

    return AttributeDAG(
        query=query,
        categories=sorted_cats,
        nodes=nodes,
        edges=edges,
    )


def _parse_relationships(rel_data: dict) -> List[DAGRelationship]:
    """Convert the LLM's ``{"relationships": [...]}`` payload into a fresh list.

    Missing fields default to ``""``. Any error while reading the payload
    (e.g. ``AttributeError`` when an entry is not a dict) propagates, so the
    caller can retry without keeping a partially parsed result.
    """
    return [
        DAGRelationship(
            source_category=rel.get("source_category", ""),
            source=rel.get("source", ""),
            target_category=rel.get("target_category", ""),
            target=rel.get("target", ""),
        )
        for rel in rel_data.get("relationships", [])
    ]


async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
    """Run the multi-step analysis and yield Server-Sent Events frames.

    Events, in order:
        ``step0_start`` / ``step0_complete`` (only for LLM-category modes),
        ``categories_resolved``, ``step1_start``, ``step1_complete``,
        ``relationships_start``, ``relationships_complete``, ``done``.

    If any step raises, a single ``error`` event carrying ``str(e)`` is
    yielded and the stream ends. Relationship generation is best-effort: after
    ``max_retries`` failed attempts it continues with no relationships.
    """
    try:
        temperature = _resolve_temperature(request)

        # ========== Step 0: Category Analysis (if needed) ==========
        step0_result: Step0Result | None = None
        if request.category_mode in _STEP0_MODES:
            yield _sse_event("step0_start", {"message": "分析類別..."})

            # For FIXED_PLUS_DYNAMIC, tell the LLM not to re-propose fixed names.
            exclude_cats = None
            if request.category_mode == CategoryMode.FIXED_PLUS_DYNAMIC:
                exclude_cats = [cat.name for cat in FIXED_CATEGORIES]

            step0_result = await execute_step0(request, exclude_categories=exclude_cats)
            if step0_result:
                yield _sse_event("step0_complete", {"result": step0_result.model_dump()})

        # ========== Resolve Final Categories ==========
        final_categories = resolve_final_categories(request, step0_result)
        yield _sse_event(
            "categories_resolved",
            {"categories": [c.model_dump() for c in final_categories]},
        )

        # ========== Step 1: Generate Attributes (Dynamic) ==========
        yield _sse_event("step1_start", {"message": "生成屬性..."})

        step1_prompt = get_step1_dynamic_attributes_prompt(request.query, final_categories)
        logger.info("Step 1 prompt: %s", step1_prompt[:200])

        step1_response = await ollama_provider.generate(
            step1_prompt, model=request.model, temperature=temperature
        )
        logger.info("Step 1 response: %s", step1_response[:500])

        step1_data = extract_json_from_response(step1_response)
        step1_result = DynamicStep1Result(attributes=step1_data)
        yield _sse_event("step1_complete", {"result": step1_result.model_dump()})

        # ========== Step 2: Generate Relationships (DAG) ==========
        yield _sse_event("relationships_start", {"message": "生成關係..."})

        step2_prompt = get_step2_dag_relationships_prompt(
            query=request.query,
            categories=final_categories,
            attributes_by_category=step1_result.attributes,
        )
        logger.info("Step 2 (relationships) prompt: %s", step2_prompt[:300])

        relationships: List[DAGRelationship] = []
        max_retries = 2
        for attempt in range(max_retries):
            try:
                step2_response = await ollama_provider.generate(
                    step2_prompt, model=request.model, temperature=temperature
                )
                logger.info("Relationships response: %s", step2_response[:500])
                rel_data = extract_json_from_response(step2_response)
                # Assign only after a full successful parse, so a failed
                # attempt never leaves partial entries to be duplicated by a retry.
                relationships = _parse_relationships(rel_data)
                break
            except Exception as e:
                logger.warning("Relationships attempt %d failed: %s", attempt + 1, e)
                if attempt < max_retries - 1:
                    # Nudge the temperature up to get a different answer on retry.
                    temperature = min(temperature + 0.1, 1.0)

        yield _sse_event("relationships_complete", {"count": len(relationships)})

        # ========== Build DAG ==========
        dag = build_dag_from_relationships(
            query=request.query,
            categories=final_categories,
            attributes_by_category=step1_result.attributes,
            relationships=relationships,
        )

        final_result = {
            "query": request.query,
            "step0_result": step0_result.model_dump() if step0_result else None,
            "categories_used": [c.model_dump() for c in final_categories],
            "step1_result": step1_result.model_dump(),
            "relationships": [r.model_dump() for r in relationships],
            "dag": dag.model_dump(),
        }

        yield _sse_event("done", final_result)

    except Exception as e:
        # Top-level boundary of the stream: log with traceback, report to client.
        logger.exception("SSE generation error: %s", e)
        yield _sse_event("error", {"error": str(e)})


@router.post("/analyze")
async def analyze_stream(request: StreamAnalyzeRequest):
    """Multi-step analysis with SSE streaming."""
    return StreamingResponse(
        generate_sse_events(request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable proxy (e.g. nginx) buffering so events reach the client promptly.
            "X-Accel-Buffering": "no",
        },
    )


@router.get("/models", response_model=ModelListResponse)
async def list_models():
    """List available LLM models.

    Raises:
        HTTPException: 500 with the underlying error message when the
            provider fails to list models.
    """
    try:
        models = await ollama_provider.list_models()
    except Exception as e:
        logger.exception("Failed to list models")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return ModelListResponse(models=models)