| """ | |
| SynthAudit.Env β REAL GRPO Training (Unsloth + TRL) | |
| ===================================================== | |
| ACTUALLY trains the model. Weights update. Rewards improve. | |
| Run on Colab T4: | |
| !pip install unsloth | |
| !pip install trl datasets | |
| !python3 training/train_real.py | |
| """ | |
from __future__ import annotations

import json, os, re, sys, time, warnings

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Make project-root and server/ imports resolve regardless of the CWD
_script_dir = os.path.dirname(os.path.abspath(__file__))
_project_dir = os.path.dirname(_script_dir)
sys.path.insert(0, _project_dir)
sys.path.insert(0, os.path.join(_project_dir, "server"))

from models import SynthAuditAction, ActionType
from server.synth_audit_environment import SynthAuditEnvironment
# ───────────────────────────────────────────────────────────────
# Reward function: runs a FULL episode from model's completion
# ───────────────────────────────────────────────────────────────
def score_completion(text: str, seed: int = 42, task_id: str = "oversight_easy") -> float:
    """Parse model output as JSON tool calls, execute in env, return score."""
    env = SynthAuditEnvironment()
    obs = env.reset(seed=seed, task_id=task_id)

    # Try to parse a JSON array of actions
    actions = []
    try:
        match = re.search(r'\[.*\]', text, re.DOTALL)
        if match:
            actions = json.loads(match.group())
    except Exception:
        pass

    # Fallback: parse individual JSON objects
    if not actions:
        for m in re.finditer(r'\{[^{}]+\}', text):
            try:
                actions.append(json.loads(m.group()))
            except Exception:
                continue

    # Execute parsed actions
    for act in actions:
        if obs.done:
            break
        try:
            action = SynthAuditAction(**act)
            obs = env.step(action)
        except Exception:
            continue
    return obs.score_so_far
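# Example of a completion this parser accepts (hypothetical model output;
# real proposal/patient IDs come from the environment's reset):
#   [{"action_type": "review_proposal", "proposal_id": "PROP-001"},
#    {"action_type": "investigate_patient", "patient_id": "P0001"},
#    {"action_type": "approve", "proposal_id": "PROP-001"}]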
def make_reward_func(fallback_seeds, fallback_task_ids):
    """Create a reward function for GRPOTrainer.

    TRL forwards extra dataset columns ("seed", "task_id") to reward
    functions as kwarg lists aligned with the completions; use those when
    present so each completion is replayed against the episode its prompt
    came from, and fall back to round-robin pairing otherwise.
    """
    def reward_func(completions, seed=None, task_id=None, **kwargs):
        scores = []
        for i, completion_list in enumerate(completions):
            # Conversational completions arrive as [{"role": ..., "content": ...}]
            text = completion_list[0]["content"] if isinstance(completion_list, list) else str(completion_list)
            s = seed[i] if seed else fallback_seeds[i % len(fallback_seeds)]
            t = task_id[i] if task_id else fallback_task_ids[i % len(fallback_task_ids)]
            scores.append(float(score_completion(text, seed=s, task_id=t)))
        return scores
    return reward_func
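# Contract with GRPOTrainer: reward functions receive the batch of completions
# plus every extra dataset column as a kwarg list, and must return one float
# per completion. Minimal sketch of a call (values illustrative):
#   fn = make_reward_func([42], ["oversight_easy"])
#   fn([[{"role": "assistant", "content": '[{"action_type": "approve", "proposal_id": "PROP-001"}]'}]],
#      seed=[42], task_id=["oversight_easy"])   # -> [<episode score>]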
# ───────────────────────────────────────────────────────────────
# Main Training
# ───────────────────────────────────────────────────────────────
def main():
    import torch

    MODEL = os.getenv("MODEL", "Qwen/Qwen2.5-3B-Instruct")
    MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
    NUM_GEN = int(os.getenv("NUM_GEN", "4"))

    print("╔" + "═" * 62 + "╗")
    print("║ SynthAudit.Env - REAL GRPO Training (Unsloth + TRL)".ljust(63) + "║")
    print("║ Multi-Agent Clinical AI Oversight".ljust(63) + "║")
    print(f"║ Model:    {MODEL:<50s} ║")
    print(f"║ Steps:    {MAX_STEPS:<50d} ║")
    print(f"║ Gen/step: {NUM_GEN:<50d} ║")
    print("╚" + "═" * 62 + "╝\n")

    if torch.cuda.is_available():
        gpu = torch.cuda.get_device_name(0)
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f" GPU: {gpu} ({vram:.1f} GB)")
    # ── Load model with Unsloth ───────────────────────────
    try:
        from unsloth import FastLanguageModel
        print(f"\n Loading {MODEL} with Unsloth (4-bit LoRA)...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            MODEL, max_seq_length=1024, load_in_4bit=True)
        model = FastLanguageModel.get_peft_model(
            model, r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16, lora_dropout=0,
            use_gradient_checkpointing="unsloth")
        print(" ✓ Unsloth 4-bit LoRA ready")
        USE_UNSLOTH = True
    except ImportError:
        from transformers import AutoModelForCausalLM, AutoTokenizer
        print(f"\n Loading {MODEL} with transformers...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL)
        # torch_dtype works across transformers versions (recent releases
        # also accept the shorter `dtype` alias)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL, torch_dtype=torch.float16, device_map="auto")
        USE_UNSLOTH = False
        print(" ⚠ No Unsloth - using raw transformers (higher VRAM)")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # ── Build dataset ─────────────────────────────────────
    from datasets import Dataset

    SYSTEM = (
        "You are an AI Oversight Auditor. A Medical AI reviewed clinical trial "
        "patients and proposed diagnoses. Some contain errors.\n\n"
        "Return a JSON array of actions to audit the proposals:\n"
        '- {"action_type": "review_proposal", "proposal_id": "PROP-001"}\n'
        '- {"action_type": "investigate_patient", "patient_id": "P0001"}\n'
        '- {"action_type": "flag_error", "proposal_id": "PROP-001", '
        '"error_type": "age_boundary_error", "reason": "Patient age 150 exceeds protocol max"}\n'
        '- {"action_type": "approve", "proposal_id": "PROP-001"}\n\n'
        "First review each proposal, then investigate the patient, then flag or approve."
    )

    # Generate varied prompts by running env resets
    prompts = []
    seeds = []
    task_ids = []
    dataset_size = max(MAX_STEPS * 2, 64)
    TASKS = ["oversight_easy"] * (dataset_size // 2) + \
            ["oversight_medium"] * (dataset_size // 4) + \
            ["oversight_hard"] * (dataset_size - dataset_size // 2 - dataset_size // 4)
    for i in range(dataset_size):
        seed = 42 + i * 7
        task = TASKS[i]
        env = SynthAuditEnvironment()
        obs = env.reset(seed=seed, task_id=task)
        proposal_text = "\n".join(
            f"  {p.proposal_id}: Patient {p.patient_id}, "
            f"Dx={p.diagnosis}, Confidence={p.confidence}"
            for p in obs.actor_proposals
        )
        user_msg = (
            f"PROTOCOL:\n{obs.protocol_excerpt[:200]}\n\n"
            f"PROPOSALS ({len(obs.actor_proposals)}):\n{proposal_text}\n\n"
            f"Audit these proposals. Return a JSON array of actions."
        )
        prompts.append([
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": user_msg},
        ])
        seeds.append(seed)
        task_ids.append(task)

    # "seed" and "task_id" ride along as extra columns so the reward function
    # can replay the exact episode each prompt came from
    dataset = Dataset.from_dict({"prompt": prompts, "seed": seeds, "task_id": task_ids})
    print(f" Dataset: {dataset_size} prompts (50% easy, 25% medium, 25% hard)")
    # ── GRPO Training ─────────────────────────────────────
    from trl import GRPOTrainer, GRPOConfig

    config = GRPOConfig(
        max_completion_length=512,
        num_generations=NUM_GEN,
        # GRPO requires the generation batch (per-device batch x grad accum x
        # processes) to be divisible by num_generations; accumulating NUM_GEN
        # micro-batches satisfies that while keeping T4 VRAM low
        gradient_accumulation_steps=NUM_GEN,
        per_device_train_batch_size=1,
        max_steps=MAX_STEPS,
        logging_steps=1,
        output_dir=os.path.join(_project_dir, "outputs", "grpo_run"),
        report_to="none",
        learning_rate=5e-6,
        save_steps=25,
        save_total_limit=2,
        log_completions=True,
    )
    reward_fn = make_reward_func(seeds, task_ids)
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_fn,
        train_dataset=dataset,
        args=config,
        processing_class=tokenizer,  # reuse our tokenizer (pad token already set)
    )
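    # GRPO in brief: NUM_GEN completions are sampled per prompt, each is scored
    # by reward_fn, and each completion's advantage is its reward normalized
    # within its group (minus the group mean, over the group std) - no separate
    # value network is needed.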
| print(f"\n βΈ GRPO Training for {MAX_STEPS} steps...") | |
| print(f" βΈ This is REAL training β weights are being updated!\n") | |
| start = time.time() | |
| trainer.train() | |
| elapsed = time.time() - start | |
    # ── Save model ────────────────────────────────────────
    out_dir = os.path.join(_project_dir, "outputs", "trained_model")
    trainer.save_model(out_dir)

    # ── Extract metrics ───────────────────────────────────
    # Trainer.state.log_history stores metrics under bare keys ("reward",
    # "loss"); tolerate "train/"-prefixed variants across TRL versions
    rewards = [h.get("reward", h.get("train/reward"))
               for h in trainer.state.log_history
               if "reward" in h or "train/reward" in h]
    losses = [h.get("loss", h.get("train/loss"))
              for h in trainer.state.log_history
              if "loss" in h or "train/loss" in h]
    results = {
        "model": MODEL,
        "method": "GRPO",
        "max_steps": MAX_STEPS,
        "num_generations": NUM_GEN,
        "elapsed_seconds": round(elapsed),
        "rewards": rewards,
        "losses": losses,
        "final_reward": rewards[-1] if rewards else None,
        "best_reward": max(rewards) if rewards else None,
    }
    os.makedirs(os.path.join(_project_dir, "outputs"), exist_ok=True)
    with open(os.path.join(_project_dir, "outputs", "training_log.json"), "w") as f:
        json.dump(results, f, indent=2)
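    # Shape of outputs/training_log.json (values illustrative):
    #   {"model": "Qwen/Qwen2.5-3B-Instruct", "method": "GRPO", "max_steps": 50,
    #    "num_generations": 4, "elapsed_seconds": 1800, "rewards": [...],
    #    "losses": [...], "final_reward": 0.7, "best_reward": 0.9}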
    # ── Plot ──────────────────────────────────────────────
    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        if rewards:
            steps = list(range(1, len(rewards) + 1))
            # Running average over a trailing window of up to 5 steps
            w = min(5, len(rewards))
            avg = []
            for i in range(len(rewards)):
                s = max(0, i - w + 1)
                avg.append(sum(rewards[s:i+1]) / (i - s + 1))
            axes[0].plot(steps, rewards, 'b-', alpha=0.3, linewidth=1)
            axes[0].plot(steps, avg, 'r-', linewidth=2.5, label=f'Running Avg (w={w})')
            axes[0].fill_between(steps, rewards, alpha=0.1, color='blue')
            axes[0].set_xlabel("Training Step")
            axes[0].set_ylabel("Reward (Episode Score)")
            axes[0].set_title("GRPO Reward Curve", fontweight='bold')
            axes[0].legend()
            axes[0].grid(True, alpha=0.3)
        if losses:
            axes[1].plot(range(1, len(losses) + 1), losses, 'g-', linewidth=1.5)
            axes[1].set_xlabel("Training Step")
            axes[1].set_ylabel("Loss")
            axes[1].set_title("Training Loss", fontweight='bold')
            axes[1].grid(True, alpha=0.3)
        fig.suptitle(f"SynthAudit.Env - GRPO Training ({MODEL.split('/')[-1]})\n"
                     f"{MAX_STEPS} steps, {elapsed/60:.0f} min, REAL weight updates",
                     fontsize=14, fontweight='bold')
        plt.tight_layout()
        path = os.path.join(_project_dir, "outputs", "reward_curve.png")
        plt.savefig(path, dpi=200, bbox_inches='tight')
        print(f"\n ✓ Reward curve: {path}")
    except ImportError:
        pass
| print(f"\n{'='*60}") | |
| print(f" REAL GRPO Training Complete") | |
| print(f" Time: {elapsed/60:.1f} min") | |
| print(f" Steps: {MAX_STEPS}") | |
| print(f" Best reward: {max(rewards) if rewards else 'N/A'}") | |
| print(f" Final reward: {rewards[-1] if rewards else 'N/A'}") | |
| print(f" Model saved: {out_dir}") | |
| print(f"{'='*60}") | |
| if __name__ == "__main__": | |
| main() | |