final-iteration / client.py
anuragredbus's picture
Viraltest OpenEnv: deploy to HF Space
28dd5a4
raw
history blame
3.88 kB
"""Viraltest Environment Client."""
from typing import Any, Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import ViraltestAction, ViraltestObservation
class ViraltestEnv(
EnvClient[ViraltestAction, ViraltestObservation, State]
):
"""
Client for the Viraltest Creator Optimization Environment.
Maintains a persistent WebSocket connection to the environment server.
Example:
>>> with ViraltestEnv(base_url="http://localhost:8000") as client:
... result = client.reset(task="weekly_engage")
... result = client.step(ViraltestAction(
... scheduled_actions=[
... {"hour": 12, "action_type": "post", "content_type": "reel",
... "topic": "AI trends", "tags": ["ai", "tech"]},
... ]
... ))
"""
def _step_payload(self, action: ViraltestAction) -> Dict[str, Any]:
actions_list = []
for sa in action.scheduled_actions:
item: Dict[str, Any] = {
"hour": sa.hour,
"action_type": sa.action_type,
}
if sa.content_type is not None:
item["content_type"] = sa.content_type
if sa.topic is not None:
item["topic"] = sa.topic
if sa.tags is not None:
item["tags"] = sa.tags
actions_list.append(item)
return {"scheduled_actions": actions_list}
def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ViraltestObservation]:
obs_data = payload.get("observation", {})
grader_score = obs_data.get("grader_score")
meta = obs_data.get("metadata", {})
if grader_score is not None:
meta["grader_score"] = grader_score
observation = ViraltestObservation(
current_hour=obs_data.get("current_hour", 0),
day_of_week=obs_data.get("day_of_week", 0),
days_elapsed=obs_data.get("days_elapsed", 0),
creator_energy=obs_data.get("creator_energy", 1.0),
follower_count=obs_data.get("follower_count", 0),
engagement_rate=obs_data.get("engagement_rate", 0.0),
hours_since_sleep=obs_data.get("hours_since_sleep", 0),
posts_today=obs_data.get("posts_today", 0),
sleep_debt=obs_data.get("sleep_debt", 0.0),
time_since_last_post=obs_data.get("time_since_last_post", 0),
trending_topics=obs_data.get("trending_topics", []),
content_queue_size=obs_data.get("content_queue_size", 0),
last_post_type=obs_data.get("last_post_type", "none"),
tag_performance=obs_data.get("tag_performance", {}),
trending_tags=obs_data.get("trending_tags", []),
competitor_recent_posts=obs_data.get("competitor_recent_posts", []),
competitor_avg_engagement=obs_data.get("competitor_avg_engagement", 0.0),
niche_saturation=obs_data.get("niche_saturation", 0.0),
daily_total_engagement=obs_data.get("daily_total_engagement", 0.0),
daily_posts_made=obs_data.get("daily_posts_made", 0),
daily_energy_min=obs_data.get("daily_energy_min", 1.0),
grader_score=grader_score,
error=obs_data.get("error"),
done=payload.get("done", False),
reward=payload.get("reward"),
metadata=meta,
)
return StepResult(
observation=observation,
reward=payload.get("reward"),
done=payload.get("done", False),
)
def _parse_state(self, payload: Dict[str, Any]) -> State:
return State(
episode_id=payload.get("episode_id"),
step_count=payload.get("step_count", 0),
)