{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Viraltest v2 — Real LLM Training with LoRA + Environment Rewards\n", "\n", "This notebook **actually trains** an LLM (Qwen2.5-1.5B-Instruct) to play our Instagram creator simulation.\n", "\n", "**Pipeline:**\n", "1. Clone repo & install deps\n", "2. Run 5 heuristic baselines × 3 tasks (15 runs) → leaderboard\n", "3. Run **untrained** LLM on all 3 tasks → \"before\" scores\n", "4. **LoRA fine-tune** with reward-weighted SFT (4 rounds × 6 episodes = real weight updates)\n", "5. Run **trained** LLM on all 3 tasks → \"after\" scores\n", "6. Generate real plots from real numbers\n", "\n", "**Requirements:** Colab T4 GPU (free tier), ~45 min total.\n", "\n", "**What makes this real training:** LoRA adapter weights are actually updated via gradient descent. The model's behavior changes because its weights change, not because we edit the prompt.\n", "\n", "**Before this notebook:** run `training/syntax_only.ipynb` (kernel + syntax only) and `training/train_grpo_smoke.ipynb` (repo + env). Pip lines use quoted package specs so Colab/zsh does not break on `>=`." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n", "!pip install -q torch torchvision torchaudio\n", "!pip install -q \"transformers>=4.45.0\" \"accelerate\" \"peft>=0.10.0\" \"trl>=0.20.0\" \"datasets\"\n", "!pip install -q matplotlib pandas\n", "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n", "!pip install -q \"openenv-core[core]>=0.2.2\"\n", "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\"" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n", "import os\n", "import sys\n", "import shutil\n", "import subprocess\n", "from pathlib import Path\n", "\n", "REPO_BRANCH = \"main\"\n", "REPO_URL = \"https://github.com/VaibhavKhandare/viral-posts-env.git\"\n", "COLAB_REPO = Path(\"/content/viral-posts-env\")\n", "\n", "\n", "def _is_repo_root(p: Path) -> bool:\n", " return (p / \"server\" / \"viraltest_environment.py\").is_file() and (p / \"models.py\").is_file()\n", "\n", "\n", "def _find_local_root() -> Path:\n", " here = Path.cwd().resolve()\n", " for cand in (here, here.parent, here.parent.parent):\n", " if _is_repo_root(cand):\n", " return cand\n", " raise FileNotFoundError(\n", " \"Could not find project root. cd into viral-posts-env or run this notebook in Google Colab.\"\n", " )\n", "\n", "\n", "# --- Colab: always clone a clean copy (avoids stale 7-day code) ---\n", "if Path(\"/content\").is_dir():\n", " if COLAB_REPO.exists():\n", " shutil.rmtree(COLAB_REPO, ignore_errors=True)\n", " p = subprocess.run(\n", " [\n", " \"git\", \"clone\", \"--branch\", REPO_BRANCH, \"--depth\", \"1\",\n", " REPO_URL, str(COLAB_REPO),\n", " ],\n", " capture_output=True,\n", " text=True,\n", " )\n", " if p.returncode != 0:\n", " raise RuntimeError(\n", " \"git clone failed. Check network and branch name.\\n\"\n", " f\"stdout:\\n{p.stdout}\\nstderr:\\n{p.stderr}\"\n", " )\n", " if not COLAB_REPO.is_dir():\n", " raise FileNotFoundError(f\"Clone did not create {COLAB_REPO}\")\n", " os.chdir(COLAB_REPO)\n", " print(\"Mode: Colab (fresh clone)\")\n", "else:\n", " # --- Local machine: do not use /content ---\n", " root = _find_local_root()\n", " os.chdir(root)\n", " print(\"Mode: local\")\n", " print(f\"Repo root: {root}\")\n", "\n", "REPO_DIR = str(Path.cwd().resolve())\n", "if REPO_DIR not in sys.path:\n", " sys.path.insert(0, REPO_DIR)\n", "\n", "PLOTS_DIR = os.path.join(REPO_DIR, \"plots\")\n", "os.makedirs(PLOTS_DIR, exist_ok=True)\n", "\n", "try:\n", " commit = subprocess.check_output(\n", " [\"git\", \"rev-parse\", \"--short\", \"HEAD\"],\n", " stderr=subprocess.DEVNULL,\n", " text=True,\n", " ).strip()\n", "except Exception:\n", " commit = \"n/a\"\n", "\n", "print(f\"Working dir: {os.getcwd()}\")\n", "print(f\"Branch: {REPO_BRANCH}\")\n", "print(f\"Commit: {commit}\")\n", "print(f\"Plots dir: {PLOTS_DIR}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 3: Imports (with runtime validation)\n", "import json, random, time, textwrap, copy, os, sys\n", "from pathlib import Path\n", "from typing import Any, Dict, List, Optional, Tuple\n", "from collections import defaultdict\n", "\n", "# Find repo root if notebook was opened from training/ and Cell 2 was skipped\n", "if not Path(\"server/viraltest_environment.py\").is_file():\n", " for cand in (Path.cwd(), Path.cwd().parent, Path.cwd().parent.parent):\n", " if (cand / \"server\" / \"viraltest_environment.py\").is_file():\n", " os.chdir(cand)\n", " s = str(cand.resolve())\n", " if s not in sys.path:\n", " sys.path.insert(0, s)\n", " print(\"Auto chdir to repo root:\", s)\n", " break\n", " else:\n", " raise RuntimeError(\n", " \"Project files not found. Run **Cell 2** first (Colab), or run from repo root.\\n\"\n", " f\" cwd = {os.getcwd()!r}\\n\"\n", " )\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "from models import ScheduledAction, ToolCall, ViraltestAction\n", "from server.viraltest_environment import (\n", " ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n", " TOPIC_CATEGORIES,\n", ")\n", "\n", "ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n", "NICHES = list(TOPIC_CATEGORIES.keys())\n", "CONTENT_TYPES = [\"reel\", \"carousel\", \"story\", \"text_post\"]\n", "INTENTS = [\"send_bait\", \"save_bait\", \"watch_bait\", \"like_bait\"]\n", "TASKS = [\"monthly_engage\", \"monthly_strategic\", \"monthly_competitive\"]\n", "\n", "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}\")\n", "print(f\"Tags: {len(TAG_POOL)}, Topics: {len(ALL_TOPICS)}, Horizon: {TASK_HORIZON} days\")\n", "\n", "# Hard stop if stale repo/code is loaded\n", "assert TASK_HORIZON == 15, (\n", " f\"Expected TASK_HORIZON=15, got {TASK_HORIZON}. \"\n", " \"Restart runtime and run from Cell 1 again (clean clone on main).\"\n", ")\n", "\n", "# Same sanity as syntax_only.ipynb (kernel parses modern Python)\n", "import ast\n", "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n", "print(\"OK: ast.parse (syntax check)\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Heuristic Baselines\n", "\n", "5 scripted agents prove the environment differentiates skill levels.\n", "\n", "Benchmark policy:\n", "- Keep heuristic baselines stable across runs so comparisons stay honest.\n", "- Do not lower heuristic strength just to improve model charts.\n", "- If baseline behavior is recalibrated, document the rationale and keep changes human-plausible." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 4: Define heuristic agents + episode runner\n", "_rng = random.Random(42)\n", "\n", "def plan_always_rest(obs_dict, day):\n", " return ViraltestAction(scheduled_actions=[])\n", "\n", "def plan_spam(obs_dict, day):\n", " return ViraltestAction(scheduled_actions=[\n", " ScheduledAction(hour=h, action_type=\"post\", content_type=\"reel\",\n", " topic=\"AI tools\", tags=[\"ai\"], intent=\"watch_bait\")\n", " for h in range(24)])\n", "\n", "def plan_random(obs_dict, day):\n", " actions = []\n", " for h in range(24):\n", " if _rng.random() < 0.1:\n", " actions.append(ScheduledAction(\n", " hour=h, action_type=\"post\",\n", " content_type=_rng.choice(CONTENT_TYPES),\n", " topic=_rng.choice(ALL_TOPICS),\n", " tags=_rng.sample(TAG_POOL[:30], 3),\n", " intent=_rng.choice(INTENTS)))\n", " return ViraltestAction(scheduled_actions=actions)\n", "\n", "def plan_minimal(obs_dict, day):\n", " return ViraltestAction(scheduled_actions=[\n", " ScheduledAction(hour=12, action_type=\"post\", content_type=\"carousel\",\n", " topic=ALL_TOPICS[day % len(ALL_TOPICS)],\n", " tags=[TAG_POOL[i % len(TAG_POOL)] for i in range(day, day+3)],\n", " intent=\"save_bait\")])\n", "\n", "def plan_smart(obs_dict, day):\n", " return ViraltestAction(\n", " tool_calls=[ToolCall(name=\"query_trends\",\n", " arguments={\"niche\": NICHES[day % len(NICHES)]})] if day <= 3 else [],\n", " scheduled_actions=[\n", " ScheduledAction(hour=8, action_type=\"create_content\"),\n", " ScheduledAction(hour=12, action_type=\"post\",\n", " content_type=CONTENT_TYPES[(day*2)%4],\n", " topic=ALL_TOPICS[(day*2)%len(ALL_TOPICS)],\n", " tags=[TAG_POOL[(day*6+i)%len(TAG_POOL)] for i in range(3)],\n", " intent=INTENTS[(day*2)%4]),\n", " ScheduledAction(hour=19, action_type=\"post\",\n", " content_type=CONTENT_TYPES[(day*2+1)%4],\n", " topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n", " tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n", " intent=INTENTS[(day*2+1)%4]),\n", " ])\n", "\n", "BASELINE_AGENTS = {\n", " \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n", " \"random\": plan_random, \"minimal\": plan_minimal, \"smart\": plan_smart,\n", "}\n", "\n", "def run_episode(task, plan_fn, seed=42):\n", " env = ViraltestEnvironment()\n", " obs = env.reset(task=task, seed=seed)\n", " obs_dict = obs.model_dump()\n", " rewards, energies = [], [obs.creator_energy]\n", " for day in range(1, TASK_HORIZON + 1):\n", " action = plan_fn(obs_dict, day)\n", " obs = env.step(action)\n", " obs_dict = obs.model_dump()\n", " rewards.append(obs.reward or 0.0)\n", " energies.append(obs.creator_energy)\n", " if obs.done: break\n", " grader = (obs.metadata or {}).get(\"grader_score\", 0.0)\n", " return {\"grader_score\": grader, \"total_reward\": sum(rewards),\n", " \"steps\": len(rewards), \"final_energy\": obs.creator_energy,\n", " \"follower_delta\": obs.follower_count - 10000,\n", " \"burned_out\": obs.creator_energy <= 0,\n", " \"rewards\": rewards, \"energies\": energies}\n", "\n", "print(\"Agents and episode runner defined.\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 5: Run baselines (safe)\n", "print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n", "print(\"=\" * 70)\n", "\n", "required = [\"BASELINE_AGENTS\", \"run_episode\", \"TASKS\", \"random\"]\n", "missing = [k for k in required if k not in globals()]\n", "if missing:\n", " raise RuntimeError(\n", " f\"Missing prerequisites: {missing}. Run notebook from top (Cell 1 -> Cell 5).\"\n", " )\n", "\n", "baseline_results = {}\n", "for name, fn in BASELINE_AGENTS.items():\n", " baseline_results[name] = {}\n", " for task in TASKS:\n", " _rng = random.Random(42)\n", " try:\n", " result = run_episode(task, fn, seed=42)\n", " except Exception as e:\n", " raise RuntimeError(\n", " f\"Baseline failed for agent={name}, task={task}: {type(e).__name__}: {e}\"\n", " ) from e\n", " baseline_results[name][task] = result\n", " print(f\" {name:>12s} | {task:>22s} | score={result['grader_score']:.4f} \"\n", " f\"| energy={result['final_energy']:.2f}\")\n", " print()\n", "\n", "print(\"\\nLEADERBOARD\")\n", "print(f\"{'Agent':<14s} {'Engage':>10s} {'Strategic':>12s} {'Competitive':>14s} {'Avg':>8s}\")\n", "print(\"-\" * 60)\n", "for name in BASELINE_AGENTS:\n", " scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n", " print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 6: Baseline plots\n", "fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n", "agent_names = list(BASELINE_AGENTS.keys())\n", "colors = ['#E53935', '#FF9800', '#9E9E9E', '#42A5F5', '#4CAF50']\n", "for i, task in enumerate(TASKS):\n", " scores = [baseline_results[a][task][\"grader_score\"] for a in agent_names]\n", " bars = axes[i].barh(agent_names, scores, color=colors)\n", " axes[i].set_title(task.replace(\"monthly_\", \"\").title(), fontsize=13, fontweight='bold')\n", " for bar, score in zip(bars, scores):\n", " axes[i].text(bar.get_width() + 0.005, bar.get_y() + bar.get_height()/2,\n", " f\"{score:.4f}\", va='center', fontsize=9)\n", "axes[0].set_ylabel(\"Agent\")\n", "fig.suptitle(\"Viraltest v2 — Heuristic Baseline Leaderboard\", fontsize=14, fontweight='bold')\n", "fig.tight_layout()\n", "fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Load LLM (Qwen2.5-1.5B-Instruct)\n", "\n", "We load the base model with 4-bit quantization to fit in free Colab's T4 GPU (16GB VRAM)." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "\n", "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n", "if tokenizer.pad_token is None:\n", " tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"left\"\n", "\n", "\n", "def _has_flash_attn():\n", " try:\n", " import flash_attn # noqa: F401\n", " return torch.cuda.is_available()\n", " except Exception:\n", " return False\n", "\n", "\n", "if torch.cuda.is_available():\n", " dtype = torch.bfloat16\n", " attn_impl = \"flash_attention_2\" if _has_flash_attn() else \"sdpa\"\n", "elif getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available():\n", " dtype, attn_impl = torch.float16, \"sdpa\"\n", "else:\n", " dtype, attn_impl = torch.float32, \"eager\"\n", "\n", "print(f\"Loading {MODEL_NAME} (dtype={dtype}, attn={attn_impl})...\")\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_NAME,\n", " trust_remote_code=True,\n", " dtype=dtype,\n", " attn_implementation=attn_impl,\n", " device_map=\"cuda:0\" if torch.cuda.is_available() else None,\n", ")\n", "if not torch.cuda.is_available():\n", " model = model.to(\"mps\") if (getattr(torch.backends, \"mps\", None) and torch.backends.mps.is_available()) else model.to(\"cpu\")\n", "\n", "model.eval()\n", "print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n", "if torch.cuda.is_available():\n", " print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 8: LLM agent functions\n", "_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n", "You are an Instagram content strategy agent. Each step is one day.\n", "You manage a creator account over a 15-day cycle.\n", "\n", "RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n", "{\n", " \"tool_calls\": [{\"name\": \"\", \"arguments\": {...}}],\n", " \"scheduled_actions\": [\n", " {\"hour\": 0-23, \"action_type\": \"post|create_content\",\n", " \"content_type\": \"reel|story|carousel|text_post\",\n", " \"topic\": \"\", \"tags\": [\"...\"],\n", " \"intent\": \"send_bait|save_bait|watch_bait|like_bait\"}\n", " ],\n", " \"notes\": \"strategy notes\"\n", "}\n", "\n", "TOOLS:\n", "- query_trends(niche) trending topics+tags for niche\n", "- query_audience(segment_id) segment topic affinities + active hours\n", "- query_competitor(competitor_id, window_days) competitor recent posts\n", "- query_tag_history(tag) your past signals (watch/sends/saves/likes) for a tag\n", "- predict_engagement(scheduled_actions) simulate a plan WITHOUT committing\n", "- draft_review(scheduled_actions) AI review of a draft plan\n", "- query_creator_pool() list collab partners with audience overlap\n", "- propose_collab(partner_id, content_type, hour) co-author the post at that hour (max 2/month)\n", "\n", "ACTION SCHEMA:\n", "- hour: 0..23 (unlisted hours = rest)\n", "- action_type: post (publish) | create_content (build queue, no publish)\n", "- content_type: reel | story | carousel | text_post\n", "- intent: which Mosseri signal the post optimises for\n", " send_bait -> DM shares (strongest discovery signal)\n", " save_bait -> bookmarks (content quality)\n", " watch_bait -> reels watch time\n", " like_bait -> likes from existing followers\n", "- tags: up to 5 hashtags\n", "- topic: free-form string\n", "- empty scheduled_actions = full day rest\"\"\")\n", "\n", "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n", "\n", "TWO-PHASE FLOW (each day has two turns — same observation, two responses):\n", "PHASE A — DISCOVERY: respond with {\"tool_calls\": [...]} only. Tools cost nothing,\n", " call as many query_* / predict_engagement / draft_review as useful. Their results\n", " are dispatched immediately and shown to you in PHASE B of the SAME day.\n", "PHASE B — PLANNING: respond with {\"scheduled_actions\": [...], \"notes\": \"...\"}\n", " using the freshly returned Tool results.\n", "Audience peak hours, segment affinities, trends, competitor schedules are NOT in\n", "the observation — discover them in PHASE A. Useful PHASE-A starter set:\n", " query_trends(niche), query_audience(segment_id), query_creator_pool(),\n", " query_competitor(competitor_id, window_days), and on later days also\n", " predict_engagement(scheduled_actions=[...candidate plan...]).\"\"\")\n", "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n", "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n", "\n", "\n", "def format_obs(obs):\n", " days = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n", " day_name = days[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n", " signals_str = \"\"\n", " signals = getattr(obs, \"engagement_signals\", None)\n", " if signals:\n", " signals_str = (f\"Signals: watch={signals.watch_time:.3f} \"\n", " f\"sends={signals.sends_per_reach:.3f} \"\n", " f\"saves={signals.saves:.3f}\\n\")\n", " tool_str = \"\"\n", " for tr in getattr(obs, \"tool_results\", []):\n", " if tr.success:\n", " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n", " if not tool_str:\n", " tool_str = \" (none — call query_* tools to discover)\\n\"\n", " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n", " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n", " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n", " f\"{signals_str}\"\n", " f\"Tool results:\\n{tool_str}\"\n", " f\"Plan today's actions (JSON only):\")\n", "\n", "\n", "def is_well_formed_response(text):\n", " try:\n", " t = text.strip()\n", " if \"```\" in t:\n", " t = \"\\n\".join(l for l in t.split(\"\\n\") if not l.strip().startswith(\"```\")).strip()\n", " s, e = t.find(\"{\"), t.rfind(\"}\") + 1\n", " d = json.loads(t[s:e])\n", " for tc in d.get(\"tool_calls\", []):\n", " if not isinstance(tc, dict) or not isinstance(tc.get(\"arguments\", {}), dict):\n", " return False\n", " return True\n", " except Exception:\n", " return False\n", "\n", "\n", "def parse_model_output(text):\n", " text = text.strip()\n", " if \"```\" in text:\n", " lines = [l for l in text.split(\"\\n\") if not l.strip().startswith(\"```\")]\n", " text = \"\\n\".join(lines).strip()\n", " start, end = text.find(\"{\"), text.rfind(\"}\") + 1\n", " if start >= 0 and end > start:\n", " text = text[start:end]\n", " try:\n", " data = json.loads(text)\n", " except Exception:\n", " return ViraltestAction(scheduled_actions=[])\n", " tool_calls = []\n", " for tc in data.get(\"tool_calls\", []):\n", " if not isinstance(tc, dict) or \"name\" not in tc:\n", " continue\n", " args = tc.get(\"arguments\", {})\n", " if isinstance(args, list) and args and isinstance(args[0], dict):\n", " args = args[0]\n", " if not isinstance(args, dict):\n", " continue\n", " try:\n", " tool_calls.append(ToolCall(name=tc[\"name\"], arguments=args))\n", " except Exception:\n", " pass\n", " scheduled = []\n", " for a in data.get(\"scheduled_actions\", []):\n", " try:\n", " scheduled.append(ScheduledAction(**a))\n", " except Exception:\n", " pass\n", " return ViraltestAction(\n", " tool_calls=tool_calls,\n", " scheduled_actions=scheduled,\n", " notes=data.get(\"notes\"),\n", " )\n", "\n", "\n", "def _infer_model_device(m):\n", " \"\"\"Works for single/multi-device models (Peft, 4-bit) where m.device may be missing.\"\"\"\n", " p = next(m.parameters(), None)\n", " if p is not None:\n", " return p.device\n", " d = getattr(m, \"device\", None)\n", " if d is not None:\n", " return d\n", " return torch.device(\"cpu\")\n", "\n", "\n", "def _build_chat(system, prompt):\n", " return [\n", " {\"role\": \"system\", \"content\": system},\n", " {\"role\": \"user\", \"content\": prompt},\n", " ]\n", "\n", "\n", "def _batched_generate(mdl, tok, prompts, eval=False, max_new_tokens=512):\n", " enc = tok(prompts, return_tensors=\"pt\", padding=True, truncation=False).to(_infer_model_device(mdl))\n", " gen_kwargs = dict(\n", " max_new_tokens=max_new_tokens,\n", " pad_token_id=tok.pad_token_id,\n", " do_sample=True, temperature=1.0, top_p=0.95,\n", " )\n", " with torch.no_grad():\n", " out = mdl.generate(**enc, **gen_kwargs)\n", " resps = tok.batch_decode(out[:, enc[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n", " return resps, enc[\"input_ids\"].shape[1]\n", "\n", "\n", "IO_LOG_PATH = os.path.join(PLOTS_DIR, \"io_log.jsonl\")\n", "open(IO_LOG_PATH, \"w\").close() # truncate\n", "\n", "\n", "def _log_io(tag, ep_idx, day, task, seed, prompt, response):\n", " rec = {\"tag\": tag, \"ep\": ep_idx, \"day\": day, \"task\": task, \"seed\": seed,\n", " \"prompt\": prompt, \"response\": response}\n", " with open(IO_LOG_PATH, \"a\") as f:\n", " f.write(json.dumps(rec) + \"\\n\")\n", "\n", "\n", "DISCOVERY_SUFFIX = \"\\n\\nPHASE A (DISCOVERY): respond with JSON {\\\"tool_calls\\\": [...]} only.\"\n", "PLANNING_SUFFIX = \"\\n\\nPHASE B (PLANNING): respond with JSON {\\\"scheduled_actions\\\": [...], \\\"notes\\\": \\\"...\\\"} using the fresh Tool results above.\"\n", "\n", "\n", "def _parse_tool_calls_only(text):\n", " return parse_model_output(text).tool_calls\n", "\n", "\n", "def _parse_actions_only(text):\n", " a = parse_model_output(text)\n", " return ViraltestAction(tool_calls=[], scheduled_actions=a.scheduled_actions, notes=a.notes)\n", "\n", "\n", "def _format_fresh_results(fresh):\n", " if not fresh:\n", " return \"\"\n", " out = \"Fresh tool results (PHASE A):\\n\"\n", " for tr in fresh:\n", " if tr.success:\n", " out += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n", " else:\n", " out += f\" {tr.name}: ERROR {tr.error}\\n\"\n", " return out\n", "\n", "\n", "def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None, log_tag=None):\n", " \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n", " sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n", " n = len(tasks_seeds)\n", " envs = [ViraltestEnvironment() for _ in range(n)]\n", " obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n", " rewards = [[] for _ in range(n)]\n", " energies = [[obs.creator_energy] for obs in obss]\n", " pairs = [[] for _ in range(n)]\n", " done_mask = [obs.done for obs in obss]\n", " rest_action = ViraltestAction(scheduled_actions=[])\n", "\n", " def _gen(prompts):\n", " chats = [_build_chat(sys_prompt, p) for p in prompts]\n", " texts = [tok.apply_chat_template(c, tokenize=False, add_generation_prompt=True) for c in chats]\n", " return _batched_generate(mdl, tok, texts, eval=eval)\n", "\n", " for day in range(1, TASK_HORIZON + 1):\n", " active = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy > 0.25]\n", " rest = [i for i in range(n) if not done_mask[i] and obss[i].creator_energy <= 0.25]\n", " if not active and not rest:\n", " break\n", "\n", " actions_by_idx = {i: rest_action for i in rest}\n", " if active:\n", " base_prompts = [format_obs(obss[i]) for i in active]\n", "\n", " disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n", " disc_resps, ptok = _gen(disc_prompts)\n", " if verbose:\n", " print(f\" D{day:2d}A: batch={len(active)} rest={len(rest)} prompt_tok={ptok}\")\n", "\n", " fresh_per_active = []\n", " for j, i in enumerate(active):\n", " tcs = _parse_tool_calls_only(disc_resps[j])\n", " fresh_per_active.append([envs[i]._dispatch_tool(tc) for tc in tcs])\n", " pairs[i].append({\"prompt\": disc_prompts[j], \"response\": disc_resps[j],\n", " \"step\": len(rewards[i]), \"phase\": \"A\"})\n", " if log_tag is not None:\n", " t, s = tasks_seeds[i]\n", " _log_io(f\"{log_tag}/A\", i, day, t, s, disc_prompts[j], disc_resps[j])\n", "\n", " plan_prompts = [base_prompts[j] + \"\\n\" + _format_fresh_results(fresh_per_active[j]) + PLANNING_SUFFIX\n", " for j in range(len(active))]\n", " plan_resps, ptok2 = _gen(plan_prompts)\n", " if verbose:\n", " print(f\" D{day:2d}B: batch={len(active)} prompt_tok={ptok2}\")\n", "\n", " for j, i in enumerate(active):\n", " actions_by_idx[i] = _parse_actions_only(plan_resps[j])\n", " pairs[i].append({\"prompt\": plan_prompts[j], \"response\": plan_resps[j],\n", " \"step\": len(rewards[i]), \"phase\": \"B\"})\n", " if log_tag is not None:\n", " t, s = tasks_seeds[i]\n", " _log_io(f\"{log_tag}/B\", i, day, t, s, plan_prompts[j], plan_resps[j])\n", "\n", " for i in range(n):\n", " if done_mask[i] or i not in actions_by_idx:\n", " continue\n", " obss[i] = envs[i].step(actions_by_idx[i])\n", " r = obss[i].reward or 0.0\n", " rewards[i].append(r)\n", " energies[i].append(obss[i].creator_energy)\n", " if obss[i].done:\n", " done_mask[i] = True\n", "\n", " GAMMA, TERMINAL_W = 0.95, 5.0\n", " results = []\n", " for i, (task, seed) in enumerate(tasks_seeds):\n", " gs = (obss[i].metadata or {}).get(\"grader_score\", 0.0)\n", " rets = [0.0] * len(rewards[i])\n", " G = gs * TERMINAL_W\n", " for t in reversed(range(len(rewards[i]))):\n", " G = rewards[i][t] + GAMMA * G\n", " rets[t] = G\n", " for pr in pairs[i]:\n", " k = pr.get(\"step\", 0)\n", " pr[\"return\"] = rets[k] if 0 <= k < len(rets) else 0.0\n", " results.append({\n", " \"task\": task, \"seed\": seed, \"grader_score\": gs,\n", " \"total_reward\": sum(rewards[i]), \"final_energy\": obss[i].creator_energy,\n", " \"rewards\": rewards[i], \"returns\": rets, \"energies\": energies[i],\n", " \"pairs\": pairs[i], \"follower_delta\": obss[i].follower_count - 10000,\n", " \"burned_out\": obss[i].creator_energy <= 0,\n", " })\n", " return results\n", "\n", "\n", "def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n", " return run_llm_episodes_batched(mdl, tok, [(task, seed)], verbose=verbose)[0]\n", "\n", "\n", "print(\"LLM agent functions defined (batched).\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 3: Untrained LLM Baseline (“Before”)\n", "\n", "Run the base model with NO fine-tuning. This establishes ground truth." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n", "print(\"Running UNTRAINED base model on all tasks (batched)...\")\n", "print(\"=\" * 60)\n", "\n", "t0 = time.time()\n", "results = run_llm_episodes_batched(model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True, log_tag=\"before\")\n", "before_results = {r[\"task\"]: r for r in results}\n", "\n", "print(\"\\n\" + \"=\" * 60)\n", "print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n", "for t in TASKS:\n", " print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 4: LoRA Fine-Tuning (Real Weight Updates)\n", "\n", "This is the core training loop. For each round:\n", "1. Collect episodes with current model\n", "2. Score each (prompt, response) pair by episode reward\n", "3. Keep top 50% highest-reward samples\n", "4. Fine-tune LoRA weights via SFT on those samples\n", "\n", "The model's actual weights change via gradient descent — this is real training." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 10: Attach LoRA adapter\n", "from peft import LoraConfig, get_peft_model, TaskType\n", "\n", "lora_config = LoraConfig(\n", " r=8, lora_alpha=16, lora_dropout=0.05,\n", " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n", " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n", ")\n", "\n", "model.enable_input_require_grads()\n", "peft_model = get_peft_model(model, lora_config)\n", "peft_model.print_trainable_parameters()" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 11: Training loop\n", "from trl import SFTTrainer, SFTConfig\n", "from datasets import Dataset\n", "\n", "NUM_ROUNDS = 2\n", "EPISODES_PER_ROUND = 9\n", "QUALITY_FLOOR = 0.40 # skip SFT for the round if no episode beats this grader score\n", "EPISODE_GRADER_FLOOR = 0.25\n", "STEP_RETURN_FLOOR = -0.02\n", "MIN_ROUND_KEEP_RATE = 0.30\n", "MIN_RETAINED_PAIRS = 18\n", "\n", "training_log = {\n", " \"round\": [], \"avg_episode_reward\": [], \"max_episode_reward\": [],\n", " \"min_episode_reward\": [], \"avg_grader\": [], \"max_grader\": [],\n", " \"episode_keep_rate\": [], \"grader_std\": [],\n", " \"p10_episode_reward\": [], \"p90_episode_reward\": [],\n", " \"n_training_samples\": [], \"train_loss\": [],\n", "}\n", "\n", "t_start = time.time()\n", "\n", "for round_idx in range(1, NUM_ROUNDS + 1):\n", " print(f\"\\n{'=' * 60}\")\n", " print(f\"TRAINING ROUND {round_idx}/{NUM_ROUNDS}\")\n", " print(f\"{'=' * 60}\")\n", "\n", " peft_model.eval()\n", " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + (round_idx - 1) * 100 + ep) for ep in range(EPISODES_PER_ROUND)]\n", " t_roll = time.time()\n", " results = run_llm_episodes_batched(peft_model, tokenizer, tasks_seeds, verbose=False,\n", " eval=False, system=SYSTEM_PROMPT_TRAIN,\n", " log_tag=f\"train_round{round_idx}\")\n", " print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n", "\n", " all_pairs, episode_rewards, episode_graders = [], [], []\n", " accepted_episodes = 0\n", " for ep, result in enumerate(results):\n", " ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n", " episode_rewards.append(ep_reward)\n", " episode_graders.append(result[\"grader_score\"])\n", "\n", " well_formed_pairs = [pr for pr in result[\"pairs\"] if is_well_formed_response(pr[\"response\"])]\n", " step_returns = [pr[\"return\"] for pr in well_formed_pairs]\n", " avg_step_return = float(np.mean(step_returns)) if step_returns else -999.0\n", "\n", " accept_episode = (\n", " result[\"grader_score\"] >= EPISODE_GRADER_FLOOR\n", " and avg_step_return >= STEP_RETURN_FLOOR\n", " )\n", "\n", " kept = 0\n", " if accept_episode:\n", " accepted_episodes += 1\n", " for pr in well_formed_pairs:\n", " if pr[\"return\"] < STEP_RETURN_FLOOR:\n", " continue\n", " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT_TRAIN}<|im_end|>\\n\"\n", " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n", " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n", " all_pairs.append({\"text\": text, \"reward\": pr[\"return\"], \"task\": result[\"task\"]})\n", " kept += 1\n", "\n", " print(\n", " f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n", " f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} \"\n", " f\"avg_step={avg_step_return:.3f} accepted={accept_episode} kept={kept}/{len(result['pairs'])}\"\n", " )\n", "\n", " avg_r = float(np.mean(episode_rewards))\n", " avg_g = float(np.mean(episode_graders))\n", " max_g = float(max(episode_graders))\n", " grader_std = float(np.std(episode_graders))\n", " keep_rate = accepted_episodes / max(len(results), 1)\n", " p10_reward = float(np.percentile(episode_rewards, 10))\n", " p90_reward = float(np.percentile(episode_rewards, 90))\n", "\n", " print(\" Data quality report:\")\n", " print(f\" avg_reward={avg_r:.3f} p10={p10_reward:.3f} p90={p90_reward:.3f}\")\n", " print(f\" avg_grader={avg_g:.4f} max_grader={max_g:.4f} grader_std={grader_std:.4f}\")\n", " print(f\" episode_keep_rate={keep_rate:.2%} retained_pairs={len(all_pairs)}\")\n", "\n", " if not all_pairs:\n", " print(\" WARNING: 0 retained pairs after quality gates; skipping SFT.\")\n", " continue\n", " if max_g < QUALITY_FLOOR:\n", " print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n", " continue\n", " if keep_rate < MIN_ROUND_KEEP_RATE:\n", " print(f\" SKIP SFT: keep_rate {keep_rate:.2%} below minimum {MIN_ROUND_KEEP_RATE:.2%}\")\n", " continue\n", "\n", " rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n", " adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n", " filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n", " if not filtered:\n", " print(\" SKIP SFT: zero positive-advantage samples\")\n", " continue\n", "\n", " per_task_filtered = {task: [p for p in filtered if p[\"task\"] == task] for task in TASKS}\n", " missing_tasks = [task for task, rows in per_task_filtered.items() if not rows]\n", " if missing_tasks:\n", " print(f\" SKIP SFT: no positive-advantage samples for tasks={missing_tasks}\")\n", " continue\n", "\n", " per_task_cap = min(len(rows) for rows in per_task_filtered.values())\n", " balanced = []\n", " for task in TASKS:\n", " balanced.extend(per_task_filtered[task][:per_task_cap])\n", "\n", " if len(balanced) < MIN_RETAINED_PAIRS:\n", " print(f\" SKIP SFT: balanced sample count {len(balanced)} below minimum {MIN_RETAINED_PAIRS}\")\n", " continue\n", "\n", " print(\n", " f\" Kept {len(filtered)}/{len(all_pairs)} positive-advantage samples; \"\n", " f\"balanced to {len(balanced)} ({per_task_cap}/task)\"\n", " )\n", "\n", " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in balanced])\n", "\n", " # SFT training (real gradient updates)\n", " sft_config = SFTConfig(\n", " output_dir=f\"./checkpoints/round_{round_idx}\",\n", " num_train_epochs=1,\n", " per_device_train_batch_size=2,\n", " gradient_accumulation_steps=4,\n", " learning_rate=5e-6,\n", " warmup_steps=5,\n", " logging_steps=1,\n", " save_strategy=\"no\",\n", " max_length=2048,\n", " bf16=True,\n", " report_to=\"none\",\n", " )\n", "\n", " peft_model.train()\n", " trainer = SFTTrainer(\n", " model=peft_model, processing_class=tokenizer,\n", " train_dataset=dataset, args=sft_config,\n", " )\n", " train_result = trainer.train()\n", " loss = train_result.training_loss\n", " print(f\" Training loss: {loss:.4f}\")\n", "\n", " training_log[\"round\"].append(round_idx)\n", " training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n", " training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n", " training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n", " training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n", " training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n", " training_log[\"episode_keep_rate\"].append(round(float(keep_rate), 4))\n", " training_log[\"grader_std\"].append(round(float(grader_std), 4))\n", " training_log[\"p10_episode_reward\"].append(round(float(p10_reward), 3))\n", " training_log[\"p90_episode_reward\"].append(round(float(p90_reward), 3))\n", " training_log[\"n_training_samples\"].append(len(balanced))\n", " training_log[\"train_loss\"].append(round(loss, 4))\n", "\n", "elapsed = time.time() - t_start\n", "print(f\"\\nTraining complete in {elapsed/60:.1f} min\")\n", "print(pd.DataFrame(training_log).to_string(index=False))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 5: Trained LLM Evaluation (“After”)\n", "\n", "Same model, same seeds, same environment — but now with updated LoRA weights." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 12: Run trained model (batched)\n", "print(\"Running TRAINED model on all tasks (batched)...\")\n", "print(\"=\" * 60)\n", "\n", "peft_model.eval()\n", "t0 = time.time()\n", "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True, log_tag=\"after\")\n", "after_results = {r[\"task\"]: r for r in results}\n", "\n", "print(\"\\n\" + \"=\" * 60)\n", "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n", "for t in TASKS:\n", " print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 6: Result Plots — Real Training Evidence" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 13: Training curves\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "rounds = training_log[\"round\"]\n", "\n", "axes[0].plot(rounds, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n", "axes[0].fill_between(rounds, training_log[\"avg_grader\"],\n", " training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n", "axes[0].set_xlabel('Round'); axes[0].set_ylabel('Grader Score')\n", "axes[0].set_title('Grader Score Over Rounds', fontweight='bold')\n", "axes[0].legend(); axes[0].grid(True, alpha=0.3)\n", "\n", "axes[1].plot(rounds, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n", "axes[1].set_xlabel('Round'); axes[1].set_ylabel('Loss')\n", "axes[1].set_title('Training Loss', fontweight='bold')\n", "axes[1].grid(True, alpha=0.3)\n", "\n", "fig.suptitle('Viraltest v2 — LoRA Training Progress (Qwen 1.5B)', fontsize=14, fontweight='bold')\n", "fig.tight_layout()\n", "fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 14: Before vs After\n", "task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n", "x = np.arange(len(TASKS))\n", "w = 0.25\n", "\n", "fig, ax = plt.subplots(figsize=(10, 6))\n", "b_scores = [before_results[t][\"grader_score\"] for t in TASKS]\n", "a_scores = [after_results[t][\"grader_score\"] for t in TASKS]\n", "s_scores = [baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS]\n", "\n", "ax.bar(x - w, b_scores, w, label='Base Model (Before)', color='#FF9800')\n", "ax.bar(x, a_scores, w, label='LoRA Trained (After)', color='#4CAF50')\n", "ax.bar(x + w, s_scores, w, label='Smart Heuristic', color='#9E9E9E', alpha=0.7)\n", "\n", "ax.set_ylabel('Grader Score'); ax.set_xticks(x); ax.set_xticklabels(task_labels)\n", "ax.set_title('Before vs After LoRA Training — Grader Scores', fontsize=14, fontweight='bold')\n", "ax.legend(); ax.grid(True, alpha=0.3, axis='y')\n", "\n", "for container in ax.containers:\n", " for bar in container:\n", " h = bar.get_height()\n", " if h > 0:\n", " ax.text(bar.get_x() + bar.get_width()/2., h + 0.005,\n", " f'{h:.4f}', ha='center', va='bottom', fontsize=9)\n", "\n", "fig.tight_layout()\n", "fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 15: Trajectory comparison\n", "fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n", "comparisons = [\n", " (\"Base Model\", before_results, '#FF9800', '--'),\n", " (\"LoRA Trained\", after_results, '#4CAF50', '-'),\n", "]\n", "for i, task in enumerate(TASKS):\n", " for label, res, color, ls in comparisons:\n", " lw = 2.5 if 'Trained' in label else 1.5\n", " axes[0, i].plot(res[task][\"rewards\"], label=label, color=color, lw=lw, ls=ls)\n", " axes[1, i].plot(res[task][\"energies\"], label=label, color=color, lw=lw, ls=ls)\n", " sr = baseline_results[\"smart\"][task]\n", " axes[0, i].plot(sr[\"rewards\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n", " axes[1, i].plot(sr[\"energies\"], label=\"Smart\", color='#9E9E9E', lw=1, ls=':')\n", " t_name = task.replace('monthly_', '').title()\n", " axes[0, i].set_title(f\"{t_name} — Rewards\"); axes[0, i].grid(True, alpha=0.3)\n", " axes[1, i].set_title(f\"{t_name} — Energy\"); axes[1, i].grid(True, alpha=0.3)\n", "axes[0, 2].legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", "fig.suptitle('Before vs After — Daily Trajectories', fontsize=14, fontweight='bold', y=1.01)\n", "fig.tight_layout()\n", "fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 7: Summary & Export" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 16: Final summary\n", "print(\"=\" * 67)\n", "print(\"FINAL RESULTS\")\n", "print(\"=\" * 67)\n", "print(f\"\\n{'Task':<25s} {'Before':>10s} {'After':>10s} {'Delta':>10s} {'Smart':>10s}\")\n", "print(\"-\" * 67)\n", "for task in TASKS:\n", " b = before_results[task][\"grader_score\"]\n", " a = after_results[task][\"grader_score\"]\n", " s = baseline_results[\"smart\"][task][\"grader_score\"]\n", " print(f\"{task:<25s} {b:>10.4f} {a:>10.4f} {a-b:>+10.4f} {s:>10.4f}\")\n", "\n", "avg_b = np.mean([before_results[t][\"grader_score\"] for t in TASKS])\n", "avg_a = np.mean([after_results[t][\"grader_score\"] for t in TASKS])\n", "avg_s = np.mean([baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS])\n", "print(\"-\" * 67)\n", "print(f\"{'AVERAGE':<25s} {avg_b:>10.4f} {avg_a:>10.4f} {avg_a-avg_b:>+10.4f} {avg_s:>10.4f}\")\n", "\n", "summary = {\n", " \"model\": MODEL_NAME,\n", " \"training\": \"LoRA SFT (real weight updates)\",\n", " \"rounds\": NUM_ROUNDS, \"episodes_per_round\": EPISODES_PER_ROUND,\n", " \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n", " \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n", " \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n", " \"improvement\": {t: after_results[t][\"grader_score\"] - before_results[t][\"grader_score\"] for t in TASKS},\n", " \"training_log\": training_log,\n", "}\n", "with open(f\"{PLOTS_DIR}/training_summary.json\", \"w\") as f:\n", " json.dump(summary, f, indent=2)\n", "\n", "pd.DataFrame(training_log).to_csv(f\"{PLOTS_DIR}/training_log.csv\", index=False)\n", "\n", "print(f\"\\nSaved to {PLOTS_DIR}/\")\n", "print(\"All results are from real LoRA weight updates on real environment runs.\")" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "# Cell 17: Save adapter\n", "save_path = \"./viraltest_trained_adapter\"\n", "peft_model.save_pretrained(save_path)\n", "tokenizer.save_pretrained(save_path)\n", "print(f\"LoRA adapter saved to {save_path}\")\n", "print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")" ], "execution_count": null, "outputs": [] } ], "metadata": { "accelerator": "GPU", "gpuClass": "standard", "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 4 }