Initial commit
backend/app/__init__.py (new file, 0 lines)
backend/app/config.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from pydantic_settings import BaseSettings
from typing import Optional


class Settings(BaseSettings):
    ollama_base_url: str = "http://192.168.30.36:11434"
    default_model: str = "qwen3:8b"
    openai_api_key: Optional[str] = None
    openai_base_url: Optional[str] = None

    class Config:
        env_file = ".env"


settings = Settings()
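Since Settings subclasses pydantic's BaseSettings, every field can be overridden through the .env file or a plain environment variable (field names are matched case-insensitively). A minimal sketch of that resolution; the localhost URL and the import path are assumptions:

import os

# Hypothetical override: environment variables take precedence over
# the defaults declared on the Settings class.
os.environ["OLLAMA_BASE_URL"] = "http://localhost:11434"

from backend.app.config import settings  # import path assumed from the tree above

print(settings.ollama_base_url)  # -> "http://localhost:11434"
print(settings.default_model)    # -> "qwen3:8b" (class default, unchanged)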
backend/app/main.py (new file, 41 lines)
@@ -0,0 +1,41 @@
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from .routers import attributes
from .services.llm_service import ollama_provider


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup work would go before the yield; shutdown work goes after.
    yield
    await ollama_provider.close()


app = FastAPI(
    title="Attribute Agent API",
    description="API for analyzing objects and extracting their attributes",
    version="1.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(attributes.router)


@app.get("/")
async def root():
    return {"message": "Attribute Agent API is running"}


@app.get("/health")
async def health_check():
    return {"status": "healthy"}
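For local development the app can be served with uvicorn; a minimal sketch (module path and port are assumptions):

# run_dev.py: hypothetical dev entry point
import uvicorn

if __name__ == "__main__":
    # Serves the FastAPI app above; on shutdown, the lifespan hook
    # closes the shared Ollama HTTP client.
    uvicorn.run("backend.app.main:app", host="0.0.0.0", port=8000, reload=True)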
backend/app/models/__init__.py (new file, 0 lines)
backend/app/models/schemas.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from pydantic import BaseModel
from typing import Optional, List


class AttributeNode(BaseModel):
    name: str
    category: Optional[str] = None  # e.g. "材料" (material), "功能" (function), "用途" (usage), "使用族群" (user group)
    children: Optional[List["AttributeNode"]] = None


AttributeNode.model_rebuild()


class AnalyzeRequest(BaseModel):
    query: str
    model: Optional[str] = None
    temperature: Optional[float] = 0.7
    categories: Optional[List[str]] = None  # if None, the default categories are used


class AnalyzeResponse(BaseModel):
    query: str
    attributes: AttributeNode


class ModelListResponse(BaseModel):
    models: List[str]


# ===== Multi-step streaming schemas =====

class Step1Result(BaseModel):
    """Step 1 result: attribute lists per category."""
    materials: List[str]
    functions: List[str]
    usages: List[str]
    users: List[str]


class CausalChain(BaseModel):
    """A single causal chain."""
    material: str
    function: str
    usage: str
    user: str


class StreamAnalyzeRequest(BaseModel):
    """Multi-step analysis request."""
    query: str
    model: Optional[str] = None
    temperature: Optional[float] = 0.7
    chain_count: int = 5  # user-configurable number of causal chains to generate


class StreamAnalyzeResponse(BaseModel):
    """Final complete result."""
    query: str
    step1_result: Step1Result
    causal_chains: List[CausalChain]
    attributes: AttributeNode
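AttributeNode refers to itself through children, which is why model_rebuild() is called right after the class definition. A quick sketch of building and serializing a small tree (the example values are hypothetical):

from backend.app.models.schemas import AttributeNode  # import path assumed

tree = AttributeNode(
    name="雨傘",
    children=[
        AttributeNode(
            name="尼龍布",
            category="材料",
            children=[AttributeNode(name="防水", category="功能")],
        )
    ],
)

# model_dump() recurses through children, yielding plain nested dicts
# ready for json.dumps or an API response.
print(tree.model_dump())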
backend/app/prompts/__init__.py (new file, 0 lines)
backend/app/prompts/attribute_prompt.py (new file, 117 lines)
@@ -0,0 +1,117 @@
from typing import List, Optional

DEFAULT_CATEGORIES = ["材料", "功能", "用途", "使用族群", "特性"]

CATEGORY_DESCRIPTIONS = {
    "材料": "物件由什麼材料組成",
    "功能": "物件能做什麼",
    "用途": "物件在什麼場景使用",
    "使用族群": "誰會使用這個物件",
    "特性": "物件有什麼特徵",
}


def get_attribute_prompt(query: str, categories: Optional[List[str]] = None) -> str:
    """Generate a prompt with a causal-chain structure.

    Note: `categories` is accepted for interface symmetry but unused here;
    the chain is fixed to 材料 → 功能 → 用途 → 使用族群.
    """
    prompt = f"""分析「{query}」的屬性,以因果鏈方式呈現:材料→功能→用途→使用族群。

請列出 3-5 種材料,每種材料延伸出完整因果鏈。

JSON 格式:
{{"name": "{query}", "children": [{{"name": "材料名", "category": "材料", "children": [{{"name": "功能名", "category": "功能", "children": [{{"name": "用途名", "category": "用途", "children": [{{"name": "族群名", "category": "使用族群"}}]}}]}}]}}]}}

只回傳 JSON。"""

    return prompt


def get_step1_attributes_prompt(query: str) -> str:
    """Step 1: generate attribute lists for each category (parallel structure)."""
    return f"""/no_think
分析「{query}」,列出以下四個類別的屬性。每個類別列出 3-5 個常見屬性。

只回傳 JSON,格式如下:
{{"materials": ["材料1", "材料2", "材料3"], "functions": ["功能1", "功能2", "功能3"], "usages": ["用途1", "用途2", "用途3"], "users": ["族群1", "族群2", "族群3"]}}

物件:{query}"""


def get_step2_causal_chain_prompt(
    query: str,
    materials: List[str],
    functions: List[str],
    usages: List[str],
    users: List[str],
    existing_chains: List[dict],
    chain_index: int,
) -> str:
    """Step 2: generate a single causal chain."""
    existing_chains_text = ""
    if existing_chains:
        chains_list = [
            f"- {c['material']} → {c['function']} → {c['usage']} → {c['user']}"
            for c in existing_chains
        ]
        existing_chains_text = f"""
【已生成的因果鏈,請勿重複】
{chr(10).join(chains_list)}
"""

    return f"""/no_think
為「{query}」生成第 {chain_index} 條因果鏈。

【可選材料】{', '.join(materials)}
【可選功能】{', '.join(functions)}
【可選用途】{', '.join(usages)}
【可選族群】{', '.join(users)}
{existing_chains_text}
【規則】
1. 從每個類別選擇一個屬性,組成合理的因果鏈
2. 因果關係必須合邏輯(材料決定功能,功能決定用途,用途決定族群)
3. 不要與已生成的因果鏈重複

只回傳 JSON:
{{"material": "選擇的材料", "function": "選擇的功能", "usage": "選擇的用途", "user": "選擇的族群"}}"""


def get_flat_attribute_prompt(query: str, categories: Optional[List[str]] = None) -> str:
    """Generate a prompt with flat/parallel categories (original design)."""
    cats = categories if categories else DEFAULT_CATEGORIES

    # Build the category list shown to the model
    category_lines = []
    for cat in cats:
        desc = CATEGORY_DESCRIPTIONS.get(cat, f"{cat}的相關屬性")
        category_lines.append(f"- {cat}:{desc}")

    categories_text = "\n".join(category_lines)

    prompt = f"""/no_think
你是一個物件屬性分析專家。請將用戶輸入的物件拆解成以下屬性類別。

【必須包含的類別】
{categories_text}

【重要】回傳格式必須是合法的 JSON,每個節點都必須有 "name" 欄位:

```json
{{
  "name": "物件名稱",
  "children": [
    {{
      "name": "類別名稱",
      "children": [
        {{"name": "屬性1"}},
        {{"name": "屬性2"}}
      ]
    }}
  ]
}}
```

只回傳 JSON,不要有任何其他文字。

用戶輸入:{query}"""

    return prompt
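To see how the no-duplicates section is assembled, here is a hypothetical call to get_step2_causal_chain_prompt with one chain already generated (all attribute values are made up):

from backend.app.prompts.attribute_prompt import get_step2_causal_chain_prompt  # path assumed

prompt = get_step2_causal_chain_prompt(
    query="雨傘",
    materials=["尼龍布", "鋁合金"],
    functions=["防水", "支撐"],
    usages=["雨天出行", "遮陽"],
    users=["通勤族", "學生"],
    existing_chains=[
        {"material": "尼龍布", "function": "防水", "usage": "雨天出行", "user": "通勤族"}
    ],
    chain_index=2,
)

# The returned prompt now carries a block listing the existing chain,
# steering the model away from repeating it.
print(prompt)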
backend/app/routers/__init__.py (new file, 0 lines)
backend/app/routers/attributes.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import json
import logging
from typing import AsyncGenerator, List

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from ..models.schemas import (
    ModelListResponse,
    StreamAnalyzeRequest,
    Step1Result,
    CausalChain,
    AttributeNode,
)
from ..prompts.attribute_prompt import (
    get_step1_attributes_prompt,
    get_step2_causal_chain_prompt,
)
from ..services.llm_service import ollama_provider, extract_json_from_response

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["attributes"])


def assemble_attribute_tree(query: str, chains: List[CausalChain]) -> AttributeNode:
    """Assemble causal chains into a tree structure."""
    # Group chains by material for the first level
    material_map = {}

    for chain in chains:
        if chain.material not in material_map:
            material_map[chain.material] = []
        material_map[chain.material].append(chain)

    # Build the tree
    root = AttributeNode(name=query, children=[])

    for material, material_chains in material_map.items():
        material_node = AttributeNode(name=material, category="材料", children=[])

        # Group by function for the second level
        function_map = {}
        for chain in material_chains:
            if chain.function not in function_map:
                function_map[chain.function] = []
            function_map[chain.function].append(chain)

        for function, function_chains in function_map.items():
            function_node = AttributeNode(name=function, category="功能", children=[])

            # Group by usage for the third level
            usage_map = {}
            for chain in function_chains:
                if chain.usage not in usage_map:
                    usage_map[chain.usage] = []
                usage_map[chain.usage].append(chain)

            for usage, usage_chains in usage_map.items():
                usage_node = AttributeNode(
                    name=usage,
                    category="用途",
                    children=[
                        AttributeNode(name=c.user, category="使用族群")
                        for c in usage_chains
                    ],
                )
                function_node.children.append(usage_node)

            material_node.children.append(function_node)

        root.children.append(material_node)

    return root


async def generate_sse_events(request: StreamAnalyzeRequest) -> AsyncGenerator[str, None]:
    """Generate the SSE event stream."""
    try:
        temperature = request.temperature if request.temperature is not None else 0.7

        # ========== Step 1: generate the attribute lists ==========
        yield f"event: step1_start\ndata: {json.dumps({'message': '正在分析屬性列表...'}, ensure_ascii=False)}\n\n"

        step1_prompt = get_step1_attributes_prompt(request.query)
        logger.info(f"Step 1 prompt: {step1_prompt[:200]}")

        step1_response = await ollama_provider.generate(
            step1_prompt, model=request.model, temperature=temperature
        )
        logger.info(f"Step 1 response: {step1_response[:500]}")

        step1_data = extract_json_from_response(step1_response)
        step1_result = Step1Result(**step1_data)

        yield f"event: step1_complete\ndata: {json.dumps({'result': step1_result.model_dump()}, ensure_ascii=False)}\n\n"

        # ========== Step 2: generate causal chains one at a time ==========
        causal_chains: List[CausalChain] = []

        for i in range(request.chain_count):
            chain_index = i + 1

            yield f"event: chain_start\ndata: {json.dumps({'index': chain_index, 'total': request.chain_count, 'message': f'正在生成第 {chain_index}/{request.chain_count} 條因果鏈...'}, ensure_ascii=False)}\n\n"

            step2_prompt = get_step2_causal_chain_prompt(
                query=request.query,
                materials=step1_result.materials,
                functions=step1_result.functions,
                usages=step1_result.usages,
                users=step1_result.users,
                existing_chains=[c.model_dump() for c in causal_chains],
                chain_index=chain_index,
            )

            # Gradually raise the temperature to encourage diversity across chains
            chain_temperature = min(temperature + 0.05 * i, 1.0)

            max_retries = 2
            chain = None
            for attempt in range(max_retries):
                try:
                    step2_response = await ollama_provider.generate(
                        step2_prompt, model=request.model, temperature=chain_temperature
                    )
                    logger.info(f"Chain {chain_index} response: {step2_response[:300]}")

                    chain_data = extract_json_from_response(step2_response)
                    chain = CausalChain(**chain_data)
                    break
                except Exception as e:
                    logger.warning(f"Chain {chain_index} attempt {attempt + 1} failed: {e}")
                    if attempt < max_retries - 1:
                        chain_temperature = min(chain_temperature + 0.1, 1.0)

            if chain:
                causal_chains.append(chain)
                yield f"event: chain_complete\ndata: {json.dumps({'index': chain_index, 'chain': chain.model_dump()}, ensure_ascii=False)}\n\n"
            else:
                yield f"event: chain_error\ndata: {json.dumps({'index': chain_index, 'error': f'第 {chain_index} 條因果鏈生成失敗'}, ensure_ascii=False)}\n\n"

        # ========== Assemble the final structure ==========
        final_tree = assemble_attribute_tree(request.query, causal_chains)

        final_result = {
            "query": request.query,
            "step1_result": step1_result.model_dump(),
            "causal_chains": [c.model_dump() for c in causal_chains],
            "attributes": final_tree.model_dump(),
        }
        yield f"event: done\ndata: {json.dumps(final_result, ensure_ascii=False)}\n\n"

    except Exception as e:
        logger.error(f"SSE generation error: {e}")
        yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"


@router.post("/analyze")
async def analyze_stream(request: StreamAnalyzeRequest):
    """Multi-step analysis with SSE streaming."""
    return StreamingResponse(
        generate_sse_events(request),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.get("/models", response_model=ModelListResponse)
async def list_models():
    """List available LLM models."""
    try:
        models = await ollama_provider.list_models()
        return ModelListResponse(models=models)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
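A minimal client sketch for the streaming endpoint: it reads the raw SSE frames with httpx (host, port, and query are assumptions):

import asyncio

import httpx


async def main() -> None:
    payload = {"query": "雨傘", "chain_count": 3}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:8000/api/analyze", json=payload
        ) as response:
            # Frames arrive as "event: ..." / "data: ..." line pairs,
            # separated by blank lines.
            async for line in response.aiter_lines():
                if line:
                    print(line)


asyncio.run(main())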
backend/app/services/__init__.py (new file, 0 lines)
backend/app/services/llm_service.py (new file, 160 lines)
@@ -0,0 +1,160 @@
import json
import re
from abc import ABC, abstractmethod
from typing import List, Optional

import httpx

from ..config import settings
from ..models.schemas import AttributeNode


class LLMProvider(ABC):
    @abstractmethod
    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        """Generate a response from the LLM."""
        pass

    @abstractmethod
    async def list_models(self) -> List[str]:
        """List available models."""
        pass


class OllamaProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None):
        self.base_url = base_url or settings.ollama_base_url
        # Generous timeout for larger models (14B, 32B, etc.)
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/api/generate"

        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "format": "json",
            "options": {
                "temperature": temperature,
            },
        }

        # Retry logic for larger models that may return empty responses
        max_retries = 3
        for attempt in range(max_retries):
            response = await self.client.post(url, json=payload)
            response.raise_for_status()

            result = response.json()
            response_text = result.get("response", "")

            # Accept the response only if it is non-empty and not a bare "{}"
            if response_text and response_text.strip() not in ["", "{}", "{ }"]:
                return response_text

            # Otherwise retry with a slightly higher temperature
            if attempt < max_retries - 1:
                payload["options"]["temperature"] = min(temperature + 0.1 * (attempt + 1), 1.0)

        # Return whatever we got on the last attempt
        return response_text

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/api/tags"

        response = await self.client.get(url)
        response.raise_for_status()

        result = response.json()
        models = result.get("models", [])
        return [m.get("name", "") for m in models if m.get("name")]

    async def close(self):
        await self.client.aclose()


class OpenAICompatibleProvider(LLMProvider):
    def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
        self.base_url = base_url or settings.openai_base_url
        self.api_key = api_key or settings.openai_api_key
        # Generous timeout for larger models
        self.client = httpx.AsyncClient(timeout=300.0)

    async def generate(
        self, prompt: str, model: Optional[str] = None, temperature: float = 0.7
    ) -> str:
        model = model or settings.default_model
        url = f"{self.base_url}/v1/chat/completions"

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
        }

        response = await self.client.post(url, json=payload, headers=headers)
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

    async def list_models(self) -> List[str]:
        url = f"{self.base_url}/v1/models"

        headers = {}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        response = await self.client.get(url, headers=headers)
        response.raise_for_status()

        result = response.json()
        return [m.get("id", "") for m in result.get("data", []) if m.get("id")]

    async def close(self):
        await self.client.aclose()


def extract_json_from_response(response: str) -> dict:
    """Extract JSON from an LLM response, handling markdown code blocks and extra whitespace."""
    # Unwrap a markdown code block if present
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)```", response)
    if json_match:
        json_str = json_match.group(1)
    else:
        json_str = response

    # Clean up: strip, then normalize whitespace
    json_str = json_str.strip()
    # Remove trailing whitespace before closing braces/brackets
    json_str = re.sub(r'\s+([}\]])', r'\1', json_str)
    # Collapse tabs/newlines (raw newlines inside strings are invalid JSON anyway)
    json_str = re.sub(r'[\t\n\r]+', ' ', json_str)

    return json.loads(json_str)


def parse_attribute_response(response: str) -> AttributeNode:
    """Parse an LLM response into an AttributeNode structure."""
    data = extract_json_from_response(response)
    return AttributeNode.model_validate(data)


def get_llm_provider(provider_type: str = "ollama") -> LLMProvider:
    """Factory function to get an LLM provider."""
    if provider_type == "ollama":
        return OllamaProvider()
    elif provider_type == "openai":
        return OpenAICompatibleProvider()
    else:
        raise ValueError(f"Unknown provider type: {provider_type}")


# Shared provider instance used by the routers and the app lifespan hook
ollama_provider = OllamaProvider()
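A quick sketch of extract_json_from_response on fenced model output (the sample string is hypothetical):

from backend.app.services.llm_service import extract_json_from_response  # path assumed

raw = """```json
{
  "material": "尼龍布",
  "function": "防水",
  "usage": "雨天出行",
  "user": "通勤族"
}
```"""

# The fence is unwrapped, newlines are collapsed to spaces, and the
# remaining text is parsed with json.loads.
print(extract_json_from_response(raw))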