"""
GRPO Training on API Debug Environment
=======================================
Trains a small LLM (Qwen 0.5B) to debug malformed API requests using
reward signals from the live HuggingFace Space environment.

Supports curriculum learning: starts on the easy task and is promoted
through classify, medium, headers, response and finally hard as the
agent's rolling average reward improves.

Run on Colab (free T4 GPU):
    pip install -r training/requirements.txt
    python training/train.py
"""
import os

# Set before `import torch` so the setting is already in the environment
# when the CUDA caching allocator is initialized.
# NOTE(review): older PyTorch releases only read PYTORCH_CUDA_ALLOC_CONF;
# confirm the installed version honors PYTORCH_ALLOC_CONF.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

import json
import re
import sys

import torch

# Make the repository root importable so `client` and `models` resolve when
# this script is run as `python training/train.py`.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions

from client import APIDebugEnv
from models import APIDebugAction

# -- GPU check ----------------------------------------------------------------
print(f"GPU available : {torch.cuda.is_available()}")
print(f"GPU name      : {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None (CPU)'}")

has_gpu = torch.cuda.is_available()
supports_bf16 = has_gpu and torch.cuda.is_bf16_supported()

# -- Config -------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
ENV_URL = "https://avichauhan-api-debug-env.hf.space"
# Not referenced by the rollout: per-task turn limits come from
# CURRICULUM[task]["max_turns"]. Kept for reference.
MAX_TURNS = 3  # easy task: 3 steps max
NUM_SAMPLES = 64

# -- Curriculum state ---------------------------------------------------------
# Tracks which task the agent is currently training on.
# Promotes when rolling average reward exceeds threshold.
# Each entry: the task to promote to, the rolling-average reward needed for
# promotion, and the per-episode turn budget used by rollout_func.
CURRICULUM = {
    "easy": {"next": "classify", "threshold": 0.7, "max_turns": 3},
    "classify": {"next": "medium", "threshold": 0.6, "max_turns": 4},
    "medium": {"next": "headers", "threshold": 0.6, "max_turns": 5},
    "headers": {"next": "response", "threshold": 0.5, "max_turns": 4},
    "response": {"next": "hard", "threshold": 0.5, "max_turns": 4},
    "hard": {"next": None, "threshold": None, "max_turns": 7},
}

current_task = "easy"
recent_rewards: list[float] = []
WINDOW_SIZE = 10

SYSTEM_PROMPT = """You are an API debugging expert. You receive a broken API request and its specification.
Your job: identify the error type and the affected fields.

Respond with ONLY a JSON object in this format:
{"error_type": "<error_type>", "affected_fields": ["field1", "field2"]}

Valid error types: missing_required_field, wrong_field_type, invalid_email_format, missing_auth_header, extra_unknown_field, null_value_in_required, wrong_http_method, malformed_json_value, invalid_enum_value, datetime_format_error, wrong_content_type, expired_auth_token"""

CLASSIFY_PROMPT = """You are an API debugging expert. This request has MULTIPLE errors.
Identify ALL error types and ALL affected fields.

Respond with ONLY a JSON object:
{"error_types": ["type1", "type2"], "affected_fields": ["field1", "field2"]}"""

MEDIUM_PROMPT = """You are an API debugging expert. Fix the broken request to match the API spec.

Respond with ONLY a JSON object:
{"fixed_request": {"field": "value"}, "fixed_headers": {"Header": "value"}}"""

HEADERS_PROMPT = """You are an API debugging expert. This request has ONLY header-level errors.
Identify the error type and fix the headers to match the API spec.

Respond with ONLY a JSON object:
{"error_type": "<error_type>", "fixed_headers": {"Header-Name": "correct_value"}}

Common header error types: wrong_content_type, expired_auth_token, missing_auth_header"""

RESPONSE_PROMPT = """You are an API response validation expert. You receive an API request, its spec, and the server response.
Identify issues in the response: wrong status codes, missing fields, wrong types, extra fields, inconsistent error format.

Respond with ONLY a JSON object:
{"response_issues": ["issue_type1"], "affected_fields": ["field1"], "expected_status_code": 200}

Valid issue types: wrong_status_code, missing_response_field, wrong_response_type, extra_response_field, inconsistent_error_format"""

HARD_PROMPT = """You are an API debugging expert. This request has MULTIPLE errors across headers and body.
Some errors are chained -- fixing one may reveal others.
Fix everything and explain your reasoning.

Respond with ONLY a JSON object:
{"fixed_request": {"field": "value"}, "fixed_headers": {"Header": "value"}, "explanation": "why each fix was needed"}"""

TASK_PROMPTS = {
    "easy": SYSTEM_PROMPT,
    "classify": CLASSIFY_PROMPT,
    "medium": MEDIUM_PROMPT,
    "headers": HEADERS_PROMPT,
    "response": RESPONSE_PROMPT,
    "hard": HARD_PROMPT,
}

# -- Environment client -------------------------------------------------------
print(f"Connecting to environment: {ENV_URL}")
env_client = APIDebugEnv(base_url=ENV_URL)


# -- JSON parser (reused from inference.py) -----------------------------------
_CODE_BLOCK_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)
_JSON_DECODER = json.JSONDecoder()


def parse_llm_response(text: str) -> dict:
    """Extract the first JSON object from an LLM completion.

    Tries, in order: the whole text, the body of the first fenced code
    block, and finally the first decodable ``{...}`` object embedded
    anywhere in the text. Nested objects are supported, which matters
    for the ``fixed_request``/``fixed_headers`` answer formats.

    Args:
        text: Raw decoded completion text.

    Returns:
        The parsed object, or ``{}`` if nothing in ``text`` decodes to a
        JSON object. Non-object JSON (lists, numbers, strings) is ignored.
    """
    if not text:
        return {}

    candidates = [text]
    code_block = _CODE_BLOCK_RE.search(text)
    if code_block:
        candidates.append(code_block.group(1))
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    # Fallback: decode starting at each '{'. raw_decode understands nesting,
    # unlike a flat r"\{[^{}]*\}" regex, which can never match
    # {"fixed_request": {...}} as a whole.
    start = text.find("{")
    while start != -1:
        try:
            parsed, _end = _JSON_DECODER.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed
        start = text.find("{", start + 1)
    return {}


def build_action(data) -> APIDebugAction:
    """Convert a parsed LLM answer into an ``APIDebugAction``.

    A ``fixed_request`` given as a JSON object is serialized to a JSON
    string before being passed on. Missing keys become ``None``.

    Args:
        data: Output of ``parse_llm_response``; anything other than a dict
            yields an empty action.

    Returns:
        The action to send to ``env.step``. If construction rejects the
        values (wrong types from a malformed completion), an empty action
        is returned so the episode is simply scored low instead of
        aborting training.
    """
    if not isinstance(data, dict):
        return APIDebugAction()
    fixed_req = data.get("fixed_request")
    if isinstance(fixed_req, dict):
        fixed_req = json.dumps(fixed_req)
    try:
        return APIDebugAction(
            error_type=data.get("error_type"),
            error_types=data.get("error_types"),
            affected_fields=data.get("affected_fields"),
            fixed_request=fixed_req,
            fixed_headers=data.get("fixed_headers"),
            response_issues=data.get("response_issues"),
            expected_status_code=data.get("expected_status_code"),
        )
    except (TypeError, ValueError):
        # NOTE(review): presumably APIDebugAction validates field types
        # (e.g. a pydantic model, whose ValidationError subclasses
        # ValueError) -- confirm against models.py.
        return APIDebugAction()


# -- Curriculum learning ------------------------------------------------------
def maybe_promote():
    """Check if agent should be promoted to next difficulty.

    Promotes ``current_task`` to ``CURRICULUM[current_task]["next"]`` when
    the mean of the last ``WINDOW_SIZE`` rewards in ``recent_rewards`` meets
    the task's threshold, then clears the window. Does nothing on the final
    task or before ``WINDOW_SIZE`` rewards have been collected.
    """
    global current_task
    config = CURRICULUM[current_task]
    if config["next"] is None or config["threshold"] is None:
        return
    if len(recent_rewards) < WINDOW_SIZE:
        return
    avg = sum(recent_rewards[-WINDOW_SIZE:]) / WINDOW_SIZE
    if avg >= config["threshold"]:
        old_task = current_task
        current_task = config["next"]
        recent_rewards.clear()
        print(f"[CURRICULUM] Promoted: {old_task} -> {current_task} (avg_reward={avg:.3f})")


# -- Rollout function ---------------------------------------------------------
def _build_user_message(base_prompt, obs, turn, max_turns) -> str:
    """Render the environment observation as the user turn for the model."""
    return (
        f"{base_prompt}\n\n"
        f"API: {obs.observation.http_method} {obs.observation.endpoint} "
        f"({obs.observation.api_name})\n"
        f"Error count: {obs.observation.error_count}\n"
        f"Step {turn + 1}/{max_turns}\n\n"
        f"Broken request:\n{obs.observation.broken_request}\n\n"
        f"Headers: {json.dumps(obs.observation.broken_headers)}\n\n"
        f"API Spec:\n{obs.observation.api_spec}\n"
        + (f"\nFeedback:\n{obs.observation.feedback}" if obs.observation.feedback else "")
    )


def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict:
    """Play one environment episode per prompt and collect token data.

    The whole batch runs on the task that is current when the call starts.
    For each prompt, the environment is reset and the model gets up to
    ``CURRICULUM[task]["max_turns"]`` turns. Prompt ids, completion ids and
    logprobs of every turn are concatenated into one sequence per episode
    (presumably mirroring the TRL OpenEnv multi-turn examples). The episode
    reward is the reward of the last step taken.

    Args:
        prompts: Base prompt strings from the dataset.
        trainer: The running trainer; supplies the tokenizer and generation.

    Returns:
        Dict with per-episode ``prompt_ids``, ``completion_ids``,
        ``logprobs`` and ``env_reward`` lists, aligned with ``prompts``.
    """
    tokenizer = trainer.processing_class
    all_prompt_ids = []
    all_completion_ids = []
    all_logprobs = []
    all_rewards = []

    task = current_task
    max_turns = CURRICULUM[task]["max_turns"]
    system_prompt = TASK_PROMPTS[task]

    for base_prompt in prompts:
        # NOTE(review): every prompt -- including the num_generations copies
        # of one prompt that form a GRPO group -- gets its own env.reset().
        # If reset() samples a random scenario, group members are scored on
        # different problems; consider seeding per group if the env allows.
        with env_client.sync() as env:
            obs = env.reset(task=task)
            episode_reward = 0.0
            episode_prompt_ids = []
            episode_comp_ids = []
            episode_logprobs = []

            for turn in range(max_turns):
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": _build_user_message(base_prompt, obs, turn, max_turns)},
                ]
                prompt_text = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True,
                    tokenize=False,
                    enable_thinking=False,
                )
                outputs = generate_rollout_completions(trainer, [prompt_text])[0]
                completion_text = tokenizer.decode(
                    outputs["completion_ids"], skip_special_tokens=True
                ).strip()

                episode_prompt_ids.extend(outputs["prompt_ids"])
                episode_comp_ids.extend(outputs["completion_ids"])
                episode_logprobs.extend(outputs["logprobs"])

                # Parse LLM output into action
                parsed = parse_llm_response(completion_text)
                action = build_action(parsed)
                obs = env.step(action)
                episode_reward = float(obs.reward or 0.0)
                if obs.done:
                    break

        all_prompt_ids.append(episode_prompt_ids)
        all_completion_ids.append(episode_comp_ids)
        all_logprobs.append(episode_logprobs)
        all_rewards.append(episode_reward)

        # Track for curriculum -- but only while this batch's task is still
        # the active one. If an earlier episode in this batch triggered a
        # promotion, the remaining episodes still ran the previous (easier)
        # task and must not seed the new task's promotion window.
        if task == current_task:
            recent_rewards.append(episode_reward)
            # Only the last WINDOW_SIZE rewards are ever read; keep it bounded.
            del recent_rewards[:-WINDOW_SIZE]
            maybe_promote()

    return {
        "prompt_ids": all_prompt_ids,
        "completion_ids": all_completion_ids,
        "logprobs": all_logprobs,
        "env_reward": all_rewards,
    }


# -- Reward function ----------------------------------------------------------
def reward_from_env(completions, **kwargs):
    """Return the environment rewards produced by ``rollout_func``.

    The ``env_reward`` list returned by ``rollout_func`` arrives here as a
    keyword argument. Falls back to all-zero rewards when it is absent or
    empty.
    """
    env_rewards = kwargs.get("env_reward", [])
    return [float(r) for r in env_rewards] if env_rewards else [0.0] * len(completions)


# -- Dataset ------------------------------------------------------------------
dataset = Dataset.from_dict({
    "prompt": ["Debug this broken API request."] * NUM_SAMPLES
})

# -- Trainer ------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, attn_implementation="eager")

grpo_args = GRPOConfig(
    use_vllm=True,
    vllm_mode="colocate",
    num_train_epochs=1,
    num_generations=2,
    max_completion_length=128,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=5e-6,
    output_dir="./outputs/api-debug-grpo",
    logging_steps=1,
    report_to="none",
    bf16=supports_bf16,
    fp16=has_gpu and not supports_bf16,
    gradient_checkpointing=True,
    vllm_gpu_memory_utilization=0.3,
    dataloader_pin_memory=False,
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_from_env,
    train_dataset=dataset,
    rollout_func=rollout_func,
    args=grpo_args,
)

if __name__ == "__main__":
    print("Starting GRPO training on API Debug Environment...")
    print(f"Model      : {MODEL_ID}")
    print(f"Environment: {ENV_URL}")
    # Each dataset prompt is rolled out num_generations times per epoch.
    print(f"Episodes   : {NUM_SAMPLES} prompts x {grpo_args.num_generations} generations")
    print(f"Task       : {current_task} (with curriculum learning)")
    print(f"bf16       : {supports_bf16}")
    print(f"fp16       : {has_gpu and not supports_bf16}")
    trainer.train()
    # train() only writes periodic checkpoint-* subdirectories (per the save
    # strategy); save the final weights so the path printed below holds them.
    trainer.save_model(grpo_args.output_dir)
    print(f"Training complete! Final task: {current_task}")
    print(f"Model saved to {grpo_args.output_dir}")