# openenv / server / your_environment.py
# (Hub history: commit 7cf2ffd by jeromerichard —
#  "Fix: add server/app.py, uv.lock, project.scripts entry point")
from __future__ import annotations
import random
import uuid
from typing import Optional, Dict, Any, Set
# Resolve the Environment base class across OpenEnv packaging variants:
# newer installs expose it as `openenv_core`, older ones as `openenv.core`.
# If neither is importable (e.g. running unit tests without the framework),
# fall back to a plain `object` base so the module still imports.
try:
    from openenv_core.env_server import Environment
    print("[env] Inheriting from openenv_core.env_server.Environment ✅")
except ImportError:
    try:
        from openenv.core.env_server import Environment
        print("[env] Inheriting from openenv.core.env_server.Environment ✅")
    except ImportError:
        Environment = object
        print("[env] openenv_core not found — using plain object base ⚠️")
from models import TrustObservation, TrustAction, TrustState, ContentSignals
from tasks import TASKS
# Per-use cost of each Layer-1 investigation tool; the summed cost of all
# distinct tools used in an episode is subtracted from the final reward.
TOOL_COSTS: Dict[str, float] = {
    "read_comments": 0.05,
    "check_user_history": 0.05,
    "check_entity_status": 0.10,
    "view_policy": 0.10,
}
# Hard cap on steps per episode; once reached, step() terminates the episode
# with done=True and reward 0.0.
MAX_STEPS = 7
# Base accuracy score keyed by (agent_decision, ground_truth).  Exact matches
# score 1.00; near-misses score partially; opposite-direction errors score
# near zero.  Pairs absent from the matrix (e.g. a non-ESCALATE decision when
# ground truth is ESCALATE is present, but the reverse ESCALATE-ground-truth
# gaps are not all listed) fall back to 0.20 via DECISION_MATRIX.get(..., 0.20)
# in the reward computation.
DECISION_MATRIX: Dict[tuple, float] = {
    ("REMOVE", "REMOVE"): 1.00,
    ("ALLOW", "ALLOW"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW_WITH_WARNING"): 1.00,
    ("ESCALATE", "ESCALATE"): 1.00,
    ("ALLOW_WITH_WARNING", "ALLOW"): 0.75,
    ("ALLOW", "ALLOW_WITH_WARNING"): 0.55,
    ("ESCALATE", "ALLOW_WITH_WARNING"): 0.65,
    ("ESCALATE", "ALLOW"): 0.45,
    ("ESCALATE", "REMOVE"): 0.45,
    ("REMOVE", "ALLOW"): 0.10,
    ("REMOVE", "ALLOW_WITH_WARNING"): 0.20,
    ("ALLOW", "REMOVE"): 0.00,
    ("ALLOW_WITH_WARNING", "REMOVE"): 0.15,
}
class TrustSafetyEnvironment(Environment):
    """
    3-Layer Risk-Aware Trust & Safety RL Environment.

    Layer 1 — Evidence gathering : agent uses investigation tools (optional)
    Layer 2 — Signal extraction  : agent outputs ContentSignals as feature extractor
    Layer 3 — Policy engine      : validates signals, applies rules, computes reward

    8-Component Reward: Accuracy · Policy Alignment · Signal Quality · Escalation
    Tool Usage · Consistency · Risk Sensitivity · Confidence
    """

    # Observation fields that accumulate over an episode and are carried
    # forward from the previous TrustObservation into the next one.
    _OBS_CARRY_FIELDS = (
        "ticket_id", "post_text", "image_description",
        "comments_found", "user_history_found",
        "entity_status_found", "policy_found",
        "extracted_signals", "validation_result",
    )

    # The only decisions the environment scores; anything else is a client bug.
    _VALID_DECISIONS = frozenset(
        {"ALLOW", "ALLOW_WITH_WARNING", "REMOVE", "ESCALATE"})

    def __init__(self, seed: int = 42) -> None:
        """Initialize with a seeded RNG so task selection is reproducible."""
        super().__init__()
        self._rng = random.Random(seed)
        self._current_task: Optional[Dict[str, Any]] = None
        self._tools_used: Set[str] = set()
        self._step_count: int = 0
        self._extracted_signals: Optional[ContentSignals] = None
        self._validation_result: Optional[Dict[str, Any]] = None
        self._signals_extracted: bool = False
        self._obs: Optional[TrustObservation] = None
        self._state = TrustState()
        # Index tasks by task_id for O(1) lookup in reset(episode_id=...).
        self._tasks: Dict[str, Dict[str, Any]] = {
            t["task_id"]: t for t in TASKS
        }

    # -----------------------------------------------------------------------
    # OpenEnv interface
    # -----------------------------------------------------------------------
    def reset(self, seed: Optional[int] = None, episode_id: Optional[str] = None,
              **kwargs) -> TrustObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: optional reseed of the task-selection RNG.
            episode_id: select a specific task by its task_id; when absent
                (or unknown) a random task is drawn from the registry.
        """
        if seed is not None:
            self._rng.seed(seed)
        if episode_id and episode_id in self._tasks:
            task = self._tasks[episode_id]
        else:
            task = self._rng.choice(list(self._tasks.values()))
        self._current_task = task
        self._tools_used = set()
        self._step_count = 0
        self._extracted_signals = None
        self._validation_result = None
        self._signals_extracted = False
        self._state = TrustState(
            episode_id=task["task_id"],
            step_count=0,
            current_task_id=task["task_id"],
            difficulty=task.get("difficulty", "medium"),
            risk_level=task.get("risk_level", "medium"),
            is_done=False,
            tools_used=[],
            signals_extracted=False,
        )
        self._obs = TrustObservation(
            ticket_id=task["task_id"],
            post_text=task["post_text"],
            image_description=task.get("image_description", ""),
            step_number=0,
            done=False,
        )
        return self._obs

    def step(self, action: TrustAction, timeouts: Optional[Any] = None,
             **kwargs) -> TrustObservation:
        """Dispatch one agent action to its handler.

        Raises:
            RuntimeError: if called before reset().
            ValueError: for an unrecognized action_type.
        """
        if self._current_task is None or self._obs is None:
            raise RuntimeError("Call reset() before step().")
        if self._step_count >= MAX_STEPS:
            # Step budget exhausted — terminate with zero reward.
            # FIX: also mark the state machine done so state and observation
            # agree (previously only the observation carried done=True).
            self._state.is_done = True
            self._obs = TrustObservation(
                ticket_id=self._current_task["task_id"],
                post_text=self._obs.post_text,
                image_description=self._obs.image_description,
                step_number=self._step_count,
                done=True,
                reward=0.0,
                info={"reason": "timeout", "tools_used": list(self._tools_used)},
            )
            return self._obs
        atype = action.action_type
        if atype == "use_tool":
            return self._handle_tool(action)
        if atype == "extract_signals":
            return self._handle_signal_extraction(action)
        if atype == "final_decision":
            return self._handle_final_decision(action)
        raise ValueError(f"Unknown action_type: {atype!r}")

    @property
    def state(self) -> TrustState:
        """Server-side episode state (step count, tools used, done flag…)."""
        return self._state

    def _carried_obs_kwargs(self) -> Dict[str, Any]:
        """Snapshot the accumulating fields of the previous observation.

        Shared by all three action handlers so the carried-field list lives
        in exactly one place (_OBS_CARRY_FIELDS).
        """
        return {k: getattr(self._obs, k) for k in self._OBS_CARRY_FIELDS}

    # -----------------------------------------------------------------------
    # Layer 1 — Tool handling
    # -----------------------------------------------------------------------
    def _handle_tool(self, action: TrustAction) -> TrustObservation:
        """Run one investigation tool and surface its canned response.

        Raises:
            ValueError: if action.tool_name is not a known tool.
        """
        tool = action.tool_name
        if tool not in TOOL_COSTS:
            raise ValueError(f"Unknown tool: {tool!r}")
        self._tools_used.add(tool)
        # Each task ships pre-authored tool output; missing entries degrade
        # gracefully to a "no data" string.
        response = self._current_task["tool_responses"].get(tool, "No data found.")
        # Which observation field each tool's output lands in.
        field_map = {
            "read_comments": "comments_found",
            "check_user_history": "user_history_found",
            "check_entity_status": "entity_status_found",
            "view_policy": "policy_found",
        }
        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.tools_used = list(self._tools_used)
        obs_kwargs = self._carried_obs_kwargs()
        obs_kwargs[field_map[tool]] = response
        obs_kwargs["step_number"] = self._step_count
        obs_kwargs["done"] = False
        obs_kwargs["reward"] = None
        self._obs = TrustObservation(**obs_kwargs)
        return self._obs

    # -----------------------------------------------------------------------
    # Layer 2 — Signal extraction + validation
    # -----------------------------------------------------------------------
    def _handle_signal_extraction(self, action: TrustAction) -> TrustObservation:
        """Clamp, store, and consistency-check the agent's ContentSignals.

        Raises:
            ValueError: if the action carries no signals payload.
        """
        raw = action.signals
        if raw is None:
            # FIX: fail with a clear message instead of an AttributeError
            # when a malformed extract_signals action arrives.
            raise ValueError("extract_signals action requires a 'signals' payload.")
        # Clamp numeric signals into [0, 1]; coerce a bad flags field to [].
        raw.toxicity_level = max(0.0, min(1.0, float(raw.toxicity_level)))
        raw.confidence = max(0.0, min(1.0, float(raw.confidence)))
        if not isinstance(raw.content_flags, list):
            raw.content_flags = []
        self._extracted_signals = raw
        self._signals_extracted = True
        self._validation_result = self._validate_signals(raw)
        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.signals_extracted = True
        obs_kwargs = self._carried_obs_kwargs()
        # Echo the (clamped) signals back to the agent as plain data.
        obs_kwargs["extracted_signals"] = {
            "target": raw.target,
            "is_protected_class": raw.is_protected_class,
            "toxicity_level": raw.toxicity_level,
            "is_direct_attack": raw.is_direct_attack,
            "context_type": raw.context_type,
            "intent": raw.intent,
            "confidence": raw.confidence,
            "abusive_language_present": raw.abusive_language_present,
            "content_flags": raw.content_flags,
        }
        obs_kwargs["validation_result"] = self._validation_result
        obs_kwargs["step_number"] = self._step_count
        obs_kwargs["done"] = False
        obs_kwargs["reward"] = None
        self._obs = TrustObservation(**obs_kwargs)
        return self._obs

    def _validate_signals(self, s: ContentSignals) -> Dict[str, Any]:
        """Cross-check signals for internal contradictions.

        Each contradiction found lowers the agent's stated confidence by a
        fixed amount; the adjusted confidence later feeds escalation logic.
        """
        issues = []
        conf = s.confidence
        if not s.abusive_language_present and s.toxicity_level > 0.75:
            issues.append("high_toxicity_without_abusive_language")
            conf -= 0.15
        if s.context_type in ("satire", "education") and s.intent == "harmful":
            issues.append("harmful_intent_contradicts_satire_context")
            conf -= 0.12
        if s.is_protected_class and s.target == "none":
            issues.append("protected_class_flagged_without_target")
            conf -= 0.10
        if s.is_direct_attack and s.toxicity_level < 0.25:
            issues.append("direct_attack_with_low_toxicity")
            conf -= 0.12
        if s.context_type == "spam" and "spam" not in s.content_flags \
                and "phishing" not in s.content_flags:
            issues.append("spam_context_missing_spam_flag")
            conf -= 0.08
        adj = round(max(0.0, min(1.0, conf)), 3)
        return {
            "original_confidence": s.confidence,
            "adjusted_confidence": adj,
            "consistency_issues": issues,
            "issue_count": len(issues),
            "is_consistent": len(issues) == 0,
            "needs_escalation_hint": adj < 0.55 or len(issues) >= 2,
        }

    # -----------------------------------------------------------------------
    # Layer 3 — Policy engine + final decision
    # -----------------------------------------------------------------------
    def _policy_engine(self) -> Dict[str, Any]:
        """Apply ordered policy rules to the extracted signals.

        Returns a dict with the recommended decision, the rule that fired,
        and a rule_strength used to weight the policy-alignment reward term.
        Rules are checked in priority order; the first match wins.
        """
        if not self._extracted_signals or not self._validation_result:
            return {"recommended": "ESCALATE", "reason": "no_signals_extracted",
                    "rule_strength": "weak"}
        s = self._extracted_signals
        adj = self._validation_result["adjusted_confidence"]
        if adj < 0.50:
            return {"recommended": "ESCALATE", "reason": "low_adjusted_confidence",
                    "rule_strength": "medium"}
        if "spam" in s.content_flags or "phishing" in s.content_flags:
            return {"recommended": "REMOVE", "reason": "spam_or_phishing_detected",
                    "rule_strength": "strong"}
        if s.is_protected_class and s.is_direct_attack \
                and s.toxicity_level > 0.70 and adj > 0.70:
            return {"recommended": "REMOVE",
                    "reason": "protected_class_direct_attack_high_toxicity",
                    "rule_strength": "strong"}
        if s.toxicity_level > 0.85 and s.intent == "harmful":
            return {"recommended": "REMOVE", "reason": "high_toxicity_harmful_intent",
                    "rule_strength": "strong"}
        if s.context_type in ("satire", "news", "education", "protest") \
                and s.intent in ("satirical", "descriptive", "educational"):
            return {"recommended": "ALLOW_WITH_WARNING",
                    "reason": "protected_expression_context",
                    "rule_strength": "medium"}
        if 0.40 < s.toxicity_level < 0.75 and s.intent == "ambiguous":
            return {"recommended": "ESCALATE", "reason": "ambiguous_moderate_toxicity",
                    "rule_strength": "medium"}
        return {"recommended": "ALLOW", "reason": "no_policy_violation_detected",
                "rule_strength": "medium"}

    def _handle_final_decision(self, action: TrustAction) -> TrustObservation:
        """Score the agent's final decision and end the episode.

        Raises:
            ValueError: if the decision is not one of the four valid labels.
        """
        decision = action.final_decision
        if decision not in self._VALID_DECISIONS:
            # FIX: previously an invalid/misspelled decision fell through the
            # matrix default and was silently scored 0.20; reject it instead,
            # matching step()'s handling of unknown action types.
            raise ValueError(f"Unknown final_decision: {decision!r}")
        components = self._compute_components(decision)
        policy_rec = components.pop("_policy_rec")
        reward = self._finalize_reward(components)
        self._step_count += 1
        self._state.step_count = self._step_count
        self._state.is_done = True
        components["final_reward"] = reward
        obs_kwargs = self._carried_obs_kwargs()
        obs_kwargs["step_number"] = self._step_count
        obs_kwargs["done"] = True
        obs_kwargs["reward"] = reward
        # Full debugging payload: ground truth, policy trace, reward breakdown.
        obs_kwargs["info"] = {
            "final_decision": decision,
            "ground_truth": self._current_task["ground_truth"],
            "policy_recommendation": policy_rec,
            "signals_extracted": self._signals_extracted,
            "tools_used": list(self._tools_used),
            "required_tools": self._current_task["required_tools"],
            "ambiguity_level": self._current_task["ambiguity_level"],
            "risk_level": self._current_task["risk_level"],
            "task_id": self._current_task["task_id"],
            "reward_breakdown": components,
        }
        self._obs = TrustObservation(**obs_kwargs)
        return self._obs

    # -----------------------------------------------------------------------
    # 8-Component Reward Engine
    # -----------------------------------------------------------------------
    def _compute_components(self, final_decision: str) -> Dict[str, Any]:
        """Compute every reward component for the given final decision.

        Returns the component dict plus the policy recommendation under the
        private "_policy_rec" key (popped by the caller before summing).
        """
        gt = self._current_task["ground_truth"]
        required_tools = self._current_task["required_tools"]
        ambiguity = self._current_task["ambiguity_level"]
        risk_level = self._current_task["risk_level"]
        policy_rec = self._policy_engine()
        # 1) Accuracy: matrix score, with a floor for escalating genuinely
        #    ambiguous cases.
        base_score = DECISION_MATRIX.get((final_decision, gt), 0.20)
        if final_decision == "ESCALATE" and ambiguity == "high":
            base_score = max(base_score, 0.70)
        is_correct = base_score >= 0.90
        # 2) Policy alignment: bonus/penalty scaled by how strong the fired
        #    rule was.
        rule_weight = {"strong": 1.0, "medium": 0.70, "weak": 0.40}.get(
            policy_rec.get("rule_strength", "medium"), 0.70)
        policy_alignment = round(
            (+0.12 if final_decision == policy_rec["recommended"] else -0.18) * rule_weight, 4)
        # 3) Signal quality: how close the extracted signals sit to ground truth.
        signal_accuracy_bonus = self._compute_signal_accuracy()
        # 4) Escalation calibration against adjusted confidence.
        adj_conf = (self._validation_result["adjusted_confidence"]
                    if self._validation_result else 0.50)
        should_escalate = adj_conf < 0.50
        if should_escalate and final_decision == "ESCALATE":
            escalation_adj = +0.15
        elif should_escalate and final_decision != "ESCALATE":
            escalation_adj = -0.18
        elif not should_escalate and final_decision == "ESCALATE" and ambiguity == "low":
            escalation_adj = -0.20
        elif not should_escalate and final_decision == "ESCALATE":
            escalation_adj = -0.10
        else:
            escalation_adj = 0.0
        # 5) Did the agent bother extracting signals at all?
        signal_bonus = +0.05 if self._signals_extracted else -0.10
        # 6) Tool usage: pay for what was used, and pay more for skipping
        #    tools the task marks as required.
        tool_cost = round(sum(TOOL_COSTS.get(t, 0.0) for t in self._tools_used), 4)
        missing_required = set(required_tools) - self._tools_used
        tool_miss_penalty = round(len(missing_required) * 0.25, 4)
        # 7) Consistency: penalty grows with the number of validation issues;
        #    skipping extraction entirely costs the two-issue rate.
        if self._validation_result:
            n = self._validation_result["issue_count"]
            validation_penalty = {0: 0.00, 1: 0.05, 2: 0.12}.get(n, 0.20)
        else:
            validation_penalty = 0.12
        # 8a) Risk sensitivity: wrong answers hurt more on high-risk tasks.
        risk_penalty = 0.0
        if not is_correct:
            risk_penalty = {"high": 0.20, "medium": 0.10, "low": 0.0}.get(risk_level, 0.0)
        # 8b) Confidence calibration: confidently wrong is penalized;
        #     appropriately-uncertain escalation earns a small bonus
        #     (negative penalty).
        if base_score < 0.50 and adj_conf > 0.80:
            confidence_penalty = 0.22
        elif base_score < 0.50 and adj_conf > 0.65:
            confidence_penalty = 0.12
        elif self._signals_extracted and final_decision == "ESCALATE" and adj_conf < 0.55:
            confidence_penalty = -0.10
        else:
            confidence_penalty = 0.0
        return {
            "base_score": base_score,
            "policy_alignment": policy_alignment,
            "signal_accuracy_bonus": signal_accuracy_bonus,
            "escalation_adj": escalation_adj,
            "signal_bonus": signal_bonus,
            "tool_cost": tool_cost,
            "tool_miss_penalty": tool_miss_penalty,
            "validation_penalty": validation_penalty,
            "risk_penalty": risk_penalty,
            "confidence_penalty": confidence_penalty,
            "_policy_rec": policy_rec,
        }

    def _finalize_reward(self, components: Dict[str, Any]) -> float:
        """Sum bonuses, subtract penalties, clamp to [0, 1], round to 4 dp."""
        raw = (
            components["base_score"]
            + components["policy_alignment"]
            + components["signal_accuracy_bonus"]
            + components["escalation_adj"]
            + components["signal_bonus"]
            - components["tool_cost"]
            - components["tool_miss_penalty"]
            - components["validation_penalty"]
            - components["risk_penalty"]
            - components["confidence_penalty"]
        )
        return round(max(0.0, min(1.0, raw)), 4)

    def _compute_signal_accuracy(self) -> float:
        """Score extracted signals against the task's ground-truth signals.

        Five equally-weighted checks (target, intent, context, toxicity
        proximity, flag overlap) build a 0–1 score that is scaled by 0.15,
        so this component contributes at most 0.15 to the reward.
        """
        if not self._extracted_signals:
            return 0.0
        gt = self._current_task.get("ground_truth_signals", {})
        if not gt:
            # Task ships no ground-truth signals: small flat bonus for trying.
            return 0.05
        s = self._extracted_signals
        score = 0.0
        if s.target == gt.get("target"):
            score += 0.20
        if s.intent == gt.get("intent"):
            score += 0.20
        if s.context_type == gt.get("context_type"):
            score += 0.20
        tox_diff = abs(s.toxicity_level - gt.get("toxicity_level", 0.5))
        score += 0.20 if tox_diff <= 0.20 else (0.10 if tox_diff <= 0.35 else 0.0)
        gt_flags = set(gt.get("content_flags", []))
        s_flags = set(s.content_flags)
        if gt_flags:
            # Recall over the ground-truth flag set.
            score += 0.20 * min(1.0, len(gt_flags & s_flags) / len(gt_flags))
        else:
            # No flags expected: full credit for none, half for spurious ones.
            score += 0.20 if not s_flags else 0.10
        return round(score * 0.15, 4)