ASTERIZER committed on
Commit
cd7ee10
·
verified ·
1 Parent(s): 1dec56b

Upload lora_sft_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. lora_sft_train.py +455 -0
lora_sft_train.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ import math
4
+ import os
5
+ import time
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import yaml
11
+ from huggingface_hub import hf_hub_download
12
+ from torch.amp import GradScaler, autocast
13
+
14
+ from sft_train import LUNAModel, SFTDataset, cosine_lr, probe_hardware, run_eval_prompts
15
+
16
+
17
+ SEP = "=" * 72
18
+
19
+
20
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) residual branch.

    The forward pass returns ``base(x) + lora_b(lora_a(dropout(x))) * scale``
    with ``scale = alpha / max(rank, 1)``. The wrapped layer's parameters are
    frozen, and ``lora_b`` is zero-initialised so a freshly built wrapper
    produces the same output as the base layer.

    Args:
        base_layer: The ``nn.Linear`` to wrap.
        rank: Inner dimension of the low-rank branch.
        alpha: Numerator of the scaling factor applied to the branch output.
        dropout: Dropout probability applied to the branch input only.

    Raises:
        TypeError: If ``base_layer`` is not an ``nn.Linear``.
    """

    def __init__(self, base_layer, rank=16, alpha=32, dropout=0.05):
        super().__init__()
        if not isinstance(base_layer, nn.Linear):
            raise TypeError("LoRALinear expects a torch.nn.Linear base layer")

        self.base = base_layer
        self.rank = rank
        self.alpha = alpha
        # max(...) guards the division when rank is 0.
        self.scale = alpha / max(rank, 1)
        self.dropout = nn.Dropout(dropout)

        in_dim = base_layer.in_features
        out_dim = base_layer.out_features
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        # Zero B so the low-rank branch contributes nothing at initialisation.
        nn.init.zeros_(self.lora_b.weight)

        # Only the two LoRA matrices stay trainable.
        for frozen in self.base.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        """Return the base projection plus the scaled low-rank update."""
        projected = self.base(x)
        low_rank = self.lora_a(self.dropout(x))
        delta = self.lora_b(low_rank) * self.scale
        return projected + delta
42
+
43
+
44
def load_config(config_path):
    """Read the YAML training config and flatten it into a single dict.

    Top-level path/repo keys and ``eval_prompts`` are optional and fall back
    to built-in defaults. The ``model``, ``train``, ``optimizer``, ``batch``,
    ``dataloader``, ``hardware`` and ``lora`` sections are required; a missing
    section or required key raises ``KeyError``.

    Args:
        config_path: Path of the YAML file to read (opened as UTF-8).

    Returns:
        A flat ``dict`` of settings keyed by the names used throughout this
        script (``"lr"``, ``"micro_batch"``, ``"lora_rank"``, ...).
    """
    with open(config_path, encoding="utf-8") as handle:
        raw = yaml.safe_load(handle)

    # Optional top-level keys and their fallbacks.
    top_level_defaults = (
        ("auto_config", True),
        ("hf_model_repo", "ASTERIZER/LUNA-100M"),
        ("hf_model_file", "sft_v1/final/model.pth"),
        ("pretrained_ckpt", "Base/out/input_models/luna_sft_v1/model.pth"),
        ("train_json", "Base/Datasets/rag_mcp_sft/train.json"),
        ("val_json", "Base/Datasets/rag_mcp_sft/val.json"),
        ("out_dir", "Base/out/sft/rag_mcp_lora"),
        ("tokenizer_dir", "Base/checkpoints/EleutherAI/pythia-160m"),
    )
    cfg = {}
    for key, fallback in top_level_defaults:
        cfg[key] = raw.get(key, fallback)

    # Required sections are read one at a time, in the same order as the
    # flattened keys, so a broken file fails on the first missing entry.
    for key in ("vocab_size", "seq_len", "n_layer", "n_embd", "n_head"):
        cfg[key] = raw["model"][key]

    for key in (
        "epochs",
        "lr_warmup_steps",
        "save_interval",
        "log_interval",
        "eval_interval",
        "max_norm",
    ):
        cfg[key] = raw["train"][key]

    for key in ("lr", "min_lr", "weight_decay"):
        cfg[key] = raw["optimizer"][key]
    cfg["betas"] = tuple(raw["optimizer"]["betas"])
    cfg["eps"] = raw["optimizer"]["eps"]

    for key in ("global_batch", "micro_batch", "grad_accum"):
        cfg[key] = raw["batch"][key]
    cfg["auto_probe_batch"] = raw["batch"].get("auto_probe_batch", True)
    cfg["probe_safety"] = raw["batch"].get("probe_safety", 0.94)

    for key in ("num_workers", "pin_memory"):
        cfg[key] = raw["dataloader"][key]

    cfg["precision"] = raw["hardware"]["precision"]
    cfg["eval_prompts"] = raw.get("eval_prompts", [])

    # LoRA settings are exposed with a "lora_" prefix, except target_modules.
    for key in ("rank", "alpha", "dropout"):
        cfg[f"lora_{key}"] = raw["lora"][key]
    cfg["target_modules"] = list(raw["lora"]["target_modules"])

    return cfg
88
+
89
+
90
def resolve_checkpoint(cfg):
    """Return a local path to the base checkpoint, downloading it if needed.

    Lookup order:

    1. ``cfg["pretrained_ckpt"]`` if that file exists.
    2. ``<parent of pretrained_ckpt> / cfg["hf_model_file"]`` — the location a
       previous run of this function downloaded to. Checking it first lets
       repeated runs work without contacting the Hub.
    3. A fresh download of ``cfg["hf_model_file"]`` from
       ``cfg["hf_model_repo"]`` into the parent directory, authenticated with
       the ``HF_TOKEN`` environment variable when it is set.

    Args:
        cfg: Flat config dict with ``pretrained_ckpt``, ``hf_model_repo`` and
            ``hf_model_file`` entries.

    Returns:
        ``pathlib.Path`` of an existing checkpoint file.

    Raises:
        FileNotFoundError: If the download completed but no file can be found
            at the expected location.
    """
    ckpt_path = Path(cfg["pretrained_ckpt"])
    if ckpt_path.exists():
        return ckpt_path

    download_root = ckpt_path.parent
    downloaded = download_root / cfg["hf_model_file"]
    if downloaded.exists():
        # Already fetched by an earlier run; skip the network round-trip.
        return downloaded

    download_root.mkdir(parents=True, exist_ok=True)
    reported = hf_hub_download(
        repo_id=cfg["hf_model_repo"],
        filename=cfg["hf_model_file"],
        local_dir=str(download_root),
        token=os.environ.get("HF_TOKEN"),
    )
    if downloaded.exists():
        return downloaded

    # Fall back to the path reported by the hub client itself in case it
    # placed the file somewhere other than the location computed above.
    reported_path = Path(reported) if reported else None
    if reported_path is not None and reported_path.exists():
        return reported_path
    raise FileNotFoundError(f"Expected downloaded checkpoint at {downloaded}")
106
+
107
+
108
def inject_lora(model, target_modules, rank, alpha, dropout):
    """Replace matching ``nn.Linear`` submodules of *model* with ``LoRALinear``.

    A submodule is wrapped when it is an ``nn.Linear`` and its qualified name
    ends with any entry of *target_modules*. Each wrapper is moved to the
    device and dtype of the layer it wraps.

    Args:
        model: Module whose tree is modified in place.
        target_modules: Name suffixes selecting the layers to wrap.
        rank: LoRA rank passed to each wrapper.
        alpha: LoRA alpha passed to each wrapper.
        dropout: LoRA dropout probability passed to each wrapper.

    Returns:
        List of qualified names of the layers that were wrapped.

    Raises:
        RuntimeError: If no submodule matched.
    """
    injected_names = []
    # Snapshot the module list up front: the tree is mutated while walking it.
    for qualified_name, layer in list(model.named_modules()):
        is_target = isinstance(layer, nn.Linear) and any(
            qualified_name.endswith(suffix) for suffix in target_modules
        )
        if not is_target:
            continue

        owner_path, _, attr_name = qualified_name.rpartition(".")
        owner = model.get_submodule(owner_path) if owner_path else model
        adapter = LoRALinear(layer, rank=rank, alpha=alpha, dropout=dropout)
        adapter = adapter.to(device=layer.weight.device, dtype=layer.weight.dtype)
        setattr(owner, attr_name, adapter)
        injected_names.append(qualified_name)

    if not injected_names:
        raise RuntimeError("No target modules matched for LoRA injection")
    return injected_names
124
+
125
+
126
def get_lora_state_dict(model):
    """Return only the LoRA adapter tensors of *model*, copied to CPU.

    An entry of ``model.state_dict()`` is kept when its key contains
    ``"lora_a.weight"`` or ``"lora_b.weight"``.
    """
    adapter_markers = ("lora_a.weight", "lora_b.weight")
    adapter_tensors = {}
    for key, value in model.state_dict().items():
        if any(marker in key for marker in adapter_markers):
            adapter_tensors[key] = value.cpu()
    return adapter_tensors
133
+
134
+
135
def count_trainable_parameters(model):
    """Return the number of scalar elements in parameters that require grad."""
    trainable_elements = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable_elements += tensor.numel()
    return trainable_elements
137
+
138
+
139
def probe_max_micro_batch_lora(model, trainable_parameters, device, dtype, seq_len, vocab_size, safety=0.94, grad_accum_sim=2):
    """Binary-search the largest micro-batch that fits in GPU memory.

    For each candidate size in ``[1, 512]`` the probe runs ``grad_accum_sim``
    forward/backward passes on random token ids followed by one AdamW step,
    and treats a CUDA out-of-memory error as "too large".

    The optimizer steps taken while probing would otherwise move the
    trainable (LoRA) weights away from their initial values, so their values
    are snapshotted beforehand and restored before returning.

    Args:
        model: Model called as ``model(input_ids, targets=..., loss_mask=...,
            return_logits=False)`` and returning ``(_, loss)``.
        trainable_parameters: Parameters handed to the probe optimizer.
        device: ``torch.device`` to probe; any non-CUDA device returns 1.
        dtype: Autocast dtype used for the forward pass.
        seq_len: Sequence length of the synthetic batches.
        vocab_size: Exclusive upper bound for the random token ids.
        safety: Fraction of the largest working size that is returned.
        grad_accum_sim: Forward/backward passes per simulated optimizer step.

    Returns:
        ``max(1, int(best * safety))`` where ``best`` is the largest size that
        completed without running out of memory.
    """
    if device.type != "cuda":
        return 1

    optimizer = torch.optim.AdamW(trainable_parameters, lr=1e-4)
    # Take the parameter list from the optimizer so this also works when
    # trainable_parameters was a one-shot iterable.
    probed_params = [p for group in optimizer.param_groups for p in group["params"]]
    saved_values = [p.detach().clone() for p in probed_params]
    lo, hi, best = 1, 512, 1

    try:
        while lo <= hi:
            mid = (lo + hi) // 2
            try:
                torch.cuda.empty_cache()
                gc.collect()
                optimizer.zero_grad(set_to_none=True)

                for _ in range(grad_accum_sim):
                    input_ids = torch.randint(0, vocab_size, (mid, seq_len), device=device)
                    loss_mask = torch.ones_like(input_ids)
                    with autocast(device_type="cuda", dtype=dtype):
                        _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                        loss = loss / grad_accum_sim
                    loss.backward()
                    del input_ids, loss_mask, loss

                # Step so optimizer state memory is included in the measurement.
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                best = mid
                lo = mid + 1
            except (torch.cuda.OutOfMemoryError, RuntimeError) as error:
                # Some OOMs surface as a plain RuntimeError; anything else is
                # a genuine failure and must propagate.
                if "out of memory" not in str(error).lower() and not isinstance(error, torch.cuda.OutOfMemoryError):
                    raise
                optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()
                gc.collect()
                hi = mid - 1
    finally:
        # Undo the probe's weight updates and drop any leftover gradients so
        # training starts from the untouched adapter initialisation.
        optimizer.zero_grad(set_to_none=True)
        with torch.no_grad():
            for parameter, saved in zip(probed_params, saved_values):
                parameter.copy_(saved)
        del optimizer, saved_values
        torch.cuda.empty_cache()
        gc.collect()

    safe = max(1, int(best * safety))
    print(f" LoRA batch probe: max_micro_batch={best}, using {safe} ({int(safety * 100)}% safety)")
    return safe
181
+
182
+
183
def load_base_weights(model, checkpoint_path, device):
    """Load base-model weights from *checkpoint_path* into *model* (strict).

    The file may hold either a bare state dict or a dict that stores the
    state dict under the ``"model"`` key. It is read with
    ``weights_only=True`` and mapped onto *device*.
    """
    payload = torch.load(checkpoint_path, map_location=device, weights_only=True)
    is_wrapped = isinstance(payload, dict) and "model" in payload
    weights = payload["model"] if is_wrapped else payload
    model.load_state_dict(weights, strict=True)
187
+
188
+
189
def train(cfg):
    """Run LoRA supervised fine-tuning as described by the flat config *cfg*.

    Steps, in order:

    1. Probe the hardware, pick the compute dtype and load the tokenizer.
    2. Build ``LUNAModel``, load the base checkpoint, freeze every parameter
       and inject ``LoRALinear`` wrappers into the configured target modules.
    3. Optionally probe the largest micro-batch that fits in VRAM and derive
       ``grad_accum`` from ``global_batch`` (both are written back to *cfg*).
    4. Train with gradient accumulation, gradient clipping and a cosine
       learning-rate schedule; periodically log, checkpoint, validate and run
       the evaluation prompts.
    5. Save the final adapter plus a bundle with the LoRA hyper-parameters.

    If ``<out_dir>/latest.pt`` exists, the adapter weights, optimizer state
    and global step are restored from it and already-completed optimizer
    steps are skipped without running any forward/backward pass.

    Args:
        cfg: Flat config dict as produced by ``load_config``. The keys
            ``micro_batch`` and ``grad_accum`` may be overwritten in place.
    """
    hw = probe_hardware()
    device = torch.device(hw["device"])
    dtype = hw.get("dtype", torch.float32) if cfg["auto_config"] else {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }.get(cfg["precision"], torch.float32)

    # Function-scope import, kept as in the original script.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(cfg["tokenizer_dir"])
    ckpt_path = resolve_checkpoint(cfg)

    model = LUNAModel(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["seq_len"],
        n_layer=cfg["n_layer"],
        n_embd=cfg["n_embd"],
        n_head=cfg["n_head"],
    ).to(device)
    load_base_weights(model, ckpt_path, device)

    # Freeze everything; the LoRA matrices created below are the only
    # parameters left with requires_grad=True.
    for parameter in model.parameters():
        parameter.requires_grad = False

    replaced = inject_lora(
        model,
        target_modules=cfg["target_modules"],
        rank=cfg["lora_rank"],
        alpha=cfg["lora_alpha"],
        dropout=cfg["lora_dropout"],
    )
    trainable_params = count_trainable_parameters(model)
    total_params = sum(parameter.numel() for parameter in model.parameters())
    trainable_parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]

    if cfg["auto_config"] and device.type == "cuda" and cfg["auto_probe_batch"]:
        print(" Probing LoRA micro_batch against available VRAM...")
        cfg["micro_batch"] = probe_max_micro_batch_lora(
            model,
            trainable_parameters=trainable_parameters,
            device=device,
            dtype=dtype,
            seq_len=cfg["seq_len"],
            vocab_size=cfg["vocab_size"],
            safety=cfg["probe_safety"],
        )
        cfg["grad_accum"] = max(1, math.ceil(cfg["global_batch"] / cfg["micro_batch"]))
        # Keep the probe's peak out of the VRAM figure logged during training.
        torch.cuda.reset_peak_memory_stats(device)

    grad_accum = cfg["grad_accum"]
    effective_batch = cfg["micro_batch"] * grad_accum

    train_dataset = SFTDataset(cfg["train_json"], tokenizer, max_len=cfg["seq_len"])
    val_dataset = SFTDataset(cfg["val_json"], tokenizer, max_len=cfg["seq_len"]) if Path(cfg["val_json"]).exists() else None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg["micro_batch"],
        shuffle=True,
        num_workers=cfg["num_workers"],
        pin_memory=cfg["pin_memory"],
        drop_last=True,
        prefetch_factor=4 if cfg["num_workers"] > 0 else None,
        persistent_workers=cfg["num_workers"] > 0,
    )
    val_loader = None
    if val_dataset is not None:
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=cfg["micro_batch"],
            shuffle=False,
            num_workers=min(2, cfg["num_workers"]),
            pin_memory=cfg["pin_memory"],
            drop_last=False,
        )

    optimizer = torch.optim.AdamW(
        trainable_parameters,
        lr=cfg["lr"],
        weight_decay=cfg["weight_decay"],
        betas=cfg["betas"],
        eps=cfg["eps"],
    )
    # Loss scaling is only enabled for fp16 on CUDA.
    scaler = GradScaler(enabled=(device.type == "cuda" and dtype == torch.float16))

    steps_per_epoch = max(1, len(train_loader) // grad_accum)
    total_steps = steps_per_epoch * cfg["epochs"]
    warmup_steps = min(cfg["lr_warmup_steps"], max(1, total_steps // 5))
    out_dir = Path(cfg["out_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    best_val_loss = float("inf")
    step = 0

    latest_path = out_dir / "latest.pt"
    if latest_path.exists():
        checkpoint = torch.load(latest_path, map_location=device, weights_only=True)
        # strict=False: the checkpoint holds adapter tensors only.
        model.load_state_dict(checkpoint["adapter"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        step = checkpoint["step"]

    print(SEP)
    print(" LUNA 100M - LoRA SFT")
    print(SEP)
    print(f" Base checkpoint : {ckpt_path}")
    print(f" Train dataset : {cfg['train_json']}")
    print(f" Val dataset : {cfg['val_json']}")
    print(f" Output dir : {out_dir}")
    print(f" Device : {hw['gpu_name']} ({hw['vram_gb']:.1f} GB)")
    print(f" Precision : {cfg['precision']} dtype={dtype}")
    print(f" LoRA modules : {', '.join(replaced)}")
    print(f" Trainable params: {trainable_params:,} / {total_params:,}")
    print(f" micro_batch : {cfg['micro_batch']}")
    print(f" grad_accum : {cfg['grad_accum']}")
    print(f" effective_batch : {effective_batch}")
    print(f" Train samples : {len(train_dataset):,}")
    print(f" Val samples : {len(val_dataset):,}" if val_dataset is not None else " Val samples : 0")
    print(SEP)

    if cfg["eval_prompts"] and step == 0:
        run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, 0, out_dir)

    model.train()
    run_t0 = time.perf_counter()
    # Micro-batches per epoch that belong to a complete accumulation window.
    micro_per_epoch = steps_per_epoch * grad_accum

    for epoch in range(cfg["epochs"]):
        epoch_first_step = epoch * steps_per_epoch
        if step >= epoch_first_step + steps_per_epoch:
            # This whole epoch was completed before the resume point.
            continue

        # Micro-batches to fast-forward inside a partially completed epoch.
        resume_skip = max(0, step - epoch_first_step) * grad_accum
        # Start every epoch with clean gradients so nothing accumulated at
        # the tail of the previous epoch leaks into the first step here.
        optimizer.zero_grad(set_to_none=True)
        micro_step = 0
        window_start = time.perf_counter()

        for input_ids, loss_mask in train_loader:
            if micro_step >= micro_per_epoch:
                # Trailing batches cannot fill an accumulation window (and on
                # the last epoch total_steps has been reached): stop here
                # rather than accumulate gradients that are never applied.
                break
            if micro_step < resume_skip:
                # Already trained on before the resume point: skip without
                # running forward/backward so no stale gradients build up.
                micro_step += 1
                continue

            if micro_step % grad_accum == 0:
                # Time the full accumulation window, not just its last
                # micro-batch, so tok/s and ETA reflect a whole step.
                window_start = time.perf_counter()

            input_ids = input_ids.to(device, non_blocking=True)
            loss_mask = loss_mask.to(device, non_blocking=True)

            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                _, loss = model(input_ids, targets=input_ids, loss_mask=loss_mask, return_logits=False)
                loss = loss / grad_accum

            scaler.scale(loss).backward()
            micro_step += 1

            if micro_step % grad_accum != 0:
                continue

            # Unscale before clipping so max_norm applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(trainable_parameters, cfg["max_norm"])
            lr_now = cosine_lr(step, warmup_steps, total_steps, cfg["lr"], cfg["min_lr"])
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr_now

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

            if device.type == "cuda":
                # Wait for queued kernels so the wall-clock timing is real.
                torch.cuda.synchronize()

            dt = time.perf_counter() - window_start
            step += 1

            if step % cfg["log_interval"] == 0 or step <= 3:
                tokens_step = effective_batch * cfg["seq_len"]
                tps = tokens_step / max(dt, 1e-6)
                vram = torch.cuda.max_memory_allocated() / 1024**3 if device.type == "cuda" else 0
                eta_h = (total_steps - step) * dt / 3600
                # The logged loss is that of the window's last micro-batch,
                # rescaled back from its grad_accum division.
                print(
                    f" step {step:6d}/{total_steps} | epoch {epoch + 1}/{cfg['epochs']} | "
                    f"loss {loss.item() * cfg['grad_accum']:.4f} | lr {lr_now:.2e} | "
                    f"{tps:,.0f} tok/s | VRAM {vram:.1f}GB | ETA {eta_h:.1f}h"
                )

            if step % cfg["save_interval"] == 0 or step == total_steps:
                step_dir = out_dir / f"step-{step:06d}"
                step_dir.mkdir(parents=True, exist_ok=True)
                adapter_state = get_lora_state_dict(model)
                torch.save(adapter_state, step_dir / "adapter_model.pt")
                # latest.pt is the file read back by the resume logic above.
                torch.save(
                    {
                        "step": step,
                        "adapter": adapter_state,
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch,
                        "loss": loss.item() * cfg["grad_accum"],
                    },
                    latest_path,
                )
                print(f" Saved -> {step_dir}")

            if step % cfg["eval_interval"] == 0 or step == total_steps:
                if val_loader is not None:
                    model.eval()
                    val_loss_sum = 0.0
                    val_count = 0
                    with torch.no_grad():
                        for val_ids, val_mask in val_loader:
                            val_ids = val_ids.to(device, non_blocking=True)
                            val_mask = val_mask.to(device, non_blocking=True)
                            with autocast(device_type=device.type, dtype=dtype, enabled=(device.type == "cuda")):
                                _, val_loss = model(val_ids, targets=val_ids, loss_mask=val_mask, return_logits=False)
                            val_loss_sum += val_loss.item()
                            val_count += 1
                            # Cap validation at 50 batches to bound its cost.
                            if val_count >= 50:
                                break
                    avg_val = val_loss_sum / max(val_count, 1)
                    print(f" Val loss: {avg_val:.4f}")
                    if avg_val < best_val_loss:
                        best_val_loss = avg_val
                        torch.save(get_lora_state_dict(model), out_dir / "best_adapter_model.pt")
                        print(" New best! Saved best_adapter_model.pt")
                    model.train()

                if cfg["eval_prompts"]:
                    run_eval_prompts(model, tokenizer, cfg["eval_prompts"], device, step, out_dir)
                    # Defensive: return to training mode in case the helper
                    # left the model in eval mode.
                    model.train()

    final_dir = out_dir / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    final_adapter = get_lora_state_dict(model)
    torch.save(final_adapter, final_dir / "adapter_model.pt")
    # The bundle carries what is needed to rebuild the adapter on the base.
    torch.save(
        {
            "step": step,
            "adapter": final_adapter,
            "lora_rank": cfg["lora_rank"],
            "lora_alpha": cfg["lora_alpha"],
            "lora_dropout": cfg["lora_dropout"],
            "target_modules": cfg["target_modules"],
            "base_checkpoint": str(ckpt_path),
        },
        final_dir / "adapter_bundle.pt",
    )

    total_h = (time.perf_counter() - run_t0) / 3600
    print(SEP)
    print(f" LoRA SFT complete in {total_h:.2f}h -> {final_dir}")
    print(f" Best val loss: {best_val_loss:.4f}")
    print(SEP)
429
+
430
+
431
def parse_args():
    """Parse the command-line options of the LoRA SFT script.

    ``--config`` defaults to ``rag_mcp_lora_config.yaml``. Every other option
    defaults to ``None``, which ``main`` treats as "keep the config value".
    """
    cli = argparse.ArgumentParser(description="LUNA 100M - LoRA SFT")
    cli.add_argument("--config", default="rag_mcp_lora_config.yaml")
    # String-valued overrides, registered in the order they appear in --help.
    for flag in ("--pretrained_ckpt", "--train_json", "--val_json", "--out_dir"):
        cli.add_argument(flag, default=None)
    cli.add_argument("--epochs", type=int, default=None)
    return cli.parse_args()
440
+
441
+
442
def main():
    """Script entry point: load the config, apply CLI overrides, then train.

    Path-like options replace the config value only when a non-empty value
    was given; ``--epochs`` replaces it whenever it was supplied.
    """
    args = parse_args()
    cfg = load_config(args.config)

    path_options = ("pretrained_ckpt", "train_json", "val_json", "out_dir")
    supplied = {name: getattr(args, name) for name in path_options}
    cfg.update({name: value for name, value in supplied.items() if value})

    if args.epochs is not None:
        cfg["epochs"] = args.epochs
    train(cfg)
452
+
453
+
454
+ if __name__ == "__main__":
455
+ main()