echo-ultimate / scripts /run_baseline.py
Vikaspandey582003's picture
Upload folder using huggingface_hub
e8ad153 verified
"""Evaluate all 4 baseline agents and generate comparison plots."""
import sys, os, argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--quick", action="store_true", help="Fewer episodes for CI")
args, _ = parser.parse_known_args()
print("🎯 Running baseline evaluation…")
from config import cfg
from env.task_bank import TaskBank
from core.baseline import run_baseline_evaluation, ALL_BASELINES
from training.evaluate import (
evaluate_agent, make_synthetic_pair, compare_and_plot,
make_synthetic_training_log, EvalResults,
)
from pathlib import Path
Path(cfg.PLOTS_DIR).mkdir(parents=True, exist_ok=True)
bank = TaskBank(); bank.ensure_loaded()
n = 50 if args.quick else cfg.FULL_EVAL_EPISODES
print(f" 📊 Evaluating {len(ALL_BASELINES)} baselines ({n} episodes each)…")
baseline_reports = run_baseline_evaluation(bank, n_episodes=n)
print(" 📈 Building comparison EvalResults…")
from training.evaluate import EvalResults
from core.metrics import CalibrationReport
def _wrap(name, rep):
r = EvalResults(report=rep, label=name)
return r
baseline_eval = {name: _wrap(name.replace("_"," ").title(), rep)
for name, rep in baseline_reports.items()}
print(" 📊 Generating synthetic trained model (for plot demo)…")
_, trained_synth = make_synthetic_pair(ece_before=0.34, ece_after=0.08)
trained_synth.label = "ECHO Trained"
make_synthetic_training_log(cfg.TRAINING_LOG)
paths = compare_and_plot(trained_synth, {"Untrained": list(baseline_eval.values())[1]})
print("\n" + "─"*60)
print(" BASELINE RESULTS")
print("─"*60)
for name, rep in baseline_reports.items():
print(f" {name:<20} ECE={rep.ece:.3f} Acc={rep.accuracy:.1%} "
f"OverConf={rep.overconfidence_rate:.1%}")
print("─"*60)
print("\n✅ All plots saved to results/plots/")
for k, p in paths.items():
print(f" • {k}: {p}")
if __name__ == "__main__":
main()