# opensleuth-demo / oracle.py
# anugrah55's picture
# Initial demo: live agent rollouts against OpenSleuth env
# dcc4ca0 verified
"""Per-task reference implementations.
These are the *known-correct* solutions for each of the 15 tasks the OpenSleuth
env exposes. They mirror the rows pushed to ``anugrah55/opensleuth-tasks`` by
``env/opensleuth_env/scripts/bootstrap_tasks_dataset.py`` (which itself mirrors
the in-process oracle in ``env/opensleuth_env/black_box.py``).
The "oracle" demo backend just looks up the task name here and submits the
canonical source. It exists so the viewer can immediately see what a perfect
score looks like end-to-end (signature → probes → submit → +100 reward).
"""
from __future__ import annotations
from typing import Dict
# Maps each task name to the full Python source of its reference solution.
# Every value is a self-contained ``def <task_name>(...)`` whose nested blocks
# use 4-space indentation, so the text can be compiled/exec'd as-is.
ORACLE_SOLUTIONS: Dict[str, str] = {
    # ---- 9 builtins -------------------------------------------------------
    "fibonacci": (
        "def fibonacci(n):\n"
        "    if not isinstance(n, int) or isinstance(n, bool) or n <= 0 or n > 90:\n"
        "        raise ValueError('Input must be a positive integer <= 90.')\n"
        "    a, b = 0, 1\n"
        "    for _ in range(n - 1):\n"
        "        a, b = b, a + b\n"
        "    return b if n > 0 else a\n"
    ),
    "reverse_string": (
        "def reverse_string(s):\n"
        "    if not isinstance(s, str):\n"
        "        raise TypeError('Input must be a string.')\n"
        "    return s[::-1]\n"
    ),
    "is_palindrome": (
        "def is_palindrome(s):\n"
        "    if not isinstance(s, str):\n"
        "        raise TypeError('Input must be a string.')\n"
        "    cleaned = ''.join(ch.lower() for ch in s if ch.isalnum())\n"
        "    return cleaned == cleaned[::-1]\n"
    ),
    "digit_sum": (
        "def digit_sum(n):\n"
        "    if not isinstance(n, int) or isinstance(n, bool):\n"
        "        raise TypeError('Input must be int.')\n"
        "    if n < 0:\n"
        "        raise ValueError('Input must be non-negative.')\n"
        "    return sum(int(c) for c in str(n))\n"
    ),
    "count_vowels": (
        "def count_vowels(s):\n"
        "    if not isinstance(s, str):\n"
        "        raise TypeError('Input must be a string.')\n"
        "    return sum(1 for c in s.lower() if c in 'aeiou')\n"
    ),
    "gcd": (
        "def gcd(pair):\n"
        "    if not isinstance(pair, (list, tuple)) or len(pair) != 2:\n"
        "        raise TypeError('Input must be a 2-element list or tuple.')\n"
        "    a, b = pair\n"
        "    if not all(isinstance(x, int) and not isinstance(x, bool) for x in (a, b)):\n"
        "        raise TypeError('Both elements must be int.')\n"
        "    if a < 0 or b < 0:\n"
        "        raise ValueError('Both elements must be non-negative.')\n"
        "    while b:\n"
        "        a, b = b, a % b\n"
        "    return a\n"
    ),
    "sort_unique": (
        "def sort_unique(xs):\n"
        "    if not isinstance(xs, list):\n"
        "        raise TypeError('Input must be a list.')\n"
        "    if not all(isinstance(x, int) and not isinstance(x, bool) for x in xs):\n"
        "        raise TypeError('All elements must be int.')\n"
        "    return sorted(set(xs))\n"
    ),
    "caesar_cipher": (
        "def caesar_cipher(s):\n"
        "    if not isinstance(s, str):\n"
        "        raise TypeError('Input must be a string.')\n"
        "    out = []\n"
        "    for ch in s:\n"
        "        if 'a' <= ch <= 'z':\n"
        "            out.append(chr((ord(ch) - ord('a') + 3) % 26 + ord('a')))\n"
        "        else:\n"
        "            out.append(ch)\n"
        "    return ''.join(out)\n"
    ),
    "is_prime": (
        "def is_prime(n):\n"
        "    if not isinstance(n, int) or isinstance(n, bool):\n"
        "        raise TypeError('Input must be int.')\n"
        "    if n < 2:\n"
        "        return False\n"
        "    if n < 4:\n"
        "        return True\n"
        "    if n % 2 == 0:\n"
        "        return False\n"
        "    i = 3\n"
        "    while i * i <= n:\n"
        "        if n % i == 0:\n"
        "            return False\n"
        "        i += 2\n"
        "    return True\n"
    ),
    # ---- 6 hub-pushed tasks -----------------------------------------------
    "roman_to_int": (
        "def roman_to_int(s):\n"
        "    if not isinstance(s, str):\n"
        "        raise TypeError('input must be str')\n"
        "    table = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000}\n"
        "    total = 0\n"
        "    prev = 0\n"
        "    for ch in reversed(s.upper()):\n"
        "        if ch not in table:\n"
        "            raise ValueError(f'invalid roman numeral character: {ch!r}')\n"
        "        v = table[ch]\n"
        "        if v < prev:\n"
        "            total -= v\n"
        "        else:\n"
        "            total += v\n"
        "        prev = v\n"
        "    return total\n"
    ),
    "levenshtein_distance": (
        "def levenshtein_distance(a, b):\n"
        "    if not isinstance(a, str) or not isinstance(b, str):\n"
        "        raise TypeError('both arguments must be str')\n"
        "    if a == b:\n"
        "        return 0\n"
        "    if not a:\n"
        "        return len(b)\n"
        "    if not b:\n"
        "        return len(a)\n"
        "    prev = list(range(len(b) + 1))\n"
        "    for i, ca in enumerate(a, 1):\n"
        "        cur = [i] + [0] * len(b)\n"
        "        for j, cb in enumerate(b, 1):\n"
        "            ins = cur[j-1] + 1\n"
        "            dele = prev[j] + 1\n"
        "            sub = prev[j-1] + (ca != cb)\n"
        "            cur[j] = min(ins, dele, sub)\n"
        "        prev = cur\n"
        "    return prev[-1]\n"
    ),
    "flatten_list": (
        "def flatten_list(xs):\n"
        "    if not isinstance(xs, (list, tuple)):\n"
        "        raise TypeError('input must be list or tuple')\n"
        "    out = []\n"
        "    rev = []\n"
        "    rev.extend(reversed(list(xs)))\n"
        "    while rev:\n"
        "        x = rev.pop()\n"
        "        if isinstance(x, (list, tuple)):\n"
        "            for y in reversed(x):\n"
        "                rev.append(y)\n"
        "        else:\n"
        "            out.append(x)\n"
        "    return out\n"
    ),
    "merge_sorted": (
        "def merge_sorted(a, b):\n"
        "    if not isinstance(a, list) or not isinstance(b, list):\n"
        "        raise TypeError('both arguments must be list')\n"
        "    for x in (*a, *b):\n"
        "        if not isinstance(x, int) or isinstance(x, bool):\n"
        "            raise TypeError('elements must be int')\n"
        "    out = []\n"
        "    i = j = 0\n"
        "    while i < len(a) and j < len(b):\n"
        "        if a[i] <= b[j]:\n"
        "            out.append(a[i]); i += 1\n"
        "        else:\n"
        "            out.append(b[j]); j += 1\n"
        "    out.extend(a[i:])\n"
        "    out.extend(b[j:])\n"
        "    return out\n"
    ),
    "run_length_encode": (
        "def run_length_encode(s):\n"
        "    if not isinstance(s, str):\n"
        "        raise TypeError('input must be str')\n"
        "    if not s:\n"
        "        return []\n"
        "    out = []\n"
        "    cur = s[0]\n"
        "    n = 1\n"
        "    for ch in s[1:]:\n"
        "        if ch == cur:\n"
        "            n += 1\n"
        "        else:\n"
        "            out.append((cur, n))\n"
        "            cur = ch\n"
        "            n = 1\n"
        "    out.append((cur, n))\n"
        "    return out\n"
    ),
    "binary_search": (
        "def binary_search(arr, target):\n"
        "    if not isinstance(arr, list):\n"
        "        raise TypeError('arr must be list')\n"
        "    if not isinstance(target, int) or isinstance(target, bool):\n"
        "        raise TypeError('target must be int')\n"
        "    lo, hi = 0, len(arr) - 1\n"
        "    while lo <= hi:\n"
        "        mid = (lo + hi) // 2\n"
        "        v = arr[mid]\n"
        "        if v == target:\n"
        "            return mid\n"
        "        if v < target:\n"
        "            lo = mid + 1\n"
        "        else:\n"
        "            hi = mid - 1\n"
        "    return -1\n"
    ),
}
def get_oracle_code(task_name: str) -> str:
    """Return the canonical source for ``task_name``.

    Looks ``task_name`` up in ``ORACLE_SOLUTIONS``. When there is no entry,
    returns the source of a stub function that raises ``NotImplementedError``
    when called.

    Args:
        task_name: Task identifier. For tasks without an oracle entry it is
            interpolated verbatim as the stub's function name.

    Returns:
        Python source text defining a function named ``task_name``.
    """
    code = ORACLE_SOLUTIONS.get(task_name)
    if code is not None:
        return code
    # Build the message first and embed it with repr(): interpolating
    # ``{task_name!r}`` directly inside a quoted literal of the generated
    # source nests quotes and makes the stub a SyntaxError.
    message = f"No oracle reference for {task_name!r}. Try the model backends."
    return (
        f"def {task_name}(*args, **kwargs):\n"
        f"    raise NotImplementedError({message!r})\n"
    )