"""Unified multi-modal ARB pure-ternary pre-trainer.

Supports text, code, image, audio, and video modalities with weighted mixing,
checkpoint resume, and packed ternary state updates. Core pretraining freezes
all IEEE-float parameters; LoRA/AdamW paths live under ``training/finetuning``.

Usage:
    # Phase 1a — Text pre-training smoke test (100M tokens on RTX 6000 Pro)
    python training/pretrain.py --text-data training/data/tinyshakespeare.txt \\
        --text-weight 1.0 --steps 50000 --batch 8 --ctx 1024

    # Phase 1b — Full text + code pre-training
    python training/pretrain.py --text-weight 0.95 --code-weight 0.05 \\
        --steps 1000000 --batch 16 --ctx 2048

    # Phase 2 — Add vision (resume from phase 1b)
    python training/pretrain.py --resume models/checkpoints/phase1b/best.pt \\
        --image-weight 0.3 --text-weight 1.0

    # Phase 3 — Add audio
    python training/pretrain.py --resume models/checkpoints/phase2/best.pt \\
        --audio-weight 0.2 --text-weight 1.0

    # Phase 4 — Add video
    python training/pretrain.py --resume models/checkpoints/phase3/best.pt \\
        --video-weight 0.1 --text-weight 1.0

    # Smoke test (1 step, CPU)
    python training/pretrain.py --steps 1 --batch 1 --ctx 4 --cpu --no-save

Notes:
    * The hidden ``--freeze-*`` flags are accepted for CLI compatibility but are
      not consulted by this script: every float parameter is frozen regardless.
    * Checkpoints store ``step`` as the number of completed steps, so resuming
      continues with the next step instead of repeating the last one.
"""
import argparse
import os
import random
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from arbitor import ARBModel  # noqa: E402
from arbitor.config import CTX  # noqa: E402
from arbitor.kernel.ternary_audit import (  # noqa: E402
    audit_model,
    format_audit,
    freeze_float_parameters,
    trainable_parameters,
)
from arbitor.kernel.ternary_scale import TScaleType  # noqa: E402
from training.data import (  # noqa: E402
    FineWebStream, FineWebConfig,
    StarCoderStream, StarCoderConfig,
    CC12MStream, CC12MConfig,
    LibriSpeechStream, LibriSpeechConfig,
    WebVidStream, WebVidConfig,
)


@dataclass
class TrainConfig:
    """All knobs for one pretraining run; mirrors the CLI flags of ``parse_args``."""

    steps: int = 5000
    batch: int = 8
    ctx: int = min(CTX, 1024)
    accum: int = 1  # micro-batches per ternary memory update (must be >= 1)
    tscale_type: str = "T32"  # TScaleType member name; unknown names fall back to T32
    backend: str = "triton"
    # The freeze_* flags are stored (and saved in checkpoints) but not read here.
    freeze_text: bool = False
    freeze_vision: bool = False
    freeze_audio: bool = False
    freeze_video: bool = False
    enable_vq: bool = True
    enable_graph: bool = True
    enable_moe: bool = True
    enable_attention: bool = True
    enable_output_router: bool = False
    # Modality mixing weights; a weight > 0 activates that modality's stream.
    text_weight: float = 1.0
    code_weight: float = 0.0
    image_weight: float = 0.0
    audio_weight: float = 0.0
    video_weight: float = 0.0
    text_data: Optional[str] = None  # optional local .txt/.pt corpus for text
    data_dir: str = "training/data"  # not read by this script
    out_dir: str = "models/checkpoints"
    run: str = "pretrain"
    resume: Optional[str] = None
    no_save: bool = False
    save_interval: int = 5000
    eval_interval: int = 500
    log_interval: int = 10
    seed: int = 42
    cpu: bool = False
    max_moe_iters: int = 4


class LocalByteStream:
    """Small local byte stream for smoke tests and phase-1 text bootstrap.

    Reads either a raw file (one token per byte) or a ``.pt`` token tensor and
    yields random ``ctx``-length windows of shape ``(batch_size, ctx)``.
    """

    def __init__(self, path: str, ctx: int, batch_size: int):
        self.path = Path(path)
        self.ctx = ctx
        self.batch_size = batch_size

    def _load(self) -> torch.Tensor:
        """Load the whole corpus as a flat ``torch.long`` tensor on CPU.

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: if the corpus is not longer than ``ctx + 1`` tokens.
        """
        if not self.path.exists():
            raise FileNotFoundError(f"Local text data not found: {self.path}")
        if self.path.suffix == ".pt":
            # Flatten so a multi-dimensional token dump is treated as one stream.
            data = torch.load(self.path, weights_only=True).long().cpu().reshape(-1)
        else:
            raw = self.path.read_bytes()
            # frombuffer avoids materialising one Python int per byte;
            # it rejects empty buffers, so handle that case explicitly.
            if raw:
                data = torch.frombuffer(bytearray(raw), dtype=torch.uint8).to(torch.long)
            else:
                data = torch.empty(0, dtype=torch.long)
        if data.numel() <= self.ctx + 1:
            raise ValueError(f"Local text data has {data.numel()} tokens but ctx={self.ctx}")
        return data

    def batches(self):
        """Yield ``(x, x[:, 3:])`` pairs of random windows forever."""
        data = self._load()
        while True:
            ix = torch.randint(0, data.numel() - self.ctx - 1, (self.batch_size,))
            x = torch.stack([data[i : i + self.ctx] for i in ix])
            yield x, x[:, 3:].contiguous()


def build_model(cfg: TrainConfig, device: torch.device):
    """Construct the ARB model for ``cfg``, freeze all float params, print the audit."""
    model = ARBModel(
        enable_image=cfg.image_weight > 0,
        enable_audio=cfg.audio_weight > 0,
        enable_vq=cfg.enable_vq,
        enable_graph=cfg.enable_graph,
        enable_memory_modules=False,
        enable_moe=cfg.enable_moe,
        max_moe_iters=cfg.max_moe_iters,
        tscale_type=getattr(TScaleType, cfg.tscale_type.upper(), TScaleType.T32),
        # Attention is only enabled when its graph and VQ prerequisites are on.
        enable_attention=cfg.enable_attention and cfg.enable_graph and cfg.enable_vq,
        enable_output_router=cfg.enable_output_router,
        enable_video_output=cfg.video_weight > 0,
        enable_talker_output=cfg.audio_weight > 0,
    ).to(device)
    freeze_float_parameters(model)
    print(format_audit(audit_model(model)))
    return model


def create_streams(cfg: TrainConfig):
    """Return ``{modality: stream}`` for every modality with a positive weight."""
    streams = {}
    if cfg.text_weight > 0:
        if cfg.text_data:
            streams['text'] = LocalByteStream(cfg.text_data, ctx=cfg.ctx, batch_size=cfg.batch)
        else:
            streams['text'] = FineWebStream(FineWebConfig(ctx=cfg.ctx, batch_size=cfg.batch))
    if cfg.code_weight > 0:
        streams['code'] = StarCoderStream(StarCoderConfig(ctx=cfg.ctx, batch_size=cfg.batch))
    # Heavier modalities use proportionally smaller batches.
    if cfg.image_weight > 0:
        streams['image'] = CC12MStream(CC12MConfig(batch_size=max(1, cfg.batch // 2)))
    if cfg.audio_weight > 0:
        streams['audio'] = LibriSpeechStream(LibriSpeechConfig(batch_size=max(1, cfg.batch // 2)))
    if cfg.video_weight > 0:
        streams['video'] = WebVidStream(WebVidConfig(batch_size=max(1, cfg.batch // 4)))
    return streams


def sample_modality(cfg: TrainConfig) -> str:
    """Pick a modality with probability proportional to its configured weight.

    Returns ``'text'`` when no weight is positive.
    """
    weights = {
        'text': cfg.text_weight,
        'code': cfg.code_weight,
        'image': cfg.image_weight,
        'audio': cfg.audio_weight,
        'video': cfg.video_weight,
    }
    active = {k: v for k, v in weights.items() if v > 0}
    if not active:
        return 'text'
    r = random.random() * sum(active.values())
    cumulative = 0.0
    for name, weight in active.items():
        cumulative += weight
        if r <= cumulative:
            return name
    # Guard against float rounding leaving r just above the final cumulative sum.
    return next(reversed(active))


def compute_loss(model, modality: str, batch, device):
    """Run one forward pass for ``modality`` and return the scalar training loss.

    Raises:
        ValueError: for an unknown modality or a token batch shorter than 4.
    """
    if modality in ('text', 'code'):
        x = batch[0].to(device, non_blocking=True)
        # Targets drop the first 3 tokens, matching the image/audio branches.
        if x.size(1) < 4:
            raise ValueError("Text token batch must contain at least 4 tokens")
        targets = x[:, 3:].contiguous()
        _, losses, _, _ = model(x, targets=targets)
        return losses.total
    if modality == 'image':
        images, captions = batch
        images = images.to(device, non_blocking=True)
        targets = captions.to(device, non_blocking=True)
        if targets.size(1) < 4:
            raise ValueError("Image caption batch must contain at least 4 byte tokens")
        _, losses, _, _ = model(x=targets, images=images, targets=targets[:, 3:])
        return losses.total
    if modality == 'audio':
        waves, vq_targets = batch
        waves = waves.to(device, non_blocking=True)
        targets = vq_targets.to(device, non_blocking=True)
        if targets.size(1) < 4:
            raise ValueError("Audio token batch must contain at least 4 tokens")
        _, losses, _, _ = model(x=targets, audio=waves, targets=targets[:, 3:])
        return losses.total
    if modality == 'video':
        text_tokens, latent_targets = batch
        text_tokens = text_tokens.to(device, non_blocking=True)
        latents = latent_targets.to(device, non_blocking=True)
        embedded = model.embedding(text_tokens)
        seq_out = model.multimodal_sequencer({'text': embedded})
        pred = model.video_head(seq_out['text'])
        latents = match_latents(latents, pred)
        return torch.nn.functional.mse_loss(pred, latents)
    raise ValueError(f"Unknown modality: {modality}")


def match_latents(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Coerce ``target`` to ``pred``'s shape for an MSE loss.

    Broadcasts a singleton batch, crops or zero-pads dim 1, then resizes the
    trailing dims with trilinear interpolation (which requires 5-D tensors —
    presumably (N, C, D, H, W); TODO confirm axis meaning).
    """
    if target.shape[0] == 1 and pred.shape[0] > 1:
        target = target.expand(pred.shape[0], *target.shape[1:]).contiguous()
    if target.shape[1] != pred.shape[1]:
        if target.shape[1] > pred.shape[1]:
            target = target[:, :pred.shape[1]]
        else:
            pad = target.new_zeros(target.shape[0], pred.shape[1] - target.shape[1], *target.shape[2:])
            target = torch.cat([target, pad], dim=1)
    if target.shape[2:] != pred.shape[2:]:
        target = torch.nn.functional.interpolate(
            target, size=pred.shape[2:], mode="trilinear", align_corners=False
        )
    return target


def save_checkpoint(path: Path, model, step: int, loss: float, cfg: TrainConfig):
    """Save model state, step, loss and config to ``path`` (no-op with ``no_save``).

    The file is written to a temporary sibling first and then atomically moved
    into place, so a crash mid-save cannot corrupt an existing checkpoint.
    """
    if cfg.no_save:
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {
        'step': step,
        'loss': loss,
        'model': model.state_dict(),
        'config': asdict(cfg),
    }
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)


def load_checkpoint(path: str, model, device):
    """Load a checkpoint file (or latest/best/final.pt from a directory) into ``model``.

    Loading is non-strict; missing/unexpected key counts are printed.

    Returns:
        ``(step, loss)`` stored in the checkpoint (defaults ``0`` and ``inf``).

    Raises:
        FileNotFoundError: if ``path`` is a directory with no known checkpoint.
    """
    ckpt_path = Path(path)
    if ckpt_path.is_dir():
        for name in ("latest.pt", "best.pt", "final.pt"):
            candidate = ckpt_path / name
            if candidate.exists():
                ckpt_path = candidate
                break
        else:
            raise FileNotFoundError(f"No latest.pt, best.pt, or final.pt found in {ckpt_path}")
    state = torch.load(ckpt_path, map_location=device, weights_only=True)
    missing, unexpected = model.load_state_dict(state['model'], strict=False)
    if missing or unexpected:
        print(
            "Checkpoint loaded with architecture drift: "
            f"{len(missing)} missing keys, {len(unexpected)} unexpected keys"
        )
    return state.get('step', 0), state.get('loss', float('inf'))


def _next_batch(streams, stream_iters, modality):
    """Return the next batch for ``modality``, restarting its iterator if exhausted."""
    try:
        return next(stream_iters[modality])
    except StopIteration:
        stream_iters[modality] = streams[modality].batches()
        return next(stream_iters[modality])


def train(cfg: TrainConfig):
    """Run pure-ternary pretraining as configured by ``cfg``.

    Raises:
        ValueError: invalid ``accum``, TileLang without opt-in, or no active stream.
        RuntimeError: if any torch Parameter is still trainable after freezing.
        FloatingPointError: on a non-finite loss, before any ternary update.
    """
    if cfg.accum < 1:
        raise ValueError(f"accum must be >= 1, got {cfg.accum}")
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    os.environ["ARB_TERNARY_BACKEND"] = cfg.backend
    tilelang_opt_in = os.environ.get("ARB_TILELANG_TRAINING", "0").lower() in {"1", "true", "yes"}
    if cfg.backend == "tilelang" and not tilelang_opt_in:
        raise ValueError(
            "TileLang BigInt training is unfinished and disabled by default. "
            "Use --backend triton for production training."
        )
    device = torch.device("cuda" if torch.cuda.is_available() and not cfg.cpu else "cpu")
    print(f"Device: {device}")
    print(f"Ternary backend: {cfg.backend}")

    model = build_model(cfg, device)
    streams = create_streams(cfg)
    if not streams:
        raise ValueError("No active training streams. Set at least one modality weight above 0.")
    print(f"Active modalities: {', '.join(streams.keys())}")

    if trainable_parameters(model):
        raise RuntimeError(
            "Pure ternary pretrain found trainable torch Parameters after freeze. "
            "Use training/finetuning for LoRA adapters."
        )

    start_step = 0
    if cfg.resume:
        if Path(cfg.resume).exists():
            start_step, _ = load_checkpoint(cfg.resume, model, device)
            print(f"Resumed from step {start_step}")
        else:
            print(f"Warning: resume path {cfg.resume} does not exist; starting from scratch")

    run_dir = Path(cfg.out_dir) / cfg.run
    writer = SummaryWriter(str(run_dir))
    model.train()
    stream_iters = {k: s.batches() for k, s in streams.items()}
    best_loss = float('inf')
    last_loss = float('inf')
    last_raw_loss = None  # detached loss of the latest micro-batch, for flushing
    completed = start_step  # number of steps finished; stored in checkpoints
    accum_loss = 0.0
    accum_steps = 0
    start_time = time.perf_counter()
    pbar = tqdm(range(start_step, cfg.steps), desc="train", dynamic_ncols=True,
                initial=start_step, total=cfg.steps)
    try:
        for step in pbar:
            report_step = step + 1
            completed = report_step
            modality = sample_modality(cfg)
            if modality not in stream_iters:
                continue
            batch = _next_batch(streams, stream_iters, modality)

            # NOTE(review): all torch Parameters are frozen (checked above), so this
            # presumably matters only if ARBModel overrides zero_grad — confirm it does
            # not clear ternary accumulators mid-window when accum > 1.
            model.zero_grad(set_to_none=True)
            raw_loss = compute_loss(model, modality, batch, device)
            raw_value = raw_loss.detach().item()  # single host sync per step
            last_loss = raw_value
            loss = raw_loss / cfg.accum if cfg.accum > 1 else raw_loss
            if not torch.isfinite(loss).all():
                raise FloatingPointError(f"Non-finite {modality} pretraining loss; aborting before ternary update")

            model.prepare_ternary_backward(loss.detach(), update_scales=True)
            loss.backward()
            last_raw_loss = raw_loss.detach()
            accum_loss += raw_value
            accum_steps += 1
            if accum_steps < cfg.accum:
                continue

            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=last_raw_loss)
            model.zero_grad(set_to_none=True)
            avg = accum_loss / cfg.accum

            if cfg.log_interval and report_step % cfg.log_interval == 0:
                writer.add_scalar("loss/train", avg, step)
                pbar.set_postfix(loss=f"{avg:.4f}", mod=modality)
                pbar.write(f"step {report_step:>6d} loss={avg:.4f} mod={modality}")
            if cfg.eval_interval and report_step % cfg.eval_interval == 0 and avg < best_loss:
                best_loss = avg
                save_checkpoint(run_dir / "best.pt", model, report_step, avg, cfg)
                pbar.write(f"step {report_step:>6d} loss={avg:.4f} mod={modality} (new best)")
            if cfg.save_interval and report_step % cfg.save_interval == 0:
                save_checkpoint(run_dir / "latest.pt", model, report_step, avg, cfg)
            accum_loss = 0.0
            accum_steps = 0

        # Apply a trailing partial accumulation window instead of discarding it.
        if accum_steps and last_raw_loss is not None:
            model._ternary_update_memory(accum_threshold=3, update_scales=True, loss_signal=last_raw_loss)
            model.zero_grad(set_to_none=True)
    finally:
        pbar.close()
        writer.close()

    total_time = time.perf_counter() - start_time
    print(f"Training complete. {completed - start_step} steps in {total_time / 3600:.1f}h")
    save_checkpoint(run_dir / "final.pt", model, completed, last_loss, cfg)


def parse_args():
    """Parse CLI flags; the resulting namespace maps 1:1 onto ``TrainConfig`` fields."""
    parser = argparse.ArgumentParser(description="Unified ARB multi-modal pre-trainer")
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--ctx", type=int, default=min(CTX, 1024))
    parser.add_argument("--accum", type=int, default=1)
    parser.add_argument("--tscale-type", type=str, default="T32")
    parser.add_argument("--backend", choices=("triton", "torch", "auto", "tilelang"), default="triton",
                        help="Training backend. Triton is the production BigInt ternary path.")
    parser.add_argument("--no-save", action="store_true")
    parser.add_argument("--save-interval", type=int, default=5000)
    parser.add_argument("--eval-interval", type=int, default=500)
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--cpu", action="store_true")
    parser.add_argument("--max-moe-iters", type=int, default=4)
    parser.add_argument("--out-dir", type=str, default="models/checkpoints")
    parser.add_argument("--run", type=str, default="pretrain")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint .pt or directory with latest.pt")
    # Hidden compatibility flags; stored on TrainConfig but not read by train().
    parser.add_argument("--freeze-text", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--freeze-vision", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--freeze-audio", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--freeze-video", action="store_true", help=argparse.SUPPRESS)
    parser.add_argument("--no-vq", dest="enable_vq", action="store_false")
    parser.add_argument("--no-graph", dest="enable_graph", action="store_false")
    parser.add_argument("--no-moe", dest="enable_moe", action="store_false")
    parser.add_argument("--no-attention", dest="enable_attention", action="store_false")
    parser.add_argument("--enable-output-router", action="store_true", default=False)
    parser.set_defaults(enable_vq=True, enable_graph=True, enable_moe=True, enable_attention=True)
    parser.add_argument("--text-weight", type=float, default=1.0)
    parser.add_argument("--code-weight", type=float, default=0.0)
    parser.add_argument("--image-weight", type=float, default=0.0)
    parser.add_argument("--audio-weight", type=float, default=0.0)
    parser.add_argument("--video-weight", type=float, default=0.0)
    parser.add_argument("--text-data", type=str, default=None,
                        help="Optional local .txt/.pt byte data for text pretraining smoke/bootstrap")
    return parser.parse_args()


if __name__ == "__main__":
    cfg = TrainConfig(**vars(parse_args()))
    train(cfg)