Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| from src.executive_assistant.agent import BaselineAgent | |
| from src.executive_assistant.config import TrainingRuntimeConfig, load_env_file | |
| from src.executive_assistant.training import evaluate_q_policy, train_q_learning | |
| def main() -> None: | |
| load_env_file(TrainingRuntimeConfig().env_file) | |
| parser = argparse.ArgumentParser(description="Train a tabular RL policy for seeded tasks.") | |
| parser.add_argument("--episodes", type=int, default=300) | |
| parser.add_argument("--epsilon", type=float, default=0.15) | |
| parser.add_argument("--checkpoint", default="artifacts/checkpoints/q_policy.json") | |
| parser.add_argument("--no-teacher", action="store_true") | |
| args = parser.parse_args() | |
| teacher = None if args.no_teacher else BaselineAgent() | |
| policy, training_scores = train_q_learning( | |
| episodes=args.episodes, | |
| epsilon=args.epsilon, | |
| teacher=teacher, | |
| ) | |
| checkpoint_path = policy.save(args.checkpoint) | |
| evaluation = evaluate_q_policy(policy) | |
| print( | |
| json.dumps( | |
| { | |
| "checkpoint": str(checkpoint_path), | |
| "training_scores": training_scores, | |
| "evaluation": evaluation, | |
| }, | |
| indent=2, | |
| ) | |
| ) | |
| if __name__ == "__main__": | |
| main() | |