# content-moderation-queue / inference.py
# (Hub page residue, kept for provenance:)
# sniki28's picture — Upload inference.py with huggingface_hub
# commit 47e3e7a verified
"""
Baseline inference script for the Content Moderation Queue OpenEnv environment.
Runs a language model against all 3 tasks and reports reproducible scores.
Required environment variables:
API_BASE_URL - LLM API endpoint (OpenAI-compatible)
MODEL_NAME - Model identifier
HF_TOKEN - API key / Hugging Face token
Stdout format strictly follows the START / STEP / END spec for automated scoring.
Usage:
export API_BASE_URL="https://api-inference.huggingface.co/v1"
export MODEL_NAME="meta-llama/Meta-Llama-3-8B-Instruct"
export HF_TOKEN="hf_..."
python inference.py
"""
import os
import sys
import json
import time
import ssl
import certifi
import httpx
import requests
from openai import OpenAI
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint (OpenAI-compatible), model id, and auth token come from env vars;
# defaults target the Hugging Face inference API.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Base URL of the running OpenEnv environment server.
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
TASKS = ["task_easy", "task_medium", "task_hard"]
MAX_RETRIES = 3
# Use certifi CA bundle to fix LibreSSL / macOS SSL certificate errors.
# Fix: _ssl_context was previously created but never used — pass it to the
# httpx client so LLM requests actually verify against this context
# (equivalent to verify=certifi.where(), which builds the same context).
_ssl_context = ssl.create_default_context(cafile=certifi.where())
_http_client = httpx.Client(verify=_ssl_context)
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL, http_client=_http_client)
# ---------------------------------------------------------------------------
# Environment API helpers
# ---------------------------------------------------------------------------
_session_id: str = ""  # active environment session id, set by env_reset()


def env_reset(task_id: str, seed: int = 42) -> dict:
    """Start a fresh episode for *task_id* and cache its session id."""
    global _session_id
    # seed=42 ensures same post order every run → reproducible baseline scores
    response = requests.post(
        f"{ENV_BASE_URL}/reset",
        params={"task_id": task_id, "seed": seed},
        timeout=30,
        verify=certifi.where(),
    )
    response.raise_for_status()
    payload = response.json()
    _session_id = payload.get("session_id", "")
    return payload
def env_step(action_type: str, reasoning: str = "", violation_type: str = "none") -> dict:
    """Submit one moderation decision for the cached session and return the result."""
    body = {
        "action_type": action_type,
        "reasoning": reasoning,
        "violation_type": violation_type,
    }
    response = requests.post(
        f"{ENV_BASE_URL}/step",
        params={"session_id": _session_id},
        json=body,
        timeout=30,
        verify=certifi.where(),
    )
    response.raise_for_status()
    return response.json()
def env_state() -> dict:
    """Fetch the current episode state for the cached session."""
    response = requests.get(
        f"{ENV_BASE_URL}/state",
        params={"session_id": _session_id},
        timeout=30,
        verify=certifi.where(),
    )
    response.raise_for_status()
    return response.json()
# ---------------------------------------------------------------------------
# LLM agent
# ---------------------------------------------------------------------------
# System prompt sent on every LLM call: encodes the moderation policy, the
# escalation rules, and the strict JSON schema the reply must follow
# (the reply is parsed with json.loads in call_llm).
SYSTEM_PROMPT = """You are an expert content moderator. You will be shown social media posts
and must decide the correct moderation action according to platform policy.
POLICY SUMMARY:
- approve : Content is safe and complies with all policies
- warn : Minor first-time violation; send user a warning, keep content
- remove : Clear policy violation; remove content, no account action
- ban_temporary : Repeat offender OR serious single violation
- ban_permanent : Zero-tolerance violation (CSAM, doxxing, 5+ prior violations, selling PII)
- escalate : Ambiguous case requiring human review OR crisis situation (self-harm signals)
IMPORTANT RULES:
1. Always read author history carefully — repeat violations escalate consequences
2. Context matters — gaming slang, creative writing, and recovery stories are NOT violations
3. For appeals with plausible evidence → escalate (do not auto-deny)
4. Self-harm crisis posts → escalate (do NOT remove — user may need support)
5. CSAM and PII selling → ban_permanent regardless of history
You must respond with valid JSON only:
{
"action_type": "<one of the 6 actions>",
"violation_type": "<spam|hate_speech|harassment|misinformation|csam|illegal_services|doxxing|self_harm_risk|none>",
"reasoning": "<one sentence explanation>"
}"""
def build_user_prompt(obs: dict) -> str:
    """Render one environment observation as the user prompt for the model."""
    history = obs.get("author_history", {})
    parts = [
        f"POST ID: {obs['post_id']}",
        f"CONTENT: {obs['content']}",
        f"AUTHOR: {obs['author_id']}",
        f"AUTHOR HISTORY: {history.get('total_violations', 0)} total violations",
    ]
    # Show at most the three most recent prior moderation actions.
    for entry in history.get("prior_actions", [])[-3:]:
        parts.append(f" - {entry.get('date','?')}: {entry.get('action','?')} for {entry.get('violation_type','?')}")
    if obs.get("is_appeal"):
        parts.append(f"THIS IS AN APPEAL. Reason: {obs.get('appeal_reason', 'No reason given')}")
    context = obs.get("context_thread", [])
    if context:
        parts.append("THREAD CONTEXT:")
        for item in context:
            parts.append(f" [{item.get('action_taken','?')}] {item.get('content','')[:120]}")
    parts.append(f"\nAVAILABLE ACTIONS: {obs.get('available_actions', [])}")
    parts.append(f"TASK: {obs.get('task_description', '')}")
    parts.append("\nRespond with JSON only.")
    return "\n".join(parts)
def call_llm(obs: dict) -> dict:
    """Query the model for a moderation decision, returning the parsed JSON dict.

    Retries up to MAX_RETRIES times (short backoff for bad JSON, longer for
    transport/API errors); on final failure returns a safe "approve" default
    so the episode can continue.
    """
    prompt = build_user_prompt(obs)

    def _strip_fences(text: str) -> str:
        # Pull the JSON payload out of markdown code fences if present.
        if "```" in text:
            text = text.split("```")[1]
            if text.startswith("json"):
                text = text[4:]
        return text

    for retry in range(MAX_RETRIES):
        last_try = retry == MAX_RETRIES - 1
        try:
            completion = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=256,
                temperature=0.0,
            )
            reply = completion.choices[0].message.content.strip()
            return json.loads(_strip_fences(reply))
        except json.JSONDecodeError:
            if not last_try:
                time.sleep(1)
                continue
            # Fallback: return safe default
            return {"action_type": "approve", "violation_type": "none", "reasoning": "JSON parse failed"}
        except Exception as e:
            if not last_try:
                time.sleep(2)
                continue
            print(f"[ERROR] LLM call failed after {MAX_RETRIES} attempts: {e}", file=sys.stderr)
            return {"action_type": "approve", "violation_type": "none", "reasoning": f"LLM error: {e}"}
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_episode(task_id: str) -> float:
    """Run one full episode of *task_id* and return its normalized score.

    Emits the strict [START]/[STEP]/[END] stdout lines that the automated
    scorer parses; flushes after each line so logs stream in real time.
    """
    print(f"[START] task_id={task_id} model={MODEL_NAME}")
    sys.stdout.flush()

    obs = env_reset(task_id)
    steps = 0
    total_reward = 0.0
    finished = False
    while not finished:
        steps += 1
        post_id = obs.get("post_id", "unknown")

        # Agent decision
        decision = call_llm(obs)
        action = decision.get("action_type", "approve")
        violation = decision.get("violation_type", "none")
        why = decision.get("reasoning", "")

        # Validate action against allowed actions for this task
        if action not in obs.get("available_actions", ["approve", "remove"]):
            action = "approve"  # safe fallback

        # Submit to environment
        outcome = env_step(action, why, violation)
        reward = outcome.get("reward", 0.0)
        finished = outcome.get("done", True)
        total_reward += reward

        # Strict [STEP] log format
        print(
            f"[STEP] step={steps} "
            f"post_id={post_id} "
            f"action={action} "
            f"violation={violation} "
            f"reward={reward:.4f} "
            f"cumulative={total_reward:.4f} "
            f"done={str(finished).lower()}"
        )
        sys.stdout.flush()
        if not finished:
            obs = outcome.get("observation", obs)

    # Final state: normalize by decision count, clamped into [0.01, 0.99].
    num_decisions = len(env_state().get("decisions", []))
    score = max(0.01, min(0.99, round(total_reward / max(num_decisions, 1), 4)))
    print(
        f"[END] task_id={task_id} "
        f"total_steps={steps} "
        f"total_reward={total_reward:.4f} "
        f"final_score={score:.4f} "
        f"model={MODEL_NAME}"
    )
    sys.stdout.flush()
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Entry point: verify config, wait for the environment, run all tasks, summarize."""
    if not HF_TOKEN:
        print("[ERROR] HF_TOKEN environment variable is not set.", file=sys.stderr)
        sys.exit(1)

    # Poll /health until the environment answers (up to 10 tries, 3 s apart).
    for attempt in range(10):
        try:
            health = requests.get(f"{ENV_BASE_URL}/health", timeout=10, verify=certifi.where())
            if health.status_code == 200:
                break
        except Exception:
            pass
        print(f"[INFO] Waiting for environment at {ENV_BASE_URL} ... ({attempt+1}/10)", file=sys.stderr)
        time.sleep(3)
    else:
        # for/else: only reached when the loop never broke out.
        print(f"[ERROR] Environment not reachable at {ENV_BASE_URL}", file=sys.stderr)
        sys.exit(1)

    print(f"[INFO] Environment ready. Running baseline on {len(TASKS)} tasks.", file=sys.stderr)
    print(f"[INFO] Model: {MODEL_NAME}", file=sys.stderr)
    print(f"[INFO] API: {API_BASE_URL}", file=sys.stderr)
    print("-" * 60, file=sys.stderr)

    scores = {}
    for task in TASKS:
        scores[task] = run_episode(task)
        time.sleep(1)  # brief pause between tasks

    # Summary
    print("\n" + "=" * 60)
    print("BASELINE RESULTS SUMMARY")
    print("=" * 60)
    for task, value in scores.items():
        filled = int(value * 20)
        bar = "█" * filled + "░" * (20 - filled)
        print(f" {task:<15} {bar} {value:.4f}")
    overall = sum(scores.values()) / len(scores)
    print(f"\n {'OVERALL':<15} {overall:.4f}")
    print("=" * 60)


if __name__ == "__main__":
    main()