# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
LLM Judge backend for Prompt Golf.

Loads a larger frozen model (default: Qwen/Qwen3-8B) once per process,
lazily on first use, and uses it to score generated outputs on criteria
that structural/regex scorers cannot capture: persona fidelity, reasoning
quality, tone, and so on.

Quantized to 8-bit via bitsandbytes, so the ~8B judge occupies ~8 GB of
VRAM alongside the target model (2B bf16, ~4 GB) and the agent stack
(~10 GB) on a single L40S. A CPU fallback is provided for smoke tests.

Select the backend via PROMPT_GOLF_JUDGE_BACKEND=mock|hf (default "hf").
Select the model via PROMPT_GOLF_JUDGE_MODEL (default Qwen/Qwen3-8B).
Disable quantization by setting PROMPT_GOLF_JUDGE_NO_QUANT to any
non-empty value, e.g. PROMPT_GOLF_JUDGE_NO_QUANT=1.
"""
from __future__ import annotations
import os
import re
from typing import Optional, Protocol
DEFAULT_JUDGE_MODEL = "Qwen/Qwen3-8B"
# Judge prompt template. The judge must output a single float in [0, 1]
# on the first line; any further explanation is ignored.
JUDGE_PROMPT_TEMPLATE = """\
You are a strict grader. Given a task and a candidate output, return a single \
floating-point score in the range [0.0, 1.0] on the FIRST line of your \
response with NO other characters on that line.

Scoring convention:
  1.0 = fully correct / fully satisfies the criterion
  0.5 = partially correct
  0.0 = incorrect or off-task

TASK DESCRIPTION:
{task_description}

EVALUATION CRITERION:
{criterion}

CANDIDATE OUTPUT:
{output}

{expected_block}First line: only the score (a number between 0 and 1).\
"""
class JudgeBackend(Protocol):
model_id: str
def score(self, task_description: str, output: str, criterion: str,
expected: Optional[str] = None) -> float:
...
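# JudgeBackend is a structural Protocol: any class with a matching `score`
# signature satisfies it, no inheritance required. A hypothetical remote
# backend, for illustration only (JUDGE_URL and the response schema are
# assumed, not part of this codebase):
#
#   class HTTPJudgeBackend:
#       model_id = "remote-judge"
#
#       def score(self, task_description, output, criterion, expected=None):
#           resp = requests.post(JUDGE_URL, json={"task": task_description,
#                                                 "output": output,
#                                                 "criterion": criterion})
#           return float(resp.json()["score"])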
# ---------------------------------------------------------------------------
# Mock backend (deterministic pattern-based, CPU-only)
# ---------------------------------------------------------------------------
class MockJudgeBackend:
"""Rule-based fake judge for local development / CI."""
model_id = "mock-judge"
_KEYWORD_MAP = {
"shakespearean": ("thou", "thy", "hath", "art", "doth", "ere"),
"pirate": ("arr", "matey", "ye", "ahoy", "booty", "plunder"),
"refusal": ("cannot", "won't", "unable", "decline", "against"),
"question": ("?",),
"terminal": ("$ ", ">>> ", "/bin/", "ls ", "cat "),
"stepwise": ("step 1", "step 2", "first", "therefore", "so"),
}
    def score(self, task_description: str, output: str, criterion: str,
              expected: Optional[str] = None) -> float:
        # Degenerate-short guard: an empty or whitespace-only candidate
        # cannot satisfy any criterion, so short-circuit to 0 rather than
        # falling through to the 0.5 neutral fallback.
        if not output.strip():
            return 0.0
        crit = criterion.lower()
        out_lc = output.lower()
# Pick the keyword group that matches the criterion
for key, keywords in self._KEYWORD_MAP.items():
if key in crit:
hits = sum(1 for kw in keywords if kw in out_lc)
return min(1.0, hits / max(1, len(keywords) // 2))
# Generic similarity to expected, if provided
if expected:
exp_lc = expected.lower()
tokens_exp = set(re.findall(r"\w+", exp_lc))
tokens_out = set(re.findall(r"\w+", out_lc))
if not tokens_exp:
return 0.0
hits = len(tokens_exp & tokens_out)
return min(1.0, hits / len(tokens_exp))
return 0.5 # neutral fallback
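# Sanity check for the mock backend (no GPU required):
#
#   >>> MockJudgeBackend().score("any task", "Arr, ahoy matey!", "pirate persona")
#   1.0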
# ---------------------------------------------------------------------------
# HF backend (Qwen3-8B, 8-bit via bitsandbytes)
# ---------------------------------------------------------------------------
class HFJudgeBackend:
"""Transformers-based judge, 8-bit quantized by default.
Loaded lazily on first use. Uses greedy decoding for determinism.
"""
def __init__(self, model_id: str = DEFAULT_JUDGE_MODEL, load_in_8bit: bool = True):
self.model_id = model_id
self.load_in_8bit = load_in_8bit
self._model = None
self._tok = None
self._device = None
def _ensure_loaded(self) -> None:
if self._model is not None:
return
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
self._tok = AutoTokenizer.from_pretrained(self.model_id)
if self._tok.pad_token is None:
self._tok.pad_token = self._tok.eos_token
self._tok.padding_side = "left"
load_kwargs = {}
if self.load_in_8bit and torch.cuda.is_available():
# bitsandbytes 8-bit: ~half the bf16 footprint, negligible
# quality loss at ~8B scale.
try:
from transformers import BitsAndBytesConfig
load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
except ImportError:
# bnb missing; fall back to bf16
load_kwargs["torch_dtype"] = torch.bfloat16
else:
load_kwargs["torch_dtype"] = (
torch.bfloat16 if torch.cuda.is_available() else torch.float32
)
if torch.cuda.is_available():
load_kwargs["device_map"] = "auto"
self._model = AutoModelForCausalLM.from_pretrained(self.model_id, **load_kwargs)
self._model.eval()
self._device = next(self._model.parameters()).device
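        # With device_map="auto", accelerate may shard the model across
        # visible GPUs; next(parameters()).device is then the device of the
        # first shard (the embeddings), which is where inputs must land.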
    def score(self, task_description: str, output: str, criterion: str,
              expected: Optional[str] = None) -> float:
        # Degenerate-short guard: skip the model call entirely for
        # empty/whitespace-only candidates.
        if not output.strip():
            return 0.0
        self._ensure_loaded()
        import torch
expected_block = ""
if expected:
expected_block = f"EXPECTED OUTPUT (reference):\n{expected}\n\n"
user_msg = JUDGE_PROMPT_TEMPLATE.format(
task_description=task_description,
criterion=criterion,
output=output,
expected_block=expected_block,
)
# Prefer chat template if present (Qwen3 has one). Disable thinking
# so the judge emits the score directly on line 1 instead of a
# <think>...</think> block before the score.
if getattr(self._tok, "chat_template", None):
messages = [
{"role": "system", "content": "You grade outputs numerically. Output only the score."},
{"role": "user", "content": user_msg},
]
try:
prompt_text = self._tok.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
prompt_text = self._tok.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True,
)
else:
prompt_text = user_msg
enc = self._tok(prompt_text, return_tensors="pt", truncation=True, max_length=3072).to(self._device)
with torch.inference_mode():
            gen = self._model.generate(
                **enc,
                max_new_tokens=16,  # only the score line is needed
                do_sample=False,    # greedy decoding for determinism
                pad_token_id=self._tok.pad_token_id,
            )
new_tokens = gen[0][enc["input_ids"].shape[1]:]
raw = self._tok.decode(new_tokens, skip_special_tokens=True).strip()
return self._parse_score(raw)
    @staticmethod
    def _parse_score(text: str) -> float:
        # Strip any <think>...</think> block (plus a dangling opening tag
        # left by truncated generation) so the score line comes first even
        # if the chat template ignored enable_thinking=False.
        text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
        text = re.sub(r"<think>.*", "", text, flags=re.DOTALL)
        first_line = text.strip().split("\n", 1)[0].strip()
        # Accept "0.75", ".5", or "Score: 1"; clamp into [0, 1].
        m = re.search(r"-?(?:\d+\.?\d*|\.\d+)", first_line)
        if not m:
            return 0.0
        try:
            v = float(m.group(0))
        except ValueError:
            return 0.0
        return float(max(0.0, min(1.0, v)))
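# Parsing behavior, for reference:
#
#   >>> HFJudgeBackend._parse_score("0.75\nbecause the tone drifts")
#   0.75
#   >>> HFJudgeBackend._parse_score("<think>hmm</think>\n1.0")
#   1.0
#   >>> HFJudgeBackend._parse_score("Score: 2")   # clamped into [0, 1]
#   1.0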
# ---------------------------------------------------------------------------
# Singleton factory
# ---------------------------------------------------------------------------
_SINGLETON: Optional[JudgeBackend] = None
def get_judge_backend() -> JudgeBackend:
"""Process-global judge, loaded on first call."""
global _SINGLETON
if _SINGLETON is not None:
return _SINGLETON
backend_kind = os.environ.get("PROMPT_GOLF_JUDGE_BACKEND", "hf").lower().strip()
if backend_kind == "mock":
_SINGLETON = MockJudgeBackend()
else:
model_id = os.environ.get("PROMPT_GOLF_JUDGE_MODEL", DEFAULT_JUDGE_MODEL)
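        # Any non-empty value of PROMPT_GOLF_JUDGE_NO_QUANT disables 8-bit.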
load_in_8bit = os.environ.get("PROMPT_GOLF_JUDGE_NO_QUANT", "") == ""
_SINGLETON = HFJudgeBackend(model_id=model_id, load_in_8bit=load_in_8bit)
return _SINGLETON
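# Smoke test without a GPU (assumed invocation; the import path is inferred
# from this file's location and may differ in your deployment):
#
#   PROMPT_GOLF_JUDGE_BACKEND=mock python -c \
#     "from server.judge import get_judge_backend; \
#      print(get_judge_backend().score('t', 'Arr matey, ahoy', 'pirate'))"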