File size: 5,650 Bytes
b0b140b
 
 
51a403e
b0b140b
51a403e
 
b0b140b
 
 
 
 
51a403e
b0b140b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a403e
 
 
 
 
 
b0b140b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a403e
 
 
 
 
 
 
 
 
 
 
 
b0b140b
 
 
 
 
 
 
 
 
 
 
51a403e
 
 
 
 
 
 
b0b140b
 
 
 
 
51a403e
b0b140b
 
 
 
 
 
 
 
 
 
51a403e
b0b140b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Sandbox for executing OptCoder-submitted optimizer code.

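In-process exec with:
- AST strip of module-level demo code (keeps only `class Optimizer`)
- Restricted globals (`np`/`numpy`, `math`, and a whitelisted subset of
  builtins) — a convenience filter, not a security boundary
- Thread-based timeout on each `step()` call (works on any thread — Gradio
  handlers, uvicorn workers, etc. — unlike SIGALRM, which requires the main
  thread). A timed-out call cannot be killed; its thread keeps running in the
  background.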
"""

from __future__ import annotations

import ast
import concurrent.futures
import math
from dataclasses import dataclass
from typing import Any

import numpy as np


class SandboxError(Exception):
    """Base class for every failure the sandbox reports.

    Used for syntax errors in a submission, step() timeouts, and
    security-related rejections alike.
    """


class StepTimeout(SandboxError):
    """Raised when a single ``step()`` call overruns its deadline."""


# Every step() submission is routed through one long-lived worker thread.
# Reusing it means an arena step does not have to build and tear down a new
# thread pool. The pool belongs to this process only, which is all an
# in-process sandbox needs. With max_workers=1, calls from concurrent callers
# run one at a time.
_STEP_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
    thread_name_prefix="lf-sandbox-step",
    max_workers=1,
)


def strip_module_code(source: str) -> str:
    """Return *source* reduced to its top-level ``class Optimizer`` definition(s).

    Everything else at module level is discarded: imports (the sandbox
    globals already supply np/numpy/math), stray demo functions,
    ``if __name__ == '__main__'`` guards, and trailing script code that
    LLM output tends to include.

    Raises:
        SandboxError: if *source* does not parse, or contains no top-level
            ``class Optimizer``.
    """
    try:
        module = ast.parse(source)
    except SyntaxError as exc:
        raise SandboxError(f"SyntaxError: {exc}") from exc

    # Only top-level class definitions named exactly "Optimizer" survive;
    # import statements are intentionally not carried over.
    optimizer_defs: list[ast.stmt] = [
        stmt
        for stmt in module.body
        if isinstance(stmt, ast.ClassDef) and stmt.name == "Optimizer"
    ]
    if not optimizer_defs:
        raise SandboxError("No `class Optimizer` found in submission")

    pruned = ast.fix_missing_locations(ast.Module(body=optimizer_defs, type_ignores=[]))
    return ast.unparse(pruned)


def _safe_globals() -> dict:
    """Build a fresh globals dict for exec'ing a submission.

    Exposes ``np``/``numpy``, ``math`` and a whitelisted subset of builtins:
    enough to define ``class Optimizer`` and write ordinary numeric code.
    Whitelisted names that this interpreter's ``builtins`` lacks are skipped.

    NOTE(security): this is a convenience filter, NOT an isolation boundary.
    ``object``, ``type`` and ``getattr`` are exposed, so e.g.
    ``object.__subclasses__()`` and attribute walking can reach arbitrary
    classes and modules. The full numpy namespace also includes file I/O
    (``np.save``, ``np.loadtxt``, ...). Run genuinely untrusted code in a
    separate process/container.

    Returns:
        A new dict suitable as the ``globals`` argument of ``exec``.
    """
    import builtins as _bi

    safe_names = [
        # numeric / iteration
        "abs", "min", "max", "sum", "len", "range", "zip", "enumerate",
        "list", "tuple", "dict", "set", "frozenset", "float", "int", "bool", "str",
        "round", "divmod", "pow", "reversed", "sorted", "any", "all", "map", "filter",
        "iter", "next", "slice",
        # introspection (convenient, but see the security note above)
        "isinstance", "issubclass", "hasattr", "getattr", "setattr",
        "True", "False", "None",
        # class definition machinery (required to define `class Optimizer`)
        "__build_class__", "__name__", "object", "super",
        "type", "property", "staticmethod", "classmethod",
        # errors (so submitted code can raise/catch sanely). A missing name in
        # an `except X:` clause only fails when an exception reaches it, and
        # then raises NameError that masks the real error, so common ones
        # are included.
        "Exception", "ValueError", "TypeError", "IndexError", "KeyError",
        "ZeroDivisionError", "RuntimeError", "ArithmeticError", "OverflowError",
        "AttributeError", "NotImplementedError", "StopIteration",
    ]
    safe_bi = {n: getattr(_bi, n) for n in safe_names if hasattr(_bi, n)}

    return {
        "__builtins__": safe_bi,
        "__name__": "__submission__",
        "np": np,
        "numpy": np,
        "math": math,
    }


@dataclass
class CompiledOptimizer:
    """Wraps an instantiated Optimizer with bounded `step` execution.

    Attributes:
        instance: The submitted ``Optimizer`` object; its ``step`` is invoked.
        step_timeout: Deadline for each ``step()`` call, in seconds.
    """
    instance: Any
    step_timeout: float = 0.5

    def step(self, x: np.ndarray, f_val: float, grad: np.ndarray) -> np.ndarray:
        """Run the submitted ``step()`` under a deadline and validate its output.

        Runs on a worker thread rather than via SIGALRM, so it is safe to call
        from any thread (Gradio handlers, uvicorn workers, ...).

        NOTE: all callers share the module's single worker, so time spent
        queued behind another caller's ``step()`` counts toward this deadline.

        Args:
            x: Current point; the submission receives a copy.
            f_val: Objective value at ``x``.
            grad: Gradient at ``x``; the submission receives a copy.

        Returns:
            The proposed next point as a float ndarray with ``x.shape``.

        Raises:
            StepTimeout: ``step()`` did not finish within ``step_timeout``.
            SandboxError: ``step()`` raised, or returned something that is not
                a finite float array of the same shape as ``x``.
        """
        global _STEP_EXECUTOR

        # Hand the submission private copies. Optimizers often update in place
        # (`x -= lr * grad; return x`), which would otherwise mutate the
        # caller's arrays. A timed-out call keeps running in the background
        # and could do so even after we have returned.
        executor = _STEP_EXECUTOR
        future = executor.submit(self.instance.step, np.copy(x), f_val, np.copy(grad))

        # Wait explicitly instead of using future.result(timeout=...). On
        # Python 3.11+, concurrent.futures.TimeoutError *is* the builtin
        # TimeoutError, so a TimeoutError raised by the submission itself
        # would be misreported as a deadline overrun.
        done, _ = concurrent.futures.wait([future], timeout=self.step_timeout)
        if not done:
            if not future.cancel():
                # step() is running, and Python threads cannot be killed. Left
                # alone, it would hold the single worker forever and every
                # later step() would time out queued behind it, so stop
                # routing work to this executor. The abandoned thread runs
                # until step() returns, if ever. concurrent.futures joins its
                # workers at interpreter exit, so a truly endless call can
                # still delay shutdown.
                if _STEP_EXECUTOR is executor:
                    _STEP_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
                        max_workers=1, thread_name_prefix="lf-sandbox-step",
                    )
            raise StepTimeout(f"step() exceeded {self.step_timeout}s")

        try:
            out = future.result()
        except Exception as e:
            raise SandboxError(f"step() raised {type(e).__name__}: {e}") from e

        try:
            out = np.asarray(out, dtype=float)
        except Exception as e:
            raise SandboxError(f"step() returned non-array value ({type(e).__name__}: {e})") from e
        if out.shape != x.shape:
            raise SandboxError(f"step() returned shape {out.shape}, expected {x.shape}")
        if not np.all(np.isfinite(out)):
            raise SandboxError("step() returned non-finite values")
        return out


def compile_optimizer(source: str, dim: int, step_timeout: float = 0.5) -> CompiledOptimizer:
    """Strip, exec, and instantiate ``Optimizer(dim=dim)``; return a wrapper.

    Args:
        source: Submitted Python source; only top-level ``class Optimizer``
            definitions are kept.
        dim: Passed to the constructor as ``Optimizer(dim=dim)``.
        step_timeout: Per-``step()`` deadline, in seconds, for the wrapper.

    Returns:
        A ``CompiledOptimizer`` wrapping the new instance.

    Raises:
        SandboxError: on a syntax error, a missing ``Optimizer`` class, an
            exception during exec or ``__init__``, or an instance without a
            callable ``step``.

    NOTE: exec() and ``__init__()`` are NOT timeout-guarded. The AST strip
    removes module-level statements, but the class body and ``__init__``
    still execute without a deadline, so a submission that loops there blocks
    the caller. Only ``step()`` calls are deadline-protected.
    """
    stripped = strip_module_code(source)
    globs = _safe_globals()

    # Execute in ONE namespace. With separate globals/locals dicts, exec runs
    # the code as if inside a class body: `Optimizer` would land only in the
    # locals dict, and methods that refer to it by name (e.g.
    # `super(Optimizer, self).__init__()`) would raise NameError.
    # This runs untrusted code in-process; the restricted globals are not an
    # isolation boundary.
    try:
        exec(compile(stripped, "<submission>", "exec"), globs)
    except SandboxError:
        raise
    except Exception as e:
        raise SandboxError(f"exec failed: {type(e).__name__}: {e}") from e

    OptimizerCls = globs.get("Optimizer")
    if OptimizerCls is None:
        raise SandboxError("Optimizer class not defined after exec")

    try:
        instance = OptimizerCls(dim=dim)
    except Exception as e:
        raise SandboxError(f"__init__ failed: {type(e).__name__}: {e}") from e

    # Require a callable `step`: a plain attribute named `step` would only
    # fail later, inside the first timed step() call.
    if not callable(getattr(instance, "step", None)):
        raise SandboxError("Optimizer instance missing `step` method")

    return CompiledOptimizer(instance=instance, step_timeout=step_timeout)