# Imported from upstream repository snapshot
# (author: mahammadaftab — "clean initial commit", 62851e9)
"""
Training script for the Executive Assistant environment.
Trains all 3 agents (Random, Rule-Based, Q-Learning) and compares performance.
Supports curriculum learning with auto-difficulty scaling.
"""
import sys
import os
import io

# Fix Windows console encoding so Unicode/emoji status output doesn't raise
# UnicodeEncodeError on legacy code pages (e.g. cp1252).
# Encoding names are case-insensitive ("UTF-8" vs "utf-8"), and the attribute
# may be None or the stream may lack .buffer when already wrapped/redirected —
# guard each stream independently instead of keying stderr off stdout.
if (sys.stdout.encoding or "").lower() != "utf-8" and hasattr(sys.stdout, "buffer"):
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
if (sys.stderr.encoding or "").lower() != "utf-8" and hasattr(sys.stderr, "buffer"):
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")

# Add project root to path so the `env`, `agents`, and `training` packages
# resolve when this script is run directly from its own directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env.assistant_env import ExecutiveAssistantEnv
from agents.random_agent import RandomAgent
from agents.rule_based_agent import RuleBasedAgent
from agents.rl_agent import RLAgent
from training.plots import plot_rewards, plot_comparison, plot_metrics
def train_agent(env, agent, num_episodes=100, is_rl=False, label="Agent"):
    """Run an agent in the environment for a number of episodes.

    Args:
        env: ExecutiveAssistantEnv instance.
        agent: Agent exposing an act(state) method.
        num_episodes: How many episodes to play.
        is_rl: When True, call agent.learn() after each step and
            agent.decay_epsilon() after each episode.
        label: Name used in progress log lines.

    Returns:
        Dict with per-episode "rewards" and "metrics" lists.
    """
    episode_rewards = []
    episode_metrics = []
    # 1-based counter so progress logging and the episode display line up.
    for episode in range(1, num_episodes + 1):
        obs = env.reset()
        episode_return = 0.0
        terminal = False
        while not terminal:
            chosen = agent.act(obs)
            next_obs, step_reward, terminal, info = env.step(chosen)
            if is_rl:
                agent.learn(step_reward, next_obs, terminal)
            episode_return += step_reward
            obs = next_obs
        episode_rewards.append(episode_return)
        # The final step's info carries the episode-level metrics.
        episode_metrics.append(info.get("metrics", {}))
        if is_rl:
            agent.decay_epsilon()
        # Progress logging every 25 episodes.
        if episode % 25 == 0:
            window = episode_rewards[-25:]
            avg_reward = sum(window) / len(window)
            print(
                f" [{label}] Episode {episode}/{num_episodes} | "
                f"Avg Reward (last 25): {avg_reward:.1f} | "
                f"Difficulty: {env.difficulty}"
            )
    return {"rewards": episode_rewards, "metrics": episode_metrics}
def main():
"""Run full training comparison."""
NUM_EPISODES = 200
os.makedirs("logs", exist_ok=True)
print("=" * 60)
print("πŸš€ AI Executive Assistant Simulator β€” Training")
print("=" * 60)
# ─── Random Agent ─────────────────────────────────────────
print("\nπŸ“Š Training Random Agent...")
env_random = ExecutiveAssistantEnv(difficulty="medium", seed=42)
random_agent = RandomAgent(seed=42)
random_results = train_agent(
env_random, random_agent, NUM_EPISODES, is_rl=False, label="Random"
)
# ─── Rule-Based Agent ─────────────────────────────────────
print("\nπŸ“Š Evaluating Rule-Based Agent...")
env_rule = ExecutiveAssistantEnv(difficulty="medium", seed=42)
rule_agent = RuleBasedAgent()
rule_results = train_agent(
env_rule, rule_agent, NUM_EPISODES, is_rl=False, label="RuleBased"
)
# ─── Q-Learning Agent ─────────────────────────────────────
print("\nπŸ“Š Training Q-Learning Agent...")
env_rl = ExecutiveAssistantEnv(
difficulty="easy", auto_curriculum=True, seed=42
)
rl_agent = RLAgent(
learning_rate=0.1,
discount_factor=0.95,
epsilon=1.0,
epsilon_decay=0.99,
seed=42,
)
rl_results = train_agent(
env_rl, rl_agent, NUM_EPISODES, is_rl=True, label="QLearning"
)
# ─── Results Summary ──────────────────────────────────────
print("\n" + "=" * 60)
print("πŸ“ˆ RESULTS SUMMARY")
print("=" * 60)
for name, results in [
("Random", random_results),
("Rule-Based", rule_results),
("Q-Learning", rl_results),
]:
rews = results["rewards"]
avg = sum(rews) / len(rews)
last_50_avg = sum(rews[-50:]) / min(50, len(rews))
best = max(rews)
worst = min(rews)
final_metrics = results["metrics"][-1] if results["metrics"] else {}
print(f"\n πŸ€– {name}")
print(f" Avg Reward: {avg:.1f}")
print(f" Last 50 Avg: {last_50_avg:.1f}")
print(f" Best Episode: {best:.1f}")
print(f" Worst Episode: {worst:.1f}")
if final_metrics:
print(f" Completion: {final_metrics.get('task_completion_rate', 0):.1%}")
print(f" Hi-Pri Compl: {final_metrics.get('high_priority_completion', 0):.1%}")
print(f" Msg Response: {final_metrics.get('message_response_rate', 0):.1%}")
print(f" Efficiency: {final_metrics.get('efficiency_score', 0):.1f}/100")
if rl_agent:
stats = rl_agent.get_stats()
print(f"\n 🧠 Q-Learning Stats")
print(f" Q-Table States: {stats['q_table_states']}")
print(f" Q-Table Entries: {stats['q_table_entries']}")
print(f" Final Epsilon: {stats['epsilon']}")
# ─── Generate Plots ───────────────────────────────────────
print("\nπŸ“Š Generating plots...")
plot_rewards(
{
"Random": random_results["rewards"],
"Rule-Based": rule_results["rewards"],
"Q-Learning": rl_results["rewards"],
},
save_path="logs/reward_curves.png",
)
# Compute averages for comparison
comparison_data = {}
for name, results in [
("Random", random_results),
("Rule-Based", rule_results),
("Q-Learning", rl_results),
]:
rews = results["rewards"]
mets = results["metrics"]
comparison_data[name] = {
"avg_reward": sum(rews) / len(rews),
"task_completion": (
sum(m.get("task_completion_rate", 0) for m in mets) / len(mets)
if mets else 0
),
"message_response": (
sum(m.get("message_response_rate", 0) for m in mets) / len(mets)
if mets else 0
),
"efficiency": (
sum(m.get("efficiency_score", 0) for m in mets) / len(mets)
if mets else 0
),
}
plot_comparison(comparison_data, save_path="logs/agent_comparison.png")
# Metrics for last agent
if rl_results["metrics"]:
plot_metrics(rl_results["metrics"][-1], save_path="logs/rl_metrics.png")
print("\nβœ… Training complete! Plots saved to logs/")
print("=" * 60)
# Script entry point: run the full training comparison only when executed
# directly (not when imported as a module).
if __name__ == "__main__":
    main()