jeromerichard committed on
Commit
7cf2ffd
·
1 Parent(s): f45aa51

Fix: add server/app.py, uv.lock, project.scripts entry point

Browse files
Files changed (7) hide show
  1. pyproject.toml +11 -25
  2. server/__init__.py +0 -0
  3. server/app.py +157 -0
  4. server/models.py +63 -0
  5. server/tasks.py +296 -0
  6. server/your_environment.py +440 -0
  7. uv.lock +0 -0
pyproject.toml CHANGED
@@ -1,33 +1,19 @@
1
- [build-system]
2
- requires = ["setuptools>=68.0", "wheel"]
3
- build-backend = "setuptools.backends.legacy:build"
4
-
5
- [project]
6
  name = "trust-safety-env"
7
  version = "1.0.0"
8
- description = "Risk-aware Trust & Safety content moderation RL environment OpenEnv compatible"
9
- readme = "README.md"
10
  requires-python = ">=3.11"
11
  dependencies = [
12
- "openenv-core>=0.2.0",
13
- "fastapi>=0.110.0",
14
- "uvicorn[standard]>=0.29.0",
15
- "pydantic>=2.6.0",
16
- "openai>=1.30.0",
17
  "requests>=2.31.0",
18
- "python-dotenv>=1.0.0",
19
  ]
20
 
21
- [project.optional-dependencies]
22
- dev = ["pytest>=8.0"]
23
 
24
- [tool.setuptools.packages.find]
25
- where = ["."]
26
- include = ["*"]
27
-
28
- [tool.openenv]
29
- name = "trust-safety-env"
30
- environment_class = "your_environment.TrustSafetyEnvironment"
31
- action_model = "models.TrustAction"
32
- observation_model = "models.TrustObservation"
33
- state_model = "models.TrustState"
 
1
[project]
name = "trust-safety-env"
version = "1.0.0"
description = "Trust & Safety RL Environment built on OpenEnv"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.30.0",
    "pydantic>=2.0.0",
    "requests>=2.31.0",
    "openenv-core>=0.2.2",
]

[project.scripts]
# NOTE(review): console-script entry points must reference a callable.
# "server.app:app" is a FastAPI instance, so running `server` invokes app()
# with no arguments — presumably this was meant to be a main() that launches
# uvicorn. Confirm intent before relying on the `server` command.
server = "server.app:app"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
 
 
 
 
 
 
 
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any, Dict, Optional
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.responses import JSONResponse
9
+ from pydantic import BaseModel
10
+
11
+ from models import TrustAction, TrustObservation, TrustState, ContentSignals
12
+ from your_environment import TrustSafetyEnvironment
13
+
14
# ── Force manual FastAPI (openenv_core create_app causes 422 on /step) ────────
print("[app] Using manual FastAPI ✅")

# Single module-level environment instance shared by every request.
# NOTE(review): state is per-process — under multi-worker uvicorn each worker
# holds its own episode; confirm single-worker deployment is intended.
_env = TrustSafetyEnvironment(seed=42)

app = FastAPI(
    title="Trust & Safety RL Environment",
    description="Risk-aware content moderation environment for agent training.",
    version="1.0.0",
)

# Fully permissive CORS so browser-based clients can call the API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
31
+
32
+
33
+ # ── Serializers ───────────────────────────────────────────────────────────────
34
+
35
+ def _obs_to_dict(obs: TrustObservation) -> Dict[str, Any]:
36
+ return {
37
+ "ticket_id": obs.ticket_id,
38
+ "post_text": obs.post_text,
39
+ "image_description": obs.image_description,
40
+ "comments_found": obs.comments_found,
41
+ "user_history_found": obs.user_history_found,
42
+ "entity_status_found": obs.entity_status_found,
43
+ "policy_found": obs.policy_found,
44
+ "extracted_signals": obs.extracted_signals,
45
+ "validation_result": obs.validation_result,
46
+ "step_number": obs.step_number,
47
+ "info": obs.info,
48
+ "done": obs.done,
49
+ "reward": obs.reward,
50
+ }
51
+
52
+
53
+ def _state_to_dict(s: TrustState) -> Dict[str, Any]:
54
+ return {
55
+ "episode_id": s.episode_id,
56
+ "step_count": s.step_count,
57
+ "current_task_id": s.current_task_id,
58
+ "difficulty": s.difficulty,
59
+ "ambiguity_level": s.ambiguity_level,
60
+ "risk_level": s.risk_level,
61
+ "tools_used": s.tools_used,
62
+ "signals_extracted": s.signals_extracted,
63
+ "is_done": s.is_done,
64
+ }
65
+
66
+
67
+ # ── Request bodies ─────────────────────────────────────────────────────────────
68
+
69
class ResetRequest(BaseModel):
    """Body for POST /reset; both fields optional."""

    # Typed Any so clients may send ints, strings, or omit the field entirely;
    # the environment decides how to interpret seed / episode_id.
    seed: Any = None
    episode_id: Any = None

    # Ignore unknown keys so loosely-formed client payloads don't 422.
    model_config = {"extra": "ignore"}
74
+
75
+
76
class ActionRequest(BaseModel):
    """Body for POST /step — a loose mirror of models.TrustAction."""

    # One of "use_tool", "extract_signals", "final_decision" (dispatched by the env).
    action_type: str = ""
    tool_name: Optional[str] = None
    signals: Optional[Dict[str, Any]] = None  # raw dict — validated below
    final_decision: Optional[str] = None

    model_config = {"extra": "ignore"}  # ← ignore unknown keys from LLM
83
+
84
+
85
+ # ── Helpers ────────────────────────────────────────────────────────────────────
86
+
87
def _parse_signals(raw: Dict[str, Any]) -> ContentSignals:
    """Defensively normalise LLM signal output before Pydantic validation.

    Args:
        raw: Untrusted dict straight from the client/LLM (mutated in place;
            callers pass a copy).

    Returns:
        A validated ContentSignals instance.

    Raises:
        ValueError/TypeError: propagated when a value cannot be coerced;
        the /step handler maps these to HTTP 400.
    """
    # Coerce AND clamp floats to [0, 1]. (The old code's comment said "Clamp"
    # but only coerced; ContentSignals' validator clamps too, so this is safe.)
    raw["toxicity_level"] = max(0.0, min(1.0, float(raw.get("toxicity_level", 0.5))))
    raw["confidence"] = max(0.0, min(1.0, float(raw.get("confidence", 0.5))))

    # content_flags must be a list of strings
    flags = raw.get("content_flags", [])
    if not isinstance(flags, list):
        flags = [flags] if isinstance(flags, str) else []
    raw["content_flags"] = [str(f) for f in flags]

    # Boolean coercion. FIX: bool("false") is True, so string booleans from an
    # LLM must be parsed explicitly before falling back to plain truthiness.
    def _to_bool(value: Any) -> bool:
        if isinstance(value, str):
            return value.strip().lower() not in ("", "false", "no", "0", "none")
        return bool(value)

    raw["is_protected_class"] = _to_bool(raw.get("is_protected_class", False))
    raw["is_direct_attack"] = _to_bool(raw.get("is_direct_attack", False))
    raw["abusive_language_present"] = _to_bool(raw.get("abusive_language_present", False))

    # string fields — fallback to sensible defaults
    raw.setdefault("target", "none")
    raw.setdefault("intent", "ambiguous")
    raw.setdefault("context_type", "statement")

    return ContentSignals(**raw)
110
+
111
+
112
+ # ── Routes ─────────────────────────────────────────────────────────────────────
113
+
114
@app.get("/health")
async def health():
    """Liveness probe for containers/orchestrators."""
    return {"status": "ok", "environment": "trust-safety-env", "version": "1.0.0"}


@app.get("/")
async def root():
    """Landing route — points clients at the interactive docs."""
    return {"status": "ok", "docs": "/docs"}
122
+
123
+
124
@app.post("/reset")
async def reset(body: ResetRequest = ResetRequest()):
    """Start a new episode; the default body lets clients POST an empty payload."""
    obs = _env.reset(seed=body.seed, episode_id=body.episode_id)
    return JSONResponse(_obs_to_dict(obs))
128
+
129
+
130
@app.post("/step")
async def step(body: ActionRequest):
    """Advance the environment one step.

    Signals arrive as a raw dict and are normalised/validated first, so
    malformed LLM output yields a 400 rather than a 500.
    """
    # Parse + validate signals defensively
    signals: Optional[ContentSignals] = None
    if body.signals:
        try:
            signals = _parse_signals(dict(body.signals))  # copy so we don't mutate
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid signals payload: {e}")

    action = TrustAction(
        action_type = body.action_type,
        tool_name = body.tool_name,
        signals = signals,
        final_decision = body.final_decision,
    )

    try:
        obs = _env.step(action)
    except (RuntimeError, ValueError) as e:
        # RuntimeError: step before reset; ValueError: unknown action/tool.
        raise HTTPException(status_code=400, detail=str(e))

    return JSONResponse(_obs_to_dict(obs))
153
+
154
+
155
@app.get("/state")
async def state():
    """Expose the environment's internal episode state (read-only)."""
    return JSONResponse(_state_to_dict(_env.state))
server/models.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Optional, List, Dict, Any
3
+ from pydantic import BaseModel, Field, field_validator
4
+
5
+
6
class ContentSignals(BaseModel):
    """Structured moderation signals the agent extracts in Layer 2."""

    target: str = "none"  # who the content targets, e.g. "individual", "group", "political"
    is_protected_class: bool = False
    toxicity_level: float = 0.5  # 0..1, clamped by the validator below
    is_direct_attack: bool = False
    context_type: str = "statement"  # e.g. "satire", "spam", "news", "meme"
    intent: str = "ambiguous"
    confidence: float = 0.5  # 0..1, clamped by the validator below
    abusive_language_present: bool = False
    content_flags: List[str] = Field(default_factory=list)

    @field_validator("toxicity_level", "confidence")
    @classmethod
    def clamp_0_1(cls, v: float) -> float:
        """Clamp score fields into [0.0, 1.0]."""
        return max(0.0, min(1.0, float(v)))

    model_config = {"extra": "ignore"}  # tolerate unknown keys from LLM output
23
+
24
+
25
class TrustAction(BaseModel):
    """One agent action: tool use, signal extraction, or a final decision."""

    action_type: str = ""  # "use_tool" | "extract_signals" | "final_decision"
    tool_name: Optional[str] = None  # required when action_type == "use_tool"
    signals: Optional[ContentSignals] = None  # required when extracting signals
    final_decision: Optional[str] = None  # e.g. "ALLOW", "ALLOW_WITH_WARNING", "REMOVE", "ESCALATE"

    model_config = {"extra": "ignore"}
32
+
33
+
34
class TrustObservation(BaseModel):
    """Everything the agent sees after reset()/step()."""

    ticket_id: str = ""  # task_id of the moderation ticket
    post_text: str = ""
    image_description: str = ""
    # Tool outputs — populated only after the matching tool has been used.
    comments_found: Optional[str] = None
    user_history_found: Optional[str] = None
    entity_status_found: Optional[str] = None
    policy_found: Optional[str] = None
    extracted_signals: Optional[Dict[str, Any]] = None  # echo of the agent's Layer-2 signals
    validation_result: Optional[Dict[str, Any]] = None  # consistency check of those signals
    step_number: int = 0
    info: Optional[Dict[str, Any]] = None  # extra metadata (e.g. timeout reason)
    done: bool = False
    reward: Optional[float] = None  # set on terminal observations

    model_config = {"extra": "ignore"}
50
+
51
+
52
class TrustState(BaseModel):
    """Server-side episode bookkeeping, served at GET /state."""

    episode_id: Optional[str] = None  # set to the task_id on reset
    step_count: int = 0
    current_task_id: Optional[str] = None
    difficulty: Optional[str] = None
    ambiguity_level: Optional[str] = None  # NOTE(review): never populated by reset() — confirm
    risk_level: Optional[str] = None
    tools_used: List[str] = Field(default_factory=list)
    signals_extracted: bool = False
    is_done: bool = False

    model_config = {"extra": "ignore"}
server/tasks.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASKS = [
2
+ {
3
+ "task_id": "T-001",
4
+ "difficulty": "easy",
5
+ "ambiguity_level": "low",
6
+ "risk_level": "low",
7
+ "post_text": (
8
+ "🚨 URGENT: Your account has been COMPROMISED! "
9
+ "Click here immediately: http://meta-security-verify-now.xyz/login?token=abc123 "
10
+ "Act within 24 hours or your account will be PERMANENTLY BANNED. — Meta Security Team"
11
+ ),
12
+ "image_description": (
13
+ "Poorly formatted alert page with mismatched Meta logos, "
14
+ "red urgent warning banners, and a suspicious external URL."
15
+ ),
16
+ "tool_responses": {
17
+ "read_comments": "Comments disabled. Zero engagement.",
18
+ "check_user_history": "Account created 2 days ago. Zero posts. Flagged by spam filter 3×.",
19
+ "check_entity_status": "Domain 'meta-security-verify-now.xyz' registered 1 day ago — confirmed phishing domain.",
20
+ "view_policy": "Policy 3.2 — Phishing: Impersonating Meta to harvest credentials = immediate REMOVE.",
21
+ },
22
+ "required_tools": [],
23
+ "ground_truth": "REMOVE",
24
+ "ground_truth_signals": {
25
+ "target": "individual",
26
+ "is_protected_class": False,
27
+ "toxicity_level": 0.90,
28
+ "is_direct_attack": True,
29
+ "context_type": "spam",
30
+ "intent": "harmful",
31
+ "abusive_language_present": False,
32
+ "content_flags": ["phishing", "spam", "impersonation"],
33
+ },
34
+ },
35
+ {
36
+ "task_id": "T-002",
37
+ "difficulty": "medium",
38
+ "ambiguity_level": "medium",
39
+ "risk_level": "low",
40
+ "post_text": (
41
+ "LMAO I will literally SNIPE you the second the game starts, "
42
+ "no mercy, you're getting absolutely destroyed 💀🎮 "
43
+ "1v1 me and I'll end your whole career bro. Your team is DONE."
44
+ ),
45
+ "image_description": (
46
+ "Call of Duty: Warzone — squad wipe in progress. "
47
+ "Kill notifications, sniper rifle, enemy team eliminated. Classic gaming screenshot."
48
+ ),
49
+ "tool_responses": {
50
+ "read_comments": (
51
+ "47 comments: 'GG bro!', 'You always trash talk before CoD 😂', "
52
+ "'rematch tonight??', 'your aim is mid bro'. "
53
+ "All mutual gaming friends. Clear competitive banter context."
54
+ ),
55
+ "check_user_history": (
56
+ "Account 3 years active. 500+ posts, 90% gaming. "
57
+ "Zero harassment reports. Regular CoD tournament player."
58
+ ),
59
+ "check_entity_status": "Private individual. Standard gaming community member.",
60
+ "view_policy": (
61
+ "Policy 7.1 — Harassment: Hyperbolic language in established gaming communities "
62
+ "is NOT harassment. 'Snipe', 'destroy', 'end your career' are non-literal expressions."
63
+ ),
64
+ },
65
+ "required_tools": ["read_comments"],
66
+ "ground_truth": "ALLOW",
67
+ "ground_truth_signals": {
68
+ "target": "individual",
69
+ "is_protected_class": False,
70
+ "toxicity_level": 0.30,
71
+ "is_direct_attack": False,
72
+ "context_type": "meme",
73
+ "intent": "ambiguous",
74
+ "abusive_language_present": False,
75
+ "content_flags": ["competitive_language"],
76
+ },
77
+ },
78
+ {
79
+ "task_id": "T-003",
80
+ "difficulty": "hard",
81
+ "ambiguity_level": "high",
82
+ "risk_level": "high",
83
+ "post_text": (
84
+ "Sharing this political cartoon I drew. "
85
+ "Mayor Davidson cutting the homeless shelter budget while sitting on a throne of gold coins. "
86
+ "Caption: 'Priorities.' #LocalPolitics #HoldThemAccountable #SatiricalArt #MayorDavidson"
87
+ ),
88
+ "image_description": (
89
+ "Hand-drawn political cartoon. Exaggerated caricature of a public official on a coin throne, "
90
+ "homeless families pushed away. Deliberately satirical editorial style. Artist signature visible."
91
+ ),
92
+ "tool_responses": {
93
+ "read_comments": (
94
+ "23 comments: 'Powerful commentary', 'Free speech!', 'Great art', 'Defamation!'. "
95
+ "No threats or coordinated harassment detected."
96
+ ),
97
+ "check_user_history": (
98
+ "Account 5 years. Verified local journalist and artist. "
99
+ "Regular political commentary. No violations. Posts cited in local news."
100
+ ),
101
+ "check_entity_status": (
102
+ "Mayor Davidson: CONFIRMED PUBLIC FIGURE. "
103
+ "Elected Mayor of Springfield. Active public interest in recent social services budget cuts. "
104
+ "Multiple credible news articles confirm the cuts are real and newsworthy."
105
+ ),
106
+ "view_policy": (
107
+ "Policy 12.4 — Political Satire: Editorial cartoons targeting PUBLIC OFFICIALS "
108
+ "on PUBLIC INTEREST matters are PROTECTED. Caricature is a recognized technique. "
109
+ "Not realistic impersonation. Correct action: ALLOW_WITH_WARNING."
110
+ ),
111
+ },
112
+ "required_tools": ["check_entity_status", "view_policy"],
113
+ "ground_truth": "ALLOW_WITH_WARNING",
114
+ "ground_truth_signals": {
115
+ "target": "political",
116
+ "is_protected_class": False,
117
+ "toxicity_level": 0.20,
118
+ "is_direct_attack": False,
119
+ "context_type": "satire",
120
+ "intent": "satirical",
121
+ "abusive_language_present": False,
122
+ "content_flags": ["political_content", "public_figure"],
123
+ },
124
+ },
125
+ # ── Additional tasks T-004 – T-006 ──────────────────────────────────────────
126
+
127
+ {
128
+ "task_id": "T-004",
129
+ "difficulty": "medium",
130
+ "ambiguity_level": "medium",
131
+ "risk_level": "high",
132
+ "post_text": (
133
+ "Sharing this 'educational' video about why [ethnic group] people "
134
+ "are genetically predisposed to criminal behaviour. "
135
+ "The science doesn't lie. Share before it gets deleted. 👀 "
136
+ "#HiddenTruths #CensoredFacts #JustAsking"
137
+ ),
138
+ "image_description": (
139
+ "Screenshot of a YouTube-style video thumbnail with a white lab coat figure "
140
+ "and a graph. Title: 'The Uncomfortable Biology of Crime Rates'. "
141
+ "Uploader has 0 verified credentials. Comments disabled."
142
+ ),
143
+ "tool_responses": {
144
+ "read_comments": (
145
+ "Comments disabled by uploader. "
146
+ "Zero engagement from credible academic or scientific accounts."
147
+ ),
148
+ "check_user_history": (
149
+ "Account 6 months old. 48 posts — all promoting pseudoscientific race content. "
150
+ "Previously received 2 policy warnings for hate content. "
151
+ "No verified credentials. Frequently uses 'just asking questions' framing."
152
+ ),
153
+ "check_entity_status": (
154
+ "Ethnic group mentioned: confirmed PROTECTED CLASS under platform policy. "
155
+ "Claim about genetic criminality: DEBUNKED by mainstream biology and criminology. "
156
+ "This framing is a known pseudoscience hate vector (scientific racism)."
157
+ ),
158
+ "view_policy": (
159
+ "Policy 4.3 — Hate Speech: Content attributing criminal or negative traits "
160
+ "to a protected class based on race or ethnicity = REMOVE. "
161
+ "Disguising hate speech as 'educational content' or 'science' does not "
162
+ "change the violation. Action: REMOVE."
163
+ ),
164
+ },
165
+ "required_tools": ["check_entity_status", "view_policy"],
166
+ "ground_truth": "REMOVE",
167
+ "ground_truth_signals": {
168
+ "target": "group",
169
+ "is_protected_class": True,
170
+ "toxicity_level": 0.85,
171
+ "is_direct_attack": True,
172
+ "context_type": "statement",
173
+ "intent": "harmful",
174
+ "abusive_language_present": False,
175
+ "content_flags": ["slur", "misinformation", "hate_speech_disguised_as_education"],
176
+ },
177
+ },
178
+
179
+ {
180
+ "task_id": "T-005",
181
+ "difficulty": "hard",
182
+ "ambiguity_level": "high",
183
+ "risk_level": "high",
184
+ "post_text": (
185
+ "The Prime Minister announced new economic reforms today. "
186
+ "Opposition leaders called it 'economic warfare on the poor.' "
187
+ "Protests erupted in 3 cities. Security forces used tear gas. "
188
+ "Government says the crackdown was 'necessary to restore order.' "
189
+ "At least 12 protesters arrested. #Politics #Protest #Reform"
190
+ ),
191
+ "image_description": (
192
+ "News-style post with a photo of a large protest crowd holding signs. "
193
+ "Tear gas clouds visible in background. "
194
+ "Caption: 'Thousands take to streets against new austerity measures.' "
195
+ "Shared by a verified news outlet handle. No graphic violence."
196
+ ),
197
+ "tool_responses": {
198
+ "read_comments": (
199
+ "340 comments split sharply: "
200
+ "'State violence against citizens!', 'Finally someone fixing the economy', "
201
+ "'This is propaganda!', 'Biased reporting!', 'Necessary action'. "
202
+ "High engagement, deep political polarization. No incitement to violence detected."
203
+ ),
204
+ "check_user_history": (
205
+ "Posted by @NationalNewsNetwork — verified news organization. "
206
+ "12 years active. Consistently neutral coverage. "
207
+ "No policy violations. Registered press entity."
208
+ ),
209
+ "check_entity_status": (
210
+ "Prime Minister: CONFIRMED PUBLIC FIGURE, head of government. "
211
+ "Economic reforms: CONFIRMED real policy, actively debated in parliament. "
212
+ "Protest: CONFIRMED real event, covered by 6+ independent news agencies. "
213
+ "Political sensitivity: VERY HIGH — active governance dispute in a democracy."
214
+ ),
215
+ "view_policy": (
216
+ "Policy 12.1 — News & Political Content: Factual reporting on real government "
217
+ "actions, protests, and police responses by VERIFIED NEWS entities is PROTECTED. "
218
+ "Even politically contentious factual reporting is protected expression. "
219
+ "Suppressing legitimate news = censorship risk. "
220
+ "Action: ALLOW_WITH_WARNING (politically sensitive content label)."
221
+ ),
222
+ },
223
+ "required_tools": ["check_user_history", "check_entity_status", "view_policy"],
224
+ "ground_truth": "ALLOW_WITH_WARNING",
225
+ "ground_truth_signals": {
226
+ "target": "political",
227
+ "is_protected_class": False,
228
+ "toxicity_level": 0.20,
229
+ "is_direct_attack": False,
230
+ "context_type": "news",
231
+ "intent": "descriptive",
232
+ "abusive_language_present": False,
233
+ "content_flags": ["political_content", "public_figure", "protest"],
234
+ },
235
+ },
236
+
237
+ {
238
+ "task_id": "T-006",
239
+ "difficulty": "hard",
240
+ "ambiguity_level": "high",
241
+ "risk_level": "high",
242
+ "post_text": (
243
+ "Sending Eid Mubarak wishes to my community 🌙✨ "
244
+ "During this holy month let us remember those suffering in conflict zones. "
245
+ "Our brothers and sisters deserve peace and dignity. "
246
+ "May Allah guide us all. #EidMubarak #Ramadan #Peace #Solidarity"
247
+ ),
248
+ "image_description": (
249
+ "Warm celebratory post with a crescent moon and stars graphic. "
250
+ "Soft gold and green colors. "
251
+ "Comment section shows replies in Arabic, Urdu, and English. "
252
+ "Emojis: praying hands, crescent moon, dove."
253
+ ),
254
+ "tool_responses": {
255
+ "read_comments": (
256
+ "89 comments: 'Eid Mubarak!', 'JazakAllah Khair', "
257
+ "'prayers for peace', 'beautiful message', "
258
+ "1 comment flagged by automated system: 'why is this allowed, "
259
+ "these people support terrorism' — that comment is from a third party, "
260
+ "not the original poster. "
261
+ "No incitement, no threats, no extremist content in original post."
262
+ ),
263
+ "check_user_history": (
264
+ "Account 4 years active. 200+ posts — community announcements, "
265
+ "religious holiday greetings, charity drives. "
266
+ "Zero prior violations. Community leader in local Muslim cultural centre."
267
+ ),
268
+ "check_entity_status": (
269
+ "Islam / Muslim community: CONFIRMED PROTECTED CLASS (religion). "
270
+ "Eid al-Fitr: globally recognized Islamic religious holiday. "
271
+ "Content: religious greeting + humanitarian solidarity message. "
272
+ "No extremist symbols, no prohibited organisations mentioned."
273
+ ),
274
+ "view_policy": (
275
+ "Policy 8.1 — Religious Expression: Greetings and solidarity messages "
276
+ "for recognized religious holidays are PROTECTED. "
277
+ "Vague 'conflict zone' solidarity without naming a prohibited group "
278
+ "is NOT a violation. "
279
+ "Automated flag was triggered by a COMMENT, not the original post. "
280
+ "Action: ALLOW — original content is clearly safe religious expression."
281
+ ),
282
+ },
283
+ "required_tools": ["read_comments", "check_entity_status", "view_policy"],
284
+ "ground_truth": "ALLOW",
285
+ "ground_truth_signals": {
286
+ "target": "religion",
287
+ "is_protected_class": True,
288
+ "toxicity_level": 0.05,
289
+ "is_direct_attack": False,
290
+ "context_type": "statement",
291
+ "intent": "educational",
292
+ "abusive_language_present": False,
293
+ "content_flags": ["political_content"],
294
+ },
295
+ },
296
+ ]
server/your_environment.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ import uuid
5
+ from typing import Optional, Dict, Any, Set
6
+
7
+ try:
8
+ from openenv_core.env_server import Environment
9
+ print("[env] Inheriting from openenv_core.env_server.Environment ✅")
10
+ except ImportError:
11
+ try:
12
+ from openenv.core.env_server import Environment
13
+ print("[env] Inheriting from openenv.core.env_server.Environment ✅")
14
+ except ImportError:
15
+ Environment = object
16
+ print("[env] openenv_core not found — using plain object base ⚠️")
17
+
18
+ from models import TrustObservation, TrustAction, TrustState, ContentSignals
19
+ from tasks import TASKS
20
+
21
+
22
# Per-tool cost charged when the agent invokes an investigation tool.
TOOL_COSTS: Dict[str, float] = {
    "read_comments": 0.05,
    "check_user_history": 0.05,
    "check_entity_status": 0.10,
    "view_policy": 0.10,
}

# Hard episode cap: step() at or beyond this count terminates with reward 0.
MAX_STEPS = 7

# Decision-pair -> base accuracy reward.
# NOTE(review): assumed key order is (agent_decision, ground_truth) — the
# lookup code is not in this chunk; confirm. Pairs with ground_truth
# "ESCALATE" other than the exact match are absent; no bundled task uses
# ESCALATE as ground truth, but verify before adding one.
DECISION_MATRIX: Dict[tuple, float] = {
    ("REMOVE", "REMOVE"): 1.00,
    ("ALLOW", "ALLOW"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW_WITH_WARNING"): 1.00,
    ("ESCALATE", "ESCALATE"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW"): 0.75,
    ("ALLOW", "ALLOW_WITH_WARNING"): 0.55,
    ("ESCALATE", "ALLOW_WITH_WARNING"): 0.65,
    ("ESCALATE", "ALLOW"): 0.45,
    ("ESCALATE", "REMOVE"): 0.45,
    ("REMOVE", "ALLOW"): 0.10,
    ("REMOVE", "ALLOW_WITH_WARNING"): 0.20,
    ("ALLOW", "REMOVE"): 0.00,
    ("ALLOW_WITH_WARNING", "REMOVE"): 0.15,
}
46
+
47
+
48
+ class TrustSafetyEnvironment(Environment):
49
+ """
50
+ 3-Layer Risk-Aware Trust & Safety RL Environment.
51
+
52
+ Layer 1 — Evidence gathering : agent uses investigation tools (optional)
53
+ Layer 2 — Signal extraction : agent outputs ContentSignals as feature extractor
54
+ Layer 3 — Policy engine : validates signals, applies rules, computes reward
55
+
56
+ 8-Component Reward: Accuracy · Policy Alignment · Signal Quality · Escalation
57
+ Tool Usage · Consistency · Risk Sensitivity · Confidence
58
+ """
59
+
60
+ def __init__(self, seed: int = 42) -> None:
61
+ super().__init__()
62
+ self._rng = random.Random(seed)
63
+ self._current_task: Optional[Dict[str, Any]] = None
64
+ self._tools_used: Set[str] = set()
65
+ self._step_count: int = 0
66
+ self._extracted_signals: Optional[ContentSignals] = None
67
+ self._validation_result: Optional[Dict[str, Any]] = None
68
+ self._signals_extracted: bool = False
69
+ self._obs: Optional[TrustObservation]= None
70
+ self._state = TrustState()
71
+
72
+ # ✅ FIX 3 — build a dict keyed by task_id for O(1) lookup
73
+ self._tasks: Dict[str, Dict[str, Any]] = {
74
+ t["task_id"]: t for t in TASKS
75
+ }
76
+
77
+ # -----------------------------------------------------------------------
78
+ # OpenEnv interface
79
+ # -----------------------------------------------------------------------
80
+
81
+ def reset(self, seed=None, episode_id=None, **kwargs) -> TrustObservation:
82
+ # ✅ FIX 1 — reset() is now correctly INSIDE the class
83
+ if seed is not None:
84
+ self._rng.seed(seed)
85
+
86
+ # Pick task by episode_id if provided, else random from all 6
87
+ if episode_id and episode_id in self._tasks:
88
+ task = self._tasks[episode_id]
89
+ else:
90
+ task = self._rng.choice(list(self._tasks.values()))
91
+
92
+ self._current_task = task
93
+ self._tools_used = set()
94
+ self._step_count = 0
95
+ self._extracted_signals = None
96
+ self._validation_result = None
97
+ self._signals_extracted = False
98
+
99
+ self._state = TrustState(
100
+ episode_id=task["task_id"],
101
+ step_count=0,
102
+ current_task_id=task["task_id"],
103
+ difficulty=task.get("difficulty", "medium"),
104
+ risk_level=task.get("risk_level", "medium"),
105
+ is_done=False,
106
+ tools_used=[],
107
+ signals_extracted=False,
108
+ )
109
+
110
+ self._obs = TrustObservation(
111
+ ticket_id=task["task_id"],
112
+ post_text=task["post_text"],
113
+ image_description=task.get("image_description", ""),
114
+ step_number=0,
115
+ done=False,
116
+ )
117
+ return self._obs # ✅ FIX 2 — single clean return, stray return removed
118
+
119
+ def step(self, action: TrustAction, timeouts: Optional[Any] = None,
120
+ **kwargs) -> TrustObservation:
121
+ if self._current_task is None or self._obs is None:
122
+ raise RuntimeError("Call reset() before step().")
123
+
124
+ if self._step_count >= MAX_STEPS:
125
+ self._obs = TrustObservation(
126
+ ticket_id=self._current_task["task_id"],
127
+ post_text=self._obs.post_text,
128
+ image_description=self._obs.image_description,
129
+ step_number=self._step_count,
130
+ done=True,
131
+ reward=0.0,
132
+ info={"reason": "timeout", "tools_used": list(self._tools_used)},
133
+ )
134
+ return self._obs
135
+
136
+ atype = action.action_type
137
+ if atype == "use_tool":
138
+ return self._handle_tool(action)
139
+ if atype == "extract_signals":
140
+ return self._handle_signal_extraction(action)
141
+ if atype == "final_decision":
142
+ return self._handle_final_decision(action)
143
+ raise ValueError(f"Unknown action_type: {atype!r}")
144
+
145
    @property
    def state(self) -> TrustState:
        """Current mutable episode state (served by GET /state)."""
        return self._state
148
+
149
+ # -----------------------------------------------------------------------
150
+ # Layer 1 — Tool handling
151
+ # -----------------------------------------------------------------------
152
+
153
+ def _handle_tool(self, action: TrustAction) -> TrustObservation:
154
+ tool = action.tool_name
155
+ if tool not in TOOL_COSTS:
156
+ raise ValueError(f"Unknown tool: {tool!r}")
157
+ self._tools_used.add(tool)
158
+ response = self._current_task["tool_responses"].get(tool, "No data found.")
159
+ field_map = {
160
+ "read_comments": "comments_found",
161
+ "check_user_history": "user_history_found",
162
+ "check_entity_status": "entity_status_found",
163
+ "view_policy": "policy_found",
164
+ }
165
+ self._step_count += 1
166
+ self._state.step_count = self._step_count
167
+ self._state.tools_used = list(self._tools_used)
168
+
169
+ obs_kwargs = {
170
+ k: getattr(self._obs, k)
171
+ for k in ("ticket_id", "post_text", "image_description",
172
+ "comments_found", "user_history_found",
173
+ "entity_status_found", "policy_found",
174
+ "extracted_signals", "validation_result")
175
+ }
176
+ obs_kwargs[field_map[tool]] = response
177
+ obs_kwargs["step_number"] = self._step_count
178
+ obs_kwargs["done"] = False
179
+ obs_kwargs["reward"] = None
180
+
181
+ self._obs = TrustObservation(**obs_kwargs)
182
+ return self._obs
183
+
184
+ # -----------------------------------------------------------------------
185
+ # Layer 2 — Signal extraction + validation
186
+ # -----------------------------------------------------------------------
187
+
188
+ def _handle_signal_extraction(self, action: TrustAction) -> TrustObservation:
189
+ raw = action.signals
190
+ raw.toxicity_level = max(0.0, min(1.0, float(raw.toxicity_level)))
191
+ raw.confidence = max(0.0, min(1.0, float(raw.confidence)))
192
+ if not isinstance(raw.content_flags, list):
193
+ raw.content_flags = []
194
+
195
+ self._extracted_signals = raw
196
+ self._signals_extracted = True
197
+ self._validation_result = self._validate_signals(raw)
198
+ self._step_count += 1
199
+ self._state.step_count = self._step_count
200
+ self._state.signals_extracted = True
201
+
202
+ obs_kwargs = {
203
+ k: getattr(self._obs, k)
204
+ for k in ("ticket_id", "post_text", "image_description",
205
+ "comments_found", "user_history_found",
206
+ "entity_status_found", "policy_found")
207
+ }
208
+ obs_kwargs["extracted_signals"] = {
209
+ "target": raw.target,
210
+ "is_protected_class": raw.is_protected_class,
211
+ "toxicity_level": raw.toxicity_level,
212
+ "is_direct_attack": raw.is_direct_attack,
213
+ "context_type": raw.context_type,
214
+ "intent": raw.intent,
215
+ "confidence": raw.confidence,
216
+ "abusive_language_present": raw.abusive_language_present,
217
+ "content_flags": raw.content_flags,
218
+ }
219
+ obs_kwargs["validation_result"] = self._validation_result
220
+ obs_kwargs["step_number"] = self._step_count
221
+ obs_kwargs["done"] = False
222
+ obs_kwargs["reward"] = None
223
+
224
+ self._obs = TrustObservation(**obs_kwargs)
225
+ return self._obs
226
+
227
+ def _validate_signals(self, s: ContentSignals) -> Dict[str, Any]:
228
+ issues = []
229
+ conf = s.confidence
230
+
231
+ if not s.abusive_language_present and s.toxicity_level > 0.75:
232
+ issues.append("high_toxicity_without_abusive_language"); conf -= 0.15
233
+ if s.context_type in ("satire", "education") and s.intent == "harmful":
234
+ issues.append("harmful_intent_contradicts_satire_context"); conf -= 0.12
235
+ if s.is_protected_class and s.target == "none":
236
+ issues.append("protected_class_flagged_without_target"); conf -= 0.10
237
+ if s.is_direct_attack and s.toxicity_level < 0.25:
238
+ issues.append("direct_attack_with_low_toxicity"); conf -= 0.12
239
+ if s.context_type == "spam" and "spam" not in s.content_flags \
240
+ and "phishing" not in s.content_flags:
241
+ issues.append("spam_context_missing_spam_flag"); conf -= 0.08
242
+
243
+ adj = round(max(0.0, min(1.0, conf)), 3)
244
+ return {
245
+ "original_confidence": s.confidence,
246
+ "adjusted_confidence": adj,
247
+ "consistency_issues": issues,
248
+ "issue_count": len(issues),
249
+ "is_consistent": len(issues) == 0,
250
+ "needs_escalation_hint": adj < 0.55 or len(issues) >= 2,
251
+ }
252
+
253
+ # -----------------------------------------------------------------------
254
+ # Layer 3 — Policy engine + final decision
255
+ # -----------------------------------------------------------------------
256
+
257
+ def _policy_engine(self) -> Dict[str, Any]:
258
+ if not self._extracted_signals or not self._validation_result:
259
+ return {"recommended": "ESCALATE", "reason": "no_signals_extracted",
260
+ "rule_strength": "weak"}
261
+
262
+ s = self._extracted_signals
263
+ adj = self._validation_result["adjusted_confidence"]
264
+
265
+ if adj < 0.50:
266
+ return {"recommended": "ESCALATE", "reason": "low_adjusted_confidence",
267
+ "rule_strength": "medium"}
268
+ if "spam" in s.content_flags or "phishing" in s.content_flags:
269
+ return {"recommended": "REMOVE", "reason": "spam_or_phishing_detected",
270
+ "rule_strength": "strong"}
271
+ if s.is_protected_class and s.is_direct_attack \
272
+ and s.toxicity_level > 0.70 and adj > 0.70:
273
+ return {"recommended": "REMOVE",
274
+ "reason": "protected_class_direct_attack_high_toxicity",
275
+ "rule_strength": "strong"}
276
+ if s.toxicity_level > 0.85 and s.intent == "harmful":
277
+ return {"recommended": "REMOVE", "reason": "high_toxicity_harmful_intent",
278
+ "rule_strength": "strong"}
279
+ if s.context_type in ("satire", "news", "education", "protest") \
280
+ and s.intent in ("satirical", "descriptive", "educational"):
281
+ return {"recommended": "ALLOW_WITH_WARNING",
282
+ "reason": "protected_expression_context",
283
+ "rule_strength": "medium"}
284
+ if 0.40 < s.toxicity_level < 0.75 and s.intent == "ambiguous":
285
+ return {"recommended": "ESCALATE", "reason": "ambiguous_moderate_toxicity",
286
+ "rule_strength": "medium"}
287
+ return {"recommended": "ALLOW", "reason": "no_policy_violation_detected",
288
+ "rule_strength": "medium"}
289
+
290
+ def _handle_final_decision(self, action: TrustAction) -> TrustObservation:
291
+ decision = action.final_decision
292
+ components = self._compute_components(decision)
293
+ policy_rec = components.pop("_policy_rec")
294
+ reward = self._finalize_reward(components)
295
+
296
+ self._step_count += 1
297
+ self._state.step_count = self._step_count
298
+ self._state.is_done = True
299
+ components["final_reward"] = reward
300
+
301
+ obs_kwargs = {
302
+ k: getattr(self._obs, k)
303
+ for k in ("ticket_id", "post_text", "image_description",
304
+ "comments_found", "user_history_found",
305
+ "entity_status_found", "policy_found",
306
+ "extracted_signals", "validation_result")
307
+ }
308
+ obs_kwargs["step_number"] = self._step_count
309
+ obs_kwargs["done"] = True
310
+ obs_kwargs["reward"] = reward
311
+ obs_kwargs["info"] = {
312
+ "final_decision": decision,
313
+ "ground_truth": self._current_task["ground_truth"],
314
+ "policy_recommendation": policy_rec,
315
+ "signals_extracted": self._signals_extracted,
316
+ "tools_used": list(self._tools_used),
317
+ "required_tools": self._current_task["required_tools"],
318
+ "ambiguity_level": self._current_task["ambiguity_level"],
319
+ "risk_level": self._current_task["risk_level"],
320
+ "task_id": self._current_task["task_id"],
321
+ "reward_breakdown": components,
322
+ }
323
+
324
+ self._obs = TrustObservation(**obs_kwargs)
325
+ return self._obs
326
+
327
+ # -----------------------------------------------------------------------
328
+ # 8-Component Reward Engine
329
+ # -----------------------------------------------------------------------
330
+
331
    def _compute_components(self, final_decision: str) -> Dict[str, Any]:
        """Compute every reward component for the given final decision.

        The components are summed (bonuses added, penalties subtracted) by
        ``_finalize_reward``.  The returned dict also smuggles the policy
        engine's recommendation under the private key ``"_policy_rec"`` so
        the caller can surface it in ``info`` without recomputing it.
        """
        gt = self._current_task["ground_truth"]
        required_tools = self._current_task["required_tools"]
        ambiguity = self._current_task["ambiguity_level"]
        risk_level = self._current_task["risk_level"]
        policy_rec = self._policy_engine()

        # 1. Base score from the (decision, ground_truth) payoff matrix;
        #    unknown pairs default to 0.20.  Escalating a high-ambiguity task
        #    is never scored below 0.70.
        base_score = DECISION_MATRIX.get((final_decision, gt), 0.20)
        if final_decision == "ESCALATE" and ambiguity == "high":
            base_score = max(base_score, 0.70)
        is_correct = base_score >= 0.90

        # 2. Alignment with the policy engine: bonus for agreeing, larger
        #    penalty for disagreeing, both scaled by the rule's strength.
        rule_weight = {"strong": 1.0, "medium": 0.70, "weak": 0.40}.get(
            policy_rec.get("rule_strength", "medium"), 0.70)
        policy_alignment = round(
            (+0.12 if final_decision == policy_rec["recommended"] else -0.18) * rule_weight, 4)

        # 3. Bonus for matching the task's ground-truth signals (<= 0.15).
        signal_accuracy_bonus = self._compute_signal_accuracy()

        # 4. Escalation calibration: reward escalating when adjusted
        #    confidence is low (< 0.50), penalize failing to escalate then,
        #    and penalize unnecessary escalation (more on low-ambiguity tasks).
        adj_conf = (self._validation_result["adjusted_confidence"]
                    if self._validation_result else 0.50)
        should_escalate = adj_conf < 0.50
        if should_escalate and final_decision == "ESCALATE":
            escalation_adj = +0.15
        elif should_escalate and final_decision != "ESCALATE":
            escalation_adj = -0.18
        elif not should_escalate and final_decision == "ESCALATE" and ambiguity == "low":
            escalation_adj = -0.20
        elif not should_escalate and final_decision == "ESCALATE":
            escalation_adj = -0.10
        else:
            escalation_adj = 0.0

        # 5. Flat bonus for having run Layer-2 extraction at all.
        signal_bonus = +0.05 if self._signals_extracted else -0.10
        # 6. Tool economics: pay per tool used, and 0.25 per required tool
        #    the agent skipped.
        tool_cost = round(sum(TOOL_COSTS.get(t, 0.0) for t in self._tools_used), 4)
        missing_required = set(required_tools) - self._tools_used
        tool_miss_penalty = round(len(missing_required) * 0.25, 4)

        # 7. Penalty scaled by how many consistency issues validation found;
        #    skipping extraction entirely costs a mid-level 0.12.
        if self._validation_result:
            n = self._validation_result["issue_count"]
            validation_penalty = {0: 0.00, 1: 0.05, 2: 0.12}.get(n, 0.20)
        else:
            validation_penalty = 0.12

        # 8. Wrong answers hurt more on higher-risk tasks.
        risk_penalty = 0.0
        if not is_correct:
            risk_penalty = {"high": 0.20, "medium": 0.10, "low": 0.0}.get(risk_level, 0.0)

        # 9. Confidence calibration: being confidently wrong is penalized,
        #    while escalating under genuinely low confidence earns a rebate
        #    (a negative penalty).
        if base_score < 0.50 and adj_conf > 0.80:
            confidence_penalty = 0.22
        elif base_score < 0.50 and adj_conf > 0.65:
            confidence_penalty = 0.12
        elif self._signals_extracted and final_decision == "ESCALATE" and adj_conf < 0.55:
            confidence_penalty = -0.10
        else:
            confidence_penalty = 0.0

        return {
            "base_score": base_score,
            "policy_alignment": policy_alignment,
            "signal_accuracy_bonus": signal_accuracy_bonus,
            "escalation_adj": escalation_adj,
            "signal_bonus": signal_bonus,
            "tool_cost": tool_cost,
            "tool_miss_penalty": tool_miss_penalty,
            "validation_penalty": validation_penalty,
            "risk_penalty": risk_penalty,
            "confidence_penalty": confidence_penalty,
            "_policy_rec": policy_rec,
        }
401
+
402
+ def _finalize_reward(self, components: Dict[str, Any]) -> float:
403
+ raw = (
404
+ components["base_score"]
405
+ + components["policy_alignment"]
406
+ + components["signal_accuracy_bonus"]
407
+ + components["escalation_adj"]
408
+ + components["signal_bonus"]
409
+ - components["tool_cost"]
410
+ - components["tool_miss_penalty"]
411
+ - components["validation_penalty"]
412
+ - components["risk_penalty"]
413
+ - components["confidence_penalty"]
414
+ )
415
+ return round(max(0.0, min(1.0, raw)), 4)
416
+
417
+ def _compute_signal_accuracy(self) -> float:
418
+ if not self._extracted_signals:
419
+ return 0.0
420
+ gt = self._current_task.get("ground_truth_signals", {})
421
+ if not gt:
422
+ return 0.05
423
+
424
+ s = self._extracted_signals
425
+ score = 0.0
426
+ if s.target == gt.get("target"): score += 0.20
427
+ if s.intent == gt.get("intent"): score += 0.20
428
+ if s.context_type == gt.get("context_type"): score += 0.20
429
+
430
+ tox_diff = abs(s.toxicity_level - gt.get("toxicity_level", 0.5))
431
+ score += 0.20 if tox_diff <= 0.20 else (0.10 if tox_diff <= 0.35 else 0.0)
432
+
433
+ gt_flags = set(gt.get("content_flags", []))
434
+ s_flags = set(s.content_flags)
435
+ if gt_flags:
436
+ score += 0.20 * min(1.0, len(gt_flags & s_flags) / len(gt_flags))
437
+ else:
438
+ score += 0.20 if not s_flags else 0.10
439
+
440
+ return round(score * 0.15, 4)
uv.lock ADDED
The diff for this file is too large to render. See raw diff