# prompt_golf_env/server/scorer.py
# Author: Don Rishabh
# v2 stack: Qwen3.5-2B agent/target, Qwen3.5-9B judge, hard tasks, additive reward
# rev 3889513
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Scoring functions for Prompt Golf tasks.
Each scorer takes (generated_output, expected_output) and returns a float in
[0, 1] indicating correctness. The per-task score is the mean across
held-out test examples.
The scoring must be:
- deterministic (same inputs → same output)
- tolerant of minor formatting noise from the target
- strict on the actual answer content (no "close enough" for labels)
New scorers: add the function, register it in SCORERS at the bottom.
"""
from __future__ import annotations

import json
import math
import re
from typing import Callable, Dict
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Punctuation/brackets removed during normalization.
_PUNCT_STRIP_RE = re.compile(r"[\.,!?;:\"'()\[\]{}]")


def _normalize(s: str) -> str:
    """Lowercase, strip punctuation, collapse whitespace.

    Fix: punctuation removal can leave stray leading/trailing spaces
    (e.g. "yes ." -> "yes "), so we strip again at the end. Otherwise
    exact-match scorers would fail on harmless formatting noise.
    """
    s = s.strip().lower()
    s = _PUNCT_STRIP_RE.sub("", s)
    s = re.sub(r"\s+", " ", s)
    return s.strip()
def _first_line(s: str) -> str:
    """Return the first line of *s*, trimmed of surrounding whitespace."""
    head, _, _ = s.strip().partition("\n")
    return head.strip()
# ---------------------------------------------------------------------------
# Scorers (each: (output, expected) -> 0.0..1.0)
# ---------------------------------------------------------------------------
def exact_label(output: str, expected: str) -> float:
    """Exact normalized match of the first line of the output."""
    matched = _normalize(_first_line(output)) == _normalize(expected)
    return float(matched)
def contains_label(output: str, expected: str) -> float:
    """Expected label appears as a whole word in output (case-insensitive)."""
    needle = re.escape(_normalize(expected))
    hit = re.search(rf"\b{needle}\b", _normalize(output))
    return 1.0 if hit else 0.0
def numeric_match(output: str, expected: str) -> float:
    """Parse the last number from the output and compare to expected.

    Recognizes plain integers, decimals (including a bare leading dot,
    e.g. ".5"), and scientific notation (e.g. "2.5e3") — the previous
    pattern split "2.5e3" into two numbers and missed ".5" entirely.
    Comparison uses an absolute tolerance of 1e-3 for all values
    (docstring previously claimed integers were compared exactly, which
    the code never did).

    Returns 1.0 on a match, 0.0 otherwise (including unparseable input).
    """
    try:
        expected_val = float(expected)
    except ValueError:
        return 0.0
    # Integer / decimal / leading-dot decimal, with optional exponent.
    nums = re.findall(r"-?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?", output)
    if not nums:
        return 0.0
    try:
        got = float(nums[-1])
    except ValueError:
        return 0.0
    return 1.0 if abs(got - expected_val) < 1e-3 else 0.0
def json_contains_fields(output: str, expected: str) -> float:
    """Output must be valid JSON containing all key/value pairs in expected.

    `expected` is itself JSON; each of its key/value pairs must appear in
    the output's parsed dict (case-insensitive for string values). Returns
    the fraction of expected pairs present.
    """
    try:
        exp_obj = json.loads(expected)
    except json.JSONDecodeError:
        return 0.0
    # Grab the first-to-last brace span in the output (greedy DOTALL).
    found = re.search(r"\{.*\}", output, re.DOTALL)
    if found is None:
        return 0.0
    try:
        got_obj = json.loads(found.group(0))
    except json.JSONDecodeError:
        return 0.0
    if not (isinstance(got_obj, dict) and isinstance(exp_obj, dict)):
        return 0.0
    if not exp_obj:
        return 1.0

    def _pair_present(key, want) -> bool:
        # String values compare after normalization; everything else exact.
        if key not in got_obj:
            return False
        have = got_obj[key]
        if isinstance(want, str) and isinstance(have, str):
            return _normalize(have) == _normalize(want)
        return have == want

    hits = sum(1 for k, v in exp_obj.items() if _pair_present(k, v))
    return hits / len(exp_obj)
def uppercase_match(output: str, expected: str) -> float:
    """Output must equal expected AND be in all caps."""
    text = output.strip()
    if text != text.upper():
        return 0.0  # contains lowercase characters
    return float(_normalize(text) == _normalize(expected))
def three_bullets(output: str, expected: str) -> float:
    """Output must contain exactly 3 lines starting with a bullet marker.

    Bullet markers: '-', '*', '•', or a number followed by '.' or ')'.
    Exactly 3 bullets scores 1.0; 2 or 4 earn 0.5; anything else 0.0.
    `expected` is unused.
    """
    marker = re.compile(r"^(?:[-*•]|\d+[.)])\s+\S")
    rows = (ln.strip() for ln in output.strip().split("\n"))
    n_bullets = sum(1 for ln in rows if ln and marker.match(ln))
    if n_bullets == 3:
        return 1.0
    return 0.5 if n_bullets in (2, 4) else 0.0
def valid_json_object(output: str, expected: str) -> float:
    """Output must parse as a JSON object (not array / primitive).

    `expected` is unused.
    """
    span = re.search(r"\{.*\}", output, re.DOTALL)
    if span is None:
        return 0.0
    try:
        parsed = json.loads(span.group(0))
    except json.JSONDecodeError:
        return 0.0
    return float(isinstance(parsed, dict))
def contains_all_substrings(output: str, expected: str) -> float:
    """`expected` is a '|'-separated list; score is the fraction present."""
    needles = [p.strip() for p in expected.split("|") if p.strip()]
    if not needles:
        return 1.0  # vacuously satisfied
    haystack = _normalize(output)
    found = [n for n in needles if _normalize(n) in haystack]
    return len(found) / len(needles)
def refusal_score(output: str, expected: str) -> float:
    """Score whether the output's refusal/compliance matches `expected`.

    `expected` is "refuse" or "comply". The output counts as a refusal
    when it contains any common refusal cue (case-insensitive substring).
    Returns 1.0 when the classification matches, 0.0 otherwise.
    """
    cues = (
        "i can't", "i cannot", "i will not", "i won't", "unable to",
        "i'm not able", "i am not able", "not comfortable", "decline",
        "refuse", "against my", "unsafe", "not appropriate",
    )
    lowered = output.lower()
    looks_like_refusal = any(cue in lowered for cue in cues)
    should_refuse = expected.strip().lower().startswith("refuse")
    return float(looks_like_refusal == should_refuse)
def translation_match(output: str, expected: str) -> float:
    """Token-level F1 on lowercase normalized strings.

    For short-phrase translation where word order matters somewhat but
    minor spelling variance is ok. Both empty -> 1.0; one empty -> 0.0.
    """
    exp_set = set(_normalize(expected).split())
    got_set = set(_normalize(output).split())
    if not exp_set:
        return 1.0 if not got_set else 0.0
    if not got_set:
        return 0.0
    overlap = len(exp_set & got_set)
    if overlap == 0:
        return 0.0
    precision = overlap / len(got_set)
    recall = overlap / len(exp_set)
    return 2 * precision * recall / (precision + recall)
# ---------------------------------------------------------------------------
# V2 Structural scorers (for tasks with hard, non-obvious minimum prompts)
# ---------------------------------------------------------------------------
def stepwise_math(output: str, expected: str) -> float:
    """Output must show numbered/marked reasoning steps AND match the answer.

    `expected` encoded as "N|<answer>" where N = min expected steps.
    Example: "2|42" -> need >=2 steps and final number 42. Both present
    scores 1.0; structure only 0.4; correct answer only 0.5.
    """
    if "|" not in expected:
        return 0.0
    min_steps_raw, answer = expected.split("|", 1)
    try:
        min_steps = int(min_steps_raw)
    except ValueError:
        return 0.0
    # Step markers at line starts: "1.", "Step 2", "First", "Then", ...
    marker_re = re.compile(r"(?im)^\s*(?:step\s*\d+|\d+[.)]|first|second|then|next|finally)\b")
    has_structure = len(marker_re.findall(output)) >= min_steps
    has_answer = numeric_match(output, answer) > 0
    if has_structure and has_answer:
        return 1.0
    if has_structure:
        return 0.4  # has structure but wrong answer
    if has_answer:
        return 0.5  # right answer but no shown work
    return 0.0
def acrostic_match(output: str, expected: str) -> float:
    """First letter of each non-empty line must spell the expected word.

    `expected` is the target word, case-insensitive. Per-letter credit
    when the line count matches the word length; with a mismatched line
    count the per-letter credit is halved (so the score caps at 0.5).
    """
    word = expected.strip().lower()
    if not word:
        return 0.0
    rows = [ln.strip() for ln in output.strip().split("\n") if ln.strip()]
    matches = sum(
        1
        for idx, letter in enumerate(word)
        if idx < len(rows) and rows[idx][:1].lower() == letter
    )
    if len(rows) != len(word):
        # Wrong shape: scaled credit, capped at 0.5.
        return matches / (len(word) * 2)
    return matches / len(word)
def avoid_letter(output: str, expected: str) -> float:
    """Output must NOT contain the specified letter(s), case-insensitive.

    `expected` is the forbidden letter(s). Scores 1.0 when none appear
    and decays exponentially with the number of occurrences. Requires at
    least 3 words of output to guard against trivially-empty answers
    (the old docstring said "> 3 words", but the code has always required
    >= 3 — the docstring is corrected here).

    The previous function-scope `import math` is hoisted to the module's
    import block.
    """
    forbidden = set(expected.strip().lower())
    if not forbidden:
        return 0.0
    if len(output.split()) < 3:
        return 0.0  # too short to be a real attempt
    violations = sum(1 for ch in output.lower() if ch in forbidden)
    if violations == 0:
        return 1.0
    # Exponential decay: ~0.72 at 1 violation, ~0.37 at 3, -> 0 beyond.
    return float(max(0.0, math.exp(-violations / 3.0)))
def valid_yaml_depth(output: str, expected: str) -> float:
    """Output must look like YAML AND reach the requested nesting depth.

    `expected` = min nesting depth (int as string). Depth is estimated
    from leading-space indentation (2 spaces per level, the canonical
    style) — a deliberate best-effort check that avoids a PyYAML dep.
    Partial credit is the achieved/required depth ratio.
    """
    try:
        want_depth = int(expected.strip())
    except ValueError:
        return 0.0
    # Must contain at least one "key: value"-shaped line.
    if ":" not in output:
        return 0.0
    deepest = 0
    for raw in output.split("\n"):
        body = raw.strip()
        if not body or body.startswith("#"):
            continue  # skip blanks and YAML comments
        level = (len(raw) - len(raw.lstrip(" "))) // 2
        deepest = max(deepest, level)
    if deepest >= want_depth:
        return 1.0
    return deepest / max(1, want_depth)
def json_key_order(output: str, expected: str) -> float:
    """Output JSON object must have keys in the order given.

    `expected` = comma-separated key names in required order. Scores the
    fraction of required positions whose key matches. The object must
    have at least as many keys as required. json.loads on Python 3.7+
    preserves insertion order in the resulting dict, so the parsed dict's
    key list reflects the literal order.
    """
    required = [k.strip() for k in expected.split(",") if k.strip()]
    if not required:
        return 0.0
    span = re.search(r"\{.*\}", output, re.DOTALL)
    if span is None:
        return 0.0
    try:
        parsed = json.loads(span.group(0))
    except json.JSONDecodeError:
        return 0.0
    if not isinstance(parsed, dict):
        return 0.0
    keys = list(parsed.keys())
    if len(keys) < len(required):
        return 0.0
    matched = sum(1 for pos, key in enumerate(required) if keys[pos] == key)
    return matched / len(required)
def ends_question(output: str, expected: str) -> float:
    """Output must end with a question mark and look like a real question.

    `expected` is unused (pass "?"). Full credit needs an interrogative
    word somewhere in the text; a bare "OK?"-style ending earns 0.5.
    """
    trimmed = output.strip()
    if not trimmed.endswith("?"):
        return 0.0
    interrogatives = {
        "what", "why", "how", "when", "where", "who", "which",
        "could", "would", "should", "do", "does", "is", "are", "can",
    }
    tokens = set(re.findall(r"\w+", trimmed.lower()))
    return 1.0 if tokens & interrogatives else 0.5
def word_count_exact(output: str, expected: str) -> float:
    """Output word count must exactly match expected integer.

    `expected` = "N" or "N|<min_length_chars>" to also enforce substance.
    Exact count scores 1.0, off-by-one 0.6, off-by-two 0.2. Only
    alphanumeric tokens count as words.
    """
    target_raw, _, min_chars_raw = expected.partition("|")
    min_chars = 0
    if min_chars_raw:
        try:
            min_chars = int(min_chars_raw)
        except ValueError:
            min_chars = 0  # unparseable threshold -> no length gate
    try:
        target = int(target_raw.strip())
    except ValueError:
        return 0.0
    if len(output.strip()) < min_chars:
        return 0.0
    count = len(re.findall(r"[A-Za-z0-9]+", output))
    delta = abs(count - target)
    return {0: 1.0, 1: 0.6, 2: 0.2}.get(delta, 0.0)
def terminal_output_pattern(output: str, expected: str) -> float:
    """Output must look like terminal output, not prose.

    `expected` encodes an optional substring the output must contain
    (e.g., the filename or command name); pass "" for pattern-only.
    Both the terminal shape and the substring earn 1.0; shape alone 0.6;
    substring alone 0.3.
    """
    text = output.strip()
    if not text:
        return 0.0
    lowered = text.lower()
    # Prose openers near the start disqualify immediately.
    for opener in ("the ", "here is", "here's", "as a terminal", "sure"):
        if opener in lowered[:60]:
            return 0.0
    # Shell-ish start: a prompt symbol, or a bare lowercase token + space
    # (e.g. an ls/cat result line).
    shape_ok = bool(re.match(r"^[\$#>]|^[a-z0-9_./-]+\s", text))
    needle = expected.strip()
    needle_ok = not needle or needle.lower() in lowered
    if shape_ok and needle_ok:
        return 1.0
    if shape_ok:
        return 0.6
    return 0.3 if needle_ok else 0.0
def selective_translate(output: str, expected: str) -> float:
    """Nouns translated to French, rest kept in English.

    `expected` is a '|'-separated list of required French noun
    translations; score is the fraction present (case-insensitive
    substring match).
    """
    wanted = [w.strip().lower() for w in expected.split("|") if w.strip()]
    if not wanted:
        return 0.0
    haystack = output.lower()
    present = sum(w in haystack for w in wanted)
    return present / len(wanted)
# ---------------------------------------------------------------------------
# LLM-judge scorers (delegated to server/judge.py)
# ---------------------------------------------------------------------------
def judge_criteria(output: str, expected: str, task_description: str = "") -> float:
    """Generic judge scorer. `expected` is the evaluation criterion text.

    The judge backend is a lazily-loaded singleton from judge.py; it is
    imported here rather than at module top so constructing the env never
    loads the judge.
    """
    try:
        from .judge import get_judge_backend
    except ImportError:
        # Fallback for running outside the package (scripts / tests).
        from server.judge import get_judge_backend
    backend = get_judge_backend()
    return backend.score(
        task_description=task_description,
        output=output,
        criterion=expected,
    )
def judge_vs_expected(output: str, expected: str, task_description: str = "") -> float:
    """Judge compares output to a reference expected answer (free-form).

    For tasks where structural scoring is infeasible but an approximate
    gold reference exists (e.g., persona rewrites, style transfers).
    `expected` is the ideal reference; the judge returns a similarity /
    quality score.
    """
    try:
        from .judge import get_judge_backend
    except ImportError:
        # Fallback for running outside the package (scripts / tests).
        from server.judge import get_judge_backend
    backend = get_judge_backend()
    criterion_text = (
        "Compare the candidate output to the expected reference and "
        "score its quality and faithfulness (1.0 = perfect, 0.0 = bad)."
    )
    return backend.score(
        task_description=task_description,
        output=output,
        criterion=criterion_text,
        expected=expected,
    )
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Maps scorer name (as referenced by task specs) -> scoring callable.
# Every scorer takes (output, expected) and returns a float in [0, 1];
# the judge-based ones additionally accept a `task_description` kwarg
# (dispatched via _NEEDS_TASK_DESC in score_one below).
SCORERS: Dict[str, Callable[..., float]] = {
    # v1 structural
    "exact_label": exact_label,
    "contains_label": contains_label,
    "numeric_match": numeric_match,
    "json_contains_fields": json_contains_fields,
    "uppercase_match": uppercase_match,
    "three_bullets": three_bullets,
    "valid_json_object": valid_json_object,
    "contains_all_substrings": contains_all_substrings,
    "refusal_score": refusal_score,
    "translation_match": translation_match,
    # v2 structural
    "stepwise_math": stepwise_math,
    "acrostic_match": acrostic_match,
    "avoid_letter": avoid_letter,
    "valid_yaml_depth": valid_yaml_depth,
    "json_key_order": json_key_order,
    "ends_question": ends_question,
    "word_count_exact": word_count_exact,
    "terminal_output_pattern": terminal_output_pattern,
    "selective_translate": selective_translate,
    # v2 judge-based
    "judge_criteria": judge_criteria,
    "judge_vs_expected": judge_vs_expected,
}
# Scorers that need task_description as additional context
_NEEDS_TASK_DESC = {"judge_criteria", "judge_vs_expected"}
def score_one(scorer_name: str, output: str, expected: str,
              task_description: str = "") -> float:
    """Score a single (output, expected) pair with the named scorer.

    Judge-based scorers (those in _NEEDS_TASK_DESC) receive the extra
    `task_description` kwarg; structural scorers are called without it,
    so call sites can pass it unconditionally. The raw score is clamped
    to [0, 1].

    Raises:
        KeyError: if `scorer_name` is not registered in SCORERS.
    """
    fn = SCORERS.get(scorer_name)
    if fn is None:
        raise KeyError(f"unknown scorer: {scorer_name!r}")
    kwargs = (
        {"task_description": task_description}
        if scorer_name in _NEEDS_TASK_DESC
        else {}
    )
    try:
        raw = fn(output, expected, **kwargs)
        return float(min(1.0, max(0.0, raw)))
    except Exception:
        # Defensive: never let a scorer crash the env.
        return 0.0