AxionLab-official committed on
Commit
5df6e5e
·
verified ·
1 Parent(s): dd5b725

Upload distill.py

Browse files
Files changed (1) hide show
  1. distill.py +428 -0
distill.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, math, random, time, json, copy
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from tokenizers import Tokenizer
7
+ from transformers import LlamaConfig, LlamaForCausalLM
8
+
9
+
10
# ─────────────────────────────────────────────
# Paths — adjust if necessary
# ─────────────────────────────────────────────
TEACHER_WEIGHTS = "model.safetensors"
TOKENIZER_FILE = "tokenizer.json"
OUTPUT_DIR = Path("student_output")

# ─────────────────────────────────────────────
# Hyperparameters
# ─────────────────────────────────────────────
TEMPERATURE = 2.0  # softens the teacher distribution in the KL term
ALPHA = 0.7  # weight of the KL loss (1 - alpha = weight of the CE loss)
LR = 3e-4
WEIGHT_DECAY = 0.01
MAX_STEPS = 3000  # total number of training steps
SAVE_EVERY = 500  # save a checkpoint every N steps
LOG_EVERY = 50  # log the loss every N steps
GEN_MAX_TOKENS = 128  # tokens generated by the teacher per seed prompt
SEQ_LEN = 128  # length of the context window used for training
BATCH_SIZE = 4  # sequences per step (CPU friendly)
SEED = 42
31
+
32
+
33
# ─────────────────────────────────────────────
# Architectures
# ─────────────────────────────────────────────
# Keyword arguments for the teacher's LlamaConfig (see make_config).
TEACHER_CONFIG = {
    "vocab_size": 4096,
    "hidden_size": 64,
    "intermediate_size": 128,
    "num_hidden_layers": 5,
    "num_attention_heads": 8,
    "num_key_value_heads": 8,
    "max_position_embeddings": 512,
    "rms_norm_eps": 1e-6,
    "tie_word_embeddings": True,
    "use_cache": False,
    "bos_token_id": 0,
    "eos_token_id": 2,
    "pad_token_id": 1,
}

# Keyword arguments for the student's LlamaConfig: same vocabulary and special
# token ids as the teacher, but narrower and one layer shallower.
STUDENT_CONFIG = {
    "vocab_size": 4096,
    "hidden_size": 48,
    "intermediate_size": 96,
    "num_hidden_layers": 4,
    "num_attention_heads": 6,
    "num_key_value_heads": 6,
    "max_position_embeddings": 512,
    "rms_norm_eps": 1e-6,
    "tie_word_embeddings": True,
    "use_cache": False,
    "bos_token_id": 0,
    "eos_token_id": 2,
    "pad_token_id": 1,
}
51
+
52
+
53
# ─────────────────────────────────────────────
# Seed prompts (running text, FineWeb-Edu style)
# The teacher is a base model — it continues text, it does not answer questions.
# ─────────────────────────────────────────────
SEED_PROMPTS = [
    # Science
    "The process of photosynthesis allows plants to",
    "In chemistry, the periodic table organizes elements by",
    "The theory of evolution explains how species",
    "Gravity is a fundamental force that causes",
    "The human nervous system is responsible for",
    "Cells are the basic unit of life and they",
    "The water cycle describes how water moves through",
    "Atoms are the smallest units of matter and",
    "The immune system protects the body by",
    "Energy cannot be created or destroyed, it can only",
    "The speed of light in a vacuum is",
    "DNA carries the genetic information that determines",
    "The laws of thermodynamics describe how energy",
    "In physics, Newton's laws of motion state that",
    "The ecosystem consists of living organisms and their",

    # History and society
    "The Renaissance was a period in European history when",
    "The Industrial Revolution transformed society by",
    "Ancient civilizations built complex societies through",
    "Democracy is a system of government in which",
    "The printing press changed the spread of knowledge by",
    "Trade routes in the ancient world connected",
    "The development of writing allowed humans to",
    "Philosophical inquiry began in ancient Greece when",
    "The scientific revolution changed the way people",
    "Colonial expansion in the 15th century led to",
    "The concept of human rights emerged from",
    "Language shapes the way people think and",
    "Art throughout history has served to",
    "Economic systems determine how resources are",
    "Education plays a central role in society because",

    # Technology and mathematics (conceptual, no calculation)
    "Computers process information using binary code, which",
    "The internet connects millions of devices around",
    "Algorithms are step-by-step instructions that",
    "Mathematical patterns can be found in nature when",
    "Logic is the foundation of reasoning and",
    "Statistics help us understand data by",
    "Geometry studies the properties of shapes and",
    "The concept of infinity in mathematics refers to",
    "Programming languages allow humans to communicate with",
    "Artificial intelligence systems learn from",

    # Nature and environment
    "The Amazon rainforest is home to an extraordinary number of",
    "Climate change is caused by an increase in",
    "Ocean currents play an important role in regulating",
    "Biodiversity refers to the variety of life found in",
    "The nitrogen cycle is essential for life because",
    "Renewable energy sources such as solar and wind",
    "Deforestation has significant consequences for",
    "Mountains are formed through geological processes including",
    "The atmosphere protects life on Earth by",
    "Coral reefs are important ecosystems that support",

    # Philosophy and cognition
    "Critical thinking involves the ability to",
    "Memory is the cognitive process by which",
    "The brain processes information through complex networks of",
    "Consciousness refers to the state of being aware of",
    "Learning occurs most effectively when",
    "Creativity is the capacity to generate new ideas by",
    "Problem solving requires breaking a challenge into",
    "Curiosity drives scientific discovery because",
    "Knowledge is built through observation and",
    "Understanding a concept deeply means being able to",

    # Medicine and the human body
    "The cardiovascular system circulates blood throughout",
    "Nutrition is fundamental to health because",
    "Sleep is essential for cognitive function and",
    "Exercise improves physical health by",
    "The digestive system breaks down food into",
    "Mental health is as important as physical health because",
    "Vaccines work by training the immune system to",
    "The skeletal system provides structure and support for",
    "Hormones regulate many bodily functions including",
    "The lungs exchange oxygen and carbon dioxide through",
]
140
+
141
+
142
# ─────────────────────────────────────────────
# Utilities
# ─────────────────────────────────────────────
145
def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch with the same value."""
    for seeder in (random.seed, torch.manual_seed):
        seeder(seed)
148
+
149
+
150
def count_params(model: torch.nn.Module) -> int:
    """Return the total element count over everything in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
152
+
153
+
154
def make_config(cfg: dict) -> LlamaConfig:
    """Build a ``LlamaConfig`` from *cfg* and force ``rope_theta`` to 10000.0.

    Every entry of *cfg* is forwarded as a keyword argument; ``rope_theta`` is
    assigned afterwards, so it overrides any value the constructor produced.
    """
    config = LlamaConfig(**cfg)
    setattr(config, "rope_theta", 10000.0)
    return config
158
+
159
+
160
def load_teacher(weights_path: str, cfg: dict, device: torch.device) -> LlamaForCausalLM:
    """Build the teacher from *cfg*, load its safetensors weights and freeze it.

    Checkpoint keys are accepted with or without the ``model.`` prefix and are
    normalised to the key layout of the full ``LlamaForCausalLM`` so that the
    decoder weights *and* ``lm_head`` are loaded in a single pass.

    Args:
        weights_path: Path to a ``.safetensors`` file.
        cfg: Keyword arguments for the teacher's ``LlamaConfig``.
        device: Device the model is moved to.

    Returns:
        The teacher in eval mode with ``requires_grad`` disabled on every
        parameter.

    Warns:
        UserWarning: if some model weights are absent from the checkpoint and
            therefore keep their random initialisation.
    """
    import warnings

    from safetensors.torch import load_file

    config = make_config(cfg)
    model = LlamaForCausalLM(config)

    # Normalise every key to the full-model layout: decoder weights live under
    # "model.", the output projection under "lm_head.".
    state = {}
    for key, tensor in load_file(weights_path).items():
        bare = key[len("model."):] if key.startswith("model.") else key
        full_key = bare if bare.startswith("lm_head.") else f"model.{bare}"
        state[full_key] = tensor

    # With tied embeddings the output projection *is* the embedding matrix, so
    # the embedding always wins. Without tying, fall back to the embedding only
    # when the checkpoint carries no lm_head of its own.
    embed = state.get("model.embed_tokens.weight")
    tied = getattr(config, "tie_word_embeddings", False)
    if embed is not None and (tied or "lm_head.weight" not in state):
        state["lm_head.weight"] = embed

    # strict=False tolerates extra checkpoint entries, but missing weights
    # would leave part of the teacher randomly initialised — surface that.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        warnings.warn(
            f"Teacher checkpoint '{weights_path}' is missing "
            f"{len(result.missing_keys)} weight(s), which keep their random "
            f"initialisation: {sorted(result.missing_keys)}",
            stacklevel=2,
        )

    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model
187
+
188
+
189
def build_student(cfg: dict, device: torch.device) -> LlamaForCausalLM:
    """Create a freshly initialised student from *cfg* on *device*, in train mode."""
    student = LlamaForCausalLM(make_config(cfg))
    student.to(device)
    student.train()
    return student
195
+
196
+
197
# ─────────────────────────────────────────────
# Sequence generation with the teacher
# ─────────────────────────────────────────────
200
@torch.no_grad()
def teacher_generate(
    teacher: LlamaForCausalLM,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: int = 25,
) -> torch.Tensor:
    """Autoregressively extend *input_ids* with top-k sampling from the teacher.

    Args:
        teacher: Callable model; ``teacher(ids).logits`` is indexed as
            ``[:, -1, :]`` and ``teacher.config`` supplies
            ``max_position_embeddings``, ``eos_token_id`` and (optionally)
            ``pad_token_id``.
        input_ids: ``(B, T)`` prompt token ids; not modified.
        max_new_tokens: Upper bound on the number of appended tokens.
        temperature: Divides the logits before sampling (floored at 1e-8).
        top_k: Number of highest-scoring tokens kept for sampling. Clamped to
            the vocabulary size; values <= 0 disable the filter.

    Returns:
        ``(B, T + n)`` token ids with ``n <= max_new_tokens``. Generation stops
        early once the sequence reaches ``max_position_embeddings`` or every
        row has produced EOS. A row that finished earlier than the others is
        filled with ``pad_token_id`` (``eos_token_id`` if no pad id is set)
        instead of being sampled further.
    """
    ids = input_ids.clone()
    max_pos = teacher.config.max_position_embeddings
    eos_id = teacher.config.eos_token_id
    pad_id = getattr(teacher.config, "pad_token_id", None)
    filler = eos_id if pad_id is None else pad_id

    # Per-row flag: True once that row has emitted EOS.
    finished = torch.zeros(ids.shape[0], dtype=torch.bool, device=ids.device)

    for _ in range(max_new_tokens):
        if ids.shape[1] >= max_pos:
            break

        logits = teacher(ids).logits[:, -1, :]  # (B, V)
        logits = logits / max(temperature, 1e-8)

        # torch.topk raises if k exceeds the vocabulary size, so clamp it.
        k = min(top_k, logits.shape[-1])
        if k > 0:
            top_vals, _ = torch.topk(logits, k, dim=-1)
            threshold = top_vals[:, -1].unsqueeze(-1)
            logits = logits.masked_fill(logits < threshold, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)

        if eos_id is not None:
            # Rows that already ended receive filler rather than new samples.
            next_id = next_id.masked_fill(finished.unsqueeze(-1), filler)

        ids = torch.cat([ids, next_id], dim=1)

        if eos_id is not None:
            finished = finished | (next_id.squeeze(-1) == eos_id)
            if finished.all():
                break

    return ids
229
+
230
+
231
+ # ─────────────────────────────────────────────
232
+ # Distillation loss
233
+ # ─────────────────────────────────────────────
234
def distill_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
    alpha: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Combine a soft-label KL term with a hard-label cross-entropy term.

    Args:
        student_logits: ``(B, T, V)`` student outputs.
        teacher_logits: ``(B, T, V)`` teacher outputs.
        labels: ``(B, T)`` token ids, with ``-100`` at positions to ignore.
        temperature: Softmax temperature applied to both models in the KL term.
        alpha: Weight of the KL term; the CE term is weighted ``1 - alpha``.

    Returns:
        ``(loss_total, kl_loss, ce_loss)``; the last two are detached.

    Positions whose label is ``-100`` are excluded from *both* terms. When no
    position is ignored the KL term equals ``F.kl_div(..., "batchmean")`` over
    the flattened ``(B*T, V)`` logits. A term with no valid position is 0
    rather than NaN.
    """
    B, T, V = student_logits.shape

    # ── KL divergence (soft labels) ──────────────────────────────────────
    # Flatten to (B*T, V); reshape also accepts non-contiguous inputs.
    s_log_probs = F.log_softmax(student_logits.reshape(-1, V) / temperature, dim=-1)
    t_probs = F.softmax(teacher_logits.reshape(-1, V) / temperature, dim=-1)
    kl_per_pos = F.kl_div(s_log_probs, t_probs, reduction="none").sum(dim=-1)

    # Logits at position i are conditioned on token i, so a position is
    # meaningful only when that token is real (not an ignored/pad slot).
    valid = labels.reshape(-1) != -100
    n_valid = valid.sum().clamp(min=1)
    kl = kl_per_pos.masked_fill(~valid, 0.0).sum() / n_valid * (temperature ** 2)

    # ── Cross-entropy (hard labels) ──────────────────────────────────────
    # Shift: position i predicts token i + 1.
    shift_logits = student_logits[:, :-1, :].reshape(-1, V)
    shift_labels = labels[:, 1:].reshape(-1)
    if (shift_labels != -100).any():
        ce = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    else:
        # cross_entropy would return NaN when every target is ignored.
        ce = student_logits.new_zeros(())

    loss = alpha * kl + (1.0 - alpha) * ce
    return loss, kl.detach(), ce.detach()
262
+
263
+
264
# ─────────────────────────────────────────────
# Training
# ─────────────────────────────────────────────
267
def _sample_teacher_batch(teacher, tokenizer, bos_id, pad_id, device):
    """Generate one ``(B, SEQ_LEN)`` batch of teacher continuations.

    Draws up to ``BATCH_SIZE`` distinct prompts from ``SEED_PROMPTS``, lets the
    teacher continue each one, then truncates / right-pads every sequence to
    exactly ``SEQ_LEN`` tokens.
    """
    # random.sample picks a random subset without reordering the shared
    # module-level list (random.shuffle would mutate it in place).
    prompts = random.sample(SEED_PROMPTS, min(BATCH_SIZE, len(SEED_PROMPTS)))

    sequences = []
    for prompt in prompts:
        enc = tokenizer.encode(prompt)
        prompt_ids = torch.tensor(
            [[bos_id] + enc.ids], dtype=torch.long, device=device
        )
        gen_ids = teacher_generate(
            teacher, prompt_ids,
            max_new_tokens=GEN_MAX_TOKENS,
            temperature=1.0,
            top_k=25,
        )

        # Truncate / pad to SEQ_LEN.
        seq = gen_ids[0].tolist()[:SEQ_LEN]
        seq += [pad_id] * (SEQ_LEN - len(seq))
        sequences.append(seq)

    return torch.tensor(sequences, dtype=torch.long, device=device)


def train():
    """Distill the teacher into the student and write the results to disk.

    Each step samples a fresh batch of teacher continuations of the seed
    prompts, runs teacher and student on it, and optimises the student with
    ``distill_loss`` (AdamW, cosine LR schedule, gradient clipping at 1.0).
    Checkpoints go to ``OUTPUT_DIR`` every ``SAVE_EVERY`` steps; the final
    weights and the student config are saved when training ends.

    Raises:
        ValueError: if ``SEED_PROMPTS`` is empty or ``BATCH_SIZE`` < 1.
        FileNotFoundError: if the tokenizer or the teacher weights are missing.
    """
    # Without prompts (or with an empty batch) no step could ever complete.
    if not SEED_PROMPTS:
        raise ValueError("SEED_PROMPTS is empty; nothing to distill from.")
    if BATCH_SIZE < 1:
        raise ValueError(f"BATCH_SIZE must be >= 1, got {BATCH_SIZE}.")

    set_seed(SEED)
    device = torch.device("cpu")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print(" Supra Mini — Distillation Pipeline")
    print("=" * 60)

    # ── Tokenizer ────────────────────────────────────────────────────────
    if not Path(TOKENIZER_FILE).exists():
        raise FileNotFoundError(
            f"Tokenizer não encontrado: '{TOKENIZER_FILE}'\n"
            f"Renomeie o arquivo 'tokenizer__1_.json' para 'tokenizer.json' "
            f"e coloque na mesma pasta deste script."
        )
    tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
    tokenizer.no_padding()
    tokenizer.no_truncation()
    print(f" Tokenizer carregado — vocab={tokenizer.get_vocab_size()}")

    # ── Teacher ──────────────────────────────────────────────────────────
    if not Path(TEACHER_WEIGHTS).exists():
        raise FileNotFoundError(
            f"Pesos do teacher não encontrados: '{TEACHER_WEIGHTS}'\n"
            f"Coloque o arquivo 'model.safetensors' na mesma pasta."
        )
    teacher = load_teacher(TEACHER_WEIGHTS, TEACHER_CONFIG, device)
    print(f" Teacher carregado — params={count_params(teacher):,} [frozen]")

    # ── Student ──────────────────────────────────────────────────────────
    student = build_student(STUDENT_CONFIG, device)
    print(f" Student inicializado — params={count_params(student):,} [trainable]")
    print(f" Compressão — {count_params(teacher)/count_params(student):.2f}x")
    print("=" * 60)

    optimizer = torch.optim.AdamW(
        student.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=MAX_STEPS, eta_min=LR * 0.1
    )

    bos_id = TEACHER_CONFIG["bos_token_id"]
    pad_id = TEACHER_CONFIG["pad_token_id"]

    step = 0
    running_loss = 0.0
    running_kl = 0.0
    running_ce = 0.0
    t_start = time.time()

    print(f"\n Iniciando treino — {MAX_STEPS} passos\n")

    while step < MAX_STEPS:
        # ── Generate a batch of sequences with the teacher ───────────────
        input_ids = _sample_teacher_batch(
            teacher, tokenizer, bos_id, pad_id, device
        )  # (B, T)

        # Labels: -100 on pad positions so the loss ignores them.
        labels = input_ids.masked_fill(input_ids == pad_id, -100)

        # ── Teacher forward pass (no gradient) ───────────────────────────
        with torch.no_grad():
            teacher_logits = teacher(input_ids).logits  # (B, T, V)

        # ── Student forward pass ─────────────────────────────────────────
        student_logits = student(input_ids).logits  # (B, T, V)

        # ── Loss ─────────────────────────────────────────────────────────
        loss, kl, ce = distill_loss(
            student_logits, teacher_logits, labels,
            temperature=TEMPERATURE, alpha=ALPHA,
        )

        # ── Backprop ─────────────────────────────────────────────────────
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        step += 1
        running_loss += loss.item()
        running_kl += kl.item()
        running_ce += ce.item()

        # ── Log ──────────────────────────────────────────────────────────
        if step % LOG_EVERY == 0:
            avg_loss = running_loss / LOG_EVERY
            avg_kl = running_kl / LOG_EVERY
            avg_ce = running_ce / LOG_EVERY
            elapsed = time.time() - t_start
            steps_s = step / max(elapsed, 1e-9)
            eta_s = (MAX_STEPS - step) / max(steps_s, 1e-6)
            eta_min = eta_s / 60

            print(
                f" step {step:>5}/{MAX_STEPS}"
                f" loss={avg_loss:.4f}"
                f" kl={avg_kl:.4f}"
                f" ce={avg_ce:.4f}"
                f" lr={scheduler.get_last_lr()[0]:.2e}"
                f" {steps_s:.2f} steps/s"
                f" ETA {eta_min:.1f}min"
            )
            running_loss = running_kl = running_ce = 0.0

        # ── Checkpoint ───────────────────────────────────────────────────
        if step % SAVE_EVERY == 0:
            ckpt_path = OUTPUT_DIR / f"student_step{step}.pt"
            torch.save(student.state_dict(), ckpt_path)
            print(f"\n ✓ Checkpoint salvo: {ckpt_path}\n")

    # ── Save the final model ─────────────────────────────────────────────
    final_path = OUTPUT_DIR / "student_final.pt"
    torch.save(student.state_dict(), final_path)

    # Save the student config so the weights can be reloaded later.
    with open(OUTPUT_DIR / "config_student.json", "w", encoding="utf-8") as f:
        json.dump(STUDENT_CONFIG, f, indent=2)

    total_time = (time.time() - t_start) / 60
    print(f"\n{'='*60}")
    print(f" Treino concluído em {total_time:.1f} minutos")
    print(f" Modelo salvo em: {final_path}")
    print(f"{'='*60}")
425
+
426
+
427
# Run the distillation pipeline only when this file is executed as a script.
if __name__ == "__main__":
    train()