# mahammadaftab's picture
# clean initial commit
# 62851e9
"""
Action definitions and action masking for the Executive Assistant environment.
Provides the action vocabulary, parsing, and validity checking to prevent
agents from taking illegal actions.
"""
import numbers
from typing import List, Dict, Tuple, Optional
# ─── Action Vocabulary ────────────────────────────────────────────────────────
ACTIONS = [
    "schedule_task",      # 0: Schedule a pending task into a time slot
    "complete_task",      # 1: Mark a scheduled/pending task as completed
    "defer_task",         # 2: Defer a task to a later time
    "send_reply",         # 3: Reply to an inbox message
    "reject_task",        # 4: Reject / cancel a task
    "ask_clarification",  # 5: Ask for clarification on a task or message
]
# Bidirectional lookups between action names and their integer indices.
ACTION_TO_IDX = {a: i for i, a in enumerate(ACTIONS)}
IDX_TO_ACTION = {i: a for i, a in enumerate(ACTIONS)}


# ─── Action Parsing ──────────────────────────────────────────────────────────
def _resolve_action_type(raw) -> str:
    """Map a raw action identifier (name or integer index) to a known action name.

    Integral values -- plain ``int`` as well as numpy integer scalars, which
    register themselves as ``numbers.Integral`` -- are looked up in
    ``IDX_TO_ACTION``. Anything else is converted with ``str()``. Unknown
    names and out-of-range indices fall back to ``"defer_task"``.
    """
    if isinstance(raw, numbers.Integral):
        name = IDX_TO_ACTION.get(int(raw), "defer_task")
    else:
        name = str(raw)
    return name if name in ACTIONS else "defer_task"


def _coerce_target(raw) -> int:
    """Convert a raw target id to ``int``; a ``None`` target becomes 0."""
    return 0 if raw is None else int(raw)


def parse_action(action_input) -> Tuple[str, int]:
    """Parse various action input formats into ``(action_type, target_id)``.

    Supports:
        - Tuple/list: ``("complete_task", 3)`` or index form ``(1, 3)``
        - Dict: ``{"action": "complete_task", "target": 3}``; ``"action"``
          may also be an integer index
        - String: ``"complete_task"`` (target defaults to 0)
        - Int (including numpy integer scalars): ``1`` maps to the action at
          that index (target defaults to 0)

    Any input that cannot be mapped to a known action yields ``"defer_task"``.
    This covers an unknown name, an out-of-range index, an empty tuple/list
    and an unsupported type. A missing or ``None`` target yields 0.

    Returns:
        (action_type_str, target_id)

    Raises:
        ValueError, TypeError: if a supplied target cannot be converted by
            ``int()``.
    """
    if isinstance(action_input, (tuple, list)):
        if not action_input:
            # An empty sequence carries no action type; fall back instead of
            # raising IndexError, matching how other malformed input is treated.
            return "defer_task", 0
        raw_type = action_input[0]
        raw_target = action_input[1] if len(action_input) > 1 else 0
    elif isinstance(action_input, dict):
        raw_type = action_input.get("action", "defer_task")
        raw_target = action_input.get("target", 0)
    elif isinstance(action_input, (str, numbers.Integral)):
        # A bare name or index carries no target information.
        raw_type = action_input
        raw_target = 0
    else:
        raw_type = "defer_task"
        raw_target = 0
    return _resolve_action_type(raw_type), _coerce_target(raw_target)
# ─── Action Masking ──────────────────────────────────────────────────────────
def get_valid_actions(state_dict: Dict) -> List[Tuple[str, int]]:
    """List every legal ``(action_type, target_id)`` pair for ``state_dict``.

    ``state_dict`` is read through its ``"tasks"`` and ``"inbox"`` keys, each
    defaulting to an empty list.

    - Each task must provide ``"status"`` and ``"id"``.
    - Each message must provide ``"id"`` and may provide ``"replied"``, which
      is treated as False when absent.

    Masking rules, emitted in this order:
        - schedule_task:     pending tasks
        - complete_task:     pending tasks, then scheduled tasks
        - defer_task:        pending tasks
        - send_reply:        unreplied messages
        - reject_task:       pending tasks
        - ask_clarification: pending tasks, then unreplied messages

    When none of the above applies, the single no-op ``("defer_task", 0)``
    is returned so there is always at least one legal action.
    """
    task_list = state_dict.get("tasks", [])
    message_list = state_dict.get("inbox", [])

    pending_ids = [task["id"] for task in task_list if task["status"] == "pending"]
    scheduled_ids = [task["id"] for task in task_list if task["status"] == "scheduled"]
    open_msg_ids = [msg["id"] for msg in message_list if not msg.get("replied", False)]

    actions: List[Tuple[str, int]] = []
    actions += [("schedule_task", tid) for tid in pending_ids]
    actions += [("complete_task", tid) for tid in pending_ids + scheduled_ids]
    actions += [("defer_task", tid) for tid in pending_ids]
    actions += [("send_reply", mid) for mid in open_msg_ids]
    actions += [("reject_task", tid) for tid in pending_ids]
    # NOTE(review): task ids and message ids share the ask_clarification target
    # space. If a task and a message have the same id, the same pair appears
    # twice here -- confirm the two id ranges are disjoint.
    actions += [("ask_clarification", oid) for oid in pending_ids + open_msg_ids]

    # Fall back to a no-op defer so the caller always has something legal.
    return actions or [("defer_task", 0)]
def is_valid_action(
    action_type: str, target_id: int, state_dict: Dict
) -> bool:
    """Return True iff ``(action_type, target_id)`` is legal in ``state_dict``.

    Legality is decided solely by membership in ``get_valid_actions``'s
    result for the same state.
    """
    legal_moves = get_valid_actions(state_dict)
    candidate = (action_type, target_id)
    return candidate in legal_moves
def get_action_mask(state_dict: Dict) -> List[int]:
    """Build a 0/1 mask aligned index-for-index with ``ACTIONS``.

    Entry ``i`` is 1 when ``get_valid_actions`` offers at least one target
    for action type ``ACTIONS[i]`` in ``state_dict``, and 0 otherwise.
    """
    available_kinds = {entry[0] for entry in get_valid_actions(state_dict)}
    return [int(kind in available_kinds) for kind in ACTIONS]