---
license: apache-2.0
tags:
- vision
- image-classification
- ternary
- quantization
- vit
datasets:
- imagenet-1k
- cifar10
- cifar100
---

# FTerViT: Fully Ternary Vision Transformer

[Paper (arXiv)](https://arxiv.org/abs/2605.21171) ·
[Code (GitHub)](https://github.com/szymonrucinski/FTerViT) ·
[Model (Hugging Face)](https://huggingface.co/szymonrucinski/FTerViT) ·
[NeurIPS](https://neurips.cc/) ·
[Demo (Hugging Face Space)](https://huggingface.co/spaces/szymonrucinski/FTerViT-demo)

Pretrained checkpoints for **FTerViT**, the first fully ternary Vision Transformer: *all* weight matrices and normalization parameters are constrained to {-1, 0, +1} (times a floating-point scale factor).

> **W2A8** · 2-bit weights · 8-bit activations · **100% ternary** · ~15x compression · 1.5 to 6.1 MB models

## 🏆 Key Results

All models use **W2A8** (2-bit weights, 8-bit activations) with 100% ternary coverage, including the patch embedding, LayerNorm, and classifier head.
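
A ternary weight carries log2(3) ≈ 1.58 bits of information and fits in a 2-bit field, so four weights share one byte. That is a 16x reduction over 32-bit floats before any scale factors or other metadata are counted; the reported ~15x sits slightly below that ceiling. The sketch below shows one way to pack ternary codes four to a byte. The encoding (field = code + 1) and the helper names `pack_ternary` / `unpack_ternary` are hypothetical illustrations, not the on-disk format of the FTerViT checkpoints.

```python
def pack_ternary(codes: list[int]) -> bytes:
    """Pack ternary codes {-1, 0, +1} into bytes, four 2-bit fields per byte.

    Hypothetical encoding: field = code + 1, so -1 -> 0b00, 0 -> 0b01, +1 -> 0b10.
    If len(codes) is not a multiple of 4, the last byte is padded with code 0.
    """
    out = bytearray()
    for start in range(0, len(codes), 4):
        group = codes[start:start + 4]
        group = group + [0] * (4 - len(group))  # pad the final group
        byte = 0
        for i, code in enumerate(group):
            if code not in (-1, 0, 1):
                raise ValueError(f"not a ternary code: {code}")
            byte |= (code + 1) << (2 * i)  # field i occupies bits 2i and 2i+1
        out.append(byte)
    return bytes(out)


def unpack_ternary(packed: bytes, n: int) -> list[int]:
    """Inverse of pack_ternary: recover the first n codes."""
    codes = []
    for byte in packed:
        for i in range(4):
            codes.append(((byte >> (2 * i)) & 0b11) - 1)
    return codes[:n]


weights = [1, 0, -1, -1, 0, 1, 1]  # 7 ternary codes
packed = pack_ternary(weights)
print(len(packed), "bytes")  # 2 bytes, versus 7 * 4 = 28 bytes as float32
print(unpack_ternary(packed, len(weights)) == weights)  # True
```

Decoding costs one shift and one mask per weight, so unpacking can happen either at load time or on the fly.
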

### 📊 ImageNet-1K

| Model | Phase | Epochs | Top-1 (%) | Binary size (MB) | Compression vs. FP32 | Checkpoint |
|-------|-------|--------|-----------|------------------|----------------------|------------|
| | DeiT-Small | Phase 1 | 250 | 75.05 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc75.05_deit_small_224.pth) | |
| | DeiT-III-Small | Phase 1 | 250 | 76.78 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc76.78_deit3_small_224.pth) | |
| | DeiT-Small | Phase 2 | +10 | **77.47** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc77.47_deit_small_224.pth) | |
| | DeiT-III-Small | Phase 2 | +10 | **79.64** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth) | |
| | DeiT-III-Small (384) | Phase 1 | 250 | 78.35 | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc78.35_deit3_small_384.pth) | |
| | DeiT-III-Small (384) | Phase 2 | +10 | **82.43** | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc82.43_deit3_small_384.pth) | |
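
The following is a quick consistency check on the table. It assumes the Compression column is the FP32 model size divided by the binary size, with 1 MB = 10^6 bytes. Under that assumption, multiplying the two columns recovers the implied FP32 size. The helper name `implied_fp32_mb` is hypothetical.

```python
# (label, binary size in MB, compression) copied from the ImageNet-1K table above
ROWS = [
    ("DeiT-Small / DeiT-III-Small, 224 px", 5.81, 15.2),
    ("DeiT-III-Small, 384 px", 6.09, 14.6),
]


def implied_fp32_mb(binary_mb: float, compression: float) -> float:
    """FP32 size implied by the table, assuming compression = FP32 size / binary size."""
    return binary_mb * compression


for label, binary_mb, ratio in ROWS:
    fp32_mb = implied_fp32_mb(binary_mb, ratio)
    # At 4 bytes per FP32 parameter, 1 MB (10^6 bytes) holds 0.25 M parameters
    print(f"{label}: ~{fp32_mb:.1f} MB in FP32, i.e. ~{fp32_mb / 4:.1f} M parameters")
```

Both rows come out near 22 M parameters, in line with the roughly 22 M parameters usually quoted for DeiT-Small. The 384 px variant is slightly larger, which is consistent with its larger position-embedding table at 384 px input.
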

### 📊 CIFAR-10 / CIFAR-100

| Model | Dataset | Top-1 (%) | FP32 baseline Top-1 (%) | Binary size (MB) | Checkpoint |
|-------|---------|-----------|-------------------------|------------------|------------|
| | DeiT-Tiny | CIFAR-10 | **97.43** | 97.52 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar10/phase2_ep010_acc97.43_deit_tiny_224.pth) | |
| | DeiT-Tiny | CIFAR-100 | **86.01** | 86.54 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar100/phase2_ep010_acc86.01_deit_tiny_224.pth) | |

## 🔧 Training Protocol

Training uses two-phase knowledge distillation:

- **Phase 1:** quantization-aware distillation (QAD) from a frozen FP32 teacher, KL-divergence loss only, lr = 1e-4 with cosine decay, 250 epochs.
- **Phase 2:** low-learning-rate recovery fine-tuning, lr = 1e-5 with cosine decay, 10 epochs.
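
Here is a minimal sketch of the two-phase recipe above, under several assumptions. The KL term compares softmax distributions at temperature 1; this is a hypothetical choice, since the card does not state a temperature. The optimizer is AdamW, which is also an assumption. The cosine schedule is stepped once per epoch via `torch.optim.lr_scheduler.CosineAnnealingLR`. Tiny `nn.Linear` layers stand in for the ternary student and the frozen FP32 teacher. The names `kl_distillation_loss` and `run_phase` are illustrative, not the authors' training code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def kl_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) between temperature-softened class distributions.

    temperature=1.0 is an assumption; the card only specifies a KL-only loss.
    """
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    log_p_teacher = F.log_softmax(teacher_logits / t, dim=-1)
    # F.kl_div(input, target) computes KL(target || input); log_target=True means the
    # target is also given as log-probabilities. The t^2 factor keeps gradient
    # magnitudes comparable across temperatures.
    return F.kl_div(log_p_student, log_p_teacher, reduction="batchmean", log_target=True) * t * t


def run_phase(student, teacher, lr, epochs, steps_per_epoch=2):
    """One training phase: AdamW (assumed) with per-epoch cosine decay from lr to 0."""
    opt = torch.optim.AdamW(student.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            x = torch.randn(16, 8)  # stand-in for an image batch
            with torch.no_grad():  # the teacher stays frozen
                teacher_logits = teacher(x)
            loss = kl_distillation_loss(student(x), teacher_logits)
            opt.zero_grad()
            loss.backward()
            opt.step()
        sched.step()
    return opt.param_groups[0]["lr"]


torch.manual_seed(0)
student = nn.Linear(8, 4)  # stand-in for the ternary student
teacher = nn.Linear(8, 4).eval().requires_grad_(False)  # stand-in for the frozen FP32 teacher

final_lr_1 = run_phase(student, teacher, lr=1e-4, epochs=250)  # Phase 1
final_lr_2 = run_phase(student, teacher, lr=1e-5, epochs=10)  # Phase 2, starting from Phase 1 weights
print(final_lr_1, final_lr_2)
```

Phase 2 creates a fresh optimizer with a 10x lower peak learning rate but keeps the student weights from Phase 1. That is the "recovery fine-tuning" step in its simplest form.
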

See the [paper](https://arxiv.org/abs/2605.21171) for full details.
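
Quantization-aware training has to backpropagate through `round()`, whose gradient is zero almost everywhere. The training branches of the layers in the inference example below use the straight-through estimator (STE) pattern `w + (q(w) - w).detach()` for this. The forward pass sees the quantized value, while the backward pass treats quantization as the identity. The sketch below isolates that pattern for absmean ternary weights. `ternary_ste` is a hypothetical helper name, and the sketch illustrates the general technique rather than reproducing the authors' training code.

```python
import torch


def ternary_ste(w: torch.Tensor) -> torch.Tensor:
    """Absmean ternary fake-quantization with a straight-through estimator.

    Forward: each weight becomes -s, 0 or +s, where s = mean(|w|).
    Backward: d(out)/d(w) is the identity, because the quantization error
    (w_q - w) enters the graph as a detached constant.
    """
    s = w.detach().abs().mean().clamp(min=1e-5)
    w_q = (w.detach() / s).round().clamp(-1, 1) * s
    return w + (w_q - w).detach()


w = torch.tensor([0.9, -0.05, -1.2, 0.3], requires_grad=True)
out = ternary_ste(w)  # s = mean(|w|) = 0.6125
print(out.detach())  # approximately [0.6125, 0.0, -0.6125, 0.0]
out.sum().backward()
print(w.grad)  # tensor([1., 1., 1., 1.]): the gradient passes straight through
```
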

## 🚀 Self-Contained Inference Example

The code below loads and evaluates an FTerViT checkpoint with no dependencies beyond `torch` (2.4 or newer, for `nn.RMSNorm`), `torchvision`, `timm`, and `huggingface_hub`. All ternary layer definitions are included inline.

```python
"""
FTerViT: self-contained inference example.
Requirements: pip install "torch>=2.4" timm huggingface_hub torchvision
(torch 2.4+ is needed for nn.RMSNorm.)
"""
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from timm.data import create_transform, resolve_data_config
from torchvision import datasets, transforms

# ============================================================================
# Ternary quantization primitives
# ============================================================================

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Per-token INT8 activation quantization (returns dequantized values)."""
    scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """INT8 activation quantization for Conv2d inputs (NCHW), one scale per sample and channel."""
    scale = 127.0 / x.abs().amax(dim=(2, 3), keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternary weight quantization: {-1, 0, +1} with per-tensor absmean scaling."""
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Ternary quantization with one absmean scale per output channel."""
    scale = 1.0 / w.abs().mean(dim=tuple(range(1, w.dim())), keepdim=True).clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

# ============================================================================
# Ternary layer definitions
# ============================================================================

class BitLinear(nn.Linear):
    """Linear layer with ternary weights and INT8 activations (input is RMS-normalized first)."""
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.norm = nn.RMSNorm(in_features, eps=1e-5)

    def forward(self, x):
        # Note: self.bias is not applied in either path below.
        x_norm = self.norm(x)
        if not self.training:
            # Inference path: integer-valued codes, both scales applied once after the matmul
            max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
            x_scale = 127.0 / max_val
            x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
            w_f = self.weight.float()
            w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
            w_q = (w_f * w_scale).round().clamp_(-1, 1)
            y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
            return y.to(x_norm.dtype)
        # Training path: fake quantization with a straight-through estimator (STE)
        x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
        return F.linear(x_q, w_q)

class BitConv2d(nn.Conv2d):
    """Conv2d with per-channel ternary weights and INT8 activations."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        # Learnable per-output-channel scale applied on top of the ternary weights
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _quant_weight(self):
        w_q = weight_quant_ternary_per_channel(self.weight)
        return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def forward(self, x):
        if self.training:
            x_q = x + (activation_quant_2d(x) - x).detach()
            w_q = self.weight + (self._quant_weight() - self.weight).detach()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        return F.conv2d(x_q, w_q, self.bias, self.stride, self.padding, self.dilation, self.groups)

class TernaryLayerNorm(nn.Module):
    """LayerNorm with ternary affine parameters (gamma, beta)."""
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        if self.training:
            w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
            b_q = self.bias + (weight_quant_ternary(self.bias) - self.bias).detach()
        else:
            w_q = weight_quant_ternary(self.weight)
            b_q = weight_quant_ternary(self.bias)
        return F.layer_norm(x, self.normalized_shape, w_q, b_q, self.eps)

# ============================================================================
# Model conversion: FP32 timm model -> fully ternary
# ============================================================================

def _set_submodule(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Replace the submodule at dotted path `name` (e.g. "blocks.0.attn.qkv")."""
    parent_name, _, attr = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, attr, new_module)

def make_ternary(model: nn.Module) -> nn.Module:
    """Swap every Linear, every LayerNorm and the patch-embedding Conv2d for ternary layers.

    The new layers are freshly initialized; load an FTerViT checkpoint afterwards.
    """
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            _set_submodule(model, name, BitLinear(module.in_features, module.out_features,
                                                  bias=module.bias is not None))
        elif isinstance(module, nn.LayerNorm):
            _set_submodule(model, name, TernaryLayerNorm(module.normalized_shape, eps=module.eps))
    # Patch embed Conv2d -> BitConv2d
    patch_embed = getattr(model, "patch_embed", None)
    if patch_embed is not None and isinstance(getattr(patch_embed, "proj", None), nn.Conv2d):
        old = patch_embed.proj
        patch_embed.proj = BitConv2d(old.in_channels, old.out_channels, old.kernel_size,
                                     stride=old.stride, padding=old.padding,
                                     bias=old.bias is not None)
    return model

# ============================================================================
# Load and evaluate
# ============================================================================

# --- Configuration (edit these) ---
# MODEL_NAME must match the checkpoint's architecture and input resolution.
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"          # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # ImageNet root containing val/, or the CIFAR download dir
NUM_CLASSES = 1000            # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ----------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Build the FP32 architecture, then swap in the ternary layers
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
model = make_ternary(model)

# 2. Download and load the checkpoint
#    (weights_only=False unpickles arbitrary objects; only load files you trust)
path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
sd = torch.load(path, map_location="cpu", weights_only=False)
sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
result = model.load_state_dict(sd, strict=False)
# strict=False silently skips mismatched keys; a missing key keeps its random initialization
if result.missing_keys:
    print(f"Warning: {len(result.missing_keys)} missing keys, e.g. {result.missing_keys[:5]}")
model = model.to(device).eval()

# 3. Build the evaluation dataloader
config = resolve_data_config({}, model=model)

if DATASET == "imagenet":
    transform = create_transform(**config, is_training=False)
    val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
else:
    # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
    transform = transforms.Compose([
        transforms.Resize((config["input_size"][1], config["input_size"][2])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    cifar_cls = datasets.CIFAR10 if DATASET == "cifar10" else datasets.CIFAR100
    val_dataset = cifar_cls(root=DATA_DIR, train=False, download=True, transform=transform)

# On Windows/macOS, num_workers > 0 requires running this under `if __name__ == "__main__":`
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                         num_workers=4, pin_memory=device.type == "cuda")

# 4. Evaluate top-1 accuracy
correct = total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
print(f"Evaluated {total} samples")
```
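
The inference branch of `BitLinear` multiplies integer-valued codes (INT8-range activations and ternary weights) and applies both scales once at the end. The training branch instead multiplies the dequantized tensors. The activation scale is constant along each token's feature dimension, and the weight scale is a single scalar per tensor. As a result, dividing the integer product by the two scales gives the same result as multiplying the dequantized tensors. The standalone sketch below checks this numerically in float64 using the same quantization formulas; the tensor shapes are arbitrary. The example above runs the integer matmul in bfloat16, which represents integers up to 256 exactly but stores the output at bfloat16 precision, so its results match only approximately.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 16, dtype=torch.float64)  # (batch, tokens, features), stand-in for the normalized input
w = torch.randn(8, 16, dtype=torch.float64)     # (out_features, in_features)

# Per-token INT8 codes: one scale per token, shared across its features
x_scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
x_int = (x * x_scale).round().clamp(-128, 127)

# Per-tensor ternary codes with absmean scaling
w_scale = 1.0 / w.abs().mean().clamp(min=1e-5)
w_int = (w * w_scale).round().clamp(-1, 1)

# Route 1 (inference branch): integer-valued matmul, both scales applied once at the end
y_int = F.linear(x_int, w_int) / (w_scale * x_scale)

# Route 2 (training-branch forward): matmul on the dequantized tensors
y_deq = F.linear(x_int / x_scale, w_int / w_scale)

print(torch.allclose(y_int, y_deq))  # True
```

On integer hardware, this separation is useful because the inner loop only ever sees small integers (adds and subtracts for ternary weights), while the floating-point rescale happens once per output element.
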

## 📝 Citation

```bibtex
@article{rucinski2026ftervit,
  title={FTerViT: Fully Ternary Vision Transformer},
  author={Ruci{\'n}ski, Szymon and Bonazzi, Pietro and Turetken, Engin and Narduzzi, Simon and Magno, Michele and Maamari, Nadim},
  journal={arXiv preprint arXiv:2605.21171},
  year={2026}
}
```
