vaibhav12332112312 committed on
Commit
1f72457
·
1 Parent(s): e52d302

training: smoke-mode + hardcoded peak hint + valid tool IDs

Browse files

- SMOKE_MODE=1 default: 1 phase x 1 round x 4 eps, lr 2e-4, r=16, 3 epochs (visible delta)
- always-on coach hint with day-aware top-3 peak hours
- system prompt lists valid niche/segment/competitor IDs (kills tool-arg errors)
- LoRA targets full MLP in smoke (gate/up/down + qkvo)
- new debug cell: io_log diff/error/hint stats

Made-with: Cursor

training/hf_run_space_train_job.sh CHANGED
@@ -8,7 +8,7 @@
8
  set -euo pipefail
9
 
10
  IMAGE="${HF_JOB_IMAGE:-pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime}"
11
- FLAVOR="${HF_JOB_FLAVOR:-l40sx1}"
12
  TIMEOUT="${HF_JOB_TIMEOUT:-8h}"
13
  SPACE_REPO="${HF_SPACE_REPO_ID:-vaibhavkhandare/train-bhai-train}"
14
  NB_EXEC_TIMEOUT="${NB_EXEC_TIMEOUT:-3600}"
 
8
  set -euo pipefail
9
 
10
  IMAGE="${HF_JOB_IMAGE:-pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime}"
11
+ FLAVOR="${HF_JOB_FLAVOR:-a100-large}"
12
  TIMEOUT="${HF_JOB_TIMEOUT:-8h}"
13
  SPACE_REPO="${HF_SPACE_REPO_ID:-vaibhavkhandare/train-bhai-train}"
14
  NB_EXEC_TIMEOUT="${NB_EXEC_TIMEOUT:-3600}"
training/train_grpo.ipynb CHANGED
@@ -175,7 +175,11 @@
175
  "# Same sanity as syntax_only.ipynb (kernel parses modern Python)\n",
176
  "import ast\n",
177
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
178
- "print(\"OK: ast.parse (syntax check)\")"
 
 
 
 
179
  ],
180
  "execution_count": null,
181
  "outputs": []
@@ -439,6 +443,11 @@
439
  "- topic: free-form string\n",
440
  "- empty scheduled_actions = full day rest\n",
441
  "\n",
 
 
 
 
 
442
  "POSTING RULES:\n",
443
  "- Each active day: 2-3 `post` actions at the audience's peak hours.\n",
444
  "- `create_content` alone earns 0 reward.\n",
@@ -494,7 +503,10 @@
494
  " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
495
  " if not tool_str:\n",
496
  " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
497
- " hint_str = f\"Coach hint: today's peak hours are {extra_hint}.\\n\" if extra_hint else \"\"\n",
 
 
 
498
  " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
499
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
500
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
@@ -653,9 +665,9 @@
653
  " actions_by_idx = {i: rest_action for i in rest}\n",
654
  " if active:\n",
655
  " def _hint_for(i):\n",
656
- " if not hint_peak_hours:\n",
657
  " return None\n",
658
- " hrs = get_peak_hours(obss[i].day_of_week, top_k=2)\n",
659
  " return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
660
  " base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
661
  "\n",
@@ -787,11 +799,19 @@
787
  "# Cell 10: Attach LoRA adapter\n",
788
  "from peft import LoraConfig, get_peft_model, TaskType\n",
789
  "\n",
790
- "lora_config = LoraConfig(\n",
791
- " r=8, lora_alpha=16, lora_dropout=0.05,\n",
792
- " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
793
- " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
794
- ")\n",
 
 
 
 
 
 
 
 
795
  "\n",
796
  "model.enable_input_require_grads()\n",
797
  "peft_model = get_peft_model(model, lora_config)\n",
@@ -810,14 +830,25 @@
810
  "from trl import SFTTrainer, SFTConfig\n",
811
  "from datasets import Dataset\n",
812
  "\n",
813
- "EPISODES_PER_ROUND = 6\n",
814
- "ROUNDS_PER_PHASE = 3\n",
815
- "QUALITY_FLOOR = 0.0\n",
816
- "\n",
817
- "PHASES = [\n",
818
- " {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
819
- " {\"name\": \"phase2_content\", \"reward_mode\": \"content\", \"system\": SYSTEM_PROMPT_CONTENT},\n",
820
- "]\n",
 
 
 
 
 
 
 
 
 
 
 
821
  "\n",
822
  "training_log = {\n",
823
  " \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
@@ -889,10 +920,10 @@
889
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
890
  " sft_config = SFTConfig(\n",
891
  " output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
892
- " num_train_epochs=1,\n",
893
  " per_device_train_batch_size=2,\n",
894
  " gradient_accumulation_steps=4,\n",
895
- " learning_rate=5e-6,\n",
896
  " warmup_steps=5,\n",
897
  " logging_steps=1,\n",
898
  " save_strategy=\"no\",\n",
@@ -965,6 +996,73 @@
965
  "execution_count": null,
966
  "outputs": []
967
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
968
  {
969
  "cell_type": "markdown",
970
  "metadata": {},
 
175
  "# Same sanity as syntax_only.ipynb (kernel parses modern Python)\n",
176
  "import ast\n",
177
  "ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
178
+ "print(\"OK: ast.parse (syntax check)\")\n",
179
+ "\n",
180
+ "SMOKE_MODE = bool(int(os.environ.get(\"SMOKE_MODE\", \"1\")))\n",
181
+ "HINT_ALWAYS = True\n",
182
+ "print(f\"SMOKE_MODE={SMOKE_MODE} | HINT_ALWAYS={HINT_ALWAYS}\")"
183
  ],
184
  "execution_count": null,
185
  "outputs": []
 
443
  "- topic: free-form string\n",
444
  "- empty scheduled_actions = full day rest\n",
445
  "\n",
446
+ "VALID TOOL ARGS (use ONLY these IDs — invented IDs return ERROR):\n",
447
+ "- niche: tech | lifestyle | fitness | business | food | travel | fashion | beauty | photography | education\n",
448
+ "- segment_id: young_professionals | students | parents | global_night_owls | passive_scrollers\n",
449
+ "- competitor_id: niche_expert | viral_chaser | lifestyle_blogger | b2b_thought_leader | food_creator | fitness_coach | travel_creator\n",
450
+ "\n",
451
  "POSTING RULES:\n",
452
  "- Each active day: 2-3 `post` actions at the audience's peak hours.\n",
453
  "- `create_content` alone earns 0 reward.\n",
 
503
  " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
504
  " if not tool_str:\n",
505
  " tool_str = \" (none — call query_* tools to discover)\\n\"\n",
506
+ " hint_str = (\n",
507
+ " f\"COACH HINT (USE THESE EXACT HOURS): post 2-3 times today at hours {extra_hint}. \"\n",
508
+ " f\"Set scheduled_actions[i].hour to one of these values.\\n\"\n",
509
+ " ) if extra_hint else \"\"\n",
510
  " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
511
  " f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
512
  " f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
 
665
  " actions_by_idx = {i: rest_action for i in rest}\n",
666
  " if active:\n",
667
  " def _hint_for(i):\n",
668
+ " if not (hint_peak_hours or HINT_ALWAYS):\n",
669
  " return None\n",
670
+ " hrs = get_peak_hours(obss[i].day_of_week, top_k=3)\n",
671
  " return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
672
  " base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
673
  "\n",
 
799
  "# Cell 10: Attach LoRA adapter\n",
800
  "from peft import LoraConfig, get_peft_model, TaskType\n",
801
  "\n",
802
+ "if SMOKE_MODE:\n",
803
+ " lora_config = LoraConfig(\n",
804
+ " r=16, lora_alpha=32, lora_dropout=0.05,\n",
805
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
806
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
807
+ " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
808
+ " )\n",
809
+ "else:\n",
810
+ " lora_config = LoraConfig(\n",
811
+ " r=8, lora_alpha=16, lora_dropout=0.05,\n",
812
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
813
+ " task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
814
+ " )\n",
815
  "\n",
816
  "model.enable_input_require_grads()\n",
817
  "peft_model = get_peft_model(model, lora_config)\n",
 
830
  "from trl import SFTTrainer, SFTConfig\n",
831
  "from datasets import Dataset\n",
832
  "\n",
833
+ "if SMOKE_MODE:\n",
834
+ " EPISODES_PER_ROUND = 4\n",
835
+ " ROUNDS_PER_PHASE = 1\n",
836
+ " QUALITY_FLOOR = 0.0\n",
837
+ " NUM_TRAIN_EPOCHS = 3\n",
838
+ " LEARNING_RATE = 2e-4\n",
839
+ " PHASES = [\n",
840
+ " {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
841
+ " ]\n",
842
+ "else:\n",
843
+ " EPISODES_PER_ROUND = 6\n",
844
+ " ROUNDS_PER_PHASE = 3\n",
845
+ " QUALITY_FLOOR = 0.0\n",
846
+ " NUM_TRAIN_EPOCHS = 1\n",
847
+ " LEARNING_RATE = 5e-6\n",
848
+ " PHASES = [\n",
849
+ " {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
850
+ " {\"name\": \"phase2_content\", \"reward_mode\": \"content\", \"system\": SYSTEM_PROMPT_CONTENT},\n",
851
+ " ]\n",
852
  "\n",
853
  "training_log = {\n",
854
  " \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
 
920
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
921
  " sft_config = SFTConfig(\n",
922
  " output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
923
+ " num_train_epochs=NUM_TRAIN_EPOCHS,\n",
924
  " per_device_train_batch_size=2,\n",
925
  " gradient_accumulation_steps=4,\n",
926
+ " learning_rate=LEARNING_RATE,\n",
927
  " warmup_steps=5,\n",
928
  " logging_steps=1,\n",
929
  " save_strategy=\"no\",\n",
 
996
  "execution_count": null,
997
  "outputs": []
998
  },
999
+ {
1000
+ "cell_type": "code",
1001
+ "metadata": {},
1002
+ "source": [
1003
+ "# Cell 12.5: Debug — analyse io_log.jsonl (before vs after, tool error rate, hint usage)\n",
1004
+ "import re\n",
1005
+ "from collections import Counter\n",
1006
+ "\n",
1007
+ "def _safe_json_loads(s):\n",
1008
+ " try:\n",
1009
+ " s = s.strip()\n",
1010
+ " if \"```\" in s:\n",
1011
+ " s = \"\\n\".join(l for l in s.split(\"\\n\") if not l.strip().startswith(\"```\")).strip()\n",
1012
+ " a, b = s.find(\"{\"), s.rfind(\"}\") + 1\n",
1013
+ " return json.loads(s[a:b]) if a >= 0 and b > a else None\n",
1014
+ " except Exception:\n",
1015
+ " return None\n",
1016
+ "\n",
1017
+ "records = []\n",
1018
+ "with open(IO_LOG_PATH) as f:\n",
1019
+ " for line in f:\n",
1020
+ " if line.strip():\n",
1021
+ " records.append(json.loads(line))\n",
1022
+ "\n",
1023
+ "by_tag = Counter(r[\"tag\"] for r in records)\n",
1024
+ "print(\"io_log records by tag:\", dict(by_tag))\n",
1025
+ "\n",
1026
+ "before = {(r[\"ep\"], r[\"day\"], r[\"tag\"].split(\"/\")[1]): r for r in records if r[\"tag\"].startswith(\"before\")}\n",
1027
+ "after = {(r[\"ep\"], r[\"day\"], r[\"tag\"].split(\"/\")[1]): r for r in records if r[\"tag\"].startswith(\"after\")}\n",
1028
+ "common = set(before) & set(after)\n",
1029
+ "identical = sum(1 for k in common if before[k][\"response\"] == after[k][\"response\"])\n",
1030
+ "print(f\"\\nbefore/after: {len(common)} common keys, identical={identical}, diff={len(common)-identical}\")\n",
1031
+ "\n",
1032
+ "tool_errs = sum(1 for r in records if r[\"tag\"].endswith(\"/A\") and \"ERROR\" in r[\"response\"])\n",
1033
+ "print(f\"PHASE A responses containing 'ERROR' string: {tool_errs}\")\n",
1034
+ "\n",
1035
+ "niche_used, seg_used, comp_used = Counter(), Counter(), Counter()\n",
1036
+ "for r in records:\n",
1037
+ " if not r[\"tag\"].endswith(\"/A\"):\n",
1038
+ " continue\n",
1039
+ " j = _safe_json_loads(r[\"response\"])\n",
1040
+ " if not j:\n",
1041
+ " continue\n",
1042
+ " for tc in j.get(\"tool_calls\", []):\n",
1043
+ " a = tc.get(\"arguments\", {}) or {}\n",
1044
+ " if tc.get(\"name\") == \"query_trends\" and \"niche\" in a: niche_used[a[\"niche\"]] += 1\n",
1045
+ " if tc.get(\"name\") == \"query_audience\" and \"segment_id\" in a: seg_used[a[\"segment_id\"]] += 1\n",
1046
+ " if tc.get(\"name\") == \"query_competitor\" and \"competitor_id\" in a: comp_used[a[\"competitor_id\"]] += 1\n",
1047
+ "print(\"\\nTop niches used:\", niche_used.most_common(8))\n",
1048
+ "print(\"Top segments used:\", seg_used.most_common(8))\n",
1049
+ "print(\"Top competitors used:\", comp_used.most_common(8))\n",
1050
+ "\n",
1051
+ "hint_seen = sum(1 for r in records if \"COACH HINT\" in r[\"prompt\"])\n",
1052
+ "print(f\"\\nPrompts containing COACH HINT: {hint_seen}/{len(records)}\")\n",
1053
+ "\n",
1054
+ "if common:\n",
1055
+ " k = next(iter(sorted(common)))\n",
1056
+ " print(f\"\\n--- diff sample @ {k} (B-phase only if available) ---\")\n",
1057
+ " bk = before.get((k[0], k[1], \"B\"))\n",
1058
+ " ak = after.get((k[0], k[1], \"B\"))\n",
1059
+ " if bk and ak:\n",
1060
+ " print(\"BEFORE response head:\", bk[\"response\"][:300].replace(\"\\n\", \" \"))\n",
1061
+ " print(\"AFTER response head:\", ak[\"response\"][:300].replace(\"\\n\", \" \"))"
1062
+ ],
1063
+ "execution_count": null,
1064
+ "outputs": []
1065
+ },
1066
  {
1067
  "cell_type": "markdown",
1068
  "metadata": {},