"""Viraltest Environment Client.""" from typing import Any, Dict from openenv.core import EnvClient from openenv.core.client_types import StepResult from openenv.core.env_server.types import State from .models import ViraltestAction, ViraltestObservation class ViraltestEnv( EnvClient[ViraltestAction, ViraltestObservation, State] ): """ Client for the Viraltest Creator Optimization Environment. Maintains a persistent WebSocket connection to the environment server. Example: >>> with ViraltestEnv(base_url="http://localhost:8000") as client: ... result = client.reset(task="weekly_engage") ... result = client.step(ViraltestAction( ... scheduled_actions=[ ... {"hour": 12, "action_type": "post", "content_type": "reel", ... "topic": "AI trends", "tags": ["ai", "tech"]}, ... ] ... )) """ def _step_payload(self, action: ViraltestAction) -> Dict[str, Any]: actions_list = [] for sa in action.scheduled_actions: item: Dict[str, Any] = { "hour": sa.hour, "action_type": sa.action_type, } if sa.content_type is not None: item["content_type"] = sa.content_type if sa.topic is not None: item["topic"] = sa.topic if sa.tags is not None: item["tags"] = sa.tags actions_list.append(item) return {"scheduled_actions": actions_list} def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ViraltestObservation]: obs_data = payload.get("observation", {}) grader_score = obs_data.get("grader_score") meta = obs_data.get("metadata", {}) if grader_score is not None: meta["grader_score"] = grader_score observation = ViraltestObservation( current_hour=obs_data.get("current_hour", 0), day_of_week=obs_data.get("day_of_week", 0), days_elapsed=obs_data.get("days_elapsed", 0), creator_energy=obs_data.get("creator_energy", 1.0), follower_count=obs_data.get("follower_count", 0), engagement_rate=obs_data.get("engagement_rate", 0.0), hours_since_sleep=obs_data.get("hours_since_sleep", 0), posts_today=obs_data.get("posts_today", 0), sleep_debt=obs_data.get("sleep_debt", 0.0), time_since_last_post=obs_data.get("time_since_last_post", 0), trending_topics=obs_data.get("trending_topics", []), content_queue_size=obs_data.get("content_queue_size", 0), last_post_type=obs_data.get("last_post_type", "none"), tag_performance=obs_data.get("tag_performance", {}), trending_tags=obs_data.get("trending_tags", []), competitor_recent_posts=obs_data.get("competitor_recent_posts", []), competitor_avg_engagement=obs_data.get("competitor_avg_engagement", 0.0), niche_saturation=obs_data.get("niche_saturation", 0.0), daily_total_engagement=obs_data.get("daily_total_engagement", 0.0), daily_posts_made=obs_data.get("daily_posts_made", 0), daily_energy_min=obs_data.get("daily_energy_min", 1.0), grader_score=grader_score, error=obs_data.get("error"), done=payload.get("done", False), reward=payload.get("reward"), metadata=meta, ) return StepResult( observation=observation, reward=payload.get("reward"), done=payload.get("done", False), ) def _parse_state(self, payload: Dict[str, Any]) -> State: return State( episode_id=payload.get("episode_id"), step_count=payload.get("step_count", 0), )