EmailMaestro / scripts /train_rl_agent.py
Flickinshots's picture
Deploy Project Epsilon Space bundle
38c9982 verified
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.executive_assistant.agent import BaselineAgent
from src.executive_assistant.config import TrainingRuntimeConfig, load_env_file
from src.executive_assistant.training import evaluate_q_policy, train_q_learning
def main() -> None:
load_env_file(TrainingRuntimeConfig().env_file)
parser = argparse.ArgumentParser(description="Train a tabular RL policy for seeded tasks.")
parser.add_argument("--episodes", type=int, default=300)
parser.add_argument("--epsilon", type=float, default=0.15)
parser.add_argument("--checkpoint", default="artifacts/checkpoints/q_policy.json")
parser.add_argument("--no-teacher", action="store_true")
args = parser.parse_args()
teacher = None if args.no_teacher else BaselineAgent()
policy, training_scores = train_q_learning(
episodes=args.episodes,
epsilon=args.epsilon,
teacher=teacher,
)
checkpoint_path = policy.save(args.checkpoint)
evaluation = evaluate_q_policy(policy)
print(
json.dumps(
{
"checkpoint": str(checkpoint_path),
"training_scores": training_scores,
"evaluation": evaluation,
},
indent=2,
)
)
if __name__ == "__main__":
main()