| """ |
| Action definitions and action masking for the Executive Assistant environment. |
| |
| Provides the action vocabulary, parsing, and validity checking to prevent |
| agents from taking illegal actions. |
| """ |
|
|
import numbers
from typing import List, Dict, Tuple, Optional
|
|
|
|
| |
|
|
# Closed vocabulary of action types. The position of each name in this list
# is its integer action index: parse_action maps int input through it, and
# get_action_mask orders its mask entries by it.
ACTIONS = [
    "schedule_task",
    "complete_task",
    "defer_task",
    "send_reply",
    "reject_task",
    "ask_clarification",
]

# Two-way lookups between an action's index in ACTIONS and its name.
IDX_TO_ACTION = dict(enumerate(ACTIONS))
ACTION_TO_IDX = {name: idx for idx, name in IDX_TO_ACTION.items()}
|
|
|
|
| |
|
|
def parse_action(action_input) -> Tuple[str, int]:
    """Parse various action input formats into ``(action_type, target_id)``.

    Supports:
        - Tuple/list: ("complete_task", 3)
        - Dict: {"action": "complete_task", "target": 3}
        - String: "complete_task" (target defaults to 0)
        - Int: 1 (maps to action index, target defaults to 0). Any
          ``numbers.Integral`` (for example a ``numpy.int64`` sampled from a
          policy) is accepted; ``bool`` is not treated as an index.

    Lenient fallbacks:
        - An unknown action name, an out-of-range index, an empty
          tuple/list, or any other input type yields ``"defer_task"``.
        - A missing target, or a target of ``None`` (e.g. JSON ``null``),
          yields 0.

    Args:
        action_input: The raw action in one of the formats above.

    Returns:
        (action_type_str, target_id)

    Raises:
        ValueError, TypeError: If a present, non-``None`` target cannot be
            converted with ``int()`` (e.g. ``"abc"``).
    """
    action_type = "defer_task"
    target_id = 0

    if isinstance(action_input, (tuple, list)):
        # Guard the index so an empty sequence keeps the defer fallback
        # instead of raising IndexError.
        if len(action_input) > 0:
            action_type = str(action_input[0])
        if len(action_input) > 1 and action_input[1] is not None:
            target_id = int(action_input[1])
    elif isinstance(action_input, dict):
        action_type = action_input.get("action", "defer_task")
        raw_target = action_input.get("target")
        # A missing key and an explicit None both mean "no target" -> 0.
        if raw_target is not None:
            target_id = int(raw_target)
    elif isinstance(action_input, numbers.Integral) and not isinstance(
        action_input, bool
    ):
        # numbers.Integral also matches numpy integer scalars, which
        # isinstance(x, int) rejects; bool is excluded because True/False
        # are not meaningful action indices.
        action_type = IDX_TO_ACTION.get(int(action_input), "defer_task")
    elif isinstance(action_input, str):
        action_type = action_input

    # Anything that is not a known action name is downgraded to defer_task.
    if action_type not in ACTIONS:
        action_type = "defer_task"

    return action_type, target_id
|
|
|
|
| |
|
|
def get_valid_actions(state_dict: Dict) -> List[Tuple[str, int]]:
    """List every legal ``(action_type, target_id)`` pair for ``state_dict``.

    Reads ``state_dict["tasks"]`` (each item needs ``"id"`` and ``"status"``)
    and ``state_dict["inbox"]`` (each item needs ``"id"``; ``"replied"`` is
    optional and treated as False when absent). Both keys default to empty.

    Masking rules, emitted in this order:
        - schedule_task: pending tasks
        - complete_task: pending tasks, then scheduled tasks
        - defer_task: pending tasks
        - send_reply: unreplied messages
        - reject_task: pending tasks
        - ask_clarification: pending tasks, then unreplied messages

    Task ids and message ids share the ``target_id`` slot, so a task and a
    message with the same id yield indistinguishable pairs.

    If no rule matches, returns ``[("defer_task", 0)]`` so callers always
    receive at least one action.
    """
    tasks = state_dict.get("tasks", [])
    inbox = state_dict.get("inbox", [])

    pending = [task for task in tasks if task["status"] == "pending"]
    scheduled = [task for task in tasks if task["status"] == "scheduled"]
    unanswered = [msg for msg in inbox if not msg.get("replied", False)]

    # Each rule pairs an action type with the items it may target; the tuple
    # order fixes the order of the returned list.
    rules = (
        ("schedule_task", pending),
        ("complete_task", pending + scheduled),
        ("defer_task", pending),
        ("send_reply", unanswered),
        ("reject_task", pending),
        ("ask_clarification", pending + unanswered),
    )

    legal = [(kind, item["id"]) for kind, items in rules for item in items]
    return legal or [("defer_task", 0)]
|
|
|
|
def is_valid_action(
    action_type: str, target_id: int, state_dict: Dict
) -> bool:
    """Return True when ``(action_type, target_id)`` is legal in ``state_dict``.

    Legality is membership in the list produced by ``get_valid_actions``,
    which is rebuilt on every call. That list's ``("defer_task", 0)``
    fallback makes that pair valid whenever nothing else is.
    """
    legal = get_valid_actions(state_dict)
    candidate = (action_type, target_id)
    return candidate in legal
|
|
|
|
def get_action_mask(state_dict: Dict) -> List[int]:
    """Summarise legality per action type as a 0/1 list aligned with ``ACTIONS``.

    Entry ``i`` is 1 when ``get_valid_actions`` yields at least one pair whose
    type is ``ACTIONS[i]``, otherwise 0. Because that function falls back to
    ``("defer_task", 0)`` when nothing else is legal, the mask is never all
    zeros.
    """
    legal_types = {kind for kind, _target in get_valid_actions(state_dict)}
    return [int(name in legal_types) for name in ACTIONS]
|
|