"""
Executive Assistant Environment — OpenEnv-compliant RL environment.
Main entry point for the simulation. Orchestrates scenario generation,
action execution, reward computation, and observation delivery.
Supports:
- Action masking (invalid action prevention)
- Partial observability (hidden tasks, delayed inbox)
- Curriculum learning (difficulty auto-scaling)
- Multi-objective reward shaping
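
Example (minimal random rollout; ``valid_actions`` is assumed to be
non-empty for a freshly generated scenario):

    import random

    env = ExecutiveAssistantEnv(difficulty="easy", seed=0)
    obs = env.reset()
    done = False
    while not done:
        obs, reward, done, info = env.step(random.choice(obs["valid_actions"]))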
"""
from typing import Dict, Tuple, List, Optional
from env.state import State
from env.rewards import RewardEngine
from env.scheduler import Scheduler
from env.scenario_generator import ScenarioGenerator
from env.actions import parse_action, get_valid_actions, get_action_mask
class ExecutiveAssistantEnv:
"""OpenEnv RL environment simulating an executive assistant.
The agent must manage a day's schedule: complete tasks, handle inbox
messages, resolve conflicts, and optimize for multiple objectives.
Attributes:
max_steps: Maximum steps per episode.
difficulty: Current difficulty level.
auto_curriculum: If True, difficulty increases automatically.
"""
def __init__(
self,
difficulty: str = "medium",
max_steps: int = 50,
auto_curriculum: bool = False,
seed: Optional[int] = None,
):
"""Initialize the environment.
Args:
difficulty: Starting difficulty ('easy', 'medium', 'hard').
max_steps: Maximum steps before forced termination.
auto_curriculum: Automatically increase difficulty over episodes.
seed: Random seed for reproducibility.
"""
self.max_steps = max_steps
self.difficulty = difficulty
self.auto_curriculum = auto_curriculum
self.seed = seed
self.generator = ScenarioGenerator(difficulty=difficulty, seed=seed)
self.reward_engine = RewardEngine()
self.scheduler = Scheduler()
# Episode tracking
self.state: Optional[State] = None
self.steps = 0
self.episode_reward = 0.0
self.episode_count = 0
self.episode_history: List[Dict] = []
# Curriculum tracking
self._consecutive_good_episodes = 0
self._curriculum_threshold = 5 # Episodes before difficulty increase
def reset(self) -> Dict:
"""Reset the environment and generate a new scenario.
Returns:
Initial observation dict.
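
        Example of the extra keys layered onto the state dict by reset()
        (values and action names are illustrative, not fixed):

            obs = env.reset()
            obs["valid_actions"]  # e.g. [("defer_task", 0), ...]
            obs["action_mask"]    # e.g. [1, 0, 1, 1, 0, 0]
            obs["step"]           # 0
            obs["difficulty"]     # "medium"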
"""
# Auto-curriculum: increase difficulty if performing well
if self.auto_curriculum and self.episode_count > 0:
self._update_curriculum()
self.state = self.generator.generate()
self.steps = 0
self.episode_reward = 0.0
self.episode_count += 1
obs = self.state.to_dict()
obs["valid_actions"] = get_valid_actions(obs)
obs["action_mask"] = get_action_mask(obs)
obs["step"] = self.steps
obs["difficulty"] = self.difficulty
return obs
def step(self, action) -> Tuple[Dict, float, bool, Dict]:
"""Execute one environment step.
Args:
action: Agent action (supports multiple formats via parse_action).
Returns:
Tuple of (observation, reward, done, info).
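
        Example (an ``(action_type, target_id)`` tuple is assumed to be one
        of the formats parse_action accepts; see env.actions):

            obs, reward, done, info = env.step(("defer_task", 2))
            if not info["action_success"]:
                print(info["action_detail"])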
"""
if self.state is None:
raise RuntimeError("Environment not initialized. Call reset() first.")
self.steps += 1
# Parse the action
action_type, target_id = parse_action(action)
# Apply action via scheduler
result = self.scheduler.apply_action(self.state, action_type, target_id)
# Compute reward
reward = self.reward_engine.compute(
self.state, action_type, target_id, result
)
self.episode_reward += reward
# Check termination
done = self.steps >= self.max_steps or self.state.is_terminal()
# Build observation
obs = self.state.to_dict()
obs["valid_actions"] = get_valid_actions(obs)
obs["action_mask"] = get_action_mask(obs)
obs["step"] = self.steps
obs["difficulty"] = self.difficulty
# Build info dict
info = {
"action_type": action_type,
"target_id": target_id,
"action_success": result.get("success", False),
"action_detail": result.get("detail", ""),
"conflicts_created": result.get("conflicts_created", 0),
"revealed_tasks": result.get("revealed_tasks", 0),
"revealed_messages": result.get("revealed_messages", 0),
"episode_reward": self.episode_reward,
"steps_remaining": self.max_steps - self.steps,
}
# Episode summary on termination
if done:
info["episode_summary"] = self.reward_engine.compute_episode_summary(
self.state
)
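            # Deferred import: env.utils is only needed when an episode ends.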
from env.utils import compute_metrics
info["metrics"] = compute_metrics(self.state.to_dict())
# Track episode for curriculum learning
self.episode_history.append({
"episode": self.episode_count,
"reward": self.episode_reward,
"difficulty": self.difficulty,
"metrics": info["metrics"],
})
return obs, reward, done, info
def get_valid_actions(self) -> List[Tuple[str, int]]:
"""Get valid actions for the current state."""
if self.state is None:
return [("defer_task", 0)]
return get_valid_actions(self.state.to_dict())
def get_action_mask(self) -> List[int]:
"""Get binary action mask for the current state."""
if self.state is None:
return [0] * 6
return get_action_mask(self.state.to_dict())
def get_state(self) -> Dict:
"""Get current observation (partial observability enforced)."""
if self.state is None:
return {}
return self.state.to_dict()
def get_full_state(self) -> Dict:
"""Get full state including hidden elements (for debugging)."""
if self.state is None:
return {}
return self.state.full_dict()
    def get_conflict_graph(self) -> Dict[int, List]:
"""Get the current task conflict graph."""
if self.state is None:
return {}
return self.scheduler.get_conflict_graph(self.state)
    def get_schedule_summary(self) -> List:
"""Get sorted schedule summary for visualization."""
if self.state is None:
return []
return self.scheduler.get_schedule_summary(self.state)
def _update_curriculum(self):
"""Auto-scale difficulty based on recent performance."""
if not self.episode_history:
return
latest = self.episode_history[-1]
metrics = latest.get("metrics", {})
efficiency = metrics.get("efficiency_score", 0)
if efficiency >= 60:
self._consecutive_good_episodes += 1
else:
self._consecutive_good_episodes = 0
if self._consecutive_good_episodes >= self._curriculum_threshold:
if self.difficulty == "easy":
self.difficulty = "medium"
self.generator.set_difficulty("medium")
self._consecutive_good_episodes = 0
elif self.difficulty == "medium":
self.difficulty = "hard"
self.generator.set_difficulty("hard")
self._consecutive_good_episodes = 0
def __repr__(self) -> str:
return (
f"ExecutiveAssistantEnv("
f"difficulty={self.difficulty}, "
f"max_steps={self.max_steps}, "
f"episode={self.episode_count})"
)
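

if __name__ == "__main__":
    # Smoke-test sketch (not part of the environment API): run one short
    # random episode and print the running reward. Assumes the env.* modules
    # imported above resolve from the project root.
    import random

    env = ExecutiveAssistantEnv(difficulty="easy", max_steps=10, seed=42)
    obs = env.reset()
    done = False
    while not done:
        actions = obs["valid_actions"] or [("defer_task", 0)]
        obs, reward, done, info = env.step(random.choice(actions))
        print(f"step {obs['step']}: reward={reward:.2f}, "
              f"total={info['episode_reward']:.2f}")
    print(env)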