import copy
import json
import math
import os
import random
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from transformers import LlamaConfig, LlamaForCausalLM

# ─────────────────────────────────────────────
# Paths — adjust if necessary
# ─────────────────────────────────────────────
TEACHER_WEIGHTS = "model.safetensors"
TOKENIZER_FILE = "tokenizer.json"
OUTPUT_DIR = Path("student_output")

# ─────────────────────────────────────────────
# Hyperparameters
# ─────────────────────────────────────────────
TEMPERATURE = 2.0     # softens the logit distributions compared by the KL term
ALPHA = 0.7           # weight of the KL loss (1 - ALPHA is the weight of the CE loss)
LR = 3e-4
WEIGHT_DECAY = 0.01
MAX_STEPS = 3000      # total number of training steps
SAVE_EVERY = 500      # write a checkpoint every N steps
LOG_EVERY = 50        # log averaged losses every N steps
GEN_MAX_TOKENS = 128  # tokens generated by the teacher per seed prompt
SEQ_LEN = 128         # context-window length used for training sequences
BATCH_SIZE = 4        # sequences per step (CPU friendly)
SEED = 42

# ─────────────────────────────────────────────
# Architectures
# ─────────────────────────────────────────────
TEACHER_CONFIG = {
    "vocab_size": 4096,
    "hidden_size": 64,
    "intermediate_size": 128,
    "num_hidden_layers": 5,
    "num_attention_heads": 8,
    "num_key_value_heads": 8,
    "max_position_embeddings": 512,
    "rms_norm_eps": 1e-6,
    "tie_word_embeddings": True,
    "use_cache": False,
    "bos_token_id": 0,
    "eos_token_id": 2,
    "pad_token_id": 1,
}

# The student inherits every teacher setting (vocabulary size, special-token
# ids, positions, norm eps, tying) and only shrinks width/depth. Sharing the
# vocabulary matters: the distillation KL compares teacher and student logits
# over the same vocabulary axis.
STUDENT_CONFIG = {
    **TEACHER_CONFIG,
    "hidden_size": 48,
    "intermediate_size": 96,
    "num_hidden_layers": 4,
    "num_attention_heads": 6,
    "num_key_value_heads": 6,
}

# ─────────────────────────────────────────────
# Seed prompts (running prose, FineWeb-Edu style).
# The teacher is a base model — it continues text, it does not answer questions.
# ─────────────────────────────────────────────
SEED_PROMPTS = [
    # Sciences
    "The process of photosynthesis allows plants to",
    "In chemistry, the periodic table organizes elements by",
    "The theory of evolution explains how species",
    "Gravity is a fundamental force that causes",
    "The human nervous system is responsible for",
    "Cells are the basic unit of life and they",
    "The water cycle describes how water moves through",
    "Atoms are the smallest units of matter and",
    "The immune system protects the body by",
    "Energy cannot be created or destroyed, it can only",
    "The speed of light in a vacuum is",
    "DNA carries the genetic information that determines",
    "The laws of thermodynamics describe how energy",
    "In physics, Newton's laws of motion state that",
    "The ecosystem consists of living organisms and their",
    # History and society
    "The Renaissance was a period in European history when",
    "The Industrial Revolution transformed society by",
    "Ancient civilizations built complex societies through",
    "Democracy is a system of government in which",
    "The printing press changed the spread of knowledge by",
    "Trade routes in the ancient world connected",
    "The development of writing allowed humans to",
    "Philosophical inquiry began in ancient Greece when",
    "The scientific revolution changed the way people",
    "Colonial expansion in the 15th century led to",
    "The concept of human rights emerged from",
    "Language shapes the way people think and",
    "Art throughout history has served to",
    "Economic systems determine how resources are",
    "Education plays a central role in society because",
    # Technology and mathematics (conceptual, no calculation)
    "Computers process information using binary code, which",
    "The internet connects millions of devices around",
    "Algorithms are step-by-step instructions that",
    "Mathematical patterns can be found in nature when",
    "Logic is the foundation of reasoning and",
    "Statistics help us understand data by",
    "Geometry studies the properties of shapes and",
    "The concept of infinity in mathematics refers to",
    "Programming languages allow humans to communicate with",
    "Artificial intelligence systems learn from",
    # Nature and environment
    "The Amazon rainforest is home to an extraordinary number of",
    "Climate change is caused by an increase in",
    "Ocean currents play an important role in regulating",
    "Biodiversity refers to the variety of life found in",
    "The nitrogen cycle is essential for life because",
    "Renewable energy sources such as solar and wind",
    "Deforestation has significant consequences for",
    "Mountains are formed through geological processes including",
    "The atmosphere protects life on Earth by",
    "Coral reefs are important ecosystems that support",
    # Philosophy and cognition
    "Critical thinking involves the ability to",
    "Memory is the cognitive process by which",
    "The brain processes information through complex networks of",
    "Consciousness refers to the state of being aware of",
    "Learning occurs most effectively when",
    "Creativity is the capacity to generate new ideas by",
    "Problem solving requires breaking a challenge into",
    "Curiosity drives scientific discovery because",
    "Knowledge is built through observation and",
    "Understanding a concept deeply means being able to",
    # Medicine and the human body
    "The cardiovascular system circulates blood throughout",
    "Nutrition is fundamental to health because",
    "Sleep is essential for cognitive function and",
    "Exercise improves physical health by",
    "The digestive system breaks down food into",
    "Mental health is as important as physical health because",
    "Vaccines work by training the immune system to",
    "The skeletal system provides structure and support for",
    "Hormones regulate many bodily functions including",
    "The lungs exchange oxygen and carbon dioxide through",
]


# ─────────────────────────────────────────────
# Utilities
# ─────────────────────────────────────────────
def set_seed(seed: int):
    """Seed Python's ``random`` module and PyTorch's global RNG with ``seed``."""
    random.seed(seed)
    torch.manual_seed(seed)


def count_params(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements in ``model.parameters()``."""
    return sum(p.numel() for p in model.parameters())


def make_config(cfg: dict) -> "LlamaConfig":
    """Build a ``LlamaConfig`` from ``cfg`` with ``rope_theta`` set explicitly.

    ``rope_theta`` is 10000.0 unless ``cfg`` provides its own ``"rope_theta"``.
    """
    config = LlamaConfig(**cfg)
    config.rope_theta = cfg.get("rope_theta", 10000.0)
    return config


def load_teacher(weights_path: str, cfg: dict, device: torch.device) -> "LlamaForCausalLM":
    """Load a frozen teacher ``LlamaForCausalLM`` from a safetensors file.

    Checkpoint keys may be stored with or without the ``model.`` prefix used by
    ``LlamaForCausalLM``; both layouts are normalised to the full-model layout
    before loading. When the config ties word embeddings (or the file has no
    ``lm_head.weight``), the output head is taken from the input embedding.

    Args:
        weights_path: path to the ``.safetensors`` file.
        cfg: architecture dict passed to ``make_config``.
        device: device the model is moved to.

    Returns:
        The model in eval mode with ``requires_grad`` disabled on all parameters.

    Raises:
        RuntimeError: if any model parameter is absent from the checkpoint.
            Distilling from a partially random teacher would silently
            produce a useless student.
    """
    from safetensors.torch import load_file

    config = make_config(cfg)
    model = LlamaForCausalLM(config)

    state = {}
    for key, tensor in load_file(weights_path).items():
        # Normalise to "model.<backbone key>" / "lm_head.<...>".
        if key.startswith(("model.", "lm_head.")):
            state[key] = tensor
        else:
            state[f"model.{key}"] = tensor

    embed = state.get("model.embed_tokens.weight")
    if embed is not None and (config.tie_word_embeddings or "lm_head.weight" not in state):
        # With tied embeddings the input embedding is the authoritative copy.
        state["lm_head.weight"] = embed

    missing, _unexpected = model.load_state_dict(state, strict=False)
    # Only real parameters matter. Buffers may legitimately be absent, and a
    # tied lm_head is deduplicated out of named_parameters().
    param_names = {name for name, _ in model.named_parameters()}
    missing_params = sorted(k for k in missing if k in param_names)
    if missing_params:
        raise RuntimeError(
            f"Pesos do teacher incompletos em '{weights_path}': "
            f"{len(missing_params)} parâmetro(s) ausente(s), ex.: {missing_params[:5]}"
        )

    model.to(device)
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model


def build_student(cfg: dict, device: torch.device) -> "LlamaForCausalLM":
    """Create a randomly initialised student on ``device`` in train mode."""
    config = make_config(cfg)
    model = LlamaForCausalLM(config)
    model.to(device)
    model.train()
    return model


# ─────────────────────────────────────────────
# Sequence generation with the teacher
# ─────────────────────────────────────────────
@torch.no_grad()
def teacher_generate(
    teacher: "LlamaForCausalLM",
    input_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: int = 25,
) -> torch.Tensor:
    """Simple autoregressive generation with top-k sampling.

    The full sequence is re-encoded at every step (no KV cache). Generation
    stops after ``max_new_tokens`` steps, when the sequence reaches
    ``teacher.config.max_position_embeddings``, or once every row has emitted
    ``eos_token_id``. Rows that finished earlier are filled with
    ``pad_token_id`` (or EOS if no pad id is configured) until the rest finish.

    Args:
        teacher: causal LM whose forward returns an object with ``.logits``.
        input_ids: prompt ids, shape (B, T0).
        max_new_tokens: maximum number of tokens to append.
        temperature: logits are divided by ``max(temperature, 1e-8)``.
        top_k: keep only the ``top_k`` highest logits (clamped to the vocab
            size). ``None`` or ``<= 0`` disables the filter.

    Returns:
        Tensor of shape (B, T0 + generated) containing prompt + continuation.
    """
    ids = input_ids.clone()
    max_pos = teacher.config.max_position_embeddings
    eos_id = teacher.config.eos_token_id
    pad_id = teacher.config.pad_token_id
    fill_id = eos_id if pad_id is None else pad_id
    finished = torch.zeros(ids.shape[0], dtype=torch.bool, device=ids.device)

    for _ in range(max_new_tokens):
        if ids.shape[1] >= max_pos:
            break
        logits = teacher(ids).logits[:, -1, :]  # (B, V)
        logits = logits / max(temperature, 1e-8)
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.shape[-1])
            kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)
        if finished.any():
            # Rows that already produced EOS stay terminated.
            next_id = next_id.masked_fill(finished.unsqueeze(-1), fill_id)
        ids = torch.cat([ids, next_id], dim=1)

        if eos_id is not None:
            finished |= next_id.squeeze(-1) == eos_id
        if finished.all():
            break
    return ids


# ─────────────────────────────────────────────
# Distillation loss
# ─────────────────────────────────────────────
def distill_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
    alpha: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Combine soft-label KL and hard-label CE into one distillation loss.

    Args:
        student_logits: (B, T, V) student logits.
        teacher_logits: (B, T, V) teacher logits (same vocabulary as student).
        labels: (B, T) token ids, with ``-100`` at positions to ignore.
            Ignored positions are excluded from BOTH the KL and the CE term.
        temperature: softening temperature applied to both distributions. The
            KL is rescaled by ``temperature ** 2``.
        alpha: weight of the KL term. ``1 - alpha`` weights the CE term.

    Returns:
        ``(loss_total, kl_loss, ce_loss)``. The last two are detached, for
        logging.
    """
    vocab = student_logits.shape[-1]

    # ── KL divergence (soft labels), averaged over non-ignored positions ──
    s_log_probs = F.log_softmax(student_logits.reshape(-1, vocab) / temperature, dim=-1)
    t_log_probs = F.log_softmax(teacher_logits.reshape(-1, vocab) / temperature, dim=-1)
    per_position_kl = F.kl_div(
        s_log_probs, t_log_probs, reduction="none", log_target=True
    ).sum(dim=-1)  # (B*T,)
    valid = (labels.reshape(-1) != -100).to(per_position_kl.dtype)
    # clamp avoids 0/0 when every position is ignored (KL is then 0).
    kl = (per_position_kl * valid).sum() / valid.sum().clamp(min=1.0)
    kl = kl * (temperature ** 2)

    # ── Cross-entropy (hard labels): position i predicts token i+1 ─────────
    shift_logits = student_logits[:, :-1, :].reshape(-1, vocab)
    shift_labels = labels[:, 1:].reshape(-1)
    ce = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

    loss = alpha * kl + (1.0 - alpha) * ce
    return loss, kl.detach(), ce.detach()


# ─────────────────────────────────────────────
# Training
# ─────────────────────────────────────────────
def _pad_or_truncate(seq: list[int], length: int, pad_id: int) -> list[int]:
    """Right-pad ``seq`` with ``pad_id``, or cut it, to exactly ``length`` ids."""
    return seq[:length] + [pad_id] * max(0, length - len(seq))


def _sample_teacher_batch(
    teacher: "LlamaForCausalLM",
    tokenizer: "Tokenizer",
    prompts: list[str],
    *,
    bos_id: int,
    pad_id: int,
    seq_len: int,
    max_new_tokens: int,
) -> torch.Tensor:
    """Let the teacher continue each prompt; return a (len(prompts), seq_len) batch."""
    rows = []
    for prompt in prompts:
        # BOS is prepended explicitly, so special tokens from the tokenizer's
        # post-processor (if it is configured to add any) are suppressed. This
        # avoids a doubled BOS or a trailing EOS right after the prompt.
        enc = tokenizer.encode(prompt, add_special_tokens=False)
        prompt_ids = torch.tensor([[bos_id] + enc.ids], dtype=torch.long)
        gen_ids = teacher_generate(
            teacher,
            prompt_ids,
            max_new_tokens=max_new_tokens,
            temperature=1.0,
            top_k=25,
        )
        rows.append(_pad_or_truncate(gen_ids[0].tolist(), seq_len, pad_id))
    return torch.tensor(rows, dtype=torch.long)


def train():
    """Distil the teacher into a freshly initialised student.

    Each step samples ``BATCH_SIZE`` seed prompts without replacement and lets
    the frozen teacher continue them. The student is then trained on those
    sequences with ``distill_loss``. A checkpoint is written every
    ``SAVE_EVERY`` steps. At the end, the final weights and ``STUDENT_CONFIG``
    (as JSON) are written to ``OUTPUT_DIR``.

    Raises:
        ValueError: if ``SEED_PROMPTS`` is empty or ``BATCH_SIZE`` < 1, since
            no batch could ever be built.
        FileNotFoundError: if the tokenizer or teacher weights are missing.
    """
    set_seed(SEED)
    if not SEED_PROMPTS:
        raise ValueError("SEED_PROMPTS está vazio — o teacher precisa de prompts para gerar dados.")
    if BATCH_SIZE < 1:
        raise ValueError(f"BATCH_SIZE deve ser >= 1 (recebido: {BATCH_SIZE}).")

    device = torch.device("cpu")
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print(" Supra Mini — Distillation Pipeline")
    print("=" * 60)

    # ── Tokenizer ────────────────────────────────────────────────────────
    if not Path(TOKENIZER_FILE).exists():
        raise FileNotFoundError(
            f"Tokenizer não encontrado: '{TOKENIZER_FILE}'\n"
            f"Renomeie o arquivo 'tokenizer__1_.json' para 'tokenizer.json' "
            f"e coloque na mesma pasta deste script."
        )
    tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
    tokenizer.no_padding()
    tokenizer.no_truncation()
    print(f" Tokenizer carregado — vocab={tokenizer.get_vocab_size()}")

    # ── Teacher ──────────────────────────────────────────────────────────
    if not Path(TEACHER_WEIGHTS).exists():
        raise FileNotFoundError(
            f"Pesos do teacher não encontrados: '{TEACHER_WEIGHTS}'\n"
            f"Coloque o arquivo 'model.safetensors' na mesma pasta."
        )
    teacher = load_teacher(TEACHER_WEIGHTS, TEACHER_CONFIG, device)
    print(f" Teacher carregado — params={count_params(teacher):,} [frozen]")

    # ── Student ──────────────────────────────────────────────────────────
    student = build_student(STUDENT_CONFIG, device)
    print(f" Student inicializado — params={count_params(student):,} [trainable]")
    print(f" Compressão — {count_params(teacher)/count_params(student):.2f}x")
    print("=" * 60)

    optimizer = torch.optim.AdamW(
        student.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=MAX_STEPS, eta_min=LR * 0.1
    )

    bos_id = TEACHER_CONFIG["bos_token_id"]
    pad_id = TEACHER_CONFIG["pad_token_id"]
    batch_size = min(BATCH_SIZE, len(SEED_PROMPTS))

    step = 0
    running_loss = 0.0
    running_kl = 0.0
    running_ce = 0.0
    t_start = time.time()

    print(f"\n Iniciando treino — {MAX_STEPS} passos\n")

    while step < MAX_STEPS:
        # ── Generate a batch of sequences with the teacher ──────────────
        # random.sample draws a fresh subset without mutating SEED_PROMPTS.
        prompts = random.sample(SEED_PROMPTS, batch_size)
        input_ids = _sample_teacher_batch(
            teacher,
            tokenizer,
            prompts,
            bos_id=bos_id,
            pad_id=pad_id,
            seq_len=SEQ_LEN,
            max_new_tokens=GEN_MAX_TOKENS,
        )  # (B, T)

        # Labels: -100 on pad positions so both loss terms ignore them.
        labels = input_ids.clone()
        labels[labels == pad_id] = -100

        # ── Forward passes ──────────────────────────────────────────────
        with torch.no_grad():
            teacher_logits = teacher(input_ids).logits  # (B, T, V)
        student_logits = student(input_ids).logits      # (B, T, V)

        loss, kl, ce = distill_loss(
            student_logits,
            teacher_logits,
            labels,
            temperature=TEMPERATURE,
            alpha=ALPHA,
        )

        # ── Backprop ────────────────────────────────────────────────────
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        step += 1
        running_loss += loss.item()
        running_kl += kl.item()
        running_ce += ce.item()

        # ── Log ─────────────────────────────────────────────────────────
        if step % LOG_EVERY == 0:
            avg_loss = running_loss / LOG_EVERY
            avg_kl = running_kl / LOG_EVERY
            avg_ce = running_ce / LOG_EVERY
            elapsed = time.time() - t_start
            steps_s = step / elapsed
            eta_s = (MAX_STEPS - step) / max(steps_s, 1e-6)
            eta_min = eta_s / 60
            print(
                f" step {step:>5}/{MAX_STEPS}"
                f" loss={avg_loss:.4f}"
                f" kl={avg_kl:.4f}"
                f" ce={avg_ce:.4f}"
                f" lr={scheduler.get_last_lr()[0]:.2e}"
                f" {steps_s:.2f} steps/s"
                f" ETA {eta_min:.1f}min"
            )
            running_loss = running_kl = running_ce = 0.0

        # ── Checkpoint ──────────────────────────────────────────────────
        if step % SAVE_EVERY == 0:
            ckpt_path = OUTPUT_DIR / f"student_step{step}.pt"
            torch.save(student.state_dict(), ckpt_path)
            print(f"\n ✓ Checkpoint salvo: {ckpt_path}\n")

    # ── Save final model ──────────────────────────────────────────────────
    final_path = OUTPUT_DIR / "student_final.pt"
    torch.save(student.state_dict(), final_path)

    # Save the student config so the weights can be reloaded later.
    with open(OUTPUT_DIR / "config_student.json", "w", encoding="utf-8") as f:
        json.dump(STUDENT_CONFIG, f, indent=2)

    total_time = (time.time() - t_start) / 60
    print(f"\n{'='*60}")
    print(f" Treino concluído em {total_time:.1f} minutos")
    print(f" Modelo salvo em: {final_path}")
    print(f"{'='*60}")


if __name__ == "__main__":
    train()