from typing import Tuple, Dict, Any
from copy import deepcopy
from .state import EnvironmentState
from .observation import Observation
from .actions import Action
from .transition import apply_action
from .reward import compute_reward
class WorkflowEnv:
    """Gym-style episodic environment over a workflow ``EnvironmentState``.

    Each ``step`` logs the action, applies it via ``apply_action``, scores
    it with ``compute_reward``, and advances the timestep. The episode
    terminates after ``max_timesteps`` steps (default 10, matching the
    original hard-coded limit).
    """

    def __init__(self, initial_state: EnvironmentState, max_timesteps: int = 10):
        """Store a pristine copy of the start state so reset() can restore it.

        Args:
            initial_state: starting environment state (deep-copied, never mutated).
            max_timesteps: episode length before ``done`` is forced True.
        """
        self.initial_state = deepcopy(initial_state)
        self._state = deepcopy(initial_state)
        self.max_timesteps = max_timesteps

    # -----------------------------
    # RESET
    # -----------------------------
    def reset(self) -> Observation:
        """Restore the initial state and return the first observation."""
        self._state = deepcopy(self.initial_state)
        return self._get_observation()

    # -----------------------------
    # STEP
    # -----------------------------
    def step(self, action: Action) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply *action* and return ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: if called after the episode has already finished.
        """
        if self._state.done:
            # RuntimeError (not bare Exception) so callers can distinguish
            # API misuse from other failures; still caught by `except Exception`.
            raise RuntimeError("Episode already finished. Call reset().")

        # Log the action before applying it, tagged with the current timestep.
        self._state.history.append({
            "timestep": self._state.timestep,
            "action": action.model_dump()
        })

        # Apply the transition, then score the resulting state.
        self._state, info = apply_action(self._state, action)
        reward = compute_reward(self._state, action.type, info)

        # Advance time; terminate once the episode budget is exhausted.
        self._state.timestep += 1
        if self._state.timestep >= self.max_timesteps:
            self._state.done = True

        # NOTE(review): the info dict from apply_action is used for the reward
        # but an empty dict is returned to the caller — preserved as-is.
        return self._get_observation(), reward, self._state.done, {}

    # -----------------------------
    # STATE ACCESS
    # -----------------------------
    def state(self) -> EnvironmentState:
        """Return the live (mutable) internal state — not a copy."""
        return self._state

    # -----------------------------
    # OBSERVATION
    # -----------------------------
    def _get_observation(self) -> Observation:
        """Build an Observation snapshot from the current state's fields."""
        return Observation(
            emails=self._state.emails,
            tasks=self._state.tasks,
            calendar=self._state.calendar,
            history=self._state.history,
            timestep=self._state.timestep
        )