"""
Prototype LM for geometric simplex structures.

Requires geometricvocab's SimplexFactory for valid simplex representations;
without it, the simplex behavior will not learn.

Installation (Jupyter/Colab):

    !pip uninstall -qy geometricvocab
    !pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git

License: MIT
"""
import math
import json
import os
import time
from itertools import combinations
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

import tiktoken
from huggingface_hub import HfApi, create_repo

from geovocab2.shapes.factory.simplex_factory import SimplexFactory

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")


# --- Run configuration ---

HF_REPO = "AbstractPhil/ksimplex-llm-prototype"
RUN_NAME = f"run_{int(time.time())}"
CHECKPOINT_DIR = Path(f"./checkpoints/{RUN_NAME}")
TENSORBOARD_DIR = Path(f"./runs/{RUN_NAME}")

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
TENSORBOARD_DIR.mkdir(parents=True, exist_ok=True)


# --- Cayley-Menger validator ---

class CMValidator(nn.Module):
    """Pairwise squared distances and squared k-simplex volume via the
    Cayley-Menger determinant."""

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1  # a k-simplex has k + 1 vertices

        pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(pairs)
        self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))
        self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))

        # vol^2 = (-1)^(k+1) / (2^k (k!)^2) * det(CM)
        sign = (-1.0) ** (k + 1)
        fact = math.factorial(k)
        self._prefactor = sign / ((2.0 ** k) * (fact ** 2))

    def forward(self, verts):
        # Gram matrix -> pairwise squared distances
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        norms = torch.diagonal(gram, dim1=-2, dim2=-1)
        d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram
        d2_mat = F.relu(d2_mat)  # clamp tiny negatives from numerical error

        d2_pairs = d2_mat[..., self._pi, self._pj]

        # Bordered Cayley-Menger matrix
        shape = d2_mat.shape[:-2]
        V = d2_mat.shape[-1]
        cm = torch.zeros(*shape, V + 1, V + 1, device=d2_mat.device, dtype=d2_mat.dtype)
        cm[..., 0, 1:] = 1.0
        cm[..., 1:, 0] = 1.0
        cm[..., 1:, 1:] = d2_mat

        vol2 = self._prefactor * torch.linalg.det(cm)

        return d2_pairs, vol2
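

# A minimal smoke test for CMValidator (illustrative sketch; defined but never
# called during training, and the helper name is hypothetical). For a unit-edge
# triangle (k=2), the Cayley-Menger formula gives vol^2 = 3/16, i.e. area sqrt(3)/4.
def _demo_cm_validator():
    cm = CMValidator(k=2)
    verts = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.5, math.sqrt(3) / 2]])
    d2_pairs, vol2 = cm(verts)
    assert torch.allclose(d2_pairs, torch.ones(3), atol=1e-5)        # unit edges
    assert torch.allclose(vol2, torch.tensor(3.0 / 16.0), atol=1e-5)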


# --- Per-order simplex channel ---

class KSimplexChannel(nn.Module):
    """Deforms a regular k-simplex template per token and gates vertex
    features by the simplex's geometric validity."""

    BASE_DEFORM = 0.05  # deformation scale around the regular template

    def __init__(self, k, in_dim, edim, feat_dim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._feat_dim = feat_dim

        self._cm = CMValidator(k)
        self._geo_dim = self._cm._npairs + 1  # pairwise d^2 plus vol^2

        factory = SimplexFactory(k=k, embed_dim=edim, method="regular", scale=1.0)
        self.register_buffer('_template', factory.build_torch(dtype=torch.float32))

        self._to_coords = nn.Linear(in_dim, self._nv * edim)
        self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)

        self._geo_gate = nn.Sequential(
            nn.Linear(self._geo_dim, feat_dim),
            nn.Sigmoid(),
        )

        self._out_dim = feat_dim + self._geo_dim

    @property
    def out_dim(self):
        return self._out_dim

    def forward(self, x):
        coords = self._to_coords(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * coords

        vert_feats = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))

        d2, vol2 = self._cm(verts)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)

        gate = self._geo_gate(geo)
        # Steep sigmoid approximates a step: ~1 for vol^2 > 0, ~0 otherwise
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)

        feat_agg = vert_feats.mean(dim=-2) * gate * validity

        out = torch.cat([feat_agg, geo], dim=-1)

        return out, vol2, d2.mean(dim=-1)
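

# Shape sketch for a single channel (hypothetical helper, not called during
# training; assumes geometricvocab's SimplexFactory is installed). With k=2 and
# feat_dim=16, the output carries 16 gated features + 3 squared edges + vol^2.
def _demo_ksimplex_channel():
    ch = KSimplexChannel(k=2, in_dim=32, edim=8, feat_dim=16)
    x = torch.randn(4, 32)
    out, vol2, d2_mean = ch(x)
    assert out.shape == (4, 20)  # feat_dim + n_pairs + 1 = 16 + 3 + 1
    assert vol2.shape == (4,) and d2_mean.shape == (4,)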


# --- Token -> stacked k-channels ---

class TokenToKChannels(nn.Module):
    """Projects token embeddings into one simplex channel per order k=1..depth,
    padding each channel's output to a common width."""

    def __init__(self, embed_dim, depth, edim, feat_dim, hidden=256):
        super().__init__()
        self._depth = depth

        self._proj = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )

        self._k_encoders = nn.ModuleList([
            KSimplexChannel(k=k + 1, in_dim=hidden, edim=edim, feat_dim=feat_dim)
            for k in range(depth)
        ])

        self._k_out_dims = [enc.out_dim for enc in self._k_encoders]
        self._max_out_dim = max(self._k_out_dims)

    def forward(self, x):
        h = self._proj(x)

        out_list, vol2_list, d2_list = [], [], []

        for enc in self._k_encoders:
            out, vol2, d2_mean = enc(h)

            # Right-pad so all channels stack to the same width
            pad_size = self._max_out_dim - out.shape[-1]
            if pad_size > 0:
                out = F.pad(out, (0, pad_size))

            out_list.append(out)
            vol2_list.append(vol2)
            d2_list.append(d2_mean)

        k_channels = torch.stack(out_list, dim=-2)
        vol2 = torch.stack(vol2_list, dim=-1)
        d2_mean = torch.stack(d2_list, dim=-1)

        return k_channels, vol2, d2_mean
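

# Stacking sketch (hypothetical helper, never called by the training loop):
# with depth=3 and feat_dim=16, the per-order channel widths are 18/20/23,
# so every channel is right-padded to 23 before stacking.
def _demo_token_to_k_channels():
    mod = TokenToKChannels(embed_dim=64, depth=3, edim=8, feat_dim=16)
    x = torch.randn(2, 5, 64)           # (batch, seq, embed)
    k_ch, vol2, d2 = mod(x)
    assert k_ch.shape == (2, 5, 3, 23)  # (B, T, K, max_out_dim)
    assert vol2.shape == (2, 5, 3)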


# --- Attention across k-channels (per token) ---

class KChannelCrossAttention(nn.Module):
    """Self-attention over the K channel axis, applied independently per token."""

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):
        super().__init__()
        self._depth = depth
        self._feat_dim = feat_dim
        self._num_heads = num_heads
        self._head_dim = feat_dim // num_heads

        self._norm_q = nn.LayerNorm(feat_dim)
        self._norm_kv = nn.LayerNorm(feat_dim)

        self._to_q = nn.Linear(feat_dim, feat_dim)
        self._to_k = nn.Linear(feat_dim, feat_dim)
        self._to_v = nn.Linear(feat_dim, feat_dim)
        self._out = nn.Linear(feat_dim, feat_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

    def forward(self, x):
        B, T, K, Fd = x.shape  # Fd, not F, to avoid shadowing torch.nn.functional

        x_flat = x.view(B * T, K, Fd)

        q = self._to_q(self._norm_q(x_flat))
        k = self._to_k(self._norm_kv(x_flat))
        v = self._to_v(self._norm_kv(x_flat))

        q = q.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        k = k.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)
        v = v.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B * T, K, Fd)
        out = self._out(out)
        out = self._drop(out)

        return x + out.view(B, T, K, Fd)
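

# Per-token locality probe (hypothetical helper): channel attention mixes only
# across K within each token, so editing one token must leave the others unchanged.
def _demo_kchannel_attention():
    attn = KChannelCrossAttention(depth=3, feat_dim=16, num_heads=4)
    attn.eval()  # disable dropout so outputs are deterministic
    x = torch.randn(1, 6, 3, 16)
    x2 = x.clone()
    x2[:, 2] += 1.0
    y, y2 = attn(x), attn(x2)
    assert torch.allclose(y[:, :2], y2[:, :2]) and torch.allclose(y[:, 3:], y2[:, 3:])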


# --- Causal attention across the sequence ---

class CausalSequenceAttention(nn.Module):
    """Standard causal self-attention over time, run on the flattened
    (K * feat_dim) channel stack."""

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1, max_seq_len=2048):
        super().__init__()
        self._num_heads = num_heads

        total_dim = depth * feat_dim
        self._head_dim = total_dim // num_heads

        self._norm = nn.LayerNorm(total_dim)
        self._to_qkv = nn.Linear(total_dim, 3 * total_dim)
        self._out = nn.Linear(total_dim, total_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

        self.register_buffer(
            '_causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x):
        B, T, K, Fd = x.shape

        x_flat = x.view(B, T, K * Fd)
        x_norm = self._norm(x_flat)

        qkv = self._to_qkv(x_norm).chunk(3, dim=-1)
        q, k, v = [t.view(B, T, self._num_heads, self._head_dim).transpose(1, 2) for t in qkv]

        attn = (q @ k.transpose(-2, -1)) * self._scale

        mask = self._causal_mask[:T, :T]
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, K * Fd)
        out = self._out(out)
        out = self._drop(out)

        return x + out.view(B, T, K, Fd)
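

# Causality probe for the sequence attention alone (hypothetical helper,
# complementing the full-model sanity check below): perturbing a late position
# must not change earlier outputs.
def _demo_causal_attention():
    attn = CausalSequenceAttention(depth=2, feat_dim=16, num_heads=4)
    attn.eval()  # disable dropout so outputs are deterministic
    x = torch.randn(1, 8, 2, 16)
    x2 = x.clone()
    x2[:, -1] += 1.0
    y, y2 = attn(x), attn(x2)
    assert torch.allclose(y[:, :-1], y2[:, :-1], atol=1e-6)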


# --- Transformer block: channel attention + causal attention + MLP ---

class GeoBlock(nn.Module):
    def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1, max_seq_len=2048):
        super().__init__()

        self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)
        self._seq_attn = CausalSequenceAttention(depth, feat_dim, num_heads, dropout, max_seq_len)

        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._mlp = nn.Sequential(
            nn.Linear(total_dim, int(total_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(total_dim * mlp_ratio), total_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        B, T, K, Fd = x.shape

        x = self._k_attn(x)    # mix across channels (per token)
        x = self._seq_attn(x)  # mix across time (causal)

        x_flat = x.view(B, T, K * Fd)
        x_flat = x_flat + self._mlp(self._norm(x_flat))
        x = x_flat.view(B, T, K, Fd)

        return x
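

# Shape-preservation sketch for the block (hypothetical helper): GeoBlock is a
# residual unit, so the (B, T, K, feat_dim) layout passes through unchanged.
def _demo_geo_block():
    blk = GeoBlock(depth=4, feat_dim=32, num_heads=4)
    x = torch.randn(2, 10, 4, 32)
    y = blk(x)
    assert y.shape == x.shape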


# --- Full model ---

class GeometricLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        max_seq_len=512,
        embed_dim=256,
        depth=4,
        edim=16,
        feat_dim=64,
        hidden=256,
        num_heads=8,
        num_blocks=8,
        dropout=0.1,
    ):
        super().__init__()

        self._vocab_size = vocab_size
        self._max_seq_len = max_seq_len
        self._depth = depth
        self._feat_dim = feat_dim

        self._tok_embed = nn.Embedding(vocab_size, embed_dim)
        self._pos_embed = nn.Embedding(max_seq_len, embed_dim)

        self._tok_to_k = TokenToKChannels(embed_dim, depth, edim, feat_dim, hidden)
        self._max_out_dim = self._tok_to_k._max_out_dim

        self._proj = nn.Linear(self._max_out_dim, feat_dim)

        self._blocks = nn.ModuleList([
            GeoBlock(depth, feat_dim, num_heads, dropout=dropout, max_seq_len=max_seq_len)
            for _ in range(num_blocks)
        ])

        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._lm_head = nn.Linear(total_dim, vocab_size, bias=False)

        self._config = {
            'vocab_size': vocab_size,
            'max_seq_len': max_seq_len,
            'embed_dim': embed_dim,
            'depth': depth,
            'edim': edim,
            'feat_dim': feat_dim,
            'hidden': hidden,
            'num_heads': num_heads,
            'num_blocks': num_blocks,
            'dropout': dropout,
            'total_dim': total_dim,
        }

    def forward(self, tokens):
        B, T = tokens.shape

        pos = torch.arange(T, device=tokens.device)
        x = self._tok_embed(tokens) + self._pos_embed(pos)

        k_channels, vol2, d2_mean = self._tok_to_k(x)
        k_channels = self._proj(k_channels)

        for blk in self._blocks:
            k_channels = blk(k_channels)

        out = k_channels.flatten(-2)
        logits = self._lm_head(self._norm(out))

        return logits, {'vol2': vol2, 'd2_mean': d2_mean}

    @torch.no_grad()
    def generate(self, prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=50):
        self.eval()
        tokens = prompt_tokens.clone()

        for _ in range(max_new_tokens):
            ctx = tokens[:, -self._max_seq_len:]
            logits, _ = self(ctx)
            logits = logits[:, -1, :] / temperature

            if top_k > 0:
                # keep only the top_k logits; everything else -> -inf
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_tok], dim=1)

        return tokens
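

# End-to-end usage sketch (hypothetical helper with a tiny CPU-sized config):
# forward returns logits plus per-channel geometry diagnostics, and generate
# appends exactly max_new_tokens tokens.
def _demo_geometric_lm():
    lm = GeometricLM(vocab_size=100, max_seq_len=32, embed_dim=32, depth=2,
                     edim=8, feat_dim=16, hidden=32, num_heads=4, num_blocks=1)
    tokens = torch.randint(0, 100, (1, 8))
    logits, info = lm(tokens)
    assert logits.shape == (1, 8, 100) and info['vol2'].shape == (1, 8, 2)
    out = lm.generate(tokens, max_new_tokens=4)
    assert out.shape == (1, 12)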


# --- Dataset: overlapping next-token windows ---

class TokenizedDataset(Dataset):
    def __init__(self, tokens, seq_len, stride=None):
        self._tokens = tokens
        self._seq_len = seq_len
        self._stride = stride if stride else seq_len // 2

    def __len__(self):
        # number of full (seq_len + 1)-token windows at the given stride
        return max(0, (len(self._tokens) - self._seq_len - 1) // self._stride + 1)

    def __getitem__(self, idx):
        start = idx * self._stride
        chunk = self._tokens[start:start + self._seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
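

# A tiny offset sketch (hypothetical helper, mirroring sanity-check Test 3
# below): with stride == seq_len the windows tile the stream, and y is always
# x shifted by one token.
def _demo_dataset_windows():
    ds = TokenizedDataset(list(range(20)), seq_len=4, stride=4)
    x, y = ds[1]
    assert x.tolist() == [4, 5, 6, 7]
    assert y.tolist() == [5, 6, 7, 8]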


# --- Loss and metrics ---

def lm_loss(logits, targets, info, ce_weight=1.0, validity_weight=0.1):
    B, T, V = logits.shape
    ce = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))
    # penalize negative squared volumes (degenerate simplices)
    validity = F.relu(-info['vol2']).mean()
    total = ce_weight * ce + validity_weight * validity
    return total, ce, validity


@torch.no_grad()
def compute_metrics(info, depth):
    vol2 = info['vol2']
    d2_mean = info['d2_mean']

    m = {'valid_rate': (vol2 > 0).float().mean().item()}
    for k in range(depth):
        m[f'k{k+1}_valid'] = (vol2[..., k] > 0).float().mean().item()
        m[f'k{k+1}_vol2'] = vol2[..., k].mean().item()
        m[f'k{k+1}_d2'] = d2_mean[..., k].mean().item()
    return m


# --- Sanity checks ---

@torch.no_grad()
def sanity_check(model, enc, device):
    """Verify there is no information leak: untrained CE near uniform,
    causal masking intact, and correct next-token dataset offsets."""
    print("\n" + "=" * 60)
    print("SANITY CHECK")
    print("=" * 60)

    model.eval()

    # Test 1: on random tokens an untrained model should sit near uniform CE
    random_tokens = torch.randint(0, 1000, (4, 256), device=device)
    logits, _ = model(random_tokens)
    random_targets = torch.randint(0, enc.n_vocab, (4, 256), device=device)
    ce = F.cross_entropy(logits.view(-1, enc.n_vocab), random_targets.view(-1))

    expected_ce = math.log(enc.n_vocab)
    print("Test 1 - Random input:")
    print(f"  CE: {ce.item():.2f} (expected ~{expected_ce:.2f})")
    print(f"  PPL: {math.exp(min(ce.item(), 20)):.0f} (expected ~{enc.n_vocab})")

    test1_pass = ce.item() > 8.0
    print(f"  Status: {'✓ PASS' if test1_pass else '✗ FAIL'}")

    # Test 2: changing late positions must not affect earlier logits
    tokens1 = torch.zeros(1, 256, dtype=torch.long, device=device)
    tokens2 = torch.zeros(1, 256, dtype=torch.long, device=device)
    tokens2[0, 128:] = 999

    logits1, _ = model(tokens1)
    logits2, _ = model(tokens2)

    diff_early = (logits1[0, :128] - logits2[0, :128]).abs().max().item()
    diff_late = (logits1[0, 128:] - logits2[0, 128:]).abs().max().item()

    print("\nTest 2 - Causal mask:")
    print(f"  Early positions diff: {diff_early:.6f} (should be ~0)")
    print(f"  Late positions diff: {diff_late:.6f} (should be >0)")

    test2_pass = diff_early < 1e-5 and diff_late > 1e-3
    print(f"  Status: {'✓ PASS' if test2_pass else '✗ FAIL'}")

    # Test 3: dataset targets are inputs shifted by one
    print("\nTest 3 - Dataset offset:")
    test_tokens = list(range(100))
    ds = TokenizedDataset(test_tokens, seq_len=10)
    x, y = ds[0]
    offset_correct = all(x[i] + 1 == y[i] for i in range(len(x)))
    print(f"  x: {x[:5].tolist()}...")
    print(f"  y: {y[:5].tolist()}...")
    print(f"  Offset correct: {'✓ PASS' if offset_correct else '✗ FAIL'}")

    print("=" * 60)

    all_pass = test1_pass and test2_pass and offset_correct
    if not all_pass:
        print("⚠️ WARNING: Some sanity checks failed!")
    else:
        print("✓ All sanity checks passed!")

    print("=" * 60 + "\n")

    model.train()
    return all_pass


# --- Generation prompts and sampling ---

PROMPTS = [
    "ROMEO: ",
    "JULIET: ",
    "To be or not to be",
    "The king ",
    "Once upon a time",
    "First Citizen:\n",
    "What light through yonder",
    "Friends, Romans, countrymen",
    "Now is the winter of",
    "All the world's a stage",
]


@torch.no_grad()
def generate_samples(model, enc, device, epoch, writer=None):
    """Generate samples from all prompts."""
    model.eval()

    samples = []
    print(f"\n{'='*60}")
    print(f"GENERATION SAMPLES - Epoch {epoch}")
    print(f"{'='*60}")

    for i, prompt in enumerate(PROMPTS):
        prompt_tokens = torch.tensor([enc.encode(prompt)], device=device)

        out_tokens = model.generate(
            prompt_tokens,
            max_new_tokens=100,
            temperature=0.8,
            top_k=50
        )

        generated = enc.decode(out_tokens[0].tolist())
        samples.append({'prompt': prompt, 'generated': generated})

        print(f"\n--- Prompt {i+1}: '{prompt.strip()}' ---")
        print(generated[:300])
        if len(generated) > 300:
            print("...")

    print(f"{'='*60}\n")

    # Log samples to TensorBoard as markdown text
    if writer:
        sample_text = "\n\n".join([
            f"**Prompt:** {s['prompt']}\n**Generated:**\n{s['generated'][:500]}"
            for s in samples
        ])
        writer.add_text("samples/generated", sample_text, epoch)

    model.train()
    return samples


# --- Checkpointing and HuggingFace upload ---

def save_checkpoint(model, optimizer, scheduler, epoch, config, metrics, checkpoint_dir):
    """Save checkpoint locally."""
    checkpoint = {
        'epoch': epoch,
        # unwrap torch.compile's wrapper if present
        'model_state_dict': model._orig_mod.state_dict() if hasattr(model, '_orig_mod') else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'metrics': metrics,
    }

    path = checkpoint_dir / f"checkpoint_epoch_{epoch:03d}.pt"
    torch.save(checkpoint, path)

    # also keep a rolling "latest" copy
    torch.save(checkpoint, checkpoint_dir / "checkpoint_latest.pt")

    with open(checkpoint_dir / "config.json", 'w') as f:
        json.dump(config, f, indent=2)

    print(f"Saved checkpoint: {path}")
    return path


def upload_to_hf(checkpoint_dir, repo_id, epoch):
    """Upload checkpoint directory to HuggingFace."""
    try:
        api = HfApi()

        # create the repo if it doesn't exist yet
        try:
            create_repo(repo_id, exist_ok=True, repo_type="model")
        except Exception as e:
            print(f"Repo creation note: {e}")

        api.upload_folder(
            folder_path=str(checkpoint_dir),
            repo_id=repo_id,
            commit_message=f"Epoch {epoch} checkpoint",
        )

        print(f"Uploaded to HuggingFace: {repo_id}")
        return True
    except Exception as e:
        print(f"HuggingFace upload failed: {e}")
        return False


# --- Training ---

def train():
    import urllib.request

    writer = SummaryWriter(log_dir=str(TENSORBOARD_DIR))
    print(f"TensorBoard logs: {TENSORBOARD_DIR}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"HuggingFace repo: {HF_REPO}")

    # Data: tiny Shakespeare
    data_path = './data/shakespeare.txt'
    if not os.path.exists(data_path):
        os.makedirs('./data', exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print("Downloading Shakespeare...")
        urllib.request.urlretrieve(url, data_path)

    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    print(f"Text length: {len(text):,} chars")

    # Tokenizer
    print("Loading tokenizer...")
    enc = tiktoken.get_encoding("gpt2")

    print("Tokenizing...")
    tokens = enc.encode(text)
    print(f"Token count: {len(tokens):,}")
    print(f"Vocab size: {enc.n_vocab:,}")
    print(f"Compression ratio: {len(text) / len(tokens):.2f}x")

    # Train/val split and loaders
    seq_len = 256
    split_idx = int(len(tokens) * 0.9)
    train_tokens = tokens[:split_idx]
    val_tokens = tokens[split_idx:]

    train_ds = TokenizedDataset(train_tokens, seq_len)
    val_ds = TokenizedDataset(val_tokens, seq_len)

    batch_size = 12
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=8, pin_memory=True, persistent_workers=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                        num_workers=4, pin_memory=True, persistent_workers=True)

    print(f"Train sequences: {len(train_ds):,} ({len(train_dl)} batches)")
    print(f"Val sequences: {len(val_ds):,} ({len(val_dl)} batches)")

    model_config = {
        'vocab_size': enc.n_vocab,
        'max_seq_len': seq_len,
        'embed_dim': 384,
        'depth': 4,
        'edim': 16,
        'feat_dim': 96,
        'hidden': 384,
        'num_heads': 8,
        'num_blocks': 8,
        'dropout': 0.1,
    }

    train_config = {
        'batch_size': batch_size,
        'seq_len': seq_len,
        'lr': 3e-4,
        'weight_decay': 0.1,
        'num_epochs': 14,
        'grad_clip': 1.0,
        'ce_weight': 1.0,
        'validity_weight': 0.1,
    }

    full_config = {
        'model': model_config,
        'training': train_config,
        'data': {
            'train_tokens': len(train_tokens),
            'val_tokens': len(val_tokens),
            'vocab_size': enc.n_vocab,
        },
        'run_name': RUN_NAME,
    }

    with open(CHECKPOINT_DIR / "config.json", 'w') as f:
        json.dump(full_config, f, indent=2)
| print("\nBuilding model...") |
| model = GeometricLM(**model_config).to(device) |
| |
| print(f"\nConfig:") |
| for k, v in model._config.items(): |
| print(f" {k}: {v}") |
| |
| params = sum(p.numel() for p in model.parameters()) |
| print(f" params: {params:,}") |
| full_config['model']['params'] = params |
| |
| |
| sanity_check(model, enc, device) |
| |
| print("\nCompiling...") |
| |
| |
| |
| opt = torch.optim.AdamW( |
| model.parameters(), |
| lr=train_config['lr'], |
| weight_decay=train_config['weight_decay'] |
| ) |
| sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=train_config['num_epochs']) |
| |
| |
| |
| |
| best_val = float('inf') |
| best_ppl = float('inf') |
| global_step = 0 |
| |
| print("\nTraining...") |
| print("=" * 120) |

    epoch_pbar = tqdm(range(train_config['num_epochs']), desc="Epochs", position=0)

    for ep in epoch_pbar:
        epoch_start = time.time()

        # Train
        model.train()
        ce_sum, val_sum, n = 0, 0, 0

        train_pbar = tqdm(train_dl, desc=f"Train {ep+1}", leave=False, position=1)
        for batch_idx, (x, y) in enumerate(train_pbar):
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            logits, info = model(x)
            loss, ce, val = lm_loss(
                logits, y, info,
                ce_weight=train_config['ce_weight'],
                validity_weight=train_config['validity_weight']
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
            opt.step()

            ce_sum += ce.item() * x.size(0)
            val_sum += val.item() * x.size(0)
            n += x.size(0)

            # Batch-level TensorBoard logging
            if global_step % 100 == 0:
                writer.add_scalar("train/ce_batch", ce.item(), global_step)
                writer.add_scalar("train/ppl_batch", math.exp(min(ce.item(), 10)), global_step)
                writer.add_scalar("train/validity_batch", val.item(), global_step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], global_step)

            global_step += 1

            train_pbar.set_postfix({
                'CE': f'{ce.item():.3f}',
                'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
            })

        tr_ce = ce_sum / n
        tr_ppl = math.exp(min(tr_ce, 10))
        tr_val = val_sum / n

        # Validate
        model.eval()
        ce_sum, n = 0, 0
        metrics_agg = []

        val_pbar = tqdm(val_dl, desc=f"Val {ep+1}", leave=False, position=1)
        with torch.no_grad():
            for x, y in val_pbar:
                x, y = x.to(device), y.to(device)
                logits, info = model(x)
                _, ce, _ = lm_loss(logits, y, info)
                ce_sum += ce.item() * x.size(0)
                n += x.size(0)
                metrics_agg.append(compute_metrics(info, model._config['depth']))

                val_pbar.set_postfix({
                    'CE': f'{ce.item():.3f}',
                    'PPL': f'{math.exp(min(ce.item(), 10)):.1f}'
                })

        va_ce = ce_sum / n
        va_ppl = math.exp(min(va_ce, 10))

        sched.step()

        if va_ce < best_val:
            best_val = va_ce
            best_ppl = va_ppl

        # Average geometry metrics over validation batches
        m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}

        epoch_time = time.time() - epoch_start

        # Epoch-level TensorBoard logging
        writer.add_scalar("epoch/train_ce", tr_ce, ep)
        writer.add_scalar("epoch/train_ppl", tr_ppl, ep)
        writer.add_scalar("epoch/val_ce", va_ce, ep)
        writer.add_scalar("epoch/val_ppl", va_ppl, ep)
        writer.add_scalar("epoch/best_ppl", best_ppl, ep)
        writer.add_scalar("epoch/validity_loss", tr_val, ep)
        writer.add_scalar("epoch/time", epoch_time, ep)

        for k in range(model._config['depth']):
            writer.add_scalar(f"geometry/k{k+1}_valid", m[f'k{k+1}_valid'], ep)
            writer.add_scalar(f"geometry/k{k+1}_vol2", m[f'k{k+1}_vol2'], ep)
            writer.add_scalar(f"geometry/k{k+1}_d2", m[f'k{k+1}_d2'], ep)

        writer.add_scalar("geometry/valid_rate", m['valid_rate'], ep)

        epoch_pbar.set_postfix({
            'TrPPL': f'{tr_ppl:.1f}',
            'VaPPL': f'{va_ppl:.1f}',
            'Best': f'{best_ppl:.1f}',
            'Valid': f"{m['valid_rate']:.0%}"
        })

        tqdm.write(
            f"\nEp {ep+1:3d} | TrCE {tr_ce:.4f} | VaCE {va_ce:.4f} | "
            f"TrPPL {tr_ppl:7.2f} | VaPPL {va_ppl:7.2f} | BestPPL {best_ppl:.2f} | "
            f"Time {epoch_time:.1f}s"
        )
        tqdm.write("        " + " | ".join(
            f"k{k+1} {m[f'k{k+1}_valid']:5.1%} vol²={m[f'k{k+1}_vol2']:.2e}"
            for k in range(model._config['depth'])
        ))

        # Periodic generation samples
        if ep % 25 == 0 or ep == train_config['num_epochs'] - 1:
            samples = generate_samples(model, enc, device, ep + 1, writer)

            with open(CHECKPOINT_DIR / f"samples_epoch_{ep+1:03d}.json", 'w') as f:
                json.dump(samples, f, indent=2)

        metrics = {
            'epoch': ep + 1,
            'train_ce': tr_ce,
            'train_ppl': tr_ppl,
            'val_ce': va_ce,
            'val_ppl': va_ppl,
            'best_ppl': best_ppl,
            'geometry': m,
        }

        # Checkpoint every other epoch and at the end
        if ep % 2 == 0 or ep == train_config['num_epochs'] - 1:
            save_checkpoint(model, opt, sched, ep + 1, full_config, metrics, CHECKPOINT_DIR)

        # Upload once, after the final epoch
        if ep == train_config['num_epochs'] - 1:
            upload_to_hf(CHECKPOINT_DIR, HF_REPO, ep + 1)

    writer.close()

    print("\n" + "=" * 120)
    print("Training complete!")
    print(f"Best val CE: {best_val:.4f}, PPL: {best_ppl:.2f}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"TensorBoard: {TENSORBOARD_DIR}")
    print(f"HuggingFace: https://huggingface.co/{HF_REPO}")
    print("=" * 120)

    return model, enc


if __name__ == "__main__":
    model, tokenizer = train()