""" Fashion-MNIST Trainer with MobiusCollective ============================================ Train a wide collective of MobiusLens towers on Fashion-MNIST. Designed for Colab with TensorBoard logging and HuggingFace upload. License: Apache 2.0 Date: 2025-01-10 Author: AbstractPhil """ import os import json import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from typing import Tuple, Dict, Any, Optional from torchvision import datasets, transforms from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm.auto import tqdm from datetime import datetime from pathlib import Path from safetensors.torch import save_file as save_safetensors # HuggingFace login for Colab try: from huggingface_hub import HfApi, login from google.colab import userdata token = userdata.get('HF_TOKEN') os.environ['HF_TOKEN'] = token login(token=token) print("Logged in to HuggingFace via Colab") HF_AVAILABLE = True except: HF_AVAILABLE = False print("HuggingFace upload disabled (not in Colab or no token)") device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") # TF32 for Ampere+ torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision('high') # ============================================================================ # IMPORTS FROM GEOFRACTAL # ============================================================================ from geofractal.router.wide_router import WideRouter from geofractal.router.base_tower import BaseTower from geofractal.router.components.torch_component import TorchComponent from geofractal.router.components.lens_component import MobiusLens, TriWaveLens from geofractal.router.components.fusion_component import AdaptiveFusion # ============================================================================ # CONV LENS BLOCK # 
class ConvLensBlock(TorchComponent):
    """Depthwise-separable conv with MobiusLens activation.

    Applies a depthwise 3x3 conv, a pointwise 1x1 conv and BatchNorm, then a
    lens nonlinearity (applied channel-last), and blends the result with the
    input through a learned sigmoid-gated residual weight.
    """

    def __init__(
        self,
        name: str,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
    ):
        super().__init__(name)
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
        # layer_idx/total_layers place this block on the global scale
        # continuum shared across all towers.
        if use_mobius:
            self.lens = MobiusLens(f'{name}_lens', channels, layer_idx, total_layers, scale_range)
        else:
            self.lens = TriWaveLens(f'{name}_lens', channels, layer_idx, total_layers, scale_range)
        # Starts near sigmoid(0.9) ≈ 0.71 identity weight, so early training
        # is dominated by the residual path.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        """Conv + lens with a learned residual blend.

        Args:
            x: Input of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, H, W).
        """
        identity = x
        h = self.conv(x)
        # Lens operates channel-last: (B, H, W, C).
        # (Removed unused `B, C, H, W = h.shape` unpack from the original.)
        h = h.permute(0, 2, 3, 1)
        h = self.lens(h)
        h = h.permute(0, 3, 1, 2)
        rw = torch.sigmoid(self.residual_weight)
        return rw * identity + (1 - rw) * h


# ============================================================================
# LENS TOWER
# ============================================================================
class LensTower(BaseTower):
    """Shallow tower covering a segment of the scale continuum.

    Tower `tower_idx` owns layers [tower_idx * depth, (tower_idx + 1) * depth)
    of the `num_towers * depth` global layer index space, so the lens scale
    schedule is continuous across towers.
    """

    def __init__(
        self,
        name: str,
        channels: int,
        depth: int,
        tower_idx: int,
        num_towers: int,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
    ):
        super().__init__(name, strict=False)
        self.tower_idx = tower_idx
        self.channels = channels
        total_layers = num_towers * depth
        start_layer = tower_idx * depth
        for i in range(depth):
            global_idx = start_layer + i
            block = ConvLensBlock(
                f'{name}_block_{i}',
                channels,
                layer_idx=global_idx,
                total_layers=total_layers,
                scale_range=scale_range,
                use_mobius=use_mobius,
            )
            self.append(block)
        self.attach('norm', nn.BatchNorm2d(channels))

    def forward(self, x: Tensor) -> Tensor:
        """Run all stages sequentially, then the tower's output BatchNorm."""
        for stage in self.stages:
            x = stage(x)
        return self['norm'](x)
# ============================================================================
# VISION ADAPTIVE FUSION (wraps AdaptiveFusion for BCHW tensors)
# ============================================================================
class VisionAdaptiveFusion(TorchComponent):
    """
    Wraps AdaptiveFusion for vision tensors (B, C, H, W).
    Permutes to channel-last, fuses, permutes back, then projects.
    """

    def __init__(self, name: str, num_towers: int, channels: int):
        super().__init__(name)
        self.num_towers = num_towers
        self.fusion = AdaptiveFusion(
            f'{name}_adaptive',
            num_inputs=num_towers,
            in_features=channels,
        )
        # Output projection (1x1 conv for spatial tensors)
        self.proj = nn.Sequential(
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, *opinions: Tensor) -> Tensor:
        """
        Args:
            *opinions: N tensors of shape (B, C, H, W)

        Returns:
            Fused tensor of shape (B, C, H, W)
        """
        # AdaptiveFusion expects channel-last input: (B, H, W, C).
        channel_last = [op.permute(0, 2, 3, 1) for op in opinions]
        fused = self.fusion(*channel_last)  # (B, H, W, C)
        fused = fused.permute(0, 3, 1, 2)   # back to (B, C, H, W)
        return self.proj(fused)


# ============================================================================
# MOBIUS COLLECTIVE
# ============================================================================
class MobiusCollective(WideRouter):
    """
    Wide collective with MobiusLens towers.

    Architecture:
    - Light stem (configurable stride)
    - Multiple shallow towers in parallel (scale continuum)
    - Adaptive fusion + classification head
    """

    def __init__(
        self,
        name: str = 'mobius_collective',
        in_channels: int = 1,
        channels: int = 64,
        num_towers: int = 4,
        depth_per_tower: int = 2,
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_mobius: bool = True,
        num_classes: int = 10,
        stem_stride: int = 2,
    ):
        super().__init__(name, auto_discover=True)
        self.in_channels = in_channels
        self.channels = channels
        self.num_towers = num_towers
        self.depth_per_tower = depth_per_tower
        self.scale_range = scale_range
        self.use_mobius = use_mobius
        self.num_classes = num_classes
        self.stem_stride = stem_stride

        # Stem: single strided conv to cut spatial resolution up front.
        self.attach('stem', nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, stride=stem_stride, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        ))

        # Towers: each covers a contiguous segment of the scale continuum.
        for i in range(num_towers):
            tower = LensTower(
                f'tower_{i}',
                channels=channels,
                depth=depth_per_tower,
                tower_idx=i,
                num_towers=num_towers,
                scale_range=scale_range,
                use_mobius=use_mobius,
            )
            self.attach(f'tower_{i}', tower)
        self.discover_towers()

        # Fusion (wraps geofractal's AdaptiveFusion for vision tensors)
        self.attach('fusion', VisionAdaptiveFusion('fusion', num_towers, channels))

        # Head
        self.attach('pool', nn.AdaptiveAvgPool2d(1))
        self.attach('head', nn.Linear(channels, num_classes))

    def forward(self, x: Tensor) -> Tensor:
        """Stem -> parallel towers -> adaptive fusion -> pooled linear head."""
        x = self['stem'](x)
        opinions = self.wide_forward(x)
        opinion_list = [opinions[f'tower_{i}'] for i in range(self.num_towers)]
        fused = self['fusion'](*opinion_list)
        fused = self['pool'](fused).flatten(1)
        return self['head'](fused)

    def get_config(self) -> Dict[str, Any]:
        """Serializable constructor arguments (written to config.json)."""
        return {
            'in_channels': self.in_channels,
            'channels': self.channels,
            'num_towers': self.num_towers,
            'depth_per_tower': self.depth_per_tower,
            'scale_range': self.scale_range,
            'use_mobius': self.use_mobius,
            'num_classes': self.num_classes,
            'stem_stride': self.stem_stride,
        }

    def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
        """Return stats from all lenses for logging."""
        stats = {}
        for tower_name in self.tower_names:
            tower = self[tower_name]
            for i, stage in enumerate(tower.stages):
                key = f"{tower_name}_block_{i}"
                stats[key] = stage.lens.get_lens_stats()
        return stats


# ============================================================================
# PRESETS
# ============================================================================
PRESETS = {
    'fashion_mobius_tiny': {
        'channels': 32,
        'num_towers': 3,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.0),
        'use_mobius': True,
    },
    'fashion_mobius_small': {
        'channels': 64,
        'num_towers': 4,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.5),
        'use_mobius': True,
    },
    'fashion_mobius_base': {
        'channels': 96,
        'num_towers': 4,
        'depth_per_tower': 3,
        'scale_range': (0.25, 2.75),
        'use_mobius': True,
    },
    'fashion_tri_small': {
        'channels': 64,
        'num_towers': 4,
        'depth_per_tower': 2,
        'scale_range': (0.5, 2.5),
        'use_mobius': False,
    },
}


# ============================================================================
# DATA
# ============================================================================
def get_fashion_mnist_loaders(
    data_dir: str = './data',
    batch_size: int = 128,
    val_batch_size: int = 256,
):
    """Get Fashion-MNIST train/val loaders with augmentation.

    Args:
        data_dir: Download/cache directory for the dataset.
        batch_size: Training batch size.
        val_batch_size: Validation batch size (previously hard-coded at 256;
            now a backward-compatible parameter with the same default).

    Returns:
        (train_loader, val_loader)
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),  # Fashion-MNIST mean/std
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.2860,), (0.3530,)),
    ])
    train_dataset = datasets.FashionMNIST(
        data_dir, train=True, download=True, transform=train_transform
    )
    val_dataset = datasets.FashionMNIST(
        data_dir, train=False, download=True, transform=val_transform
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, persistent_workers=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        num_workers=2, pin_memory=True, persistent_workers=True
    )
    return train_loader, val_loader


# ============================================================================
# CHECKPOINT MANAGER
# ============================================================================
class CheckpointManager:
    """Handles saving, logging, and optional HF upload."""

    def __init__(
        self,
        output_dir: str,
        experiment_name: str,
        hf_repo: Optional[str] = None,
        save_every: int = 10,
        upload_every: int = 20,
    ):
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.experiment_name = experiment_name
        self.hf_repo = hf_repo
        self.save_every = save_every
        self.upload_every = upload_every

        # Layout: <output_dir>/<experiment>/<timestamp>/{checkpoints,tensorboard}
        self.run_dir = Path(output_dir) / experiment_name / self.timestamp
        self.ckpt_dir = self.run_dir / "checkpoints"
        self.tb_dir = self.run_dir / "tensorboard"
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        self.tb_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(self.tb_dir))
        self.hf_api = HfApi() if HF_AVAILABLE and hf_repo else None
        self.best_acc = 0.0
        self.best_epoch = 0
        print(f"Checkpoints: {self.run_dir}")

    def save_config(self, model_config: Dict, train_config: Dict):
        """Write model + training hyperparameters to config.json."""
        config = {
            'model': model_config,
            'training': train_config,
            'timestamp': self.timestamp,
        }
        with open(self.run_dir / "config.json", 'w') as f:
            json.dump(config, f, indent=2)

    def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
        """Log a dict of scalars to TensorBoard under an optional tag prefix."""
        for name, value in scalars.items():
            tag = f"{prefix}/{name}" if prefix else name
            self.writer.add_scalar(tag, value, epoch)

    def log_lens_stats(self, epoch: int, model: nn.Module):
        """Log per-block lens statistics (unwraps a torch.compile wrapper)."""
        raw = model._orig_mod if hasattr(model, '_orig_mod') else model
        stats = raw.get_all_lens_stats()
        for block_name, block_stats in stats.items():
            for stat_name, value in block_stats.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler,
        epoch: int,
        train_acc: float,
        val_acc: float,
        train_loss: float,
    ):
        """Save the best model (safetensors + resumable .pt) and periodic snapshots."""
        raw = model._orig_mod if hasattr(model, '_orig_mod') else model
        is_best = val_acc > self.best_acc
        if is_best:
            self.best_acc = val_acc
            self.best_epoch = epoch
            # Save best
            save_safetensors(raw.state_dict(), str(self.ckpt_dir / "best_model.safetensors"))
            torch.save({
                'epoch': epoch,
                'model_state_dict': raw.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_acc': self.best_acc,
                'train_acc': train_acc,
                'val_acc': val_acc,
            }, self.ckpt_dir / "best_model.pt")
        # Periodic save
        if epoch % self.save_every == 0:
            save_safetensors(raw.state_dict(), str(self.ckpt_dir / f"epoch_{epoch:04d}.safetensors"))

    def upload(self, epoch: int, force: bool = False):
        """Best-effort upload of config + best weights to the HF repo."""
        if not self.hf_api or not self.hf_repo:
            return
        if not force and epoch % self.upload_every != 0:
            return
        try:
            hf_path = f"fashion_mnist/{self.experiment_name}/{self.timestamp}"
            for f in [self.run_dir / "config.json", self.ckpt_dir / "best_model.safetensors"]:
                if f.exists():
                    self.hf_api.upload_file(
                        path_or_fileobj=str(f),
                        path_in_repo=f"{hf_path}/{f.name}",
                        repo_id=self.hf_repo,
                        repo_type="model",
                    )
            print(f"Uploaded to {self.hf_repo}/{hf_path}")
        except Exception as e:
            # Deliberately best-effort: training must not die on a failed upload.
            print(f"Upload failed: {e}")

    def close(self):
        self.writer.close()


# ============================================================================
# TRAINING
# ============================================================================
def train_fashion_mnist(
    preset: str = 'fashion_mobius_small',
    epochs: int = 50,
    lr: float = 1e-3,
    batch_size: int = 128,
    output_dir: str = './outputs',
    hf_repo: Optional[str] = 'AbstractPhil/mobiusnet-collective',
    use_compile: bool = True,
    save_every: int = 10,
    upload_every: int = 20,
):
    """Train MobiusCollective on Fashion-MNIST.

    Args:
        preset: Key into PRESETS selecting the architecture.
        epochs: Number of training epochs.
        lr: AdamW learning rate (cosine-annealed over `epochs`).
        batch_size: Training batch size.
        output_dir: Root directory for checkpoints/logs.
        hf_repo: HuggingFace repo id for uploads (None disables upload).
        use_compile: Wrap the model with torch.compile when available.
        save_every: Epoch interval for periodic checkpoint snapshots.
        upload_every: Epoch interval for HF uploads.

    Returns:
        (model, best_acc): the trained model (possibly the compiled wrapper)
        and the best validation accuracy reached.
    """
    config = PRESETS[preset]

    print("=" * 70)
    print(f"FASHION-MNIST - {preset.upper()}")
    print("=" * 70)
    print(f"Channels: {config['channels']}")
    print(f"Towers: {config['num_towers']} x {config['depth_per_tower']} depth")
    print(f"Scale range: {config['scale_range']}")
    print(f"Lens: {'Mobius' if config['use_mobius'] else 'TriWave'}")
    print()

    # Data
    train_loader, val_loader = get_fashion_mnist_loaders('./data', batch_size)

    # Model
    model = MobiusCollective(
        name=preset,
        in_channels=1,  # Fashion-MNIST is grayscale
        num_classes=10,
        stem_stride=2,  # 28x28 -> 14x14
        **config,
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")

    # Checkpoint manager
    ckpt = CheckpointManager(
        output_dir=output_dir,
        experiment_name=preset,
        hf_repo=hf_repo,
        save_every=save_every,
        upload_every=upload_every,
    )

    # Save config (before compile, while get_config is directly reachable)
    train_config = {
        'epochs': epochs,
        'lr': lr,
        'batch_size': batch_size,
        'optimizer': 'AdamW',
        'scheduler': 'CosineAnnealingLR',
        'total_params': total_params,
    }
    ckpt.save_config(model.get_config(), train_config)

    # Compile
    if use_compile and hasattr(torch, 'compile'):
        print("Compiling model...")
        model = torch.compile(model, mode='reduce-overhead')

    # Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        # Train
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch:3d}")
        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Weight per-batch stats by batch size so partial batches count correctly.
            train_loss += loss.item() * x.size(0)
            train_correct += (logits.argmax(1) == y).sum().item()
            train_total += x.size(0)
            pbar.set_postfix(loss=f"{loss.item():.4f}")
        scheduler.step()

        # Validate
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                val_correct += (logits.argmax(1) == y).sum().item()
                val_total += x.size(0)

        # Metrics
        train_acc = train_correct / train_total
        val_acc = val_correct / val_total
        avg_loss = train_loss / train_total
        current_lr = scheduler.get_last_lr()[0]

        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc
        marker = " ★" if is_best else ""
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

        # Logging — fix: validation metrics were previously logged under the
        # 'train/' prefix, which mislabeled the TensorBoard dashboard.
        ckpt.log_scalars(epoch, {
            'loss': avg_loss,
            'train_acc': train_acc,
            'lr': current_lr,
        }, prefix='train')
        ckpt.log_scalars(epoch, {
            'val_acc': val_acc,
            'best_acc': best_acc,
        }, prefix='val')
        ckpt.log_lens_stats(epoch, model)

        # Save
        ckpt.save_checkpoint(model, optimizer, scheduler, epoch, train_acc, val_acc, avg_loss)

        # Upload
        ckpt.upload(epoch)

    # Final upload
    ckpt.upload(epochs, force=True)
    ckpt.close()

    print()
    print("=" * 70)
    print("TRAINING COMPLETE")
    print("=" * 70)
    print(f"Preset: {preset}")
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Params: {total_params:,}")
    print(f"Checkpoints: {ckpt.run_dir}")
    print("=" * 70)

    return model, best_acc


# ============================================================================
# MAIN
# ============================================================================
if __name__ == '__main__':
    model, best_acc = train_fashion_mnist(
        preset='fashion_mobius_small',
        epochs=50,
        lr=1e-3,
        batch_size=128,
        output_dir='./outputs',
        hf_repo='AbstractPhil/mobiusnet-collective',  # Set to None to disable upload
        use_compile=True,
        save_every=10,
        upload_every=20,
    )