File size: 925 Bytes
"""Evaluate heuristic baseline on all tasks and print scores."""
import copy
import json
import sys
from src.tasks import TASKS
from src.grader import RobustnessGrader
from src.baseline import heuristic_policy
def main(n_episodes: int = 10):
    """Score the heuristic policy on every registered task.

    Each entry of ``TASKS`` is evaluated with its own ``RobustnessGrader``
    built from a deep copy of the task config. The per-task result is
    printed key by key and collected into the returned dict.

    Args:
        n_episodes: Forwarded to ``grader.evaluate_policy`` as the number
            of episodes to run per task.

    Returns:
        A dict keyed by task id. The value is whatever
        ``evaluate_policy`` returned, or ``{"error": <message>}`` when
        anything in that task's evaluation or printing raised.
    """
    outcomes = {}
    for task_id, task_cfg in TASKS.items():
        try:
            # The grader gets a deep copy of the config rather than the
            # object stored in TASKS.
            private_cfg = copy.deepcopy(task_cfg)
            task_grader = RobustnessGrader(private_cfg)
            scores = task_grader.evaluate_policy(heuristic_policy, n_episodes=n_episodes)
            outcomes[task_id] = scores

            print(f"{task_id}:")
            for metric, value in scores.items():
                print(f" {metric}: {value}")
            print()
        except Exception as exc:
            # One failing task must not abort the remaining ones: record
            # the message in place of the result and keep going.
            outcomes[task_id] = {"error": str(exc)}
            print(f"{task_id}: FAILED — {exc}\n")
    return outcomes
if __name__ == "__main__":
    # An optional first command-line argument sets the episode count;
    # without it the script runs 10 episodes per task.
    cli_args = sys.argv[1:]
    if cli_args:
        episodes = int(cli_args[0])
    else:
        episodes = 10
    main(n_episodes=episodes)
|