""" Post-generation deduplication for experiment results. Applies embedding-based deduplication uniformly to all conditions to normalize idea counts and measure "dedup survival rate". Usage: python -m experiments.deduplication --input results/experiment_xxx.json """ import sys import json import argparse import asyncio import logging from pathlib import Path from typing import List, Dict, Any, Optional from dataclasses import dataclass # Add backend to path for imports sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) from app.services.embedding_service import embedding_service from app.models.schemas import ExpertTransformationDescription from experiments.config import DEDUP_THRESHOLD, RESULTS_DIR # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) @dataclass class DedupStats: """Deduplication statistics for a single condition.""" condition: str pre_dedup_count: int post_dedup_count: int duplicates_removed: int survival_rate: float groups: List[Dict[str, Any]] def ideas_to_descriptions( ideas: List[str], ideas_with_source: Optional[List[Dict[str, Any]]] = None ) -> List[ExpertTransformationDescription]: """ Convert experiment ideas to ExpertTransformationDescription format for compatibility with the embedding service. """ descriptions = [] if ideas_with_source: # Use source information if available for i, item in enumerate(ideas_with_source): desc = ExpertTransformationDescription( keyword=item.get("keyword", item.get("attribute", item.get("perspective_word", ""))), expert_id=f"source-{i}", expert_name=item.get("expert_name", item.get("perspective_word", "direct")), description=item.get("idea", "") ) descriptions.append(desc) else: # Simple conversion for ideas without source for i, idea in enumerate(ideas): desc = ExpertTransformationDescription( keyword="", expert_id=f"idea-{i}", expert_name="direct", description=idea ) descriptions.append(desc) return descriptions async def deduplicate_condition( ideas: List[str], ideas_with_source: Optional[List[Dict[str, Any]]] = None, threshold: float = DEDUP_THRESHOLD ) -> Dict[str, Any]: """ Apply deduplication to ideas from a single condition. 
async def deduplicate_condition(
    ideas: List[str],
    ideas_with_source: Optional[List[Dict[str, Any]]] = None,
    threshold: float = DEDUP_THRESHOLD
) -> Dict[str, Any]:
    """
    Apply deduplication to ideas from a single condition.

    Args:
        ideas: Raw idea strings for the condition
        ideas_with_source: Optional source metadata for each idea
        threshold: Similarity threshold for deduplication

    Returns:
        Dict with deduplicated ideas and statistics
    """
    if not ideas:
        return {
            "unique_ideas": [],
            "unique_ideas_with_source": [],
            "groups": [],
            "stats": {
                "pre_dedup_count": 0,
                "post_dedup_count": 0,
                "duplicates_removed": 0,
                "survival_rate": 1.0
            }
        }

    # Convert to description format
    descriptions = ideas_to_descriptions(ideas, ideas_with_source)

    # Run deduplication
    result = await embedding_service.deduplicate(
        descriptions=descriptions,
        threshold=threshold
    )

    # Extract unique ideas (representatives from each group)
    unique_ideas = []
    unique_ideas_with_source = []
    groups_info = []

    for group in result.groups:
        rep = group.representative
        unique_ideas.append(rep.description)

        # Reconstruct source info
        source_info = {
            "idea": rep.description,
            "keyword": rep.keyword,
            "expert_name": rep.expert_name
        }
        unique_ideas_with_source.append(source_info)

        # Group info for analysis
        group_info = {
            "representative": rep.description,
            "duplicates": [d.description for d in group.duplicates],
            "duplicate_count": len(group.duplicates),
            "similarity_scores": group.similarity_scores
        }
        groups_info.append(group_info)

    pre_count = len(ideas)
    post_count = len(unique_ideas)
    survival_rate = post_count / pre_count if pre_count > 0 else 1.0

    return {
        "unique_ideas": unique_ideas,
        "unique_ideas_with_source": unique_ideas_with_source,
        "groups": groups_info,
        "stats": {
            "pre_dedup_count": pre_count,
            "post_dedup_count": post_count,
            "duplicates_removed": pre_count - post_count,
            "survival_rate": survival_rate
        }
    }
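
# Illustrative sketch of the experiment-results layout this script consumes. The keys
# below are inferred from the accessors in process_experiment_results() rather than a
# formal schema, so treat this as an assumption, not a specification:
#     {
#         "experiment_id": "...",
#         "results": [
#             {
#                 "query_id": "...",
#                 "query": "...",
#                 "conditions": {
#                     "<condition name>": {
#                         "success": true,
#                         "ideas": ["..."],
#                         "ideas_with_source": [
#                             {"idea": "...", "keyword": "...", "expert_name": "..."}
#                         ]
#                     }
#                 }
#             }
#         ]
#     }
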
async def process_experiment_results(
    input_file: Path,
    output_file: Optional[Path] = None,
    threshold: float = DEDUP_THRESHOLD
) -> Dict[str, Any]:
    """
    Process an experiment results file and apply deduplication.

    Args:
        input_file: Path to experiment results JSON
        output_file: Path for output (default: input_file with _deduped suffix)
        threshold: Similarity threshold for deduplication

    Returns:
        Processed results with deduplication applied
    """
    # Load experiment results
    with open(input_file, "r", encoding="utf-8") as f:
        experiment = json.load(f)

    logger.info(f"Processing experiment: {experiment.get('experiment_id', 'unknown')}")
    logger.info(f"Deduplication threshold: {threshold}")

    # Process each query's conditions
    dedup_summary = {
        "threshold": threshold,
        "conditions": {}
    }

    for query_result in experiment["results"]:
        query = query_result["query"]
        query_id = query_result["query_id"]
        logger.info(f"\nProcessing query: {query} ({query_id})")

        for condition, cond_result in query_result["conditions"].items():
            if not cond_result.get("success", False):
                logger.warning(f"  Skipping failed condition: {condition}")
                continue

            logger.info(f"  Deduplicating {condition}...")

            ideas = cond_result.get("ideas", [])
            ideas_with_source = cond_result.get("ideas_with_source", [])

            dedup_result = await deduplicate_condition(
                ideas=ideas,
                ideas_with_source=ideas_with_source,
                threshold=threshold
            )

            # Add dedup results to condition
            cond_result["dedup"] = dedup_result

            # Update summary stats
            if condition not in dedup_summary["conditions"]:
                dedup_summary["conditions"][condition] = {
                    "total_pre_dedup": 0,
                    "total_post_dedup": 0,
                    "total_removed": 0,
                    "query_stats": []
                }

            stats = dedup_result["stats"]
            cond_summary = dedup_summary["conditions"][condition]
            cond_summary["total_pre_dedup"] += stats["pre_dedup_count"]
            cond_summary["total_post_dedup"] += stats["post_dedup_count"]
            cond_summary["total_removed"] += stats["duplicates_removed"]
            cond_summary["query_stats"].append({
                "query_id": query_id,
                "query": query,
                **stats
            })

            logger.info(f"    {stats['pre_dedup_count']} -> {stats['post_dedup_count']} "
                        f"(survival: {stats['survival_rate']:.1%})")

    # Calculate overall survival rates
    for condition, cond_stats in dedup_summary["conditions"].items():
        if cond_stats["total_pre_dedup"] > 0:
            cond_stats["overall_survival_rate"] = (
                cond_stats["total_post_dedup"] / cond_stats["total_pre_dedup"]
            )
        else:
            cond_stats["overall_survival_rate"] = 1.0

    # Add dedup summary to experiment
    experiment["dedup_summary"] = dedup_summary

    # Save results
    if output_file is None:
        stem = input_file.stem.replace("_complete", "").replace("_intermediate", "")
        output_file = input_file.parent / f"{stem}_deduped.json"

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(experiment, f, indent=2, ensure_ascii=False)

    logger.info(f"\nResults saved to: {output_file}")

    return experiment


def print_dedup_summary(experiment: Dict[str, Any]):
    """Print formatted deduplication summary."""
    dedup = experiment.get("dedup_summary", {})

    print("\n" + "=" * 70)
    print("DEDUPLICATION SUMMARY")
    print("=" * 70)
    print(f"Threshold: {dedup.get('threshold', 'N/A')}")

    print("\nResults by condition:")
    print("-" * 70)
    print(f"{'Condition':<30} {'Pre-Dedup':<12} {'Post-Dedup':<12} {'Survival':<10}")
    print("-" * 70)

    for condition, stats in dedup.get("conditions", {}).items():
        pre = stats.get("total_pre_dedup", 0)
        post = stats.get("total_post_dedup", 0)
        survival = stats.get("overall_survival_rate", 1.0)
        print(f"{condition:<30} {pre:<12} {post:<12} {survival:<10.1%}")

    print("-" * 70)
    print("\nInterpretation:")
    print("- Higher survival rate = more diverse/unique ideas")
    print("- Lower survival rate = more redundant ideas removed")


async def main():
    parser = argparse.ArgumentParser(
        description="Apply deduplication to experiment results"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input experiment results JSON file"
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output file path (default: <input stem>_deduped.json)"
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=DEDUP_THRESHOLD,
        help=f"Similarity threshold (default: {DEDUP_THRESHOLD})"
    )

    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        # Try relative to results dir
        input_path = RESULTS_DIR / args.input
        if not input_path.exists():
            print(f"Error: Input file not found: {args.input}")
            sys.exit(1)

    output_path = Path(args.output) if args.output else None

    experiment = await process_experiment_results(
        input_file=input_path,
        output_file=output_path,
        threshold=args.threshold
    )

    print_dedup_summary(experiment)


if __name__ == "__main__":
    asyncio.run(main())
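
# Example programmatic use (a minimal sketch; the CLI above is the intended entry
# point, and the path below is just the placeholder from the module docstring):
#
#     import asyncio
#     from pathlib import Path
#     from experiments.deduplication import process_experiment_results
#
#     experiment = asyncio.run(
#         process_experiment_results(input_file=Path("results/experiment_xxx.json"))
#     )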