| import argparse |
| import math |
| import os |
| from collections import Counter |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from datasets import load_from_disk |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import LambdaLR |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| |
def modulate(x, shift, scale):
    """Apply an adaptive-LayerNorm style affine modulation to ``x``.

    ``shift`` and ``scale`` each gain a singleton axis at position 1 so they
    broadcast over that axis of ``x``; the result is ``x * (1 + scale) + shift``.
    """
    shift_b = shift.unsqueeze(1)
    scale_b = scale.unsqueeze(1)
    return shift_b + x * (scale_b + 1)
|
|
class TimestepEmbedder(nn.Module):
    """Map a per-example scalar time to a ``hidden_size`` conditioning vector.

    The raw time value goes straight through a two-layer MLP
    (Linear -> SiLU -> Linear); no sinusoidal features are used.
    """

    def __init__(self, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Embed ``t`` after appending a trailing feature axis of size 1."""
        t_col = t[..., None]
        return self.mlp(t_col)
|
|
class DiTBlock(nn.Module):
    """Transformer block with adaptive LayerNorm conditioning (DiT style).

    A conditioning vector ``c`` is projected to six per-channel tensors. They
    shift and scale the two affine-free pre-norms and gate the two residual
    branches (self-attention, then MLP).

    Args:
        hidden_size: channel width of the token stream and of ``c``.
        n_heads: number of attention heads; must divide ``hidden_size``
            (``nn.MultiheadAttention`` enforces this).
    """

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        # Affine-free norms: the per-example shift/scale come from adaLN_modulation.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Apply the block.

        Args:
            x: token features; with ``batch_first=True`` the attention expects
                (batch, seq, hidden_size).
            c: conditioning vectors; ``chunk(6, dim=1)`` followed by
                ``unsqueeze(1)`` implies (batch, hidden_size).

        Returns:
            A tensor with the same shape as ``x``.
        """
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=1)

        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
        # The attention weights are discarded. need_weights=False lets PyTorch
        # use its fused scaled_dot_product_attention path instead of
        # materialising and head-averaging the full attention matrix.
        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1, need_weights=False)
        x = x + gate_msa.unsqueeze(1) * attn_output

        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
        mlp_output = self.mlp(x_norm2)
        x = x + gate_mlp.unsqueeze(1) * mlp_output
        return x
|
|
class MDLM(nn.Module):
    """Time-conditioned transformer denoiser over a discrete vocabulary.

    Token ids are embedded and summed with a learned positional table, then
    passed through ``n_layers`` DiT blocks conditioned on a time embedding and
    projected to ``vocab_size`` logits. The embedding table has one extra row
    at index ``vocab_size`` (``mask_token_id``). The head never predicts that
    id, because it only emits ``vocab_size`` logits.

    Args:
        vocab_size: number of real (predictable) tokens.
        seq_len: maximum sequence length (size of the positional table).
        model_dim: hidden width.
        n_heads: attention heads per block.
        n_layers: number of DiT blocks.
    """

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        # Extra embedding row reserved for the mask token.
        self.mask_token_id = vocab_size
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        # NOTE(review): unit-normal init, and _init_weights does not touch raw
        # Parameters. At init this is ~50x the std-0.02 token embeddings. Kept
        # as-is to preserve training behaviour; consider scaling it down.
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        self.transformer_blocks = nn.ModuleList(
            [DiTBlock(model_dim, n_heads) for _ in range(n_layers)]
        )
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one submodule (used via ``self.apply``).

        Linear/Embedding weights are drawn from N(0, 0.02) and Linear biases
        are zeroed. LayerNorms get unit weight and zero bias when they have
        those parameters (affine-free norms have ``None`` and are skipped).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def forward(self, x, t):
        """Return next-token logits for every position.

        Args:
            x: integer token ids with the sequence on axis 1; presumably
                (batch, seq_len). TODO confirm.
            t: per-example times passed to the time embedder.

        Returns:
            Logits of shape (batch, seq_len, vocab_size) for a (batch, seq_len) input.

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        seq_len = x.shape[1]
        # Fail clearly instead of relying on a broadcast error. A
        # single-position table would otherwise broadcast silently.
        if seq_len > self.seq_len:
            raise ValueError(
                f"input sequence length {seq_len} exceeds the model's "
                f"maximum of {self.seq_len}"
            )
        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
        t_embed = self.time_embedder(t)
        for block in self.transformer_blocks:
            x_embed = block(x_embed, t_embed)
        x_embed = self.final_norm(x_embed)
        return self.lm_head(x_embed)
|
|
| |
def get_lr_scheduler(optimizer, warmup_steps, total_steps, lr_min, lr_max):
    """Build a linear-warmup + cosine-decay ``LambdaLR`` schedule.

    The lambda returns a multiplier relative to ``lr_max``, so the optimizer's
    base learning rate should be ``lr_max``. Learning rate per step:

    - steps ``[0, warmup_steps)``: linear ramp from ``lr_min`` toward ``lr_max``;
    - steps ``[warmup_steps, total_steps]``: half-cosine from ``lr_max`` down
      to ``lr_min``;
    - steps after ``total_steps``: held at ``lr_min``.

    Args:
        optimizer: optimizer whose param groups are scheduled.
        warmup_steps: number of warmup steps (0 disables warmup).
        total_steps: step at which the decay reaches ``lr_min``.
        lr_min: floor learning rate.
        lr_max: peak learning rate; must be non-zero.

    Returns:
        A ``LambdaLR`` to be stepped once per optimizer step.
    """
    lr_range = lr_max - lr_min

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            lr = lr_min + lr_range * (current_step / warmup_steps)
        else:
            decay_steps = total_steps - warmup_steps
            if decay_steps <= 0:
                progress = 1.0
            else:
                progress = (current_step - warmup_steps) / decay_steps
            # Clamp so stepping past total_steps holds lr_min instead of the
            # cosine climbing back toward lr_max.
            progress = min(progress, 1.0)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            lr = lr_min + lr_range * cosine_decay
        return lr / lr_max

    return LambdaLR(optimizer, lr_lambda)
|
|
| |
def train_one_epoch(model, dataloader, optimizer, scheduler, device, epoch, args):
    """Train ``model`` for one pass over ``dataloader``.

    For each batch, a target ``x_1`` and a source ``x_0`` are built:

    - If the dataset has ``input_ids_x1`` / ``input_ids_x0`` columns
      (rectified pairs), both come from those columns.
    - Otherwise ``x_1`` is ``input_ids`` and ``x_0`` is uniform random tokens
      in ``[0, model.vocab_size)``.

    A time ``t ~ U[0, 1)`` is drawn per example, and each position takes its
    ``x_1`` token with probability ``t`` (otherwise ``x_0``). The loss is
    cross-entropy against ``x_1`` at every position, with label smoothing
    from ``args.label_smoothing``.

    Args:
        model: module exposing ``vocab_size`` and ``forward(x_t, t)``.
        dataloader: yields dict batches; ``dataloader.dataset`` must expose
            ``column_names``.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch, after the optimizer.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index, used only for the progress-bar label.
        args: namespace providing ``label_smoothing``.

    Returns:
        Mean of the per-batch losses.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Train]")

    # Rectified datasets carry explicit (x_0, x_1) pairs; otherwise x_0 is noise.
    is_rectified_dataset = 'input_ids_x0' in dataloader.dataset.column_names

    for batch in progress_bar:
        optimizer.zero_grad()

        if is_rectified_dataset:
            x_0 = torch.tensor(batch['input_ids_x0']).to(device)
            x_1 = torch.tensor(batch['input_ids_x1']).to(device)
        else:
            x_1 = torch.tensor(batch['input_ids']).to(device)
            x_0 = torch.randint(0, model.vocab_size, x_1.shape, device=device)

        batch_size, _ = x_1.shape
        t_continuous = torch.rand(batch_size, device=device)
        # Each position keeps its target token with probability t.
        mask = torch.rand(x_1.shape, device=device) < t_continuous.view(-1, 1)
        x_t = torch.where(mask, x_1, x_0)

        logits = model(x_t, t_continuous)
        loss = F.cross_entropy(
            logits.view(-1, model.vocab_size),
            x_1.view(-1),
            label_smoothing=args.label_smoothing,
        )

        loss.backward()
        optimizer.step()
        scheduler.step()

        # One host sync per step instead of two.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        progress_bar.set_postfix(loss=loss_value, lr=scheduler.get_last_lr()[0])

    if num_batches == 0:
        raise ValueError("training dataloader yielded no batches")
    return total_loss / num_batches
|
|
def _estimate_total_correlation(logits, k):
    """Estimate the batch-mean KL between an empirical joint and the product of marginals.

    For each example, ``k`` sequences are sampled position-by-position from
    ``softmax(logits)``. The KL divergence is then computed between the
    empirical distribution of those sequences and the product of the
    per-position marginals.

    Args:
        logits: (batch, seq_len, vocab) model outputs. The rank is fixed by the
            three-way unpacking of ``logits.shape`` below.
        k: number of sequences sampled per example.

    Returns:
        The mean per-example estimate as a Python float (0.0 for an empty batch).
    """
    batch_size, seq_len, vocab_size = logits.shape
    p_marginal = F.softmax(logits, dim=-1)
    # k independent draws per position from its marginal: (batch, seq_len, k).
    samples = torch.multinomial(
        p_marginal.view(-1, vocab_size), k, replacement=True
    ).view(batch_size, seq_len, k)
    # Log-probability of each sampled sequence under the product of marginals,
    # in one gather instead of one device indexing op per token: (batch, k).
    log_prod = torch.log(p_marginal + 1e-9).gather(-1, samples).sum(dim=1)

    # One device->host transfer each; samples become (batch, k, seq_len) lists.
    sample_rows = samples.permute(0, 2, 1).tolist()
    log_prod_rows = log_prod.tolist()

    kl_divs = []
    for seqs, log_prods in zip(sample_rows, log_prod_rows):
        seq_tuples = [tuple(s) for s in seqs]
        counts = Counter(seq_tuples)
        n = len(seq_tuples)
        # Averaging over all n samples weights each distinct sequence by its
        # empirical frequency, i.e. sum_j p_j * (log p_j - log q_j).
        kl = sum(
            math.log(counts[s] / n + 1e-9) - log_q
            for s, log_q in zip(seq_tuples, log_prods)
        ) / n
        kl_divs.append(kl)
    return sum(kl_divs) / len(kl_divs) if kl_divs else 0.0


def validate(model, val_dataloader, device, epoch, args):
    """Evaluate ``model`` on the validation set.

    The corruption matches training: a fresh random ``t`` per example and,
    for non-rectified data, fresh random ``x_0`` tokens. Both come from the
    global RNG, so results differ between calls unless it is seeded.

    For the first ``args.tc_batches`` batches, a total-correlation style
    diagnostic is also computed with ``args.tc_k_samples`` samples per
    example; see ``_estimate_total_correlation``.

    Args:
        model: module exposing ``vocab_size`` and ``forward(x_t, t)``.
        val_dataloader: yields dict batches; its ``dataset`` must expose
            ``column_names``.
        device: device the batch tensors are moved to.
        epoch: zero-based epoch index, used only for the progress-bar label.
        args: namespace providing ``tc_batches`` and ``tc_k_samples``.

    Returns:
        ``(avg_val_nll, perplexity, avg_tc)`` as Python floats. ``avg_tc`` is
        0.0 when no TC batches ran, and ``perplexity`` is ``inf`` if
        ``exp`` overflows.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    model.eval()
    total_val_nll = 0.0
    total_tc = 0.0
    tc_batches = 0
    num_batches = 0
    progress_bar = tqdm(val_dataloader, desc=f"Epoch {epoch+1} [Val]")

    # Rectified datasets carry explicit (x_0, x_1) pairs; otherwise x_0 is noise.
    is_rectified_dataset = 'input_ids_x0' in val_dataloader.dataset.column_names

    with torch.no_grad():
        for i, batch in enumerate(progress_bar):
            if is_rectified_dataset:
                x_0 = torch.tensor(batch['input_ids_x0']).to(device)
                x_1 = torch.tensor(batch['input_ids_x1']).to(device)
            else:
                x_1 = torch.tensor(batch['input_ids']).to(device)
                x_0 = torch.randint(0, model.vocab_size, x_1.shape, device=device)

            batch_size, _ = x_1.shape
            t_continuous = torch.rand(batch_size, device=device)
            mask = torch.rand(x_1.shape, device=device) < t_continuous.view(-1, 1)
            x_t = torch.where(mask, x_1, x_0)

            logits = model(x_t, t_continuous)
            val_nll = F.cross_entropy(logits.view(-1, model.vocab_size), x_1.view(-1))
            total_val_nll += val_nll.item()
            num_batches += 1

            if i < args.tc_batches:
                total_tc += _estimate_total_correlation(logits, args.tc_k_samples)
                tc_batches += 1

    if num_batches == 0:
        raise ValueError("validation dataloader yielded no batches")
    avg_val_nll = total_val_nll / num_batches
    try:
        perplexity = math.exp(avg_val_nll)
    except OverflowError:
        # A diverged model should be reported, not crash validation.
        perplexity = float("inf")
    avg_tc = total_tc / tc_batches if tc_batches > 0 else 0.0
    return avg_val_nll, perplexity, avg_tc
|
|
| |
def main(args):
    """Train or fine-tune an MDLM, keeping the checkpoint with the best validation NLL.

    Side effects: rewrites ``args.checkpoint_dir`` to
    ``<checkpoint_dir>/<run_name>`` and creates that directory. Writes
    ``best.pt`` there whenever the validation NLL improves. When
    ``args.resume_from_checkpoint`` is set, only the model weights are loaded
    from it; the optimizer and scheduler start fresh.

    Raises:
        SystemExit: with status 1 if the resume checkpoint cannot be loaded.
    """
    run_name = f"rectify_version{args.version}_lr{args.learning_rate}_ls{args.label_smoothing}"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    args.checkpoint_dir = os.path.join(args.checkpoint_dir, run_name)
    print(f"Saving to {args.checkpoint_dir}")
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    print("Loading datasets...")
    train_dataset = load_from_disk(args.train_dataset_path)
    val_dataset = load_from_disk(args.val_dataset_path)
    # batch_size=None disables automatic batching: each dataset row is yielded
    # as one batch, so rows are presumably pre-batched. TODO confirm.
    # NOTE(review): shuffle=False also fixes the training batch order across
    # epochs; confirm that is intended.
    train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=None, shuffle=False)

    print("Initializing model...")
    model = MDLM(args.vocab_size, args.seq_len, args.model_dim, args.n_heads, args.n_layers).to(device)

    if args.resume_from_checkpoint:
        print(f"Loading weights from: {args.resume_from_checkpoint}")
        try:
            # SECURITY: weights_only=False runs the full pickle loader, which can
            # execute arbitrary code. Only load checkpoints from trusted sources.
            # (Checkpoints written below embed the argparse `args` object.)
            checkpoint = torch.load(args.resume_from_checkpoint, map_location=device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
        except Exception as e:
            print(f"Error loading checkpoint for resume: {e}")
            # Exit non-zero so wrappers and job schedulers see the failure.
            # A bare return would let the script report success.
            raise SystemExit(1) from e
        print("Model weights loaded successfully for fine-tuning.")

    print(f"Model initialized with {sum(p.numel() for p in model.parameters()):,} parameters.")

    optimizer = AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    num_training_steps = args.epochs * len(train_dataloader)
    # The first 10% of all optimizer steps are warmup.
    warmup_steps = int(num_training_steps * 0.1)
    scheduler = get_lr_scheduler(optimizer, warmup_steps, num_training_steps, args.min_learning_rate, args.learning_rate)

    best_val_nll = float('inf')
    best_checkpoint_path = os.path.join(args.checkpoint_dir, "best.pt")

    print("Starting training...")
    for epoch in range(args.epochs):
        train_loss = train_one_epoch(model, train_dataloader, optimizer, scheduler, device, epoch, args)
        val_nll, perplexity, tc_error = validate(model, val_dataloader, device, epoch, args)

        print(f"Epoch {epoch+1}/{args.epochs} -> Train Loss: {train_loss:.4f}, Val NLL: {val_nll:.4f}, Val PPL: {perplexity:.2f}, TC: {tc_error:.4f}")

        if val_nll < best_val_nll:
            best_val_nll = val_nll
            torch.save({
                'epoch': epoch + 1, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(),
                # float() so a device tensor is never pickled into the checkpoint.
                'val_nll': val_nll, 'tc_error': float(tc_error), 'args': args
            }, best_checkpoint_path)
            print(f"New best checkpoint saved to {best_checkpoint_path}")

    print("Training complete.")
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train or fine-tune a ReDi (MDLM) model.")

    # Data
    parser.add_argument("--train_dataset_path", type=str, required=True)
    parser.add_argument("--val_dataset_path", type=str, required=True)
    # Model
    parser.add_argument("--model_dim", type=int, default=1024)
    parser.add_argument("--n_heads", type=int, default=8)
    parser.add_argument("--n_layers", type=int, default=6)
    parser.add_argument("--vocab_size", type=int, default=24)
    parser.add_argument("--seq_len", type=int, default=100)
    # Optimization
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--min_learning_rate", type=float, default=1e-5)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    # NOTE(review): the int default renders as "ls0" in the run name, whereas
    # an explicit "--label_smoothing 0" renders as "ls0.0".
    parser.add_argument("--label_smoothing", type=float, default=0)
    # Total-correlation diagnostic computed during validation
    parser.add_argument("--tc_batches", type=int, default=5)
    parser.add_argument("--tc_k_samples", type=int, default=10)
    # Output / run bookkeeping
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    # argparse applies `type` only to string defaults. A string default keeps
    # args.version a str whether or not the flag is passed.
    parser.add_argument("--version", type=str, default="1", help="Rectification iteration number.")
    parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="Path to a checkpoint to load model weights from for fine-tuning.")

    args = parser.parse_args()
    main(args)
|
|