from __future__ import annotations

from enum import Enum
from typing import Any

from pydantic import BaseModel, Field, model_validator

from openenv.core.env_server.types import Action, Observation


class ViolationCategory(str, Enum):
    """All possible content violation categories."""

    HATE_SPEECH = "hate_speech"
    VIOLENCE = "violence"
    NUDITY = "nudity"
    SPAM = "spam"
    MISINFORMATION = "misinformation"
    HARASSMENT = "harassment"
    SELF_HARM = "self_harm"
    TERRORISM = "terrorism"
    SCAM = "scam"
    CLEAN = "clean"  # No violation


class ContentType(str, Enum):
    """Type of content being moderated."""

    TEXT_POST = "text_post"
    IMAGE_DESCRIPTION = "image_description"
    AD_COPY = "ad_copy"
    WHATSAPP_MESSAGE = "whatsapp_message"


class ModerationAction(str, Enum):
    """Actions the agent can take on content."""

    APPROVE = "approve"
    REMOVE = "remove"
    ESCALATE = "escalate"
    RESTRICT = "restrict"  # Limit reach, don't remove
    REQUEST_REVIEW = "request_review"


class ContentItem(BaseModel):
    """A single piece of content to be moderated."""

    content_id: str = Field(..., description="Unique content identifier")
    content_type: ContentType
    text: str = Field(..., description="Raw content text or image description")
    author_region: str = Field(default="US", description="ISO 3166-1 alpha-2 country code")
    author_history: list[str] = Field(
        default_factory=list,
        description="Prior violation categories for this author",
    )
    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional context: thread_id, reply_to, ad_targeting, etc.",
    )
    language: str = Field(default="en", description="ISO 639-1 language code")
    media_urls: list[str] = Field(
        default_factory=list,
        description="List of direct URLs to associated media (e.g. JPG, PNG)",
    )
    media_types: list[str] = Field(
        default_factory=list,
        description="Corresponding media types (e.g. 'image', 'video')",
    )


class ModerationObservation(Observation):
    """
    What the agent sees at each step.

    Inherits from Observation (which provides `done`, `reward`, `metadata`).
    """

    step: int = Field(default=0, ge=0, description="Current step index")
    content_item: ContentItem
    policy_excerpt: str = Field(
        default="",
        description="Relevant policy text snippet for this content type",
    )
    thread_history: list[ContentItem] = Field(
        default_factory=list,
        description="Prior messages in thread (for thread tasks only)",
    )
    conflicting_policies: list[str] = Field(
        default_factory=list,
        description="Names of policies that may conflict for this item (hard task)",
    )
    task_name: str = Field(default="", description="Active task identifier")
    instructions: str = Field(
        default="",
        description="Natural language task instructions for the agent",
    )


class ModerationDecision(Action):
    """
    The agent's moderation decision for a single content item.

    This is what the agent submits via step().
    """

    content_id: str = Field(..., description="Must match the observed content_item.content_id")
    labels: list[ViolationCategory] = Field(
        ...,
        min_length=1,
        description="One or more violation categories. Use [CLEAN] if no violation.",
    )
    action: ModerationAction = Field(..., description="Enforcement action to take")
    confidence: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Agent confidence in this decision (0.0–1.0)",
    )
    reasoning: str = Field(
        default="",
        description="Optional free-text reasoning (used in hard task grader)",
    )
    policy_citations: list[str] = Field(
        default_factory=list,
        description="Policy rule IDs cited for this decision",
    )

    @model_validator(mode="after")
    def clean_label_exclusivity(self) -> ModerationDecision:
        """CLEAN label must be the only label if present."""
        if ViolationCategory.CLEAN in self.labels and len(self.labels) > 1:
            raise ValueError("CLEAN label cannot be combined with violation labels")
        return self


class ModerationReward(BaseModel):
    """
    Structured reward breakdown for one step.

    Provides partial credit signals at every step.
    """

    total: float = Field(..., ge=-1.0, le=1.0, description="Composite reward for this step")

    # Partial credit components
    label_precision: float = Field(default=0.0, ge=0.0, le=1.0)
    label_recall: float = Field(default=0.0, ge=0.0, le=1.0)
    action_correct: float = Field(default=0.0, ge=0.0, le=1.0)
    policy_citation_score: float = Field(default=0.0, ge=0.0, le=1.0)
    false_positive_penalty: float = Field(
        default=0.0, le=0.0, description="Negative if clean content flagged"
    )
    reasoning_quality: float = Field(default=0.0, ge=0.0, le=1.0, description="Hard task only")
    breakdown: dict[str, float] = Field(
        default_factory=dict,
        description="Raw component scores for debugging",
    )


class ModerationState(BaseModel):
    """
    Full internal state of the environment.

    Returned by state() endpoint.
    """

    task_name: str
    episode_id: str
    current_step: int
    max_steps: int
    done: bool
    cumulative_reward: float
    items_seen: int
    items_remaining: int
    decisions_log: list[dict[str, Any]] = Field(default_factory=list)
    score: float = Field(default=0.0, ge=0.0, le=1.0)
    ground_truth_data: list[dict[str, Any]] = Field(
        default_factory=list,
        description="All ground truth data for the current episode (used by grader)",
    )
    has_policy_conflict: bool = Field(default=False, description="Flag for thread task grader")
    is_final_message: bool = Field(default=False, description="Flag for thread task grader")
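

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; not part of the environment API).
    # Assumes pydantic v2 semantics, and that the base `Action` type is itself a
    # pydantic model, as the Field/model_validator usage above implies. All IDs
    # and policy citations here are hypothetical.
    decision = ModerationDecision(
        content_id="post_001",  # must echo the observed content_item.content_id
        labels=[ViolationCategory.HATE_SPEECH, ViolationCategory.HARASSMENT],
        action=ModerationAction.REMOVE,
        confidence=0.92,
        reasoning="Targets a protected group; author has prior violations.",
        policy_citations=["HS-2.1"],  # hypothetical policy rule ID
    )
    print(decision.model_dump_json(indent=2))

    # The after-validator rejects CLEAN combined with any violation label.
    # pydantic.ValidationError subclasses ValueError, so this catch suffices.
    try:
        ModerationDecision(
            content_id="post_002",
            labels=[ViolationCategory.CLEAN, ViolationCategory.SPAM],
            action=ModerationAction.APPROVE,
            confidence=0.5,
        )
    except ValueError as exc:
        print(f"Rejected as expected: {exc}")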