275 lines
9.1 KiB
Python
275 lines
9.1 KiB
Python
"""
LLM Deduplication Service - deduplication via LLM pairwise comparison.

An LLM judges whether two descriptions are semantic duplicates; the pairwise
comparisons are run in parallel to speed things up.
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
from typing import List, Tuple, Optional, Literal
|
||
|
||
import httpx
|
||
import numpy as np
|
||
|
||
from ..config import settings
|
||
from ..models.schemas import (
|
||
ExpertTransformationDescription,
|
||
DeduplicationResult,
|
||
DeduplicationMethod,
|
||
DescriptionGroup,
|
||
)
|
||
from ..prompts.language_config import LanguageType
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class LLMDeduplicationService:
|
||
"""LLM deduplication service: uses LLM pairwise comparison to judge semantic similarity."""
|
||
|
||
def __init__(
    self,
    *,
    base_url: Optional[str] = None,
    default_model: str = "qwen3:4b",
    timeout: float = 60.0,
    max_concurrent: int = 5,
):
    """Create the service and its shared async HTTP client.

    Every argument is keyword-only and optional; calling with no arguments
    gives the same configuration that was previously hard-coded.

    Args:
        base_url: Root URL of the Ollama server. When omitted,
            ``settings.ollama_base_url`` is used.
        default_model: Model used when ``deduplicate`` is called without an
            explicit model. The default is a fast model suited to simple
            judgments.
        timeout: Timeout value passed to ``httpx.AsyncClient``.
        max_concurrent: Maximum number of comparisons in flight at once,
            to avoid overloading Ollama.

    Raises:
        ValueError: If ``max_concurrent`` is smaller than 1.
    """
    if max_concurrent < 1:
        # A semaphore with zero slots never admits a task, so every batch
        # comparison would wait forever.
        raise ValueError("max_concurrent must be at least 1")

    self.base_url = settings.ollama_base_url if base_url is None else base_url
    self.default_model = default_model
    self.client = httpx.AsyncClient(timeout=timeout)
    self.max_concurrent = max_concurrent
|
||
|
||
def _get_comparison_prompt(self, desc1: str, desc2: str, lang: LanguageType = "zh") -> str:
    """Build the YES/NO comparison prompt for a pair of descriptions.

    Args:
        desc1: Text of the first description.
        desc2: Text of the second description.
        lang: ``"en"`` selects the English prompt; any other value selects
            the Chinese prompt.

    Returns:
        The prompt text with both descriptions interpolated.
    """
    # Anything other than "en" falls through to the Chinese template.
    if lang != "en":
        zh_prompt = f"""判斷以下兩個創新描述是否表達相同或非常相似的概念:

描述1: {desc1}

描述2: {desc2}

如果兩者描述的創新概念本質相同或非常相似,回答 "YES"
如果兩者描述不同的創新概念,回答 "NO"
只回答 YES 或 NO,不要其他文字"""
        return zh_prompt

    en_prompt = f"""Determine whether the following two innovative descriptions express the same or very similar concepts:

Description 1: {desc1}

Description 2: {desc2}

If both descriptions essentially express the same or very similar innovative concept, answer "YES"
If the two descriptions express different innovative concepts, answer "NO"
Only answer YES or NO, no other text"""
    return en_prompt
|
||
|
||
async def compare_pair(
    self,
    desc1: str,
    desc2: str,
    model: str,
    semaphore: asyncio.Semaphore,
    lang: LanguageType = "zh"
) -> bool:
    """
    Ask the LLM whether two descriptions are semantic duplicates.

    A non-streaming request is posted to ``/api/generate`` under
    ``self.base_url``. The reply is normalised before it is checked: the
    prompt shows the expected answer as ``"YES"`` in quotes, so a reply that
    echoes the quotes (or wraps the word in emphasis markers, or follows a
    closed ``</think>`` block) must still count as a YES.

    Args:
        desc1: First description
        desc2: Second description
        model: LLM model name
        semaphore: Concurrency control semaphore; held for the whole request
        lang: Language for the prompt

    Returns:
        bool: True when the normalised reply starts with "YES". False for
        any other reply, and also when the request or the parsing raises
        (the error is logged and the pair is treated as not similar).
    """
    # Characters that may precede the verdict without changing its meaning.
    leading_noise = "\"'`*“‘「『 \t\r\n"

    async with semaphore:  # Limit the number of in-flight requests
        prompt = self._get_comparison_prompt(desc1, desc2, lang)

        try:
            response = await self.client.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {
                        "temperature": 0.1,  # Low temperature for consistent judgments
                        # Only a short answer is needed.
                        # NOTE(review): if the model emits a reasoning block
                        # first, 10 tokens may run out before the verdict —
                        # confirm for the models in use.
                        "num_predict": 10,
                    },
                },
            )
            response.raise_for_status()
            raw_answer = response.json()["response"]

            # Keep only the text after a closed reasoning block, if present;
            # rpartition returns the whole string when the marker is absent.
            answer = raw_answer.rpartition("</think>")[2]
            answer = answer.strip().lstrip(leading_noise).upper()
            is_similar = answer.startswith("YES")

            logger.debug(
                "LLM comparison: '%s...' vs '%s...' -> %s (%s)",
                desc1[:30], desc2[:30], answer, is_similar,
            )
            return is_similar
        except Exception as e:
            # Deliberately best-effort: one failed comparison must not abort
            # the whole batch, so the pair is assumed not similar.
            logger.error(f"LLM comparison failed: {e}")
            return False
|
||
|
||
async def compare_batch(
    self,
    pairs: List[Tuple[int, int, str, str]],
    model: str,
    lang: LanguageType = "zh"
) -> List[Tuple[int, int, bool]]:
    """
    Compare many description pairs concurrently.

    Each pair is judged through ``compare_pair``. A single semaphore sized by
    ``self.max_concurrent`` is shared by all of them, so at most that many
    comparisons run at once.

    Args:
        pairs: Pairs to compare, as ``(i, j, desc1, desc2)`` tuples
        model: LLM model name
        lang: Language for the prompt

    Returns:
        One ``(i, j, is_similar)`` tuple per input pair, in input order.
    """
    limiter = asyncio.Semaphore(self.max_concurrent)

    async def judge(entry: Tuple[int, int, str, str]) -> Tuple[int, int, bool]:
        left, right, text_a, text_b = entry
        verdict = await self.compare_pair(text_a, text_b, model, limiter, lang)
        return (left, right, verdict)

    # gather schedules every comparison up front; the semaphore does the throttling.
    pending = [judge(entry) for entry in pairs]
    outcome = await asyncio.gather(*pending)
    return outcome
|
||
|
||
def cluster_by_similarity(
    self,
    similarity_matrix: np.ndarray,
    threshold: float
) -> List[List[int]]:
    """
    Greedy clustering: group items whose similarity is >= threshold.

    Items are visited in index order. Each item that is not yet in a group
    starts a new one as its representative and pulls in every later,
    still-unassigned item whose similarity to that representative reaches
    the threshold. Later items are compared with the representative only,
    not with the other members of the group.

    The original note states that this is the same algorithm as the one
    used by embedding_service.

    Args:
        similarity_matrix: Square matrix of pairwise similarities
        threshold: Minimum similarity (inclusive) for joining a group

    Returns:
        Groups of item indices; the first index of each group is its
        representative.
    """
    size = len(similarity_matrix)
    taken = set()
    clusters = []

    for leader in range(size):
        if leader in taken:
            continue

        # The leader represents the new group; followers are the later
        # items that are still free and similar enough to it.
        members = [leader] + [
            other
            for other in range(leader + 1, size)
            if other not in taken and similarity_matrix[leader][other] >= threshold
        ]
        taken.update(members)
        clusters.append(members)

    return clusters
|
||
|
||
async def deduplicate(
    self,
    descriptions: List[ExpertTransformationDescription],
    model: Optional[str] = None,
    lang: LanguageType = "zh"
) -> DeduplicationResult:
    """
    Deduplicate descriptions using pairwise LLM comparison.

    Every unordered pair of descriptions is sent to the LLM, the YES/NO
    verdicts are written into a symmetric 0/1 similarity matrix, and the
    matrix is clustered with ``cluster_by_similarity``. The first item of
    each cluster becomes the group's representative and the rest are
    reported as its duplicates.

    Args:
        descriptions: List of descriptions to deduplicate
        model: LLM model name; falls back to ``self.default_model``
        lang: Language for the prompt

    Returns:
        DeduplicationResult: Groups plus summary counts. ``threshold_used``
        is always 0.5 for this method.
    """
    model = model or self.default_model
    # LLM verdicts are only 0 or 1, so a fixed 0.5 threshold separates them.
    binary_threshold = 0.5

    # Empty input: nothing to compare.
    if not descriptions:
        return DeduplicationResult(
            total_input=0,
            total_groups=0,
            total_duplicates=0,
            groups=[],
            threshold_used=binary_threshold,
            method_used=DeduplicationMethod.LLM,
            model_used=model
        )

    n = len(descriptions)
    # Identity matrix: every item is similar to itself, all others start at 0.
    similarity_matrix = np.eye(n)

    # Every pair (a, b) with a < b needs one comparison.
    pairs = [
        (a, b, descriptions[a].description, descriptions[b].description)
        for a in range(n)
        for b in range(a + 1, n)
    ]

    total_pairs = len(pairs)
    logger.info(f"LLM deduplication: {total_pairs} pairs to compare (parallel={self.max_concurrent}, model={model}, lang={lang})")

    # Parallel batch comparison
    results = await self.compare_batch(pairs, model, lang)

    # Fill the similarity matrix symmetrically.
    for a, b, is_similar in results:
        score = 1.0 if is_similar else 0.0
        similarity_matrix[a][b] = score
        similarity_matrix[b][a] = score

    logger.info("Clustering results...")
    clusters = self.cluster_by_similarity(similarity_matrix, binary_threshold)

    # Build the result groups. A singleton cluster simply has no duplicates,
    # so both cases share one code path.
    result_groups = []
    total_duplicates = 0

    for group_idx, (rep_idx, *dup_indices) in enumerate(clusters):
        result_groups.append(DescriptionGroup(
            group_id=f"group-{group_idx}",
            representative=descriptions[rep_idx],
            duplicates=[descriptions[idx] for idx in dup_indices],
            # Scores are always 1.0 because the LLM gives a YES/NO verdict.
            similarity_scores=[1.0] * len(dup_indices)
        ))
        total_duplicates += len(dup_indices)

    logger.info(f"LLM deduplication complete: {n} -> {len(result_groups)} groups, {total_duplicates} duplicates found")

    return DeduplicationResult(
        total_input=n,
        total_groups=len(result_groups),
        total_duplicates=total_duplicates,
        groups=result_groups,
        threshold_used=binary_threshold,
        method_used=DeduplicationMethod.LLM,
        model_used=model
    )
|
||
|
||
async def close(self) -> None:
    """Close the HTTP client held in ``self.client``."""
    await self.client.aclose()
|
||
|
||
|
||
# Global instance
|
||
# NOTE: instantiated at import time, so the service's HTTP client is created
# as a side effect of importing this module.
llm_deduplication_service = LLMDeduplicationService()
|