anuragredbus commited on
Commit
2ae4336
·
1 Parent(s): 4abeb9a

reduced steps to fit out free tier

Browse files
Files changed (1) hide show
  1. inference.py +47 -7
inference.py CHANGED
@@ -26,8 +26,12 @@ from typing import Any, Dict, List, Optional
26
 
27
  from openai import OpenAI
28
 
29
- from viraltest import ViraltestAction, ViraltestEnv
30
- from viraltest.server.viraltest_environment import TAG_POOL, TASK_HORIZON
 
 
 
 
31
 
32
  DOCKER_IMAGE = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
33
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
@@ -45,6 +49,16 @@ SUCCESS_SCORE_THRESHOLD = 0.1
45
 
46
  VALID_TAGS_TEXT = ", ".join(TAG_POOL)
47
 
 
 
 
 
 
 
 
 
 
 
48
  SYSTEM_PROMPT = textwrap.dedent(f"""\
49
  You are a social media content strategy agent. Each step is one full day (24 hours).
50
  You receive the current day's state and must plan your actions for the entire day.
@@ -56,8 +70,8 @@ FORMAT (JSON only, no markdown, no prose):
56
  {{
57
  "scheduled_actions": [
58
  {{"hour": 10, "action_type": "create_content"}},
59
- {{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI trends", "tags": ["ai", "coding"]}},
60
- {{"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup tips", "tags": ["startup", "growth"]}}
61
  ]
62
  }}
63
 
@@ -65,6 +79,7 @@ RULES:
65
  - hour: 0-23 (which hour of the day to perform the action)
66
  - action_type: "post" or "create_content" (rest is automatic for unlisted hours)
67
  - For posts: content_type (reel|story|carousel|text_post), topic, and tags are required
 
68
  - Tags must be from this pool: {VALID_TAGS_TEXT}
69
  - Max 5 tags per post
70
  - Empty scheduled_actions means rest all day
@@ -77,9 +92,9 @@ and use create_content to build a content queue for cheaper posts later.""")
77
 
78
 
79
  def should_force_rest_day(obs: Any) -> bool:
80
- """If energy is critically low, submit an empty schedule (all rest)."""
81
  energy = float(getattr(obs, "creator_energy", 1.0))
82
- return energy <= 0.15
83
 
84
 
85
  def log_start(task: str, env: str, model: str) -> None:
@@ -162,6 +177,30 @@ def parse_daily_plan(response_text: str) -> ViraltestAction:
162
  return ViraltestAction(scheduled_actions=[])
163
 
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  def format_action_str(action: ViraltestAction) -> str:
166
  """Format daily plan for [STEP] log line."""
167
  if not action.scheduled_actions:
@@ -201,7 +240,8 @@ def get_model_daily_plan(
201
  stream=False,
202
  )
203
  text = (completion.choices[0].message.content or "").strip()
204
- return parse_daily_plan(text) if text else ViraltestAction(scheduled_actions=[])
 
205
  except Exception as exc:
206
  err_str = str(exc)
207
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
 
26
 
27
  from openai import OpenAI
28
 
29
+ from viraltest import ScheduledAction, ViraltestAction, ViraltestEnv
30
+ from viraltest.server.viraltest_environment import (
31
+ TAG_POOL,
32
+ TASK_HORIZON,
33
+ TOPIC_CATEGORIES,
34
+ )
35
 
36
  DOCKER_IMAGE = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
37
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
 
49
 
50
  VALID_TAGS_TEXT = ", ".join(TAG_POOL)
51
 
52
+ # Flatten env topic categories — posts must use these exact strings (see sanitize_predefined_topics).
53
+ PREDEFINED_TOPICS: tuple[str, ...] = tuple(
54
+ topic for topics in TOPIC_CATEGORIES.values() for topic in topics
55
+ )
56
+ _TOPIC_CANONICAL: dict[str, str] = {t.lower(): t for t in PREDEFINED_TOPICS}
57
+ PREDEFINED_TOPICS_TEXT = ", ".join(PREDEFINED_TOPICS)
58
+
59
+ # When energy is at or below this level, skip the model and rest the full day (avoid burnout).
60
+ NEAR_ZERO_ENERGY_THRESHOLD = 0.25
61
+
62
  SYSTEM_PROMPT = textwrap.dedent(f"""\
63
  You are a social media content strategy agent. Each step is one full day (24 hours).
64
  You receive the current day's state and must plan your actions for the entire day.
 
70
  {{
71
  "scheduled_actions": [
72
  {{"hour": 10, "action_type": "create_content"}},
73
+ {{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"]}},
74
+ {{"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup life", "tags": ["startup", "growth"]}}
75
  ]
76
  }}
77
 
 
79
  - hour: 0-23 (which hour of the day to perform the action)
80
  - action_type: "post" or "create_content" (rest is automatic for unlisted hours)
81
  - For posts: content_type (reel|story|carousel|text_post), topic, and tags are required
82
+ - Topic must be exactly one of these strings (no paraphrasing): {PREDEFINED_TOPICS_TEXT}
83
  - Tags must be from this pool: {VALID_TAGS_TEXT}
84
  - Max 5 tags per post
85
  - Empty scheduled_actions means rest all day
 
92
 
93
 
94
  def should_force_rest_day(obs: Any) -> bool:
95
+ """If energy is near zero, always submit an empty schedule (all rest)."""
96
  energy = float(getattr(obs, "creator_energy", 1.0))
97
+ return energy <= NEAR_ZERO_ENERGY_THRESHOLD
98
 
99
 
100
  def log_start(task: str, env: str, model: str) -> None:
 
177
  return ViraltestAction(scheduled_actions=[])
178
 
179
 
180
+ def _resolve_predefined_topic(raw: Optional[str], obs: Any, hour: int) -> str:
181
+ """Map a model-provided topic to a canonical string from TOPIC_CATEGORIES."""
182
+ if raw and raw.strip():
183
+ key = raw.strip().lower()
184
+ if key in _TOPIC_CANONICAL:
185
+ return _TOPIC_CANONICAL[key]
186
+ for tt in obs.trending_topics or []:
187
+ tl = (tt or "").strip().lower()
188
+ if tl in _TOPIC_CANONICAL:
189
+ return _TOPIC_CANONICAL[tl]
190
+ return PREDEFINED_TOPICS[hour % len(PREDEFINED_TOPICS)]
191
+
192
+
193
+ def sanitize_predefined_topics(action: ViraltestAction, obs: Any) -> ViraltestAction:
194
+ """Force every post topic to match the environment's predefined topic set."""
195
+ out: List[ScheduledAction] = []
196
+ for sa in action.scheduled_actions:
197
+ if sa.action_type == "post":
198
+ out.append(sa.model_copy(update={"topic": _resolve_predefined_topic(sa.topic, obs, sa.hour)}))
199
+ else:
200
+ out.append(sa)
201
+ return ViraltestAction(scheduled_actions=out)
202
+
203
+
204
  def format_action_str(action: ViraltestAction) -> str:
205
  """Format daily plan for [STEP] log line."""
206
  if not action.scheduled_actions:
 
240
  stream=False,
241
  )
242
  text = (completion.choices[0].message.content or "").strip()
243
+ plan = parse_daily_plan(text) if text else ViraltestAction(scheduled_actions=[])
244
+ return sanitize_predefined_topics(plan, obs)
245
  except Exception as exc:
246
  err_str = str(exc)
247
  print(f"[DEBUG] Model request failed: {exc}", flush=True)