Spaces:
Paused
Paused
Commit ·
1f72457
1
Parent(s): e52d302
training: smoke-mode + hardcoded peak hint + valid tool IDs
Browse files
- SMOKE_MODE=1 default: 1 phase x 1 round x 4 episodes, lr 2e-4, LoRA r=16, 3 epochs (visible delta)
- always-on coach hint with day-aware top-3 peak hours
- system prompt lists valid niche/segment/competitor IDs (kills tool-arg errors)
- LoRA targets the full MLP in smoke mode (gate_proj/up_proj/down_proj in addition to q/k/v/o projections)
- new debug cell: io_log diff/error/hint stats
Made-with: Cursor
- training/hf_run_space_train_job.sh +1 -1
- training/train_grpo.ipynb +117 -19
training/hf_run_space_train_job.sh
CHANGED
|
@@ -8,7 +8,7 @@
|
|
| 8 |
set -euo pipefail
|
| 9 |
|
| 10 |
IMAGE="${HF_JOB_IMAGE:-pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime}"
|
| 11 |
-
FLAVOR="${HF_JOB_FLAVOR:-
|
| 12 |
TIMEOUT="${HF_JOB_TIMEOUT:-8h}"
|
| 13 |
SPACE_REPO="${HF_SPACE_REPO_ID:-vaibhavkhandare/train-bhai-train}"
|
| 14 |
NB_EXEC_TIMEOUT="${NB_EXEC_TIMEOUT:-3600}"
|
|
|
|
| 8 |
set -euo pipefail
|
| 9 |
|
| 10 |
IMAGE="${HF_JOB_IMAGE:-pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime}"
|
| 11 |
+
FLAVOR="${HF_JOB_FLAVOR:-a100-large}"
|
| 12 |
TIMEOUT="${HF_JOB_TIMEOUT:-8h}"
|
| 13 |
SPACE_REPO="${HF_SPACE_REPO_ID:-vaibhavkhandare/train-bhai-train}"
|
| 14 |
NB_EXEC_TIMEOUT="${NB_EXEC_TIMEOUT:-3600}"
|
training/train_grpo.ipynb
CHANGED
|
@@ -175,7 +175,11 @@
|
|
| 175 |
"# Same sanity as syntax_only.ipynb (kernel parses modern Python)\n",
|
| 176 |
"import ast\n",
|
| 177 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 178 |
-
"print(\"OK: ast.parse (syntax check)\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
],
|
| 180 |
"execution_count": null,
|
| 181 |
"outputs": []
|
|
@@ -439,6 +443,11 @@
|
|
| 439 |
"- topic: free-form string\n",
|
| 440 |
"- empty scheduled_actions = full day rest\n",
|
| 441 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
"POSTING RULES:\n",
|
| 443 |
"- Each active day: 2-3 `post` actions at the audience's peak hours.\n",
|
| 444 |
"- `create_content` alone earns 0 reward.\n",
|
|
@@ -494,7 +503,10 @@
|
|
| 494 |
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 495 |
" if not tool_str:\n",
|
| 496 |
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
| 497 |
-
" hint_str =
|
|
|
|
|
|
|
|
|
|
| 498 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 499 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 500 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
|
@@ -653,9 +665,9 @@
|
|
| 653 |
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 654 |
" if active:\n",
|
| 655 |
" def _hint_for(i):\n",
|
| 656 |
-
" if not hint_peak_hours:\n",
|
| 657 |
" return None\n",
|
| 658 |
-
" hrs = get_peak_hours(obss[i].day_of_week, top_k=
|
| 659 |
" return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
|
| 660 |
" base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
|
| 661 |
"\n",
|
|
@@ -787,11 +799,19 @@
|
|
| 787 |
"# Cell 10: Attach LoRA adapter\n",
|
| 788 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 789 |
"\n",
|
| 790 |
-
"
|
| 791 |
-
"
|
| 792 |
-
"
|
| 793 |
-
"
|
| 794 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 795 |
"\n",
|
| 796 |
"model.enable_input_require_grads()\n",
|
| 797 |
"peft_model = get_peft_model(model, lora_config)\n",
|
|
@@ -810,14 +830,25 @@
|
|
| 810 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 811 |
"from datasets import Dataset\n",
|
| 812 |
"\n",
|
| 813 |
-
"
|
| 814 |
-
"
|
| 815 |
-
"
|
| 816 |
-
"\n",
|
| 817 |
-
"
|
| 818 |
-
"
|
| 819 |
-
"
|
| 820 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
"\n",
|
| 822 |
"training_log = {\n",
|
| 823 |
" \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
|
|
@@ -889,10 +920,10 @@
|
|
| 889 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 890 |
" sft_config = SFTConfig(\n",
|
| 891 |
" output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
|
| 892 |
-
" num_train_epochs=
|
| 893 |
" per_device_train_batch_size=2,\n",
|
| 894 |
" gradient_accumulation_steps=4,\n",
|
| 895 |
-
" learning_rate=
|
| 896 |
" warmup_steps=5,\n",
|
| 897 |
" logging_steps=1,\n",
|
| 898 |
" save_strategy=\"no\",\n",
|
|
@@ -965,6 +996,73 @@
|
|
| 965 |
"execution_count": null,
|
| 966 |
"outputs": []
|
| 967 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
{
|
| 969 |
"cell_type": "markdown",
|
| 970 |
"metadata": {},
|
|
|
|
| 175 |
"# Same sanity as syntax_only.ipynb (kernel parses modern Python)\n",
|
| 176 |
"import ast\n",
|
| 177 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 178 |
+
"print(\"OK: ast.parse (syntax check)\")\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"SMOKE_MODE = bool(int(os.environ.get(\"SMOKE_MODE\", \"1\")))\n",
|
| 181 |
+
"HINT_ALWAYS = True\n",
|
| 182 |
+
"print(f\"SMOKE_MODE={SMOKE_MODE} | HINT_ALWAYS={HINT_ALWAYS}\")"
|
| 183 |
],
|
| 184 |
"execution_count": null,
|
| 185 |
"outputs": []
|
|
|
|
| 443 |
"- topic: free-form string\n",
|
| 444 |
"- empty scheduled_actions = full day rest\n",
|
| 445 |
"\n",
|
| 446 |
+
"VALID TOOL ARGS (use ONLY these IDs — invented IDs return ERROR):\n",
|
| 447 |
+
"- niche: tech | lifestyle | fitness | business | food | travel | fashion | beauty | photography | education\n",
|
| 448 |
+
"- segment_id: young_professionals | students | parents | global_night_owls | passive_scrollers\n",
|
| 449 |
+
"- competitor_id: niche_expert | viral_chaser | lifestyle_blogger | b2b_thought_leader | food_creator | fitness_coach | travel_creator\n",
|
| 450 |
+
"\n",
|
| 451 |
"POSTING RULES:\n",
|
| 452 |
"- Each active day: 2-3 `post` actions at the audience's peak hours.\n",
|
| 453 |
"- `create_content` alone earns 0 reward.\n",
|
|
|
|
| 503 |
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 504 |
" if not tool_str:\n",
|
| 505 |
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
| 506 |
+
" hint_str = (\n",
|
| 507 |
+
" f\"COACH HINT (USE THESE EXACT HOURS): post 2-3 times today at hours {extra_hint}. \"\n",
|
| 508 |
+
" f\"Set scheduled_actions[i].hour to one of these values.\\n\"\n",
|
| 509 |
+
" ) if extra_hint else \"\"\n",
|
| 510 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 511 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 512 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
|
|
|
| 665 |
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 666 |
" if active:\n",
|
| 667 |
" def _hint_for(i):\n",
|
| 668 |
+
" if not (hint_peak_hours or HINT_ALWAYS):\n",
|
| 669 |
" return None\n",
|
| 670 |
+
" hrs = get_peak_hours(obss[i].day_of_week, top_k=3)\n",
|
| 671 |
" return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
|
| 672 |
" base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
|
| 673 |
"\n",
|
|
|
|
| 799 |
"# Cell 10: Attach LoRA adapter\n",
|
| 800 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
| 801 |
"\n",
|
| 802 |
+
"if SMOKE_MODE:\n",
|
| 803 |
+
" lora_config = LoraConfig(\n",
|
| 804 |
+
" r=16, lora_alpha=32, lora_dropout=0.05,\n",
|
| 805 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 806 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 807 |
+
" task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
|
| 808 |
+
" )\n",
|
| 809 |
+
"else:\n",
|
| 810 |
+
" lora_config = LoraConfig(\n",
|
| 811 |
+
" r=8, lora_alpha=16, lora_dropout=0.05,\n",
|
| 812 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
|
| 813 |
+
" task_type=TaskType.CAUSAL_LM, bias=\"none\",\n",
|
| 814 |
+
" )\n",
|
| 815 |
"\n",
|
| 816 |
"model.enable_input_require_grads()\n",
|
| 817 |
"peft_model = get_peft_model(model, lora_config)\n",
|
|
|
|
| 830 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 831 |
"from datasets import Dataset\n",
|
| 832 |
"\n",
|
| 833 |
+
"if SMOKE_MODE:\n",
|
| 834 |
+
" EPISODES_PER_ROUND = 4\n",
|
| 835 |
+
" ROUNDS_PER_PHASE = 1\n",
|
| 836 |
+
" QUALITY_FLOOR = 0.0\n",
|
| 837 |
+
" NUM_TRAIN_EPOCHS = 3\n",
|
| 838 |
+
" LEARNING_RATE = 2e-4\n",
|
| 839 |
+
" PHASES = [\n",
|
| 840 |
+
" {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
|
| 841 |
+
" ]\n",
|
| 842 |
+
"else:\n",
|
| 843 |
+
" EPISODES_PER_ROUND = 6\n",
|
| 844 |
+
" ROUNDS_PER_PHASE = 3\n",
|
| 845 |
+
" QUALITY_FLOOR = 0.0\n",
|
| 846 |
+
" NUM_TRAIN_EPOCHS = 1\n",
|
| 847 |
+
" LEARNING_RATE = 5e-6\n",
|
| 848 |
+
" PHASES = [\n",
|
| 849 |
+
" {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
|
| 850 |
+
" {\"name\": \"phase2_content\", \"reward_mode\": \"content\", \"system\": SYSTEM_PROMPT_CONTENT},\n",
|
| 851 |
+
" ]\n",
|
| 852 |
"\n",
|
| 853 |
"training_log = {\n",
|
| 854 |
" \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
|
|
|
|
| 920 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 921 |
" sft_config = SFTConfig(\n",
|
| 922 |
" output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
|
| 923 |
+
" num_train_epochs=NUM_TRAIN_EPOCHS,\n",
|
| 924 |
" per_device_train_batch_size=2,\n",
|
| 925 |
" gradient_accumulation_steps=4,\n",
|
| 926 |
+
" learning_rate=LEARNING_RATE,\n",
|
| 927 |
" warmup_steps=5,\n",
|
| 928 |
" logging_steps=1,\n",
|
| 929 |
" save_strategy=\"no\",\n",
|
|
|
|
| 996 |
"execution_count": null,
|
| 997 |
"outputs": []
|
| 998 |
},
|
| 999 |
+
{
|
| 1000 |
+
"cell_type": "code",
|
| 1001 |
+
"metadata": {},
|
| 1002 |
+
"source": [
|
| 1003 |
+
"# Cell 12.5: Debug — analyse io_log.jsonl (before vs after, tool error rate, hint usage)\n",
|
| 1004 |
+
"import re\n",
|
| 1005 |
+
"from collections import Counter\n",
|
| 1006 |
+
"\n",
|
| 1007 |
+
"def _safe_json_loads(s):\n",
|
| 1008 |
+
" try:\n",
|
| 1009 |
+
" s = s.strip()\n",
|
| 1010 |
+
" if \"```\" in s:\n",
|
| 1011 |
+
" s = \"\\n\".join(l for l in s.split(\"\\n\") if not l.strip().startswith(\"```\")).strip()\n",
|
| 1012 |
+
" a, b = s.find(\"{\"), s.rfind(\"}\") + 1\n",
|
| 1013 |
+
" return json.loads(s[a:b]) if a >= 0 and b > a else None\n",
|
| 1014 |
+
" except Exception:\n",
|
| 1015 |
+
" return None\n",
|
| 1016 |
+
"\n",
|
| 1017 |
+
"records = []\n",
|
| 1018 |
+
"with open(IO_LOG_PATH) as f:\n",
|
| 1019 |
+
" for line in f:\n",
|
| 1020 |
+
" if line.strip():\n",
|
| 1021 |
+
" records.append(json.loads(line))\n",
|
| 1022 |
+
"\n",
|
| 1023 |
+
"by_tag = Counter(r[\"tag\"] for r in records)\n",
|
| 1024 |
+
"print(\"io_log records by tag:\", dict(by_tag))\n",
|
| 1025 |
+
"\n",
|
| 1026 |
+
"before = {(r[\"ep\"], r[\"day\"], r[\"tag\"].split(\"/\")[1]): r for r in records if r[\"tag\"].startswith(\"before\")}\n",
|
| 1027 |
+
"after = {(r[\"ep\"], r[\"day\"], r[\"tag\"].split(\"/\")[1]): r for r in records if r[\"tag\"].startswith(\"after\")}\n",
|
| 1028 |
+
"common = set(before) & set(after)\n",
|
| 1029 |
+
"identical = sum(1 for k in common if before[k][\"response\"] == after[k][\"response\"])\n",
|
| 1030 |
+
"print(f\"\\nbefore/after: {len(common)} common keys, identical={identical}, diff={len(common)-identical}\")\n",
|
| 1031 |
+
"\n",
|
| 1032 |
+
"tool_errs = sum(1 for r in records if r[\"tag\"].endswith(\"/A\") and \"ERROR\" in r[\"response\"])\n",
|
| 1033 |
+
"print(f\"PHASE A responses containing 'ERROR' string: {tool_errs}\")\n",
|
| 1034 |
+
"\n",
|
| 1035 |
+
"niche_used, seg_used, comp_used = Counter(), Counter(), Counter()\n",
|
| 1036 |
+
"for r in records:\n",
|
| 1037 |
+
" if not r[\"tag\"].endswith(\"/A\"):\n",
|
| 1038 |
+
" continue\n",
|
| 1039 |
+
" j = _safe_json_loads(r[\"response\"])\n",
|
| 1040 |
+
" if not j:\n",
|
| 1041 |
+
" continue\n",
|
| 1042 |
+
" for tc in j.get(\"tool_calls\", []):\n",
|
| 1043 |
+
" a = tc.get(\"arguments\", {}) or {}\n",
|
| 1044 |
+
" if tc.get(\"name\") == \"query_trends\" and \"niche\" in a: niche_used[a[\"niche\"]] += 1\n",
|
| 1045 |
+
" if tc.get(\"name\") == \"query_audience\" and \"segment_id\" in a: seg_used[a[\"segment_id\"]] += 1\n",
|
| 1046 |
+
" if tc.get(\"name\") == \"query_competitor\" and \"competitor_id\" in a: comp_used[a[\"competitor_id\"]] += 1\n",
|
| 1047 |
+
"print(\"\\nTop niches used:\", niche_used.most_common(8))\n",
|
| 1048 |
+
"print(\"Top segments used:\", seg_used.most_common(8))\n",
|
| 1049 |
+
"print(\"Top competitors used:\", comp_used.most_common(8))\n",
|
| 1050 |
+
"\n",
|
| 1051 |
+
"hint_seen = sum(1 for r in records if \"COACH HINT\" in r[\"prompt\"])\n",
|
| 1052 |
+
"print(f\"\\nPrompts containing COACH HINT: {hint_seen}/{len(records)}\")\n",
|
| 1053 |
+
"\n",
|
| 1054 |
+
"if common:\n",
|
| 1055 |
+
" k = next(iter(sorted(common)))\n",
|
| 1056 |
+
" print(f\"\\n--- diff sample @ {k} (B-phase only if available) ---\")\n",
|
| 1057 |
+
" bk = before.get((k[0], k[1], \"B\"))\n",
|
| 1058 |
+
" ak = after.get((k[0], k[1], \"B\"))\n",
|
| 1059 |
+
" if bk and ak:\n",
|
| 1060 |
+
" print(\"BEFORE response head:\", bk[\"response\"][:300].replace(\"\\n\", \" \"))\n",
|
| 1061 |
+
" print(\"AFTER response head:\", ak[\"response\"][:300].replace(\"\\n\", \" \"))"
|
| 1062 |
+
],
|
| 1063 |
+
"execution_count": null,
|
| 1064 |
+
"outputs": []
|
| 1065 |
+
},
|
| 1066 |
{
|
| 1067 |
"cell_type": "markdown",
|
| 1068 |
"metadata": {},
|