File size: 3,877 Bytes
4abeb9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""Viraltest Environment Client."""

from typing import Any, Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from .models import ViraltestAction, ViraltestObservation


class ViraltestEnv(
    EnvClient[ViraltestAction, ViraltestObservation, State]
):
    """
    Client for the Viraltest Creator Optimization Environment.

    Maintains a persistent WebSocket connection to the environment server.

    Example:
        >>> with ViraltestEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset(task="weekly_engage")
        ...     result = client.step(ViraltestAction(
        ...         scheduled_actions=[
        ...             {"hour": 12, "action_type": "post", "content_type": "reel",
        ...              "topic": "AI trends", "tags": ["ai", "tech"]},
        ...         ]
        ...     ))
    """

    def _step_payload(self, action: ViraltestAction) -> Dict[str, Any]:
        """Serialize an action into the request body for a step.

        Each scheduled action always carries ``hour`` and ``action_type``.
        The optional fields ``content_type``, ``topic`` and ``tags`` are
        omitted from the item when they are ``None`` rather than sent as null.

        Args:
            action: The action whose ``scheduled_actions`` are serialized.

        Returns:
            A dict of the form ``{"scheduled_actions": [item, ...]}``.
        """
        actions_list = []
        for sa in action.scheduled_actions:
            item: Dict[str, Any] = {
                "hour": sa.hour,
                "action_type": sa.action_type,
            }
            if sa.content_type is not None:
                item["content_type"] = sa.content_type
            if sa.topic is not None:
                item["topic"] = sa.topic
            if sa.tags is not None:
                # Copy so the outgoing payload does not alias the caller's
                # container (and any iterable becomes a JSON-friendly list).
                item["tags"] = list(sa.tags)
            actions_list.append(item)
        return {"scheduled_actions": actions_list}

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[ViraltestObservation]:
        """Build a StepResult from a step/reset response payload.

        Observation fields missing from ``payload["observation"]`` fall back
        to the defaults below. If the observation carries a ``grader_score``,
        it is also copied into the observation's ``metadata`` under the key
        ``"grader_score"``.

        Args:
            payload: Decoded server response. Reads the top-level keys
                ``"observation"``, ``"reward"`` and ``"done"``.

        Returns:
            A StepResult wrapping the parsed ViraltestObservation, with the
            same ``reward`` and ``done`` values as the observation.
        """
        # `or {}` also covers an explicit null sent by the server, which
        # `.get(key, {})` alone would pass through as None.
        obs_data = payload.get("observation") or {}
        grader_score = obs_data.get("grader_score")
        # Copy so that adding grader_score does not mutate the caller's
        # payload; `or {}` guards against "metadata": null.
        meta = dict(obs_data.get("metadata") or {})
        if grader_score is not None:
            meta["grader_score"] = grader_score

        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = ViraltestObservation(
            current_hour=obs_data.get("current_hour", 0),
            day_of_week=obs_data.get("day_of_week", 0),
            days_elapsed=obs_data.get("days_elapsed", 0),
            creator_energy=obs_data.get("creator_energy", 1.0),
            follower_count=obs_data.get("follower_count", 0),
            engagement_rate=obs_data.get("engagement_rate", 0.0),
            hours_since_sleep=obs_data.get("hours_since_sleep", 0),
            posts_today=obs_data.get("posts_today", 0),
            sleep_debt=obs_data.get("sleep_debt", 0.0),
            time_since_last_post=obs_data.get("time_since_last_post", 0),
            trending_topics=obs_data.get("trending_topics", []),
            content_queue_size=obs_data.get("content_queue_size", 0),
            last_post_type=obs_data.get("last_post_type", "none"),
            tag_performance=obs_data.get("tag_performance", {}),
            trending_tags=obs_data.get("trending_tags", []),
            competitor_recent_posts=obs_data.get("competitor_recent_posts", []),
            competitor_avg_engagement=obs_data.get("competitor_avg_engagement", 0.0),
            niche_saturation=obs_data.get("niche_saturation", 0.0),
            daily_total_engagement=obs_data.get("daily_total_engagement", 0.0),
            daily_posts_made=obs_data.get("daily_posts_made", 0),
            daily_energy_min=obs_data.get("daily_energy_min", 1.0),
            grader_score=grader_score,
            error=obs_data.get("error"),
            done=done,
            reward=reward,
            metadata=meta,
        )
        return StepResult(
            observation=observation,
            reward=reward,
            done=done,
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """Build a State from a state-endpoint payload.

        Args:
            payload: Decoded server response. Reads ``"episode_id"``, which
                defaults to None, and ``"step_count"``, which defaults to 0.

        Returns:
            The corresponding State instance.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )