Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| FastAPI application for the Viraltest Environment. | |
| This module creates an HTTP server that exposes the ViraltestEnvironment | |
| over HTTP and WebSocket endpoints, compatible with EnvClient. | |
| Endpoints: | |
| - POST /reset: Reset the environment | |
| - POST /step: Execute an action | |
| - GET /state: Get current environment state | |
| - GET /schema: Get action/observation schemas | |
| - WS /ws: WebSocket endpoint for persistent sessions | |
| Usage: | |
| # Development (with auto-reload): | |
| uvicorn server.app:app --reload --host 0.0.0.0 --port 8000 | |
| # Production: | |
| uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4 | |
| # Or run directly: | |
| python -m server.app | |
| """ | |
| import json | |
| import os | |
| import random as stdlib_random | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from fastapi import Body | |
| from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse | |
# OpenEnv's server factory is mandatory; surface a clear install hint when it is missing.
try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
    ) from e

# OpenEnv Gradio UI lives at /web; Dockerfile sets this — default on for local parity with HF Spaces.
# setdefault only fills the variable in when the process environment does not define it at all.
os.environ.setdefault("ENABLE_WEB_INTERFACE", "true")

# Prefer package-relative imports; fall back to top-level module names when the
# file is executed outside the package (e.g. ``python -m server.app``).
try:
    from ..models import ScheduledAction, ViraltestAction, ViraltestObservation
    from .viraltest_environment import ViraltestEnvironment
except ImportError:
    from models import ScheduledAction, ViraltestAction, ViraltestObservation
    from server.viraltest_environment import ViraltestEnvironment
# Dashboard markup is read once at import time and served verbatim by the
# dashboard handler. Decode explicitly as UTF-8: the platform default encoding
# (e.g. cp1252 on Windows) could garble or reject non-ASCII characters.
_DASHBOARD_HTML = (Path(__file__).parent / "dashboard.html").read_text(encoding="utf-8")

# OpenEnv HTTP/WebSocket application exposing ViraltestEnvironment
# (reset/step/state/schema/ws endpoints per the module docstring).
app = create_app(
    ViraltestEnvironment,
    ViraltestAction,
    ViraltestObservation,
    env_name="viraltest",
    # Presumably caps live environment sessions at one — confirm against openenv docs.
    max_concurrent_envs=1,
)
# True when the OpenEnv Gradio UI (served under /web) is enabled via ENABLE_WEB_INTERFACE.
_gradio_web = os.getenv("ENABLE_WEB_INTERFACE", "false").lower() in {"true", "1", "yes"}

if not _gradio_web:
    # Without the Gradio UI, these handlers send clients to the custom dashboard.
    # NOTE(review): no route decorators are visible in this view — confirm how
    # (and at which paths) these are registered on `app`.
    def _dashboard_redirect():
        """Build a temporary (302) redirect to /dashboard."""
        return RedirectResponse("/dashboard", status_code=302)

    async def _root_redirect():
        return _dashboard_redirect()

    async def _web_disabled_redirect():
        return _dashboard_redirect()
# Environment instance used by the dashboard reset/step handlers; None until first use.
_dash_env: Optional[ViraltestEnvironment] = None
# JSON file (next to this module) holding summaries of past dashboard simulations.
_HISTORY_FILE = Path(__file__).with_name("simulation_history.json")
def _obs_to_dict(obs: ViraltestObservation) -> Dict[str, Any]:
    """Package an observation as ``{"observation", "reward", "done"}`` for JSON responses.

    ``observation`` is the full ``model_dump()`` of *obs*; ``reward`` and ``done``
    are copied from the observation's own attributes.
    """
    dumped = obs.model_dump()
    return {"observation": dumped, "reward": obs.reward, "done": obs.done}
def _load_history() -> List[Dict[str, Any]]:
    """Return the saved simulation history, or ``[]`` when it is unavailable.

    ``_HISTORY_FILE`` is expected to contain a JSON array of run-summary
    objects. A missing, unreadable, non-UTF-8, malformed, or wrongly-shaped file
    yields an empty list instead of raising. Non-object rows are dropped, so
    callers can treat every entry as a dict.
    """
    try:
        # Read directly instead of checking exists() first: the file can vanish in between.
        data = json.loads(_HISTORY_FILE.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing file; ValueError covers both
        # JSONDecodeError and UnicodeDecodeError.
        return []
    if not isinstance(data, list):
        return []
    return [row for row in data if isinstance(row, dict)]
def _save_history_entry(entry: Dict[str, Any], max_entries: int = 100) -> None:
    """Append *entry* to the history file, keeping only the newest runs.

    Args:
        entry: JSON-serialisable run summary to persist.
        max_entries: Maximum number of entries retained; the oldest are
            dropped first. Defaults to 100. A value <= 0 keeps none.

    The file is replaced atomically (temp file in the same directory, then
    ``os.replace``), so a crash or a concurrent reader never sees half-written
    JSON.

    Raises:
        OSError: if the temp file cannot be written or renamed.
    """
    import tempfile  # local: only needed for the atomic write below

    history = _load_history()
    history.append(entry)
    # Use a non-negative start index: ``history[-0:]`` would keep every entry.
    history = history[max(0, len(history) - max_entries):]
    # Serialise before touching the filesystem so encoding errors leave no temp file.
    payload = json.dumps(history, indent=2)

    fd, tmp_name = tempfile.mkstemp(
        dir=_HISTORY_FILE.parent, prefix=_HISTORY_FILE.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(payload)
        os.replace(tmp_name, _HISTORY_FILE)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
async def dashboard():
    """Serve the dashboard page markup that was loaded into ``_DASHBOARD_HTML`` at import.

    NOTE(review): no route decorator is visible in this view; confirm how this is mounted.
    """
    page = _DASHBOARD_HTML
    return page
async def dashboard_history():
    """Return the saved simulation runs, back-filling any missing descriptions.

    Each row is shallow-copied. When a row has no truthy ``description`` but its
    ``scenario_id`` is a known SCENARIOS key, that scenario's description is used.
    """
    def _with_description(row):
        enriched = dict(row)
        if enriched.get("description"):
            return enriched
        scenario_key = enriched.get("scenario_id")
        if scenario_key and scenario_key in SCENARIOS:
            enriched["description"] = SCENARIOS[scenario_key][1]
        return enriched

    return [_with_description(row) for row in _load_history()]
async def dashboard_history_clear():
    """Delete the simulation history file and report ``{"status": "cleared"}``.

    Succeeds whether or not the file exists. ``missing_ok=True`` replaces the
    former exists()/unlink() pair, which could raise FileNotFoundError if the
    file disappeared between the two calls (e.g. two concurrent clears).
    """
    _HISTORY_FILE.unlink(missing_ok=True)
    return {"status": "cleared"}
async def dashboard_reset(body: Dict[str, Any] = Body(default={})):
    """Start a fresh dashboard episode and return its initial observation.

    Replaces the module-level ``_dash_env`` with a new ViraltestEnvironment,
    resets it with ``body["task"]`` (``"weekly_engage"`` when absent), and
    returns the ``{"observation", "reward", "done"}`` dict.
    """
    global _dash_env
    fresh_env = ViraltestEnvironment()
    _dash_env = fresh_env
    requested_task = body.get("task", "weekly_engage")
    return _obs_to_dict(fresh_env.reset(task=requested_task))
async def dashboard_step(body: Dict[str, Any] = Body(...)):
    """Apply one action to the dashboard episode and return the new observation.

    The action may arrive wrapped (``{"action": {...}}``) or as the bare action
    payload. If no episode exists yet, a default one is created and reset first.
    Returns the ``{"observation", "reward", "done"}`` dict.

    A payload that is not a JSON object, or that ViraltestAction rejects,
    produces HTTP 422 with an ``error`` message. Previously it raised an
    unhandled exception, which became HTTP 500. Validation runs before any
    environment is created.
    """
    global _dash_env
    action_data = body.get("action", body)
    if not isinstance(action_data, dict):
        return JSONResponse(status_code=422, content={"error": "action must be a JSON object"})
    try:
        action = ViraltestAction(**action_data)
    except (TypeError, ValueError) as exc:
        # Presumably a pydantic model; its ValidationError subclasses ValueError.
        return JSONResponse(status_code=422, content={"error": f"Invalid action: {exc}"})
    if _dash_env is None:
        _dash_env = ViraltestEnvironment()
        _dash_env.reset()
    return _obs_to_dict(_dash_env.step(action))
# Shared tag vocabulary; same relative/absolute import fallback as the models.
try:
    from .viraltest_environment import TAG_POOL
except ImportError:
    from server.viraltest_environment import TAG_POOL

# Seeded RNG used by the randomised scenario plans (dashboard_simulate re-seeds it per run).
_SIM_RNG = stdlib_random.Random(99)
# Content formats the plans rotate through; plan code indexes this with ``% 4``.
_CONTENT_TYPES = ["reel", "carousel", "story", "text_post"]
# Topic pool sampled by the "No Rest" plan.
_TOPICS = ["AI tools", "fitness routine", "growth hacks", "travel guide", "food recipe", "wellness tips"]
def _make_daily_plan(actions: list) -> ViraltestAction:
    """Wrap a list of ScheduledAction keyword dicts into a single ViraltestAction.

    Each element of *actions* is unpacked into ``ScheduledAction(**spec)``;
    an empty list gives a plan with no scheduled actions.
    """
    scheduled = [ScheduledAction(**spec) for spec in actions]
    return ViraltestAction(scheduled_actions=scheduled)
def _plan_always_rest(obs: dict, day: int) -> ViraltestAction:
    """Schedule nothing at all; both arguments are ignored."""
    nothing_scheduled: list = []
    return _make_daily_plan(nothing_scheduled)
def _plan_spam(obs: dict, day: int) -> ViraltestAction:
    """Post an identical "AI tools" reel tagged ``ai`` at every hour 0-23."""
    def _reel_at(hour):
        return {"hour": hour, "action_type": "post", "content_type": "reel",
                "topic": "AI tools", "tags": ["ai"]}

    return _make_daily_plan([_reel_at(hour) for hour in range(24)])
def _plan_smart(obs: dict, day: int) -> ViraltestAction:
    """Create content at 08:00, then post on the top trend at 12:00 and 19:00.

    Each post carries up to two trending tags plus one TAG_POOL tag that rotates
    with *day*. Content types also rotate through _CONTENT_TYPES by day.
    """
    topic = (obs.get("trending_topics") or ["AI tools"])[0]
    trend_tags = list((obs.get("trending_tags") or [])[:2])
    base = day * 2
    pool_size = len(TAG_POOL)
    midday_tag, evening_tag = TAG_POOL[base % pool_size], TAG_POOL[(base + 1) % pool_size]
    midday_type, evening_type = _CONTENT_TYPES[base % 4], _CONTENT_TYPES[(base + 1) % 4]
    return _make_daily_plan([
        {"hour": 8, "action_type": "create_content"},
        {"hour": 12, "action_type": "post", "content_type": midday_type,
         "topic": topic, "tags": trend_tags + [midday_tag]},
        {"hour": 19, "action_type": "post", "content_type": evening_type,
         "topic": topic, "tags": trend_tags + [evening_tag]},
    ])
def _plan_no_rest(obs: dict, day: int) -> ViraltestAction:
    """Post at all 24 hours, drawing a random topic and three random tags per hour.

    Randomness comes from the module-level ``_SIM_RNG``. For each hour it calls
    ``choice`` (topic) and then ``sample`` (tags), in that order.
    """
    schedule = []
    for hour in range(24):
        topic = _SIM_RNG.choice(_TOPICS)
        tags = _SIM_RNG.sample(TAG_POOL, 3)
        schedule.append({
            "hour": hour,
            "action_type": "post",
            "content_type": _CONTENT_TYPES[hour % 4],
            "topic": topic,
            "tags": tags,
        })
    return _make_daily_plan(schedule)
def _plan_minimal(obs: dict, day: int) -> ViraltestAction:
    """One noon carousel on the top trending topic with up to three trending tags."""
    topics = obs.get("trending_topics") or ["minimalism"]
    tags = list((obs.get("trending_tags") or [])[:3])
    noon_post = {"hour": 12, "action_type": "post", "content_type": "carousel",
                 "topic": topics[0], "tags": tags}
    return _make_daily_plan([noon_post])
def _plan_reel_max(obs: dict, day: int) -> ViraltestAction:
    """Two reels (12:00 and 14:00) on the top trend, sharing up to three trending tags."""
    topic = (obs.get("trending_topics") or ["viral content"])[0]
    tags = list((obs.get("trending_tags") or [])[:3])
    return _make_daily_plan([
        {"hour": hour, "action_type": "post", "content_type": "reel", "topic": topic, "tags": tags}
        for hour in (12, 14)
    ])
def _plan_split_schedule(obs: dict, day: int) -> ViraltestAction:
    """Morning carousel (09:00) and evening reel (19:00), tagged with trends plus ``tips``."""
    topic = (obs.get("trending_topics") or ["daily content"])[0]
    tags = list((obs.get("trending_tags") or [])[:2]) + ["tips"]
    slots = ((9, "carousel"), (19, "reel"))
    return _make_daily_plan([
        {"hour": hour, "action_type": "post", "content_type": kind, "topic": topic, "tags": tags}
        for hour, kind in slots
    ])
def _plan_double_peak(obs: dict, day: int) -> ViraltestAction:
    """A reel at 09:00 and a carousel at 15:00 on the top trend, up to three trending tags."""
    topic = (obs.get("trending_topics") or ["peak time content"])[0]
    tags = list((obs.get("trending_tags") or [])[:3])
    slots = ((9, "reel"), (15, "carousel"))
    return _make_daily_plan([
        {"hour": hour, "action_type": "post", "content_type": kind, "topic": topic, "tags": tags}
        for hour, kind in slots
    ])
def _plan_tag_explorer(obs: dict, day: int) -> ViraltestAction:
    """Post at 10:00 and 18:00, each with three fresh TAG_POOL tags.

    The six tags are a consecutive (wrap-around) window of TAG_POOL starting at
    ``day * 6``. The first three go on the morning post and the rest on the
    evening post. Content types rotate by day.
    """
    topic = (obs.get("trending_topics") or ["devtools"])[0]
    pool_size = len(TAG_POOL)
    offset = (day * 6) % pool_size
    window = [TAG_POOL[(offset + i) % pool_size] for i in range(6)]
    morning_tags, evening_tags = window[:3], window[3:]
    base = day * 2
    return _make_daily_plan([
        {"hour": 10, "action_type": "post", "content_type": _CONTENT_TYPES[base % 4],
         "topic": topic, "tags": morning_tags},
        {"hour": 18, "action_type": "post", "content_type": _CONTENT_TYPES[(base + 1) % 4],
         "topic": topic, "tags": evening_tags},
    ])
def _plan_queue_optimizer(obs: dict, day: int) -> ViraltestAction:
    """Bank content first, then post from the queue.

    While ``day < 2`` or fewer than two items are queued (``content_queue_size``,
    default 0), schedule create_content at 08:00, 10:00 and 14:00. Otherwise
    post on the top trend at 12:00 and 19:00 with consecutive content types.
    """
    if day < 2 or obs.get("content_queue_size", 0) < 2:
        return _make_daily_plan([{"hour": hour, "action_type": "create_content"} for hour in (8, 10, 14)])

    topic = (obs.get("trending_topics") or ["productivity"])[0]
    tags = list((obs.get("trending_tags") or [])[:2]) + ["growth"]
    return _make_daily_plan([
        {"hour": 12, "action_type": "post", "content_type": _CONTENT_TYPES[day % 4],
         "topic": topic, "tags": tags},
        {"hour": 19, "action_type": "post", "content_type": _CONTENT_TYPES[(day + 1) % 4],
         "topic": topic, "tags": tags},
    ])
def _plan_weekend(obs: dict, day: int) -> ViraltestAction:
    """Two reels (11:00, 17:00) only when ``day_of_week`` is 5 or 6; otherwise nothing.

    Presumably 5/6 are Saturday/Sunday (the scenario is described as posting on Sat/Sun).
    """
    is_weekend = obs.get("day_of_week", 0) in (5, 6)
    if not is_weekend:
        return _make_daily_plan([])
    topic = (obs.get("trending_topics") or ["travel"])[0]
    tags = list((obs.get("trending_tags") or [])[:3])
    return _make_daily_plan([
        {"hour": hour, "action_type": "post", "content_type": "reel", "topic": topic, "tags": tags}
        for hour in (11, 17)
    ])
def _plan_weekday_only(obs: dict, day: int) -> ViraltestAction:
    """One noon post when ``day_of_week`` < 5; schedules nothing on days 5 and above."""
    if obs.get("day_of_week", 0) >= 5:
        return _make_daily_plan([])
    topic = (obs.get("trending_topics") or ["weekday content"])[0]
    tags = list((obs.get("trending_tags") or [])[:2]) + ["productivity"]
    noon_post = {"hour": 12, "action_type": "post", "content_type": _CONTENT_TYPES[day % 4],
                 "topic": topic, "tags": tags}
    return _make_daily_plan([noon_post])
def _plan_random(obs: dict, day: int) -> ViraltestAction:
    """Roll ``_SIM_RNG.random()`` once per hour and act on the result.

    A roll below 0.1 posts a random content type, topic and two tags. A roll in
    [0.1, 0.15) schedules create_content. Any higher roll leaves the hour empty.
    """
    schedule = []
    for hour in range(24):
        roll = _SIM_RNG.random()
        if roll >= 0.15:
            continue
        if roll >= 0.1:
            schedule.append({"hour": hour, "action_type": "create_content"})
            continue
        content_type = _SIM_RNG.choice(_CONTENT_TYPES)
        topic = _SIM_RNG.choice(["random topic", "AI tools", "fitness", "travel"])
        tags = _SIM_RNG.sample(TAG_POOL, 2)
        schedule.append({"hour": hour, "action_type": "post", "content_type": content_type,
                         "topic": topic, "tags": tags})
    return _make_daily_plan(schedule)
def _plan_sleep_conscious(obs: dict, day: int) -> ViraltestAction:
    """Post at 10:00 and create content at 16:00; every other hour is left free."""
    topic = (obs.get("trending_topics") or ["wellness"])[0]
    tags = list((obs.get("trending_tags") or [])[:2]) + ["productivity"]
    morning_post = {"hour": 10, "action_type": "post", "content_type": _CONTENT_TYPES[day % 4],
                    "topic": topic, "tags": tags}
    afternoon_create = {"hour": 16, "action_type": "create_content"}
    return _make_daily_plan([morning_post, afternoon_create])
def _plan_sleep_deprived(obs: dict, day: int) -> ViraltestAction:
    """Fill all 24 hours: the first two hours in 09-20 post, every other hour creates content."""
    topic = (obs.get("trending_topics") or ["coding"])[0]
    tags = list((obs.get("trending_tags") or [])[:2])
    schedule = []
    posts_made = 0  # running count replaces re-scanning the schedule each hour
    for hour in range(24):
        if 9 <= hour <= 20 and posts_made < 2:
            schedule.append({"hour": hour, "action_type": "post",
                             "content_type": _CONTENT_TYPES[hour % 4],
                             "topic": topic, "tags": tags})
            posts_made += 1
        else:
            schedule.append({"hour": hour, "action_type": "create_content"})
    return _make_daily_plan(schedule)
def _plan_growth_focus(obs: dict, day: int) -> ViraltestAction:
    """One 13:00 reel on the top trend with fixed tags viral/growth/trending."""
    topic = (obs.get("trending_topics") or ["growth hacks"])[0]
    reel = {"hour": 13, "action_type": "post", "content_type": "reel", "topic": topic,
            "tags": ["viral", "growth", "trending"]}
    return _make_daily_plan([reel])
def _plan_tech_niche(obs: dict, day: int) -> ViraltestAction:
    """Two fixed-topic AI/coding posts (12:00, 18:00), ignoring trends in *obs*."""
    niche_topic = "AI tools and coding tips"
    slots = (
        (12, _CONTENT_TYPES[day % 4], ["ai", "coding", "devtools"]),
        (18, _CONTENT_TYPES[(day + 1) % 4], ["ai", "ml", "startup"]),
    )
    return _make_daily_plan([
        {"hour": hour, "action_type": "post", "content_type": kind, "topic": niche_topic, "tags": tags}
        for hour, kind, tags in slots
    ])
def _plan_conservative(obs: dict, day: int) -> ViraltestAction:
    """One text post at 13:00 on the top trend with up to two trending tags."""
    topic = (obs.get("trending_topics") or ["quick tip"])[0]
    tags = list((obs.get("trending_tags") or [])[:2])
    text_post = {"hour": 13, "action_type": "post", "content_type": "text_post",
                 "topic": topic, "tags": tags}
    return _make_daily_plan([text_post])
# Scenario id -> (display label, description, plan function(obs_dict, day) -> ViraltestAction).
# Insertion order is preserved; the listing endpoint sorts by label itself.
SCENARIOS = dict(
    always_rest=("Always Rest", "Never posts. Tests follower decay + zero engagement.", _plan_always_rest),
    spam=("Spam Post", "Same reel every hour. Burns out fast.", _plan_spam),
    no_rest=("No Rest", "Posts every hour, never rests. Burns out fast.", _plan_no_rest),
    smart=("Smart Agent", "Optimal: peak hours, trending, varied types, rests.", _plan_smart),
    queue_optimizer=("Queue Optimizer", "Creates content first, posts from queue.", _plan_queue_optimizer),
    weekend=("Weekend Warrior", "Only posts on Sat/Sun.", _plan_weekend),
    tag_explorer=("Tag Explorer", "New tag combo every post. Max discovery.", _plan_tag_explorer),
    sleep_deprived=("Sleep Deprived", "Never rests. Tests sleep deprivation.", _plan_sleep_deprived),
    sleep_conscious=("Sleep Conscious", "Proper sleep schedule.", _plan_sleep_conscious),
    minimal=("Minimal Poster", "1 post per day at noon.", _plan_minimal),
    reel_max=("Reel Maximizer", "Reels at peak hours for max reach.", _plan_reel_max),
    split_schedule=("Split Schedule", "Morning and evening posts.", _plan_split_schedule),
    double_peak=("Double Peak", "Posts at 9am and 3pm.", _plan_double_peak),
    growth_focus=("Growth Focus", "Maximizes follower growth.", _plan_growth_focus),
    weekday_only=("Weekday Only", "No weekend posting.", _plan_weekday_only),
    tech_niche=("Tech Niche", "AI/coding content focus.", _plan_tech_niche),
    conservative=("Conservative", "One text post at 1pm.", _plan_conservative),
    random=("Random Actor", "Random actions. Baseline test.", _plan_random),
)
async def dashboard_scenarios():
    """List all simulation strategies for the dashboard UI.

    Responds with ``{"count": N, "scenarios": [{"id", "label", "description"}, ...]}``,
    sorted case-insensitively by label, with headers that disable caching.
    """
    scenarios = sorted(
        (
            {"id": scenario_id, "label": label, "description": description}
            for scenario_id, (label, description, *_rest) in SCENARIOS.items()
        ),
        key=lambda item: item["label"].lower(),
    )
    return JSONResponse(
        content={"count": len(scenarios), "scenarios": scenarios},
        headers={"Cache-Control": "no-store, max-age=0, must-revalidate"},
    )
def _count_scheduled(action: ViraltestAction, action_type: str) -> int:
    """Count the entries of ``action.scheduled_actions`` whose ``action_type`` matches."""
    return sum(1 for sa in action.scheduled_actions if sa.action_type == action_type)


def _step_record(day: int, action: ViraltestAction, obs: ViraltestObservation) -> Dict[str, Any]:
    """Flatten one simulated day (the plan plus the resulting observation) into a JSON row."""
    reward = obs.reward if obs.reward is not None else 0.0
    n_posts = _count_scheduled(action, "post")
    n_create = _count_scheduled(action, "create_content")
    return {
        "step": day,
        "action": f"day{day}(posts={n_posts},creates={n_create})",
        "reward": round(reward, 4),
        "done": obs.done,
        "error": obs.error,
        "energy": round(obs.creator_energy, 3),
        "hours_since_sleep": obs.hours_since_sleep,
        "sleep_debt": round(obs.sleep_debt, 3),
        "followers": obs.follower_count,
        "engagement_rate": round(obs.engagement_rate, 4),
        "niche_saturation": round(obs.niche_saturation, 3),
        "posts_today": obs.posts_today,
        "hour": obs.current_hour,
        "day": obs.day_of_week,
        "days_elapsed": obs.days_elapsed,
        "queue": obs.content_queue_size,
        "tag_performance": obs.tag_performance,
        "trending_topics": obs.trending_topics,
        "trending_tags": obs.trending_tags,
        "competitor_avg_engagement": round(obs.competitor_avg_engagement, 4),
        "daily_total_engagement": round(obs.daily_total_engagement, 4),
        "daily_posts_made": obs.daily_posts_made,
        "daily_energy_min": round(obs.daily_energy_min, 3),
    }


def _final_snapshot(obs: ViraltestObservation) -> Dict[str, Any]:
    """Summarise the creator's state from the last observation of a run."""
    return {
        "energy": round(obs.creator_energy, 3),
        "hours_since_sleep": obs.hours_since_sleep,
        "sleep_debt": round(obs.sleep_debt, 3),
        "followers": obs.follower_count,
        "engagement_rate": round(obs.engagement_rate, 4),
        "burned_out": obs.creator_energy <= 0,
    }


async def dashboard_simulate(body: Dict[str, Any] = Body(...)):
    """Run a scripted scenario for up to seven days and return a per-day trace.

    Body keys:
        scenario: a SCENARIOS id (default ``"smart"``).
        task: task name passed to ``env.reset`` (default ``"weekly_competitive"``).

    Before each run, _SIM_RNG is re-seeded to 99 and the environment is reset
    with ``seed=42``. The loop stops early when an observation reports ``done``.

    Returns ``{"error": "Unknown scenario: ..."}`` for an unknown or non-string
    scenario id. Otherwise returns the result dict (scenario, description,
    task, steps, total_steps, score, final).

    A summary is appended to the history file on a best-effort basis: if the
    write fails (e.g. a read-only filesystem), a warning is logged and the
    result is still returned instead of failing the request.
    """
    import logging  # local: only used for the best-effort history warning

    global _SIM_RNG
    _SIM_RNG = stdlib_random.Random(99)
    scenario_id = body.get("scenario", "smart")
    task = body.get("task", "weekly_competitive")
    # An unhashable id (e.g. a JSON list) would make the membership test raise TypeError.
    if not isinstance(scenario_id, str) or scenario_id not in SCENARIOS:
        return {"error": f"Unknown scenario: {scenario_id}"}
    label, desc, plan_fn = SCENARIOS[scenario_id]

    env = ViraltestEnvironment()
    obs = env.reset(task=task, seed=42)
    obs_dict = obs.model_dump()
    steps: List[Dict[str, Any]] = []
    for day in range(1, 8):
        action = plan_fn(obs_dict, day)
        obs = env.step(action)
        obs_dict = obs.model_dump()
        steps.append(_step_record(day, action, obs))
        if obs.done:
            break

    # grader_score may be absent or explicitly None; round() rejects None.
    score = (obs.metadata or {}).get("grader_score")
    if score is None:
        score = 0.0
    result = {
        "scenario": label,
        "description": desc,
        "task": task,
        "steps": steps,
        "total_steps": len(steps),
        "score": round(score, 4),
        "final": _final_snapshot(obs),
    }

    rewards = [s["reward"] for s in steps]
    total_posts = sum(s.get("daily_posts_made", 0) for s in steps)
    try:
        _save_history_entry({
            "id": datetime.now(timezone.utc).isoformat(),
            "scenario": label,
            "scenario_id": scenario_id,
            "description": desc,
            "task": task,
            "score": round(score, 4),
            "total_steps": len(steps),
            "total_posts": total_posts,
            "avg_reward": round(sum(rewards) / len(rewards), 4) if rewards else 0,
            "final": result["final"],
        })
    except OSError:
        logging.getLogger(__name__).warning(
            "Could not persist simulation history to %s", _HISTORY_FILE, exc_info=True
        )
    return result
def main(host: str = "0.0.0.0", port: int = 8000):
    """Serve ``app`` with uvicorn; entry point for ``uv run`` or ``python -m``.

    Runs the server without Docker, e.g.::

        uv run --project . server
        uv run --project . server --port 8001
        python -m viraltest.server.app

    Args:
        host: Address to bind (default ``"0.0.0.0"``).
        port: Port to listen on (default 8000).

    For production deployments, consider invoking uvicorn directly with
    multiple workers: ``uvicorn viraltest.server.app:app --workers 4``.
    """
    import uvicorn  # imported lazily; only needed when actually serving

    uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=None)
    cli_args = cli.parse_args()
    # Forward --port only when supplied, so main()'s own default applies otherwise.
    main(**({} if cli_args.port is None else {"port": cli_args.port}))