Files
novelty-seeking/backend/app/services/llm_deduplication_service.py
2026-01-05 22:32:08 +08:00

275 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
LLM Deduplication Service - Using LLM pairwise comparison for deduplication
Let LLM determine whether two descriptions are semantically duplicate, accelerated by parallel processing.
"""
import asyncio
import logging
from typing import List, Tuple, Optional, Literal
import httpx
import numpy as np
from ..config import settings
from ..models.schemas import (
ExpertTransformationDescription,
DeduplicationResult,
DeduplicationMethod,
DescriptionGroup,
)
from ..prompts.language_config import LanguageType
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class LLMDeduplicationService:
    """Deduplicate descriptions by asking an LLM to compare them pairwise.

    Every unordered pair of descriptions is sent to the ``/api/generate``
    endpoint under ``settings.ollama_base_url`` with a YES/NO prompt. The
    binary verdicts fill a 0/1 similarity matrix, which is then grouped
    with a greedy, representative-based clustering pass.
    """

    # Threshold used for clustering and reported in results. LLM verdicts are
    # binary (1.0 for YES, 0.0 for NO), so 0.5 cleanly separates the two.
    _LLM_THRESHOLD = 0.5

    # Characters a model may put in front of its verdict (whitespace, quotes,
    # markdown emphasis). The prompt itself shows the answer in quotes, so a
    # quoted reply such as '"YES"' must still count as YES.
    _ANSWER_LEADING_NOISE = " \t\r\n\"'`*\u201c\u2018\u300c\u300e"

    def __init__(
        self,
        *,
        default_model: str = "qwen3:4b",
        max_concurrent: int = 5,
        timeout: float = 60.0,
    ):
        """Create the service and its shared async HTTP client.

        Args:
            default_model: Model used when ``deduplicate`` gets no model.
                The default is a small, fast model, which is enough for a
                simple YES/NO judgment.
            max_concurrent: Maximum number of in-flight comparison requests,
                kept low to avoid overloading the LLM server.
            timeout: Timeout in seconds passed to ``httpx.AsyncClient``.
        """
        self.base_url = settings.ollama_base_url
        self.default_model = default_model
        self.client = httpx.AsyncClient(timeout=timeout)
        self.max_concurrent = max_concurrent

    def _get_comparison_prompt(self, desc1: str, desc2: str, lang: LanguageType = "zh") -> str:
        """Build the YES/NO comparison prompt in the requested language.

        Args:
            desc1: First description.
            desc2: Second description.
            lang: ``"en"`` selects the English prompt; any other value
                selects the Chinese prompt.

        Returns:
            The prompt text with both descriptions embedded.
        """
        # NOTE: the prompt bodies are runtime text sent to the model; their
        # lines intentionally start at column 0 inside the string literal.
        if lang == "en":
            return f"""Determine whether the following two innovative descriptions express the same or very similar concepts:
Description 1: {desc1}
Description 2: {desc2}
If both descriptions essentially express the same or very similar innovative concept, answer "YES"
If the two descriptions express different innovative concepts, answer "NO"
Only answer YES or NO, no other text"""
        return f"""判斷以下兩個創新描述是否表達相同或非常相似的概念:
描述1: {desc1}
描述2: {desc2}
如果兩者描述的創新概念本質相同或非常相似,回答 "YES"
如果兩者描述不同的創新概念,回答 "NO"
只回答 YES 或 NO不要其他文字"""

    @staticmethod
    def _is_affirmative(raw_answer: str) -> bool:
        """Return True when the model's raw reply is a YES verdict.

        The check is case-insensitive. Anything up to and including the last
        ``</think>`` tag is discarded, and leading whitespace, quotes and
        markdown emphasis are skipped, so replies such as ``"YES"`` or
        ``<think>...</think>YES`` are recognized. Every reply that starts
        with YES after plain whitespace stripping is still accepted.
        """
        answer = raw_answer.upper()
        # rpartition returns the whole string as the tail when no tag exists.
        answer = answer.rpartition("</THINK>")[2]
        answer = answer.lstrip(LLMDeduplicationService._ANSWER_LEADING_NOISE)
        return answer.startswith("YES")

    async def compare_pair(
        self,
        desc1: str,
        desc2: str,
        model: str,
        semaphore: asyncio.Semaphore,
        lang: LanguageType = "zh"
    ) -> bool:
        """Ask the LLM whether two descriptions are semantic duplicates.

        Args:
            desc1: First description.
            desc2: Second description.
            model: LLM model name.
            semaphore: Semaphore limiting how many requests run at once.
            lang: Language for the prompt.

        Returns:
            True if the model answered YES. False if it answered anything
            else, or if the request or response handling raised; failures
            are logged and treated as "not similar" rather than propagated.
        """
        async with semaphore:  # Limit concurrent requests
            prompt = self._get_comparison_prompt(desc1, desc2, lang)
            try:
                response = await self.client.post(
                    f"{self.base_url}/api/generate",
                    json={
                        "model": model,
                        "prompt": prompt,
                        "stream": False,
                        "options": {
                            "temperature": 0.1,  # Low temperature for consistent verdicts
                            "num_predict": 10,  # Only a short answer is needed
                        }
                    }
                )
                response.raise_for_status()
                raw_answer = response.json()["response"]
                is_similar = self._is_affirmative(raw_answer)
                logger.debug(
                    "LLM comparison: '%s...' vs '%s...' -> %s (%s)",
                    desc1[:30],
                    desc2[:30],
                    raw_answer.strip().upper(),
                    is_similar,
                )
                return is_similar
            except Exception as e:
                # Deliberate best effort: one failed comparison must not abort
                # the whole batch, so assume the pair is not similar.
                logger.error("LLM comparison failed: %s", e)
                return False

    async def compare_batch(
        self,
        pairs: List[Tuple[int, int, str, str]],
        model: str,
        lang: LanguageType = "zh"
    ) -> List[Tuple[int, int, bool]]:
        """Compare many description pairs concurrently.

        Args:
            pairs: Pairs to compare, as ``[(i, j, desc1, desc2), ...]``.
            model: LLM model name.
            lang: Language for the prompt.

        Returns:
            Comparison results ``[(i, j, is_similar), ...]`` in the same
            order as ``pairs``.
        """
        # One semaphore per batch caps in-flight requests at max_concurrent.
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def compare_one(pair: Tuple[int, int, str, str]) -> Tuple[int, int, bool]:
            i, j, desc1, desc2 = pair
            is_similar = await self.compare_pair(desc1, desc2, model, semaphore, lang)
            return (i, j, is_similar)

        # All comparisons are scheduled at once; the semaphore throttles them.
        results = await asyncio.gather(*[compare_one(p) for p in pairs])
        return results

    def cluster_by_similarity(
        self,
        similarity_matrix: np.ndarray,
        threshold: float
    ) -> List[List[int]]:
        """Greedily group items whose similarity is at least ``threshold``.

        Items are visited in index order. Each still-unassigned item starts
        a new group as its representative and absorbs every later unassigned
        item whose similarity to that representative is >= ``threshold``.
        Membership is decided against the representative only, so the
        grouping is not transitive.

        Args:
            similarity_matrix: Square matrix indexed as ``[i][j]``.
            threshold: Minimum similarity for joining a group.

        Returns:
            Groups of item indices; the first index of each group is its
            representative.
        """
        n = len(similarity_matrix)
        assigned = [False] * n
        groups = []
        for i in range(n):
            if assigned[i]:
                continue
            # Start a new group with item i as its representative.
            group = [i]
            assigned[i] = True
            # Absorb every later unassigned item similar to i.
            for j in range(i + 1, n):
                if not assigned[j] and similarity_matrix[i][j] >= threshold:
                    group.append(j)
                    assigned[j] = True
            groups.append(group)
        return groups

    async def deduplicate(
        self,
        descriptions: List[ExpertTransformationDescription],
        model: Optional[str] = None,
        lang: LanguageType = "zh"
    ) -> DeduplicationResult:
        """Deduplicate descriptions using pairwise LLM comparison.

        Args:
            descriptions: Descriptions to deduplicate.
            model: LLM model name; falls back to ``self.default_model``.
            lang: Language for the prompt.

        Returns:
            DeduplicationResult with one group per cluster. The first member
            of each cluster is the representative and the remaining members
            are its duplicates, each with a similarity score of 1.0.
        """
        model = model or self.default_model
        threshold = self._LLM_THRESHOLD

        # Empty input: nothing to compare.
        if not descriptions:
            return DeduplicationResult(
                total_input=0,
                total_groups=0,
                total_duplicates=0,
                groups=[],
                threshold_used=threshold,
                method_used=DeduplicationMethod.LLM,
                model_used=model
            )

        n = len(descriptions)
        # Identity matrix: every item is similar to itself.
        similarity_matrix = np.eye(n)

        # Every unordered pair (i < j) is compared exactly once.
        pairs = [
            (i, j, descriptions[i].description, descriptions[j].description)
            for i in range(n)
            for j in range(i + 1, n)
        ]
        total_pairs = len(pairs)
        logger.info(
            "LLM deduplication: %s pairs to compare (parallel=%s, model=%s, lang=%s)",
            total_pairs,
            self.max_concurrent,
            model,
            lang,
        )

        # Parallel batch comparison
        results = await self.compare_batch(pairs, model, lang)

        # Fill the matrix symmetrically with the binary verdicts.
        for i, j, is_similar in results:
            similarity_value = 1.0 if is_similar else 0.0
            similarity_matrix[i][j] = similarity_value
            similarity_matrix[j][i] = similarity_value

        logger.info("Clustering results...")
        clusters = self.cluster_by_similarity(similarity_matrix, threshold)

        # Build result groups. A singleton cluster simply has no duplicates,
        # so one code path covers both cases.
        result_groups = []
        total_duplicates = 0
        for group_idx, indices in enumerate(clusters):
            rep_idx, *dup_indices = indices
            result_groups.append(DescriptionGroup(
                group_id=f"group-{group_idx}",
                representative=descriptions[rep_idx],
                duplicates=[descriptions[idx] for idx in dup_indices],
                # YES/NO verdicts carry no graded score, so duplicates get 1.0.
                similarity_scores=[1.0] * len(dup_indices)
            ))
            total_duplicates += len(dup_indices)

        logger.info(
            "LLM deduplication complete: %s -> %s groups, %s duplicates found",
            n,
            len(result_groups),
            total_duplicates,
        )
        return DeduplicationResult(
            total_input=n,
            total_groups=len(result_groups),
            total_duplicates=total_duplicates,
            groups=result_groups,
            threshold_used=threshold,
            method_used=DeduplicationMethod.LLM,
            model_used=model
        )

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()
# Module-level shared instance, constructed when this module is imported.
llm_deduplication_service = LLMDeduplicationService()