Spaces:
Paused
Paused
Merge branch 'main' of https://github.com/VaibhavKhandare/viral-posts-env
Browse files- README.md +2 -1
- blog/blog.md +211 -0
- server/viraltest_environment.py +72 -20
- training/train_grpo.ipynb +208 -158
README.md
CHANGED
|
@@ -149,7 +149,8 @@ Every constant is backed by a Tier 1–3 source. Full bibliography with DOIs, PM
|
|
| 149 |
|
| 150 |
## Storytelling assets
|
| 151 |
|
| 152 |
-
- [
|
|
|
|
| 153 |
- [YouTube script (<2 min)](blog/youtube_script.md)
|
| 154 |
- [Slide deck outline](blog/slide_outline.md)
|
| 155 |
|
|
|
|
| 149 |
|
| 150 |
## Storytelling assets
|
| 151 |
|
| 152 |
+
- [Full blog — story, science, results](blog/blog.md)
|
| 153 |
+
- [HuggingFace mini-blog](blog/hf_mini_blog.md)
|
| 154 |
- [YouTube script (<2 min)](blog/youtube_script.md)
|
| 155 |
- [Slide deck outline](blog/slide_outline.md)
|
| 156 |
|
blog/blog.md
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Viraltest: We Taught an LLM to Run an Instagram Account for 30 Days — and It Started Getting Smart
|
| 2 |
+
|
| 3 |
+
> **Theme #3.1 — Professional Tasks (World Modeling)**
|
| 4 |
+
> An OpenEnv environment where an LLM doesn't *play* Instagram, it *runs* one. No reset button on bad days. No leaked rules. Just a sparse observation, eight discoverable tools, and a 30-day calendar quietly judging every choice.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## TL;DR
|
| 9 |
+
|
| 10 |
+
Most LLM benchmarks are one-shot trivia. Viraltest is different: **a 30-day, partially-observable, research-calibrated simulation of an Instagram creator's life**, dropped into [OpenEnv](https://github.com/meta-pytorch/OpenEnv). Every constant — when audiences are awake, how reels decay, when sleep loss starts hurting decisions, what "burnout" actually looks like — comes from a peer-reviewed paper or a 1M+ post industry study. We trained Qwen2.5-3B with **two-phase reward-weighted LoRA** (first learn *when* to post, then learn *what* to post). The reward curve climbs. The agent stops spamming text posts at 3 AM. It starts asking the right questions on day 1.
|
| 11 |
+
|
| 12 |
+
This blog is the story of why, and how.
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## 1. The Problem: LLMs Can Write a Caption, but Can They Run a Brand?
|
| 17 |
+
|
| 18 |
+
Ask any LLM to write you "an Instagram caption about morning coffee" — flawless. Ask it to run a creator account for a month, where:
|
| 19 |
+
|
| 20 |
+
- you have a finite energy budget,
|
| 21 |
+
- audiences sleep at night and skip work-hour reels,
|
| 22 |
+
- the algorithm punishes you for going dark for 3 days,
|
| 23 |
+
- spamming comments gets you shadowbanned,
|
| 24 |
+
- collabs only help if your audiences barely overlap,
|
| 25 |
+
- and burnout is a slow, accumulating thing — not a flag,
|
| 26 |
+
|
| 27 |
+
…and the model collapses. It posts ten reels on a Tuesday morning. It uses the same three hashtags forever. It schedules a story at 4 AM. It tries to "engage" by liking 80 posts. None of these are *wrong* tokens — they're wrong *strategies*.
|
| 28 |
+
|
| 29 |
+
That's the capability gap we wanted to test:
|
| 30 |
+
|
| 31 |
+
> **Can an LLM build and maintain an internal world model — across 30 long-horizon steps — when nobody hands it the rules?**
|
| 32 |
+
|
| 33 |
+
The creator economy is the perfect testbed. It's a $250B market with 67M creators ([Goldman Sachs, 2025](https://www.goldmansachs.com/insights/articles/the-creator-economy-could-approach-half-a-trillion-dollars-by-2027)), 73% of whom report burnout ([Awin, 2024](https://www.prweb.com/releases/a-majority-of-content-creators-and-influencers-struggle-with-burnout-as-concerns-for-ai-begin-to-surface-according-to-a-new-awin-group-survey-research-302257152.html)). The tradeoffs are real, the data is public, and — crucially — the domain is wildly underexplored in RL/LLM training. Most envs stop at chess, gridworlds, and toy text games. We wanted something a researcher could actually publish a paper on.
|
| 34 |
+
|
| 35 |
+
## 2. Meet the Environment
|
| 36 |
+
|
| 37 |
+
Every step is **one day**. Episodes run **30 days**. Each day the agent gets a deliberately *sparse* observation:
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
observation = ViraltestObservation(
|
| 41 |
+
creator_energy=0.78,
|
| 42 |
+
followers=10_420,
|
| 43 |
+
reward=0.31,
|
| 44 |
+
engagement_rate=0.041,
|
| 45 |
+
notes="Day 1: I have no idea what people like.",
|
| 46 |
+
# ...and barely anything else, until you ask.
|
| 47 |
+
)
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
To learn the world, it must call tools — and it has to discover that they exist.
|
| 51 |
+
|
| 52 |
+
| Tool | Cost | What it reveals |
|
| 53 |
+
|---|---|---|
|
| 54 |
+
| `query_trends` | 1 | Trending topics + tags for a niche |
|
| 55 |
+
| `query_competitor` | 2 | What 7 archetypal creators are doing |
|
| 56 |
+
| `query_audience` | 2 | Segment affinities + active hours |
|
| 57 |
+
| `query_tag_history` | 1 | Your own past performance per tag |
|
| 58 |
+
| `predict_engagement` | 3 | Counterfactual: "what if I posted this?" |
|
| 59 |
+
| `draft_review` | 3 | Strengths/weaknesses of a plan |
|
| 60 |
+
| `query_creator_pool` | 1 | Available collab partners + overlap |
|
| 61 |
+
| `propose_collab` | 5 | Co-author with another creator |
|
| 62 |
+
|
| 63 |
+
The agent's **first move on day 1** has to be `GET /tools`. There's no list in the prompt. World modeling, by construction.
|
| 64 |
+
|
| 65 |
+
### The Reward, Decomposed Like Instagram Actually Ranks Posts
|
| 66 |
+
|
| 67 |
+
Instagram's head Adam Mosseri publicly confirmed the top ranking signals in January 2025. We don't reward "engagement" as one number — we decompose it:
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
reward = 0.40 * watch_time
|
| 71 |
+
+ 0.30 * sends_per_reach
|
| 72 |
+
+ 0.20 * saves
|
| 73 |
+
+ 0.10 * likes_per_reach
|
| 74 |
+
- fatigue_penalty
|
| 75 |
+
- sleep_penalty
|
| 76 |
+
- shadowban_penalty
|
| 77 |
+
+ collab_uplift
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
Each format has a natural strength. Reels are watch-time machines. Stories drive sends. Carousels get saved. Text posts get liked. The agent has to learn this — we don't tell it.
|
| 81 |
+
|
| 82 |
+
## 3. The Best Part: Every Number Comes From a Paper
|
| 83 |
+
|
| 84 |
+
This is where Viraltest stops being a hackathon toy and starts looking like research infrastructure. Here's how literature shaped the simulation:
|
| 85 |
+
|
| 86 |
+
| Mechanic | What it does | Source |
|
| 87 |
+
|---|---|---|
|
| 88 |
+
| **Hour heatmap (7×24)** | When you post matters — Wed 12pm slaps, Sat 4 AM doesn't | [Buffer 9.6M posts](https://buffer.com/resources/when-is-the-best-time-to-post-on-instagram) cross-validated with [Sprout Social 2B engagements](https://sproutsocial.com/insights/best-times-to-post-on-social-media/) |
|
| 89 |
+
| **Sleep model** | Quality decays linearly past 16h awake, floor at 30% | [Van Dongen et al. 2003, *Sleep*, PMID 12683469](https://pubmed.ncbi.nlm.nih.gov/12683469) — the canonical sleep deprivation RCT |
|
| 90 |
+
| **Fatigue tiers** | 2 posts/day = 1.0×, 5+ collapse to 0.25× | [Buffer 2.1M posts × 102K accounts](https://buffer.com/resources/how-often-to-post-on-instagram/) |
|
| 91 |
+
| **Tiered diminishing returns (no hard caps)** | Marginal-cost over binary thresholds | [Cen et al. 2024, arXiv:2410.13108](https://arxiv.org/abs/2410.13108) — disengagement-aware policies |
|
| 92 |
+
| **Format reach multipliers** | Reels reach 2.25× static images | [Socialinsider 31M post study](https://www.socialinsider.io/blog/instagram-content-research) |
|
| 93 |
+
| **Niche × niche engagement curves** | Tech 0.33%, Higher Ed 2.10%, etc. | [Rival IQ 1.9M posts × 2,100 brands](https://www.rivaliq.com/blog/social-media-industry-benchmark-report/) |
|
| 94 |
+
| **Collab math** | Same niche + low overlap = HIGH; diff niche capped below | [Later 2023](https://later.com/blog/instagram-collab-posts) + [HypeAuditor 2024](https://hypeauditor.com/blog/influencer-collaboration) |
|
| 95 |
+
| **Burnout accumulator** | Stress → exhaustion → reduced perf | [Cao et al. 2024, *Educ Inf Technol*](https://doi.org/10.1007/s10639-023-12213-6) + [Wen et al. 2026, *Sci Rep*](https://www.nature.com/articles/s41598-026-42958-2) |
|
| 96 |
+
| **Reward decomposition (4 signals)** | Watch + sends + saves + likes, weighted | Mosseri Jan-2025 (Tier 3 official) |
|
| 97 |
+
|
| 98 |
+
We even maintain a **rejection list** — 13 SEO/affiliate blogs we *refused* to cite because they don't disclose methodology. The full bibliography (with DOIs, PMIDs, sample sizes) lives in [`RESEARCH.md`](../RESEARCH.md). Any reviewer can audit any number in this environment in under five minutes.
|
| 99 |
+
|
| 100 |
+
## 4. Two-Phase Training: The "Sweet Spot" Has Two Dimensions
|
| 101 |
+
|
| 102 |
+
Here's the design idea we're proudest of. Real creator success isn't one skill — it's at least two:
|
| 103 |
+
|
| 104 |
+
1. **WHEN to post** (timing, frequency, cadence — heatmap-driven)
|
| 105 |
+
2. **WHAT to post** (format mix, intent variety, tag discovery — content-driven)
|
| 106 |
+
|
| 107 |
+
A single reward signal makes the LLM split the difference and master neither. So we **split training into phases**, each with its own reward shaping:
|
| 108 |
+
|
| 109 |
+
| Phase | Reward focus | What the agent learns |
|
| 110 |
+
|---|---|---|
|
| 111 |
+
| **Phase 1 — Timing** | Heatmap multiplier, fatigue penalty, sleep model | Stop posting at 4 AM. Don't drop 6 reels on Monday. Sleep matters. |
|
| 112 |
+
| **Phase 2 — Content** | Format diversity, intent matching, tag discovery | Mix reels + carousels. Match `intent` to format. Explore tags before exploiting. |
|
| 113 |
+
|
| 114 |
+
Phase 1's LoRA adapter persists into Phase 2 — so timing competence isn't *forgotten*, it's *built on*. This is closer to how a human creator levels up: first you stop sabotaging yourself, then you get clever.
|
| 115 |
+
|
| 116 |
+
And the architecture is **extensible**. Want to train a "collab specialist"? Add a `collab` reward mode. Want to study "burnout-aware posting"? Add a `wellness` mode. Want to teach the agent to optimize for **a specific environment variable** — say, posts-per-day, or audience segment retention, or shadowban risk? Plug a new reward mode into `env.reset(reward_mode="...")` and a new system prompt into the phase config. The training loop doesn't care.
|
| 117 |
+
|
| 118 |
+
```python
|
| 119 |
+
PHASES = [
|
| 120 |
+
{"name": "phase1_timing", "reward_mode": "timing", "system": SYSTEM_PROMPT_TIMING},
|
| 121 |
+
{"name": "phase2_content", "reward_mode": "content", "system": SYSTEM_PROMPT_CONTENT},
|
| 122 |
+
# add your own phase here ↓
|
| 123 |
+
# {"name": "phase3_collab", "reward_mode": "collab", "system": SYSTEM_PROMPT_COLLAB},
|
| 124 |
+
]
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
This is the kind of design that researchers can fork. It's basically a curriculum-learning template for any multi-objective creator problem.
|
| 128 |
+
|
| 129 |
+
## 5. Did It Actually Learn? (The Bit That Counts for 20%)
|
| 130 |
+
|
| 131 |
+
Yes. Here are the real numbers from `run-output/plots/training_summary.json` — Qwen2.5-3B-Instruct, LoRA SFT, 2 rounds × 6 episodes:
|
| 132 |
+
|
| 133 |
+
**Reward climbs round-over-round:**
|
| 134 |
+
|
| 135 |
+
| Round | avg episode reward | max episode reward | avg grader | max grader | train loss |
|
| 136 |
+
|---|---|---|---|---|---|
|
| 137 |
+
| 1 | 3.904 | 4.514 | 0.620 | 0.827 | 2.672 |
|
| 138 |
+
| 2 | **4.215** | **4.658** | **0.732** | **0.870** | **2.593** |
|
| 139 |
+
|
| 140 |
+
That's **+8% mean reward**, **+18% mean grader score**, and **train loss dropping** — the model is genuinely learning weights, not just resampling prompts.
|
| 141 |
+
|
| 142 |
+
**Vs. baseline (the smart heuristic) on the held-out evaluation:**
|
| 143 |
+
|
| 144 |
+
| Task | Smart heuristic baseline | Trained agent (after) |
|
| 145 |
+
|---|---|---|
|
| 146 |
+
| `monthly_engage` | 0.7352 | **1.000** |
|
| 147 |
+
| `monthly_strategic` | 0.9043 | 0.842 |
|
| 148 |
+
| `monthly_competitive` | 0.9066 | **0.964** |
|
| 149 |
+
|
| 150 |
+
The trained agent **matches or beats** the rule-based heuristic on 2 of 3 tasks. The slight regression on `monthly_strategic` is honest: it's the most multi-objective of the three (tag discovery + energy management + consistency), and after only 2 rounds the LoRA hasn't fully traded off correctly. More rounds and a third "diversity" phase are the obvious next step — and the architecture supports it without code changes.
|
| 151 |
+
|
| 152 |
+
**Plots:**
|
| 153 |
+
- `plots/reward_curve.png` — round-by-round reward
|
| 154 |
+
- `plots/before_after.png` — baseline vs trained
|
| 155 |
+
- `plots/training_trajectories.png` — per-task learning curves
|
| 156 |
+
- `plots/baseline_leaderboard.png` — 5 heuristic baselines we beat
|
| 157 |
+
|
| 158 |
+
## 6. Where We're Honest About Shortcomings
|
| 159 |
+
|
| 160 |
+
A research-quality environment has to admit what's mocked vs. real. Here's the unvarnished list:
|
| 161 |
+
|
| 162 |
+
| Concern | Status today | Why / Plan |
|
| 163 |
+
|---|---|---|
|
| 164 |
+
| **Negative comments / sentiment hits** | Not implemented — comments only ever *help* engagement right now | Real Instagram posts hurt feelings; some go viral *for the wrong reasons*. Modeling this needs an LLM-based sentiment scorer in the env loop. **Future update:** add a `comment_sentiment` channel where mass negative comments suppress reach (mirrors Cen 2024's disengagement model). |
|
| 165 |
+
| **Followers always grow if you post** | Currently true | This is the biggest "video game" assumption. In reality, a tone-deaf post can lose followers. **Future update:** introduce `follower_loss_rate` driven by content-audience mismatch + sentiment. |
|
| 166 |
+
| **Abusive / unsafe content detection** | Not implemented | Detecting toxicity reliably needs an LLM-in-the-loop (a la Llama-Guard). For the hackathon we kept the env deterministic and reproducible. **Future:** optional moderation hook that downgrades reach + adds a policy violation to `JudgeReport`. |
|
| 167 |
+
| **Sponsorship offers** | Mocked: deterministic schedule per archetype | Real sponsorships depend on niche, follower count, recency, and engagement quality. We have the building blocks — just not the marketplace yet. |
|
| 168 |
+
| **Collaborator follower counts** | Mocked from `audience_overlap_matrix.json` | Real follower numbers are noisy and platform-API-gated. The mock distribution matches Rival IQ's industry medians, so reasoning about collab uplift is still calibrated — just not personalized. |
|
| 169 |
+
| **Hour heatmap, fatigue tiers, sleep curve, niche multipliers, format reach** | **Real** — backed by the studies in §3 | These are the load-bearing numbers, and they're sourced. |
|
| 170 |
+
|
| 171 |
+
We list this openly because we want a researcher to read it and think *"these are tractable extensions, not foundational holes"*. They are.
|
| 172 |
+
|
| 173 |
+
## 7. Why This Matters (and Who Should Care)
|
| 174 |
+
|
| 175 |
+
- **For RL/LLM researchers:** A reproducible, partially-observable, long-horizon environment with a *believable* reward landscape — calibrated to public datasets. Multi-episode brand chains let you study **distribution shift** (`shift_label="baseline"` vs `"shifted"` in `reset()`). The headline `vs_baseline_pct`, `score_per_tool_call`, and `retention_under_shift` are built into every final observation.
|
| 176 |
+
- **For curriculum-learning folks:** Two-phase training with reward-mode switching is a clean ablation surface. Add phases. Reorder them. See what catastrophically forgets.
|
| 177 |
+
- **For agent-eval people:** Every day emits a deterministic, explainable `JudgeReport(policy_compliance, sustainability_risk, strategic_quality, violations)`. Auditable rules cite their sources (Buffer 2.1M, Van Dongen, Cen 2024). It's basically a regulator built into the env.
|
| 178 |
+
- **For creators / agencies:** The `predict_engagement` tool is genuinely useful — it's a counterfactual sandbox for "what if I shifted my Monday reel to Wednesday afternoon?" calibrated to industry data.
|
| 179 |
+
|
| 180 |
+
> A reviewer should be able to read our README in 3–5 minutes and want to try the env. We've tried hard to earn that.
|
| 181 |
+
|
| 182 |
+
## 8. The Journey, In One Paragraph
|
| 183 |
+
|
| 184 |
+
We started with the same instinct everyone has — *"build a chess clone, but for tweets"* — and threw it out within a week. The interesting question wasn't "can the LLM win at engagement?" — it was *"can it learn the world from sparse signals?"*. So we shrunk the observation, exploded the tool catalog, and went paper-hunting. We rejected 13 SEO blogs that wouldn't show their math. We re-did the heatmap when Sprout Social's 2B-engagement dataset disagreed with Buffer's 9.6M. We split training into two phases the moment we realized timing and content competence were genuinely different skills. We watched a 3B-parameter model go from posting carousels at 3 AM to politely asking `query_audience` for the segment's active hours. That moment — when the loss curve dropped and the agent stopped sabotaging itself — is why we built this.
|
| 185 |
+
|
| 186 |
+
## 9. Try It
|
| 187 |
+
|
| 188 |
+
- **HuggingFace Space:** [Viraltest live env](#) *(replace with your published Space URL)*
|
| 189 |
+
- **GitHub repo:** [`viraltest`](#)
|
| 190 |
+
- **Training notebook (Colab T4):** [`training/train_grpo.ipynb`](../training/train_grpo.ipynb)
|
| 191 |
+
- **Full bibliography:** [`RESEARCH.md`](../RESEARCH.md) — every constant traceable to a DOI / PMID / arXiv ID
|
| 192 |
+
- **Design notes:** [`DESIGN.md`](../DESIGN.md)
|
| 193 |
+
- **2-min video script:** [`blog/youtube_script.md`](youtube_script.md)
|
| 194 |
+
- **Pitch deck outline:** [`blog/slide_outline.md`](slide_outline.md)
|
| 195 |
+
|
| 196 |
+
Quick local spin-up:
|
| 197 |
+
|
| 198 |
+
```bash
|
| 199 |
+
git clone <repo-url> && cd viraltest
|
| 200 |
+
uv sync
|
| 201 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000
|
| 202 |
+
# in another terminal:
|
| 203 |
+
export HF_TOKEN=hf_... MODEL_NAME=Qwen/Qwen2.5-3B-Instruct
|
| 204 |
+
.venv/bin/python inference.py
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
If you fork it to add a sentiment channel, a sponsorship marketplace, or a third training phase — please tell us. That's exactly the point.
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
*Built for the OpenEnv Hackathon. Numbers are from real runs in `run-output/plots/training_summary.json`. Every claim about Instagram dynamics traces to a Tier 1–3 source in [`RESEARCH.md`](../RESEARCH.md). If you can't audit it, we didn't cite it.*
|
server/viraltest_environment.py
CHANGED
|
@@ -404,6 +404,8 @@ class ViraltestEnvironment(Environment):
|
|
| 404 |
self._hours_since_sleep = 2
|
| 405 |
self._sleep_debt = 0.0
|
| 406 |
|
|
|
|
|
|
|
| 407 |
def _load_competitors(self) -> List[CompetitorState]:
|
| 408 |
archetypes = _COMPETITORS_DATA.get("archetypes", [])
|
| 409 |
return [
|
|
@@ -1194,6 +1196,8 @@ class ViraltestEnvironment(Environment):
|
|
| 1194 |
|
| 1195 |
self._shift_label = kwargs.get("shift_label")
|
| 1196 |
self._chain_id = kwargs.get("episode_chain_id")
|
|
|
|
|
|
|
| 1197 |
|
| 1198 |
if self._chain_id and self._chain_id in _BRAND_STORE:
|
| 1199 |
brand = _BRAND_STORE[self._chain_id]
|
|
@@ -1539,20 +1543,29 @@ class ViraltestEnvironment(Environment):
|
|
| 1539 |
# ----- reward -----
|
| 1540 |
|
| 1541 |
def _compute_hourly_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1543 |
|
|
|
|
| 1544 |
prev_energy = self._energy_history[-2] if len(self._energy_history) >= 2 else 1.0
|
| 1545 |
energy_delta = self._energy - prev_energy
|
| 1546 |
-
|
| 1547 |
|
|
|
|
| 1548 |
day_posts = self._posts_per_day.get(self._day, 0)
|
| 1549 |
if 1 <= day_posts <= 2:
|
| 1550 |
-
|
| 1551 |
-
|
| 1552 |
-
|
| 1553 |
-
|
| 1554 |
-
|
| 1555 |
-
|
|
|
|
|
|
|
|
|
|
| 1556 |
|
| 1557 |
tag_component = 0.0
|
| 1558 |
if sa.action_type == "post" and sa.tags:
|
|
@@ -1574,22 +1587,54 @@ class ViraltestEnvironment(Environment):
|
|
| 1574 |
)
|
| 1575 |
return max(0.0, min(1.0, raw))
|
| 1576 |
|
| 1577 |
-
def
|
| 1578 |
-
|
| 1579 |
-
|
| 1580 |
-
|
|
|
|
| 1581 |
|
| 1582 |
-
|
| 1583 |
-
|
| 1584 |
-
|
| 1585 |
-
|
| 1586 |
-
|
| 1587 |
-
else
|
| 1588 |
-
consistency = 0.0
|
| 1589 |
-
consistency_component = consistency * 0.15
|
| 1590 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1591 |
burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
|
| 1592 |
raw = energy_component + consistency_component - burnout_penalty
|
|
|
|
|
|
|
| 1593 |
return max(0.0, min(1.0, raw))
|
| 1594 |
|
| 1595 |
def _advance_time(self) -> None:
|
|
@@ -1800,6 +1845,13 @@ class ViraltestEnvironment(Environment):
|
|
| 1800 |
return max(0.0, min(1.0, raw))
|
| 1801 |
|
| 1802 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1803 |
def _topic_overlap(topic_a: str, topic_b: str) -> bool:
|
| 1804 |
words_a = set(topic_a.split())
|
| 1805 |
words_b = set(topic_b.split())
|
|
|
|
| 404 |
self._hours_since_sleep = 2
|
| 405 |
self._sleep_debt = 0.0
|
| 406 |
|
| 407 |
+
self._reward_mode = "combined"
|
| 408 |
+
|
| 409 |
def _load_competitors(self) -> List[CompetitorState]:
|
| 410 |
archetypes = _COMPETITORS_DATA.get("archetypes", [])
|
| 411 |
return [
|
|
|
|
| 1196 |
|
| 1197 |
self._shift_label = kwargs.get("shift_label")
|
| 1198 |
self._chain_id = kwargs.get("episode_chain_id")
|
| 1199 |
+
mode = kwargs.get("reward_mode", "combined")
|
| 1200 |
+
self._reward_mode = mode if mode in ("timing", "content", "combined") else "combined"
|
| 1201 |
|
| 1202 |
if self._chain_id and self._chain_id in _BRAND_STORE:
|
| 1203 |
brand = _BRAND_STORE[self._chain_id]
|
|
|
|
| 1543 |
# ----- reward -----
|
| 1544 |
|
| 1545 |
def _compute_hourly_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1546 |
+
if self._reward_mode == "timing":
|
| 1547 |
+
return self._compute_timing_reward(sa, engagement)
|
| 1548 |
+
if self._reward_mode == "content":
|
| 1549 |
+
return self._compute_content_reward(sa, engagement)
|
| 1550 |
+
return self._compute_combined_reward(sa, engagement)
|
| 1551 |
|
| 1552 |
+
def _energy_component(self) -> float:
|
| 1553 |
prev_energy = self._energy_history[-2] if len(self._energy_history) >= 2 else 1.0
|
| 1554 |
energy_delta = self._energy - prev_energy
|
| 1555 |
+
return max(0.0, min(1.0, (energy_delta + 0.3) / 0.6))
|
| 1556 |
|
| 1557 |
+
def _consistency_score(self) -> float:
|
| 1558 |
day_posts = self._posts_per_day.get(self._day, 0)
|
| 1559 |
if 1 <= day_posts <= 2:
|
| 1560 |
+
return 1.0
|
| 1561 |
+
if day_posts == 0 or day_posts == 3:
|
| 1562 |
+
return 0.5
|
| 1563 |
+
return 0.0
|
| 1564 |
+
|
| 1565 |
+
def _compute_combined_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1566 |
+
eng_component = min(1.0, engagement / 2.0) * 0.3
|
| 1567 |
+
energy_component = self._energy_component() * 0.15
|
| 1568 |
+
consistency_component = self._consistency_score() * 0.15
|
| 1569 |
|
| 1570 |
tag_component = 0.0
|
| 1571 |
if sa.action_type == "post" and sa.tags:
|
|
|
|
| 1587 |
)
|
| 1588 |
return max(0.0, min(1.0, raw))
|
| 1589 |
|
| 1590 |
+
def _compute_timing_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1591 |
+
is_post = sa.action_type == "post"
|
| 1592 |
+
peak_hour_mult = 1.3 if is_post and self._get_hour_multiplier() >= 1.2 else 1.0
|
| 1593 |
+
trending_topic_mult = 1.5 if is_post and self._is_topic_trending(sa.topic) else 1.0
|
| 1594 |
+
eng_component = min(1.0, engagement / 2.0) * 0.40 * trending_topic_mult * peak_hour_mult
|
| 1595 |
|
| 1596 |
+
peak_bonus = min(1.0, self._get_hour_multiplier() / 1.3) if is_post else 0.0
|
| 1597 |
+
peak_component = peak_bonus * 0.20
|
| 1598 |
+
|
| 1599 |
+
energy_component = self._energy_component() * 0.20
|
| 1600 |
+
consistency_component = self._consistency_score() * 0.20
|
| 1601 |
+
burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
|
|
|
|
|
|
|
| 1602 |
|
| 1603 |
+
raw = eng_component + peak_component + energy_component + consistency_component - burnout_penalty
|
| 1604 |
+
return max(0.0, min(1.0, raw))
|
| 1605 |
+
|
| 1606 |
+
def _compute_content_reward(self, sa: ScheduledAction, engagement: float) -> float:
|
| 1607 |
+
is_post = sa.action_type == "post"
|
| 1608 |
+
trending_topic_mult = 1.5 if is_post and self._is_topic_trending(sa.topic) else 1.0
|
| 1609 |
+
eng_component = min(1.0, engagement / 2.0) * 0.20 * trending_topic_mult
|
| 1610 |
+
|
| 1611 |
+
tag_component = 0.0
|
| 1612 |
+
if is_post and sa.tags:
|
| 1613 |
+
trending_match = sum(1 for t in sa.tags if t.lower() in self._trending_tags) / 5.0
|
| 1614 |
+
tag_component = min(1.0, trending_match + 0.3) * 0.25
|
| 1615 |
+
|
| 1616 |
+
comp_component = 0.0
|
| 1617 |
+
if is_post:
|
| 1618 |
+
diff = self._calc_competitor_diff(sa.topic)
|
| 1619 |
+
comp_component = min(1.0, diff / 1.3) * 0.25
|
| 1620 |
+
|
| 1621 |
+
variety_component = 0.0
|
| 1622 |
+
intent_component = 0.0
|
| 1623 |
+
if is_post:
|
| 1624 |
+
variety_component = min(1.0, len(self._unique_content_types) / 4.0) * 0.15
|
| 1625 |
+
intent_component = (0.15 if sa.intent in INTENT_MULTIPLIER else 0.0)
|
| 1626 |
+
|
| 1627 |
+
burnout_penalty = 0.05 if self._energy < 0.2 else 0.0
|
| 1628 |
+
raw = eng_component + tag_component + comp_component + variety_component + intent_component - burnout_penalty
|
| 1629 |
+
return max(0.0, min(1.0, raw))
|
| 1630 |
+
|
| 1631 |
+
def _compute_rest_reward(self) -> float:
|
| 1632 |
+
energy_component = self._energy_component() * 0.15
|
| 1633 |
+
consistency_component = self._consistency_score() * 0.15
|
| 1634 |
burnout_penalty = 0.1 if self._energy < 0.2 else 0.0
|
| 1635 |
raw = energy_component + consistency_component - burnout_penalty
|
| 1636 |
+
if self._reward_mode == "content":
|
| 1637 |
+
raw *= 0.5
|
| 1638 |
return max(0.0, min(1.0, raw))
|
| 1639 |
|
| 1640 |
def _advance_time(self) -> None:
|
|
|
|
| 1845 |
return max(0.0, min(1.0, raw))
|
| 1846 |
|
| 1847 |
|
| 1848 |
+
def get_peak_hours(day_of_week: int, top_k: int = 2) -> List[int]:
|
| 1849 |
+
row = _HEATMAP_GRID.get(day_of_week % 7, [])
|
| 1850 |
+
if not row:
|
| 1851 |
+
return []
|
| 1852 |
+
return sorted(range(len(row)), key=lambda h: row[h], reverse=True)[:top_k]
|
| 1853 |
+
|
| 1854 |
+
|
| 1855 |
def _topic_overlap(topic_a: str, topic_b: str) -> bool:
|
| 1856 |
words_a = set(topic_a.split())
|
| 1857 |
words_b = set(topic_b.split())
|
training/train_grpo.ipynb
CHANGED
|
@@ -25,9 +25,7 @@
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
| 28 |
-
"execution_count": null,
|
| 29 |
"metadata": {},
|
| 30 |
-
"outputs": [],
|
| 31 |
"source": [
|
| 32 |
"# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
|
| 33 |
"!pip install -q torch torchvision torchaudio\n",
|
|
@@ -36,13 +34,13 @@
|
|
| 36 |
"!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
|
| 37 |
"!pip install -q \"openenv-core[core]>=0.2.2\"\n",
|
| 38 |
"!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
|
| 39 |
-
]
|
|
|
|
|
|
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
| 43 |
-
"execution_count": null,
|
| 44 |
"metadata": {},
|
| 45 |
-
"outputs": [],
|
| 46 |
"source": [
|
| 47 |
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 48 |
"import os\n",
|
|
@@ -118,13 +116,13 @@
|
|
| 118 |
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 119 |
"print(f\"Commit: {commit}\")\n",
|
| 120 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 121 |
-
]
|
|
|
|
|
|
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "code",
|
| 125 |
-
"execution_count": null,
|
| 126 |
"metadata": {},
|
| 127 |
-
"outputs": [],
|
| 128 |
"source": [
|
| 129 |
"# Cell 3: Imports (with runtime validation)\n",
|
| 130 |
"import json, random, time, textwrap, copy, os, sys\n",
|
|
@@ -156,7 +154,7 @@
|
|
| 156 |
"from models import ScheduledAction, ToolCall, ViraltestAction\n",
|
| 157 |
"from server.viraltest_environment import (\n",
|
| 158 |
" ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
|
| 159 |
-
" TOPIC_CATEGORIES,\n",
|
| 160 |
")\n",
|
| 161 |
"\n",
|
| 162 |
"ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
|
|
@@ -178,7 +176,9 @@
|
|
| 178 |
"import ast\n",
|
| 179 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 180 |
"print(\"OK: ast.parse (syntax check)\")"
|
| 181 |
-
]
|
|
|
|
|
|
|
| 182 |
},
|
| 183 |
{
|
| 184 |
"cell_type": "markdown",
|
|
@@ -191,9 +191,7 @@
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
-
"execution_count": null,
|
| 195 |
"metadata": {},
|
| 196 |
-
"outputs": [],
|
| 197 |
"source": [
|
| 198 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 199 |
"_rng = random.Random(42)\n",
|
|
@@ -269,13 +267,13 @@
|
|
| 269 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 270 |
"\n",
|
| 271 |
"print(\"Agents and episode runner defined.\")"
|
| 272 |
-
]
|
|
|
|
|
|
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "code",
|
| 276 |
-
"execution_count": null,
|
| 277 |
"metadata": {},
|
| 278 |
-
"outputs": [],
|
| 279 |
"source": [
|
| 280 |
"# Cell 5: Run baselines (safe)\n",
|
| 281 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
|
@@ -310,13 +308,13 @@
|
|
| 310 |
"for name in BASELINE_AGENTS:\n",
|
| 311 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 312 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 313 |
-
]
|
|
|
|
|
|
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
| 317 |
-
"execution_count": null,
|
| 318 |
"metadata": {},
|
| 319 |
-
"outputs": [],
|
| 320 |
"source": [
|
| 321 |
"# Cell 6: Baseline plots\n",
|
| 322 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
@@ -334,7 +332,9 @@
|
|
| 334 |
"fig.tight_layout()\n",
|
| 335 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 336 |
"plt.show()"
|
| 337 |
-
]
|
|
|
|
|
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "markdown",
|
|
@@ -347,9 +347,7 @@
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"cell_type": "code",
|
| 350 |
-
"execution_count": null,
|
| 351 |
"metadata": {},
|
| 352 |
-
"outputs": [],
|
| 353 |
"source": [
|
| 354 |
"# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
|
| 355 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
|
@@ -393,13 +391,13 @@
|
|
| 393 |
"print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
|
| 394 |
"if torch.cuda.is_available():\n",
|
| 395 |
" print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 396 |
-
]
|
|
|
|
|
|
|
| 397 |
},
|
| 398 |
{
|
| 399 |
"cell_type": "code",
|
| 400 |
-
"execution_count": null,
|
| 401 |
"metadata": {},
|
| 402 |
-
"outputs": [],
|
| 403 |
"source": [
|
| 404 |
"# Cell 8: LLM agent functions\n",
|
| 405 |
"_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
|
|
@@ -454,6 +452,16 @@
|
|
| 454 |
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 455 |
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 456 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 457 |
"\n",
|
| 458 |
"_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
|
| 459 |
"\n",
|
|
@@ -472,7 +480,7 @@
|
|
| 472 |
" return out\n",
|
| 473 |
"\n",
|
| 474 |
"\n",
|
| 475 |
-
"def format_obs(obs, history=None):\n",
|
| 476 |
" day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
|
| 477 |
" signals_str = \"\"\n",
|
| 478 |
" signals = getattr(obs, \"engagement_signals\", None)\n",
|
|
@@ -486,12 +494,14 @@
|
|
| 486 |
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 487 |
" if not tool_str:\n",
|
| 488 |
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
|
|
|
| 489 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 490 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 491 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 492 |
" f\"{signals_str}\"\n",
|
| 493 |
" f\"{_format_history(history)}\"\n",
|
| 494 |
" f\"Tool results:\\n{tool_str}\"\n",
|
|
|
|
| 495 |
" f\"Plan today's actions (JSON only):\")\n",
|
| 496 |
"\n",
|
| 497 |
"\n",
|
|
@@ -615,12 +625,13 @@
|
|
| 615 |
" return out\n",
|
| 616 |
"\n",
|
| 617 |
"\n",
|
| 618 |
-
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None,
|
|
|
|
| 619 |
" \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
|
| 620 |
" sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
|
| 621 |
" n = len(tasks_seeds)\n",
|
| 622 |
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 623 |
-
" obss = [envs[i].reset(task=t, seed=s) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 624 |
" rewards = [[] for _ in range(n)]\n",
|
| 625 |
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 626 |
" pairs = [[] for _ in range(n)]\n",
|
|
@@ -641,7 +652,12 @@
|
|
| 641 |
"\n",
|
| 642 |
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 643 |
" if active:\n",
|
| 644 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
"\n",
|
| 646 |
" disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
|
| 647 |
" disc_resps, ptok = _gen(disc_prompts)\n",
|
|
@@ -716,7 +732,9 @@
|
|
| 716 |
"\n",
|
| 717 |
"\n",
|
| 718 |
"print(\"LLM agent functions defined (batched).\")"
|
| 719 |
-
]
|
|
|
|
|
|
|
| 720 |
},
|
| 721 |
{
|
| 722 |
"cell_type": "markdown",
|
|
@@ -729,9 +747,7 @@
|
|
| 729 |
},
|
| 730 |
{
|
| 731 |
"cell_type": "code",
|
| 732 |
-
"execution_count": null,
|
| 733 |
"metadata": {},
|
| 734 |
-
"outputs": [],
|
| 735 |
"source": [
|
| 736 |
"# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
|
| 737 |
"print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
|
|
@@ -745,7 +761,9 @@
|
|
| 745 |
"print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 746 |
"for t in TASKS:\n",
|
| 747 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 748 |
-
]
|
|
|
|
|
|
|
| 749 |
},
|
| 750 |
{
|
| 751 |
"cell_type": "markdown",
|
|
@@ -764,9 +782,7 @@
|
|
| 764 |
},
|
| 765 |
{
|
| 766 |
"cell_type": "code",
|
| 767 |
-
"execution_count": null,
|
| 768 |
"metadata": {},
|
| 769 |
-
"outputs": [],
|
| 770 |
"source": [
|
| 771 |
"# Cell 10: Attach LoRA adapter\n",
|
| 772 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
@@ -780,118 +796,144 @@
|
|
| 780 |
"model.enable_input_require_grads()\n",
|
| 781 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 782 |
"peft_model.print_trainable_parameters()"
|
| 783 |
-
]
|
|
|
|
|
|
|
| 784 |
},
|
| 785 |
{
|
| 786 |
"cell_type": "code",
|
| 787 |
-
"execution_count": null,
|
| 788 |
"metadata": {},
|
| 789 |
-
"outputs": [],
|
| 790 |
"source": [
|
| 791 |
-
"# Cell 11:
|
|
|
|
|
|
|
| 792 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 793 |
"from datasets import Dataset\n",
|
| 794 |
"\n",
|
| 795 |
-
"NUM_ROUNDS = 2\n",
|
| 796 |
"EPISODES_PER_ROUND = 6\n",
|
| 797 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 798 |
"\n",
|
| 799 |
"training_log = {\n",
|
| 800 |
-
" \"round\": [], \"
|
| 801 |
-
" \"
|
|
|
|
| 802 |
" \"n_training_samples\": [], \"train_loss\": [],\n",
|
| 803 |
"}\n",
|
| 804 |
"\n",
|
| 805 |
"t_start = time.time()\n",
|
| 806 |
-
"\n",
|
| 807 |
-
"
|
| 808 |
-
"
|
| 809 |
-
"
|
| 810 |
-
"
|
| 811 |
-
"\n",
|
| 812 |
-
"
|
| 813 |
-
"
|
| 814 |
-
"
|
| 815 |
-
"
|
| 816 |
-
"
|
| 817 |
-
"
|
| 818 |
-
"
|
| 819 |
-
"\n",
|
| 820 |
-
"
|
| 821 |
-
"
|
| 822 |
-
"
|
| 823 |
-
"
|
| 824 |
-
"
|
| 825 |
-
"
|
| 826 |
-
"
|
| 827 |
-
"
|
| 828 |
-
"
|
| 829 |
-
"
|
| 830 |
-
"
|
| 831 |
-
"
|
| 832 |
-
"
|
| 833 |
-
" kept
|
| 834 |
-
"
|
| 835 |
-
"
|
| 836 |
-
"\n",
|
| 837 |
-
"
|
| 838 |
-
"
|
| 839 |
-
"
|
| 840 |
-
"
|
| 841 |
-
"
|
| 842 |
-
"
|
| 843 |
-
"
|
| 844 |
-
"
|
| 845 |
-
"
|
| 846 |
-
"
|
| 847 |
-
"\n",
|
| 848 |
-
"
|
| 849 |
-
"
|
| 850 |
-
"
|
| 851 |
-
"
|
| 852 |
-
"
|
| 853 |
-
"
|
| 854 |
-
"
|
| 855 |
-
"\n",
|
| 856 |
-
"
|
| 857 |
-
"\n",
|
| 858 |
-
"
|
| 859 |
-
"
|
| 860 |
-
"
|
| 861 |
-
"
|
| 862 |
-
"
|
| 863 |
-
"
|
| 864 |
-
"
|
| 865 |
-
"
|
| 866 |
-
"
|
| 867 |
-
"
|
| 868 |
-
"
|
| 869 |
-
"
|
| 870 |
-
"
|
| 871 |
-
"
|
| 872 |
-
"\n",
|
| 873 |
-
"
|
| 874 |
-
"
|
| 875 |
-
"
|
| 876 |
-
"
|
| 877 |
-
"
|
| 878 |
-
"
|
| 879 |
-
"
|
| 880 |
-
"
|
| 881 |
-
"\n",
|
| 882 |
-
"
|
| 883 |
-
"
|
| 884 |
-
"
|
| 885 |
-
"
|
| 886 |
-
"
|
| 887 |
-
"
|
| 888 |
-
"
|
| 889 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
"\n",
|
| 891 |
"elapsed = time.time() - t_start\n",
|
| 892 |
-
"print(f\"\\
|
| 893 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 894 |
-
]
|
|
|
|
|
|
|
| 895 |
},
|
| 896 |
{
|
| 897 |
"cell_type": "markdown",
|
|
@@ -904,9 +946,7 @@
|
|
| 904 |
},
|
| 905 |
{
|
| 906 |
"cell_type": "code",
|
| 907 |
-
"execution_count": null,
|
| 908 |
"metadata": {},
|
| 909 |
-
"outputs": [],
|
| 910 |
"source": [
|
| 911 |
"# Cell 12: Run trained model (batched)\n",
|
| 912 |
"print(\"Running TRAINED model on all tasks (batched)...\")\n",
|
|
@@ -921,7 +961,9 @@
|
|
| 921 |
"print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 922 |
"for t in TASKS:\n",
|
| 923 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 924 |
-
]
|
|
|
|
|
|
|
| 925 |
},
|
| 926 |
{
|
| 927 |
"cell_type": "markdown",
|
|
@@ -932,37 +974,41 @@
|
|
| 932 |
},
|
| 933 |
{
|
| 934 |
"cell_type": "code",
|
| 935 |
-
"execution_count": null,
|
| 936 |
"metadata": {},
|
| 937 |
-
"outputs": [],
|
| 938 |
"source": [
|
| 939 |
-
"# Cell 13: Training curves\n",
|
| 940 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 941 |
-
"
|
|
|
|
|
|
|
| 942 |
"\n",
|
| 943 |
-
"axes[0].plot(
|
| 944 |
-
"axes[0].fill_between(
|
| 945 |
" training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
|
| 946 |
-
"
|
| 947 |
-
"axes[0].
|
|
|
|
|
|
|
| 948 |
"axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
|
| 949 |
"\n",
|
| 950 |
-
"axes[1].plot(
|
| 951 |
-
"
|
|
|
|
|
|
|
| 952 |
"axes[1].set_title('Training Loss', fontweight='bold')\n",
|
| 953 |
"axes[1].grid(True, alpha=0.3)\n",
|
| 954 |
"\n",
|
| 955 |
-
"fig.suptitle('Viraltest v2 — LoRA Training
|
| 956 |
"fig.tight_layout()\n",
|
| 957 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 958 |
"plt.show()"
|
| 959 |
-
]
|
|
|
|
|
|
|
| 960 |
},
|
| 961 |
{
|
| 962 |
"cell_type": "code",
|
| 963 |
-
"execution_count": null,
|
| 964 |
"metadata": {},
|
| 965 |
-
"outputs": [],
|
| 966 |
"source": [
|
| 967 |
"# Cell 14: Before vs After\n",
|
| 968 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
@@ -992,13 +1038,13 @@
|
|
| 992 |
"fig.tight_layout()\n",
|
| 993 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 994 |
"plt.show()"
|
| 995 |
-
]
|
|
|
|
|
|
|
| 996 |
},
|
| 997 |
{
|
| 998 |
"cell_type": "code",
|
| 999 |
-
"execution_count": null,
|
| 1000 |
"metadata": {},
|
| 1001 |
-
"outputs": [],
|
| 1002 |
"source": [
|
| 1003 |
"# Cell 15: Trajectory comparison\n",
|
| 1004 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
@@ -1022,7 +1068,9 @@
|
|
| 1022 |
"fig.tight_layout()\n",
|
| 1023 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 1024 |
"plt.show()"
|
| 1025 |
-
]
|
|
|
|
|
|
|
| 1026 |
},
|
| 1027 |
{
|
| 1028 |
"cell_type": "markdown",
|
|
@@ -1033,9 +1081,7 @@
|
|
| 1033 |
},
|
| 1034 |
{
|
| 1035 |
"cell_type": "code",
|
| 1036 |
-
"execution_count": null,
|
| 1037 |
"metadata": {},
|
| 1038 |
-
"outputs": [],
|
| 1039 |
"source": [
|
| 1040 |
"# Cell 16: Final summary\n",
|
| 1041 |
"print(\"=\" * 67)\n",
|
|
@@ -1057,8 +1103,10 @@
|
|
| 1057 |
"\n",
|
| 1058 |
"summary = {\n",
|
| 1059 |
" \"model\": MODEL_NAME,\n",
|
| 1060 |
-
" \"training\": \"LoRA SFT (
|
| 1061 |
-
" \"
|
|
|
|
|
|
|
| 1062 |
" \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1063 |
" \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1064 |
" \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
|
|
@@ -1072,13 +1120,13 @@
|
|
| 1072 |
"\n",
|
| 1073 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 1074 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 1075 |
-
]
|
|
|
|
|
|
|
| 1076 |
},
|
| 1077 |
{
|
| 1078 |
"cell_type": "code",
|
| 1079 |
-
"execution_count": null,
|
| 1080 |
"metadata": {},
|
| 1081 |
-
"outputs": [],
|
| 1082 |
"source": [
|
| 1083 |
"# Cell 17: Save adapter\n",
|
| 1084 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
@@ -1086,7 +1134,9 @@
|
|
| 1086 |
"tokenizer.save_pretrained(save_path)\n",
|
| 1087 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 1088 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 1089 |
-
]
|
|
|
|
|
|
|
| 1090 |
}
|
| 1091 |
],
|
| 1092 |
"metadata": {
|
|
@@ -1112,4 +1162,4 @@
|
|
| 1112 |
},
|
| 1113 |
"nbformat": 4,
|
| 1114 |
"nbformat_minor": 4
|
| 1115 |
-
}
|
|
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
|
|
|
| 28 |
"metadata": {},
|
|
|
|
| 29 |
"source": [
|
| 30 |
"# Cell 1: Install dependencies (quote versions — zsh treats `>` as redirect otherwise)\n",
|
| 31 |
"!pip install -q torch torchvision torchaudio\n",
|
|
|
|
| 34 |
"!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
|
| 35 |
"!pip install -q \"openenv-core[core]>=0.2.2\"\n",
|
| 36 |
"!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
|
| 37 |
+
],
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"outputs": []
|
| 40 |
},
|
| 41 |
{
|
| 42 |
"cell_type": "code",
|
|
|
|
| 43 |
"metadata": {},
|
|
|
|
| 44 |
"source": [
|
| 45 |
"# Cell 2: Resolve repo path (Colab: fresh clone. Local: auto-detect project root)\n",
|
| 46 |
"import os\n",
|
|
|
|
| 116 |
"print(f\"Branch: {REPO_BRANCH}\")\n",
|
| 117 |
"print(f\"Commit: {commit}\")\n",
|
| 118 |
"print(f\"Plots dir: {PLOTS_DIR}\")"
|
| 119 |
+
],
|
| 120 |
+
"execution_count": null,
|
| 121 |
+
"outputs": []
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "code",
|
|
|
|
| 125 |
"metadata": {},
|
|
|
|
| 126 |
"source": [
|
| 127 |
"# Cell 3: Imports (with runtime validation)\n",
|
| 128 |
"import json, random, time, textwrap, copy, os, sys\n",
|
|
|
|
| 154 |
"from models import ScheduledAction, ToolCall, ViraltestAction\n",
|
| 155 |
"from server.viraltest_environment import (\n",
|
| 156 |
" ViraltestEnvironment, TAG_POOL, TASK_HORIZON,\n",
|
| 157 |
+
" TOPIC_CATEGORIES, get_peak_hours,\n",
|
| 158 |
")\n",
|
| 159 |
"\n",
|
| 160 |
"ALL_TOPICS = [t for topics in TOPIC_CATEGORIES.values() for t in topics]\n",
|
|
|
|
| 176 |
"import ast\n",
|
| 177 |
"ast.parse(\"def _t(x: int) -> str: return f'{x}'\")\n",
|
| 178 |
"print(\"OK: ast.parse (syntax check)\")"
|
| 179 |
+
],
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"outputs": []
|
| 182 |
},
|
| 183 |
{
|
| 184 |
"cell_type": "markdown",
|
|
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
|
|
|
| 194 |
"metadata": {},
|
|
|
|
| 195 |
"source": [
|
| 196 |
"# Cell 4: Define heuristic agents + episode runner\n",
|
| 197 |
"_rng = random.Random(42)\n",
|
|
|
|
| 267 |
" \"rewards\": rewards, \"energies\": energies}\n",
|
| 268 |
"\n",
|
| 269 |
"print(\"Agents and episode runner defined.\")"
|
| 270 |
+
],
|
| 271 |
+
"execution_count": null,
|
| 272 |
+
"outputs": []
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "code",
|
|
|
|
| 276 |
"metadata": {},
|
|
|
|
| 277 |
"source": [
|
| 278 |
"# Cell 5: Run baselines (safe)\n",
|
| 279 |
"print(\"Running heuristic baselines (5 agents × 3 tasks)...\")\n",
|
|
|
|
| 308 |
"for name in BASELINE_AGENTS:\n",
|
| 309 |
" scores = [baseline_results[name][t][\"grader_score\"] for t in TASKS]\n",
|
| 310 |
" print(f\"{name:<14s} {scores[0]:>10.4f} {scores[1]:>12.4f} {scores[2]:>14.4f} {sum(scores)/3:>8.4f}\")"
|
| 311 |
+
],
|
| 312 |
+
"execution_count": null,
|
| 313 |
+
"outputs": []
|
| 314 |
},
|
| 315 |
{
|
| 316 |
"cell_type": "code",
|
|
|
|
| 317 |
"metadata": {},
|
|
|
|
| 318 |
"source": [
|
| 319 |
"# Cell 6: Baseline plots\n",
|
| 320 |
"fig, axes = plt.subplots(1, 3, figsize=(16, 5), sharey=True)\n",
|
|
|
|
| 332 |
"fig.tight_layout()\n",
|
| 333 |
"fig.savefig(f\"{PLOTS_DIR}/baseline_leaderboard.png\", dpi=150, bbox_inches='tight')\n",
|
| 334 |
"plt.show()"
|
| 335 |
+
],
|
| 336 |
+
"execution_count": null,
|
| 337 |
+
"outputs": []
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "markdown",
|
|
|
|
| 347 |
},
|
| 348 |
{
|
| 349 |
"cell_type": "code",
|
|
|
|
| 350 |
"metadata": {},
|
|
|
|
| 351 |
"source": [
|
| 352 |
"# Cell 7: Load model (Qwen2.5-3B bf16 on CUDA + flash-attn-2; fp16/fp32 fallback)\n",
|
| 353 |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
|
|
|
| 391 |
"print(f\"Model loaded. dtype={next(model.parameters()).dtype} device={next(model.parameters()).device}\")\n",
|
| 392 |
"if torch.cuda.is_available():\n",
|
| 393 |
" print(f\"CUDA memory: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 394 |
+
],
|
| 395 |
+
"execution_count": null,
|
| 396 |
+
"outputs": []
|
| 397 |
},
|
| 398 |
{
|
| 399 |
"cell_type": "code",
|
|
|
|
| 400 |
"metadata": {},
|
|
|
|
| 401 |
"source": [
|
| 402 |
"# Cell 8: LLM agent functions\n",
|
| 403 |
"_SYSTEM_BASE = textwrap.dedent(\"\"\"\\\n",
|
|
|
|
| 452 |
"SYSTEM_PROMPT_EVAL = SYSTEM_PROMPT\n",
|
| 453 |
"SYSTEM_PROMPT_TRAIN = SYSTEM_PROMPT\n",
|
| 454 |
"\n",
|
| 455 |
+
"SYSTEM_PROMPT_TIMING = SYSTEM_PROMPT + textwrap.dedent(\"\"\"\n",
|
| 456 |
+
"\n",
|
| 457 |
+
"FOCUS: optimise WHEN to post. Identify peak hours for the audience (use query_audience / query_trends).\n",
|
| 458 |
+
"2 posts/day at peak hours beats 4 posts at random hours.\"\"\")\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"SYSTEM_PROMPT_CONTENT = SYSTEM_PROMPT + textwrap.dedent(\"\"\"\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"FOCUS: optimise WHAT to post. Vary content_type and intent across the week,\n",
|
| 463 |
+
"pick differentiated topics, exploit trending tags.\"\"\")\n",
|
| 464 |
+
"\n",
|
| 465 |
"\n",
|
| 466 |
"_DAY_NAMES = [\"Mon\", \"Tue\", \"Wed\", \"Thu\", \"Fri\", \"Sat\", \"Sun\"]\n",
|
| 467 |
"\n",
|
|
|
|
| 480 |
" return out\n",
|
| 481 |
"\n",
|
| 482 |
"\n",
|
| 483 |
+
"def format_obs(obs, history=None, extra_hint=None):\n",
|
| 484 |
" day_name = _DAY_NAMES[obs.day_of_week] if 0 <= obs.day_of_week < 7 else \"?\"\n",
|
| 485 |
" signals_str = \"\"\n",
|
| 486 |
" signals = getattr(obs, \"engagement_signals\", None)\n",
|
|
|
|
| 494 |
" tool_str += f\" {tr.name}: {json.dumps(tr.data)}\\n\"\n",
|
| 495 |
" if not tool_str:\n",
|
| 496 |
" tool_str = \" (none — call query_* tools to discover)\\n\"\n",
|
| 497 |
+
" hint_str = f\"Coach hint: today's peak hours are {extra_hint}.\\n\" if extra_hint else \"\"\n",
|
| 498 |
" return (f\"Day: {day_name} | days_elapsed={obs.days_elapsed}\\n\"\n",
|
| 499 |
" f\"Energy: {obs.creator_energy:.2f} | Followers: {obs.follower_count}\\n\"\n",
|
| 500 |
" f\"Engagement: {obs.engagement_rate:.3f} | Queue: {obs.content_queue_size}\\n\"\n",
|
| 501 |
" f\"{signals_str}\"\n",
|
| 502 |
" f\"{_format_history(history)}\"\n",
|
| 503 |
" f\"Tool results:\\n{tool_str}\"\n",
|
| 504 |
+
" f\"{hint_str}\"\n",
|
| 505 |
" f\"Plan today's actions (JSON only):\")\n",
|
| 506 |
"\n",
|
| 507 |
"\n",
|
|
|
|
| 625 |
" return out\n",
|
| 626 |
"\n",
|
| 627 |
"\n",
|
| 628 |
+
"def run_llm_episodes_batched(mdl, tok, tasks_seeds, verbose=True, eval=False, system=None,\n",
|
| 629 |
+
" log_tag=None, hint_peak_hours=False, reward_mode=\"combined\"):\n",
|
| 630 |
" \"\"\"Run N episodes in parallel. ReAct two-pass: discovery -> dispatch -> planning.\"\"\"\n",
|
| 631 |
" sys_prompt = system or (SYSTEM_PROMPT_EVAL if eval else SYSTEM_PROMPT_TRAIN)\n",
|
| 632 |
" n = len(tasks_seeds)\n",
|
| 633 |
" envs = [ViraltestEnvironment() for _ in range(n)]\n",
|
| 634 |
+
" obss = [envs[i].reset(task=t, seed=s, reward_mode=reward_mode) for i, (t, s) in enumerate(tasks_seeds)]\n",
|
| 635 |
" rewards = [[] for _ in range(n)]\n",
|
| 636 |
" energies = [[obs.creator_energy] for obs in obss]\n",
|
| 637 |
" pairs = [[] for _ in range(n)]\n",
|
|
|
|
| 652 |
"\n",
|
| 653 |
" actions_by_idx = {i: rest_action for i in rest}\n",
|
| 654 |
" if active:\n",
|
| 655 |
+
" def _hint_for(i):\n",
|
| 656 |
+
" if not hint_peak_hours:\n",
|
| 657 |
+
" return None\n",
|
| 658 |
+
" hrs = get_peak_hours(obss[i].day_of_week, top_k=2)\n",
|
| 659 |
+
" return \", \".join(f\"{h:02d}:00\" for h in hrs) if hrs else None\n",
|
| 660 |
+
" base_prompts = [format_obs(obss[i], histories[i], extra_hint=_hint_for(i)) for i in active]\n",
|
| 661 |
"\n",
|
| 662 |
" disc_prompts = [p + DISCOVERY_SUFFIX for p in base_prompts]\n",
|
| 663 |
" disc_resps, ptok = _gen(disc_prompts)\n",
|
|
|
|
| 732 |
"\n",
|
| 733 |
"\n",
|
| 734 |
"print(\"LLM agent functions defined (batched).\")"
|
| 735 |
+
],
|
| 736 |
+
"execution_count": null,
|
| 737 |
+
"outputs": []
|
| 738 |
},
|
| 739 |
{
|
| 740 |
"cell_type": "markdown",
|
|
|
|
| 747 |
},
|
| 748 |
{
|
| 749 |
"cell_type": "code",
|
|
|
|
| 750 |
"metadata": {},
|
|
|
|
| 751 |
"source": [
|
| 752 |
"# Cell 9: Run untrained model (batched: all 3 tasks in parallel envs)\n",
|
| 753 |
"print(\"Running UNTRAINED base model on all tasks (batched)...\")\n",
|
|
|
|
| 761 |
"print(f\"BEFORE TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 762 |
"for t in TASKS:\n",
|
| 763 |
" print(f\" {t}: grader={before_results[t]['grader_score']:.4f}\")"
|
| 764 |
+
],
|
| 765 |
+
"execution_count": null,
|
| 766 |
+
"outputs": []
|
| 767 |
},
|
| 768 |
{
|
| 769 |
"cell_type": "markdown",
|
|
|
|
| 782 |
},
|
| 783 |
{
|
| 784 |
"cell_type": "code",
|
|
|
|
| 785 |
"metadata": {},
|
|
|
|
| 786 |
"source": [
|
| 787 |
"# Cell 10: Attach LoRA adapter\n",
|
| 788 |
"from peft import LoraConfig, get_peft_model, TaskType\n",
|
|
|
|
| 796 |
"model.enable_input_require_grads()\n",
|
| 797 |
"peft_model = get_peft_model(model, lora_config)\n",
|
| 798 |
"peft_model.print_trainable_parameters()"
|
| 799 |
+
],
|
| 800 |
+
"execution_count": null,
|
| 801 |
+
"outputs": []
|
| 802 |
},
|
| 803 |
{
|
| 804 |
"cell_type": "code",
|
|
|
|
| 805 |
"metadata": {},
|
|
|
|
| 806 |
"source": [
|
| 807 |
+
"# Cell 11: Two-phase training loop (timing -> content)\n",
|
| 808 |
+
"# Each phase: 3 rounds (round 0 = hardcoded peak-hours hint, rounds 1-2 = normal prompt).\n",
|
| 809 |
+
"# Adapter persisted to ./checkpoints/phaseN_adapter/ between phases.\n",
|
| 810 |
"from trl import SFTTrainer, SFTConfig\n",
|
| 811 |
"from datasets import Dataset\n",
|
| 812 |
"\n",
|
|
|
|
| 813 |
"EPISODES_PER_ROUND = 6\n",
|
| 814 |
+
"ROUNDS_PER_PHASE = 3\n",
|
| 815 |
+
"QUALITY_FLOOR = 0.0\n",
|
| 816 |
+
"\n",
|
| 817 |
+
"PHASES = [\n",
|
| 818 |
+
" {\"name\": \"phase1_timing\", \"reward_mode\": \"timing\", \"system\": SYSTEM_PROMPT_TIMING},\n",
|
| 819 |
+
" {\"name\": \"phase2_content\", \"reward_mode\": \"content\", \"system\": SYSTEM_PROMPT_CONTENT},\n",
|
| 820 |
+
"]\n",
|
| 821 |
"\n",
|
| 822 |
"training_log = {\n",
|
| 823 |
+
" \"phase\": [], \"round\": [], \"global_step\": [], \"use_hint\": [],\n",
|
| 824 |
+
" \"avg_episode_reward\": [], \"max_episode_reward\": [], \"min_episode_reward\": [],\n",
|
| 825 |
+
" \"avg_grader\": [], \"max_grader\": [],\n",
|
| 826 |
" \"n_training_samples\": [], \"train_loss\": [],\n",
|
| 827 |
"}\n",
|
| 828 |
"\n",
|
| 829 |
"t_start = time.time()\n",
|
| 830 |
+
"global_step = 0\n",
|
| 831 |
+
"\n",
|
| 832 |
+
"for phase in PHASES:\n",
|
| 833 |
+
" phase_name = phase[\"name\"]\n",
|
| 834 |
+
" sys_prompt = phase[\"system\"]\n",
|
| 835 |
+
" reward_mode = phase[\"reward_mode\"]\n",
|
| 836 |
+
" print(f\"\\n{'#' * 60}\\n# PHASE {phase_name} (reward_mode={reward_mode})\\n{'#' * 60}\")\n",
|
| 837 |
+
"\n",
|
| 838 |
+
" for round_idx in range(ROUNDS_PER_PHASE):\n",
|
| 839 |
+
" use_hint = (round_idx == 0)\n",
|
| 840 |
+
" print(f\"\\n{'=' * 60}\\n{phase_name} | ROUND {round_idx+1}/{ROUNDS_PER_PHASE} | hint={use_hint}\\n{'=' * 60}\")\n",
|
| 841 |
+
"\n",
|
| 842 |
+
" peft_model.eval()\n",
|
| 843 |
+
" tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + ep + round_idx * 10) for ep in range(EPISODES_PER_ROUND)]\n",
|
| 844 |
+
" t_roll = time.time()\n",
|
| 845 |
+
" results = run_llm_episodes_batched(\n",
|
| 846 |
+
" peft_model, tokenizer, tasks_seeds, verbose=False, eval=False,\n",
|
| 847 |
+
" system=sys_prompt, hint_peak_hours=use_hint, reward_mode=reward_mode,\n",
|
| 848 |
+
" log_tag=f\"{phase_name}_r{round_idx}\",\n",
|
| 849 |
+
" )\n",
|
| 850 |
+
" print(f\" Rollouts: {len(results)} eps × {TASK_HORIZON} days in {time.time()-t_roll:.1f}s\")\n",
|
| 851 |
+
"\n",
|
| 852 |
+
" all_pairs, episode_rewards, episode_graders = [], [], []\n",
|
| 853 |
+
" for ep, result in enumerate(results):\n",
|
| 854 |
+
" ep_reward = result[\"total_reward\"] + 2.0 * result[\"grader_score\"]\n",
|
| 855 |
+
" episode_rewards.append(ep_reward)\n",
|
| 856 |
+
" episode_graders.append(result[\"grader_score\"])\n",
|
| 857 |
+
" kept = 0\n",
|
| 858 |
+
" for pr in result[\"pairs\"]:\n",
|
| 859 |
+
" if not is_well_formed_response(pr[\"response\"]):\n",
|
| 860 |
+
" continue\n",
|
| 861 |
+
" text = (f\"<|im_start|>system\\n{sys_prompt}<|im_end|>\\n\"\n",
|
| 862 |
+
" f\"<|im_start|>user\\n{pr['prompt']}<|im_end|>\\n\"\n",
|
| 863 |
+
" f\"<|im_start|>assistant\\n{pr['response']}<|im_end|>\")\n",
|
| 864 |
+
" all_pairs.append({\"text\": text, \"reward\": pr[\"return\"]})\n",
|
| 865 |
+
" kept += 1\n",
|
| 866 |
+
" print(f\" ep {ep+1}/{EPISODES_PER_ROUND}: {result['task'].split('_')[-1]:>11s} \"\n",
|
| 867 |
+
" f\"grader={result['grader_score']:.4f} reward={ep_reward:.3f} kept={kept}/{len(result['pairs'])}\")\n",
|
| 868 |
+
"\n",
|
| 869 |
+
" avg_r = float(np.mean(episode_rewards))\n",
|
| 870 |
+
" avg_g = float(np.mean(episode_graders))\n",
|
| 871 |
+
" max_g = float(max(episode_graders))\n",
|
| 872 |
+
" print(f\" Avg reward={avg_r:.3f} Avg grader={avg_g:.4f} max_grader={max_g:.4f} | pairs={len(all_pairs)}\")\n",
|
| 873 |
+
"\n",
|
| 874 |
+
" loss = float(\"nan\")\n",
|
| 875 |
+
" n_filtered = 0\n",
|
| 876 |
+
" if not all_pairs:\n",
|
| 877 |
+
" print(\" WARNING: 0 well-formed pairs collected; skipping SFT.\")\n",
|
| 878 |
+
" elif max_g < QUALITY_FLOOR:\n",
|
| 879 |
+
" print(f\" SKIP SFT: no episode beat quality_floor={QUALITY_FLOOR:.2f}\")\n",
|
| 880 |
+
" else:\n",
|
| 881 |
+
" rets = np.array([p[\"reward\"] for p in all_pairs], dtype=float)\n",
|
| 882 |
+
" adv = (rets - rets.mean()) / (rets.std() + 1e-6)\n",
|
| 883 |
+
" filtered = [p for p, a in zip(all_pairs, adv) if a > 0.0]\n",
|
| 884 |
+
" if not filtered:\n",
|
| 885 |
+
" print(\" SKIP SFT: zero positive-advantage samples\")\n",
|
| 886 |
+
" else:\n",
|
| 887 |
+
" n_filtered = len(filtered)\n",
|
| 888 |
+
" print(f\" Kept {n_filtered}/{len(all_pairs)} positive-advantage samples\")\n",
|
| 889 |
+
" dataset = Dataset.from_list([{\"text\": p[\"text\"]} for p in filtered])\n",
|
| 890 |
+
" sft_config = SFTConfig(\n",
|
| 891 |
+
" output_dir=f\"./checkpoints/{phase_name}_r{round_idx}\",\n",
|
| 892 |
+
" num_train_epochs=1,\n",
|
| 893 |
+
" per_device_train_batch_size=2,\n",
|
| 894 |
+
" gradient_accumulation_steps=4,\n",
|
| 895 |
+
" learning_rate=5e-6,\n",
|
| 896 |
+
" warmup_steps=5,\n",
|
| 897 |
+
" logging_steps=1,\n",
|
| 898 |
+
" save_strategy=\"no\",\n",
|
| 899 |
+
" max_length=2048,\n",
|
| 900 |
+
" bf16=True,\n",
|
| 901 |
+
" report_to=\"none\",\n",
|
| 902 |
+
" )\n",
|
| 903 |
+
" peft_model.train()\n",
|
| 904 |
+
" trainer = SFTTrainer(\n",
|
| 905 |
+
" model=peft_model, processing_class=tokenizer,\n",
|
| 906 |
+
" train_dataset=dataset, args=sft_config,\n",
|
| 907 |
+
" )\n",
|
| 908 |
+
" train_result = trainer.train()\n",
|
| 909 |
+
" loss = float(train_result.training_loss)\n",
|
| 910 |
+
" print(f\" Training loss: {loss:.4f}\")\n",
|
| 911 |
+
"\n",
|
| 912 |
+
" global_step += 1\n",
|
| 913 |
+
" training_log[\"phase\"].append(phase_name)\n",
|
| 914 |
+
" training_log[\"round\"].append(round_idx + 1)\n",
|
| 915 |
+
" training_log[\"global_step\"].append(global_step)\n",
|
| 916 |
+
" training_log[\"use_hint\"].append(use_hint)\n",
|
| 917 |
+
" training_log[\"avg_episode_reward\"].append(round(float(avg_r), 3))\n",
|
| 918 |
+
" training_log[\"max_episode_reward\"].append(round(float(max(episode_rewards)), 3))\n",
|
| 919 |
+
" training_log[\"min_episode_reward\"].append(round(float(min(episode_rewards)), 3))\n",
|
| 920 |
+
" training_log[\"avg_grader\"].append(round(float(avg_g), 4))\n",
|
| 921 |
+
" training_log[\"max_grader\"].append(round(float(max(episode_graders)), 4))\n",
|
| 922 |
+
" training_log[\"n_training_samples\"].append(n_filtered)\n",
|
| 923 |
+
" training_log[\"train_loss\"].append(round(loss, 4) if loss == loss else float(\"nan\"))\n",
|
| 924 |
+
"\n",
|
| 925 |
+
" save_dir = f\"./checkpoints/{phase_name}_adapter\"\n",
|
| 926 |
+
" os.makedirs(save_dir, exist_ok=True)\n",
|
| 927 |
+
" peft_model.save_pretrained(save_dir)\n",
|
| 928 |
+
" tokenizer.save_pretrained(save_dir)\n",
|
| 929 |
+
" print(f\"\\n Saved {phase_name} adapter -> {save_dir}\")\n",
|
| 930 |
"\n",
|
| 931 |
"elapsed = time.time() - t_start\n",
|
| 932 |
+
"print(f\"\\nTwo-phase training complete in {elapsed/60:.1f} min\")\n",
|
| 933 |
"print(pd.DataFrame(training_log).to_string(index=False))"
|
| 934 |
+
],
|
| 935 |
+
"execution_count": null,
|
| 936 |
+
"outputs": []
|
| 937 |
},
|
| 938 |
{
|
| 939 |
"cell_type": "markdown",
|
|
|
|
| 946 |
},
|
| 947 |
{
|
| 948 |
"cell_type": "code",
|
|
|
|
| 949 |
"metadata": {},
|
|
|
|
| 950 |
"source": [
|
| 951 |
"# Cell 12: Run trained model (batched)\n",
|
| 952 |
"print(\"Running TRAINED model on all tasks (batched)...\")\n",
|
|
|
|
| 961 |
"print(f\"AFTER TRAINING (took {time.time()-t0:.1f}s):\")\n",
|
| 962 |
"for t in TASKS:\n",
|
| 963 |
" print(f\" {t}: grader={after_results[t]['grader_score']:.4f}\")"
|
| 964 |
+
],
|
| 965 |
+
"execution_count": null,
|
| 966 |
+
"outputs": []
|
| 967 |
},
|
| 968 |
{
|
| 969 |
"cell_type": "markdown",
|
|
|
|
| 974 |
},
|
| 975 |
{
|
| 976 |
"cell_type": "code",
|
|
|
|
| 977 |
"metadata": {},
|
|
|
|
| 978 |
"source": [
|
| 979 |
+
"# Cell 13: Training curves (two-phase)\n",
|
| 980 |
"fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 981 |
+
"steps = training_log[\"global_step\"]\n",
|
| 982 |
+
"phases = training_log[\"phase\"]\n",
|
| 983 |
+
"phase1_end = max([s for s, p in zip(steps, phases) if p == \"phase1_timing\"], default=0)\n",
|
| 984 |
"\n",
|
| 985 |
+
"axes[0].plot(steps, training_log[\"avg_grader\"], 'o-', color='#2196F3', lw=2, label='Avg grader')\n",
|
| 986 |
+
"axes[0].fill_between(steps, training_log[\"avg_grader\"],\n",
|
| 987 |
" training_log[\"max_grader\"], alpha=0.2, color='#2196F3')\n",
|
| 988 |
+
"if phase1_end > 0:\n",
|
| 989 |
+
" axes[0].axvline(phase1_end + 0.5, color='gray', ls='--', alpha=0.6, label='phase split')\n",
|
| 990 |
+
"axes[0].set_xlabel('Global step'); axes[0].set_ylabel('Grader Score')\n",
|
| 991 |
+
"axes[0].set_title('Grader Score (timing -> content)', fontweight='bold')\n",
|
| 992 |
"axes[0].legend(); axes[0].grid(True, alpha=0.3)\n",
|
| 993 |
"\n",
|
| 994 |
+
"axes[1].plot(steps, training_log[\"train_loss\"], 's-', color='#E53935', lw=2)\n",
|
| 995 |
+
"if phase1_end > 0:\n",
|
| 996 |
+
" axes[1].axvline(phase1_end + 0.5, color='gray', ls='--', alpha=0.6)\n",
|
| 997 |
+
"axes[1].set_xlabel('Global step'); axes[1].set_ylabel('Loss')\n",
|
| 998 |
"axes[1].set_title('Training Loss', fontweight='bold')\n",
|
| 999 |
"axes[1].grid(True, alpha=0.3)\n",
|
| 1000 |
"\n",
|
| 1001 |
+
"fig.suptitle('Viraltest v2 — Two-Phase LoRA Training (timing -> content)', fontsize=14, fontweight='bold')\n",
|
| 1002 |
"fig.tight_layout()\n",
|
| 1003 |
"fig.savefig(f'{PLOTS_DIR}/reward_curve.png', dpi=150, bbox_inches='tight')\n",
|
| 1004 |
"plt.show()"
|
| 1005 |
+
],
|
| 1006 |
+
"execution_count": null,
|
| 1007 |
+
"outputs": []
|
| 1008 |
},
|
| 1009 |
{
|
| 1010 |
"cell_type": "code",
|
|
|
|
| 1011 |
"metadata": {},
|
|
|
|
| 1012 |
"source": [
|
| 1013 |
"# Cell 14: Before vs After\n",
|
| 1014 |
"task_labels = [t.replace('monthly_', '').title() for t in TASKS]\n",
|
|
|
|
| 1038 |
"fig.tight_layout()\n",
|
| 1039 |
"fig.savefig(f'{PLOTS_DIR}/before_after.png', dpi=150, bbox_inches='tight')\n",
|
| 1040 |
"plt.show()"
|
| 1041 |
+
],
|
| 1042 |
+
"execution_count": null,
|
| 1043 |
+
"outputs": []
|
| 1044 |
},
|
| 1045 |
{
|
| 1046 |
"cell_type": "code",
|
|
|
|
| 1047 |
"metadata": {},
|
|
|
|
| 1048 |
"source": [
|
| 1049 |
"# Cell 15: Trajectory comparison\n",
|
| 1050 |
"fig, axes = plt.subplots(2, 3, figsize=(16, 8))\n",
|
|
|
|
| 1068 |
"fig.tight_layout()\n",
|
| 1069 |
"fig.savefig(f'{PLOTS_DIR}/training_trajectories.png', dpi=150, bbox_inches='tight')\n",
|
| 1070 |
"plt.show()"
|
| 1071 |
+
],
|
| 1072 |
+
"execution_count": null,
|
| 1073 |
+
"outputs": []
|
| 1074 |
},
|
| 1075 |
{
|
| 1076 |
"cell_type": "markdown",
|
|
|
|
| 1081 |
},
|
| 1082 |
{
|
| 1083 |
"cell_type": "code",
|
|
|
|
| 1084 |
"metadata": {},
|
|
|
|
| 1085 |
"source": [
|
| 1086 |
"# Cell 16: Final summary\n",
|
| 1087 |
"print(\"=\" * 67)\n",
|
|
|
|
| 1103 |
"\n",
|
| 1104 |
"summary = {\n",
|
| 1105 |
" \"model\": MODEL_NAME,\n",
|
| 1106 |
+
" \"training\": \"Two-phase LoRA SFT (timing -> content) with hardcoded peak-hours hint on round 1 of each phase\",\n",
|
| 1107 |
+
" \"phases\": [p[\"name\"] for p in PHASES],\n",
|
| 1108 |
+
" \"rounds_per_phase\": ROUNDS_PER_PHASE,\n",
|
| 1109 |
+
" \"episodes_per_round\": EPISODES_PER_ROUND,\n",
|
| 1110 |
" \"before\": {t: before_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1111 |
" \"after\": {t: after_results[t][\"grader_score\"] for t in TASKS},\n",
|
| 1112 |
" \"smart_heuristic\": {t: baseline_results[\"smart\"][t][\"grader_score\"] for t in TASKS},\n",
|
|
|
|
| 1120 |
"\n",
|
| 1121 |
"print(f\"\\nSaved to {PLOTS_DIR}/\")\n",
|
| 1122 |
"print(\"All results are from real LoRA weight updates on real environment runs.\")"
|
| 1123 |
+
],
|
| 1124 |
+
"execution_count": null,
|
| 1125 |
+
"outputs": []
|
| 1126 |
},
|
| 1127 |
{
|
| 1128 |
"cell_type": "code",
|
|
|
|
| 1129 |
"metadata": {},
|
|
|
|
| 1130 |
"source": [
|
| 1131 |
"# Cell 17: Save adapter\n",
|
| 1132 |
"save_path = \"./viraltest_trained_adapter\"\n",
|
|
|
|
| 1134 |
"tokenizer.save_pretrained(save_path)\n",
|
| 1135 |
"print(f\"LoRA adapter saved to {save_path}\")\n",
|
| 1136 |
"print(\"Load with: PeftModel.from_pretrained(base_model, save_path)\")"
|
| 1137 |
+
],
|
| 1138 |
+
"execution_count": null,
|
| 1139 |
+
"outputs": []
|
| 1140 |
}
|
| 1141 |
],
|
| 1142 |
"metadata": {
|
|
|
|
| 1162 |
},
|
| 1163 |
"nbformat": 4,
|
| 1164 |
"nbformat_minor": 4
|
| 1165 |
+
}
|