| """ |
| Tabular Q-Learning RL agent. |
| |
| Uses state hashing and epsilon-greedy exploration with action masking. |
| Designed as a lightweight RL agent that can learn effective strategies |
| without requiring deep learning frameworks. |
| """ |
|
|
| import random |
| import json |
| import hashlib |
| from typing import Dict, Tuple, Optional |
| from collections import defaultdict |
|
|
|
|
| class RLAgent: |
| """Tabular Q-learning agent with epsilon-greedy exploration. |
| |
| Features: |
| - State hashing for tabular lookup. |
| - Action masking (only considers valid actions). |
| - Epsilon decay for explore → exploit transition. |
| - Learning rate and discount factor tuning. |
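
    Example (illustrative usage; the observation shape below is an
    assumption inferred from this module, not a documented contract):

        >>> agent = RLAgent(seed=0)
        >>> obs = {"time": "09:00", "tasks": [], "inbox": [],
        ...        "valid_actions": [("defer_task", 0)]}
        >>> agent.act(obs)
        ('defer_task', 0)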
| """ |
|
|
| def __init__( |
| self, |
| learning_rate: float = 0.1, |
| discount_factor: float = 0.95, |
| epsilon: float = 1.0, |
| epsilon_min: float = 0.05, |
| epsilon_decay: float = 0.995, |
| seed: Optional[int] = None, |
| ): |
| """Initialize the Q-learning agent. |
| |
| Args: |
| learning_rate: Alpha for Q-value updates. |
| discount_factor: Gamma for future reward discounting. |
| epsilon: Initial exploration rate. |
| epsilon_min: Minimum exploration rate. |
| epsilon_decay: Multiplicative decay per episode. |
| seed: Random seed. |
| """ |
        self.lr = learning_rate
        self.gamma = discount_factor
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay

        # Q-table: state hash -> {action key -> Q-value}, defaulting to 0.0
        # for unseen (state, action) pairs.
        self.q_table: Dict[str, Dict[str, float]] = defaultdict(
            lambda: defaultdict(float)
        )

        # (state, action) remembered from the last act() call; learn()
        # applies its TD update to this pair.
        self.last_state_hash: Optional[str] = None
        self.last_action_key: Optional[str] = None

        # Dedicated RNG instance so seeding the agent does not perturb the
        # global `random` module state shared with other code.
        self.rng = random.Random(seed)

    def _hash_state(self, state: Dict) -> str:
        """Create a compact hash of the state for table lookup.

        We hash key features rather than the full state for generalization:
        - Current time
        - Number of pending/completed/missed tasks by priority
        - Number of unreplied messages by urgency
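
        For example, two observations that differ only in fields the feature
        set ignores map to the same table entry ("notes" is a hypothetical
        extra field used purely for illustration):

            >>> a = RLAgent()
            >>> s1 = {"time": "09:00", "tasks": [], "inbox": []}
            >>> s2 = {"time": "09:00", "tasks": [], "inbox": [], "notes": "x"}
            >>> a._hash_state(s1) == a._hash_state(s2)
            True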
| """ |
| features = { |
| "time": state.get("time", ""), |
| "pending_high": sum( |
| 1 for t in state.get("tasks", []) |
| if t["status"] == "pending" and t["priority"] == "high" |
| ), |
| "pending_med": sum( |
| 1 for t in state.get("tasks", []) |
| if t["status"] == "pending" and t["priority"] == "medium" |
| ), |
| "pending_low": sum( |
| 1 for t in state.get("tasks", []) |
| if t["status"] == "pending" and t["priority"] == "low" |
| ), |
| "completed": sum( |
| 1 for t in state.get("tasks", []) |
| if t["status"] == "completed" |
| ), |
| "missed": sum( |
| 1 for t in state.get("tasks", []) |
| if t["status"] == "missed" |
| ), |
| "urgent_unreplied": sum( |
| 1 for m in state.get("inbox", []) |
| if m.get("urgency") == "high" and not m.get("replied", False) |
| ), |
| "normal_unreplied": sum( |
| 1 for m in state.get("inbox", []) |
| if m.get("urgency") != "high" and not m.get("replied", False) |
| ), |
| } |
|
|
| feature_str = json.dumps(features, sort_keys=True) |
| return hashlib.md5(feature_str.encode()).hexdigest()[:12] |

    def _action_key(self, action: Tuple[str, int]) -> str:
        """Convert an action tuple to a string key for Q-table lookup."""
        return f"{action[0]}:{action[1]}"

    def _parse_action_key(self, key: str) -> Tuple[str, int]:
        """Convert an action key back to a tuple."""
        # rsplit guards against action types that themselves contain ":".
        action_type, target_id = key.rsplit(":", 1)
        return (action_type, int(target_id))

    def act(self, state: Dict) -> Tuple[str, int]:
        """Choose an action using an epsilon-greedy policy with action masking.

        Args:
            state: Observation dict from the environment.

        Returns:
            (action_type, target_id) tuple.
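
        Example (greedy path; epsilon is forced to 0 and a Q-value is
        planted by hand purely for illustration):

            >>> a = RLAgent(epsilon=0.0)
            >>> obs = {"time": "09:00", "tasks": [], "inbox": [],
            ...        "valid_actions": [("do_task", 1), ("defer_task", 1)]}
            >>> a.q_table[a._hash_state(obs)]["do_task:1"] = 5.0
            >>> a.act(obs)
            ('do_task', 1)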
| """ |
| valid_actions = state.get("valid_actions", []) |
| if not valid_actions: |
| return ("defer_task", 0) |
|
|
| state_hash = self._hash_state(state) |
|
|
| |
| if random.random() < self.epsilon: |
| action = random.choice(valid_actions) |
| else: |
| |
| q_values = self.q_table[state_hash] |
| best_action = None |
| best_q = float("-inf") |
|
|
| for va in valid_actions: |
| ak = self._action_key(va) |
| q = q_values[ak] |
| if q > best_q: |
| best_q = q |
| best_action = va |
|
|
| action = best_action if best_action else random.choice(valid_actions) |
|
|
| |
| self.last_state_hash = state_hash |
| self.last_action_key = self._action_key(action) |
|
|
| return action |

    def learn(
        self,
        reward: float,
        next_state: Dict,
        done: bool,
    ):
        """Update Q-values using the Q-learning update rule.

        Q(s,a) ← Q(s,a) + α[r + γ·max_a' Q(s',a') - Q(s,a)]

        Args:
            reward: Reward received from the last action.
            next_state: New observation after the action.
            done: Whether the episode ended.
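
        Example (worked numbers): with α=0.1, γ=0.95, current Q(s,a)=0,
        r=1, and max_a' Q(s',a')=2, the update yields
        0 + 0.1·(1 + 0.95·2 − 0) = 0.29.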
| """ |
| if self.last_state_hash is None or self.last_action_key is None: |
| return |
|
|
| current_q = self.q_table[self.last_state_hash][self.last_action_key] |
|
|
| if done: |
| target = reward |
| else: |
| |
| next_hash = self._hash_state(next_state) |
| next_valid = next_state.get("valid_actions", []) |
| if next_valid: |
| max_next_q = max( |
| self.q_table[next_hash][self._action_key(a)] |
| for a in next_valid |
| ) |
| else: |
| max_next_q = 0.0 |
|
|
| target = reward + self.gamma * max_next_q |
|
|
| |
| self.q_table[self.last_state_hash][self.last_action_key] = ( |
| current_q + self.lr * (target - current_q) |
| ) |

    def decay_epsilon(self):
        """Decay exploration rate after each episode."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def get_q_table_size(self) -> int:
        """Return the number of unique states seen."""
        return len(self.q_table)

    def get_stats(self) -> Dict:
        """Return agent statistics."""
        total_entries = sum(len(v) for v in self.q_table.values())
        return {
            "q_table_states": len(self.q_table),
            "q_table_entries": total_entries,
            "epsilon": round(self.epsilon, 4),
        }

    def __repr__(self):
        return (
            f"RLAgent(lr={self.lr}, gamma={self.gamma}, "
            f"epsilon={self.epsilon:.3f}, states={len(self.q_table)})"
        )
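

if __name__ == "__main__":
    # Minimal smoke-test sketch, not the real environment: the observation
    # dict below is a hypothetical stand-in shaped the way _hash_state() and
    # act() expect, and the reward rule is illustrative only.
    agent = RLAgent(seed=42)
    obs = {
        "time": "09:00",
        "tasks": [{"status": "pending", "priority": "high"}],
        "inbox": [{"urgency": "high", "replied": False}],
        "valid_actions": [("do_task", 0), ("reply_message", 0)],
    }
    for _ in range(5):
        action = agent.act(obs)
        # Assumed payoff: finishing the task is worth more than replying.
        reward = 1.0 if action[0] == "do_task" else 0.5
        agent.learn(reward, obs, done=True)
        agent.decay_epsilon()
    print(agent, agent.get_stats())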
|