"""Evaluate all 4 baseline agents and generate comparison plots."""
import argparse
import logging
import os
import sys

# Make the repository root importable when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--quick", action="store_true", help="Fewer episodes for CI")
    # parse_known_args tolerates extra flags injected by test runners / CI wrappers.
    args, _ = parser.parse_known_args()

    print("🎯  Running baseline evaluation…")
    # Heavy project imports are deferred until after argument parsing.
    from config import cfg
    from env.task_bank import TaskBank
    from core.baseline import run_baseline_evaluation, ALL_BASELINES
    from training.evaluate import (
        evaluate_agent, make_synthetic_pair, compare_and_plot,
        make_synthetic_training_log, EvalResults,
    )
    from pathlib import Path

    Path(cfg.PLOTS_DIR).mkdir(parents=True, exist_ok=True)
    bank = TaskBank()
    bank.ensure_loaded()
    n = 50 if args.quick else cfg.FULL_EVAL_EPISODES

    print(f"  📊  Evaluating {len(ALL_BASELINES)} baselines ({n} episodes each)…")
    baseline_reports = run_baseline_evaluation(bank, n_episodes=n)

    print("  📈  Building comparison EvalResults…")

    def _wrap(name, rep):
        # Attach a human-readable label to each baseline's raw report.
        return EvalResults(report=rep, label=name)

    baseline_eval = {name: _wrap(name.replace("_"," ").title(), rep)
                     for name, rep in baseline_reports.items()}

    print("  📊  Generating synthetic trained model (for plot demo)…")
    # Only the post-training half of the synthetic pair is needed here.
    _, trained_synth = make_synthetic_pair(ece_before=0.34, ece_after=0.08)
    trained_synth.label = "ECHO Trained"

    make_synthetic_training_log(cfg.TRAINING_LOG)
    # Look up the untrained baseline by key rather than by position; indexing
    # into dict values depends on insertion order and silently picks the wrong
    # agent if the baseline set changes. ("untrained" as the key name is an
    # assumption; falls back to the first baseline if absent.)
    untrained = baseline_eval.get("untrained", next(iter(baseline_eval.values())))
    paths = compare_and_plot(trained_synth, {"Untrained": untrained})

    print("\n" + "─"*60)
    print("  BASELINE RESULTS")
    print("─"*60)
    for name, rep in baseline_reports.items():
        print(f"  {name:<20}  ECE={rep.ece:.3f}  Acc={rep.accuracy:.1%}  "
              f"OverConf={rep.overconfidence_rate:.1%}")
    print("─"*60)
    print("\n✅  All plots saved to results/plots/")
    for k, p in paths.items():
        print(f"    • {k}: {p}")

if __name__ == "__main__":
    main()