# NOTE(review): removed extraction artifacts ("Spaces:" / "Configuration error")
# that preceded the module docstring — they were not part of the program.
"""
inference.py — Trust & Safety RL Environment Evaluation
========================================================

MANDATORY env vars:
    API_BASE_URL   LLM endpoint (e.g. https://router.huggingface.co/v1)
    MODEL_NAME     Model ID (e.g. meta-llama/Llama-3.1-8B-Instruct)
    HF_TOKEN       API key
    ENV_BASE_URL   Environment server URL (default: http://localhost:8000)
"""
import os, json, time, requests
from openai import OpenAI

# LLM endpoint, credentials and model id. HF_TOKEN is preferred; API_KEY is
# accepted as a fallback credential variable.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
# Base URL of the RL environment server (serves the /reset and /step endpoints).
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:8000")

# Fail fast at import time when credentials are missing.
if not API_KEY:
    raise EnvironmentError("Set HF_TOKEN (your Hugging Face / API key).")
# NOTE: MODEL_NAME has a default above, so this guard only fires if the
# env var is explicitly set to an empty string.
if not MODEL_NAME:
    raise EnvironmentError("Set MODEL_NAME.")
# System prompt: forces the model through a strict 3-stage moderation
# pipeline and constrains every reply to a single strict-JSON action.
# (Mis-encoded characters — dashes, arrows, dividers garbled to "β" — have
# been restored to their intended glyphs.)
SYSTEM_PROMPT = """You are a Trust & Safety analyst. Follow this EXACT 3-stage pipeline.
Skipping any stage results in a score penalty.
────────────────────────────────────────────
STAGE 1 — Investigate (optional)
────────────────────────────────────────────
Use tools only when context is needed. For OBVIOUS violations (spam, phishing), skip.
{"action_type": "use_tool", "tool_name": "read_comments"}
{"action_type": "use_tool", "tool_name": "check_user_history"}
{"action_type": "use_tool", "tool_name": "check_entity_status"}
{"action_type": "use_tool", "tool_name": "view_policy"}
────────────────────────────────────────────
STAGE 2 — Extract signals (MANDATORY)
────────────────────────────────────────────
You MUST call extract_signals before final_decision. No exceptions.
You are a FEATURE EXTRACTOR — output what the content shows, do NOT decide.
{"action_type": "extract_signals", "signals": {
  "target": "individual|group|religion|race|gender|political|none",
  "is_protected_class": true|false,
  "toxicity_level": 0.0-1.0,
  "is_direct_attack": true|false,
  "context_type": "statement|satire|news|meme|personal_attack|protest|education|spam",
  "intent": "harmful|descriptive|ambiguous|satirical|educational",
  "confidence": 0.0-1.0,
  "abusive_language_present": true|false,
  "content_flags": ["spam","phishing","slur","threat","misinformation",
                    "political_content","public_figure","competitive_language"]
}}
After submitting signals you receive VALIDATION. If needs_escalation_hint is true → prefer ESCALATE.
────────────────────────────────────────────
STAGE 3 — Final decision (MANDATORY)
────────────────────────────────────────────
{"action_type": "final_decision", "final_decision": "ALLOW|REMOVE|ALLOW_WITH_WARNING|ESCALATE"}
DECISION RULES (apply in order):
1. Phishing / spam / impersonation → REMOVE
2. Explicit hate speech, direct threats → REMOVE
3. Political satire of a PUBLIC FIGURE → ALLOW_WITH_WARNING
4. News / journalism / education → ALLOW_WITH_WARNING
5. Gaming / sports banter (confirmed) → ALLOW
6. Adjusted confidence after validation < 0.6 → ESCALATE
7. Genuinely ambiguous, no clear rule → ESCALATE
8. Safe, no signals triggered → ALLOW
NEVER jump from Stage 1 directly to final_decision.
ALWAYS: [optional tools] → extract_signals → final_decision
Respond in strict JSON only. No markdown."""
| SIGNAL_DEFAULTS = { | |
| "target": "none", "is_protected_class": False, | |
| "toxicity_level": 0.5, "is_direct_attack": False, | |
| "context_type": "statement", "intent": "ambiguous", | |
| "confidence": 0.5, "abusive_language_present": False, | |
| "content_flags": [], | |
| } | |
| VALID_TOOLS = {"read_comments","check_user_history","check_entity_status","view_policy"} | |
| VALID_DECISIONS = {"ALLOW","REMOVE","ALLOW_WITH_WARNING","ESCALATE"} | |
| VALID_TYPES = {"use_tool","extract_signals","final_decision"} | |
| def _clamp_signals(raw: dict) -> dict: | |
| result = SIGNAL_DEFAULTS.copy() | |
| for k in SIGNAL_DEFAULTS: | |
| if k in raw: | |
| result[k] = raw[k] | |
| try: | |
| result["toxicity_level"] = max(0.0, min(1.0, float(result["toxicity_level"]))) | |
| result["confidence"] = max(0.0, min(1.0, float(result["confidence"]))) | |
| except (TypeError, ValueError): | |
| result["toxicity_level"] = 0.5 | |
| result["confidence"] = 0.5 | |
| if not isinstance(result["content_flags"], list): | |
| result["content_flags"] = [] | |
| return result | |
| def _parse(text: str) -> dict: | |
| text = text.strip() | |
| s, e = text.find("{"), text.rfind("}") + 1 | |
| if s == -1 or e == 0: | |
| raise ValueError(f"No JSON in: {text}") | |
| return json.loads(text[s:e]) | |
def _normalize(raw: dict) -> dict:
    """Coerce a parsed model action into one of the three canonical shapes.

    Anything malformed — unknown action type, unrecognized tool name,
    missing/empty signals, or an invalid decision string — degrades into a
    safe ESCALATE final decision instead of raising.
    """
    def escalate() -> dict:
        return {"action_type": "final_decision", "final_decision": "ESCALATE"}

    action_type = raw.get("action_type", "")
    if action_type == "use_tool":
        tool = raw.get("tool_name", "")
        if tool not in VALID_TOOLS:
            return escalate()
        return {"action_type": "use_tool", "tool_name": tool}
    if action_type == "extract_signals":
        signals = raw.get("signals")
        if not signals:
            return escalate()
        return {"action_type": "extract_signals", "signals": _clamp_signals(signals)}
    if action_type == "final_decision":
        decision = raw.get("final_decision", "ESCALATE")
        if decision not in VALID_DECISIONS:
            return escalate()
        return {"action_type": "final_decision", "final_decision": decision}
    # action_type not in VALID_TYPES
    return escalate()
| def _obs_to_prompt(obs: dict) -> str: | |
| lines = [ | |
| f"=== TICKET {obs.get('ticket_id','')} (Step {obs.get('step_number',0)}) ===", | |
| f"\nPOST TEXT:\n{obs.get('post_text','')}", | |
| f"\nIMAGE:\n{obs.get('image_description','')}", | |
| ] | |
| for key, label in [ | |
| ("comments_found","COMMENTS"),("user_history_found","USER HISTORY"), | |
| ("entity_status_found","ENTITY STATUS"),("policy_found","POLICY"), | |
| ]: | |
| if obs.get(key): | |
| lines.append(f"\n{label}:\n{obs[key]}") | |
| if obs.get("extracted_signals"): | |
| lines.append(f"\nYOUR EXTRACTED SIGNALS:\n{json.dumps(obs['extracted_signals'],indent=2)}") | |
| if obs.get("validation_result"): | |
| v = obs["validation_result"] | |
| hint = "β οΈ YES β prefer ESCALATE" if v.get("needs_escalation_hint") else "No" | |
| lines.append( | |
| f"\nπ VALIDATION:\n" | |
| f" Adj. Confidence : {v.get('adjusted_confidence')}\n" | |
| f" Issues : {v.get('consistency_issues')}\n" | |
| f" Escalation Hint : {hint}" | |
| ) | |
| if not obs.get("extracted_signals"): | |
| lines.append("\nβ οΈ REMINDER: Call extract_signals before final_decision.") | |
| lines.append("\nYour next action (strict JSON only):") | |
| return "\n".join(lines) | |
def run_task(client: OpenAI, task_id: str) -> float:
    """Run one moderation episode for *task_id* against the environment server.

    Resets the environment (up to 30 attempts) until the server serves the
    requested ticket, then drives the LLM through at most 14 steps of the
    tool / extract_signals / final_decision loop, printing progress and, on
    completion, the reward breakdown reported by the server.

    Args:
        client:  configured OpenAI-compatible client used for chat completions.
        task_id: ticket id the environment should serve (e.g. "T-001").

    Returns:
        The final scalar reward (0.0 if the episode never reached ``done``).

    Raises:
        RuntimeError: if 30 resets never produce the requested ticket.
        requests.HTTPError: on a non-2xx response from the environment server.

    Note: only the garbled glyphs in the progress/summary strings were
    repaired; the control flow is unchanged.
    """
    # --- Reset until the server hands us the requested ticket. ---------------
    for _ in range(30):
        r = requests.post(
            f"{ENV_BASE_URL}/reset",
            json={"episode_id": task_id},  # ask for this exact episode
            timeout=10,
        )
        r.raise_for_status()
        obs = r.json()
        # The server may return the observation flat (TrustObservation) or
        # wrapped under an "observation" key — accept both shapes.
        if isinstance(obs, dict) and "observation" in obs:
            obs = obs["observation"]
        if obs.get("ticket_id") == task_id:
            break
    else:  # for/else: loop exhausted without break
        raise RuntimeError(f"Could not get task {task_id} after 30 resets.")

    print(f"\n{'='*62}\nTask: {task_id} | Starting...\n{'='*62}")
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    final_reward = 0.0

    for step_num in range(14):
        messages.append({"role": "user", "content": _obs_to_prompt(obs)})
        time.sleep(0.5)  # light pacing between LLM calls
        resp = client.chat.completions.create(
            model=MODEL_NAME, messages=messages, temperature=0.0,
            response_format={"type": "json_object"},
        )
        llm_text = resp.choices[0].message.content or ""
        messages.append({"role": "assistant", "content": llm_text})
        try:
            action = _normalize(_parse(llm_text))
        except Exception as ex:
            # Unparseable model output ends the episode with the reward so far.
            print(f"  [Step {step_num+1}] Parse error: {ex}")
            break

        atype = action["action_type"]
        if atype == "use_tool":
            print(f"  [Step {step_num+1}] 🔧 use_tool → {action.get('tool_name')}")
        elif atype == "extract_signals":
            s = action.get("signals", {})
            print(f"  [Step {step_num+1}] 📊 extract_signals → "
                  f"intent={s.get('intent')} | ctx={s.get('context_type')} | "
                  f"tox={s.get('toxicity_level')} | conf={s.get('confidence')}")
        else:
            print(f"  [Step {step_num+1}] ⚖️ final_decision → {action.get('final_decision')}")

        r2 = requests.post(f"{ENV_BASE_URL}/step", json=action, timeout=30)
        r2.raise_for_status()
        result = r2.json()
        # Again tolerate both flat and wrapped response shapes.
        if "observation" in result:
            obs = result["observation"]
            done = result.get("done", obs.get("done", False))
            final_reward = float(result.get("reward") or obs.get("reward") or 0.0)
        else:
            obs = result
            done = result.get("done", False)
            final_reward = float(result.get("reward") or 0.0)

        if done:
            info = obs.get("info") or {}
            bd = info.get("reward_breakdown", {})
            pol = info.get("policy_recommendation", {})
            vr = obs.get("validation_result") or {}
            print(f"\n  ── EPISODE COMPLETE {'─'*42}")
            print(f"  Decision:          {info.get('final_decision','N/A')}")
            print(f"  Ground Truth:      {info.get('ground_truth','N/A')}")
            print(f"  Policy Engine:     {pol.get('recommended','N/A')} "
                  f"[{pol.get('rule_strength','?')} rule] ({pol.get('reason','?')})")
            print(f"  Signals Extracted: {'✅' if info.get('signals_extracted') else '❌ SKIPPED'}")
            print(f"  Tools Used:        {info.get('tools_used', [])}")
            print(f"  Required Tools:    {info.get('required_tools', [])}")
            print(f"  Adj. Confidence:   {vr.get('adjusted_confidence','N/A')}")
            print(f"  Issues:            {vr.get('consistency_issues',[])}")
            print(f"  Ambiguity / Risk:  {info.get('ambiguity_level','?')} / {info.get('risk_level','?')}")
            if bd:
                print(f"\n  ── Reward Breakdown {'─'*42}")
                print(f"  1. Base Decision Score:    {bd.get('base_score',0):+.4f}")
                print(f"  2. Policy Alignment:       {bd.get('policy_alignment',0):+.4f}")
                print(f"  3. Signal Accuracy Bonus:  {bd.get('signal_accuracy_bonus',0):+.4f}")
                print(f"  4. Escalation Adjustment:  {bd.get('escalation_adj',0):+.4f}")
                print(f"  5. Signal Process Bonus:   {bd.get('signal_bonus',0):+.4f}")
                print(f"     Tool Cost:              -{bd.get('tool_cost',0):.4f}")
                print(f"     Tool Miss Penalty:      -{bd.get('tool_miss_penalty',0):.4f}")
                print(f"     Validation Penalty:     -{bd.get('validation_penalty',0):.4f}")
                print(f"     Risk Penalty:           -{bd.get('risk_penalty',0):.4f}")
                print(f"     Confidence Discipline:  -{bd.get('confidence_penalty',0):.4f}")
                print(f"  {'─'*60}")
                print(f"  FINAL REWARD:              {bd.get('final_reward',0):.4f}")
            print(f"\n  SCORE: {final_reward:.4f}")
            break
    return final_reward
def main() -> None:
    """Run the three baseline tickets and print a summary scoreboard.

    Builds the OpenAI-compatible client from module-level config, runs
    ``run_task`` for each hard-coded ticket, then prints per-task scores
    (PASS at >= 0.6) plus average / min / max. Garbled glyphs in the banner
    and scoreboard strings have been restored.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    print("=" * 62)
    print("Trust & Safety RL Environment — Baseline Evaluation")
    print("=" * 62)
    print(f"Model      : {MODEL_NAME}")
    print(f"LLM API    : {API_BASE_URL}")
    print(f"Env Server : {ENV_BASE_URL}")
    print(f"Reward     : Accuracy · Policy · Signals · Escalation")
    print(f"             Tools · Consistency · Risk · Confidence")

    # (ticket_id, description, risk level) — served by the environment server.
    tasks = [
        ("T-001", "Easy — Phishing Spam", "low"),
        ("T-002", "Medium — Gaming Banter", "low"),
        ("T-003", "Hard — Political Satire", "high"),
    ]
    scores = []
    for tid, desc, risk in tasks:
        print(f"\n\n>>> {tid} | {desc} | Risk: {risk}")
        scores.append((tid, desc, run_task(client, tid)))

    print("\n" + "=" * 62)
    print("FINAL BASELINE RESULTS")
    print("=" * 62)
    total = 0.0
    for tid, desc, s in scores:
        print(f"  {tid} | {desc:<32} | {s:.4f} {'✅ PASS' if s >= 0.6 else '❌ FAIL'}")
        total += s
    vals = [s for _, _, s in scores]
    print(f"\n  Average : {total/len(scores):.4f}")
    print(f"  Min     : {min(vals):.4f} | Max : {max(vals):.4f}")
    print("=" * 62)


if __name__ == "__main__":
    main()