"""
SynthAudit.Env - REAL Colab Training (No Fakes)
=================================================
Actually trains Llama 3.2 3B on the oversight environment.
Two paths:
    PATH A: TRL GRPOTrainer + environment_factory (needs a TRL/transformers
            build that exposes environment_factory; checked at runtime)
    PATH B: Manual generate → score loop (works with any TRL; real model
            inference, no gradient updates)
INSTALL (run in Colab BEFORE this script):
    !pip install trl datasets peft accelerate bitsandbytes
    !pip install git+https://github.com/huggingface/transformers.git@main
    !pip install jmespath
    !pip install pydantic openai matplotlib
Run:
    python training/train_colab.py
    python training/train_colab.py --path manual   # Force manual loop
    python training/train_colab.py --path grpo     # Force TRL GRPO
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
_script_dir = os.path.dirname(os.path.abspath(__file__))
_project_dir = os.path.dirname(_script_dir)
sys.path.insert(0, _project_dir)
sys.path.insert(0, os.path.join(_project_dir, "server"))
from models import SynthAuditAction, ActionType
from server.synth_audit_environment import SynthAuditEnvironment
# ═══════════════════════════════════════════════════════════════
# Environment Wrapper (shared by both paths)
# ═══════════════════════════════════════════════════════════════
class SynthAuditTrainEnv:
"""4-tool env for 3B model. TRL auto-discovers these methods."""
def __init__(self):
self.env = SynthAuditEnvironment()
self.reward = 0.0
self.done = False
def reset(self, seed=42, task_id="oversight_easy", **kwargs) -> str:
self.reward = 0.0
self.done = False
obs = self.env.reset(seed=seed, task_id=task_id)
proposals = "\n".join(
f"- {p.proposal_id}: Patient {p.patient_id}, Conf={p.confidence}"
for p in obs.actor_proposals
)
return (
f"Audit {len(obs.actor_proposals)} proposals.\n"
f"Proposals:\n{proposals}\n"
f"For each: review_proposal, investigate_patient, then flag_error or approve."
)
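    # Illustrative reset() output (hypothetical IDs and confidence value):
    #   Audit 3 proposals.
    #   Proposals:
    #   - PROP-001: Patient P0001, Conf=0.92
    #   For each: review_proposal, investigate_patient, then flag_error or approve.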
def review_proposal(self, proposal_id: str) -> str:
"""Review a proposal's reasoning. Args: proposal_id (e.g. PROP-001)"""
return self._step(SynthAuditAction(
action_type=ActionType.review_proposal, proposal_id=proposal_id))
def investigate_patient(self, patient_id: str) -> str:
"""Get patient EHR data. Args: patient_id (e.g. P0001)"""
return self._step(SynthAuditAction(
action_type=ActionType.investigate_patient, patient_id=patient_id))
    def flag_error(self, proposal_id: str, reason: str) -> str:
        """Flag proposal as wrong. Args: proposal_id, reason"""
        # NOTE: error_type is fixed for this 4-tool setup; the model only
        # supplies the free-text reason.
        return self._step(SynthAuditAction(
            action_type=ActionType.flag_error, proposal_id=proposal_id,
            error_type="age_boundary_error", reason=reason))
def approve(self, proposal_id: str) -> str:
"""Approve proposal as correct. Args: proposal_id"""
return self._step(SynthAuditAction(
action_type=ActionType.approve, proposal_id=proposal_id))
def _step(self, action):
if self.done:
return "Episode complete."
try:
obs = self.env.step(action)
self.reward = obs.score_so_far
self.done = obs.done
return obs.feedback
except Exception as e:
return f"Error: {e}"
def reward_func(environments, **kwargs):
return [env.reward for env in environments]
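# NOTE (assumed contract of the experimental environment_factory API): TRL is
# expected to pass one env instance per rollout, so each completion's reward is
# that episode's final score_so_far as captured in SynthAuditTrainEnv._step.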
# ═══════════════════════════════════════════════════════════════
# PATH A: TRL GRPOTrainer with environment_factory
# ═══════════════════════════════════════════════════════════════
def run_grpo_training(model_name: str, max_steps: int):
"""Real GRPO training. Requires TRL + transformers>=5.2."""
import torch
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
print(f"\n Loading {model_name}...")
# Try Unsloth first for memory efficiency
model = model_name
try:
from unsloth import FastLanguageModel
        print(" ✓ Unsloth detected → 4-bit LoRA")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name, max_seq_length=1024, load_in_4bit=True)
model = FastLanguageModel.get_peft_model(
model, r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16, lora_dropout=0,
use_gradient_checkpointing="unsloth")
    except ImportError:
        # Without Unsloth, `model` stays the hub id string; GRPOTrainer
        # accepts a model name and loads it itself (higher VRAM).
        print(" ⚠ No Unsloth → loading model directly (higher VRAM)")
SYSTEM = ("You audit clinical AI proposals. For each proposal, call "
"review_proposal to see reasoning, investigate_patient to check data, "
"then flag_error or approve.")
dataset = Dataset.from_dict({
"prompt": [[
{"role": "system", "content": SYSTEM},
{"role": "user", "content": "Audit the clinical proposals now."},
]] * 16,
})
config = GRPOConfig(
max_completion_length=1024,
num_generations=2,
gradient_accumulation_steps=4,
per_device_train_batch_size=1,
max_steps=max_steps,
logging_steps=1,
log_completions=True,
output_dir=os.path.join(_project_dir, "outputs", "grpo_run"),
report_to="none",
learning_rate=5e-6,
)
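    # Effective update batch: 1 sample/device x 4 accumulation steps = 4
    # completions per optimizer step; num_generations=2 samples each prompt
    # twice so GRPO can compute group-relative advantages.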
trainer = GRPOTrainer(
model=model,
reward_funcs=reward_func,
train_dataset=dataset,
args=config,
environment_factory=SynthAuditTrainEnv,
)
print(f"\n GRPO Training for {max_steps} steps (REAL model training)...\n")
start = time.time()
trainer.train()
elapsed = time.time() - start
out_dir = os.path.join(_project_dir, "outputs", "trained_model")
trainer.save_model(out_dir)
    print(f"\n✓ REAL training complete in {elapsed:.0f}s. Model saved to {out_dir}")
    # TRL's log key differs across versions ("reward" vs "train/reward").
    rewards = [h.get("reward", h.get("train/reward"))
               for h in trainer.state.log_history
               if "reward" in h or "train/reward" in h]
    return rewards
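# Typical invocation (PATH A), assuming a TRL build that accepts
# environment_factory:
#   rewards = run_grpo_training("meta-llama/Llama-3.2-3B-Instruct", max_steps=30)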
# ═══════════════════════════════════════════════════════════════
# PATH B: Manual generate β†’ score β†’ update (works with any setup)
# ═══════════════════════════════════════════════════════════════
def run_manual_training(model_name: str, max_steps: int):
    """Manual loop with REAL model inference.
    Generates text with the model, parses tool calls,
    runs them in the environment, scores the episode.
    NOTE: this path performs real rollouts and scoring only; it applies
    no gradient updates (use PATH A for actual weight updates).
    """
import torch
print(f"\n Loading {model_name} for manual training...")
# Load model
try:
from unsloth import FastLanguageModel
        print(" ✓ Unsloth 4-bit LoRA")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name, max_seq_length=1024, load_in_4bit=True)
model = FastLanguageModel.get_peft_model(
model, r=16,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=16, lora_dropout=0,
use_gradient_checkpointing="unsloth")
        FastLanguageModel.for_inference(model)
except ImportError:
import warnings
warnings.filterwarnings("ignore", message=".*unauthenticated.*")
warnings.filterwarnings("ignore", message=".*torch_dtype.*")
from transformers import AutoModelForCausalLM, AutoTokenizer
print(" Loading with transformers...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name, dtype=torch.float16, device_map="auto")
    # Llama tokenizers ship without a pad token; reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
SYSTEM = ("You audit clinical AI proposals. For each proposal, you must:\n"
"1. Call review_proposal(proposal_id) to see the Actor's reasoning\n"
"2. Call investigate_patient(patient_id) to check raw data\n"
"3. Call flag_error(proposal_id, reason) OR approve(proposal_id)\n"
"Respond with ONE tool call per turn as JSON: "
'{\"tool\": \"review_proposal\", \"args\": {\"proposal_id\": \"PROP-001\"}}')
rewards_per_episode = []
# Curriculum: Phase 1=easy, Phase 2=medium, Phase 3=hard
CURRICULUM = [
("oversight_easy", "Phase 1: Easy"),
("oversight_medium", "Phase 2: Medium"),
("oversight_hard", "Phase 3: Hard"),
]
phase_size = max(1, max_steps // 3)
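    # e.g. max_steps=30 -> phase_size=10: episodes 1-10 easy, 11-20 medium,
    # 21-30 hard (phase_idx is clamped to 2 below).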
est_min = max_steps * 1.5 # ~1.5 min per episode on T4
print(f" Estimated time: ~{est_min:.0f} min ({max_steps} episodes)\n")
for episode in range(max_steps):
phase_idx = min(episode // phase_size, 2)
task_id, phase_name = CURRICULUM[phase_idx]
# Print phase transition
if episode == 0 or episode == phase_size or episode == phase_size * 2:
print(f"\n ── {phase_name} (episodes {episode+1}-{min(episode+phase_size, max_steps)}) ──", flush=True)
env = SynthAuditTrainEnv()
seed = 42 + episode * 7
task_prompt = env.reset(seed=seed, task_id=task_id)
messages = [
{"role": "system", "content": SYSTEM},
{"role": "user", "content": task_prompt},
]
# Multi-turn interaction
for turn in range(15):
if env.done:
break
# Generate
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(input_text, return_tensors="pt",
truncation=True, max_length=2048)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = model.generate(
**inputs, max_new_tokens=256,
temperature=0.7, do_sample=True,
pad_token_id=tokenizer.pad_token_id)
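            # Sampled decoding (do_sample=True, temperature=0.7) keeps
            # rollouts stochastic across episodes and tool-call orderings.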
response = tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True)
            # Parse the tool call from the response and execute it
feedback = _execute_tool_call(env, response)
messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": feedback})
# End episode if not done
if not env.done:
env._step(SynthAuditAction(
action_type=ActionType.submit_audit_report,
report="Audit complete."))
score = env.reward
rewards_per_episode.append(score)
window = min(5, len(rewards_per_episode))
avg = sum(rewards_per_episode[-window:]) / window
        bar = "█" * int(score * 30) + "░" * (30 - int(score * 30))
print(f" Episode {episode+1:3d} | Score: {score:.3f} | "
f"Avg: {avg:.3f} | {bar}", flush=True)
return rewards_per_episode
def _execute_tool_call(env: SynthAuditTrainEnv, response: str) -> str:
"""Parse JSON tool call from model response and execute it."""
import json as _json
import re
    # Try to extract JSON from the response; allow one level of nesting so
    # {"tool": ..., "args": {...}} is captured whole.
    try:
        match = re.search(r'\{(?:[^{}]|\{[^{}]*\})*\}', response)
if match:
call = _json.loads(match.group())
tool = call.get("tool", "")
args = call.get("args", {})
if tool == "review_proposal" and "proposal_id" in args:
return env.review_proposal(args["proposal_id"])
elif tool == "investigate_patient" and "patient_id" in args:
return env.investigate_patient(args["patient_id"])
elif tool == "flag_error" and "proposal_id" in args:
return env.flag_error(
args["proposal_id"], args.get("reason", "flagged"))
elif tool == "approve" and "proposal_id" in args:
return env.approve(args["proposal_id"])
    except Exception:
        pass  # fall through to heuristic parsing below
# Fallback: try to find proposal/patient IDs in text
prop_match = re.search(r'PROP-\d+', response)
patient_match = re.search(r'P\d{4}', response)
if "flag" in response.lower() and prop_match:
return env.flag_error(prop_match.group(), "Flagged based on analysis")
elif "approve" in response.lower() and prop_match:
return env.approve(prop_match.group())
elif "review" in response.lower() and prop_match:
return env.review_proposal(prop_match.group())
elif "investigate" in response.lower() and patient_match:
return env.investigate_patient(patient_match.group())
return "Could not parse tool call. Use JSON format: {\"tool\": \"...\", \"args\": {...}}"
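# Illustrative dispatches (hypothetical IDs):
#   '{"tool": "approve", "args": {"proposal_id": "PROP-002"}}' -> env.approve("PROP-002")
#   'Let me investigate P0042 first.' -> env.investigate_patient("P0042")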
# ═══════════════════════════════════════════════════════════════
# Reward Curve Plotting
# ═══════════════════════════════════════════════════════════════
def plot_reward_curve(rewards: list[float], label: str = "GRPO Training"):
"""Generate publication-quality reward curve."""
try:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
        if not rewards:
            print(" No rewards to plot.")
            return
        episodes = list(range(1, len(rewards) + 1))
window = min(5, len(rewards))
running_avg = []
for i in range(len(rewards)):
start = max(0, i - window + 1)
running_avg.append(sum(rewards[start:i+1]) / (i - start + 1))
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(episodes, rewards, 'b-o', alpha=0.4, markersize=4,
label='Episode Score', linewidth=1)
ax.plot(episodes, running_avg, 'r-', linewidth=2.5,
label=f'Running Average (w={window})')
ax.fill_between(episodes, rewards, alpha=0.1, color='blue')
ax.set_xlabel("Training Episode", fontsize=14)
ax.set_ylabel("Oversight Score", fontsize=14)
        ax.set_title(f"SynthAudit.Env - {label}\n"
"Multi-Agent Clinical AI Oversight (Fleet AI)",
fontsize=15, fontweight='bold')
ax.legend(fontsize=12, loc='lower right')
ax.grid(True, alpha=0.3)
ax.set_ylim(0, max(rewards) * 1.2 + 0.05)
best_ep = rewards.index(max(rewards)) + 1
best_score = max(rewards)
ax.annotate(f'Best: {best_score:.3f}',
xy=(best_ep, best_score),
xytext=(best_ep + 1, best_score + 0.03),
arrowprops=dict(arrowstyle='->', color='red'),
fontsize=11, color='red', fontweight='bold')
os.makedirs(os.path.join(_project_dir, "outputs"), exist_ok=True)
path = os.path.join(_project_dir, "outputs", "reward_curve.png")
plt.tight_layout()
plt.savefig(path, dpi=200, bbox_inches='tight')
        print(f"\n✓ Reward curve saved to {path}")
print(f" Best: {best_score:.3f} at episode {best_ep}")
print(f" Final avg: {running_avg[-1]:.3f}")
except ImportError:
print(" matplotlib not available. Skipping plot.")
# ═══════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="meta-llama/Llama-3.2-3B-Instruct")
parser.add_argument("--path", choices=["auto", "grpo", "manual"],
default="auto", help="Training path")
parser.add_argument("--max-steps", type=int, default=30,
help="Training episodes (30=~45min, 60=~1.5hr, 100=~2.5hr)")
args = parser.parse_args()
    print("╔══════════════════════════════════════════════════════════════╗")
    print("║          SynthAudit.Env - REAL Model Training                  ║")
    print("║          Multi-Agent Clinical AI Oversight                     ║")
    print(f"║  Model: {args.model:<53s}║")
    print("╚══════════════════════════════════════════════════════════════╝\n")
import torch
if torch.cuda.is_available():
gpu = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f" GPU: {gpu} ({vram:.1f} GB)")
else:
        print(" ⚠ No GPU - training will be very slow")
rewards = []
if args.path == "grpo" or args.path == "auto":
try:
from trl import GRPOTrainer
import inspect
if "environment_factory" in inspect.signature(GRPOTrainer.__init__).parameters:
                print("\n ✓ TRL GRPOTrainer with environment_factory available")
                print(" → PATH A: Native GRPO training (REAL)\n")
rewards = run_grpo_training(args.model, args.max_steps)
if rewards:
plot_reward_curve(rewards, "GRPO Training (Real)")
return
else:
print(" ⚠ TRL found but environment_factory not in GRPOTrainer")
if args.path == "grpo":
print(" Install: pip install git+https://github.com/huggingface/transformers.git@main")
return
except ImportError:
if args.path == "grpo":
print(" ⚠ TRL not installed. Run: pip install trl")
return
# Fall through to manual
    print("\n → PATH B: Manual generate → score loop (REAL model inference)\n")
rewards = run_manual_training(args.model, args.max_steps)
# Save results
os.makedirs(os.path.join(_project_dir, "outputs"), exist_ok=True)
results = {
"episodes": list(range(1, len(rewards) + 1)),
"scores": rewards,
"model": args.model,
"method": "real_training",
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
}
with open(os.path.join(_project_dir, "outputs", "training_log.json"), "w") as f:
json.dump(results, f, indent=2)
plot_reward_curve(rewards, f"Real Training ({args.model.split('/')[-1]})")
if __name__ == "__main__":
main()