feat: Add experiments framework and novelty-driven agent loop
- Add complete experiments directory with pilot study infrastructure
  - 5 experimental conditions (direct, expert-only, attribute-only, full-pipeline, random-perspective)
  - Human assessment tool with React frontend and FastAPI backend
  - AUT flexibility analysis with jump signal detection
  - Result visualization and metrics computation
- Add novelty-driven agent loop module (experiments/novelty_loop/)
  - NoveltyDrivenTaskAgent with expert perspective perturbation
  - Three termination strategies: breakthrough, exhaust, coverage
  - Interactive CLI demo with colored output
  - Embedding-based novelty scoring
- Add DDC knowledge domain classification data (en/zh)
- Add CLAUDE.md project documentation
- Update research report with experiment findings

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
521
experiments/visualize.py
Normal file
521
experiments/visualize.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""
|
||||
Visualization for experiment results.
|
||||
|
||||
Generates:
|
||||
- Box plots of diversity by condition
|
||||
- 2×2 interaction plots
|
||||
- Bar charts of survival rates
|
||||
- t-SNE/UMAP of idea embeddings (optional)
|
||||
|
||||
Usage:
|
||||
python -m experiments.visualize --input results/experiment_xxx_metrics.json
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Add experiments to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from experiments.config import RESULTS_DIR
|
||||
|
||||
# Try to import visualization libraries
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
MATPLOTLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MATPLOTLIB_AVAILABLE = False
|
||||
print("Warning: matplotlib not installed. Visualization unavailable.")
|
||||
print("Install with: pip install matplotlib")
|
||||
|
||||
# Human-readable tick/legend label for each experimental condition key.
CONDITION_LABELS = {
    "c1_direct": "C1: Direct",
    "c2_expert_only": "C2: Expert-Only",
    "c3_attribute_only": "C3: Attr-Only",
    "c4_full_pipeline": "C4: Full Pipeline",
    "c5_random_perspective": "C5: Random",
}

# Hex fill colour for each experimental condition key.
CONDITION_COLORS = {
    "c1_direct": "#808080",              # gray   - baseline
    "c2_expert_only": "#2196F3",         # blue
    "c3_attribute_only": "#FF9800",      # orange
    "c4_full_pipeline": "#4CAF50",       # green  - main condition
    "c5_random_perspective": "#9C27B0",  # purple - control
}

# Cells of the 2×2 (attributes × experts) factorial design mapped to the
# condition key that occupies each cell.
FACTORIAL_2X2 = {
    "no_attr_no_expert": "c1_direct",
    "no_attr_with_expert": "c2_expert_only",
    "with_attr_no_expert": "c3_attribute_only",
    "with_attr_with_expert": "c4_full_pipeline",
}
|
||||
|
||||
|
||||
def extract_metric_values(
    metrics: Dict[str, Any],
    metric_path: str
) -> Dict[str, List[float]]:
    """Collect one numeric metric per condition across all queries.

    Iterates ``metrics["metrics_by_query"]``; for each entry, every item of
    its ``"conditions"`` mapping is descended along ``metric_path``.

    Args:
        metrics: Parsed metrics document.
        metric_path: Dot-separated key path inside one condition's metrics,
            e.g. ``"post_dedup_diversity.mean_pairwise_distance"``.

    Returns:
        Mapping of condition name to the list of values found, in query
        order. A condition that appears in the data but never has a usable
        value at ``metric_path`` maps to an empty list.
    """
    by_condition: Dict[str, List[float]] = {}
    keys = metric_path.split(".")

    # ``or`` guards against an explicit JSON null as well as a missing key.
    for query_metrics in metrics.get("metrics_by_query") or []:
        for condition, cond_metrics in (query_metrics.get("conditions") or {}).items():
            values = by_condition.setdefault(condition, [])

            # Descend the path; any non-dict intermediate means "not present".
            value = cond_metrics
            for key in keys:
                if not isinstance(value, dict):
                    value = None
                    break
                value = value.get(key)

            # bool is a subclass of int, but a True/False flag is not a
            # measurement and must not be recorded as 1.0 / 0.0.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                continue

            number = float(value)
            # json.load accepts NaN/Infinity; one such value would turn every
            # downstream mean and standard deviation into NaN.
            if np.isfinite(number):
                values.append(number)

    return by_condition
|
||||
|
||||
|
||||
def plot_box_comparison(
    metrics: Dict[str, Any],
    metric_path: str,
    title: str,
    ylabel: str,
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Save a box plot of one metric with one box per condition.

    Individual observations are overlaid as jittered points. Does nothing
    when matplotlib is unavailable, and prints a notice instead of plotting
    when no condition has any value at ``metric_path``.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        metric_path: Dot-separated path of the metric to plot.
        title: Figure title.
        ylabel: Y-axis label.
        output_path: File the figure is written to (150 dpi).
        figsize: Figure size passed to ``plt.subplots``.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, metric_path)

    # Fixed display order. Conditions without any value are dropped so they
    # do not produce an empty box (same rule as plot_metrics_comparison).
    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    conditions = [c for c in ordered_conditions if by_condition.get(c)]

    if not conditions:
        print(f"No data for {metric_path}")
        return

    data = [by_condition[c] for c in conditions]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Tick labels are applied below rather than through boxplot():
        # its ``labels`` keyword is deprecated since Matplotlib 3.9 in favour
        # of ``tick_labels``, which older releases do not accept.
        bp = ax.boxplot(data, patch_artist=True)

        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        # Overlay the raw observations with a little horizontal jitter.
        # A fixed seed keeps regenerated figures identical.
        rng = np.random.default_rng(0)
        for position, (values, color) in enumerate(zip(data, colors), start=1):
            x = rng.normal(position, 0.04, size=len(values))
            ax.scatter(x, values, alpha=0.6, color=color, edgecolor='black', s=50)

        # boxplot() places the boxes at positions 1..N by default.
        ax.set_xticks(list(range(1, len(labels) + 1)))
        ax.set_xticklabels(labels, rotation=15, ha='right')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)

    print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def plot_interaction_2x2(
    metrics: Dict[str, Any],
    metric_path: str,
    title: str,
    ylabel: str,
    output_path: Path,
    figsize: tuple = (8, 6),
    interaction_threshold: float = 0.01
):
    """Save a 2×2 factorial interaction plot (attributes × experts).

    The x axis contrasts "Without Experts" and "With Experts". One line joins
    C1 -> C2 (without attributes), the other C3 -> C4 (with attributes).
    Points are condition means; error bars are sample standard deviations
    (``ddof=1``, or 0 when a condition has a single value). The interaction
    estimate ``(C4 - C3) - (C2 - C1)`` is shown in a text box.

    Does nothing when matplotlib is unavailable, and prints a notice instead
    of plotting unless all four conditions have at least one value.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        metric_path: Dot-separated path of the metric to plot.
        title: Figure title.
        ylabel: Y-axis label.
        output_path: File the figure is written to (150 dpi).
        figsize: Figure size passed to ``plt.subplots``.
        interaction_threshold: Absolute interaction size, in the metric's own
            units, above which the effect is labelled super-/sub-additive
            rather than additive.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, metric_path)

    # All four cells of the design are needed to draw both lines.
    required = ["c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline"]
    if not all(by_condition.get(c) for c in required):
        print(f"Insufficient data for 2×2 plot of {metric_path}")
        return

    means = {c: np.mean(by_condition[c]) for c in required}
    stds = {
        c: np.std(by_condition[c], ddof=1) if len(by_condition[c]) > 1 else 0
        for c in required
    }

    # X positions: No Experts, With Experts
    x = [0, 1]
    x_labels = ["Without Experts", "With Experts"]

    y_no_attr = [means["c1_direct"], means["c2_expert_only"]]
    err_no_attr = [stds["c1_direct"], stds["c2_expert_only"]]
    y_with_attr = [means["c3_attribute_only"], means["c4_full_pipeline"]]
    err_with_attr = [stds["c3_attribute_only"], stds["c4_full_pipeline"]]

    # Difference of the two slopes: zero means the lines are parallel.
    slope_no_attr = y_no_attr[1] - y_no_attr[0]
    slope_with_attr = y_with_attr[1] - y_with_attr[0]
    interaction = slope_with_attr - slope_no_attr

    interaction_text = f"Interaction: {interaction:+.4f}"
    if interaction > interaction_threshold:
        interaction_text += " (super-additive)"
    elif interaction < -interaction_threshold:
        interaction_text += " (sub-additive)"
    else:
        interaction_text += " (additive)"

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Line 1: Without Attributes (C1 -> C2)
        ax.errorbar(x, y_no_attr, yerr=err_no_attr, marker='o', markersize=10,
                    linewidth=2, capsize=5, label="Without Attributes",
                    color="#FF9800", linestyle='--')

        # Line 2: With Attributes (C3 -> C4)
        ax.errorbar(x, y_with_attr, yerr=err_with_attr, marker='s', markersize=10,
                    linewidth=2, capsize=5, label="With Attributes",
                    color="#4CAF50", linestyle='-')

        # Tag each point with its condition; offsets are in points.
        ax.annotate("C1", (x[0], y_no_attr[0]), textcoords="offset points",
                    xytext=(-15, -15), fontsize=9)
        ax.annotate("C2", (x[1], y_no_attr[1]), textcoords="offset points",
                    xytext=(5, -15), fontsize=9)
        ax.annotate("C3", (x[0], y_with_attr[0]), textcoords="offset points",
                    xytext=(-15, 10), fontsize=9)
        ax.annotate("C4", (x[1], y_with_attr[1]), textcoords="offset points",
                    xytext=(5, 10), fontsize=9)

        ax.set_xticks(x)
        ax.set_xticklabels(x_labels)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(loc='best')
        ax.grid(axis='y', alpha=0.3)

        ax.text(0.02, 0.98, interaction_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)

    print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def plot_survival_rates(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Save a bar chart of mean deduplication survival rate per condition.

    Reads the ``survival_rate`` metric, multiplies it by 100 for display as a
    percentage and draws sample standard deviations (``ddof=1``, or 0 for a
    single value) as error bars. Does nothing when matplotlib is unavailable,
    and prints a notice instead of plotting when no condition has a value.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        output_path: File the figure is written to (150 dpi).
        figsize: Figure size passed to ``plt.subplots``.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, "survival_rate")

    # Conditions present in the data but without any value are dropped:
    # np.mean([]) is NaN and would yield a blank bar labelled "nan%".
    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    conditions = [c for c in ordered_conditions if by_condition.get(c)]

    if not conditions:
        print("No survival rate data")
        return

    # Convert the rates to percentages.
    means = [np.mean(by_condition[c]) * 100 for c in conditions]
    stds = [
        np.std(by_condition[c], ddof=1) * 100 if len(by_condition[c]) > 1 else 0
        for c in conditions
    ]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]

    fig, ax = plt.subplots(figsize=figsize)
    try:
        x = np.arange(len(conditions))
        bars = ax.bar(x, means, yerr=stds, capsize=5, color=colors,
                      alpha=0.8, edgecolor='black')

        # Print the mean just above the top of each bar.
        for bar, mean in zip(bars, means):
            height = bar.get_height()
            ax.annotate(f'{mean:.1f}%',
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=10)

        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=15, ha='right')
        ax.set_ylabel("Survival Rate (%)")
        ax.set_title("Deduplication Survival Rate by Condition\n(Higher = More Diverse Generation)")
        # Headroom above 100% so labels on full bars stay inside the axes.
        ax.set_ylim(0, 110)
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)

    print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def plot_idea_counts(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Save a stacked bar chart of unique versus removed idea counts.

    For each condition the lower segment is the mean ``unique_count`` and the
    hatched upper segment is mean ``raw_count`` minus mean ``unique_count``.
    The unique mean is printed inside the bar and the raw mean, in
    parentheses, above it. Does nothing when matplotlib is unavailable, and
    prints a notice instead of plotting when no condition has both counts.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        output_path: File the figure is written to (150 dpi).
        figsize: Figure size passed to ``plt.subplots``.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    raw_counts = extract_metric_values(metrics, "raw_count")
    unique_counts = extract_metric_values(metrics, "unique_count")

    # A condition needs at least one value for BOTH counts; an empty list
    # would make np.mean return NaN and draw a blank, "nan"-labelled bar.
    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    conditions = [
        c for c in ordered_conditions
        if raw_counts.get(c) and unique_counts.get(c)
    ]

    if not conditions:
        print("No count data")
        return

    raw_means = [np.mean(raw_counts[c]) for c in conditions]
    unique_means = [np.mean(unique_counts[c]) for c in conditions]
    removed_means = [r - u for r, u in zip(raw_means, unique_means)]

    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    x = np.arange(len(conditions))
    width = 0.6

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Stacked bars: unique (bottom) + removed (top)
        ax.bar(x, unique_means, width, label='Unique Ideas',
               color=[CONDITION_COLORS.get(c, "#888888") for c in conditions], alpha=0.9)
        ax.bar(x, removed_means, width, bottom=unique_means, label='Duplicates Removed',
               color='lightgray', alpha=0.7, hatch='//')

        for i, (unique, raw) in enumerate(zip(unique_means, raw_means)):
            # Unique count centred inside the lower segment.
            ax.annotate(f'{unique:.0f}', xy=(x[i], unique / 2),
                        ha='center', va='center', fontsize=10, fontweight='bold')
            # Raw total just above the stack. The gap is given in points so it
            # looks the same whatever the magnitude of the counts.
            ax.annotate(f'({raw:.0f})', xy=(x[i], raw),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=9, color='gray')

        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=15, ha='right')
        ax.set_ylabel("Number of Ideas")
        ax.set_title("Idea Counts by Condition\n(Unique ideas shown, raw total in parentheses)")
        ax.legend(loc='upper right')
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)

    print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def plot_metrics_comparison(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (12, 8)
):
    """Save a 2×2 grid of bar charts comparing key metrics per condition.

    Panels, in reading order: survival rate (shown as a percentage), semantic
    diversity, query distance and number of clusters. Each bar is the mean of
    one condition's values; a panel without any data shows "No data". A
    shared legend of all conditions is placed below the grid. Does nothing
    when matplotlib is unavailable.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        output_path: File the figure is written to (150 dpi).
        figsize: Figure size passed to ``plt.subplots``.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    # (metric path, panel title, display as percentage)
    panels = [
        ("survival_rate", "Survival Rate", True),
        ("post_dedup_diversity.mean_pairwise_distance", "Semantic Diversity", False),
        ("post_dedup_query_distance.mean_distance", "Query Distance (Novelty)", False),
        ("post_dedup_clusters.optimal_clusters", "Number of Clusters", False),
    ]

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    try:
        # axes.flat walks the grid row by row, matching the order of panels.
        for (metric_path, title, is_percentage), ax in zip(panels, axes.flat):
            by_condition = extract_metric_values(metrics, metric_path)
            conditions = [c for c in ordered_conditions if by_condition.get(c)]

            if not conditions:
                ax.text(0.5, 0.5, "No data", ha='center', va='center',
                        transform=ax.transAxes)
                ax.set_title(title)
                continue

            means = [np.mean(by_condition[c]) for c in conditions]
            if is_percentage:
                means = [m * 100 for m in means]

            colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
            x = np.arange(len(conditions))
            ax.bar(x, means, color=colors, alpha=0.8, edgecolor='black')

            # Short tick label taken from the condition itself ("C1: Direct"
            # -> "C1"). Slicing a fixed C1..C5 list by count would attach the
            # wrong label to every bar after a missing condition.
            short_labels = [
                CONDITION_LABELS.get(c, c).split(":", 1)[0] for c in conditions
            ]
            ax.set_xticks(x)
            ax.set_xticklabels(short_labels)
            ax.set_title(title)
            ax.grid(axis='y', alpha=0.3)

            if is_percentage:
                ax.set_ylim(0, 110)

        # One legend for the whole figure, below the panels.
        legend_elements = [
            mpatches.Patch(facecolor=CONDITION_COLORS[c], label=CONDITION_LABELS[c])
            for c in ordered_conditions if c in CONDITION_COLORS
        ]
        fig.legend(handles=legend_elements, loc='lower center', ncol=3,
                   bbox_to_anchor=(0.5, -0.02))

        fig.tight_layout()
        # Reserve room at the bottom for the shared legend.
        fig.subplots_adjust(bottom=0.15)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even when saving fails.
        plt.close(fig)

    print(f"Saved: {output_path}")
|
||||
|
||||
|
||||
def generate_all_visualizations(
    metrics: Dict[str, Any],
    output_dir: Path
):
    """Render every figure for one experiment into ``output_dir``.

    Creates ``output_dir`` if needed and writes seven PNG files, each named
    ``<experiment_id>_<figure>.png``, where ``experiment_id`` comes from
    ``metrics`` (falling back to ``"experiment"``). Prints a message and
    returns without plotting when matplotlib is unavailable.

    Args:
        metrics: Parsed metrics document.
        output_dir: Directory that receives the figures.
    """
    if not MATPLOTLIB_AVAILABLE:
        print("matplotlib not available. Cannot generate visualizations.")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    experiment_id = metrics.get("experiment_id", "experiment")

    print(f"\nGenerating visualizations for {experiment_id}...")

    def figure_path(suffix: str) -> Path:
        # All figures share the "<experiment_id>_<suffix>.png" naming scheme.
        return output_dir / f"{experiment_id}_{suffix}.png"

    diversity_metric = "post_dedup_diversity.mean_pairwise_distance"
    novelty_metric = "post_dedup_query_distance.mean_distance"

    # 1-2. Bar charts: survival rates, then raw vs. unique idea counts
    plot_survival_rates(metrics, figure_path("survival_rates"))
    plot_idea_counts(metrics, figure_path("idea_counts"))

    # 3-4. Box plots: (metric path, title, y label, file suffix)
    box_plots = [
        (diversity_metric, "Semantic Diversity by Condition (Post-Dedup)",
         "Mean Pairwise Distance", "diversity_boxplot"),
        (novelty_metric, "Query Distance by Condition (Novelty)",
         "Distance from Original Query", "query_distance_boxplot"),
    ]
    for metric_path, title, ylabel, suffix in box_plots:
        plot_box_comparison(metrics, metric_path, title, ylabel, figure_path(suffix))

    # 5-6. 2×2 interaction plots: (metric path, title, y label, file suffix)
    interaction_plots = [
        (diversity_metric, "2×2 Factorial: Semantic Diversity",
         "Mean Pairwise Distance", "interaction_diversity"),
        (novelty_metric, "2×2 Factorial: Query Distance (Novelty)",
         "Distance from Original Query", "interaction_novelty"),
    ]
    for metric_path, title, ylabel, suffix in interaction_plots:
        plot_interaction_2x2(metrics, metric_path, title, ylabel, figure_path(suffix))

    # 7. Multi-panel comparison
    plot_metrics_comparison(metrics, figure_path("metrics_comparison"))

    print(f"\nAll visualizations saved to: {output_dir}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: load a metrics JSON file and plot it.

    ``--input`` is tried as given and then relative to ``RESULTS_DIR``; the
    process exits with status 1 when neither path exists. Figures go to
    ``--output-dir`` when supplied, otherwise to ``RESULTS_DIR / "figures"``.
    """
    parser = argparse.ArgumentParser(
        description="Generate visualizations for experiment results"
    )
    parser.add_argument("--input", type=str, required=True,
                        help="Input metrics JSON file")
    parser.add_argument("--output-dir", type=str,
                        help="Output directory for figures (default: results/figures/)")
    args = parser.parse_args()

    # Resolve the input: literal path first, results directory second.
    input_path = Path(args.input)
    if not input_path.exists():
        fallback = RESULTS_DIR / args.input
        if not fallback.exists():
            print(f"Error: Input file not found: {args.input}")
            sys.exit(1)
        input_path = fallback

    with open(input_path, "r", encoding="utf-8") as f:
        metrics = json.load(f)

    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / "figures"

    generate_all_visualizations(metrics, output_dir)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user