| import json, yaml, time |
| import torch |
| from pathlib import Path |
| from tokenizers import Tokenizer |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from model.tiny_gpt2 import TinyGPT2, GPTConfig |
|
|
class SFTDataset(Dataset):
    """Instruction/answer pairs read from a JSONL file and tokenized up front.

    Every non-blank line of the file is parsed as a JSON object that must
    provide the keys ``instruction`` and ``output``. Each record is rendered
    into a single prompt string, encoded with ``tokenizer.encode(...).ids``
    and cached in ``self.ids``.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.
        tokenizer: Object whose ``encode(text)`` returns something with an
            ``ids`` list of ints.
        block_size: Maximum number of token ids kept per sample.
    """

    def __init__(self, jsonl_path, tokenizer, block_size):
        self.block = block_size
        self.tok = tokenizer
        # Close the file deterministically and tolerate blank lines (e.g. a
        # stray trailing newline), which json.loads would reject.
        with open(jsonl_path, 'r', encoding='utf-8') as fh:
            self.samples = [json.loads(line) for line in fh if line.strip()]
        self.ids = [self.tok.encode(self._render(s)).ids for s in self.samples]

    @staticmethod
    def _render(sample):
        """Build the training prompt for one record."""
        return (
            f"Instruction:\n{sample['instruction'].strip()}\n"
            f"Answer:\n{sample['output'].strip()}\n"
        )

    def __len__(self):
        """Return the number of tokenized samples."""
        return len(self.ids)

    def __getitem__(self, i):
        """Return ``(inputs, targets)`` for sample ``i`` as long tensors.

        The ids are truncated to ``block_size``; targets are the inputs
        shifted left by one token, so both tensors have one element fewer
        than the truncated sequence. Lengths vary between samples.
        """
        ids = self.ids[i][:self.block]
        x = ids[:-1]
        y = ids[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)
|
|
if __name__ == "__main__":
    # Supervised fine-tuning entry point: load the pretrained checkpoint named
    # by the config, train on the SFT JSONL for at most 800 optimizer steps
    # (a single pass over the data), then save the weights under out/sft/.
    with open("train/config.yaml") as cfg_file:
        cfg = yaml.safe_load(cfg_file)
    Path("out/sft").mkdir(parents=True, exist_ok=True)
    tok = Tokenizer.from_file(cfg["tokenizer_path"])

    save_dir = Path(cfg["save_dir"])
    with open(save_dir / "gpt_config.json") as gcfg_file:
        gcfg = GPTConfig(**json.load(gcfg_file))
    model = TinyGPT2(gcfg)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(save_dir / "model.pt", map_location="cpu"))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Target value that cross_entropy skips; used to mask padded positions.
    IGNORE_INDEX = -100

    def pad_batch(batch):
        """Right-pad variable-length (x, y) pairs so they can be stacked.

        The default collate function stacks tensors and raises when samples
        in a batch differ in length, which SFT samples normally do. Inputs
        are padded with token id 0 and targets with IGNORE_INDEX so padded
        positions contribute nothing to the loss.
        """
        # NOTE(review): padding inputs on the right is harmless only if the
        # model attends causally — confirm for TinyGPT2.
        xs, ys = zip(*batch)
        x = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=0)
        y = torch.nn.utils.rnn.pad_sequence(
            ys, batch_first=True, padding_value=IGNORE_INDEX
        )
        return x, y

    ds = SFTDataset(cfg["sft_jsonl"], tok, gcfg.block_size)
    dl = DataLoader(
        ds, batch_size=8, shuffle=True, drop_last=True, collate_fn=pad_batch
    )
    opt = AdamW(model.parameters(), lr=1e-4)

    model.train()
    t0 = time.time()
    for step, (x, y) in enumerate(dl, start=1):
        x, y = x.to(device), y.to(device)
        logits = model(x)
        # Flatten the leading dimensions so every position is one
        # classification example over the last (vocabulary) dimension.
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1),
            ignore_index=IGNORE_INDEX,
        )
        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
        if step % 50 == 0:
            # Report the wall time of the last 50 steps, then restart the clock.
            dt = time.time() - t0
            t0 = time.time()
            print(f"sft step {step:5d} | loss {loss.item():.4f} | {dt:.2f}s")
        if step >= 800:
            break

    torch.save(model.state_dict(), "out/sft/model_sft.pt")
    print("SFT saved.")
|
|