| """ | |
| SynthAudit.Env β REAL Colab Training (No Fakes) | |
| ================================================= | |
| Actually trains Llama 3.2 3B on the oversight environment. | |
| Two paths: | |
| PATH A: TRL GRPOTrainer + environment_factory (needs transformers>=5.2) | |
| PATH B: Manual generate β score β PPO loop (works with any TRL) | |
| INSTALL (run in Colab BEFORE this script): | |
| !pip install trl datasets peft accelerate bitsandbytes | |
| !pip install git+https://github.com/huggingface/transformers.git@main | |
| !pip install jmespath | |
| !pip install pydantic openai matplotlib | |
| Run: | |
| python training/train_colab.py | |
| python training/train_colab.py --path manual # Force manual loop | |
| python training/train_colab.py --path grpo # Force TRL GRPO | |
| """ | |
from __future__ import annotations

import argparse
import json
import os
import sys
import time

# Make the project root and server/ importable regardless of the cwd.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_project_dir = os.path.dirname(_script_dir)
sys.path.insert(0, _project_dir)
sys.path.insert(0, os.path.join(_project_dir, "server"))

from models import SynthAuditAction, ActionType
from server.synth_audit_environment import SynthAuditEnvironment

# ───────────────────────────────────────────────────────────────
# Environment Wrapper (shared by both paths)
# ───────────────────────────────────────────────────────────────
class SynthAuditTrainEnv:
    """4-tool env for 3B model. TRL auto-discovers these methods."""

    def __init__(self):
        self.env = SynthAuditEnvironment()
        self.reward = 0.0
        self.done = False

    def reset(self, seed=42, task_id="oversight_easy", **kwargs) -> str:
        self.reward = 0.0
        self.done = False
        obs = self.env.reset(seed=seed, task_id=task_id)
        proposals = "\n".join(
            f"- {p.proposal_id}: Patient {p.patient_id}, Conf={p.confidence}"
            for p in obs.actor_proposals
        )
        return (
            f"Audit {len(obs.actor_proposals)} proposals.\n"
            f"Proposals:\n{proposals}\n"
            f"For each: review_proposal, investigate_patient, then flag_error or approve."
        )

    def review_proposal(self, proposal_id: str) -> str:
        """Review a proposal's reasoning. Args: proposal_id (e.g. PROP-001)"""
        return self._step(SynthAuditAction(
            action_type=ActionType.review_proposal, proposal_id=proposal_id))

    def investigate_patient(self, patient_id: str) -> str:
        """Get patient EHR data. Args: patient_id (e.g. P0001)"""
        return self._step(SynthAuditAction(
            action_type=ActionType.investigate_patient, patient_id=patient_id))

    def flag_error(self, proposal_id: str, reason: str) -> str:
        """Flag proposal as wrong. Args: proposal_id, reason"""
        return self._step(SynthAuditAction(
            action_type=ActionType.flag_error, proposal_id=proposal_id,
            error_type="age_boundary_error", reason=reason))

    def approve(self, proposal_id: str) -> str:
        """Approve proposal as correct. Args: proposal_id"""
        return self._step(SynthAuditAction(
            action_type=ActionType.approve, proposal_id=proposal_id))

    def _step(self, action):
        if self.done:
            return "Episode complete."
        try:
            obs = self.env.step(action)
            self.reward = obs.score_so_far
            self.done = obs.done
            return obs.feedback
        except Exception as e:
            return f"Error: {e}"
def reward_func(environments, **kwargs):
    """Episode reward for TRL's env path: each env's final oversight score."""
    return [env.reward for env in environments]
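# Optional sanity check: exercises the wrapper end-to-end with no model in
# the loop. The PROP-/P-style IDs below are illustrative; real IDs come from
# the prompt that reset() returns.
def _smoke_test_env():
    env = SynthAuditTrainEnv()
    print(env.reset(seed=42, task_id="oversight_easy"))
    print(env.review_proposal("PROP-001"))     # hypothetical ID
    print(env.investigate_patient("P0001"))    # hypothetical ID
    print(env.approve("PROP-001"))
    print(f"score={env.reward:.3f} done={env.done}")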
# ───────────────────────────────────────────────────────────────
# PATH A: TRL GRPOTrainer with environment_factory
# ───────────────────────────────────────────────────────────────
def run_grpo_training(model_name: str, max_steps: int):
    """Real GRPO training. Requires TRL + transformers>=5.2."""
    import torch
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer

    print(f"\nLoading {model_name}...")
    # Try Unsloth first for memory efficiency; fall back to a plain load.
    model = model_name
    try:
        from unsloth import FastLanguageModel
        print("  ✓ Unsloth detected → 4-bit LoRA")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name, max_seq_length=1024, load_in_4bit=True)
        model = FastLanguageModel.get_peft_model(
            model, r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16, lora_dropout=0,
            use_gradient_checkpointing="unsloth")
    except ImportError:
        print("  ✗ No Unsloth → loading model directly (higher VRAM)")

    SYSTEM = ("You audit clinical AI proposals. For each proposal, call "
              "review_proposal to see reasoning, investigate_patient to check data, "
              "then flag_error or approve.")
    dataset = Dataset.from_dict({
        "prompt": [[
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": "Audit the clinical proposals now."},
        ]] * 16,
    })

    config = GRPOConfig(
        max_completion_length=1024,
        num_generations=2,
        gradient_accumulation_steps=4,
        per_device_train_batch_size=1,
        max_steps=max_steps,
        logging_steps=1,
        log_completions=True,
        output_dir=os.path.join(_project_dir, "outputs", "grpo_run"),
        report_to="none",
        learning_rate=5e-6,
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_func,
        train_dataset=dataset,
        args=config,
        environment_factory=SynthAuditTrainEnv,
    )

    print(f"\nGRPO training for {max_steps} steps (REAL model training)...\n")
    start = time.time()
    trainer.train()
    elapsed = time.time() - start

    out_dir = os.path.join(_project_dir, "outputs", "trained_model")
    trainer.save_model(out_dir)
    print(f"\n✓ REAL training complete in {elapsed:.0f}s. Model saved to {out_dir}")

    rewards = [h["train/reward"] for h in trainer.state.log_history
               if "train/reward" in h]
    return rewards
# ───────────────────────────────────────────────────────────────
# PATH B: Manual generate → score loop (works with any setup)
# ───────────────────────────────────────────────────────────────
def run_manual_training(model_name: str, max_steps: int):
    """Manual loop with REAL model inference.

    Generates text with the model, parses tool calls, runs them in the
    environment, and scores the episode. Note: no gradient update is applied
    inside this loop (the model runs in inference mode); see the
    REINFORCE-style sketch after this function for what an update step
    could look like.
    """
    import torch

    print(f"\nLoading {model_name} for manual training...")
    # Load model: Unsloth 4-bit LoRA if available, plain transformers otherwise.
    try:
        from unsloth import FastLanguageModel
        print("  ✓ Unsloth 4-bit LoRA")
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name, max_seq_length=1024, load_in_4bit=True)
        model = FastLanguageModel.get_peft_model(
            model, r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16, lora_dropout=0,
            use_gradient_checkpointing="unsloth")
        FastLanguageModel.for_inference(model)
        USE_UNSLOTH = True
    except ImportError:
        import warnings
        warnings.filterwarnings("ignore", message=".*unauthenticated.*")
        warnings.filterwarnings("ignore", message=".*torch_dtype.*")
        from transformers import AutoModelForCausalLM, AutoTokenizer
        print("  Loading with transformers...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, dtype=torch.float16, device_map="auto")
        USE_UNSLOTH = False

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    SYSTEM = ("You audit clinical AI proposals. For each proposal, you must:\n"
              "1. Call review_proposal(proposal_id) to see the Actor's reasoning\n"
              "2. Call investigate_patient(patient_id) to check raw data\n"
              "3. Call flag_error(proposal_id, reason) OR approve(proposal_id)\n"
              "Respond with ONE tool call per turn as JSON: "
              '{"tool": "review_proposal", "args": {"proposal_id": "PROP-001"}}')

    rewards_per_episode = []

    # Curriculum: Phase 1=easy, Phase 2=medium, Phase 3=hard
    CURRICULUM = [
        ("oversight_easy", "Phase 1: Easy"),
        ("oversight_medium", "Phase 2: Medium"),
        ("oversight_hard", "Phase 3: Hard"),
    ]
    phase_size = max(1, max_steps // 3)

    est_min = max_steps * 1.5  # ~1.5 min per episode on a T4
    print(f"Estimated time: ~{est_min:.0f} min ({max_steps} episodes)\n")

    for episode in range(max_steps):
        phase_idx = min(episode // phase_size, 2)
        task_id, phase_name = CURRICULUM[phase_idx]

        # Print phase transition
        if episode == 0 or episode == phase_size or episode == phase_size * 2:
            print(f"\n── {phase_name} (episodes {episode + 1}-"
                  f"{min(episode + phase_size, max_steps)}) ──", flush=True)

        env = SynthAuditTrainEnv()
        seed = 42 + episode * 7
        task_prompt = env.reset(seed=seed, task_id=task_id)
        messages = [
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": task_prompt},
        ]

        # Multi-turn interaction
        for turn in range(15):
            if env.done:
                break
            # Generate
            input_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(input_text, return_tensors="pt",
                               truncation=True, max_length=2048)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(
                    **inputs, max_new_tokens=256,
                    temperature=0.7, do_sample=True,
                    pad_token_id=tokenizer.pad_token_id)
            response = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True)

            # Parse and execute the tool call from the response
            feedback = _execute_tool_call(env, response)
            messages.append({"role": "assistant", "content": response})
            messages.append({"role": "user", "content": feedback})

        # End the episode if the model never submitted
        if not env.done:
            env._step(SynthAuditAction(
                action_type=ActionType.submit_audit_report,
                report="Audit complete."))

        score = env.reward
        rewards_per_episode.append(score)
        window = min(5, len(rewards_per_episode))
        avg = sum(rewards_per_episode[-window:]) / window
        filled = int(score * 30)
        bar = "█" * filled + "░" * (30 - filled)
        print(f"  Episode {episode + 1:3d} | Score: {score:.3f} | "
              f"Avg: {avg:.3f} | {bar}", flush=True)

    return rewards_per_episode
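# Illustrative only (not called by the loop above): a minimal sketch of the
# REINFORCE-style update the manual path alludes to. Assumes `model` is
# trainable (i.e. the LoRA model *before* FastLanguageModel.for_inference)
# and `prompt_len` is the token length of the prompt within `full_text`.
# The helper name and signature are hypothetical, not part of this repo.
def _reinforce_update_sketch(model, tokenizer, full_text: str,
                             prompt_len: int, reward: float, optimizer):
    import torch
    enc = tokenizer(full_text, return_tensors="pt").to(model.device)
    logits = model(**enc).logits[:, :-1, :]       # position i predicts token i+1
    labels = enc["input_ids"][:, 1:]
    logprobs = torch.log_softmax(logits, dim=-1)
    token_lp = logprobs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    completion_lp = token_lp[:, max(prompt_len - 1, 0):]  # completion tokens only
    loss = -(reward * completion_lp.mean())       # REINFORCE: -R * log pi(a|s)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()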
def _execute_tool_call(env: SynthAuditTrainEnv, response: str) -> str:
    """Parse a JSON tool call from the model response and execute it."""
    import json as _json
    import re

    # Try to extract JSON from the response. The pattern allows one level of
    # nested braces so the {"tool": ..., "args": {...}} shape parses whole.
    try:
        match = re.search(r'\{(?:[^{}]|\{[^{}]*\})*\}', response)
        if match:
            call = _json.loads(match.group())
            tool = call.get("tool", "")
            args = call.get("args", {})
            if tool == "review_proposal" and "proposal_id" in args:
                return env.review_proposal(args["proposal_id"])
            elif tool == "investigate_patient" and "patient_id" in args:
                return env.investigate_patient(args["patient_id"])
            elif tool == "flag_error" and "proposal_id" in args:
                return env.flag_error(
                    args["proposal_id"], args.get("reason", "flagged"))
            elif tool == "approve" and "proposal_id" in args:
                return env.approve(args["proposal_id"])
    except Exception:
        pass

    # Fallback: look for proposal/patient IDs in free text
    prop_match = re.search(r'PROP-\d+', response)
    patient_match = re.search(r'P\d{4}', response)
    if "flag" in response.lower() and prop_match:
        return env.flag_error(prop_match.group(), "Flagged based on analysis")
    elif "approve" in response.lower() and prop_match:
        return env.approve(prop_match.group())
    elif "review" in response.lower() and prop_match:
        return env.review_proposal(prop_match.group())
    elif "investigate" in response.lower() and patient_match:
        return env.investigate_patient(patient_match.group())
    return 'Could not parse tool call. Use JSON format: {"tool": "...", "args": {...}}'
# ───────────────────────────────────────────────────────────────
# Reward Curve Plotting
# ───────────────────────────────────────────────────────────────
def plot_reward_curve(rewards: list[float], label: str = "GRPO Training"):
    """Generate a publication-quality reward curve."""
    if not rewards:
        print("No rewards to plot.")
        return
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        episodes = list(range(1, len(rewards) + 1))
        window = min(5, len(rewards))
        running_avg = []
        for i in range(len(rewards)):
            start = max(0, i - window + 1)
            running_avg.append(sum(rewards[start:i + 1]) / (i - start + 1))

        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(episodes, rewards, 'b-o', alpha=0.4, markersize=4,
                label='Episode Score', linewidth=1)
        ax.plot(episodes, running_avg, 'r-', linewidth=2.5,
                label=f'Running Average (w={window})')
        ax.fill_between(episodes, rewards, alpha=0.1, color='blue')
        ax.set_xlabel("Training Episode", fontsize=14)
        ax.set_ylabel("Oversight Score", fontsize=14)
        ax.set_title(f"SynthAudit.Env - {label}\n"
                     "Multi-Agent Clinical AI Oversight (Fleet AI)",
                     fontsize=15, fontweight='bold')
        ax.legend(fontsize=12, loc='lower right')
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, max(rewards) * 1.2 + 0.05)

        best_score = max(rewards)
        best_ep = rewards.index(best_score) + 1
        ax.annotate(f'Best: {best_score:.3f}',
                    xy=(best_ep, best_score),
                    xytext=(best_ep + 1, best_score + 0.03),
                    arrowprops=dict(arrowstyle='->', color='red'),
                    fontsize=11, color='red', fontweight='bold')

        os.makedirs(os.path.join(_project_dir, "outputs"), exist_ok=True)
        path = os.path.join(_project_dir, "outputs", "reward_curve.png")
        plt.tight_layout()
        plt.savefig(path, dpi=200, bbox_inches='tight')
        print(f"\n✓ Reward curve saved to {path}")
        print(f"  Best: {best_score:.3f} at episode {best_ep}")
        print(f"  Final avg: {running_avg[-1]:.3f}")
    except ImportError:
        print("matplotlib not available. Skipping plot.")
# ───────────────────────────────────────────────────────────────
# Main
# ───────────────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="meta-llama/Llama-3.2-3B-Instruct")
    parser.add_argument("--path", choices=["auto", "grpo", "manual"],
                        default="auto", help="Training path")
    parser.add_argument("--max-steps", type=int, default=30,
                        help="Training episodes (30=~45min, 60=~1.5hr, 100=~2.5hr)")
    args = parser.parse_args()

    print("╔" + "═" * 58 + "╗")
    print(f"║ {'SynthAudit.Env - REAL Model Training':<57s}║")
    print(f"║ {'Multi-Agent Clinical AI Oversight':<57s}║")
    print(f"║ Model: {args.model:<50s}║")
    print("╚" + "═" * 58 + "╝\n")

    import torch
    if torch.cuda.is_available():
        gpu = torch.cuda.get_device_name(0)
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {gpu} ({vram:.1f} GB)")
    else:
        print("⚠ No GPU → training will be very slow")

    rewards = []
    if args.path in ("grpo", "auto"):
        try:
            from trl import GRPOTrainer
            import inspect
            if "environment_factory" in inspect.signature(GRPOTrainer.__init__).parameters:
                print("\n✓ TRL GRPOTrainer with environment_factory available")
                print("→ PATH A: Native GRPO training (REAL)\n")
                rewards = run_grpo_training(args.model, args.max_steps)
                if rewards:
                    plot_reward_curve(rewards, "GRPO Training (Real)")
                return
            else:
                print("✗ TRL found but environment_factory not in GRPOTrainer")
                if args.path == "grpo":
                    print("  Install: pip install git+https://github.com/huggingface/transformers.git@main")
                    return
        except ImportError:
            if args.path == "grpo":
                print("✗ TRL not installed. Run: pip install trl")
                return
        # Fall through to the manual path

    print("\n→ PATH B: Manual generate → score loop (REAL model inference)\n")
    rewards = run_manual_training(args.model, args.max_steps)

    # Save results
    os.makedirs(os.path.join(_project_dir, "outputs"), exist_ok=True)
    results = {
        "episodes": list(range(1, len(rewards) + 1)),
        "scores": rewards,
        "model": args.model,
        "method": "real_training",
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(os.path.join(_project_dir, "outputs", "training_log.json"), "w") as f:
        json.dump(results, f, indent=2)

    plot_reward_curve(rewards, f"Real Training ({args.model.split('/')[-1]})")


if __name__ == "__main__":
    main()