#!/usr/bin/env python3
"""REINFORCE with Learned Baseline -- RL training for PolypharmacyEnv.

Trains a small neural-network policy to perform medication reviews in the
PolypharmacyEnv environment. The policy learns to query drug-drug interactions,
propose clinical interventions, and decide when to finalise the review.

Usage examples:
    python train_rl.py --task easy_screening --episodes 200
    python train_rl.py --task budgeted_screening --episodes 500
    python train_rl.py --task complex_tradeoff --episodes 1000
    python train_rl.py --task easy_screening --episodes 300 --lr 5e-4 --batch-size 8

Architecture:
    - Fixed-size state encoding (16-dim global summary features)
    - Fixed 166-dim action space with dynamic validity masking
    - 3-layer MLP policy (state -> logits over actions)
    - 3-layer MLP value baseline (state -> scalar return estimate)
    - REINFORCE gradient with advantage = (discounted return) - baseline
    - Entropy bonus for sustained exploration
"""
from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# ---------------------------------------------------------------------------
# Environment imports (direct, no HTTP)
# ---------------------------------------------------------------------------
_BACKEND_SRC = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "backend", "src"
)
sys.path.insert(0, _BACKEND_SRC)

from polypharmacy_env.env_core import PolypharmacyEnv  # noqa: E402
from polypharmacy_env.models import (  # noqa: E402
    PolypharmacyAction,
    PolypharmacyObservation,
)
from polypharmacy_env.config import TASK_CONFIGS, TaskConfig  # noqa: E402

# ---------------------------------------------------------------------------
# Constants -- action-space geometry
# ---------------------------------------------------------------------------
MAX_MEDS = 15  # upper bound across all task difficulties
INTERVENTION_TYPES: List[str] = [
    "stop",
    "dose_reduce",
    "substitute",
    "add_monitoring",
]
N_INTERVENTION_TYPES = len(INTERVENTION_TYPES)

# Pre-compute the mapping (med_position_i, med_position_j) -> flat action index
# for all possible query_ddi pairs where i < j.
_PAIR_INDEX: Dict[Tuple[int, int], int] = {}
_idx = 0
for _i in range(MAX_MEDS):
    for _j in range(_i + 1, MAX_MEDS):
        _PAIR_INDEX[(_i, _j)] = _idx
        _idx += 1
N_PAIRS = _idx  # C(15,2) = 105
_REVERSE_PAIR: Dict[int, Tuple[int, int]] = {v: k for k, v in _PAIR_INDEX.items()}
N_INTERVENTIONS = MAX_MEDS * N_INTERVENTION_TYPES  # 60
FINISH_IDX = N_PAIRS + N_INTERVENTIONS  # 165
N_ACTIONS = FINISH_IDX + 1  # 166

# State feature vector length (see encode_state)
STATE_DIM = 16
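
# Illustrative layout check (derived only from the constants above). The flat
# action space is ordered: 105 query_ddi pairs first, then 60 interventions
# (15 meds x 4 types), then the single finish_review action:
#   _PAIR_INDEX[(0, 1)] == 0    # first pair
#   _PAIR_INDEX[(0, 14)] == 13  # last pair involving med position 0
#   N_PAIRS + 0 * N_INTERVENTION_TYPES + 0 == 105  # "stop" for med 0
assert N_PAIRS == 105 and N_INTERVENTIONS == 60 and FINISH_IDX == 165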

# ---------------------------------------------------------------------------
# State encoding
# ---------------------------------------------------------------------------
def encode_state(obs: PolypharmacyObservation, task_cfg: TaskConfig) -> torch.Tensor:
    """Encode the observation into a compact 16-dim feature vector.

    All values are normalised to roughly [0, 1] to help gradient flow.
    """
    meds = obs.current_medications
    n_meds = len(meds)
    n_high_risk = sum(1 for m in meds if m.is_high_risk_elderly)
    n_beers_any = sum(1 for m in meds if m.beers_flags)
    n_beers_avoid = sum(
        1 for m in meds if any("avoid" in f for f in m.beers_flags)
    )
    queries = obs.interaction_queries
    n_queries = len(queries)
    n_severe = sum(1 for q in queries if q.severity == "severe")
    n_moderate = sum(1 for q in queries if q.severity == "moderate")
    n_interventions = len(obs.interventions)
    max_possible_pairs = max(n_meds * (n_meds - 1) // 2, 1)

    # Drugs involved in any discovered severe DDI (among current meds)
    current_ids = {m.drug_id for m in meds}
    drugs_in_severe: Set[str] = set()
    for q in queries:
        if q.severity == "severe":
            if q.drug_id_1 in current_ids:
                drugs_in_severe.add(q.drug_id_1)
            if q.drug_id_2 in current_ids:
                drugs_in_severe.add(q.drug_id_2)

    features = [
        n_meds / MAX_MEDS,
        n_high_risk / max(n_meds, 1),
        n_beers_any / max(n_meds, 1),
        n_beers_avoid / max(n_meds, 1),
        obs.remaining_query_budget / max(task_cfg.query_budget, 1),
        obs.remaining_intervention_budget / max(task_cfg.intervention_budget, 1),
        n_queries / max(task_cfg.query_budget, 1),
        n_severe / max(n_queries, 1),
        n_moderate / max(n_queries, 1),
        n_interventions / max(task_cfg.intervention_budget, 1),
        obs.step_index / max(task_cfg.max_steps, 1),
        n_queries / max_possible_pairs,  # fraction of pairs queried
        float(obs.remaining_query_budget > 0),
        float(obs.remaining_intervention_budget > 0),
        len(drugs_in_severe) / max(n_meds, 1),  # how much of the regimen is "hot"
        float(n_meds <= 2),  # very few meds left -- may be time to finish
    ]
    return torch.tensor(features, dtype=torch.float32)
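
# Note: the features list above must contain exactly STATE_DIM (16) entries.
# If a feature is added or removed, update STATE_DIM to match -- saved
# checkpoints record this dimension and will not load otherwise.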

# ---------------------------------------------------------------------------
# Action-space helpers
# ---------------------------------------------------------------------------
def get_action_mask(obs: PolypharmacyObservation) -> torch.Tensor:
    """Return a bool tensor of shape (N_ACTIONS,). True = action is valid."""
    mask = torch.zeros(N_ACTIONS, dtype=torch.bool)
    meds = obs.current_medications
    n_meds = min(len(meds), MAX_MEDS)

    # Already-queried drug-id pairs (order-invariant)
    queried: Set[frozenset] = set()
    for q in obs.interaction_queries:
        queried.add(frozenset((q.drug_id_1, q.drug_id_2)))

    # --- query_ddi actions ---
    if obs.remaining_query_budget > 0 and n_meds >= 2:
        for i in range(n_meds):
            for j in range(i + 1, n_meds):
                pair_key = frozenset((meds[i].drug_id, meds[j].drug_id))
                if pair_key not in queried:
                    mask[_PAIR_INDEX[(i, j)]] = True

    # --- propose_intervention actions ---
    if obs.remaining_intervention_budget > 0:
        for i in range(n_meds):
            for k in range(N_INTERVENTION_TYPES):
                mask[N_PAIRS + i * N_INTERVENTION_TYPES + k] = True

    # --- finish_review (always valid) ---
    mask[FINISH_IDX] = True
    return mask


def action_idx_to_env_action(
    idx: int,
    meds: list,
) -> PolypharmacyAction:
    """Map a flat action index back to a concrete PolypharmacyAction."""
    if idx == FINISH_IDX:
        return PolypharmacyAction(action_type="finish_review")
    if idx < N_PAIRS:
        i, j = _REVERSE_PAIR[idx]
        return PolypharmacyAction(
            action_type="query_ddi",
            drug_id_1=meds[i].drug_id,
            drug_id_2=meds[j].drug_id,
        )
    # Otherwise it is an intervention action
    rel = idx - N_PAIRS
    med_idx = rel // N_INTERVENTION_TYPES
    type_idx = rel % N_INTERVENTION_TYPES
    return PolypharmacyAction(
        action_type="propose_intervention",
        target_drug_id=meds[med_idx].drug_id,
        intervention_type=INTERVENTION_TYPES[type_idx],
        rationale="rl_policy",
    )
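
# Round-trip examples (illustrative; assumes an episode where all 15 med
# slots are occupied so every index below is valid):
#   idx 0   -> query_ddi(meds[0], meds[1])
#   idx 104 -> query_ddi(meds[13], meds[14])                 # last pair
#   idx 107 -> propose_intervention(meds[0], "substitute")   # 105 + 0*4 + 2
#   idx 165 -> finish_review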

# ---------------------------------------------------------------------------
# Neural-network modules
# ---------------------------------------------------------------------------
class PolicyNetwork(nn.Module):
    """3-layer MLP that maps state features to action logits."""

    def __init__(
        self,
        state_dim: int = STATE_DIM,
        action_dim: int = N_ACTIONS,
        hidden: int = 128,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, action_dim)

    def forward(
        self,
        state: torch.Tensor,
        mask: torch.Tensor,
    ) -> Categorical:
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        logits = logits.masked_fill(~mask, float("-inf"))
        return Categorical(logits=logits)
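
# Minimal masking sketch (standalone, illustrative): an action whose logit is
# set to -inf receives probability exactly 0 under Categorical, so invalid
# actions are never sampled and contribute no gradient:
#   logits = torch.zeros(4)
#   mask = torch.tensor([True, False, True, True])
#   dist = Categorical(logits=logits.masked_fill(~mask, float("-inf")))
#   dist.probs  # -> tensor([0.3333, 0.0000, 0.3333, 0.3333])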

class ValueNetwork(nn.Module):
    """3-layer MLP baseline that estimates the expected return from a state."""

    def __init__(self, state_dim: int = STATE_DIM, hidden: int = 128) -> None:
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden)
        self.fc2 = nn.Linear(hidden, hidden // 2)
        self.fc3 = nn.Linear(hidden // 2, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        return self.fc3(x).squeeze(-1)


# ---------------------------------------------------------------------------
# Episode rollout
# ---------------------------------------------------------------------------
def run_episode(
    env: PolypharmacyEnv,
    task_id: str,
    policy: PolicyNetwork,
    value_net: ValueNetwork,
    task_cfg: TaskConfig,
    seed: Optional[int] = None,
    greedy: bool = False,
) -> Dict[str, Any]:
    """Roll out one full episode, collecting the REINFORCE trajectory.

    When *greedy* is True the policy acts deterministically (argmax) and
    gradients are not recorded. Used for evaluation.
    """
    obs = env.reset(task_id=task_id, seed=seed)
    states: List[torch.Tensor] = []
    actions: List[torch.Tensor] = []
    log_probs: List[torch.Tensor] = []
    rewards: List[float] = []
    values: List[torch.Tensor] = []
    entropies: List[torch.Tensor] = []
    grader_score = 0.0
    while not obs.done:
        state = encode_state(obs, task_cfg)
        mask = get_action_mask(obs)
        # Safety: if somehow no action is valid, force finish
        if not mask.any():
            mask[FINISH_IDX] = True
        if greedy:
            with torch.no_grad():
                dist = policy(state, mask)
                action_idx = dist.probs.argmax()
                value = value_net(state)
        else:
            with torch.no_grad():
                value = value_net(state)
            dist = policy(state, mask)
            action_idx = dist.sample()
        log_prob = dist.log_prob(action_idx)
        entropy = dist.entropy()
        states.append(state)
        actions.append(action_idx)
        log_probs.append(log_prob)
        values.append(value)
        entropies.append(entropy)
        env_action = action_idx_to_env_action(
            action_idx.item(), obs.current_medications
        )
        obs = env.step(env_action)
        reward = float(obs.reward) if obs.reward is not None else 0.0
        rewards.append(reward)
        if obs.done:
            grader_score = obs.metadata.get("grader_score", 0.0)
    return {
        "states": states,
        "actions": actions,
        "log_probs": log_probs,
        "rewards": rewards,
        "values": values,
        "entropies": entropies,
        "grader_score": grader_score,
        "total_reward": sum(rewards),
        "n_steps": len(rewards),
    }


# ---------------------------------------------------------------------------
# Return computation
# ---------------------------------------------------------------------------
def compute_returns(rewards: List[float], gamma: float = 0.99) -> torch.Tensor:
    """Discounted cumulative returns (G_t) for each timestep."""
    returns: List[float] = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    return torch.tensor(returns, dtype=torch.float32)
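
# Worked example: rewards [1.0, 0.0, 2.0] with gamma = 0.5 give
#   G_2 = 2.0
#   G_1 = 0.0 + 0.5 * 2.0 = 1.0
#   G_0 = 1.0 + 0.5 * 1.0 = 1.5
# so compute_returns([1.0, 0.0, 2.0], 0.5) -> tensor([1.5000, 1.0000, 2.0000])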

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train(args: argparse.Namespace) -> None:  # noqa: C901 (complex but linear)
    task_id: str = args.task
    n_episodes: int = args.episodes
    lr: float = args.lr
    gamma: float = args.gamma
    entropy_coeff: float = args.entropy_coeff
    batch_size: int = args.batch_size
    hidden_dim: int = args.hidden_dim
    print_every: int = args.print_every
    task_cfg = TASK_CONFIGS[task_id]

    # ---- Initialise env & networks ----------------------------------------
    env = PolypharmacyEnv()
    policy = PolicyNetwork(STATE_DIM, N_ACTIONS, hidden=hidden_dim)
    value_net = ValueNetwork(STATE_DIM, hidden=hidden_dim)
    policy_optim = torch.optim.Adam(policy.parameters(), lr=lr)
    value_optim = torch.optim.Adam(value_net.parameters(), lr=lr * 3)

    # ---- Book-keeping -----------------------------------------------------
    ckpt_dir = Path(args.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = Path(args.metrics_file)
    episode_rewards: List[float] = []
    episode_grader_scores: List[float] = []
    episode_steps: List[int] = []
    episode_policy_losses: List[float] = []
    episode_value_losses: List[float] = []
    best_avg_score: float = -float("inf")

    print("=" * 72)
    print("REINFORCE Training -- PolypharmacyEnv")
    print("=" * 72)
    print(f"  task           : {task_id}")
    print(f"  episodes       : {n_episodes}")
    print(f"  batch_size     : {batch_size}")
    print(f"  lr             : {lr}")
    print(f"  gamma          : {gamma}")
    print(f"  entropy_coeff  : {entropy_coeff}")
    print(f"  hidden_dim     : {hidden_dim}")
    print(f"  state_dim      : {STATE_DIM}")
    print(f"  action_space   : {N_ACTIONS}")
    print(f"  task budgets   : query={task_cfg.query_budget} "
          f"intervention={task_cfg.intervention_budget} "
          f"max_steps={task_cfg.max_steps}")
    print(f"  checkpoint_dir : {ckpt_dir}")
    print(f"  metrics_file   : {metrics_path}")
    print("=" * 72)
    print()
    t_start = time.time()

    # ---- Main training loop -----------------------------------------------
    # Accumulate a mini-batch of trajectories, then perform one gradient step.
    batch_trajs: List[Dict[str, Any]] = []
    for ep in range(1, n_episodes + 1):
        traj = run_episode(env, task_id, policy, value_net, task_cfg, seed=ep)
        episode_rewards.append(traj["total_reward"])
        episode_grader_scores.append(traj["grader_score"])
        episode_steps.append(traj["n_steps"])
        if traj["n_steps"] == 0:
            # Degenerate episode (should not happen); skip update
            continue
        batch_trajs.append(traj)

        # ---- Gradient step every batch_size episodes ----------------------
        if len(batch_trajs) >= batch_size:
            # Aggregate losses across the batch
            total_policy_loss = torch.tensor(0.0)
            total_value_loss = torch.tensor(0.0)
            total_entropy = torch.tensor(0.0)
            total_steps = 0
            for bt in batch_trajs:
                returns = compute_returns(bt["rewards"], gamma)
                old_values_t = torch.stack(bt["values"])  # detached, from rollout
                log_probs_t = torch.stack(bt["log_probs"])
                entropies_t = torch.stack(bt["entropies"])
                # Advantages use the *detached* rollout values as baseline
                advantages = returns - old_values_t.detach()
                # Per-trajectory advantage normalisation (reduces variance)
                if len(advantages) > 1:
                    advantages = (advantages - advantages.mean()) / (
                        advantages.std() + 1e-8
                    )
                # REINFORCE policy gradient (negative because we minimise)
                total_policy_loss = total_policy_loss + (
                    -(log_probs_t * advantages).sum()
                )
                # Recompute value predictions WITH gradients for the value loss
                states_t = torch.stack(bt["states"])
                fresh_values = value_net(states_t)
                total_value_loss = total_value_loss + F.mse_loss(
                    fresh_values, returns, reduction="sum"
                )
                # Entropy (we want to maximise -> subtract from loss)
                total_entropy = total_entropy + entropies_t.sum()
                total_steps += len(bt["rewards"])

            # Normalise by total number of timesteps in the batch
            total_policy_loss = total_policy_loss / total_steps
            total_value_loss = total_value_loss / total_steps
            total_entropy = total_entropy / total_steps

            # Combined policy loss with entropy bonus
            combined_policy_loss = total_policy_loss - entropy_coeff * total_entropy
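            # In expectation the minimised objective is
            #   L = -E[log pi(a_t|s_t) * A_t] - entropy_coeff * E[H(pi(.|s_t))]
            # with A_t = G_t - V(s_t), normalised per trajectory above.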
            policy_optim.zero_grad()
            combined_policy_loss.backward()
            nn.utils.clip_grad_norm_(policy.parameters(), max_norm=1.0)
            policy_optim.step()

            value_optim.zero_grad()
            total_value_loss.backward()
            nn.utils.clip_grad_norm_(value_net.parameters(), max_norm=1.0)
            value_optim.step()

            episode_policy_losses.append(total_policy_loss.item())
            episode_value_losses.append(total_value_loss.item())
            batch_trajs = []

        # ---- Logging ------------------------------------------------------
        if ep % print_every == 0 or ep == 1:
            window = min(print_every, ep)
            recent_r = episode_rewards[-window:]
            recent_s = episode_grader_scores[-window:]
            recent_st = episode_steps[-window:]
            avg_r = sum(recent_r) / len(recent_r)
            avg_s = sum(recent_s) / len(recent_s)
            avg_st = sum(recent_st) / len(recent_st)
            elapsed = time.time() - t_start
            print(
                f"[ep {ep:>4d}/{n_episodes}] "
                f"avg_reward={avg_r:+.4f} "
                f"avg_grader={avg_s:.4f} "
                f"avg_steps={avg_st:.1f} "
                f"elapsed={elapsed:.1f}s"
            )

        # Save best checkpoint based on rolling grader score
        eval_window = min(30, ep)
        rolling_score = sum(episode_grader_scores[-eval_window:]) / eval_window
        if rolling_score > best_avg_score:
            best_avg_score = rolling_score
            _save_checkpoint(
                policy, value_net, policy_optim, value_optim,
                ep, best_avg_score, task_id,
                ckpt_dir / f"best_{task_id}.pt",
            )

    # ---- Final checkpoint -------------------------------------------------
    _save_checkpoint(
        policy, value_net, policy_optim, value_optim,
        n_episodes, best_avg_score, task_id,
        ckpt_dir / f"final_{task_id}.pt",
    )

    # ---- Save training metrics to JSON ------------------------------------
    metrics = {
        "task_id": task_id,
        "n_episodes": n_episodes,
        "hyperparameters": {
            "lr": lr,
            "gamma": gamma,
            "entropy_coeff": entropy_coeff,
            "batch_size": batch_size,
            "hidden_dim": hidden_dim,
            "state_dim": STATE_DIM,
            "action_dim": N_ACTIONS,
        },
        "episode_rewards": episode_rewards,
        "episode_grader_scores": episode_grader_scores,
        "episode_steps": episode_steps,
        "policy_losses": episode_policy_losses,
        "value_losses": episode_value_losses,
        "best_avg_grader_score": best_avg_score,
        "total_training_time_s": time.time() - t_start,
    }
    metrics_path.parent.mkdir(parents=True, exist_ok=True)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    print(f"\nTraining metrics saved to {metrics_path}")

    # ---- Post-training evaluation -----------------------------------------
    n_eval = 20
    print("\n" + "=" * 72)
    print(f"Post-training evaluation ({n_eval} episodes each mode)")
    print("=" * 72)
    for mode, is_greedy in [("stochastic", False), ("greedy", True)]:
        eval_rewards, eval_scores, eval_steps_list = [], [], []
        for i in range(n_eval):
            traj = run_episode(
                env, task_id, policy, value_net, task_cfg,
                seed=10_000 + i, greedy=is_greedy,
            )
            eval_rewards.append(traj["total_reward"])
            eval_scores.append(traj["grader_score"])
            eval_steps_list.append(traj["n_steps"])
        avg_r = sum(eval_rewards) / len(eval_rewards)
        avg_s = sum(eval_scores) / len(eval_scores)
        avg_st = sum(eval_steps_list) / len(eval_steps_list)
        print(
            f"  [{mode:>10s}] avg_reward={avg_r:+.4f} "
            f"avg_grader={avg_s:.4f} avg_steps={avg_st:.1f}"
        )
        metrics[f"eval_{mode}_avg_reward"] = avg_r
        metrics[f"eval_{mode}_avg_grader_score"] = avg_s
        metrics[f"eval_{mode}_avg_steps"] = avg_st
        metrics[f"eval_{mode}_rewards"] = eval_rewards
        metrics[f"eval_{mode}_grader_scores"] = eval_scores

    print(f"  best training rolling-avg grader: {best_avg_score:.4f}")
    print()
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)
    print("Done.")


# ---------------------------------------------------------------------------
# Checkpoint I/O
# ---------------------------------------------------------------------------
def _save_checkpoint(
    policy: PolicyNetwork,
    value_net: ValueNetwork,
    policy_optim: torch.optim.Optimizer,
    value_optim: torch.optim.Optimizer,
    episode: int,
    best_score: float,
    task_id: str,
    path: Path,
) -> None:
    torch.save(
        {
            "episode": episode,
            "best_avg_grader_score": best_score,
            "task_id": task_id,
            "policy_state_dict": policy.state_dict(),
            "value_state_dict": value_net.state_dict(),
            "policy_optim_state_dict": policy_optim.state_dict(),
            "value_optim_state_dict": value_optim.state_dict(),
            "state_dim": STATE_DIM,
            "action_dim": N_ACTIONS,
        },
        path,
    )


def load_checkpoint(
    path: Path,
    hidden_dim: int = 128,
) -> Tuple[PolicyNetwork, ValueNetwork]:
    """Load a trained policy + value net from a checkpoint file."""
    ckpt = torch.load(path, map_location="cpu")
    policy = PolicyNetwork(
        ckpt.get("state_dim", STATE_DIM),
        ckpt.get("action_dim", N_ACTIONS),
        hidden=hidden_dim,
    )
    value_net = ValueNetwork(ckpt.get("state_dim", STATE_DIM), hidden=hidden_dim)
    policy.load_state_dict(ckpt["policy_state_dict"])
    value_net.load_state_dict(ckpt["value_state_dict"])
    return policy, value_net
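
# Illustrative inference sketch (hypothetical checkpoint path; assumes the
# checkpoint was trained with the default hidden_dim=128):
#   policy, value_net = load_checkpoint(Path("checkpoints/best_easy_screening.pt"))
#   env = PolypharmacyEnv()
#   traj = run_episode(env, "easy_screening", policy, value_net,
#                      TASK_CONFIGS["easy_screening"], seed=0, greedy=True)
#   print(traj["grader_score"], traj["n_steps"])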

# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="REINFORCE training for PolypharmacyEnv",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument(
        "--task",
        type=str,
        default="easy_screening",
        choices=list(TASK_CONFIGS.keys()),
        help="Task difficulty to train on",
    )
    p.add_argument("--episodes", type=int, default=200, help="Number of training episodes")
    p.add_argument("--batch-size", type=int, default=5, help="Episodes per gradient update")
    p.add_argument("--lr", type=float, default=3e-4, help="Learning rate for Adam")
    p.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    p.add_argument(
        "--entropy-coeff", type=float, default=0.02,
        help="Entropy bonus coefficient (higher = more exploration)",
    )
    p.add_argument("--hidden-dim", type=int, default=128, help="Hidden layer width")
    p.add_argument("--print-every", type=int, default=10, help="Print interval (episodes)")
    p.add_argument(
        "--checkpoint-dir",
        type=str,
        default=os.path.join(_BACKEND_SRC, "polypharmacy_env", "checkpoints"),
        help="Directory to save model checkpoints",
    )
    p.add_argument(
        "--metrics-file",
        type=str,
        default="training_metrics.json",
        help="Path for JSON training metrics",
    )
    return p.parse_args()


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = parse_args()
    train(args)