Spaces:
Paused
Paused
Commit ·
571f8a4
1
Parent(s): 28dd5a4
reduced steps to fit out free tier
Browse files- inference.py +47 -7
inference.py
CHANGED
|
@@ -26,8 +26,12 @@ from typing import Any, Dict, List, Optional
|
|
| 26 |
|
| 27 |
from openai import OpenAI
|
| 28 |
|
| 29 |
-
from viraltest import ViraltestAction, ViraltestEnv
|
| 30 |
-
from viraltest.server.viraltest_environment import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
DOCKER_IMAGE = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
|
| 33 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
|
|
@@ -45,6 +49,16 @@ SUCCESS_SCORE_THRESHOLD = 0.1
|
|
| 45 |
|
| 46 |
VALID_TAGS_TEXT = ", ".join(TAG_POOL)
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
SYSTEM_PROMPT = textwrap.dedent(f"""\
|
| 49 |
You are a social media content strategy agent. Each step is one full day (24 hours).
|
| 50 |
You receive the current day's state and must plan your actions for the entire day.
|
|
@@ -56,8 +70,8 @@ FORMAT (JSON only, no markdown, no prose):
|
|
| 56 |
{{
|
| 57 |
"scheduled_actions": [
|
| 58 |
{{"hour": 10, "action_type": "create_content"}},
|
| 59 |
-
{{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI
|
| 60 |
-
{{"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup
|
| 61 |
]
|
| 62 |
}}
|
| 63 |
|
|
@@ -65,6 +79,7 @@ RULES:
|
|
| 65 |
- hour: 0-23 (which hour of the day to perform the action)
|
| 66 |
- action_type: "post" or "create_content" (rest is automatic for unlisted hours)
|
| 67 |
- For posts: content_type (reel|story|carousel|text_post), topic, and tags are required
|
|
|
|
| 68 |
- Tags must be from this pool: {VALID_TAGS_TEXT}
|
| 69 |
- Max 5 tags per post
|
| 70 |
- Empty scheduled_actions means rest all day
|
|
@@ -77,9 +92,9 @@ and use create_content to build a content queue for cheaper posts later.""")
|
|
| 77 |
|
| 78 |
|
| 79 |
def should_force_rest_day(obs: Any) -> bool:
|
| 80 |
-
"""If energy is
|
| 81 |
energy = float(getattr(obs, "creator_energy", 1.0))
|
| 82 |
-
return energy <=
|
| 83 |
|
| 84 |
|
| 85 |
def log_start(task: str, env: str, model: str) -> None:
|
|
@@ -162,6 +177,30 @@ def parse_daily_plan(response_text: str) -> ViraltestAction:
|
|
| 162 |
return ViraltestAction(scheduled_actions=[])
|
| 163 |
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
def format_action_str(action: ViraltestAction) -> str:
|
| 166 |
"""Format daily plan for [STEP] log line."""
|
| 167 |
if not action.scheduled_actions:
|
|
@@ -201,7 +240,8 @@ def get_model_daily_plan(
|
|
| 201 |
stream=False,
|
| 202 |
)
|
| 203 |
text = (completion.choices[0].message.content or "").strip()
|
| 204 |
-
|
|
|
|
| 205 |
except Exception as exc:
|
| 206 |
err_str = str(exc)
|
| 207 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|
|
|
|
| 26 |
|
| 27 |
from openai import OpenAI
|
| 28 |
|
| 29 |
+
from viraltest import ScheduledAction, ViraltestAction, ViraltestEnv
|
| 30 |
+
from viraltest.server.viraltest_environment import (
|
| 31 |
+
TAG_POOL,
|
| 32 |
+
TASK_HORIZON,
|
| 33 |
+
TOPIC_CATEGORIES,
|
| 34 |
+
)
|
| 35 |
|
| 36 |
DOCKER_IMAGE = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME")
|
| 37 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
|
|
|
|
| 49 |
|
| 50 |
VALID_TAGS_TEXT = ", ".join(TAG_POOL)
|
| 51 |
|
| 52 |
+
# Flatten env topic categories — posts must use these exact strings (see sanitize_predefined_topics).
|
| 53 |
+
PREDEFINED_TOPICS: tuple[str, ...] = tuple(
|
| 54 |
+
topic for topics in TOPIC_CATEGORIES.values() for topic in topics
|
| 55 |
+
)
|
| 56 |
+
_TOPIC_CANONICAL: dict[str, str] = {t.lower(): t for t in PREDEFINED_TOPICS}
|
| 57 |
+
PREDEFINED_TOPICS_TEXT = ", ".join(PREDEFINED_TOPICS)
|
| 58 |
+
|
| 59 |
+
# When energy is at or below this level, skip the model and rest the full day (avoid burnout).
|
| 60 |
+
NEAR_ZERO_ENERGY_THRESHOLD = 0.25
|
| 61 |
+
|
| 62 |
SYSTEM_PROMPT = textwrap.dedent(f"""\
|
| 63 |
You are a social media content strategy agent. Each step is one full day (24 hours).
|
| 64 |
You receive the current day's state and must plan your actions for the entire day.
|
|
|
|
| 70 |
{{
|
| 71 |
"scheduled_actions": [
|
| 72 |
{{"hour": 10, "action_type": "create_content"}},
|
| 73 |
+
{{"hour": 12, "action_type": "post", "content_type": "reel", "topic": "AI tools", "tags": ["ai", "coding"]}},
|
| 74 |
+
{{"hour": 18, "action_type": "post", "content_type": "carousel", "topic": "startup life", "tags": ["startup", "growth"]}}
|
| 75 |
]
|
| 76 |
}}
|
| 77 |
|
|
|
|
| 79 |
- hour: 0-23 (which hour of the day to perform the action)
|
| 80 |
- action_type: "post" or "create_content" (rest is automatic for unlisted hours)
|
| 81 |
- For posts: content_type (reel|story|carousel|text_post), topic, and tags are required
|
| 82 |
+
- Topic must be exactly one of these strings (no paraphrasing): {PREDEFINED_TOPICS_TEXT}
|
| 83 |
- Tags must be from this pool: {VALID_TAGS_TEXT}
|
| 84 |
- Max 5 tags per post
|
| 85 |
- Empty scheduled_actions means rest all day
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
def should_force_rest_day(obs: Any) -> bool:
|
| 95 |
+
"""If energy is near zero, always submit an empty schedule (all rest)."""
|
| 96 |
energy = float(getattr(obs, "creator_energy", 1.0))
|
| 97 |
+
return energy <= NEAR_ZERO_ENERGY_THRESHOLD
|
| 98 |
|
| 99 |
|
| 100 |
def log_start(task: str, env: str, model: str) -> None:
|
|
|
|
| 177 |
return ViraltestAction(scheduled_actions=[])
|
| 178 |
|
| 179 |
|
| 180 |
+
def _resolve_predefined_topic(raw: Optional[str], obs: Any, hour: int) -> str:
|
| 181 |
+
"""Map a model-provided topic to a canonical string from TOPIC_CATEGORIES."""
|
| 182 |
+
if raw and raw.strip():
|
| 183 |
+
key = raw.strip().lower()
|
| 184 |
+
if key in _TOPIC_CANONICAL:
|
| 185 |
+
return _TOPIC_CANONICAL[key]
|
| 186 |
+
for tt in obs.trending_topics or []:
|
| 187 |
+
tl = (tt or "").strip().lower()
|
| 188 |
+
if tl in _TOPIC_CANONICAL:
|
| 189 |
+
return _TOPIC_CANONICAL[tl]
|
| 190 |
+
return PREDEFINED_TOPICS[hour % len(PREDEFINED_TOPICS)]
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def sanitize_predefined_topics(action: ViraltestAction, obs: Any) -> ViraltestAction:
|
| 194 |
+
"""Force every post topic to match the environment's predefined topic set."""
|
| 195 |
+
out: List[ScheduledAction] = []
|
| 196 |
+
for sa in action.scheduled_actions:
|
| 197 |
+
if sa.action_type == "post":
|
| 198 |
+
out.append(sa.model_copy(update={"topic": _resolve_predefined_topic(sa.topic, obs, sa.hour)}))
|
| 199 |
+
else:
|
| 200 |
+
out.append(sa)
|
| 201 |
+
return ViraltestAction(scheduled_actions=out)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
def format_action_str(action: ViraltestAction) -> str:
|
| 205 |
"""Format daily plan for [STEP] log line."""
|
| 206 |
if not action.scheduled_actions:
|
|
|
|
| 240 |
stream=False,
|
| 241 |
)
|
| 242 |
text = (completion.choices[0].message.content or "").strip()
|
| 243 |
+
plan = parse_daily_plan(text) if text else ViraltestAction(scheduled_actions=[])
|
| 244 |
+
return sanitize_predefined_topics(plan, obs)
|
| 245 |
except Exception as exc:
|
| 246 |
err_str = str(exc)
|
| 247 |
print(f"[DEBUG] Model request failed: {exc}", flush=True)
|