# openenv / client.py
# Author: jeromerichard
# Trust & Safety RL Environment - OpenEnv Hackathon
# Commit: 74e3b5e
from __future__ import annotations
from typing import Any
from openenv.core.env_client import EnvClient # βœ… correct import
from openenv.core.client_types import StepResult # βœ… correct import
from models import TrustAction, TrustObservation, TrustState, ContentSignals
class TrustSafetyEnv(EnvClient[TrustAction, TrustObservation, TrustState]):
    """
    Typed WebSocket/HTTP client for the Trust & Safety RL Environment.

    Specializes ``EnvClient`` with three generic parameters (action,
    observation, state) and implements the hooks that convert between the
    typed models and the plain dict payloads exchanged with the server.

    Usage (sync — for scripts, GRPOTrainer):
        env = TrustSafetyEnv(base_url="http://localhost:8000").sync()
        result = env.reset()
        result = env.reset(episode_id="T-001")
        result = env.step(TrustAction(action_type="use_tool", tool_name="view_policy"))
        result = env.step(TrustAction(action_type="final_decision", final_decision="REMOVE"))
        state = env.state()
        env.close()

    Usage (async):
        async with TrustSafetyEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
    """

    # NOTE(review): the three hooks below are deliberately named without a
    # leading underscore (step_payload / parse_result / parse_state).
    # Confirm these match the abstract method names declared by the
    # installed ``EnvClient``. If the base class declares different names,
    # this subclass stays abstract and cannot be instantiated.

    def step_payload(self, action: TrustAction) -> dict[str, Any]:
        """Serialize ``action`` into the dict sent with a step request.

        ``action_type`` is always present. ``tool_name``, ``signals`` and
        ``final_decision`` are included only when they are not ``None``.

        Args:
            action: The typed action to send.

        Returns:
            A plain dict holding the populated action fields.
        """
        payload: dict[str, Any] = {"action_type": action.action_type}
        if action.tool_name is not None:
            payload["tool_name"] = action.tool_name
        if action.signals is not None:
            s = action.signals
            # Numeric fields are coerced with float() and the flags copied
            # into a fresh list -- presumably so the payload holds only
            # plain built-in types.
            payload["signals"] = {
                "target": s.target,
                "is_protected_class": s.is_protected_class,
                "toxicity_level": float(s.toxicity_level),
                "is_direct_attack": s.is_direct_attack,
                "context_type": s.context_type,
                "intent": s.intent,
                "confidence": float(s.confidence),
                "abusive_language_present": s.abusive_language_present,
                "content_flags": list(s.content_flags),
            }
        if action.final_decision is not None:
            payload["final_decision"] = action.final_decision
        return payload

    def parse_result(self, payload: dict[str, Any]) -> StepResult[TrustObservation]:
        """Build a typed ``StepResult`` from a server response payload.

        Observation fields are read from ``payload["observation"]`` when that
        key holds a value; otherwise (key missing or null) they are read from
        the top level of ``payload``. ``reward`` and ``done`` prefer the
        top-level values and fall back to the ones inside the observation.

        Args:
            payload: Decoded response dict.

        Returns:
            A ``StepResult`` wrapping a ``TrustObservation``.
        """
        obs_data = payload.get("observation")
        if obs_data is None:
            # Flat payloads and an explicit null "observation" both mean
            # the observation fields live at the top level.
            obs_data = payload

        # Resolve once so the observation and the StepResult always agree.
        reward = payload.get("reward", obs_data.get("reward"))
        done = payload.get("done", obs_data.get("done", False))

        obs = TrustObservation(
            ticket_id=obs_data.get("ticket_id", ""),
            post_text=obs_data.get("post_text", ""),
            image_description=obs_data.get("image_description", ""),
            comments_found=obs_data.get("comments_found"),
            user_history_found=obs_data.get("user_history_found"),
            entity_status_found=obs_data.get("entity_status_found"),
            policy_found=obs_data.get("policy_found"),
            extracted_signals=obs_data.get("extracted_signals"),
            validation_result=obs_data.get("validation_result"),
            step_number=obs_data.get("step_number", 0),
            info=obs_data.get("info"),
            done=done,
            reward=reward,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=done,
        )

    def parse_state(self, payload: dict[str, Any]) -> TrustState:
        """Build a typed ``TrustState`` from a state response payload.

        Missing keys fall back to ``None``, except ``step_count`` (0),
        ``tools_used`` (a new empty list), ``signals_extracted`` (False)
        and ``is_done`` (False).

        Args:
            payload: Decoded state dict.

        Returns:
            The corresponding ``TrustState``.
        """
        return TrustState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            current_task_id=payload.get("current_task_id"),
            difficulty=payload.get("difficulty"),
            ambiguity_level=payload.get("ambiguity_level"),
            risk_level=payload.get("risk_level"),
            tools_used=payload.get("tools_used", []),
            signals_extracted=payload.get("signals_extracted", False),
            is_done=payload.get("is_done", False),
        )