"""Trust & Safety moderation environment with a three-layer decision flow."""
from __future__ import annotations

import random
import uuid
from typing import Optional, Dict, Any, Set

try:
    from openenv_core.env_server import Environment
    print("[env] Inheriting from openenv_core.env_server.Environment ✅")
except ImportError:
    try:
        from openenv.core.env_server import Environment
        print("[env] Inheriting from openenv.core.env_server.Environment ✅")
    except ImportError:
        Environment = object
        print("[env] openenv_core not found — using plain object base ⚠️")

from models import TrustObservation, TrustAction, TrustState, ContentSignals
from tasks import TASKS

# Reward cost charged once per distinct investigation tool used in an episode.
TOOL_COSTS: Dict[str, float] = {
    "read_comments": 0.05,
    "check_user_history": 0.05,
    "check_entity_status": 0.10,
    "view_policy": 0.10,
}

# Observation field that receives each tool's response.
TOOL_RESULT_FIELDS: Dict[str, str] = {
    "read_comments": "comments_found",
    "check_user_history": "user_history_found",
    "check_entity_status": "entity_status_found",
    "view_policy": "policy_found",
}

# Observation fields copied forward from one step to the next.
_CARRIED_OBS_FIELDS = (
    "ticket_id",
    "post_text",
    "image_description",
    "comments_found",
    "user_history_found",
    "entity_status_found",
    "policy_found",
    "extracted_signals",
    "validation_result",
)

MAX_STEPS = 7

# Base score keyed by (agent decision, ground-truth decision). Pairs that are
# not listed fall back to the default used in ``_compute_components``.
DECISION_MATRIX: Dict[tuple, float] = {
    ("REMOVE", "REMOVE"): 1.00,
    ("ALLOW", "ALLOW"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW_WITH_WARNING"): 1.00,
    ("ESCALATE", "ESCALATE"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW"): 0.75,
    ("ALLOW", "ALLOW_WITH_WARNING"): 0.55,
    ("ESCALATE", "ALLOW_WITH_WARNING"): 0.65,
    ("ESCALATE", "ALLOW"): 0.45,
    ("ESCALATE", "REMOVE"): 0.45,
    ("REMOVE", "ALLOW"): 0.10,
    ("REMOVE", "ALLOW_WITH_WARNING"): 0.20,
    ("ALLOW", "REMOVE"): 0.00,
    ("ALLOW_WITH_WARNING", "REMOVE"): 0.15,
}


class TrustSafetyEnvironment(Environment):
    """
    3-Layer Risk-Aware Trust & Safety RL Environment.

    Layer 1 — Evidence gathering : agent uses investigation tools (optional)
    Layer 2 — Signal extraction  : agent outputs ContentSignals as feature extractor
    Layer 3 — Policy engine      : validates signals, applies rules, computes reward

    8-Component Reward:
        Accuracy · Policy Alignment · Signal Quality · Escalation ·
        Tool Usage · Consistency · Risk Sensitivity · Confidence
    """

    def __init__(self, seed: int = 42) -> None:
        """Create the environment with a private RNG seeded by ``seed``."""
        super().__init__()
        self._rng = random.Random(seed)
        self._current_task: Optional[Dict[str, Any]] = None
        self._tools_used: Set[str] = set()
        self._step_count: int = 0
        self._extracted_signals: Optional[ContentSignals] = None
        self._validation_result: Optional[Dict[str, Any]] = None
        self._signals_extracted: bool = False
        self._obs: Optional[TrustObservation] = None
        self._state = TrustState()
        # Index tasks by task_id so reset() can look one up directly.
        self._tasks: Dict[str, Dict[str, Any]] = {
            t["task_id"]: t for t in TASKS
        }

    # -----------------------------------------------------------------------
    # OpenEnv interface
    # -----------------------------------------------------------------------

    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None,
              **kwargs) -> TrustObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: When not None, reseeds the environment RNG before the task
                is chosen.
            episode_id: When it matches a known task_id that task is used;
                otherwise a task is drawn at random.
            **kwargs: Accepted and ignored.
        """
        if seed is not None:
            self._rng.seed(seed)

        # Pick the task by episode_id if provided, else a random one.
        if episode_id and episode_id in self._tasks:
            task = self._tasks[episode_id]
        else:
            task = self._rng.choice(list(self._tasks.values()))

        self._current_task = task
        self._tools_used = set()
        self._step_count = 0
        self._extracted_signals = None
        self._validation_result = None
        self._signals_extracted = False

        self._state = TrustState(
            episode_id=task["task_id"],
            step_count=0,
            current_task_id=task["task_id"],
            difficulty=task.get("difficulty", "medium"),
            risk_level=task.get("risk_level", "medium"),
            is_done=False,
            tools_used=[],
            signals_extracted=False,
        )
        self._obs = TrustObservation(
            ticket_id=task["task_id"],
            post_text=task["post_text"],
            image_description=task.get("image_description", ""),
            step_number=0,
            done=False,
        )
        return self._obs

    def step(self, action: TrustAction, timeouts: Optional[Any] = None,
             **kwargs) -> TrustObservation:
        """Apply one agent action and return the resulting observation.

        Once an episode has ended (final decision or timeout) further calls
        return the stored terminal observation unchanged, so a decision cannot
        be re-submitted for a new reward without calling ``reset()``.

        Raises:
            RuntimeError: If ``reset()`` has not been called yet.
            ValueError: If ``action.action_type`` is not recognised.
        """
        if self._current_task is None or self._obs is None:
            raise RuntimeError("Call reset() before step().")

        # Episode already finished: do not mutate state or re-score.
        if self._obs.done:
            return self._obs

        if self._step_count >= MAX_STEPS:
            # Keep the state object in agreement with the done observation.
            self._state.is_done = True
            return self._advance_observation(
                done=True,
                reward=0.0,
                info={"reason": "timeout", "tools_used": list(self._tools_used)},
            )

        atype = action.action_type
        if atype == "use_tool":
            return self._handle_tool(action)
        if atype == "extract_signals":
            return self._handle_signal_extraction(action)
        if atype == "final_decision":
            return self._handle_final_decision(action)
        raise ValueError(f"Unknown action_type: {atype!r}")

    @property
    def state(self) -> TrustState:
        """Current episode state object."""
        return self._state

    def _advance_observation(self, *, done: bool, reward: Optional[float],
                             **overrides: Any) -> TrustObservation:
        """Build, store and return the next observation.

        Fields listed in ``_CARRIED_OBS_FIELDS`` are copied from the previous
        observation, then ``overrides`` are applied, then the step number,
        ``done`` and ``reward`` are set.
        """
        fields = {name: getattr(self._obs, name) for name in _CARRIED_OBS_FIELDS}
        fields.update(overrides)
        fields["step_number"] = self._step_count
        fields["done"] = done
        fields["reward"] = reward
        self._obs = TrustObservation(**fields)
        return self._obs

    # -----------------------------------------------------------------------
    # Layer 1 — Tool handling
    # -----------------------------------------------------------------------

    def _handle_tool(self, action: TrustAction) -> TrustObservation:
        """Record a tool call and expose its response on the observation.

        Raises:
            ValueError: If ``action.tool_name`` is not a key of ``TOOL_COSTS``.
        """
        tool = action.tool_name
        if tool not in TOOL_COSTS:
            raise ValueError(f"Unknown tool: {tool!r}")

        # A set: calling the same tool twice is only recorded (and charged) once.
        self._tools_used.add(tool)
        response = self._current_task["tool_responses"].get(tool, "No data found.")

        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.tools_used = list(self._tools_used)

        return self._advance_observation(
            done=False,
            reward=None,
            **{TOOL_RESULT_FIELDS[tool]: response},
        )

    # -----------------------------------------------------------------------
    # Layer 2 — Signal extraction + validation
    # -----------------------------------------------------------------------

    def _handle_signal_extraction(self, action: TrustAction) -> TrustObservation:
        """Store the agent's signals, validate them and report both back.

        ``toxicity_level`` and ``confidence`` are clamped to [0, 1] and a
        non-list ``content_flags`` is replaced by an empty list. These
        corrections are written onto ``action.signals`` itself.

        Raises:
            ValueError: If the action carries no signals.
        """
        raw = action.signals
        if raw is None:
            raise ValueError("extract_signals action requires signals")

        raw.toxicity_level = max(0.0, min(1.0, float(raw.toxicity_level)))
        raw.confidence = max(0.0, min(1.0, float(raw.confidence)))
        if not isinstance(raw.content_flags, list):
            raw.content_flags = []

        self._extracted_signals = raw
        self._signals_extracted = True
        self._validation_result = self._validate_signals(raw)

        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.signals_extracted = True

        signal_view = {
            "target": raw.target,
            "is_protected_class": raw.is_protected_class,
            "toxicity_level": raw.toxicity_level,
            "is_direct_attack": raw.is_direct_attack,
            "context_type": raw.context_type,
            "intent": raw.intent,
            "confidence": raw.confidence,
            "abusive_language_present": raw.abusive_language_present,
            "content_flags": raw.content_flags,
        }
        return self._advance_observation(
            done=False,
            reward=None,
            extracted_signals=signal_view,
            validation_result=self._validation_result,
        )

    def _validate_signals(self, s: ContentSignals) -> Dict[str, Any]:
        """Check the signals for internal contradictions.

        Each contradiction found appends an issue name and lowers the
        confidence; the adjusted confidence is clamped to [0, 1] and rounded
        to three decimals.
        """
        issues = []
        conf = s.confidence

        if not s.abusive_language_present and s.toxicity_level > 0.75:
            issues.append("high_toxicity_without_abusive_language")
            conf -= 0.15
        if s.context_type in ("satire", "education") and s.intent == "harmful":
            issues.append("harmful_intent_contradicts_satire_context")
            conf -= 0.12
        if s.is_protected_class and s.target == "none":
            issues.append("protected_class_flagged_without_target")
            conf -= 0.10
        if s.is_direct_attack and s.toxicity_level < 0.25:
            issues.append("direct_attack_with_low_toxicity")
            conf -= 0.12
        if s.context_type == "spam" and "spam" not in s.content_flags \
                and "phishing" not in s.content_flags:
            issues.append("spam_context_missing_spam_flag")
            conf -= 0.08

        adj = round(max(0.0, min(1.0, conf)), 3)
        return {
            "original_confidence": s.confidence,
            "adjusted_confidence": adj,
            "consistency_issues": issues,
            "issue_count": len(issues),
            "is_consistent": len(issues) == 0,
            "needs_escalation_hint": adj < 0.55 or len(issues) >= 2,
        }

    # -----------------------------------------------------------------------
    # Layer 3 — Policy engine + final decision
    # -----------------------------------------------------------------------

    def _policy_engine(self) -> Dict[str, Any]:
        """Return the rule-based recommendation for the stored signals.

        Rules are evaluated in order and the first match wins. The result has
        the keys ``recommended``, ``reason`` and ``rule_strength``.
        """
        if not self._extracted_signals or not self._validation_result:
            return {"recommended": "ESCALATE", "reason": "no_signals_extracted",
                    "rule_strength": "weak"}

        s = self._extracted_signals
        adj = self._validation_result["adjusted_confidence"]

        if adj < 0.50:
            return {"recommended": "ESCALATE", "reason": "low_adjusted_confidence",
                    "rule_strength": "medium"}
        if "spam" in s.content_flags or "phishing" in s.content_flags:
            return {"recommended": "REMOVE", "reason": "spam_or_phishing_detected",
                    "rule_strength": "strong"}
        if s.is_protected_class and s.is_direct_attack \
                and s.toxicity_level > 0.70 and adj > 0.70:
            return {"recommended": "REMOVE",
                    "reason": "protected_class_direct_attack_high_toxicity",
                    "rule_strength": "strong"}
        if s.toxicity_level > 0.85 and s.intent == "harmful":
            return {"recommended": "REMOVE", "reason": "high_toxicity_harmful_intent",
                    "rule_strength": "strong"}
        if s.context_type in ("satire", "news", "education", "protest") \
                and s.intent in ("satirical", "descriptive", "educational"):
            return {"recommended": "ALLOW_WITH_WARNING",
                    "reason": "protected_expression_context",
                    "rule_strength": "medium"}
        if 0.40 < s.toxicity_level < 0.75 and s.intent == "ambiguous":
            return {"recommended": "ESCALATE", "reason": "ambiguous_moderate_toxicity",
                    "rule_strength": "medium"}
        return {"recommended": "ALLOW", "reason": "no_policy_violation_detected",
                "rule_strength": "medium"}

    def _handle_final_decision(self, action: TrustAction) -> TrustObservation:
        """Score the final decision and end the episode.

        The returned observation has ``done=True``, the clamped reward, and an
        ``info`` dict holding the ground truth, the policy recommendation and
        the per-component reward breakdown.
        """
        decision = action.final_decision
        components = self._compute_components(decision)
        policy_rec = components.pop("_policy_rec")
        reward = self._finalize_reward(components)

        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.is_done = True

        components["final_reward"] = reward
        task = self._current_task
        info = {
            "final_decision": decision,
            "ground_truth": task["ground_truth"],
            "policy_recommendation": policy_rec,
            "signals_extracted": self._signals_extracted,
            "tools_used": list(self._tools_used),
            "required_tools": task["required_tools"],
            "ambiguity_level": task["ambiguity_level"],
            # Same default as reset(), which treats risk_level as optional.
            "risk_level": task.get("risk_level", "medium"),
            "task_id": task["task_id"],
            "reward_breakdown": components,
        }
        return self._advance_observation(done=True, reward=reward, info=info)

    # -----------------------------------------------------------------------
    # 8-Component Reward Engine
    # -----------------------------------------------------------------------

    def _compute_components(self, final_decision: str) -> Dict[str, Any]:
        """Compute every reward component for ``final_decision``.

        The returned dict also carries the policy recommendation under the
        private key ``_policy_rec``; the caller removes it before summing.
        """
        gt = self._current_task["ground_truth"]
        required_tools = self._current_task["required_tools"]
        ambiguity = self._current_task["ambiguity_level"]
        # Same default as reset(), which treats risk_level as optional.
        risk_level = self._current_task.get("risk_level", "medium")
        policy_rec = self._policy_engine()

        # Accuracy: unlisted (decision, truth) pairs score 0.20.
        base_score = DECISION_MATRIX.get((final_decision, gt), 0.20)
        if final_decision == "ESCALATE" and ambiguity == "high":
            base_score = max(base_score, 0.70)
        is_correct = base_score >= 0.90

        # Policy alignment, scaled by how strong the matching rule was.
        rule_weight = {"strong": 1.0, "medium": 0.70, "weak": 0.40}.get(
            policy_rec.get("rule_strength", "medium"), 0.70)
        policy_alignment = round(
            (+0.12 if final_decision == policy_rec["recommended"] else -0.18)
            * rule_weight, 4)

        signal_accuracy_bonus = self._compute_signal_accuracy()

        # Escalation: confidence defaults to 0.50 when no signals were validated.
        adj_conf = (self._validation_result["adjusted_confidence"]
                    if self._validation_result else 0.50)
        should_escalate = adj_conf < 0.50
        if should_escalate and final_decision == "ESCALATE":
            escalation_adj = +0.15
        elif should_escalate and final_decision != "ESCALATE":
            escalation_adj = -0.18
        elif not should_escalate and final_decision == "ESCALATE" and ambiguity == "low":
            escalation_adj = -0.20
        elif not should_escalate and final_decision == "ESCALATE":
            escalation_adj = -0.10
        else:
            escalation_adj = 0.0

        signal_bonus = +0.05 if self._signals_extracted else -0.10

        # Tool usage: pay for tools used, and pay more for required ones skipped.
        tool_cost = round(sum(TOOL_COSTS.get(t, 0.0) for t in self._tools_used), 4)
        missing_required = set(required_tools) - self._tools_used
        tool_miss_penalty = round(len(missing_required) * 0.25, 4)

        # Consistency: penalty grows with the number of validation issues.
        if self._validation_result:
            n = self._validation_result["issue_count"]
            validation_penalty = {0: 0.00, 1: 0.05, 2: 0.12}.get(n, 0.20)
        else:
            validation_penalty = 0.12

        # Risk sensitivity: only charged when the decision was not correct.
        risk_penalty = 0.0
        if not is_correct:
            risk_penalty = {"high": 0.20, "medium": 0.10, "low": 0.0}.get(risk_level, 0.0)

        # Confidence: a negative value acts as a bonus because
        # _finalize_reward subtracts this component.
        if base_score < 0.50 and adj_conf > 0.80:
            confidence_penalty = 0.22
        elif base_score < 0.50 and adj_conf > 0.65:
            confidence_penalty = 0.12
        elif self._signals_extracted and final_decision == "ESCALATE" and adj_conf < 0.55:
            confidence_penalty = -0.10
        else:
            confidence_penalty = 0.0

        return {
            "base_score": base_score,
            "policy_alignment": policy_alignment,
            "signal_accuracy_bonus": signal_accuracy_bonus,
            "escalation_adj": escalation_adj,
            "signal_bonus": signal_bonus,
            "tool_cost": tool_cost,
            "tool_miss_penalty": tool_miss_penalty,
            "validation_penalty": validation_penalty,
            "risk_penalty": risk_penalty,
            "confidence_penalty": confidence_penalty,
            "_policy_rec": policy_rec,
        }

    def _finalize_reward(self, components: Dict[str, Any]) -> float:
        """Combine the components into a reward clamped to [0, 1].

        Bonuses are added, costs and penalties are subtracted, and the result
        is rounded to four decimals.
        """
        raw = (
            components["base_score"]
            + components["policy_alignment"]
            + components["signal_accuracy_bonus"]
            + components["escalation_adj"]
            + components["signal_bonus"]
            - components["tool_cost"]
            - components["tool_miss_penalty"]
            - components["validation_penalty"]
            - components["risk_penalty"]
            - components["confidence_penalty"]
        )
        return round(max(0.0, min(1.0, raw)), 4)

    def _compute_signal_accuracy(self) -> float:
        """Score the extracted signals against the task's reference signals.

        Returns 0.0 when no signals were extracted and 0.05 when the task has
        no ``ground_truth_signals``. Otherwise five checks worth 0.20 each are
        summed and scaled by 0.15, giving at most 0.15.
        """
        if not self._extracted_signals:
            return 0.0
        gt = self._current_task.get("ground_truth_signals", {})
        if not gt:
            return 0.05

        s = self._extracted_signals
        score = 0.0
        if s.target == gt.get("target"):
            score += 0.20
        if s.intent == gt.get("intent"):
            score += 0.20
        if s.context_type == gt.get("context_type"):
            score += 0.20

        # Toxicity: full credit within 0.20, half credit within 0.35.
        tox_diff = abs(s.toxicity_level - gt.get("toxicity_level", 0.5))
        score += 0.20 if tox_diff <= 0.20 else (0.10 if tox_diff <= 0.35 else 0.0)

        # Flags: credit is the share of reference flags found; when none are
        # expected, reporting none earns full credit and extras earn half.
        gt_flags = set(gt.get("content_flags", []))
        s_flags = set(s.content_flags)
        if gt_flags:
            score += 0.20 * min(1.0, len(gt_flags & s_flags) / len(gt_flags))
        else:
            score += 0.20 if not s_flags else 0.10

        return round(score * 0.15, 4)