"""
inference.py — Trust & Safety RL Environment Evaluation
========================================================

Drives an LLM agent, through an OpenAI-compatible chat-completions API, over
the Trust & Safety environment server for a fixed list of tasks, and prints a
per-task reward breakdown followed by summary statistics.

Environment variables:
    HF_TOKEN       API key (required; API_KEY is accepted as a fallback)
    API_BASE_URL   LLM endpoint (default: https://router.huggingface.co/v1)
    MODEL_NAME     Model ID (default: meta-llama/Llama-3.1-8B-Instruct)
    ENV_BASE_URL   Environment server URL (default: http://localhost:8000)
"""

import json
import math
import os
import time

import requests
from openai import OpenAI

API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000")

# Fail fast at startup rather than on the first API call.
if not API_KEY:
    raise EnvironmentError("Set HF_TOKEN (your Hugging Face / API key).")
if not MODEL_NAME:  # reachable when MODEL_NAME is exported as an empty string
    raise EnvironmentError("Set MODEL_NAME.")

SYSTEM_PROMPT = """You are a Trust & Safety analyst. Follow this EXACT 3-stage pipeline.
Skipping any stage results in a score penalty.

════════════════════════════════════════════
STAGE 1 — Investigate (optional)
════════════════════════════════════════════
Use tools only when context is needed. For OBVIOUS violations (spam, phishing), skip.

{"action_type": "use_tool", "tool_name": "read_comments"}
{"action_type": "use_tool", "tool_name": "check_user_history"}
{"action_type": "use_tool", "tool_name": "check_entity_status"}
{"action_type": "use_tool", "tool_name": "view_policy"}

════════════════════════════════════════════
STAGE 2 — Extract signals (MANDATORY)
════════════════════════════════════════════
You MUST call extract_signals before final_decision. No exceptions.
You are a FEATURE EXTRACTOR — output what the content shows, do NOT decide.

{"action_type": "extract_signals", "signals": {
  "target": "individual|group|religion|race|gender|political|none",
  "is_protected_class": true|false,
  "toxicity_level": 0.0-1.0,
  "is_direct_attack": true|false,
  "context_type": "statement|satire|news|meme|personal_attack|protest|education|spam",
  "intent": "harmful|descriptive|ambiguous|satirical|educational",
  "confidence": 0.0-1.0,
  "abusive_language_present": true|false,
  "content_flags": ["spam","phishing","slur","threat","misinformation",
                    "political_content","public_figure","competitive_language"]
}}

After submitting signals you receive VALIDATION.
If needs_escalation_hint is true → prefer ESCALATE.

════════════════════════════════════════════
STAGE 3 — Final decision (MANDATORY)
════════════════════════════════════════════
{"action_type": "final_decision", "final_decision": "ALLOW|REMOVE|ALLOW_WITH_WARNING|ESCALATE"}

DECISION RULES (apply in order):
1. Phishing / spam / impersonation → REMOVE
2. Explicit hate speech, direct threats → REMOVE
3. Political satire of a PUBLIC FIGURE → ALLOW_WITH_WARNING
4. News / journalism / education → ALLOW_WITH_WARNING
5. Gaming / sports banter (confirmed) → ALLOW
6. Adjusted confidence after validation < 0.6 → ESCALATE
7. Genuinely ambiguous, no clear rule → ESCALATE
8. Safe, no signals triggered → ALLOW

NEVER jump from Stage 1 directly to final_decision.
ALWAYS: [optional tools] → extract_signals → final_decision

Respond in strict JSON only. No markdown."""

# Fallback value for every signal field; also the whitelist of accepted keys.
SIGNAL_DEFAULTS = {
    "target": "none",
    "is_protected_class": False,
    "toxicity_level": 0.5,
    "is_direct_attack": False,
    "context_type": "statement",
    "intent": "ambiguous",
    "confidence": 0.5,
    "abusive_language_present": False,
    "content_flags": [],
}

VALID_TOOLS = {"read_comments", "check_user_history", "check_entity_status", "view_policy"}
VALID_DECISIONS = {"ALLOW", "REMOVE", "ALLOW_WITH_WARNING", "ESCALATE"}
VALID_TYPES = {"use_tool", "extract_signals", "final_decision"}

# A task whose score is at or above this value is reported as PASS.
PASS_THRESHOLD = 0.6

_DECODER = json.JSONDecoder()


def _escalate() -> dict:
    """Return a fresh fallback action: a final ESCALATE decision."""
    return {"action_type": "final_decision", "final_decision": "ESCALATE"}


def _clean_str(value: object) -> str:
    """Return *value* with surrounding whitespace stripped, or "" if it is not a str."""
    return value.strip() if isinstance(value, str) else ""


def _unit_float(value: object) -> float:
    """Coerce *value* to a float clamped to [0.0, 1.0].

    Values that ``float()`` rejects (or overflows on), and NaN, yield 0.5.
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.5
    if math.isnan(number):
        # max/min would otherwise silently turn NaN into 1.0.
        return 0.5
    return max(0.0, min(1.0, number))


def _clamp_signals(raw: dict) -> dict:
    """Merge model-supplied signals over SIGNAL_DEFAULTS and sanitise them.

    Args:
        raw: Signal mapping produced by the model. Non-dict input is treated
            as empty.

    Returns:
        A new dict with exactly the SIGNAL_DEFAULTS keys (same order).
        ``toxicity_level`` and ``confidence`` are clamped to [0.0, 1.0], each
        falling back to 0.5 independently when invalid. ``content_flags`` is
        always a fresh list, so the shared default is never aliased.
    """
    if not isinstance(raw, dict):
        raw = {}
    result = {key: raw.get(key, default) for key, default in SIGNAL_DEFAULTS.items()}
    result["toxicity_level"] = _unit_float(result["toxicity_level"])
    result["confidence"] = _unit_float(result["confidence"])
    flags = result["content_flags"]
    result["content_flags"] = list(flags) if isinstance(flags, list) else []
    return result


def _parse(text: str) -> dict:
    """Return the first JSON object embedded in *text*.

    Surrounding prose, code fences or stray braces are tolerated: each ``{``
    is tried in turn as the start of an object.

    Raises:
        ValueError: if no decodable JSON object is found.
    """
    text = text.strip()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = _DECODER.raw_decode(text[start:])
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
        else:
            return obj
    raise ValueError(f"No JSON in: {text}")


def _normalize(raw: dict) -> dict:
    """Validate a parsed model action and convert it to the env's action schema.

    Tokens are matched after stripping whitespace; action types and tool names
    are case-folded to lower case, decisions to upper case. Anything that is
    invalid, including non-string values, becomes a final ESCALATE decision.
    """
    if not isinstance(raw, dict):
        return _escalate()

    action_type = _clean_str(raw.get("action_type")).lower()
    if action_type not in VALID_TYPES:
        return _escalate()

    if action_type == "use_tool":
        tool = _clean_str(raw.get("tool_name")).lower()
        if tool in VALID_TOOLS:
            return {"action_type": "use_tool", "tool_name": tool}
        return _escalate()

    if action_type == "extract_signals":
        sigs = raw.get("signals")
        if isinstance(sigs, dict) and sigs:
            return {"action_type": "extract_signals", "signals": _clamp_signals(sigs)}
        return _escalate()

    decision = _clean_str(raw.get("final_decision", "ESCALATE")).upper()
    return {
        "action_type": "final_decision",
        "final_decision": decision if decision in VALID_DECISIONS else "ESCALATE",
    }


def _obs_to_prompt(obs: dict) -> str:
    """Render an environment observation as the next user message for the LLM.

    Includes the ticket text and image description, any tool outputs already
    present in *obs*, the previously extracted signals and validation result,
    and a reminder to call extract_signals while none have been extracted.
    """
    lines = [
        f"=== TICKET {obs.get('ticket_id', '')} (Step {obs.get('step_number', 0)}) ===",
        f"\nPOST TEXT:\n{obs.get('post_text', '')}",
        f"\nIMAGE:\n{obs.get('image_description', '')}",
    ]
    for key, label in (
        ("comments_found", "COMMENTS"),
        ("user_history_found", "USER HISTORY"),
        ("entity_status_found", "ENTITY STATUS"),
        ("policy_found", "POLICY"),
    ):
        if obs.get(key):
            lines.append(f"\n{label}:\n{obs[key]}")
    if obs.get("extracted_signals"):
        lines.append(
            f"\nYOUR EXTRACTED SIGNALS:\n{json.dumps(obs['extracted_signals'], indent=2)}"
        )
    if obs.get("validation_result"):
        v = obs["validation_result"]
        hint = "⚠️ YES — prefer ESCALATE" if v.get("needs_escalation_hint") else "No"
        lines.append(
            f"\n📋 VALIDATION:\n"
            f"  Adj. Confidence : {v.get('adjusted_confidence')}\n"
            f"  Issues          : {v.get('consistency_issues')}\n"
            f"  Escalation Hint : {hint}"
        )
    if not obs.get("extracted_signals"):
        lines.append("\n⚠️ REMINDER: Call extract_signals before final_decision.")
    lines.append("\nYour next action (strict JSON only):")
    return "\n".join(lines)


def _reset_to_task(task_id: str, attempts: int = 30) -> dict:
    """POST /reset until the env serves *task_id*; return that observation.

    Accepts both a flat observation and one wrapped under "observation".

    Raises:
        RuntimeError: if *task_id* is not served within *attempts* resets.
        requests.HTTPError: on a non-2xx reset response.
    """
    for _ in range(attempts):
        r = requests.post(
            f"{ENV_BASE_URL}/reset",
            json={"episode_id": task_id},
            timeout=10,
        )
        r.raise_for_status()
        obs = r.json()
        if isinstance(obs, dict) and "observation" in obs:
            obs = obs["observation"]
        if obs.get("ticket_id") == task_id:
            return obs
    raise RuntimeError(f"Could not get task {task_id} after {attempts} resets.")


def _unpack_step(result: dict) -> tuple:
    """Split a /step response into ``(observation, done, reward)``.

    Handles both a flat observation and one wrapped under "observation"; in
    the wrapped form, top-level "done"/"reward" take precedence.
    """
    if "observation" in result:
        obs = result["observation"]
        done = result.get("done", obs.get("done", False))
        reward = float(result.get("reward") or obs.get("reward") or 0.0)
    else:
        obs = result
        done = result.get("done", False)
        reward = float(result.get("reward") or 0.0)
    return obs, bool(done), reward


def _log_action(step: int, action: dict) -> None:
    """Print a one-line summary of the action about to be sent to the env."""
    atype = action["action_type"]
    if atype == "use_tool":
        print(f"  [Step {step}] 🔧 use_tool → {action.get('tool_name')}")
    elif atype == "extract_signals":
        s = action.get("signals", {})
        print(
            f"  [Step {step}] 🔍 extract_signals → "
            f"intent={s.get('intent')} | ctx={s.get('context_type')} | "
            f"tox={s.get('toxicity_level')} | conf={s.get('confidence')}"
        )
    else:
        print(f"  [Step {step}] ⚖️ final_decision → {action.get('final_decision')}")


def _print_episode_summary(obs: dict, final_reward: float) -> None:
    """Print the end-of-episode report found in the final observation."""
    # `or {}` also covers keys that are present but null.
    info = obs.get("info") or {}
    bd = info.get("reward_breakdown") or {}
    pol = info.get("policy_recommendation") or {}
    vr = obs.get("validation_result") or {}

    print(f"\n  ── EPISODE COMPLETE {'─' * 42}")
    print(f"  Decision:           {info.get('final_decision', 'N/A')}")
    print(f"  Ground Truth:       {info.get('ground_truth', 'N/A')}")
    print(
        f"  Policy Engine:      {pol.get('recommended', 'N/A')} "
        f"[{pol.get('rule_strength', '?')} rule] ({pol.get('reason', '?')})"
    )
    print(f"  Signals Extracted:  {'✅' if info.get('signals_extracted') else '❌ SKIPPED'}")
    print(f"  Tools Used:         {info.get('tools_used', [])}")
    print(f"  Required Tools:     {info.get('required_tools', [])}")
    print(f"  Adj. Confidence:    {vr.get('adjusted_confidence', 'N/A')}")
    print(f"  Issues:             {vr.get('consistency_issues', [])}")
    print(
        f"  Ambiguity / Risk:   {info.get('ambiguity_level', '?')} / "
        f"{info.get('risk_level', '?')}"
    )

    if bd:
        print(f"\n  ── Reward Breakdown {'─' * 42}")
        print(f"  1. Base Decision Score:    {bd.get('base_score', 0):+.4f}")
        print(f"  2. Policy Alignment:       {bd.get('policy_alignment', 0):+.4f}")
        print(f"  3. Signal Accuracy Bonus:  {bd.get('signal_accuracy_bonus', 0):+.4f}")
        print(f"  4. Escalation Adjustment:  {bd.get('escalation_adj', 0):+.4f}")
        print(f"  5. Signal Process Bonus:   {bd.get('signal_bonus', 0):+.4f}")
        print(f"     Tool Cost:              -{bd.get('tool_cost', 0):.4f}")
        print(f"     Tool Miss Penalty:      -{bd.get('tool_miss_penalty', 0):.4f}")
        print(f"     Validation Penalty:     -{bd.get('validation_penalty', 0):.4f}")
        print(f"     Risk Penalty:           -{bd.get('risk_penalty', 0):.4f}")
        print(f"     Confidence Discipline:  -{bd.get('confidence_penalty', 0):.4f}")
        print(f"  {'─' * 60}")
        print(f"  FINAL REWARD:              {bd.get('final_reward', 0):.4f}")

    print(f"\n  SCORE: {final_reward:.4f}")


def run_task(client: OpenAI, task_id: str, max_steps: int = 14) -> float:
    """Run one episode of *task_id* with the LLM agent and return its reward.

    Args:
        client: OpenAI-compatible client used for chat completions.
        task_id: Ticket / episode identifier to reset the env to.
        max_steps: Maximum number of agent actions before giving up.

    Returns:
        The reward from the last /step response. This is 0.0 if the model's
        first reply cannot be parsed; on an unparsable reply the episode is
        abandoned.

    Raises:
        RuntimeError: if the env never serves *task_id* on reset.
        requests.HTTPError: on a non-2xx env response.
    """
    obs = _reset_to_task(task_id)
    print(f"\n{'=' * 62}\nTask: {task_id} | Starting...\n{'=' * 62}")

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    final_reward = 0.0

    for step_num in range(1, max_steps + 1):
        messages.append({"role": "user", "content": _obs_to_prompt(obs)})
        time.sleep(0.5)  # brief pause between LLM calls (presumably rate limiting)
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=0.0,
            response_format={"type": "json_object"},
        )
        llm_text = resp.choices[0].message.content or ""
        messages.append({"role": "assistant", "content": llm_text})

        try:
            action = _normalize(_parse(llm_text))
        except ValueError as ex:  # includes json.JSONDecodeError
            print(f"  [Step {step_num}] Parse error: {ex}")
            break

        _log_action(step_num, action)

        r = requests.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
        r.raise_for_status()
        obs, done, final_reward = _unpack_step(r.json())

        if done:
            _print_episode_summary(obs, final_reward)
            break
    else:
        print(f"  Episode did not finish within {max_steps} steps.")

    return final_reward


def main() -> None:
    """Evaluate the configured model on every task and print a summary table."""
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    print("=" * 62)
    print("Trust & Safety RL Environment — Baseline Evaluation")
    print("=" * 62)
    print(f"Model      : {MODEL_NAME}")
    print(f"LLM API    : {API_BASE_URL}")
    print(f"Env Server : {ENV_BASE_URL}")
    print("Reward     : Accuracy · Policy · Signals · Escalation")
    print("             Tools · Consistency · Risk · Confidence")

    tasks = [
        ("T-001", "Easy — Phishing Spam", "low"),
        ("T-002", "Medium — Gaming Banter", "low"),
        ("T-003", "Hard — Political Satire", "high"),
    ]

    scores = []
    for tid, desc, risk in tasks:
        print(f"\n\n>>> {tid} | {desc} | Risk: {risk}")
        scores.append((tid, desc, run_task(client, tid)))

    print("\n" + "=" * 62)
    print("FINAL BASELINE RESULTS")
    print("=" * 62)
    for tid, desc, s in scores:
        verdict = "✅ PASS" if s >= PASS_THRESHOLD else "❌ FAIL"
        print(f"  {tid} | {desc:<32} | {s:.4f}  {verdict}")

    vals = [s for _, _, s in scores]
    print(f"\n  Average : {sum(vals) / len(vals):.4f}")
    print(f"  Min     : {min(vals):.4f}  |  Max : {max(vals):.4f}")
    print("=" * 62)


if __name__ == "__main__":
    main()