| """ |
| Orbital LoRA β Stress Test: Task Switch |
| |
| MRPC (60 steps) β SST-2 (60 steps) |
| Baseline (r=16 fixed) vs Orbital Controller |
| """ |
|
|
| import time, random, math, numpy as np, torch, torch.nn as nn |
| import torch.nn.functional as F, evaluate |
| from datasets import load_dataset |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| from torch.utils.data import DataLoader |
|
|
import sys, os
# Make the repo root (two levels up from this script) importable so the
# project modules below resolve.  Bug fix: the builtin is `__file__`, not
# `file` — the original raised NameError at import time.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
| from nested_lora import NestedLoRALinear, inject_nested_lora |
| from orbital_controller import OrbitalController |
| from controller import set_rank |
|
|
# ── CONFIG ──────────────────────────────────────────
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| MODEL = "distilbert-base-uncased" |
| BATCH = 8 |
| LR = 5e-5 |
| SEEDS = [0, 1, 2] |
|
|
| MAX_RANK = 16 |
| WARMUP = 10 |
| STABLE_WINDOW = 6 |
|
|
| STEPS_TASK1 = 60 |
| STEPS_TASK2 = 60 |
| TOTAL_STEPS = STEPS_TASK1 + STEPS_TASK2 |
|
|
# ── DATA ────────────────────────────────────────────
|
|
| print("Loading data...") |
| tok = AutoTokenizer.from_pretrained(MODEL) |
|
|
| ds_mrpc = load_dataset("glue", "mrpc") |
| def tok_mrpc(x): |
| return tok(x["sentence1"], x["sentence2"], |
| truncation=True, padding="max_length", max_length=128) |
| ds_mrpc = ds_mrpc.map(tok_mrpc, batched=True) |
| ds_mrpc.set_format(type="torch", columns=["input_ids", "attention_mask", "label"]) |
| train_mrpc = DataLoader(ds_mrpc["train"], batch_size=BATCH, shuffle=True) |
| val_mrpc = DataLoader(ds_mrpc["validation"], batch_size=BATCH) |
|
|
# SST-2: single-sentence sentiment task (GLUE).
ds_sst2 = load_dataset("glue", "sst2")
def tok_sst2(x):
    # Single sentence, fixed-length padded to match the MRPC encoding.
    return tok(x["sentence"], truncation=True, padding="max_length", max_length=128)
ds_sst2 = ds_sst2.map(tok_sst2, batched=True)
ds_sst2.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
train_sst2 = DataLoader(ds_sst2["train"], batch_size=BATCH, shuffle=True)
val_sst2 = DataLoader(ds_sst2["validation"], batch_size=BATCH)

# GLUE metric objects: MRPC reports F1, SST-2 reports accuracy.
metric_mrpc = evaluate.load("glue", "mrpc")
metric_sst2 = evaluate.load("glue", "sst2")
|
|
# ── HELPERS ─────────────────────────────────────────
|
|
def make_iter(loader):
    """Yield batches from *loader* forever, restarting it when exhausted."""
    while True:
        yield from loader
|
|
def get_batch(it):
    """Pull the next batch from iterator *it* and move its tensors to DEVICE.

    Returns an (input_ids, attention_mask, label) tuple.
    """
    batch = next(it)
    fields = ("input_ids", "attention_mask", "label")
    return tuple(batch[key].to(DEVICE) for key in fields)
|
|
def build_model():
    """Build a fresh DistilBERT classifier with nested-LoRA adapters injected.

    ignore_mismatched_sizes lets the 2-label head replace any pretrained
    head of a different shape without erroring.
    """
    base = AutoModelForSequenceClassification.from_pretrained(
        MODEL, num_labels=2, ignore_mismatched_sizes=True
    )
    return inject_nested_lora(base, MAX_RANK).to(DEVICE)
|
|
def eval_f1(model, loader, metric_fn):
    """Evaluate *model* on *loader*; return F1, else accuracy, else 0.0.

    Deliberately leaves the model in train() mode on exit, since it is
    called mid-experiment between training phases.
    """
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for batch in loader:
            x = batch["input_ids"].to(DEVICE)
            m = batch["attention_mask"].to(DEVICE)
            y = batch["label"].to(DEVICE)
            logits = model(input_ids=x, attention_mask=m).logits
            preds.extend(logits.argmax(dim=-1).cpu().numpy())
            labels.extend(y.cpu().numpy())
    model.train()
    result = metric_fn.compute(predictions=preds, references=labels)
    # MRPC exposes "f1"; SST-2 exposes "accuracy" — take whichever exists.
    return result.get("f1", result.get("accuracy", 0.0))
|
|
def eff_rank(usage):
    """Return the usage-weighted mean rank, or 0 when nothing was recorded.

    *usage* maps rank -> number of steps spent at that rank.
    """
    total = sum(usage.values())
    if total <= 0:
        return 0
    weighted = sum(rank * count for rank, count in usage.items())
    return weighted / total
|
|
# ── TRAIN BASELINE ──────────────────────────────────
|
|
def train_baseline(model):
    """Train *model* at the fixed maximum LoRA rank across both tasks.

    MRPC batches for the first STEPS_TASK1 steps, SST-2 afterwards.
    Returns the trained model.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    # Consistency fix: use MAX_RANK instead of a hard-coded 16 so the
    # baseline tracks the configured maximum rank.
    set_rank(model, MAX_RANK)
    it_mrpc = make_iter(train_mrpc)
    it_sst2 = make_iter(train_sst2)

    for step in range(TOTAL_STEPS):
        # Switch data sources at the task boundary ("shock").
        x, m, y = get_batch(it_mrpc if step < STEPS_TASK1 else it_sst2)

        loss = model(input_ids=x, attention_mask=m, labels=y).loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad()

    return model
|
|
# ── TRAIN ORBITAL ───────────────────────────────────
|
|
def train_orbital(model):
    """Train with the OrbitalController adapting the LoRA rank each step.

    Returns (model, usage, rank_trace): *usage* maps rank -> step count,
    *rank_trace* is the per-step rank sequence.
    """
    ctrl = OrbitalController(warmup=WARMUP, stable_window=STABLE_WINDOW)
    ctrl.rank = 4
    set_rank(model, 4)  # start small; the controller may grow the rank

    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    usage = {4: 0, 8: 0, 16: 0}
    rank_trace = []
    it_mrpc = make_iter(train_mrpc)
    it_sst2 = make_iter(train_sst2)

    for step in range(TOTAL_STEPS):
        x, m, y = get_batch(it_mrpc if step < STEPS_TASK1 else it_sst2)

        loss = model(input_ids=x, attention_mask=m, labels=y).loss
        loss.backward()

        # Controller proposes a rank from the loss signal; clamp to
        # [4, MAX_RANK] (MAX_RANK replaces a hard-coded 16).
        new_rank = ctrl.step(loss.item())
        new_rank = max(4, min(MAX_RANK, new_rank))
        set_rank(model, new_rank)

        # Robustness fix: `usage[new_rank] += 1` raised KeyError for any
        # controller output outside {4, 8, 16} (e.g. 5) — count with .get.
        usage[new_rank] = usage.get(new_rank, 0) + 1
        rank_trace.append(new_rank)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        opt.zero_grad()

    return model, usage, rank_trace
|
|
# ── RUN ─────────────────────────────────────────────
|
|
| print(f"\nDevice: {DEVICE}") |
| print(f"Plan: MRPC Γ {STEPS_TASK1} β SST-2 Γ {STEPS_TASK2}") |
| print(f"Shock at step {STEPS_TASK1}") |
| print("=" * 55) |
|
|
| results = [] |
|
|
| for seed in SEEDS: |
| print(f"\n{'β' * 55}\n SEED {seed}\n{'β' * 55}") |
|
|
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| np.random.seed(seed) |
| random.seed(seed) |
|
|
| base_model = build_model() |
| base_model = train_baseline(base_model) |
| f1_mrpc_base = eval_f1(base_model, val_mrpc, metric_mrpc) |
| f1_sst2_base = eval_f1(base_model, val_sst2, metric_sst2) |
| del base_model; torch.cuda.empty_cache() |
|
|
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| np.random.seed(seed) |
| random.seed(seed) |
|
|
| uni_model = build_model() |
| uni_model, usage, rank_trace = train_orbital(uni_model) |
| f1_mrpc_uni = eval_f1(uni_model, val_mrpc, metric_mrpc) |
| f1_sst2_uni = eval_f1(uni_model, val_sst2, metric_sst2) |
|
|
| er = eff_rank(usage) |
| saving = 1 - er / 16 |
| transitions = sum(1 for i in range(1, len(rank_trace)) if rank_trace[i] != rank_trace[i-1]) |
|
|
| print(f"\n {'':30s} {'BASELINE':>10s} {'ORBITAL':>10s}") |
| print(f" {'β' * 55}") |
| print(f" {'MRPC F1 (retention)':30s} {f1_mrpc_base:10.3f} {f1_mrpc_uni:10.3f}") |
| print(f" {'SST-2 Acc (new task)':30s} {f1_sst2_base:10.3f} {f1_sst2_uni:10.3f}") |
| print(f"\n Orbital: eff_rank={er:.1f} saving={saving*100:.0f}% transitions={transitions}") |
|
|
| results.append({ |
| 'f1_mrpc_base': f1_mrpc_base, 'f1_sst2_base': f1_sst2_base, |
| 'f1_mrpc_uni': f1_mrpc_uni, 'f1_sst2_uni': f1_sst2_uni, |
| 'eff_rank': er, 'saving': saving |
| }) |
| del uni_model; torch.cuda.empty_cache() |
|
|
# ── SUMMARY ─────────────────────────────────────────
|
|
print(f"\n{'=' * 55}\n SUMMARY\n{'=' * 55}")
# Seed-averaged metrics for both regimes (baseline vs orbital).
mrpc_b = np.mean([r['f1_mrpc_base'] for r in results])
mrpc_u = np.mean([r['f1_mrpc_uni'] for r in results])
sst2_b = np.mean([r['f1_sst2_base'] for r in results])
sst2_u = np.mean([r['f1_sst2_uni'] for r in results])
er_avg = np.mean([r['eff_rank'] for r in results])
sv_avg = np.mean([r['saving'] for r in results])

# Mojibake repaired: the baseline→orbital arrows were corrupted to 'β'.
print(f"\n {'MRPC F1':20s} {mrpc_b:.3f} → {mrpc_u:.3f}")
print(f" {'SST-2 Acc':20s} {sst2_b:.3f} → {sst2_u:.3f}")
print(f" {'Eff rank':20s} 16.0 → {er_avg:.1f}")
print(f" {'Saving':20s} 0% → {sv_avg*100:.0f}%")
|
|