File size: 2,030 Bytes
1b64cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Tuple, Dict, Any
from copy import deepcopy

from .state import EnvironmentState
from .observation import Observation
from .actions import Action
from .transition import apply_action
from .reward import compute_reward


class WorkflowEnv:
    """Episodic environment that drives an ``EnvironmentState`` step by step.

    The constructor keeps two independent deep copies of the state it is
    given: ``initial_state`` (the template restored by ``reset()``) and a
    private working copy that ``step()`` advances through ``apply_action``
    and scores with ``compute_reward``.

    Attributes:
        initial_state: Deep copy of the state passed to the constructor.
        max_steps: Timestep count at which the episode is flagged as done.
    """

    # Step budget used when the caller does not supply one.
    DEFAULT_MAX_STEPS = 10

    def __init__(
        self,
        initial_state: EnvironmentState,
        *,
        max_steps: int = DEFAULT_MAX_STEPS,
    ):
        """Create an environment starting from a copy of ``initial_state``.

        Args:
            initial_state: Starting state; it is deep-copied, so the caller's
                object is never modified by this environment.
            max_steps: Timestep count at which ``step()`` marks the episode
                as done. Must be at least 1.

        Raises:
            ValueError: If ``max_steps`` is smaller than 1.
        """
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")

        self.max_steps = max_steps
        self.initial_state = deepcopy(initial_state)
        self._state = deepcopy(initial_state)

    # -----------------------------
    # RESET
    # -----------------------------
    def reset(self) -> Observation:
        """Restore the working state from the stored template.

        Returns:
            The observation built from the freshly restored state.
        """
        # Copy again so the template itself is never advanced by step().
        self._state = deepcopy(self.initial_state)
        return self._get_observation()

    # -----------------------------
    # STEP
    # -----------------------------
    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one action and advance the episode by one timestep.

        The action is appended to the state's history (with the timestep at
        which it was taken) before the transition is applied.

        Args:
            action: The action to apply; it must provide ``model_dump()`` and
                a ``type`` attribute.

        Returns:
            A ``(observation, reward, done, info)`` tuple. ``info`` is a
            shallow copy of the dict produced by ``apply_action``, or an
            empty dict if that value is not a dict.

        Raises:
            RuntimeError: If the current state is already flagged as done.
        """
        if self._state.done:
            # RuntimeError is still an Exception, so handlers written against
            # the previous bare Exception keep working.
            raise RuntimeError("Episode already finished. Call reset().")

        # Record the action under the timestep at which it was issued.
        self._state.history.append({
            "timestep": self._state.timestep,
            "action": action.model_dump(),
        })

        # The transition returns the state to continue from plus its info.
        self._state, info = apply_action(self._state, action)
        reward = compute_reward(self._state, action.type, info)

        self._state.timestep += 1

        # Only ever set the flag here; a done flag already present on the
        # state returned by the transition is left untouched.
        if self._state.timestep >= self.max_steps:
            self._state.done = True

        # Hand the transition info to the caller instead of discarding it.
        step_info: Dict[str, Any] = dict(info) if isinstance(info, dict) else {}
        return self._get_observation(), reward, self._state.done, step_info

    # -----------------------------
    # STATE ACCESS
    # -----------------------------
    def state(self) -> EnvironmentState:
        """Return the live working state object (not a copy)."""
        return self._state

    # -----------------------------
    # OBSERVATION
    # -----------------------------
    def _get_observation(self) -> Observation:
        """Build an ``Observation`` from the current working state.

        The emails, tasks, calendar, history and timestep attributes of the
        state are passed to ``Observation`` as-is.
        """
        return Observation(
            emails=self._state.emails,
            tasks=self._state.tasks,
            calendar=self._state.calendar,
            history=self._state.history,
            timestep=self._state.timestep,
        )