jeromerichard committed on
Commit 74e3b5e · 0 Parent(s):

Trust & Safety RL Environment - OpenEnv Hackathon

Files changed (15)
  1. .gitignore +20 -0
  2. Dockerfile +15 -0
  3. abouthack +838 -0
  4. app.py +157 -0
  5. client.py +89 -0
  6. inference.py +295 -0
  7. models.py +63 -0
  8. openenv.yaml +64 -0
  9. pyproject.toml +33 -0
  10. readme.md +24 -0
  11. requirements.txt +6 -0
  12. simpley +1716 -0
  13. tasks.py +296 -0
  14. train.py +222 -0
  15. your_environment.py +440 -0
.gitignore ADDED
# Environment variables (contains API keys)
.env

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
*.egg-info/
dist/
build/
.venv/
venv/

# IDE
.vscode/
.idea/
*.swp
*.swo
Dockerfile ADDED
FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 7860

ENV PORT=7860
ENV HOST=0.0.0.0

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
abouthack ADDED
https://www.scaler.com/school-of-technology/meta-pytorch-hackathon

Code Review by Meta Engineers

Get your work reviewed by engineers shaping agentic AI at Meta.

Real Open Source Contribution

Your code ships to a Meta-backed project, visible on your GitHub profile.

ROUND-1

Round 1 - Build Your Mini RL Environment:

Wednesday, 25th March - Wednesday, 8th April

Build a mini RL environment with defined tasks, graders, and reward logic. Evaluation includes programmatic checks and LLM scoring.

OpenEnv is an open-source framework by Meta & Hugging Face for creating standardized, isolated, and reusable environments for training and deploying AI agents.

Think of it as the universal language for AI training environments. It uses a Gymnasium-style API, containerized execution via Docker, and a central hub on Hugging Face for sharing environments.

Teams at Meta use OpenEnv to define environments once and run them consistently across training, post-training, and evaluation. Now, you get to build on the same infrastructure.

Read this repository for the reference code to be used in this hackathon:
https://github.com/huggingface/openenv-course

Round 1 — Problem Statement:
Build a complete, real-world OpenEnv environment that an AI agent can learn from through the standard step() / reset() / state() API.

Key Requirements at a Glance

Must simulate a real-world task (not games or toys)

Implement the full OpenEnv spec: typed models, step()/reset()/state(), openenv.yaml

Minimum 3 tasks with agent graders (easy → medium → hard, scores 0.0–1.0)

Meaningful reward function with partial progress signals

Baseline inference script with reproducible scores

Deploy to Hugging Face Spaces + working Dockerfile

README with environment description, action/observation spaces, setup instructions

Functional Requirements:

Real-world task simulation

The environment must simulate a task humans actually do. Not games, not toys. Examples: email triage, code review, data cleaning, scheduling, customer support, content moderation.

OpenEnv spec compliance

Implement the full OpenEnv interface: typed Observation, Action, and Reward Pydantic models. step(action) returns observation, reward, done, info. reset() returns the initial observation. state() returns the current state. openenv.yaml with metadata. Tested via openenv validate.

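As an illustration of that interface, here is a minimal sketch of an environment with typed models. The task, class, and field names are invented for this example, and plain dataclasses stand in for the Pydantic models the spec requires:

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class TriageAction:
    # Example action: label the current email (hypothetical task)
    label: str  # e.g. "urgent", "normal", "spam"


@dataclass
class TriageObservation:
    # Example observation returned by reset() and step()
    email_text: str
    step_number: int = 0
    done: bool = False
    reward: float = 0.0


class EmailTriageEnv:
    """Minimal step()/reset()/state() shape (not the real OpenEnv base class)."""

    def __init__(self) -> None:
        self._inbox: List[str] = ["Server down, fix ASAP", "Lunch menu for Friday"]
        self._urgent = {0}  # indices of genuinely urgent emails
        self._idx = 0

    def reset(self) -> TriageObservation:
        self._idx = 0
        return TriageObservation(email_text=self._inbox[0])

    def step(self, action: TriageAction) -> TriageObservation:
        # Reward 1.0 for a correct label, 0.0 otherwise
        expected = "urgent" if self._idx in self._urgent else "normal"
        reward = 1.0 if action.label == expected else 0.0
        self._idx += 1
        done = self._idx >= len(self._inbox)
        text = "" if done else self._inbox[self._idx]
        return TriageObservation(text, self._idx, done, reward)

    def state(self) -> Dict[str, int]:
        return {"step_count": self._idx}
```

A real submission would subclass the OpenEnv base classes and use Pydantic; only the shape of the three methods is the point here.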
Minimum 3 tasks with agent graders

Each task defines a concrete objective an agent must accomplish, with a programmatic grader that scores performance (0.0–1.0). Tasks should range: easy → medium → hard. Graders must have clear, deterministic success/failure criteria.

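A grader satisfying these constraints can be as small as a single pure function; the email-labelling task here is a made-up example:

```python
from typing import List


def grade_labels(predicted: List[str], expected: List[str]) -> float:
    """Deterministic grader: fraction of correct labels, always in [0.0, 1.0]."""
    if not expected:
        return 0.0
    correct = sum(p == e for p, e in zip(predicted, expected))
    return correct / len(expected)
```

Because the score is a simple ratio over a fixed answer key, repeated runs on the same trajectory always produce the same value, which is what the determinism requirement asks for.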
Meaningful reward function

Provides signal over the full trajectory (not just binary end-of-episode). Rewards partial progress toward task completion. Penalizes clearly undesirable behavior (e.g. infinite loops, destructive actions).

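One way to satisfy all three properties is a shaped per-step reward; the coefficients below are arbitrary illustrations, not prescribed values:

```python
def shaped_reward(
    progress: float,       # task completion in [0, 1] after this step
    prev_progress: float,  # completion before this step
    destructive: bool,     # did the agent take a clearly harmful action?
) -> float:
    reward = progress - prev_progress  # dense signal: credit partial progress
    reward -= 0.01                     # small step cost discourages infinite loops
    if destructive:
        reward -= 0.5                  # explicit penalty for undesirable behavior
    return reward
```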
Baseline inference script

Uses the OpenAI API client to run a model against the environment. Reads API credentials from environment variables (OPENAI_API_KEY). Produces a reproducible baseline score on all 3 tasks.

Non-Functional Requirements:

Deploys to a Hugging Face Space

Environment must run as a containerized HF Space tagged with openenv.

Containerized execution

Must include a working Dockerfile. The environment should start cleanly with docker build + docker run.

Documentation

README must include: environment description and motivation, action and observation space definitions, task descriptions with expected difficulty, setup and usage instructions, baseline scores.

Evaluation Criteria:

Real-world utility (30%): Does the environment model a genuine task? Would someone actually use this to train or evaluate agents?

Task & grader quality (25%): Are tasks well-defined with clear objectives? Do graders accurately and fairly measure success? Meaningful difficulty progression?

Environment design (20%): Clean state management, sensible action/observation spaces, good reward shaping, proper episode boundaries.

Code quality & spec compliance (15%): Follows OpenEnv spec, clean project structure, typed models, documented, tested, Dockerfile works.

Creativity & novelty (10%): Novel problem domain, interesting mechanics, clever reward design, original approach.

Scoring Breakdown

Real-world utility (30%)

• 0–5: Toy/artificial problem with no practical application
• 6–15: Valid domain but shallow modeling of the real task
• 16–25: Good domain modeling, would be useful for agent evaluation
• 26–30: Excellent — fills a real gap, immediate value for the RL/agent community

Task & grader quality (25%)

• 3+ tasks with difficulty range?
• Graders produce scores between 0.0–1.0?
• Graders deterministic and reproducible?
• Hard task genuinely challenges frontier models?

Environment design (20%)

• reset() produces clean state?
• Action/observation types well-designed and documented?
• Reward function provides useful varying signal (not just sparse)?
• Episode boundaries sensible?

Code quality & spec compliance (15%)

• openenv validate passes?
• docker build && docker run works?
• HF Space deploys and responds?
• Baseline script runs and reproduces scores?

Creativity & novelty (10%)

• Domain we haven't seen in OpenEnv before?
• Reward design has interesting properties?
• Clever mechanics that make the environment engaging?

How Judging Works:

Phase 1: Automated Validation

Pass/fail gate — HF Space deploys, OpenEnv spec compliance, Dockerfile builds, baseline reproduces, 3+ tasks with graders.

Phase 2: Agentic Evaluation

Scored — baseline agent re-run, standard open LLM agent (e.g. Nemotron 3 Super) run against all environments, score variance check.

Phase 3: Human Review

Top submissions reviewed by Meta and Hugging Face engineers for real-world utility, creativity, and exploit checks.

Disqualification Criteria

Environment does not deploy or respond

Plagiarized or trivially modified existing environments

Graders that always return the same score

No baseline inference script

Pre-Submission Checklist — all must pass or you're disqualified:

HF Space deploys

Automated ping to the Space URL — must return 200 and respond to reset()

OpenEnv spec compliance

Validate openenv.yaml, typed models, step()/reset()/state() endpoints

Dockerfile builds

Automated docker build on the submitted repo

Baseline reproduces

Run the submitted inference script — must complete without error and produce scores

3+ tasks with graders

Enumerate tasks, run each grader, verify scores in the 0.0–1.0 range

Additional Instructions

Before submitting, ensure the following variables are defined in your environment configuration:

API_BASE_URL: the API endpoint for the LLM.

MODEL_NAME: the model identifier to use for inference.

HF_TOKEN: your Hugging Face / API key.

The inference script must be named `inference.py` and placed in the root directory of the project.

Participants must use the OpenAI client for all LLM calls, using the variables above.

Infra Restrictions

The inference script must finish in under 20 minutes.

Make sure your environment and inference script can run on a machine with 2 vCPUs and 8 GB of memory.

Validator

Run the pre-submission validation script before submitting.

Sample Inference Script:

"""
Inference Script Example
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your environment configuration:
  API_BASE_URL  The API endpoint for the LLM.
  MODEL_NAME    The model identifier to use for inference.
  HF_TOKEN      Your Hugging Face / API key.

- The inference script must be named `inference.py` and placed in the root directory of the project
- Participants must use the OpenAI client for all LLM calls using the above variables
"""

import os
import re
import base64
import textwrap
from io import BytesIO
from typing import List, Optional, Dict

from openai import OpenAI
import numpy as np
from PIL import Image

from browsergym_env import BrowserGymAction, BrowserGymEnv

API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME")
MAX_STEPS = 8
MAX_DOM_CHARS = 3500
TEMPERATURE = 0.2
MAX_TOKENS = 200
FALLBACK_ACTION = "noop()"

DEBUG = True
ACTION_PREFIX_RE = re.compile(
    r"^(action|next action)\s*[:\-]\s*",
    re.IGNORECASE,
)
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)


SYSTEM_PROMPT = textwrap.dedent(
    """
    You control a web browser through BrowserGym.
    Reply with exactly one action string.
    The action must be a valid BrowserGym command such as:
    - noop()
    - click('<BID>')
    - type('selector', 'text to enter')
    - fill('selector', 'text to enter')
    - send_keys('Enter')
    - scroll('down')
    Use single quotes around string arguments.
    When clicking, use the BrowserGym element IDs (BIDs) listed in the user message.
    If you are unsure, respond with noop().
    Do not include explanations or additional text.
    """
).strip()


def build_history_lines(history: List[str]) -> str:
    if not history:
        return "None"
    return "\n".join(history[-4:])


def extract_screenshot_uri(observation) -> Optional[str]:
    if observation.screenshot is None:
        return None
    screen_array = np.array(observation.screenshot, dtype=np.uint8)
    image = Image.fromarray(screen_array)
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    data_uri = base64.b64encode(buffer.read()).decode("utf-8")
    return f"data:image/png;base64,{data_uri}"


def extract_clickable_elements(observation) -> List[Dict[str, str]]:
    """Collect BrowserGym element IDs that can be clicked."""
    metadata = getattr(observation, "metadata", {}) or {}
    obs_dict = metadata.get("browsergym_obs", {}) or {}
    extra_props = obs_dict.get("extra_element_properties", {}) or {}

    clickables: List[Dict[str, str]] = []
    for bid, props in extra_props.items():
        if not props.get("clickable"):
            continue

        bbox = props.get("bbox") or []
        bbox_str = ", ".join(str(v) for v in bbox) if bbox else "?"
        clickables.append(
            {
                "bid": str(bid),
                "bbox": bbox_str,
            }
        )

    # Keep a stable ordering for readability
    clickables.sort(key=lambda item: item["bid"])
    return clickables


def build_user_prompt(step: int, observation, history: List[str]) -> str:
    goal = observation.goal or "(not provided)"
    url = observation.url or "(unknown)"
    error_note = "Yes" if observation.last_action_error else "No"

    clickables = extract_clickable_elements(observation)
    if clickables:
        actions_hint = "\n".join(
            f"  - {item['bid']} (bbox: {item['bbox']})" for item in clickables
        )
    else:
        actions_hint = "  (none detected)"

    prompt = textwrap.dedent(
        f"""
        Step: {step}
        Goal: {goal}
        Current URL: {url}
        Previous steps:
        {build_history_lines(history)}
        Last action error: {error_note}
        Available clickable element IDs: {actions_hint}
        Reply with exactly one BrowserGym action string.
        """
    ).strip()
    return prompt


def parse_model_action(response_text: str) -> str:
    if not response_text:
        return FALLBACK_ACTION

    # Prefer the first line that looks like an action string
    lines = response_text.splitlines()
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        line = ACTION_PREFIX_RE.sub("", line)
        match = ACTION_PATTERN.search(line)
        if match:
            action = match.group(0).strip()
            # Collapse internal whitespace
            action = re.sub(r"\s+", " ", action)
            return action

    # Fall back to searching the whole response
    match = ACTION_PATTERN.search(response_text)
    if match:
        action = match.group(0).strip()
        action = re.sub(r"\s+", " ", action)
        return action

    return FALLBACK_ACTION


def main() -> None:
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    env = BrowserGymEnv.from_docker_image(
        image="browsergym-env:latest",
        env_vars={
            "BROWSERGYM_BENCHMARK": "miniwob",
            "BROWSERGYM_TASK_NAME": "click-test",
        },
    )

    history: List[str] = []

    try:
        result = env.reset()
        observation = result.observation
        print(f"Episode goal: {observation.goal}")

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                print("Environment signalled done. Stopping early.")
                break

            user_prompt = build_user_prompt(step, observation, history)
            user_content = [{"type": "text", "text": user_prompt}]
            screenshot_uri = extract_screenshot_uri(observation)
            if screenshot_uri:
                user_content.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": screenshot_uri},
                    }
                )

            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": SYSTEM_PROMPT}],
                },
                {
                    "role": "user",
                    "content": user_content,
                },
            ]

            try:
                completion = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=messages,
                    temperature=TEMPERATURE,
                    max_tokens=MAX_TOKENS,
                    stream=False,
                )
                response_text = completion.choices[0].message.content or ""
            # pylint: disable=broad-except
            except Exception as exc:  # noqa: BLE001
                failure_msg = f"Model request failed ({exc}). Using fallback action."
                print(failure_msg)
                response_text = FALLBACK_ACTION

            action_str = parse_model_action(response_text)
            print(f"Step {step}: model suggested -> {action_str}")

            result = env.step(BrowserGymAction(action_str=action_str))
            observation = result.observation

            reward = result.reward or 0.0
            error_flag = " ERROR" if observation.last_action_error else ""
            history_line = (
                f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
            )
            history.append(history_line)
            print(
                "  Reward: "
                f"{reward:+.2f} | Done: {result.done} | Last action error: "
                f"{observation.last_action_error}"
            )

            if result.done:
                print("Episode complete.")
                break

        else:
            print(f"Reached max steps ({MAX_STEPS}).")

    finally:
        env.close()


if __name__ == "__main__":
    main()

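The action-parsing helper in the script above can be exercised in isolation. This trimmed re-implementation uses the same two regexes to show what the parser accepts and when it falls back to noop():

```python
import re

ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", re.IGNORECASE)
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)
FALLBACK_ACTION = "noop()"


def parse_model_action(response_text: str) -> str:
    # Empty responses fall back immediately
    if not response_text:
        return FALLBACK_ACTION
    # Prefer the first line that looks like an action call
    for raw_line in response_text.splitlines():
        line = ACTION_PREFIX_RE.sub("", raw_line.strip())
        match = ACTION_PATTERN.search(line)
        if match:
            return re.sub(r"\s+", " ", match.group(0).strip())
    # Otherwise search the whole response
    match = ACTION_PATTERN.search(response_text)
    if match:
        return re.sub(r"\s+", " ", match.group(0).strip())
    return FALLBACK_ACTION
```

Note how the "Action:" / "Next action:" prefixes a chatty model tends to emit are stripped before matching, and anything without a function-call shape degrades to a safe noop().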
Pre-Validation Script:

#!/usr/bin/env bash
#
# validate-submission.sh — OpenEnv Submission Validator
#
# Checks that your HF Space is live, the Docker image builds, and openenv validate passes.
#
# Prerequisites:
#   - Docker: https://docs.docker.com/get-docker/
#   - openenv-core: pip install openenv-core
#   - curl (usually pre-installed)
#
# Run:
#   curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
#
# Or download and run locally:
#   chmod +x validate-submission.sh
#   ./validate-submission.sh <ping_url> [repo_dir]
#
# Arguments:
#   ping_url   Your HuggingFace Space URL (e.g. https://your-space.hf.space)
#   repo_dir   Path to your repo (default: current directory)
#
# Examples:
#   ./validate-submission.sh https://my-team.hf.space
#   ./validate-submission.sh https://my-team.hf.space ./my-repo
#

set -uo pipefail

DOCKER_BUILD_TIMEOUT=600
if [ -t 1 ]; then
  RED='\033[0;31m'
  GREEN='\033[0;32m'
  YELLOW='\033[1;33m'
  BOLD='\033[1m'
  NC='\033[0m'
else
  RED='' GREEN='' YELLOW='' BOLD='' NC=''
fi

run_with_timeout() {
  local secs="$1"; shift
  if command -v timeout &>/dev/null; then
    timeout "$secs" "$@"
  elif command -v gtimeout &>/dev/null; then
    gtimeout "$secs" "$@"
  else
    "$@" &
    local pid=$!
    ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
    local watcher=$!
    wait "$pid" 2>/dev/null
    local rc=$?
    kill "$watcher" 2>/dev/null
    wait "$watcher" 2>/dev/null
    return $rc
  fi
}

portable_mktemp() {
  local prefix="${1:-validate}"
  mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
}

CLEANUP_FILES=()
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
trap cleanup EXIT

PING_URL="${1:-}"
REPO_DIR="${2:-.}"

if [ -z "$PING_URL" ]; then
  printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
  printf "\n"
  printf "  ping_url   Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
  printf "  repo_dir   Path to your repo (default: current directory)\n"
  exit 1
fi

if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
  printf "Error: directory '%s' not found\n" "${2:-.}"
  exit 1
fi
PING_URL="${PING_URL%/}"
export PING_URL
PASS=0

log() { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
fail() { log "${RED}FAILED${NC} -- $1"; }
hint() { printf "  ${YELLOW}Hint:${NC} %b\n" "$1"; }
stop_at() {
  printf "\n"
  printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
  exit 1
}

printf "\n"
printf "${BOLD}========================================${NC}\n"
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
printf "${BOLD}========================================${NC}\n"
log "Repo:     $REPO_DIR"
log "Ping URL: $PING_URL"
printf "\n"

log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."

CURL_OUTPUT=$(portable_mktemp "validate-curl")
CLEANUP_FILES+=("$CURL_OUTPUT")
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
  -H "Content-Type: application/json" -d '{}' \
  "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")

if [ "$HTTP_CODE" = "200" ]; then
  pass "HF Space is live and responds to /reset"
elif [ "$HTTP_CODE" = "000" ]; then
  fail "HF Space not reachable (connection failed or timed out)"
  hint "Check your network connection and that the Space is running."
  hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
  stop_at "Step 1"
else
  fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
  hint "Make sure your Space is running and the URL is correct."
  hint "Try opening $PING_URL in your browser first."
  stop_at "Step 1"
fi

log "${BOLD}Step 2/3: Running docker build${NC} ..."

if ! command -v docker &>/dev/null; then
  fail "docker command not found"
  hint "Install Docker: https://docs.docker.com/get-docker/"
  stop_at "Step 2"
fi

if [ -f "$REPO_DIR/Dockerfile" ]; then
  DOCKER_CONTEXT="$REPO_DIR"
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
  DOCKER_CONTEXT="$REPO_DIR/server"
else
  fail "No Dockerfile found in repo root or server/ directory"
  stop_at "Step 2"
fi

log "  Found Dockerfile in $DOCKER_CONTEXT"

BUILD_OK=false
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

if [ "$BUILD_OK" = true ]; then
  pass "Docker build succeeded"
else
  fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
  printf "%s\n" "$BUILD_OUTPUT" | tail -20
  stop_at "Step 2"
fi

log "${BOLD}Step 3/3: Running openenv validate${NC} ..."

if ! command -v openenv &>/dev/null; then
  fail "openenv command not found"
  hint "Install it: pip install openenv-core"
  stop_at "Step 3"
fi

VALIDATE_OK=false
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true

if [ "$VALIDATE_OK" = true ]; then
  pass "openenv validate passed"
  [ -n "$VALIDATE_OUTPUT" ] && log "  $VALIDATE_OUTPUT"
else
  fail "openenv validate failed"
  printf "%s\n" "$VALIDATE_OUTPUT"
  stop_at "Step 3"
fi

printf "\n"
printf "${BOLD}========================================${NC}\n"
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
printf "${BOLD}========================================${NC}\n"
printf "\n"

exit 0

When Round 1 opens, you'll choose 1 of 4–5 problem statements and build an OpenEnv environment around it.

Example of what a problem statement looks like:

"Build a mini-game RL environment with clearly defined tasks, automated graders, and reward logic using the OpenEnv framework."

→ Create a mini-game an AI agent can play

→ Define tasks with increasing difficulty

→ Write graders that verify task completion

→ Define reward logic for scoring

→ Package using OpenEnv for automated evaluation

Evaluation Criteria

Runtime correctness: runs without errors

Interface compliance: follows the OpenEnv standard

Task design: clear, realistic, testable

Grading logic: the reward system makes sense

Install before April 1st.

Python 3.10+

Install 3.10, 3.11, or 3.12.

$ python --version

Git + GitHub account

Push your submission to GitHub or HF.

$ git --version

Hugging Face CLI

Deploy to HF Spaces.

$ pip install huggingface_hub
$ huggingface-cli login

OpenEnv

The framework.

$ pip install openenv-core

Google Colab

The prep course runs in Colab. The free tier works.

→ colab.research.google.com

Docker

Isolated container testing.

$ docker --version

Recommended: VS Code

Best Python + Docker support.

When Round 1 starts on 1 April:

Step 1: Application Form
Choose 1 of the 4–5 problem statements revealed on the platform.

Step 2: Scaffold

$ openenv init my_env

Generate the project structure.

Step 3: Build
Define your environment in the generated files.

Step 4: Test locally

$ uv run server

Step 5: Deploy

$ openenv push --repo-id your-username/my-env

Step 6: Submit
Paste your HF Spaces URL here before the deadline.

Deadline: 8 April 2026, 11:59 PM IST
app.py ADDED
from __future__ import annotations

from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from models import TrustAction, TrustObservation, TrustState, ContentSignals
from your_environment import TrustSafetyEnvironment

# ── Force manual FastAPI (openenv_core create_app causes 422 on /step) ────────
print("[app] Using manual FastAPI")

_env = TrustSafetyEnvironment(seed=42)

app = FastAPI(
    title="Trust & Safety RL Environment",
    description="Risk-aware content moderation environment for agent training.",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# ── Serializers ───────────────────────────────────────────────────────────────

def _obs_to_dict(obs: TrustObservation) -> Dict[str, Any]:
    return {
        "ticket_id": obs.ticket_id,
        "post_text": obs.post_text,
        "image_description": obs.image_description,
        "comments_found": obs.comments_found,
        "user_history_found": obs.user_history_found,
        "entity_status_found": obs.entity_status_found,
        "policy_found": obs.policy_found,
        "extracted_signals": obs.extracted_signals,
        "validation_result": obs.validation_result,
        "step_number": obs.step_number,
        "info": obs.info,
        "done": obs.done,
        "reward": obs.reward,
    }


def _state_to_dict(s: TrustState) -> Dict[str, Any]:
    return {
        "episode_id": s.episode_id,
        "step_count": s.step_count,
        "current_task_id": s.current_task_id,
        "difficulty": s.difficulty,
        "ambiguity_level": s.ambiguity_level,
        "risk_level": s.risk_level,
        "tools_used": s.tools_used,
        "signals_extracted": s.signals_extracted,
        "is_done": s.is_done,
    }


# ── Request bodies ─────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    seed: Any = None
    episode_id: Any = None

    model_config = {"extra": "ignore"}


class ActionRequest(BaseModel):
    action_type: str = ""
    tool_name: Optional[str] = None
    signals: Optional[Dict[str, Any]] = None  # raw dict — validated below
    final_decision: Optional[str] = None

    model_config = {"extra": "ignore"}  # ignore unknown keys from the LLM


# ── Helpers ────────────────────────────────────────────────────────────────────

def _parse_signals(raw: Dict[str, Any]) -> ContentSignals:
    """Defensively normalise LLM signal output before Pydantic validation."""
    # Clamp floats
    raw["toxicity_level"] = float(raw.get("toxicity_level", 0.5))
    raw["confidence"] = float(raw.get("confidence", 0.5))

    # content_flags must be a list of strings
    flags = raw.get("content_flags", [])
    if not isinstance(flags, list):
        flags = [flags] if isinstance(flags, str) else []
    raw["content_flags"] = [str(f) for f in flags]

    # Boolean coercion
    raw["is_protected_class"] = bool(raw.get("is_protected_class", False))
    raw["is_direct_attack"] = bool(raw.get("is_direct_attack", False))
    raw["abusive_language_present"] = bool(raw.get("abusive_language_present", False))

    # String fields — fall back to sensible defaults
    raw.setdefault("target", "none")
    raw.setdefault("intent", "ambiguous")
    raw.setdefault("context_type", "statement")

    return ContentSignals(**raw)


# ── Routes ─────────────────────────────────────────────────────────────────────

@app.get("/health")
async def health():
    return {"status": "ok", "environment": "trust-safety-env", "version": "1.0.0"}


@app.get("/")
async def root():
    return {"status": "ok", "docs": "/docs"}


@app.post("/reset")
async def reset(body: ResetRequest = ResetRequest()):
    obs = _env.reset(seed=body.seed, episode_id=body.episode_id)
    return JSONResponse(_obs_to_dict(obs))


@app.post("/step")
async def step(body: ActionRequest):
    # Parse + validate signals defensively
    signals: Optional[ContentSignals] = None
    if body.signals:
        try:
            signals = _parse_signals(dict(body.signals))  # copy so we don't mutate
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid signals payload: {e}")

    action = TrustAction(
        action_type=body.action_type,
        tool_name=body.tool_name,
        signals=signals,
        final_decision=body.final_decision,
    )

    try:
        obs = _env.step(action)
    except (RuntimeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=str(e))

    return JSONResponse(_obs_to_dict(obs))


@app.get("/state")
async def state():
    return JSONResponse(_state_to_dict(_env.state))
client.py ADDED
@@ -0,0 +1,89 @@
+ from __future__ import annotations
+ from typing import Any
+
+ from openenv.core.env_client import EnvClient  # ✅ correct import
+ from openenv.core.client_types import StepResult  # ✅ correct import
+
+ from models import TrustAction, TrustObservation, TrustState, ContentSignals
+
+
+ class TrustSafetyEnv(EnvClient[TrustAction, TrustObservation, TrustState]):  # ✅ EnvClient, 3 generics
+     """
+     Typed WebSocket/HTTP client for the Trust & Safety RL Environment.
+
+     Usage (sync — for scripts, GRPOTrainer):
+         env = TrustSafetyEnv(base_url="http://localhost:8000").sync()
+         result = env.reset()
+         result = env.reset(episode_id="T-001")
+         result = env.step(TrustAction(action_type="use_tool", tool_name="view_policy"))
+         result = env.step(TrustAction(action_type="final_decision", final_decision="REMOVE"))
+         state = env.state()
+         env.close()
+
+     Usage (async):
+         async with TrustSafetyEnv(base_url="http://localhost:8000") as env:
+             result = await env.reset()
+     """
+
+     def step_payload(self, action: TrustAction) -> dict:  # ✅ NO underscore
+         payload: dict[str, Any] = {"action_type": action.action_type}
+
+         if action.tool_name is not None:
+             payload["tool_name"] = action.tool_name
+
+         if action.signals is not None:
+             s = action.signals
+             payload["signals"] = {
+                 "target": s.target,
+                 "is_protected_class": s.is_protected_class,
+                 "toxicity_level": float(s.toxicity_level),
+                 "is_direct_attack": s.is_direct_attack,
+                 "context_type": s.context_type,
+                 "intent": s.intent,
+                 "confidence": float(s.confidence),
+                 "abusive_language_present": s.abusive_language_present,
+                 "content_flags": list(s.content_flags),
+             }
+
+         if action.final_decision is not None:
+             payload["final_decision"] = action.final_decision
+
+         return payload
+
+     def parse_result(self, payload: dict) -> StepResult[TrustObservation]:  # ✅ NO underscore
+         obs_data = payload.get("observation", payload)
+
+         obs = TrustObservation(
+             ticket_id=obs_data.get("ticket_id", ""),
+             post_text=obs_data.get("post_text", ""),
+             image_description=obs_data.get("image_description", ""),
+             comments_found=obs_data.get("comments_found"),
+             user_history_found=obs_data.get("user_history_found"),
+             entity_status_found=obs_data.get("entity_status_found"),
+             policy_found=obs_data.get("policy_found"),
+             extracted_signals=obs_data.get("extracted_signals"),
+             validation_result=obs_data.get("validation_result"),
+             step_number=obs_data.get("step_number", 0),
+             info=obs_data.get("info"),
+             done=payload.get("done", obs_data.get("done", False)),
+             reward=payload.get("reward", obs_data.get("reward")),
+         )
+
+         return StepResult(
+             observation=obs,
+             reward=payload.get("reward", obs_data.get("reward")),
+             done=payload.get("done", obs_data.get("done", False)),
+         )
+
+     def parse_state(self, payload: dict) -> TrustState:  # ✅ NO underscore
+         return TrustState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             current_task_id=payload.get("current_task_id"),
+             difficulty=payload.get("difficulty"),
+             ambiguity_level=payload.get("ambiguity_level"),
+             risk_level=payload.get("risk_level"),
+             tools_used=payload.get("tools_used", []),
+             signals_extracted=payload.get("signals_extracted", False),
+             is_done=payload.get("is_done", False),
+         )
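
The payload shape `step_payload` emits can be illustrated without the `openenv` dependency. The sketch below uses plain dataclasses as stand-ins for the Pydantic models (the class names `Action` and `Signals` are hypothetical, not the real `TrustAction`/`ContentSignals`), and mirrors the include-only-set-fields logic above:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Signals:
    target: str = "none"
    toxicity_level: float = 0.5
    confidence: float = 0.5
    content_flags: list = field(default_factory=list)

@dataclass
class Action:
    action_type: str
    tool_name: Optional[str] = None
    signals: Optional[Signals] = None
    final_decision: Optional[str] = None

def step_payload(action: Action) -> dict:
    # Mirror of TrustSafetyEnv.step_payload: serialize only the fields that are set.
    payload = {"action_type": action.action_type}
    if action.tool_name is not None:
        payload["tool_name"] = action.tool_name
    if action.signals is not None:
        s = action.signals
        payload["signals"] = {
            "target": s.target,
            "toxicity_level": float(s.toxicity_level),
            "confidence": float(s.confidence),
            "content_flags": list(s.content_flags),
        }
    if action.final_decision is not None:
        payload["final_decision"] = action.final_decision
    return payload

p1 = step_payload(Action(action_type="use_tool", tool_name="view_policy"))
p2 = step_payload(Action(action_type="final_decision", final_decision="REMOVE"))
```

Keeping unset fields out of the payload means the server's `ActionRequest` defaults apply, rather than explicit nulls overwriting them.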
inference.py ADDED
@@ -0,0 +1,295 @@
+ """
+ inference.py — Trust & Safety RL Environment Evaluation
+ ========================================================
+ MANDATORY env vars:
+     API_BASE_URL    LLM endpoint (e.g. https://router.huggingface.co/v1)
+     MODEL_NAME      Model ID (e.g. meta-llama/Llama-3.1-8B-Instruct)
+     HF_TOKEN        API key
+     ENV_BASE_URL    Environment server URL (default: http://localhost:8000)
+ """
+
+ import os, json, time, requests
+ from openai import OpenAI
+
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
+ API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
+ ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000")
+
+ if not API_KEY:
+     raise EnvironmentError("Set HF_TOKEN (your Hugging Face / API key).")
+ if not MODEL_NAME:
+     raise EnvironmentError("Set MODEL_NAME.")
+
+ SYSTEM_PROMPT = """You are a Trust & Safety analyst. Follow this EXACT 3-stage pipeline.
+ Skipping any stage results in a score penalty.
+
+ ════════════════════════════════════════════
+ STAGE 1 — Investigate (optional)
+ ════════════════════════════════════════════
+ Use tools only when context is needed. For OBVIOUS violations (spam, phishing), skip.
+
+ {"action_type": "use_tool", "tool_name": "read_comments"}
+ {"action_type": "use_tool", "tool_name": "check_user_history"}
+ {"action_type": "use_tool", "tool_name": "check_entity_status"}
+ {"action_type": "use_tool", "tool_name": "view_policy"}
+
+ ════════════════════════════════════════════
+ STAGE 2 — Extract signals (MANDATORY)
+ ════════════════════════════════════════════
+ You MUST call extract_signals before final_decision. No exceptions.
+ You are a FEATURE EXTRACTOR — output what the content shows, do NOT decide.
+
+ {"action_type": "extract_signals", "signals": {
+   "target": "individual|group|religion|race|gender|political|none",
+   "is_protected_class": true|false,
+   "toxicity_level": 0.0-1.0,
+   "is_direct_attack": true|false,
+   "context_type": "statement|satire|news|meme|personal_attack|protest|education|spam",
+   "intent": "harmful|descriptive|ambiguous|satirical|educational",
+   "confidence": 0.0-1.0,
+   "abusive_language_present": true|false,
+   "content_flags": ["spam","phishing","slur","threat","misinformation",
+                     "political_content","public_figure","competitive_language"]
+ }}
+
+ After submitting signals you receive VALIDATION. If needs_escalation_hint is true → prefer ESCALATE.
+
+ ════════════════════════════════════════════
+ STAGE 3 — Final decision (MANDATORY)
+ ════════════════════════════════════════════
+ {"action_type": "final_decision", "final_decision": "ALLOW|REMOVE|ALLOW_WITH_WARNING|ESCALATE"}
+
+ DECISION RULES (apply in order):
+ 1. Phishing / spam / impersonation → REMOVE
+ 2. Explicit hate speech, direct threats → REMOVE
+ 3. Political satire of a PUBLIC FIGURE → ALLOW_WITH_WARNING
+ 4. News / journalism / education → ALLOW_WITH_WARNING
+ 5. Gaming / sports banter (confirmed) → ALLOW
+ 6. Adjusted confidence after validation < 0.6 → ESCALATE
+ 7. Genuinely ambiguous, no clear rule → ESCALATE
+ 8. Safe, no signals triggered → ALLOW
+
+ NEVER jump from Stage 1 directly to final_decision.
+ ALWAYS: [optional tools] → extract_signals → final_decision
+ Respond in strict JSON only. No markdown."""
+
+ SIGNAL_DEFAULTS = {
+     "target": "none", "is_protected_class": False,
+     "toxicity_level": 0.5, "is_direct_attack": False,
+     "context_type": "statement", "intent": "ambiguous",
+     "confidence": 0.5, "abusive_language_present": False,
+     "content_flags": [],
+ }
+ VALID_TOOLS = {"read_comments", "check_user_history", "check_entity_status", "view_policy"}
+ VALID_DECISIONS = {"ALLOW", "REMOVE", "ALLOW_WITH_WARNING", "ESCALATE"}
+ VALID_TYPES = {"use_tool", "extract_signals", "final_decision"}
+
+
+ def _clamp_signals(raw: dict) -> dict:
+     result = SIGNAL_DEFAULTS.copy()
+     for k in SIGNAL_DEFAULTS:
+         if k in raw:
+             result[k] = raw[k]
+     try:
+         result["toxicity_level"] = max(0.0, min(1.0, float(result["toxicity_level"])))
+         result["confidence"] = max(0.0, min(1.0, float(result["confidence"])))
+     except (TypeError, ValueError):
+         result["toxicity_level"] = 0.5
+         result["confidence"] = 0.5
+     if not isinstance(result["content_flags"], list):
+         result["content_flags"] = []
+     return result
+
+
+ def _parse(text: str) -> dict:
+     text = text.strip()
+     s, e = text.find("{"), text.rfind("}") + 1
+     if s == -1 or e == 0:
+         raise ValueError(f"No JSON in: {text}")
+     return json.loads(text[s:e])
+
+
+ def _normalize(raw: dict) -> dict:
+     t = raw.get("action_type", "")
+     if t not in VALID_TYPES:
+         return {"action_type": "final_decision", "final_decision": "ESCALATE"}
+     if t == "use_tool":
+         tool = raw.get("tool_name", "")
+         return {"action_type": "use_tool", "tool_name": tool} if tool in VALID_TOOLS \
+             else {"action_type": "final_decision", "final_decision": "ESCALATE"}
+     if t == "extract_signals":
+         sigs = raw.get("signals")
+         return {"action_type": "extract_signals", "signals": _clamp_signals(sigs)} \
+             if sigs else {"action_type": "final_decision", "final_decision": "ESCALATE"}
+     dec = raw.get("final_decision", "ESCALATE")
+     return {"action_type": "final_decision",
+             "final_decision": dec if dec in VALID_DECISIONS else "ESCALATE"}
+
+
+ def _obs_to_prompt(obs: dict) -> str:
+     lines = [
+         f"=== TICKET {obs.get('ticket_id','')} (Step {obs.get('step_number',0)}) ===",
+         f"\nPOST TEXT:\n{obs.get('post_text','')}",
+         f"\nIMAGE:\n{obs.get('image_description','')}",
+     ]
+     for key, label in [
+         ("comments_found", "COMMENTS"), ("user_history_found", "USER HISTORY"),
+         ("entity_status_found", "ENTITY STATUS"), ("policy_found", "POLICY"),
+     ]:
+         if obs.get(key):
+             lines.append(f"\n{label}:\n{obs[key]}")
+     if obs.get("extracted_signals"):
+         lines.append(f"\nYOUR EXTRACTED SIGNALS:\n{json.dumps(obs['extracted_signals'], indent=2)}")
+     if obs.get("validation_result"):
+         v = obs["validation_result"]
+         hint = "⚠️ YES — prefer ESCALATE" if v.get("needs_escalation_hint") else "No"
+         lines.append(
+             f"\n📋 VALIDATION:\n"
+             f"  Adj. Confidence : {v.get('adjusted_confidence')}\n"
+             f"  Issues          : {v.get('consistency_issues')}\n"
+             f"  Escalation Hint : {hint}"
+         )
+     if not obs.get("extracted_signals"):
+         lines.append("\n⚠️ REMINDER: Call extract_signals before final_decision.")
+     lines.append("\nYour next action (strict JSON only):")
+     return "\n".join(lines)
+
+
+ def run_task(client: OpenAI, task_id: str) -> float:
+     for _ in range(30):
+         # CORRECT ✅ — pass task ID directly
+         r = requests.post(
+             f"{ENV_BASE_URL}/reset",
+             json={"episode_id": task_id},  # ← this is the only change
+             timeout=10,
+         )
+         r.raise_for_status()
+         obs = r.json()
+         # Handle both flat (TrustObservation) and wrapped response
+         if isinstance(obs, dict) and "observation" in obs:
+             obs = obs["observation"]
+         if obs.get("ticket_id") == task_id:
+             break
+     else:
+         raise RuntimeError(f"Could not get task {task_id} after 30 resets.")
+
+     print(f"\n{'='*62}\nTask: {task_id} | Starting...\n{'='*62}")
+     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+     final_reward = 0.0
+
+     for step_num in range(14):
+         messages.append({"role": "user", "content": _obs_to_prompt(obs)})
+         time.sleep(0.5)
+
+         resp = client.chat.completions.create(
+             model=MODEL_NAME, messages=messages, temperature=0.0,
+             response_format={"type": "json_object"},
+         )
+         llm_text = resp.choices[0].message.content or ""
+         messages.append({"role": "assistant", "content": llm_text})
+
+         try:
+             action = _normalize(_parse(llm_text))
+         except Exception as ex:
+             print(f"  [Step {step_num+1}] Parse error: {ex}"); break
+
+         atype = action["action_type"]
+         if atype == "use_tool":
+             print(f"  [Step {step_num+1}] 🔧 use_tool → {action.get('tool_name')}")
+         elif atype == "extract_signals":
+             s = action.get("signals", {})
+             print(f"  [Step {step_num+1}] 🔍 extract_signals → "
+                   f"intent={s.get('intent')} | ctx={s.get('context_type')} | "
+                   f"tox={s.get('toxicity_level')} | conf={s.get('confidence')}")
+         else:
+             print(f"  [Step {step_num+1}] ⚖️ final_decision → {action.get('final_decision')}")
+
+         r2 = requests.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
+         r2.raise_for_status()
+         result = r2.json()
+
+         # Handle flat (TrustObservation) and wrapped response
+         if "observation" in result:
+             obs = result["observation"]
+             done = result.get("done", obs.get("done", False))
+             final_reward = float(result.get("reward") or obs.get("reward") or 0.0)
+         else:
+             obs = result
+             done = result.get("done", False)
+             final_reward = float(result.get("reward") or 0.0)
+
+         if done:
+             info = obs.get("info") or {}
+             bd = info.get("reward_breakdown", {})
+             pol = info.get("policy_recommendation", {})
+             vr = obs.get("validation_result") or {}
+
+             print(f"\n  ── EPISODE COMPLETE {'─'*42}")
+             print(f"  Decision:          {info.get('final_decision','N/A')}")
+             print(f"  Ground Truth:      {info.get('ground_truth','N/A')}")
+             print(f"  Policy Engine:     {pol.get('recommended','N/A')} "
+                   f"[{pol.get('rule_strength','?')} rule] ({pol.get('reason','?')})")
+             print(f"  Signals Extracted: {'✅' if info.get('signals_extracted') else '❌ SKIPPED'}")
+             print(f"  Tools Used:        {info.get('tools_used', [])}")
+             print(f"  Required Tools:    {info.get('required_tools', [])}")
+             print(f"  Adj. Confidence:   {vr.get('adjusted_confidence','N/A')}")
+             print(f"  Issues:            {vr.get('consistency_issues',[])}")
+             print(f"  Ambiguity / Risk:  {info.get('ambiguity_level','?')} / {info.get('risk_level','?')}")
+             if bd:
+                 print(f"\n  ── Reward Breakdown {'─'*42}")
+                 print(f"  1. Base Decision Score:    {bd.get('base_score',0):+.4f}")
+                 print(f"  2. Policy Alignment:       {bd.get('policy_alignment',0):+.4f}")
+                 print(f"  3. Signal Accuracy Bonus:  {bd.get('signal_accuracy_bonus',0):+.4f}")
+                 print(f"  4. Escalation Adjustment:  {bd.get('escalation_adj',0):+.4f}")
+                 print(f"  5. Signal Process Bonus:   {bd.get('signal_bonus',0):+.4f}")
+                 print(f"     Tool Cost:             -{bd.get('tool_cost',0):.4f}")
+                 print(f"     Tool Miss Penalty:     -{bd.get('tool_miss_penalty',0):.4f}")
+                 print(f"     Validation Penalty:    -{bd.get('validation_penalty',0):.4f}")
+                 print(f"     Risk Penalty:          -{bd.get('risk_penalty',0):.4f}")
+                 print(f"     Confidence Discipline: -{bd.get('confidence_penalty',0):.4f}")
+                 print(f"  {'─'*60}")
+                 print(f"  FINAL REWARD:              {bd.get('final_reward',0):.4f}")
+             print(f"\n  SCORE: {final_reward:.4f}")
+             break
+
+     return final_reward
+
+
+ def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     print("=" * 62)
+     print("Trust & Safety RL Environment — Baseline Evaluation")
+     print("=" * 62)
+     print(f"Model      : {MODEL_NAME}")
+     print(f"LLM API    : {API_BASE_URL}")
+     print(f"Env Server : {ENV_BASE_URL}")
+     print("Reward     : Accuracy · Policy · Signals · Escalation")
+     print("             Tools · Consistency · Risk · Confidence")
+
+     tasks = [
+         ("T-001", "Easy — Phishing Spam", "low"),
+         ("T-002", "Medium — Gaming Banter", "low"),
+         ("T-003", "Hard — Political Satire", "high"),
+     ]
+     scores = []
+     for tid, desc, risk in tasks:
+         print(f"\n\n>>> {tid} | {desc} | Risk: {risk}")
+         scores.append((tid, desc, run_task(client, tid)))
+
+     print("\n" + "=" * 62)
+     print("FINAL BASELINE RESULTS")
+     print("=" * 62)
+     total = 0.0
+     for tid, desc, s in scores:
+         print(f"  {tid} | {desc:<32} | {s:.4f} {'✅ PASS' if s >= 0.6 else '❌ FAIL'}")
+         total += s
+     vals = [s for _, _, s in scores]
+     print(f"\n  Average : {total/len(scores):.4f}")
+     print(f"  Min     : {min(vals):.4f} | Max : {max(vals):.4f}")
+     print("=" * 62)
+
+
+ if __name__ == "__main__":
+     main()
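
The two defensive helpers above (`_parse` and `_clamp_signals`) reduce to a small, dependency-free pattern: find the outermost `{...}` in raw LLM text, then coerce numeric fields into [0, 1] with a safe fallback. A standalone sketch (simplified re-implementation, not the exact functions):

```python
import json

def clamp01(v, default=0.5):
    """Coerce v to a float in [0, 1]; fall back to default on bad input."""
    try:
        return max(0.0, min(1.0, float(v)))
    except (TypeError, ValueError):
        return default

def extract_json(text):
    """Pull the first-to-last brace span out of raw LLM output, as _parse does."""
    s, e = text.find("{"), text.rfind("}") + 1
    if s == -1 or e == 0:
        raise ValueError("no JSON object found")
    return json.loads(text[s:e])

# An LLM reply with chatter around the JSON and an out-of-range float:
reply = 'Sure! {"action_type": "extract_signals", "signals": {"toxicity_level": 1.7}}'
action = extract_json(reply)
tox = clamp01(action["signals"]["toxicity_level"])  # 1.7 clamps to 1.0
```

The find/rfind slice tolerates leading chatter and trailing punctuation but still raises on replies with no JSON at all, which is why `run_task` wraps the parse in a try/except.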
models.py ADDED
@@ -0,0 +1,63 @@
+ from __future__ import annotations
+ from typing import Optional, List, Dict, Any
+ from pydantic import BaseModel, Field, field_validator
+
+
+ class ContentSignals(BaseModel):
+     target: str = "none"
+     is_protected_class: bool = False
+     toxicity_level: float = 0.5
+     is_direct_attack: bool = False
+     context_type: str = "statement"
+     intent: str = "ambiguous"
+     confidence: float = 0.5
+     abusive_language_present: bool = False
+     content_flags: List[str] = Field(default_factory=list)
+
+     @field_validator("toxicity_level", "confidence")
+     @classmethod
+     def clamp_0_1(cls, v: float) -> float:
+         return max(0.0, min(1.0, float(v)))
+
+     model_config = {"extra": "ignore"}
+
+
+ class TrustAction(BaseModel):
+     action_type: str = ""
+     tool_name: Optional[str] = None
+     signals: Optional[ContentSignals] = None
+     final_decision: Optional[str] = None
+
+     model_config = {"extra": "ignore"}
+
+
+ class TrustObservation(BaseModel):
+     ticket_id: str = ""
+     post_text: str = ""
+     image_description: str = ""
+     comments_found: Optional[str] = None
+     user_history_found: Optional[str] = None
+     entity_status_found: Optional[str] = None
+     policy_found: Optional[str] = None
+     extracted_signals: Optional[Dict[str, Any]] = None
+     validation_result: Optional[Dict[str, Any]] = None
+     step_number: int = 0
+     info: Optional[Dict[str, Any]] = None
+     done: bool = False
+     reward: Optional[float] = None
+
+     model_config = {"extra": "ignore"}
+
+
+ class TrustState(BaseModel):
+     episode_id: Optional[str] = None
+     step_count: int = 0
+     current_task_id: Optional[str] = None
+     difficulty: Optional[str] = None
+     ambiguity_level: Optional[str] = None
+     risk_level: Optional[str] = None
+     tools_used: List[str] = Field(default_factory=list)
+     signals_extracted: bool = False
+     is_done: bool = False
+
+     model_config = {"extra": "ignore"}
openenv.yaml ADDED
@@ -0,0 +1,64 @@
+ spec_version: 1
+ name: trust-safety-env
+ type: environment
+ runtime: python
+ app: app:app
+ port: 8000
+
+ description: >
+   Risk-aware content moderation RL environment for Trust & Safety decision-making.
+   Agents investigate content, extract structured signals, and make policy-aligned
+   decisions under uncertainty across hate speech, political sensitivity, and
+   cultural nuance. Models real-world moderation at scale (Meta-style).
+
+ author: Jerome Richard D
+ version: "1.0.0"
+ license: MIT
+
+ action_space:
+   type: TrustAction
+   description: "use_tool | extract_signals | final_decision"
+
+ observation_space:
+   type: TrustObservation
+   description: "Content ticket with progressive context revelation"
+
+ tasks:
+   - id: T-001
+     name: Phishing Spam Detection
+     difficulty: easy
+     description: Identify and remove clear phishing / impersonation content
+
+   - id: T-002
+     name: Gaming Banter Classification
+     difficulty: medium
+     description: Distinguish competitive gaming banter from genuine harassment
+
+   - id: T-003
+     name: Political Satire Review
+     difficulty: hard
+     description: Handle editorial satire of public figures with high-risk sensitivity
+
+   - id: T-004
+     name: Hate Speech Disguised as Education
+     difficulty: medium
+     description: Detect hate speech hidden behind pseudoscientific or educational framing
+
+   - id: T-005
+     name: Political News with Protest Violence
+     difficulty: hard
+     description: Protect legitimate journalism on sensitive political events without over-censorship
+
+   - id: T-006
+     name: Religious Expression False Flag
+     difficulty: hard
+     description: Distinguish protected religious expression from an automated false-positive flag
+
+ tags:
+   - content-moderation
+   - trust-safety
+   - hate-speech
+   - political-sensitivity
+   - cultural-nuance
+   - real-world
+   - openenv
pyproject.toml ADDED
@@ -0,0 +1,33 @@
+ [build-system]
+ requires = ["setuptools>=68.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "trust-safety-env"
+ version = "1.0.0"
+ description = "Risk-aware Trust & Safety content moderation RL environment — OpenEnv compatible"
+ readme = "README.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "openenv-core>=0.2.0",
+     "fastapi>=0.110.0",
+     "uvicorn[standard]>=0.29.0",
+     "pydantic>=2.6.0",
+     "openai>=1.30.0",
+     "requests>=2.31.0",
+     "python-dotenv>=1.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest>=8.0"]
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+ include = ["*"]
+
+ [tool.openenv]
+ name = "trust-safety-env"
+ environment_class = "your_environment.TrustSafetyEnvironment"
+ action_model = "models.TrustAction"
+ observation_model = "models.TrustObservation"
+ state_model = "models.TrustState"
readme.md ADDED
@@ -0,0 +1,24 @@
+ ---
+ title: Trust Safety RL Environment
+ emoji: 🛡️
+ colorFrom: blue
+ colorTo: indigo
+ sdk: docker
+ app_port: 7860
+ tags:
+   - openenv
+   - reinforcement-learning
+   - content-moderation
+   - trust-safety
+ ---
+
+ # Trust & Safety RL Environment
+
+ 3-layer risk-aware content moderation RL environment built on OpenEnv.
+
+ ## Endpoints
+ - `POST /reset` — start a new episode
+ - `POST /step` — take an action
+ - `GET /state` — current episode state
+ - `GET /health` — health check
+ - `GET /docs` — interactive API docs
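
For reference, the request bodies that `POST /step` accepts can be sketched as plain JSON, one per `action_type` (illustrative values only; in practice each would be POSTed with something like `requests.post(f"{base_url}/step", json=body)`, with the Space assumed to serve on port 7860):

```python
import json

# Example /step bodies for each of the three action types (no HTTP call made here).
use_tool = {"action_type": "use_tool", "tool_name": "view_policy"}
extract = {
    "action_type": "extract_signals",
    "signals": {
        "target": "none", "is_protected_class": False,
        "toxicity_level": 0.1, "is_direct_attack": False,
        "context_type": "news", "intent": "descriptive",
        "confidence": 0.9, "abusive_language_present": False,
        "content_flags": [],
    },
}
decide = {"action_type": "final_decision", "final_decision": "ALLOW"}

# Serialize in pipeline order: investigate → extract → decide.
bodies = [json.dumps(b) for b in (use_tool, extract, decide)]
```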
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ fastapi>=0.115.0
+ uvicorn[standard]>=0.30.0
+ pydantic>=2.0.0
+ requests>=2.31.0
+ openenv-core>=0.2.2
+ python-dotenv>=1.0.0
simpley ADDED
@@ -0,0 +1,1716 @@
1
+ app.py:
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ from typing import Any, Dict, Optional
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import JSONResponse
10
+ from pydantic import BaseModel
11
+
12
+ from models import TrustAction, TrustObservation, TrustState, ContentSignals
13
+ from your_environment import TrustSafetyEnvironment
14
+
15
+ # ── Force manual FastAPI (openenv_core create_app causes 422 on /step) ────────
16
+ print("[app] Using manual FastAPI ✅")
17
+
18
+ _env = TrustSafetyEnvironment(seed=42)
19
+
20
+ app = FastAPI(
21
+ title="Trust & Safety RL Environment",
22
+ description="Risk-aware content moderation environment for agent training.",
23
+ version="1.0.0",
24
+ )
25
+
26
+ app.add_middleware(
27
+ CORSMiddleware,
28
+ allow_origins=["*"],
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+
34
+ # ── Serializers ───────────────────────────────────────────────────────────────
35
+
36
+ def _obs_to_dict(obs: TrustObservation) -> Dict[str, Any]:
37
+ return {
38
+ "ticket_id": obs.ticket_id,
39
+ "post_text": obs.post_text,
40
+ "image_description": obs.image_description,
41
+ "comments_found": obs.comments_found,
42
+ "user_history_found": obs.user_history_found,
43
+ "entity_status_found": obs.entity_status_found,
44
+ "policy_found": obs.policy_found,
45
+ "extracted_signals": obs.extracted_signals,
46
+ "validation_result": obs.validation_result,
47
+ "step_number": obs.step_number,
48
+ "info": obs.info,
49
+ "done": obs.done,
50
+ "reward": obs.reward,
51
+ }
52
+
53
+
54
+ def _state_to_dict(s: TrustState) -> Dict[str, Any]:
55
+ return {
56
+         "episode_id": s.episode_id,
+         "step_count": s.step_count,
+         "current_task_id": s.current_task_id,
+         "difficulty": s.difficulty,
+         "ambiguity_level": s.ambiguity_level,
+         "risk_level": s.risk_level,
+         "tools_used": s.tools_used,
+         "signals_extracted": s.signals_extracted,
+         "is_done": s.is_done,
+     }
+
+
+ # ── Request bodies ─────────────────────────────────────────────────────────────
+
+ class ResetRequest(BaseModel):
+     seed: Any = None
+     episode_id: Any = None
+
+     model_config = {"extra": "ignore"}
+
+
+ class ActionRequest(BaseModel):
+     action_type: str = ""
+     tool_name: Optional[str] = None
+     signals: Optional[Dict[str, Any]] = None  # raw dict — validated below
+     final_decision: Optional[str] = None
+
+     model_config = {"extra": "ignore"}  # ← ignore unknown keys from LLM
+
+
+ # ── Helpers ────────────────────────────────────────────────────────────────────
+
+ def _parse_signals(raw: Dict[str, Any]) -> ContentSignals:
+     """Defensively normalise LLM signal output before Pydantic validation."""
+     # Clamp floats into [0.0, 1.0]
+     raw["toxicity_level"] = max(0.0, min(1.0, float(raw.get("toxicity_level", 0.5))))
+     raw["confidence"] = max(0.0, min(1.0, float(raw.get("confidence", 0.5))))
+
+     # content_flags must be a list of strings
+     flags = raw.get("content_flags", [])
+     if not isinstance(flags, list):
+         flags = [flags] if isinstance(flags, str) else []
+     raw["content_flags"] = [str(f) for f in flags]
+
+     # boolean coercion
+     raw["is_protected_class"] = bool(raw.get("is_protected_class", False))
+     raw["is_direct_attack"] = bool(raw.get("is_direct_attack", False))
+     raw["abusive_language_present"] = bool(raw.get("abusive_language_present", False))
+
+     # string fields — fall back to sensible defaults
+     raw.setdefault("target", "none")
+     raw.setdefault("intent", "ambiguous")
+     raw.setdefault("context_type", "statement")
+
+     return ContentSignals(**raw)
+
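For reference, the coercion rules in `_parse_signals` can be exercised standalone. A minimal sketch that applies the same normalisation to a plain dict (the `coerce_signals` name and the dict return type are illustrative, not part of app.py):

```python
def coerce_signals(raw: dict) -> dict:
    """Sketch of the defensive normalisation: clamp floats, coerce types, default strings."""
    out = dict(raw)  # work on a copy so the caller's dict is untouched
    for key in ("toxicity_level", "confidence"):
        try:
            out[key] = max(0.0, min(1.0, float(out.get(key, 0.5))))
        except (TypeError, ValueError):
            out[key] = 0.5  # unparseable value falls back to the neutral default
    flags = out.get("content_flags", [])
    if not isinstance(flags, list):
        flags = [flags] if isinstance(flags, str) else []
    out["content_flags"] = [str(f) for f in flags]
    for key in ("is_protected_class", "is_direct_attack", "abusive_language_present"):
        out[key] = bool(out.get(key, False))
    out.setdefault("target", "none")
    out.setdefault("intent", "ambiguous")
    out.setdefault("context_type", "statement")
    return out
```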
+
+ # ── Routes ─────────────────────────────────────────────────────────────────────
+
+ @app.get("/health")
+ async def health():
+     return {"status": "ok", "environment": "trust-safety-env", "version": "1.0.0"}
+
+
+ @app.get("/")
+ async def root():
+     return {"status": "ok", "docs": "/docs"}
+
+
+ @app.post("/reset")
+ async def reset(body: ResetRequest = ResetRequest()):
+     obs = _env.reset(seed=body.seed, episode_id=body.episode_id)
+     return JSONResponse(_obs_to_dict(obs))
+
+
+ @app.post("/step")
+ async def step(body: ActionRequest):
+     # Parse + validate signals defensively
+     signals: Optional[ContentSignals] = None
+     if body.signals:
+         try:
+             signals = _parse_signals(dict(body.signals))  # copy so we don't mutate
+         except Exception as e:
+             raise HTTPException(status_code=400, detail=f"Invalid signals payload: {e}")
+
+     action = TrustAction(
+         action_type=body.action_type,
+         tool_name=body.tool_name,
+         signals=signals,
+         final_decision=body.final_decision,
+     )
+
+     try:
+         obs = _env.step(action)
+     except (RuntimeError, ValueError) as e:
+         raise HTTPException(status_code=400, detail=str(e))
+
+     return JSONResponse(_obs_to_dict(obs))
+
+
+ @app.get("/state")
+ async def state():
+     return JSONResponse(_state_to_dict(_env.state))
+
+
+ client.py:
+ from __future__ import annotations
+ from typing import Any
+ from openenv.core.http_env_client import HTTPEnvClient
+ from openenv.core.types import StepResult
+ from models import TrustAction, TrustObservation, TrustState, ContentSignals
+
+
+ class TrustSafetyEnv(HTTPEnvClient[TrustAction, TrustObservation]):
+     """
+     Typed HTTP client for the Trust & Safety RL Environment.
+
+     Usage:
+         client = TrustSafetyEnv(base_url="http://localhost:8000")
+         result = client.reset()
+         result = client.step(TrustAction(action_type="final_decision",
+                                          final_decision="ALLOW"))
+         state = client.state()
+         client.close()
+     """
+
+     def _step_payload(self, action: TrustAction) -> dict:
+         payload: dict = {"action_type": action.action_type}
+         if action.tool_name is not None:
+             payload["tool_name"] = action.tool_name
+         if action.signals is not None:
+             s = action.signals
+             payload["signals"] = {
+                 "target": s.target,
+                 "is_protected_class": s.is_protected_class,
+                 "toxicity_level": s.toxicity_level,
+                 "is_direct_attack": s.is_direct_attack,
+                 "context_type": s.context_type,
+                 "intent": s.intent,
+                 "confidence": s.confidence,
+                 "abusive_language_present": s.abusive_language_present,
+                 "content_flags": s.content_flags,
+             }
+         if action.final_decision is not None:
+             payload["final_decision"] = action.final_decision
+         return payload
+
+     def _parse_result(self, payload: dict) -> StepResult[TrustObservation]:
+         obs_data = payload.get("observation", payload)  # handle flat or nested
+         signals_raw = obs_data.get("extracted_signals")
+         signals = None
+         if isinstance(signals_raw, dict):
+             try:
+                 signals = ContentSignals(**signals_raw)
+             except Exception:
+                 signals = None
+
+         obs = TrustObservation(
+             ticket_id=obs_data.get("ticket_id", ""),
+             post_text=obs_data.get("post_text", ""),
+             image_description=obs_data.get("image_description", ""),
+             comments_found=obs_data.get("comments_found"),
+             user_history_found=obs_data.get("user_history_found"),
+             entity_status_found=obs_data.get("entity_status_found"),
+             policy_found=obs_data.get("policy_found"),
+             extracted_signals=obs_data.get("extracted_signals"),
+             validation_result=obs_data.get("validation_result"),
+             step_number=obs_data.get("step_number", 0),
+             info=obs_data.get("info"),
+             done=payload.get("done", obs_data.get("done", False)),
+             reward=payload.get("reward", obs_data.get("reward")),
+         )
+         return StepResult(
+             observation=obs,
+             reward=payload.get("reward", obs_data.get("reward")),
+             done=payload.get("done", obs_data.get("done", False)),
+         )
+
+     def _parse_state(self, payload: dict) -> TrustState:
+         return TrustState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             current_task_id=payload.get("current_task_id"),
+             difficulty=payload.get("difficulty"),
+             ambiguity_level=payload.get("ambiguity_level"),
+             risk_level=payload.get("risk_level"),
+             tools_used=payload.get("tools_used", []),
+             signals_extracted=payload.get("signals_extracted", False),
+             is_done=payload.get("is_done", False),
+         )
+
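Both `_parse_result` above and the evaluation script accept either a flat observation or a wrapped `{"observation": ..., "reward": ..., "done": ...}` payload. The unwrapping rule can be sketched standalone (the `unwrap_step` helper is hypothetical, not part of client.py):

```python
def unwrap_step(payload: dict) -> tuple[dict, bool, float]:
    """Return (observation, done, reward) from a flat or wrapped /step response."""
    obs = payload.get("observation", payload)  # wrapped → inner dict, flat → payload itself
    done = bool(payload.get("done", obs.get("done", False)))
    reward = float(payload.get("reward") or obs.get("reward") or 0.0)
    return obs, done, reward
```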
+
+ Dockerfile:
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ RUN apt-get update \
+     && apt-get install -y --no-install-recommends curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ COPY pyproject.toml .
+ RUN pip install --no-cache-dir -e "."
+
+ COPY . .
+
+ ENV PYTHONPATH="/app:$PYTHONPATH"
+
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
+     CMD curl -f http://localhost:8000/health || exit 1
+
+ EXPOSE 8000
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
+
+
+
+ inference.py:
+ """
+ inference.py — Trust & Safety RL Environment Evaluation
+ ========================================================
+ Env vars (HF_TOKEN is mandatory; the rest have defaults):
+     API_BASE_URL   LLM endpoint (e.g. https://router.huggingface.co/v1)
+     MODEL_NAME     Model ID (e.g. meta-llama/Llama-3.1-8B-Instruct)
+     HF_TOKEN       API key
+     ENV_BASE_URL   Environment server URL (default: http://localhost:8000)
+ """
+
+ import os, json, time, requests
+ from openai import OpenAI
+
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
+ API_KEY      = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
+ MODEL_NAME   = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
+ ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000")
+
+ if not API_KEY:
+     raise EnvironmentError("Set HF_TOKEN (your Hugging Face / API key).")
+ if not MODEL_NAME:
+     raise EnvironmentError("Set MODEL_NAME.")
+
+ SYSTEM_PROMPT = """You are a Trust & Safety analyst. Follow this EXACT 3-stage pipeline.
+ Skipping any stage results in a score penalty.
+
+ ════════════════════════════════════════════
+ STAGE 1 — Investigate (optional)
+ ════════════════════════════════════════════
+ Use tools only when context is needed. For OBVIOUS violations (spam, phishing), skip.
+
+ {"action_type": "use_tool", "tool_name": "read_comments"}
+ {"action_type": "use_tool", "tool_name": "check_user_history"}
+ {"action_type": "use_tool", "tool_name": "check_entity_status"}
+ {"action_type": "use_tool", "tool_name": "view_policy"}
+
+ ════════════════════════════════════════════
+ STAGE 2 — Extract signals (MANDATORY)
+ ════════════════════════════════════════════
+ You MUST call extract_signals before final_decision. No exceptions.
+ You are a FEATURE EXTRACTOR — output what the content shows, do NOT decide.
+
+ {"action_type": "extract_signals", "signals": {
+     "target": "individual|group|religion|race|gender|political|none",
+     "is_protected_class": true|false,
+     "toxicity_level": 0.0-1.0,
+     "is_direct_attack": true|false,
+     "context_type": "statement|satire|news|meme|personal_attack|protest|education|spam",
+     "intent": "harmful|descriptive|ambiguous|satirical|educational",
+     "confidence": 0.0-1.0,
+     "abusive_language_present": true|false,
+     "content_flags": ["spam","phishing","slur","threat","misinformation",
+                       "political_content","public_figure","competitive_language"]
+ }}
+
+ After submitting signals you receive VALIDATION. If needs_escalation_hint is true → prefer ESCALATE.
+
+ ════════════════════════════════════════════
+ STAGE 3 — Final decision (MANDATORY)
+ ════════════════════════════════════════════
+ {"action_type": "final_decision", "final_decision": "ALLOW|REMOVE|ALLOW_WITH_WARNING|ESCALATE"}
+
+ DECISION RULES (apply in order):
+ 1. Phishing / spam / impersonation            → REMOVE
+ 2. Explicit hate speech, direct threats       → REMOVE
+ 3. Political satire of a PUBLIC FIGURE        → ALLOW_WITH_WARNING
+ 4. News / journalism / education              → ALLOW_WITH_WARNING
+ 5. Gaming / sports banter (confirmed)         → ALLOW
+ 6. Adjusted confidence after validation < 0.6 → ESCALATE
+ 7. Genuinely ambiguous, no clear rule         → ESCALATE
+ 8. Safe, no signals triggered                 → ALLOW
+
+ NEVER jump from Stage 1 directly to final_decision.
+ ALWAYS: [optional tools] → extract_signals → final_decision
+ Respond in strict JSON only. No markdown."""
+
+ SIGNAL_DEFAULTS = {
+     "target": "none", "is_protected_class": False,
+     "toxicity_level": 0.5, "is_direct_attack": False,
+     "context_type": "statement", "intent": "ambiguous",
+     "confidence": 0.5, "abusive_language_present": False,
+     "content_flags": [],
+ }
+ VALID_TOOLS     = {"read_comments", "check_user_history", "check_entity_status", "view_policy"}
+ VALID_DECISIONS = {"ALLOW", "REMOVE", "ALLOW_WITH_WARNING", "ESCALATE"}
+ VALID_TYPES     = {"use_tool", "extract_signals", "final_decision"}
+
+
+ def _clamp_signals(raw: dict) -> dict:
+     result = SIGNAL_DEFAULTS.copy()
+     for k in SIGNAL_DEFAULTS:
+         if k in raw:
+             result[k] = raw[k]
+     try:
+         result["toxicity_level"] = max(0.0, min(1.0, float(result["toxicity_level"])))
+         result["confidence"] = max(0.0, min(1.0, float(result["confidence"])))
+     except (TypeError, ValueError):
+         result["toxicity_level"] = 0.5
+         result["confidence"] = 0.5
+     if not isinstance(result["content_flags"], list):
+         result["content_flags"] = []
+     return result
+
+
+ def _parse(text: str) -> dict:
+     text = text.strip()
+     s, e = text.find("{"), text.rfind("}") + 1
+     if s == -1 or e == 0:
+         raise ValueError(f"No JSON in: {text}")
+     return json.loads(text[s:e])
+
+
+ def _normalize(raw: dict) -> dict:
+     t = raw.get("action_type", "")
+     if t not in VALID_TYPES:
+         return {"action_type": "final_decision", "final_decision": "ESCALATE"}
+     if t == "use_tool":
+         tool = raw.get("tool_name", "")
+         return {"action_type": "use_tool", "tool_name": tool} if tool in VALID_TOOLS \
+             else {"action_type": "final_decision", "final_decision": "ESCALATE"}
+     if t == "extract_signals":
+         sigs = raw.get("signals")
+         return {"action_type": "extract_signals", "signals": _clamp_signals(sigs)} \
+             if sigs else {"action_type": "final_decision", "final_decision": "ESCALATE"}
+     dec = raw.get("final_decision", "ESCALATE")
+     return {"action_type": "final_decision",
+             "final_decision": dec if dec in VALID_DECISIONS else "ESCALATE"}
+
+
+ def _obs_to_prompt(obs: dict) -> str:
+     lines = [
+         f"=== TICKET {obs.get('ticket_id','')} (Step {obs.get('step_number',0)}) ===",
+         f"\nPOST TEXT:\n{obs.get('post_text','')}",
+         f"\nIMAGE:\n{obs.get('image_description','')}",
+     ]
+     for key, label in [
+         ("comments_found", "COMMENTS"), ("user_history_found", "USER HISTORY"),
+         ("entity_status_found", "ENTITY STATUS"), ("policy_found", "POLICY"),
+     ]:
+         if obs.get(key):
+             lines.append(f"\n{label}:\n{obs[key]}")
+     if obs.get("extracted_signals"):
+         lines.append(f"\nYOUR EXTRACTED SIGNALS:\n{json.dumps(obs['extracted_signals'], indent=2)}")
+     if obs.get("validation_result"):
+         v = obs["validation_result"]
+         hint = "⚠️ YES — prefer ESCALATE" if v.get("needs_escalation_hint") else "No"
+         lines.append(
+             f"\n📋 VALIDATION:\n"
+             f"  Adj. Confidence : {v.get('adjusted_confidence')}\n"
+             f"  Issues          : {v.get('consistency_issues')}\n"
+             f"  Escalation Hint : {hint}"
+         )
+     if not obs.get("extracted_signals"):
+         lines.append("\n⚠️ REMINDER: Call extract_signals before final_decision.")
+     lines.append("\nYour next action (strict JSON only):")
+     return "\n".join(lines)
+
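The `_parse` then `_normalize` chain above guarantees that every model reply becomes a legal action, degrading to a safe ESCALATE on anything malformed. A simplified standalone sketch of that fallback chain (not the module itself; the tool and signal branches are omitted):

```python
import json

VALID_DECISIONS = {"ALLOW", "REMOVE", "ALLOW_WITH_WARNING", "ESCALATE"}

def parse_json_blob(text: str) -> dict:
    """Extract the outermost {...} span from a possibly noisy model reply."""
    s, e = text.find("{"), text.rfind("}") + 1
    if s == -1 or e == 0:
        raise ValueError("no JSON object found")
    return json.loads(text[s:e])

def safe_action(text: str) -> dict:
    """Parse a reply; anything malformed degrades to a safe ESCALATE decision."""
    try:
        raw = parse_json_blob(text)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return {"action_type": "final_decision", "final_decision": "ESCALATE"}
    dec = raw.get("final_decision")
    if raw.get("action_type") == "final_decision" and dec in VALID_DECISIONS:
        return {"action_type": "final_decision", "final_decision": dec}
    return {"action_type": "final_decision", "final_decision": "ESCALATE"}
```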
+ def run_task(client: OpenAI, task_id: str) -> float:
+     for _ in range(30):
+         # Request the specific task by passing its ID as the episode_id
+         r = requests.post(
+             f"{ENV_BASE_URL}/reset",
+             json={"episode_id": task_id},
+             timeout=10,
+         )
+         r.raise_for_status()
+         obs = r.json()
+         # Handle both flat (TrustObservation) and wrapped response
+         if isinstance(obs, dict) and "observation" in obs:
+             obs = obs["observation"]
+         if obs.get("ticket_id") == task_id:
+             break
+     else:
+         raise RuntimeError(f"Could not get task {task_id} after 30 resets.")
+
+     print(f"\n{'='*62}\nTask: {task_id} | Starting...\n{'='*62}")
+     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+     final_reward = 0.0
+
+     for step_num in range(14):
+         messages.append({"role": "user", "content": _obs_to_prompt(obs)})
+         time.sleep(0.5)
+
+         resp = client.chat.completions.create(
+             model=MODEL_NAME, messages=messages, temperature=0.0,
+             response_format={"type": "json_object"},
+         )
+         llm_text = resp.choices[0].message.content or ""
+         messages.append({"role": "assistant", "content": llm_text})
+
+         try:
+             action = _normalize(_parse(llm_text))
+         except Exception as ex:
+             print(f"  [Step {step_num+1}] Parse error: {ex}")
+             break
+
+         atype = action["action_type"]
+         if atype == "use_tool":
+             print(f"  [Step {step_num+1}] 🔧 use_tool → {action.get('tool_name')}")
+         elif atype == "extract_signals":
+             s = action.get("signals", {})
+             print(f"  [Step {step_num+1}] 🔍 extract_signals → "
+                   f"intent={s.get('intent')} | ctx={s.get('context_type')} | "
+                   f"tox={s.get('toxicity_level')} | conf={s.get('confidence')}")
+         else:
+             print(f"  [Step {step_num+1}] ⚖️ final_decision → {action.get('final_decision')}")
+
+         r2 = requests.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
+         r2.raise_for_status()
+         result = r2.json()
+
+         # Handle flat (TrustObservation) and wrapped response
+         if "observation" in result:
+             obs = result["observation"]
+             done = result.get("done", obs.get("done", False))
+             final_reward = float(result.get("reward") or obs.get("reward") or 0.0)
+         else:
+             obs = result
+             done = result.get("done", False)
+             final_reward = float(result.get("reward") or 0.0)
+
+         if done:
+             info = obs.get("info") or {}
+             bd = info.get("reward_breakdown", {})
+             pol = info.get("policy_recommendation", {})
+             vr = obs.get("validation_result") or {}
+
+             print(f"\n  ── EPISODE COMPLETE {'─'*42}")
+             print(f"  Decision:          {info.get('final_decision','N/A')}")
+             print(f"  Ground Truth:      {info.get('ground_truth','N/A')}")
+             print(f"  Policy Engine:     {pol.get('recommended','N/A')} "
+                   f"[{pol.get('rule_strength','?')} rule] ({pol.get('reason','?')})")
+             print(f"  Signals Extracted: {'✅' if info.get('signals_extracted') else '❌ SKIPPED'}")
+             print(f"  Tools Used:        {info.get('tools_used', [])}")
+             print(f"  Required Tools:    {info.get('required_tools', [])}")
+             print(f"  Adj. Confidence:   {vr.get('adjusted_confidence','N/A')}")
+             print(f"  Issues:            {vr.get('consistency_issues',[])}")
+             print(f"  Ambiguity / Risk:  {info.get('ambiguity_level','?')} / {info.get('risk_level','?')}")
+             if bd:
+                 print(f"\n  ── Reward Breakdown {'─'*42}")
+                 print(f"  1. Base Decision Score:    {bd.get('base_score',0):+.4f}")
+                 print(f"  2. Policy Alignment:       {bd.get('policy_alignment',0):+.4f}")
+                 print(f"  3. Signal Accuracy Bonus:  {bd.get('signal_accuracy_bonus',0):+.4f}")
+                 print(f"  4. Escalation Adjustment:  {bd.get('escalation_adj',0):+.4f}")
+                 print(f"  5. Signal Process Bonus:   {bd.get('signal_bonus',0):+.4f}")
+                 print(f"     Tool Cost:              -{bd.get('tool_cost',0):.4f}")
+                 print(f"     Tool Miss Penalty:      -{bd.get('tool_miss_penalty',0):.4f}")
+                 print(f"     Validation Penalty:     -{bd.get('validation_penalty',0):.4f}")
+                 print(f"     Risk Penalty:           -{bd.get('risk_penalty',0):.4f}")
+                 print(f"     Confidence Discipline:  -{bd.get('confidence_penalty',0):.4f}")
+                 print(f"  {'─'*60}")
+                 print(f"  FINAL REWARD:              {bd.get('final_reward',0):.4f}")
+             print(f"\n  SCORE: {final_reward:.4f}")
+             break
+
+     return final_reward
+
+
+ def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     print("=" * 62)
+     print("Trust & Safety RL Environment — Baseline Evaluation")
+     print("=" * 62)
+     print(f"Model      : {MODEL_NAME}")
+     print(f"LLM API    : {API_BASE_URL}")
+     print(f"Env Server : {ENV_BASE_URL}")
+     print("Reward     : Accuracy · Policy · Signals · Escalation")
+     print("             Tools · Consistency · Risk · Confidence")
+
+     tasks = [
+         ("T-001", "Easy — Phishing Spam", "low"),
+         ("T-002", "Medium — Gaming Banter", "low"),
+         ("T-003", "Hard — Political Satire", "high"),
+     ]
+     scores = []
+     for tid, desc, risk in tasks:
+         print(f"\n\n>>> {tid} | {desc} | Risk: {risk}")
+         scores.append((tid, desc, run_task(client, tid)))
+
+     print("\n" + "=" * 62)
+     print("FINAL BASELINE RESULTS")
+     print("=" * 62)
+     total = 0.0
+     for tid, desc, s in scores:
+         print(f"  {tid} | {desc:<32} | {s:.4f} {'✅ PASS' if s >= 0.6 else '❌ FAIL'}")
+         total += s
+     vals = [s for _, _, s in scores]
+     print(f"\n  Average : {total/len(scores):.4f}")
+     print(f"  Min : {min(vals):.4f} | Max : {max(vals):.4f}")
+     print("=" * 62)
+
+
+ if __name__ == "__main__":
+     main()
+
+
+ openenv.yaml:
+ spec_version: 1
+ name: trust-safety-env
+ type: environment
+ runtime: python
+ app: app:app
+ port: 8000
+
+ description: >
+   Risk-aware content moderation RL environment for Trust & Safety decision-making.
+   Agents investigate content, extract structured signals, and make policy-aligned
+   decisions under uncertainty across hate speech, political sensitivity, and
+   cultural nuance. Models real-world moderation at scale (Meta-style).
+
+ author: Jerome Richard D
+ version: "1.0.0"
+ license: MIT
+
+ action_space:
+   type: TrustAction
+   description: "use_tool | extract_signals | final_decision"
+
+ observation_space:
+   type: TrustObservation
+   description: "Content ticket with progressive context revelation"
+
+ tasks:
+   - id: T-001
+     name: Phishing Spam Detection
+     difficulty: easy
+     description: Identify and remove clear phishing / impersonation content
+
+   - id: T-002
+     name: Gaming Banter Classification
+     difficulty: medium
+     description: Distinguish competitive gaming banter from genuine harassment
+
+   - id: T-003
+     name: Political Satire Review
+     difficulty: hard
+     description: Handle editorial satire of public figures with high-risk sensitivity
+
+   - id: T-004
+     name: Hate Speech Disguised as Education
+     difficulty: medium
+     description: Detect hate speech hidden behind pseudoscientific or educational framing
+
+   - id: T-005
+     name: Political News with Protest Violence
+     difficulty: hard
+     description: Protect legitimate journalism on sensitive political events without over-censorship
+
+   - id: T-006
+     name: Religious Expression False Flag
+     difficulty: hard
+     description: Distinguish protected religious expression from an automated false-positive flag
+
+ tags:
+   - content-moderation
+   - trust-safety
+   - hate-speech
+   - political-sensitivity
+   - cultural-nuance
+   - real-world
+   - openenv
+
933
+ pyproject.toml:
+ [build-system]
+ requires = ["setuptools>=68.0", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "trust-safety-env"
+ version = "1.0.0"
+ description = "Risk-aware Trust & Safety content moderation RL environment — OpenEnv compatible"
+ readme = "readme.md"
+ requires-python = ">=3.11"
+ dependencies = [
+     "openenv-core>=0.2.0",
+     "fastapi>=0.110.0",
+     "uvicorn[standard]>=0.29.0",
+     "pydantic>=2.6.0",
+     "openai>=1.30.0",
+     "requests>=2.31.0",
+     "python-dotenv>=1.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest>=8.0"]
+
+ [tool.setuptools.packages.find]
+ where = ["."]
+ include = ["*"]
+
+ [tool.openenv]
+ name = "trust-safety-env"
+ environment_class = "your_environment.TrustSafetyEnvironment"
+ action_model = "models.TrustAction"
+ observation_model = "models.TrustObservation"
+ state_model = "models.TrustState"
+
968
+ requirements.txt:
+
+ fastapi>=0.110.0
+ uvicorn[standard]>=0.29.0
+ pydantic>=2.6.0
+ openai>=1.30.0
+ requests>=2.31.0
+ python-dotenv>=1.0.0
+
+
978
+ tasks.py:
+ TASKS = [
+     {
+         "task_id": "T-001",
+         "difficulty": "easy",
+         "ambiguity_level": "low",
+         "risk_level": "low",
+         "post_text": (
+             "🚨 URGENT: Your account has been COMPROMISED! "
+             "Click here immediately: http://meta-security-verify-now.xyz/login?token=abc123 "
+             "Act within 24 hours or your account will be PERMANENTLY BANNED. — Meta Security Team"
+         ),
+         "image_description": (
+             "Poorly formatted alert page with mismatched Meta logos, "
+             "red urgent warning banners, and a suspicious external URL."
+         ),
+         "tool_responses": {
+             "read_comments": "Comments disabled. Zero engagement.",
+             "check_user_history": "Account created 2 days ago. Zero posts. Flagged by spam filter 3×.",
+             "check_entity_status": "Domain 'meta-security-verify-now.xyz' registered 1 day ago — confirmed phishing domain.",
+             "view_policy": "Policy 3.2 — Phishing: Impersonating Meta to harvest credentials = immediate REMOVE.",
+         },
+         "required_tools": [],
+         "ground_truth": "REMOVE",
+         "ground_truth_signals": {
+             "target": "individual",
+             "is_protected_class": False,
+             "toxicity_level": 0.90,
+             "is_direct_attack": True,
+             "context_type": "spam",
+             "intent": "harmful",
+             "abusive_language_present": False,
+             "content_flags": ["phishing", "spam", "impersonation"],
+         },
+     },
+     {
+         "task_id": "T-002",
+         "difficulty": "medium",
+         "ambiguity_level": "medium",
+         "risk_level": "low",
+         "post_text": (
+             "LMAO I will literally SNIPE you the second the game starts, "
+             "no mercy, you're getting absolutely destroyed 💀🎮 "
+             "1v1 me and I'll end your whole career bro. Your team is DONE."
+         ),
+         "image_description": (
+             "Call of Duty: Warzone — squad wipe in progress. "
+             "Kill notifications, sniper rifle, enemy team eliminated. Classic gaming screenshot."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "47 comments: 'GG bro!', 'You always trash talk before CoD 😂', "
+                 "'rematch tonight??', 'your aim is mid bro'. "
+                 "All mutual gaming friends. Clear competitive banter context."
+             ),
+             "check_user_history": (
+                 "Account 3 years active. 500+ posts, 90% gaming. "
+                 "Zero harassment reports. Regular CoD tournament player."
+             ),
+             "check_entity_status": "Private individual. Standard gaming community member.",
+             "view_policy": (
+                 "Policy 7.1 — Harassment: Hyperbolic language in established gaming communities "
+                 "is NOT harassment. 'Snipe', 'destroy', 'end your career' are non-literal expressions."
+             ),
+         },
+         "required_tools": ["read_comments"],
+         "ground_truth": "ALLOW",
+         "ground_truth_signals": {
+             "target": "individual",
+             "is_protected_class": False,
+             "toxicity_level": 0.30,
+             "is_direct_attack": False,
+             "context_type": "meme",
+             "intent": "ambiguous",
+             "abusive_language_present": False,
+             "content_flags": ["competitive_language"],
+         },
+     },
+     {
+         "task_id": "T-003",
+         "difficulty": "hard",
+         "ambiguity_level": "high",
+         "risk_level": "high",
+         "post_text": (
+             "Sharing this political cartoon I drew. "
+             "Mayor Davidson cutting the homeless shelter budget while sitting on a throne of gold coins. "
+             "Caption: 'Priorities.' #LocalPolitics #HoldThemAccountable #SatiricalArt #MayorDavidson"
+         ),
+         "image_description": (
+             "Hand-drawn political cartoon. Exaggerated caricature of a public official on a coin throne, "
+             "homeless families pushed away. Deliberately satirical editorial style. Artist signature visible."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "23 comments: 'Powerful commentary', 'Free speech!', 'Great art', 'Defamation!'. "
+                 "No threats or coordinated harassment detected."
+             ),
+             "check_user_history": (
+                 "Account 5 years. Verified local journalist and artist. "
+                 "Regular political commentary. No violations. Posts cited in local news."
+             ),
+             "check_entity_status": (
+                 "Mayor Davidson: CONFIRMED PUBLIC FIGURE. "
+                 "Elected Mayor of Springfield. Active public interest in recent social services budget cuts. "
+                 "Multiple credible news articles confirm the cuts are real and newsworthy."
+             ),
+             "view_policy": (
+                 "Policy 12.4 — Political Satire: Editorial cartoons targeting PUBLIC OFFICIALS "
+                 "on PUBLIC INTEREST matters are PROTECTED. Caricature is a recognized technique. "
+                 "Not realistic impersonation. Correct action: ALLOW_WITH_WARNING."
+             ),
+         },
+         "required_tools": ["check_entity_status", "view_policy"],
+         "ground_truth": "ALLOW_WITH_WARNING",
+         "ground_truth_signals": {
+             "target": "political",
+             "is_protected_class": False,
+             "toxicity_level": 0.20,
+             "is_direct_attack": False,
+             "context_type": "satire",
+             "intent": "satirical",
+             "abusive_language_present": False,
+             "content_flags": ["political_content", "public_figure"],
+         },
+     },
+     # ── Append these to the TASKS list ──────────────────────────────────────
+     {
+         "task_id": "T-004",
+         "difficulty": "medium",
+         "ambiguity_level": "medium",
+         "risk_level": "high",
+         "post_text": (
+             "Sharing this 'educational' video about why [ethnic group] people "
+             "are genetically predisposed to criminal behaviour. "
+             "The science doesn't lie. Share before it gets deleted. 👀 "
+             "#HiddenTruths #CensoredFacts #JustAsking"
+         ),
+         "image_description": (
+             "Screenshot of a YouTube-style video thumbnail with a white lab coat figure "
+             "and a graph. Title: 'The Uncomfortable Biology of Crime Rates'. "
+             "Uploader has 0 verified credentials. Comments disabled."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "Comments disabled by uploader. "
+                 "Zero engagement from credible academic or scientific accounts."
+             ),
+             "check_user_history": (
+                 "Account 6 months old. 48 posts — all promoting pseudoscientific race content. "
+                 "Previously received 2 policy warnings for hate content. "
+                 "No verified credentials. Frequently uses 'just asking questions' framing."
+             ),
+             "check_entity_status": (
+                 "Ethnic group mentioned: confirmed PROTECTED CLASS under platform policy. "
+                 "Claim about genetic criminality: DEBUNKED by mainstream biology and criminology. "
+                 "This framing is a known pseudoscience hate vector (scientific racism)."
+             ),
+             "view_policy": (
+                 "Policy 4.3 — Hate Speech: Content attributing criminal or negative traits "
+                 "to a protected class based on race or ethnicity = REMOVE. "
+                 "Disguising hate speech as 'educational content' or 'science' does not "
+                 "change the violation. Action: REMOVE."
+             ),
+         },
+         "required_tools": ["check_entity_status", "view_policy"],
+         "ground_truth": "REMOVE",
+         "ground_truth_signals": {
+             "target": "group",
+             "is_protected_class": True,
+             "toxicity_level": 0.85,
+             "is_direct_attack": True,
+             "context_type": "statement",
+             "intent": "harmful",
+             "abusive_language_present": False,
+             "content_flags": ["slur", "misinformation", "hate_speech_disguised_as_education"],
+         },
+     },
+     {
+         "task_id": "T-005",
+         "difficulty": "hard",
+         "ambiguity_level": "high",
+         "risk_level": "high",
+         "post_text": (
+             "The Prime Minister announced new economic reforms today. "
+             "Opposition leaders called it 'economic warfare on the poor.' "
+             "Protests erupted in 3 cities. Security forces used tear gas. "
+             "Government says the crackdown was 'necessary to restore order.' "
+             "At least 12 protesters arrested. #Politics #Protest #Reform"
+         ),
+         "image_description": (
+             "News-style post with a photo of a large protest crowd holding signs. "
+             "Tear gas clouds visible in background. "
+             "Caption: 'Thousands take to streets against new austerity measures.' "
+             "Shared by a verified news outlet handle. No graphic violence."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "340 comments split sharply: "
+                 "'State violence against citizens!', 'Finally someone fixing the economy', "
+                 "'This is propaganda!', 'Biased reporting!', 'Necessary action'. "
+                 "High engagement, deep political polarization. No incitement to violence detected."
+             ),
+             "check_user_history": (
+                 "Posted by @NationalNewsNetwork — verified news organization. "
+                 "12 years active. Consistently neutral coverage. "
+                 "No policy violations. Registered press entity."
+             ),
+             "check_entity_status": (
+                 "Prime Minister: CONFIRMED PUBLIC FIGURE, head of government. "
+                 "Economic reforms: CONFIRMED real policy, actively debated in parliament. "
+                 "Protest: CONFIRMED real event, covered by 6+ independent news agencies. "
+                 "Political sensitivity: VERY HIGH — active governance dispute in a democracy."
+             ),
+             "view_policy": (
+                 "Policy 12.1 — News & Political Content: Factual reporting on real government "
+                 "actions, protests, and police responses by VERIFIED NEWS entities is PROTECTED. "
+                 "Even politically contentious factual reporting is protected expression. "
+                 "Suppressing legitimate news = censorship risk. "
+                 "Action: ALLOW_WITH_WARNING (politically sensitive content label)."
+             ),
+         },
+         "required_tools": ["check_user_history", "check_entity_status", "view_policy"],
+         "ground_truth": "ALLOW_WITH_WARNING",
+         "ground_truth_signals": {
+             "target": "political",
+             "is_protected_class": False,
+             "toxicity_level": 0.20,
+             "is_direct_attack": False,
+             "context_type": "news",
+             "intent": "descriptive",
+             "abusive_language_present": False,
+             "content_flags": ["political_content", "public_figure", "protest"],
+         },
+     },
+     {
+         "task_id": "T-006",
+         "difficulty": "hard",
+         "ambiguity_level": "high",
+         "risk_level": "high",
+         "post_text": (
+             "Sending Eid Mubarak wishes to my community 🌙✨ "
+             "During this holy month let us remember those suffering in conflict zones. "
+             "Our brothers and sisters deserve peace and dignity. "
+             "May Allah guide us all. #EidMubarak #Ramadan #Peace #Solidarity"
+         ),
+         "image_description": (
+             "Warm celebratory post with a crescent moon and stars graphic. "
+             "Soft gold and green colors. "
+             "Comment section shows replies in Arabic, Urdu, and English. "
+             "Emojis: praying hands, crescent moon, dove."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "89 comments: 'Eid Mubarak!', 'JazakAllah Khair', "
+                 "'prayers for peace', 'beautiful message', "
+                 "1 comment flagged by automated system: 'why is this allowed, "
+                 "these people support terrorism' — that comment is from a third party, "
+                 "not the original poster. "
+                 "No incitement, no threats, no extremist content in original post."
+             ),
+             "check_user_history": (
+                 "Account 4 years active. 200+ posts — community announcements, "
+                 "religious holiday greetings, charity drives. "
+                 "Zero prior violations. Community leader in local Muslim cultural centre."
+             ),
+             "check_entity_status": (
+                 "Islam / Muslim community: CONFIRMED PROTECTED CLASS (religion). "
+                 "Eid al-Fitr: globally recognized Islamic religious holiday. "
+                 "Content: religious greeting + humanitarian solidarity message. "
+                 "No extremist symbols, no prohibited organisations mentioned."
+             ),
+             "view_policy": (
+                 "Policy 8.1 — Religious Expression: Greetings and solidarity messages "
+                 "for recognized religious holidays are PROTECTED. "
+                 "Vague 'conflict zone' solidarity without naming a prohibited group "
+                 "is NOT a violation. "
+                 "Automated flag was triggered by a COMMENT, not the original post. "
+                 "Action: ALLOW — original content is clearly safe religious expression."
+             ),
+         },
+         "required_tools": ["read_comments", "check_entity_status", "view_policy"],
+         "ground_truth": "ALLOW",
+         "ground_truth_signals": {
+             "target": "religion",
+             "is_protected_class": True,
+             "toxicity_level": 0.05,
+             "is_direct_attack": False,
+             "context_type": "statement",
+             "intent": "educational",
+             "abusive_language_present": False,
+             "content_flags": ["political_content"],
+         },
+     },
+ ]
+
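+ Hand-edited entries like these drift out of shape easily; a quick structural check over each TASKS item catches missing keys or a mistyped decision label before an episode silently misbehaves. A minimal sketch, assuming the schema shown above — the `validate_task` helper and its key set are illustrative, not part of the repo:

```python
ALLOWED_DECISIONS = {"ALLOW", "ALLOW_WITH_WARNING", "REMOVE", "ESCALATE"}
REQUIRED_KEYS = {"task_id", "difficulty", "ambiguity_level", "risk_level",
                 "post_text", "tool_responses", "required_tools",
                 "ground_truth", "ground_truth_signals"}

def validate_task(task: dict) -> list[str]:
    """Return a list of schema problems for one TASKS entry (empty = OK)."""
    problems = [f"missing key: {k}" for k in sorted(REQUIRED_KEYS - task.keys())]
    # Ground truth must be one of the four decisions the environment scores.
    if task.get("ground_truth") not in ALLOWED_DECISIONS:
        problems.append(f"bad ground_truth: {task.get('ground_truth')!r}")
    # Every required tool should have a canned response to reveal.
    if not set(task.get("required_tools", [])) <= task.get("tool_responses", {}).keys():
        problems.append("required_tools not covered by tool_responses")
    return problems
```

+ Running `[validate_task(t) for t in TASKS]` at import time (or in a pytest) would make a broken task fail loudly instead of scoring oddly.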
+ your_environment.py:
+ from __future__ import annotations
+
+ import random
+ import uuid
+ from typing import Optional, Dict, Any, Set
+
+ try:
+     from openenv_core.env_server import Environment
+     print("[env] Inheriting from openenv_core.env_server.Environment ✅")
+ except ImportError:
+     try:
+         from openenv.core.env_server import Environment
+         print("[env] Inheriting from openenv.core.env_server.Environment ✅")
+     except ImportError:
+         Environment = object
+         print("[env] openenv_core not found — using plain object base ⚠️")
+
+ from models import TrustObservation, TrustAction, TrustState, ContentSignals
+ from tasks import TASKS
+
+
+ TOOL_COSTS: Dict[str, float] = {
+     "read_comments": 0.05,
+     "check_user_history": 0.05,
+     "check_entity_status": 0.10,
+     "view_policy": 0.10,
+ }
+
+ MAX_STEPS = 7
+
+ DECISION_MATRIX: Dict[tuple, float] = {
+     ("REMOVE", "REMOVE"): 1.00,
+     ("ALLOW", "ALLOW"): 1.00,
+     ("ALLOW_WITH_WARNING", "ALLOW_WITH_WARNING"): 1.00,
+     ("ESCALATE", "ESCALATE"): 1.00,
+     ("ALLOW_WITH_WARNING", "ALLOW"): 0.75,
+     ("ALLOW", "ALLOW_WITH_WARNING"): 0.55,
+     ("ESCALATE", "ALLOW_WITH_WARNING"): 0.65,
+     ("ESCALATE", "ALLOW"): 0.45,
+     ("ESCALATE", "REMOVE"): 0.45,
+     ("REMOVE", "ALLOW"): 0.10,
+     ("REMOVE", "ALLOW_WITH_WARNING"): 0.20,
+     ("ALLOW", "REMOVE"): 0.00,
+     ("ALLOW_WITH_WARNING", "REMOVE"): 0.15,
+ }
+
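+ The matrix above can be exercised on its own. A minimal sketch of the accuracy lookup, including the default for unlisted pairs and the high-ambiguity ESCALATE floor that `_compute_components` applies later — the `base_score` helper name is illustrative, and the matrix below is deliberately truncated:

```python
DECISION_MATRIX = {
    ("REMOVE", "REMOVE"): 1.00,
    ("ALLOW", "ALLOW"): 1.00,
    ("ESCALATE", "ALLOW"): 0.45,
    ("ALLOW", "REMOVE"): 0.00,
    # ... remaining (decision, ground_truth) pairs as defined above
}

def base_score(decision: str, ground_truth: str, ambiguity: str) -> float:
    """Accuracy component: matrix lookup, 0.20 default for unlisted pairs,
    and ESCALATE on high-ambiguity tickets is floored at 0.70."""
    score = DECISION_MATRIX.get((decision, ground_truth), 0.20)
    if decision == "ESCALATE" and ambiguity == "high":
        score = max(score, 0.70)
    return score
```

+ The asymmetry is the point: wrongly allowing content that should be removed scores 0.00, while wrongly removing allowable content still earns 0.10, and escalating is never catastrophic.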
+
1324
+ class TrustSafetyEnvironment(Environment):
1325
+ """
1326
+ 3-Layer Risk-Aware Trust & Safety RL Environment.
1327
+
1328
+ Layer 1 — Evidence gathering : agent uses investigation tools (optional)
1329
+ Layer 2 — Signal extraction : agent outputs ContentSignals as feature extractor
1330
+ Layer 3 — Policy engine : validates signals, applies rules, computes reward
1331
+
1332
+ 8-Component Reward: Accuracy · Policy Alignment · Signal Quality · Escalation
1333
+ Tool Usage · Consistency · Risk Sensitivity · Confidence
1334
+ """
1335
+
1336
+ def __init__(self, seed: int = 42) -> None:
1337
+ super().__init__()
1338
+ self._rng = random.Random(seed)
1339
+ self._current_task: Optional[Dict[str, Any]] = None
1340
+ self._tools_used: Set[str] = set()
1341
+ self._step_count: int = 0
1342
+ self._extracted_signals: Optional[ContentSignals] = None
1343
+ self._validation_result: Optional[Dict[str, Any]] = None
1344
+ self._signals_extracted: bool = False
1345
+ self._obs: Optional[TrustObservation]= None
1346
+ self._state = TrustState()
1347
+
1348
+ # ✅ FIX 3 — build a dict keyed by task_id for O(1) lookup
1349
+ self._tasks: Dict[str, Dict[str, Any]] = {
1350
+ t["task_id"]: t for t in TASKS
1351
+ }
1352
+
1353
+ # -----------------------------------------------------------------------
1354
+ # OpenEnv interface
1355
+ # -----------------------------------------------------------------------
1356
+
1357
+ def reset(self, seed=None, episode_id=None, **kwargs) -> TrustObservation:
1358
+ # ✅ FIX 1 — reset() is now correctly INSIDE the class
1359
+ if seed is not None:
1360
+ self._rng.seed(seed)
1361
+
1362
+ # Pick task by episode_id if provided, else random from all 6
1363
+ if episode_id and episode_id in self._tasks:
1364
+ task = self._tasks[episode_id]
1365
+ else:
1366
+ task = self._rng.choice(list(self._tasks.values()))
1367
+
1368
+ self._current_task = task
1369
+ self._tools_used = set()
1370
+ self._step_count = 0
1371
+ self._extracted_signals = None
1372
+ self._validation_result = None
1373
+ self._signals_extracted = False
1374
+
1375
+ self._state = TrustState(
1376
+ episode_id=task["task_id"],
1377
+ step_count=0,
1378
+ current_task_id=task["task_id"],
1379
+ difficulty=task.get("difficulty", "medium"),
1380
+ risk_level=task.get("risk_level", "medium"),
1381
+ is_done=False,
1382
+ tools_used=[],
1383
+ signals_extracted=False,
1384
+ )
1385
+
1386
+ self._obs = TrustObservation(
1387
+ ticket_id=task["task_id"],
1388
+ post_text=task["post_text"],
1389
+ image_description=task.get("image_description", ""),
1390
+ step_number=0,
1391
+ done=False,
1392
+ )
1393
+ return self._obs # ✅ FIX 2 — single clean return, stray return removed
1394
+
1395
+ def step(self, action: TrustAction, timeouts: Optional[Any] = None,
1396
+ **kwargs) -> TrustObservation:
1397
+ if self._current_task is None or self._obs is None:
1398
+ raise RuntimeError("Call reset() before step().")
1399
+
1400
+ if self._step_count >= MAX_STEPS:
1401
+ self._obs = TrustObservation(
1402
+ ticket_id=self._current_task["task_id"],
1403
+ post_text=self._obs.post_text,
1404
+ image_description=self._obs.image_description,
1405
+ step_number=self._step_count,
1406
+ done=True,
1407
+ reward=0.0,
1408
+ info={"reason": "timeout", "tools_used": list(self._tools_used)},
1409
+ )
1410
+ return self._obs
1411
+
1412
+ atype = action.action_type
1413
+ if atype == "use_tool":
1414
+ return self._handle_tool(action)
1415
+ if atype == "extract_signals":
1416
+ return self._handle_signal_extraction(action)
1417
+ if atype == "final_decision":
1418
+ return self._handle_final_decision(action)
1419
+ raise ValueError(f"Unknown action_type: {atype!r}")
1420
+
1421
+ @property
1422
+ def state(self) -> TrustState:
1423
+ return self._state
1424
+
1425
+ # -----------------------------------------------------------------------
1426
+ # Layer 1 — Tool handling
1427
+ # -----------------------------------------------------------------------
1428
+
1429
+ def _handle_tool(self, action: TrustAction) -> TrustObservation:
1430
+ tool = action.tool_name
1431
+ if tool not in TOOL_COSTS:
1432
+ raise ValueError(f"Unknown tool: {tool!r}")
1433
+ self._tools_used.add(tool)
1434
+ response = self._current_task["tool_responses"].get(tool, "No data found.")
1435
+ field_map = {
1436
+ "read_comments": "comments_found",
1437
+ "check_user_history": "user_history_found",
1438
+ "check_entity_status": "entity_status_found",
1439
+ "view_policy": "policy_found",
1440
+ }
1441
+ self._step_count += 1
1442
+ self._state.step_count = self._step_count
1443
+ self._state.tools_used = list(self._tools_used)
1444
+
1445
+ obs_kwargs = {
1446
+ k: getattr(self._obs, k)
1447
+ for k in ("ticket_id", "post_text", "image_description",
1448
+ "comments_found", "user_history_found",
1449
+ "entity_status_found", "policy_found",
1450
+ "extracted_signals", "validation_result")
1451
+ }
1452
+ obs_kwargs[field_map[tool]] = response
1453
+ obs_kwargs["step_number"] = self._step_count
1454
+ obs_kwargs["done"] = False
1455
+ obs_kwargs["reward"] = None
1456
+
1457
+ self._obs = TrustObservation(**obs_kwargs)
1458
+ return self._obs
1459
+
1460
+ # -----------------------------------------------------------------------
1461
+ # Layer 2 — Signal extraction + validation
1462
+ # -----------------------------------------------------------------------
1463
+
1464
+ def _handle_signal_extraction(self, action: TrustAction) -> TrustObservation:
1465
+ raw = action.signals
1466
+ raw.toxicity_level = max(0.0, min(1.0, float(raw.toxicity_level)))
1467
+ raw.confidence = max(0.0, min(1.0, float(raw.confidence)))
1468
+ if not isinstance(raw.content_flags, list):
1469
+ raw.content_flags = []
1470
+
1471
+ self._extracted_signals = raw
1472
+ self._signals_extracted = True
1473
+ self._validation_result = self._validate_signals(raw)
1474
+ self._step_count += 1
1475
+ self._state.step_count = self._step_count
1476
+ self._state.signals_extracted = True
1477
+
1478
+ obs_kwargs = {
1479
+ k: getattr(self._obs, k)
1480
+ for k in ("ticket_id", "post_text", "image_description",
1481
+ "comments_found", "user_history_found",
1482
+ "entity_status_found", "policy_found")
1483
+ }
1484
+ obs_kwargs["extracted_signals"] = {
1485
+ "target": raw.target,
1486
+ "is_protected_class": raw.is_protected_class,
1487
+ "toxicity_level": raw.toxicity_level,
1488
+ "is_direct_attack": raw.is_direct_attack,
1489
+ "context_type": raw.context_type,
1490
+ "intent": raw.intent,
1491
+ "confidence": raw.confidence,
1492
+ "abusive_language_present": raw.abusive_language_present,
1493
+ "content_flags": raw.content_flags,
1494
+ }
1495
+ obs_kwargs["validation_result"] = self._validation_result
1496
+ obs_kwargs["step_number"] = self._step_count
1497
+ obs_kwargs["done"] = False
1498
+ obs_kwargs["reward"] = None
1499
+
1500
+ self._obs = TrustObservation(**obs_kwargs)
1501
+ return self._obs
1502
+
1503
+ def _validate_signals(self, s: ContentSignals) -> Dict[str, Any]:
1504
+ issues = []
1505
+ conf = s.confidence
1506
+
1507
+ if not s.abusive_language_present and s.toxicity_level > 0.75:
1508
+ issues.append("high_toxicity_without_abusive_language"); conf -= 0.15
1509
+ if s.context_type in ("satire", "education") and s.intent == "harmful":
1510
+ issues.append("harmful_intent_contradicts_satire_context"); conf -= 0.12
1511
+ if s.is_protected_class and s.target == "none":
1512
+ issues.append("protected_class_flagged_without_target"); conf -= 0.10
1513
+ if s.is_direct_attack and s.toxicity_level < 0.25:
1514
+ issues.append("direct_attack_with_low_toxicity"); conf -= 0.12
1515
+ if s.context_type == "spam" and "spam" not in s.content_flags \
1516
+ and "phishing" not in s.content_flags:
1517
+ issues.append("spam_context_missing_spam_flag"); conf -= 0.08
1518
+
1519
+ adj = round(max(0.0, min(1.0, conf)), 3)
1520
+ return {
1521
+ "original_confidence": s.confidence,
1522
+ "adjusted_confidence": adj,
1523
+ "consistency_issues": issues,
1524
+ "issue_count": len(issues),
1525
+ "is_consistent": len(issues) == 0,
1526
+ "needs_escalation_hint": adj < 0.55 or len(issues) >= 2,
1527
+ }
1528
+
1529
+ # -----------------------------------------------------------------------
1530
+ # Layer 3 — Policy engine + final decision
1531
+ # -----------------------------------------------------------------------
1532
+
1533
+ def _policy_engine(self) -> Dict[str, Any]:
1534
+ if not self._extracted_signals or not self._validation_result:
1535
+ return {"recommended": "ESCALATE", "reason": "no_signals_extracted",
1536
+ "rule_strength": "weak"}
1537
+
1538
+ s = self._extracted_signals
1539
+ adj = self._validation_result["adjusted_confidence"]
1540
+
1541
+ if adj < 0.50:
1542
+ return {"recommended": "ESCALATE", "reason": "low_adjusted_confidence",
1543
+ "rule_strength": "medium"}
1544
+ if "spam" in s.content_flags or "phishing" in s.content_flags:
1545
+ return {"recommended": "REMOVE", "reason": "spam_or_phishing_detected",
1546
+ "rule_strength": "strong"}
1547
+ if s.is_protected_class and s.is_direct_attack \
1548
+ and s.toxicity_level > 0.70 and adj > 0.70:
1549
+ return {"recommended": "REMOVE",
1550
+ "reason": "protected_class_direct_attack_high_toxicity",
1551
+ "rule_strength": "strong"}
1552
+ if s.toxicity_level > 0.85 and s.intent == "harmful":
1553
+ return {"recommended": "REMOVE", "reason": "high_toxicity_harmful_intent",
1554
+ "rule_strength": "strong"}
1555
+ if s.context_type in ("satire", "news", "education", "protest") \
1556
+ and s.intent in ("satirical", "descriptive", "educational"):
1557
+ return {"recommended": "ALLOW_WITH_WARNING",
1558
+ "reason": "protected_expression_context",
1559
+ "rule_strength": "medium"}
1560
+ if 0.40 < s.toxicity_level < 0.75 and s.intent == "ambiguous":
1561
+ return {"recommended": "ESCALATE", "reason": "ambiguous_moderate_toxicity",
1562
+ "rule_strength": "medium"}
1563
+ return {"recommended": "ALLOW", "reason": "no_policy_violation_detected",
1564
+ "rule_strength": "medium"}
1565
+
1566
+ def _handle_final_decision(self, action: TrustAction) -> TrustObservation:
1567
+ decision = action.final_decision
1568
+ components = self._compute_components(decision)
1569
+ policy_rec = components.pop("_policy_rec")
1570
+ reward = self._finalize_reward(components)
1571
+
1572
+ self._step_count += 1
1573
+ self._state.step_count = self._step_count
1574
+ self._state.is_done = True
1575
+ components["final_reward"] = reward
1576
+
1577
+ obs_kwargs = {
1578
+ k: getattr(self._obs, k)
1579
+ for k in ("ticket_id", "post_text", "image_description",
1580
+ "comments_found", "user_history_found",
1581
+ "entity_status_found", "policy_found",
1582
+ "extracted_signals", "validation_result")
1583
+ }
1584
+ obs_kwargs["step_number"] = self._step_count
1585
+ obs_kwargs["done"] = True
1586
+ obs_kwargs["reward"] = reward
1587
+ obs_kwargs["info"] = {
1588
+ "final_decision": decision,
1589
+ "ground_truth": self._current_task["ground_truth"],
1590
+ "policy_recommendation": policy_rec,
1591
+ "signals_extracted": self._signals_extracted,
1592
+ "tools_used": list(self._tools_used),
1593
+ "required_tools": self._current_task["required_tools"],
1594
+ "ambiguity_level": self._current_task["ambiguity_level"],
1595
+ "risk_level": self._current_task["risk_level"],
1596
+ "task_id": self._current_task["task_id"],
1597
+ "reward_breakdown": components,
1598
+ }
1599
+
1600
+ self._obs = TrustObservation(**obs_kwargs)
1601
+ return self._obs
1602
+
1603
+ # -----------------------------------------------------------------------
1604
+ # 8-Component Reward Engine
1605
+ # -----------------------------------------------------------------------
1606
+
1607
+ def _compute_components(self, final_decision: str) -> Dict[str, Any]:
1608
+ gt = self._current_task["ground_truth"]
1609
+ required_tools = self._current_task["required_tools"]
1610
+ ambiguity = self._current_task["ambiguity_level"]
1611
+ risk_level = self._current_task["risk_level"]
1612
+ policy_rec = self._policy_engine()
1613
+
1614
+ base_score = DECISION_MATRIX.get((final_decision, gt), 0.20)
1615
+ if final_decision == "ESCALATE" and ambiguity == "high":
1616
+ base_score = max(base_score, 0.70)
1617
+ is_correct = base_score >= 0.90
1618
+
1619
+ rule_weight = {"strong": 1.0, "medium": 0.70, "weak": 0.40}.get(
1620
+ policy_rec.get("rule_strength", "medium"), 0.70)
1621
+ policy_alignment = round(
1622
+ (+0.12 if final_decision == policy_rec["recommended"] else -0.18) * rule_weight, 4)
1623
+
1624
+ signal_accuracy_bonus = self._compute_signal_accuracy()
1625
+
1626
+ adj_conf = (self._validation_result["adjusted_confidence"]
1627
+ if self._validation_result else 0.50)
1628
+ should_escalate = adj_conf < 0.50
1629
+ if should_escalate and final_decision == "ESCALATE":
1630
+ escalation_adj = +0.15
1631
+ elif should_escalate and final_decision != "ESCALATE":
1632
+ escalation_adj = -0.18
1633
+ elif not should_escalate and final_decision == "ESCALATE" and ambiguity == "low":
1634
+ escalation_adj = -0.20
1635
+ elif not should_escalate and final_decision == "ESCALATE":
1636
+ escalation_adj = -0.10
1637
+ else:
1638
+ escalation_adj = 0.0
1639
+
1640
+ signal_bonus = +0.05 if self._signals_extracted else -0.10
1641
+ tool_cost = round(sum(TOOL_COSTS.get(t, 0.0) for t in self._tools_used), 4)
1642
+ missing_required = set(required_tools) - self._tools_used
1643
+         tool_miss_penalty = round(len(missing_required) * 0.25, 4)
+
+         if self._validation_result:
+             n = self._validation_result["issue_count"]
+             validation_penalty = {0: 0.00, 1: 0.05, 2: 0.12}.get(n, 0.20)
+         else:
+             validation_penalty = 0.12
+
+         risk_penalty = 0.0
+         if not is_correct:
+             risk_penalty = {"high": 0.20, "medium": 0.10, "low": 0.0}.get(risk_level, 0.0)
+
+         if base_score < 0.50 and adj_conf > 0.80:
+             confidence_penalty = 0.22
+         elif base_score < 0.50 and adj_conf > 0.65:
+             confidence_penalty = 0.12
+         elif self._signals_extracted and final_decision == "ESCALATE" and adj_conf < 0.55:
+             confidence_penalty = -0.10
+         else:
+             confidence_penalty = 0.0
+
+         return {
+             "base_score": base_score,
+             "policy_alignment": policy_alignment,
+             "signal_accuracy_bonus": signal_accuracy_bonus,
+             "escalation_adj": escalation_adj,
+             "signal_bonus": signal_bonus,
+             "tool_cost": tool_cost,
+             "tool_miss_penalty": tool_miss_penalty,
+             "validation_penalty": validation_penalty,
+             "risk_penalty": risk_penalty,
+             "confidence_penalty": confidence_penalty,
+             "_policy_rec": policy_rec,
+         }
+
+     def _finalize_reward(self, components: Dict[str, Any]) -> float:
+         raw = (
+             components["base_score"]
+             + components["policy_alignment"]
+             + components["signal_accuracy_bonus"]
+             + components["escalation_adj"]
+             + components["signal_bonus"]
+             - components["tool_cost"]
+             - components["tool_miss_penalty"]
+             - components["validation_penalty"]
+             - components["risk_penalty"]
+             - components["confidence_penalty"]
+         )
+         return round(max(0.0, min(1.0, raw)), 4)
+
+     def _compute_signal_accuracy(self) -> float:
+         if not self._extracted_signals:
+             return 0.0
+         gt = self._current_task.get("ground_truth_signals", {})
+         if not gt:
+             return 0.05
+
+         s = self._extracted_signals
+         score = 0.0
+         if s.target == gt.get("target"): score += 0.20
+         if s.intent == gt.get("intent"): score += 0.20
+         if s.context_type == gt.get("context_type"): score += 0.20
+
+         tox_diff = abs(s.toxicity_level - gt.get("toxicity_level", 0.5))
+         score += 0.20 if tox_diff <= 0.20 else (0.10 if tox_diff <= 0.35 else 0.0)
+
+         gt_flags = set(gt.get("content_flags", []))
+         s_flags = set(s.content_flags)
+         if gt_flags:
+             score += 0.20 * min(1.0, len(gt_flags & s_flags) / len(gt_flags))
+         else:
+             score += 0.20 if not s_flags else 0.10
+
+         return round(score * 0.15, 4)
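The reward composition in `_finalize_reward` (bonuses add, penalties subtract, result clamped to [0, 1]) can be sanity-checked in isolation. A minimal standalone sketch, with a hypothetical component breakdown for illustration (the keys mirror the dict built above; `_policy_rec` is metadata and excluded from the sum):

```python
def finalize_reward(components):
    # Mirrors _finalize_reward: five bonus components add, five penalty
    # components subtract, and the raw sum is clamped to [0, 1].
    bonuses = ("base_score", "policy_alignment", "signal_accuracy_bonus",
               "escalation_adj", "signal_bonus")
    penalties = ("tool_cost", "tool_miss_penalty", "validation_penalty",
                 "risk_penalty", "confidence_penalty")
    raw = sum(components[k] for k in bonuses) - sum(components[k] for k in penalties)
    return round(max(0.0, min(1.0, raw)), 4)

# Hypothetical breakdown: correct REMOVE, two tools used, one consistency issue.
example = {
    "base_score": 1.00, "policy_alignment": 0.05, "signal_accuracy_bonus": 0.12,
    "escalation_adj": 0.0, "signal_bonus": 0.05,
    "tool_cost": 0.15, "tool_miss_penalty": 0.0, "validation_penalty": 0.05,
    "risk_penalty": 0.0, "confidence_penalty": 0.0,
}
print(finalize_reward(example))  # → 1.0 (raw 1.02 is clamped to the [0, 1] range)
```

Note the clamp: strong episodes can exceed 1.0 before clamping, so the components are shaping signals rather than an exact budget.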
tasks.py ADDED
@@ -0,0 +1,296 @@
+ TASKS = [
+     {
+         "task_id": "T-001",
+         "difficulty": "easy",
+         "ambiguity_level": "low",
+         "risk_level": "low",
+         "post_text": (
+             "🚨 URGENT: Your account has been COMPROMISED! "
+             "Click here immediately: http://meta-security-verify-now.xyz/login?token=abc123 "
+             "Act within 24 hours or your account will be PERMANENTLY BANNED. — Meta Security Team"
+         ),
+         "image_description": (
+             "Poorly formatted alert page with mismatched Meta logos, "
+             "red urgent warning banners, and a suspicious external URL."
+         ),
+         "tool_responses": {
+             "read_comments": "Comments disabled. Zero engagement.",
+             "check_user_history": "Account created 2 days ago. Zero posts. Flagged by spam filter 3×.",
+             "check_entity_status": "Domain 'meta-security-verify-now.xyz' registered 1 day ago — confirmed phishing domain.",
+             "view_policy": "Policy 3.2 — Phishing: Impersonating Meta to harvest credentials = immediate REMOVE.",
+         },
+         "required_tools": [],
+         "ground_truth": "REMOVE",
+         "ground_truth_signals": {
+             "target": "individual",
+             "is_protected_class": False,
+             "toxicity_level": 0.90,
+             "is_direct_attack": True,
+             "context_type": "spam",
+             "intent": "harmful",
+             "abusive_language_present": False,
+             "content_flags": ["phishing", "spam", "impersonation"],
+         },
+     },
+     {
+         "task_id": "T-002",
+         "difficulty": "medium",
+         "ambiguity_level": "medium",
+         "risk_level": "low",
+         "post_text": (
+             "LMAO I will literally SNIPE you the second the game starts, "
+             "no mercy, you're getting absolutely destroyed 💀🎮 "
+             "1v1 me and I'll end your whole career bro. Your team is DONE."
+         ),
+         "image_description": (
+             "Call of Duty: Warzone — squad wipe in progress. "
+             "Kill notifications, sniper rifle, enemy team eliminated. Classic gaming screenshot."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "47 comments: 'GG bro!', 'You always trash talk before CoD 😂', "
+                 "'rematch tonight??', 'your aim is mid bro'. "
+                 "All mutual gaming friends. Clear competitive banter context."
+             ),
+             "check_user_history": (
+                 "Account 3 years active. 500+ posts, 90% gaming. "
+                 "Zero harassment reports. Regular CoD tournament player."
+             ),
+             "check_entity_status": "Private individual. Standard gaming community member.",
+             "view_policy": (
+                 "Policy 7.1 — Harassment: Hyperbolic language in established gaming communities "
+                 "is NOT harassment. 'Snipe', 'destroy', 'end your career' are non-literal expressions."
+             ),
+         },
+         "required_tools": ["read_comments"],
+         "ground_truth": "ALLOW",
+         "ground_truth_signals": {
+             "target": "individual",
+             "is_protected_class": False,
+             "toxicity_level": 0.30,
+             "is_direct_attack": False,
+             "context_type": "meme",
+             "intent": "ambiguous",
+             "abusive_language_present": False,
+             "content_flags": ["competitive_language"],
+         },
+     },
+     {
+         "task_id": "T-003",
+         "difficulty": "hard",
+         "ambiguity_level": "high",
+         "risk_level": "high",
+         "post_text": (
+             "Sharing this political cartoon I drew. "
+             "Mayor Davidson cutting the homeless shelter budget while sitting on a throne of gold coins. "
+             "Caption: 'Priorities.' #LocalPolitics #HoldThemAccountable #SatiricalArt #MayorDavidson"
+         ),
+         "image_description": (
+             "Hand-drawn political cartoon. Exaggerated caricature of a public official on a coin throne, "
+             "homeless families pushed away. Deliberately satirical editorial style. Artist signature visible."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "23 comments: 'Powerful commentary', 'Free speech!', 'Great art', 'Defamation!'. "
+                 "No threats or coordinated harassment detected."
+             ),
+             "check_user_history": (
+                 "Account 5 years. Verified local journalist and artist. "
+                 "Regular political commentary. No violations. Posts cited in local news."
+             ),
+             "check_entity_status": (
+                 "Mayor Davidson: CONFIRMED PUBLIC FIGURE. "
+                 "Elected Mayor of Springfield. Active public interest in recent social services budget cuts. "
+                 "Multiple credible news articles confirm the cuts are real and newsworthy."
+             ),
+             "view_policy": (
+                 "Policy 12.4 — Political Satire: Editorial cartoons targeting PUBLIC OFFICIALS "
+                 "on PUBLIC INTEREST matters are PROTECTED. Caricature is a recognized technique. "
+                 "Not realistic impersonation. Correct action: ALLOW_WITH_WARNING."
+             ),
+         },
+         "required_tools": ["check_entity_status", "view_policy"],
+         "ground_truth": "ALLOW_WITH_WARNING",
+         "ground_truth_signals": {
+             "target": "political",
+             "is_protected_class": False,
+             "toxicity_level": 0.20,
+             "is_direct_attack": False,
+             "context_type": "satire",
+             "intent": "satirical",
+             "abusive_language_present": False,
+             "content_flags": ["political_content", "public_figure"],
+         },
+     },
+     {
+         "task_id": "T-004",
+         "difficulty": "medium",
+         "ambiguity_level": "medium",
+         "risk_level": "high",
+         "post_text": (
+             "Sharing this 'educational' video about why [ethnic group] people "
+             "are genetically predisposed to criminal behaviour. "
+             "The science doesn't lie. Share before it gets deleted. 👀 "
+             "#HiddenTruths #CensoredFacts #JustAsking"
+         ),
+         "image_description": (
+             "Screenshot of a YouTube-style video thumbnail with a white lab coat figure "
+             "and a graph. Title: 'The Uncomfortable Biology of Crime Rates'. "
+             "Uploader has 0 verified credentials. Comments disabled."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "Comments disabled by uploader. "
+                 "Zero engagement from credible academic or scientific accounts."
+             ),
+             "check_user_history": (
+                 "Account 6 months old. 48 posts — all promoting pseudoscientific race content. "
+                 "Previously received 2 policy warnings for hate content. "
+                 "No verified credentials. Frequently uses 'just asking questions' framing."
+             ),
+             "check_entity_status": (
+                 "Ethnic group mentioned: confirmed PROTECTED CLASS under platform policy. "
+                 "Claim about genetic criminality: DEBUNKED by mainstream biology and criminology. "
+                 "This framing is a known pseudoscience hate vector (scientific racism)."
+             ),
+             "view_policy": (
+                 "Policy 4.3 — Hate Speech: Content attributing criminal or negative traits "
+                 "to a protected class based on race or ethnicity = REMOVE. "
+                 "Disguising hate speech as 'educational content' or 'science' does not "
+                 "change the violation. Action: REMOVE."
+             ),
+         },
+         "required_tools": ["check_entity_status", "view_policy"],
+         "ground_truth": "REMOVE",
+         "ground_truth_signals": {
+             "target": "group",
+             "is_protected_class": True,
+             "toxicity_level": 0.85,
+             "is_direct_attack": True,
+             "context_type": "statement",
+             "intent": "harmful",
+             "abusive_language_present": False,
+             "content_flags": ["slur", "misinformation", "hate_speech_disguised_as_education"],
+         },
+     },
+     {
+         "task_id": "T-005",
+         "difficulty": "hard",
+         "ambiguity_level": "high",
+         "risk_level": "high",
+         "post_text": (
+             "The Prime Minister announced new economic reforms today. "
+             "Opposition leaders called it 'economic warfare on the poor.' "
+             "Protests erupted in 3 cities. Security forces used tear gas. "
+             "Government says the crackdown was 'necessary to restore order.' "
+             "At least 12 protesters arrested. #Politics #Protest #Reform"
+         ),
+         "image_description": (
+             "News-style post with a photo of a large protest crowd holding signs. "
+             "Tear gas clouds visible in background. "
+             "Caption: 'Thousands take to streets against new austerity measures.' "
+             "Shared by a verified news outlet handle. No graphic violence."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "340 comments split sharply: "
+                 "'State violence against citizens!', 'Finally someone fixing the economy', "
+                 "'This is propaganda!', 'Biased reporting!', 'Necessary action'. "
+                 "High engagement, deep political polarization. No incitement to violence detected."
+             ),
+             "check_user_history": (
+                 "Posted by @NationalNewsNetwork — verified news organization. "
+                 "12 years active. Consistently neutral coverage. "
+                 "No policy violations. Registered press entity."
+             ),
+             "check_entity_status": (
+                 "Prime Minister: CONFIRMED PUBLIC FIGURE, head of government. "
+                 "Economic reforms: CONFIRMED real policy, actively debated in parliament. "
+                 "Protest: CONFIRMED real event, covered by 6+ independent news agencies. "
+                 "Political sensitivity: VERY HIGH — active governance dispute in a democracy."
+             ),
+             "view_policy": (
+                 "Policy 12.1 — News & Political Content: Factual reporting on real government "
+                 "actions, protests, and police responses by VERIFIED NEWS entities is PROTECTED. "
+                 "Even politically contentious factual reporting is protected expression. "
+                 "Suppressing legitimate news = censorship risk. "
+                 "Action: ALLOW_WITH_WARNING (politically sensitive content label)."
+             ),
+         },
+         "required_tools": ["check_user_history", "check_entity_status", "view_policy"],
+         "ground_truth": "ALLOW_WITH_WARNING",
+         "ground_truth_signals": {
+             "target": "political",
+             "is_protected_class": False,
+             "toxicity_level": 0.20,
+             "is_direct_attack": False,
+             "context_type": "news",
+             "intent": "descriptive",
+             "abusive_language_present": False,
+             "content_flags": ["political_content", "public_figure", "protest"],
+         },
+     },
+     {
+         "task_id": "T-006",
+         "difficulty": "hard",
+         "ambiguity_level": "high",
+         "risk_level": "high",
+         "post_text": (
+             "Sending Eid Mubarak wishes to my community 🌙✨ "
+             "During this holy month let us remember those suffering in conflict zones. "
+             "Our brothers and sisters deserve peace and dignity. "
+             "May Allah guide us all. #EidMubarak #Ramadan #Peace #Solidarity"
+         ),
+         "image_description": (
+             "Warm celebratory post with a crescent moon and stars graphic. "
+             "Soft gold and green colors. "
+             "Comment section shows replies in Arabic, Urdu, and English. "
+             "Emojis: praying hands, crescent moon, dove."
+         ),
+         "tool_responses": {
+             "read_comments": (
+                 "89 comments: 'Eid Mubarak!', 'JazakAllah Khair', "
+                 "'prayers for peace', 'beautiful message', "
+                 "1 comment flagged by automated system: 'why is this allowed, "
+                 "these people support terrorism' — that comment is from a third party, "
+                 "not the original poster. "
+                 "No incitement, no threats, no extremist content in original post."
+             ),
+             "check_user_history": (
+                 "Account 4 years active. 200+ posts — community announcements, "
+                 "religious holiday greetings, charity drives. "
+                 "Zero prior violations. Community leader in local Muslim cultural centre."
+             ),
+             "check_entity_status": (
+                 "Islam / Muslim community: CONFIRMED PROTECTED CLASS (religion). "
+                 "Eid al-Fitr: globally recognized Islamic religious holiday. "
+                 "Content: religious greeting + humanitarian solidarity message. "
+                 "No extremist symbols, no prohibited organisations mentioned."
+             ),
+             "view_policy": (
+                 "Policy 8.1 — Religious Expression: Greetings and solidarity messages "
+                 "for recognized religious holidays are PROTECTED. "
+                 "Vague 'conflict zone' solidarity without naming a prohibited group "
+                 "is NOT a violation. "
+                 "Automated flag was triggered by a COMMENT, not the original post. "
+                 "Action: ALLOW — original content is clearly safe religious expression."
+             ),
+         },
+         "required_tools": ["read_comments", "check_entity_status", "view_policy"],
+         "ground_truth": "ALLOW",
+         "ground_truth_signals": {
+             "target": "religion",
+             "is_protected_class": True,
+             "toxicity_level": 0.05,
+             "is_direct_attack": False,
+             "context_type": "statement",
+             "intent": "educational",
+             "abusive_language_present": False,
+             "content_flags": ["political_content"],
+         },
+     },
+ ]
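The task records in `tasks.py` share a fixed schema, so a small consistency check can catch authoring mistakes before training (an unknown ground-truth label, a required tool with no canned response, an out-of-range toxicity value). A sketch under the schema shown above; `validate_task` is a hypothetical helper, not part of the repo:

```python
VALID_DECISIONS = {"REMOVE", "ALLOW", "ALLOW_WITH_WARNING", "ESCALATE"}

def validate_task(task):
    # Collect schema violations instead of raising, so all problems
    # in a task file can be reported in one pass.
    errors = []
    if task["ground_truth"] not in VALID_DECISIONS:
        errors.append(f"bad ground_truth: {task['ground_truth']!r}")
    for tool in task["required_tools"]:
        if tool not in task["tool_responses"]:
            errors.append(f"required tool {tool!r} has no tool_response")
    tox = task["ground_truth_signals"]["toxicity_level"]
    if not 0.0 <= tox <= 1.0:
        errors.append(f"toxicity_level out of range: {tox}")
    return errors

# Hypothetical minimal task record for illustration:
task = {
    "task_id": "T-999",
    "ground_truth": "REMOVE",
    "required_tools": ["view_policy"],
    "tool_responses": {"view_policy": "Policy 1.0 — example."},
    "ground_truth_signals": {"toxicity_level": 0.9},
}
print(validate_task(task))  # → []
```

Running `validate_task` over `TASKS` at import time would make a bad entry fail fast rather than silently skew rewards.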
train.py ADDED
@@ -0,0 +1,222 @@
+ from __future__ import annotations
+
+ import json
+ import random
+ import re
+
+ from datasets import Dataset
+ from trl import GRPOConfig, GRPOTrainer
+ from trl.trainer.grpo_trainer import generate_rollout_completions
+
+ from client import TrustSafetyEnv
+ from models import TrustAction, ContentSignals
+
+ # ── Config ────────────────────────────────────────────────────────────────────
+ MODEL_NAME = "Qwen/Qwen3-1.7B"
+ ENV_BASE_URL = "https://YOUR-USERNAME-trust-safety-env.hf.space"  # ← your HF Space URL
+ MAX_STEPS = 7
+ TASKS = ["T-001", "T-002", "T-003", "T-004", "T-005", "T-006"]
+
+ SYSTEM_PROMPT = """You are a Trust & Safety content moderation agent.
+
+ MANDATORY 3-STEP PROCESS — follow this EXACT order every time:
+
+ STEP 1 — USE TOOLS (always before anything else)
+ Call investigation tools to gather evidence before deciding:
+ - read_comments → check comment context
+ - check_user_history → check prior violations
+ - check_entity_status → verify public figure / org status
+ - view_policy → retrieve relevant platform policy
+ Respond with JSON: {"action_type": "use_tool", "tool_name": "<tool>"}
+
+ STEP 2 — EXTRACT SIGNALS
+ After gathering evidence, respond with JSON:
+ {"action_type": "extract_signals", "signals": {
+   "target": "<none|individual|group|political|religion>",
+   "is_protected_class": <true|false>,
+   "toxicity_level": <0.0-1.0>,
+   "is_direct_attack": <true|false>,
+   "context_type": "<spam|meme|satire|news|education|statement|protest>",
+   "intent": "<harmful|satirical|descriptive|educational|ambiguous>",
+   "confidence": <0.0-1.0>,
+   "abusive_language_present": <true|false>,
+   "content_flags": ["<spam|phishing|hate_speech|harassment|...>"]
+ }}
+
+ STEP 3 — FINAL DECISION
+ Respond with JSON: {"action_type": "final_decision", "final_decision": "<REMOVE|ALLOW|ALLOW_WITH_WARNING|ESCALATE>"}
+
+ ⚠️ Skipping tools costs -0.25 reward per missed required tool.
+ ⚠️ Always output valid JSON. Nothing else."""
+
+
+ # ── Dataset — one row per task, shuffled once at build time ──────────────────
+ def make_dataset() -> Dataset:
+     rows = [{"task_id": tid, "prompt": f"Process moderation ticket {tid}"} for tid in TASKS]
+     random.shuffle(rows)
+     return Dataset.from_list(rows * 50)  # 300 rows ≈ 50 passes over all 6 tasks
+
+
+ # ── Action parser — extract JSON action from model output ────────────────────
+ def parse_action(text: str) -> TrustAction | None:
+     try:
+         match = re.search(r'\{.*\}', text, re.DOTALL)
+         if not match:
+             return None
+         d = json.loads(match.group())
+         atype = d.get("action_type", "")
+         if atype == "use_tool":
+             return TrustAction(action_type="use_tool", tool_name=d.get("tool_name"))
+         if atype == "extract_signals":
+             raw = d.get("signals", {})
+             flags = raw.get("content_flags", [])
+             if not isinstance(flags, list):
+                 flags = []
+             signals = ContentSignals(
+                 target=raw.get("target", "none"),
+                 is_protected_class=bool(raw.get("is_protected_class", False)),
+                 toxicity_level=float(raw.get("toxicity_level", 0.5)),
+                 is_direct_attack=bool(raw.get("is_direct_attack", False)),
+                 context_type=raw.get("context_type", "statement"),
+                 intent=raw.get("intent", "ambiguous"),
+                 confidence=float(raw.get("confidence", 0.5)),
+                 abusive_language_present=bool(raw.get("abusive_language_present", False)),
+                 content_flags=flags,
+             )
+             return TrustAction(action_type="extract_signals", signals=signals)
+         if atype == "final_decision":
+             return TrustAction(action_type="final_decision",
+                                final_decision=d.get("final_decision", "ESCALATE"))
+     except Exception:
+         pass
+     return None
+
+
+ # ── Rollout — one full episode per prompt ────────────────────────────────────
+ def rollout_once(trainer, env, tokenizer, task_id: str):
+     result = env.reset(episode_id=task_id)
+     obs = result.observation
+     ep_reward = 0.0
+
+     all_prompt_ids = []
+     all_completion_ids = []
+     all_logprobs = []
+     reward_components = {"accuracy": 0.0, "tools": 0.0, "signals": 0.0}
+
+     for step in range(MAX_STEPS):
+         if result.done:
+             break
+
+         # Build context from current observation
+         context_parts = [f"TICKET: {obs.ticket_id}\n{obs.post_text}"]
+         if obs.comments_found: context_parts.append(f"Comments: {obs.comments_found}")
+         if obs.user_history_found: context_parts.append(f"User History: {obs.user_history_found}")
+         if obs.entity_status_found: context_parts.append(f"Entity: {obs.entity_status_found}")
+         if obs.policy_found: context_parts.append(f"Policy: {obs.policy_found}")
+         if obs.extracted_signals: context_parts.append(f"Signals: {obs.extracted_signals}")
+
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": "\n\n".join(context_parts)},
+         ]
+
+         rollout = generate_rollout_completions(trainer, [messages])
+         text = rollout["text"][0] if rollout["text"] else ""
+
+         action = parse_action(text)
+         if action is None:
+             action = TrustAction(action_type="final_decision", final_decision="ESCALATE")
+
+         result = env.step(action)
+         obs = result.observation
+
+         if result.reward is not None:
+             ep_reward = result.reward
+
+         all_prompt_ids.extend(rollout.get("prompt_ids", []))
+         all_completion_ids.extend(rollout.get("completion_ids", []))
+         all_logprobs.extend(rollout.get("logprobs", []))
+
+     # Reward components for GRPO shaping signal (read from the final observation)
+     info = obs.info or {}
+     breakdown = info.get("reward_breakdown", {})
+     reward_components["accuracy"] = breakdown.get("base_score", ep_reward)
+     reward_components["tools"] = -breakdown.get("tool_miss_penalty", 0.0)
+     reward_components["signals"] = breakdown.get("signal_accuracy_bonus", 0.0)
+
+     return {
+         "prompt_ids": all_prompt_ids,
+         "completion_ids": all_completion_ids,
+         "logprobs": all_logprobs,
+         "accuracy_reward": reward_components["accuracy"],
+         "tool_reward": reward_components["tools"],
+         "signal_reward": reward_components["signals"],
+     }
+
+
+ # ── Rollout wrapper — called by GRPOTrainer ───────────────────────────────────
+ def rollout_func(trainer, prompts, **kwargs):
+     env = TrustSafetyEnv(base_url=ENV_BASE_URL).sync()
+     results = []
+     for p in prompts:
+         task_id = p.get("task_id", random.choice(TASKS))
+         results.append(rollout_once(trainer, env, None, task_id))
+     env.close()
+     return results
+
+
+ # ── Reward functions (called by GRPOTrainer separately) ──────────────────────
+ def reward_accuracy(completions, **kwargs):
+     """Base decision correctness from environment."""
+     return [r.get("accuracy_reward", 0.0) for r in completions]
+
+ def reward_tools(completions, **kwargs):
+     """Penalise skipping required tools."""
+     return [r.get("tool_reward", 0.0) for r in completions]
+
+ def reward_signals(completions, **kwargs):
+     """Bonus for accurate signal extraction."""
+     return [r.get("signal_reward", 0.0) for r in completions]
+
+
+ # ── GRPOConfig ────────────────────────────────────────────────────────────────
+ grpo_config = GRPOConfig(
+     output_dir="./trust-safety-grpo",
+     num_train_epochs=1,
+     learning_rate=5e-6,
+     gradient_accumulation_steps=64,
+     per_device_train_batch_size=1,
+     num_generations=2,
+     max_completion_length=256,  # JSON actions are longer than Wordle guesses
+     max_prompt_length=1024,
+     use_vllm=True,
+     vllm_mode="colocate",
+     vllm_gpu_memory_utilization=0.15,
+     gradient_checkpointing=True,
+     report_to="trackio",
+     logging_steps=5,
+     save_steps=50,
+ )
+
+
+ # ── Main ──────────────────────────────────────────────────────────────────────
+ if __name__ == "__main__":
+     dataset = make_dataset()
+
+     trainer = GRPOTrainer(
+         model=MODEL_NAME,
+         reward_funcs=[reward_accuracy, reward_tools, reward_signals],
+         rollout_func=rollout_func,
+         train_dataset=dataset,
+         args=grpo_config,
+     )
+
+     print("=" * 60)
+     print(f"Training {MODEL_NAME} on Trust & Safety environment")
+     print(f"Tasks : {TASKS}")
+     print(f"Epochs: {grpo_config.num_train_epochs}")
+     print("=" * 60)
+
+     trainer.train()
+     trainer.save_model("./trust-safety-grpo-final")
+     print("✅ Model saved to ./trust-safety-grpo-final")
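The training loop depends on one brittle point: if the model's output contains no parseable JSON object, `parse_action` returns `None` and the rollout substitutes a forced ESCALATE. A minimal standalone sketch of just that extraction-and-fallback logic (the regex and fallback mirror `parse_action` above; the helper name `extract_action_type` is hypothetical):

```python
import json
import re

def extract_action_type(text):
    # Grab the first {...} span in the completion and read action_type.
    # Anything missing or unparseable yields None, which the rollout
    # converts into a final ESCALATE decision.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        return None
    try:
        return json.loads(match.group()).get("action_type")
    except json.JSONDecodeError:
        return None

print(extract_action_type('Sure! {"action_type": "use_tool", "tool_name": "view_policy"}'))
# → use_tool
print(extract_action_type("no json here"))  # → None
```

Note the greedy `.*` with `re.DOTALL` spans from the first `{` to the last `}`, so a completion containing two JSON objects will fail to parse and fall back to ESCALATE; that conservative failure mode is what the loop relies on.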
your_environment.py ADDED
@@ -0,0 +1,440 @@
1
+ from __future__ import annotations
2
+
3
+ import random
4
+ import uuid
5
+ from typing import Optional, Dict, Any, Set
6
+
7
+ try:
8
+ from openenv_core.env_server import Environment
9
+ print("[env] Inheriting from openenv_core.env_server.Environment ✅")
10
+ except ImportError:
11
+ try:
12
+ from openenv.core.env_server import Environment
13
+ print("[env] Inheriting from openenv.core.env_server.Environment ✅")
14
+ except ImportError:
15
+ Environment = object
16
+ print("[env] openenv_core not found — using plain object base ⚠️")
17
+
18
+ from models import TrustObservation, TrustAction, TrustState, ContentSignals
19
+ from tasks import TASKS
20
+
21
+
22
+ TOOL_COSTS: Dict[str, float] = {
23
+ "read_comments": 0.05,
24
+ "check_user_history": 0.05,
25
+ "check_entity_status": 0.10,
26
+ "view_policy": 0.10,
27
+ }
28
+
29
+ MAX_STEPS = 7
30
+
31
+ DECISION_MATRIX: Dict[tuple, float] = {
32
+ ("REMOVE", "REMOVE"): 1.00,
33
+ ("ALLOW", "ALLOW"): 1.00,
34
+ ("ALLOW_WITH_WARNING", "ALLOW_WITH_WARNING"): 1.00,
35
+ ("ESCALATE", "ESCALATE"): 1.00,
36
+ ("ALLOW_WITH_WARNING", "ALLOW"): 0.75,
37
+ ("ALLOW", "ALLOW_WITH_WARNING"): 0.55,
38
+ ("ESCALATE", "ALLOW_WITH_WARNING"): 0.65,
39
+ ("ESCALATE", "ALLOW"): 0.45,
40
+ ("ESCALATE", "REMOVE"): 0.45,
41
+ ("REMOVE", "ALLOW"): 0.10,
42
+ ("REMOVE", "ALLOW_WITH_WARNING"): 0.20,
43
+ ("ALLOW", "REMOVE"): 0.00,
44
+ ("ALLOW_WITH_WARNING", "REMOVE"): 0.15,
45
+ }
46
+
47
+
48
+ class TrustSafetyEnvironment(Environment):
49
+ """
50
+ 3-Layer Risk-Aware Trust & Safety RL Environment.
51
+
52
+ Layer 1 — Evidence gathering : agent uses investigation tools (optional)
53
+ Layer 2 — Signal extraction : agent outputs ContentSignals as feature extractor
54
+ Layer 3 — Policy engine : validates signals, applies rules, computes reward
55
+
56
+ 8-Component Reward: Accuracy · Policy Alignment · Signal Quality · Escalation
57
+ Tool Usage · Consistency · Risk Sensitivity · Confidence
58
+ """
59
+
60
+ def __init__(self, seed: int = 42) -> None:
61
+ super().__init__()
62
+ self._rng = random.Random(seed)
63
+ self._current_task: Optional[Dict[str, Any]] = None
64
+ self._tools_used: Set[str] = set()
65
+ self._step_count: int = 0
66
+ self._extracted_signals: Optional[ContentSignals] = None
67
+ self._validation_result: Optional[Dict[str, Any]] = None
68
+ self._signals_extracted: bool = False
69
+ self._obs: Optional[TrustObservation]= None
70
+ self._state = TrustState()
71
+
72
+ # ✅ FIX 3 — build a dict keyed by task_id for O(1) lookup
73
+ self._tasks: Dict[str, Dict[str, Any]] = {
74
+ t["task_id"]: t for t in TASKS
75
+ }
76
+
77
+ # -----------------------------------------------------------------------
78
+ # OpenEnv interface
79
+ # -----------------------------------------------------------------------
80
+
81
+ def reset(self, seed=None, episode_id=None, **kwargs) -> TrustObservation:
82
+ # ✅ FIX 1 — reset() is now correctly INSIDE the class
83
+ if seed is not None:
84
+ self._rng.seed(seed)
85
+
86
+ # Pick task by episode_id if provided, else random from all 6
87
+ if episode_id and episode_id in self._tasks:
88
+ task = self._tasks[episode_id]
89
+ else:
90
+ task = self._rng.choice(list(self._tasks.values()))
91
+
92
+ self._current_task = task
93
+ self._tools_used = set()
94
+ self._step_count = 0
95
+ self._extracted_signals = None
96
+ self._validation_result = None
97
+ self._signals_extracted = False
98
+
99
+ self._state = TrustState(
100
+ episode_id=task["task_id"],
101
+ step_count=0,
102
+ current_task_id=task["task_id"],
103
+ difficulty=task.get("difficulty", "medium"),
104
+ risk_level=task.get("risk_level", "medium"),
105
+ is_done=False,
106
+ tools_used=[],
107
+ signals_extracted=False,
108
+ )
109
+
110
+ self._obs = TrustObservation(
111
+ ticket_id=task["task_id"],
112
+ post_text=task["post_text"],
113
+ image_description=task.get("image_description", ""),
114
+ step_number=0,
115
+ done=False,
116
+ )
117
+ return self._obs # ✅ FIX 2 — single clean return, stray return removed
118
+
119
+ def step(self, action: TrustAction, timeouts: Optional[Any] = None,
120
+ **kwargs) -> TrustObservation:
121
+ if self._current_task is None or self._obs is None:
122
+ raise RuntimeError("Call reset() before step().")
123
+
124
+ if self._step_count >= MAX_STEPS:
125
+ self._obs = TrustObservation(
126
+ ticket_id=self._current_task["task_id"],
127
+ post_text=self._obs.post_text,
128
+ image_description=self._obs.image_description,
129
+ step_number=self._step_count,
130
+ done=True,
131
+ reward=0.0,
132
+ info={"reason": "timeout", "tools_used": list(self._tools_used)},
133
+ )
134
+ return self._obs
135
+
136
+ atype = action.action_type
137
+ if atype == "use_tool":
138
+ return self._handle_tool(action)
139
+ if atype == "extract_signals":
140
+ return self._handle_signal_extraction(action)
141
+ if atype == "final_decision":
142
+ return self._handle_final_decision(action)
143
+ raise ValueError(f"Unknown action_type: {atype!r}")
144
+
145
+ @property
146
+ def state(self) -> TrustState:
147
+ return self._state
148
+
149
+ # -----------------------------------------------------------------------
150
+ # Layer 1 — Tool handling
151
+ # -----------------------------------------------------------------------
152
+
153
+ def _handle_tool(self, action: TrustAction) -> TrustObservation:
154
+ tool = action.tool_name
155
+ if tool not in TOOL_COSTS:
156
+ raise ValueError(f"Unknown tool: {tool!r}")
157
+ self._tools_used.add(tool)
158
+ response = self._current_task["tool_responses"].get(tool, "No data found.")
159
+ field_map = {
160
+ "read_comments": "comments_found",
161
+ "check_user_history": "user_history_found",
162
+ "check_entity_status": "entity_status_found",
163
+ "view_policy": "policy_found",
164
+ }
165
+ self._step_count += 1
166
+ self._state.step_count = self._step_count
167
+ self._state.tools_used = list(self._tools_used)
168
+
169
+ obs_kwargs = {
170
+ k: getattr(self._obs, k)
171
+ for k in ("ticket_id", "post_text", "image_description",
172
+ "comments_found", "user_history_found",
173
+ "entity_status_found", "policy_found",
174
+ "extracted_signals", "validation_result")
175
+ }
176
+ obs_kwargs[field_map[tool]] = response
177
+ obs_kwargs["step_number"] = self._step_count
178
+ obs_kwargs["done"] = False
179
+ obs_kwargs["reward"] = None
180
+
181
+ self._obs = TrustObservation(**obs_kwargs)
182
+ return self._obs
183
+
+    # -----------------------------------------------------------------------
+    # Layer 2 — Signal extraction + validation
+    # -----------------------------------------------------------------------
+
+    def _handle_signal_extraction(self, action: TrustAction) -> TrustObservation:
+        raw = action.signals
+        raw.toxicity_level = max(0.0, min(1.0, float(raw.toxicity_level)))
+        raw.confidence = max(0.0, min(1.0, float(raw.confidence)))
+        if not isinstance(raw.content_flags, list):
+            raw.content_flags = []
+
+        self._extracted_signals = raw
+        self._signals_extracted = True
+        self._validation_result = self._validate_signals(raw)
+        self._step_count += 1
+        self._state.step_count = self._step_count
+        self._state.signals_extracted = True
+
+        obs_kwargs = {
+            k: getattr(self._obs, k)
+            for k in ("ticket_id", "post_text", "image_description",
+                      "comments_found", "user_history_found",
+                      "entity_status_found", "policy_found")
+        }
+        obs_kwargs["extracted_signals"] = {
+            "target": raw.target,
+            "is_protected_class": raw.is_protected_class,
+            "toxicity_level": raw.toxicity_level,
+            "is_direct_attack": raw.is_direct_attack,
+            "context_type": raw.context_type,
+            "intent": raw.intent,
+            "confidence": raw.confidence,
+            "abusive_language_present": raw.abusive_language_present,
+            "content_flags": raw.content_flags,
+        }
+        obs_kwargs["validation_result"] = self._validation_result
+        obs_kwargs["step_number"] = self._step_count
+        obs_kwargs["done"] = False
+        obs_kwargs["reward"] = None
+
+        self._obs = TrustObservation(**obs_kwargs)
+        return self._obs
226
+
+    def _validate_signals(self, s: ContentSignals) -> Dict[str, Any]:
+        issues = []
+        conf = s.confidence
+
+        if not s.abusive_language_present and s.toxicity_level > 0.75:
+            issues.append("high_toxicity_without_abusive_language"); conf -= 0.15
+        if s.context_type in ("satire", "education") and s.intent == "harmful":
+            issues.append("harmful_intent_contradicts_satire_context"); conf -= 0.12
+        if s.is_protected_class and s.target == "none":
+            issues.append("protected_class_flagged_without_target"); conf -= 0.10
+        if s.is_direct_attack and s.toxicity_level < 0.25:
+            issues.append("direct_attack_with_low_toxicity"); conf -= 0.12
+        if s.context_type == "spam" and "spam" not in s.content_flags \
+                and "phishing" not in s.content_flags:
+            issues.append("spam_context_missing_spam_flag"); conf -= 0.08
+
+        adj = round(max(0.0, min(1.0, conf)), 3)
+        return {
+            "original_confidence": s.confidence,
+            "adjusted_confidence": adj,
+            "consistency_issues": issues,
+            "issue_count": len(issues),
+            "is_consistent": len(issues) == 0,
+            "needs_escalation_hint": adj < 0.55 or len(issues) >= 2,
+        }
252
+
+    # -----------------------------------------------------------------------
+    # Layer 3 — Policy engine + final decision
+    # -----------------------------------------------------------------------
+
+    def _policy_engine(self) -> Dict[str, Any]:
+        if not self._extracted_signals or not self._validation_result:
+            return {"recommended": "ESCALATE", "reason": "no_signals_extracted",
+                    "rule_strength": "weak"}
+
+        s = self._extracted_signals
+        adj = self._validation_result["adjusted_confidence"]
+
+        if adj < 0.50:
+            return {"recommended": "ESCALATE", "reason": "low_adjusted_confidence",
+                    "rule_strength": "medium"}
+        if "spam" in s.content_flags or "phishing" in s.content_flags:
+            return {"recommended": "REMOVE", "reason": "spam_or_phishing_detected",
+                    "rule_strength": "strong"}
+        if s.is_protected_class and s.is_direct_attack \
+                and s.toxicity_level > 0.70 and adj > 0.70:
+            return {"recommended": "REMOVE",
+                    "reason": "protected_class_direct_attack_high_toxicity",
+                    "rule_strength": "strong"}
+        if s.toxicity_level > 0.85 and s.intent == "harmful":
+            return {"recommended": "REMOVE", "reason": "high_toxicity_harmful_intent",
+                    "rule_strength": "strong"}
+        if s.context_type in ("satire", "news", "education", "protest") \
+                and s.intent in ("satirical", "descriptive", "educational"):
+            return {"recommended": "ALLOW_WITH_WARNING",
+                    "reason": "protected_expression_context",
+                    "rule_strength": "medium"}
+        if 0.40 < s.toxicity_level < 0.75 and s.intent == "ambiguous":
+            return {"recommended": "ESCALATE", "reason": "ambiguous_moderate_toxicity",
+                    "rule_strength": "medium"}
+        return {"recommended": "ALLOW", "reason": "no_policy_violation_detected",
+                "rule_strength": "medium"}
+
290
+    def _handle_final_decision(self, action: TrustAction) -> TrustObservation:
+        decision = action.final_decision
+        components = self._compute_components(decision)
+        policy_rec = components.pop("_policy_rec")
+        reward = self._finalize_reward(components)
+
+        self._step_count += 1
+        self._state.step_count = self._step_count
+        self._state.is_done = True
+        components["final_reward"] = reward
+
+        obs_kwargs = {
+            k: getattr(self._obs, k)
+            for k in ("ticket_id", "post_text", "image_description",
+                      "comments_found", "user_history_found",
+                      "entity_status_found", "policy_found",
+                      "extracted_signals", "validation_result")
+        }
+        obs_kwargs["step_number"] = self._step_count
+        obs_kwargs["done"] = True
+        obs_kwargs["reward"] = reward
+        obs_kwargs["info"] = {
+            "final_decision": decision,
+            "ground_truth": self._current_task["ground_truth"],
+            "policy_recommendation": policy_rec,
+            "signals_extracted": self._signals_extracted,
+            "tools_used": list(self._tools_used),
+            "required_tools": self._current_task["required_tools"],
+            "ambiguity_level": self._current_task["ambiguity_level"],
+            "risk_level": self._current_task["risk_level"],
+            "task_id": self._current_task["task_id"],
+            "reward_breakdown": components,
+        }
+
+        self._obs = TrustObservation(**obs_kwargs)
+        return self._obs
326
+
+    # -----------------------------------------------------------------------
+    # 8-Component Reward Engine
+    # -----------------------------------------------------------------------
+
+    def _compute_components(self, final_decision: str) -> Dict[str, Any]:
+        gt = self._current_task["ground_truth"]
+        required_tools = self._current_task["required_tools"]
+        ambiguity = self._current_task["ambiguity_level"]
+        risk_level = self._current_task["risk_level"]
+        policy_rec = self._policy_engine()
+
+        base_score = DECISION_MATRIX.get((final_decision, gt), 0.20)
+        if final_decision == "ESCALATE" and ambiguity == "high":
+            base_score = max(base_score, 0.70)
+        is_correct = base_score >= 0.90
+
+        rule_weight = {"strong": 1.0, "medium": 0.70, "weak": 0.40}.get(
+            policy_rec.get("rule_strength", "medium"), 0.70)
+        policy_alignment = round(
+            (+0.12 if final_decision == policy_rec["recommended"] else -0.18) * rule_weight, 4)
+
+        signal_accuracy_bonus = self._compute_signal_accuracy()
+
+        adj_conf = (self._validation_result["adjusted_confidence"]
+                    if self._validation_result else 0.50)
+        should_escalate = adj_conf < 0.50
+        if should_escalate and final_decision == "ESCALATE":
+            escalation_adj = +0.15
+        elif should_escalate and final_decision != "ESCALATE":
+            escalation_adj = -0.18
+        elif not should_escalate and final_decision == "ESCALATE" and ambiguity == "low":
+            escalation_adj = -0.20
+        elif not should_escalate and final_decision == "ESCALATE":
+            escalation_adj = -0.10
+        else:
+            escalation_adj = 0.0
+
+        signal_bonus = +0.05 if self._signals_extracted else -0.10
+        tool_cost = round(sum(TOOL_COSTS.get(t, 0.0) for t in self._tools_used), 4)
+        missing_required = set(required_tools) - self._tools_used
+        tool_miss_penalty = round(len(missing_required) * 0.25, 4)
+
+        if self._validation_result:
+            n = self._validation_result["issue_count"]
+            validation_penalty = {0: 0.00, 1: 0.05, 2: 0.12}.get(n, 0.20)
+        else:
+            validation_penalty = 0.12
+
+        risk_penalty = 0.0
+        if not is_correct:
+            risk_penalty = {"high": 0.20, "medium": 0.10, "low": 0.0}.get(risk_level, 0.0)
+
+        if base_score < 0.50 and adj_conf > 0.80:
+            confidence_penalty = 0.22
+        elif base_score < 0.50 and adj_conf > 0.65:
+            confidence_penalty = 0.12
+        elif self._signals_extracted and final_decision == "ESCALATE" and adj_conf < 0.55:
+            confidence_penalty = -0.10
+        else:
+            confidence_penalty = 0.0
+
+        return {
+            "base_score": base_score,
+            "policy_alignment": policy_alignment,
+            "signal_accuracy_bonus": signal_accuracy_bonus,
+            "escalation_adj": escalation_adj,
+            "signal_bonus": signal_bonus,
+            "tool_cost": tool_cost,
+            "tool_miss_penalty": tool_miss_penalty,
+            "validation_penalty": validation_penalty,
+            "risk_penalty": risk_penalty,
+            "confidence_penalty": confidence_penalty,
+            "_policy_rec": policy_rec,
+        }
401
+
+    def _finalize_reward(self, components: Dict[str, Any]) -> float:
+        raw = (
+            components["base_score"]
+            + components["policy_alignment"]
+            + components["signal_accuracy_bonus"]
+            + components["escalation_adj"]
+            + components["signal_bonus"]
+            - components["tool_cost"]
+            - components["tool_miss_penalty"]
+            - components["validation_penalty"]
+            - components["risk_penalty"]
+            - components["confidence_penalty"]
+        )
+        return round(max(0.0, min(1.0, raw)), 4)
416
+
+    def _compute_signal_accuracy(self) -> float:
+        if not self._extracted_signals:
+            return 0.0
+        gt = self._current_task.get("ground_truth_signals", {})
+        if not gt:
+            return 0.05
+
+        s = self._extracted_signals
+        score = 0.0
+        if s.target == gt.get("target"): score += 0.20
+        if s.intent == gt.get("intent"): score += 0.20
+        if s.context_type == gt.get("context_type"): score += 0.20
+
+        tox_diff = abs(s.toxicity_level - gt.get("toxicity_level", 0.5))
+        score += 0.20 if tox_diff <= 0.20 else (0.10 if tox_diff <= 0.35 else 0.0)
+
+        gt_flags = set(gt.get("content_flags", []))
+        s_flags = set(s.content_flags)
+        if gt_flags:
+            score += 0.20 * min(1.0, len(gt_flags & s_flags) / len(gt_flags))
+        else:
+            score += 0.20 if not s_flags else 0.10
+
+        return round(score * 0.15, 4)
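
For reference, the reward engine above is purely additive: five bonus components minus five penalty components, clamped to [0, 1]. A minimal standalone sketch of that sum-then-clamp (component names copied from the breakdown dict; the example values are hypothetical, not taken from any task):

```python
# Sum-then-clamp reward shaping, mirroring _finalize_reward above.
BONUSES = ("base_score", "policy_alignment", "signal_accuracy_bonus",
           "escalation_adj", "signal_bonus")
PENALTIES = ("tool_cost", "tool_miss_penalty", "validation_penalty",
             "risk_penalty", "confidence_penalty")

def finalize_reward(components: dict) -> float:
    # Missing components default to 0.0, so partial breakdowns still score.
    raw = sum(components.get(k, 0.0) for k in BONUSES) \
        - sum(components.get(k, 0.0) for k in PENALTIES)
    return round(max(0.0, min(1.0, raw)), 4)

# Hypothetical breakdown: correct decision, aligned with a strong policy rule.
example = {"base_score": 0.95, "policy_alignment": 0.12,
           "signal_accuracy_bonus": 0.09, "escalation_adj": 0.0,
           "signal_bonus": 0.05, "tool_cost": 0.06,
           "tool_miss_penalty": 0.0, "validation_penalty": 0.05,
           "risk_penalty": 0.0, "confidence_penalty": 0.0}
print(finalize_reward(example))  # → 1.0 (raw 1.10 is clamped to the ceiling)
```

Because of the clamp, stacked bonuses cannot push the reward above 1.0, and heavy penalties on a wrong decision bottom out at 0.0 rather than going negative.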