"""
Action definitions and action masking for the Executive Assistant environment.

Provides the action vocabulary, parsing, and validity checking to prevent
agents from taking illegal actions.
"""

import operator
from typing import List, Dict, Tuple, Optional

# ─── Action Vocabulary ────────────────────────────────────────────────────────

ACTIONS = [
    "schedule_task",       # 0: Schedule a pending task into a time slot
    "complete_task",       # 1: Mark a scheduled/pending task as completed
    "defer_task",          # 2: Defer a task to a later time
    "send_reply",          # 3: Reply to an inbox message
    "reject_task",         # 4: Reject / cancel a task
    "ask_clarification",   # 5: Ask for clarification on a task or message
]

ACTION_TO_IDX = {a: i for i, a in enumerate(ACTIONS)}
IDX_TO_ACTION = {i: a for i, a in enumerate(ACTIONS)}

# Action substituted whenever an input cannot be interpreted.
_FALLBACK_ACTION = "defer_task"


# ─── Action Parsing ──────────────────────────────────────────────────────────

def _resolve_action_type(raw) -> str:
    """Turn a raw action specifier (name or integer index) into an action name.

    Strings are taken as action names. Anything supporting ``__index__``
    (plain ints, and integer-like objects in general) is looked up in
    ``IDX_TO_ACTION``. Any other object is stringified. Whatever does not end
    up naming an entry of ``ACTIONS`` becomes the fallback action.
    """
    if isinstance(raw, str):
        name = raw
    else:
        try:
            name = IDX_TO_ACTION.get(operator.index(raw), _FALLBACK_ACTION)
        except TypeError:
            # Not integer-like: fall back to its string form, as the
            # tuple/list branch has always done.
            name = str(raw)
    return name if name in ACTIONS else _FALLBACK_ACTION


def _resolve_target(raw) -> int:
    """Convert a raw target to ``int``, using 0 when it cannot be converted."""
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        return 0


def parse_action(action_input) -> Tuple[str, int]:
    """Parse various action input formats into (action_type, target_id).

    Supports:
        - Tuple/list: ("complete_task", 3) or (1, 3)
        - Dict: {"action": "complete_task", "target": 3}
        - String: "complete_task" (target defaults to 0)
        - Int or integer-like: 1 (maps to action index, target defaults to 0)

    The parser never raises on malformed input: an unknown, missing, or
    out-of-range action becomes "defer_task", and a missing or unconvertible
    target (including an empty tuple/list) becomes 0.

    Returns:
        (action_type_str, target_id)
    """
    if isinstance(action_input, (tuple, list)):
        # An empty sequence carries no information at all -> full fallback.
        raw_type = action_input[0] if action_input else _FALLBACK_ACTION
        raw_target = action_input[1] if len(action_input) > 1 else 0
    elif isinstance(action_input, dict):
        raw_type = action_input.get("action", _FALLBACK_ACTION)
        raw_target = action_input.get("target", 0)
    else:
        raw_type = action_input
        raw_target = 0

    return _resolve_action_type(raw_type), _resolve_target(raw_target)


# ─── Action Masking ──────────────────────────────────────────────────────────

def get_valid_actions(state_dict: Dict) -> List[Tuple[str, int]]:
    """Return all legal (action_type, target_id) pairs for the current state.

    Action masking rules:
        - schedule_task: only for pending tasks
        - complete_task: only for pending or scheduled tasks
        - defer_task: only for pending tasks
        - send_reply: only for unreplied messages
        - reject_task: only for pending tasks
        - ask_clarification: for any pending task or unreplied message

    Reads ``state_dict["tasks"]`` (dicts with "id" and "status") and
    ``state_dict["inbox"]`` (dicts with "id" and optional "replied"); a
    missing or ``None`` entry is treated as empty. When a pending task and an
    unreplied message share an id, the resulting ("ask_clarification", id)
    pair is listed once. If nothing is legal, the single no-op
    ("defer_task", 0) is returned.
    """
    tasks = state_dict.get("tasks") or []
    inbox = state_dict.get("inbox") or []

    pending_ids = [t["id"] for t in tasks if t["status"] == "pending"]
    scheduled_ids = [t["id"] for t in tasks if t["status"] == "scheduled"]
    unreplied_ids = [m["id"] for m in inbox if not m.get("replied", False)]

    valid: List[Tuple[str, int]] = []

    # schedule_task — pending tasks only
    valid.extend(("schedule_task", tid) for tid in pending_ids)

    # complete_task — pending or scheduled tasks
    valid.extend(("complete_task", tid) for tid in pending_ids + scheduled_ids)

    # defer_task — pending tasks only
    valid.extend(("defer_task", tid) for tid in pending_ids)

    # send_reply — unreplied messages only
    valid.extend(("send_reply", mid) for mid in unreplied_ids)

    # reject_task — pending tasks only
    valid.extend(("reject_task", tid) for tid in pending_ids)

    # ask_clarification — pending tasks or unreplied messages. Task and
    # message ids can coincide, so skip message ids already emitted for a
    # task to avoid listing the same pair twice.
    valid.extend(("ask_clarification", tid) for tid in pending_ids)
    already_listed = set(pending_ids)
    valid.extend(
        ("ask_clarification", mid)
        for mid in unreplied_ids
        if mid not in already_listed
    )

    # If no valid actions exist, allow a no-op defer
    if not valid:
        valid.append(("defer_task", 0))

    return valid


def is_valid_action(
    action_type: str, target_id: int, state_dict: Dict
) -> bool:
    """Check if a specific action is valid in the current state.

    Returns True when (action_type, target_id) is one of the pairs produced
    by ``get_valid_actions(state_dict)``.
    """
    return (action_type, target_id) in get_valid_actions(state_dict)


def get_action_mask(state_dict: Dict) -> List[int]:
    """Get a binary mask over the action space.

    Returns a list of 0s and 1s, one per entry of ``ACTIONS`` and in the same
    order, where 1 means at least one valid target exists for that action
    type.
    """
    valid_types = {action_type for action_type, _ in get_valid_actions(state_dict)}
    return [1 if action in valid_types else 0 for action in ACTIONS]