from __future__ import annotations
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field, model_validator

from openenv.core.env_server.types import Action, Observation

class ViolationCategory(str, Enum):
    """All possible content violation categories."""
    HATE_SPEECH = "hate_speech"
    VIOLENCE = "violence"
    NUDITY = "nudity"
    SPAM = "spam"
    MISINFORMATION = "misinformation"
    HARASSMENT = "harassment"
    SELF_HARM = "self_harm"
    TERRORISM = "terrorism"
    SCAM = "scam"
    CLEAN = "clean"  # No violation

class ContentType(str, Enum):
    """Type of content being moderated."""
    TEXT_POST = "text_post"
    IMAGE_DESCRIPTION = "image_description"
    AD_COPY = "ad_copy"
    WHATSAPP_MESSAGE = "whatsapp_message"

class ModerationAction(str, Enum):
    """Actions the agent can take on content."""
    APPROVE = "approve"
    REMOVE = "remove"
    ESCALATE = "escalate"
    RESTRICT = "restrict"         # Limit reach, don't remove
    REQUEST_REVIEW = "request_review"

class ContentItem(BaseModel):
    """A single piece of content to be moderated."""
    content_id: str = Field(..., description="Unique content identifier")
    content_type: ContentType
    text: str = Field(..., description="Raw content text or image description")
    author_region: str = Field(default="US", description="ISO 3166-1 alpha-2 country code")
    author_history: list[str] = Field(
        default_factory=list,
        description="Prior violation categories for this author"
    )
    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional context: thread_id, reply_to, ad_targeting etc."
    )
    language: str = Field(default="en", description="ISO 639-1 language code")
    media_urls: list[str] = Field(
        default_factory=list, 
        description="List of direct URLs to associated media (e.g. JPG, PNG)"
    )
    media_types: list[str] = Field(
        default_factory=list, 
        description="Corresponding media types (e.g. 'image', 'video')"
    )

class ModerationObservation(Observation):
    """
    What the agent sees at each step.
    Inherits from Observation (which provides `done`, `reward`, `metadata`).
    """
    step: int = Field(default=0, ge=0, description="Current step index")
    content_item: ContentItem
    policy_excerpt: str = Field(
        default="",
        description="Relevant policy text snippet for this content type"
    )
    thread_history: list[ContentItem] = Field(
        default_factory=list,
        description="Prior messages in thread (for thread tasks only)"
    )
    conflicting_policies: list[str] = Field(
        default_factory=list,
        description="Names of policies that may conflict for this item (hard task)"
    )
    task_name: str = Field(default="", description="Active task identifier")
    instructions: str = Field(default="", description="Natural language task instructions for the agent")

class ModerationDecision(Action):
    """
    The agent's moderation decision for a single content item.
    This is what the agent submits via step().
    """
    content_id: str = Field(..., description="Must match the observed content_item.content_id")
    labels: list[ViolationCategory] = Field(
        ...,
        min_length=1,
        description="One or more violation categories. Use [CLEAN] if no violation."
    )
    action: ModerationAction = Field(..., description="Enforcement action to take")
    confidence: float = Field(
        ..., ge=0.0, le=1.0,
        description="Agent confidence in this decision (0.0–1.0)"
    )
    reasoning: str = Field(
        default="",
        description="Optional free-text reasoning (used in hard task grader)"
    )
    policy_citations: list[str] = Field(
        default_factory=list,
        description="Policy rule IDs cited for this decision"
    )

    @model_validator(mode="after")
    def clean_label_exclusivity(self) -> ModerationDecision:
        """CLEAN label must be the only label if present."""
        if ViolationCategory.CLEAN in self.labels and len(self.labels) > 1:
            raise ValueError("CLEAN label cannot be combined with violation labels")
        return self
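
# Illustrative only: the content_id below is hypothetical. A decision carrying a
# single violation label validates; mixing CLEAN with a violation label is
# rejected by the validator above.
#
#   ModerationDecision(
#       content_id="post_123",
#       labels=[ViolationCategory.SPAM],
#       action=ModerationAction.REMOVE,
#       confidence=0.9,
#   )  # OK
#
#   ModerationDecision(
#       content_id="post_123",
#       labels=[ViolationCategory.CLEAN, ViolationCategory.SPAM],
#       action=ModerationAction.APPROVE,
#       confidence=0.5,
#   )  # raises pydantic.ValidationError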

class ModerationReward(BaseModel):
    """
    Structured reward breakdown for one step.
    Exposes partial-credit component scores alongside the composite total.
    """
    total: float = Field(..., ge=-1.0, le=1.0, description="Composite reward for this step")

    # Partial credit components
    label_precision: float = Field(default=0.0, ge=0.0, le=1.0)
    label_recall: float = Field(default=0.0, ge=0.0, le=1.0)
    action_correct: float = Field(default=0.0, ge=0.0, le=1.0)
    policy_citation_score: float = Field(default=0.0, ge=0.0, le=1.0)
    false_positive_penalty: float = Field(default=0.0, le=0.0, description="Negative if clean content flagged")
    reasoning_quality: float = Field(default=0.0, ge=0.0, le=1.0, description="Hard task only")

    breakdown: dict[str, float] = Field(
        default_factory=dict,
        description="Raw component scores for debugging"
    )
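
# Sketch of how a grader might combine the components above into `total`.
# The weights here are hypothetical, not the environment's actual scheme:
#
#   total = clamp(
#       0.3 * label_precision + 0.3 * label_recall
#       + 0.25 * action_correct + 0.15 * policy_citation_score
#       + false_positive_penalty,   # <= 0, subtracts from the score
#       -1.0, 1.0,
#   )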

class ModerationState(BaseModel):
    """
    Full internal state of the environment.
    Returned by state() endpoint.
    """
    task_name: str
    episode_id: str
    current_step: int
    max_steps: int
    done: bool
    cumulative_reward: float
    items_seen: int
    items_remaining: int
    decisions_log: list[dict[str, Any]] = Field(default_factory=list)
    score: float = Field(default=0.0, ge=0.0, le=1.0)
    ground_truth_data: list[dict[str, Any]] = Field(
        default_factory=list,
        description="All ground truth data for the current episode (used by grader)"
    )
    has_policy_conflict: bool = Field(default=False, description="Flag for thread task grader")
    is_final_message: bool = Field(default=False, description="Flag for thread task grader")
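
if __name__ == "__main__":
    # Minimal smoke test / usage sketch. All IDs, text, and policy citations
    # here are hypothetical and exist only to exercise the models.
    item = ContentItem(
        content_id="post_001",
        content_type=ContentType.TEXT_POST,
        text="Buy 10,000 followers now!!! Click the link in my bio.",
    )
    decision = ModerationDecision(
        content_id=item.content_id,
        labels=[ViolationCategory.SPAM],
        action=ModerationAction.REMOVE,
        confidence=0.9,
        policy_citations=["SPAM-1.0"],
    )
    print(decision.model_dump_json(indent=2))

    # The CLEAN-exclusivity validator rejects mixed labels.
    try:
        ModerationDecision(
            content_id=item.content_id,
            labels=[ViolationCategory.CLEAN, ViolationCategory.SPAM],
            action=ModerationAction.APPROVE,
            confidence=0.5,
        )
    except ValueError as exc:  # pydantic.ValidationError subclasses ValueError
        print(f"Rejected as expected: {exc}")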