Initial commit

2025-12-02 02:06:51 +08:00
commit eb6c0c51fa
37 changed files with 7454 additions and 0 deletions

0
backend/app/__init__.py Normal file
View File

15
backend/app/config.py Normal file
View File

@@ -0,0 +1,15 @@
from pydantic_settings import BaseSettings
from typing import Optional


class Settings(BaseSettings):
    ollama_base_url: str = "http://192.168.30.36:11434"
    default_model: str = "qwen3:8b"
    openai_api_key: Optional[str] = None
    openai_base_url: Optional[str] = None

    class Config:
        env_file = ".env"


settings = Settings()
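Since `Settings` is backed by pydantic-settings, every field can be overridden through environment variables or a local `.env` file. A minimal sketch of that override path (the values and the `app.config` import path are assumptions, not part of this commit):

```python
import os

# Hypothetical override values; pydantic-settings matches environment variables
# to field names case-insensitively, and also reads them from a local .env file.
os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"
os.environ["DEFAULT_MODEL"] = "qwen3:8b"

from app.config import Settings  # import path assumed when run from backend/

print(Settings().ollama_base_url)  # -> http://localhost:11434
```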

41
backend/app/main.py Normal file
View File

@@ -0,0 +1,41 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .routers import attributes
from .services.llm_service import ollama_provider


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    await ollama_provider.close()


app = FastAPI(
    title="Attribute Agent API",
    description="API for analyzing objects and extracting their attributes",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(attributes.router)


@app.get("/")
async def root():
    return {"message": "Attribute Agent API is running"}


@app.get("/health")
async def health_check():
    return {"status": "healthy"}

61
backend/app/models/schemas.py Normal file
View File

@@ -0,0 +1,61 @@
from pydantic import BaseModel
from typing import Optional, List


class AttributeNode(BaseModel):
    name: str
    category: Optional[str] = None  # e.g. "材料", "功能", "用途", "使用族群" (material, function, usage, user group)
    children: Optional[List["AttributeNode"]] = None


AttributeNode.model_rebuild()


class AnalyzeRequest(BaseModel):
    query: str
    model: Optional[str] = None
    temperature: Optional[float] = 0.7
    categories: Optional[List[str]] = None  # if None, use the default categories


class AnalyzeResponse(BaseModel):
    query: str
    attributes: AttributeNode


class ModelListResponse(BaseModel):
    models: List[str]


# ===== Multi-step streaming schemas =====

class Step1Result(BaseModel):
    """Step 1 result: attribute lists per category."""
    materials: List[str]
    functions: List[str]
    usages: List[str]
    users: List[str]


class CausalChain(BaseModel):
    """A single causal chain."""
    material: str
    function: str
    usage: str
    user: str


class StreamAnalyzeRequest(BaseModel):
    """Multi-step analysis request."""
    query: str
    model: Optional[str] = None
    temperature: Optional[float] = 0.7
    chain_count: int = 5  # user-configurable number of causal chains to generate


class StreamAnalyzeResponse(BaseModel):
    """Final complete result."""
    query: str
    step1_result: Step1Result
    causal_chains: List[CausalChain]
    attributes: AttributeNode
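A quick sketch of how the recursive `AttributeNode` model nests (the example values and the `app.models.schemas` import path are illustrative assumptions):

```python
from app.models.schemas import AttributeNode  # import path assumed when run from backend/

# Hypothetical two-level tree: object -> material -> function
node = AttributeNode(
    name="水壺",
    children=[
        AttributeNode(
            name="不鏽鋼",
            category="材料",
            children=[AttributeNode(name="保溫", category="功能")],
        )
    ],
)
print(node.model_dump(exclude_none=True))
```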

117
backend/app/prompts/attribute_prompt.py Normal file
View File

@@ -0,0 +1,117 @@
from typing import List, Optional

# Default attribute categories: material, function, usage, user group, characteristic
DEFAULT_CATEGORIES = ["材料", "功能", "用途", "使用族群", "特性"]

CATEGORY_DESCRIPTIONS = {
    "材料": "物件由什麼材料組成",
    "功能": "物件能做什麼",
    "用途": "物件在什麼場景使用",
    "使用族群": "誰會使用這個物件",
    "特性": "物件有什麼特徵",
}


def get_attribute_prompt(query: str, categories: Optional[List[str]] = None) -> str:
    """Generate a prompt with causal-chain structure."""
    prompt = f"""分析「{query}」的屬性,以因果鏈方式呈現:材料→功能→用途→使用族群。
請列出 3-5 種材料,每種材料延伸出完整因果鏈。
JSON 格式:
{{"name": "{query}", "children": [{{"name": "材料名", "category": "材料", "children": [{{"name": "功能名", "category": "功能", "children": [{{"name": "用途名", "category": "用途", "children": [{{"name": "族群名", "category": "使用族群"}}]}}]}}]}}]}}
只回傳 JSON。"""
    return prompt


def get_step1_attributes_prompt(query: str) -> str:
    """Step 1: generate the attribute lists for each category (flat structure)."""
    return f"""/no_think
分析「{query}」,列出以下四個類別的屬性。每個類別列出 3-5 個常見屬性。
只回傳 JSON,格式如下:
{{"materials": ["材料1", "材料2", "材料3"], "functions": ["功能1", "功能2", "功能3"], "usages": ["用途1", "用途2", "用途3"], "users": ["族群1", "族群2", "族群3"]}}
物件:{query}"""


def get_step2_causal_chain_prompt(
    query: str,
    materials: List[str],
    functions: List[str],
    usages: List[str],
    users: List[str],
    existing_chains: List[dict],
    chain_index: int,
) -> str:
    """Step 2: generate a single causal chain."""
    existing_chains_text = ""
    if existing_chains:
        chains_list = [
            f"- {c['material']} → {c['function']} → {c['usage']} → {c['user']}"
            for c in existing_chains
        ]
        existing_chains_text = f"""
【已生成的因果鏈,請勿重複】
{chr(10).join(chains_list)}
"""
    return f"""/no_think
為「{query}」生成第 {chain_index} 條因果鏈。
【可選材料】{', '.join(materials)}
【可選功能】{', '.join(functions)}
【可選用途】{', '.join(usages)}
【可選族群】{', '.join(users)}
{existing_chains_text}
【規則】
1. 從每個類別選擇一個屬性,組成合理的因果鏈
2. 因果關係必須合邏輯(材料決定功能,功能決定用途,用途決定族群)
3. 不要與已生成的因果鏈重複
只回傳 JSON:
{{"material": "選擇的材料", "function": "選擇的功能", "usage": "選擇的用途", "user": "選擇的族群"}}"""


def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None) -> str:
    """Generate a prompt with flat/parallel categories (original design)."""
    cats = categories if categories else DEFAULT_CATEGORIES

    # Build the category list shown to the model
    category_lines = []
    for cat in cats:
        desc = CATEGORY_DESCRIPTIONS.get(cat, f"{cat}的相關屬性")
        category_lines.append(f"- {cat}:{desc}")
    categories_text = "\n".join(category_lines)

    prompt = f"""/no_think
你是一個物件屬性分析專家。請將用戶輸入的物件拆解成以下屬性類別。
【必須包含的類別】
{categories_text}
【重要】回傳格式必須是合法的 JSON,每個節點都必須有 "name" 欄位:
```json
{{
  "name": "物件名稱",
  "children": [
    {{
      "name": "類別名稱",
      "children": [
        {{"name": "屬性1"}},
        {{"name": "屬性2"}}
      ]
    }}
  ]
}}
```
只回傳 JSON,不要有任何其他文字。
用戶輸入:{query}"""
    return prompt

178
backend/app/routers/attributes.py Normal file
View File

@@ -0,0 +1,178 @@
import json
import logging
from typing import AsyncGenerator, List

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from ..models.schemas import (
    ModelListResponse,
    StreamAnalyzeRequest,
    Step1Result,
    CausalChain,
    AttributeNode,
)
from ..prompts.attribute_prompt import (
    get_step1_attributes_prompt,
    get_step2_causal_chain_prompt,
)
from ..services.llm_service import ollama_provider, extract_json_from_response

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api", tags=["attributes"])


def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
    """Assemble the causal chains into a tree structure."""
    # Group by material for the first level
    material_map = {}
    for chain in chains:
        if chain.material not in material_map:
            material_map[chain.material] = []
        material_map[chain.material].append(chain)

    # Build the tree
    root = AttributeNode(name=query, children=[])
    for material, material_chains in material_map.items():
        material_node = AttributeNode(name=material, category="材料", children=[])

        # Group by function for the second level
        function_map = {}
        for chain in material_chains:
            if chain.function not in function_map:
                function_map[chain.function] = []
            function_map[chain.function].append(chain)

        for function, function_chains in function_map.items():
            function_node = AttributeNode(name=function, category="功能", children=[])

            # Group by usage for the third level
            usage_map = {}
            for chain in function_chains:
                if chain.usage not in usage_map:
                    usage_map[chain.usage] = []
                usage_map[chain.usage].append(chain)

            for usage, usage_chains in usage_map.items():
                usage_node = AttributeNode(
                    name=usage,
                    category="用途",
                    children=[
                        AttributeNode(name=c.user, category="使用族群")
                        for c in usage_chains
                    ],
                )
                function_node.children.append(usage_node)
            material_node.children.append(function_node)
        root.children.append(material_node)
    return root


async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
    """Generate the SSE event stream."""
    try:
        temperature = request.temperature if request.temperature is not None else 0.7

        # ========== Step 1: generate the attribute lists ==========
        yield f"event: step1_start\ndata: {json.dumps({'message': '正在分析屬性列表...'}, ensure_ascii=False)}\n\n"

        step1_prompt = get_step1_attributes_prompt(request.query)
        logger.info(f"Step 1 prompt: {step1_prompt[:200]}")
        step1_response = await ollama_provider.generate(
            step1_prompt, model=request.model, temperature=temperature
        )
        logger.info(f"Step 1 response: {step1_response[:500]}")
        step1_data = extract_json_from_response(step1_response)
        step1_result = Step1Result(**step1_data)

        yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n"

        # ========== Step 2: generate the causal chains one by one ==========
        causal_chains: List[CausalChain] = []
        for i in range(request.chain_count):
            chain_index = i + 1
            yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n"

            step2_prompt = get_step2_causal_chain_prompt(
                query=request.query,
                materials=step1_result.materials,
                functions=step1_result.functions,
                usages=step1_result.usages,
                users=step1_result.users,
                existing_chains=[c.model_dump() for c in causal_chains],
                chain_index=chain_index,
            )

            # Gradually raise the temperature to encourage diversity
            chain_temperature = min(temperature + 0.05 * i, 1.0)

            max_retries = 2
            chain = None
            for attempt in range(max_retries):
                try:
                    step2_response = await ollama_provider.generate(
                        step2_prompt, model=request.model, temperature=chain_temperature
                    )
                    logger.info(f"Chain {chain_index} response: {step2_response[:300]}")
                    chain_data = extract_json_from_response(step2_response)
                    chain = CausalChain(**chain_data)
                    break
                except Exception as e:
                    logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}")
                    if attempt < max_retries - 1:
                        chain_temperature = min(chain_temperature + 0.1, 1.0)

            if chain:
                causal_chains.append(chain)
                yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n"
            else:
                yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'第 {chain_index} 條因果鏈生成失敗'}, ensure_ascii=False)}\n\n"

        # ========== Assemble the final structure ==========
        final_tree = assemble_attribute_tree(request.query, causal_chains)
        final_result = {
            "query": request.query,
            "step1_result": step1_result.model_dump(),
            "causal_chains": [c.model_dump() for c in causal_chains],
            "attributes": final_tree.model_dump(),
        }
        yield f"event: done\ndata: {json.dumps(final_result, ensure_ascii=False)}\n\n"

    except Exception as e:
        logger.error(f"SSE generation error: {e}")
        yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"


@router.post("/analyze")
async def analyze_stream(request: StreamAnalyzeRequest):
    """Multi-step analysis with SSE streaming."""
    return StreamingResponse(
        generate_sse_events(request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.get("/models", response_model=ModelListResponse)
async def list_models():
    """List available LLM models."""
    try:
        models = await ollama_provider.list_models()
        return ModelListResponse(models=models)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
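A minimal client sketch (not part of this commit) for consuming the SSE stream emitted by POST /api/analyze; the base URL and the payload values are assumptions:

```python
import asyncio
import httpx


async def main():
    # Hypothetical request body; query and chain_count follow StreamAnalyzeRequest
    payload = {"query": "水壺", "chain_count": 3}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/api/analyze", json=payload
        ) as response:
            async for line in response.aiter_lines():
                # Each SSE message arrives as an "event: ..." line followed by a "data: {...}" line
                if line.startswith(("event:", "data:")):
                    print(line)


asyncio.run(main())
```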

160
backend/app/services/llm_service.py Normal file
View File

@@ -0,0 +1,160 @@
import json
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass


class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Increase timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
            "options": {
                "temperature": temperature,
            },
        }
        # Retry logic for larger models that may return empty responses
        max_retries = 3
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()
            result = response.json()
            response_text = result.get("response", "")
            # Check if response is valid (not empty or just "{}")
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text
            # If empty, retry with slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)
        # Return whatever we got on the last attempt
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"
        response = await self.client.get(url)
        response.raise_for_status()
        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()


class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Increase timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        # Accept temperature to match the LLMProvider interface and forward it to the API
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }
        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"
        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        response = await self.client.get(url, headers=headers)
        response.raise_for_status()
        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()


def extract_json_from_response(response: str) -> dict:
    """Extract JSON from an LLM response, handling markdown code blocks and extra whitespace."""
    # Remove markdown code blocks if present
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        json_str = json_match.group(1)
    else:
        json_str = response
    # Clean up: remove extra whitespace, normalize spaces
    json_str = json_str.strip()
    # Remove trailing whitespace before closing braces/brackets
    json_str = re.sub(r'\s+([}\]])', r'\1', json_str)
    # Collapse tabs/newlines into single spaces
    json_str = re.sub(r'[\t\n\r]+', ' ', json_str)
    return json.loads(json_str)


def parse_attribute_response(response: str) -> AttributeNode:
    """Parse an LLM response into an AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get an LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


ollama_provider = OllamaProvider()
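A quick illustration of what `extract_json_from_response` does with a fenced model reply (the sample reply and the `app.services.llm_service` import path are assumptions for the sketch):

```python
from app.services.llm_service import extract_json_from_response  # path assumed when run from backend/

# A typical model reply wrapped in a markdown code block
raw = '```json\n{ "materials": ["不鏽鋼"] }\n```'
print(extract_json_from_response(raw))  # -> {'materials': ['不鏽鋼']}
```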

6
backend/requirements.txt Normal file
View File

@@ -0,0 +1,6 @@
fastapi>=0.109.0
uvicorn[standard]>=0.27.0
httpx>=0.26.0
pydantic>=2.5.0
pydantic-settings>=2.1.0
python-dotenv>=1.0.0