Simo76 commited on
Commit
1a5825f
Β·
1 Parent(s): d72fbc5

Refactor stress test to use Orbital Controller

Browse files

Updated stress test script to use Orbital Controller instead of Nested LoRA. Improved code readability and structure.

Files changed (1) hide show
  1. experiments/stress_test_task_switch.py +144 -153
experiments/stress_test_task_switch.py CHANGED
@@ -1,15 +1,8 @@
1
- """
2
- Unified-LoRA β€” Stress Test: Task Switch
3
- =========================================
4
 
5
  MRPC (60 steps) β†’ SST-2 (60 steps)
6
- Baseline (r=16 fixed) vs Nested Orbital Controller
7
-
8
- Self-contained, reproducible on Google Colab with T4 GPU.
9
-
10
- Usage:
11
- pip install transformers datasets evaluate
12
- python stress_test_task_switch.py
13
  """
14
 
15
  import time, random, math, numpy as np, torch, torch.nn as nn
@@ -18,12 +11,15 @@ from datasets import load_dataset
18
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
19
  from torch.utils.data import DataLoader
20
 
21
- # Import from controller.py (same repo)
22
  import sys, os
23
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
- from controller import NestedLoRALinear, OrbitalController, inject_nested_lora, set_rank
 
 
 
 
 
25
 
26
- # ── CONFIG ──────────────────────────────────────────
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
  MODEL = "distilbert-base-uncased"
29
  BATCH = 8
@@ -34,18 +30,19 @@ MAX_RANK = 16
34
  WARMUP = 10
35
  STABLE_WINDOW = 6
36
 
37
- STEPS_TASK1 = 60 # MRPC
38
- STEPS_TASK2 = 60 # SST-2
39
  TOTAL_STEPS = STEPS_TASK1 + STEPS_TASK2
40
 
41
- # ── DATA ────────────────────────────────────────────
 
42
  print("Loading data...")
43
  tok = AutoTokenizer.from_pretrained(MODEL)
44
 
45
  ds_mrpc = load_dataset("glue", "mrpc")
46
  def tok_mrpc(x):
47
- return tok(x["sentence1"], x["sentence2"],
48
- truncation=True, padding="max_length", max_length=128)
49
  ds_mrpc = ds_mrpc.map(tok_mrpc, batched=True)
50
  ds_mrpc.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
51
  train_mrpc = DataLoader(ds_mrpc["train"], batch_size=BATCH, shuffle=True)
@@ -53,7 +50,7 @@ val_mrpc = DataLoader(ds_mrpc["validation"], batch_size=BATCH)
53
 
54
  ds_sst2 = load_dataset("glue", "sst2")
55
  def tok_sst2(x):
56
- return tok(x["sentence"], truncation=True, padding="max_length", max_length=128)
57
  ds_sst2 = ds_sst2.map(tok_sst2, batched=True)
58
  ds_sst2.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
59
  train_sst2 = DataLoader(ds_sst2["train"], batch_size=BATCH, shuffle=True)
@@ -62,100 +59,97 @@ val_sst2 = DataLoader(ds_sst2["validation"], batch_size=BATCH)
62
  metric_mrpc = evaluate.load("glue", "mrpc")
63
  metric_sst2 = evaluate.load("glue", "sst2")
64
 
65
- # ── HELPERS ─────────────────────────────────────────
 
66
  def make_iter(loader):
67
- while True:
68
- for batch in loader:
69
- yield batch
70
 
71
- def get_batch(it, device):
72
- batch = next(it)
73
- return (batch["input_ids"].to(device),
74
- batch["attention_mask"].to(device),
75
- batch["label"].to(device))
76
 
77
  def build_model():
78
- base = AutoModelForSequenceClassification.from_pretrained(
79
- MODEL, num_labels=2, ignore_mismatched_sizes=True
80
- )
81
- return inject_nested_lora(base, MAX_RANK).to(DEVICE)
82
 
83
  def eval_f1(model, loader, metric_fn):
84
- model.eval()
85
- preds, labels = [], []
86
- with torch.no_grad():
87
- for batch in loader:
88
- x = batch["input_ids"].to(DEVICE)
89
- m = batch["attention_mask"].to(DEVICE)
90
- y = batch["label"].to(DEVICE)
91
- logits = model(input_ids=x, attention_mask=m).logits
92
- preds.extend(logits.argmax(dim=-1).cpu().numpy())
93
- labels.extend(y.cpu().numpy())
94
- model.train()
95
- result = metric_fn.compute(predictions=preds, references=labels)
96
- return result.get("f1", result.get("accuracy", 0.0))
97
 
98
  def eff_rank(usage):
99
- tot = sum(usage.values())
100
- return sum(k * v for k, v in usage.items()) / tot if tot > 0 else 0
 
 
101
 
102
- # ── TRAIN BASELINE ──────────────────────────────────
103
  def train_baseline(model):
104
- opt = torch.optim.AdamW(model.parameters(), lr=LR)
105
- set_rank(model, 16)
106
- it_mrpc = make_iter(train_mrpc)
107
- it_sst2 = make_iter(train_sst2)
108
- loss_trace = []
109
-
110
- for step in range(TOTAL_STEPS):
111
- if step < STEPS_TASK1:
112
- x, m, y = get_batch(it_mrpc, DEVICE)
113
- else:
114
- x, m, y = get_batch(it_sst2, DEVICE)
115
-
116
- loss = model(input_ids=x, attention_mask=m, labels=y).loss
117
- loss_trace.append(loss.item())
118
- loss.backward()
119
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
120
- opt.step()
121
- opt.zero_grad()
122
-
123
- return model, loss_trace
124
-
125
- # ── TRAIN UNIFIED ───────────────────────────────────
126
- def train_unified(model):
127
- ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
128
- ctrl.rank = 4
129
- set_rank(model, 4)
130
-
131
- opt = torch.optim.AdamW(model.parameters(), lr=LR)
132
- usage = {4: 0, 8: 0, 16: 0}
133
- rank_trace, loss_trace = [], []
134
- it_mrpc = make_iter(train_mrpc)
135
- it_sst2 = make_iter(train_sst2)
136
-
137
- for step in range(TOTAL_STEPS):
138
- if step < STEPS_TASK1:
139
- x, m, y = get_batch(it_mrpc, DEVICE)
140
- else:
141
- x, m, y = get_batch(it_sst2, DEVICE)
142
-
143
- loss = model(input_ids=x, attention_mask=m, labels=y).loss
144
- new_rank = ctrl.step(loss.item())
145
- set_rank(model, new_rank)
146
-
147
- usage[new_rank] += 1
148
- rank_trace.append(new_rank)
149
- loss_trace.append(loss.item())
150
-
151
- loss.backward()
152
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
153
- opt.step()
154
- opt.zero_grad()
155
-
156
- return model, usage, rank_trace, loss_trace, ctrl
157
-
158
- # ── RUN ─────────────────────────────────────────────
159
  print(f"\nDevice: {DEVICE}")
160
  print(f"Plan: MRPC Γ— {STEPS_TASK1} β†’ SST-2 Γ— {STEPS_TASK2}")
161
  print(f"Shock at step {STEPS_TASK1}")
@@ -164,49 +158,48 @@ print("=" * 55)
164
  results = []
165
 
166
  for seed in SEEDS:
167
- print(f"\n{'─' * 55}\n SEED {seed}\n{'─' * 55}")
168
-
169
- torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
170
- base_model = build_model()
171
- base_model, base_loss = train_baseline(base_model)
172
- f1_mrpc_base = eval_f1(base_model, val_mrpc, metric_mrpc)
173
- f1_sst2_base = eval_f1(base_model, val_sst2, metric_sst2)
174
- del base_model; torch.cuda.empty_cache()
175
-
176
- torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
177
- uni_model = build_model()
178
- uni_model, usage, rank_trace, uni_loss, ctrl = train_unified(uni_model)
179
- f1_mrpc_uni = eval_f1(uni_model, val_mrpc, metric_mrpc)
180
- f1_sst2_uni = eval_f1(uni_model, val_sst2, metric_sst2)
181
-
182
- er = eff_rank(usage)
183
- saving = 1 - er / 16
184
- transitions = sum(1 for i in range(1, len(rank_trace)) if rank_trace[i] != rank_trace[i-1])
185
-
186
- print(f"\n {'':30s} {'BASELINE':>10s} {'UNIFIED':>10s}")
187
- print(f" {'─' * 55}")
188
- print(f" {'MRPC F1 (retention)':30s} {f1_mrpc_base:10.3f} {f1_mrpc_uni:10.3f}")
189
- print(f" {'SST-2 Acc (new task)':30s} {f1_sst2_base:10.3f} {f1_sst2_uni:10.3f}")
190
- print(f"\n Unified: eff_rank={er:.1f} saving={saving*100:.0f}% transitions={transitions}")
191
- print(f" Usage: r4={usage[4]} r8={usage[8]} r16={usage[16]}")
192
-
193
- # Rank trace
194
- trace_str = ""
195
- for i, r in enumerate(rank_trace):
196
- if i % 10 == 0:
197
- marker = " <<<SHOCK" if i == STEPS_TASK1 else ""
198
- trace_str += f"\n [{i:3d}]{marker} "
199
- trace_str += f"r{r:<3d}"
200
- print(f" Rank trace:{trace_str}")
201
-
202
- results.append({
203
- 'seed': seed, 'f1_mrpc_base': f1_mrpc_base, 'f1_sst2_base': f1_sst2_base,
204
- 'f1_mrpc_uni': f1_mrpc_uni, 'f1_sst2_uni': f1_sst2_uni,
205
- 'eff_rank': er, 'saving': saving, 'transitions': transitions,
206
- })
207
- del uni_model; torch.cuda.empty_cache()
208
-
209
- # ── SUMMARY ─────────────────────────────────────────
210
  print(f"\n{'=' * 55}\n SUMMARY\n{'=' * 55}")
211
  mrpc_b = np.mean([r['f1_mrpc_base'] for r in results])
212
  mrpc_u = np.mean([r['f1_mrpc_uni'] for r in results])
@@ -215,9 +208,7 @@ sst2_u = np.mean([r['f1_sst2_uni'] for r in results])
215
  er_avg = np.mean([r['eff_rank'] for r in results])
216
  sv_avg = np.mean([r['saving'] for r in results])
217
 
218
- print(f"\n {'':30s} {'BASELINE':>10s} {'UNIFIED':>10s} {'DELTA':>8s}")
219
- print(f" {'─' * 60}")
220
- print(f" {'MRPC F1 (retention)':30s} {mrpc_b:10.3f} {mrpc_u:10.3f} {mrpc_u-mrpc_b:+8.3f}")
221
- print(f" {'SST-2 Acc (new task)':30s} {sst2_b:10.3f} {sst2_u:10.3f} {sst2_u-sst2_b:+8.3f}")
222
- print(f" {'Eff rank':30s} {'16.0':>10s} {er_avg:10.1f}")
223
- print(f" {'Saving':30s} {'0%':>10s} {sv_avg*100:9.0f}%")
 
1
+ """
2
+ Orbital LoRA β€” Stress Test: Task Switch
 
3
 
4
  MRPC (60 steps) β†’ SST-2 (60 steps)
5
+ Baseline (r=16 fixed) vs Orbital Controller
 
 
 
 
 
 
6
  """
7
 
8
  import time, random, math, numpy as np, torch, torch.nn as nn
 
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
  from torch.utils.data import DataLoader
13
 
 
14
  import sys, os
15
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(file))))
16
+
17
+ from nested_lora import NestedLoRALinear, inject_nested_lora
18
+ from orbital_controller import OrbitalController
19
+ from controller import set_rank
20
+
21
+ ── CONFIG ──────────────────────────────────────────
22
 
 
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
  MODEL = "distilbert-base-uncased"
25
  BATCH = 8
 
30
  WARMUP = 10
31
  STABLE_WINDOW = 6
32
 
33
+ STEPS_TASK1 = 60
34
+ STEPS_TASK2 = 60
35
  TOTAL_STEPS = STEPS_TASK1 + STEPS_TASK2
36
 
37
+ ── DATA ────────────────────────────────────────────
38
+
39
  print("Loading data...")
40
  tok = AutoTokenizer.from_pretrained(MODEL)
41
 
42
  ds_mrpc = load_dataset("glue", "mrpc")
43
  def tok_mrpc(x):
44
+ return tok(x["sentence1"], x["sentence2"],
45
+ truncation=True, padding="max_length", max_length=128)
46
  ds_mrpc = ds_mrpc.map(tok_mrpc, batched=True)
47
  ds_mrpc.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
48
  train_mrpc = DataLoader(ds_mrpc["train"], batch_size=BATCH, shuffle=True)
 
50
 
51
  ds_sst2 = load_dataset("glue", "sst2")
52
  def tok_sst2(x):
53
+ return tok(x["sentence"], truncation=True, padding="max_length", max_length=128)
54
  ds_sst2 = ds_sst2.map(tok_sst2, batched=True)
55
  ds_sst2.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
56
  train_sst2 = DataLoader(ds_sst2["train"], batch_size=BATCH, shuffle=True)
 
59
  metric_mrpc = evaluate.load("glue", "mrpc")
60
  metric_sst2 = evaluate.load("glue", "sst2")
61
 
62
+ ── HELPERS ─────────────────────────────────────────
63
+
64
  def make_iter(loader):
65
+ while True:
66
+ for batch in loader:
67
+ yield batch
68
 
69
+ def get_batch(it):
70
+ batch = next(it)
71
+ return (batch["input_ids"].to(DEVICE),
72
+ batch["attention_mask"].to(DEVICE),
73
+ batch["label"].to(DEVICE))
74
 
75
  def build_model():
76
+ base = AutoModelForSequenceClassification.from_pretrained(
77
+ MODEL, num_labels=2, ignore_mismatched_sizes=True
78
+ )
79
+ return inject_nested_lora(base, MAX_RANK).to(DEVICE)
80
 
81
  def eval_f1(model, loader, metric_fn):
82
+ model.eval()
83
+ preds, labels = [], []
84
+ with torch.no_grad():
85
+ for batch in loader:
86
+ x = batch["input_ids"].to(DEVICE)
87
+ m = batch["attention_mask"].to(DEVICE)
88
+ y = batch["label"].to(DEVICE)
89
+ logits = model(input_ids=x, attention_mask=m).logits
90
+ preds.extend(logits.argmax(dim=-1).cpu().numpy())
91
+ labels.extend(y.cpu().numpy())
92
+ model.train()
93
+ result = metric_fn.compute(predictions=preds, references=labels)
94
+ return result.get("f1", result.get("accuracy", 0.0))
95
 
96
  def eff_rank(usage):
97
+ tot = sum(usage.values())
98
+ return sum(k * v for k, v in usage.items()) / tot if tot > 0 else 0
99
+
100
+ ── TRAIN BASELINE ──────────────────────────────────
101
 
 
102
  def train_baseline(model):
103
+ opt = torch.optim.AdamW(model.parameters(), lr=LR)
104
+ set_rank(model, 16)
105
+ it_mrpc = make_iter(train_mrpc)
106
+ it_sst2 = make_iter(train_sst2)
107
+
108
+ for step in range(TOTAL_STEPS):
109
+ x, m, y = get_batch(it_mrpc if step < STEPS_TASK1 else it_sst2)
110
+
111
+ loss = model(input_ids=x, attention_mask=m, labels=y).loss
112
+ loss.backward()
113
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
114
+ opt.step()
115
+ opt.zero_grad()
116
+
117
+ return model
118
+
119
+ ── TRAIN ORBITAL ───────────────────────────────────
120
+
121
+ def train_orbital(model):
122
+ ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
123
+ ctrl.rank = 4
124
+ set_rank(model, 4)
125
+
126
+ opt = torch.optim.AdamW(model.parameters(), lr=LR)
127
+ usage = {4: 0, 8: 0, 16: 0}
128
+ rank_trace = []
129
+ it_mrpc = make_iter(train_mrpc)
130
+ it_sst2 = make_iter(train_sst2)
131
+
132
+ for step in range(TOTAL_STEPS):
133
+ x, m, y = get_batch(it_mrpc if step < STEPS_TASK1 else it_sst2)
134
+
135
+ loss = model(input_ids=x, attention_mask=m, labels=y).loss
136
+ loss.backward()
137
+
138
+ new_rank = ctrl.step(loss.item())
139
+ new_rank = max(4, min(16, new_rank))
140
+ set_rank(model, new_rank)
141
+
142
+ usage[new_rank] += 1
143
+ rank_trace.append(new_rank)
144
+
145
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
146
+ opt.step()
147
+ opt.zero_grad()
148
+
149
+ return model, usage, rank_trace
150
+
151
+ ── RUN ─────────────────────────────────────────────
152
+
 
 
 
 
 
153
  print(f"\nDevice: {DEVICE}")
154
  print(f"Plan: MRPC Γ— {STEPS_TASK1} β†’ SST-2 Γ— {STEPS_TASK2}")
155
  print(f"Shock at step {STEPS_TASK1}")
 
158
  results = []
159
 
160
  for seed in SEEDS:
161
+ print(f"\n{'─' * 55}\n SEED {seed}\n{'─' * 55}")
162
+
163
+ torch.manual_seed(seed)
164
+ torch.cuda.manual_seed_all(seed)
165
+ np.random.seed(seed)
166
+ random.seed(seed)
167
+
168
+ base_model = build_model()
169
+ base_model = train_baseline(base_model)
170
+ f1_mrpc_base = eval_f1(base_model, val_mrpc, metric_mrpc)
171
+ f1_sst2_base = eval_f1(base_model, val_sst2, metric_sst2)
172
+ del base_model; torch.cuda.empty_cache()
173
+
174
+ torch.manual_seed(seed)
175
+ torch.cuda.manual_seed_all(seed)
176
+ np.random.seed(seed)
177
+ random.seed(seed)
178
+
179
+ uni_model = build_model()
180
+ uni_model, usage, rank_trace = train_orbital(uni_model)
181
+ f1_mrpc_uni = eval_f1(uni_model, val_mrpc, metric_mrpc)
182
+ f1_sst2_uni = eval_f1(uni_model, val_sst2, metric_sst2)
183
+
184
+ er = eff_rank(usage)
185
+ saving = 1 - er / 16
186
+ transitions = sum(1 for i in range(1, len(rank_trace)) if rank_trace[i] != rank_trace[i-1])
187
+
188
+ print(f"\n {'':30s} {'BASELINE':>10s} {'ORBITAL':>10s}")
189
+ print(f" {'─' * 55}")
190
+ print(f" {'MRPC F1 (retention)':30s} {f1_mrpc_base:10.3f} {f1_mrpc_uni:10.3f}")
191
+ print(f" {'SST-2 Acc (new task)':30s} {f1_sst2_base:10.3f} {f1_sst2_uni:10.3f}")
192
+ print(f"\n Orbital: eff_rank={er:.1f} saving={saving*100:.0f}% transitions={transitions}")
193
+
194
+ results.append({
195
+ 'f1_mrpc_base': f1_mrpc_base, 'f1_sst2_base': f1_sst2_base,
196
+ 'f1_mrpc_uni': f1_mrpc_uni, 'f1_sst2_uni': f1_sst2_uni,
197
+ 'eff_rank': er, 'saving': saving
198
+ })
199
+ del uni_model; torch.cuda.empty_cache()
200
+
201
+ ── SUMMARY ─────────────────────────────────────────
202
+
 
203
  print(f"\n{'=' * 55}\n SUMMARY\n{'=' * 55}")
204
  mrpc_b = np.mean([r['f1_mrpc_base'] for r in results])
205
  mrpc_u = np.mean([r['f1_mrpc_uni'] for r in results])
 
208
  er_avg = np.mean([r['eff_rank'] for r in results])
209
  sv_avg = np.mean([r['saving'] for r in results])
210
 
211
+ print(f"\n {'MRPC F1':20s} {mrpc_b:.3f} β†’ {mrpc_u:.3f}")
212
+ print(f" {'SST-2 Acc':20s} {sst2_b:.3f} β†’ {sst2_u:.3f}")
213
+ print(f" {'Eff rank':20s} 16.0 β†’ {er_avg:.1f}")
214
+ print(f" {'Saving':20s} 0% β†’ {sv_avg*100:.0f}%")