# Trust & Safety moderation environment — 3-layer pipeline, 8-component reward.
from __future__ import annotations

import random
import uuid
from typing import Optional, Dict, Any, Set

# Resolve the OpenEnv base class across packaging variants:
#   1. `openenv_core` (standalone distribution)
#   2. `openenv.core` (namespaced distribution)
#   3. plain `object` fallback so the module still imports without OpenEnv
#      installed (the environment then just lacks the server plumbing).
try:
    from openenv_core.env_server import Environment
    print("[env] Inheriting from openenv_core.env_server.Environment ✅")
except ImportError:
    try:
        from openenv.core.env_server import Environment
        print("[env] Inheriting from openenv.core.env_server.Environment ✅")
    except ImportError:
        Environment = object
        print("[env] openenv_core not found — using plain object base ⚠️")

# Project-local data models and the static task fixtures.
from models import TrustObservation, TrustAction, TrustState, ContentSignals
from tasks import TASKS
# Per-call reward cost of each Layer-1 investigation tool.
TOOL_COSTS: Dict[str, float] = {
    "read_comments": 0.05,
    "check_user_history": 0.05,
    "check_entity_status": 0.10,
    "view_policy": 0.10,
}

# Hard cap on steps per episode; exceeding it terminates with reward 0.0.
MAX_STEPS = 7

# (agent_decision, ground_truth) -> base accuracy score in [0, 1].
# Exact matches score 1.0; partial credit is asymmetric — e.g. wrongly
# removing allowable content (REMOVE vs ALLOW → 0.10) is punished harder
# than an over-cautious warning (ALLOW_WITH_WARNING vs ALLOW → 0.75).
# Pairs not listed fall back to 0.20 in the reward computation.
DECISION_MATRIX: Dict[tuple, float] = {
    ("REMOVE", "REMOVE"): 1.00,
    ("ALLOW", "ALLOW"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW_WITH_WARNING"): 1.00,
    ("ESCALATE", "ESCALATE"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW"): 0.75,
    ("ALLOW", "ALLOW_WITH_WARNING"): 0.55,
    ("ESCALATE", "ALLOW_WITH_WARNING"): 0.65,
    ("ESCALATE", "ALLOW"): 0.45,
    ("ESCALATE", "REMOVE"): 0.45,
    ("REMOVE", "ALLOW"): 0.10,
    ("REMOVE", "ALLOW_WITH_WARNING"): 0.20,
    ("ALLOW", "REMOVE"): 0.00,
    ("ALLOW_WITH_WARNING", "REMOVE"): 0.15,
}
class TrustSafetyEnvironment(Environment):
    """
    3-Layer Risk-Aware Trust & Safety RL Environment.

    Layer 1 — Evidence gathering : agent uses investigation tools (optional)
    Layer 2 — Signal extraction  : agent outputs ContentSignals as feature extractor
    Layer 3 — Policy engine      : validates signals, applies rules, computes reward

    8-Component Reward: Accuracy · Policy Alignment · Signal Quality · Escalation
                        Tool Usage · Consistency · Risk Sensitivity · Confidence
    """

    def __init__(self, seed: int = 42) -> None:
        """Initialize the RNG, per-episode bookkeeping, and the task index."""
        super().__init__()
        self._rng = random.Random(seed)
        self._current_task: Optional[Dict[str, Any]] = None
        self._tools_used: Set[str] = set()
        self._step_count: int = 0
        self._extracted_signals: Optional[ContentSignals] = None
        self._validation_result: Optional[Dict[str, Any]] = None
        self._signals_extracted: bool = False
        self._obs: Optional[TrustObservation] = None
        self._state = TrustState()
        # Tasks keyed by task_id for O(1) lookup when reset() targets an episode.
        self._tasks: Dict[str, Dict[str, Any]] = {
            t["task_id"]: t for t in TASKS
        }

    # -----------------------------------------------------------------------
    # OpenEnv interface
    # -----------------------------------------------------------------------
    def reset(self, seed=None, episode_id=None, **kwargs) -> TrustObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: optional reseed for the internal RNG.
            episode_id: if it matches a known task_id, that task is used;
                otherwise a task is sampled uniformly at random.
        """
        if seed is not None:
            self._rng.seed(seed)
        # Pick task by episode_id if provided, else random from all tasks.
        if episode_id and episode_id in self._tasks:
            task = self._tasks[episode_id]
        else:
            task = self._rng.choice(list(self._tasks.values()))
        self._current_task = task
        self._tools_used = set()
        self._step_count = 0
        self._extracted_signals = None
        self._validation_result = None
        self._signals_extracted = False
        self._state = TrustState(
            episode_id=task["task_id"],
            step_count=0,
            current_task_id=task["task_id"],
            difficulty=task.get("difficulty", "medium"),
            risk_level=task.get("risk_level", "medium"),
            is_done=False,
            tools_used=[],
            signals_extracted=False,
        )
        self._obs = TrustObservation(
            ticket_id=task["task_id"],
            post_text=task["post_text"],
            image_description=task.get("image_description", ""),
            step_number=0,
            done=False,
        )
        return self._obs

    def step(self, action: TrustAction, timeouts: Optional[Any] = None,
             **kwargs) -> TrustObservation:
        """Advance one step, dispatching on ``action.action_type``.

        Raises:
            RuntimeError: if called before reset() or after the episode ended.
            ValueError: for an unknown ``action_type``.
        """
        if self._current_task is None or self._obs is None:
            raise RuntimeError("Call reset() before step().")
        # FIX: refuse to step a finished episode instead of silently mutating it.
        if self._state.is_done:
            raise RuntimeError("Episode is done; call reset() to start a new one.")
        if self._step_count >= MAX_STEPS:
            # FIX: mark the state done so state() agrees with the terminal
            # observation (previously _state.is_done stayed False on timeout).
            self._state.is_done = True
            self._obs = TrustObservation(
                ticket_id=self._current_task["task_id"],
                post_text=self._obs.post_text,
                image_description=self._obs.image_description,
                step_number=self._step_count,
                done=True,
                reward=0.0,
                info={"reason": "timeout", "tools_used": list(self._tools_used)},
            )
            return self._obs
        atype = action.action_type
        if atype == "use_tool":
            return self._handle_tool(action)
        if atype == "extract_signals":
            return self._handle_signal_extraction(action)
        if atype == "final_decision":
            return self._handle_final_decision(action)
        raise ValueError(f"Unknown action_type: {atype!r}")

    def state(self) -> TrustState:
        """Return the current per-episode state snapshot."""
        return self._state

    # -----------------------------------------------------------------------
    # Layer 1 — Tool handling
    # -----------------------------------------------------------------------
    def _handle_tool(self, action: TrustAction) -> TrustObservation:
        """Run one investigation tool and fold its response into the observation."""
        tool = action.tool_name
        if tool not in TOOL_COSTS:
            raise ValueError(f"Unknown tool: {tool!r}")
        # Set semantics: re-running a tool is only charged once in the reward.
        self._tools_used.add(tool)
        response = self._current_task["tool_responses"].get(tool, "No data found.")
        # Observation field each tool's response lands in.
        field_map = {
            "read_comments": "comments_found",
            "check_user_history": "user_history_found",
            "check_entity_status": "entity_status_found",
            "view_policy": "policy_found",
        }
        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.tools_used = list(self._tools_used)
        # Carry forward all previously gathered evidence, then add the new field.
        obs_kwargs = {
            k: getattr(self._obs, k)
            for k in ("ticket_id", "post_text", "image_description",
                      "comments_found", "user_history_found",
                      "entity_status_found", "policy_found",
                      "extracted_signals", "validation_result")
        }
        obs_kwargs[field_map[tool]] = response
        obs_kwargs["step_number"] = self._step_count
        obs_kwargs["done"] = False
        obs_kwargs["reward"] = None
        self._obs = TrustObservation(**obs_kwargs)
        return self._obs

    # -----------------------------------------------------------------------
    # Layer 2 — Signal extraction + validation
    # -----------------------------------------------------------------------
    def _handle_signal_extraction(self, action: TrustAction) -> TrustObservation:
        """Clamp, store, and cross-validate the agent-provided ContentSignals."""
        raw = action.signals
        # FIX: fail with a clear message instead of an opaque AttributeError
        # when the action carries no signals payload.
        if raw is None:
            raise ValueError("extract_signals action requires action.signals.")
        # Clamp numeric fields into [0, 1]; coerce a malformed flags payload.
        raw.toxicity_level = max(0.0, min(1.0, float(raw.toxicity_level)))
        raw.confidence = max(0.0, min(1.0, float(raw.confidence)))
        if not isinstance(raw.content_flags, list):
            raw.content_flags = []
        self._extracted_signals = raw
        self._signals_extracted = True
        self._validation_result = self._validate_signals(raw)
        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.signals_extracted = True
        obs_kwargs = {
            k: getattr(self._obs, k)
            for k in ("ticket_id", "post_text", "image_description",
                      "comments_found", "user_history_found",
                      "entity_status_found", "policy_found")
        }
        obs_kwargs["extracted_signals"] = {
            "target": raw.target,
            "is_protected_class": raw.is_protected_class,
            "toxicity_level": raw.toxicity_level,
            "is_direct_attack": raw.is_direct_attack,
            "context_type": raw.context_type,
            "intent": raw.intent,
            "confidence": raw.confidence,
            "abusive_language_present": raw.abusive_language_present,
            "content_flags": raw.content_flags,
        }
        obs_kwargs["validation_result"] = self._validation_result
        obs_kwargs["step_number"] = self._step_count
        obs_kwargs["done"] = False
        obs_kwargs["reward"] = None
        self._obs = TrustObservation(**obs_kwargs)
        return self._obs

    def _validate_signals(self, s: ContentSignals) -> Dict[str, Any]:
        """Cross-check signals for internal consistency.

        Each contradiction docks the reported confidence; the adjusted value
        feeds both the policy engine and the escalation reward component.
        """
        issues = []
        conf = s.confidence
        if not s.abusive_language_present and s.toxicity_level > 0.75:
            issues.append("high_toxicity_without_abusive_language")
            conf -= 0.15
        if s.context_type in ("satire", "education") and s.intent == "harmful":
            issues.append("harmful_intent_contradicts_satire_context")
            conf -= 0.12
        if s.is_protected_class and s.target == "none":
            issues.append("protected_class_flagged_without_target")
            conf -= 0.10
        if s.is_direct_attack and s.toxicity_level < 0.25:
            issues.append("direct_attack_with_low_toxicity")
            conf -= 0.12
        if s.context_type == "spam" and "spam" not in s.content_flags \
                and "phishing" not in s.content_flags:
            issues.append("spam_context_missing_spam_flag")
            conf -= 0.08
        adj = round(max(0.0, min(1.0, conf)), 3)
        return {
            "original_confidence": s.confidence,
            "adjusted_confidence": adj,
            "consistency_issues": issues,
            "issue_count": len(issues),
            "is_consistent": len(issues) == 0,
            # Nudge toward ESCALATE when confidence is shaky or issues pile up.
            "needs_escalation_hint": adj < 0.55 or len(issues) >= 2,
        }

    # -----------------------------------------------------------------------
    # Layer 3 — Policy engine + final decision
    # -----------------------------------------------------------------------
    def _policy_engine(self) -> Dict[str, Any]:
        """Apply rules (strongest first) to the validated signals.

        Returns a dict with the recommended decision, the triggering reason,
        and the rule strength used to weight the policy-alignment reward.
        """
        if not self._extracted_signals or not self._validation_result:
            return {"recommended": "ESCALATE", "reason": "no_signals_extracted",
                    "rule_strength": "weak"}
        s = self._extracted_signals
        adj = self._validation_result["adjusted_confidence"]
        if adj < 0.50:
            return {"recommended": "ESCALATE", "reason": "low_adjusted_confidence",
                    "rule_strength": "medium"}
        if "spam" in s.content_flags or "phishing" in s.content_flags:
            return {"recommended": "REMOVE", "reason": "spam_or_phishing_detected",
                    "rule_strength": "strong"}
        if s.is_protected_class and s.is_direct_attack \
                and s.toxicity_level > 0.70 and adj > 0.70:
            return {"recommended": "REMOVE",
                    "reason": "protected_class_direct_attack_high_toxicity",
                    "rule_strength": "strong"}
        if s.toxicity_level > 0.85 and s.intent == "harmful":
            return {"recommended": "REMOVE", "reason": "high_toxicity_harmful_intent",
                    "rule_strength": "strong"}
        if s.context_type in ("satire", "news", "education", "protest") \
                and s.intent in ("satirical", "descriptive", "educational"):
            return {"recommended": "ALLOW_WITH_WARNING",
                    "reason": "protected_expression_context",
                    "rule_strength": "medium"}
        if 0.40 < s.toxicity_level < 0.75 and s.intent == "ambiguous":
            return {"recommended": "ESCALATE", "reason": "ambiguous_moderate_toxicity",
                    "rule_strength": "medium"}
        return {"recommended": "ALLOW", "reason": "no_policy_violation_detected",
                "rule_strength": "medium"}

    def _handle_final_decision(self, action: TrustAction) -> TrustObservation:
        """Score the final decision, close the episode, and emit the terminal obs."""
        decision = action.final_decision
        components = self._compute_components(decision)
        policy_rec = components.pop("_policy_rec")
        reward = self._finalize_reward(components)
        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.is_done = True
        components["final_reward"] = reward
        obs_kwargs = {
            k: getattr(self._obs, k)
            for k in ("ticket_id", "post_text", "image_description",
                      "comments_found", "user_history_found",
                      "entity_status_found", "policy_found",
                      "extracted_signals", "validation_result")
        }
        obs_kwargs["step_number"] = self._step_count
        obs_kwargs["done"] = True
        obs_kwargs["reward"] = reward
        obs_kwargs["info"] = {
            "final_decision": decision,
            "ground_truth": self._current_task["ground_truth"],
            "policy_recommendation": policy_rec,
            "signals_extracted": self._signals_extracted,
            "tools_used": list(self._tools_used),
            "required_tools": self._current_task["required_tools"],
            "ambiguity_level": self._current_task["ambiguity_level"],
            "risk_level": self._current_task["risk_level"],
            "task_id": self._current_task["task_id"],
            "reward_breakdown": components,
        }
        self._obs = TrustObservation(**obs_kwargs)
        return self._obs

    # -----------------------------------------------------------------------
    # 8-Component Reward Engine
    # -----------------------------------------------------------------------
    def _compute_components(self, final_decision: str) -> Dict[str, Any]:
        """Compute every reward component for the given final decision.

        Returns a dict of named components plus the ``_policy_rec`` dict
        (popped by the caller before the breakdown is published).
        """
        gt = self._current_task["ground_truth"]
        required_tools = self._current_task["required_tools"]
        ambiguity = self._current_task["ambiguity_level"]
        risk_level = self._current_task["risk_level"]
        policy_rec = self._policy_engine()
        # 1) Accuracy: matrix lookup with floor 0.20 for unlisted pairs;
        #    escalating genuinely ambiguous cases is never scored below 0.70.
        base_score = DECISION_MATRIX.get((final_decision, gt), 0.20)
        if final_decision == "ESCALATE" and ambiguity == "high":
            base_score = max(base_score, 0.70)
        is_correct = base_score >= 0.90
        # 2) Policy alignment, weighted by how strong the triggered rule was.
        rule_weight = {"strong": 1.0, "medium": 0.70, "weak": 0.40}.get(
            policy_rec.get("rule_strength", "medium"), 0.70)
        policy_alignment = round(
            (+0.12 if final_decision == policy_rec["recommended"] else -0.18) * rule_weight, 4)
        # 3) Signal quality vs ground-truth signals.
        signal_accuracy_bonus = self._compute_signal_accuracy()
        # 4) Escalation calibration: reward escalating when confidence is low,
        #    penalize failing to (or escalating an unambiguous case).
        adj_conf = (self._validation_result["adjusted_confidence"]
                    if self._validation_result else 0.50)
        should_escalate = adj_conf < 0.50
        if should_escalate and final_decision == "ESCALATE":
            escalation_adj = +0.15
        elif should_escalate and final_decision != "ESCALATE":
            escalation_adj = -0.18
        elif not should_escalate and final_decision == "ESCALATE" and ambiguity == "low":
            escalation_adj = -0.20
        elif not should_escalate and final_decision == "ESCALATE":
            escalation_adj = -0.10
        else:
            escalation_adj = 0.0
        # 5) Did the agent bother to extract signals at all?
        signal_bonus = +0.05 if self._signals_extracted else -0.10
        # 6) Tool usage: pay for tools used, penalize missing required ones.
        tool_cost = round(sum(TOOL_COSTS.get(t, 0.0) for t in self._tools_used), 4)
        missing_required = set(required_tools) - self._tools_used
        tool_miss_penalty = round(len(missing_required) * 0.25, 4)
        # 7) Consistency: penalty scales with validation issue count
        #    (no extraction at all is treated like 2 issues).
        if self._validation_result:
            n = self._validation_result["issue_count"]
            validation_penalty = {0: 0.00, 1: 0.05, 2: 0.12}.get(n, 0.20)
        else:
            validation_penalty = 0.12
        # 8) Risk sensitivity: wrong calls on high-risk content cost extra.
        risk_penalty = 0.0
        if not is_correct:
            risk_penalty = {"high": 0.20, "medium": 0.10, "low": 0.0}.get(risk_level, 0.0)
        # Confidence calibration: overconfident wrong answers are penalized;
        # escalating while uncertain earns a small bonus (negative penalty).
        if base_score < 0.50 and adj_conf > 0.80:
            confidence_penalty = 0.22
        elif base_score < 0.50 and adj_conf > 0.65:
            confidence_penalty = 0.12
        elif self._signals_extracted and final_decision == "ESCALATE" and adj_conf < 0.55:
            confidence_penalty = -0.10
        else:
            confidence_penalty = 0.0
        return {
            "base_score": base_score,
            "policy_alignment": policy_alignment,
            "signal_accuracy_bonus": signal_accuracy_bonus,
            "escalation_adj": escalation_adj,
            "signal_bonus": signal_bonus,
            "tool_cost": tool_cost,
            "tool_miss_penalty": tool_miss_penalty,
            "validation_penalty": validation_penalty,
            "risk_penalty": risk_penalty,
            "confidence_penalty": confidence_penalty,
            "_policy_rec": policy_rec,
        }

    def _finalize_reward(self, components: Dict[str, Any]) -> float:
        """Sum bonuses minus penalties, clamp to [0, 1], round to 4 places."""
        raw = (
            components["base_score"]
            + components["policy_alignment"]
            + components["signal_accuracy_bonus"]
            + components["escalation_adj"]
            + components["signal_bonus"]
            - components["tool_cost"]
            - components["tool_miss_penalty"]
            - components["validation_penalty"]
            - components["risk_penalty"]
            - components["confidence_penalty"]
        )
        return round(max(0.0, min(1.0, raw)), 4)

    def _compute_signal_accuracy(self) -> float:
        """Score extracted signals against ground-truth signals.

        Five equally-weighted checks (target, intent, context, toxicity
        closeness, flag overlap) sum to at most 1.0, then scaled by 0.15
        so this component caps at 0.15 in the final reward.
        """
        if not self._extracted_signals:
            return 0.0
        gt = self._current_task.get("ground_truth_signals", {})
        if not gt:
            # Task carries no ground-truth signals: token credit for trying.
            return 0.05
        s = self._extracted_signals
        score = 0.0
        if s.target == gt.get("target"):
            score += 0.20
        if s.intent == gt.get("intent"):
            score += 0.20
        if s.context_type == gt.get("context_type"):
            score += 0.20
        tox_diff = abs(s.toxicity_level - gt.get("toxicity_level", 0.5))
        score += 0.20 if tox_diff <= 0.20 else (0.10 if tox_diff <= 0.35 else 0.0)
        gt_flags = set(gt.get("content_flags", []))
        s_flags = set(s.content_flags)
        if gt_flags:
            score += 0.20 * min(1.0, len(gt_flags & s_flags) / len(gt_flags))
        else:
            # No expected flags: full credit for none, half for spurious ones.
            score += 0.20 if not s_flags else 0.10
        return round(score * 0.15, 4)