- Add complete experiments directory with pilot study infrastructure - 5 experimental conditions (direct, expert-only, attribute-only, full-pipeline, random-perspective) - Human assessment tool with React frontend and FastAPI backend - AUT flexibility analysis with jump signal detection - Result visualization and metrics computation - Add novelty-driven agent loop module (experiments/novelty_loop/) - NoveltyDrivenTaskAgent with expert perspective perturbation - Three termination strategies: breakthrough, exhaust, coverage - Interactive CLI demo with colored output - Embedding-based novelty scoring - Add DDC knowledge domain classification data (en/zh) - Add CLAUDE.md project documentation - Update research report with experiment findings Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
522 lines
16 KiB
Python
"""
|
||
Visualization for experiment results.
|
||
|
||
Generates:
|
||
- Box plots of diversity by condition
|
||
- 2×2 interaction plots
|
||
- Bar charts of survival rates
|
||
- t-SNE/UMAP of idea embeddings (optional)
|
||
|
||
Usage:
|
||
python -m experiments.visualize --input results/experiment_xxx_metrics.json
|
||
"""
|
||
|
||
import sys
|
||
import json
|
||
import argparse
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any, Optional
|
||
|
||
import numpy as np
|
||
|
||
# Put the directory two levels above this file on sys.path (presumably the one
# containing the `experiments` package) so the absolute import below resolves
# when this file is run directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from experiments.config import RESULTS_DIR

# matplotlib is optional: when it is missing, MATPLOTLIB_AVAILABLE is False and
# the plotting functions in this module return without drawing anything.
try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
    MATPLOTLIB_AVAILABLE = True
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    # Diagnostics go to stderr so they are not mixed into normal output.
    print("Warning: matplotlib not installed. Visualization unavailable.", file=sys.stderr)
    print("Install with: pip install matplotlib", file=sys.stderr)
|
||
|
||
# Legend / tick label shown for each condition id.
CONDITION_LABELS = dict(
    c1_direct="C1: Direct",
    c2_expert_only="C2: Expert-Only",
    c3_attribute_only="C3: Attr-Only",
    c4_full_pipeline="C4: Full Pipeline",
    c5_random_perspective="C5: Random",
)

# Hex RGB fill colour used for each condition across all figures.
CONDITION_COLORS = dict(
    c1_direct="#808080",              # gray - baseline
    c2_expert_only="#2196F3",         # blue
    c3_attribute_only="#FF9800",      # orange
    c4_full_pipeline="#4CAF50",       # green - main condition
    c5_random_perspective="#9C27B0",  # purple - control
)

# Cell of the 2×2 design (attributes × experts) -> condition id.
FACTORIAL_2X2 = dict(
    no_attr_no_expert="c1_direct",
    no_attr_with_expert="c2_expert_only",
    with_attr_no_expert="c3_attribute_only",
    with_attr_with_expert="c4_full_pipeline",
)
|
||
|
||
|
||
def extract_metric_values(
    metrics: Dict[str, Any],
    metric_path: str
) -> Dict[str, List[float]]:
    """Collect one metric's per-query values, grouped by condition.

    Args:
        metrics: Parsed metrics JSON. Reads ``metrics["metrics_by_query"]``,
            a list whose items carry a ``"conditions"`` mapping of
            condition id -> per-condition metric dict.
        metric_path: Dot-separated key path into each per-condition dict,
            e.g. ``"post_dedup_diversity.mean_pairwise_distance"``.

    Returns:
        Mapping of condition id -> list of floats, in query order. Every
        condition seen in any query gets an entry. The list may be empty when
        the metric is missing, non-numeric, boolean, or non-finite (NaN/inf)
        for all of that condition's queries.
    """
    keys = metric_path.split(".")
    by_condition: Dict[str, List[float]] = {}

    for query_metrics in metrics.get("metrics_by_query") or []:
        conditions = query_metrics.get("conditions") or {}
        for condition, cond_metrics in conditions.items():
            values = by_condition.setdefault(condition, [])

            # Walk the dotted path; a missing key or non-dict step yields None.
            value: Any = cond_metrics
            for key in keys:
                value = value.get(key) if isinstance(value, dict) else None
                if value is None:
                    break

            # bool is a subclass of int, but a JSON flag is not a measurement.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                continue

            number = float(value)
            # json.load accepts NaN/Infinity; such values would poison
            # np.mean/np.std in every plot that aggregates this list.
            if not np.isfinite(number):
                continue
            values.append(number)

    return by_condition
|
||
|
||
|
||
def plot_box_comparison(
    metrics: Dict[str, Any],
    metric_path: str,
    title: str,
    ylabel: str,
    output_path: Path,
    figsize: tuple = (10, 6),
    jitter_seed: Optional[int] = 0,
):
    """Save a box plot comparing one metric across conditions C1-C5.

    Each box summarizes the per-query values of ``metric_path`` for one
    condition. The individual values are overlaid as horizontally jittered
    points. Conditions are drawn in C1..C5 order, and conditions without any
    values are skipped. Returns silently when matplotlib is unavailable.

    Args:
        metrics: Parsed metrics JSON (as read by ``extract_metric_values``).
        metric_path: Dot-separated path of the metric to plot.
        title: Axes title.
        ylabel: Y-axis label.
        output_path: Destination image file.
        figsize: Figure size in inches.
        jitter_seed: Seed for the point jitter so repeated runs produce
            identical figures; ``None`` draws fresh entropy each call.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, metric_path)

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    # A condition can be present with an empty list (metric missing for all of
    # its queries); drawing an empty box for it would only add noise.
    conditions = [c for c in ordered_conditions if by_condition.get(c)]

    if not conditions:
        print(f"No data for {metric_path}")
        return

    fig, ax = plt.subplots(figsize=figsize)

    data = [by_condition[c] for c in conditions]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
    positions = np.arange(1, len(conditions) + 1)

    # Tick labels are set explicitly below: boxplot's ``labels=`` keyword was
    # renamed ``tick_labels`` in Matplotlib 3.9 and is deprecated.
    bp = ax.boxplot(data, positions=positions, patch_artist=True)

    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

    # Overlay raw points with a small horizontal jitter to reduce overlap.
    rng = np.random.default_rng(jitter_seed)
    for pos, values, color in zip(positions, data, colors):
        jittered_x = rng.normal(pos, 0.04, size=len(values))
        ax.scatter(jittered_x, values, alpha=0.6, color=color, edgecolor='black', s=50)

    ax.set_xticks(positions)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")
|
||
|
||
|
||
def _describe_interaction(interaction: float, threshold: float = 0.01) -> str:
    """Format an interaction contrast and classify it against ``threshold``.

    Values above ``+threshold`` are labelled super-additive, values below
    ``-threshold`` sub-additive, and anything in between additive.
    """
    text = f"Interaction: {interaction:+.4f}"
    if interaction > threshold:
        return text + " (super-additive)"
    if interaction < -threshold:
        return text + " (sub-additive)"
    return text + " (additive)"


def plot_interaction_2x2(
    metrics: Dict[str, Any],
    metric_path: str,
    title: str,
    ylabel: str,
    output_path: Path,
    figsize: tuple = (8, 6),
    interaction_threshold: float = 0.01,
):
    """Save a 2×2 factorial interaction plot (attributes × experts).

    Plots the mean of ``metric_path`` with ±1 sample standard deviation error
    bars. One line runs C1 -> C2 (without attributes) and the other C3 -> C4
    (with attributes), against "Without Experts" / "With Experts" on the
    x-axis. The contrast ``(C4 - C3) - (C2 - C1)`` is printed in the corner
    and classified using ``interaction_threshold`` (same units as the metric).
    If any of C1-C4 has no values, prints a message and returns. Returns
    silently when matplotlib is unavailable.

    Args:
        metrics: Parsed metrics JSON (as read by ``extract_metric_values``).
        metric_path: Dot-separated path of the metric to plot.
        title: Axes title.
        ylabel: Y-axis label.
        output_path: Destination image file.
        figsize: Figure size in inches.
        interaction_threshold: Magnitude above which the contrast is called
            super-/sub-additive rather than additive.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, metric_path)

    # All four cells of the 2×2 design must have at least one value.
    required = ["c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline"]
    if not all(by_condition.get(c) for c in required):
        print(f"Insufficient data for 2×2 plot of {metric_path}")
        return

    fig, ax = plt.subplots(figsize=figsize)

    means = {c: np.mean(by_condition[c]) for c in required}
    # Sample SD; a single observation has no spread, so draw no error bar.
    stds = {c: np.std(by_condition[c], ddof=1) if len(by_condition[c]) > 1 else 0 for c in required}

    x = [0, 1]
    x_labels = ["Without Experts", "With Experts"]

    # Line 1: Without Attributes (C1 -> C2)
    y_no_attr = [means["c1_direct"], means["c2_expert_only"]]
    err_no_attr = [stds["c1_direct"], stds["c2_expert_only"]]
    ax.errorbar(x, y_no_attr, yerr=err_no_attr, marker='o', markersize=10,
                linewidth=2, capsize=5, label="Without Attributes",
                color="#FF9800", linestyle='--')

    # Line 2: With Attributes (C3 -> C4)
    y_with_attr = [means["c3_attribute_only"], means["c4_full_pipeline"]]
    err_with_attr = [stds["c3_attribute_only"], stds["c4_full_pipeline"]]
    ax.errorbar(x, y_with_attr, yerr=err_with_attr, marker='s', markersize=10,
                linewidth=2, capsize=5, label="With Attributes",
                color="#4CAF50", linestyle='-')

    # Label each point with its condition (offsets are in points).
    ax.annotate("C1", (x[0], y_no_attr[0]), textcoords="offset points",
                xytext=(-15, -15), fontsize=9)
    ax.annotate("C2", (x[1], y_no_attr[1]), textcoords="offset points",
                xytext=(5, -15), fontsize=9)
    ax.annotate("C3", (x[0], y_with_attr[0]), textcoords="offset points",
                xytext=(-15, 10), fontsize=9)
    ax.annotate("C4", (x[1], y_with_attr[1]), textcoords="offset points",
                xytext=(5, 10), fontsize=9)

    ax.set_xticks(x)
    ax.set_xticklabels(x_labels)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc='best')
    ax.grid(axis='y', alpha=0.3)

    # Interaction = difference of the expert effect with vs. without
    # attributes; zero means parallel lines (purely additive effects).
    slope_no_attr = y_no_attr[1] - y_no_attr[0]
    slope_with_attr = y_with_attr[1] - y_with_attr[0]
    interaction = slope_with_attr - slope_no_attr

    ax.text(0.02, 0.98, _describe_interaction(interaction, interaction_threshold),
            transform=ax.transAxes, fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")
|
||
|
||
|
||
def plot_survival_rates(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Save a bar chart of mean deduplication survival rate per condition.

    The per-query ``survival_rate`` is treated as a fraction and multiplied by
    100 for display. Bars show the mean with ±1 sample standard deviation
    error bars and a percentage label. Conditions with no values are skipped.
    Returns silently when matplotlib is unavailable.

    Args:
        metrics: Parsed metrics JSON (as read by ``extract_metric_values``).
        output_path: Destination image file.
        figsize: Figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, "survival_rate")

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    # An empty list would make np.mean return NaN (with a RuntimeWarning) and
    # render as a missing bar labelled "nan%".
    conditions = [c for c in ordered_conditions if by_condition.get(c)]

    if not conditions:
        print("No survival rate data")
        return

    fig, ax = plt.subplots(figsize=figsize)

    # Convert fractions to percentages; sample SD needs at least two values.
    means = [np.mean(by_condition[c]) * 100 for c in conditions]
    stds = [np.std(by_condition[c], ddof=1) * 100 if len(by_condition[c]) > 1 else 0 for c in conditions]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]

    x = np.arange(len(conditions))
    bars = ax.bar(x, means, yerr=stds, capsize=5, color=colors, alpha=0.8, edgecolor='black')

    # Value label just above each bar.
    for bar, mean in zip(bars, means):
        ax.annotate(f'{mean:.1f}%',
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=10)

    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.set_ylabel("Survival Rate (%)")
    ax.set_title("Deduplication Survival Rate by Condition\n(Higher = More Diverse Generation)")
    ax.set_ylim(0, 110)
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")
|
||
|
||
|
||
def plot_idea_counts(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Save a stacked bar chart of unique vs. removed-duplicate idea counts.

    For each condition, the bottom segment is the mean ``unique_count``. The
    hatched top segment is ``mean(raw_count) - mean(unique_count)``, so the
    full bar height equals the mean raw count. Conditions lacking values for
    either count are skipped. Returns silently when matplotlib is unavailable.

    Args:
        metrics: Parsed metrics JSON (as read by ``extract_metric_values``).
        output_path: Destination image file.
        figsize: Figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    raw_counts = extract_metric_values(metrics, "raw_count")
    unique_counts = extract_metric_values(metrics, "unique_count")

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    # Both counts need at least one value; an empty list would give NaN means.
    conditions = [
        c for c in ordered_conditions
        if raw_counts.get(c) and unique_counts.get(c)
    ]

    if not conditions:
        print("No count data")
        return

    fig, ax = plt.subplots(figsize=figsize)

    raw_means = [np.mean(raw_counts[c]) for c in conditions]
    unique_means = [np.mean(unique_counts[c]) for c in conditions]
    removed_means = [r - u for r, u in zip(raw_means, unique_means)]

    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    x = np.arange(len(conditions))
    width = 0.6

    # Stacked bars: unique (bottom) + removed duplicates (top).
    ax.bar(x, unique_means, width, label='Unique Ideas',
           color=[CONDITION_COLORS.get(c, "#888888") for c in conditions], alpha=0.9)
    ax.bar(x, removed_means, width, bottom=unique_means, label='Duplicates Removed',
           color='lightgray', alpha=0.7, hatch='//')

    for xi, unique, raw in zip(x, unique_means, raw_means):
        ax.annotate(f'{unique:.0f}', xy=(xi, unique / 2),
                    ha='center', va='center', fontsize=10, fontweight='bold')
        # Offset in points (not data units) so the gap above the bar is the
        # same regardless of the y-axis scale.
        ax.annotate(f'({raw:.0f})', xy=(xi, raw),
                    xytext=(0, 3), textcoords="offset points",
                    ha='center', va='bottom', fontsize=9, color='gray')

    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=15, ha='right')
    ax.set_ylabel("Number of Ideas")
    ax.set_title("Idea Counts by Condition\n(Unique ideas shown, raw total in parentheses)")
    ax.legend(loc='upper right')
    ax.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")
|
||
|
||
|
||
def _short_condition_label(condition: str) -> str:
    """Return the short tick label for a condition id, e.g. ``"c2_expert_only"`` -> ``"C2"``."""
    return condition.split("_", 1)[0].upper()


def plot_metrics_comparison(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (12, 8)
):
    """Save a 2×2 panel of mean bar charts for four key metrics.

    The panels are survival rate (shown as a percentage), semantic diversity,
    query distance and cluster count. Each panel plots only the conditions
    that have values for that metric and shows "No data" when none do. A
    shared legend lists all five conditions. Returns silently when matplotlib
    is unavailable.

    Args:
        metrics: Parsed metrics JSON (as read by ``extract_metric_values``).
        output_path: Destination image file.
        figsize: Figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # (metric path, panel title, target axes, stored as a 0-1 fraction?)
    metrics_to_plot = [
        ("survival_rate", "Survival Rate", axes[0, 0], True),
        ("post_dedup_diversity.mean_pairwise_distance", "Semantic Diversity", axes[0, 1], False),
        ("post_dedup_query_distance.mean_distance", "Query Distance (Novelty)", axes[1, 0], False),
        ("post_dedup_clusters.optimal_clusters", "Number of Clusters", axes[1, 1], False),
    ]

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]

    for metric_path, title, ax, is_percentage in metrics_to_plot:
        by_condition = extract_metric_values(metrics, metric_path)
        conditions = [c for c in ordered_conditions if by_condition.get(c)]
        ax.set_title(title)

        if not conditions:
            ax.text(0.5, 0.5, "No data", ha='center', va='center', transform=ax.transAxes)
            continue

        means = [np.mean(by_condition[c]) for c in conditions]
        if is_percentage:
            means = [m * 100 for m in means]

        colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
        x = np.arange(len(conditions))
        ax.bar(x, means, color=colors, alpha=0.8, edgecolor='black')

        # Derive each tick label from its own condition so a missing
        # condition does not shift the rest (C1, C3 must not become C1, C2).
        ax.set_xticks(x)
        ax.set_xticklabels([_short_condition_label(c) for c in conditions])
        ax.grid(axis='y', alpha=0.3)

        if is_percentage:
            ax.set_ylim(0, 110)

    legend_elements = [
        mpatches.Patch(facecolor=CONDITION_COLORS[c], label=CONDITION_LABELS.get(c, c))
        for c in ordered_conditions if c in CONDITION_COLORS
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.02))

    fig.tight_layout()
    # Leave room at the bottom for the shared legend.
    fig.subplots_adjust(bottom=0.15)
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {output_path}")
|
||
|
||
|
||
def generate_all_visualizations(
    metrics: Dict[str, Any],
    output_dir: Path
):
    """Write every figure for one experiment into ``output_dir``.

    Files are named ``<experiment_id>_<kind>.png``, where ``experiment_id`` is
    ``metrics["experiment_id"]`` or ``"experiment"`` when absent. Creates
    ``output_dir`` (with parents) if needed. When matplotlib is not installed,
    prints a notice and writes nothing.
    """
    if not MATPLOTLIB_AVAILABLE:
        print("matplotlib not available. Cannot generate visualizations.")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    exp_id = metrics.get("experiment_id", "experiment")
    print(f"\nGenerating visualizations for {exp_id}...")

    def figure_path(kind: str) -> Path:
        return output_dir / f"{exp_id}_{kind}.png"

    diversity_metric = "post_dedup_diversity.mean_pairwise_distance"
    novelty_metric = "post_dedup_query_distance.mean_distance"

    # Aggregate counts: survival rates, then raw vs. unique idea counts.
    plot_survival_rates(metrics, figure_path("survival_rates"))
    plot_idea_counts(metrics, figure_path("idea_counts"))

    # Per-condition distributions.
    box_specs = (
        (diversity_metric, "Semantic Diversity by Condition (Post-Dedup)",
         "Mean Pairwise Distance", "diversity_boxplot"),
        (novelty_metric, "Query Distance by Condition (Novelty)",
         "Distance from Original Query", "query_distance_boxplot"),
    )
    for metric_key, title, y_label, kind in box_specs:
        plot_box_comparison(metrics, metric_key, title, y_label, figure_path(kind))

    # 2×2 factorial interaction plots.
    interaction_specs = (
        (diversity_metric, "2×2 Factorial: Semantic Diversity",
         "Mean Pairwise Distance", "interaction_diversity"),
        (novelty_metric, "2×2 Factorial: Query Distance (Novelty)",
         "Distance from Original Query", "interaction_novelty"),
    )
    for metric_key, title, y_label, kind in interaction_specs:
        plot_interaction_2x2(metrics, metric_key, title, y_label, figure_path(kind))

    # Four-panel summary.
    plot_metrics_comparison(metrics, figure_path("metrics_comparison"))

    print(f"\nAll visualizations saved to: {output_dir}")
|
||
|
||
|
||
def main():
    """Command-line entry point: load a metrics JSON file and render all figures.

    ``--input`` is tried as given, then relative to ``RESULTS_DIR``. Figures
    are written to ``--output-dir`` or, by default, ``RESULTS_DIR / "figures"``.
    Exits with status 1 (message on stderr) if the input file is missing, is
    not valid JSON, or does not contain a JSON object.
    """
    parser = argparse.ArgumentParser(
        description="Generate visualizations for experiment results"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input metrics JSON file"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Output directory for figures (default: results/figures/)"
    )
    args = parser.parse_args()

    # Accept a path as given, or one relative to the results directory.
    # is_file() (not exists()) so a directory name is rejected here rather
    # than failing later inside open().
    input_path = Path(args.input)
    if not input_path.is_file():
        input_path = RESULTS_DIR / args.input
    if not input_path.is_file():
        print(f"Error: Input file not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    try:
        with open(input_path, "r", encoding="utf-8") as f:
            metrics = json.load(f)
    except json.JSONDecodeError as err:
        print(f"Error: Invalid JSON in {input_path}: {err}", file=sys.stderr)
        sys.exit(1)

    # Every consumer calls metrics.get(...), so the top level must be an object.
    if not isinstance(metrics, dict):
        print(f"Error: Expected a JSON object in {input_path}", file=sys.stderr)
        sys.exit(1)

    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / "figures"

    generate_all_visualizations(metrics, output_dir)


if __name__ == "__main__":
    main()
|