Simo76 committed on
Commit
d6b96b5
·
1 Parent(s): 1fda0d1

Add OrbitalController for adaptive trajectory control


Implement a closed-loop trajectory controller for dynamic capacity adaptation in machine learning models. The controller raises model capacity when a stress signal derived from the loss trajectory spikes, records each increase on a stack, and reverses it only after a sustained window of confirmed stability.
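For orientation, here is a minimal sketch of the intended control loop on a synthetic loss trace (the `losses` values are made up for illustration; the real training-loop examples live in the docstrings of the file below):

from orbital_controller import OrbitalController

# Synthetic trace: stable phase, sudden shock, then recovery.
losses = [0.5] * 30 + [2.5] * 10 + [0.4] * 30

ctrl = OrbitalController(ranks=[4, 8, 16], warmup=10, stable_window=6)
for loss in losses:
    rank = ctrl.step(loss)  # capacity level to apply for the next step,
                            # e.g. set_rank(model, rank) in a real run

print(ctrl)  # repr shows step count, current rank, orbit stack, stability counter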

Files changed (1)
  1. orbital_controller.py +291 -0
orbital_controller.py ADDED
@@ -0,0 +1,291 @@
+"""
+Orbital Controller — Trajectory Control with Memory
+=====================================================
+
+Closed-loop rank controller that adapts model capacity based on
+observed training stress. Works with any rank-adjustable system
+(NestedLoRA, adaptive LR, or API-based training).
+
+This module is the "intelligence" — pure control logic, no model code.
+Pair with NestedLoRA for the complete Unified-LoRA system.
+
+Author: Simona Vargiu
+License: Apache 2.0
+"""
+
+import numpy as np
+from typing import Dict, List, Optional
+
+
+class OrbitalController:
+    """
+    Closed-loop trajectory controller for dynamic capacity adaptation.
+
+    Unlike threshold-based controllers that map stress to rank statically,
+    this implements orbital dynamics with memory:
+
+        Ascend:  stress detected → jump to higher orbital, push delta
+        Hold:    oscillating → stay, don't move
+        Descend: confirmed stable → pop delta, symmetric return
+
+    Each capacity increase is tracked on a stack and reversed only under
+    confirmed stability. This prevents premature compression (returning
+    too early) and oscillatory collapse (bouncing between ranks).
+
+    The stress signal and thresholds are adaptive — they auto-calibrate
+    to any model/task/loss scale without manual tuning.
+
+    Args:
+        ranks: Available capacity levels (default: [4, 8, 16])
+        warmup: Steps at max capacity to build EMA baseline
+        stable_window: Consecutive stable steps required for descent
+
+    Example:
+        >>> from nested_lora import inject_nested_lora, set_rank
+        >>> from orbital_controller import OrbitalController
+        >>>
+        >>> model = inject_nested_lora(model, max_rank=16)
+        >>> ctrl = OrbitalController()
+        >>>
+        >>> for step, batch in enumerate(loader):
+        ...     loss = model(**batch).loss
+        ...     new_rank = ctrl.step(loss.item())
+        ...     set_rank(model, new_rank)
+        ...     loss.backward()
+        ...     optimizer.step()
+    """
+
+    def __init__(
+        self,
+        ranks: Optional[List[int]] = None,
+        warmup: int = 10,
+        stable_window: int = 6,
+    ):
+        self.RANKS = ranks or [4, 8, 16]
+        self.warmup = warmup
+        self.stable_window = stable_window
+        self.reset()
+
+    def reset(self):
+        """Reset controller to initial state."""
+        self.rank = self.RANKS[-1]
+        self.orbit_stack = []
+        self.loss_ema = 0.0
+        self.prev_loss = None
+        self.phi_hist = []
+        self.stable_count = 0
+        self.step_count = 0
+        self.post_warmup = False
+
+        self.history = {
+            "rank": [],
+            "phi": [],
+            "stable_count": [],
+        }
+
+    # ── Stress signal ───────────────────────────────
+
+    def _compute_phi(self, loss: float) -> float:
+        """
+        Stress signal from loss trajectory.
+
+            φ = |loss - EMA| + 2.0 × max(0, loss - prev_loss)
+
+        Combines deviation from trend (general instability)
+        with spike detection (sudden deterioration).
+        """
+        self.loss_ema = 0.9 * self.loss_ema + 0.1 * loss
+        delta = abs(loss - self.loss_ema)
+        spike = max(0.0, loss - self.prev_loss) if self.prev_loss is not None else 0.0
+        self.prev_loss = loss
+        return delta + 2.0 * spike
+
+    def _thresholds(self):
+        """
+        Adaptive thresholds from running statistics.
+
+            t_stress = μ + 0.7σ   (above this → ascend)
+            t_stable = μ - 0.3σ   (below this → stability confirmed)
+
+        Auto-calibrates to loss scale. No manual tuning.
+        """
+        if len(self.phi_hist) < 10:
+            return 0.15, 0.04
+        recent = self.phi_hist[-40:]
+        mu = np.mean(recent)
+        sigma = np.std(recent) + 1e-8
+        t_stress = mu + 0.7 * sigma
+        t_stable = max(mu - 0.3 * sigma, 0.0)
+        return t_stress, t_stable
+
+    # ── Core logic ──────────────────────────────────
+
+    def _rank_index(self) -> int:
+        return self.RANKS.index(self.rank)
+
+    def step(self, loss: float) -> int:
+        """
+        Called once per training step. Returns the capacity level to use.
+
+        Args:
+            loss: Current step loss value
+
+        Returns:
+            int: Active rank (or capacity level) for next step
+        """
+        self.step_count += 1
+
+        # First step: initialize EMA
+        if self.prev_loss is None:
+            self.loss_ema = loss
+            self.prev_loss = loss
+            self._log(0.0)
+            return self.rank
+
+        phi = self._compute_phi(loss)
+        self.phi_hist.append(phi)
+
+        # Warmup: build baseline at max capacity
+        if self.step_count <= self.warmup:
+            self._log(phi)
+            return self.rank
+
+        # Transition: warmup → ground state
+        if not self.post_warmup:
+            self.post_warmup = True
+            self.rank = self.RANKS[0]
+            self.orbit_stack = []
+            self.stable_count = 0
+            self._log(phi)
+            return self.rank
+
+        t_stress, t_stable = self._thresholds()
+
+        # Stability counter
+        if phi <= t_stable:
+            self.stable_count += 1
+        elif phi > t_stress:
+            self.stable_count = 0
+        else:
+            self.stable_count = max(0, self.stable_count - 1)
+
+        # ASCEND: stress → jump to higher orbital
+        if phi > t_stress and self.rank < self.RANKS[-1]:
+            idx = self._rank_index()
+            new_idx = min(idx + 1, len(self.RANKS) - 1)
+            new_rank = self.RANKS[new_idx]
+            if new_rank != self.rank:
+                self.orbit_stack.append(new_rank - self.rank)
+                self.rank = new_rank
+            self.stable_count = 0
+            self._log(phi)
+            return self.rank
+
+        # DESCEND: confirmed stability → symmetric return
+        if self.stable_count >= self.stable_window and self.orbit_stack:
+            delta = self.orbit_stack.pop()
+            target = self.rank - delta
+            self.rank = min(self.RANKS, key=lambda r: abs(r - target))
+            self.rank = max(self.rank, self.RANKS[0])
+            self.stable_count = 0
+            self._log(phi)
+            return self.rank
+
+        # HOLD: neutral → don't move
+        self._log(phi)
+        return self.rank
+
+    # ── Introspection ───────────────────────────────
+
+    def _log(self, phi: float):
+        self.history["rank"].append(self.rank)
+        self.history["phi"].append(phi)
+        self.history["stable_count"].append(self.stable_count)
+
+    def get_state(self) -> Dict:
+        """Current controller state."""
+        return {
+            "rank": self.rank,
+            "step": self.step_count,
+            "orbit_stack": list(self.orbit_stack),
+            "stable_count": self.stable_count,
+            "phi": self.phi_hist[-1] if self.phi_hist else 0.0,
+        }
+
+    def get_history(self) -> Dict[str, list]:
+        """Complete training history."""
+        return self.history
+
+    def __repr__(self) -> str:
+        return (
+            f"OrbitalController(step={self.step_count}, rank={self.rank}, "
+            f"stack={self.orbit_stack}, stable={self.stable_count})"
+        )
+
+
+# ============================================================
+# CONVENIENCE: setup helper
+# ============================================================
+
+def setup_unified_lora(model, max_rank=16, ranks=None, warmup=10, stable_window=6):
+    """
+    One-call setup: inject NestedLoRA + create OrbitalController.
+
+    Args:
+        model: PyTorch model
+        max_rank: Maximum LoRA rank
+        ranks: Available rank levels
+        warmup: Controller warmup steps
+        stable_window: Steps of stability before descent
+
+    Returns:
+        (model, controller) tuple
+
+    Example:
+        >>> from orbital_controller import setup_unified_lora
+        >>> from nested_lora import set_rank
+        >>>
+        >>> model, ctrl = setup_unified_lora(model)
+        >>> for step, batch in enumerate(loader):
+        ...     loss = model(**batch).loss
+        ...     set_rank(model, ctrl.step(loss.item()))
+        ...     loss.backward(); optimizer.step(); optimizer.zero_grad()
+    """
+    from nested_lora import inject_nested_lora
+
+    model = inject_nested_lora(model, max_rank)
+    controller = OrbitalController(
+        ranks=ranks or [4, 8, 16],
+        warmup=warmup,
+        stable_window=stable_window,
+    )
+    return model, controller
+
+
+# ============================================================
+# DEMO
+# ============================================================
+
+if __name__ == "__main__":
+    print("Orbital Controller — Demo")
+    print("=" * 50)
+    print("Simulating: 30 stable → 10 shock → 30 recovery\n")
+
+    ctrl = OrbitalController(warmup=8, stable_window=5)
+
+    for step in range(70):
+        if step < 30:
+            loss = np.random.uniform(0.4, 0.6)
+        elif step < 40:
+            loss = np.random.uniform(1.5, 3.0)
+        else:
+            loss = np.random.uniform(0.3, 0.5)
+
+        rank = ctrl.step(loss)
+
+        if step % 5 == 0 or step == 30:
+            s = ctrl.get_state()
+            tag = " <<<SHOCK" if step == 30 else ""
+            print(f" [{step:3d}] rank={rank:2d} phi={s['phi']:.3f} stack={s['orbit_stack']}{tag}")
+
+    print(f"\nFinal: {ctrl}")