| """OpenSleuth GRPO trainer. |
| |
| Trains a small Qwen2.5 model with TRL's GRPOTrainer to do in-context program |
| synthesis — given the public signature of a hidden function plus a handful of |
| (input, output) probe examples, emit a Python function that reproduces it. |
| |
| Reward comes from the live OpenSleuth env Space: the agent's code is executed |
| against the hidden reference under domain-aware fuzzing, and the verifier |
| returns an `execution_reward - complexity_penalty` score that we hand back to |
| GRPO as the per-completion reward (plus a tiny formatting shaping reward). |
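|
| Example invocation (illustrative; assumes this file is saved as trainer.py,
| and note that every flag also has an environment-variable default):
|
|     python trainer.py --difficulty easy --num-generations 4 --no-4bit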
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
| import time |
|
|
| import torch |
| from peft import LoraConfig |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from trl import GRPOConfig, GRPOTrainer |
|
|
| from opensleuth_train import ( |
| EnvClient, |
| SYSTEM_PROMPT, |
| build_synthesis_dataset, |
| discover_functions, |
| ) |
| from opensleuth_train.reward import format_reward, make_env_reward |
|
|
| logging.basicConfig( |
| level=logging.INFO, |
| format="%(asctime)s %(levelname)s %(name)s: %(message)s", |
| stream=sys.stdout, |
| ) |
| log = logging.getLogger("opensleuth.train") |
|
|
| def _split_csv(s: str) -> list[str]: |
| return [x.strip() for x in s.split(",") if x.strip()] |
|
|
| def parse_args() -> argparse.Namespace: |
| p = argparse.ArgumentParser() |
| p.add_argument("--env-url", default=os.environ.get("ENV_URL", "https://anugrah55-opensleuth-env-gemini-cli.hf.space")) |
|
| p.add_argument("--model-name", default=os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-3B-Instruct")) |
| p.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "/data/opensleuth-grpo")) |
| p.add_argument( |
| "--push-to-hub", |
| default=os.environ.get( |
| "PUSH_TO_HUB", "anugrah55/opensleuth-qwen2.5-3b-grpo-v2" |
| ), |
| ) |
| |
| p.add_argument( |
| "--functions", |
| default=os.environ.get("FUNCTIONS_INCLUDE", ""), |
| help="Comma-separated subset of task names to train on. Empty = all " |
| "tasks the env exposes (builtin + Hub).", |
| ) |
| p.add_argument( |
| "--difficulty", |
| default=os.environ.get("DIFFICULTY_FILTER", "all"), |
| choices=["easy", "medium", "hard", "all"], |
| help="Curriculum filter: only sample tasks at this difficulty level.", |
| ) |
|
| p.add_argument("--n-easy", type=int, default=int(os.environ.get("N_EASY", "8"))) |
| p.add_argument("--n-medium", type=int, default=int(os.environ.get("N_MEDIUM", "16"))) |
| p.add_argument("--n-hard", type=int, default=int(os.environ.get("N_HARD", "24"))) |
| p.add_argument( |
| "--n-per-function", |
| type=int, |
| default=int(os.environ.get("N_PER_FUNCTION", "0")), |
| help="If >0, overrides per-difficulty rollout counts with a uniform N.", |
| ) |
| p.add_argument("--n-probes", type=int, default=int(os.environ.get("N_PROBES", "6"))) |
| |
| p.add_argument("--num-generations", type=int, default=int(os.environ.get("NUM_GENERATIONS", "2"))) |
| p.add_argument("--max-completion-length", type=int, default=int(os.environ.get("MAX_COMPLETION_LENGTH", "384"))) |
| p.add_argument("--max-prompt-length", type=int, default=int(os.environ.get("MAX_PROMPT_LENGTH", "1024"))) |
| p.add_argument("--learning-rate", type=float, default=float(os.environ.get("LEARNING_RATE", "1e-5"))) |
| p.add_argument("--num-train-epochs", type=float, default=float(os.environ.get("NUM_TRAIN_EPOCHS", "1"))) |
|
| p.add_argument("--per-device-batch-size", type=int, default=int(os.environ.get("PER_DEVICE_BATCH_SIZE", "2"))) |
| p.add_argument("--gradient-accumulation-steps", type=int, default=int(os.environ.get("GRAD_ACCUM", "4"))) |
| p.add_argument("--no-4bit", action="store_true", default=os.environ.get("NO_4BIT", "0") == "1") |
| p.add_argument("--seed", type=int, default=int(os.environ.get("SEED", "42"))) |
| return p.parse_args() |
|
|
| def wait_for_env(client: EnvClient, max_wait_s: float = 300.0) -> None:
|     """Poll the env's health endpoint until it responds or max_wait_s elapses."""
|     log.info("waiting for env at %s ...", client.base_url)
| start = time.time() |
| last_err = "" |
| while time.time() - start < max_wait_s: |
| try: |
| h = client.health() |
| log.info("env healthy: %s", h) |
| return |
| except Exception as e: |
| last_err = str(e) |
| time.sleep(5) |
| raise RuntimeError(f"env never became healthy after {max_wait_s}s. Last error: {last_err}") |
|
|
| def main() -> int: |
| args = parse_args() |
| log.info("args: %s", vars(args)) |
|
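|     # The env Space may be cold-starting; block until its health check
|     # responds before loading the model onto the GPU.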
| client = EnvClient(base_url=args.env_url, timeout=60.0, retries=4) |
| wait_for_env(client) |
|
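|     # Resolve the task catalog from the env, honoring the optional
|     # name-subset and difficulty filters from the CLI.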
| include = _split_csv(args.functions) if args.functions else None |
| difficulty = None if args.difficulty == "all" else args.difficulty |
| tasks = discover_functions(client, include=include, difficulty=difficulty) |
| log.info( |
| "env catalog: %d task(s) after filter (include=%s, difficulty=%s):", |
| len(tasks), include, difficulty, |
| ) |
| for t in tasks: |
| log.info( |
| " - %-22s difficulty=%-8s source=%s", |
| t["name"], t.get("difficulty"), t.get("source"), |
| ) |
|
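|     # Sample rollout prompts per task. By default harder tiers get more
|     # rollouts (8/16/24); --n-per-function replaces that with a uniform count.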
| n_per_function_override = args.n_per_function if args.n_per_function > 0 else None |
| log.info( |
| "building synthesis dataset (n_easy=%d n_medium=%d n_hard=%d override=%s n_probes=%d)", |
| args.n_easy, args.n_medium, args.n_hard, n_per_function_override, args.n_probes, |
| ) |
| dataset = build_synthesis_dataset( |
| client, |
| tasks=tasks, |
| n_per_function=n_per_function_override, |
| n_easy=args.n_easy, |
| n_medium=args.n_medium, |
| n_hard=args.n_hard, |
| n_probes=args.n_probes, |
| seed=args.seed, |
| ) |
| log.info("dataset size: %d rows", len(dataset)) |
|
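|     # TRL treats a "prompt" column holding a list of chat messages as
|     # conversational data and applies the tokenizer's chat template, so wrap
|     # each raw prompt with the shared system prompt here.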
| def to_chat(row): |
| return { |
| "prompt": [ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": row["prompt"]}, |
| ], |
| "target_function_name": row["target_function_name"], |
| "row_seed": row["row_seed"], |
| } |
|
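|     # Drop the raw string prompt (replaced by the chat-formatted one above).
|     # Extra columns such as target_function_name and row_seed survive the map
|     # and are forwarded to the reward functions as keyword arguments by TRL.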
| drop_cols = [c for c in ("prompt", "difficulty") if c in dataset.column_names] |
| dataset = dataset.map(to_chat, remove_columns=drop_cols) |
|
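|     # QLoRA-style 4-bit quantization: NF4 with double quantization and bf16
|     # compute keeps the 3B base model within a single-GPU memory budget.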
| log.info("loading model %s (4bit=%s)", args.model_name, not args.no_4bit) |
| bnb_config = None |
| if not args.no_4bit: |
| bnb_config = BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_compute_dtype=torch.bfloat16, |
| bnb_4bit_use_double_quant=True, |
| bnb_4bit_quant_type="nf4", |
| ) |
|
| tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
|
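|     # LoRA adapters on the attention projections only; GRPOTrainer wraps the
|     # (possibly 4-bit) base model with this config via PEFT.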
| peft_config = LoraConfig( |
| r=16, |
| lora_alpha=32, |
| lora_dropout=0.05, |
| target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], |
| task_type="CAUSAL_LM", |
| bias="none", |
| ) |
|
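|     # GRPO scores completions in groups of num_generations per prompt, so on
|     # a single GPU the per-device batch must hold whole groups.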
| per_device_bs = args.per_device_batch_size or args.num_generations |
| if per_device_bs % args.num_generations != 0: |
| raise ValueError( |
| f"per_device_batch_size ({per_device_bs}) must be a multiple of " |
| f"num_generations ({args.num_generations})." |
| ) |
| log.info( |
| "GRPO batching: per_device_batch_size=%d (= %d prompt(s) × %d generations), grad_accum=%d", |
| per_device_bs, per_device_bs // args.num_generations, args.num_generations, |
| args.gradient_accumulation_steps, |
| ) |
|
| grpo_config = GRPOConfig( |
| output_dir=args.output_dir, |
| per_device_train_batch_size=per_device_bs, |
| gradient_accumulation_steps=args.gradient_accumulation_steps, |
| learning_rate=args.learning_rate, |
| num_train_epochs=args.num_train_epochs, |
| max_prompt_length=args.max_prompt_length, |
| max_completion_length=args.max_completion_length, |
| num_generations=args.num_generations, |
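|         # KL coefficient: penalizes drift from the frozen reference policy.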
| beta=0.04, |
| bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False, |
| fp16=False, |
| logging_steps=1, |
| save_steps=50, |
| save_total_limit=2, |
| report_to=[], |
| seed=args.seed, |
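|         # Push only when a target repo id and an HF_TOKEN are both present.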
| push_to_hub=bool(args.push_to_hub) and bool(os.environ.get("HF_TOKEN")), |
| hub_model_id=args.push_to_hub or None, |
| hub_strategy="end", |
| gradient_checkpointing=True, |
| ) |
|
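|     # TRL logs per-reward statistics under each callable's __name__; the env
|     # reward is a closure, so give it a stable, readable name.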
| env_reward_fn = make_env_reward(client) |
| env_reward_fn.__name__ = "env_verifier_reward" |
| format_reward.__name__ = "format_reward" |
|
| log.info("loading base model with quantization=%s", bnb_config is not None) |
| model_kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16} |
| if bnb_config is not None: |
| model_kwargs["quantization_config"] = bnb_config |
| model = AutoModelForCausalLM.from_pretrained(args.model_name, **model_kwargs) |
|
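|     # Each reward function is called on every completion and the scores are
|     # summed (equally weighted unless reward_weights is set in GRPOConfig).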
| log.info("instantiating GRPOTrainer") |
| trainer = GRPOTrainer( |
| model=model, |
| reward_funcs=[env_reward_fn, format_reward], |
| args=grpo_config, |
| train_dataset=dataset, |
| peft_config=peft_config, |
| processing_class=tokenizer, |
| ) |
|
| log.info("starting GRPO training") |
| trainer.train() |
| log.info("training complete; saving to %s", args.output_dir) |
| trainer.save_model(args.output_dir) |
| if grpo_config.push_to_hub: |
| log.info("pushing to hub: %s", args.push_to_hub) |
| trainer.push_to_hub() |
| return 0 |
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |