| import os, math, random, time, json, copy
|
| from pathlib import Path
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from tokenizers import Tokenizer
|
| from transformers import LlamaConfig, LlamaForCausalLM
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Files read and written by the pipeline
# ---------------------------------------------------------------------------
OUTPUT_DIR: Path = Path("student_output")    # checkpoints + student config land here
TOKENIZER_FILE: str = "tokenizer.json"       # opened with Tokenizer.from_file() in train()
TEACHER_WEIGHTS: str = "model.safetensors"   # safetensors checkpoint read by load_teacher()

# ---------------------------------------------------------------------------
# Distillation objective (consumed by distill_loss via train())
# ---------------------------------------------------------------------------
ALPHA: float = 0.7         # weight of the KL term; cross-entropy gets 1 - ALPHA
TEMPERATURE: float = 2.0   # softmax temperature applied to student and teacher logits

# ---------------------------------------------------------------------------
# Optimisation (AdamW + cosine annealing down to LR * 0.1)
# ---------------------------------------------------------------------------
WEIGHT_DECAY: float = 0.01
LR: float = 3e-4
MAX_STEPS: int = 3000      # number of optimizer steps; also T_max of the scheduler

# ---------------------------------------------------------------------------
# Bookkeeping
# ---------------------------------------------------------------------------
LOG_EVERY: int = 50        # steps between averaged loss printouts
SAVE_EVERY: int = 500      # steps between intermediate checkpoints

# ---------------------------------------------------------------------------
# Batch construction
# ---------------------------------------------------------------------------
BATCH_SIZE: int = 4        # teacher-generated sequences per optimizer step
SEQ_LEN: int = 128         # every sequence is padded / truncated to this length
GEN_MAX_TOKENS: int = 128  # max_new_tokens handed to teacher_generate()
SEED: int = 42             # passed to set_seed() at the start of train()
|
|
|
|
|
|
|
|
|
|
|
# Keyword arguments for the teacher's LlamaConfig (expanded by make_config()).
# train() also reads the special-token ids from this mapping.
TEACHER_CONFIG = {
    # Model dimensions
    "vocab_size": 4096,
    "hidden_size": 64,
    "intermediate_size": 128,
    "num_hidden_layers": 5,
    "num_attention_heads": 8,
    "num_key_value_heads": 8,
    # Positional range and normalisation
    "max_position_embeddings": 512,
    "rms_norm_eps": 1e-6,
    # Weight sharing / caching
    "tie_word_embeddings": True,
    "use_cache": False,
    # Special tokens
    "bos_token_id": 0,
    "eos_token_id": 2,
    "pad_token_id": 1,
}
|
|
|
# Keyword arguments for the student's LlamaConfig (expanded by make_config()).
# train() dumps this mapping verbatim to config_student.json.
STUDENT_CONFIG = {
    # Model dimensions
    "vocab_size": 4096,
    "hidden_size": 48,
    "intermediate_size": 96,
    "num_hidden_layers": 4,
    "num_attention_heads": 6,
    "num_key_value_heads": 6,
    # Positional range and normalisation
    "max_position_embeddings": 512,
    "rms_norm_eps": 1e-6,
    # Weight sharing / caching
    "tie_word_embeddings": True,
    "use_cache": False,
    # Special tokens
    "bos_token_id": 0,
    "eos_token_id": 2,
    "pad_token_id": 1,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt openers, grouped by theme. The groups exist only for readability;
# they are flattened, in order, into SEED_PROMPTS below.
_PROMPT_GROUPS = (
    # Natural sciences
    (
        "The process of photosynthesis allows plants to",
        "In chemistry, the periodic table organizes elements by",
        "The theory of evolution explains how species",
        "Gravity is a fundamental force that causes",
        "The human nervous system is responsible for",
        "Cells are the basic unit of life and they",
        "The water cycle describes how water moves through",
        "Atoms are the smallest units of matter and",
        "The immune system protects the body by",
        "Energy cannot be created or destroyed, it can only",
        "The speed of light in a vacuum is",
        "DNA carries the genetic information that determines",
        "The laws of thermodynamics describe how energy",
        "In physics, Newton's laws of motion state that",
        "The ecosystem consists of living organisms and their",
    ),
    # History and society
    (
        "The Renaissance was a period in European history when",
        "The Industrial Revolution transformed society by",
        "Ancient civilizations built complex societies through",
        "Democracy is a system of government in which",
        "The printing press changed the spread of knowledge by",
        "Trade routes in the ancient world connected",
        "The development of writing allowed humans to",
        "Philosophical inquiry began in ancient Greece when",
        "The scientific revolution changed the way people",
        "Colonial expansion in the 15th century led to",
        "The concept of human rights emerged from",
        "Language shapes the way people think and",
        "Art throughout history has served to",
        "Economic systems determine how resources are",
        "Education plays a central role in society because",
    ),
    # Computing and mathematics
    (
        "Computers process information using binary code, which",
        "The internet connects millions of devices around",
        "Algorithms are step-by-step instructions that",
        "Mathematical patterns can be found in nature when",
        "Logic is the foundation of reasoning and",
        "Statistics help us understand data by",
        "Geometry studies the properties of shapes and",
        "The concept of infinity in mathematics refers to",
        "Programming languages allow humans to communicate with",
        "Artificial intelligence systems learn from",
    ),
    # Earth and environment
    (
        "The Amazon rainforest is home to an extraordinary number of",
        "Climate change is caused by an increase in",
        "Ocean currents play an important role in regulating",
        "Biodiversity refers to the variety of life found in",
        "The nitrogen cycle is essential for life because",
        "Renewable energy sources such as solar and wind",
        "Deforestation has significant consequences for",
        "Mountains are formed through geological processes including",
        "The atmosphere protects life on Earth by",
        "Coral reefs are important ecosystems that support",
    ),
    # Mind and learning
    (
        "Critical thinking involves the ability to",
        "Memory is the cognitive process by which",
        "The brain processes information through complex networks of",
        "Consciousness refers to the state of being aware of",
        "Learning occurs most effectively when",
        "Creativity is the capacity to generate new ideas by",
        "Problem solving requires breaking a challenge into",
        "Curiosity drives scientific discovery because",
        "Knowledge is built through observation and",
        "Understanding a concept deeply means being able to",
    ),
    # Health and the human body
    (
        "The cardiovascular system circulates blood throughout",
        "Nutrition is fundamental to health because",
        "Sleep is essential for cognitive function and",
        "Exercise improves physical health by",
        "The digestive system breaks down food into",
        "Mental health is as important as physical health because",
        "Vaccines work by training the immune system to",
        "The skeletal system provides structure and support for",
        "Hormones regulate many bodily functions including",
        "The lungs exchange oxygen and carbon dioxide through",
    ),
)

# Flat, mutable list of every prompt (train() draws its batches from it).
SEED_PROMPTS = [prompt for group in _PROMPT_GROUPS for prompt in group]
|
|
|
|
|
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch's default generator.

    Args:
        seed: Value handed to ``random.seed`` and then ``torch.manual_seed``.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
|
|
|
|
|
def count_params(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
def make_config(cfg: dict) -> LlamaConfig:
    """Build a ``LlamaConfig`` from a plain keyword dictionary.

    Args:
        cfg: Mapping expanded as keyword arguments to ``LlamaConfig``.

    Returns:
        The new config, with ``rope_theta`` overwritten to ``10000.0`` after
        construction.
    """
    rope_base = 10000.0
    config = LlamaConfig(**cfg)
    config.rope_theta = rope_base
    return config
|
|
|
|
|
def load_teacher(weights_path: str, cfg: dict, device: torch.device) -> LlamaForCausalLM:
    """Build the teacher model, load its safetensors weights and freeze it.

    Checkpoint keys may be stored with the ``model.`` backbone prefix or
    bare (``embed_tokens.weight``, ``layers.0...``); either form is mapped
    to the full ``LlamaForCausalLM`` naming and loaded in a single pass.
    An ``lm_head.weight`` entry in the checkpoint is loaded as well, unless
    the config ties the head to the embeddings, in which case the embedding
    matrix is used for both.

    Args:
        weights_path: Path to the ``.safetensors`` checkpoint.
        cfg: Keyword dictionary passed to ``make_config``.
        device: Device the model is moved to.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on every
        parameter.

    Warns:
        UserWarning: If some model parameters were not found in the
            checkpoint (they keep their random initialisation), or if the
            checkpoint contains keys the model does not know.
    """
    import warnings

    from safetensors.torch import load_file

    config = make_config(cfg)
    model = LlamaForCausalLM(config)

    # Express every key in the full model's namespace: backbone weights live
    # under "model.", the output head under "lm_head.".
    state = {}
    for key, tensor in load_file(weights_path).items():
        if not key.startswith(("model.", "lm_head.")):
            key = f"model.{key}"
        state[key] = tensor

    # With tied embeddings the head shares the embedding matrix, so always
    # feed it the embedding weights; otherwise only fill the head in when
    # the checkpoint does not provide one.
    embed = state.get("model.embed_tokens.weight")
    tied = getattr(config, "tie_word_embeddings", False)
    if embed is not None and (tied or "lm_head.weight" not in state):
        state["lm_head.weight"] = embed

    # strict=False keeps loading tolerant of naming differences, but the
    # outcome is reported so a randomly initialised teacher cannot go unnoticed.
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        warnings.warn(
            f"load_teacher: {len(missing)} parameter(s) missing from "
            f"'{weights_path}' keep their random initialisation: {sorted(missing)}",
            stacklevel=2,
        )
    if unexpected:
        warnings.warn(
            f"load_teacher: {len(unexpected)} checkpoint key(s) ignored: "
            f"{sorted(unexpected)}",
            stacklevel=2,
        )

    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model
|
|
|
|
|
def build_student(cfg: dict, device: torch.device) -> LlamaForCausalLM:
    """Create a freshly initialised student model ready for training.

    Args:
        cfg: Keyword dictionary passed to ``make_config``.
        device: Device the model is moved to.

    Returns:
        The new ``LlamaForCausalLM`` on ``device``, switched to train mode.
    """
    student = LlamaForCausalLM(make_config(cfg))
    # Both .to() and .train() hand back the module itself, so they chain.
    return student.to(device).train()
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def teacher_generate(
    teacher: LlamaForCausalLM,
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: int = 25,
) -> torch.Tensor:
    """Simple autoregressive generation with top-k sampling.

    The whole sequence is re-fed to the teacher at every step and the next
    token is sampled from the temperature-scaled logits of the last
    position, restricted to the ``top_k`` highest-scoring tokens.

    Args:
        teacher: Callable model whose output has a ``logits`` attribute and
            whose ``config`` provides ``max_position_embeddings``,
            ``eos_token_id`` and (optionally) ``pad_token_id``.
        input_ids: Prompt token ids, indexed as (batch, sequence).
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Divisor applied to the logits (floored at ``1e-8``).
        top_k: Number of candidates kept per step. Values larger than the
            vocabulary are clamped to it; ``top_k <= 0`` disables the filter.

    Returns:
        The prompt followed by the sampled tokens. Generation stops when
        every row has produced EOS, when ``max_new_tokens`` is reached, or
        when the length hits ``max_position_embeddings``. Rows that finish
        before the others are filled with the pad token (EOS if the config
        has no pad token) instead of being sampled further.
    """
    ids = input_ids.clone()
    config = teacher.config
    max_pos = config.max_position_embeddings
    eos_id = config.eos_token_id
    pad_id = getattr(config, "pad_token_id", None)
    if pad_id is None:
        pad_id = eos_id

    scale = max(temperature, 1e-8)
    # Per-row flag: True once that row has emitted EOS.
    finished = torch.zeros(ids.shape[0], dtype=torch.bool, device=ids.device)

    for _ in range(max_new_tokens):
        if ids.shape[1] >= max_pos:
            break

        logits = teacher(ids).logits[:, -1, :] / scale

        # torch.topk rejects k larger than the vocabulary, so clamp first.
        k = min(top_k, logits.shape[-1])
        if k > 0:
            kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        next_id = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

        if eos_id is not None:
            # Rows that already ended only receive padding from now on.
            next_id = next_id.masked_fill(finished.unsqueeze(-1), pad_id)
            finished = finished | (next_id.squeeze(-1) == eos_id)

        ids = torch.cat([ids, next_id], dim=1)

        if eos_id is not None and bool(finished.all()):
            break

    return ids
|
|
|
|
|
|
|
|
|
|
|
def distill_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
    alpha: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Return ``(total_loss, kl_loss, ce_loss)`` for one distillation step.

    Args:
        student_logits: Student outputs, shape (B, T, V).
        teacher_logits: Teacher outputs, shape (B, T, V).
        labels: Token ids, shape (B, T); ``-100`` marks positions to ignore.
        temperature: Softmax temperature applied to both sets of logits for
            the KL term; the term is rescaled by ``temperature ** 2``.
        alpha: Weight of the KL term; cross-entropy is weighted ``1 - alpha``.

    Returns:
        ``total_loss`` carries gradients; ``kl_loss`` and ``ce_loss`` are
        detached copies of the two components, intended for logging.

    Notes:
        * KL is averaged over positions whose own label is not ``-100``, so
          teacher outputs on ignored (e.g. padded) positions do not act as
          training targets. With no ignored labels this equals
          ``F.kl_div(..., reduction="batchmean")`` on the flattened logits.
        * Cross-entropy compares the logits at position ``t`` with the label
          at ``t + 1`` (next-token prediction).
        * When every label is ignored both components are ``0`` rather than
          NaN.
    """
    vocab = student_logits.shape[-1]

    # --- Soft-target term: KL(teacher || student) at temperature T ---------
    s_log_probs = F.log_softmax(student_logits.reshape(-1, vocab) / temperature, dim=-1)
    t_probs = F.softmax(teacher_logits.reshape(-1, vocab) / temperature, dim=-1)
    per_position_kl = F.kl_div(s_log_probs, t_probs, reduction="none").sum(dim=-1)

    valid = labels.reshape(-1) != -100
    # clamp keeps the division defined when nothing is valid (numerator is 0).
    n_valid = valid.sum().clamp(min=1)
    kl = per_position_kl[valid].sum() / n_valid * (temperature ** 2)

    # --- Hard-target term: next-token cross-entropy ------------------------
    shift_logits = student_logits[:, :-1, :].reshape(-1, vocab)
    shift_labels = labels[:, 1:].reshape(-1)
    ce_sum = F.cross_entropy(
        shift_logits, shift_labels, ignore_index=-100, reduction="sum"
    )
    n_targets = (shift_labels != -100).sum().clamp(min=1)
    ce = ce_sum / n_targets

    loss = alpha * kl + (1.0 - alpha) * ce
    return loss, kl.detach(), ce.detach()
|
|
|
|
|
|
|
|
|
|
|
def _sample_teacher_batch(teacher, tokenizer, device):
    """Generate one batch of teacher continuations for randomly drawn prompts.

    Returns ``(input_ids, labels)``, both of shape (rows, SEQ_LEN), where
    ``labels`` equals ``input_ids`` with ``-100`` on the padded tail.
    """
    bos_id = TEACHER_CONFIG["bos_token_id"]
    pad_id = TEACHER_CONFIG["pad_token_id"]

    rows = []
    lengths = []
    # random.sample draws distinct prompts without reordering the shared list.
    for prompt in random.sample(SEED_PROMPTS, min(BATCH_SIZE, len(SEED_PROMPTS))):
        prompt_ids = torch.tensor(
            [[bos_id] + tokenizer.encode(prompt).ids],
            dtype=torch.long,
            device=device,
        )
        # Tokens beyond SEQ_LEN would be cut off below, so do not sample them.
        budget = max(0, min(GEN_MAX_TOKENS, SEQ_LEN - prompt_ids.shape[1]))
        generated = teacher_generate(
            teacher, prompt_ids,
            max_new_tokens=budget,
            temperature=1.0,
            top_k=25,
        )
        tokens = generated[0, :SEQ_LEN].tolist()
        lengths.append(len(tokens))
        rows.append(tokens + [pad_id] * (SEQ_LEN - len(tokens)))

    input_ids = torch.tensor(rows, dtype=torch.long, device=device)

    # Mask by sequence length rather than by token value, so a pad-id token
    # sampled by the teacher inside the real span stays a training target.
    positions = torch.arange(SEQ_LEN, device=device).unsqueeze(0)
    is_padding = positions >= torch.tensor(lengths, device=device).unsqueeze(1)
    labels = input_ids.masked_fill(is_padding, -100)
    return input_ids, labels


def train() -> None:
    """Run the full distillation pipeline on CPU.

    Steps: seed the RNGs, load the tokenizer and the frozen teacher, build
    the student, then for ``MAX_STEPS`` optimizer steps generate a batch of
    teacher continuations, compute ``distill_loss`` on it and update the
    student with AdamW, gradient clipping and a cosine learning-rate
    schedule. Averaged losses are printed every ``LOG_EVERY`` steps and a
    checkpoint is written every ``SAVE_EVERY`` steps. At the end the final
    weights and ``STUDENT_CONFIG`` (as JSON) are saved in ``OUTPUT_DIR``.

    Raises:
        ValueError: If ``BATCH_SIZE`` is below 1 or ``SEED_PROMPTS`` is
            empty (no batch could ever be built).
        FileNotFoundError: If the tokenizer or teacher weight file is missing.
    """
    if BATCH_SIZE < 1 or not SEED_PROMPTS:
        raise ValueError("BATCH_SIZE must be >= 1 and SEED_PROMPTS must not be empty")

    set_seed(SEED)
    device = torch.device("cpu")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print(" Supra Mini β Distillation Pipeline")
    print("=" * 60)

    # --- Tokenizer ----------------------------------------------------------
    if not Path(TOKENIZER_FILE).exists():
        raise FileNotFoundError(
            f"Tokenizer nΓ£o encontrado: '{TOKENIZER_FILE}'\n"
            f"Renomeie o arquivo 'tokenizer__1_.json' para 'tokenizer.json' "
            f"e coloque na mesma pasta deste script."
        )
    tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
    tokenizer.no_padding()
    tokenizer.no_truncation()
    print(f" Tokenizer carregado β vocab={tokenizer.get_vocab_size()}")

    # --- Teacher ------------------------------------------------------------
    if not Path(TEACHER_WEIGHTS).exists():
        raise FileNotFoundError(
            f"Pesos do teacher nΓ£o encontrados: '{TEACHER_WEIGHTS}'\n"
            f"Coloque o arquivo 'model.safetensors' na mesma pasta."
        )
    teacher = load_teacher(TEACHER_WEIGHTS, TEACHER_CONFIG, device)
    teacher_params = count_params(teacher)
    print(f" Teacher carregado β params={teacher_params:,} [frozen]")

    # --- Student ------------------------------------------------------------
    student = build_student(STUDENT_CONFIG, device)
    student_params = count_params(student)
    print(f" Student inicializado β params={student_params:,} [trainable]")
    print(f" CompressΓ£o β {teacher_params/student_params:.2f}x")
    print("=" * 60)

    optimizer = torch.optim.AdamW(
        student.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=MAX_STEPS, eta_min=LR * 0.1
    )

    step = 0
    running_loss = 0.0
    running_kl = 0.0
    running_ce = 0.0
    t_start = time.time()

    print(f"\n Iniciando treino β {MAX_STEPS} passos\n")

    while step < MAX_STEPS:
        input_ids, labels = _sample_teacher_batch(teacher, tokenizer, device)

        # Teacher targets need no gradients; the student forward does.
        with torch.no_grad():
            teacher_logits = teacher(input_ids).logits
        student_logits = student(input_ids).logits

        loss, kl, ce = distill_loss(
            student_logits, teacher_logits, labels,
            temperature=TEMPERATURE, alpha=ALPHA,
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        step += 1
        running_loss += loss.item()
        running_kl += kl.item()
        running_ce += ce.item()

        if step % LOG_EVERY == 0:
            avg_loss = running_loss / LOG_EVERY
            avg_kl = running_kl / LOG_EVERY
            avg_ce = running_ce / LOG_EVERY
            # Guard against a zero reading from a coarse clock.
            elapsed = max(time.time() - t_start, 1e-9)
            steps_s = step / elapsed
            eta_s = (MAX_STEPS - step) / max(steps_s, 1e-6)
            eta_min = eta_s / 60

            print(
                f" step {step:>5}/{MAX_STEPS}"
                f" loss={avg_loss:.4f}"
                f" kl={avg_kl:.4f}"
                f" ce={avg_ce:.4f}"
                f" lr={scheduler.get_last_lr()[0]:.2e}"
                f" {steps_s:.2f} steps/s"
                f" ETA {eta_min:.1f}min"
            )
            running_loss = running_kl = running_ce = 0.0

        if step % SAVE_EVERY == 0:
            ckpt_path = OUTPUT_DIR / f"student_step{step}.pt"
            torch.save(student.state_dict(), ckpt_path)
            print(f"\n β Checkpoint salvo: {ckpt_path}\n")

    final_path = OUTPUT_DIR / "student_final.pt"
    torch.save(student.state_dict(), final_path)

    # Store the architecture next to the weights so the student can be rebuilt.
    with open(OUTPUT_DIR / "config_student.json", "w", encoding="utf-8") as f:
        json.dump(STUDENT_CONFIG, f, indent=2)

    total_time = (time.time() - t_start) / 60
    print(f"\n{'='*60}")
    print(f" Treino concluΓdo em {total_time:.1f} minutos")
    print(f" Modelo salvo em: {final_path}")
    print(f"{'='*60}")


if __name__ == "__main__":
    train()
|
|
|