| """ |
| Viraltest — Edge Case & Scenario Tests (Daily Plan Format) |
| Runs scenarios for all 3 tasks using the new daily step format. |
| Each step = one full day. Agent submits a sparse daily plan. |
| """ |
|
|
| import random as stdlib_random |
| from typing import Callable, Dict, List, Tuple |
|
|
| from models import ScheduledAction, ViraltestAction |
| from server.viraltest_environment import ( |
| TAG_POOL, |
| ViraltestEnvironment, |
| ViraltestObservation, |
| ) |
|
|
# Task identifiers; each scenario is run once per entry via env.reset(task=...).
TASKS = [
    "monthly_engage",
    "monthly_strategic",
    "monthly_competitive",
]
# Seed handed to env.reset() for every episode.
SEED = 42

# Content formats the plan builders index into or sample from.
_CONTENT_TYPES = [
    "reel",
    "carousel",
    "story",
    "text_post",
]
# Topics sampled by plan_no_rest.
_TOPICS = [
    "AI tools",
    "fitness routine",
    "growth hacks",
    "travel guide",
    "food recipe",
    "wellness tips",
]
# Module-level RNG read by the randomised plan builders at call time.
_rng = stdlib_random.Random(99)
|
|
|
|
def _plan(actions: list) -> ViraltestAction:
    """Wrap a list of keyword dicts into a single ViraltestAction.

    Each dict in ``actions`` is expanded as keyword arguments to
    ``ScheduledAction``; an empty list yields a plan with no scheduled
    actions.
    """
    scheduled = []
    for spec in actions:
        scheduled.append(ScheduledAction(**spec))
    return ViraltestAction(scheduled_actions=scheduled)
|
|
|
|
def run_episode(
    task: str,
    plan_fn: Callable[[Dict, int], ViraltestAction],
    label: str,
) -> float:
    """Play one episode of at most 30 daily steps and print a report.

    Args:
        task: Task identifier forwarded to ``env.reset``.
        plan_fn: Called as ``plan_fn(observation_dict, day)`` with ``day``
            counting from 1; must return the action for that day.
        label: Accepted for the caller's convenience; not used here.

    Returns:
        The ``grader_score`` entry of the final observation's metadata,
        or 0.0 when metadata is missing or has no such key.
    """
    env = ViraltestEnvironment()
    obs = env.reset(task=task, seed=SEED)
    snapshot = obs.model_dump()
    rewards: List[float] = []
    min_energy = 1.0
    burned_out = False

    for day in range(1, 31):
        obs = env.step(plan_fn(snapshot, day))
        snapshot = obs.model_dump()
        # A missing reward counts as zero for the totals below.
        rewards.append(0.0 if obs.reward is None else obs.reward)
        if obs.creator_energy < min_energy:
            min_energy = obs.creator_energy
        if not obs.done:
            continue
        # Episode finished: it is a burnout only if energy ran out.
        if obs.creator_energy <= 0:
            burned_out = True
        break

    metadata = obs.metadata or {}
    score = metadata.get("grader_score", 0.0)
    total_steps = len(rewards)
    total_reward = sum(rewards)
    follower_delta = obs.follower_count - 10000

    print(f" Task: {task}")
    print(f" Days: {total_steps} | Done: {obs.done} | Burned out: {burned_out}")
    print(f" Score: {score:.4f} | Total reward: {total_reward:.2f} | Avg reward: {total_reward / total_steps:.3f}")
    print(f" Energy: {obs.creator_energy:.2f} | Min energy: {min_energy:.2f}")
    print(f" Followers: {obs.follower_count} (started 10000, delta {follower_delta:+d})")
    print(f" Engagement rate: {obs.engagement_rate:.4f}")
    print(f" Unique tags: {len(obs.tag_performance)}")
    print(f" Niche saturation: {obs.niche_saturation:.3f}")
    print()
    return score
|
|
|
|
def plan_always_rest(obs: dict, day: int) -> ViraltestAction:
    """Return a plan with no scheduled actions; both arguments are ignored."""
    nothing_scheduled: list = []
    return _plan(nothing_scheduled)
|
|
|
|
def plan_spam(obs: dict, day: int) -> ViraltestAction:
    """Schedule the same reel post at every hour from 0 through 23.

    Neither ``obs`` nor ``day`` influences the plan.
    """
    hourly_posts = []
    for hour in range(24):
        hourly_posts.append(
            {
                "hour": hour,
                "action_type": "post",
                "content_type": "reel",
                "topic": "AI tools",
                "tags": ["ai"],
            }
        )
    return _plan(hourly_posts)
|
|
|
|
def plan_smart(obs: dict, day: int) -> ViraltestAction:
    """Create content at hour 8, then post at hours 12 and 19.

    Both posts use the first trending topic from ``obs`` (falling back to
    "AI tools") and up to two trending tags plus one tag taken from
    ``TAG_POOL``. The pool tag and the content type are both selected by
    an index derived from ``day``, so they rotate from day to day.
    """
    topics = obs.get("trending_topics") or ["AI tools"]
    topic = topics[0]
    hot_tags = list((obs.get("trending_tags") or [])[:2])
    base = day * 2

    def _post(hour: int, offset: int) -> dict:
        # The same rotating index picks both the pool tag and content type.
        idx = base + offset
        return {
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[idx % 4],
            "topic": topic,
            "tags": hot_tags + [TAG_POOL[idx % len(TAG_POOL)]],
        }

    return _plan(
        [
            {"hour": 8, "action_type": "create_content"},
            _post(12, 0),
            _post(19, 1),
        ]
    )
|
|
|
|
def plan_no_rest(obs: dict, day: int) -> ViraltestAction:
    """Schedule a post at every hour 0-23 with randomised topic and tags.

    The content type cycles through ``_CONTENT_TYPES`` by hour; the topic
    and three tags are drawn from the module-level ``_rng``.
    """

    def _random_post(hour: int) -> dict:
        # Draw order (topic first, then tags) fixes the RNG stream.
        topic = _rng.choice(_TOPICS)
        tags = _rng.sample(TAG_POOL, 3)
        return {
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[hour % 4],
            "topic": topic,
            "tags": tags,
        }

    return _plan([_random_post(hour) for hour in range(24)])
|
|
|
|
def plan_minimal(obs: dict, day: int) -> ViraltestAction:
    """Schedule a single carousel post at hour 12.

    Uses the first trending topic from ``obs`` (falling back to
    "minimalism") and up to three trending tags.
    """
    topics = obs.get("trending_topics") or ["minimalism"]
    trending_tags = obs.get("trending_tags") or []
    noon_post = {
        "hour": 12,
        "action_type": "post",
        "content_type": "carousel",
        "topic": topics[0],
        "tags": list(trending_tags[:3]),
    }
    return _plan([noon_post])
|
|
|
|
def plan_tag_explorer(obs: dict, day: int) -> ViraltestAction:
    """Post at hours 10 and 18 using six consecutive tags from ``TAG_POOL``.

    The six-tag window starts at ``(day * 6) % len(TAG_POOL)`` and wraps
    around the pool; the first three tags go to the first post and the
    next three to the second. The topic is the first trending topic from
    ``obs`` (falling back to "devtools").
    """
    topics = obs.get("trending_topics") or ["devtools"]
    topic = topics[0]
    pool_size = len(TAG_POOL)
    start = (day * 6) % pool_size
    window = [TAG_POOL[(start + k) % pool_size] for k in range(6)]
    first_type = _CONTENT_TYPES[(day * 2) % 4]
    second_type = _CONTENT_TYPES[(day * 2 + 1) % 4]
    morning = {
        "hour": 10,
        "action_type": "post",
        "content_type": first_type,
        "topic": topic,
        "tags": window[:3],
    }
    evening = {
        "hour": 18,
        "action_type": "post",
        "content_type": second_type,
        "topic": topic,
        "tags": window[3:],
    }
    return _plan([morning, evening])
|
|
|
|
def plan_queue_optimizer(obs: dict, day: int) -> ViraltestAction:
    """Create content while the queue is short, otherwise post twice.

    On days 1-2, or whenever ``obs["content_queue_size"]`` (default 0) is
    below 2, the plan is three ``create_content`` actions at hours 8, 10
    and 14. Otherwise it posts at hours 12 and 19 with the first trending
    topic (falling back to "productivity"), up to two trending tags plus
    "growth", and content types chosen by ``day``.
    """
    topics = obs.get("trending_topics") or ["productivity"]
    topic = topics[0]
    tags = list((obs.get("trending_tags") or [])[:2]) + ["growth"]
    queue_size = obs.get("content_queue_size", 0)

    needs_stock = day < 3 or queue_size < 2
    if needs_stock:
        return _plan(
            [{"hour": hour, "action_type": "create_content"} for hour in (8, 10, 14)]
        )

    midday = {
        "hour": 12,
        "action_type": "post",
        "content_type": _CONTENT_TYPES[day % 4],
        "topic": topic,
        "tags": tags,
    }
    evening = {
        "hour": 19,
        "action_type": "post",
        "content_type": _CONTENT_TYPES[(day + 1) % 4],
        "topic": topic,
        "tags": tags,
    }
    return _plan([midday, evening])
|
|
|
|
def plan_double_peak(obs: dict, day: int) -> ViraltestAction:
    """Post a reel at hour 9 and a carousel at hour 15.

    Both posts share the first trending topic from ``obs`` (falling back
    to "peak time content") and up to three trending tags.
    """
    topics = obs.get("trending_topics") or ["peak time content"]
    topic = topics[0]
    tags = list((obs.get("trending_tags") or [])[:3])
    slots = ((9, "reel"), (15, "carousel"))
    return _plan(
        [
            {
                "hour": hour,
                "action_type": "post",
                "content_type": content_type,
                "topic": topic,
                "tags": tags,
            }
            for hour, content_type in slots
        ]
    )
|
|
|
|
def plan_random(obs: dict, day: int) -> ViraltestAction:
    """Build a sparse random plan using the module-level ``_rng``.

    For each hour 0-23 one random number is drawn: below 0.1 schedules a
    post with random content type, topic and two tags; from 0.1 up to
    (but excluding) 0.15 schedules ``create_content``; anything else
    leaves the hour empty.
    """
    topic_choices = ["random topic", "AI tools", "fitness", "travel"]
    picked = []
    for hour in range(24):
        roll = _rng.random()
        if roll >= 0.15:
            continue
        if roll >= 0.1:
            picked.append({"hour": hour, "action_type": "create_content"})
            continue
        # Draw order (type, topic, tags) fixes the RNG stream.
        content_type = _rng.choice(_CONTENT_TYPES)
        topic = _rng.choice(topic_choices)
        tags = _rng.sample(TAG_POOL, 2)
        picked.append(
            {
                "hour": hour,
                "action_type": "post",
                "content_type": content_type,
                "topic": topic,
                "tags": tags,
            }
        )
    return _plan(picked)
|
|
|
|
# Scenario table consumed by the __main__ driver.
# Each entry is (display name, plan builder, one-line description).
SCENARIOS: List[Tuple[str, Callable, str]] = [
    ("Always Rest", plan_always_rest, "Zero engagement, no growth, energy stays max"),
    ("Spam Post", plan_spam, "Post every hour, burns out instantly"),
    ("Smart Agent", plan_smart, "Peak hours, trending, varied types, energy management"),
    ("No Rest", plan_no_rest, "Post every hour, never rests, burns out"),
    ("Minimal Poster", plan_minimal, "1 carousel at noon per day"),
    ("Tag Explorer", plan_tag_explorer, "Rotates through tag pool for max discovery"),
    ("Queue Optimizer", plan_queue_optimizer, "Creates content first, posts from queue"),
    ("Double Peak", plan_double_peak, "Posts at 9am and 3pm"),
    ("Random Actor", plan_random, "Random sparse actions each day"),
]
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every scenario against every task, print a
    # detailed report per episode, then a summary table of grader scores.
    print("=" * 70)
    print("VIRALTEST — DAILY PLAN SCENARIO TESTS")
    print("=" * 70)
    print()

    # Scores returned by run_episode are kept so the summary table can be
    # printed without replaying every episode a second time.
    summary_rows: List[Tuple[str, List[float]]] = []

    for scenario_name, plan_fn, description in SCENARIOS:
        print("=" * 70)
        print(f"{scenario_name}")
        print(f" {description}")
        print("=" * 70)
        print()

        task_scores: List[float] = []
        for task in TASKS:
            # Rebind the module-level RNG so the randomised plan builders
            # start every episode from the same random stream.
            _rng = stdlib_random.Random(99)
            task_scores.append(run_episode(task, plan_fn, scenario_name))
        summary_rows.append((scenario_name, task_scores))

        print()

    print("=" * 70)
    print("SUMMARY TABLE")
    print("=" * 70)
    print()
    print(f"{'Scenario':<30} {'Engage':>8} {'Strategic':>10} {'Competitive':>12}")
    print("-" * 62)

    # Columns follow the order of TASKS (engage, strategic, competitive).
    for scenario_name, task_scores in summary_rows:
        print(f"{scenario_name:<30} {task_scores[0]:>8.4f} {task_scores[1]:>10.4f} {task_scores[2]:>12.4f}")

    print()
    print("EXPECTED: Smart/Queue/Tag Explorer should score highest.")
    print("Burnout agents (spam, no_rest) should score near 0 on strategic/competitive.")
|
|