File size: 925 Bytes
8a02303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Evaluate heuristic baseline on all tasks and print scores."""

import copy
import json
import sys
import traceback

from src.tasks import TASKS
from src.grader import RobustnessGrader
from src.baseline import heuristic_policy


def main(n_episodes: int = 10):
    """Evaluate ``heuristic_policy`` on every task in ``TASKS`` and print scores.

    Each task is graded independently: if grading one task raises, the
    failure is recorded and reported and evaluation continues with the next
    task.

    Args:
        n_episodes: Number of episodes passed to
            ``RobustnessGrader.evaluate_policy`` for each task.

    Returns:
        A dict mapping each task id (a key of ``TASKS``) to either the mapping
        returned by ``evaluate_policy`` or ``{"error": "<ExcType>: <message>"}``
        when grading that task raised.
    """
    all_results = {}

    for tid, cfg in TASKS.items():
        try:
            # Grade against a private copy so the shared TASKS entry stays
            # untouched even if the grader mutates its config.
            grader = RobustnessGrader(copy.deepcopy(cfg))
            result = grader.evaluate_policy(
                heuristic_policy, n_episodes=n_episodes
            )
            all_results[tid] = result

            print(f"{tid}:")
            for k, v in result.items():
                print(f"  {k}: {v}")
            print()

        except Exception as e:
            # Top-level boundary: record the failure and keep going so one
            # broken task does not hide the scores of the others. The exception
            # type is included because str(e) is empty for exceptions raised
            # without a message (e.g. a bare AssertionError()), and the
            # traceback goes to stderr so the failure can actually be diagnosed.
            error = f"{type(e).__name__}: {e}"
            all_results[tid] = {"error": error}
            traceback.print_exc()
            print(f"{tid}: FAILED — {error}\n")

    return all_results


if __name__ == "__main__":
    # Optional first CLI argument: number of episodes per task. When it is
    # absent, main() supplies its own default so the value lives in one place.
    if len(sys.argv) > 1:
        try:
            episodes = int(sys.argv[1])
        except ValueError:
            # sys.exit with a string prints it to stderr and exits with status 1,
            # which is friendlier than an uncaught ValueError traceback.
            sys.exit(
                f"usage: {sys.argv[0]} [n_episodes] "
                f"(n_episodes must be an integer, got {sys.argv[1]!r})"
            )
        if episodes < 1:
            sys.exit(
                f"usage: {sys.argv[0]} [n_episodes] "
                f"(n_episodes must be >= 1, got {episodes})"
            )
        main(n_episodes=episodes)
    else:
        main()