LH-Tech-AI commited on
Commit
0e9ee61
·
verified ·
1 Parent(s): 6e3719b

Create train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +178 -0
train_model.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4
+
5
+ print("[*] Loading libraries...")
6
+ import torch
7
+ import math
8
+ import numpy as np
9
+ from datasets import load_dataset
10
+ from tokenizers import ByteLevelBPETokenizer
11
+ from transformers import (
12
+ LlamaConfig,
13
+ LlamaForCausalLM,
14
+ PreTrainedTokenizerFast,
15
+ Trainer,
16
+ TrainingArguments,
17
+ )
18
+ from torch.utils.data import Dataset
19
+ from tqdm import tqdm
20
+
21
+ print("[*] Loading tokenizer...")
22
+ fast_tokenizer = ByteLevelBPETokenizer(
23
+ "./custom_llama_tokenizer-vocab.json",
24
+ "./custom_llama_tokenizer-merges.txt"
25
+ )
26
+ tokenizer = PreTrainedTokenizerFast(
27
+ tokenizer_object=fast_tokenizer,
28
+ bos_token="<s>",
29
+ eos_token="</s>",
30
+ unk_token="<unk>",
31
+ pad_token="<pad>",
32
+ )
33
+
34
+ TOKEN_BIN = "/kaggle/working/tokens.bin"
35
+ TARGET_TOKENS = 500_000_000
36
+ SEQ_LEN = 256
37
+ BATCH_TEXTS = 1000
38
+ FLUSH_EVERY = 1_000_000
39
+
40
+
41
+ def build_token_bin(fast_tokenizer, path=TOKEN_BIN, target_tokens=TARGET_TOKENS):
42
+ if os.path.exists(path) and os.path.getsize(path) >= target_tokens * 2:
43
+ print(f"[=] Reusing existing token file: {path}")
44
+ return
45
+
46
+ print(f"[*] Streaming + tokenizing {target_tokens:,} tokens → {path}")
47
+ mm = np.memmap(path, dtype=np.uint16, mode="w+", shape=(target_tokens,))
48
+
49
+ dataset = load_dataset(
50
+ "HuggingFaceFW/fineweb-edu", "sample-10BT",
51
+ split="train", streaming=True
52
+ )
53
+
54
+ written = 0
55
+ buf = []
56
+ texts = []
57
+ pbar = tqdm(total=target_tokens, desc="[*] Gathering tokens", unit="tok")
58
+
59
+ def flush_buf():
60
+ nonlocal written, buf
61
+ if not buf:
62
+ return False
63
+ n = min(len(buf), target_tokens - written)
64
+ mm[written:written + n] = np.asarray(buf[:n], dtype=np.uint16)
65
+ written += n
66
+ pbar.update(n)
67
+ del buf[:n]
68
+ return written >= target_tokens
69
+
70
+ for example in dataset:
71
+ texts.append(example["text"])
72
+ if len(texts) >= BATCH_TEXTS:
73
+ encs = fast_tokenizer.encode_batch(texts)
74
+ texts.clear()
75
+ for e in encs:
76
+ buf.extend(e.ids)
77
+ if len(buf) >= FLUSH_EVERY:
78
+ if flush_buf():
79
+ break
80
+
81
+ if written < target_tokens and texts:
82
+ encs = fast_tokenizer.encode_batch(texts)
83
+ for e in encs:
84
+ buf.extend(e.ids)
85
+ if written < target_tokens:
86
+ flush_buf()
87
+
88
+ pbar.close()
89
+ mm.flush()
90
+ del mm
91
+ print(f"[+] Wrote {written:,} tokens to {path} "
92
+ f"({os.path.getsize(path)/1e6:.1f} MB)")
93
+
94
+
95
+ class MemmapDataset(Dataset):
96
+ def __init__(self, path, total_tokens, seq_len=SEQ_LEN):
97
+ self.path = path
98
+ self.seq_len = seq_len
99
+ self.n_chunks = total_tokens // seq_len
100
+ self._data = None # lazy open (Multiprocessing-safe)
101
+
102
+ @property
103
+ def data(self):
104
+ if self._data is None:
105
+ self._data = np.memmap(
106
+ self.path, dtype=np.uint16, mode="r",
107
+ shape=(self.n_chunks * self.seq_len,)
108
+ )
109
+ return self._data
110
+
111
+ def __len__(self):
112
+ return self.n_chunks
113
+
114
+ def __getitem__(self, idx):
115
+ s = idx * self.seq_len
116
+ arr = np.asarray(self.data[s:s + self.seq_len], dtype=np.int64)
117
+ ids = torch.from_numpy(arr)
118
+ return {"input_ids": ids, "labels": ids.clone()}
119
+
120
+
121
+ def collate_fn(batch):
122
+ input_ids = torch.stack([b["input_ids"] for b in batch])
123
+ labels = torch.stack([b["labels"] for b in batch])
124
+ return {"input_ids": input_ids, "labels": labels}
125
+
126
+
127
+ print(f"[*] Preparing {TARGET_TOKENS:,} tokens (streaming, memmap-backed)...")
128
+ build_token_bin(fast_tokenizer, TOKEN_BIN, TARGET_TOKENS)
129
+ dataset = MemmapDataset(TOKEN_BIN, TARGET_TOKENS, seq_len=SEQ_LEN)
130
+ print(f"[+] Dataset ready: {len(dataset):,} chunks of {SEQ_LEN} tokens")
131
+
132
+ print("[*] Setting up model...")
133
+ config = LlamaConfig(
134
+ vocab_size=len(tokenizer.get_vocab()),
135
+ hidden_size=48,
136
+ intermediate_size=96,
137
+ num_hidden_layers=4,
138
+ num_attention_heads=4,
139
+ max_position_embeddings=256,
140
+ pad_token_id=tokenizer.pad_token_id,
141
+ bos_token_id=tokenizer.bos_token_id,
142
+ eos_token_id=tokenizer.eos_token_id,
143
+ )
144
+ model = LlamaForCausalLM(config)
145
+ print(f"[*] Model parameters: {model.num_parameters():,}")
146
+
147
+ print("[*] Defining training arguments...")
148
+ training_args = TrainingArguments(
149
+ output_dir="./Supra-Mini-0.1m",
150
+ num_train_epochs=2,
151
+ per_device_train_batch_size=1024,
152
+ gradient_accumulation_steps=1,
153
+ save_steps=500,
154
+ save_total_limit=2,
155
+ logging_steps=100,
156
+ weight_decay=0.01,
157
+ fp16=torch.cuda.is_available(),
158
+ push_to_hub=False,
159
+ report_to="none",
160
+ dataloader_num_workers=2,
161
+ dataloader_pin_memory=True,
162
+ learning_rate=6e-4,
163
+ lr_scheduler_type="cosine",
164
+ warmup_ratio=0.05,
165
+ )
166
+
167
+ trainer = Trainer(
168
+ model=model,
169
+ args=training_args,
170
+ train_dataset=dataset,
171
+ data_collator=collate_fn,
172
+ )
173
+
174
+ print("[*] Starting training...")
175
+ trainer.train()
176
+ trainer.save_model("./Supra-Mini-0.1m-FINAL")
177
+ tokenizer.save_pretrained("./Supra-Mini-0.1m-FINAL")
178
+ print("[*] Training finished.")