from __future__ import annotations import re, random from typing import Any import torch from datasets import Dataset from transformers import AutoTokenizer from trl import GRPOConfig, GRPOTrainer from trl.trainer.grpo_trainer import generate_rollout_completions from client import TrustSafetyEnv from models import TrustAction, ContentSignals # ── Config ──────────────────────────────────────────────────────────────────── MODEL_NAME = "Qwen/Qwen3-1.7B" ENV_BASE_URL = "https://YOUR-USERNAME-trust-safety-env.hf.space" # ← your HF Space URL MAX_STEPS = 7 TASKS = ["T-001","T-002","T-003","T-004","T-005","T-006"] SYSTEM_PROMPT = """You are a Trust & Safety content moderation agent. MANDATORY 3-STEP PROCESS — follow this EXACT order every time: STEP 1 — USE TOOLS (always before anything else) Call investigation tools to gather evidence before deciding: - read_comments → check comment context - check_user_history → check prior violations - check_entity_status → verify public figure / org status - view_policy → retrieve relevant platform policy Respond with JSON: {"action_type": "use_tool", "tool_name": ""} STEP 2 — EXTRACT SIGNALS After gathering evidence, respond with JSON: {"action_type": "extract_signals", "signals": { "target": "", "is_protected_class": , "toxicity_level": <0.0-1.0>, "is_direct_attack": , "context_type": "", "intent": "", "confidence": <0.0-1.0>, "abusive_language_present": , "content_flags": [""] }} STEP 3 — FINAL DECISION Respond with JSON: {"action_type": "final_decision", "final_decision": ""} ⚠️ Skipping tools costs -0.25 reward per missed required tool. ⚠️ Always output valid JSON. Nothing else.""" # ── Dataset — one row per task, shuffled every epoch ───────────────────────── def make_dataset() -> Dataset: rows = [{"task_id": tid, "prompt": f"Process moderation ticket {tid}"} for tid in TASKS] random.shuffle(rows) return Dataset.from_list(rows * 50) # 300 rows = ~50 passes over all 6 tasks # ── Action parser — extract JSON action from model output ──────────────────── def parse_action(text: str) -> TrustAction | None: try: match = re.search(r'\{.*\}', text, re.DOTALL) if not match: return None import json d = json.loads(match.group()) atype = d.get("action_type", "") if atype == "use_tool": return TrustAction(action_type="use_tool", tool_name=d.get("tool_name")) if atype == "extract_signals": raw = d.get("signals", {}) flags = raw.get("content_flags", []) if not isinstance(flags, list): flags = [] signals = ContentSignals( target = raw.get("target", "none"), is_protected_class = bool(raw.get("is_protected_class", False)), toxicity_level = float(raw.get("toxicity_level", 0.5)), is_direct_attack = bool(raw.get("is_direct_attack", False)), context_type = raw.get("context_type", "statement"), intent = raw.get("intent", "ambiguous"), confidence = float(raw.get("confidence", 0.5)), abusive_language_present= bool(raw.get("abusive_language_present", False)), content_flags = flags, ) return TrustAction(action_type="extract_signals", signals=signals) if atype == "final_decision": return TrustAction(action_type="final_decision", final_decision=d.get("final_decision", "ESCALATE")) except Exception: pass return None # ── Rollout — one full episode per prompt ──────────────────────────────────── def rollout_once(trainer, env, tokenizer, task_id: str): result = env.reset(episode_id=task_id) obs = result.observation ep_reward = 0.0 all_prompt_ids = [] all_completion_ids = [] all_logprobs = [] reward_components = {"accuracy": 0.0, "tools": 0.0, "signals": 0.0} for step in range(MAX_STEPS): if result.done: break # Build context from current observation context_parts = [f"TICKET: {obs.ticket_id}\n{obs.post_text}"] if obs.comments_found: context_parts.append(f"Comments: {obs.comments_found}") if obs.user_history_found: context_parts.append(f"User History: {obs.user_history_found}") if obs.entity_status_found: context_parts.append(f"Entity: {obs.entity_status_found}") if obs.policy_found: context_parts.append(f"Policy: {obs.policy_found}") if obs.extracted_signals: context_parts.append(f"Signals: {obs.extracted_signals}") messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": "\n\n".join(context_parts)}, ] rollout = generate_rollout_completions(trainer, [messages]) text = rollout["text"][0] if rollout["text"] else "" action = parse_action(text) if action is None: action = TrustAction(action_type="final_decision", final_decision="ESCALATE") result = env.step(action) obs = result.observation if result.reward is not None: ep_reward = result.reward all_prompt_ids.extend(rollout.get("prompt_ids", [])) all_completion_ids.extend(rollout.get("completion_ids", [])) all_logprobs.extend(rollout.get("logprobs", [])) # Reward components for GRPO shaping signal info = obs.info or {} breakdown = info.get("reward_breakdown", {}) reward_components["accuracy"] = breakdown.get("base_score", ep_reward) reward_components["tools"] = -breakdown.get("tool_miss_penalty", 0.0) reward_components["signals"] = breakdown.get("signal_accuracy_bonus", 0.0) return { "prompt_ids": all_prompt_ids, "completion_ids": all_completion_ids, "logprobs": all_logprobs, "accuracy_reward": reward_components["accuracy"], "tool_reward": reward_components["tools"], "signal_reward": reward_components["signals"], } # ── Rollout wrapper — called by GRPOTrainer ─────────────────────────────────── def rollout_func(trainer, prompts, **kwargs): env = TrustSafetyEnv(base_url=ENV_BASE_URL).sync() results = [] for p in prompts: task_id = p.get("task_id", random.choice(TASKS)) results.append(rollout_once(trainer, env, None, task_id)) env.close() return results # ── Reward functions (called by GRPOTrainer separately) ────────────────────── def reward_accuracy(completions, **kwargs): """Base decision correctness from environment.""" return [r.get("accuracy_reward", 0.0) for r in completions] def reward_tools(completions, **kwargs): """Penalise skipping required tools.""" return [r.get("tool_reward", 0.0) for r in completions] def reward_signals(completions, **kwargs): """Bonus for accurate signal extraction.""" return [r.get("signal_reward", 0.0) for r in completions] # ── GRPOConfig ──────────────────────────────────────────────────────────────── grpo_config = GRPOConfig( output_dir = "./trust-safety-grpo", num_train_epochs = 1, learning_rate = 5e-6, gradient_accumulation_steps = 64, per_device_train_batch_size = 1, num_generations = 2, max_completion_length = 256, # JSON actions are longer than Wordle guesses max_prompt_length = 1024, use_vllm = True, vllm_mode = "colocate", vllm_gpu_memory_utilization = 0.15, gradient_checkpointing = True, report_to = "trackio", logging_steps = 5, save_steps = 50, ) # ── Main ────────────────────────────────────────────────────────────────────── if __name__ == "__main__": dataset = make_dataset() trainer = GRPOTrainer( model = MODEL_NAME, reward_funcs = [reward_accuracy, reward_tools, reward_signals], rollout_func = rollout_func, train_dataset = dataset, args = grpo_config, ) print("=" * 60) print(f"Training {MODEL_NAME} on Trust & Safety environment") print(f"Tasks : {TASKS}") print(f"Epochs: {grpo_config.num_train_epochs}") print("=" * 60) trainer.train() trainer.save_model("./trust-safety-grpo-final") print("✅ Model saved to ./trust-safety-grpo-final")