| """Unified multi-modal ARB pure-ternary pre-trainer. |
| |
| Supports text, code, image, audio, and video modalities with weighted mixing, |
| checkpoint resume, and packed ternary state updates. Core pretraining freezes |
| all IEEE-float parameters; LoRA/AdamW paths live under ``training/finetuning``. |
| |
| Usage: |
| # Phase 1a — Text pre-training smoke test (100M tokens on RTX 6000 Pro) |
| python training/pretrain.py --text-data training/data/tinyshakespeare.txt \\ |
| --text-weight 1.0 --steps 50000 --batch 8 --ctx 1024 |
| |
| # Phase 1b — Full text + code pre-training |
| python training/pretrain.py --text-weight 0.95 --code-weight 0.05 \\ |
| --steps 1000000 --batch 16 --ctx 2048 |
| |
| # Phase 2 — Add vision (freeze text, train vision adapters) |
| python training/pretrain.py --resume models/checkpoints/phase1b/best.pt \\ |
| --image-weight 0.3 --text-weight 1.0 |
| |
| # Phase 3 — Add audio |
| python training/pretrain.py --resume models/checkpoints/phase2/best.pt \\ |
| --audio-weight 0.2 --text-weight 1.0 |
| |
| # Phase 4 — Add video |
| python training/pretrain.py --resume models/checkpoints/phase3/best.pt \\ |
| --video-weight 0.1 --text-weight 1.0 |
| |
| # Smoke test (1 step, CPU) |
| python training/pretrain.py --steps 1 --batch 1 --ctx 4 --cpu --no-save |
| """ |
| import argparse, os, random, sys, time |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| from torch.utils.tensorboard import SummaryWriter |
| from tqdm import tqdm |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) |
| from arbitor import ARBModel |
| from arbitor.config import CTX |
| from arbitor.kernel.ternary_audit import audit_model, format_audit, freeze_float_parameters, trainable_parameters |
| from arbitor.kernel.ternary_scale import TScaleType |
| from training.data import ( |
| FineWebStream, FineWebConfig, |
| StarCoderStream, StarCoderConfig, |
| CC12MStream, CC12MConfig, |
| LibriSpeechStream, LibriSpeechConfig, |
| WebVidStream, WebVidConfig, |
| ) |
|
|
|
|
@dataclass
class TrainConfig:
    """Settings for one pretraining run.

    ``__main__`` builds this from ``parse_args()`` via ``TrainConfig(**vars(args))``,
    so every CLI option's ``dest`` must match a field name here.
    """

    # --- Schedule / shapes ---
    steps: int = 5000  # train() iterates over range(start_step, steps)
    batch: int = 8  # text/code batch size; image/audio streams use batch // 2, video batch // 4 (min 1)
    ctx: int = min(CTX, 1024)  # sequence length handed to the text/code streams
    accum: int = 1  # loop steps accumulated before each _ternary_update_memory call
    tscale_type: str = "T32"  # looked up (upper-cased) as an attribute of TScaleType
    backend: str = "triton"  # exported as ARB_TERNARY_BACKEND before the model is built


    # --- Architecture switches ---
    # NOTE(review): the freeze_* fields are not read by the code visible here; their CLI
    # flags are hidden from --help (help=argparse.SUPPRESS). Confirm whether they are still needed.
    freeze_text: bool = False
    freeze_vision: bool = False
    freeze_audio: bool = False
    freeze_video: bool = False
    enable_vq: bool = True
    enable_graph: bool = True
    enable_moe: bool = True
    enable_attention: bool = True  # build_model only enables attention when graph and VQ are also on
    enable_output_router: bool = False


    # --- Modality mixing weights: a weight > 0 activates that stream and its model branch ---
    text_weight: float = 1.0
    code_weight: float = 0.0
    image_weight: float = 0.0
    audio_weight: float = 0.0
    video_weight: float = 0.0


    # --- Data, output and bookkeeping ---
    text_data: Optional[str] = None  # local .txt/.pt file; when set, replaces the FineWeb text stream
    data_dir: str = "training/data"  # NOTE(review): not read by the code visible here
    out_dir: str = "models/checkpoints"  # run artifacts go to out_dir / run
    run: str = "pretrain"
    resume: Optional[str] = None  # checkpoint file or directory to load before training
    no_save: bool = False  # when True, save_checkpoint is a no-op
    save_interval: int = 5000  # step interval for latest.pt (0 disables)
    eval_interval: int = 500  # step interval for best.pt checks (0 disables)
    log_interval: int = 10  # step interval for console/TensorBoard logging (0 disables)
    seed: int = 42  # seeds both torch and Python's random (modality sampling)
    cpu: bool = False  # force CPU even when CUDA is available


    max_moe_iters: int = 4  # forwarded to ARBModel(max_moe_iters=...)
|
|
|
|
class LocalByteStream:
    """Small local byte stream for smoke tests and phase-1 text bootstrap.

    Reads either a raw file, where every byte becomes one token in ``[0, 255]``, or
    a ``.pt`` file holding a tensor of token ids. It then yields random contiguous
    windows of length ``ctx`` forever.
    """

    def __init__(self, path: str, ctx: int, batch_size: int):
        self.path = Path(path)
        self.ctx = ctx
        self.batch_size = batch_size

    def _load(self) -> torch.Tensor:
        """Return the whole corpus as a flat int64 CPU tensor.

        Raises:
            FileNotFoundError: if ``self.path`` does not exist.
            ValueError: if the corpus does not hold more than ``ctx + 1`` tokens.
        """
        if not self.path.exists():
            raise FileNotFoundError(f"Local text data not found: {self.path}")
        if self.path.suffix == ".pt":
            # Load straight to CPU. Flatten so that a saved (rows, cols) token tensor
            # is treated as one stream instead of producing 3-D batches.
            data = torch.load(self.path, map_location="cpu", weights_only=True).long().reshape(-1)
        else:
            raw = self.path.read_bytes()
            # Decode in one C-level pass. torch.tensor(list(raw)) creates one Python int
            # per byte, which is slow and memory-hungry for large corpora.
            # bytearray gives frombuffer a writable buffer; frombuffer rejects empty
            # buffers, so an empty file becomes an empty tensor and hits the check below.
            if raw:
                data = torch.frombuffer(bytearray(raw), dtype=torch.uint8).long()
            else:
                data = torch.empty(0, dtype=torch.long)
        if data.numel() <= self.ctx + 1:
            raise ValueError(f"Local text data has {data.numel()} tokens but ctx={self.ctx}")
        return data

    def batches(self):
        """Yield ``(x, x[:, 3:])`` pairs of random windows, forever.

        ``x`` has shape ``(batch_size, ctx)``. The second element is ``x`` with its
        first three positions dropped (made contiguous).
        """
        data = self._load()
        offsets = torch.arange(self.ctx)
        while True:
            starts = torch.randint(0, data.numel() - self.ctx - 1, (self.batch_size,))
            # Gather all windows with one advanced-indexing op: row b is data[starts[b] : starts[b] + ctx].
            x = data[starts.unsqueeze(1) + offsets]
            yield x, x[:, 3:].contiguous()
|
|
|
|
def build_model(cfg: TrainConfig, device: torch.device):
    """Build the ARB model for the active modalities, freeze float params and print the audit.

    Optional branches follow the modality weights:

    * image input when ``image_weight > 0``;
    * audio input and talker output when ``audio_weight > 0``;
    * video output when ``video_weight > 0``.

    Attention is enabled only when ``enable_attention``, ``enable_graph`` and
    ``enable_vq`` are all true. Memory modules are always disabled here.

    Returns:
        The model on ``device``, after ``freeze_float_parameters`` has run.
    """
    tscale_type = getattr(TScaleType, cfg.tscale_type.upper(), None)
    if tscale_type is None:
        # Keep the historical T32 fallback, but make a mistyped --tscale-type visible
        # instead of silently training with a different scale type.
        print(f"Warning: unknown tscale type {cfg.tscale_type!r}; falling back to T32")
        tscale_type = TScaleType.T32

    model = ARBModel(
        enable_image=cfg.image_weight > 0,
        enable_audio=cfg.audio_weight > 0,
        enable_vq=cfg.enable_vq,
        enable_graph=cfg.enable_graph,
        enable_memory_modules=False,
        enable_moe=cfg.enable_moe,
        max_moe_iters=cfg.max_moe_iters,
        tscale_type=tscale_type,
        enable_attention=cfg.enable_attention and cfg.enable_graph and cfg.enable_vq,
        enable_output_router=cfg.enable_output_router,
        enable_video_output=cfg.video_weight > 0,
        enable_talker_output=cfg.audio_weight > 0,
    ).to(device)

    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))
    return model
|
|
|
|
def create_streams(cfg: TrainConfig):
    """Create one batch stream per modality whose mixing weight is positive.

    Text comes from ``LocalByteStream`` when ``cfg.text_data`` is set, otherwise from
    ``FineWebStream``. Image and audio streams use half of ``cfg.batch`` and video
    uses a quarter, never less than 1.

    Returns:
        dict mapping 'text' / 'code' / 'image' / 'audio' / 'video' to its stream,
        containing only the active modalities.
    """
    half_batch = max(1, cfg.batch // 2)
    quarter_batch = max(1, cfg.batch // 4)
    active_streams = {}

    if cfg.text_weight > 0:
        active_streams['text'] = (
            LocalByteStream(cfg.text_data, ctx=cfg.ctx, batch_size=cfg.batch)
            if cfg.text_data
            else FineWebStream(FineWebConfig(ctx=cfg.ctx, batch_size=cfg.batch))
        )
    if cfg.code_weight > 0:
        active_streams['code'] = StarCoderStream(StarCoderConfig(ctx=cfg.ctx, batch_size=cfg.batch))
    if cfg.image_weight > 0:
        active_streams['image'] = CC12MStream(CC12MConfig(batch_size=half_batch))
    if cfg.audio_weight > 0:
        active_streams['audio'] = LibriSpeechStream(LibriSpeechConfig(batch_size=half_batch))
    if cfg.video_weight > 0:
        active_streams['video'] = WebVidStream(WebVidConfig(batch_size=quarter_batch))

    return active_streams
|
|
|
|
def sample_modality(cfg: TrainConfig) -> str:
    """Pick a modality at random, with probability proportional to its mixing weight.

    Only modalities whose weight is > 0 are eligible, and if none are, ``'text'`` is
    returned without consuming any randomness. Draws come from Python's ``random``
    module, which ``train`` seeds.
    """
    candidates = [
        (name, weight)
        for name, weight in (
            ('text', cfg.text_weight),
            ('code', cfg.code_weight),
            ('image', cfg.image_weight),
            ('audio', cfg.audio_weight),
            ('video', cfg.video_weight),
        )
        if weight > 0
    ]
    if not candidates:
        return 'text'

    threshold = random.random() * sum(weight for _, weight in candidates)
    running = 0.0
    for name, weight in candidates:
        running += weight
        if threshold <= running:
            return name
    # Floating-point rounding safety net: fall back to the last eligible modality.
    return candidates[-1][0]
|
|
|
|
def compute_loss(model, modality: str, batch, device):
    """Compute the training loss for one ``batch`` of ``modality``.

    * ``'text'`` / ``'code'``: ``batch[0]`` holds token ids (``batch[1]`` is ignored).
      The model is called with those tokens and ``targets`` equal to the tokens with
      their first three positions dropped. Returns ``losses.total``.
    * ``'image'`` / ``'audio'``: ``batch`` is ``(images_or_waves, token_ids)``. The
      tokens are passed as ``x`` and, shifted by three, as ``targets``. The other
      tensor is passed as the ``images=`` / ``audio=`` keyword. Returns ``losses.total``.
    * ``'video'``: ``batch`` is ``(text_tokens, latent_targets)``. The embedded text
      goes through ``model.multimodal_sequencer`` and ``model.video_head``, and the
      prediction is scored by MSE against the latents aligned with ``match_latents``.

    Raises:
        ValueError: if an image/audio token batch has fewer than 4 positions, or the
            modality is unknown.
    """
    if modality in ('text', 'code'):
        tokens = batch[0].to(device, non_blocking=True)
        _, losses, _, _ = model(tokens, targets=tokens[:, 3:].contiguous())
        return losses.total

    if modality in ('image', 'audio'):
        side_input, token_ids = batch
        side_input = side_input.to(device, non_blocking=True)
        tokens = token_ids.to(device, non_blocking=True)
        if tokens.size(1) < 4:
            raise ValueError(
                "Image caption batch must contain at least 4 byte tokens"
                if modality == 'image'
                else "Audio token batch must contain at least 4 tokens"
            )
        side_keyword = 'images' if modality == 'image' else 'audio'
        _, losses, _, _ = model(x=tokens, targets=tokens[:, 3:], **{side_keyword: side_input})
        return losses.total

    if modality == 'video':
        text_tokens, latent_targets = batch
        text_tokens = text_tokens.to(device, non_blocking=True)
        latents = latent_targets.to(device, non_blocking=True)
        relational = model.multimodal_sequencer({'text': model.embedding(text_tokens)})['text']
        prediction = model.video_head(relational)
        return torch.nn.functional.mse_loss(prediction, match_latents(latents, prediction))

    raise ValueError(f"Unknown modality: {modality}")
|
|
|
|
def match_latents(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Reshape ``target`` so it can be compared element-wise with ``pred``.

    The code assumes 5-D tensors: ``expand`` is given five sizes, and trilinear
    interpolation requires 5-D input.

    1. A single-sample target (dim 0 == 1) is broadcast to ``pred``'s batch size
       when that is larger.
    2. Dim 1 is truncated, or zero-padded at the end, to ``pred.shape[1]``.
    3. Dims 2+ are resized with trilinear interpolation when they still differ.
    """
    wanted_batch = pred.shape[0]
    wanted_dim1 = pred.shape[1]
    wanted_tail = pred.shape[2:]

    if wanted_batch > 1 and target.shape[0] == 1:
        target = target.expand(wanted_batch, -1, -1, -1, -1).contiguous()

    current_dim1 = target.shape[1]
    if current_dim1 > wanted_dim1:
        target = target[:, :wanted_dim1]
    elif current_dim1 < wanted_dim1:
        filler = target.new_zeros(target.shape[0], wanted_dim1 - current_dim1, *target.shape[2:])
        target = torch.cat([target, filler], dim=1)

    if target.shape[2:] == wanted_tail:
        return target
    return torch.nn.functional.interpolate(
        target, size=wanted_tail, mode="trilinear", align_corners=False
    )
|
|
|
|
def save_checkpoint(path: Path, model, step: int, loss: float, cfg: TrainConfig):
    """Write ``{'step', 'loss', 'model', 'config'}`` to ``path`` unless ``cfg.no_save``.

    The state is first written to a sibling ``<name>.tmp`` file and then renamed over
    ``path``. An interrupted save therefore cannot leave a truncated checkpoint behind;
    the rename is atomic when both paths are on the same filesystem.

    Args:
        path: destination file (``str`` or ``Path``); parent directories are created.
        model: module whose ``state_dict()`` is stored under ``'model'``.
        step: step value stored for resuming (returned by ``load_checkpoint``).
        loss: loss value stored alongside the weights.
        cfg: run configuration; a copy of its fields is stored under ``'config'``.
    """
    if cfg.no_save:
        return
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {
        'step': step,
        'loss': loss,
        'model': model.state_dict(),
        'config': dict(vars(cfg)),  # snapshot copy, not the live instance __dict__
    }
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(state, tmp_path)
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a half-written temp file around on failure or Ctrl-C.
        tmp_path.unlink(missing_ok=True)
        raise
|
|
|
|
def load_checkpoint(path: str, model, device):
    """Load model weights from a checkpoint file or a run directory.

    Args:
        path: a checkpoint file, or a directory. For a directory, the first file found
            among ``latest.pt``, ``best.pt`` and ``final.pt`` (in that order) is used.
        model: module that receives ``state['model']`` via
            ``load_state_dict(strict=False)``. Missing/unexpected key counts are
            printed rather than raised.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        ``(step, loss)`` from the checkpoint, defaulting to ``(0, inf)`` when absent.

    Raises:
        FileNotFoundError: if ``path`` is a directory with none of the expected files,
            or if the resolved file does not exist.
    """
    ckpt_path = Path(path)
    if ckpt_path.is_dir():
        candidates = (ckpt_path / name for name in ("latest.pt", "best.pt", "final.pt"))
        found = next((candidate for candidate in candidates if candidate.exists()), None)
        if found is None:
            # Without this, torch.load(directory) fails with a confusing IsADirectoryError.
            raise FileNotFoundError(
                f"No latest.pt, best.pt or final.pt found in checkpoint directory {ckpt_path}"
            )
        ckpt_path = found
    state = torch.load(ckpt_path, map_location=device, weights_only=True)
    missing, unexpected = model.load_state_dict(state['model'], strict=False)
    if missing or unexpected:
        print(
            "Checkpoint loaded with architecture drift: "
            f"{len(missing)} missing keys, {len(unexpected)} unexpected keys"
        )
    return state.get('step', 0), state.get('loss', float('inf'))
|
|
|
|
def _resolve_start_step(cfg: TrainConfig, model, device) -> int:
    """Load ``cfg.resume`` into ``model`` when it is set and return the step to resume from."""
    if not cfg.resume:
        return 0
    if not Path(cfg.resume).exists():
        # A mistyped --resume path used to be ignored, silently restarting from scratch.
        raise FileNotFoundError(f"Resume checkpoint not found: {cfg.resume}")
    start_step, _ = load_checkpoint(cfg.resume, model, device)
    print(f"Resumed from step {start_step}")
    return start_step


def train(cfg: TrainConfig):
    """Run pure-ternary pretraining as configured by ``cfg``.

    Each loop step does the following:

    1. Sample a modality and draw a batch from its stream.
    2. Compute the loss.
    3. Call ``model.prepare_ternary_backward`` and then ``loss.backward()``.

    After ``cfg.accum`` steps, ``model._ternary_update_memory`` applies the packed
    ternary update. Logging, ``best.pt``/``latest.pt`` checkpoints and TensorBoard
    scalars fire on step-count boundaries. They use the mean loss of the most
    recently completed accumulation window.

    Checkpoints store the number of completed steps, so resuming continues at the
    next unseen step. Checkpoints written by older versions stored the 0-based index
    of the last finished step, so resuming from one of those repeats a single step.

    Raises:
        ValueError: ``cfg.accum < 1``, TileLang backend without opt-in, or no active streams.
        RuntimeError: trainable torch Parameters remain after freezing.
        FileNotFoundError: ``cfg.resume`` is set but does not exist.
        FloatingPointError: a non-finite loss was produced.
    """
    if cfg.accum < 1:
        raise ValueError(f"accum must be >= 1, got {cfg.accum}")
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    os.environ["ARB_TERNARY_BACKEND"] = cfg.backend
    if cfg.backend == "tilelang" and os.environ.get("ARB_TILELANG_TRAINING", "0").lower() not in {"1", "true", "yes"}:
        raise ValueError(
            "TileLang BigInt training is unfinished and disabled by default. "
            "Use --backend triton for production training."
        )

    device = torch.device("cuda" if torch.cuda.is_available() and not cfg.cpu else "cpu")
    print(f"Device: {device}")
    print(f"Ternary backend: {cfg.backend}")

    model = build_model(cfg, device)
    streams = create_streams(cfg)
    if not streams:
        raise ValueError("No active training streams. Set at least one modality weight above 0.")
    print(f"Active modalities: {', '.join(streams.keys())}")

    if trainable_parameters(model):
        raise RuntimeError(
            "Pure ternary pretrain found trainable torch Parameters after freeze. "
            "Use training/finetuning for LoRA adapters."
        )
    start_step = _resolve_start_step(cfg, model, device)

    run_dir = Path(cfg.out_dir) / cfg.run
    writer = SummaryWriter(str(run_dir))
    try:
        model.train()
        stream_iters = {name: stream.batches() for name, stream in streams.items()}
        best_loss = float('inf')
        last_loss = float('inf')
        window_loss = None  # mean loss of the most recently completed accumulation window
        accum_loss = 0.0
        accum_steps = 0
        completed = start_step  # number of finished loop steps (what checkpoints store)
        start_time = time.perf_counter()

        with tqdm(range(start_step, cfg.steps), desc="train", dynamic_ncols=True,
                  initial=start_step, total=cfg.steps) as pbar:
            for step in pbar:
                modality = sample_modality(cfg)
                stream = stream_iters.get(modality)
                if stream is None:
                    continue

                try:
                    batch = next(stream)
                except StopIteration:
                    # Finite stream exhausted: restart it.
                    stream_iters[modality] = streams[modality].batches()
                    batch = next(stream_iters[modality])

                model.zero_grad(set_to_none=True)
                raw_loss = compute_loss(model, modality, batch, device)
                last_loss = raw_loss.detach().item()

                loss = raw_loss / cfg.accum if cfg.accum > 1 else raw_loss
                if not torch.isfinite(loss).all():
                    raise FloatingPointError(f"Non-finite {modality} pretraining loss; aborting before ternary update")
                model.prepare_ternary_backward(loss.detach(), update_scales=True)
                loss.backward()
                accum_loss += last_loss  # reuse the synced value instead of a second .item()
                accum_steps += 1
                completed = step + 1

                if accum_steps >= cfg.accum:
                    model._ternary_update_memory(accum_threshold=3, update_scales=True,
                                                 loss_signal=raw_loss.detach())
                    model.zero_grad(set_to_none=True)
                    window_loss = accum_loss / accum_steps
                    accum_loss = 0.0
                    accum_steps = 0

                if window_loss is None:
                    continue  # no accumulation window has completed yet

                # Interval checks run on every step, not only on steps that close an
                # accumulation window, so accum > 1 cannot skip logs or checkpoints.
                if cfg.log_interval and completed % cfg.log_interval == 0:
                    writer.add_scalar("loss/train", window_loss, completed)
                    pbar.set_postfix(loss=f"{window_loss:.4f}", mod=modality)
                    print(f"step {completed:>6d} loss={window_loss:.4f} mod={modality}")

                if cfg.eval_interval and completed % cfg.eval_interval == 0 and window_loss < best_loss:
                    best_loss = window_loss
                    save_checkpoint(run_dir / "best.pt", model, completed, window_loss, cfg)
                    print(f"step {completed:>6d} loss={window_loss:.4f} mod={modality} (new best)")

                if cfg.save_interval and completed % cfg.save_interval == 0:
                    save_checkpoint(run_dir / "latest.pt", model, completed, window_loss, cfg)

        total_time = time.perf_counter() - start_time
        print(f"Training complete. {completed - start_step} steps in {total_time / 3600:.1f}h")
        save_checkpoint(run_dir / "final.pt", model, completed, last_loss, cfg)
    finally:
        writer.close()
|
|
|
|
def parse_args():
    """Parse the pretraining command line into an ``argparse.Namespace``.

    Each option's ``dest`` equals a ``TrainConfig`` field name, so the result can be
    passed straight to ``TrainConfig(**vars(args))``. The ``--freeze-*`` flags are
    accepted but hidden from ``--help``.
    """
    parser = argparse.ArgumentParser(description="Unified ARB multi-modal pre-trainer")
    add = parser.add_argument

    # Schedule, shapes and runtime options.
    add("--steps", type=int, default=5000)
    add("--batch", type=int, default=8)
    add("--ctx", type=int, default=min(CTX, 1024))
    add("--accum", type=int, default=1)
    add("--tscale-type", type=str, default="T32")
    add("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton",
        help="Training backend. Triton is the production BigInt ternary path.")
    add("--no-save", action="store_true")
    for interval_flag, every in (("--save-interval", 5000), ("--eval-interval", 500), ("--log-interval", 10)):
        add(interval_flag, type=int, default=every)
    add("--seed", type=int, default=42)
    add("--cpu", action="store_true")
    add("--max-moe-iters", type=int, default=4)
    add("--out-dir", type=str, default="models/checkpoints")
    add("--run", type=str, default="pretrain")
    add("--resume", type=str, default=None,
        help="Path to checkpoint .pt or directory with latest.pt")

    # Architecture switches.
    for part in ("text", "vision", "audio", "video"):
        add(f"--freeze-{part}", action="store_true", help=argparse.SUPPRESS)
    for component in ("vq", "graph", "moe", "attention"):
        add(f"--no-{component}", dest=f"enable_{component}", action="store_false")
    add("--enable-output-router", action="store_true", default=False)
    parser.set_defaults(enable_vq=True, enable_graph=True, enable_moe=True, enable_attention=True)

    # Modality mixing weights and optional local text data.
    for modality, weight in (("text", 1.0), ("code", 0.0), ("image", 0.0), ("audio", 0.0), ("video", 0.0)):
        add(f"--{modality}-weight", type=float, default=weight)
    add("--text-data", type=str, default=None,
        help="Optional local .txt/.pt byte data for text pretraining smoke/bootstrap")
    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parsed options map one-to-one onto TrainConfig fields.
    cli_args = parse_args()
    cfg = TrainConfig(**vars(cli_args))
    train(cfg)
|
|