{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Minimal Unsloth GRPO on Colab with a remote OpenEnv Space\n", "\n", "This notebook is intentionally similar to the 2048 notebook pattern:\n", "- training runs locally inside Colab\n", "- the environment is accessed remotely through a Hugging Face Space\n", "- the reward function is defined in notebook code by replaying actions against that remote env\n", "- prompt / action / conclusion formatting mirrors the repo logic without importing the repo training script\n", "\n", "Default remote env: `Ev3Dev/hackathon`\n", "\n", "**Runtime**: Enable a GPU in Colab: Runtime -> Change runtime type -> GPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 1. Clone the repo for lightweight client / model definitions only\n", "REPO_URL = \"https://github.com/mhtruong1031/OpenENV-Hackathon.git\"  # or your fork\n", "REPO_DIR = \"OpenENV-Hackathon\"\n", "\n", "!git clone --depth 1 {REPO_URL} {REPO_DIR}\n", "%cd {REPO_DIR}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 2. Install only the runtime pieces needed for notebook-side training\n", "# %pip (rather than !pip) guarantees packages land in the running kernel's environment.\n", "%pip install -q unsloth unsloth_zoo --no-deps\n", "%pip install -q \"openenv-core[core]>=0.2.0\" \"pydantic>=2\" \"numpy>=1.24.0\" \"scipy>=1.10.0\" \"datasets>=4.6.1\" \"accelerate>=1.13.0\" \"peft>=0.15.0\" \"bitsandbytes>=0.45.0\" \"matplotlib>=3.8.0\"\n", "%pip install -q \"transformers>=4.57.0\" \"trl>=0.29.0\" \"torchvision>=0.20.0\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 3. 
Import repo reward helpers, but keep the environment remote\n", "import inspect\n", "import json\n", "import random\n", "import re\n", "import sys\n", "from pathlib import Path\n", "from typing import Any, Dict, List\n", "\n", "# Unsloth must be imported before trl / transformers / peft.\n", "import unsloth  # noqa: F401\n", "import torch\n", "from unsloth import FastLanguageModel, PatchFastRL\n", "\n", "sys.path.insert(0, str(Path.cwd()))\n", "\n", "from client import BioExperimentEnv\n", "from models import ActionType, ExperimentAction\n", "from training_script import (\n", "    INVALID_ACTION_PENALTY,\n", "    ENVIRONMENT_ERROR_PENALTY,\n", "    OpenEnvReward,\n", "    build_training_prompt,\n", "    build_experiment_action,\n", "    decode_history_actions,\n", "    pick_action,\n", "    save_training_plots,\n", ")\n", "\n", "MAX_COMPLETION_TOKENS = 160\n", "LORA_TARGET_MODULES = [\n", "    \"q_proj\",\n", "    \"k_proj\",\n", "    \"v_proj\",\n", "    \"o_proj\",\n", "    \"gate_proj\",\n", "    \"up_proj\",\n", "    \"down_proj\",\n", "]\n", "\n", "\n", "def hf_space_repo_to_base_url(repo_id: str) -> str:\n", "    \"\"\"Map a Space repo id (``owner/name``) to its direct ``*.hf.space`` URL.\n", "\n", "    Hugging Face builds the subdomain from the lowercased owner and Space\n", "    name with runs of non-alphanumeric characters ('_', '.', ...) collapsed\n", "    to '-', so sanitize both parts the same way instead of only '_'.\n", "    \"\"\"\n", "\n", "    def sanitize(part: str) -> str:\n", "        return re.sub(r\"[^a-z0-9]+\", \"-\", part.lower()).strip(\"-\")\n", "\n", "    owner, space_name = repo_id.split(\"/\", 1)\n", "    return f\"https://{sanitize(owner)}-{sanitize(space_name)}.hf.space\"\n", "\n", "\n", "def build_remote_prompt_examples(\n", "    base_url: str,\n", "    dataset_episodes: int,\n", "    rollout_steps: int,\n", "    seed: int,\n", ") -> List[Dict[str, str]]:\n", "    rng = random.Random(seed)\n", "    examples: List[Dict[str, str]] = []\n", "\n", "    for _ in range(dataset_episodes):\n", "        with BioExperimentEnv(base_url=base_url) as env:\n", "            result = env.reset()\n", "            obs = result.observation\n", "            history_actions: List[ExperimentAction] = []\n", "\n", "            for step_idx in range(rollout_steps):\n", "                if obs.done:\n", "                    break\n", "\n", "                next_action = build_experiment_action(\n", "                    action_type=pick_action(\n", "                        \"heuristic\",\n", "                        step_idx,\n", "                        [action.action_type for action in history_actions],\n", "                    ),\n", "                    
discovered_markers=obs.discovered_markers,\n", " candidate_mechanisms=obs.candidate_mechanisms,\n", " conditions=obs.task.conditions,\n", " )\n", " examples.append(\n", " {\n", " \"prompt\": build_training_prompt(obs),\n", " \"history_actions\": json.dumps(\n", " [action.model_dump() for action in history_actions]\n", " ),\n", " \"reference_action\": json.dumps(next_action.model_dump()),\n", " \"problem_statement\": obs.task.problem_statement,\n", " \"episode_tag\": f\"remote-{rng.randrange(10**9):09d}\",\n", " }\n", " )\n", "\n", " history_actions.append(next_action)\n", " result = env.step(next_action)\n", " obs = result.observation\n", " if result.done:\n", " break\n", "\n", " return examples\n", "\n", "\n", "def build_grpo_config(**overrides: Any):\n", " from trl import GRPOConfig\n", "\n", " supported = set(inspect.signature(GRPOConfig.__init__).parameters)\n", " config_kwargs = {\n", " \"output_dir\": overrides[\"output_dir\"],\n", " \"learning_rate\": overrides[\"learning_rate\"],\n", " \"per_device_train_batch_size\": overrides[\"per_device_train_batch_size\"],\n", " \"gradient_accumulation_steps\": overrides[\"gradient_accumulation_steps\"],\n", " \"num_generations\": overrides[\"num_generations\"],\n", " \"max_completion_length\": overrides[\"max_completion_length\"],\n", " \"num_train_epochs\": overrides[\"num_train_epochs\"],\n", " \"logging_steps\": overrides[\"logging_steps\"],\n", " \"save_steps\": overrides[\"save_steps\"],\n", " \"bf16\": overrides[\"bf16\"],\n", " \"fp16\": overrides[\"fp16\"],\n", " \"report_to\": \"none\",\n", " \"remove_unused_columns\": False,\n", " }\n", " # Keep prompt truncation enabled. 
Leaving this as None can trigger\n", " # an Unsloth rotary-cache shape mismatch on long GRPO prompts.\n", " if \"max_prompt_length\" in supported:\n", " config_kwargs[\"max_prompt_length\"] = overrides[\"max_prompt_length\"]\n", " if (\n", " \"max_length\" in supported\n", " and \"max_prompt_length\" not in supported\n", " and \"max_completion_length\" not in supported\n", " ):\n", " config_kwargs[\"max_length\"] = (\n", " overrides[\"max_prompt_length\"] + overrides[\"max_completion_length\"]\n", " )\n", " return GRPOConfig(**{k: v for k, v in config_kwargs.items() if k in supported})\n", "\n", "\n", "print(\"CUDA:\", torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"\")\n", "Path(\"artifacts\").mkdir(exist_ok=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 4. Config + collect prompt states from the remote Space\n", "SPACE_REPO_ID = \"Ev3Dev/hackathon\"\n", "SPACE_BASE_URL = hf_space_repo_to_base_url(SPACE_REPO_ID)\n", "# If your Space has a custom domain, replace SPACE_BASE_URL manually.\n", "\n", "MODEL_ID = \"unsloth/Llama-3.2-3B-Instruct-bnb-4bit\"\n", "OUTPUT_DIR = \"artifacts/grpo-unsloth-llama32-3b-remote-space\"\n", "\n", "DATASET_EPISODES = 8\n", "ROLLOUT_STEPS = 6\n", "NUM_GENERATIONS = 2\n", "# Keep this modest for Unsloth GRPO stability on Colab.\n", "MAX_PROMPT_LENGTH = 768\n", "MAX_SEQ_LENGTH = 2048\n", "PER_DEVICE_TRAIN_BATCH_SIZE = 1\n", "GRADIENT_ACCUMULATION_STEPS = 4\n", "LEARNING_RATE = 5e-6\n", "NUM_TRAIN_EPOCHS = 1.0\n", "LOGGING_STEPS = 1\n", "SAVE_STEPS = 25\n", "SEED = 42\n", "LORA_R = 16\n", "LORA_ALPHA = 16\n", "LORA_DROPOUT = 0.0\n", "\n", "examples = build_remote_prompt_examples(\n", " base_url=SPACE_BASE_URL,\n", " dataset_episodes=DATASET_EPISODES,\n", " rollout_steps=ROLLOUT_STEPS,\n", " seed=SEED,\n", ")\n", "\n", "reward_fn = OpenEnvReward(\n", " reward_backend=\"remote\",\n", " base_url=SPACE_BASE_URL,\n", " 
invalid_action_penalty=INVALID_ACTION_PENALTY,\n", " environment_error_penalty=ENVIRONMENT_ERROR_PENALTY,\n", ")\n", "\n", "print(\"Remote env:\", SPACE_BASE_URL)\n", "print(\"Prompt states:\", len(examples))\n", "print(\"Sample prompt preview:\\n\")\n", "print(examples[0][\"prompt\"][:2000])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 5. Local GRPO training in Colab, remote env for rewards\n", "from datasets import Dataset\n", "from trl import GRPOTrainer\n", "\n", "PatchFastRL(\"GRPO\", FastLanguageModel)\n", "train_dataset = Dataset.from_list(examples)\n", "\n", "bf16 = bool(getattr(torch.cuda, \"is_bf16_supported\", lambda: False)()) if torch.cuda.is_available() else False\n", "runtime_dtype = torch.bfloat16 if bf16 else (torch.float16 if torch.cuda.is_available() else torch.float32)\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=MODEL_ID,\n", " max_seq_length=MAX_SEQ_LENGTH,\n", " dtype=runtime_dtype,\n", " load_in_4bit=True,\n", ")\n", "if tokenizer.pad_token is None and tokenizer.eos_token is not None:\n", " tokenizer.pad_token = tokenizer.eos_token\n", "\n", "model = FastLanguageModel.get_peft_model(\n", " model,\n", " r=LORA_R,\n", " target_modules=LORA_TARGET_MODULES,\n", " lora_alpha=LORA_ALPHA,\n", " lora_dropout=LORA_DROPOUT,\n", " bias=\"none\",\n", " use_gradient_checkpointing=True,\n", " random_state=SEED,\n", ")\n", "\n", "training_args = build_grpo_config(\n", " output_dir=OUTPUT_DIR,\n", " learning_rate=LEARNING_RATE,\n", " per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,\n", " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n", " num_generations=NUM_GENERATIONS,\n", " max_completion_length=MAX_COMPLETION_TOKENS,\n", " max_prompt_length=MAX_PROMPT_LENGTH,\n", " num_train_epochs=NUM_TRAIN_EPOCHS,\n", " logging_steps=LOGGING_STEPS,\n", " save_steps=SAVE_STEPS,\n", " bf16=bf16,\n", " fp16=torch.cuda.is_available() and not bf16,\n", ")\n", 
"\n", "# Build the trainer: GRPO samples completions per prompt and scores them\n", "# with the remote-environment reward wrapper defined above.\n", "trainer = GRPOTrainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=train_dataset,\n", "    reward_funcs=[reward_fn],\n", "    processing_class=tokenizer,\n", ")\n", "\n", "# Some TRL builds probe vision-token attributes even for text-only models;\n", "# stub any that are absent so the training loop does not fail on them.\n", "for _vision_attr in (\"image_token_id\", \"vision_start_token_id\", \"vision_end_token_id\"):\n", "    if not hasattr(trainer, _vision_attr):\n", "        setattr(trainer, _vision_attr, None)\n", "\n", "trainer.train()\n", "\n", "# Persist the adapter + tokenizer, then render the training curves.\n", "trainer.save_model(OUTPUT_DIR)\n", "tokenizer.save_pretrained(OUTPUT_DIR)\n", "plot_paths = save_training_plots(trainer.state.log_history, OUTPUT_DIR)\n", "\n", "result = {\n", "    \"trainer\": trainer,\n", "    \"plot_paths\": plot_paths,\n", "    \"output_dir\": OUTPUT_DIR,\n", "}\n", "print(\"Saved to:\", OUTPUT_DIR)\n", "print(\"Plots:\", plot_paths)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 6. (Optional) Show curves and sanity-check the repo reward wrapper\n", "from IPython.display import Image, display\n", "\n", "# Score the heuristic reference action through the same reward wrapper the\n", "# trainer used, as a quick end-to-end check of the remote env.\n", "reference_completion = [{\"role\": \"assistant\", \"content\": examples[0][\"reference_action\"]}]\n", "sample_reward = reward_fn(\n", "    completions=[reference_completion],\n", "    history_actions=[examples[0][\"history_actions\"]],\n", ")[0]\n", "print(\"Sample reward for reference action:\", sample_reward)\n", "\n", "saved_plots = result.get(\"plot_paths\") or {}\n", "for name, path in saved_plots.items():\n", "    print(name, path)\n", "    display(Image(filename=path))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 4 }