Files
novelty-seeking/experiments/visualize.py
gbanyan 43c025e060 feat: Add experiments framework and novelty-driven agent loop
- Add complete experiments directory with pilot study infrastructure
  - 5 experimental conditions (direct, expert-only, attribute-only, full-pipeline, random-perspective)
  - Human assessment tool with React frontend and FastAPI backend
  - AUT flexibility analysis with jump signal detection
  - Result visualization and metrics computation

- Add novelty-driven agent loop module (experiments/novelty_loop/)
  - NoveltyDrivenTaskAgent with expert perspective perturbation
  - Three termination strategies: breakthrough, exhaust, coverage
  - Interactive CLI demo with colored output
  - Embedding-based novelty scoring

- Add DDC knowledge domain classification data (en/zh)
- Add CLAUDE.md project documentation
- Update research report with experiment findings

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 10:16:21 +08:00

522 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Visualization for experiment results.

Generates one image per figure, named ``<experiment_id>_<figure>.png``:
- Bar chart of deduplication survival rates by condition
- Stacked bar chart of raw vs. unique idea counts by condition
- Box plots of semantic diversity and query distance by condition
- 2×2 interaction plots (attributes × experts) for diversity and query distance
- Multi-panel comparison of key metrics

Not implemented in this module: t-SNE/UMAP projections of idea embeddings.

Usage:
    python -m experiments.visualize --input results/experiment_xxx_metrics.json
"""
import sys
import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Optional
import numpy as np
# Add experiments to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from experiments.config import RESULTS_DIR
# Plotting is optional: probe for matplotlib once at import time. Every
# plotting function checks MATPLOTLIB_AVAILABLE and returns early when False.
try:
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches
except ImportError:
    MATPLOTLIB_AVAILABLE = False
    print("Warning: matplotlib not installed. Visualization unavailable.")
    print("Install with: pip install matplotlib")
else:
    MATPLOTLIB_AVAILABLE = True
# Human-readable legend / tick labels, keyed by condition id.
CONDITION_LABELS = dict(
    c1_direct="C1: Direct",
    c2_expert_only="C2: Expert-Only",
    c3_attribute_only="C3: Attr-Only",
    c4_full_pipeline="C4: Full Pipeline",
    c5_random_perspective="C5: Random",
)

# Fill color per condition id.
CONDITION_COLORS = dict(
    c1_direct="#808080",              # gray: baseline
    c2_expert_only="#2196F3",         # blue
    c3_attribute_only="#FF9800",      # orange
    c4_full_pipeline="#4CAF50",       # green: main condition
    c5_random_perspective="#9C27B0",  # purple: control
)

# 2x2 factorial design (attributes x experts) -> condition id.
# C5 (random perspective) is a control outside this design.
FACTORIAL_2X2 = dict(
    no_attr_no_expert="c1_direct",
    no_attr_with_expert="c2_expert_only",
    with_attr_no_expert="c3_attribute_only",
    with_attr_with_expert="c4_full_pipeline",
)
def extract_metric_values(
    metrics: Dict[str, Any],
    metric_path: str
) -> Dict[str, List[float]]:
    """Extract values for a specific metric across all queries.

    Walks ``metrics["metrics_by_query"][i]["conditions"][<condition>]`` and
    follows ``metric_path`` (dot-separated keys, e.g.
    ``"post_dedup_diversity.mean_pairwise_distance"``) into each condition's
    metrics dict.

    Args:
        metrics: Parsed metrics document.
        metric_path: Dot-separated key path to a numeric leaf value.

    Returns:
        Mapping of condition id -> list of finite float values, one entry per
        query where the path resolved to a real number. Every condition seen
        in any query gets a key, possibly mapped to an empty list.
    """
    by_condition: Dict[str, List[float]] = {}
    keys = metric_path.split(".")
    # `or` guards against explicit nulls in the JSON, not just missing keys.
    for query_metrics in metrics.get("metrics_by_query") or []:
        for condition, cond_metrics in (query_metrics.get("conditions") or {}).items():
            values = by_condition.setdefault(condition, [])
            value: Any = cond_metrics
            for key in keys:
                value = value.get(key) if isinstance(value, dict) else None
                if value is None:
                    break
            # bool is an int subclass: exclude it so flags are never averaged.
            # NaN/inf (which the json module accepts) would poison means/boxes.
            if (
                isinstance(value, (int, float))
                and not isinstance(value, bool)
                and np.isfinite(value)
            ):
                values.append(float(value))
    return by_condition
def plot_box_comparison(
    metrics: Dict[str, Any],
    metric_path: str,
    title: str,
    ylabel: str,
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Create a box plot comparing one metric across conditions.

    Draws one box per condition (C1..C5 order, conditions without data are
    skipped) and overlays the individual per-query values as jittered points.
    The figure is saved to ``output_path``; nothing is done when matplotlib
    is unavailable.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        metric_path: Dot-separated path to the metric inside each condition.
        title: Figure title.
        ylabel: Y-axis label.
        output_path: Where the figure is saved.
        figsize: Matplotlib figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, metric_path)

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    # Require a non-empty value list: an empty box produces numpy warnings
    # and a meaningless tick.
    conditions = [c for c in ordered_conditions if by_condition.get(c)]
    if not conditions:
        print(f"No data for {metric_path}")
        return

    data = [by_condition[c] for c in conditions]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
    positions = list(range(1, len(conditions) + 1))
    # Fixed seed so the point jitter (and thus the image) is reproducible.
    rng = np.random.default_rng(0)

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Tick labels are set explicitly below: boxplot's `labels=` keyword
        # was renamed to `tick_labels` in Matplotlib 3.9 and is being removed.
        bp = ax.boxplot(data, positions=positions, patch_artist=True)
        for patch, color in zip(bp["boxes"], colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)

        # Individual observations, jittered horizontally around each box.
        for pos, values, color in zip(positions, data, colors):
            jitter_x = rng.normal(pos, 0.04, size=len(values))
            ax.scatter(jitter_x, values, alpha=0.6, color=color,
                       edgecolor="black", s=50)

        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=15, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(axis="y", alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)
    print(f"Saved: {output_path}")
def plot_interaction_2x2(
    metrics: Dict[str, Any],
    metric_path: str,
    title: str,
    ylabel: str,
    output_path: Path,
    figsize: tuple = (8, 6),
    *,
    additive_tolerance: float = 0.01
):
    """Create a 2×2 factorial interaction plot (attributes × experts).

    Plots condition means (with sample-std error bars) as two lines:
    C1 -> C2 ("Without Attributes") and C3 -> C4 ("With Attributes"), with
    experts on the x axis. The interaction term
    ``(C4 - C3) - (C2 - C1)`` is printed on the plot and classified as
    super-additive / sub-additive / additive.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        metric_path: Dot-separated path to the metric inside each condition.
        title: Figure title.
        ylabel: Y-axis label.
        output_path: Where the figure is saved.
        figsize: Matplotlib figure size in inches.
        additive_tolerance: Interactions with absolute value up to this
            threshold are labeled "additive". It is in the metric's own
            units, so tune it per metric scale.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, metric_path)

    # All four factorial cells need at least one value.
    required = ["c1_direct", "c2_expert_only", "c3_attribute_only", "c4_full_pipeline"]
    if not all(by_condition.get(c) for c in required):
        print(f"Insufficient data for 2×2 plot of {metric_path}")
        return

    means = {c: np.mean(by_condition[c]) for c in required}
    # Sample std (ddof=1) needs >= 2 observations; otherwise draw no error bar.
    stds = {
        c: np.std(by_condition[c], ddof=1) if len(by_condition[c]) > 1 else 0
        for c in required
    }

    x = [0, 1]
    x_labels = ["Without Experts", "With Experts"]
    y_no_attr = [means["c1_direct"], means["c2_expert_only"]]
    err_no_attr = [stds["c1_direct"], stds["c2_expert_only"]]
    y_with_attr = [means["c3_attribute_only"], means["c4_full_pipeline"]]
    err_with_attr = [stds["c3_attribute_only"], stds["c4_full_pipeline"]]

    # Difference of differences: non-parallel lines indicate an interaction.
    slope_no_attr = y_no_attr[1] - y_no_attr[0]
    slope_with_attr = y_with_attr[1] - y_with_attr[0]
    interaction = slope_with_attr - slope_no_attr
    if interaction > additive_tolerance:
        kind = "super-additive"
    elif interaction < -additive_tolerance:
        kind = "sub-additive"
    else:
        kind = "additive"
    interaction_text = f"Interaction: {interaction:+.4f} ({kind})"

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Line 1: Without Attributes (C1 -> C2)
        ax.errorbar(x, y_no_attr, yerr=err_no_attr, marker='o', markersize=10,
                    linewidth=2, capsize=5, label="Without Attributes",
                    color="#FF9800", linestyle='--')
        # Line 2: With Attributes (C3 -> C4)
        ax.errorbar(x, y_with_attr, yerr=err_with_attr, marker='s', markersize=10,
                    linewidth=2, capsize=5, label="With Attributes",
                    color="#4CAF50", linestyle='-')

        # Point labels, offset in screen points so they don't sit on markers.
        ax.annotate("C1", (x[0], y_no_attr[0]), textcoords="offset points",
                    xytext=(-15, -15), fontsize=9)
        ax.annotate("C2", (x[1], y_no_attr[1]), textcoords="offset points",
                    xytext=(5, -15), fontsize=9)
        ax.annotate("C3", (x[0], y_with_attr[0]), textcoords="offset points",
                    xytext=(-15, 10), fontsize=9)
        ax.annotate("C4", (x[1], y_with_attr[1]), textcoords="offset points",
                    xytext=(5, 10), fontsize=9)

        ax.set_xticks(x)
        ax.set_xticklabels(x_labels)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(loc='best')
        ax.grid(axis='y', alpha=0.3)

        ax.text(0.02, 0.98, interaction_text, transform=ax.transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)
    print(f"Saved: {output_path}")
def plot_survival_rates(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Create a bar chart of deduplication survival rates per condition.

    Each bar is the mean ``survival_rate`` over queries, multiplied by 100 and
    shown as a percentage. ``survival_rate`` is presumably a fraction in [0, 1].
    Error bars show the sample standard deviation.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        output_path: Where the figure is saved.
        figsize: Matplotlib figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    by_condition = extract_metric_values(metrics, "survival_rate")

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    # Skip conditions with no values: np.mean([]) would yield a NaN bar.
    conditions = [c for c in ordered_conditions if by_condition.get(c)]
    if not conditions:
        print("No survival rate data")
        return

    # Convert fractions to percentages.
    means = [np.mean(by_condition[c]) * 100 for c in conditions]
    # Sample std (ddof=1) needs >= 2 observations.
    stds = [
        np.std(by_condition[c], ddof=1) * 100 if len(by_condition[c]) > 1 else 0
        for c in conditions
    ]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
    x = np.arange(len(conditions))

    fig, ax = plt.subplots(figsize=figsize)
    try:
        bars = ax.bar(x, means, yerr=stds, capsize=5, color=colors,
                      alpha=0.8, edgecolor='black')

        # Value labels sit above the error-bar cap so they are not struck through.
        for bar, mean, std in zip(bars, means, stds):
            ax.annotate(f'{mean:.1f}%',
                        xy=(bar.get_x() + bar.get_width() / 2, mean + std),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=10)

        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=15, ha='right')
        ax.set_ylabel("Survival Rate (%)")
        ax.set_title("Deduplication Survival Rate by Condition\n(Higher = More Diverse Generation)")
        # Keep the usual 0-110% range, but grow it if an error bar would be clipped.
        top = max(m + s for m, s in zip(means, stds))
        ax.set_ylim(0, max(110, top * 1.05))
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)
    print(f"Saved: {output_path}")
def plot_idea_counts(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (10, 6)
):
    """Create a stacked bar chart of raw vs. unique idea counts.

    For each condition, the lower (colored) segment is the mean number of
    unique ideas per query and the upper (hatched gray) segment is the mean
    number of duplicates removed (raw - unique). The unique count is printed
    inside the bar and the raw total in parentheses above it.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        output_path: Where the figure is saved.
        figsize: Matplotlib figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    raw_counts = extract_metric_values(metrics, "raw_count")
    unique_counts = extract_metric_values(metrics, "unique_count")

    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]
    # Both series must be non-empty; np.mean([]) would produce NaN bars.
    conditions = [
        c for c in ordered_conditions
        if raw_counts.get(c) and unique_counts.get(c)
    ]
    if not conditions:
        print("No count data")
        return

    raw_means = [np.mean(raw_counts[c]) for c in conditions]
    unique_means = [np.mean(unique_counts[c]) for c in conditions]
    removed_means = [raw - unique for raw, unique in zip(raw_means, unique_means)]
    labels = [CONDITION_LABELS.get(c, c) for c in conditions]
    colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
    x = np.arange(len(conditions))
    width = 0.6

    fig, ax = plt.subplots(figsize=figsize)
    try:
        # Stacked bars: unique (bottom) + removed duplicates (top).
        ax.bar(x, unique_means, width, label='Unique Ideas',
               color=colors, alpha=0.9)
        ax.bar(x, removed_means, width, bottom=unique_means,
               label='Duplicates Removed', color='lightgray', alpha=0.7,
               hatch='//')

        for xi, unique, raw in zip(x, unique_means, raw_means):
            ax.annotate(f'{unique:.0f}', xy=(xi, unique / 2),
                        ha='center', va='center', fontsize=10,
                        fontweight='bold')
            # Offset in points (not data units) so spacing is scale-independent.
            ax.annotate(f'({raw:.0f})', xy=(xi, raw),
                        xytext=(0, 3), textcoords="offset points",
                        ha='center', va='bottom', fontsize=9, color='gray')

        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=15, ha='right')
        ax.set_ylabel("Number of Ideas")
        ax.set_title("Idea Counts by Condition\n(Unique ideas shown, raw total in parentheses)")
        ax.legend(loc='upper right')
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)
    print(f"Saved: {output_path}")
def plot_metrics_comparison(
    metrics: Dict[str, Any],
    output_path: Path,
    figsize: tuple = (12, 8)
):
    """Create a 2x2 multi-panel bar comparison of key metrics.

    The panels show:
    - survival rate (as %);
    - semantic diversity;
    - query distance;
    - number of clusters.

    Each bar is the mean over queries for one condition. Panels without any
    data show a "No data" placeholder. A shared legend maps short tick labels
    (C1..C5) to full condition names.

    Args:
        metrics: Parsed metrics document (see ``extract_metric_values``).
        output_path: Where the figure is saved.
        figsize: Matplotlib figure size in inches.
    """
    if not MATPLOTLIB_AVAILABLE:
        return

    # (metric path, panel title, plot as percentage?)
    panel_specs = [
        ("survival_rate", "Survival Rate", True),
        ("post_dedup_diversity.mean_pairwise_distance", "Semantic Diversity", False),
        ("post_dedup_query_distance.mean_distance", "Query Distance (Novelty)", False),
        ("post_dedup_clusters.optimal_clusters", "Number of Clusters", False),
    ]
    ordered_conditions = [
        "c1_direct", "c2_expert_only", "c3_attribute_only",
        "c4_full_pipeline", "c5_random_perspective"
    ]

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    try:
        # axes.flat is row-major: [0,0], [0,1], [1,0], [1,1].
        for ax, (metric_path, title, is_percentage) in zip(axes.flat, panel_specs):
            by_condition = extract_metric_values(metrics, metric_path)
            conditions = [c for c in ordered_conditions if by_condition.get(c)]
            ax.set_title(title)
            if not conditions:
                ax.text(0.5, 0.5, "No data", ha='center', va='center',
                        transform=ax.transAxes)
                continue

            scale = 100 if is_percentage else 1
            means = [np.mean(by_condition[c]) * scale for c in conditions]
            colors = [CONDITION_COLORS.get(c, "#888888") for c in conditions]
            x = np.arange(len(conditions))
            ax.bar(x, means, color=colors, alpha=0.8, edgecolor='black')

            # Derive each short label from its own condition (e.g. "C3"); slicing
            # a fixed ["C1", ..., "C5"] list mislabels bars when a condition
            # is missing from the data.
            short_labels = [CONDITION_LABELS.get(c, c).split(":", 1)[0] for c in conditions]
            ax.set_xticks(x)
            ax.set_xticklabels(short_labels)
            ax.grid(axis='y', alpha=0.3)
            if is_percentage:
                ax.set_ylim(0, 110)

        legend_elements = [
            mpatches.Patch(facecolor=CONDITION_COLORS[c], label=CONDITION_LABELS[c])
            for c in ordered_conditions if c in CONDITION_COLORS
        ]
        fig.legend(handles=legend_elements, loc='lower center', ncol=3,
                   bbox_to_anchor=(0.5, -0.02))

        fig.tight_layout()
        # Leave room at the bottom for the shared legend.
        fig.subplots_adjust(bottom=0.15)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if rendering/saving fails.
        plt.close(fig)
    print(f"Saved: {output_path}")
def generate_all_visualizations(
    metrics: Dict[str, Any],
    output_dir: Path
):
    """Render every figure for one experiment into ``output_dir``.

    The directory is created if needed. Each file is named
    ``<experiment_id>_<figure>.png``; the id comes from
    ``metrics["experiment_id"]`` and falls back to ``"experiment"``.
    Figures are rendered in this order:
    - survival rates;
    - idea counts;
    - the two box plots;
    - the two 2×2 interaction plots;
    - the multi-panel comparison.
    """
    if not MATPLOTLIB_AVAILABLE:
        print("matplotlib not available. Cannot generate visualizations.")
        return

    output_dir.mkdir(parents=True, exist_ok=True)
    exp_id = metrics.get("experiment_id", "experiment")
    print(f"\nGenerating visualizations for {exp_id}...")

    def target(suffix):
        return output_dir / f"{exp_id}_{suffix}.png"

    # (metric path, y-axis label) pairs shared by box and interaction plots.
    diversity = ("post_dedup_diversity.mean_pairwise_distance", "Mean Pairwise Distance")
    novelty = ("post_dedup_query_distance.mean_distance", "Distance from Original Query")

    plot_survival_rates(metrics, target("survival_rates"))
    plot_idea_counts(metrics, target("idea_counts"))

    box_specs = (
        (diversity, "Semantic Diversity by Condition (Post-Dedup)", "diversity_boxplot"),
        (novelty, "Query Distance by Condition (Novelty)", "query_distance_boxplot"),
    )
    for (path, y_label), plot_title, suffix in box_specs:
        plot_box_comparison(metrics, path, plot_title, y_label, target(suffix))

    interaction_specs = (
        (diversity, "2×2 Factorial: Semantic Diversity", "interaction_diversity"),
        (novelty, "2×2 Factorial: Query Distance (Novelty)", "interaction_novelty"),
    )
    for (path, y_label), plot_title, suffix in interaction_specs:
        plot_interaction_2x2(metrics, path, plot_title, y_label, target(suffix))

    plot_metrics_comparison(metrics, target("metrics_comparison"))

    print(f"\nAll visualizations saved to: {output_dir}")
def main():
    """Command-line entry point: load a metrics JSON file and render all figures.

    ``--input`` is tried as given first, then relative to ``RESULTS_DIR``.
    Figures are written to ``--output-dir``, or ``RESULTS_DIR / "figures"``
    when that option is omitted. Exits with status 1, with the message on
    stderr, in these cases:
    - the input file cannot be found;
    - the input is not valid JSON;
    - the input's top-level value is not a JSON object.
    """
    parser = argparse.ArgumentParser(
        description="Generate visualizations for experiment results"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input metrics JSON file"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Output directory for figures (default: results/figures/)"
    )
    args = parser.parse_args()

    # is_file() rather than exists(): a directory would pass exists() and
    # then fail inside open().
    input_path = Path(args.input)
    if not input_path.is_file():
        input_path = RESULTS_DIR / args.input
    if not input_path.is_file():
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(f"Error: Input file not found: {args.input}")

    try:
        with open(input_path, "r", encoding="utf-8") as f:
            metrics = json.load(f)
    except json.JSONDecodeError as err:
        sys.exit(f"Error: Could not parse {input_path} as JSON: {err}")
    if not isinstance(metrics, dict):
        sys.exit(f"Error: Expected a JSON object at top level of {input_path}")

    output_dir = Path(args.output_dir) if args.output_dir else RESULTS_DIR / "figures"
    generate_all_visualizations(metrics, output_dir)


if __name__ == "__main__":
    main()