from __future__ import annotations
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, model_validator
from openenv.core.env_server.types import Action, Observation


class ViolationCategory(str, Enum):
"""All possible content violation categories."""
HATE_SPEECH = "hate_speech"
VIOLENCE = "violence"
NUDITY = "nudity"
SPAM = "spam"
MISINFORMATION = "misinformation"
HARASSMENT = "harassment"
SELF_HARM = "self_harm"
TERRORISM = "terrorism"
SCAM = "scam"
CLEAN = "clean" # No violation
class ContentType(str, Enum):
"""Type of content being moderated."""
TEXT_POST = "text_post"
IMAGE_DESCRIPTION = "image_description"
AD_COPY = "ad_copy"
    WHATSAPP_MESSAGE = "whatsapp_message"


class ModerationAction(str, Enum):
"""Actions the agent can take on content."""
APPROVE = "approve"
REMOVE = "remove"
ESCALATE = "escalate"
RESTRICT = "restrict" # Limit reach, don't remove
REQUEST_REVIEW = "request_review"
class ContentItem(BaseModel):
"""A single piece of content to be moderated."""
content_id: str = Field(..., description="Unique content identifier")
content_type: ContentType
text: str = Field(..., description="Raw content text or image description")
author_region: str = Field(default="US", description="ISO 3166-1 alpha-2 country code")
author_history: list[str] = Field(
default_factory=list,
description="Prior violation categories for this author"
)
context: dict[str, Any] = Field(
default_factory=dict,
description="Additional context: thread_id, reply_to, ad_targeting etc."
)
language: str = Field(default="en", description="ISO 639-1 language code")
media_urls: list[str] = Field(
default_factory=list,
description="List of direct URLs to associated media (e.g. JPG, PNG)"
)
media_types: list[str] = Field(
default_factory=list,
description="Corresponding media types (e.g. 'image', 'video')"
)
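

# Illustrative usage sketch (not part of the environment's API): shows how a
# ContentItem might be populated. All field values below are invented sample
# data, not fixtures from the environment.
def _example_content_item() -> ContentItem:
    return ContentItem(
        content_id="post_001",
        content_type=ContentType.TEXT_POST,
        text="WIN A FREE PHONE!!! Click the link and send your bank details.",
        author_region="US",
        author_history=["spam"],
        language="en",
    )
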
class ModerationObservation(Observation):
"""
What the agent sees at each step.
Inherits from Observation (which provides `done`, `reward`, `metadata`).
"""
step: int = Field(default=0, ge=0, description="Current step index")
content_item: ContentItem
policy_excerpt: str = Field(
default="",
description="Relevant policy text snippet for this content type"
)
thread_history: list[ContentItem] = Field(
default_factory=list,
description="Prior messages in thread (for thread tasks only)"
)
conflicting_policies: list[str] = Field(
default_factory=list,
description="Names of policies that may conflict for this item (hard task)"
)
task_name: str = Field(default="", description="Active task identifier")
instructions: str = Field(default="", description="Natural language task instructions for the agent")
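

# Illustrative sketch (not part of the environment's API): a minimal observation
# as the agent might receive it. Assumes the inherited Observation fields
# (done, reward, metadata) all carry defaults, as the docstring above implies;
# the task name and policy text are invented.
def _example_observation() -> ModerationObservation:
    return ModerationObservation(
        step=0,
        content_item=_example_content_item(),
        policy_excerpt="Spam: content that deceptively solicits engagement or payment.",
        task_name="single_item_moderation",
        instructions="Label the content and choose one enforcement action.",
    )
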
class ModerationDecision(Action):
"""
The agent's moderation decision for a single content item.
This is what the agent submits via step().
"""
content_id: str = Field(..., description="Must match the observed content_item.content_id")
labels: list[ViolationCategory] = Field(
...,
min_length=1,
description="One or more violation categories. Use [CLEAN] if no violation."
)
action: ModerationAction = Field(..., description="Enforcement action to take")
confidence: float = Field(
..., ge=0.0, le=1.0,
description="Agent confidence in this decision (0.0–1.0)"
)
reasoning: str = Field(
default="",
description="Optional free-text reasoning (used in hard task grader)"
)
policy_citations: list[str] = Field(
default_factory=list,
description="Policy rule IDs cited for this decision"
)
@model_validator(mode="after")
def clean_label_exclusivity(self) -> ModerationDecision:
"""CLEAN label must be the only label if present."""
if ViolationCategory.CLEAN in self.labels and len(self.labels) > 1:
raise ValueError("CLEAN label cannot be combined with violation labels")
return self
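

# Illustrative usage sketch (values invented): a well-formed decision for the
# sample item above. Note the validator: labels=[ViolationCategory.CLEAN,
# ViolationCategory.SPAM] would raise a ValueError, since CLEAN must stand alone.
def _example_decision() -> ModerationDecision:
    return ModerationDecision(
        content_id="post_001",  # must match the observed content_item.content_id
        labels=[ViolationCategory.SPAM, ViolationCategory.SCAM],
        action=ModerationAction.REMOVE,
        confidence=0.87,
        reasoning="Engagement bait paired with a request for bank details.",
        policy_citations=["SPAM-2.1"],  # hypothetical rule ID, for illustration
    )
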
class ModerationReward(BaseModel):
"""
Structured reward breakdown for one step.
Provides partial credit signals at every step.
"""
total: float = Field(..., ge=-1.0, le=1.0, description="Composite reward for this step")
# Partial credit components
label_precision: float = Field(default=0.0, ge=0.0, le=1.0)
label_recall: float = Field(default=0.0, ge=0.0, le=1.0)
action_correct: float = Field(default=0.0, ge=0.0, le=1.0)
policy_citation_score: float = Field(default=0.0, ge=0.0, le=1.0)
false_positive_penalty: float = Field(default=0.0, le=0.0, description="Negative if clean content flagged")
reasoning_quality: float = Field(default=0.0, ge=0.0, le=1.0, description="Hard task only")
breakdown: dict[str, float] = Field(
default_factory=dict,
description="Raw component scores for debugging"
)
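

# Hypothetical scoring sketch: the actual weights live in the environment's
# grader, not in this schema; the 0.4/0.3/0.2/0.1 split below is an assumption
# made purely to show how the partial-credit fields could compose into `total`.
def _example_reward() -> ModerationReward:
    precision, recall = 1.0, 0.5
    action_ok, citations, reasoning = 1.0, 0.5, 0.0
    fp_penalty = 0.0  # would be negative if clean content had been flagged
    total = (
        0.4 * (precision + recall) / 2
        + 0.3 * action_ok
        + 0.2 * citations
        + 0.1 * reasoning
        + fp_penalty
    )
    return ModerationReward(
        total=total,
        label_precision=precision,
        label_recall=recall,
        action_correct=action_ok,
        policy_citation_score=citations,
        false_positive_penalty=fp_penalty,
        breakdown={"label_f1": 2 * precision * recall / (precision + recall)},
    )
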
class ModerationState(BaseModel):
"""
Full internal state of the environment.
Returned by state() endpoint.
"""
task_name: str
episode_id: str
current_step: int
max_steps: int
done: bool
cumulative_reward: float
items_seen: int
items_remaining: int
decisions_log: list[dict[str, Any]] = Field(default_factory=list)
score: float = Field(default=0.0, ge=0.0, le=1.0)
    ground_truth_data: list[dict[str, Any]] = Field(default_factory=list, description="All ground truth data for the current episode (used by grader)")
has_policy_conflict: bool = Field(default=False, description="Flag for thread task grader")
is_final_message: bool = Field(default=False, description="Flag for thread task grader")
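

# Illustrative mid-episode snapshot (values invented) of what the state()
# endpoint described above might return.
def _example_state() -> ModerationState:
    return ModerationState(
        task_name="single_item_moderation",
        episode_id="ep_0042",
        current_step=3,
        max_steps=10,
        done=False,
        cumulative_reward=2.1,
        items_seen=3,
        items_remaining=7,
        score=0.7,
    )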