{ "cells": [ { "cell_type": "markdown", "id": "eab24a17", "metadata": {}, "source": [ "# SmartPayEnv Theme-4 Judge Repro (Colab, Self-Contained, Unsloth + TRL@git)\n", "\n", "Self-contained notebook. Does NOT import anything from the repo.\n", "\n", "Pipeline:\n", "1. Install deps (Unsloth + TRL from GitHub)\n", "2. HF login (uses your HF credits)\n", "3. Connect to deployed SmartPayEnv Space\n", "4. Collect group-relative preference pairs (inline)\n", "5. Baseline eval (random + heuristic) on frozen seed\n", "6. Train policy with Unsloth FastLanguageModel + TRL DPO\n", "7. Trained-policy eval on the same frozen seed\n", "8. Plots:\n", " - rollout reward curve\n", " - DPO training loss\n", " - before/after mean reward (random vs heuristic vs trained)\n", " - mean reward per risk bucket (low / medium / high)\n", "9. Save artifacts to ./artifacts\n", "\n", "Hackathon: OpenEnv (India 2026), Theme #4 — Self-Improvement.\n", "Space: https://huggingface.co/spaces/Pratap-K/SmartPayEnv" ] }, { "cell_type": "markdown", "id": "57c1f412", "metadata": {}, "source": [ "## 1. Install dependencies (Unsloth + TRL from GitHub)" ] }, { "cell_type": "code", "execution_count": null, "id": "c0142bbc", "metadata": {}, "outputs": [], "source": [ "!pip -q install --upgrade pip\n", "!pip -q install \"unsloth @ git+https://github.com/unslothai/unsloth.git\"\n", "!pip -q install \"trl @ git+https://github.com/huggingface/trl.git\"\n", "!pip -q install --upgrade transformers accelerate peft bitsandbytes datasets huggingface_hub matplotlib pandas requests numpy" ] }, { "cell_type": "markdown", "id": "e4f39274", "metadata": {}, "source": [ "## 2. Authenticate Hugging Face" ] }, { "cell_type": "code", "execution_count": null, "id": "6a201e39", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()" ] }, { "cell_type": "markdown", "id": "d5373ffe", "metadata": {}, "source": [ "## 3. Configuration" ] }, { "cell_type": "code", "execution_count": null, "id": "73b92d43", "metadata": {}, "outputs": [], "source": [ "import os, json, random\n", "import numpy as np\n", "\n", "QUICK_MODE = True\n", "ENV_URL = 'https://pratap-k-smartpayenv.hf.space'\n", "DIFFICULTY = 2\n", "SEED = 42\n", "\n", "ROLLOUT_STEPS = 60 if QUICK_MODE else 240\n", "GROUP_SIZE = 6 if QUICK_MODE else 10\n", "EVAL_EPISODES = 3 if QUICK_MODE else 5\n", "EVAL_STEPS_PER_EPISODE = 30 if QUICK_MODE else 60\n", "\n", "MODEL_ID = 'unsloth/Qwen2.5-0.5B-Instruct'\n", "MAX_SEQ_LEN = 2048\n", "LOAD_IN_4BIT = True\n", "\n", "os.makedirs('artifacts', exist_ok=True)\n", "random.seed(SEED)\n", "np.random.seed(SEED)\n", "print('Config ready. QUICK_MODE =', QUICK_MODE, '| MODEL_ID =', MODEL_ID)" ] }, { "cell_type": "markdown", "id": "7060469d", "metadata": {}, "source": [ "## 4. Health check" ] }, { "cell_type": "code", "execution_count": null, "id": "b0198da2", "metadata": {}, "outputs": [], "source": [ "import requests\n", "r = requests.get(f'{ENV_URL}/health', timeout=30)\n", "print('Health:', r.status_code, r.text[:120])" ] }, { "cell_type": "markdown", "id": "b2d9dcfc", "metadata": {}, "source": [ "## 5. 
Inline env helpers" ] }, { "cell_type": "code", "execution_count": null, "id": "1e10b8d4", "metadata": {}, "outputs": [], "source": [ "def env_reset(difficulty=DIFFICULTY):\n", " res = requests.post(f'{ENV_URL}/reset', json={'difficulty': int(difficulty)}, timeout=30)\n", " res.raise_for_status()\n", " payload = res.json()\n", " return payload.get('observation', payload)\n", "\n", "def env_step(action):\n", " res = requests.post(f'{ENV_URL}/step', json={'action': action}, timeout=30)\n", " res.raise_for_status()\n", " return res.json()\n", "\n", "def env_simulate(action):\n", " res = requests.post(f'{ENV_URL}/simulate', json={'action': action}, timeout=30)\n", " res.raise_for_status()\n", " return res.json()\n", "\n", "def all_actions():\n", " out = []\n", " for g in (0,1,2):\n", " for f in (0,1,2,3):\n", " for r in (0,1):\n", " out.append({'gateway': g, 'fraud_decision': f, 'retry_strategy': r})\n", " return out\n", "\n", "ACTIONS = all_actions()\n", "print('Total candidate actions:', len(ACTIONS))" ] }, { "cell_type": "markdown", "id": "d0d33873", "metadata": {}, "source": [ "## 6. Collect group-relative preference pairs" ] }, { "cell_type": "code", "execution_count": null, "id": "db6c57b4", "metadata": {}, "outputs": [], "source": [ "def collect_pairs(steps=ROLLOUT_STEPS, group=GROUP_SIZE, difficulty=DIFFICULTY):\n", " obs = env_reset(difficulty)\n", " pairs, reward_trace = [], []\n", " for _ in range(steps):\n", " sampled = random.sample(ACTIONS, k=min(group, len(ACTIONS)))\n", " scored = []\n", " for a in sampled:\n", " try:\n", " sim = env_simulate(a)\n", " scored.append((a, float(sim.get('reward', 0.0))))\n", " except requests.RequestException:\n", " continue\n", " if len(scored) < 2:\n", " break\n", " scored.sort(key=lambda x: x[1], reverse=True)\n", " best, best_r = scored[0]\n", " worst, worst_r = scored[-1]\n", "\n", " prompt = (\n", " 'SmartPayEnv observation:\\n'\n", " f'{json.dumps(obs, sort_keys=True)}\\n'\n", " 'Return one action JSON with fields: gateway, fraud_decision, retry_strategy.'\n", " )\n", " pairs.append({\n", " 'prompt': prompt,\n", " 'chosen': json.dumps(best, sort_keys=True),\n", " 'rejected': json.dumps(worst, sort_keys=True),\n", " 'chosen_reward': best_r,\n", " 'rejected_reward': worst_r,\n", " })\n", " reward_trace.append(best_r)\n", "\n", " step_payload = env_step(best)\n", " obs = step_payload.get('observation', step_payload)\n", " if bool(obs.get('done', False)):\n", " obs = env_reset(difficulty)\n", " return pairs, reward_trace\n", "\n", "pairs, rollout_rewards = collect_pairs()\n", "print('Collected pairs:', len(pairs))" ] }, { "cell_type": "markdown", "id": "9d0f2b46", "metadata": {}, "source": [ "## 7. 
Baseline evaluation (random + heuristic) with risk-bucket breakdown" ] }, { "cell_type": "code", "execution_count": null, "id": "fc0a1f5b", "metadata": {}, "outputs": [], "source": [ "def risk_bucket(obs):\n", " r = float(obs.get('observed_fraud_risk', 0.0))\n", " if r < 0.3:\n", " return 'low'\n", " if r < 0.65:\n", " return 'medium'\n", " return 'high'\n", "\n", "def eval_policy(policy_fn, episodes=EVAL_EPISODES, steps=EVAL_STEPS_PER_EPISODE, difficulty=DIFFICULTY):\n", " all_rewards = []\n", " per_episode_means = []\n", " bucket_rewards = {'low': [], 'medium': [], 'high': []}\n", " for _ in range(episodes):\n", " obs = env_reset(difficulty)\n", " ep_rewards = []\n", " for _ in range(steps):\n", " bucket = risk_bucket(obs)\n", " action = policy_fn(obs)\n", " payload = env_step(action)\n", " obs = payload.get('observation', payload)\n", " r = float(obs.get('reward', payload.get('reward', 0.0)))\n", " ep_rewards.append(r)\n", " bucket_rewards[bucket].append(r)\n", " if bool(obs.get('done', False)):\n", " obs = env_reset(difficulty)\n", " all_rewards.extend(ep_rewards)\n", " per_episode_means.append(float(np.mean(ep_rewards)))\n", " bucket_means = {k: (float(np.mean(v)) if v else 0.0) for k, v in bucket_rewards.items()}\n", " return {\n", " 'mean_reward': float(np.mean(all_rewards)) if all_rewards else 0.0,\n", " 'per_episode_mean': per_episode_means,\n", " 'bucket_means': bucket_means,\n", " 'all_rewards': all_rewards,\n", " }\n", "\n", "def random_policy(_obs):\n", " return random.choice(ACTIONS)\n", "\n", "def heuristic_policy(obs):\n", " risk = float(obs.get('observed_fraud_risk', 0.0))\n", " rates = obs.get('gateway_success_rates', [0.9, 0.9, 0.9]) or [0.9, 0.9, 0.9]\n", " gateway = int(np.argmax(rates))\n", " if risk > 0.65:\n", " fd = 1\n", " elif risk > 0.4:\n", " fd = 2\n", " else:\n", " fd = 0\n", " return {'gateway': gateway, 'fraud_decision': fd, 'retry_strategy': 1}\n", "\n", "baseline_random = eval_policy(random_policy)\n", "baseline_heuristic = eval_policy(heuristic_policy)\n", "print('Random baseline:', baseline_random['mean_reward'], baseline_random['bucket_means'])\n", "print('Heuristic baseline:', baseline_heuristic['mean_reward'], baseline_heuristic['bucket_means'])" ] }, { "cell_type": "markdown", "id": "7c6c10e3", "metadata": {}, "source": [ "## 8. 
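Train with Unsloth FastLanguageModel + TRL DPO" ] }, { "cell_type": "markdown", "id": "1d3f5a7b", "metadata": {}, "source": [ "For reference, the objective DPOTrainer minimizes (the standard DPO loss; $\\pi_\\theta$ is the trained policy, $\\pi_{\\mathrm{ref}}$ the frozen reference, $\\beta$ the implicit KL strength, $y_w$/$y_l$ the chosen/rejected actions):\n", "\n", "$$\\mathcal{L}_{\\mathrm{DPO}} = -\\,\\mathbb{E}_{(x,\\,y_w,\\,y_l)}\\left[\\log\\sigma\\left(\\beta\\log\\frac{\\pi_\\theta(y_w\\mid x)}{\\pi_{\\mathrm{ref}}(y_w\\mid x)} - \\beta\\log\\frac{\\pi_\\theta(y_l\\mid x)}{\\pi_{\\mathrm{ref}}(y_l\\mid x)}\\right)\\right]$$\n", "\n", "With ref_model=None and a PEFT model, TRL's documented behavior is to use the base model with adapters disabled as the reference, so no second model is loaded. One caveat: bf16=True in the config below assumes an Ampere-or-newer GPU; on a T4, set bf16=False and fp16=True (or gate it with unsloth.is_bfloat16_supported())." ] }, { "cell_type": "markdown", "id": "3f5b7d9c", "metadata": {}, "source": [ "The training cell: 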
Train with Unsloth FastLanguageModel + TRL DPO" ] }, { "cell_type": "code", "execution_count": null, "id": "bf9a3739", "metadata": {}, "outputs": [], "source": [ "from unsloth import FastLanguageModel\n", "from datasets import Dataset\n", "from trl import DPOConfig, DPOTrainer\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=MODEL_ID,\n", " max_seq_length=MAX_SEQ_LEN,\n", " dtype=None,\n", " load_in_4bit=LOAD_IN_4BIT,\n", ")\n", "model = FastLanguageModel.get_peft_model(\n", " model,\n", " r=16,\n", " target_modules=['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],\n", " lora_alpha=16,\n", " lora_dropout=0.0,\n", " bias='none',\n", " use_gradient_checkpointing='unsloth',\n", " random_state=SEED,\n", ")\n", "if tokenizer.pad_token is None:\n", " tokenizer.pad_token = tokenizer.eos_token\n", "\n", "ds = Dataset.from_list(pairs)\n", "print(ds)\n", "\n", "cfg = DPOConfig(\n", " output_dir='outputs/theme4_dpo_unsloth',\n", " per_device_train_batch_size=1,\n", " gradient_accumulation_steps=4,\n", " num_train_epochs=1 if QUICK_MODE else 2,\n", " logging_steps=2,\n", " learning_rate=5e-6,\n", " max_prompt_length=1024,\n", " max_length=1280,\n", " save_strategy='no',\n", " report_to=[],\n", " bf16=True,\n", ")\n", "\n", "trainer = DPOTrainer(\n", " model=model,\n", " ref_model=None,\n", " args=cfg,\n", " train_dataset=ds,\n", " processing_class=tokenizer,\n", ")\n", "trainer.train()\n", "\n", "loss_history = [h.get('loss') for h in trainer.state.log_history if 'loss' in h]\n", "print('Training loss points:', len(loss_history))" ] }, { "cell_type": "markdown", "id": "12cfc52f", "metadata": {}, "source": [ "## 9. Trained-policy evaluation" ] }, { "cell_type": "code", "execution_count": null, "id": "814937a9", "metadata": {}, "outputs": [], "source": [ "import re\n", "import torch\n", "\n", "FastLanguageModel.for_inference(model)\n", "device = next(model.parameters()).device\n", "ACTION_RE = re.compile(r'\\{[^{}]*\\}')\n", "\n", "def parse_action(text):\n", " m = ACTION_RE.search(text)\n", " if not m:\n", " return {'gateway': 1, 'fraud_decision': 0, 'retry_strategy': 1}\n", " try:\n", " a = json.loads(m.group(0))\n", " return {\n", " 'gateway': int(a.get('gateway', 1)) % 3,\n", " 'fraud_decision': int(a.get('fraud_decision', 0)) % 4,\n", " 'retry_strategy': int(a.get('retry_strategy', 1)) % 2,\n", " }\n", " except Exception:\n", " return {'gateway': 1, 'fraud_decision': 0, 'retry_strategy': 1}\n", "\n", "def trained_policy(obs):\n", " prompt = (\n", " 'SmartPayEnv observation:\\n'\n", " f'{json.dumps(obs, sort_keys=True)}\\n'\n", " 'Return one action JSON with fields: gateway, fraud_decision, retry_strategy.'\n", " )\n", " inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=1024).to(device)\n", " with torch.no_grad():\n", " out = model.generate(\n", " **inputs,\n", " max_new_tokens=64,\n", " do_sample=False,\n", " pad_token_id=tokenizer.pad_token_id,\n", " )\n", " text = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n", " return parse_action(text)\n", "\n", "trained_eval = eval_policy(trained_policy)\n", "print('Trained policy mean reward:', trained_eval['mean_reward'])\n", "print('Trained per-bucket:', trained_eval['bucket_means'])" ] }, { "cell_type": "markdown", "id": "cf9d641c", "metadata": {}, "source": [ "## 10. 
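Plots and saved artifacts" ] }, { "cell_type": "markdown", "id": "2b4d6f81", "metadata": {}, "source": [ "Before plotting, the same numbers as one compact table (an added view; pandas was installed in step 1 but not yet used). The column names are ad hoc." ] }, { "cell_type": "code", "execution_count": null, "id": "4d6f8a92", "metadata": {}, "outputs": [], "source": [ "# Tabulate mean reward overall and per risk bucket for all three policies\n", "import pandas as pd\n", "\n", "rows = []\n", "for name, res in [('random', baseline_random), ('heuristic', baseline_heuristic), ('trained', trained_eval)]:\n", "    row = {'policy': name, 'mean_reward': res['mean_reward']}\n", "    row.update({f'{b}_risk': res['bucket_means'][b] for b in ('low', 'medium', 'high')})\n", "    rows.append(row)\n", "print(pd.DataFrame(rows).to_string(index=False))" ] }, { "cell_type": "markdown", "id": "6f8a9cb3", "metadata": {}, "source": [ "Next: the 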
Plots and saved artifacts" ] }, { "cell_type": "code", "execution_count": null, "id": "e228c3ac", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "plt.figure(figsize=(8,4))\n", "plt.plot(rollout_rewards, label='Best-action reward per rollout step')\n", "plt.xlabel('Rollout step')\n", "plt.ylabel('Reward')\n", "plt.title('Group-relative rollout reward (data-collection phase)')\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.savefig('artifacts/rollout_reward_curve.png', dpi=140)\n", "plt.show()\n", "\n", "if loss_history:\n", " plt.figure(figsize=(8,4))\n", " plt.plot(loss_history, label='DPO training loss')\n", " plt.xlabel('Logging step')\n", " plt.ylabel('Loss')\n", " plt.title('TRL DPO training loss (Unsloth)')\n", " plt.legend()\n", " plt.tight_layout()\n", " plt.savefig('artifacts/training_loss.png', dpi=140)\n", " plt.show()\n", "\n", "labels = ['Random', 'Heuristic', 'Trained LLM']\n", "values = [baseline_random['mean_reward'], baseline_heuristic['mean_reward'], trained_eval['mean_reward']]\n", "plt.figure(figsize=(7,4))\n", "bars = plt.bar(labels, values, color=['#bbb','#88c','#4a8'])\n", "for b, v in zip(bars, values):\n", " plt.text(b.get_x()+b.get_width()/2, v+0.01, f'{v:.3f}', ha='center')\n", "plt.ylabel('Mean reward (frozen holdout)')\n", "plt.title('Before vs After Training')\n", "plt.tight_layout()\n", "plt.savefig('artifacts/before_after_rewards.png', dpi=140)\n", "plt.show()\n", "\n", "buckets = ['low', 'medium', 'high']\n", "rand_b = [baseline_random['bucket_means'][b] for b in buckets]\n", "heur_b = [baseline_heuristic['bucket_means'][b] for b in buckets]\n", "trnd_b = [trained_eval['bucket_means'][b] for b in buckets]\n", "x = np.arange(len(buckets))\n", "w = 0.27\n", "plt.figure(figsize=(8,4))\n", "plt.bar(x - w, rand_b, width=w, label='Random', color='#bbb')\n", "plt.bar(x, heur_b, width=w, label='Heuristic', color='#88c')\n", "plt.bar(x + w, trnd_b, width=w, label='Trained LLM', color='#4a8')\n", "plt.xticks(x, [b.title()+' Risk' for b in buckets])\n", "plt.ylabel('Mean reward')\n", "plt.title('Per Risk-Bucket Reward (frozen holdout)')\n", "plt.legend()\n", "plt.tight_layout()\n", "plt.savefig('artifacts/per_bucket_rewards.png', dpi=140)\n", "plt.show()\n", "\n", "summary = {\n", " 'env_url': ENV_URL,\n", " 'model_id': MODEL_ID,\n", " 'quick_mode': QUICK_MODE,\n", " 'pairs_collected': len(pairs),\n", " 'baseline_random_mean_reward': baseline_random['mean_reward'],\n", " 'baseline_heuristic_mean_reward': baseline_heuristic['mean_reward'],\n", " 'trained_mean_reward': trained_eval['mean_reward'],\n", " 'reward_gain_vs_random': trained_eval['mean_reward'] - baseline_random['mean_reward'],\n", " 'reward_gain_vs_heuristic': trained_eval['mean_reward'] - baseline_heuristic['mean_reward'],\n", " 'per_bucket': {\n", " 'random': baseline_random['bucket_means'],\n", " 'heuristic': baseline_heuristic['bucket_means'],\n", " 'trained': trained_eval['bucket_means'],\n", " },\n", " 'rollout_reward_trace': rollout_rewards,\n", " 'training_loss_history': loss_history,\n", " 'eval_per_episode': {\n", " 'random': baseline_random['per_episode_mean'],\n", " 'heuristic': baseline_heuristic['per_episode_mean'],\n", " 'trained': trained_eval['per_episode_mean'],\n", " },\n", "}\n", "with open('artifacts/run_summary.json', 'w', encoding='utf-8') as f:\n", " json.dump(summary, f, indent=2)\n", "print(json.dumps({k:v for k,v in summary.items() if k not in ('rollout_reward_trace','training_loss_history')}, indent=2))" ] }, { "cell_type": "markdown", 
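"id": "8b0c2e47", "metadata": {}, "source": [ "A last sanity check before the optional upload (an added cell, not part of the original pipeline): confirm the artifact files exist on disk and that the saved summary reloads cleanly." ] }, { "cell_type": "code", "execution_count": null, "id": "0d2e4f59", "metadata": {}, "outputs": [], "source": [ "# Verify artifacts on disk and round-trip the JSON summary\n", "print('Artifacts:', sorted(os.listdir('artifacts')))\n", "with open('artifacts/run_summary.json', encoding='utf-8') as f:\n", "    reloaded = json.load(f)\n", "print('Summary keys:', sorted(reloaded.keys()))" ] }, { "cell_type": "markdown",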
"metadata": {}, "source": [ "## 11. (Optional) Upload artifacts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !huggingface-cli upload artifacts artifacts --repo-type dataset" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }