import json
import logging
from collections import defaultdict
from typing import AsyncGenerator, List

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from ..models.schemas import (
    ModelListResponse,
    StreamAnalyzeRequest,
    Step1Result,
    CausalChain,
    AttributeNode,
)
from ..prompts.attribute_prompt import (
    get_step1_attributes_prompt,
    get_step2_causal_chain_prompt,
)
from ..services.llm_service import ollama_provider, extract_json_from_response

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api", tags=["attributes"])


def _sse_event(event: str, data: dict) -> str:
    """Format one Server-Sent Events frame (event name + JSON data payload)."""
    # ensure_ascii=False keeps CJK text readable on the wire.
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
    """Assemble flat causal chains into a tree: query -> material -> function -> usage -> user.

    Grouping at every level preserves first-seen order (dict insertion order).

    Args:
        query: The analyzed subject; becomes the root node's name.
        chains: Causal chains, each carrying material/function/usage/user fields.

    Returns:
        Root AttributeNode of the assembled tree.
    """
    # Level 1: group chains by material.
    material_map = defaultdict(list)
    for chain in chains:
        material_map[chain.material].append(chain)

    root = AttributeNode(name=query, children=[])
    for material, material_chains in material_map.items():
        material_node = AttributeNode(name=material, category="材料", children=[])

        # Level 2: group this material's chains by function.
        function_map = defaultdict(list)
        for chain in material_chains:
            function_map[chain.function].append(chain)

        for function, function_chains in function_map.items():
            function_node = AttributeNode(name=function, category="功能", children=[])

            # Level 3: group this function's chains by usage.
            usage_map = defaultdict(list)
            for chain in function_chains:
                usage_map[chain.usage].append(chain)

            for usage, usage_chains in usage_map.items():
                # Level 4: leaf nodes are the user groups of each usage.
                usage_node = AttributeNode(
                    name=usage,
                    category="用途",
                    children=[
                        AttributeNode(name=c.user, category="使用族群")
                        for c in usage_chains
                    ],
                )
                function_node.children.append(usage_node)

            material_node.children.append(function_node)

        root.children.append(material_node)

    return root


async def _generate_chain(
    model: str, prompt: str, temperature: float, chain_index: int
) -> CausalChain | None:
    """Call the LLM for one causal chain, retrying on parse/validation errors.

    Returns the parsed CausalChain, or None if every attempt failed.
    """
    max_retries = 2
    for attempt in range(max_retries):
        try:
            response = await ollama_provider.generate(
                prompt, model=model, temperature=temperature
            )
            logger.info("Chain %d response: %s", chain_index, response[:300])
            chain_data = extract_json_from_response(response)
            return CausalChain(**chain_data)
        except Exception as e:
            logger.warning(
                "Chain %d attempt %d failed: %s", chain_index, attempt + 1, e
            )
            if attempt < max_retries - 1:
                # Nudge temperature up before retrying to escape a bad sample.
                temperature = min(temperature + 0.1, 1.0)
    return None


async def generate_sse_events(
    request: StreamAnalyzeRequest,
) -> AsyncGenerator[str, None]:
    """Yield the multi-step analysis as an SSE stream.

    Event sequence: step1_start, step1_complete, then per chain
    chain_start + (chain_complete | chain_error), and finally done.
    Any unexpected failure yields a single error event instead of raising.
    """
    try:
        temperature = request.temperature if request.temperature is not None else 0.7

        # ========== Step 1: generate the attribute lists ==========
        yield _sse_event("step1_start", {"message": "正在分析屬性列表..."})

        step1_prompt = get_step1_attributes_prompt(request.query)
        logger.info("Step 1 prompt: %s", step1_prompt[:200])
        step1_response = await ollama_provider.generate(
            step1_prompt, model=request.model, temperature=temperature
        )
        logger.info("Step 1 response: %s", step1_response[:500])
        step1_data = extract_json_from_response(step1_response)
        step1_result = Step1Result(**step1_data)
        yield _sse_event("step1_complete", {"result": step1_result.model_dump()})

        # ========== Step 2: generate causal chains one at a time ==========
        causal_chains: List[CausalChain] = []
        for i in range(request.chain_count):
            chain_index = i + 1
            yield _sse_event(
                "chain_start",
                {
                    "index": chain_index,
                    "total": request.chain_count,
                    "message": f"正在生成第 {chain_index}/{request.chain_count} 條因果鏈...",
                },
            )

            step2_prompt = get_step2_causal_chain_prompt(
                query=request.query,
                materials=step1_result.materials,
                functions=step1_result.functions,
                usages=step1_result.usages,
                users=step1_result.users,
                existing_chains=[c.model_dump() for c in causal_chains],
                chain_index=chain_index,
            )
            # Gradually raise temperature to increase diversity across chains.
            chain_temperature = min(temperature + 0.05 * i, 1.0)

            chain = await _generate_chain(
                request.model, step2_prompt, chain_temperature, chain_index
            )
            if chain:
                causal_chains.append(chain)
                yield _sse_event(
                    "chain_complete",
                    {"index": chain_index, "chain": chain.model_dump()},
                )
            else:
                yield _sse_event(
                    "chain_error",
                    {
                        "index": chain_index,
                        "error": f"第 {chain_index} 條因果鏈生成失敗",
                    },
                )

        # ========== Assemble the final structure ==========
        final_tree = assemble_attribute_tree(request.query, causal_chains)
        final_result = {
            "query": request.query,
            "step1_result": step1_result.model_dump(),
            "causal_chains": [c.model_dump() for c in causal_chains],
            "attributes": final_tree.model_dump(),
        }
        yield _sse_event("done", final_result)

    except Exception as e:
        # logger.exception records the traceback; the original logger.error lost it.
        logger.exception("SSE generation error: %s", e)
        yield _sse_event("error", {"error": str(e)})


@router.post("/analyze")
async def analyze_stream(request: StreamAnalyzeRequest):
    """Run the multi-step analysis, streaming progress via SSE."""
    return StreamingResponse(
        generate_sse_events(request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disable nginx response buffering so events flush immediately.
            "X-Accel-Buffering": "no",
        },
    )


@router.get("/models", response_model=ModelListResponse)
async def list_models():
    """List available LLM models.

    Raises:
        HTTPException: 500 when the model provider cannot be queried.
    """
    try:
        models = await ollama_provider.list_models()
        return ModelListResponse(models=models)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e