Initial commit
This commit is contained in:
0
backend/app/routers/__init__.py
Normal file
0
backend/app/routers/__init__.py
Normal file
178
backend/app/routers/attributes.py
Normal file
178
backend/app/routers/attributes.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from ..models.schemas import (
|
||||
ModelListResponse,
|
||||
StreamAnalyzeRequest,
|
||||
Step1Result,
|
||||
CausalChain,
|
||||
AttributeNode,
|
||||
)
|
||||
from ..prompts.attribute_prompt import (
|
||||
get_step1_attributes_prompt,
|
||||
get_step2_causal_chain_prompt,
|
||||
)
|
||||
from ..services.llm_service import ollama_provider, extract_json_from_response
|
||||
|
||||
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# All routes in this file are mounted under the /api prefix and grouped
# under the "attributes" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/api", tags=["attributes"])
def _group_chains(chains: List[CausalChain], key) -> dict:
    """Group chains by key(chain), preserving first-seen key order."""
    grouped: dict = {}
    for chain in chains:
        grouped.setdefault(key(chain), []).append(chain)
    return grouped


def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
    """Assemble causal chains into a four-level attribute tree.

    Levels, root to leaf: query -> material -> function -> usage -> user.
    At each level, chains sharing the same value are grouped under one node;
    grouping preserves the first-seen order of each value in ``chains``.

    Args:
        query: The analysis query; becomes the root node's name.
        chains: Causal chains to organize. May be empty, in which case the
            returned root has no children.

    Returns:
        The root ``AttributeNode`` of the assembled tree.
    """
    root = AttributeNode(name=query, children=[])

    # Level 1: group by material.
    for material, by_material in _group_chains(chains, lambda c: c.material).items():
        material_node = AttributeNode(name=material, category="材料", children=[])

        # Level 2: group by function.
        for function, by_function in _group_chains(
            by_material, lambda c: c.function
        ).items():
            function_node = AttributeNode(name=function, category="功能", children=[])

            # Level 3: group by usage; level 4 (leaves): one node per user.
            for usage, by_usage in _group_chains(
                by_function, lambda c: c.usage
            ).items():
                function_node.children.append(
                    AttributeNode(
                        name=usage,
                        category="用途",
                        children=[
                            AttributeNode(name=c.user, category="使用族群")
                            for c in by_usage
                        ],
                    )
                )

            material_node.children.append(function_node)

        root.children.append(material_node)

    return root
def _sse_event(event: str, data: dict) -> str:
    """Format one Server-Sent Events frame (non-ASCII kept readable)."""
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
    """Yield SSE frames for the two-step attribute analysis.

    Step 1 asks the LLM for the attribute lists. Step 2 generates
    ``request.chain_count`` causal chains one at a time, retrying each
    chain once with a slightly higher temperature on parse/validation
    failure, then emits a final ``done`` event carrying the assembled
    attribute tree.

    Events emitted: step1_start, step1_complete, chain_start,
    chain_complete, chain_error, done, and error (terminal, on any
    unhandled failure).
    """
    try:
        # Default sampling temperature when the client does not supply one.
        temperature = request.temperature if request.temperature is not None else 0.7

        # ========== Step 1: generate the attribute lists ==========
        yield _sse_event("step1_start", {"message": "正在分析屬性列表..."})

        step1_prompt = get_step1_attributes_prompt(request.query)
        logger.info("Step 1 prompt: %s", step1_prompt[:200])

        step1_response = await ollama_provider.generate(
            step1_prompt, model=request.model, temperature=temperature
        )
        logger.info("Step 1 response: %s", step1_response[:500])

        step1_data = extract_json_from_response(step1_response)
        step1_result = Step1Result(**step1_data)

        yield _sse_event("step1_complete", {"result": step1_result.model_dump()})

        # ========== Step 2: generate causal chains one by one ==========
        causal_chains: List[CausalChain] = []

        for i in range(request.chain_count):
            chain_index = i + 1

            yield _sse_event(
                "chain_start",
                {
                    "index": chain_index,
                    "total": request.chain_count,
                    "message": f"正在生成第 {chain_index}/{request.chain_count} 條因果鏈...",
                },
            )

            step2_prompt = get_step2_causal_chain_prompt(
                query=request.query,
                materials=step1_result.materials,
                functions=step1_result.functions,
                usages=step1_result.usages,
                users=step1_result.users,
                existing_chains=[c.model_dump() for c in causal_chains],
                chain_index=chain_index,
            )

            # Gradually raise the temperature to increase chain diversity.
            chain_temperature = min(temperature + 0.05 * i, 1.0)

            max_retries = 2
            chain = None
            for attempt in range(max_retries):
                try:
                    step2_response = await ollama_provider.generate(
                        step2_prompt, model=request.model, temperature=chain_temperature
                    )
                    logger.info(
                        "Chain %s response: %s", chain_index, step2_response[:300]
                    )

                    chain_data = extract_json_from_response(step2_response)
                    chain = CausalChain(**chain_data)
                    break
                except Exception as e:
                    # Parse/validation failures get one retry with a
                    # slightly higher temperature before giving up.
                    logger.warning(
                        "Chain %s attempt %s failed: %s", chain_index, attempt + 1, e
                    )
                    if attempt < max_retries - 1:
                        chain_temperature = min(chain_temperature + 0.1, 1.0)

            if chain:
                causal_chains.append(chain)
                yield _sse_event(
                    "chain_complete",
                    {"index": chain_index, "chain": chain.model_dump()},
                )
            else:
                yield _sse_event(
                    "chain_error",
                    {"index": chain_index, "error": f"第 {chain_index} 條因果鏈生成失敗"},
                )

        # ========== Assemble the final structure ==========
        final_tree = assemble_attribute_tree(request.query, causal_chains)

        final_result = {
            "query": request.query,
            "step1_result": step1_result.model_dump(),
            "causal_chains": [c.model_dump() for c in causal_chains],
            "attributes": final_tree.model_dump(),
        }
        yield _sse_event("done", final_result)

    except Exception as e:
        # Terminal failure: record the full traceback server-side and
        # surface the message to the client as a final SSE event.
        logger.exception("SSE generation error: %s", e)
        yield _sse_event("error", {"error": str(e)})
@router.post("/analyze")
async def analyze_stream(request: StreamAnalyzeRequest):
    """Run the multi-step analysis, streamed to the client over SSE."""
    # Disable caching and (nginx) proxy buffering so each event reaches
    # the client as soon as it is generated.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        generate_sse_events(request),
        media_type="text/event-stream",
        headers=sse_headers,
    )
@router.get("/models", response_model=ModelListResponse)
async def list_models():
    """List available LLM models.

    Returns:
        ModelListResponse wrapping the model names reported by the
        Ollama provider.

    Raises:
        HTTPException: 500 when the provider cannot be reached or errors.
    """
    # Keep the try body minimal: only the provider call can legitimately
    # fail here.
    try:
        models = await ollama_provider.list_models()
    except Exception as e:
        # Chain the original error so server logs show the root cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return ModelListResponse(models=models)
Reference in New Issue
Block a user