vaibhav12332112312 commited on
Commit
abe4587
·
1 Parent(s): 76b19bd

train: shrink to weekly horizon + bounded steps

Browse files

TASK_HORIZON 30->7 and max_steps=7 to fit L40S 48GB without OOM.
Drop the horizon==30 hard-stop assertion now that 7 is canonical.

Made-with: Cursor

server/viraltest_environment.py CHANGED
@@ -92,7 +92,7 @@ _HEATMAP_GRID: Dict[int, List[float]] = {
92
  # Constants (research-backed, Tier 1-3 sources)
93
  # ---------------------------------------------------------------------------
94
 
95
- TASK_HORIZON = 30 # 30 daily steps (monthly cycle)
96
 
97
  # Socialinsider 2026 (31M posts)
98
  CONTENT_ENERGY_COST = {
@@ -166,7 +166,7 @@ COLLAB_GROWTH_K = 1.50 # cross-pollination follower spillover, scales (1 - o
166
  COLLAB_PARTNER_REPEAT_PENALTY = 0.7 # discount on multipliers when partner reused this brand
167
  COLLAB_FATIGUE_K = 0.3 # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
168
 
169
- API_BUDGET_INITIAL = 100
170
 
171
  # Heuristic baselines for headline metric `vs_baseline_pct`.
172
  # Data-driven: loaded from `plots/training_summary.json["smart_heuristic"]` recorded by
@@ -191,18 +191,6 @@ HEURISTIC_BASELINE_SCORES: Dict[str, float] = _load_heuristic_baselines() or {
191
  # {"baseline": score, "shifted": score} so the second run can compute retention_under_shift.
192
  _SHIFT_HISTORY: Dict[str, Dict[str, float]] = {}
193
 
194
- # Tool costs
195
- TOOL_COSTS = {
196
- "query_audience": 2,
197
- "query_competitor": 2,
198
- "query_tag_history": 1,
199
- "query_trends": 1,
200
- "predict_engagement": 3,
201
- "draft_review": 3,
202
- "query_creator_pool": 1,
203
- "propose_collab": 5,
204
- }
205
-
206
  # ---------------------------------------------------------------------------
207
  # Brand state for multi-episode persistence
208
  # ---------------------------------------------------------------------------
@@ -413,9 +401,6 @@ class ViraltestEnvironment(Environment):
413
  totals = [h.get("total", 0.0) for h in window]
414
  return sum(totals) / len(totals) if totals else 0.0
415
 
416
- def _get_tag_performance_dict(self) -> Dict[str, float]:
417
- return {tag: self._tag_performance_avg(tag) for tag in self._unique_tags_used}
418
-
419
  # ----- competitors -----
420
 
421
  def _advance_competitors(self) -> None:
@@ -436,14 +421,6 @@ class ViraltestEnvironment(Environment):
436
  "engagement": round(eng, 3), "hours_ago": 0,
437
  })
438
 
439
- def _get_competitor_recent_posts(self, limit: int = 5) -> List[Dict[str, Any]]:
440
- all_posts: List[Dict[str, Any]] = []
441
- for comp in self._competitors:
442
- for p in comp.recent_posts:
443
- all_posts.append(p)
444
- all_posts.sort(key=lambda x: x["hours_ago"])
445
- return all_posts[:limit]
446
-
447
  def _get_competitor_avg_engagement(self) -> float:
448
  engagements = [p["engagement"] for comp in self._competitors for p in comp.recent_posts]
449
  return sum(engagements) / len(engagements) if engagements else 0.0
@@ -547,12 +524,6 @@ class ViraltestEnvironment(Environment):
547
  # ----- tool dispatcher -----
548
 
549
  def _dispatch_tool(self, tool: ToolCall) -> ToolResult:
550
- cost = TOOL_COSTS.get(tool.name, 1)
551
- if self._api_budget < cost:
552
- return ToolResult(name=tool.name, success=False, error="rate_limit_exceeded", budget_remaining=self._api_budget)
553
-
554
- self._api_budget -= cost
555
-
556
  if tool.name == "query_audience":
557
  seg_id = tool.arguments.get("segment_id", "")
558
  for seg in _AUDIENCE_DATA.get("segments", []):
@@ -752,9 +723,8 @@ class ViraltestEnvironment(Environment):
752
  self._shift_label = kwargs.get("shift_label")
753
  self._chain_id = kwargs.get("episode_chain_id")
754
 
755
- chain_id = kwargs.get("episode_chain_id")
756
- if chain_id and chain_id in _BRAND_STORE:
757
- brand = _BRAND_STORE[chain_id]
758
  self._unique_tags_used = set(brand.get("top_tags", []))
759
  self._unique_content_types = set(brand.get("dominant_types", []))
760
  self._collab_history = brand.get("collab_history", [])
@@ -870,10 +840,9 @@ class ViraltestEnvironment(Environment):
870
  grader_score = self._run_grader()
871
  headline = self._compute_headline_metrics(grader_score)
872
 
873
- chain_id = kwargs.get("episode_chain_id")
874
- if chain_id:
875
  top_tags = sorted(self._unique_tags_used, key=lambda t: self._tag_performance_avg(t), reverse=True)[:3]
876
- _BRAND_STORE[chain_id] = {
877
  "top_tags": list(top_tags),
878
  "dominant_types": list(self._unique_content_types),
879
  "collab_history": self._collab_history[-3:],
 
92
  # Constants (research-backed, Tier 1-3 sources)
93
  # ---------------------------------------------------------------------------
94
 
95
+ TASK_HORIZON = 7 # 7 daily steps (weekly cycle)
96
 
97
  # Socialinsider 2026 (31M posts)
98
  CONTENT_ENERGY_COST = {
 
166
  COLLAB_PARTNER_REPEAT_PENALTY = 0.7 # discount on multipliers when partner reused this brand
167
  COLLAB_FATIGUE_K = 0.3 # per-collab diminishing-returns factor: 1/(1+K*prior_collabs_this_episode)
168
 
169
+ API_BUDGET_INITIAL = 10**9 # effectively unlimited; rate-limit removed
170
 
171
  # Heuristic baselines for headline metric `vs_baseline_pct`.
172
  # Data-driven: loaded from `plots/training_summary.json["smart_heuristic"]` recorded by
 
191
  # {"baseline": score, "shifted": score} so the second run can compute retention_under_shift.
192
  _SHIFT_HISTORY: Dict[str, Dict[str, float]] = {}
193
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # ---------------------------------------------------------------------------
195
  # Brand state for multi-episode persistence
196
  # ---------------------------------------------------------------------------
 
401
  totals = [h.get("total", 0.0) for h in window]
402
  return sum(totals) / len(totals) if totals else 0.0
403
 
 
 
 
404
  # ----- competitors -----
405
 
406
  def _advance_competitors(self) -> None:
 
421
  "engagement": round(eng, 3), "hours_ago": 0,
422
  })
423
 
 
 
 
 
 
 
 
 
424
  def _get_competitor_avg_engagement(self) -> float:
425
  engagements = [p["engagement"] for comp in self._competitors for p in comp.recent_posts]
426
  return sum(engagements) / len(engagements) if engagements else 0.0
 
524
  # ----- tool dispatcher -----
525
 
526
  def _dispatch_tool(self, tool: ToolCall) -> ToolResult:
 
 
 
 
 
 
527
  if tool.name == "query_audience":
528
  seg_id = tool.arguments.get("segment_id", "")
529
  for seg in _AUDIENCE_DATA.get("segments", []):
 
723
  self._shift_label = kwargs.get("shift_label")
724
  self._chain_id = kwargs.get("episode_chain_id")
725
 
726
+ if self._chain_id and self._chain_id in _BRAND_STORE:
727
+ brand = _BRAND_STORE[self._chain_id]
 
728
  self._unique_tags_used = set(brand.get("top_tags", []))
729
  self._unique_content_types = set(brand.get("dominant_types", []))
730
  self._collab_history = brand.get("collab_history", [])
 
840
  grader_score = self._run_grader()
841
  headline = self._compute_headline_metrics(grader_score)
842
 
843
+ if self._chain_id:
 
844
  top_tags = sorted(self._unique_tags_used, key=lambda t: self._tag_performance_avg(t), reverse=True)[:3]
845
+ _BRAND_STORE[self._chain_id] = {
846
  "top_tags": list(top_tags),
847
  "dominant_types": list(self._unique_content_types),
848
  "collab_history": self._collab_history[-3:],
training/train_grpo.ipynb CHANGED
The diff for this file is too large to render. See raw diff