LH-Tech-AI commited on
Commit
1606f40
Β·
verified Β·
1 Parent(s): 571dc68

Create sft.py

Browse files
Files changed (1) hide show
  1. sft.py +249 -0
sft.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Β© SupraLabs 2026 - SFT script for Supra-50M on alpaca-cleaned
3
+ No TRL. Uses HuggingFace Trainer with prompt-masked cross-entropy loss.
4
+ """
5
+
6
+ import os
7
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
8
+
9
+ print("[*] Loading libraries...")
10
+ import torch
11
+ import numpy as np
12
+ from dataclasses import dataclass
13
+ from typing import Optional
14
+ from datasets import load_dataset
15
+ from transformers import (
16
+ AutoModelForCausalLM,
17
+ Trainer,
18
+ TrainingArguments,
19
+ PreTrainedTokenizerBase,
20
+ PreTrainedTokenizerFast
21
+ )
22
+ from torch.utils.data import Dataset
23
+
24
+ # ── Config ────────────────────────────────────────────────────────────────────
25
+
26
+ MODEL_ID = "./Chimera-FINAL"
27
+ OUTPUT_DIR = "./Supra-50M-SFT"
28
+ MAX_LENGTH = 512 # alpaca samples are short, 512 is plenty
29
+ IGNORE_INDEX = -100 # standard label mask value for cross-entropy
30
+
31
+ # Conservative hyperparameters β€” small model, don't nuke the pretraining
32
+ LEARNING_RATE = 3e-4
33
+ EPOCHS = 4
34
+ BATCH_SIZE = 8
35
+ GRAD_ACCUM = 2 # effective batch size = 16
36
+ WARMUP_RATIO = 0.1
37
+ WEIGHT_DECAY = 0.0
38
+ MAX_GRAD_NORM = 1.0
39
+
40
+ # ── Alpaca prompt template ────────────────────────────────────────────────────
41
+
42
+ PROMPT_WITH_INPUT = (
43
+ "Below is an instruction that describes a task, paired with an input "
44
+ "that provides further context. Write a response that appropriately "
45
+ "completes the request.\n\n"
46
+ "### Instruction:\n{instruction}\n\n"
47
+ "### Input:\n{input}\n\n"
48
+ "### Response:\n"
49
+ )
50
+
51
+ PROMPT_WITHOUT_INPUT = (
52
+ "Below is an instruction that describes a task. Write a response that "
53
+ "appropriately completes the request.\n\n"
54
+ "### Instruction:\n{instruction}\n\n"
55
+ "### Response:\n"
56
+ )
57
+
58
+ def build_prompt(sample: dict) -> tuple[str, str]:
59
+ """Returns (prompt, response) β€” kept separate so we can mask the prompt."""
60
+ instruction = sample["instruction"].strip()
61
+ inp = sample.get("input", "").strip()
62
+ output = sample["output"].strip()
63
+
64
+ if inp:
65
+ prompt = PROMPT_WITH_INPUT.format(instruction=instruction, input=inp)
66
+ else:
67
+ prompt = PROMPT_WITHOUT_INPUT.format(instruction=instruction)
68
+
69
+ return prompt, output
70
+
71
+
72
+ # ── Dataset ───────────────────────────────────────────────────────────────────
73
+
74
+ class AlpacaDataset(Dataset):
75
+ """
76
+ Tokenizes each sample and masks the prompt portion of the labels so the
77
+ model only computes loss on the response tokens β€” not on the instruction.
78
+ """
79
+
80
+ def __init__(self, hf_dataset, tokenizer: PreTrainedTokenizerBase, max_length: int):
81
+ self.tokenizer = tokenizer
82
+ self.max_length = max_length
83
+ self.samples = hf_dataset
84
+
85
+ def __len__(self):
86
+ return len(self.samples)
87
+
88
+ def __getitem__(self, idx):
89
+ prompt, response = build_prompt(self.samples[idx])
90
+
91
+ # Tokenize prompt and response separately so we know the prompt length
92
+ prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
93
+ prompt_ids = [self.tokenizer.bos_token_id] + prompt_ids # explizit
94
+ response_ids = self.tokenizer.encode(response, add_special_tokens=False) + [self.tokenizer.eos_token_id]
95
+
96
+ input_ids = prompt_ids + response_ids
97
+
98
+ # Truncate to max_length
99
+ input_ids = input_ids[:self.max_length]
100
+
101
+ # Labels: mask prompt tokens with IGNORE_INDEX
102
+ prompt_len = min(len(prompt_ids), len(input_ids))
103
+ labels = [IGNORE_INDEX] * prompt_len + input_ids[prompt_len:]
104
+
105
+ # Sanity: both must be the same length after truncation
106
+ assert len(input_ids) == len(labels)
107
+
108
+ return {
109
+ "input_ids": torch.tensor(input_ids, dtype=torch.long),
110
+ "labels": torch.tensor(labels, dtype=torch.long),
111
+ }
112
+
113
+
114
+ # ── Collator ──────────────────────────────────────────────────────────────────
115
+
116
+ @dataclass
117
+ class PaddingCollator:
118
+ """
119
+ Right-pads input_ids and labels to the longest sequence in the batch.
120
+ Labels are padded with IGNORE_INDEX so padding never contributes to loss.
121
+ """
122
+ tokenizer: PreTrainedTokenizerBase
123
+ max_length: int
124
+
125
+ def __call__(self, batch):
126
+ max_len = max(len(x["input_ids"]) for x in batch)
127
+ max_len = min(max_len, self.max_length)
128
+
129
+ input_ids_padded = []
130
+ labels_padded = []
131
+ attention_masks = []
132
+
133
+ for item in batch:
134
+ ids = item["input_ids"][:max_len]
135
+ lbls = item["labels"][:max_len]
136
+ pad_n = max_len - len(ids)
137
+
138
+ input_ids_padded.append(
139
+ torch.cat([ids, torch.full((pad_n,), self.tokenizer.pad_token_id, dtype=torch.long)])
140
+ )
141
+ labels_padded.append(
142
+ torch.cat([lbls, torch.full((pad_n,), IGNORE_INDEX, dtype=torch.long)])
143
+ )
144
+ attention_masks.append(
145
+ torch.cat([torch.ones(len(ids), dtype=torch.long),
146
+ torch.zeros(pad_n, dtype=torch.long)])
147
+ )
148
+
149
+ return {
150
+ "input_ids": torch.stack(input_ids_padded),
151
+ "labels": torch.stack(labels_padded),
152
+ "attention_mask": torch.stack(attention_masks),
153
+ }
154
+
155
+
156
+ # ── Main ──────────────────────────────────────────────────────────────────────
157
+
158
+ def main():
159
+ # Load tokenizer + model from Hub
160
+ print(f"[*] Loading tokenizer from {MODEL_ID}...")
161
+ from tokenizers import ByteLevelBPETokenizer
162
+
163
+ fast_tokenizer = ByteLevelBPETokenizer(
164
+ "custom_llama_tokenizer-vocab.json",
165
+ "custom_llama_tokenizer-merges.txt"
166
+ )
167
+ tokenizer = PreTrainedTokenizerFast(
168
+ tokenizer_object=fast_tokenizer,
169
+ bos_token="<s>",
170
+ eos_token="</s>",
171
+ unk_token="<unk>",
172
+ pad_token="<pad>",
173
+ )
174
+
175
+ print(f"[*] Loading model from {MODEL_ID}...")
176
+ model = AutoModelForCausalLM.from_pretrained(
177
+ MODEL_ID,
178
+ dtype=torch.bfloat16,
179
+ device_map="auto",
180
+ )
181
+
182
+ print(f"[+] Model loaded β€” {model.num_parameters():,} parameters")
183
+
184
+ # Load alpaca-cleaned (β‰ˆ52k instruction-tuning pairs)
185
+ print("[*] Loading alpaca-cleaned dataset...")
186
+ raw = load_dataset("yahma/alpaca-cleaned", split="train")
187
+ print(f"[+] Dataset: {len(raw):,} samples")
188
+
189
+ # Optional: quick sanity-check split (comment out for full training)
190
+ # raw = raw.select(range(1000))
191
+
192
+ split = raw.train_test_split(test_size=0.01, seed=42)
193
+ train_dataset = AlpacaDataset(split["train"], tokenizer, MAX_LENGTH)
194
+ eval_dataset = AlpacaDataset(split["test"], tokenizer, MAX_LENGTH)
195
+ collator = PaddingCollator(tokenizer=tokenizer, max_length=MAX_LENGTH)
196
+
197
+ print(f"[+] Dataset ready: {len(train_dataset):,} samples")
198
+ print(f"[+] Example prompt preview:\n{build_prompt(raw[0])[0][:800]}...")
199
+
200
+ # Training arguments
201
+ training_args = TrainingArguments(
202
+ output_dir=OUTPUT_DIR,
203
+ num_train_epochs=EPOCHS,
204
+ per_device_train_batch_size=BATCH_SIZE,
205
+ gradient_accumulation_steps=GRAD_ACCUM,
206
+ learning_rate=LEARNING_RATE,
207
+ lr_scheduler_type="cosine",
208
+ warmup_ratio=WARMUP_RATIO,
209
+ weight_decay=WEIGHT_DECAY,
210
+ max_grad_norm=MAX_GRAD_NORM,
211
+ bf16=True,
212
+ fp16=False,
213
+ logging_steps=50,
214
+ save_total_limit=2,
215
+ report_to="none",
216
+ dataloader_num_workers=8,
217
+ dataloader_pin_memory=True,
218
+ optim="adamw_torch_fused",
219
+ adam_beta1=0.9,
220
+ adam_beta2=0.999,
221
+ push_to_hub=False,
222
+ seed=42,
223
+ data_seed=42,
224
+ eval_strategy="epoch",
225
+ save_strategy="epoch",
226
+ load_best_model_at_end=True,
227
+ metric_for_best_model="eval_loss",
228
+ greater_is_better=False,
229
+ )
230
+
231
+ trainer = Trainer(
232
+ model=model,
233
+ args=training_args,
234
+ train_dataset=train_dataset,
235
+ eval_dataset=eval_dataset,
236
+ data_collator=collator,
237
+ )
238
+
239
+ print("[*] Starting SFT...")
240
+ trainer.train()
241
+
242
+ print(f"[*] Saving final model to {OUTPUT_DIR}-FINAL ...")
243
+ trainer.save_model(f"{OUTPUT_DIR}-FINAL")
244
+ tokenizer.save_pretrained(f"{OUTPUT_DIR}-FINAL")
245
+ print("[+] Done.")
246
+
247
+
248
+ if __name__ == "__main__":
249
+ main()