# Source: openenv / inference.py (author: jeromerichard, commit 74e3b5e)
# Trust & Safety RL Environment - OpenEnv Hackathon
"""
inference.py β€” Trust & Safety RL Environment Evaluation
========================================================
MANDATORY env vars:
API_BASE_URL LLM endpoint (e.g. https://router.huggingface.co/v1)
MODEL_NAME Model ID (e.g. meta-llama/Llama-3.1-8B-Instruct)
HF_TOKEN API key
ENV_BASE_URL Environment server URL (default: http://localhost:8000)
"""
# --- Standard library ---
import os
import json
import time

# --- Third party ---
import requests
from openai import OpenAI

# Endpoint configuration, read from the environment (see module docstring
# for the full list of variables and their meanings).
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000")

# Fail fast at import time so a misconfigured run does not die mid-episode.
if not API_KEY:
    raise EnvironmentError("Set HF_TOKEN (your Hugging Face / API key).")
if not MODEL_NAME:
    raise EnvironmentError("Set MODEL_NAME.")
# System prompt enforcing the 3-stage moderation pipeline
# (optional tools -> extract_signals -> final_decision). Sent once per
# episode as the first chat message; the model must reply in strict JSON.
# NOTE: this is runtime behavior, not documentation — the environment
# scores compliance with this exact protocol, so edit with care.
SYSTEM_PROMPT = """You are a Trust & Safety analyst. Follow this EXACT 3-stage pipeline.
Skipping any stage results in a score penalty.
════════════════════════════════════════════
STAGE 1 β€” Investigate (optional)
════════════════════════════════════════════
Use tools only when context is needed. For OBVIOUS violations (spam, phishing), skip.
{"action_type": "use_tool", "tool_name": "read_comments"}
{"action_type": "use_tool", "tool_name": "check_user_history"}
{"action_type": "use_tool", "tool_name": "check_entity_status"}
{"action_type": "use_tool", "tool_name": "view_policy"}
════════════════════════════════════════════
STAGE 2 β€” Extract signals (MANDATORY)
════════════════════════════════════════════
You MUST call extract_signals before final_decision. No exceptions.
You are a FEATURE EXTRACTOR β€” output what the content shows, do NOT decide.
{"action_type": "extract_signals", "signals": {
"target": "individual|group|religion|race|gender|political|none",
"is_protected_class": true|false,
"toxicity_level": 0.0-1.0,
"is_direct_attack": true|false,
"context_type": "statement|satire|news|meme|personal_attack|protest|education|spam",
"intent": "harmful|descriptive|ambiguous|satirical|educational",
"confidence": 0.0-1.0,
"abusive_language_present": true|false,
"content_flags": ["spam","phishing","slur","threat","misinformation",
"political_content","public_figure","competitive_language"]
}}
After submitting signals you receive VALIDATION. If needs_escalation_hint is true β†’ prefer ESCALATE.
════════════════════════════════════════════
STAGE 3 β€” Final decision (MANDATORY)
════════════════════════════════════════════
{"action_type": "final_decision", "final_decision": "ALLOW|REMOVE|ALLOW_WITH_WARNING|ESCALATE"}
DECISION RULES (apply in order):
1. Phishing / spam / impersonation β†’ REMOVE
2. Explicit hate speech, direct threats β†’ REMOVE
3. Political satire of a PUBLIC FIGURE β†’ ALLOW_WITH_WARNING
4. News / journalism / education β†’ ALLOW_WITH_WARNING
5. Gaming / sports banter (confirmed) β†’ ALLOW
6. Adjusted confidence after validation < 0.6 β†’ ESCALATE
7. Genuinely ambiguous, no clear rule β†’ ESCALATE
8. Safe, no signals triggered β†’ ALLOW
NEVER jump from Stage 1 directly to final_decision.
ALWAYS: [optional tools] β†’ extract_signals β†’ final_decision
Respond in strict JSON only. No markdown."""
SIGNAL_DEFAULTS = {
"target": "none", "is_protected_class": False,
"toxicity_level": 0.5, "is_direct_attack": False,
"context_type": "statement", "intent": "ambiguous",
"confidence": 0.5, "abusive_language_present": False,
"content_flags": [],
}
VALID_TOOLS = {"read_comments","check_user_history","check_entity_status","view_policy"}
VALID_DECISIONS = {"ALLOW","REMOVE","ALLOW_WITH_WARNING","ESCALATE"}
VALID_TYPES = {"use_tool","extract_signals","final_decision"}
def _clamp_signals(raw: dict) -> dict:
result = SIGNAL_DEFAULTS.copy()
for k in SIGNAL_DEFAULTS:
if k in raw:
result[k] = raw[k]
try:
result["toxicity_level"] = max(0.0, min(1.0, float(result["toxicity_level"])))
result["confidence"] = max(0.0, min(1.0, float(result["confidence"])))
except (TypeError, ValueError):
result["toxicity_level"] = 0.5
result["confidence"] = 0.5
if not isinstance(result["content_flags"], list):
result["content_flags"] = []
return result
def _parse(text: str) -> dict:
text = text.strip()
s, e = text.find("{"), text.rfind("}") + 1
if s == -1 or e == 0:
raise ValueError(f"No JSON in: {text}")
return json.loads(text[s:e])
def _normalize(raw: dict) -> dict:
    """Validate a parsed model reply and coerce it into a legal env action.

    Any malformed reply (unknown action type, unknown tool, missing or
    empty signals, unknown decision) degrades to a safe ESCALATE decision.

    Returns:
        A dict shaped for the environment's /step endpoint.
    """
    def _escalate() -> dict:
        # Safe fallback action when the model's output is unusable.
        return {"action_type": "final_decision", "final_decision": "ESCALATE"}

    action_type = raw.get("action_type", "")
    if action_type not in VALID_TYPES:
        return _escalate()
    if action_type == "use_tool":
        tool = raw.get("tool_name", "")
        if tool not in VALID_TOOLS:
            return _escalate()
        return {"action_type": "use_tool", "tool_name": tool}
    if action_type == "extract_signals":
        signals = raw.get("signals")
        if not signals:
            return _escalate()
        return {"action_type": "extract_signals", "signals": _clamp_signals(signals)}
    # Only "final_decision" remains at this point.
    decision = raw.get("final_decision", "ESCALATE")
    if decision not in VALID_DECISIONS:
        decision = "ESCALATE"
    return {"action_type": "final_decision", "final_decision": decision}
def _obs_to_prompt(obs: dict) -> str:
lines = [
f"=== TICKET {obs.get('ticket_id','')} (Step {obs.get('step_number',0)}) ===",
f"\nPOST TEXT:\n{obs.get('post_text','')}",
f"\nIMAGE:\n{obs.get('image_description','')}",
]
for key, label in [
("comments_found","COMMENTS"),("user_history_found","USER HISTORY"),
("entity_status_found","ENTITY STATUS"),("policy_found","POLICY"),
]:
if obs.get(key):
lines.append(f"\n{label}:\n{obs[key]}")
if obs.get("extracted_signals"):
lines.append(f"\nYOUR EXTRACTED SIGNALS:\n{json.dumps(obs['extracted_signals'],indent=2)}")
if obs.get("validation_result"):
v = obs["validation_result"]
hint = "⚠️ YES β€” prefer ESCALATE" if v.get("needs_escalation_hint") else "No"
lines.append(
f"\nπŸ“‹ VALIDATION:\n"
f" Adj. Confidence : {v.get('adjusted_confidence')}\n"
f" Issues : {v.get('consistency_issues')}\n"
f" Escalation Hint : {hint}"
)
if not obs.get("extracted_signals"):
lines.append("\n⚠️ REMINDER: Call extract_signals before final_decision.")
lines.append("\nYour next action (strict JSON only):")
return "\n".join(lines)
def run_task(client: OpenAI, task_id: str) -> float:
    """Run one evaluation episode for *task_id* and return its final reward.

    Resets the environment server until it serves the requested ticket
    (up to 30 attempts), then alternates LLM completions with /step calls
    for at most 14 steps, printing a per-step trace and — when the episode
    completes — a detailed reward breakdown.

    Args:
        client: Configured OpenAI-compatible chat client.
        task_id: Ticket ID to request from the environment (e.g. "T-001").

    Returns:
        The episode's final reward; 0.0 if the model output could not be
        parsed or the episode never reached a terminal state.

    Raises:
        RuntimeError: if the server never serves the requested ticket.
        requests.HTTPError: on a non-2xx response from the env server.
    """
    for _ in range(30):
        # Request a specific episode from the environment by ID.
        r = requests.post(
            f"{ENV_BASE_URL}/reset",
            json={"episode_id": task_id},
            timeout=10
        )
        r.raise_for_status()
        obs = r.json()
        # Handle both flat (TrustObservation) and wrapped response shapes.
        if isinstance(obs, dict) and "observation" in obs:
            obs = obs["observation"]
        if obs.get("ticket_id") == task_id:
            break
    else:
        # for/else: loop exhausted without break -> ticket never served.
        raise RuntimeError(f"Could not get task {task_id} after 30 resets.")
    print(f"\n{'='*62}\nTask: {task_id} | Starting...\n{'='*62}")
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    final_reward = 0.0
    for step_num in range(14):
        messages.append({"role": "user", "content": _obs_to_prompt(obs)})
        time.sleep(0.5)  # light rate limiting between LLM calls
        resp = client.chat.completions.create(
            model=MODEL_NAME, messages=messages, temperature=0.0,
            response_format={"type": "json_object"},
        )
        llm_text = resp.choices[0].message.content or ""
        messages.append({"role": "assistant", "content": llm_text})
        try:
            action = _normalize(_parse(llm_text))
        except Exception as ex:
            # Unparseable model output ends the episode with reward 0.0.
            print(f" [Step {step_num+1}] Parse error: {ex}"); break
        atype = action["action_type"]
        if atype == "use_tool":
            print(f" [Step {step_num+1}] πŸ”§ use_tool β†’ {action.get('tool_name')}")
        elif atype == "extract_signals":
            s = action.get("signals", {})
            print(f" [Step {step_num+1}] πŸ” extract_signals β†’ "
                  f"intent={s.get('intent')} | ctx={s.get('context_type')} | "
                  f"tox={s.get('toxicity_level')} | conf={s.get('confidence')}")
        else:
            print(f" [Step {step_num+1}] βš–οΈ final_decision β†’ {action.get('final_decision')}")
        r2 = requests.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
        r2.raise_for_status()
        result = r2.json()
        # Handle flat (TrustObservation) and wrapped response shapes; reward
        # may live on the wrapper or the observation depending on server version.
        if "observation" in result:
            obs = result["observation"]
            done = result.get("done", obs.get("done", False))
            final_reward = float(result.get("reward") or obs.get("reward") or 0.0)
        else:
            obs = result
            done = result.get("done", False)
            final_reward = float(result.get("reward") or 0.0)
        if done:
            # Terminal state: print the episode summary and reward breakdown.
            info = obs.get("info") or {}
            bd = info.get("reward_breakdown", {})
            pol = info.get("policy_recommendation", {})
            vr = obs.get("validation_result") or {}
            print(f"\n ── EPISODE COMPLETE {'─'*42}")
            print(f" Decision: {info.get('final_decision','N/A')}")
            print(f" Ground Truth: {info.get('ground_truth','N/A')}")
            print(f" Policy Engine: {pol.get('recommended','N/A')} "
                  f"[{pol.get('rule_strength','?')} rule] ({pol.get('reason','?')})")
            print(f" Signals Extracted: {'βœ…' if info.get('signals_extracted') else '❌ SKIPPED'}")
            print(f" Tools Used: {info.get('tools_used', [])}")
            print(f" Required Tools: {info.get('required_tools', [])}")
            print(f" Adj. Confidence: {vr.get('adjusted_confidence','N/A')}")
            print(f" Issues: {vr.get('consistency_issues',[])}")
            print(f" Ambiguity / Risk: {info.get('ambiguity_level','?')} / {info.get('risk_level','?')}")
            if bd:
                print(f"\n ── Reward Breakdown {'─'*42}")
                print(f" 1. Base Decision Score: {bd.get('base_score',0):+.4f}")
                print(f" 2. Policy Alignment: {bd.get('policy_alignment',0):+.4f}")
                print(f" 3. Signal Accuracy Bonus: {bd.get('signal_accuracy_bonus',0):+.4f}")
                print(f" 4. Escalation Adjustment: {bd.get('escalation_adj',0):+.4f}")
                print(f" 5. Signal Process Bonus: {bd.get('signal_bonus',0):+.4f}")
                print(f" Tool Cost: -{bd.get('tool_cost',0):.4f}")
                print(f" Tool Miss Penalty: -{bd.get('tool_miss_penalty',0):.4f}")
                print(f" Validation Penalty: -{bd.get('validation_penalty',0):.4f}")
                print(f" Risk Penalty: -{bd.get('risk_penalty',0):.4f}")
                print(f" Confidence Discipline: -{bd.get('confidence_penalty',0):.4f}")
                print(f" {'─'*60}")
                print(f" FINAL REWARD: {bd.get('final_reward',0):.4f}")
            print(f"\n SCORE: {final_reward:.4f}")
            break
    return final_reward
def main() -> None:
    """Evaluate the baseline model on the three fixed tickets and print a summary."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    rule = "=" * 62
    print(rule)
    print("Trust & Safety RL Environment β€” Baseline Evaluation")
    print(rule)
    print(f"Model : {MODEL_NAME}")
    print(f"LLM API : {API_BASE_URL}")
    print(f"Env Server : {ENV_BASE_URL}")
    print(f"Reward : Accuracy Β· Policy Β· Signals Β· Escalation")
    print(f" Tools Β· Consistency Β· Risk Β· Confidence")
    # (ticket id, short description, nominal risk level) per benchmark task.
    tasks = [
        ("T-001", "Easy β€” Phishing Spam", "low"),
        ("T-002", "Medium β€” Gaming Banter", "low"),
        ("T-003", "Hard β€” Political Satire", "high"),
    ]
    scores = []
    for task_id, description, risk in tasks:
        print(f"\n\n>>> {task_id} | {description} | Risk: {risk}")
        scores.append((task_id, description, run_task(client, task_id)))
    print("\n" + rule)
    print("FINAL BASELINE RESULTS")
    print(rule)
    total = 0.0
    for task_id, description, score in scores:
        status = 'βœ… PASS' if score >= 0.6 else '❌ FAIL'
        print(f" {task_id} | {description:<32} | {score:.4f} {status}")
        total += score
    values = [score for _, _, score in scores]
    print(f"\n Average : {total/len(scores):.4f}")
    print(f" Min : {min(values):.4f} | Max : {max(values):.4f}")
    print(rule)
# Script entry point: allows `python inference.py` while keeping the
# module importable without side effects beyond configuration checks.
if __name__ == "__main__":
    main()