Commit · abe4587
Parent(s): 76b19bd
train: shrink to weekly horizon + bounded steps
TASK_HORIZON 30->7 and max_steps=7 to fit L40S 48GB without OOM.
Drop the horizon==30 hard-stop assertion now that 7 is canonical.
Made-with: Cursor
- server/viraltest_environment.py +6 -37
- training/train_grpo.ipynb +0 -0
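
For context on how the two knobs interact: TASK_HORIZON caps episode length on the environment side, while the trainer's max_steps caps rollout length. A minimal sketch, assuming a generic gym-style loop (the names run_episode, env, and policy are illustrative, not from this repo):

TASK_HORIZON = 7  # environment-side cap: 7 daily steps (one weekly cycle)
MAX_STEPS = 7     # trainer-side rollout cap, matched to the horizon

def run_episode(env, policy):
    # With both caps at 7, no rollout ever exceeds one simulated week.
    obs = env.reset()
    for _ in range(min(TASK_HORIZON, MAX_STEPS)):
        action = policy(obs)
        obs, reward, done, info = env.step(action)
        if done:
            break

Shorter rollouts mean shorter sequences held in memory per GRPO group, which is what lets training fit on a single L40S 48GB without OOM.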
server/viraltest_environment.py
CHANGED
@@ -92,7 +92,7 @@ _HEATMAP_GRID: Dict[int, List[float]] = {
 # Constants (research-backed, Tier 1-3 sources)
 # ---------------------------------------------------------------------------
 
-TASK_HORIZON = 30
+TASK_HORIZON = 7  # 7 daily steps (weekly cycle)
 
 # Socialinsider 2026 (31M posts)
 CONTENT_ENERGY_COST = {
@@ -166,7 +166,7 @@ COLLAB_GROWTH_K = 1.50 # cross-pollination follower spillover, scales (1 - o
 COLLAB_PARTNER_REPEAT_PENALTY = 0.7  # discount on multipliers when partner reused this brand
 COLLAB_FATIGUE_K = 0.3  # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
 
-API_BUDGET_INITIAL =
+API_BUDGET_INITIAL = 10**9  # effectively unlimited; rate-limit removed
 
 # Heuristic baselines for headline metric `vs_baseline_pct`.
 # Data-driven: loaded from `plots/training_summary.json["smart_heuristic"]` recorded by
@@ -191,18 +191,6 @@ HEURISTIC_BASELINE_SCORES: Dict[str, float] = _load_heuristic_baselines() or {
 # {"baseline": score, "shifted": score} so the second run can compute retention_under_shift.
 _SHIFT_HISTORY: Dict[str, Dict[str, float]] = {}
 
-# Tool costs
-TOOL_COSTS = {
-    "query_audience": 2,
-    "query_competitor": 2,
-    "query_tag_history": 1,
-    "query_trends": 1,
-    "predict_engagement": 3,
-    "draft_review": 3,
-    "query_creator_pool": 1,
-    "propose_collab": 5,
-}
-
 # ---------------------------------------------------------------------------
 # Brand state for multi-episode persistence
 # ---------------------------------------------------------------------------
@@ -413,9 +401,6 @@ class ViraltestEnvironment(Environment):
         totals = [h.get("total", 0.0) for h in window]
         return sum(totals) / len(totals) if totals else 0.0
 
-    def _get_tag_performance_dict(self) -> Dict[str, float]:
-        return {tag: self._tag_performance_avg(tag) for tag in self._unique_tags_used}
-
     # ----- competitors -----
 
     def _advance_competitors(self) -> None:
@@ -436,14 +421,6 @@ class ViraltestEnvironment(Environment):
                 "engagement": round(eng, 3), "hours_ago": 0,
             })
 
-    def _get_competitor_recent_posts(self, limit: int = 5) -> List[Dict[str, Any]]:
-        all_posts: List[Dict[str, Any]] = []
-        for comp in self._competitors:
-            for p in comp.recent_posts:
-                all_posts.append(p)
-        all_posts.sort(key=lambda x: x["hours_ago"])
-        return all_posts[:limit]
-
     def _get_competitor_avg_engagement(self) -> float:
         engagements = [p["engagement"] for comp in self._competitors for p in comp.recent_posts]
         return sum(engagements) / len(engagements) if engagements else 0.0
@@ -547,12 +524,6 @@ class ViraltestEnvironment(Environment):
     # ----- tool dispatcher -----
 
     def _dispatch_tool(self, tool: ToolCall) -> ToolResult:
-        cost = TOOL_COSTS.get(tool.name, 1)
-        if self._api_budget < cost:
-            return ToolResult(name=tool.name, success=False, error="rate_limit_exceeded", budget_remaining=self._api_budget)
-
-        self._api_budget -= cost
-
         if tool.name == "query_audience":
             seg_id = tool.arguments.get("segment_id", "")
             for seg in _AUDIENCE_DATA.get("segments", []):
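
The gate deleted above was the usual charge-then-refuse pattern. With TOOL_COSTS gone and API_BUDGET_INITIAL set to 10**9, the check could never fire in any realistic episode, so removing it is behavior-preserving. A standalone sketch of what the old gate did (the helper charge_tool is illustrative, not from this repo):

def charge_tool(tool_name: str, budget: int, costs: dict) -> int:
    # Charge a per-tool cost; unknown tools cost 1, mirroring the removed code.
    cost = costs.get(tool_name, 1)
    if budget < cost:
        raise RuntimeError("rate_limit_exceeded")
    return budget - cost

budget = 10**9  # effectively unlimited, as in the new constant
budget = charge_tool("propose_collab", budget, {"propose_collab": 5})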
@@ -752,9 +723,8 @@ class ViraltestEnvironment(Environment):
         self._shift_label = kwargs.get("shift_label")
         self._chain_id = kwargs.get("episode_chain_id")
 
-        if chain_id and chain_id in _BRAND_STORE:
-            assert TASK_HORIZON == 30
-            brand = _BRAND_STORE[chain_id]
+        if self._chain_id and self._chain_id in _BRAND_STORE:
+            brand = _BRAND_STORE[self._chain_id]
             self._unique_tags_used = set(brand.get("top_tags", []))
             self._unique_content_types = set(brand.get("dominant_types", []))
             self._collab_history = brand.get("collab_history", [])
@@ -870,10 +840,9 @@ class ViraltestEnvironment(Environment):
         grader_score = self._run_grader()
         headline = self._compute_headline_metrics(grader_score)
 
-
-        if chain_id:
+        if self._chain_id:
             top_tags = sorted(self._unique_tags_used, key=lambda t: self._tag_performance_avg(t), reverse=True)[:3]
-            _BRAND_STORE[chain_id] = {
+            _BRAND_STORE[self._chain_id] = {
                 "top_tags": list(top_tags),
                 "dominant_types": list(self._unique_content_types),
                 "collab_history": self._collab_history[-3:],
training/train_grpo.ipynb
CHANGED
The diff for this file is too large to render. See raw diff.