vaibhav12332112312 committed on
Commit
9ee7a09
·
1 Parent(s): 56f70b1

train: per-step credit + drop replies + larger batches

Browse files

- run_llm_episode now computes Monte-Carlo return-to-go per day
(G_t = r_t + 0.95*G_{t+1}, terminal = grader_score*5); each pair
gets its own return for top-K filtering, removing the
same-reward-per-pair caveat.
- SFTConfig: per_device_train_batch_size 16->32, bf16, fused AdamW,
warmup_ratio=0.1, no grad checkpointing for ~90% VRAM use on 48GB.
- SYSTEM_PROMPT rewritten as full tool/action schema; removed
subjective rules (optimal posts, diversity bonus); reach-bonus
clarified.
- format_obs no longer truncates tool results; history window 4->14
messages (7 days); removed hard-coded auto-rest heuristic.
- Drop replies feature across env, models, client, inference, and
training scripts.
- NUM_ROUNDS/EPISODES_PER_ROUND set to 1 for smoke run.

Made-with: Cursor

__init__.py CHANGED
@@ -10,7 +10,6 @@ from .client import ViraltestEnv
10
  from .models import (
11
  CollabProposal,
12
  EngagementSignals,
13
- ReplyAction,
14
  ScheduledAction,
15
  ToolCall,
16
  ToolResult,
@@ -21,7 +20,6 @@ from .models import (
21
  __all__ = [
22
  "CollabProposal",
23
  "EngagementSignals",
24
- "ReplyAction",
25
  "ScheduledAction",
26
  "ToolCall",
27
  "ToolResult",
 
10
  from .models import (
11
  CollabProposal,
12
  EngagementSignals,
 
13
  ScheduledAction,
14
  ToolCall,
15
  ToolResult,
 
20
  __all__ = [
21
  "CollabProposal",
22
  "EngagementSignals",
 
23
  "ScheduledAction",
24
  "ToolCall",
25
  "ToolResult",
client.py CHANGED
@@ -43,12 +43,6 @@ class ViraltestEnv(EnvClient[ViraltestAction, ViraltestObservation, State]):
43
  actions_list.append(item)
44
  payload["scheduled_actions"] = actions_list
45
 
46
- if action.replies:
47
- payload["replies"] = [
48
- {"post_hour": r.post_hour, "reply_hour": r.reply_hour}
49
- for r in action.replies
50
- ]
51
-
52
  if action.collab:
53
  payload["collab"] = {
54
  "partner_id": action.collab.partner_id,
 
43
  actions_list.append(item)
44
  payload["scheduled_actions"] = actions_list
45
 
 
 
 
 
 
 
46
  if action.collab:
47
  payload["collab"] = {
48
  "partner_id": action.collab.partner_id,
inference.py CHANGED
@@ -74,7 +74,6 @@ RESPONSE FORMAT (JSON only, no markdown, no prose):
74
  {"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"},
75
  {"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup life", "tags": ["startup", "growth"], "intent": "save_bait"}
76
  ],
77
- "replies": [{"post_hour": 12, "reply_hour": 13}],
78
  "notes": "Day 3: tech niche trending up. Competitor Alpha posted at 10am. Avoiding overlap."
79
  }
80
 
@@ -87,7 +86,6 @@ RULES:
87
  - Use notes to track hypotheses and observations across days
88
  - Tool calls cost API budget (starts at 100). Use wisely.
89
  - Max 2 collaborations per month
90
- - Reply within 90 minutes of a post for reach bonus
91
 
92
  Think strategically: use tools to discover what works, then exploit what you learn.""")
93
 
@@ -201,13 +199,11 @@ def parse_daily_plan(response_text: str) -> ViraltestAction:
201
  if isinstance(a, dict):
202
  scheduled.append(a)
203
 
204
- replies_raw = data.get("replies", [])
205
  notes = data.get("notes")
206
 
207
  return ViraltestAction(
208
  tool_calls=tool_calls,
209
  scheduled_actions=scheduled,
210
- replies=replies_raw if isinstance(replies_raw, list) else [],
211
  notes=notes,
212
  )
213
  except (json.JSONDecodeError, Exception):
@@ -236,7 +232,6 @@ def sanitize_predefined_topics(action: ViraltestAction, obs: Any) -> ViraltestAc
236
  return ViraltestAction(
237
  tool_calls=action.tool_calls,
238
  scheduled_actions=out,
239
- replies=action.replies,
240
  collab=action.collab,
241
  notes=action.notes,
242
  )
 
74
  {"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"},
75
  {"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup life", "tags": ["startup", "growth"], "intent": "save_bait"}
76
  ],
 
77
  "notes": "Day 3: tech niche trending up. Competitor Alpha posted at 10am. Avoiding overlap."
78
  }
79
 
 
86
  - Use notes to track hypotheses and observations across days
87
  - Tool calls cost API budget (starts at 100). Use wisely.
88
  - Max 2 collaborations per month
 
89
 
90
  Think strategically: use tools to discover what works, then exploit what you learn.""")
91
 
 
199
  if isinstance(a, dict):
200
  scheduled.append(a)
201
 
 
202
  notes = data.get("notes")
203
 
204
  return ViraltestAction(
205
  tool_calls=tool_calls,
206
  scheduled_actions=scheduled,
 
207
  notes=notes,
208
  )
209
  except (json.JSONDecodeError, Exception):
 
232
  return ViraltestAction(
233
  tool_calls=action.tool_calls,
234
  scheduled_actions=out,
 
235
  collab=action.collab,
236
  notes=action.notes,
237
  )
models.py CHANGED
@@ -56,13 +56,6 @@ class ScheduledAction(BaseModel):
56
  return v
57
 
58
 
59
- class ReplyAction(BaseModel):
60
- """Reply to comments on a post made earlier today (within reply window)."""
61
-
62
- post_hour: int = Field(..., ge=0, le=23, description="Hour of the post to reply on")
63
- reply_hour: int = Field(..., ge=0, le=23, description="Hour to send replies")
64
-
65
-
66
  class CollabProposal(BaseModel):
67
  """Propose a collaboration with a competitor archetype."""
68
 
@@ -82,10 +75,6 @@ class ViraltestAction(Action):
82
  default_factory=list,
83
  description="Actions scheduled at specific hours; unlisted hours are rest",
84
  )
85
- replies: List[ReplyAction] = Field(
86
- default_factory=list,
87
- description="Reply actions on posts made today (within 90-min window for reach bonus)",
88
- )
89
  collab: Optional[CollabProposal] = Field(
90
  default=None,
91
  description="Optional collaboration proposal (max 2 per month)",
 
56
  return v
57
 
58
 
 
 
 
 
 
 
 
59
  class CollabProposal(BaseModel):
60
  """Propose a collaboration with a competitor archetype."""
61
 
 
75
  default_factory=list,
76
  description="Actions scheduled at specific hours; unlisted hours are rest",
77
  )
 
 
 
 
78
  collab: Optional[CollabProposal] = Field(
79
  default=None,
80
  description="Optional collaboration proposal (max 2 per month)",
server/viraltest_environment.py CHANGED
@@ -29,7 +29,6 @@ try:
29
  EngagementSignals,
30
  HeadlineMetrics,
31
  JudgeReport,
32
- ReplyAction,
33
  ScheduledAction,
34
  ToolCall,
35
  ToolResult,
@@ -42,7 +41,6 @@ except ImportError:
42
  EngagementSignals,
43
  HeadlineMetrics,
44
  JudgeReport,
45
- ReplyAction,
46
  ScheduledAction,
47
  ToolCall,
48
  ToolResult,
@@ -168,8 +166,6 @@ COLLAB_GROWTH_K = 1.50 # cross-pollination follower spillover, scales (1 - o
168
  COLLAB_PARTNER_REPEAT_PENALTY = 0.7 # discount on multipliers when partner reused this brand
169
  COLLAB_FATIGUE_K = 0.3 # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
170
 
171
- REPLY_WINDOW_MINUTES = 90
172
- REPLY_REACH_BONUS = 1.4
173
  API_BUDGET_INITIAL = 100
174
 
175
  # Heuristic baselines for headline metric `vs_baseline_pct`.
@@ -847,19 +843,6 @@ class ViraltestEnvironment(Environment):
847
  if self._energy <= 0.0:
848
  burned_out = True
849
 
850
- # Process replies
851
- for reply in action.replies:
852
- if 0 <= reply.reply_hour < 24 and 0 <= reply.post_hour < 24:
853
- diff_minutes = abs(reply.reply_hour - reply.post_hour) * 60
854
- if diff_minutes <= REPLY_WINDOW_MINUTES:
855
- daily_engagement *= REPLY_REACH_BONUS
856
- daily_signals = EngagementSignals(
857
- watch_time=daily_signals.watch_time * REPLY_REACH_BONUS,
858
- sends_per_reach=daily_signals.sends_per_reach * REPLY_REACH_BONUS,
859
- saves=daily_signals.saves * REPLY_REACH_BONUS,
860
- likes_per_reach=daily_signals.likes_per_reach * REPLY_REACH_BONUS,
861
- )
862
-
863
  # Weekly tracking
864
  self._total_posts_this_week += daily_posts
865
  if self._day % 7 == 0 and self._day > 0:
 
29
  EngagementSignals,
30
  HeadlineMetrics,
31
  JudgeReport,
 
32
  ScheduledAction,
33
  ToolCall,
34
  ToolResult,
 
41
  EngagementSignals,
42
  HeadlineMetrics,
43
  JudgeReport,
 
44
  ScheduledAction,
45
  ToolCall,
46
  ToolResult,
 
166
  COLLAB_PARTNER_REPEAT_PENALTY = 0.7 # discount on multipliers when partner reused this brand
167
  COLLAB_FATIGUE_K = 0.3 # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
168
 
 
 
169
  API_BUDGET_INITIAL = 100
170
 
171
  # Heuristic baselines for headline metric `vs_baseline_pct`.
 
843
  if self._energy <= 0.0:
844
  burned_out = True
845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
846
  # Weekly tracking
847
  self._total_posts_this_week += daily_posts
848
  if self._day % 7 == 0 and self._day > 0:
training/hf_run_space_train_job.sh CHANGED
@@ -22,13 +22,12 @@ REMOTE_SCRIPT=$(cat <<'EOS'
22
  set -euo pipefail
23
  export DEBIAN_FRONTEND=noninteractive
24
  apt-get update -qq && apt-get install -y --no-install-recommends git curl ca-certificates
25
- pip install -q --root-user-action=ignore --upgrade "typing_extensions>=4.15.0" jupyter nbconvert nbclient ipykernel huggingface_hub
26
  rm -rf /work
27
  git clone --depth 1 "https://user:${HF_TOKEN}@huggingface.co/spaces/${SPACE_REPO}" /work
28
  cd /work
29
- jupyter nbconvert --to notebook --execute training/train_grpo.ipynb \
30
- --output train_grpo.executed.ipynb \
31
- --ExecutePreprocessor.timeout="${NB_EXEC_TIMEOUT}"
32
  python -c "import os; from huggingface_hub import HfApi; HfApi().upload_folder(folder_path='.', path_in_repo='run-output', repo_id=os.environ['SPACE_REPO'], repo_type='space', allow_patterns=['training/train_grpo.executed.ipynb','plots/**','**/lora-*/**'])"
33
  EOS
34
  )
 
22
  set -euo pipefail
23
  export DEBIAN_FRONTEND=noninteractive
24
  apt-get update -qq && apt-get install -y --no-install-recommends git curl ca-certificates
25
+ pip install -q --root-user-action=ignore --upgrade "typing_extensions>=4.15.0" jupyter nbconvert nbclient ipykernel huggingface_hub papermill
26
  rm -rf /work
27
  git clone --depth 1 "https://user:${HF_TOKEN}@huggingface.co/spaces/${SPACE_REPO}" /work
28
  cd /work
29
+ papermill --log-output --progress-bar --execution-timeout "${NB_EXEC_TIMEOUT}" \
30
+ training/train_grpo.ipynb training/train_grpo.executed.ipynb
 
31
  python -c "import os; from huggingface_hub import HfApi; HfApi().upload_folder(folder_path='.', path_in_repo='run-output', repo_id=os.environ['SPACE_REPO'], repo_type='space', allow_patterns=['training/train_grpo.executed.ipynb','plots/**','**/lora-*/**'])"
32
  EOS
33
  )
training/run_llm_training.py CHANGED
@@ -106,7 +106,6 @@ def plan_smart(obs_dict, day):
106
  ScheduledAction(hour=19, action_type="post", content_type=ct2,
107
  topic=topic2, tags=tags2, intent=intent2),
108
  ],
109
- replies=[{"post_hour": 12, "reply_hour": 13}],
110
  )
111
 
112
  BASELINE_AGENTS = {
@@ -157,17 +156,13 @@ RESPONSE FORMAT — return ONLY valid JSON, no markdown, no explanation:
157
  "scheduled_actions": [
158
  {"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"}
159
  ],
160
- "replies": [{"post_hour": 12, "reply_hour": 13}],
161
  "notes": "strategy notes"
162
  }
163
 
164
  RULES:
165
  - hour: 0-23. content_type: reel|story|carousel|text_post
166
  - intent: send_bait|save_bait|watch_bait|like_bait
167
- - 1-2 posts per day is optimal. More = audience fatigue + energy drain.
168
- - Empty scheduled_actions = rest (recovers energy).
169
- - Vary content types and topics across days for diversity bonus.
170
- - Reply within 90 min of a post for reach bonus.""")
171
 
172
  LEARNED_ADDENDUM = """
173
 
@@ -253,7 +248,7 @@ def parse_model_output(text):
253
  pass
254
  return ViraltestAction(
255
  tool_calls=tool_calls, scheduled_actions=scheduled,
256
- replies=data.get("replies", []), notes=data.get("notes"),
257
  )
258
  except (json.JSONDecodeError, Exception):
259
  return ViraltestAction(scheduled_actions=[])
 
106
  ScheduledAction(hour=19, action_type="post", content_type=ct2,
107
  topic=topic2, tags=tags2, intent=intent2),
108
  ],
 
109
  )
110
 
111
  BASELINE_AGENTS = {
 
156
  "scheduled_actions": [
157
  {"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"}
158
  ],
 
159
  "notes": "strategy notes"
160
  }
161
 
162
  RULES:
163
  - hour: 0-23. content_type: reel|story|carousel|text_post
164
  - intent: send_bait|save_bait|watch_bait|like_bait
165
+ - Empty scheduled_actions = rest (recovers energy).""")
 
 
 
166
 
167
  LEARNED_ADDENDUM = """
168
 
 
248
  pass
249
  return ViraltestAction(
250
  tool_calls=tool_calls, scheduled_actions=scheduled,
251
+ notes=data.get("notes"),
252
  )
253
  except (json.JSONDecodeError, Exception):
254
  return ViraltestAction(scheduled_actions=[])
training/run_training_evidence.py CHANGED
@@ -100,7 +100,6 @@ def plan_smart(obs_dict: dict, day: int) -> ViraltestAction:
100
  ScheduledAction(hour=19, action_type="post", content_type=ct2,
101
  topic=topic2, tags=tags2, intent=intent2),
102
  ],
103
- replies=[{"post_hour": 12, "reply_hour": 13}],
104
  notes=f"Day {day}: varied content at peak hours.",
105
  )
106
 
@@ -156,7 +155,6 @@ class PostingPolicy:
156
  tag_offset: int = 0
157
  topic_offset: int = 0
158
  create_hour: Optional[int] = None
159
- use_reply: bool = False
160
  use_tools_early: bool = False
161
  rest_if_low_energy: float = 0.3
162
 
@@ -186,16 +184,9 @@ class PostingPolicy:
186
  tool_calls.append(ToolCall(name="query_trends",
187
  arguments={"niche": NICHES[day % len(NICHES)]}))
188
 
189
- replies = []
190
- if policy.use_reply and policy.post_hours:
191
- first_post = policy.post_hours[0]
192
- if first_post < 23:
193
- replies = [{"post_hour": first_post, "reply_hour": first_post + 1}]
194
-
195
  return ViraltestAction(
196
  tool_calls=tool_calls,
197
  scheduled_actions=actions,
198
- replies=replies,
199
  notes=f"Day {day}: policy-driven plan.",
200
  )
201
  return plan_fn
@@ -208,13 +199,12 @@ class PostingPolicy:
208
  tag_offset=self.tag_offset,
209
  topic_offset=self.topic_offset,
210
  create_hour=self.create_hour,
211
- use_reply=self.use_reply,
212
  use_tools_early=self.use_tools_early,
213
  rest_if_low_energy=self.rest_if_low_energy,
214
  )
215
 
216
  mutation = rng.choice(["hours", "types", "intents", "tags", "topics",
217
- "create", "reply", "tools", "energy", "n_posts"])
218
 
219
  if mutation == "hours":
220
  child.post_hours = sorted(rng.sample(range(6, 23), min(rng.randint(1, 3), 3)))
@@ -230,8 +220,6 @@ class PostingPolicy:
230
  child.topic_offset = rng.randint(0, len(ALL_TOPICS) - 1)
231
  elif mutation == "create":
232
  child.create_hour = rng.choice([None, 7, 8, 9, 10])
233
- elif mutation == "reply":
234
- child.use_reply = not child.use_reply
235
  elif mutation == "tools":
236
  child.use_tools_early = not child.use_tools_early
237
  elif mutation == "energy":
@@ -262,7 +250,6 @@ def evolutionary_search(
262
  tag_offset=rng.randint(0, len(TAG_POOL) - 1),
263
  topic_offset=rng.randint(0, len(ALL_TOPICS) - 1),
264
  create_hour=rng.choice([None, 7, 8, 9]),
265
- use_reply=rng.random() > 0.5,
266
  use_tools_early=rng.random() > 0.5,
267
  rest_if_low_energy=rng.choice([0.2, 0.25, 0.3, 0.35]),
268
  ) for _ in range(population_size)]
 
100
  ScheduledAction(hour=19, action_type="post", content_type=ct2,
101
  topic=topic2, tags=tags2, intent=intent2),
102
  ],
 
103
  notes=f"Day {day}: varied content at peak hours.",
104
  )
105
 
 
155
  tag_offset: int = 0
156
  topic_offset: int = 0
157
  create_hour: Optional[int] = None
 
158
  use_tools_early: bool = False
159
  rest_if_low_energy: float = 0.3
160
 
 
184
  tool_calls.append(ToolCall(name="query_trends",
185
  arguments={"niche": NICHES[day % len(NICHES)]}))
186
 
 
 
 
 
 
 
187
  return ViraltestAction(
188
  tool_calls=tool_calls,
189
  scheduled_actions=actions,
 
190
  notes=f"Day {day}: policy-driven plan.",
191
  )
192
  return plan_fn
 
199
  tag_offset=self.tag_offset,
200
  topic_offset=self.topic_offset,
201
  create_hour=self.create_hour,
 
202
  use_tools_early=self.use_tools_early,
203
  rest_if_low_energy=self.rest_if_low_energy,
204
  )
205
 
206
  mutation = rng.choice(["hours", "types", "intents", "tags", "topics",
207
+ "create", "tools", "energy", "n_posts"])
208
 
209
  if mutation == "hours":
210
  child.post_hours = sorted(rng.sample(range(6, 23), min(rng.randint(1, 3), 3)))
 
220
  child.topic_offset = rng.randint(0, len(ALL_TOPICS) - 1)
221
  elif mutation == "create":
222
  child.create_hour = rng.choice([None, 7, 8, 9, 10])
 
 
223
  elif mutation == "tools":
224
  child.use_tools_early = not child.use_tools_early
225
  elif mutation == "energy":
 
250
  tag_offset=rng.randint(0, len(TAG_POOL) - 1),
251
  topic_offset=rng.randint(0, len(ALL_TOPICS) - 1),
252
  create_hour=rng.choice([None, 7, 8, 9]),
 
253
  use_tools_early=rng.random() > 0.5,
254
  rest_if_low_energy=rng.choice([0.2, 0.25, 0.3, 0.35]),
255
  ) for _ in range(population_size)]
training/train_grpo.ipynb CHANGED
@@ -301,8 +301,7 @@
301
  " topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n",
302
  " tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n",
303
  " intent=INTENTS[(day*2+1)%4]),\n",
304
- " ],\n",
305
- " replies=[{\"post_hour\": 12, \"reply_hour\": 13}])\n",
306
  "\n",
307
  "BASELINE_AGENTS = {\n",
308
  " \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n",
@@ -570,22 +569,38 @@
570
  "\n",
571
  "RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n",
572
  "{\n",
573
- " \"tool_calls\": [{\"name\": \"query_trends\", \"arguments\": {\"niche\": \"tech\"}}],\n",
574
  " \"scheduled_actions\": [\n",
575
- " {\"hour\": 12, \"action_type\": \"post\", \"content_type\": \"reel\",\n",
576
- " \"topic\": \"AI tools\", \"tags\": [\"ai\", \"coding\"], \"intent\": \"watch_bait\"}\n",
 
 
577
  " ],\n",
578
- " \"replies\": [{\"post_hour\": 12, \"reply_hour\": 13}],\n",
579
  " \"notes\": \"strategy notes\"\n",
580
  "}\n",
581
  "\n",
582
- "RULES:\n",
583
- "- content_type: reel|story|carousel|text_post\n",
584
- "- intent: send_bait|save_bait|watch_bait|like_bait\n",
585
- "- 1-2 posts/day optimal. More = fatigue.\n",
586
- "- Empty scheduled_actions = rest (recovers energy).\n",
587
- "- Vary content types and topics for diversity bonus.\n",
588
- "- Reply within 90 min of post for reach bonus.\"\"\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  "\n",
590
  "\n",
591
  "def format_obs(obs):\n",
@@ -600,7 +615,7 @@
600
  " tool_str = \"\"\n",
601
  " for tr in getattr(obs, \"tool_results\", []):\n",
602
  " if tr.success:\n",
603
- " tool_str += f\" {tr.name}: {json.dumps(tr.data)[:200]}\\n\"\n",
604
  " if not tool_str:\n",
605
  " tool_str = \" (none)\\n\"\n",
606
  " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
@@ -633,7 +648,6 @@
633
  " return ViraltestAction(\n",
634
  " tool_calls=tool_calls,\n",
635
  " scheduled_actions=scheduled,\n",
636
- " replies=data.get(\"replies\", []),\n",
637
  " notes=data.get(\"notes\"),\n",
638
  " )\n",
639
  " except Exception:\n",
@@ -652,10 +666,10 @@
652
  " return torch.device(\"cpu\")\n",
653
  "\n",
654
  "\n",
655
- "def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
656
  " prompt = format_obs(obs)\n",
657
  " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
658
- " messages.extend(history[-4:])\n",
659
  " messages.append({\"role\": \"user\", \"content\": prompt})\n",
660
  " text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
661
  " inputs = tok(text_input, return_tensors=\"pt\").to(_infer_model_device(mdl))\n",
@@ -663,21 +677,27 @@
663
  " out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
664
  " do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
665
  " resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
 
 
 
 
 
 
 
 
666
  " return resp, parse_model_output(resp)\n",
667
  "\n",
668
  "\n",
669
- "def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
670
  " env = ViraltestEnvironment()\n",
671
  " obs = env.reset(task=task, seed=seed)\n",
672
  " rewards, energies = [], [obs.creator_energy]\n",
673
  " history, pairs = [], []\n",
674
  " for day in range(1, TASK_HORIZON + 1):\n",
675
  " if obs.done: break\n",
676
- " if obs.creator_energy <= 0.25:\n",
677
- " action = ViraltestAction(scheduled_actions=[])\n",
678
- " resp = '{\"scheduled_actions\": []}'\n",
679
- " else:\n",
680
- " resp, action = generate_action(mdl, tok, obs, history)\n",
681
  " prompt = format_obs(obs)\n",
682
  " pairs.append({\"prompt\": prompt, \"response\": resp})\n",
683
  " obs = env.step(action)\n",
@@ -691,9 +711,17 @@
691
  " print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
692
  " if obs.done: break\n",
693
  " gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
 
 
 
 
 
 
 
 
694
  " return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
695
  " \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
696
- " \"energies\": energies, \"pairs\": pairs,\n",
697
  " \"follower_delta\": obs.follower_count - 10000,\n",
698
  " \"burned_out\": obs.creator_energy <= 0}\n",
699
  "\n",
@@ -778,8 +806,8 @@
778
  "from trl import SFTTrainer, SFTConfig\n",
779
  "from datasets import Dataset\n",
780
  "\n",
781
- "NUM_ROUNDS = 4\n",
782
- "EPISODES_PER_ROUND = 6\n",
783
  "TOP_K_FRACTION = 0.5\n",
784
  "\n",
785
  "training_log = {\n",
@@ -811,19 +839,21 @@
811
  " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
812
  " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
813
  " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
814
- " all_pairs.append({\"text\": text, \"reward\": ep_reward})\n",
815
  "\n",
 
816
  " print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
817
- " f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f}\")\n",
 
818
  "\n",
819
  " avg_r = np.mean(episode_rewards)\n",
820
  " avg_g = np.mean(episode_graders)\n",
821
  " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
822
  "\n",
823
- " # Filter to top-K\n",
824
  " threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
825
  " filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
826
- " print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples\")\n",
827
  "\n",
828
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
829
  "\n",
@@ -831,14 +861,18 @@
831
  " sft_config = SFTConfig(\n",
832
  " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
833
  " num_train_epochs=2,\n",
834
- " per_device_train_batch_size=1,\n",
835
- " gradient_accumulation_steps=4,\n",
836
  " learning_rate=2e-5,\n",
837
- " warmup_steps=5,\n",
838
- " logging_steps=5,\n",
839
  " save_strategy=\"no\",\n",
840
  " max_length=1024,\n",
841
- " fp16=True,\n",
 
 
 
 
842
  " report_to=\"none\",\n",
843
  " )\n",
844
  "\n",
@@ -1082,7 +1116,7 @@
1082
  "name": "python",
1083
  "nbconvert_exporter": "python",
1084
  "pygments_lexer": "ipython3",
1085
- "version": "3.14.2"
1086
  }
1087
  },
1088
  "nbformat": 4,
 
301
  " topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n",
302
  " tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n",
303
  " intent=INTENTS[(day*2+1)%4]),\n",
304
+ " ])\n",
 
305
  "\n",
306
  "BASELINE_AGENTS = {\n",
307
  " \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n",
 
569
  "\n",
570
  "RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n",
571
  "{\n",
572
+ " \"tool_calls\": [{\"name\": \"<tool>\", \"arguments\": {...}}],\n",
573
  " \"scheduled_actions\": [\n",
574
+ " {\"hour\": 0-23, \"action_type\": \"post|create_content\",\n",
575
+ " \"content_type\": \"reel|story|carousel|text_post\",\n",
576
+ " \"topic\": \"<string>\", \"tags\": [\"...\"],\n",
577
+ " \"intent\": \"send_bait|save_bait|watch_bait|like_bait\"}\n",
578
  " ],\n",
 
579
  " \"notes\": \"strategy notes\"\n",
580
  "}\n",
581
  "\n",
582
+ "TOOLS (cost in API budget, total=100):\n",
583
+ "- query_trends(niche) cost=1 trending topics+tags for niche\n",
584
+ "- query_audience(segment_id) cost=2 segment topic affinities + active hours\n",
585
+ "- query_competitor(competitor_id, window_days) cost=2 competitor recent posts\n",
586
+ "- query_tag_history(tag) cost=1 your past signals (watch/sends/saves/likes) for a tag\n",
587
+ "- predict_engagement(scheduled_actions) cost=3 simulate a plan WITHOUT committing\n",
588
+ "- draft_review(scheduled_actions) cost=3 AI review of a draft plan\n",
589
+ "- query_creator_pool() cost=1 list collab partners with audience overlap\n",
590
+ "- propose_collab(partner_id, content_type, hour) cost=5 co-author the post at that hour (max 2/month)\n",
591
+ "\n",
592
+ "ACTION SCHEMA:\n",
593
+ "- hour: 0..23 (unlisted hours = rest)\n",
594
+ "- action_type: post (publish) | create_content (build queue, no publish)\n",
595
+ "- content_type: reel | story | carousel | text_post\n",
596
+ "- intent: which Mosseri signal the post optimises for\n",
597
+ " send_bait -> DM shares (strongest discovery signal)\n",
598
+ " save_bait -> bookmarks (content quality)\n",
599
+ " watch_bait -> reels watch time\n",
600
+ " like_bait -> likes from existing followers\n",
601
+ "- tags: up to 5 hashtags\n",
602
+ "- topic: free-form string\n",
603
+ "- empty scheduled_actions = full day rest\"\"\")\n",
604
  "\n",
605
  "\n",
606
  "def format_obs(obs):\n",
 
615
  " tool_str = \"\"\n",
616
  " for tr in getattr(obs, \"tool_results\", []):\n",
617
  " if tr.success:\n",
618
+ " tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
619
  " if not tool_str:\n",
620
  " tool_str = \" (none)\\n\"\n",
621
  " return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
 
648
  " return ViraltestAction(\n",
649
  " tool_calls=tool_calls,\n",
650
  " scheduled_actions=scheduled,\n",
 
651
  " notes=data.get(\"notes\"),\n",
652
  " )\n",
653
  " except Exception:\n",
 
666
  " return torch.device(\"cpu\")\n",
667
  "\n",
668
  "\n",
669
+ "def generate_action(mdl, tok, obs, history, temperature=0.7, debug=True):\n",
670
  " prompt = format_obs(obs)\n",
671
  " messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
672
+ " messages.extend(history[-14:])\n",
673
  " messages.append({\"role\": \"user\", \"content\": prompt})\n",
674
  " text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
675
  " inputs = tok(text_input, return_tensors=\"pt\").to(_infer_model_device(mdl))\n",
 
677
  " out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
678
  " do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
679
  " resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
680
+ " if debug:\n",
681
+ " print(\"=\" * 60)\n",
682
+ " print(f\"[LLM PROMPT] tokens={inputs['input_ids'].shape[1]}\")\n",
683
+ " print(prompt)\n",
684
+ " print(\"-\" * 60)\n",
685
+ " print(f\"[LLM RESPONSE] tokens={out.shape[1] - inputs['input_ids'].shape[1]}\")\n",
686
+ " print(resp)\n",
687
+ " print(\"=\" * 60)\n",
688
  " return resp, parse_model_output(resp)\n",
689
  "\n",
690
  "\n",
691
+ "def run_llm_episode(mdl, tok, task, seed=42, verbose=False, debug_llm=True):\n",
692
  " env = ViraltestEnvironment()\n",
693
  " obs = env.reset(task=task, seed=seed)\n",
694
  " rewards, energies = [], [obs.creator_energy]\n",
695
  " history, pairs = [], []\n",
696
  " for day in range(1, TASK_HORIZON + 1):\n",
697
  " if obs.done: break\n",
698
+ " if debug_llm:\n",
699
+ " print(f\"\\n>>> Day {day} | task={task} | energy={obs.creator_energy:.2f}\")\n",
700
+ " resp, action = generate_action(mdl, tok, obs, history, debug=debug_llm)\n",
 
 
701
  " prompt = format_obs(obs)\n",
702
  " pairs.append({\"prompt\": prompt, \"response\": resp})\n",
703
  " obs = env.step(action)\n",
 
711
  " print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
712
  " if obs.done: break\n",
713
  " gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
714
+ " # Per-step credit assignment: G_t = r_t + gamma * G_{t+1}, terminal = grader_score * w\n",
715
+ " GAMMA, TERMINAL_W = 0.95, 5.0\n",
716
+ " G, returns = gs * TERMINAL_W, [0.0] * len(rewards)\n",
717
+ " for t in reversed(range(len(rewards))):\n",
718
+ " G = rewards[t] + GAMMA * G\n",
719
+ " returns[t] = G\n",
720
+ " for i, pr in enumerate(pairs):\n",
721
+ " pr[\"return\"] = returns[i] if i < len(returns) else 0.0\n",
722
  " return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
723
  " \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
724
+ " \"returns\": returns, \"energies\": energies, \"pairs\": pairs,\n",
725
  " \"follower_delta\": obs.follower_count - 10000,\n",
726
  " \"burned_out\": obs.creator_energy <= 0}\n",
727
  "\n",
 
806
  "from trl import SFTTrainer, SFTConfig\n",
807
  "from datasets import Dataset\n",
808
  "\n",
809
+ "NUM_ROUNDS = 1\n",
810
+ "EPISODES_PER_ROUND = 1\n",
811
  "TOP_K_FRACTION = 0.5\n",
812
  "\n",
813
  "training_log = {\n",
 
839
  " text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
840
  " f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
841
  " f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
842
+ " all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
843
  "\n",
844
+ " rets = result[\"returns\"]\n",
845
  " print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
846
+ " f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} \"\n",
847
+ " f\"return[min={min(rets):.2f} max={max(rets):.2f} mean={np.mean(rets):.2f}]\")\n",
848
  "\n",
849
  " avg_r = np.mean(episode_rewards)\n",
850
  " avg_g = np.mean(episode_graders)\n",
851
  " print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
852
  "\n",
853
+ " # Filter to top-K by per-pair return (per-step credit assignment)\n",
854
  " threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
855
  " filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
856
+ " print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples (return >= {threshold:.3f})\")\n",
857
  "\n",
858
  " dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
859
  "\n",
 
861
  " sft_config = SFTConfig(\n",
862
  " output_dir=f\"./checkpoints/round_{round_idx}\",\n",
863
  " num_train_epochs=2,\n",
864
+ " per_device_train_batch_size=32,\n",
865
+ " gradient_accumulation_steps=1,\n",
866
  " learning_rate=2e-5,\n",
867
+ " warmup_ratio=0.1,\n",
868
+ " logging_steps=1,\n",
869
  " save_strategy=\"no\",\n",
870
  " max_length=1024,\n",
871
+ " bf16=True,\n",
872
+ " gradient_checkpointing=False,\n",
873
+ " dataloader_num_workers=4,\n",
874
+ " dataloader_pin_memory=True,\n",
875
+ " optim=\"adamw_torch_fused\",\n",
876
  " report_to=\"none\",\n",
877
  " )\n",
878
  "\n",
 
1116
  "name": "python",
1117
  "nbconvert_exporter": "python",
1118
  "pygments_lexer": "ipython3",
1119
+ "version": "3.13.1"
1120
  }
1121
  },
1122
  "nbformat": 4,