metadata
license: apache-2.0
tags:
- vision
- image-classification
- ternary
- quantization
- vit
datasets:
- imagenet-1k
- cifar10
- cifar100
FTerViT: Fully Ternary Vision Transformer
Pretrained checkpoints for FTerViT — the first fully ternary Vision Transformer where all weight matrices and normalization parameters are constrained to {-1, 0, +1}.
W2A8 · 2-bit weights · 8-bit activations · 100% ternary · 15x compression · sub-6 MB models
🏆 Key Results
All models use W2A8 (2-bit weights, 8-bit activations) with 100% ternary coverage — including patch embedding, LayerNorm, and classifier head.
📊 ImageNet-1K
| Model | Phase | Epochs | Top-1 (%) | Binary (MB) | Compression | Checkpoint |
|---|---|---|---|---|---|---|
| DeiT-Small | Phase 1 | 250 | 75.05 | 5.81 | 15.2x | download |
| DeiT-III-Small | Phase 1 | 250 | 76.78 | 5.81 | 15.2x | download |
| DeiT-Small | Phase 2 | +10 | 77.47 | 5.81 | 15.2x | download |
| DeiT-III-Small | Phase 2 | +10 | 79.64 | 5.81 | 15.2x | download |
| DeiT-III-Small (384) | Phase 1 | 250 | 78.35 | 6.09 | 14.6x | download |
| DeiT-III-Small (384) | Phase 2 | +10 | 82.43 | 6.09 | 14.6x | download |
📊 CIFAR-10 / CIFAR-100
| Model | Dataset | Top-1 (%) | FP32 Baseline | Binary (MB) | Checkpoint |
|---|---|---|---|---|---|
| DeiT-Tiny | CIFAR-10 | 97.43 | 97.52 | 1.53 | download |
| DeiT-Tiny | CIFAR-100 | 86.01 | 86.54 | 1.53 | download |
🔧 Training Protocol
Training uses a two-phase knowledge distillation approach:
- Phase 1: QAD (quantization-aware distillation) with a frozen FP32 teacher, KL-only loss, lr=1e-4 cosine decay, 250 epochs
- Phase 2: Low-lr recovery fine-tuning, lr=1e-5 cosine decay, 10 epochs
See the paper for full details.
🚀 Self-Contained Inference Example
The code below loads and evaluates an FTerViT checkpoint without any external dependencies beyond torch, torchvision, timm, and huggingface_hub. All ternary layer definitions are included inline.
"""
FTerViT — self-contained inference example.
Requirements: pip install torch timm huggingface_hub torchvision
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
# ============================================================================
# Ternary quantization primitives
# ============================================================================
def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize activations to INT8 levels, one scale per last-dim slice.

    The scale maps the largest absolute value along the last dimension to
    127. Values are scaled, rounded, clamped to [-128, 127] and divided by
    the same scale again, so the result stays a floating-point tensor of
    the input's shape.

    Args:
        x: Activation tensor; statistics are taken over ``dim=-1``.

    Returns:
        Tensor of the same shape as ``x`` holding the de-quantized values.
    """
    # Out-of-place clamp on purpose: autograd keeps amax's output for its
    # backward pass, and modifying that tensor in place makes a later
    # backward() through this function fail.
    max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    scale = 127.0 / max_abs
    return (x * scale).round().clamp(-128, 127) / scale
def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize a 4-D activation tensor to INT8 levels per (dim 0, dim 1) slice.

    The absolute maximum is taken over dimensions 2 and 3, so every slice
    indexed by the first two dimensions gets its own scale mapping that
    maximum to 127. Values are scaled, rounded, clamped to [-128, 127] and
    divided by the scale again.

    Args:
        x: Activation tensor with at least four dimensions (intended for
            Conv2d inputs in NCHW layout).

    Returns:
        Tensor of the same shape as ``x`` holding the de-quantized values.
    """
    # Out-of-place clamp on purpose: autograd keeps amax's output for its
    # backward pass, and modifying that tensor in place makes a later
    # backward() through this function fail.
    max_abs = x.abs().amax(dim=(2, 3), keepdim=True).clamp(min=1e-5)
    scale = 127.0 / max_abs
    return (x * scale).round().clamp(-128, 127) / scale
def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Snap a tensor to three levels using a single absmean scale.

    The whole tensor shares one scale, the reciprocal of its mean absolute
    value (floored at 1e-5). After scaling, entries are rounded and clamped
    to {-1, 0, +1}, then divided by the scale so the output is expressed in
    the original magnitude.
    """
    mean_abs = torch.clamp(w.abs().mean(), min=1e-5)
    inv_step = 1.0 / mean_abs
    ternary = torch.clamp(torch.round(w * inv_step), -1, 1)
    return ternary / inv_step
def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Snap a tensor to three levels with one absmean scale per dim-0 slice.

    The mean absolute value is reduced over every dimension except the
    first (floored at 1e-5), so each index along dim 0 has its own scale.
    Scaled entries are rounded, clamped to {-1, 0, +1} and divided by the
    scale again.
    """
    reduce_dims = tuple(range(1, w.dim()))
    mean_abs = torch.clamp(w.abs().mean(dim=reduce_dims, keepdim=True), min=1e-5)
    inv_step = 1.0 / mean_abs
    ternary = torch.clamp(torch.round(w * inv_step), -1, 1)
    return ternary / inv_step
# ============================================================================
# Ternary layer definitions
# ============================================================================
class BitLinear(nn.Linear):
"""Linear layer with ternary weights and INT8 activations."""
def __init__(self, in_features, out_features, bias=True):
super().__init__(in_features, out_features, bias)
self.norm = nn.RMSNorm(in_features, eps=1e-5)
def forward(self, x):
x_norm = self.norm(x)
if not self.training:
max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
x_scale = 127.0 / max_val
x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
w_f = self.weight.float()
w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
w_q = (w_f * w_scale).round().clamp_(-1, 1)
y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
return y.to(x_norm.dtype)
# Training path (STE)
x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
return F.linear(x_q, w_q)
class BitConv2d(nn.Conv2d):
    """Conv2d with per-output-channel ternary weights and INT8 activations.

    Weights are ternarized with one absmean scale per output channel and
    then multiplied by ``channel_scale``, a parameter of shape
    ``(out_channels,)`` initialised to ones. Activations are fake-quantized
    with ``activation_quant_2d``.

    Args:
        Same as ``nn.Conv2d`` (``in_channels``, ``out_channels``,
        ``kernel_size``, ``stride``, ``padding``, ``dilation``, ``groups``,
        ``bias``, ``padding_mode``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _quant_weight(self):
        """Return the ternarized weight multiplied by the per-channel scale."""
        w_q = weight_quant_ternary_per_channel(self.weight)
        # Reshape (out_channels,) -> (out_channels, 1, 1, 1) for broadcasting.
        return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def forward(self, x):
        if self.training:
            # Straight-through estimator: quantized values in the forward
            # pass, identity gradient to x and self.weight in the backward.
            # NOTE(review): channel_scale sits inside the detached term, so
            # it receives no gradient on this path - confirm this is intended.
            x_q = x + (activation_quant_2d(x) - x).detach()
            w_q = self.weight + (self._quant_weight() - self.weight).detach()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        # Delegate to nn.Conv2d's own forward helper so that a non-'zeros'
        # padding_mode given to the constructor is honoured; for 'zeros' it
        # is the plain F.conv2d call with stride/padding/dilation/groups.
        return self._conv_forward(x_q, w_q, self.bias)
class TernaryLayerNorm(nn.Module):
    """Layer normalization whose affine scale and shift are ternarized.

    ``weight`` (initialised to ones) and ``bias`` (initialised to zeros)
    are stored in full precision and passed through
    ``weight_quant_ternary`` on every forward call. In training mode the
    quantization uses a straight-through estimator so both parameters keep
    receiving gradients.

    Args:
        normalized_shape: int or sequence of ints giving the trailing
            dimensions to normalize over.
        eps: Value forwarded to ``F.layer_norm``.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        gamma = weight_quant_ternary(self.weight)
        beta = weight_quant_ternary(self.bias)
        if self.training:
            # Straight-through: forward sees the ternary values, backward
            # treats the quantizer as the identity.
            gamma = self.weight + (gamma - self.weight).detach()
            beta = self.bias + (beta - self.bias).detach()
        return F.layer_norm(x, self.normalized_shape, gamma, beta, self.eps)
# ============================================================================
# Model conversion: FP32 timm model -> fully ternary
# ============================================================================
def make_ternary(model: nn.Module) -> nn.Module:
    """Swap Linear, LayerNorm and the patch-embed Conv2d for ternary layers.

    The model is modified in place and also returned. Replacement layers
    are freshly constructed from the shape/config attributes of the layer
    they replace; the replaced layer's parameter values are not copied, so
    a checkpoint is expected to be loaded afterwards.

    Conversions:
        * every ``nn.Linear`` submodule -> ``BitLinear``
        * every ``nn.LayerNorm`` submodule -> ``TernaryLayerNorm``
        * ``model.patch_embed.proj`` if it is an ``nn.Conv2d`` -> ``BitConv2d``

    Layers that are already ``BitLinear`` / ``BitConv2d`` are left alone,
    so calling this function twice does not re-initialise them. The root
    module itself is never replaced, only its descendants.

    Args:
        model: Module tree to convert.

    Returns:
        The same ``model`` object.
    """
    # Snapshot the module list first: the tree is mutated during the walk.
    for name, module in list(model.named_modules()):
        if not name:
            # The root has no parent attribute a replacement could be
            # attached to.
            continue
        if isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
            replacement = BitLinear(module.in_features, module.out_features,
                                    bias=module.bias is not None)
        elif isinstance(module, nn.LayerNorm):
            replacement = TernaryLayerNorm(module.normalized_shape, eps=module.eps)
        else:
            continue
        parent_name, _, attr = name.rpartition(".")
        # get_submodule("") returns the model itself, which covers
        # top-level children without rebuilding a name->module dict.
        setattr(model.get_submodule(parent_name), attr, replacement)

    # Patch embed Conv2d -> BitConv2d
    patch_embed = getattr(model, "patch_embed", None)
    proj = getattr(patch_embed, "proj", None)
    if isinstance(proj, nn.Conv2d) and not isinstance(proj, BitConv2d):
        # Carry over the full conv configuration, not just stride/padding.
        patch_embed.proj = BitConv2d(
            proj.in_channels, proj.out_channels, proj.kernel_size,
            stride=proj.stride, padding=proj.padding, dilation=proj.dilation,
            groups=proj.groups, bias=proj.bias is not None,
            padding_mode=proj.padding_mode,
        )
    return model
# ============================================================================
# Load and evaluate
# ============================================================================
import timm
from timm.data import create_transform, resolve_data_config
from torchvision import datasets, transforms

# --- Configuration (change these) ---
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"          # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # ImageNet root containing val/, or CIFAR download dir
NUM_CLASSES = 1000            # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ------------------------------------

# The guard matters because the DataLoader below uses worker processes:
# under the "spawn" start method each worker re-imports this file, and
# without the guard it would re-run the whole evaluation.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Build model + ternary conversion
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    model = make_ternary(model)

    # 2. Load checkpoint
    path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
    # SECURITY: weights_only=False unpickles arbitrary Python objects from
    # the downloaded file. Only use it with checkpoints you trust; switch to
    # weights_only=True if the checkpoint is a plain tensor state dict.
    sd = torch.load(path, map_location="cpu", weights_only=False)
    sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
    # strict=False tolerates key mismatches, so report them instead of
    # silently evaluating a partially initialised model.
    load_result = model.load_state_dict(sd, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: {len(load_result.missing_keys)} missing and "
              f"{len(load_result.unexpected_keys)} unexpected checkpoint keys")
        print(f"  missing: {load_result.missing_keys[:5]}")
        print(f"  unexpected: {load_result.unexpected_keys[:5]}")
    model = model.to(device).eval()

    # 3. Build eval dataloader
    # The converted model is the same object timm created, so the data
    # config can be resolved from it instead of building a second model.
    config = resolve_data_config({}, model=model)
    if DATASET == "imagenet":
        transform = create_transform(**config, is_training=False)
        val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
    elif DATASET in ("cifar10", "cifar100"):
        # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([
            transforms.Resize((config["input_size"][1], config["input_size"][2])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # Pick the dataset class from DATASET rather than from NUM_CLASSES
        # so the two settings cannot silently disagree.
        cifar_cls = datasets.CIFAR10 if DATASET == "cifar10" else datasets.CIFAR100
        val_dataset = cifar_cls(root=DATA_DIR, train=False, download=True, transform=transform)
    else:
        raise ValueError(
            f"Unsupported DATASET {DATASET!r}; expected 'imagenet', 'cifar10', or 'cifar100'"
        )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
        pin_memory=device.type == "cuda",  # pinning only helps host->GPU copies
    )

    # 4. Evaluate
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Avoid a ZeroDivisionError in the accuracy computation below.
        raise RuntimeError("No validation samples were loaded; check DATASET and DATA_DIR")
    print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
    print(f"Evaluated {total} samples")
📝 Citation
@article{rucinski2026ftervit,
title={FTerViT: Fully Ternary Vision Transformer},
author={Ruci{\'n}ski, Szymon and Bonazzi, Pietro and Turetken, Engin and Narduzzi, Simon and Magno, Michele and Maamari, Nadim},
journal={arXiv preprint arXiv:2605.21171},
year={2026}
}