Image Classification
vision
ternary
quantization
vit

FTerViT: Fully Ternary Vision Transformer

arXiv GitHub HuggingFace NeurIPS Demo

Pretrained checkpoints for FTerViT — the first fully ternary Vision Transformer, in which all weight matrices and normalization parameters are constrained to {-1, 0, +1} (up to a per-tensor or per-channel scaling factor).

W2A8 · 2-bit weights · 8-bit activations · 100% ternary · ~15x compression · ~6 MB models (5.81–6.09 MB)

🏆 Key Results

All models use W2A8 (2-bit weights, 8-bit activations) with 100% ternary coverage — including patch embedding, LayerNorm, and classifier head.

📊 ImageNet-1K

Model Phase Epochs Top-1 (%) Binary (MB) Compression Checkpoint
DeiT-Small Phase 1 250 75.05 5.81 15.2x download
DeiT-III-Small Phase 1 250 76.78 5.81 15.2x download
DeiT-Small Phase 2 +10 77.47 5.81 15.2x download
DeiT-III-Small Phase 2 +10 79.64 5.81 15.2x download
DeiT-III-Small (384) Phase 1 250 78.35 6.09 14.6x download
DeiT-III-Small (384) Phase 2 +10 82.43 6.09 14.6x download

📊 CIFAR-10 / CIFAR-100

Model Dataset Top-1 (%) FP32 Baseline Binary (MB) Checkpoint
DeiT-Tiny CIFAR-10 97.43 97.52 1.53 download
DeiT-Tiny CIFAR-100 86.01 86.54 1.53 download

🔧 Training Protocol

Training uses a two-phase knowledge distillation approach:

  • Phase 1: quantization-aware distillation (QAD) from a frozen FP32 teacher, KL-divergence loss only, lr=1e-4 with cosine decay, 250 epochs
  • Phase 2: low-learning-rate recovery fine-tuning, lr=1e-5 with cosine decay, 10 additional epochs

See the paper for full details.

🚀 Self-Contained Inference Example

The code below loads and evaluates an FTerViT checkpoint with no dependencies beyond torch, torchvision, timm, and huggingface_hub. All ternary layer definitions are included inline.

"""
FTerViT — self-contained inference example.
Requirements: pip install torch timm huggingface_hub torchvision
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# ============================================================================
# Ternary quantization primitives
# ============================================================================

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Per-token INT8 activation quantization (fake-quant).

    Each slice along the last dimension gets its own symmetric scale that maps
    its largest magnitude to 127. Values are rounded, clamped to [-128, 127] and
    scaled back, so the result has the same shape as ``x`` but lies on an INT8 grid.
    All-zero slices stay zero (the max is floored at 1e-5).

    Args:
        x: Input tensor; quantization groups are slices along ``dim=-1``.

    Returns:
        Dequantized tensor with the same shape as ``x``.
    """
    # Out-of-place clamp: ``amax`` saves its output for backward, so clamping it
    # in place would make any backward pass through this function fail.
    max_abs = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    scale = 127.0 / max_abs
    return (x * scale).round().clamp(-128, 127) / scale

def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """Per-channel INT8 activation quantization for Conv2d (NCHW).

    The max magnitude is taken over the spatial dims (2, 3), so every
    (sample, channel) plane gets its own symmetric scale mapping that max to 127.
    Values are rounded, clamped to [-128, 127] and scaled back.

    Args:
        x: 4-D tensor laid out as (N, C, H, W).

    Returns:
        Dequantized tensor with the same shape as ``x``.
    """
    # Out-of-place clamp: ``amax`` saves its output for backward, so clamping it
    # in place would make any backward pass through this function fail.
    max_abs = x.abs().amax(dim=(2, 3), keepdim=True).clamp(min=1e-5)
    scale = 127.0 / max_abs
    return (x * scale).round().clamp(-128, 127) / scale

def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Snap ``w`` onto a scaled ternary grid {-a, 0, +a}, where a = mean(|w|).

    A single absmean step size is shared by the whole tensor (floored at 1e-5
    so an all-zero tensor maps to zeros). The returned tensor has ``w``'s shape.
    """
    inv_step = 1.0 / w.abs().mean().clamp_(min=1e-5)
    levels = torch.round(w * inv_step)
    return torch.clamp(levels, -1, 1) / inv_step

def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Ternarize ``w`` with one absmean step size per output channel (dim 0).

    The mean of |w| is taken over every dimension except the first, so each
    output channel is mapped onto its own grid {-a_c, 0, +a_c}.
    """
    reduce_dims = tuple(range(1, w.dim()))
    step = w.abs().mean(dim=reduce_dims, keepdim=True).clamp_(min=1e-5)
    inv_step = 1.0 / step
    return torch.clamp(torch.round(w * inv_step), -1, 1) / inv_step

# ============================================================================
# Ternary layer definitions
# ============================================================================

class BitLinear(nn.Linear):
    """Linear layer with ternary weights and INT8 activations.

    The input first goes through a learnable ``nn.RMSNorm`` over its last
    dimension. Activations are then quantized to INT8 with one scale per slice
    along the last dim, and the weight is quantized to a per-tensor ternary grid
    {-s, 0, +s} with s = mean(|W|).

    Note: ``bias`` is allocated by ``nn.Linear`` (so checkpoints that contain bias
    keys still load), but neither forward path passes it to ``F.linear``, so it
    is never applied. NOTE(review): presumably intentional, to keep the layer
    fully ternary; confirm against the training code.
    """
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        # Normalization applied to the input before activation quantization.
        self.norm = nn.RMSNorm(in_features, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """RMS-normalize ``x``, quantize, and return ``x_q @ w_q.T`` (no bias).

        Eval mode emulates an integer kernel: INT8-valued activations and
        {-1, 0, +1} weights are multiplied in bfloat16, and the product is
        rescaled afterwards. Training mode uses fake quantization with a
        straight-through estimator (STE).
        """
        x_norm = self.norm(x)
        if not self.training:
            # Per-row absmax scale mapping each row's largest magnitude to 127.
            max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
            x_scale = 127.0 / max_val
            # Integer codes in [-128, 127]; integers this small are exact in bfloat16.
            x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
            # Ternarize in float32 regardless of the parameter dtype.
            w_f = self.weight.float()
            w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
            w_q = (w_f * w_scale).round().clamp_(-1, 1)
            # Integer-domain matmul, then undo both scales.
            # NOTE(review): the product is produced in bfloat16 (~8 significant
            # bits), so accumulated sums larger than 256 in magnitude may be
            # rounded. This path need not match the training path bit-for-bit.
            y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
            return y.to(x_norm.dtype)
        # Training path (STE): the forward pass sees the quantized values, while the
        # detached correction term makes quantization act as identity for gradients.
        x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
        return F.linear(x_q, w_q)

class BitConv2d(nn.Conv2d):
    """Conv2d with per-channel ternary weights and INT8 activations.

    Weights are ternarized per output channel with
    ``weight_quant_ternary_per_channel`` and then multiplied by the learnable
    per-channel ``channel_scale``. Inputs are fake-quantized to INT8 with
    ``activation_quant_2d`` (one scale per sample and channel). The bias, if
    present, is applied in full precision.

    Constructor arguments mirror ``torch.nn.Conv2d``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        # Learnable multiplier applied on top of each output channel's ternary weights.
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _broadcast_channel_scale(self):
        """Reshape ``channel_scale`` to (out_channels, 1, ..., 1) to broadcast over the weight."""
        return self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def _quant_weight(self):
        """Return the effective weight: per-channel ternary values times ``channel_scale``."""
        w_q = weight_quant_ternary_per_channel(self.weight)
        return w_q * self._broadcast_channel_scale()

    def forward(self, x):
        if self.training:
            x_q = x + (activation_quant_2d(x) - x).detach()
            # Apply the straight-through estimator to the ternarization only, and
            # multiply by ``channel_scale`` outside the detached term. If the scale
            # were inside the detached correction, it would be cut out of the graph
            # and never receive gradients.
            w_ternary = weight_quant_ternary_per_channel(self.weight)
            w_ste = self.weight + (w_ternary - self.weight).detach()
            w_q = w_ste * self._broadcast_channel_scale()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        # _conv_forward honours padding_mode (reflect/replicate/circular); a bare
        # F.conv2d call would always zero-pad.
        return self._conv_forward(x_q, w_q, self.bias)

class TernaryLayerNorm(nn.Module):
    """Layer normalization whose affine gamma/beta are ternarized on the fly.

    Both parameters are stored in full precision and snapped to a scaled ternary
    grid with ``weight_quant_ternary`` on every forward pass. In training mode a
    straight-through estimator lets gradients reach the stored parameters.
    """
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def _ternarize(self, param):
        """Ternarize ``param``; in training mode wrap it in a straight-through estimator."""
        quantized = weight_quant_ternary(param)
        if not self.training:
            return quantized
        return param + (quantized - param).detach()

    def forward(self, x):
        gamma = self._ternarize(self.weight)
        beta = self._ternarize(self.bias)
        return F.layer_norm(x, self.normalized_shape, gamma, beta, self.eps)

# ============================================================================
# Model conversion: FP32 timm model -> fully ternary
# ============================================================================

def make_ternary(model: nn.Module) -> nn.Module:
    """Convert all Linear, LayerNorm, and patch embed Conv2d to ternary.

    Replaces, in place:

    * every ``nn.Linear`` with ``BitLinear``,
    * every ``nn.LayerNorm`` with ``TernaryLayerNorm``,
    * ``model.patch_embed.proj`` (when it is an ``nn.Conv2d``) with ``BitConv2d``.

    Replacement modules are freshly initialized: no parameters are copied from
    the modules they replace, so a ternary checkpoint must be loaded afterwards.
    Modules that are already ternary are skipped, so a second call is a no-op
    instead of silently resetting loaded weights.

    Args:
        model: Module tree to convert; mutated in place.

    Returns:
        ``model`` itself, so ``model = make_ternary(model)`` works.
    """

    def _replace(name: str, new_module: nn.Module) -> None:
        # "blocks.0.attn.qkv" -> parent "blocks.0.attn", attribute "qkv";
        # get_submodule("") returns ``model`` itself for top-level children.
        parent_name, _, attr = name.rpartition(".")
        setattr(model.get_submodule(parent_name), attr, new_module)

    # Snapshot named_modules() before mutating the tree. The root (name "")
    # cannot be swapped in place, so it is skipped. BitLinear subclasses
    # nn.Linear, so already-converted layers must be excluded explicitly.
    for name, module in list(model.named_modules()):
        if name and isinstance(module, nn.Linear) and not isinstance(module, BitLinear):
            _replace(name, BitLinear(module.in_features, module.out_features,
                                     bias=module.bias is not None))

    for name, module in list(model.named_modules()):
        if name and isinstance(module, nn.LayerNorm):
            _replace(name, TernaryLayerNorm(module.normalized_shape, eps=module.eps))

    # Patch embedding projection: copy every Conv2d hyper-parameter so the
    # replacement has the same convolution geometry as the original.
    patch_embed = getattr(model, "patch_embed", None)
    old = getattr(patch_embed, "proj", None)
    if isinstance(old, nn.Conv2d) and not isinstance(old, BitConv2d):
        patch_embed.proj = BitConv2d(
            old.in_channels, old.out_channels, old.kernel_size,
            stride=old.stride, padding=old.padding, dilation=old.dilation,
            groups=old.groups, bias=old.bias is not None,
            padding_mode=old.padding_mode,
        )
    return model

# ============================================================================
# Load and evaluate
# ============================================================================

import warnings

import timm
from timm.data import create_transform, resolve_data_config
from torchvision import datasets, transforms

# --- Configuration (change these) ---
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"       # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # path to ImageNet val/ or CIFAR download dir
NUM_CLASSES = 1000          # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ------------------------------------


def main() -> None:
    """Download the configured checkpoint, build the ternary model and print Top-1 accuracy."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Build model + ternary conversion
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    model = make_ternary(model)

    # 2. Load checkpoint
    path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only use it for checkpoints you trust; try
    # weights_only=True if the file holds plain tensors.
    sd = torch.load(path, map_location="cpu", weights_only=False)
    sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
    # strict=False tolerates extra checkpoint entries. Missing keys, however, mean
    # some layers keep their random init, so surface them instead of silently
    # reporting a meaningless accuracy.
    missing, _unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        warnings.warn(
            f"{len(missing)} model parameters were not found in the checkpoint "
            f"(e.g. {missing[:5]}); they keep their random initialization."
        )
    model = model.to(device).eval()

    # 3. Build eval dataloader
    # make_ternary mutates the model in place, so it still carries timm's
    # pretrained config; no need to build a second model just to read it.
    config = resolve_data_config({}, model=model)

    if DATASET == "imagenet":
        transform = create_transform(**config, is_training=False)
        val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
    elif DATASET in ("cifar10", "cifar100"):
        # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([
            transforms.Resize((config["input_size"][1], config["input_size"][2])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        cifar_cls = datasets.CIFAR10 if DATASET == "cifar10" else datasets.CIFAR100
        val_dataset = cifar_cls(root=DATA_DIR, train=False, download=True, transform=transform)
    else:
        raise ValueError(
            f"Unknown DATASET {DATASET!r}; expected 'imagenet', 'cifar10' or 'cifar100'"
        )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=4, pin_memory=device.type == "cuda",
    )

    # 4. Evaluate
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise RuntimeError(f"No evaluation samples found under {DATA_DIR!r}")
    accuracy = correct / total
    print(f"Top-1 accuracy: {accuracy:.4f} ({accuracy * 100:.2f}%)")
    print(f"Evaluated {total} samples")


# The guard matters: DataLoader workers started with the "spawn" method
# (the default on macOS and Windows) re-import this module.
if __name__ == "__main__":
    main()

📝 Citation

@article{rucinski2026ftervit,
  title={FTerViT: Fully Ternary Vision Transformer},
  author={Ruci{\'n}ski, Szymon and Bonazzi, Pietro and Turetken, Engin and Narduzzi, Simon and Magno, Michele and Maamari, Nadim},
  journal={arXiv preprint arXiv:2605.21171},
  year={2026}
}
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Datasets used to train szymonrucinski/FTerViT

Space using szymonrucinski/FTerViT 1

Paper for szymonrucinski/FTerViT