---
title: GRPO Training Collapse Analysis
description: Root-cause analysis of GRPO training collapse on Qwen3-1.7B caused by extra kwargs in tool calls and advantage collapse
doc_type: exploration
---

# GRPO Training Collapse Analysis

## What happened

After SFT warmup, GRPO training on Qwen3-1.7B collapsed within the first 30 steps. The model degenerated into passing extra `null` arguments to every tool call (`"sql": null, "table_name": "...", "value": null`), triggering `unexpected keyword argument` errors on every rollout. It never recovered across 351 steps (~8 hours on L4).

## Timeline

| Step | Reward | What the model does |
|------|--------|---------------------|
| 10 | -1.25 | First call has extra args, gets error, loops with `Episode is over` |
| 20 | 0.01 | Occasionally correct describe, but passes wrong args to answer |
| 30 | 0.00 | Stuck: `describe(sql=null, table_name="concert")` infinite loop |
| 40-351 | 0.00 | Complete collapse: every rollout is identical error loops |

## Why it collapsed

### 1. SFT taught wrong argument patterns

The SFT examples show `describe(table_name=...)` correctly, but the base Qwen3-1.7B model has a strong prior from pretraining to include all available parameter names in every call. The 353-turn SFT warmup (2 epochs, batch=2) wasn't enough to override this for all 4 tools.

### 2. Extra kwargs cause hard failures, not soft degradation

When the model passes `describe(sql=null, table_name="flights")`, TRL dispatches `SQLEnvTRL.describe(sql=None, table_name="flights")`, which raises `TypeError: unexpected keyword argument 'sql'`. This is a **hard wall** — the model gets zero useful information back, just an error string it can't learn from.

### 3. GRPO advantage collapse

With 6 generations per question:

- All 6 rollouts pass the same extra args → all get reward 0.0
- Advantage = 0.0 for every sample → zero gradient signal
- The model has no way to discover that dropping the extra args would work
- Loss oscillates near 0 throughout training

### 4. No recovery mechanism

Once the model enters the error loop:

- Error messages say `unexpected keyword argument 'sql'` but don't say "try calling with only `table_name`"
- The model retries the same call pattern endlessly
- The post-episode penalty accumulates negative reward (-1.25 at step 10) but doesn't help because ALL rollouts are equally bad
- No positive examples exist in any rollout group to provide advantage signal

## The core problem: kwargs rejection vs. kwargs tolerance

The TRL adapter methods have strict signatures:

```python
def describe(self, table_name: str) -> str:
def query(self, sql: str) -> str:
def answer(self, value: str) -> str:
```

When the model generates `{"table_name": "flights", "sql": null}`, Python raises `TypeError` before the method body executes. The model never gets a schema response, so it has no path to success.

## Fix: Accept and ignore extra kwargs

The simplest fix is to make the tool methods tolerant of extra arguments:

```python
def describe(self, table_name: str, **kwargs) -> str:
def query(self, sql: str, **kwargs) -> str:
def answer(self, value: str, **kwargs) -> str:
def sample(self, table_name: str, **kwargs) -> str:
```

This means `describe(sql=null, table_name="flights")` would work — it would ignore `sql` and return the schema. The model gets useful feedback, can write SQL, and has a path to positive reward. GRPO then has signal to learn that the extra args are unnecessary.
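A minimal, self-contained sketch of the tolerant pattern — the class and schema string here are illustrative stand-ins, not the actual `SQLEnvTRL` adapter:

```python
class TolerantEnv:
    """Illustrative stand-in for a kwargs-tolerant TRL tool adapter."""

    def describe(self, table_name: str, **kwargs) -> str:
        # Extra args such as sql=None land in **kwargs and are ignored.
        return f"schema for {table_name}"


env = TolerantEnv()
# The exact call pattern that crashed the strict adapter now succeeds:
print(env.describe(table_name="flights", sql=None))  # -> schema for flights
```

The model still receives its tool response on a malformed call, so every rollout group can contain at least some positive-reward trajectories.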
**Why this is the right approach:**

- Small models (1.7B) lack the capacity to perfectly learn function signatures from tool definitions alone
- The tool definitions in `` XML clearly state which params are required — the model will converge toward correct signatures over time via reward signal
- Strict rejection creates an unrecoverable dead end; tolerance creates a learning gradient
- This matches how real APIs work — most accept and ignore unexpected fields

## Other contributing factors

### SFT quality issues

- SFT was only 100 questions x ~3.5 turns = 347 examples
- Only 2 epochs at batch=2 (total 347 steps)
- The model learned tool-call format but not strict argument isolation
- Need: more SFT data or more epochs on existing data

### Missing KL penalty

- No KL divergence penalty against the SFT reference model
- GRPO updated the policy freely, drifting away from the SFT distribution
- A KL penalty (beta=0.01-0.05) would have anchored the model near the working SFT baseline

### Learning rate may be too high

- The default TRL learning rate (5e-7 or 1e-6) may be too aggressive for 1.7B
- A lower LR (1e-7) would make smaller updates, reducing drift risk

## Recommended fixes (priority order)

### 1. Add `**kwargs` to all tool methods (critical)

Prevents the hard wall. The model can still learn correct signatures from reward signal.

### 2. Increase SFT warmup

- 4 epochs instead of 2
- Or increase SFT data from 100 to 200 questions
- Verify post-SFT that the model generates correct single-arg calls

### 3. Add KL penalty

```python
GRPOConfig(
    ...,
    beta=0.04,  # KL penalty against SFT reference
)
```

Prevents the policy from drifting too far from the working SFT baseline.

### 4. Lower GRPO learning rate

From the default to 1e-7 or 5e-8.
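The advantage collapse these fixes target can be reproduced numerically with GRPO's group-normalized advantage — a sketch of the formula, not TRL's implementation:

```python
def group_advantages(rewards: list[float]) -> list[float]:
    """Group-relative advantage: (r - mean) / std over one rollout group (sketch)."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    if std == 0:
        # Every rollout got the same reward: zero advantage, zero gradient.
        return [0.0] * len(rewards)
    return [(r - mean) / std for r in rewards]


print(group_advantages([0.0] * 6))          # collapsed group -> all zeros
print(group_advantages([0.0] * 5 + [1.0]))  # one success restores signal
```

A single successful rollout per group of 6 is enough to produce nonzero advantages, which is why making malformed calls survivable matters more than any hyperparameter change.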
## Verification checklist

Before running GRPO again:

- [ ] Post-SFT format check shows `describe(table_name="X")` with NO extra args
- [ ] Tool methods accept `**kwargs` so extra args don't crash
- [ ] First 10 GRPO steps show at least some reward > 0
- [ ] Reward doesn't flatline at 0.0 by step 30
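The first checklist item can be spot-checked automatically. A sketch, assuming tool calls arrive as JSON argument strings; the allow-list below mirrors the four signatures and is hypothetical:

```python
import json

# Hypothetical allow-list mirroring the four tool signatures.
ALLOWED_ARGS = {
    "describe": {"table_name"},
    "query": {"sql"},
    "answer": {"value"},
    "sample": {"table_name"},
}

def extra_args(tool: str, args_json: str) -> set[str]:
    """Return argument names the tool's signature does not accept."""
    return set(json.loads(args_json)) - ALLOWED_ARGS.get(tool, set())

# The failure pattern from the timeline is flagged immediately:
print(extra_args("describe", '{"table_name": "flights", "sql": null}'))  # {'sql'}
```

Running this over a batch of post-SFT generations gives a concrete pass/fail signal before committing 8 hours of GRPO compute.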