from __future__ import annotations

from typing import Any

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient

from models import ContentSignals, TrustAction, TrustObservation, TrustState


class TrustSafetyEnv(EnvClient[TrustAction, TrustObservation, TrustState]):
    """Typed WebSocket/HTTP client for the Trust & Safety RL Environment.

    The class only translates between the typed models (``TrustAction``,
    ``TrustObservation``, ``TrustState``) and plain dict payloads; transport
    is inherited from ``EnvClient``.

    Usage (sync — for scripts, GRPOTrainer):
        env = TrustSafetyEnv(base_url="http://localhost:8000").sync()
        result = env.reset()
        result = env.reset(episode_id="T-001")
        result = env.step(TrustAction(action_type="use_tool", tool_name="view_policy"))
        result = env.step(TrustAction(action_type="final_decision", final_decision="REMOVE"))
        state = env.state()
        env.close()

    Usage (async):
        async with TrustSafetyEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
    """

    def step_payload(self, action: TrustAction) -> dict[str, Any]:
        """Serialize ``action`` into the dict sent to the environment server.

        ``action_type`` is always present. ``tool_name``, ``signals`` and
        ``final_decision`` are included only when they are not ``None``.

        Args:
            action: The action to serialize.

        Returns:
            A dict containing only the populated action fields.
        """
        payload: dict[str, Any] = {"action_type": action.action_type}

        if action.tool_name is not None:
            payload["tool_name"] = action.tool_name

        if action.signals is not None:
            s = action.signals
            payload["signals"] = {
                "target": s.target,
                "is_protected_class": s.is_protected_class,
                # Numeric fields are coerced to float before sending.
                "toxicity_level": float(s.toxicity_level),
                "is_direct_attack": s.is_direct_attack,
                "context_type": s.context_type,
                "intent": s.intent,
                "confidence": float(s.confidence),
                "abusive_language_present": s.abusive_language_present,
                # Copy into a plain list so the payload does not alias the model.
                "content_flags": list(s.content_flags),
            }

        if action.final_decision is not None:
            payload["final_decision"] = action.final_decision

        return payload

    def parse_result(self, payload: dict[str, Any]) -> StepResult[TrustObservation]:
        """Build a ``StepResult`` from a server response.

        Observation fields are read from ``payload["observation"]`` when that
        key holds a value; otherwise the payload itself is treated as the
        observation. ``reward`` and ``done`` are taken from the top level of
        the payload first, then from the observation data.

        Args:
            payload: Decoded response dict from the server.

        Returns:
            A ``StepResult`` wrapping the parsed ``TrustObservation``.
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # Missing or null "observation": fall back to the flat payload
            # instead of failing on ``None.get``.
            obs_data = payload

        # Resolve once so the observation and the StepResult always agree.
        reward = payload.get("reward", obs_data.get("reward"))
        done = payload.get("done", obs_data.get("done", False))

        obs = TrustObservation(
            ticket_id=obs_data.get("ticket_id", ""),
            post_text=obs_data.get("post_text", ""),
            image_description=obs_data.get("image_description", ""),
            comments_found=obs_data.get("comments_found"),
            user_history_found=obs_data.get("user_history_found"),
            entity_status_found=obs_data.get("entity_status_found"),
            policy_found=obs_data.get("policy_found"),
            extracted_signals=obs_data.get("extracted_signals"),
            validation_result=obs_data.get("validation_result"),
            step_number=obs_data.get("step_number", 0),
            info=obs_data.get("info"),
            done=done,
            reward=reward,
        )
        return StepResult(observation=obs, reward=reward, done=done)

    def parse_state(self, payload: dict[str, Any]) -> TrustState:
        """Build a ``TrustState`` from a server state response.

        Missing keys become ``None``, except ``step_count`` (0),
        ``tools_used`` (a new empty list), and ``signals_extracted`` /
        ``is_done`` (``False``).

        Args:
            payload: Decoded state dict from the server.

        Returns:
            The parsed ``TrustState``.
        """
        return TrustState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            current_task_id=payload.get("current_task_id"),
            difficulty=payload.get("difficulty"),
            ambiguity_level=payload.get("ambiguity_level"),
            risk_level=payload.get("risk_level"),
            tools_used=payload.get("tools_used", []),
            signals_extracted=payload.get("signals_extracted", False),
            is_done=payload.get("is_done", False),
        )

    # NOTE(review): upstream OpenEnv appears to declare these hooks on
    # EnvClient with a leading underscore (``_step_payload``, ``_parse_result``,
    # ``_parse_state``) — confirm against the installed openenv version.
    # The delegating methods below keep both spellings working; they call
    # through ``self`` so subclass overrides of the public names are honoured.
    def _step_payload(self, action: TrustAction) -> dict[str, Any]:
        """Underscore-named hook; delegates to :meth:`step_payload`."""
        return self.step_payload(action)

    def _parse_result(self, payload: dict[str, Any]) -> StepResult[TrustObservation]:
        """Underscore-named hook; delegates to :meth:`parse_result`."""
        return self.parse_result(payload)

    def _parse_state(self, payload: dict[str, Any]) -> TrustState:
        """Underscore-named hook; delegates to :meth:`parse_state`."""
        return self.parse_state(payload)