"""
Baseline inference script for the Content Moderation Queue OpenEnv environment.

Runs a language model against all 3 tasks and reports reproducible scores.

Required environment variables:
    API_BASE_URL - LLM API endpoint (OpenAI-compatible)
    MODEL_NAME   - Model identifier
    HF_TOKEN     - API key / Hugging Face token

Optional environment variables:
    ENV_BASE_URL - Base URL of the environment server
                   (default: http://localhost:7860)

Stdout format strictly follows the START / STEP / END spec for automated scoring.
Diagnostics ([INFO] / [ERROR]) go to stderr.

Usage:
    export API_BASE_URL="https://api-inference.huggingface.co/v1"
    export MODEL_NAME="meta-llama/Meta-Llama-3-8B-Instruct"
    export HF_TOKEN="hf_..."
    python inference.py
"""

import os
import sys
import json
import time
import ssl

import certifi
import httpx
import requests
from openai import OpenAI

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")

TASKS = ["task_easy", "task_medium", "task_hard"]
MAX_RETRIES = 3

# Use certifi CA bundle to fix LibreSSL / macOS SSL certificate errors.
# The CA file path is shared by every `requests` call; the SSL context built
# from it is handed to httpx (previously the context was created but unused).
_CA_BUNDLE = certifi.where()
_ssl_context = ssl.create_default_context(cafile=_CA_BUNDLE)
_http_client = httpx.Client(verify=_ssl_context)

client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL, http_client=_http_client)

# ---------------------------------------------------------------------------
# Environment API helpers
# ---------------------------------------------------------------------------

# Session identifier returned by /reset; sent with every /step and /state call.
_session_id: str = ""


def env_reset(task_id: str, seed: int = 42) -> dict:
    """Start a new episode for *task_id* and remember its session id.

    Args:
        task_id: Task identifier sent to the ``/reset`` endpoint.
        seed: Seed forwarded to the environment.

    Returns:
        The decoded JSON body of the ``/reset`` response.

    Raises:
        requests.HTTPError: If the environment answers with an error status.
    """
    global _session_id
    # seed=42 ensures same post order every run → reproducible baseline scores
    resp = requests.post(
        f"{ENV_BASE_URL}/reset",
        params={"task_id": task_id, "seed": seed},
        timeout=30,
        verify=_CA_BUNDLE,
    )
    resp.raise_for_status()
    data = resp.json()
    _session_id = data.get("session_id", "")
    return data


def env_step(action_type: str, reasoning: str = "", violation_type: str = "none") -> dict:
    """Submit one moderation decision for the current session.

    Args:
        action_type: Moderation action to apply.
        reasoning: Free-text justification sent along with the action.
        violation_type: Violation label sent along with the action.

    Returns:
        The decoded JSON body of the ``/step`` response.

    Raises:
        requests.HTTPError: If the environment answers with an error status.
    """
    payload = {
        "action_type": action_type,
        "reasoning": reasoning,
        "violation_type": violation_type,
    }
    resp = requests.post(
        f"{ENV_BASE_URL}/step",
        params={"session_id": _session_id},
        json=payload,
        timeout=30,
        verify=_CA_BUNDLE,
    )
    resp.raise_for_status()
    return resp.json()


def env_state() -> dict:
    """Fetch the current session state from the ``/state`` endpoint.

    Returns:
        The decoded JSON body of the ``/state`` response.

    Raises:
        requests.HTTPError: If the environment answers with an error status.
    """
    resp = requests.get(
        f"{ENV_BASE_URL}/state",
        params={"session_id": _session_id},
        timeout=30,
        verify=_CA_BUNDLE,
    )
    resp.raise_for_status()
    return resp.json()


# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------

# NOTE(review): the JSON template at the end of the prompt has empty string
# placeholders; they look like stripped "<...>" hints — confirm against the
# original prompt before relying on them.
SYSTEM_PROMPT = """You are an expert content moderator. You will be shown social media posts and must decide the correct moderation action according to platform policy.

POLICY SUMMARY:
- approve : Content is safe and complies with all policies
- warn : Minor first-time violation; send user a warning, keep content
- remove : Clear policy violation; remove content, no account action
- ban_temporary : Repeat offender OR serious single violation
- ban_permanent : Zero-tolerance violation (CSAM, doxxing, 5+ prior violations, selling PII)
- escalate : Ambiguous case requiring human review OR crisis situation (self-harm signals)

IMPORTANT RULES:
1. Always read author history carefully — repeat violations escalate consequences
2. Context matters — gaming slang, creative writing, and recovery stories are NOT violations
3. For appeals with plausible evidence → escalate (do not auto-deny)
4. Self-harm crisis posts → escalate (do NOT remove — user may need support)
5. CSAM and PII selling → ban_permanent regardless of history

You must respond with valid JSON only:
{
 "action_type": "",
 "violation_type": "",
 "reasoning": ""
}"""


def build_user_prompt(obs: dict) -> str:
    """Render an observation dict as the user message for the LLM.

    Missing or null fields are tolerated (rendered as ``?`` or omitted) so a
    sparse observation never aborts the episode.

    Args:
        obs: Observation returned by the environment.

    Returns:
        The newline-joined prompt text.
    """
    lines = []
    lines.append(f"POST ID: {obs.get('post_id', '?')}")
    lines.append(f"CONTENT: {obs.get('content', '')}")
    lines.append(f"AUTHOR: {obs.get('author_id', '?')}")

    # `or {}` / `or []` also covers keys that are present but null.
    history = obs.get("author_history") or {}
    violations = history.get("total_violations", 0)
    prior = history.get("prior_actions") or []
    lines.append(f"AUTHOR HISTORY: {violations} total violations")
    for p in prior[-3:]:  # show last 3
        lines.append(f" - {p.get('date','?')}: {p.get('action','?')} for {p.get('violation_type','?')}")

    if obs.get("is_appeal"):
        lines.append(f"THIS IS AN APPEAL. Reason: {obs.get('appeal_reason', 'No reason given')}")

    thread = obs.get("context_thread") or []
    if thread:
        lines.append("THREAD CONTEXT:")
        for t in thread:
            # Truncate each thread entry to keep the prompt short.
            snippet = (t.get("content") or "")[:120]
            lines.append(f" [{t.get('action_taken','?')}] {snippet}")

    lines.append(f"\nAVAILABLE ACTIONS: {obs.get('available_actions', [])}")
    lines.append(f"TASK: {obs.get('task_description', '')}")
    lines.append("\nRespond with JSON only.")
    return "\n".join(lines)


def _parse_llm_json(raw: str) -> dict:
    """Extract the JSON object from a raw model reply.

    Handles replies wrapped in markdown code fences and replies that surround
    the object with prose.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded from *raw*.
    """
    text = raw.strip()

    # Extract JSON if wrapped in markdown code fences
    if "```" in text:
        text = text.split("```")[1].strip()
        if text[:4].lower() == "json":
            text = text[4:]

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span when prose surrounds the JSON.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])

    if not isinstance(parsed, dict):
        # Callers use .get() on the result, so lists/strings/numbers are invalid.
        raise json.JSONDecodeError("expected a JSON object", text, 0)
    return parsed


def call_llm(obs: dict) -> dict:
    """Ask the model for a moderation decision on *obs*.

    Retries up to ``MAX_RETRIES`` times on API errors or unparseable replies.

    Args:
        obs: Observation returned by the environment.

    Returns:
        The decoded JSON object from the model, or a default
        ``{"action_type": "approve", "violation_type": "none", ...}`` dict when
        every attempt fails.
    """
    user_prompt = build_user_prompt(obs)
    failure_reason = "JSON parse failed"

    for attempt in range(MAX_RETRIES):
        is_last_attempt = attempt == MAX_RETRIES - 1
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=256,
                temperature=0.0,
            )
            # content may be None (e.g. empty completion); treat as empty text.
            raw = response.choices[0].message.content or ""
            return _parse_llm_json(raw)

        except json.JSONDecodeError:
            failure_reason = "JSON parse failed"
            if not is_last_attempt:
                time.sleep(1)
                continue
        except Exception as e:
            # Deliberately broad: any client/transport failure is retried, then
            # reported, so one bad call never aborts the whole episode.
            failure_reason = f"LLM error: {e}"
            if not is_last_attempt:
                time.sleep(2)
                continue
            print(f"[ERROR] LLM call failed after {MAX_RETRIES} attempts: {e}", file=sys.stderr)

    # Fallback: return safe default (also reached when MAX_RETRIES is 0).
    return {"action_type": "approve", "violation_type": "none", "reasoning": failure_reason}


# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------

def run_episode(task_id: str) -> float:
    """Run one full episode of *task_id* and print START / STEP / END lines.

    Args:
        task_id: Task identifier passed to :func:`env_reset`.

    Returns:
        The mean reward per decision, rounded to 4 places and clamped to
        ``[0.01, 0.99]``.

    Raises:
        requests.HTTPError: If any environment call fails.
    """
    print(f"[START] task_id={task_id} model={MODEL_NAME}", flush=True)

    obs = env_reset(task_id)
    step_num = 0
    episode_reward = 0.0
    done = False

    while not done:
        step_num += 1
        post_id = obs.get("post_id", "unknown")

        # Agent decision
        llm_output = call_llm(obs)
        action_type = llm_output.get("action_type", "approve")
        # Null / empty values from the model fall back to the defaults.
        violation_type = str(llm_output.get("violation_type") or "none")
        reasoning = str(llm_output.get("reasoning") or "")

        # Validate action against allowed actions for this task
        allowed = obs.get("available_actions") or ["approve", "remove"]
        if action_type not in allowed:
            action_type = "approve"  # safe fallback

        # Submit to environment
        result = env_step(action_type, reasoning, violation_type)

        reward = float(result.get("reward") or 0.0)
        done = bool(result.get("done", True))

        episode_reward += reward

        # Strict [STEP] log format
        print(
            f"[STEP] step={step_num} "
            f"post_id={post_id} "
            f"action={action_type} "
            f"violation={violation_type} "
            f"reward={reward:.4f} "
            f"cumulative={episode_reward:.4f} "
            f"done={str(done).lower()}",
            flush=True,
        )

        if not done:
            # Keep the previous observation if the response carries none.
            obs = result.get("observation") or obs

    # Final state
    final_state = env_state()
    decisions = final_state.get("decisions") or []
    # If the state does not list decisions, average over the steps taken
    # instead of dividing by 1 (which would report the total, not the mean).
    num_decisions = len(decisions) or step_num
    final_score = max(0.01, min(0.99, round(episode_reward / num_decisions, 4)))

    print(
        f"[END] task_id={task_id} "
        f"total_steps={step_num} "
        f"total_reward={episode_reward:.4f} "
        f"final_score={final_score:.4f} "
        f"model={MODEL_NAME}",
        flush=True,
    )

    return final_score


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _wait_for_environment(attempts: int = 10, delay: float = 3) -> bool:
    """Poll ``/health`` until it returns 200; return False if it never does."""
    for attempt in range(attempts):
        try:
            resp = requests.get(f"{ENV_BASE_URL}/health", timeout=10, verify=_CA_BUNDLE)
            if resp.status_code == 200:
                return True
        except requests.RequestException:
            # Server not up yet (connection refused, timeout, ...): retry.
            pass
        print(
            f"[INFO] Waiting for environment at {ENV_BASE_URL} ... ({attempt+1}/{attempts})",
            file=sys.stderr,
        )
        time.sleep(delay)
    return False


def main():
    """Check prerequisites, run every task in ``TASKS`` and print a summary.

    Exits with status 1 when ``HF_TOKEN`` is unset or the environment never
    becomes healthy.
    """
    if not HF_TOKEN:
        print("[ERROR] HF_TOKEN environment variable is not set.", file=sys.stderr)
        sys.exit(1)

    # Wait for environment to be ready
    if not _wait_for_environment():
        print(f"[ERROR] Environment not reachable at {ENV_BASE_URL}", file=sys.stderr)
        sys.exit(1)

    print(f"[INFO] Environment ready. Running baseline on {len(TASKS)} tasks.", file=sys.stderr)
    print(f"[INFO] Model: {MODEL_NAME}", file=sys.stderr)
    print(f"[INFO] API: {API_BASE_URL}", file=sys.stderr)
    print("-" * 60, file=sys.stderr)

    all_scores = {}
    for task_id in TASKS:
        all_scores[task_id] = run_episode(task_id)
        time.sleep(1)  # brief pause between tasks

    # Summary
    print("\n" + "=" * 60)
    print("BASELINE RESULTS SUMMARY")
    print("=" * 60)
    for task_id, score in all_scores.items():
        filled = int(score * 20)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {task_id:<15} {bar} {score:.4f}")
    # Guard against an empty task list so the summary never divides by zero.
    avg = sum(all_scores.values()) / len(all_scores) if all_scores else 0.0
    print(f"\n {'OVERALL':<15} {avg:.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()