---
license: apache-2.0
tags:
- vision
- image-classification
- ternary
- quantization
- vit
datasets:
- imagenet-1k
- cifar10
- cifar100
---
# FTerViT: Fully Ternary Vision Transformer
[arXiv](https://arxiv.org/abs/2605.21171)
[GitHub](https://github.com/szymonrucinski/FTerViT)
[Hugging Face](https://huggingface.co/szymonrucinski/FTerViT)
[NeurIPS](https://neurips.cc/)
[Demo](https://huggingface.co/spaces/szymonrucinski/FTerViT-demo)
Pretrained checkpoints for **FTerViT** — the first fully ternary Vision Transformer where *all* weight matrices and normalization parameters are constrained to {-1, 0, +1}.
> **W2A8** · 2-bit weights · 8-bit activations · **100% ternary** · up to 15.2x compression · 1.5 to 6.1 MB models
## 🏆 Key Results
All models use **W2A8** (2-bit weights, 8-bit activations) with 100% ternary coverage — including patch embedding, LayerNorm, and classifier head.
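To make the "2-bit weights" figure concrete, here is a minimal sketch of how ternary values can be packed four to a byte. The layout (codes 0, 1, 2 stored little-end first within each byte) and the helper names `pack_ternary` and `unpack_ternary` are assumptions for illustration; the card does not specify the actual binary format used for the reported sizes.

```python
import random


def pack_ternary(values):
    """Pack values in {-1, 0, +1} into bytes, four 2-bit codes per byte (assumed layout)."""
    codes = [v + 1 for v in values]       # map {-1, 0, +1} -> {0, 1, 2}
    codes += [0] * (-len(codes) % 4)      # pad to a multiple of four codes
    return bytes(
        codes[i] | codes[i + 1] << 2 | codes[i + 2] << 4 | codes[i + 3] << 6
        for i in range(0, len(codes), 4)
    )


def unpack_ternary(packed, count):
    """Inverse of pack_ternary; `count` is the number of original values."""
    values = [((byte >> shift) & 0b11) - 1 for byte in packed for shift in (0, 2, 4, 6)]
    return values[:count]                 # drop the padding codes


random.seed(0)
weights = [random.choice((-1, 0, 1)) for _ in range(384 * 384)]  # one 384x384 matrix
packed = pack_ternary(weights)
assert unpack_ternary(packed, len(weights)) == weights

fp32_bytes = len(weights) * 4
print(f"{fp32_bytes} bytes (FP32) -> {len(packed)} bytes (packed), "
      f"{fp32_bytes / len(packed):.1f}x")
```

For a single weight matrix this gives an ideal 16x reduction over FP32. The slightly lower ratios reported in the tables below would be consistent with some additional data being stored alongside the packed weights, but that is a guess rather than something the card states.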
### 📊 ImageNet-1K
| Model | Phase | Epochs | Top-1 (%) | Binary (MB) | Compression | Checkpoint |
|-------|-------|--------|-----------|-------------|-------------|------------|
| DeiT-Small | Phase 1 | 250 | 75.05 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc75.05_deit_small_224.pth) |
| DeiT-III-Small | Phase 1 | 250 | 76.78 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc76.78_deit3_small_224.pth) |
| DeiT-Small | Phase 2 | +10 | **77.47** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc77.47_deit_small_224.pth) |
| DeiT-III-Small | Phase 2 | +10 | **79.64** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth) |
| DeiT-III-Small (384) | Phase 1 | 250 | 78.35 | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc78.35_deit3_small_384.pth) |
| DeiT-III-Small (384) | Phase 2 | +10 | **82.43** | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc82.43_deit3_small_384.pth) |
### 📊 CIFAR-10 / CIFAR-100
| Model | Dataset | Top-1 (%) | FP32 Baseline | Binary (MB) | Checkpoint |
|-------|---------|-----------|---------------|-------------|------------|
| DeiT-Tiny | CIFAR-10 | **97.43** | 97.52 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar10/phase2_ep010_acc97.43_deit_tiny_224.pth) |
| DeiT-Tiny | CIFAR-100 | **86.01** | 86.54 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar100/phase2_ep010_acc86.01_deit_tiny_224.pth) |
## 🔧 Training Protocol
Training uses a two-phase knowledge distillation approach:
- **Phase 1:** Quantization-aware distillation (QAD) with a frozen FP32 teacher, KL-only loss, lr=1e-4 with cosine decay, 250 epochs
- **Phase 2:** Low-learning-rate recovery fine-tuning, lr=1e-5 with cosine decay, 10 epochs
See the paper for full details.
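As a rough illustration of what a "KL-only" distillation loss with a frozen teacher can look like, here is a minimal sketch. The function name `kl_distillation_loss`, the `temperature` parameter, and the random logits are assumptions for illustration; the exact loss formulation and hyperparameters are in the paper and may differ.

```python
import torch
import torch.nn.functional as F


def kl_distillation_loss(student_logits, teacher_logits, temperature: float = 1.0):
    """KL(teacher || student) averaged over the batch, with no cross-entropy term."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    log_p_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)
    loss = F.kl_div(log_p_student, log_p_teacher, reduction="batchmean", log_target=True)
    return loss * temperature ** 2  # conventional rescaling; a no-op when temperature == 1


# Stand-ins for model outputs: the teacher is frozen, so its logits carry no gradient.
student_logits = torch.randn(8, 1000, requires_grad=True)
with torch.no_grad():
    teacher_logits = torch.randn(8, 1000)

loss = kl_distillation_loss(student_logits, teacher_logits)
loss.backward()  # gradients flow only into the student
print(f"distillation loss: {loss.item():.4f}")
```

In a real training loop the two logit tensors would come from the ternary student and the FP32 teacher evaluated on the same batch.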
## 🚀 Self-Contained Inference Example
The code below loads and evaluates an FTerViT checkpoint **without any external dependencies beyond `torch`, `torchvision`, `timm`, and `huggingface_hub`**. All ternary layer definitions are included inline.
```python
"""
FTerViT: self-contained inference example.

Requirements: pip install torch timm huggingface_hub torchvision
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# ============================================================================
# Ternary quantization primitives
# ============================================================================

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Per-token INT8 activation quantization."""
    scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale


def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """Per-channel INT8 activation quantization for Conv2d (NCHW)."""
    scale = 127.0 / x.abs().amax(dim=(2, 3), keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale


def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternary weight quantization: {-1, 0, +1} with absmean scaling."""
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale


def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Per-output-channel ternary quantization."""
    scale = 1.0 / w.abs().mean(dim=tuple(range(1, w.dim())), keepdim=True).clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

# ============================================================================
# Ternary layer definitions
# ============================================================================

class BitLinear(nn.Linear):
    """Linear layer with ternary weights and INT8 activations."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # nn.RMSNorm requires PyTorch 2.4 or newer.
        self.norm = nn.RMSNorm(in_features, eps=1e-5)

    def forward(self, x):
        # Note: self.bias is not added in either path below.
        x_norm = self.norm(x)
        if not self.training:
            # Inference path: integer-valued operands, rescaled after the matmul.
            max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
            x_scale = 127.0 / max_val
            x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
            w_f = self.weight.float()
            w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
            w_q = (w_f * w_scale).round().clamp_(-1, 1)
            y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
            return y.to(x_norm.dtype)
        # Training path (STE)
        x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
        return F.linear(x_q, w_q)


class BitConv2d(nn.Conv2d):
    """Conv2d with per-channel ternary weights and INT8 activations."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _quant_weight(self):
        # Ternarize per output channel, then apply the learned per-channel scale.
        w_q = weight_quant_ternary_per_channel(self.weight)
        return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def forward(self, x):
        if self.training:
            # Training path (STE)
            x_q = x + (activation_quant_2d(x) - x).detach()
            w_q = self.weight + (self._quant_weight() - self.weight).detach()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        return F.conv2d(x_q, w_q, self.bias, self.stride, self.padding, self.dilation, self.groups)


class TernaryLayerNorm(nn.Module):
    """LayerNorm with ternary affine parameters (gamma, beta)."""

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        if self.training:
            # Training path (STE)
            w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
            b_q = self.bias + (weight_quant_ternary(self.bias) - self.bias).detach()
        else:
            w_q = weight_quant_ternary(self.weight)
            b_q = weight_quant_ternary(self.bias)
        return F.layer_norm(x, self.normalized_shape, w_q, b_q, self.eps)

# ============================================================================
# Model conversion: FP32 timm model -> fully ternary
# ============================================================================

def make_ternary(model: nn.Module) -> nn.Module:
    """Convert all Linear, LayerNorm, and patch embed Conv2d to ternary.

    The replacement layers are freshly initialized; trained weights come from
    the checkpoint loaded afterwards.
    """
    # Linear -> BitLinear
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            parent_name, attr = name.rsplit(".", 1) if "." in name else ("", name)
            parent = model if not parent_name else dict(model.named_modules())[parent_name]
            setattr(parent, attr, BitLinear(module.in_features, module.out_features,
                                            bias=module.bias is not None))

    # LayerNorm -> TernaryLayerNorm
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.LayerNorm):
            parent_name, attr = name.rsplit(".", 1) if "." in name else ("", name)
            parent = model if not parent_name else dict(model.named_modules())[parent_name]
            setattr(parent, attr, TernaryLayerNorm(module.normalized_shape, eps=module.eps))

    # Patch embed Conv2d -> BitConv2d
    patch_embed = getattr(model, "patch_embed", None)
    if (patch_embed is not None and hasattr(patch_embed, "proj")
            and isinstance(patch_embed.proj, nn.Conv2d)):
        old = patch_embed.proj
        new = BitConv2d(old.in_channels, old.out_channels, old.kernel_size,
                        stride=old.stride, padding=old.padding, bias=old.bias is not None)
        patch_embed.proj = new

    return model

# ============================================================================
# Load and evaluate
# ============================================================================

import timm
from torchvision import datasets, transforms

# --- Configuration (change these) ---
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"          # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # directory containing ImageNet val/, or CIFAR download dir
NUM_CLASSES = 1000            # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Build model + ternary conversion
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
model = make_ternary(model)

# 2. Load checkpoint
path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
# weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
sd = torch.load(path, map_location="cpu", weights_only=False)
sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
model.load_state_dict(sd, strict=False)
model = model.to(device).eval()

# 3. Build eval dataloader
from timm.data import resolve_data_config, create_transform

config = resolve_data_config({}, model=timm.create_model(MODEL_NAME, pretrained=False))
if DATASET == "imagenet":
    transform = create_transform(**config, is_training=False)
    val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
else:
    # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
    transform = transforms.Compose([
        transforms.Resize((config["input_size"][1], config["input_size"][2])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    cls = datasets.CIFAR10 if NUM_CLASSES == 10 else datasets.CIFAR100
    val_dataset = cls(root=DATA_DIR, train=False, download=True, transform=transform)

val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                         num_workers=4, pin_memory=True)

# 4. Evaluate
correct = total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
print(f"Evaluated {total} samples")
```
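The two ideas the layers above rely on, absmean ternary quantization and the straight-through estimator (STE), can be checked in isolation. The sketch below re-declares `weight_quant_ternary` with out-of-place clamps so that it runs on its own; the random 64x64 tensor is an arbitrary stand-in for a real weight matrix.

```python
import torch


def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Absmean ternary quantizer, mirroring the one in the example above."""
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale


torch.manual_seed(0)
w = torch.randn(64, 64, requires_grad=True)

# Undo the absmean scale to recover the integer codes behind the quantized weights.
with torch.no_grad():
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    codes = (weight_quant_ternary(w) * scale).round().to(torch.int8)
print(sorted(set(codes.flatten().tolist())))  # the distinct ternary levels

# STE: the forward pass sees quantized weights, the backward pass sees the identity.
w_q = w + (weight_quant_ternary(w) - w).detach()
w_q.sum().backward()
print(torch.equal(w.grad, torch.ones_like(w)))  # True: gradient passes straight through
```

Because the quantization residual is detached, the gradient of `w_q` with respect to `w` is exactly one everywhere, which is what lets the latent full-precision weights keep training even though rounding itself has zero gradient almost everywhere.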
## 📝 Citation
```bibtex
@article{rucinski2026ftervit,
title={FTerViT: Fully Ternary Vision Transformer},
author={Ruci{\'n}ski, Szymon and Bonazzi, Pietro and Turetken, Engin and Narduzzi, Simon and Magno, Michele and Maamari, Nadim},
journal={arXiv preprint arXiv:2605.21171},
year={2026}
}
```