# chaosops/train/grpo_train.py
"""GRPO training entry point for ChaosOps AI.
Runs on Colab T4 (0.5B model) or onsite HF-credit GPUs (7B model):
python -m chaosops.train.grpo_train \
--model-name Qwen/Qwen2.5-7B-Instruct \
--total-episodes 400 \
--group-size 4 \
--output-dir artifacts/chaosops-grpo
Design
------
* :func:`build_training_dataset` pre-rolls episodes with ``oracle_policy`` and
captures every agent turn as a dataset row. Each row is one
``(prompt, scenario, action_history)`` triple — sufficient to deterministically
reconstruct the env state for reward scoring.
* :func:`chaosops_reward` is the TRL-compatible reward function: it parses the
model's completion, replays scenario + history in a fresh env, applies the
action, and returns the per-step shaped reward (blend of team + oversight).
* GRPOTrainer samples ``group_size`` completions per prompt, computes
group-relative advantages from the rewards, and updates the LoRA adapter.
* :class:`ChaosOpsMetricsCallback` writes ``training_metrics.json`` in the
schema the Colab notebook's plot cell expects.
``rollout_episode`` / ``sample_group`` are retained for use by the dashboard
and evaluation scripts.
"""
from __future__ import annotations
import argparse
import dataclasses
import json
import statistics
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Callable
from chaosops.agents.llm_adapter import (
build_prompt,
parse_action,
)
from chaosops.agents.policies import oracle_policy
from chaosops.agents.runner import EpisodeStep
from chaosops.curriculum.generator import Curriculum, scenarios_for_tier
from chaosops.env.environment import ChaosOpsEnvironment
from chaosops.env.models import (
AgentRole,
ChaosOpsAction,
DifficultyTier,
FailureType,
)
from chaosops.env.world_sim import Scenario
from chaosops.rewards.reward_fn import combine_rewards
# ---------------------------------------------------------------------------
# Trajectory generation (kept for dashboard / eval callers)
# ---------------------------------------------------------------------------
@dataclasses.dataclass
class TurnSample:
"""One (prompt, completion, reward) triple — the unit GRPO consumes."""
prompt: str
completion: str
role: AgentRole
team_reward: float
oversight_reward: float
combined_reward: float
step: int
done: bool
GenerateFn = Callable[[str, AgentRole], str]
"""Signature: ``(prompt, role) -> completion``."""
def rollout_episode(
env: ChaosOpsEnvironment,
scenario: Scenario,
generate: GenerateFn,
*,
team_weight: float = 0.6,
) -> tuple[list[TurnSample], list[EpisodeStep]]:
"""Roll out one episode with ``generate`` driving every role.
Returns both the TurnSample list and the EpisodeStep list (1:1).
"""
observation = env.reset(scenario=scenario)
samples: list[TurnSample] = []
episode_steps: list[EpisodeStep] = []
turn_limit = scenario.max_steps * len(env.turn_order)
for turn in range(turn_limit):
role = observation.turn_role
prompt = build_prompt(observation)
completion = generate(prompt, role)
action = parse_action(completion, role=role)
next_obs = env.step(action)
breakdown = env.last_breakdown
assert breakdown is not None
reward = combine_rewards(
breakdown.team_reward, breakdown.oversight_reward, team_weight=team_weight
)
samples.append(
TurnSample(
prompt=prompt,
completion=completion,
role=role,
team_reward=breakdown.team_reward,
oversight_reward=breakdown.oversight_reward,
combined_reward=reward,
step=env.state.step_count,
done=next_obs.done,
)
)
episode_steps.append(
EpisodeStep(
turn=turn,
role=role,
observation=observation,
action=action,
reward=next_obs.reward or 0.0,
breakdown=breakdown,
done=next_obs.done,
)
)
if next_obs.done:
break
observation = next_obs
return samples, episode_steps
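# Usage sketch (illustrative, not an eval script: assumes a GenerateFn such as the
# one returned by ``make_generate_fn`` below and a scenario from ``scenarios_for_tier``):
#
#   env = ChaosOpsEnvironment()
#   scen = scenarios_for_tier(DifficultyTier.EASY, seed_offset=0, episodes_per_type=1)[0]
#   samples, steps = rollout_episode(env, scen, generate, team_weight=0.6)
#   print(len(samples), trajectory_reward(samples))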
def sample_group(
env: ChaosOpsEnvironment,
scenario: Scenario,
generate: GenerateFn,
*,
group_size: int,
team_weight: float,
) -> list[list[TurnSample]]:
"""Roll out ``group_size`` trajectories on perturbed seeds of the same scenario."""
group: list[list[TurnSample]] = []
base_seed = scenario.seed
for k in range(group_size):
perturbed = dataclasses.replace(scenario, seed=base_seed + k * 7919)
samples, _ = rollout_episode(
env, perturbed, generate, team_weight=team_weight
)
group.append(samples)
return group
def trajectory_reward(samples: Iterable[TurnSample]) -> float:
return sum(s.combined_reward for s in samples)
# ---------------------------------------------------------------------------
# Scenario / action serialization for dataset rows
# ---------------------------------------------------------------------------
def _scenario_to_json(scen: Scenario) -> str:
return json.dumps(
{
"failure_type": scen.failure_type.value,
"difficulty": scen.difficulty.value,
"seed": scen.seed,
"max_steps": scen.max_steps,
"inject_misleading_logs": scen.inject_misleading_logs,
"rogue_fleet_agent": scen.rogue_fleet_agent,
}
)
def _scenario_from_json(payload: str) -> Scenario:
d = json.loads(payload)
return Scenario(
failure_type=FailureType(d["failure_type"]),
difficulty=DifficultyTier(d["difficulty"]),
seed=int(d["seed"]),
max_steps=int(d["max_steps"]),
inject_misleading_logs=bool(d["inject_misleading_logs"]),
rogue_fleet_agent=d["rogue_fleet_agent"],
)
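# Round trip: ``_scenario_from_json(_scenario_to_json(scen))`` reproduces exactly the
# six fields serialized above, which (together with the replayed action history) is
# what the reward function relies on to rebuild the env state deterministically.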
# ---------------------------------------------------------------------------
# Dataset construction — oracle-rollout prompts
# ---------------------------------------------------------------------------
def build_training_dataset(scenarios: list[Scenario]):
"""Pre-roll every ``scenario`` with ``oracle_policy`` and collect per-turn rows.
Each row: ``{prompt, scenario_json, action_history_json, role, turn_idx}``.
The reward function uses scenario + action_history to deterministically
reconstruct the env state before scoring the model's completion.
"""
from datasets import Dataset # type: ignore[import-not-found]
rows: list[dict[str, Any]] = []
for scen in scenarios:
env = ChaosOpsEnvironment()
observation = env.reset(scenario=scen)
policy = oracle_policy(scen.failure_type)
action_history: list[dict[str, Any]] = []
turn_limit = scen.max_steps * len(env.turn_order)
for turn in range(turn_limit):
prompt = build_prompt(observation)
rows.append(
{
"prompt": prompt,
"scenario_json": _scenario_to_json(scen),
"action_history_json": json.dumps(action_history),
"role": observation.turn_role.value,
"turn_idx": turn,
}
)
action = policy(observation, observation.turn_role)
action_history.append(action.model_dump(mode="json"))
observation = env.step(action)
if observation.done:
break
return Dataset.from_list(rows)
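# Illustrative row shape (values are placeholders; ``role`` is an ``AgentRole.value``
# string and ``scenario_json`` the output of ``_scenario_to_json``):
#
#   {
#       "prompt": "<rendered observation prompt>",
#       "scenario_json": "<scenario JSON>",
#       "action_history_json": "[]",
#       "role": "<agent role>",
#       "turn_idx": 0,
#   }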
# ---------------------------------------------------------------------------
# GRPO reward function (modern TRL signature)
# ---------------------------------------------------------------------------
def make_reward_fn(team_weight: float, rogue_bonus_multiplier: float = 1.0):
"""Return a TRL-compatible reward function.
``rogue_bonus_multiplier`` scales the OversightRubric weights at score
time so the GRPO gradient on ``flag_rogue`` actions can be amplified
without touching the env's published reward formula.
"""
def chaosops_reward(
prompts: list[str],
completions: list[str],
scenario_json: list[str],
action_history_json: list[str],
role: list[str],
turn_idx: list[int],
**_kwargs: Any,
) -> list[float]:
rewards: list[float] = []
for completion, scen_js, hist_js, role_v in zip(
completions, scenario_json, action_history_json, role, strict=False
):
try:
reward = _score_completion(
completion=completion,
scen_js=scen_js,
hist_js=hist_js,
role_v=role_v,
team_weight=team_weight,
rogue_bonus_multiplier=rogue_bonus_multiplier,
)
except Exception:
# Robust to parsing / replay failures — penalise but don't crash training.
reward = -5.0
rewards.append(reward)
return rewards
return chaosops_reward
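# Smoke-test sketch: a hand-rolled call in the same keyword-list style TRL uses, with
# ``row`` one entry from ``build_training_dataset`` and ``completion`` any raw model
# output string (parse/replay failures come back as -5.0 rather than raising):
#
#   reward_fn = make_reward_fn(team_weight=0.6)
#   scores = reward_fn(
#       prompts=[row["prompt"]],
#       completions=[completion],
#       scenario_json=[row["scenario_json"]],
#       action_history_json=[row["action_history_json"]],
#       role=[row["role"]],
#       turn_idx=[row["turn_idx"]],
#   )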
def _score_completion(
*,
completion: str,
scen_js: str,
hist_js: str,
role_v: str,
team_weight: float,
rogue_bonus_multiplier: float = 1.0,
) -> float:
from chaosops.rewards.reward_fn import compute_step_reward
scen = _scenario_from_json(scen_js)
history_raw = json.loads(hist_js)
env = ChaosOpsEnvironment()
observation = env.reset(scenario=scen)
for past in history_raw:
past_action = ChaosOpsAction.model_validate(past)
observation = env.step(past_action)
if observation.done:
return 0.0
role_enum = AgentRole(role_v)
if observation.turn_role != role_enum:
# Replayed state doesn't match the captured row — treat as neutral.
return 0.0
# Completion may include chat-template artefacts; parse_action handles JSON extraction.
text = completion if isinstance(completion, str) else str(completion)
action = parse_action(text, role=role_enum)
env.step(action)
breakdown = env.last_breakdown
if breakdown is None:
return 0.0
if rogue_bonus_multiplier != 1.0:
# Re-score this step with scaled oversight rubric so the GRPO
# gradient on `flag_rogue` actions is amplified.
flags = {
"resolved": False, # post-action state already updated; re-derive flags from breakdown
"wrong_fix": breakdown.wrong_fix_penalty < 0,
"miscommunication": breakdown.miscommunication_penalty < 0,
"root_cause_correct": breakdown.early_root_cause_bonus > 0,
"rogue_flagged_correctly": breakdown.rogue_caught_bonus > 0,
"rogue_flagged_incorrectly": breakdown.rogue_false_positive_penalty < 0,
"cascade_triggered": breakdown.cascade_penalty < 0,
}
# The `resolved` flag is recoverable from env state (post-step):
flags["resolved"] = env.state.resolved
rescored = compute_step_reward(
state=env.state,
outcome_flags=flags,
rogue_bonus_multiplier=rogue_bonus_multiplier,
)
return combine_rewards(
rescored.team_reward,
rescored.oversight_reward,
team_weight=team_weight,
)
return combine_rewards(
breakdown.team_reward,
breakdown.oversight_reward,
team_weight=team_weight,
)
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_unsloth_model(
model_name: str,
*,
max_seq_length: int = 2048,
load_in_4bit: bool = True,
lora_rank: int = 32,
):
"""Load a base LLM with Unsloth + LoRA. Returns ``(model, tokenizer)``.
Requires triton + a C compiler at runtime; if either is missing,
fall back to :func:`load_transformers_model`.
"""
from unsloth import FastLanguageModel # type: ignore[import-not-found]
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
)
model = FastLanguageModel.get_peft_model(
model,
r=lora_rank,
lora_alpha=lora_rank,
lora_dropout=0.0,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
bias="none",
use_gradient_checkpointing="unsloth",
)
return model, tokenizer
def load_transformers_model(
model_name: str,
*,
max_seq_length: int = 2048,
load_in_4bit: bool = True,
lora_rank: int = 32,
):
"""Plain ``transformers + peft`` model loader — no Unsloth/triton dep.
Used when the runtime image doesn't ship triton/cc (most lightweight
CUDA images). Slightly slower per step than Unsloth but works on any
standard PyTorch image.
"""
import torch # type: ignore[import-not-found]
from peft import LoraConfig, get_peft_model # type: ignore[import-not-found]
from transformers import ( # type: ignore[import-not-found]
AutoModelForCausalLM,
AutoTokenizer,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
load_kwargs: dict[str, Any] = {}
if load_in_4bit:
try:
from transformers import BitsAndBytesConfig # type: ignore[import-not-found]
load_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.float16,
)
except Exception:
# bitsandbytes unavailable; fall back to unquantized fp16 LoRA.
load_kwargs["torch_dtype"] = torch.float16
else:
load_kwargs["torch_dtype"] = torch.float16
if torch.cuda.is_available():
load_kwargs["device_map"] = {"": 0}
base = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
lora_cfg = LoraConfig(
r=lora_rank,
lora_alpha=lora_rank,
lora_dropout=0.0,
bias="none",
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
return model, tokenizer
def load_model(
model_name: str,
*,
backend: str = "auto",
max_seq_length: int = 2048,
load_in_4bit: bool = True,
lora_rank: int = 32,
):
"""Dispatch to the requested loader, with auto-fallback.
``backend`` ∈ ``{"auto", "unsloth", "transformers"}``. ``auto`` tries
Unsloth first and falls back to transformers if the import fails or
the runtime can't satisfy triton's C-compiler dep.
"""
if backend == "transformers":
return load_transformers_model(
model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
lora_rank=lora_rank,
)
if backend == "unsloth":
return load_unsloth_model(
model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
lora_rank=lora_rank,
)
# auto
try:
return load_unsloth_model(
model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
lora_rank=lora_rank,
)
except Exception as exc:
print(f"[grpo_train] Unsloth path failed ({exc!r}); using transformers")
return load_transformers_model(
model_name,
max_seq_length=max_seq_length,
load_in_4bit=load_in_4bit,
lora_rank=lora_rank,
)
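# Usage sketch for the small Colab T4 path mentioned in the module docstring (repo id
# and hyperparameters are illustrative; ``backend="auto"`` drops to transformers when
# Unsloth or triton is unavailable):
#
#   model, tokenizer = load_model(
#       "Qwen/Qwen2.5-0.5B-Instruct",
#       backend="auto",
#       max_seq_length=1024,
#       load_in_4bit=True,
#       lora_rank=16,
#   )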
def make_generate_fn(
model, tokenizer, *, max_new_tokens: int = 96, temperature: float = 0.7
) -> GenerateFn:
"""Wrap an HF model in the ``GenerateFn`` signature used by dashboard rollouts."""
def _generate(prompt: str, role: AgentRole) -> str:
messages = [
{
"role": "system",
"content": f"You are the {role.value.upper()} agent in ChaosOps AI.",
},
{"role": "user", "content": prompt},
]
rendered = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(rendered, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=temperature > 0,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)
text = tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
)
return text
return _generate
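# Greedy-decoding sketch for dashboard/eval rollouts (``temperature=0.0`` disables
# sampling via ``do_sample=temperature > 0`` above; settings are illustrative):
#
#   generate = make_generate_fn(model, tokenizer, temperature=0.0)
#   samples, steps = rollout_episode(ChaosOpsEnvironment(), scen, generate)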
# ---------------------------------------------------------------------------
# Metrics callback — writes training_metrics.json as the plot cell expects
# ---------------------------------------------------------------------------
def _make_metrics_callback(output_dir: Path):
from transformers import TrainerCallback # type: ignore[import-not-found]
class ChaosOpsMetricsCallback(TrainerCallback):
"""Capture TRL's per-log reward stats and persist them to JSON.
The Colab notebook's plot cell reads three fields: ``mean_team_reward``,
``mean_oversight_reward``, ``mean_combined_reward``. The reward function
emits a single ``combine_rewards(team, oversight)`` scalar, so all three
fields carry that same combined value; the signal is not split during
training, and the curve still shows the reward rising as training
progresses.
"""
def __init__(self) -> None:
self.log: list[dict[str, Any]] = []
self.output_dir = output_dir
self.metrics_path = output_dir / "training_metrics.json"
output_dir.mkdir(parents=True, exist_ok=True)
def on_log(self, args, state, control, logs=None, **kwargs): # noqa: ANN001 — HF signature
if not logs:
return
reward_key_candidates = [
"reward",
"rewards/chaosops_reward/mean",
"rewards/chaosops_reward",
]
reward: float | None = None
for key in reward_key_candidates:
if key in logs:
reward = float(logs[key])
break
if reward is None:
return
entry = {
"episode": int(state.global_step),
"mean_team_reward": reward,
"mean_oversight_reward": reward,
"mean_combined_reward": reward,
}
for extra in ("loss", "kl", "reward_std"):
if extra in logs:
entry[extra] = float(logs[extra])
self.log.append(entry)
self.metrics_path.write_text(json.dumps(self.log, indent=2))
return ChaosOpsMetricsCallback()
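# Resulting training_metrics.json shape (values illustrative; the plot cell reads the
# three ``mean_*`` fields, extras like ``loss`` are optional):
#
#   [
#     {"episode": 2, "mean_team_reward": 1.4, "mean_oversight_reward": 1.4,
#      "mean_combined_reward": 1.4, "loss": 0.031},
#     ...
#   ]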
# ---------------------------------------------------------------------------
# Scenario sourcing
# ---------------------------------------------------------------------------
def _collect_scenarios(curriculum: Curriculum, *, total: int) -> list[Scenario]:
"""Pull ``total`` scenarios from the current tier, cycling failure types."""
scenarios: list[Scenario] = []
cycle_seed = 0
while len(scenarios) < total:
batch = scenarios_for_tier(
curriculum.tier,
seed_offset=cycle_seed,
episodes_per_type=1,
)
scenarios.extend(batch)
cycle_seed += 97
return scenarios[:total]
def _scenarios_from_schedule(schedule: str, *, total: int) -> list[Scenario]:
"""Build a curriculum dataset from a step-budget schedule.
Format: ``"easy:200,medium:200,hard:200"`` — generates 200 EASY then 200
MEDIUM then 200 HARD scenarios so TRL's GRPOTrainer (which iterates the
dataset in order under ``shuffle=False`` semantics for max_steps) sees
increasing difficulty over training.
If the schedule's total < ``total``, the last tier is padded by cycling
its failure types until ``total`` is reached.
"""
parsed: list[tuple[DifficultyTier, int]] = []
for chunk in schedule.split(","):
tier_name, _, count = chunk.partition(":")
tier = DifficultyTier(tier_name.strip().lower())
parsed.append((tier, int(count.strip())))
scenarios: list[Scenario] = []
for tier, count in parsed:
cycle_seed = 0
tier_scenarios: list[Scenario] = []
while len(tier_scenarios) < count:
batch = scenarios_for_tier(
tier, seed_offset=cycle_seed, episodes_per_type=1
)
tier_scenarios.extend(batch)
cycle_seed += 97
scenarios.extend(tier_scenarios[:count])
# Pad with the last tier if the schedule under-shoots ``total``.
if scenarios and len(scenarios) < total:
last_tier = parsed[-1][0]
cycle_seed = 9000 # offset past the schedule's seeds
while len(scenarios) < total:
batch = scenarios_for_tier(
last_tier, seed_offset=cycle_seed, episodes_per_type=1
)
scenarios.extend(batch)
cycle_seed += 97
return scenarios[:total]
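# Worked example (assuming each tier batch covers at least two failure types):
# ``_scenarios_from_schedule("easy:2,hard:2", total=6)`` yields 2 EASY scenarios,
# then 2 HARD, then pads with further HARD scenarios generated from seed offset 9000
# until the list reaches ``total`` and is truncated to exactly 6.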
# ---------------------------------------------------------------------------
# Training loop — modern TRL GRPO API
# ---------------------------------------------------------------------------
def run_grpo(
*,
model,
tokenizer,
total_episodes: int,
group_size: int,
team_weight: float,
curriculum: Curriculum,
log_every: int,
output_dir: Path,
max_seq_length: int = 1024,
max_completion_length: int = 96,
learning_rate: float = 5e-6,
temperature: float = 0.7,
curriculum_schedule: str | None = None,
rogue_bonus_multiplier: float = 1.0,
) -> dict[str, Any]:
"""Run GRPO training via TRL's GRPOTrainer.
``total_episodes`` caps the number of optimisation steps (``max_steps``).
Each optim step consumes one unique prompt from the dataset and rolls
``group_size`` completions — the classic GRPO group.
"""
from trl import GRPOConfig, GRPOTrainer # type: ignore[import-not-found]
output_dir.mkdir(parents=True, exist_ok=True)
scenario_count = max(total_episodes, 8)
if curriculum_schedule:
scenarios = _scenarios_from_schedule(
curriculum_schedule, total=scenario_count
)
print(
f"[grpo_train] curriculum schedule active: {curriculum_schedule} "
f"({len(scenarios)} scenarios across tiers)"
)
else:
scenarios = _collect_scenarios(curriculum, total=scenario_count)
dataset = build_training_dataset(scenarios)
# Every optim step: 1 unique prompt × group_size completions.
per_device_train_batch_size = group_size
config = GRPOConfig(
output_dir=str(output_dir),
per_device_train_batch_size=per_device_train_batch_size,
gradient_accumulation_steps=1,
num_generations=group_size,
temperature=temperature,
max_prompt_length=max_seq_length,
max_completion_length=max_completion_length,
learning_rate=learning_rate,
logging_steps=log_every,
max_steps=total_episodes,
save_steps=max(total_episodes, 10_000),
save_strategy="no",
report_to=[],
bf16=False,
fp16=True,
remove_unused_columns=False,
)
reward_fn = make_reward_fn(team_weight, rogue_bonus_multiplier=rogue_bonus_multiplier)
if rogue_bonus_multiplier != 1.0:
print(
f"[grpo_train] rogue rubric ×{rogue_bonus_multiplier} "
f"(catch={50.0 * rogue_bonus_multiplier:+.0f}, "
f"FP={-75.0 * rogue_bonus_multiplier:+.0f})"
)
metrics_callback = _make_metrics_callback(output_dir)
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
args=config,
train_dataset=dataset,
reward_funcs=[reward_fn],
callbacks=[metrics_callback],
)
trainer.train()
# Persist final LoRA adapter for downstream inference.
adapter_dir = output_dir / "lora_adapter"
try:
trainer.model.save_pretrained(str(adapter_dir))
tokenizer.save_pretrained(str(adapter_dir))
except Exception as exc: # pragma: no cover — best-effort
print(f"[grpo_train] could not save adapter: {exc}")
# Guarantee the metrics file exists for the plot cell even if no log event fired.
metrics_path = output_dir / "training_metrics.json"
if not metrics_path.exists():
metrics_path.write_text(
json.dumps(
[
{
"episode": 0,
"mean_team_reward": 0.0,
"mean_oversight_reward": 0.0,
"mean_combined_reward": 0.0,
}
],
indent=2,
)
)
rewards_collected = [e["mean_combined_reward"] for e in metrics_callback.log]
summary = {
"final_tier": curriculum.tier.value,
"total_episodes": total_episodes,
"dataset_size": len(dataset),
"group_size": group_size,
"metrics_path": str(metrics_path),
"adapter_path": str(adapter_dir),
"mean_logged_reward": (
statistics.mean(rewards_collected) if rewards_collected else float("nan")
),
}
return summary
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-name",
type=str,
default="Qwen/Qwen2.5-3B-Instruct",
help="HF repo id. Use 7B variant once GPU is provisioned.",
)
parser.add_argument("--total-episodes", type=int, default=30)
parser.add_argument("--group-size", type=int, default=2)
parser.add_argument("--team-weight", type=float, default=0.6)
parser.add_argument("--log-every", type=int, default=2)
parser.add_argument("--max-seq-length", type=int, default=1024)
parser.add_argument("--lora-rank", type=int, default=16)
parser.add_argument(
"--output-dir", type=Path, default=Path("artifacts/chaosops-grpo")
)
parser.add_argument(
"--start-tier",
type=str,
default=DifficultyTier.EASY.value,
choices=[t.value for t in DifficultyTier],
)
parser.add_argument(
"--backend",
type=str,
default="auto",
choices=["auto", "unsloth", "transformers"],
help="Model loader. 'auto' tries Unsloth, falls back to transformers.",
)
parser.add_argument(
"--learning-rate",
type=float,
default=5e-6,
help="GRPO learning rate. Default 5e-6; 2e-5 if reward stays flat.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.7,
help="Sampling temperature for completions during GRPO rollout.",
)
parser.add_argument(
"--curriculum-schedule",
type=str,
default=None,
help=(
"Step-budget tier schedule, e.g. 'easy:200,medium:200,hard:200'. "
"Overrides --start-tier when set."
),
)
parser.add_argument(
"--rogue-bonus-multiplier",
type=float,
default=1.0,
help=(
"Scale BOTH the OversightRubric rogue-catch bonus (+50) and FP "
"penalty (-75) by this factor. >1.0 amplifies the gradient on "
"flag_rogue actions; useful when prior runs collapsed off them."
),
)
return parser.parse_args()
def main() -> None:
args = _parse_args()
model, tokenizer = load_model(
args.model_name,
backend=args.backend,
max_seq_length=args.max_seq_length,
lora_rank=args.lora_rank,
)
curriculum = Curriculum(tier=DifficultyTier(args.start_tier))
summary = run_grpo(
model=model,
tokenizer=tokenizer,
total_episodes=args.total_episodes,
group_size=args.group_size,
team_weight=args.team_weight,
curriculum=curriculum,
log_every=args.log_every,
output_dir=args.output_dir,
max_seq_length=args.max_seq_length,
learning_rate=args.learning_rate,
temperature=args.temperature,
curriculum_schedule=args.curriculum_schedule,
rogue_bonus_multiplier=args.rogue_bonus_multiplier,
)
print(json.dumps(summary, indent=2))
if __name__ == "__main__":
main()