---
license: apache-2.0
tags:
  - vision
  - image-classification
  - ternary
  - quantization
  - vit
datasets:
  - imagenet-1k
  - cifar10
  - cifar100
---

# FTerViT: Fully Ternary Vision Transformer

[![arXiv](https://img.shields.io/badge/arXiv-2605.21171-B31B1B?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2605.21171)
[![GitHub](https://img.shields.io/badge/GitHub-FTerViT-181717?style=for-the-badge&logo=github&logoColor=white)](https://github.com/szymonrucinski/FTerViT)
[![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-FTerViT-FFD21E?style=for-the-badge)](https://huggingface.co/szymonrucinski/FTerViT)
[![NeurIPS](https://img.shields.io/badge/NeurIPS-2026-purple?style=for-the-badge)](https://neurips.cc/)
[![Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Demo-Live-orange?style=for-the-badge)](https://huggingface.co/spaces/szymonrucinski/FTerViT-demo)

Pretrained checkpoints for **FTerViT**, the first fully ternary Vision Transformer, in which *all* weight matrices and normalization parameters are constrained to {-1, 0, +1} (up to per-tensor or per-channel scaling factors).

> **W2A8** · 2-bit weights · 8-bit activations · **100% ternary** · ~15x compression · 1.5 to 6.1 MB models

## 🏆 Key Results

All models use **W2A8** (2-bit weights, 8-bit activations) with 100% ternary coverage, including the patch embedding, LayerNorm layers, and classifier head.
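
A ternary value carries log2(3) ≈ 1.58 bits of information, so "2-bit weights" refers to storing each value in a 2-bit slot. The sketch below shows one way to pack four ternary values per byte. The code assignment (0 → `0b00`, +1 → `0b01`, -1 → `0b10`), the LSB-first layout, and the helper names are hypothetical choices for illustration, not the format behind the binary sizes reported in this card.

```python
# Hypothetical 2-bit packing of ternary weights: 4 values per byte, LSB first.
# The code assignment is an illustrative assumption, not FTerViT's on-disk format.
ENCODE = {0: 0b00, 1: 0b01, -1: 0b10}
DECODE = {code: value for value, code in ENCODE.items()}


def pack_ternary(values: list[int]) -> bytes:
    """Pack {-1, 0, +1} values into bytes, 4 values per byte."""
    out = bytearray()
    for start in range(0, len(values), 4):
        byte = 0
        for slot, v in enumerate(values[start:start + 4]):
            byte |= ENCODE[v] << (2 * slot)
        out.append(byte)
    return bytes(out)


def unpack_ternary(data: bytes, count: int) -> list[int]:
    """Inverse of pack_ternary; `count` drops the padding slots of the last byte."""
    values = []
    for byte in data:
        for slot in range(4):
            values.append(DECODE[(byte >> (2 * slot)) & 0b11])
    return values[:count]


weights = [1, 0, -1, -1, 0, 1, 1]
packed = pack_ternary(weights)
print(len(packed), packed.hex())
assert unpack_ternary(packed, len(weights)) == weights
```

With this layout a tensor of N ternary weights occupies ceil(N / 4) bytes, plus whatever scale factors are stored alongside it.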

### 📊 ImageNet-1K

| Model | Phase | Epochs | Top-1 (%) | Binary size (MB) | Compression | Checkpoint |
|-------|-------|--------|-----------|-------------|-------------|------------|
| DeiT-Small | Phase 1 | 250 | 75.05 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc75.05_deit_small_224.pth) |
| DeiT-III-Small | Phase 1 | 250 | 76.78 | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc76.78_deit3_small_224.pth) |
| DeiT-Small | Phase 2 | +10 | **77.47** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc77.47_deit_small_224.pth) |
| DeiT-III-Small | Phase 2 | +10 | **79.64** | 5.81 | 15.2x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth) |
| DeiT-III-Small (384) | Phase 1 | 250 | 78.35 | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase1_ep250_acc78.35_deit3_small_384.pth) |
| DeiT-III-Small (384) | Phase 2 | +10 | **82.43** | 6.09 | 14.6x | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/imagenet1k/phase2_ep010_acc82.43_deit3_small_384.pth) |
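
The Compression column appears consistent with FP32 model size divided by binary size. The rough check below computes the parameter count of a standard DeiT-Small backbone (ViT-S/16 at 224 px, embedding width 384, depth 12, MLP ratio 4, 1000 classes) and assumes 4 bytes per FP32 parameter and 1 MB = 10^6 bytes. The extra per-layer RMSNorm gains added by FTerViT are not counted, and `vit_param_count` is a helper written for this illustration.

```python
"""Back-of-the-envelope check of the ImageNet-1K size and compression columns.
Assumptions: standard DeiT-Small layout, FP32 = 4 bytes/param, 1 MB = 1e6 bytes;
FTerViT's extra RMSNorm gains are ignored."""


def vit_param_count(dim: int, depth: int, mlp_ratio: int, img: int,
                    patch: int, num_classes: int, in_chans: int = 3) -> int:
    tokens = (img // patch) ** 2 + 1                     # patches + [CLS]
    patch_embed = in_chans * patch * patch * dim + dim   # conv weight + bias
    embeddings = dim + tokens * dim                      # cls_token + pos_embed
    hidden = mlp_ratio * dim
    block = (
        2 * 2 * dim                                      # two LayerNorms (gamma, beta)
        + dim * 3 * dim + 3 * dim                        # qkv projection
        + dim * dim + dim                                # attention output projection
        + dim * hidden + hidden                          # MLP fc1
        + hidden * dim + dim                             # MLP fc2
    )
    head = 2 * dim + dim * num_classes + num_classes     # final LayerNorm + classifier
    return patch_embed + embeddings + depth * block + head


params = vit_param_count(dim=384, depth=12, mlp_ratio=4, img=224, patch=16, num_classes=1000)
fp32_mb = params * 4 / 1e6
ideal_2bit_mb = params * 2 / 8 / 1e6
reported_binary_mb = 5.81  # DeiT-Small / DeiT-III-Small (224) rows of the table

print(f"params: {params:,}")
print(f"FP32: {fp32_mb:.2f} MB, ideal 2-bit packing: {ideal_2bit_mb:.2f} MB")
print(f"FP32 / reported binary: {fp32_mb / reported_binary_mb:.1f}x")
```

The roughly 0.3 MB between the ideal packed size and the reported 5.81 MB is not broken down in this card.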

### 📊 CIFAR-10 / CIFAR-100

| Model | Dataset | Top-1 (%) | FP32 baseline Top-1 (%) | Binary size (MB) | Checkpoint |
|-------|---------|-----------|---------------|-------------|------------|
| DeiT-Tiny | CIFAR-10 | **97.43** | 97.52 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar10/phase2_ep010_acc97.43_deit_tiny_224.pth) |
| DeiT-Tiny | CIFAR-100 | **86.01** | 86.54 | 1.53 | [download](https://huggingface.co/szymonrucinski/FTerViT/resolve/main/cifar100/phase2_ep010_acc86.01_deit_tiny_224.pth) |
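
The checkpoint paths in both tables follow a regular pattern: a dataset folder, then phase, epoch count, top-1 accuracy, architecture, and input resolution. The hypothetical helper below recovers those fields, plus the class count implied by the folder (1000 / 10 / 100, as in the inference script's `NUM_CLASSES` comment), so the evaluation settings can be cross-checked against the chosen file. The regex is inferred from the filenames listed here; it is not an official API of the repository.

```python
"""Hypothetical helper: recover run metadata from an FTerViT checkpoint path."""
import re
from typing import NamedTuple

# Class count per dataset folder.
CLASSES_PER_FOLDER = {"imagenet1k": 1000, "cifar10": 10, "cifar100": 100}

CKPT_RE = re.compile(
    r"^(?P<folder>imagenet1k|cifar10|cifar100)/"
    r"phase(?P<phase>\d+)_ep(?P<epochs>\d+)_acc(?P<acc>\d+\.\d+)_"
    r"(?P<arch>[a-z0-9_]+?)_(?P<res>\d+)\.pth$"
)


class CheckpointInfo(NamedTuple):
    folder: str
    num_classes: int
    phase: int
    epochs: int      # epochs of this phase only (Phase 2 files encode the +10)
    top1: float
    arch: str
    resolution: int


def parse_checkpoint(path: str) -> CheckpointInfo:
    m = CKPT_RE.match(path)
    if m is None:
        raise ValueError(f"unrecognized checkpoint name: {path}")
    return CheckpointInfo(
        folder=m["folder"],
        num_classes=CLASSES_PER_FOLDER[m["folder"]],
        phase=int(m["phase"]),
        epochs=int(m["epochs"]),
        top1=float(m["acc"]),
        arch=m["arch"],
        resolution=int(m["res"]),
    )


info = parse_checkpoint("imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth")
print(info)
```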

## 🔧 Training Protocol

Training uses a two-phase knowledge distillation approach:

- **Phase 1:** quantization-aware distillation (QAD) from a frozen FP32 teacher, KL-divergence loss only, lr=1e-4 with cosine decay, 250 epochs
- **Phase 2:** low-learning-rate recovery fine-tuning from the Phase 1 checkpoint, lr=1e-5 with cosine decay, 10 epochs

See the paper for full details.
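
The two ingredients named in the protocol, a KL-divergence distillation loss against a frozen teacher and a cosine-decayed learning rate, can be sketched in plain Python. This is an illustration under assumptions, not the training code: the loss direction KL(teacher ‖ student), the temperature of 1, decay to zero, and per-epoch (rather than per-step) scheduling are all guesses, and the card does not state which loss Phase 2 uses.

```python
import math


def softmax(logits: list[float], temperature: float = 1.0) -> list[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)                                   # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """KL(teacher || student) on temperature-softened distributions (assumed form).
    The teacher is frozen, so in training only the student would get gradients."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
    return kl * temperature ** 2                      # conventional T^2 rescaling


def cosine_lr(epoch: int, total_epochs: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Cosine decay from base_lr at epoch 0 to min_lr at total_epochs."""
    progress = epoch / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Phase 1: lr=1e-4 over 250 epochs; Phase 2: lr=1e-5 over 10 epochs.
for name, base_lr, epochs in [("phase 1", 1e-4, 250), ("phase 2", 1e-5, 10)]:
    lrs = [cosine_lr(e, epochs, base_lr) for e in (0, epochs // 2, epochs)]
    print(name, ["%.2e" % lr for lr in lrs])

print(kl_distillation_loss([2.0, 0.5, -1.0], [2.0, 0.5, -1.0]))  # identical logits
print(kl_distillation_loss([0.1, 0.2, 3.0], [2.0, 0.5, -1.0]))   # mismatched logits
```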

## 🚀 Self-Contained Inference Example

The code below loads and evaluates an FTerViT checkpoint using only `torch` (2.4 or newer, which provides `nn.RMSNorm`), `timm`, `torchvision`, and `huggingface_hub`. All ternary layer definitions are included inline.

```python
"""
FTerViT — self-contained inference example.
Requirements: pip install torch timm huggingface_hub torchvision
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# ============================================================================
# Ternary quantization primitives
# ============================================================================

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Per-token INT8 activation quantization."""
    scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """Per-channel INT8 activation quantization for Conv2d (NCHW)."""
    scale = 127.0 / x.abs().amax(dim=(2, 3), keepdim=True).clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Ternary weight quantization: {-1, 0, +1} with absmean scaling."""
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Per-output-channel ternary quantization."""
    scale = 1.0 / w.abs().mean(dim=tuple(range(1, w.dim())), keepdim=True).clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

# ============================================================================
# Ternary layer definitions
# ============================================================================

class BitLinear(nn.Linear):
    """Linear layer with ternary weights and INT8 activations."""
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.norm = nn.RMSNorm(in_features, eps=1e-5)

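    def forward(self, x):
        x_norm = self.norm(x)  # RMSNorm on the input before quantization
        # Note: self.bias is not applied in either path below.
        if not self.training:
            # Eval path: integer-valued INT8 activations times ternary weights in bfloat16,
            # followed by a single rescale by 1 / (w_scale * x_scale).
            max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
            x_scale = 127.0 / max_val
            x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
            w_f = self.weight.float()
            w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
            w_q = (w_f * w_scale).round().clamp_(-1, 1)
            y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
            return y.to(x_norm.dtype)
        # Training path: straight-through estimator (STE). The forward pass sees quantized
        # values while gradients flow to the full-precision latent tensors.
        x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
        return F.linear(x_q, w_q)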

class BitConv2d(nn.Conv2d):
    """Conv2d with per-channel ternary weights and INT8 activations."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _quant_weight(self):
        w_q = weight_quant_ternary_per_channel(self.weight)
        return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def forward(self, x):
        if self.training:
            x_q = x + (activation_quant_2d(x) - x).detach()
            w_q = self.weight + (self._quant_weight() - self.weight).detach()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        return F.conv2d(x_q, w_q, self.bias, self.stride, self.padding, self.dilation, self.groups)

class TernaryLayerNorm(nn.Module):
    """LayerNorm with ternary affine parameters (gamma, beta)."""
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        if self.training:
            w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
            b_q = self.bias + (weight_quant_ternary(self.bias) - self.bias).detach()
        else:
            w_q = weight_quant_ternary(self.weight)
            b_q = weight_quant_ternary(self.bias)
        return F.layer_norm(x, self.normalized_shape, w_q, b_q, self.eps)

# ============================================================================
# Model conversion: FP32 timm model -> fully ternary
# ============================================================================

def make_ternary(model: nn.Module) -> nn.Module:
    """Convert all Linear, LayerNorm, and patch-embed Conv2d layers to ternary versions.
    The new layers are freshly initialized; their weights come from the checkpoint."""
    def replace(name: str, new: nn.Module) -> None:
        parent_name, _, attr = name.rpartition(".")
        setattr(model.get_submodule(parent_name), attr, new)  # "" -> the model itself

    # Snapshot the module list first, since the tree is mutated while iterating.
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):          # Linear -> BitLinear
            replace(name, BitLinear(module.in_features, module.out_features,
                                    bias=module.bias is not None))
        elif isinstance(module, nn.LayerNorm):     # LayerNorm -> TernaryLayerNorm
            replace(name, TernaryLayerNorm(module.normalized_shape, eps=module.eps))
    # Patch-embed Conv2d -> BitConv2d
    patch_embed = getattr(model, "patch_embed", None)
    if patch_embed is not None and isinstance(getattr(patch_embed, "proj", None), nn.Conv2d):
        old = patch_embed.proj
        patch_embed.proj = BitConv2d(old.in_channels, old.out_channels, old.kernel_size,
                                     stride=old.stride, padding=old.padding,
                                     bias=old.bias is not None)
    return model

# ============================================================================
# Load and evaluate
# ============================================================================

import timm
from torchvision import datasets, transforms

# --- Configuration (change these) ---
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"  # timm architecture; must match the checkpoint (incl. resolution)
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"       # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # ImageNet: folder containing val/; CIFAR: download dir
NUM_CLASSES = 1000          # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ------------------------------------

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Build model + ternary conversion
model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
model = make_ternary(model)

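# 2. Load checkpoint (strip the "timm_model." prefix from the saved keys, if present)
path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
sd = torch.load(path, map_location="cpu", weights_only=False)  # unpickles arbitrary objects: load trusted files only
sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
result = model.load_state_dict(sd, strict=False)
# strict=False skips mismatched keys silently; report them so a wrong MODEL_NAME is noticed.
print(f"missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")
model = model.to(device).eval()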

# 3. Build eval dataloader (the timm preprocessing config is read from the model itself)
from timm.data import resolve_data_config, create_transform
config = resolve_data_config({}, model=model)

if DATASET == "imagenet":
    transform = create_transform(**config, is_training=False)
    val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
else:
    # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
    transform = transforms.Compose([
        transforms.Resize((config["input_size"][1], config["input_size"][2])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    cls = datasets.CIFAR10 if DATASET == "cifar10" else datasets.CIFAR100
    val_dataset = cls(root=DATA_DIR, train=False, download=True, transform=transform)

# With num_workers > 0 on Windows/macOS (spawn start method), run this script under
# `if __name__ == "__main__":`, or set num_workers=0.
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=4, pin_memory=device.type == "cuda")

# 4. Evaluate
correct = total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
print(f"Evaluated {total} samples")
```
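
The two quantizers in the script can also be checked by hand without a GPU. The dependency-free sketch below mirrors `weight_quant_ternary` (absmean scaling, round, clamp to [-1, 1]) and `activation_quant` (absmax scaling to 127, round, clamp to [-128, 127]) for a single vector, then shows why `BitLinear`'s eval path divides the integer dot product by `w_scale * x_scale`. It skips the RMSNorm step and is an illustration of the arithmetic only; like `torch.round`, Python's `round` rounds halves to even.

```python
"""Pure-Python mirror of the ternary / INT8 quantizers (illustration only)."""


def ternary_codes(w: list[float]) -> tuple[list[int], float]:
    """Absmean ternary quantization: integer codes in {-1, 0, 1} and the scale."""
    scale = 1.0 / max(sum(abs(v) for v in w) / len(w), 1e-5)
    codes = [max(-1, min(1, round(v * scale))) for v in w]
    return codes, scale


def int8_codes(x: list[float]) -> tuple[list[int], float]:
    """Per-token absmax INT8 quantization: integer codes in [-128, 127] and the scale."""
    scale = 127.0 / max(max(abs(v) for v in x), 1e-5)
    codes = [max(-128, min(127, round(v * scale))) for v in x]
    return codes, scale


w = [0.4, -0.05, -0.9, 0.2]
x = [0.5, -2.0, 1.0, 0.25]

w_q, w_scale = ternary_codes(w)
x_q, x_scale = int8_codes(x)
print("ternary codes:", w_q, "dequantized:", [c / w_scale for c in w_q])
print("int8 codes:", x_q)

# BitLinear eval path: integer dot product, then one rescale.
int_dot = sum(a * b for a, b in zip(x_q, w_q))
y_int_path = int_dot / (w_scale * x_scale)
# Same value from dequantized floats (what the STE training path computes).
y_float_path = sum((a / x_scale) * (b / w_scale) for a, b in zip(x_q, w_q))
print(int_dot, y_int_path, y_float_path)
```

Because both scales are shared across the whole dot product, the expensive part reduces to additions and subtractions of INT8 values, with a single floating-point rescale at the end.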

## 📝 Citation

```bibtex
@article{rucinski2026ftervit,
  title={FTerViT: Fully Ternary Vision Transformer},
  author={Ruci{\'n}ski, Szymon and Bonazzi, Pietro and Turetken, Engin and Narduzzi, Simon and Magno, Michele and Maamari, Nadim},
  journal={arXiv preprint arXiv:2605.21171},
  year={2026}
}
```