Spaces:
Paused
train: per-step credit + drop replies + larger batches
Browse files
- run_llm_episode now computes Monte-Carlo return-to-go per day
(G_t = r_t + 0.95*G_{t+1}, seeded with a terminal value of grader_score*5);
each (prompt, response) pair gets its own return for top-K filtering,
removing the same-reward-per-pair caveat.
- SFTConfig: per_device_train_batch_size 16->32, bf16, fused AdamW,
warmup_ratio=0.1, gradient checkpointing disabled to target ~90% VRAM use on 48GB.
- SYSTEM_PROMPT rewritten as full tool/action schema; removed
subjective rules (optimal posts, diversity bonus); reach-bonus
clarified.
- format_obs no longer truncates tool results; history window 4->14
messages (7 days); removed hard-coded auto-rest heuristic.
- Drop replies feature across env, models, client, inference, and
training scripts.
- NUM_ROUNDS/EPISODES_PER_ROUND set to 1 for smoke run.
Made-with: Cursor
- __init__.py +0 -2
- client.py +0 -6
- inference.py +0 -5
- models.py +0 -11
- server/viraltest_environment.py +0 -17
- training/hf_run_space_train_job.sh +3 -4
- training/run_llm_training.py +2 -7
- training/run_training_evidence.py +1 -14
- training/train_grpo.ipynb +70 -36
|
@@ -10,7 +10,6 @@ from .client import ViraltestEnv
|
|
| 10 |
from .models import (
|
| 11 |
CollabProposal,
|
| 12 |
EngagementSignals,
|
| 13 |
-
ReplyAction,
|
| 14 |
ScheduledAction,
|
| 15 |
ToolCall,
|
| 16 |
ToolResult,
|
|
@@ -21,7 +20,6 @@ from .models import (
|
|
| 21 |
__all__ = [
|
| 22 |
"CollabProposal",
|
| 23 |
"EngagementSignals",
|
| 24 |
-
"ReplyAction",
|
| 25 |
"ScheduledAction",
|
| 26 |
"ToolCall",
|
| 27 |
"ToolResult",
|
|
|
|
| 10 |
from .models import (
|
| 11 |
CollabProposal,
|
| 12 |
EngagementSignals,
|
|
|
|
| 13 |
ScheduledAction,
|
| 14 |
ToolCall,
|
| 15 |
ToolResult,
|
|
|
|
| 20 |
__all__ = [
|
| 21 |
"CollabProposal",
|
| 22 |
"EngagementSignals",
|
|
|
|
| 23 |
"ScheduledAction",
|
| 24 |
"ToolCall",
|
| 25 |
"ToolResult",
|
|
@@ -43,12 +43,6 @@ class ViraltestEnv(EnvClient[ViraltestAction, ViraltestObservation, State]):
|
|
| 43 |
actions_list.append(item)
|
| 44 |
payload["scheduled_actions"] = actions_list
|
| 45 |
|
| 46 |
-
if action.replies:
|
| 47 |
-
payload["replies"] = [
|
| 48 |
-
{"post_hour": r.post_hour, "reply_hour": r.reply_hour}
|
| 49 |
-
for r in action.replies
|
| 50 |
-
]
|
| 51 |
-
|
| 52 |
if action.collab:
|
| 53 |
payload["collab"] = {
|
| 54 |
"partner_id": action.collab.partner_id,
|
|
|
|
| 43 |
actions_list.append(item)
|
| 44 |
payload["scheduled_actions"] = actions_list
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
if action.collab:
|
| 47 |
payload["collab"] = {
|
| 48 |
"partner_id": action.collab.partner_id,
|
|
@@ -74,7 +74,6 @@ RESPONSE FORMAT (JSON only, no markdown, no prose):
|
|
| 74 |
{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"},
|
| 75 |
{"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup life", "tags": ["startup", "growth"], "intent": "save_bait"}
|
| 76 |
],
|
| 77 |
-
"replies": [{"post_hour": 12, "reply_hour": 13}],
|
| 78 |
"notes": "Day 3: tech niche trending up. Competitor Alpha posted at 10am. Avoiding overlap."
|
| 79 |
}
|
| 80 |
|
|
@@ -87,7 +86,6 @@ RULES:
|
|
| 87 |
- Use notes to track hypotheses and observations across days
|
| 88 |
- Tool calls cost API budget (starts at 100). Use wisely.
|
| 89 |
- Max 2 collaborations per month
|
| 90 |
-
- Reply within 90 minutes of a post for reach bonus
|
| 91 |
|
| 92 |
Think strategically: use tools to discover what works, then exploit what you learn.""")
|
| 93 |
|
|
@@ -201,13 +199,11 @@ def parse_daily_plan(response_text: str) -> ViraltestAction:
|
|
| 201 |
if isinstance(a, dict):
|
| 202 |
scheduled.append(a)
|
| 203 |
|
| 204 |
-
replies_raw = data.get("replies", [])
|
| 205 |
notes = data.get("notes")
|
| 206 |
|
| 207 |
return ViraltestAction(
|
| 208 |
tool_calls=tool_calls,
|
| 209 |
scheduled_actions=scheduled,
|
| 210 |
-
replies=replies_raw if isinstance(replies_raw, list) else [],
|
| 211 |
notes=notes,
|
| 212 |
)
|
| 213 |
except (json.JSONDecodeError, Exception):
|
|
@@ -236,7 +232,6 @@ def sanitize_predefined_topics(action: ViraltestAction, obs: Any) -> ViraltestAc
|
|
| 236 |
return ViraltestAction(
|
| 237 |
tool_calls=action.tool_calls,
|
| 238 |
scheduled_actions=out,
|
| 239 |
-
replies=action.replies,
|
| 240 |
collab=action.collab,
|
| 241 |
notes=action.notes,
|
| 242 |
)
|
|
|
|
| 74 |
{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"},
|
| 75 |
{"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup life", "tags": ["startup", "growth"], "intent": "save_bait"}
|
| 76 |
],
|
|
|
|
| 77 |
"notes": "Day 3: tech niche trending up. Competitor Alpha posted at 10am. Avoiding overlap."
|
| 78 |
}
|
| 79 |
|
|
|
|
| 86 |
- Use notes to track hypotheses and observations across days
|
| 87 |
- Tool calls cost API budget (starts at 100). Use wisely.
|
| 88 |
- Max 2 collaborations per month
|
|
|
|
| 89 |
|
| 90 |
Think strategically: use tools to discover what works, then exploit what you learn.""")
|
| 91 |
|
|
|
|
| 199 |
if isinstance(a, dict):
|
| 200 |
scheduled.append(a)
|
| 201 |
|
|
|
|
| 202 |
notes = data.get("notes")
|
| 203 |
|
| 204 |
return ViraltestAction(
|
| 205 |
tool_calls=tool_calls,
|
| 206 |
scheduled_actions=scheduled,
|
|
|
|
| 207 |
notes=notes,
|
| 208 |
)
|
| 209 |
except (json.JSONDecodeError, Exception):
|
|
|
|
| 232 |
return ViraltestAction(
|
| 233 |
tool_calls=action.tool_calls,
|
| 234 |
scheduled_actions=out,
|
|
|
|
| 235 |
collab=action.collab,
|
| 236 |
notes=action.notes,
|
| 237 |
)
|
|
@@ -56,13 +56,6 @@ class ScheduledAction(BaseModel):
|
|
| 56 |
return v
|
| 57 |
|
| 58 |
|
| 59 |
-
class ReplyAction(BaseModel):
|
| 60 |
-
"""Reply to comments on a post made earlier today (within reply window)."""
|
| 61 |
-
|
| 62 |
-
post_hour: int = Field(..., ge=0, le=23, description="Hour of the post to reply on")
|
| 63 |
-
reply_hour: int = Field(..., ge=0, le=23, description="Hour to send replies")
|
| 64 |
-
|
| 65 |
-
|
| 66 |
class CollabProposal(BaseModel):
|
| 67 |
"""Propose a collaboration with a competitor archetype."""
|
| 68 |
|
|
@@ -82,10 +75,6 @@ class ViraltestAction(Action):
|
|
| 82 |
default_factory=list,
|
| 83 |
description="Actions scheduled at specific hours; unlisted hours are rest",
|
| 84 |
)
|
| 85 |
-
replies: List[ReplyAction] = Field(
|
| 86 |
-
default_factory=list,
|
| 87 |
-
description="Reply actions on posts made today (within 90-min window for reach bonus)",
|
| 88 |
-
)
|
| 89 |
collab: Optional[CollabProposal] = Field(
|
| 90 |
default=None,
|
| 91 |
description="Optional collaboration proposal (max 2 per month)",
|
|
|
|
| 56 |
return v
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
class CollabProposal(BaseModel):
|
| 60 |
"""Propose a collaboration with a competitor archetype."""
|
| 61 |
|
|
|
|
| 75 |
default_factory=list,
|
| 76 |
description="Actions scheduled at specific hours; unlisted hours are rest",
|
| 77 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
collab: Optional[CollabProposal] = Field(
|
| 79 |
default=None,
|
| 80 |
description="Optional collaboration proposal (max 2 per month)",
|
|
@@ -29,7 +29,6 @@ try:
|
|
| 29 |
EngagementSignals,
|
| 30 |
HeadlineMetrics,
|
| 31 |
JudgeReport,
|
| 32 |
-
ReplyAction,
|
| 33 |
ScheduledAction,
|
| 34 |
ToolCall,
|
| 35 |
ToolResult,
|
|
@@ -42,7 +41,6 @@ except ImportError:
|
|
| 42 |
EngagementSignals,
|
| 43 |
HeadlineMetrics,
|
| 44 |
JudgeReport,
|
| 45 |
-
ReplyAction,
|
| 46 |
ScheduledAction,
|
| 47 |
ToolCall,
|
| 48 |
ToolResult,
|
|
@@ -168,8 +166,6 @@ COLLAB_GROWTH_K = 1.50 # cross-pollination follower spillover, scales (1 - o
|
|
| 168 |
COLLAB_PARTNER_REPEAT_PENALTY = 0.7 # discount on multipliers when partner reused this brand
|
| 169 |
COLLAB_FATIGUE_K = 0.3 # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
|
| 170 |
|
| 171 |
-
REPLY_WINDOW_MINUTES = 90
|
| 172 |
-
REPLY_REACH_BONUS = 1.4
|
| 173 |
API_BUDGET_INITIAL = 100
|
| 174 |
|
| 175 |
# Heuristic baselines for headline metric `vs_baseline_pct`.
|
|
@@ -847,19 +843,6 @@ class ViraltestEnvironment(Environment):
|
|
| 847 |
if self._energy <= 0.0:
|
| 848 |
burned_out = True
|
| 849 |
|
| 850 |
-
# Process replies
|
| 851 |
-
for reply in action.replies:
|
| 852 |
-
if 0 <= reply.reply_hour < 24 and 0 <= reply.post_hour < 24:
|
| 853 |
-
diff_minutes = abs(reply.reply_hour - reply.post_hour) * 60
|
| 854 |
-
if diff_minutes <= REPLY_WINDOW_MINUTES:
|
| 855 |
-
daily_engagement *= REPLY_REACH_BONUS
|
| 856 |
-
daily_signals = EngagementSignals(
|
| 857 |
-
watch_time=daily_signals.watch_time * REPLY_REACH_BONUS,
|
| 858 |
-
sends_per_reach=daily_signals.sends_per_reach * REPLY_REACH_BONUS,
|
| 859 |
-
saves=daily_signals.saves * REPLY_REACH_BONUS,
|
| 860 |
-
likes_per_reach=daily_signals.likes_per_reach * REPLY_REACH_BONUS,
|
| 861 |
-
)
|
| 862 |
-
|
| 863 |
# Weekly tracking
|
| 864 |
self._total_posts_this_week += daily_posts
|
| 865 |
if self._day % 7 == 0 and self._day > 0:
|
|
|
|
| 29 |
EngagementSignals,
|
| 30 |
HeadlineMetrics,
|
| 31 |
JudgeReport,
|
|
|
|
| 32 |
ScheduledAction,
|
| 33 |
ToolCall,
|
| 34 |
ToolResult,
|
|
|
|
| 41 |
EngagementSignals,
|
| 42 |
HeadlineMetrics,
|
| 43 |
JudgeReport,
|
|
|
|
| 44 |
ScheduledAction,
|
| 45 |
ToolCall,
|
| 46 |
ToolResult,
|
|
|
|
| 166 |
COLLAB_PARTNER_REPEAT_PENALTY = 0.7 # discount on multipliers when partner reused this brand
|
| 167 |
COLLAB_FATIGUE_K = 0.3 # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
|
| 168 |
|
|
|
|
|
|
|
| 169 |
API_BUDGET_INITIAL = 100
|
| 170 |
|
| 171 |
# Heuristic baselines for headline metric `vs_baseline_pct`.
|
|
|
|
| 843 |
if self._energy <= 0.0:
|
| 844 |
burned_out = True
|
| 845 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 846 |
# Weekly tracking
|
| 847 |
self._total_posts_this_week += daily_posts
|
| 848 |
if self._day % 7 == 0 and self._day > 0:
|
|
@@ -22,13 +22,12 @@ REMOTE_SCRIPT=$(cat <<'EOS'
|
|
| 22 |
set -euo pipefail
|
| 23 |
export DEBIAN_FRONTEND=noninteractive
|
| 24 |
apt-get update -qq && apt-get install -y --no-install-recommends git curl ca-certificates
|
| 25 |
-
pip install -q --root-user-action=ignore --upgrade "typing_extensions>=4.15.0" jupyter nbconvert nbclient ipykernel huggingface_hub
|
| 26 |
rm -rf /work
|
| 27 |
git clone --depth 1 "https://user:${HF_TOKEN}@huggingface.co/spaces/${SPACE_REPO}" /work
|
| 28 |
cd /work
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
--ExecutePreprocessor.timeout="${NB_EXEC_TIMEOUT}"
|
| 32 |
python -c "import os; from huggingface_hub import HfApi; HfApi().upload_folder(folder_path='.', path_in_repo='run-output', repo_id=os.environ['SPACE_REPO'], repo_type='space', allow_patterns=['training/train_grpo.executed.ipynb','plots/**','**/lora-*/**'])"
|
| 33 |
EOS
|
| 34 |
)
|
|
|
|
| 22 |
set -euo pipefail
|
| 23 |
export DEBIAN_FRONTEND=noninteractive
|
| 24 |
apt-get update -qq && apt-get install -y --no-install-recommends git curl ca-certificates
|
| 25 |
+
pip install -q --root-user-action=ignore --upgrade "typing_extensions>=4.15.0" jupyter nbconvert nbclient ipykernel huggingface_hub papermill
|
| 26 |
rm -rf /work
|
| 27 |
git clone --depth 1 "https://user:${HF_TOKEN}@huggingface.co/spaces/${SPACE_REPO}" /work
|
| 28 |
cd /work
|
| 29 |
+
papermill --log-output --progress-bar --execution-timeout "${NB_EXEC_TIMEOUT}" \
|
| 30 |
+
training/train_grpo.ipynb training/train_grpo.executed.ipynb
|
|
|
|
| 31 |
python -c "import os; from huggingface_hub import HfApi; HfApi().upload_folder(folder_path='.', path_in_repo='run-output', repo_id=os.environ['SPACE_REPO'], repo_type='space', allow_patterns=['training/train_grpo.executed.ipynb','plots/**','**/lora-*/**'])"
|
| 32 |
EOS
|
| 33 |
)
|
|
@@ -106,7 +106,6 @@ def plan_smart(obs_dict, day):
|
|
| 106 |
ScheduledAction(hour=19, action_type="post", content_type=ct2,
|
| 107 |
topic=topic2, tags=tags2, intent=intent2),
|
| 108 |
],
|
| 109 |
-
replies=[{"post_hour": 12, "reply_hour": 13}],
|
| 110 |
)
|
| 111 |
|
| 112 |
BASELINE_AGENTS = {
|
|
@@ -157,17 +156,13 @@ RESPONSE FORMAT — return ONLY valid JSON, no markdown, no explanation:
|
|
| 157 |
"scheduled_actions": [
|
| 158 |
{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"}
|
| 159 |
],
|
| 160 |
-
"replies": [{"post_hour": 12, "reply_hour": 13}],
|
| 161 |
"notes": "strategy notes"
|
| 162 |
}
|
| 163 |
|
| 164 |
RULES:
|
| 165 |
- hour: 0-23. content_type: reel|story|carousel|text_post
|
| 166 |
- intent: send_bait|save_bait|watch_bait|like_bait
|
| 167 |
-
-
|
| 168 |
-
- Empty scheduled_actions = rest (recovers energy).
|
| 169 |
-
- Vary content types and topics across days for diversity bonus.
|
| 170 |
-
- Reply within 90 min of a post for reach bonus.""")
|
| 171 |
|
| 172 |
LEARNED_ADDENDUM = """
|
| 173 |
|
|
@@ -253,7 +248,7 @@ def parse_model_output(text):
|
|
| 253 |
pass
|
| 254 |
return ViraltestAction(
|
| 255 |
tool_calls=tool_calls, scheduled_actions=scheduled,
|
| 256 |
-
|
| 257 |
)
|
| 258 |
except (json.JSONDecodeError, Exception):
|
| 259 |
return ViraltestAction(scheduled_actions=[])
|
|
|
|
| 106 |
ScheduledAction(hour=19, action_type="post", content_type=ct2,
|
| 107 |
topic=topic2, tags=tags2, intent=intent2),
|
| 108 |
],
|
|
|
|
| 109 |
)
|
| 110 |
|
| 111 |
BASELINE_AGENTS = {
|
|
|
|
| 156 |
"scheduled_actions": [
|
| 157 |
{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"], "intent": "watch_bait"}
|
| 158 |
],
|
|
|
|
| 159 |
"notes": "strategy notes"
|
| 160 |
}
|
| 161 |
|
| 162 |
RULES:
|
| 163 |
- hour: 0-23. content_type: reel|story|carousel|text_post
|
| 164 |
- intent: send_bait|save_bait|watch_bait|like_bait
|
| 165 |
+
- Empty scheduled_actions = rest (recovers energy).""")
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
LEARNED_ADDENDUM = """
|
| 168 |
|
|
|
|
| 248 |
pass
|
| 249 |
return ViraltestAction(
|
| 250 |
tool_calls=tool_calls, scheduled_actions=scheduled,
|
| 251 |
+
notes=data.get("notes"),
|
| 252 |
)
|
| 253 |
except (json.JSONDecodeError, Exception):
|
| 254 |
return ViraltestAction(scheduled_actions=[])
|
|
@@ -100,7 +100,6 @@ def plan_smart(obs_dict: dict, day: int) -> ViraltestAction:
|
|
| 100 |
ScheduledAction(hour=19, action_type="post", content_type=ct2,
|
| 101 |
topic=topic2, tags=tags2, intent=intent2),
|
| 102 |
],
|
| 103 |
-
replies=[{"post_hour": 12, "reply_hour": 13}],
|
| 104 |
notes=f"Day {day}: varied content at peak hours.",
|
| 105 |
)
|
| 106 |
|
|
@@ -156,7 +155,6 @@ class PostingPolicy:
|
|
| 156 |
tag_offset: int = 0
|
| 157 |
topic_offset: int = 0
|
| 158 |
create_hour: Optional[int] = None
|
| 159 |
-
use_reply: bool = False
|
| 160 |
use_tools_early: bool = False
|
| 161 |
rest_if_low_energy: float = 0.3
|
| 162 |
|
|
@@ -186,16 +184,9 @@ class PostingPolicy:
|
|
| 186 |
tool_calls.append(ToolCall(name="query_trends",
|
| 187 |
arguments={"niche": NICHES[day % len(NICHES)]}))
|
| 188 |
|
| 189 |
-
replies = []
|
| 190 |
-
if policy.use_reply and policy.post_hours:
|
| 191 |
-
first_post = policy.post_hours[0]
|
| 192 |
-
if first_post < 23:
|
| 193 |
-
replies = [{"post_hour": first_post, "reply_hour": first_post + 1}]
|
| 194 |
-
|
| 195 |
return ViraltestAction(
|
| 196 |
tool_calls=tool_calls,
|
| 197 |
scheduled_actions=actions,
|
| 198 |
-
replies=replies,
|
| 199 |
notes=f"Day {day}: policy-driven plan.",
|
| 200 |
)
|
| 201 |
return plan_fn
|
|
@@ -208,13 +199,12 @@ class PostingPolicy:
|
|
| 208 |
tag_offset=self.tag_offset,
|
| 209 |
topic_offset=self.topic_offset,
|
| 210 |
create_hour=self.create_hour,
|
| 211 |
-
use_reply=self.use_reply,
|
| 212 |
use_tools_early=self.use_tools_early,
|
| 213 |
rest_if_low_energy=self.rest_if_low_energy,
|
| 214 |
)
|
| 215 |
|
| 216 |
mutation = rng.choice(["hours", "types", "intents", "tags", "topics",
|
| 217 |
-
"create", "
|
| 218 |
|
| 219 |
if mutation == "hours":
|
| 220 |
child.post_hours = sorted(rng.sample(range(6, 23), min(rng.randint(1, 3), 3)))
|
|
@@ -230,8 +220,6 @@ class PostingPolicy:
|
|
| 230 |
child.topic_offset = rng.randint(0, len(ALL_TOPICS) - 1)
|
| 231 |
elif mutation == "create":
|
| 232 |
child.create_hour = rng.choice([None, 7, 8, 9, 10])
|
| 233 |
-
elif mutation == "reply":
|
| 234 |
-
child.use_reply = not child.use_reply
|
| 235 |
elif mutation == "tools":
|
| 236 |
child.use_tools_early = not child.use_tools_early
|
| 237 |
elif mutation == "energy":
|
|
@@ -262,7 +250,6 @@ def evolutionary_search(
|
|
| 262 |
tag_offset=rng.randint(0, len(TAG_POOL) - 1),
|
| 263 |
topic_offset=rng.randint(0, len(ALL_TOPICS) - 1),
|
| 264 |
create_hour=rng.choice([None, 7, 8, 9]),
|
| 265 |
-
use_reply=rng.random() > 0.5,
|
| 266 |
use_tools_early=rng.random() > 0.5,
|
| 267 |
rest_if_low_energy=rng.choice([0.2, 0.25, 0.3, 0.35]),
|
| 268 |
) for _ in range(population_size)]
|
|
|
|
| 100 |
ScheduledAction(hour=19, action_type="post", content_type=ct2,
|
| 101 |
topic=topic2, tags=tags2, intent=intent2),
|
| 102 |
],
|
|
|
|
| 103 |
notes=f"Day {day}: varied content at peak hours.",
|
| 104 |
)
|
| 105 |
|
|
|
|
| 155 |
tag_offset: int = 0
|
| 156 |
topic_offset: int = 0
|
| 157 |
create_hour: Optional[int] = None
|
|
|
|
| 158 |
use_tools_early: bool = False
|
| 159 |
rest_if_low_energy: float = 0.3
|
| 160 |
|
|
|
|
| 184 |
tool_calls.append(ToolCall(name="query_trends",
|
| 185 |
arguments={"niche": NICHES[day % len(NICHES)]}))
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
return ViraltestAction(
|
| 188 |
tool_calls=tool_calls,
|
| 189 |
scheduled_actions=actions,
|
|
|
|
| 190 |
notes=f"Day {day}: policy-driven plan.",
|
| 191 |
)
|
| 192 |
return plan_fn
|
|
|
|
| 199 |
tag_offset=self.tag_offset,
|
| 200 |
topic_offset=self.topic_offset,
|
| 201 |
create_hour=self.create_hour,
|
|
|
|
| 202 |
use_tools_early=self.use_tools_early,
|
| 203 |
rest_if_low_energy=self.rest_if_low_energy,
|
| 204 |
)
|
| 205 |
|
| 206 |
mutation = rng.choice(["hours", "types", "intents", "tags", "topics",
|
| 207 |
+
"create", "tools", "energy", "n_posts"])
|
| 208 |
|
| 209 |
if mutation == "hours":
|
| 210 |
child.post_hours = sorted(rng.sample(range(6, 23), min(rng.randint(1, 3), 3)))
|
|
|
|
| 220 |
child.topic_offset = rng.randint(0, len(ALL_TOPICS) - 1)
|
| 221 |
elif mutation == "create":
|
| 222 |
child.create_hour = rng.choice([None, 7, 8, 9, 10])
|
|
|
|
|
|
|
| 223 |
elif mutation == "tools":
|
| 224 |
child.use_tools_early = not child.use_tools_early
|
| 225 |
elif mutation == "energy":
|
|
|
|
| 250 |
tag_offset=rng.randint(0, len(TAG_POOL) - 1),
|
| 251 |
topic_offset=rng.randint(0, len(ALL_TOPICS) - 1),
|
| 252 |
create_hour=rng.choice([None, 7, 8, 9]),
|
|
|
|
| 253 |
use_tools_early=rng.random() > 0.5,
|
| 254 |
rest_if_low_energy=rng.choice([0.2, 0.25, 0.3, 0.35]),
|
| 255 |
) for _ in range(population_size)]
|
|
@@ -301,8 +301,7 @@
|
|
| 301 |
" topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n",
|
| 302 |
" tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n",
|
| 303 |
" intent=INTENTS[(day*2+1)%4]),\n",
|
| 304 |
-
" ]
|
| 305 |
-
" replies=[{\"post_hour\": 12, \"reply_hour\": 13}])\n",
|
| 306 |
"\n",
|
| 307 |
"BASELINE_AGENTS = {\n",
|
| 308 |
" \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n",
|
|
@@ -570,22 +569,38 @@
|
|
| 570 |
"\n",
|
| 571 |
"RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n",
|
| 572 |
"{\n",
|
| 573 |
-
" \"tool_calls\": [{\"name\": \"
|
| 574 |
" \"scheduled_actions\": [\n",
|
| 575 |
-
" {\"hour\":
|
| 576 |
-
" \"
|
|
|
|
|
|
|
| 577 |
" ],\n",
|
| 578 |
-
" \"replies\": [{\"post_hour\": 12, \"reply_hour\": 13}],\n",
|
| 579 |
" \"notes\": \"strategy notes\"\n",
|
| 580 |
"}\n",
|
| 581 |
"\n",
|
| 582 |
-
"
|
| 583 |
-
"-
|
| 584 |
-
"-
|
| 585 |
-
"-
|
| 586 |
-
"-
|
| 587 |
-
"-
|
| 588 |
-
"-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
"\n",
|
| 590 |
"\n",
|
| 591 |
"def format_obs(obs):\n",
|
|
@@ -600,7 +615,7 @@
|
|
| 600 |
" tool_str = \"\"\n",
|
| 601 |
" for tr in getattr(obs, \"tool_results\", []):\n",
|
| 602 |
" if tr.success:\n",
|
| 603 |
-
" tool_str += f\" {tr.name}: {json.dumps(tr.data)
|
| 604 |
" if not tool_str:\n",
|
| 605 |
" tool_str = \" (none)\\n\"\n",
|
| 606 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
|
@@ -633,7 +648,6 @@
|
|
| 633 |
" return ViraltestAction(\n",
|
| 634 |
" tool_calls=tool_calls,\n",
|
| 635 |
" scheduled_actions=scheduled,\n",
|
| 636 |
-
" replies=data.get(\"replies\", []),\n",
|
| 637 |
" notes=data.get(\"notes\"),\n",
|
| 638 |
" )\n",
|
| 639 |
" except Exception:\n",
|
|
@@ -652,10 +666,10 @@
|
|
| 652 |
" return torch.device(\"cpu\")\n",
|
| 653 |
"\n",
|
| 654 |
"\n",
|
| 655 |
-
"def generate_action(mdl, tok, obs, history, temperature=0.7):\n",
|
| 656 |
" prompt = format_obs(obs)\n",
|
| 657 |
" messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
|
| 658 |
-
" messages.extend(history[-
|
| 659 |
" messages.append({\"role\": \"user\", \"content\": prompt})\n",
|
| 660 |
" text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 661 |
" inputs = tok(text_input, return_tensors=\"pt\").to(_infer_model_device(mdl))\n",
|
|
@@ -663,21 +677,27 @@
|
|
| 663 |
" out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
|
| 664 |
" do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
|
| 665 |
" resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 666 |
" return resp, parse_model_output(resp)\n",
|
| 667 |
"\n",
|
| 668 |
"\n",
|
| 669 |
-
"def run_llm_episode(mdl, tok, task, seed=42, verbose=False):\n",
|
| 670 |
" env = ViraltestEnvironment()\n",
|
| 671 |
" obs = env.reset(task=task, seed=seed)\n",
|
| 672 |
" rewards, energies = [], [obs.creator_energy]\n",
|
| 673 |
" history, pairs = [], []\n",
|
| 674 |
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 675 |
" if obs.done: break\n",
|
| 676 |
-
" if
|
| 677 |
-
"
|
| 678 |
-
"
|
| 679 |
-
" else:\n",
|
| 680 |
-
" resp, action = generate_action(mdl, tok, obs, history)\n",
|
| 681 |
" prompt = format_obs(obs)\n",
|
| 682 |
" pairs.append({\"prompt\": prompt, \"response\": resp})\n",
|
| 683 |
" obs = env.step(action)\n",
|
|
@@ -691,9 +711,17 @@
|
|
| 691 |
" print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
|
| 692 |
" if obs.done: break\n",
|
| 693 |
" gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
" return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
|
| 695 |
" \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
|
| 696 |
-
" \"energies\": energies, \"pairs\": pairs,\n",
|
| 697 |
" \"follower_delta\": obs.follower_count - 10000,\n",
|
| 698 |
" \"burned_out\": obs.creator_energy <= 0}\n",
|
| 699 |
"\n",
|
|
@@ -778,8 +806,8 @@
|
|
| 778 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 779 |
"from datasets import Dataset\n",
|
| 780 |
"\n",
|
| 781 |
-
"NUM_ROUNDS =
|
| 782 |
-
"EPISODES_PER_ROUND =
|
| 783 |
"TOP_K_FRACTION = 0.5\n",
|
| 784 |
"\n",
|
| 785 |
"training_log = {\n",
|
|
@@ -811,19 +839,21 @@
|
|
| 811 |
" text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
|
| 812 |
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 813 |
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 814 |
-
" all_pairs.append({\"text\": text, \"reward\":
|
| 815 |
"\n",
|
|
|
|
| 816 |
" print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
|
| 817 |
-
" f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f}\"
|
|
|
|
| 818 |
"\n",
|
| 819 |
" avg_r = np.mean(episode_rewards)\n",
|
| 820 |
" avg_g = np.mean(episode_graders)\n",
|
| 821 |
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
|
| 822 |
"\n",
|
| 823 |
-
" # Filter to top-K\n",
|
| 824 |
" threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
|
| 825 |
" filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
|
| 826 |
-
" print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples\")\n",
|
| 827 |
"\n",
|
| 828 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 829 |
"\n",
|
|
@@ -831,14 +861,18 @@
|
|
| 831 |
" sft_config = SFTConfig(\n",
|
| 832 |
" output_dir=f\"./checkpoints/round_{round_idx}\",\n",
|
| 833 |
" num_train_epochs=2,\n",
|
| 834 |
-
" per_device_train_batch_size=
|
| 835 |
-
" gradient_accumulation_steps=
|
| 836 |
" learning_rate=2e-5,\n",
|
| 837 |
-
"
|
| 838 |
-
" logging_steps=
|
| 839 |
" save_strategy=\"no\",\n",
|
| 840 |
" max_length=1024,\n",
|
| 841 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
" report_to=\"none\",\n",
|
| 843 |
" )\n",
|
| 844 |
"\n",
|
|
@@ -1082,7 +1116,7 @@
|
|
| 1082 |
"name": "python",
|
| 1083 |
"nbconvert_exporter": "python",
|
| 1084 |
"pygments_lexer": "ipython3",
|
| 1085 |
-
"version": "3.
|
| 1086 |
}
|
| 1087 |
},
|
| 1088 |
"nbformat": 4,
|
|
|
|
| 301 |
" topic=ALL_TOPICS[(day*2+1)%len(ALL_TOPICS)],\n",
|
| 302 |
" tags=[TAG_POOL[(day*6+3+i)%len(TAG_POOL)] for i in range(3)],\n",
|
| 303 |
" intent=INTENTS[(day*2+1)%4]),\n",
|
| 304 |
+
" ])\n",
|
|
|
|
| 305 |
"\n",
|
| 306 |
"BASELINE_AGENTS = {\n",
|
| 307 |
" \"always_rest\": plan_always_rest, \"spam\": plan_spam,\n",
|
|
|
|
| 569 |
"\n",
|
| 570 |
"RESPONSE FORMAT — return ONLY valid JSON, no markdown:\n",
|
| 571 |
"{\n",
|
| 572 |
+
" \"tool_calls\": [{\"name\": \"<tool>\", \"arguments\": {...}}],\n",
|
| 573 |
" \"scheduled_actions\": [\n",
|
| 574 |
+
" {\"hour\": 0-23, \"action_type\": \"post|create_content\",\n",
|
| 575 |
+
" \"content_type\": \"reel|story|carousel|text_post\",\n",
|
| 576 |
+
" \"topic\": \"<string>\", \"tags\": [\"...\"],\n",
|
| 577 |
+
" \"intent\": \"send_bait|save_bait|watch_bait|like_bait\"}\n",
|
| 578 |
" ],\n",
|
|
|
|
| 579 |
" \"notes\": \"strategy notes\"\n",
|
| 580 |
"}\n",
|
| 581 |
"\n",
|
| 582 |
+
"TOOLS (cost in API budget, total=100):\n",
|
| 583 |
+
"- query_trends(niche) cost=1 trending topics+tags for niche\n",
|
| 584 |
+
"- query_audience(segment_id) cost=2 segment topic affinities + active hours\n",
|
| 585 |
+
"- query_competitor(competitor_id, window_days) cost=2 competitor recent posts\n",
|
| 586 |
+
"- query_tag_history(tag) cost=1 your past signals (watch/sends/saves/likes) for a tag\n",
|
| 587 |
+
"- predict_engagement(scheduled_actions) cost=3 simulate a plan WITHOUT committing\n",
|
| 588 |
+
"- draft_review(scheduled_actions) cost=3 AI review of a draft plan\n",
|
| 589 |
+
"- query_creator_pool() cost=1 list collab partners with audience overlap\n",
|
| 590 |
+
"- propose_collab(partner_id, content_type, hour) cost=5 co-author the post at that hour (max 2/month)\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"ACTION SCHEMA:\n",
|
| 593 |
+
"- hour: 0..23 (unlisted hours = rest)\n",
|
| 594 |
+
"- action_type: post (publish) | create_content (build queue, no publish)\n",
|
| 595 |
+
"- content_type: reel | story | carousel | text_post\n",
|
| 596 |
+
"- intent: which Mosseri signal the post optimises for\n",
|
| 597 |
+
" send_bait -> DM shares (strongest discovery signal)\n",
|
| 598 |
+
" save_bait -> bookmarks (content quality)\n",
|
| 599 |
+
" watch_bait -> reels watch time\n",
|
| 600 |
+
" like_bait -> likes from existing followers\n",
|
| 601 |
+
"- tags: up to 5 hashtags\n",
|
| 602 |
+
"- topic: free-form string\n",
|
| 603 |
+
"- empty scheduled_actions = full day rest\"\"\")\n",
|
| 604 |
"\n",
|
| 605 |
"\n",
|
| 606 |
"def format_obs(obs):\n",
|
|
|
|
| 615 |
" tool_str = \"\"\n",
|
| 616 |
" for tr in getattr(obs, \"tool_results\", []):\n",
|
| 617 |
" if tr.success:\n",
|
| 618 |
+
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 619 |
" if not tool_str:\n",
|
| 620 |
" tool_str = \" (none)\\n\"\n",
|
| 621 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
|
|
|
| 648 |
" return ViraltestAction(\n",
|
| 649 |
" tool_calls=tool_calls,\n",
|
| 650 |
" scheduled_actions=scheduled,\n",
|
|
|
|
| 651 |
" notes=data.get(\"notes\"),\n",
|
| 652 |
" )\n",
|
| 653 |
" except Exception:\n",
|
|
|
|
| 666 |
" return torch.device(\"cpu\")\n",
|
| 667 |
"\n",
|
| 668 |
"\n",
|
| 669 |
+
"def generate_action(mdl, tok, obs, history, temperature=0.7, debug=True):\n",
|
| 670 |
" prompt = format_obs(obs)\n",
|
| 671 |
" messages = [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]\n",
|
| 672 |
+
" messages.extend(history[-14:])\n",
|
| 673 |
" messages.append({\"role\": \"user\", \"content\": prompt})\n",
|
| 674 |
" text_input = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 675 |
" inputs = tok(text_input, return_tensors=\"pt\").to(_infer_model_device(mdl))\n",
|
|
|
|
| 677 |
" out = mdl.generate(**inputs, max_new_tokens=512, temperature=temperature,\n",
|
| 678 |
" do_sample=True, top_p=0.9, pad_token_id=tok.eos_token_id)\n",
|
| 679 |
" resp = tok.decode(out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True)\n",
|
| 680 |
+
" if debug:\n",
|
| 681 |
+
" print(\"=\" * 60)\n",
|
| 682 |
+
" print(f\"[LLM PROMPT] tokens={inputs['input_ids'].shape[1]}\")\n",
|
| 683 |
+
" print(prompt)\n",
|
| 684 |
+
" print(\"-\" * 60)\n",
|
| 685 |
+
" print(f\"[LLM RESPONSE] tokens={out.shape[1] - inputs['input_ids'].shape[1]}\")\n",
|
| 686 |
+
" print(resp)\n",
|
| 687 |
+
" print(\"=\" * 60)\n",
|
| 688 |
" return resp, parse_model_output(resp)\n",
|
| 689 |
"\n",
|
| 690 |
"\n",
|
| 691 |
+
"def run_llm_episode(mdl, tok, task, seed=42, verbose=False, debug_llm=True):\n",
|
| 692 |
" env = ViraltestEnvironment()\n",
|
| 693 |
" obs = env.reset(task=task, seed=seed)\n",
|
| 694 |
" rewards, energies = [], [obs.creator_energy]\n",
|
| 695 |
" history, pairs = [], []\n",
|
| 696 |
" for day in range(1, TASK_HORIZON + 1):\n",
|
| 697 |
" if obs.done: break\n",
|
| 698 |
+
" if debug_llm:\n",
|
| 699 |
+
" print(f\"\\n>>> Day {day} | task={task} | energy={obs.creator_energy:.2f}\")\n",
|
| 700 |
+
" resp, action = generate_action(mdl, tok, obs, history, debug=debug_llm)\n",
|
|
|
|
|
|
|
| 701 |
" prompt = format_obs(obs)\n",
|
| 702 |
" pairs.append({\"prompt\": prompt, \"response\": resp})\n",
|
| 703 |
" obs = env.step(action)\n",
|
|
|
|
| 711 |
" print(f\" Day {day:2d}: r={r:.4f} e={obs.creator_energy:.2f} posts={n_p} tools={len(action.tool_calls)}\")\n",
|
| 712 |
" if obs.done: break\n",
|
| 713 |
" gs = (obs.metadata or {}).get(\"grader_score\", 0.0)\n",
|
| 714 |
+
" # Per-step credit assignment: G_t = r_t + gamma * G_{t+1}, terminal = grader_score * w\n",
|
| 715 |
+
" GAMMA, TERMINAL_W = 0.95, 5.0\n",
|
| 716 |
+
" G, returns = gs * TERMINAL_W, [0.0] * len(rewards)\n",
|
| 717 |
+
" for t in reversed(range(len(rewards))):\n",
|
| 718 |
+
" G = rewards[t] + GAMMA * G\n",
|
| 719 |
+
" returns[t] = G\n",
|
| 720 |
+
" for i, pr in enumerate(pairs):\n",
|
| 721 |
+
" pr[\"return\"] = returns[i] if i < len(returns) else 0.0\n",
|
| 722 |
" return {\"task\": task, \"grader_score\": gs, \"total_reward\": sum(rewards),\n",
|
| 723 |
" \"final_energy\": obs.creator_energy, \"rewards\": rewards,\n",
|
| 724 |
+
" \"returns\": returns, \"energies\": energies, \"pairs\": pairs,\n",
|
| 725 |
" \"follower_delta\": obs.follower_count - 10000,\n",
|
| 726 |
" \"burned_out\": obs.creator_energy <= 0}\n",
|
| 727 |
"\n",
|
|
|
|
| 806 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 807 |
"from datasets import Dataset\n",
|
| 808 |
"\n",
|
| 809 |
+
"NUM_ROUNDS = 1\n",
|
| 810 |
+
"EPISODES_PER_ROUND = 1\n",
|
| 811 |
"TOP_K_FRACTION = 0.5\n",
|
| 812 |
"\n",
|
| 813 |
"training_log = {\n",
|
|
|
|
| 839 |
" text = (f\"<|im_start|>system\\n{SYSTEM_PROMPT}<|im_end|>\\n\"\n",
|
| 840 |
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 841 |
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 842 |
+
" all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
|
| 843 |
"\n",
|
| 844 |
+
" rets = result[\"returns\"]\n",
|
| 845 |
" print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {task.split('_')[-1]:>11s} \"\n",
|
| 846 |
+
" f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} \"\n",
|
| 847 |
+
" f\"return[min={min(rets):.2f} max={max(rets):.2f} mean={np.mean(rets):.2f}]\")\n",
|
| 848 |
"\n",
|
| 849 |
" avg_r = np.mean(episode_rewards)\n",
|
| 850 |
" avg_g = np.mean(episode_graders)\n",
|
| 851 |
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f}\")\n",
|
| 852 |
"\n",
|
| 853 |
+
" # Filter to top-K by per-pair return (per-step credit assignment)\n",
|
| 854 |
" threshold = np.percentile([p[\"reward\"] for p in all_pairs], (1 - TOP_K_FRACTION) * 100)\n",
|
| 855 |
" filtered = [p for p in all_pairs if p[\"reward\"] >= threshold] or all_pairs\n",
|
| 856 |
+
" print(f\" Filtered to {len(filtered)}/{len(all_pairs)} samples (return >= {threshold:.3f})\")\n",
|
| 857 |
"\n",
|
| 858 |
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 859 |
"\n",
|
|
|
|
| 861 |
" sft_config = SFTConfig(\n",
|
| 862 |
" output_dir=f\"./checkpoints/round_{round_idx}\",\n",
|
| 863 |
" num_train_epochs=2,\n",
|
| 864 |
+
" per_device_train_batch_size=32,\n",
|
| 865 |
+
" gradient_accumulation_steps=1,\n",
|
| 866 |
" learning_rate=2e-5,\n",
|
| 867 |
+
" warmup_ratio=0.1,\n",
|
| 868 |
+
" logging_steps=1,\n",
|
| 869 |
" save_strategy=\"no\",\n",
|
| 870 |
" max_length=1024,\n",
|
| 871 |
+
" bf16=True,\n",
|
| 872 |
+
" gradient_checkpointing=False,\n",
|
| 873 |
+
" dataloader_num_workers=4,\n",
|
| 874 |
+
" dataloader_pin_memory=True,\n",
|
| 875 |
+
" optim=\"adamw_torch_fused\",\n",
|
| 876 |
" report_to=\"none\",\n",
|
| 877 |
" )\n",
|
| 878 |
"\n",
|
|
|
|
| 1116 |
"name": "python",
|
| 1117 |
"nbconvert_exporter": "python",
|
| 1118 |
"pygments_lexer": "ipython3",
|
| 1119 |
+
"version": "3.13.1"
|
| 1120 |
}
|
| 1121 |
},
|
| 1122 |
"nbformat": 4,
|