"""
Baseline inference script for the Content Moderation Queue OpenEnv environment.
Runs a language model against all 3 tasks and reports reproducible scores.

Required environment variables:
    API_BASE_URL - LLM API endpoint (OpenAI-compatible)
    MODEL_NAME   - Model identifier
    HF_TOKEN     - API key / Hugging Face token

Stdout format strictly follows the START / STEP / END spec for automated scoring.

Usage:
    export API_BASE_URL="https://api-inference.huggingface.co/v1"
    export MODEL_NAME="meta-llama/Meta-Llama-3-8B-Instruct"
    export HF_TOKEN="hf_..."
    python inference.py
"""
| import os | |
| import sys | |
| import json | |
| import time | |
| import ssl | |
| import certifi | |
| import httpx | |
| import requests | |
| from openai import OpenAI | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint settings (OpenAI-compatible API).
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Base URL of the running OpenEnv environment server.
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
TASKS = ["task_easy", "task_medium", "task_hard"]
MAX_RETRIES = 3
# Use the certifi CA bundle to fix LibreSSL / macOS SSL certificate errors.
# httpx accepts the CA-bundle *path* directly via `verify=`, so the
# previously built (and never used) ssl.SSLContext has been removed.
_http_client = httpx.Client(verify=certifi.where())
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL, http_client=_http_client)
# ---------------------------------------------------------------------------
# Environment API helpers
# ---------------------------------------------------------------------------
# Session id returned by the most recent /reset call; consumed by env_step
# and env_state. Mutated by env_reset().
_session_id: str = ""
def env_reset(task_id: str, seed: int = 42) -> dict:
    """Reset the environment to the start of *task_id* and cache the session id.

    The default seed=42 keeps the post order identical across runs, so the
    baseline scores are reproducible.
    """
    global _session_id
    response = requests.post(
        f"{ENV_BASE_URL}/reset",
        params={"task_id": task_id, "seed": seed},
        timeout=30,
        verify=certifi.where(),
    )
    response.raise_for_status()
    payload = response.json()
    _session_id = payload.get("session_id", "")
    return payload
def env_step(action_type: str, reasoning: str = "", violation_type: str = "none") -> dict:
    """Submit one moderation decision to /step and return the step result dict."""
    body = {
        "action_type": action_type,
        "reasoning": reasoning,
        "violation_type": violation_type,
    }
    response = requests.post(
        f"{ENV_BASE_URL}/step",
        params={"session_id": _session_id},
        json=body,
        timeout=30,
        verify=certifi.where(),
    )
    response.raise_for_status()
    return response.json()
def env_state() -> dict:
    """Fetch the current session state from the environment's /state endpoint."""
    response = requests.get(
        f"{ENV_BASE_URL}/state",
        params={"session_id": _session_id},
        timeout=30,
        verify=certifi.where(),
    )
    response.raise_for_status()
    return response.json()
# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------
# System prompt sent with every model request. The policy summary and the
# JSON response contract below are consumed verbatim by the model and parsed
# by call_llm(), so do not edit this string casually.
SYSTEM_PROMPT = """You are an expert content moderator. You will be shown social media posts
and must decide the correct moderation action according to platform policy.
POLICY SUMMARY:
- approve : Content is safe and complies with all policies
- warn : Minor first-time violation; send user a warning, keep content
- remove : Clear policy violation; remove content, no account action
- ban_temporary : Repeat offender OR serious single violation
- ban_permanent : Zero-tolerance violation (CSAM, doxxing, 5+ prior violations, selling PII)
- escalate : Ambiguous case requiring human review OR crisis situation (self-harm signals)
IMPORTANT RULES:
1. Always read author history carefully — repeat violations escalate consequences
2. Context matters — gaming slang, creative writing, and recovery stories are NOT violations
3. For appeals with plausible evidence → escalate (do not auto-deny)
4. Self-harm crisis posts → escalate (do NOT remove — user may need support)
5. CSAM and PII selling → ban_permanent regardless of history
You must respond with valid JSON only:
{
"action_type": "<one of the 6 actions>",
"violation_type": "<spam|hate_speech|harassment|misinformation|csam|illegal_services|doxxing|self_harm_risk|none>",
"reasoning": "<one sentence explanation>"
}"""
def build_user_prompt(obs: dict) -> str:
    """Render an environment observation as the user prompt for the LLM.

    Includes the post, the author's violation history (last 3 prior actions),
    appeal information, thread context, and the allowed-action list.

    Args:
        obs: Observation dict returned by the environment's /reset or /step.

    Returns:
        A newline-joined prompt string ending with a JSON-only instruction.
    """
    lines = []
    # Use .get with defaults so a malformed observation cannot crash the run;
    # the rest of this function already accessed obs defensively, so the first
    # three fields are now consistent with that style.
    lines.append(f"POST ID: {obs.get('post_id', 'unknown')}")
    lines.append(f"CONTENT: {obs.get('content', '')}")
    lines.append(f"AUTHOR: {obs.get('author_id', 'unknown')}")
    history = obs.get("author_history", {})
    violations = history.get("total_violations", 0)
    prior = history.get("prior_actions", [])
    lines.append(f"AUTHOR HISTORY: {violations} total violations")
    if prior:
        for p in prior[-3:]:  # only the most recent 3 are relevant context
            lines.append(f"  - {p.get('date','?')}: {p.get('action','?')} for {p.get('violation_type','?')}")
    if obs.get("is_appeal"):
        lines.append(f"THIS IS AN APPEAL. Reason: {obs.get('appeal_reason', 'No reason given')}")
    thread = obs.get("context_thread", [])
    if thread:
        lines.append("THREAD CONTEXT:")
        for t in thread:
            lines.append(f"  [{t.get('action_taken','?')}] {t.get('content','')[:120]}")
    lines.append(f"\nAVAILABLE ACTIONS: {obs.get('available_actions', [])}")
    lines.append(f"TASK: {obs.get('task_description', '')}")
    lines.append("\nRespond with JSON only.")
    return "\n".join(lines)
def call_llm(obs: dict) -> dict:
    """Query the moderation model for a decision on *obs*.

    Retries up to MAX_RETRIES times on malformed JSON or transport errors,
    then falls back to a safe "approve" decision so the episode can continue.
    """
    prompt = build_user_prompt(obs)
    for attempt in range(MAX_RETRIES):
        final_try = attempt == MAX_RETRIES - 1
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=256,
                temperature=0.0,
            )
            text = completion.choices[0].message.content.strip()
            # Strip a markdown code fence if the model wrapped its JSON.
            if "```" in text:
                text = text.split("```")[1]
                if text.startswith("json"):
                    text = text[4:]
            return json.loads(text)
        except json.JSONDecodeError:
            if final_try:
                # Fallback: return safe default
                return {"action_type": "approve", "violation_type": "none", "reasoning": "JSON parse failed"}
            time.sleep(1)
        except Exception as e:
            if final_try:
                print(f"[ERROR] LLM call failed after {MAX_RETRIES} attempts: {e}", file=sys.stderr)
                return {"action_type": "approve", "violation_type": "none", "reasoning": f"LLM error: {e}"}
            time.sleep(2)
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_episode(task_id: str) -> float:
    """Run one full moderation episode for *task_id* and return its final score.

    Loops: LLM decision -> action validation -> /step submission, until the
    environment reports done. Emits the strict [START] / [STEP] / [END] stdout
    lines consumed by the automated scorer — do not change their format.
    """
    print(f"[START] task_id={task_id} model={MODEL_NAME}")
    sys.stdout.flush()
    obs = env_reset(task_id)
    step_num = 0
    episode_reward = 0.0
    done = False
    while not done:
        step_num += 1
        post_id = obs.get("post_id", "unknown")
        # Agent decision
        llm_output = call_llm(obs)
        action_type = llm_output.get("action_type", "approve")
        violation_type = llm_output.get("violation_type", "none")
        reasoning = llm_output.get("reasoning", "")
        # Validate action against allowed actions for this task
        allowed = obs.get("available_actions", ["approve", "remove"])
        if action_type not in allowed:
            action_type = "approve"  # safe fallback
        # Submit to environment
        result = env_step(action_type, reasoning, violation_type)
        reward = result.get("reward", 0.0)
        done = result.get("done", True)
        info = result.get("info", {})  # NOTE(review): currently unused — kept for future logging?
        episode_reward += reward
        # Strict [STEP] log format
        print(
            f"[STEP] step={step_num} "
            f"post_id={post_id} "
            f"action={action_type} "
            f"violation={violation_type} "
            f"reward={reward:.4f} "
            f"cumulative={episode_reward:.4f} "
            f"done={str(done).lower()}"
        )
        sys.stdout.flush()
        if not done:
            obs = result.get("observation", obs)
    # Final state
    final_state = env_state()
    num_decisions = len(final_state.get("decisions", []))
    # Mean per-decision reward, clamped into (0.01, 0.99) for the scoreboard;
    # max(num_decisions, 1) guards against division by zero.
    final_score = max(0.01, min(0.99, round(episode_reward / max(num_decisions, 1), 4)))
    print(
        f"[END] task_id={task_id} "
        f"total_steps={step_num} "
        f"total_reward={episode_reward:.4f} "
        f"final_score={final_score:.4f} "
        f"model={MODEL_NAME}"
    )
    sys.stdout.flush()
    return final_score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Validate configuration, wait for the environment, run all tasks, print a summary.

    Exits with status 1 when HF_TOKEN is missing or the environment never
    becomes reachable. Informational output goes to stderr; only the strict
    [START]/[STEP]/[END] lines and the summary go to stdout.
    """
    if not HF_TOKEN:
        print("[ERROR] HF_TOKEN environment variable is not set.", file=sys.stderr)
        sys.exit(1)
    # Wait for environment to be ready: poll /health up to 10 times, 3 s apart.
    for attempt in range(10):
        try:
            resp = requests.get(f"{ENV_BASE_URL}/health", timeout=10, verify=certifi.where())
            if resp.status_code == 200:
                break
        except Exception:
            # Connection errors are expected while the server is starting up.
            pass
        print(f"[INFO] Waiting for environment at {ENV_BASE_URL} ... ({attempt+1}/10)", file=sys.stderr)
        time.sleep(3)
    else:
        # for/else: reached only if the loop never hit `break` (never healthy).
        print(f"[ERROR] Environment not reachable at {ENV_BASE_URL}", file=sys.stderr)
        sys.exit(1)
    print(f"[INFO] Environment ready. Running baseline on {len(TASKS)} tasks.", file=sys.stderr)
    print(f"[INFO] Model: {MODEL_NAME}", file=sys.stderr)
    print(f"[INFO] API: {API_BASE_URL}", file=sys.stderr)
    print("-" * 60, file=sys.stderr)
    all_scores = {}
    for task_id in TASKS:
        score = run_episode(task_id)
        all_scores[task_id] = score
        time.sleep(1)  # brief pause between tasks
    # Summary: per-task score with a 20-char bar chart, then the mean.
    print("\n" + "=" * 60)
    print("BASELINE RESULTS SUMMARY")
    print("=" * 60)
    for task_id, score in all_scores.items():
        bar = "█" * int(score * 20) + "░" * (20 - int(score * 20))
        print(f"  {task_id:<15} {bar} {score:.4f}")
    avg = sum(all_scores.values()) / len(all_scores)
    print(f"\n  {'OVERALL':<15} {avg:.4f}")
    print("=" * 60)
if __name__ == "__main__":
    main()