""" ImmunoOrg 2.0: Trajectory Generator for GRPO Training ======================================================= Executes scenarios and records observation → action → reward trajectories suitable for GRPO training with TRL (Transformers Reinforcement Learning). """ import json import gzip import logging from typing import List, Dict, Any, Optional, Tuple from pathlib import Path from dataclasses import dataclass, field, asdict import time logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) # ============================================================ # DATA CLASSES # ============================================================ @dataclass class TrajectoryFrame: """Single step in a trajectory.""" step: int observation: Dict[str, Any] action: Dict[str, Any] reward: float terminated: bool @dataclass class Trajectory: """Complete episode trajectory.""" scenario_id: str dataset_type: str difficulty: int seed: int frames: List[TrajectoryFrame] = field(default_factory=list) cumulative_reward: float = 0.0 num_frames: int = 0 avg_reward: float = 0.0 max_steps: int = 0 time_to_containment: Optional[int] = None status: str = "pending" # pending | completed | failed error_message: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { "scenario_id": self.scenario_id, "dataset_type": self.dataset_type, "difficulty": self.difficulty, "seed": self.seed, "frames": [ { "step": f.step, "observation": f.observation, "action": f.action, "reward": f.reward, "terminated": f.terminated } for f in self.frames ], "cumulative_reward": self.cumulative_reward, "num_frames": self.num_frames, "avg_reward": self.avg_reward, "max_steps": self.max_steps, "time_to_containment": self.time_to_containment, "status": self.status, "error_message": self.error_message } # ============================================================ # TRAJECTORY GENERATOR # ============================================================ class TrajectoryGenerator: """ Generates training trajectories by executing scenarios in the environment. This class: - Loads scenario definitions - Initializes the environment with scenario configs - Runs agent to completion - Records observation → action → reward frames - Computes trajectory statistics """ def __init__(self, env=None, agent=None, output_dir: str = "training/trajectories"): """ Initialize trajectory generator. Args: env: ImmunoOrgEnvironment instance agent: Agent with act(obs) -> action method output_dir: Directory for saving trajectories """ self.env = env self.agent = agent self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) self.trajectories = [] self.stats = { "total_trajectories": 0, "successful": 0, "failed": 0, "total_frames": 0, "avg_reward_overall": 0.0 } def generate_trajectory( self, scenario: Dict[str, Any], max_steps: Optional[int] = None, verbose: bool = False ) -> Trajectory: """ Execute one scenario and record trajectory. Args: scenario: Scenario configuration dictionary max_steps: Override max steps (if None, use scenario config) verbose: Print progress Returns: Trajectory object with all frames and stats """ if not self.env or not self.agent: raise ValueError("Environment and agent must be initialized before generating trajectories") scenario_id = scenario.get("scenario_id", "unknown") difficulty = scenario.get("difficulty", 1) seed = scenario.get("seed", 42) dataset_type = scenario.get("dataset_type", "unknown") config = scenario.get("config", {}) # Get max steps from config max_steps = max_steps or config.get("max_steps", 100) # Create trajectory object trajectory = Trajectory( scenario_id=scenario_id, dataset_type=dataset_type, difficulty=difficulty, seed=seed, max_steps=max_steps ) try: # Reset environment with scenario seed obs = self.env.reset(seed=seed) if verbose: logger.info(f"Executing {scenario_id} (difficulty={difficulty}, seed={seed})") step = 0 containment_time = None # Run episode to completion or max steps while step < max_steps: # Agent observes and acts action = self.agent.act(obs) # Environment steps next_obs, reward, terminated = self.env.step(action) # Serialize observation and action obs_dict = self._serialize_observation(obs) action_dict = self._serialize_action(action) # Record frame frame = TrajectoryFrame( step=step, observation=obs_dict, action=action_dict, reward=float(reward), terminated=terminated ) trajectory.frames.append(frame) trajectory.cumulative_reward += reward # Track containment time (first step where threats are contained) if containment_time is None and self._is_contained(next_obs): containment_time = step if terminated: trajectory.status = "completed" break obs = next_obs step += 1 # Compute trajectory stats trajectory.num_frames = len(trajectory.frames) trajectory.avg_reward = ( trajectory.cumulative_reward / max(1, len(trajectory.frames)) ) trajectory.time_to_containment = containment_time if trajectory.status == "pending": trajectory.status = "truncated" # Max steps reached self.stats["total_trajectories"] += 1 self.stats["successful"] += 1 self.stats["total_frames"] += trajectory.num_frames self.trajectories.append(trajectory) if verbose: logger.info( f" → {scenario_id}: {trajectory.num_frames} frames, " f"reward={trajectory.cumulative_reward:.3f}, " f"avg={trajectory.avg_reward:.3f}" ) return trajectory except Exception as e: logger.error(f"Error executing {scenario_id}: {str(e)}") trajectory.status = "failed" trajectory.error_message = str(e) self.stats["failed"] += 1 return trajectory def generate_trajectories_batch( self, scenarios: List[Dict[str, Any]], max_parallel: int = 1, verbose: bool = True ) -> List[Trajectory]: """ Generate trajectories for multiple scenarios. Args: scenarios: List of scenario configurations max_parallel: Number of parallel workers (currently 1, TODO: implement parallelism) verbose: Print progress Returns: List of Trajectory objects """ trajectories = [] total = len(scenarios) logger.info(f"Generating {total} trajectories...") start_time = time.time() for i, scenario in enumerate(scenarios): trajectory = self.generate_trajectory(scenario, verbose=verbose) trajectories.append(trajectory) if (i + 1) % max(1, total // 10) == 0: elapsed = time.time() - start_time rate = (i + 1) / elapsed remaining = (total - i - 1) / rate if rate > 0 else 0 logger.info( f"Progress: {i+1}/{total} ({100*(i+1)/total:.1f}%) | " f"Elapsed: {elapsed:.1f}s | ETA: {remaining:.1f}s" ) elapsed = time.time() - start_time logger.info(f"Batch complete: {len(trajectories)} trajectories in {elapsed:.1f}s") return trajectories def save_trajectories( self, trajectories: List[Trajectory], filename: str, compress: bool = True ) -> str: """ Save trajectories to file. Args: trajectories: List of Trajectory objects filename: Output filename compress: Compress with gzip Returns: Path to saved file """ output_path = self.output_dir / filename # Convert to dictionaries trajectory_dicts = [t.to_dict() for t in trajectories] if compress and filename.endswith('.json'): output_path = output_path.with_suffix('.json.gz') with gzip.open(str(output_path), 'wt', encoding='utf-8') as f: json.dump(trajectory_dicts, f, indent=2) logger.info(f"Saved {len(trajectories)} trajectories to {output_path} (compressed)") else: with open(output_path, 'w', encoding='utf-8') as f: json.dump(trajectory_dicts, f, indent=2) logger.info(f"Saved {len(trajectories)} trajectories to {output_path}") return str(output_path) def _serialize_observation(self, obs) -> Dict[str, Any]: """Serialize observation for storage.""" if hasattr(obs, 'dict'): return obs.dict() elif isinstance(obs, dict): return obs else: return {"raw": str(obs)} def _serialize_action(self, action) -> Dict[str, Any]: """Serialize action for storage.""" if hasattr(action, 'dict'): return action.dict() elif isinstance(action, dict): return action else: return {"raw": str(action)} def _is_contained(self, obs) -> bool: """Check if all threats are contained based on observation.""" if isinstance(obs, dict): # Check for threat indicators detected_attacks = obs.get("detected_attacks", []) threat_level = obs.get("threat_level", 0.0) return len(detected_attacks) == 0 and threat_level < 0.1 elif hasattr(obs, 'detected_attacks') and hasattr(obs, 'threat_level'): return len(obs.detected_attacks) == 0 and obs.threat_level < 0.1 return False def print_stats(self): """Print trajectory generation statistics.""" total = self.stats["total_trajectories"] successful = self.stats["successful"] failed = self.stats["failed"] if total > 0: self.stats["avg_reward_overall"] = sum( t.cumulative_reward for t in self.trajectories ) / total stats_str = f""" TRAJECTORY GENERATION STATISTICS ================================= Total Trajectories: {total} - Successful: {successful} ({100*successful/max(1,total):.1f}%) - Failed: {failed} ({100*failed/max(1,total):.1f}%) Total Frames: {self.stats['total_frames']} Average Frames per Trajectory: {self.stats['total_frames']/max(1,total):.1f} Average Reward (Overall): {self.stats['avg_reward_overall']:.4f} Difficulty Distribution: """ for d in range(1, 5): count = sum(1 for t in self.trajectories if t.difficulty == d) if count > 0: avg_reward = sum( t.cumulative_reward for t in self.trajectories if t.difficulty == d ) / count stats_str += f" - Difficulty {d}: {count} trajectories (avg_reward={avg_reward:.4f})\n" logger.info(stats_str) # ============================================================ # GRPO DATASET CONVERTER # ============================================================ class GRPODatasetConverter: """ Converts trajectories to GRPO training format for TRL. GRPO Training Inputs: - prompt: LLM input (formatted observation) - completion: LLM output (formatted action) - reward: scalar feedback (0-1) """ def __init__(self, tokenizer=None): """ Initialize converter. Args: tokenizer: HuggingFace tokenizer for prompt/completion """ self.tokenizer = tokenizer def convert_trajectory_to_grpo(self, trajectory: Trajectory) -> List[Dict[str, Any]]: """ Convert single trajectory to GRPO training samples. Args: trajectory: Trajectory object Returns: List of {prompt, completion, reward} dictionaries """ grpo_samples = [] for frame in trajectory.frames: obs = frame.observation action = frame.action reward = frame.reward # Build prompt from observation prompt = self._build_llm_prompt(obs, trajectory.difficulty) # Format action as completion completion = self._format_action_as_completion(action) grpo_samples.append({ "prompt": prompt, "completion": completion, "reward": reward, "scenario_id": trajectory.scenario_id, "step": frame.step, "difficulty": trajectory.difficulty }) return grpo_samples def convert_trajectories_to_grpo(self, trajectories: List[Trajectory]) -> List[Dict[str, Any]]: """ Convert multiple trajectories to GRPO training data. Args: trajectories: List of Trajectory objects Returns: List of GRPO training samples """ grpo_data = [] for trajectory in trajectories: samples = self.convert_trajectory_to_grpo(trajectory) grpo_data.extend(samples) logger.info(f"Converted {len(trajectories)} trajectories to {len(grpo_data)} GRPO samples") return grpo_data def _build_llm_prompt(self, obs: Dict[str, Any], difficulty: int) -> str: """ Build LLM prompt from observation. Args: obs: Observation dictionary difficulty: Current difficulty level Returns: Formatted prompt string """ prompt = f"""You are the Patronus AI, an autonomous self-healing enterprise agent. PHASE: {obs.get('current_phase', 'unknown')} DIFFICULTY: {difficulty} THREAT_LEVEL: {obs.get('threat_level', 0.0):.2f} STEP: {obs.get('step_count', 0)} | TIME: {obs.get('sim_time', 0.0):.0f} === BOARD DIRECTIVES === {chr(10).join(obs.get('directives', [])) if obs.get('directives') else 'None'} === RAG CVE INTELLIGENCE === {chr(10).join(obs.get('alerts', [])) if obs.get('alerts') else 'No alerts'} === NETWORK STATE === Visible Nodes: {len(obs.get('visible_nodes', []))} | Health: {obs.get('network_health_summary', {}).get('average_health', 0.0):.1f} Detected Threats: {len(obs.get('detected_attacks', []))} === ORG STATE === Departments: {len(obs.get('org_nodes', []))} | Pending Approvals: {len(obs.get('pending_approvals', []))} TASK: Analyze the situation. Return your reasoning and chosen action. FORMAT: REASONING: | ACTION: | DETAIL: | TARGET: """ return prompt def _format_action_as_completion(self, action: Dict[str, Any]) -> str: """ Format action as LLM completion. Args: action: Action dictionary Returns: Formatted action string """ reasoning = action.get("reasoning", "Executing standard procedure") action_type = action.get("action_type", "DIAGNOSTIC") detail = action.get("tactical_action") or action.get("strategic_action") or action.get("diagnostic_action") or "QUERY_BELIEF_MAP" target = action.get("target", "system") completion = f"REASONING: {reasoning} | ACTION: {action_type} | DETAIL: {detail} | TARGET: {target}" return completion def save_grpo_data( self, grpo_data: List[Dict[str, Any]], filename: str, output_dir: str = "training/grpo_data", compress: bool = True ) -> str: """ Save GRPO training data to file. Args: grpo_data: List of GRPO training samples filename: Output filename output_dir: Output directory compress: Compress with gzip Returns: Path to saved file """ output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) output_file = output_path / filename if compress and filename.endswith('.json'): output_file = output_file.with_suffix('.json.gz') with gzip.open(str(output_file), 'wt', encoding='utf-8') as f: json.dump(grpo_data, f, indent=2) logger.info(f"Saved {len(grpo_data)} GRPO samples to {output_file} (compressed)") else: with open(output_file, 'w', encoding='utf-8') as f: json.dump(grpo_data, f, indent=2) logger.info(f"Saved {len(grpo_data)} GRPO samples to {output_file}") return str(output_file) # ============================================================ # UTILITY FUNCTIONS # ============================================================ def load_scenarios(filepath: str) -> List[Dict[str, Any]]: """ Load scenarios from JSON file (handles gzip). Args: filepath: Path to scenarios file Returns: List of scenario dictionaries """ if filepath.endswith('.gz'): with gzip.open(filepath, 'rt', encoding='utf-8') as f: return json.load(f) else: with open(filepath, 'r', encoding='utf-8') as f: return json.load(f) def load_trajectories(filepath: str) -> List[Trajectory]: """ Load trajectories from JSON file (handles gzip). Args: filepath: Path to trajectories file Returns: List of Trajectory objects """ if filepath.endswith('.gz'): with gzip.open(filepath, 'rt', encoding='utf-8') as f: data = json.load(f) else: with open(filepath, 'r', encoding='utf-8') as f: data = json.load(f) trajectories = [] for traj_dict in data: traj = Trajectory( scenario_id=traj_dict["scenario_id"], dataset_type=traj_dict["dataset_type"], difficulty=traj_dict["difficulty"], seed=traj_dict["seed"], cumulative_reward=traj_dict["cumulative_reward"], num_frames=traj_dict["num_frames"], avg_reward=traj_dict["avg_reward"], max_steps=traj_dict["max_steps"], time_to_containment=traj_dict.get("time_to_containment"), status=traj_dict.get("status", "unknown") ) # Note: frames not loaded for memory efficiency trajectories.append(traj) return trajectories