Spaces:
Sleeping
Sleeping
Don Rishabh
v2 stack: Qwen3.5-2B agent/target, Qwen3.5-9B judge, hard tasks, additive reward
3889513 | # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Scoring functions for Prompt Golf tasks. | |
| Each scorer takes (generated_output, expected_output) and returns a float in | |
| [0, 1] indicating correctness. The per-task score is the mean across | |
| held-out test examples. | |
| The scoring must be: | |
| - deterministic (same inputs → same output) | |
| - tolerant of minor formatting noise from the target | |
| - strict on the actual answer content (no "close enough" for labels) | |
| New scorers: add the function, register it in SCORERS at the bottom. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import re | |
| from typing import Callable, Dict | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
| _PUNCT_STRIP_RE = re.compile(r"[\.,!?;:\"'()\[\]{}]") | |
| def _normalize(s: str) -> str: | |
| """Lowercase, strip punctuation, collapse whitespace.""" | |
| s = s.strip().lower() | |
| s = _PUNCT_STRIP_RE.sub("", s) | |
| s = re.sub(r"\s+", " ", s) | |
| return s | |
| def _first_line(s: str) -> str: | |
| return s.strip().split("\n", 1)[0].strip() | |
| # --------------------------------------------------------------------------- | |
| # Scorers (each: (output, expected) -> 0.0..1.0) | |
| # --------------------------------------------------------------------------- | |
def exact_label(output: str, expected: str) -> float:
    """Return 1.0 when the output's first line equals ``expected`` after normalization.

    Only the first non-blank line of ``output`` is considered; both sides go
    through ``_normalize`` (case, punctuation and whitespace insensitive).
    Any other result scores 0.0.
    """
    predicted = _normalize(_first_line(output))
    gold = _normalize(expected)
    return float(predicted == gold)
def contains_label(output: str, expected: str) -> float:
    """Return 1.0 if the expected label appears as a whole word in the output.

    Both strings are passed through ``_normalize`` (lowercase, punctuation
    stripped, whitespace collapsed), so the check is case-insensitive.

    "Whole word" means the label is not directly preceded or followed by a
    word character. This is expressed with lookarounds rather than regex
    word boundaries for two reasons:

    - Labels that begin or end with a non-word character (e.g. "c++") can
      still match.
    - An empty label would otherwise become a boundary-only pattern that
      matches almost any output. Here it scores 0.0 instead.
    """
    label = _normalize(expected)
    if not label:
        # A blank label is a malformed task; never hand out free credit.
        return 0.0
    pattern = r"(?<!\w)" + re.escape(label) + r"(?!\w)"
    return 1.0 if re.search(pattern, _normalize(output)) else 0.0
def numeric_match(output: str, expected: str) -> float:
    """Parse the last number from the output and compare it to ``expected``.

    Thousands separators ("1,234,567") are understood in both the output and
    ``expected``. A comma counts as a separator only when exactly three
    digits follow it, so lists like "1,2,3" still split into separate
    numbers.

    A match is an absolute difference below 1e-3. The same tolerance is
    applied to integer and non-integer answers.

    Returns 1.0 on a match. Returns 0.0 when ``expected`` is not numeric,
    when the output contains no number, or when the values differ.
    """
    # Optional sign, then either a comma-grouped integer or a plain digit run,
    # then an optional fractional part. (?!\d) stops "1,2345" from being read
    # as "1,234" followed by "5".
    number_re = r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
    try:
        expected_val = float(re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", expected))
    except ValueError:
        return 0.0
    nums = re.findall(number_re, output)
    if not nums:
        return 0.0
    # Every match is digits, commas and at most one '.', so float() cannot fail.
    got = float(nums[-1].replace(",", ""))
    return 1.0 if abs(got - expected_val) < 1e-3 else 0.0
def json_contains_fields(output: str, expected: str) -> float:
    """Score the fraction of ``expected``'s key/value pairs found in the output JSON.

    ``expected`` is itself a JSON object. The output is scanned for the first
    substring starting at a '{' that decodes as JSON. Surrounding prose, and
    any further text or objects after it, are ignored.

    Each expected key must be present in that object with a matching value:
    - String values are compared after ``_normalize``, so case, punctuation
      and whitespace runs are ignored.
    - All other values are compared with ``==``.

    Returns 0.0 when either side is not a JSON object, and 1.0 when
    ``expected`` is an empty object. Otherwise returns the number of
    matching keys divided by the number of expected keys.
    """
    try:
        exp_obj = json.loads(expected)
    except json.JSONDecodeError:
        return 0.0
    if not isinstance(exp_obj, dict):
        return 0.0

    # Try decoding from each '{' in turn and keep the first success. A greedy
    # r"\{.*\}" span runs to the LAST '}', so trailing text containing braces
    # would make an otherwise valid object fail to parse.
    decoder = json.JSONDecoder()
    got_obj = None
    for brace in re.finditer(r"\{", output):
        try:
            got_obj, _end = decoder.raw_decode(output, brace.start())
        except json.JSONDecodeError:
            continue
        break
    if not isinstance(got_obj, dict):
        return 0.0

    total = len(exp_obj)
    if total == 0:
        return 1.0
    hits = 0
    for k, v in exp_obj.items():
        if k not in got_obj:
            continue
        gv = got_obj[k]
        if isinstance(v, str) and isinstance(gv, str):
            if _normalize(gv) == _normalize(v):
                hits += 1
        elif gv == v:
            hits += 1
    return hits / total
def uppercase_match(output: str, expected: str) -> float:
    """Return 1.0 when the output is fully upper-case and equals ``expected``.

    The output is stripped and must be unchanged by ``str.upper()``; text with
    no cased letters passes that check. The content comparison then uses
    ``_normalize`` on both sides, so it is case-insensitive. Anything else
    scores 0.0.
    """
    text = output.strip()
    shouting = text.upper() == text
    if shouting and _normalize(text) == _normalize(expected):
        return 1.0
    return 0.0
def three_bullets(output: str, expected: str) -> float:
    """Score how close the output is to exactly three bullet lines.

    A bullet line, after stripping, starts with '-', '*', '•', or a number
    followed by '.' or ')'. The marker must be followed by whitespace and
    then a non-space character. ``expected`` is unused.

    Returns 1.0 for exactly 3 bullets, 0.5 for 2 or 4, else 0.0.
    """
    marker = re.compile(r"^(?:[-*•]|\d+[.)])\s+\S")
    stripped_lines = (raw.strip() for raw in output.strip().split("\n"))
    count = sum(1 for line in stripped_lines if line and marker.match(line))
    return {3: 1.0, 2: 0.5, 4: 0.5}.get(count, 0.0)
def valid_json_object(output: str, expected: str) -> float:
    """Return 1.0 if the output is (or contains) a JSON object, not an array or primitive.

    ``expected`` is unused.

    If the whole stripped output is valid JSON, it must be an object. A bare
    array or primitive scores 0.0 even if an object is nested inside it.

    Otherwise the output is scanned for the first substring starting at a
    '{' that decodes as JSON. An object wrapped in prose, or followed by
    text containing braces, therefore still counts.
    """
    try:
        whole = json.loads(output.strip())
    except json.JSONDecodeError:
        pass
    else:
        return 1.0 if isinstance(whole, dict) else 0.0

    # Decode from each '{' in turn; a greedy r"\{.*\}" span would run to the
    # last '}' in the text and fail on any trailing brace-containing prose.
    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", output):
        try:
            decoder.raw_decode(output, brace.start())
        except json.JSONDecodeError:
            continue
        # A JSON value that starts with '{' can only decode to an object.
        return 1.0
    return 0.0
def contains_all_substrings(output: str, expected: str) -> float:
    """Return the fraction of '|'-separated parts of ``expected`` found in the output.

    Parts are trimmed and blank parts are dropped. Each remaining part
    counts as found when its ``_normalize``d form is a substring of the
    ``_normalize``d output. With no parts at all the score is 1.0.
    """
    needles = [piece.strip() for piece in expected.split("|")]
    needles = [needle for needle in needles if needle]
    if not needles:
        return 1.0
    haystack = _normalize(output)
    found = [needle for needle in needles if _normalize(needle) in haystack]
    return len(found) / len(needles)
def refusal_score(output: str, expected: str) -> float:
    """Score whether the output's refuse/comply behaviour matches ``expected``.

    ``expected`` starting with "refuse" (case-insensitive, after trimming)
    means a refusal is wanted; anything else means compliance is wanted.

    The output counts as a refusal if it contains any of the cue phrases
    below, case-insensitively. Typographic apostrophes are folded to ASCII
    first, so "I can’t" matches the same cue as "I can't".

    Returns 1.0 when the detected behaviour matches the wanted one, else 0.0.
    """
    refusal_cues = (
        "i can't", "i cannot", "i will not", "i won't", "unable to",
        "i'm not able", "i am not able", "not comfortable", "decline",
        "refuse", "against my", "unsafe", "not appropriate",
    )
    # LLMs frequently emit U+2019 / U+2018 instead of "'", which would
    # otherwise defeat every contraction cue ("can't", "won't", "i'm").
    apostrophes = str.maketrans({"\u2018": "'", "\u2019": "'"})
    out_lc = output.lower().translate(apostrophes)
    is_refusal = any(cue in out_lc for cue in refusal_cues)
    want_refusal = expected.strip().lower().startswith("refuse")
    return 1.0 if is_refusal == want_refusal else 0.0
def translation_match(output: str, expected: str) -> float:
    """Return the token-level F1 between the output and the expected translation.

    Both strings are passed through ``_normalize`` and split on whitespace.
    Tokens are compared as multisets (bag of words, as in SQuAD-style F1):
    - Word order is ignored.
    - Spelling must match exactly.
    - Each occurrence of a repeated token counts separately, so padding the
      output with duplicates lowers precision instead of being free.

    Returns 1.0 when both sides are empty and 0.0 when only one side is.
    """
    from collections import Counter

    got_toks = _normalize(output).split()
    exp_toks = _normalize(expected).split()
    if not exp_toks:
        return 1.0 if not got_toks else 0.0
    if not got_toks:
        return 0.0
    # Counter '&' keeps the per-token minimum count: the multiset overlap.
    overlap = sum((Counter(got_toks) & Counter(exp_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(got_toks)
    recall = overlap / len(exp_toks)
    return 2 * precision * recall / (precision + recall)
# ---------------------------------------------------------------------------
# V2 structural scorers (for tasks with hard, non-obvious minimum prompts)
# ---------------------------------------------------------------------------
def stepwise_math(output: str, expected: str) -> float:
    """Score shown reasoning steps together with a correct final number.

    ``expected`` is encoded as "N|<answer>", where N is the minimum number
    of step markers. For example, "2|42" needs at least two steps and a
    final number of 42. The answer check is ``numeric_match``, which
    compares against the last number in the output.

    A step marker is a line that begins, after optional whitespace, with:
    - "Step <n>";
    - a list number such as "1." or "1)";
    - one of the words first, second, then, next or finally.
    Matching is case-insensitive.

    Returns:
    - 1.0 when there are enough steps and the answer is right;
    - 0.5 for the right answer without enough steps;
    - 0.4 for enough steps with a wrong answer;
    - 0.0 otherwise, including for a malformed ``expected``.
    """
    if "|" not in expected:
        return 0.0
    n_str, ans = expected.split("|", 1)
    try:
        n_req = int(n_str)
    except ValueError:
        return 0.0
    # The word-boundary anchor must apply only to the word-like markers.
    # After "1." the next character is normally a space, and there is no
    # word boundary between "." and " ", so a trailing \b made "1. ..." and
    # "1) ..." lines never count. (?!\d) stops a leading decimal such as
    # "1.5" from being mistaken for a step.
    step_re = re.compile(
        r"(?im)^\s*(?:step\s*\d+\b|\d+[.)](?!\d)|(?:first|second|then|next|finally)\b)"
    )
    n_steps = len(step_re.findall(output))
    ans_ok = numeric_match(output, ans) > 0
    # Partial credit: both needed for full score
    if n_steps >= n_req and ans_ok:
        return 1.0
    if n_steps >= n_req:
        return 0.4  # has structure but wrong answer
    if ans_ok:
        return 0.5  # right answer but no shown work
    return 0.0
def acrostic_match(output: str, expected: str) -> float:
    """Score how well the first letters of the output's lines spell ``expected``.

    ``expected`` is trimmed and compared case-insensitively. Blank lines in
    the output are ignored. The i-th remaining line matches when its first
    character, lowercased, equals the i-th letter of the target.

    When the line count equals the target length, the score is
    matches / len(target). Otherwise it is matches / (2 * len(target)), so a
    wrong line count caps the score at 0.5. An empty target scores 0.0.
    """
    word = expected.strip().lower()
    if not word:
        return 0.0
    rows = [row.strip() for row in output.strip().split("\n")]
    rows = [row for row in rows if row]
    matched = sum(
        1
        for idx, letter in enumerate(word)
        if idx < len(rows) and rows[idx][:1].lower() == letter
    )
    if len(rows) == len(word):
        return matched / len(word)
    # Wrong number of lines: same per-letter count, halved.
    return matched / (2 * len(word))
def avoid_letter(output: str, expected: str) -> float:
    """Score an output that must NOT contain the given letter(s), case-insensitively.

    ``expected`` lists the forbidden letters, e.g. "e" or "e, a". Only its
    letters and digits are forbidden; whitespace and punctuation in it act as
    separators. If it names no letters, the score is 0.0.

    The output must contain at least 3 whitespace-separated words, which
    guards against empty or trivial output; shorter output scores 0.0.

    Returns 1.0 when no forbidden character occurs. Otherwise returns
    exp(-violations / 3): about 0.72 for one violation and 0.37 for three.
    """
    import math

    # Keep only alphanumerics so separators like ", " in "e, a" are not
    # themselves treated as forbidden (which would penalise every space).
    forbidden = {ch for ch in expected.lower() if ch.isalnum()}
    if not forbidden:
        return 0.0
    if len(output.split()) < 3:
        return 0.0
    violations = sum(1 for ch in output.lower() if ch in forbidden)
    if violations == 0:
        return 1.0
    # Exponential decay; always in (0, 1), so no clamping is needed.
    return math.exp(-violations / 3.0)
def valid_yaml_depth(output: str, expected: str) -> float:
    """Heuristically check that the output is YAML-like and nested deeply enough.

    ``expected`` is the minimum nesting depth, written as an integer string.
    No real YAML parsing is done, to avoid a PyYAML dependency.

    How the depth is measured:
    - Blank lines and '#' comment lines are skipped.
    - A line's level is its leading-space count // 2. Tabs are not counted.
    - The depth is the highest level seen.
    The output must also contain at least one ':' somewhere.

    Returns 1.0 when the depth is at least the requirement, otherwise
    depth / max(1, required). Returns 0.0 for a non-integer ``expected`` or
    when the output has no ':'.
    """
    try:
        needed = int(expected.strip())
    except ValueError:
        return 0.0
    if ":" not in output:
        return 0.0
    levels = [
        (len(row) - len(row.lstrip(" "))) // 2
        for row in output.split("\n")
        if row.strip() and not row.strip().startswith("#")
    ]
    deepest = max(levels, default=0)
    return 1.0 if deepest >= needed else deepest / max(1, needed)
def json_key_order(output: str, expected: str) -> float:
    """Score whether the output's JSON object lists keys in the required order.

    ``expected`` gives the required key names, comma-separated, in order.
    The first substring starting at a '{' that decodes as JSON is used.
    Prose around it, and text or braces after it, are ignored.

    Returns 0.0 when:
    - ``expected`` names no keys;
    - no JSON object is found;
    - the object has fewer keys than required.
    Otherwise returns the fraction of positions i at which the object's i-th
    key equals the i-th required key.
    """
    want_order = [k.strip() for k in expected.split(",") if k.strip()]
    if not want_order:
        return 0.0
    # Decode from each '{' in turn; a greedy r"\{.*\}" span would run to the
    # last '}' and fail whenever trailing text also contains braces.
    decoder = json.JSONDecoder()
    obj = None
    for brace in re.finditer(r"\{", output):
        try:
            obj, _end = decoder.raw_decode(output, brace.start())
        except json.JSONDecodeError:
            continue
        break
    if not isinstance(obj, dict):
        return 0.0
    # json decodes objects into dicts, which keep keys in insertion order,
    # i.e. the order they appear in the text.
    got_order = list(obj)
    if len(got_order) < len(want_order):
        return 0.0
    correct = sum(1 for want, got in zip(want_order, got_order) if want == got)
    return correct / len(want_order)
def ends_question(output: str, expected: str) -> float:
    """Score whether the output ends as a genuine question.

    ``expected`` is unused (callers pass "?"). The stripped output must end
    with '?'; otherwise the score is 0.0. It then scores 1.0 if any word
    token is one of the interrogative words below, or 0.5 otherwise, so
    bare "OK?" gets only partial credit.
    """
    stripped = output.strip()
    if stripped[-1:] != "?":
        return 0.0
    interrogatives = frozenset((
        "what", "why", "how", "when", "where", "who", "which", "could",
        "would", "should", "do", "does", "is", "are", "can",
    ))
    words = set(re.findall(r"\w+", stripped.lower()))
    return 1.0 if words & interrogatives else 0.5
def word_count_exact(output: str, expected: str) -> float:
    """Score how closely the output's word count matches the expected integer.

    ``expected`` is either "N" or "N|<min_length_chars>". With the second
    form, the stripped output must also be at least that many characters
    long, otherwise it scores 0.0. An unparseable minimum is treated as 0.

    A word is a whitespace-separated token containing at least one letter or
    digit. Punctuation-only tokens (e.g. a standalone dash) therefore don't
    count, while contractions ("don't") and hyphenated compounds count as one
    word each.

    Returns 1.0 for an exact count, 0.6 when off by one, 0.2 when off by two,
    and 0.0 otherwise. A non-integer N scores 0.0.
    """
    if "|" in expected:
        n_str, min_chars_str = expected.split("|", 1)
        try:
            min_chars = int(min_chars_str)
        except ValueError:
            min_chars = 0
    else:
        n_str = expected
        min_chars = 0
    try:
        n = int(n_str.strip())
    except ValueError:
        return 0.0
    if len(output.strip()) < min_chars:
        return 0.0
    # Whitespace tokenisation, not [A-Za-z0-9]+ runs. The runs split "don't"
    # into two words and "state-of-the-art" into four, and they drop
    # non-ASCII letters.
    got = sum(1 for tok in output.split() if any(ch.isalnum() for ch in tok))
    return {0: 1.0, 1: 0.6, 2: 0.2}.get(abs(got - n), 0.0)
def terminal_output_pattern(output: str, expected: str) -> float:
    """Score whether the output looks like raw terminal output rather than prose.

    ``expected`` is an optional substring the output must contain (e.g. a
    filename or command name); pass "" to check the pattern only.

    Scoring:
    - Empty output scores 0.0.
    - If a prose marker ("the ", "here is", "sure", ...) appears in the first
      60 lowercased characters, the score is 0.0.
    - The output looks like a terminal if it starts with '$', '#' or '>', or
      with a lowercase path/word token ([a-z0-9_./-]+) followed by whitespace
      or the end of the text.
    - Returns 1.0 when it looks like a terminal and contains the substring,
      0.6 when it only looks like a terminal, 0.3 when it only contains the
      substring, else 0.0.
    """
    text = output.strip()
    if not text:
        return 0.0
    prose_indicators = ("the ", "here is", "here's", "as a terminal", "sure")
    text_lc = text.lower()
    if any(ind in text_lc[:60] for ind in prose_indicators):
        return 0.0
    # Accept end-of-text after the leading token as well as whitespace, so
    # single-token results (e.g. `pwd` -> "/home/user") are recognised; the
    # text is already stripped, so requiring \s alone rejected them.
    starts_ok = bool(re.match(r"^[\$#>]|^[a-z0-9_./-]+(?:\s|$)", text))
    substr_ok = expected.strip() == "" or expected.strip().lower() in text_lc
    if starts_ok and substr_ok:
        return 1.0
    if starts_ok:
        return 0.6
    if substr_ok:
        return 0.3
    return 0.0
def selective_translate(output: str, expected: str) -> float:
    """Score how many required French noun translations appear in the output.

    ``expected`` is a '|'-separated list of required French words. Each one
    counts as present when it occurs as a whole word in the lowercased
    output, i.e. not directly preceded or followed by a word character.
    Only the presence of these words is checked; whether the rest of the
    text stayed in English is not verified.

    Returns the fraction found. An empty list scores 0.0.
    """
    required = [w.strip().lower() for w in expected.split("|") if w.strip()]
    if not required:
        return 0.0
    out_lc = output.lower()
    # Whole-word match: plain substring search credited "chat" inside
    # "chateau" and "eau" inside "bureau".
    hits = sum(
        1
        for w in required
        if re.search(r"(?<!\w)" + re.escape(w) + r"(?!\w)", out_lc)
    )
    return hits / len(required)
| # --------------------------------------------------------------------------- | |
| # LLM-judge scorers (delegated to server/judge.py) | |
| # --------------------------------------------------------------------------- | |
def judge_criteria(output: str, expected: str, task_description: str = "") -> float:
    """Have the LLM judge grade ``output`` against a free-text criterion.

    ``expected`` holds the evaluation criterion. It is forwarded, together
    with ``task_description`` and the output, to the backend returned by
    ``get_judge_backend()`` (presumably a cached singleton; see the judge
    module).
    """
    # Deferred import so constructing the environment does not load the judge.
    # Falls back to the absolute ``server.judge`` path when the relative
    # import is unavailable.
    try:
        from .judge import get_judge_backend
    except ImportError:
        from server.judge import get_judge_backend
    backend = get_judge_backend()
    return backend.score(
        task_description=task_description,
        output=output,
        criterion=expected,
    )
def judge_vs_expected(output: str, expected: str, task_description: str = "") -> float:
    """Have the LLM judge grade ``output`` against a reference answer.

    This is meant for tasks where structural scoring is infeasible but an
    approximate gold answer exists (e.g. persona rewrites, style transfers).
    ``expected`` is that reference. It is sent to the judge together with a
    fixed comparison criterion, and the judge's score is returned.
    """
    # Deferred import, mirroring judge_criteria: relative first, then the
    # absolute ``server.judge`` path.
    try:
        from .judge import get_judge_backend
    except ImportError:
        from server.judge import get_judge_backend
    reference_criterion = (
        "Compare the candidate output to the expected reference and "
        "score its quality and faithfulness (1.0 = perfect, 0.0 = bad)."
    )
    backend = get_judge_backend()
    return backend.score(
        task_description=task_description,
        output=output,
        criterion=reference_criterion,
        expected=expected,
    )
| # --------------------------------------------------------------------------- | |
| # Registry | |
| # --------------------------------------------------------------------------- | |
# Each scorer is registered under its own function name, so a key can never
# drift from the function it points to. To register a new scorer, add the
# function to the appropriate group below.
SCORERS: Dict[str, Callable[..., float]] = {
    scorer.__name__: scorer
    for scorer in (
        # v1 structural
        exact_label,
        contains_label,
        numeric_match,
        json_contains_fields,
        uppercase_match,
        three_bullets,
        valid_json_object,
        contains_all_substrings,
        refusal_score,
        translation_match,
        # v2 structural
        stepwise_math,
        acrostic_match,
        avoid_letter,
        valid_yaml_depth,
        json_key_order,
        ends_question,
        word_count_exact,
        terminal_output_pattern,
        selective_translate,
        # v2 judge-based
        judge_criteria,
        judge_vs_expected,
    )
}

# Names of scorers that score_one also calls with task_description=...;
# every other scorer receives only (output, expected).
_NEEDS_TASK_DESC = {"judge_criteria", "judge_vs_expected"}
def score_one(scorer_name: str, output: str, expected: str,
              task_description: str = "") -> float:
    """Score a single (output, expected) pair with the named scorer.

    Scorers listed in ``_NEEDS_TASK_DESC`` (the judge-based ones) also
    receive ``task_description`` as a keyword argument. All other scorers are
    called with ``(output, expected)`` only, so callers may pass it
    unconditionally.

    Returns the scorer's result as a float clamped to [0, 1]. It returns 0.0
    when the result is NaN, or when the scorer raises. Such an exception is
    logged as a warning instead of propagating, so one failing scorer cannot
    crash the environment.

    Raises:
        KeyError: if ``scorer_name`` is not registered in ``SCORERS``.
    """
    import logging
    import math

    fn = SCORERS.get(scorer_name)
    if fn is None:
        raise KeyError(f"unknown scorer: {scorer_name!r}")
    try:
        if scorer_name in _NEEDS_TASK_DESC:
            raw = fn(output, expected, task_description=task_description)
        else:
            raw = fn(output, expected)
        value = float(raw)
    except Exception:
        # Defensive: never let a scorer crash the env. Still leave a trace,
        # so a broken scorer or an unreachable judge is not silently
        # indistinguishable from a genuinely wrong answer.
        logging.getLogger(__name__).warning(
            "scorer %r raised; scoring as 0.0", scorer_name, exc_info=True
        )
        return 0.0
    # min/max cannot order NaN: max(0.0, min(1.0, nan)) evaluates to 1.0, so
    # without this check a NaN score would be rewarded as a perfect answer.
    if math.isnan(value):
        return 0.0
    return max(0.0, min(1.0, value))