Spaces:
Sleeping
Sleeping
| """Evaluate all 4 baseline agents and generate comparison plots.""" | |
| import sys, os, argparse | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import logging | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") | |
| def main(): | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--quick", action="store_true", help="Fewer episodes for CI") | |
| args, _ = parser.parse_known_args() | |
| print("🎯 Running baseline evaluation…") | |
| from config import cfg | |
| from env.task_bank import TaskBank | |
| from core.baseline import run_baseline_evaluation, ALL_BASELINES | |
| from training.evaluate import ( | |
| evaluate_agent, make_synthetic_pair, compare_and_plot, | |
| make_synthetic_training_log, EvalResults, | |
| ) | |
| from pathlib import Path | |
| Path(cfg.PLOTS_DIR).mkdir(parents=True, exist_ok=True) | |
| bank = TaskBank(); bank.ensure_loaded() | |
| n = 50 if args.quick else cfg.FULL_EVAL_EPISODES | |
| print(f" 📊 Evaluating {len(ALL_BASELINES)} baselines ({n} episodes each)…") | |
| baseline_reports = run_baseline_evaluation(bank, n_episodes=n) | |
| print(" 📈 Building comparison EvalResults…") | |
| from training.evaluate import EvalResults | |
| from core.metrics import CalibrationReport | |
| def _wrap(name, rep): | |
| r = EvalResults(report=rep, label=name) | |
| return r | |
| baseline_eval = {name: _wrap(name.replace("_"," ").title(), rep) | |
| for name, rep in baseline_reports.items()} | |
| print(" 📊 Generating synthetic trained model (for plot demo)…") | |
| _, trained_synth = make_synthetic_pair(ece_before=0.34, ece_after=0.08) | |
| trained_synth.label = "ECHO Trained" | |
| make_synthetic_training_log(cfg.TRAINING_LOG) | |
| paths = compare_and_plot(trained_synth, {"Untrained": list(baseline_eval.values())[1]}) | |
| print("\n" + "─"*60) | |
| print(" BASELINE RESULTS") | |
| print("─"*60) | |
| for name, rep in baseline_reports.items(): | |
| print(f" {name:<20} ECE={rep.ece:.3f} Acc={rep.accuracy:.1%} " | |
| f"OverConf={rep.overconfidence_rate:.1%}") | |
| print("─"*60) | |
| print("\n✅ All plots saved to results/plots/") | |
| for k, p in paths.items(): | |
| print(f" • {k}: {p}") | |
| if __name__ == "__main__": | |
| main() | |