"""Reference optimizers run by `run_baseline` action.

These are invoked by the env — not by OptCoder's submitted code. They
produce diagnostic trajectories (x_t, f_t, |g_t|) that the agent sees.
The source code is NEVER exposed to the agent.
"""

from typing import Callable

import numpy as np


def _step_sgd(x, g, state, lr=0.01):
    """Plain gradient-descent step: x <- x - lr * g."""
    return x - lr * g, state


def _step_momentum(x, g, state, lr=0.01, beta=0.9):
    """Heavy-ball momentum: v <- beta * v - lr * g, then x <- x + v."""
    v = state.get("v", np.zeros_like(x))
    v = beta * v - lr * g
    state["v"] = v
    return x + v, state


def _step_adam(x, g, state, lr=0.001, b1=0.9, b2=0.999, eps=1e-8):
    """Adam step with bias-corrected first/second moment estimates."""
    m = state.get("m", np.zeros_like(x))
    v = state.get("v", np.zeros_like(x))
    t = state.get("t", 0) + 1
    m = b1 * m + (1 - b1) * g          # EMA of the gradient
    v = b2 * v + (1 - b2) * g**2       # EMA of the squared gradient
    m_hat = m / (1 - b1**t)            # bias correction for the warm-up phase
    v_hat = v / (1 - b2**t)
    state["m"], state["v"], state["t"] = m, v, t
    return x - lr * m_hat / (np.sqrt(v_hat) + eps), state


def _run_adam_with_lr(f, grad, x0: np.ndarray, lr: float, steps: int) -> tuple[np.ndarray, float]:
    """Run Adam for `steps` steps from x0 with the given lr. Returns (x_final, f_final).

    Used by the LR-tuning sweep for the Adam baseline. Returns (x0, inf) on divergence.
    """
    x = x0.copy().astype(float)
    state: dict = {}
    for _ in range(steps):
        g = np.asarray(grad(x), dtype=float)
        x, state = _step_adam(x, g, state, lr=lr)
        if not np.all(np.isfinite(x)):
            return x0, float("inf")
    return x, float(f(x))


def tune_adam_lr(f, grad, x0: np.ndarray,
                 lrs: tuple[float, ...] = (1e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1),
                 sweep_steps: int = 30) -> float:
    """Grid-search Adam's LR on a short run from x0. Returns the best LR."""
    best_lr = lrs[0]
    best_f = float("inf")
    for lr in lrs:
        _, f_final = _run_adam_with_lr(f, grad, x0, lr=lr, steps=sweep_steps)
        if f_final < best_f:
            best_f = f_final
            best_lr = lr
    return best_lr


def _run_baseline_with_lr(name: str, f, grad, x0: np.ndarray,
                           lr: float, steps: int) -> tuple[np.ndarray, float]:
    """Run any reference baseline with an overridden LR. Returns (x_final, f_final)."""
    if name not in BASELINES:
        raise ValueError(f"Unknown baseline {name!r}")
    step_fn = BASELINES[name]
    x = x0.copy().astype(float)
    state: dict = {}
    for _ in range(steps):
        g = np.asarray(grad(x), dtype=float)
        x, state = step_fn(x, g, state, lr=lr)
        if not np.all(np.isfinite(x)):
            return x0, float("inf")
    return x, float(f(x))


def tune_baseline_lr(name: str, f, grad, x0: np.ndarray,
                     lrs: tuple[float, ...] = (1e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1),
                     sweep_steps: int = 30) -> float:
    """Grid-search the LR for any named baseline (sgd / momentum / adam / lbfgs).

    Each baseline's `_step_*` fn accepts `lr` as a kwarg, so the sweep uses the
    same harness as Adam tuning but is parameterised by the step function.
    """
    best_lr = lrs[0]
    best_f = float("inf")
    for lr in lrs:
        try:
            _, f_final = _run_baseline_with_lr(name, f, grad, x0,
                                                 lr=lr, steps=sweep_steps)
        except Exception:
            f_final = float("inf")
        if f_final < best_f:
            best_f = f_final
            best_lr = lr
    return best_lr


def run_baseline_tuned(name: str, f, grad, x0: np.ndarray, steps: int = 30,
                        tune_x0: np.ndarray | None = None,
                        sweep_steps: int = 30) -> dict:
    """Run a baseline with its LR tuned to the landscape first.

    Returns the same dict shape as `run_baseline`, plus a `lr` field.
    """
    tune_start = tune_x0 if tune_x0 is not None else x0
    best_lr = tune_baseline_lr(name, f, grad, tune_start, sweep_steps=sweep_steps)

    step_fn = BASELINES[name]
    x = x0.copy().astype(float)
    state: dict = {}
    traj: list[dict] = []
    for t in range(steps):
        fv = float(f(x))
        g = np.asarray(grad(x), dtype=float)
        gn = float(np.linalg.norm(g))
        traj.append({"t": t, "x": x.tolist(), "f": fv, "grad_norm": gn})
        x, state = step_fn(x, g, state, lr=best_lr)
        if not np.all(np.isfinite(x)):
            traj.append({"t": t + 1, "x": None, "f": None, "grad_norm": None,
                         "diverged": True})
            break
    if np.all(np.isfinite(x)):
        traj.append({"t": len(traj), "x": x.tolist(), "f": float(f(x)),
                     "grad_norm": float(np.linalg.norm(np.asarray(grad(x))))})
    return {"name": name, "trajectory": traj, "final_x": x.tolist(),
            "lr": best_lr}


def _step_lbfgs(x, g, state, lr=0.01, m_size=5):
    """Crude L-BFGS with finite-step history. Good enough as a reference."""
    xs = state.setdefault("xs", [])     # positions
    gs = state.setdefault("gs", [])     # gradients

    if len(xs) < 2:
        # Not enough history yet (first two calls): take a plain gradient step
        x_new = x - lr * g
    else:
        # Two-loop recursion over the last m_size curvature pairs.
        # i = 1 is the most recent pair, so the lists below are ordered
        # newest-first, which is the order the first loop expects.
        s_list, y_list, rho_list = [], [], []
        for i in range(1, min(m_size, len(xs)) + 1):
            s = xs[-i] - xs[-i - 1] if len(xs) > i else None
            if s is None:
                continue
            y = gs[-i] - gs[-i - 1]
            denom = float(y @ s)
            if abs(denom) < 1e-12:
                # Skip pairs with negligible curvature to keep rho finite
                continue
            s_list.append(s); y_list.append(y); rho_list.append(1.0 / denom)

        q = g.copy()
        alpha = []
        for s, y, rho in zip(s_list, y_list, rho_list):
            a = rho * float(s @ q)
            alpha.append(a)
            q = q - a * y

        # Initial Hessian scaling: H0 = gamma * I, gamma from the most recent pair
        if y_list:
            y0 = y_list[0]; s0 = s_list[0]
            gamma = float(s0 @ y0) / (float(y0 @ y0) + 1e-12)
        else:
            gamma = 1.0
        r = gamma * q

        for (s, y, rho), a in zip(reversed(list(zip(s_list, y_list, rho_list))), reversed(alpha)):
            b = rho * float(y @ r)
            r = r + (a - b) * s

        x_new = x - lr * r

    xs.append(x.copy())
    gs.append(g.copy())
    return x_new, state


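# Every entry in BASELINES is a pure step function with the same contract:
#   step_fn(x, g, state, lr=..., **hyperparams) -> (x_new, state)
# `state` is a plain dict that the step function may read and mutate to carry
# its running quantities (momentum buffer, Adam moments, L-BFGS history).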
BASELINES: dict[str, Callable] = {
    "sgd": _step_sgd,
    "momentum": _step_momentum,
    "adam": _step_adam,
    "lbfgs": _step_lbfgs,
}


def run_baseline(name: str, f, grad, x0: np.ndarray, steps: int = 30) -> dict:
    """Run a reference optimizer from x0 for `steps` steps.

    Returns a trajectory dict with per-step (x, f, |g|).
    """
    if name not in BASELINES:
        raise ValueError(f"Unknown baseline {name!r}")
    step_fn = BASELINES[name]
    x = x0.copy().astype(float)
    state: dict = {}
    traj = []
    for t in range(steps):
        fv = float(f(x))
        g = np.asarray(grad(x), dtype=float)
        gn = float(np.linalg.norm(g))
        traj.append({"t": t, "x": x.tolist(), "f": fv, "grad_norm": gn})
        x, state = step_fn(x, g, state)
        if not np.all(np.isfinite(x)):
            # Record the divergence (no finite state to report) and stop early
            traj.append({"t": t + 1, "x": None, "f": None, "grad_norm": None,
                         "diverged": True})
            break
    # Final state
    if np.all(np.isfinite(x)):
        traj.append({"t": len(traj), "x": x.tolist(), "f": float(f(x)),
                     "grad_norm": float(np.linalg.norm(np.asarray(grad(x))))})
    return {"name": name, "trajectory": traj, "final_x": x.tolist()}