# File size: 1,314 Bytes
# Revision: 1b64cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from tasks.easy import create_easy_task
from tasks.medium import create_medium_task
from tasks.hard import create_hard_task

from graders.easy_grader import EasyGrader
from graders.medium_grader import MediumGrader
from graders.hard_grader import HardGrader

from app.env import WorkflowEnv
from baseline.policy import BaselinePolicy


def run_task(task_name, create_task_fn, grader_cls, *, max_steps=10):
    """Run the baseline policy on one task and return its graded score.

    The task is built with ``create_task_fn``, wrapped in a ``WorkflowEnv``,
    and driven by a fresh ``BaselinePolicy`` until the environment reports
    completion, the policy returns ``None``, or ``max_steps`` steps have been
    taken. The recorded history is printed and then passed to the grader.

    Args:
        task_name: Label used as the prefix of the printed trajectory line.
        create_task_fn: Zero-argument callable returning a
            ``(state, ground_truth)`` pair.
        grader_cls: Zero-argument callable (the callers in this file pass a
            class) whose result provides ``grade(trajectory, ground_truth)``.
        max_steps: Upper bound on the number of ``env.step`` calls. Defaults
            to 10, the limit that was previously hard-coded in the loop.

    Returns:
        Whatever ``grader.grade`` returns for the recorded trajectory.
    """
    state, ground_truth = create_task_fn()
    env = WorkflowEnv(state)
    policy = BaselinePolicy()

    obs = env.reset()
    done = False
    steps = 0

    while not done and steps < max_steps:
        action = policy.act(obs)
        if action is None:
            # A None action is treated as the policy's signal to stop early.
            break

        # Only the observation and the done flag drive the loop; the reward
        # and the fourth returned value are intentionally discarded.
        obs, _reward, done, _info = env.step(action)
        steps += 1

    # The grader scores the history recorded on the environment's state,
    # not the observations seen by this loop.
    trajectory = env.state().history
    print(f"{task_name} trajectory:", trajectory)

    grader = grader_cls()
    return grader.grade(trajectory, ground_truth)


def main():
    """Run the baseline on every difficulty tier and print a score summary.

    Tasks are executed in the order easy, medium, hard; each run prints its
    own trajectory via ``run_task``. Afterwards a summary header is printed,
    followed by one line per tier with the score rounded to three decimals.
    """
    # One row per tier: (label, task factory, grader class).
    suite = (
        ("easy", create_easy_task, EasyGrader),
        ("medium", create_medium_task, MediumGrader),
        ("hard", create_hard_task, HardGrader),
    )

    # Dict insertion order preserves the tier order for the report below.
    results = {
        label: run_task(label, make_task, grader)
        for label, make_task, grader in suite
    }

    print("\n===== BASELINE RESULTS =====")
    for label, score in results.items():
        print(f"{label}: {round(score, 3)}")


# Run the baseline evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()