""" Visualization for experiment results. Generates: - Box plots of diversity by condition - 2×2 interaction plots - Bar charts of survival rates - t-SNE/UMAP of idea embeddings (optional) Usage: python -m experiments.visualize --input results/experiment_xxx_metrics.json """ import sys import json import argparse from pathlib import Path from typing import List, Dict, Any, Optional import numpy as np # Add experiments to path sys.path.insert(0, str(Path(__file__).parent.parent)) from experiments.config import RESULTS_DIR # Try to import visualization libraries try: import matplotlib.pyplot as plt import matplotlib.patches as mpatches MATPLOTLIB_AVAILABLE = True except ImportError: MATPLOTLIB_AVAILABLE = False print("Warning: matplotlib not installed. Visualization unavailable.") print("Install with: pip install matplotlib") # Condition display names and colors CONDITION_LABELS = { "c1_direct": "C1: Direct", "c2_expert_only": "C2: Expert-Only", "c3_attribute_only": "C3: Attr-Only", "c4_full_pipeline": "C4: Full Pipeline", "c5_random_perspective": "C5: Random" } CONDITION_COLORS = { "c1_direct": "#808080", # Gray (baseline) "c2_expert_only": "#2196F3", # Blue "c3_attribute_only": "#FF9800", # Orange "c4_full_pipeline": "#4CAF50", # Green (main) "c5_random_perspective": "#9C27B0" # Purple (control) } # 2×2 factorial structure FACTORIAL_2X2 = { "no_attr_no_expert": "c1_direct", "no_attr_with_expert": "c2_expert_only", "with_attr_no_expert": "c3_attribute_only", "with_attr_with_expert": "c4_full_pipeline" } def extract_metric_values( metrics: Dict[str, Any], metric_path: str ) -> Dict[str, List[float]]: """Extract values for a specific metric across all queries.""" by_condition = {} for query_metrics in metrics.get("metrics_by_query", []): for condition, cond_metrics in query_metrics.get("conditions", {}).items(): if condition not in by_condition: by_condition[condition] = [] value = cond_metrics for key in metric_path.split("."): if value is None: break if isinstance(value, dict): value = value.get(key) else: value = None if value is not None and isinstance(value, (int, float)): by_condition[condition].append(float(value)) return by_condition def plot_box_comparison( metrics: Dict[str, Any], metric_path: str, title: str, ylabel: str, output_path: Path, figsize: tuple = (10, 6) ): """Create box plot comparing conditions.""" if not MATPLOTLIB_AVAILABLE: return by_condition = extract_metric_values(metrics, metric_path) # Order conditions ordered_conditions = [ "c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline", "c5_random_perspective" ] conditions = [c for c in ordered_conditions if c in by_condition] if not conditions: print(f"No data for {metric_path}") return fig, ax = plt.subplots(figsize=figsize) # Prepare data data = [by_condition[c] for c in conditions] labels = [CONDITION_LABELS.get(c, c) for c in conditions] colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions] # Create box plot bp = ax.boxplot(data, labels=labels, patch_artist=True) # Color boxes for patch, color in zip(bp['boxes'], colors): patch.set_facecolor(color) patch.set_alpha(0.7) # Add individual points for i, (cond, values) in enumerate(zip(conditions, data)): x = np.random.normal(i + 1, 0.04, size=len(values)) ax.scatter(x, values, alpha=0.6, color=colors[i], edgecolor='black', s=50) ax.set_ylabel(ylabel) ax.set_title(title) ax.grid(axis='y', alpha=0.3) # Rotate labels if needed plt.xticks(rotation=15, ha='right') plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() print(f"Saved: {output_path}") def plot_interaction_2x2( metrics: Dict[str, Any], metric_path: str, title: str, ylabel: str, output_path: Path, figsize: tuple = (8, 6) ): """Create 2×2 factorial interaction plot.""" if not MATPLOTLIB_AVAILABLE: return by_condition = extract_metric_values(metrics, metric_path) # Check if all 2×2 conditions available required = ["c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline"] if not all(c in by_condition and by_condition[c] for c in required): print(f"Insufficient data for 2×2 plot of {metric_path}") return fig, ax = plt.subplots(figsize=figsize) # Calculate means means = {c: np.mean(by_condition[c]) for c in required} stds = {c: np.std(by_condition[c], ddof=1) if len(by_condition[c]) > 1 else 0 for c in required} # X positions: No Experts, With Experts x = [0, 1] x_labels = ["Without Experts", "With Experts"] # Line 1: Without Attributes (C1 -> C2) y_no_attr = [means["c1_direct"], means["c2_expert_only"]] err_no_attr = [stds["c1_direct"], stds["c2_expert_only"]] ax.errorbar(x, y_no_attr, yerr=err_no_attr, marker='o', markersize=10, linewidth=2, capsize=5, label="Without Attributes", color="#FF9800", linestyle='--') # Line 2: With Attributes (C3 -> C4) y_with_attr = [means["c3_attribute_only"], means["c4_full_pipeline"]] err_with_attr = [stds["c3_attribute_only"], stds["c4_full_pipeline"]] ax.errorbar(x, y_with_attr, yerr=err_with_attr, marker='s', markersize=10, linewidth=2, capsize=5, label="With Attributes", color="#4CAF50", linestyle='-') # Annotate points offset = 0.02 * (ax.get_ylim()[1] - ax.get_ylim()[0]) if ax.get_ylim()[1] != ax.get_ylim()[0] else 0.01 ax.annotate("C1", (x[0], y_no_attr[0]), textcoords="offset points", xytext=(-15, -15), fontsize=9) ax.annotate("C2", (x[1], y_no_attr[1]), textcoords="offset points", xytext=(5, -15), fontsize=9) ax.annotate("C3", (x[0], y_with_attr[0]), textcoords="offset points", xytext=(-15, 10), fontsize=9) ax.annotate("C4", (x[1], y_with_attr[1]), textcoords="offset points", xytext=(5, 10), fontsize=9) ax.set_xticks(x) ax.set_xticklabels(x_labels) ax.set_ylabel(ylabel) ax.set_title(title) ax.legend(loc='best') ax.grid(axis='y', alpha=0.3) # Check for interaction (non-parallel lines) slope_no_attr = y_no_attr[1] - y_no_attr[0] slope_with_attr = y_with_attr[1] - y_with_attr[0] interaction = slope_with_attr - slope_no_attr interaction_text = f"Interaction: {interaction:+.4f}" if interaction > 0.01: interaction_text += " (super-additive)" elif interaction < -0.01: interaction_text += " (sub-additive)" else: interaction_text += " (additive)" ax.text(0.02, 0.98, interaction_text, transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5)) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() print(f"Saved: {output_path}") def plot_survival_rates( metrics: Dict[str, Any], output_path: Path, figsize: tuple = (10, 6) ): """Create bar chart of deduplication survival rates.""" if not MATPLOTLIB_AVAILABLE: return by_condition = extract_metric_values(metrics, "survival_rate") ordered_conditions = [ "c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline", "c5_random_perspective" ] conditions = [c for c in ordered_conditions if c in by_condition] if not conditions: print("No survival rate data") return fig, ax = plt.subplots(figsize=figsize) # Calculate means and stds means = [np.mean(by_condition[c]) * 100 for c in conditions] # Convert to percentage stds = [np.std(by_condition[c], ddof=1) * 100 if len(by_condition[c]) > 1 else 0 for c in conditions] labels = [CONDITION_LABELS.get(c, c) for c in conditions] colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions] x = np.arange(len(conditions)) bars = ax.bar(x, means, yerr=stds, capsize=5, color=colors, alpha=0.8, edgecolor='black') # Add value labels on bars for bar, mean in zip(bars, means): height = bar.get_height() ax.annotate(f'{mean:.1f}%', xy=(bar.get_x() + bar.get_width() / 2, height), xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=10) ax.set_xticks(x) ax.set_xticklabels(labels, rotation=15, ha='right') ax.set_ylabel("Survival Rate (%)") ax.set_title("Deduplication Survival Rate by Condition\n(Higher = More Diverse Generation)") ax.set_ylim(0, 110) ax.grid(axis='y', alpha=0.3) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() print(f"Saved: {output_path}") def plot_idea_counts( metrics: Dict[str, Any], output_path: Path, figsize: tuple = (10, 6) ): """Create stacked bar chart of raw vs unique idea counts.""" if not MATPLOTLIB_AVAILABLE: return raw_counts = extract_metric_values(metrics, "raw_count") unique_counts = extract_metric_values(metrics, "unique_count") ordered_conditions = [ "c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline", "c5_random_perspective" ] conditions = [c for c in ordered_conditions if c in raw_counts and c in unique_counts] if not conditions: print("No count data") return fig, ax = plt.subplots(figsize=figsize) # Calculate means raw_means = [np.mean(raw_counts[c]) for c in conditions] unique_means = [np.mean(unique_counts[c]) for c in conditions] removed_means = [r - u for r, u in zip(raw_means, unique_means)] labels = [CONDITION_LABELS.get(c, c) for c in conditions] x = np.arange(len(conditions)) width = 0.6 # Stacked bars: unique (bottom) + removed (top) bars1 = ax.bar(x, unique_means, width, label='Unique Ideas', color=[CONDITION_COLORS.get(c, "#888888") for c in conditions], alpha=0.9) bars2 = ax.bar(x, removed_means, width, bottom=unique_means, label='Duplicates Removed', color='lightgray', alpha=0.7, hatch='//') # Add value labels for i, (unique, raw) in enumerate(zip(unique_means, raw_means)): ax.annotate(f'{unique:.0f}', xy=(x[i], unique / 2), ha='center', va='center', fontsize=10, fontweight='bold') ax.annotate(f'({raw:.0f})', xy=(x[i], raw + 1), ha='center', va='bottom', fontsize=9, color='gray') ax.set_xticks(x) ax.set_xticklabels(labels, rotation=15, ha='right') ax.set_ylabel("Number of Ideas") ax.set_title("Idea Counts by Condition\n(Unique ideas shown, raw total in parentheses)") ax.legend(loc='upper right') ax.grid(axis='y', alpha=0.3) plt.tight_layout() plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() print(f"Saved: {output_path}") def plot_metrics_comparison( metrics: Dict[str, Any], output_path: Path, figsize: tuple = (12, 8) ): """Create multi-panel comparison of key metrics.""" if not MATPLOTLIB_AVAILABLE: return fig, axes = plt.subplots(2, 2, figsize=figsize) # Extract metrics metrics_to_plot = [ ("survival_rate", "Survival Rate", axes[0, 0], True), ("post_dedup_diversity.mean_pairwise_distance", "Semantic Diversity", axes[0, 1], False), ("post_dedup_query_distance.mean_distance", "Query Distance (Novelty)", axes[1, 0], False), ("post_dedup_clusters.optimal_clusters", "Number of Clusters", axes[1, 1], False), ] ordered_conditions = [ "c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline", "c5_random_perspective" ] for metric_path, title, ax, is_percentage in metrics_to_plot: by_condition = extract_metric_values(metrics, metric_path) conditions = [c for c in ordered_conditions if c in by_condition and by_condition[c]] if not conditions: ax.text(0.5, 0.5, "No data", ha='center', va='center', transform=ax.transAxes) ax.set_title(title) continue means = [np.mean(by_condition[c]) for c in conditions] if is_percentage: means = [m * 100 for m in means] colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions] x = np.arange(len(conditions)) bars = ax.bar(x, means, color=colors, alpha=0.8, edgecolor='black') # Simplified labels short_labels = ["C1", "C2", "C3", "C4", "C5"][:len(conditions)] ax.set_xticks(x) ax.set_xticklabels(short_labels) ax.set_title(title) ax.grid(axis='y', alpha=0.3) if is_percentage: ax.set_ylim(0, 110) # Add legend legend_elements = [ mpatches.Patch(facecolor=CONDITION_COLORS[c], label=CONDITION_LABELS[c]) for c in ordered_conditions if c in CONDITION_COLORS ] fig.legend(handles=legend_elements, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.02)) plt.tight_layout() plt.subplots_adjust(bottom=0.15) plt.savefig(output_path, dpi=150, bbox_inches='tight') plt.close() print(f"Saved: {output_path}") def generate_all_visualizations( metrics: Dict[str, Any], output_dir: Path ): """Generate all visualization figures.""" if not MATPLOTLIB_AVAILABLE: print("matplotlib not available. Cannot generate visualizations.") return output_dir.mkdir(parents=True, exist_ok=True) experiment_id = metrics.get("experiment_id", "experiment") print(f"\nGenerating visualizations for {experiment_id}...") # 1. Survival rates bar chart plot_survival_rates( metrics, output_dir / f"{experiment_id}_survival_rates.png" ) # 2. Idea counts stacked bar plot_idea_counts( metrics, output_dir / f"{experiment_id}_idea_counts.png" ) # 3. Diversity box plot plot_box_comparison( metrics, "post_dedup_diversity.mean_pairwise_distance", "Semantic Diversity by Condition (Post-Dedup)", "Mean Pairwise Distance", output_dir / f"{experiment_id}_diversity_boxplot.png" ) # 4. Query distance box plot plot_box_comparison( metrics, "post_dedup_query_distance.mean_distance", "Query Distance by Condition (Novelty)", "Distance from Original Query", output_dir / f"{experiment_id}_query_distance_boxplot.png" ) # 5. 2×2 interaction plot for diversity plot_interaction_2x2( metrics, "post_dedup_diversity.mean_pairwise_distance", "2×2 Factorial: Semantic Diversity", "Mean Pairwise Distance", output_dir / f"{experiment_id}_interaction_diversity.png" ) # 6. 2×2 interaction plot for query distance plot_interaction_2x2( metrics, "post_dedup_query_distance.mean_distance", "2×2 Factorial: Query Distance (Novelty)", "Distance from Original Query", output_dir / f"{experiment_id}_interaction_novelty.png" ) # 7. Multi-panel comparison plot_metrics_comparison( metrics, output_dir / f"{experiment_id}_metrics_comparison.png" ) print(f"\nAll visualizations saved to: {output_dir}") def main(): parser = argparse.ArgumentParser( description="Generate visualizations for experiment results" ) parser.add_argument( "--input", type=str, required=True, help="Input metrics JSON file" ) parser.add_argument( "--output-dir", type=str, help="Output directory for figures (default: results/figures/)" ) args = parser.parse_args() input_path = Path(args.input) if not input_path.exists(): input_path = RESULTS_DIR / args.input if not input_path.exists(): print(f"Error: Input file not found: {args.input}") sys.exit(1) # Load metrics with open(input_path, "r", encoding="utf-8") as f: metrics = json.load(f) # Output directory if args.output_dir: output_dir = Path(args.output_dir) else: output_dir = RESULTS_DIR / "figures" generate_all_visualizations(metrics, output_dir) if __name__ == "__main__": main()