# final-iteration / test_scenarios.py
# Commit 28dd5a4 by anuragredbus: "Viraltest OpenEnv: deploy to HF Space" (8.47 kB)
"""
Viraltest — Edge Case & Scenario Tests (Daily Plan Format)
Runs scenarios for all 3 tasks using the new daily step format.
Each step = one full day. Agent submits a sparse daily plan.
"""
import random as stdlib_random
from typing import Callable, Dict, List, Tuple
from models import ScheduledAction, ViraltestAction
from server.viraltest_environment import (
TAG_POOL,
ViraltestEnvironment,
ViraltestObservation,
)
# Task ids passed to ViraltestEnvironment.reset(task=...); every scenario runs on each.
TASKS = ["weekly_engage", "weekly_strategic", "weekly_competitive"]
# Seed passed to env.reset(seed=...) so every episode starts from the same state.
SEED = 42
# Content types rotated through (or randomly chosen) by the planners below.
_CONTENT_TYPES = ["reel", "carousel", "story", "text_post"]
# Topic pool sampled by plan_no_rest.
_TOPICS = ["AI tools", "fitness routine", "growth hacks", "travel guide", "food recipe", "wellness tips"]
# Shared RNG for the randomized planners (plan_no_rest, plan_random).
# The __main__ driver resets it to seed 99 before each episode for reproducibility.
_rng = stdlib_random.Random(99)
def _plan(actions: list) -> ViraltestAction:
    """Build a ViraltestAction from a list of action specs.

    Each spec is a dict unpacked as keyword arguments to ``ScheduledAction``;
    the resulting actions keep the input order.
    """
    scheduled = list(map(lambda spec: ScheduledAction(**spec), actions))
    return ViraltestAction(scheduled_actions=scheduled)
def run_episode(
    task: str,
    plan_fn: Callable[[Dict, int], ViraltestAction],
    label: str,
) -> float:
    """Play one episode (up to 7 daily steps) and print a summary of it.

    Args:
        task: Task id passed to ``ViraltestEnvironment.reset``.
        plan_fn: Planner called once per day as ``plan_fn(obs_dict, day)``,
            where ``obs_dict`` is ``model_dump()`` of the latest observation
            and ``day`` runs from 1 to 7. Returns that day's action.
        label: Scenario name. Not used here (the ``__main__`` driver prints
            it before calling); kept for interface compatibility.

    Returns:
        ``metadata["grader_score"]`` of the last observation, or 0.0 when the
        metadata is missing or has no such key.
    """
    env = ViraltestEnvironment()
    obs = env.reset(task=task, seed=SEED)
    # Take the baselines from the reset observation instead of assuming
    # 10000 followers and 1.0 energy, so the follower delta and the minimum
    # energy stay correct if the environment's initial state changes.
    start_followers = obs.follower_count
    min_energy = obs.creator_energy
    obs_dict = obs.model_dump()
    rewards: List[float] = []
    burned_out = False
    for day in range(1, 8):
        obs = env.step(plan_fn(obs_dict, day))
        obs_dict = obs.model_dump()
        rewards.append(obs.reward if obs.reward is not None else 0.0)
        min_energy = min(min_energy, obs.creator_energy)
        if obs.done:
            # Ending the episode with no energy left is reported as burnout.
            burned_out = obs.creator_energy <= 0
            break
    score = (obs.metadata or {}).get("grader_score", 0.0)
    total_reward = sum(rewards)
    # The loop always runs at least one step, but guard the division anyway.
    avg_reward = total_reward / max(len(rewards), 1)
    print(f" Task: {task}")
    print(f" Days: {len(rewards)} | Done: {obs.done} | Burned out: {burned_out}")
    print(f" Score: {score:.4f} | Total reward: {total_reward:.2f} | Avg reward: {avg_reward:.3f}")
    print(f" Energy: {obs.creator_energy:.2f} | Min energy: {min_energy:.2f}")
    print(f" Followers: {obs.follower_count} (started {start_followers}, delta {obs.follower_count - start_followers:+d})")
    print(f" Engagement rate: {obs.engagement_rate:.4f}")
    print(f" Unique tags: {len(obs.tag_performance)}")
    print(f" Niche saturation: {obs.niche_saturation:.3f}")
    print()
    return score
def plan_always_rest(obs: dict, day: int) -> ViraltestAction:
    """Schedule nothing for the day; ``obs`` and ``day`` are ignored."""
    nothing_scheduled: list = []
    return _plan(nothing_scheduled)
def plan_spam(obs: dict, day: int) -> ViraltestAction:
    """Post the same reel (topic "AI tools", tag "ai") at every hour 0-23."""
    hourly_posts = []
    for hour in range(24):
        hourly_posts.append({
            "hour": hour,
            "action_type": "post",
            "content_type": "reel",
            "topic": "AI tools",
            "tags": ["ai"],
        })
    return _plan(hourly_posts)
def plan_smart(obs: dict, day: int) -> ViraltestAction:
    """Create content at 08:00, then post twice (12:00 and 19:00).

    Both posts use the first entry of ``obs["trending_topics"]`` (falling back
    to "AI tools"). Each post carries up to two tags from
    ``obs["trending_tags"]`` plus one ``TAG_POOL`` tag. The pool tag and the
    content type rotate with ``day``, so consecutive days vary.
    """
    topic = (obs.get("trending_topics") or ["AI tools"])[0]
    trend_tags = list((obs.get("trending_tags") or [])[:2])
    base = day * 2
    posts = []
    for hour, offset in ((12, 0), (19, 1)):
        idx = base + offset
        posts.append({
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[idx % 4],
            "topic": topic,
            "tags": trend_tags + [TAG_POOL[idx % len(TAG_POOL)]],
        })
    return _plan([{"hour": 8, "action_type": "create_content"}, *posts])
def plan_no_rest(obs: dict, day: int) -> ViraltestAction:
    """Post at every hour 0-23 with rotating content types and random topics/tags.

    Draws from the module-level ``_rng``. For each hour it makes one topic
    choice from ``_TOPICS``, then samples three distinct tags from ``TAG_POOL``.
    """
    # Dict displays evaluate left to right, so the topic draw still precedes
    # the tag sample for every hour.
    return _plan([
        {
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[hour % 4],
            "topic": _rng.choice(_TOPICS),
            "tags": _rng.sample(TAG_POOL, 3),
        }
        for hour in range(24)
    ])
def plan_minimal(obs: dict, day: int) -> ViraltestAction:
    """Schedule a single carousel post at 12:00.

    The topic is the first ``obs["trending_topics"]`` entry (fallback
    "minimalism"). The tags are up to three entries of ``obs["trending_tags"]``.
    """
    topics = obs.get("trending_topics") or ["minimalism"]
    noon_post = {
        "hour": 12,
        "action_type": "post",
        "content_type": "carousel",
        "topic": topics[0],
        "tags": list((obs.get("trending_tags") or [])[:3]),
    }
    return _plan([noon_post])
def plan_tag_explorer(obs: dict, day: int) -> ViraltestAction:
    """Post at 10:00 and 18:00, each with three consecutive ``TAG_POOL`` tags.

    A six-tag window starts at ``day * 6`` (wrapping around the pool), so
    successive days walk through the pool. The first three tags go on the
    morning post and the last three on the evening post. Content types rotate
    with ``day``.
    """
    topic = (obs.get("trending_topics") or ["devtools"])[0]
    pool_size = len(TAG_POOL)
    offset = (day * 6) % pool_size
    window = [TAG_POOL[(offset + k) % pool_size] for k in range(6)]
    slots = (
        (10, day * 2, window[:3]),
        (18, day * 2 + 1, window[3:]),
    )
    return _plan([
        {
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[ct_idx % 4],
            "topic": topic,
            "tags": tags,
        }
        for hour, ct_idx, tags in slots
    ])
def plan_queue_optimizer(obs: dict, day: int) -> ViraltestAction:
    """Build up content first, then post twice a day.

    On days 1-2, or whenever ``obs["content_queue_size"]`` (default 0) is below
    2, only ``create_content`` actions are scheduled (08:00, 10:00, 14:00).
    Otherwise two posts go out at 12:00 and 19:00 on the first trending topic.
    Both posts share one tag list: up to two trending tags plus "growth".
    """
    topic = (obs.get("trending_topics") or ["productivity"])[0]
    tags = list((obs.get("trending_tags") or [])[:2]) + ["growth"]
    queued = obs.get("content_queue_size", 0)
    # `or` short-circuits: the queue size is only compared from day 3 onward.
    needs_content = day < 3 or queued < 2
    if needs_content:
        return _plan([{"hour": h, "action_type": "create_content"} for h in (8, 10, 14)])
    return _plan([
        {
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[(day + shift) % 4],
            "topic": topic,
            "tags": tags,
        }
        for hour, shift in ((12, 0), (19, 1))
    ])
def plan_double_peak(obs: dict, day: int) -> ViraltestAction:
    """Post a reel at 09:00 and a carousel at 15:00 on the first trending topic.

    Both posts share one tag list of up to three ``obs["trending_tags"]``.
    """
    topic = (obs.get("trending_topics") or ["peak time content"])[0]
    tags = list((obs.get("trending_tags") or [])[:3])
    schedule = (("reel", 9), ("carousel", 15))
    return _plan([
        {"hour": hour, "action_type": "post", "content_type": kind, "topic": topic, "tags": tags}
        for kind, hour in schedule
    ])
def plan_random(obs: dict, day: int) -> ViraltestAction:
    """Randomly schedule sparse actions across the day using ``_rng``.

    For each hour 0-23, one uniform draw decides the action:
    - below 0.1: a post with a random content type, a random topic, and two
      distinct ``TAG_POOL`` tags;
    - in [0.1, 0.15): a ``create_content`` action;
    - otherwise: nothing.
    """
    scheduled = []
    for hour in range(24):
        roll = _rng.random()
        if roll >= 0.15:
            continue
        if roll >= 0.1:
            scheduled.append({"hour": hour, "action_type": "create_content"})
            continue
        scheduled.append({
            "hour": hour,
            "action_type": "post",
            "content_type": _rng.choice(_CONTENT_TYPES),
            "topic": _rng.choice(["random topic", "AI tools", "fitness", "travel"]),
            "tags": _rng.sample(TAG_POOL, 2),
        })
    return _plan(scheduled)
# (display name, daily planner, description) for every scenario the __main__
# driver runs on each task. Descriptions record the author's expectation for
# each strategy; they are only printed, never checked.
SCENARIOS: List[Tuple[str, Callable, str]] = [
    ("Always Rest", plan_always_rest, "Zero engagement, no growth, energy stays max"),
    ("Spam Post", plan_spam, "Post every hour, burns out instantly"),
    ("Smart Agent", plan_smart, "Peak hours, trending, varied types, energy management"),
    ("No Rest", plan_no_rest, "Post every hour, never rests, burns out"),
    ("Minimal Poster", plan_minimal, "1 carousel at noon per day"),
    ("Tag Explorer", plan_tag_explorer, "Rotates through tag pool for max discovery"),
    ("Queue Optimizer", plan_queue_optimizer, "Creates content first, posts from queue"),
    ("Double Peak", plan_double_peak, "Posts at 9am and 3pm"),
    ("Random Actor", plan_random, "Random sparse actions each day"),
]
if __name__ == "__main__":
    print("=" * 70)
    print("VIRALTEST — DAILY PLAN SCENARIO TESTS")
    print("=" * 70)
    print()
    # (scenario name, grader scores in TASKS order), collected during the
    # detailed pass. The summary table reuses these results instead of
    # re-simulating every episode with a second, hand-copied episode loop,
    # which doubled runtime and could drift from run_episode.
    summary: List[Tuple[str, List[float]]] = []
    for scenario_name, plan_fn, description in SCENARIOS:
        print("=" * 70)
        print(f"{scenario_name}")
        print(f" {description}")
        print("=" * 70)
        print()
        task_scores: List[float] = []
        for task in TASKS:
            # Reset the shared planner RNG in place so every episode of the
            # randomized planners draws the same sequence (seed 99, matching
            # the module default). Reseeding the existing object keeps working
            # even if this code moves into a function, where rebinding `_rng`
            # would only create a local.
            _rng.seed(99)
            task_scores.append(run_episode(task, plan_fn, scenario_name))
        summary.append((scenario_name, task_scores))
        print()
    print("=" * 70)
    print("SUMMARY TABLE")
    print("=" * 70)
    print()
    print(f"{'Scenario':<30} {'Engage':>8} {'Strategic':>10} {'Competitive':>12}")
    print("-" * 62)
    for scenario_name, scores in summary:
        print(f"{scenario_name:<30} {scores[0]:>8.4f} {scores[1]:>10.4f} {scores[2]:>12.4f}")
    print()
    print("EXPECTED: Smart/Queue/Tag Explorer should score highest.")
    print("Burnout agents (spam, no_rest) should score near 0 on strategic/competitive.")