aegis_training / aegis_env /curriculum.py
YashashMathur's picture
Upload aegis_env
ab65ac6 verified
"""
Curriculum scheduling and scenario loading for AEGIS-Env.
"""
import os
import json
import random
from typing import Dict, Any, Optional
from scripts.adversarial_generator import AdversarialGenerator
MAX_TOKENS = 300 # Layer-3: token cap (~300 words)
def truncate_to_tokens(text: str, max_tokens: int = MAX_TOKENS) -> str:
"""Layer-3 fix: hardcap tokens to prevent context blowup during training."""
words = text.split()
if len(words) <= max_tokens:
return text
return " ".join(words[:max_tokens]) + " [TRUNCATED]"
def wrap_with_delimiters(text: str) -> str:
"""Layer-3 fix: fence worker output to prevent prompt injection."""
return f"[WORKER_THOUGHTS_START]\n{text}\n[WORKER_THOUGHTS_END]"
class CurriculumScheduler:
"""CUR-01, CUR-02: Support for Level 3 (Adversarial) escalation."""
@staticmethod
def get_level(training_step: int) -> int:
if training_step < 150:
return 1
if training_step < 300:
return 2
return 3
class ScenarioLoader:
"""SCN-10: Scenario loader with support for Level 3 Adversarial generation."""
def __init__(self, scenario_dir: Optional[str] = None, seed: int = 42):
self._rng = random.Random(seed)
self.seed = seed
self.scenarios_by_level: Dict[int, list] = {}
self.train_scenarios: list = []
self.eval_scenarios: list = []
self._step_counter = 0 # Layer-2: circular buffer counter
if scenario_dir and os.path.exists(scenario_dir):
for level in [1, 2]:
level_path = os.path.join(scenario_dir, f"level_{level}")
if os.path.exists(level_path):
scenarios = []
for fname in os.listdir(level_path):
if fname.endswith(".json"):
fpath = os.path.join(level_path, fname)
with open(fpath, "r", encoding="utf-8") as f:
scenarios.append(json.load(f))
self.scenarios_by_level[level] = scenarios
# 80/20 train/eval partition
all_scenarios = (
self.scenarios_by_level.get(1, [])
+ self.scenarios_by_level.get(2, [])
)
self._rng.shuffle(all_scenarios)
split = int(len(all_scenarios) * 0.8)
self.train_scenarios = all_scenarios[:split]
self.eval_scenarios = all_scenarios[split:]
def sample(self, level: int, mode: str = "train") -> Dict[str, Any]:
"""Layer-2 fix: circular buffer modulo prevents IndexError on exhaustion.
For level 3, use AdversarialGenerator on Level 2 scenarios.
"""
pool = self.train_scenarios if mode == "train" else self.eval_scenarios
if not pool:
raise ValueError(
f"No scenarios found for mode='{mode}'. Check scenario_dir."
)
# Circular modulo — never raises IndexError
idx = self._step_counter % len(pool)
self._step_counter += 1
scenario = pool[idx]
if level >= 3:
return AdversarialGenerator.generate(scenario, seed=self.seed + self._step_counter)
return scenario
def reset_counter(self) -> None:
"""Reset circular counter for a new training run."""
self._step_counter = 0