---
license: apache-2.0
tags:
- vision
- image-classification
- ternary
- quantization
- vit
datasets:
- imagenet-1k
- cifar10
- cifar100
---

# FTerViT: Fully Ternary Vision Transformer

[![arXiv](https://img.shields.io/badge/arXiv-2605.21171-B31B1B?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2605.21171)
[![GitHub](https://img.shields.io/badge/GitHub-FTerViT-181717?style=for-the-badge&logo=github&logoColor=white)](https://github.com/szymonrucinski/FTerViT)
[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-FTerViT-FFD21E?style=for-the-badge)](https://huggingface.co/szymonrucinski/FTerViT)
[![NeurIPS](https://img.shields.io/badge/NeurIPS-2026-purple?style=for-the-badge)](https://neurips.cc/)
[![Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-Live-orange?style=for-the-badge)](https://huggingface.co/spaces/szymonrucinski/FTerViT-demo)

Pretrained checkpoints for **FTerViT**, the first fully ternary Vision Transformer, in which *all* weight matrices and normalization parameters are constrained to {-1, 0, +1}.

> **W2A8** · 2-bit weights · 8-bit activations · **100% ternary** · ~15x compression · ~6 MB models

## 🏆 Key Results

All models use **W2A8** (2-bit weights, 8-bit activations) with 100% ternary coverage, including the patch embedding, LayerNorm, and classifier head.
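
Because every weight takes one of three values, it fits in a 2-bit code. The sketch below illustrates that idea only; it is not the storage format used for the checkpoints or for the binary sizes reported below. It ternarizes a random matrix shaped like a DeiT-Small MLP `fc1` weight (1536 × 384) with absmean scaling, the same rule as `weight_quant_ternary` in the inference example, then packs four codes per byte using a hypothetical encoding (-1, 0, +1 mapped to codes 0, 1, 2) and checks the round trip. The helpers `ternarize`, `pack_ternary`, and `unpack_ternary` are illustrative names, not part of the FTerViT code.

```python
import torch


def ternarize(w: torch.Tensor) -> torch.Tensor:
    """Absmean ternarization: integer codes in {-1, 0, +1} (the scale is not returned)."""
    scale = w.abs().mean().clamp(min=1e-5)
    return (w / scale).round().clamp(-1, 1)


def pack_ternary(q: torch.Tensor) -> torch.Tensor:
    """Pack {-1, 0, +1} values into uint8, four 2-bit codes per byte.

    Hypothetical encoding: -1 -> 0, 0 -> 1, +1 -> 2 (code 3 is unused).
    """
    codes = q.flatten().to(torch.int64) + 1
    pad = (-codes.numel()) % 4  # pad to a multiple of 4 values
    codes = torch.cat([codes, torch.ones(pad, dtype=torch.int64, device=codes.device)])  # pad with the code for 0
    codes = codes.view(-1, 4)
    packed = codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)
    return packed.to(torch.uint8)


def unpack_ternary(packed: torch.Tensor, numel: int) -> torch.Tensor:
    """Inverse of pack_ternary: return the first `numel` values as int8 in {-1, 0, +1}."""
    p = packed.to(torch.int64)
    codes = torch.stack([(p >> s) & 0b11 for s in (0, 2, 4, 6)], dim=1).flatten()
    return (codes[:numel] - 1).to(torch.int8)


torch.manual_seed(0)
w = torch.randn(1536, 384)  # shaped like a DeiT-Small MLP fc1 weight
q = ternarize(w)
packed = pack_ternary(q)

# The round trip recovers the ternary codes exactly
assert torch.equal(unpack_ternary(packed, q.numel()).float(), q.flatten())

fp32_bytes = w.numel() * 4
print(f"FP32: {fp32_bytes} B, packed: {packed.numel()} B, ratio: {fp32_bytes / packed.numel():.1f}x")
# FP32: 2359296 B, packed: 147456 B, ratio: 16.0x
```

Storing only the packed codes gives exactly 16x over FP32 for this matrix. The 14.6x to 15.2x figures in the tables come from the authors' binary format, which this sketch does not reproduce (for example, it stores no scales at all).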

### 📊 ImageNet-1K

| Model | Phase | Epochs | Top-1 (%) | Binary size (MB) | Compression (vs. FP32) | Checkpoint |
|-------|-------|--------|-----------|------------------|------------------------|------------|
| DeiT-Small | Phase 1 | 250 | 75.05 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc75.05_deit_small_224.pth) |
| DeiT-III-Small | Phase 1 | 250 | 76.78 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc76.78_deit3_small_224.pth) |
| DeiT-Small | Phase 2 | +10 | **77.47** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc77.47_deit_small_224.pth) |
| DeiT-III-Small | Phase 2 | +10 | **79.64** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth) |
| DeiT-III-Small (384) | Phase 1 | 250 | 78.35 | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc78.35_deit3_small_384.pth) |
| DeiT-III-Small (384) | Phase 2 | +10 | **82.43** | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc82.43_deit3_small_384.pth) |

### 📊 CIFAR-10 / CIFAR-100

| Model | Dataset | Top-1 (%) | FP32 Baseline (%) | Binary size (MB) | Checkpoint |
|-------|---------|-----------|-------------------|------------------|------------|
| DeiT-Tiny | CIFAR-10 | **97.43** | 97.52 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar10/phase2_ep010_acc97.43_deit_tiny_224.pth) |
| DeiT-Tiny | CIFAR-100 | **86.01** | 86.54 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar100/phase2_ep010_acc86.01_deit_tiny_224.pth) |

## 🔧 Training Protocol

Training uses a two-phase knowledge distillation approach:

- **Phase 1:** quantization-aware distillation (QAD) from a frozen FP32 teacher, KL-divergence loss only, lr=1e-4 with cosine decay, 250 epochs
- **Phase 2:** low-lr recovery fine-tuning, lr=1e-5 with cosine decay, 10 additional epochs

See the paper for full details.

## 🚀 Self-Contained Inference Example

The script below downloads an FTerViT checkpoint from the Hugging Face Hub and measures its top-1 accuracy using only `torch`, `torchvision`, `timm`, and `huggingface_hub`. All ternary layer definitions are included inline.

```python
"""
FTerViT: self-contained inference example.

Requirements:
    pip install "torch>=2.4" timm huggingface_hub torchvision
    (torch>=2.4 is needed for nn.RMSNorm)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from huggingface_hub import hf_hub_download
from timm.data import resolve_data_config, create_transform
from torchvision import datasets, transforms

# ============================================================================
# Ternary quantization primitives
# (each returns "fake-quantized" values: integer codes rescaled back to float)
# ============================================================================

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Per-token INT8 activation quantization (absmax over the last dim)."""
    scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale


def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """Per-channel INT8 activation quantization for Conv2d inputs (NCHW)."""
    scale = 127.0 / x.abs().amax(dim=(2, 3), keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale


def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternary weight quantization to {-1, 0, +1} with absmean scaling."""
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale


def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Per-output-channel ternary quantization with absmean scaling."""
    scale = 1.0 / w.abs().mean(dim=tuple(range(1, w.dim())), keepdim=True).clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

# ============================================================================
# Ternary layer definitions
# ============================================================================

class BitLinear(nn.Linear):
    """Linear layer with ternary weights and INT8 activations (RMSNorm on the input)."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.norm = nn.RMSNorm(in_features, eps=1e-5)

    def forward(self, x):
        x_norm = self.norm(x)
        # NOTE: as written, neither path below adds self.bias.
        if not self.training:
            # Inference path: integer-valued codes, rescaled after the matmul
            max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
            x_scale = 127.0 / max_val
            x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
            w_f = self.weight.float()
            w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
            w_q = (w_f * w_scale).round().clamp_(-1, 1)  # values in {-1, 0, +1}
            y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
            return y.to(x_norm.dtype)
        # Training path: straight-through estimator (STE)
        x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
        return F.linear(x_q, w_q)


class BitConv2d(nn.Conv2d):
    """Conv2d with per-channel ternary weights and INT8 activations."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, padding_mode="zeros"):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        # Per-output-channel scale multiplied onto the ternary weights
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _quant_weight(self):
        w_q = weight_quant_ternary_per_channel(self.weight)
        return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def forward(self, x):
        if self.training:
            x_q = x + (activation_quant_2d(x) - x).detach()
            w_q = self.weight + (self._quant_weight() - self.weight).detach()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        return F.conv2d(x_q, w_q, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


class TernaryLayerNorm(nn.Module):
    """LayerNorm with ternary affine parameters (gamma, beta)."""

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        if self.training:
            w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
            b_q = self.bias + (weight_quant_ternary(self.bias) - self.bias).detach()
        else:
            w_q = weight_quant_ternary(self.weight)
            b_q = weight_quant_ternary(self.bias)
        return F.layer_norm(x, self.normalized_shape, w_q, b_q, self.eps)

# ============================================================================
# Model conversion: FP32 timm model -> fully ternary
# ============================================================================

def _set_submodule(model: nn.Module, name: str, new: nn.Module) -> None:
    """Replace the submodule at dotted path `name` with `new`."""
    parent_name, _, attr = name.rpartition(".")
    setattr(model.get_submodule(parent_name), attr, new)


def make_ternary(model: nn.Module) -> nn.Module:
    """Swap all Linear, LayerNorm, and the patch-embed Conv2d for ternary layers.

    Only the module structure changes; the weights come from the checkpoint.
    """
    # Snapshot the module list first, since we replace modules while iterating
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            _set_submodule(model, name, BitLinear(
                module.in_features, module.out_features, bias=module.bias is not None))
        elif isinstance(module, nn.LayerNorm):
            _set_submodule(model, name, TernaryLayerNorm(module.normalized_shape, eps=module.eps))

    # Patch embed Conv2d -> BitConv2d
    patch_embed = getattr(model, "patch_embed", None)
    if patch_embed is not None and isinstance(getattr(patch_embed, "proj", None), nn.Conv2d):
        old = patch_embed.proj
        patch_embed.proj = BitConv2d(old.in_channels, old.out_channels, old.kernel_size,
                                     stride=old.stride, padding=old.padding,
                                     bias=old.bias is not None)
    return model

# ============================================================================
# Load and evaluate
# ============================================================================

# --- Configuration (change these) ---
# MODEL_NAME must match the checkpoint's architecture
# (the CIFAR-10/100 checkpoints are DeiT-Tiny models).
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"          # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # ImageNet root containing val/, or CIFAR download dir
NUM_CLASSES = 1000            # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ------------------------------------


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Build the FP32 architecture, then swap in the ternary layers
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    model = make_ternary(model)

    # 2. Load the checkpoint (weights_only=False unpickles arbitrary objects,
    #    so only use it with files you trust)
    path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
    sd = torch.load(path, map_location="cpu", weights_only=False)
    sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
    # strict=False skips mismatched keys silently, so report them
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(f"Missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    model = model.to(device).eval()

    # 3. Build the eval dataloader
    config = resolve_data_config({}, model=timm.create_model(MODEL_NAME, pretrained=False))
    if DATASET == "imagenet":
        transform = create_transform(**config, is_training=False)
        # ImageFolder expects one sub-folder per class under DATA_DIR/val
        val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
    else:
        # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([
            transforms.Resize((config["input_size"][1], config["input_size"][2])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        cifar_cls = datasets.CIFAR10 if DATASET == "cifar10" else datasets.CIFAR100
        val_dataset = cifar_cls(root=DATA_DIR, train=False, download=True, transform=transform)

    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                             num_workers=4, pin_memory=True)

    # 4. Evaluate top-1 accuracy
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
    print(f"Evaluated {total} samples")


# The __main__ guard lets DataLoader workers start safely on spawn-based platforms
if __name__ == "__main__":
    main()
```

## 📝 Citation

```bibtex
@article{rucinski2026ftervit,
  title={FTerViT: Fully Ternary Vision Transformer},
  author={Ruci{\'n}ski, Szymon and Bonazzi, Pietro and Turetken, Engin and Narduzzi, Simon and Magno, Michele and Maamari, Nadim},
  journal={arXiv preprint arXiv:2605.21171},
  year={2026}
}
```