""" Compute metrics for experiment results. Computes metrics BOTH before and after deduplication: - Pre-dedup: Measures raw generation capability - Post-dedup: Measures quality of unique ideas Also normalizes idea counts for fair cross-condition comparison. Usage: python -m experiments.compute_metrics --input results/experiment_xxx_deduped.json """ import sys import json import argparse import asyncio import logging import random from pathlib import Path from typing import List, Dict, Any, Optional, Tuple from dataclasses import dataclass, asdict import numpy as np # Add backend to path for imports sys.path.insert(0, str(Path(__file__).parent.parent / "backend")) from app.services.embedding_service import embedding_service from app.services.llm_service import ollama_provider, extract_json_from_response from experiments.config import RESULTS_DIR, MODEL, RANDOM_SEED # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) @dataclass class DiversityMetrics: """Semantic diversity metrics for a set of ideas.""" mean_pairwise_distance: float std_pairwise_distance: float min_pairwise_distance: float max_pairwise_distance: float idea_count: int @dataclass class ClusterMetrics: """Cluster analysis metrics.""" optimal_clusters: int silhouette_score: float cluster_sizes: List[int] @dataclass class QueryDistanceMetrics: """Distance from original query metrics.""" mean_distance: float std_distance: float min_distance: float max_distance: float distances: List[float] @dataclass class RelevanceMetrics: """LLM-as-judge relevance metrics (for hallucination detection).""" relevance_rate: float # Score >= 2 nonsense_rate: float # Score == 1 mean_score: float score_distribution: Dict[int, int] # {1: count, 2: count, 3: count} @dataclass class ConditionMetrics: """All metrics for a single condition.""" condition: str query: str # Idea counts raw_count: int unique_count: int survival_rate: float # Pre-dedup metrics (on raw ideas) pre_dedup_diversity: Optional[DiversityMetrics] # Post-dedup metrics (on unique ideas) post_dedup_diversity: Optional[DiversityMetrics] post_dedup_clusters: Optional[ClusterMetrics] post_dedup_query_distance: Optional[QueryDistanceMetrics] # Normalized metrics (on equal-sized samples) normalized_diversity: Optional[DiversityMetrics] normalized_sample_size: int # Relevance/hallucination (post-dedup only) relevance: Optional[RelevanceMetrics] # ============================================================ # Embedding-based metrics # ============================================================ async def get_embeddings(texts: List[str]) -> List[List[float]]: """Get embeddings for a list of texts.""" if not texts: return [] return await embedding_service.get_embeddings_batch(texts) def compute_pairwise_distances(embeddings: List[List[float]]) -> List[float]: """Compute all pairwise cosine distances.""" n = len(embeddings) if n < 2: return [] distances = [] for i in range(n): for j in range(i + 1, n): sim = embedding_service.cosine_similarity(embeddings[i], embeddings[j]) dist = 1 - sim # Convert similarity to distance distances.append(dist) return distances async def compute_diversity_metrics(ideas: List[str]) -> Optional[DiversityMetrics]: """Compute semantic diversity metrics for a set of ideas.""" if len(ideas) < 2: return None embeddings = await get_embeddings(ideas) distances = compute_pairwise_distances(embeddings) if not distances: return None return DiversityMetrics( 
        mean_pairwise_distance=float(np.mean(distances)),
        std_pairwise_distance=float(np.std(distances)),
        min_pairwise_distance=float(np.min(distances)),
        max_pairwise_distance=float(np.max(distances)),
        idea_count=len(ideas)
    )


async def compute_query_distance_metrics(
    query: str,
    ideas: List[str]
) -> Optional[QueryDistanceMetrics]:
    """Compute distance of ideas from the original query."""
    if not ideas:
        return None

    # Get query embedding
    query_emb = await embedding_service.get_embedding(query)
    idea_embs = await get_embeddings(ideas)

    distances = []
    for emb in idea_embs:
        sim = embedding_service.cosine_similarity(query_emb, emb)
        dist = 1 - sim
        distances.append(dist)

    return QueryDistanceMetrics(
        mean_distance=float(np.mean(distances)),
        std_distance=float(np.std(distances)),
        min_distance=float(np.min(distances)),
        max_distance=float(np.max(distances)),
        distances=distances
    )


async def compute_cluster_metrics(ideas: List[str]) -> Optional[ClusterMetrics]:
    """Compute cluster analysis metrics."""
    if len(ideas) < 3:
        return None

    try:
        from sklearn.cluster import KMeans
        from sklearn.metrics import silhouette_score
    except ImportError:
        logger.warning("sklearn not installed, skipping cluster metrics")
        return None

    embeddings = await get_embeddings(ideas)
    embeddings_np = np.array(embeddings)

    # Find optimal k using silhouette score
    max_k = min(len(ideas) - 1, 10)
    if max_k < 2:
        return None

    best_k = 2
    best_score = -1
    for k in range(2, max_k + 1):
        try:
            kmeans = KMeans(n_clusters=k, random_state=RANDOM_SEED, n_init=10)
            labels = kmeans.fit_predict(embeddings_np)
            score = silhouette_score(embeddings_np, labels)
            if score > best_score:
                best_score = score
                best_k = k
        except Exception as e:
            logger.warning(f"Clustering failed for k={k}: {e}")
            continue

    # Get cluster sizes for optimal k
    kmeans = KMeans(n_clusters=best_k, random_state=RANDOM_SEED, n_init=10)
    labels = kmeans.fit_predict(embeddings_np)
    cluster_sizes = [int(np.sum(labels == i)) for i in range(best_k)]

    return ClusterMetrics(
        optimal_clusters=best_k,
        silhouette_score=float(best_score),
        cluster_sizes=sorted(cluster_sizes, reverse=True)
    )


# ============================================================
# LLM-as-Judge relevance metrics
# ============================================================

async def judge_relevance(query: str, idea: str, model: str = None) -> Dict[str, Any]:
    """Use LLM to judge if an idea is relevant to the query."""
    model = model or MODEL

    prompt = f"""/no_think
You are evaluating whether a generated idea is relevant and applicable to an original query.

Original query: {query}

Generated idea: {idea}

Rate the relevance on a scale of 1-3:
1 = Nonsense/completely irrelevant (no logical connection to the query)
2 = Weak but valid connection (requires stretch but has some relevance)
3 = Clearly relevant and applicable (directly relates to the query)

Return JSON only:
{{"score": N, "reason": "brief explanation (10-20 words)"}}
"""

    try:
        response = await ollama_provider.generate(
            prompt=prompt,
            model=model,
            temperature=0.3  # Lower temperature for more consistent judgments
        )
        result = extract_json_from_response(response)
        return {
            "score": result.get("score", 2),
            "reason": result.get("reason", "")
        }
    except Exception as e:
        logger.warning(f"Relevance judgment failed: {e}")
        return {"score": 2, "reason": "judgment failed"}


async def compute_relevance_metrics(
    query: str,
    ideas: List[str],
    model: str = None,
    sample_size: int = None
) -> Optional[RelevanceMetrics]:
    """Compute LLM-as-judge relevance metrics for ideas."""
    if not ideas:
        return None

    # Optionally sample to reduce API calls
    if sample_size and len(ideas) > sample_size:
        rng = random.Random(RANDOM_SEED)
        ideas_to_judge = rng.sample(ideas, sample_size)
    else:
        ideas_to_judge = ideas

    scores = []
    for idea in ideas_to_judge:
        result = await judge_relevance(query, idea, model)
        scores.append(result["score"])

    # Compute distribution
    distribution = {1: 0, 2: 0, 3: 0}
    for s in scores:
        if s in distribution:
            distribution[s] += 1

    nonsense_count = distribution[1]
    relevant_count = distribution[2] + distribution[3]

    return RelevanceMetrics(
        relevance_rate=relevant_count / len(scores) if scores else 0,
        nonsense_rate=nonsense_count / len(scores) if scores else 0,
        mean_score=float(np.mean(scores)) if scores else 0,
        score_distribution=distribution
    )


# ============================================================
# Main metrics computation
# ============================================================

async def compute_condition_metrics(
    query: str,
    condition: str,
    raw_ideas: List[str],
    unique_ideas: List[str],
    normalized_sample_size: int,
    compute_relevance: bool = False
) -> ConditionMetrics:
    """Compute all metrics for a single condition."""
    raw_count = len(raw_ideas)
    unique_count = len(unique_ideas)
    survival_rate = unique_count / raw_count if raw_count > 0 else 1.0

    logger.info(f"  Computing metrics for {condition}...")
    logger.info(f"    Raw: {raw_count}, Unique: {unique_count}, Survival: {survival_rate:.1%}")

    # Pre-dedup diversity (on raw ideas)
    logger.info(f"    Computing pre-dedup diversity...")
    pre_dedup_diversity = await compute_diversity_metrics(raw_ideas)

    # Post-dedup diversity (on unique ideas)
    logger.info(f"    Computing post-dedup diversity...")
    post_dedup_diversity = await compute_diversity_metrics(unique_ideas)

    # Cluster analysis (post-dedup)
    logger.info(f"    Computing cluster metrics...")
    post_dedup_clusters = await compute_cluster_metrics(unique_ideas)

    # Query distance (post-dedup)
    logger.info(f"    Computing query distance...")
    post_dedup_query_distance = await compute_query_distance_metrics(query, unique_ideas)

    # Normalized diversity (equal-sized sample for fair comparison)
    normalized_diversity = None
    if len(unique_ideas) >= normalized_sample_size and normalized_sample_size > 1:
        logger.info(f"    Computing normalized diversity (n={normalized_sample_size})...")
        rng = random.Random(RANDOM_SEED)
        sampled_ideas = rng.sample(unique_ideas, normalized_sample_size)
        normalized_diversity = await compute_diversity_metrics(sampled_ideas)

    # Relevance metrics (optional, expensive)
    relevance = None
    if compute_relevance and unique_ideas:
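        # LLM-as-judge is the expensive step: one generate() call per judged idea.
        # Only a small sample is scored, and a failed judgment falls back to a
        # neutral score of 2 (see judge_relevance), which can inflate relevance_rate.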
        logger.info(f"    Computing relevance metrics (LLM-as-judge)...")
        # Sample up to 10 ideas to reduce cost
        relevance = await compute_relevance_metrics(
            query, unique_ideas,
            sample_size=min(10, len(unique_ideas))
        )

    return ConditionMetrics(
        condition=condition,
        query=query,
        raw_count=raw_count,
        unique_count=unique_count,
        survival_rate=survival_rate,
        pre_dedup_diversity=pre_dedup_diversity,
        post_dedup_diversity=post_dedup_diversity,
        post_dedup_clusters=post_dedup_clusters,
        post_dedup_query_distance=post_dedup_query_distance,
        normalized_diversity=normalized_diversity,
        normalized_sample_size=normalized_sample_size,
        relevance=relevance
    )


async def process_experiment_results(
    input_file: Path,
    output_file: Optional[Path] = None,
    compute_relevance: bool = False
) -> Dict[str, Any]:
    """
    Process experiment results and compute all metrics.

    Args:
        input_file: Path to deduped experiment results JSON
        output_file: Path for output (default: input with _metrics suffix)
        compute_relevance: Whether to compute LLM-as-judge relevance

    Returns:
        Results with computed metrics
    """
    # Load experiment results
    with open(input_file, "r", encoding="utf-8") as f:
        experiment = json.load(f)

    logger.info(f"Processing experiment: {experiment.get('experiment_id', 'unknown')}")

    # Determine normalized sample size (minimum unique count across all conditions)
    min_unique_count = float('inf')
    for query_result in experiment["results"]:
        for condition, cond_result in query_result["conditions"].items():
            if cond_result.get("success", False):
                dedup = cond_result.get("dedup", {})
                unique_count = len(dedup.get("unique_ideas", cond_result.get("ideas", [])))
                if unique_count > 0:
                    min_unique_count = min(min_unique_count, unique_count)

    normalized_sample_size = min(int(min_unique_count), 10) if min_unique_count != float('inf') else 5
    logger.info(f"Normalized sample size: {normalized_sample_size}")

    # Process each query
    all_metrics = []
    for query_result in experiment["results"]:
        query = query_result["query"]
        query_id = query_result["query_id"]

        logger.info(f"\nProcessing query: {query} ({query_id})")

        query_metrics = {
            "query_id": query_id,
            "query": query,
            "conditions": {}
        }

        for condition, cond_result in query_result["conditions"].items():
            if not cond_result.get("success", False):
                logger.warning(f"  Skipping failed condition: {condition}")
                continue

            # Get raw and unique ideas
            raw_ideas = cond_result.get("ideas", [])
            dedup = cond_result.get("dedup", {})
            unique_ideas = dedup.get("unique_ideas", raw_ideas)

            # Compute metrics
            metrics = await compute_condition_metrics(
                query=query,
                condition=condition,
                raw_ideas=raw_ideas,
                unique_ideas=unique_ideas,
                normalized_sample_size=normalized_sample_size,
                compute_relevance=compute_relevance
            )

            # Convert to dict for JSON serialization
            query_metrics["conditions"][condition] = asdict(metrics)

        all_metrics.append(query_metrics)

    # Calculate aggregate statistics
    aggregate = calculate_aggregate_metrics(all_metrics)

    # Build output
    output = {
        "experiment_id": experiment.get("experiment_id"),
        "config": experiment.get("config"),
        "normalized_sample_size": normalized_sample_size,
        "metrics_by_query": all_metrics,
        "aggregate": aggregate
    }

    # Save results
    if output_file is None:
        stem = input_file.stem.replace("_deduped", "").replace("_complete", "")
        output_file = input_file.parent / f"{stem}_metrics.json"

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)

    logger.info(f"\nMetrics saved to: {output_file}")
    return output


def calculate_aggregate_metrics(all_metrics: List[Dict]) -> Dict[str, Any]:
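    # Output shape per condition: {metric_name: {"mean", "std", "min", "max", "n"}};
    # a metric key is omitted when no query contributed a value for it.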
"""Calculate aggregate statistics across all queries.""" aggregate = {} # Collect metrics by condition by_condition = {} for query_metrics in all_metrics: for condition, metrics in query_metrics["conditions"].items(): if condition not in by_condition: by_condition[condition] = { "raw_counts": [], "unique_counts": [], "survival_rates": [], "pre_dedup_diversity": [], "post_dedup_diversity": [], "normalized_diversity": [], "query_distances": [], "cluster_counts": [], "silhouette_scores": [], "relevance_rates": [], "nonsense_rates": [] } bc = by_condition[condition] bc["raw_counts"].append(metrics["raw_count"]) bc["unique_counts"].append(metrics["unique_count"]) bc["survival_rates"].append(metrics["survival_rate"]) if metrics.get("pre_dedup_diversity"): bc["pre_dedup_diversity"].append( metrics["pre_dedup_diversity"]["mean_pairwise_distance"] ) if metrics.get("post_dedup_diversity"): bc["post_dedup_diversity"].append( metrics["post_dedup_diversity"]["mean_pairwise_distance"] ) if metrics.get("normalized_diversity"): bc["normalized_diversity"].append( metrics["normalized_diversity"]["mean_pairwise_distance"] ) if metrics.get("post_dedup_query_distance"): bc["query_distances"].append( metrics["post_dedup_query_distance"]["mean_distance"] ) if metrics.get("post_dedup_clusters"): bc["cluster_counts"].append( metrics["post_dedup_clusters"]["optimal_clusters"] ) bc["silhouette_scores"].append( metrics["post_dedup_clusters"]["silhouette_score"] ) if metrics.get("relevance"): bc["relevance_rates"].append(metrics["relevance"]["relevance_rate"]) bc["nonsense_rates"].append(metrics["relevance"]["nonsense_rate"]) # Calculate means and stds for condition, data in by_condition.items(): aggregate[condition] = {} for metric_name, values in data.items(): if values: aggregate[condition][metric_name] = { "mean": float(np.mean(values)), "std": float(np.std(values)), "min": float(np.min(values)), "max": float(np.max(values)), "n": len(values) } return aggregate def print_metrics_summary(metrics: Dict[str, Any]): """Print a formatted summary of computed metrics.""" print("\n" + "=" * 80) print("METRICS SUMMARY") print("=" * 80) print(f"\nNormalized sample size: {metrics.get('normalized_sample_size', 'N/A')}") aggregate = metrics.get("aggregate", {}) # Idea counts print("\n--- Idea Counts ---") print(f"{'Condition':<25} {'Raw':<10} {'Unique':<10} {'Survival':<10}") print("-" * 55) for cond, data in aggregate.items(): raw = data.get("raw_counts", {}).get("mean", 0) unique = data.get("unique_counts", {}).get("mean", 0) survival = data.get("survival_rates", {}).get("mean", 0) print(f"{cond:<25} {raw:<10.1f} {unique:<10.1f} {survival:<10.1%}") # Diversity metrics print("\n--- Semantic Diversity (Mean Pairwise Distance) ---") print(f"{'Condition':<25} {'Pre-Dedup':<12} {'Post-Dedup':<12} {'Normalized':<12}") print("-" * 61) for cond, data in aggregate.items(): pre = data.get("pre_dedup_diversity", {}).get("mean", 0) post = data.get("post_dedup_diversity", {}).get("mean", 0) norm = data.get("normalized_diversity", {}).get("mean", 0) print(f"{cond:<25} {pre:<12.4f} {post:<12.4f} {norm:<12.4f}") # Query distance print("\n--- Query Distance (Novelty) ---") print(f"{'Condition':<25} {'Mean Distance':<15} {'Std':<10}") print("-" * 50) for cond, data in aggregate.items(): dist = data.get("query_distances", {}) mean = dist.get("mean", 0) std = dist.get("std", 0) print(f"{cond:<25} {mean:<15.4f} {std:<10.4f}") # Cluster metrics print("\n--- Cluster Analysis ---") print(f"{'Condition':<25} {'Clusters':<12} {'Silhouette':<12}") 
    print("-" * 49)
    for cond, data in aggregate.items():
        clusters = data.get("cluster_counts", {}).get("mean", 0)
        silhouette = data.get("silhouette_scores", {}).get("mean", 0)
        print(f"{cond:<25} {clusters:<12.1f} {silhouette:<12.4f}")

    # Relevance (if computed)
    has_relevance = any(
        "relevance_rates" in data and data["relevance_rates"].get("n", 0) > 0
        for data in aggregate.values()
    )
    if has_relevance:
        print("\n--- Relevance (LLM-as-Judge) ---")
        print(f"{'Condition':<25} {'Relevance':<12} {'Nonsense':<12}")
        print("-" * 49)
        for cond, data in aggregate.items():
            rel = data.get("relevance_rates", {}).get("mean", 0)
            non = data.get("nonsense_rates", {}).get("mean", 0)
            print(f"{cond:<25} {rel:<12.1%} {non:<12.1%}")

    print("\n" + "=" * 80)
    print("Interpretation:")
    print("- Higher pairwise distance = more diverse ideas")
    print("- Higher query distance = more novel (farther from original)")
    print("- More clusters = more distinct themes")
    print("- Higher silhouette = cleaner cluster separation")
    print("=" * 80)


async def main():
    parser = argparse.ArgumentParser(
        description="Compute metrics for experiment results"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input deduped experiment results JSON file"
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Output file path (default: <input stem>_metrics.json)"
    )
    parser.add_argument(
        "--relevance",
        action="store_true",
        help="Compute LLM-as-judge relevance metrics (expensive)"
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        input_path = RESULTS_DIR / args.input
    if not input_path.exists():
        print(f"Error: Input file not found: {args.input}")
        sys.exit(1)

    output_path = Path(args.output) if args.output else None

    metrics = await process_experiment_results(
        input_file=input_path,
        output_file=output_path,
        compute_relevance=args.relevance
    )

    print_metrics_summary(metrics)


if __name__ == "__main__":
    asyncio.run(main())
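

# ---------------------------------------------------------------------------
# Reference sketch of the expected input file, inferred from the access
# patterns in process_experiment_results(). This is illustrative, not a
# schema: the values are made up, and real files produced by the dedup step
# may carry additional keys.
#
# {
#   "experiment_id": "experiment_xxx",
#   "config": {...},
#   "results": [
#     {
#       "query_id": "q1",
#       "query": "original brainstorming query",
#       "conditions": {
#         "some_condition": {
#           "success": true,
#           "ideas": ["raw idea 1", "raw idea 2"],
#           "dedup": {"unique_ideas": ["unique idea 1"]}
#         }
#       }
#     }
#   ]
# }
# ---------------------------------------------------------------------------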