vaibhav12332112312 committed on
Commit
4299c91
·
1 Parent(s): 9fac734

Mandate tool calls in system prompt to debug zero-tool collapse


All 180 logged responses (before / round1 / after) had tool_calls = [].
The previous prompt only listed the tool catalog and hinted at posting hours,
giving the model no reason to call any query_* tool. Add an explicit tool
policy: query_* calls are required on day 1, and predict_engagement plus a
fresh query_* on later days. Concrete argument examples (segment_id,
competitor_id) keep validation from rejecting the calls.

Made-with: Cursor
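The zero-tool collapse described above (all 180 logged responses with `tool_calls = []`) can be spotted with a quick pass over the episode logs. A minimal sketch, assuming a JSONL-style log where each record carries a `tag` (before / round1 / after) and the model's `tool_calls` list; the field names here are assumptions, not the notebook's actual schema:

```python
import json

# Hypothetical log schema: one JSON object per line with a "tag"
# ("before" / "round1" / "after") and the model's "tool_calls" list.
def count_empty_tool_calls(lines):
    """Return {tag: (empty_count, total_count)} over the logged responses."""
    totals, empties = {}, {}
    for line in lines:
        rec = json.loads(line)
        tag = rec.get("tag", "unknown")
        totals[tag] = totals.get(tag, 0) + 1
        if not rec.get("tool_calls"):  # [] or missing -> zero-tool response
            empties[tag] = empties.get(tag, 0) + 1
    return {t: (empties.get(t, 0), totals[t]) for t in totals}

logs = [
    '{"tag": "before", "tool_calls": []}',
    '{"tag": "before", "tool_calls": [{"name": "query_trends"}]}',
    '{"tag": "after", "tool_calls": []}',
]
print(count_empty_tool_calls(logs))  # {'before': (1, 2), 'after': (1, 1)}
```

A collapse like the one in this commit shows up as the empty count equaling the total for every tag.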

Files changed (1)
  1. training/train_grpo.ipynb +14 -5
training/train_grpo.ipynb CHANGED
@@ -441,9 +441,18 @@
     "- topic: free-form string\n",
     "- empty scheduled_actions = full day rest\"\"\")\n",
     "\n",
-    "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
-    "\n",
-    "HINT: schedule posts during/just before the audience_active_hours window that is when your target users are online.\"\"\")\n",
+    "SYSTEM_PROMPT = _SYSTEM_BASE + textwrap.dedent(\"\"\"\n",
+    "\n",
+    "TOOL POLICY (MANDATORY; empty tool_calls on day 1 = wasted day):\n",
+    "- days_elapsed == 0 -> call AT LEAST these in tool_calls:\n",
+    "  {\"name\": \"query_trends\", \"arguments\": {\"niche\": \"<one of TOPIC_CATEGORIES keys>\"}}\n",
+    "  {\"name\": \"query_audience\", \"arguments\": {\"segment_id\": \"young_professionals\"}}\n",
+    "  {\"name\": \"query_creator_pool\", \"arguments\": {}}\n",
+    "  {\"name\": \"query_competitor\", \"arguments\": {\"competitor_id\": \"niche_expert\", \"window_days\": 7}}\n",
+    "- days_elapsed >= 1 -> before scheduling, call:\n",
+    "  {\"name\": \"predict_engagement\", \"arguments\": {\"scheduled_actions\": [...]}}\n",
+    "  and at least one query_* tool whose result you don't already have in Tool results.\n",
+    "- audience_active_hours in the observation is a coarse hint; query_audience returns ranked topic affinities you cannot get otherwise.\"\"\")\n",
     "SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
     "SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
     "\n",
@@ -840,8 +849,8 @@
     "\n",
     "peft_model.eval()\n",
     "t0 = time.time()\n",
-    "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True, log_tag=\"after\")\n",
-    "after_results = {r[\"task\"]: r for r in results}\n",
+    "results = run_llm_episodes_batched(peft_model, tokenizer, [(t, 42) for t in TASKS], verbose=True, eval=True, log_tag=\"after\")\n",
+    "after_results = {r[\"task\"]: r for r in results}\n",
     "\n",
     "print(\"\\n\" + \"=\" * 60)\n",
     "print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",