from __future__ import annotations

import uuid
from typing import Any

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import (
        ContentItem,
        ContentType,
        ModerationDecision,
        ModerationObservation,
        ModerationReward,
        ModerationState,
    )
except ImportError:
    from models import (
        ContentItem,
        ContentType,
        ModerationDecision,
        ModerationObservation,
        ModerationReward,
        ModerationState,
    )

from server.dataset import (
    get_posts, get_image_descriptions, get_ad_copies, get_whatsapp_threads,
    get_community_standards, get_ad_policies,
)
from server.graders import (
    grade_single_label, grade_multi_label, grade_ad_policy, grade_thread_hard, get_ground_truth,
)
from server.tasks.task_single_label import (
    build_episode as build_single_label_episode,
    build_observation as build_single_label_obs,
    MAX_STEPS as SINGLE_MAX,
    TASK_NAME as SINGLE_TASK,
)
from server.tasks.task_multi_label import (
    build_episode as build_multi_label_episode,
    build_observation as build_multi_label_obs,
    MAX_STEPS as MULTI_MAX,
    TASK_NAME as MULTI_TASK,
)
from server.tasks.task_ad_policy import (
    build_episode as build_ad_episode,
    build_observation as build_ad_obs,
    MAX_STEPS as AD_MAX,
    TASK_NAME as AD_TASK,
)
from server.tasks.task_thread_hard import (
    build_episode as build_thread_episode,
    build_observation as build_thread_obs,
    MAX_STEPS as THREAD_MAX,
    TASK_NAME as THREAD_TASK,
)

VALID_TASKS = {SINGLE_TASK, MULTI_TASK, AD_TASK, THREAD_TASK}


class MetaContentModerationEnv(Environment[ModerationDecision, ModerationObservation, ModerationState]):
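    """Content-moderation environment serving four task variants.

    An episode walks a deterministic, seed-driven sequence of content items
    (posts, image descriptions, ad copies, or WhatsApp thread messages). Each
    step() grades one ModerationDecision against dataset ground truth and
    returns the next observation. The class is a process-wide singleton
    (see __new__).
    """
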
    SUPPORTS_CONCURRENT_SESSIONS: bool = True
    _instance = None

    def __new__(cls, *args, **kwargs):
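        # Classic singleton: every construction returns the same instance,
        # and __init__ re-runs are skipped via the _initialized flag.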
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, task: str = "single-label-classify", seed: int = 42) -> None:
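        """Validate the requested task and set up empty per-episode state."""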
        if getattr(self, "_initialized", False):
            return
        self._initialized = True
        if task not in VALID_TASKS:
            raise ValueError(f"Unknown task '{task}'. Valid: {VALID_TASKS}")
        self.task = task
        self.seed = seed

        self._episode_id: str = ""
        self._step: int = 0
        self._max_steps: int = 0
        self._done: bool = False
        self._cumulative_reward: float = 0.0
        self._decisions_log: list[dict[str, Any]] = []

        self._items: list[ContentItem] = []
        self._ground_truth_all: list[dict] = []
        self._thread_steps: list[Any] = []

    def reset(self, task: str | None = None, seed: int | None = None) -> ModerationObservation:
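        """Start a new episode, optionally switching task and/or seed.

        Returns the first observation with reward 0.0 and done False.
        """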
        if task is not None:
            if task not in VALID_TASKS:
                raise ValueError(f"Unknown task '{task}'. Valid: {VALID_TASKS}")
            self.task = task
        if seed is not None:
            self.seed = seed

        self._episode_id = str(uuid.uuid4())
        self._step = 0
        self._done = False
        self._cumulative_reward = 0.0
        self._decisions_log = []

        self._load_episode_data()

        obs = self._make_observation()
        obs.reward = 0.0
        obs.done = False
        obs.metadata = {
            "episode_id": self._episode_id,
            "step": self._step,
            "cumulative_reward": 0.0,
        }
        return obs

    def step(self, action: ModerationDecision) -> ModerationObservation:
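        """Grade one decision, advance the episode, and return the next observation.

        The reward for the decision just graded rides on the *next*
        observation; the episode ends once max_steps items have been graded.
        """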
        if not self._episode_id:
            raise RuntimeError("Call reset() before step()")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        reward_obj = self._grade(action)
        reward = reward_obj.total

        self._cumulative_reward += reward
        self._decisions_log.append({
            "step": self._step,
            "content_id": action.content_id,
            "labels": [label.value for label in action.labels],
            "action": action.action.value,
            "reward": reward,
            "breakdown": reward_obj.breakdown,
        })

        self._step += 1
        self._done = self._step >= self._max_steps

        if self._done:
            next_obs = self._make_terminal_observation()
        else:
            next_obs = self._make_observation()

        next_obs.reward = reward
        next_obs.done = self._done
        next_obs.metadata = {
            "episode_id": self._episode_id,
            "cumulative_reward": self._cumulative_reward,
            "step": self._step,
            "reward_breakdown": reward_obj.model_dump(),
        }
        return next_obs

    @property
    def state(self) -> ModerationState:
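        """Snapshot of the current episode as a ModerationState."""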
        score = self._compute_score()
        on_thread_task = self.task == THREAD_TASK
        has_conflict = (
            on_thread_task
            and self._step < len(self._thread_steps)
            and bool(self._thread_steps[self._step][2])
        )
        return ModerationState(
            task_name=self.task,
            episode_id=self._episode_id,
            current_step=self._step,
            max_steps=self._max_steps,
            done=self._done,
            cumulative_reward=self._cumulative_reward,
            items_seen=self._step,
            items_remaining=max(0, self._max_steps - self._step),
            decisions_log=self._decisions_log,
            score=score,
            ground_truth_data=self._ground_truth_all,
            has_policy_conflict=has_conflict,
            is_final_message=on_thread_task and self._step == self._max_steps - 1,
        )

    def _load_episode_data(self) -> None:
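        """Build the item sequence and ground-truth pool for self.task."""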
        if self.task == SINGLE_TASK:
            self._items = build_single_label_episode(self.seed)
            self._max_steps = min(SINGLE_MAX, len(self._items))
            self._ground_truth_all = get_posts(self.seed) + get_image_descriptions(self.seed)

        elif self.task == MULTI_TASK:
            self._items = build_multi_label_episode(self.seed)
            self._max_steps = min(MULTI_MAX, len(self._items))
            self._ground_truth_all = get_posts(self.seed) + get_ad_copies(self.seed)

        elif self.task == AD_TASK:
            self._items = build_ad_episode(self.seed)
            self._max_steps = min(AD_MAX, len(self._items))
            self._ground_truth_all = get_ad_copies(self.seed)

        elif self.task == THREAD_TASK:
            self._thread_steps = build_thread_episode(self.seed)
            self._max_steps = min(THREAD_MAX, len(self._thread_steps))
            threads = get_whatsapp_threads(self.seed)
            self._ground_truth_all = [msg for t in threads for msg in t["messages"]]
            self._items = [step[0] for step in self._thread_steps]

    def _make_observation(self) -> ModerationObservation:
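        """Build the observation for the current step of the active task."""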
        if self.task == THREAD_TASK:
            item, history, conflicts = self._thread_steps[self._step]
            return build_thread_obs(self._step, item, history, conflicts)

        item = self._items[self._step]

        if self.task == SINGLE_TASK:
            return build_single_label_obs(self._step, item)
        elif self.task == MULTI_TASK:
            return build_multi_label_obs(self._step, item)
        elif self.task == AD_TASK:
            return build_ad_obs(self._step, item)

        raise ValueError(f"Unknown task: {self.task}")

    def _make_terminal_observation(self) -> ModerationObservation:
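        # A placeholder item so terminal observations keep the same schema.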
        dummy = ContentItem(
            content_id="__terminal__",
            content_type=ContentType.TEXT_POST,
            text="Episode complete.",
        )
        return ModerationObservation(
            step=self._step,
            content_item=dummy,
            task_name=self.task,
            instructions="Episode complete. No more items.",
        )

    def _grade(self, action: ModerationDecision) -> ModerationReward:
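        """Route the decision to the grader matching the active task."""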
        gt = get_ground_truth(action.content_id, self._ground_truth_all)

        if self.task == SINGLE_TASK:
            return grade_single_label(action, gt["labels"], gt["action"])

        elif self.task == MULTI_TASK:
            return grade_multi_label(action, gt["labels"], gt["action"])

        elif self.task == AD_TASK:
            return grade_ad_policy(action, gt["labels"], gt["action"], gt["policy_ids"])

        elif self.task == THREAD_TASK:
            _, _, conflicts = self._thread_steps[self._step]
            is_final = (self._step == self._max_steps - 1)
            return grade_thread_hard(
                action, gt["labels"], gt["action"],
                has_policy_conflict=bool(conflicts),
                is_final_message=is_final,
            )

        raise ValueError(f"Unknown task: {self.task}")

    def _compute_score(self) -> float:
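        """Map cumulative reward to a score in [0.01, 0.99].

        Per-step rewards are assumed to lie in [-1, 1], so the mean reward
        is shifted and scaled via (mean + 1) / 2 before clamping.
        """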
        if not self._decisions_log:
            return 0.01
        max_possible = self._max_steps * 1.0
        if max_possible <= 0:
            return 0.01

        avg_reward = self._cumulative_reward / max_possible
        normalized = (avg_reward + 1.0) / 2.0

        score = min(max(normalized, 0.01), 0.99)
        return round(score, 4)
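

# Minimal usage sketch (illustrative only, not part of the server wiring).
# It assumes this module is importable with the server.* packages on the
# path; constructing a ModerationDecision is omitted because its label and
# action enum values live in models.py.
if __name__ == "__main__":
    env = MetaContentModerationEnv(task=SINGLE_TASK, seed=42)
    obs = env.reset()
    print(f"episode={env.state.episode_id} max_steps={env.state.max_steps}")
    print(f"first item: {obs.content_item.content_id}")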