Add comprehensive README with self-contained inference example
Browse files
README.md
CHANGED
|
@@ -48,25 +48,199 @@ Training uses a two-phase knowledge distillation approach:
|
|
| 48 |
|
| 49 |
See the paper for full details.
|
| 50 |
|
| 51 |
-
##
|
|
|
|
|
|
|
| 52 |
|
| 53 |
```python
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
import torch
|
|
|
|
|
|
|
| 56 |
from huggingface_hub import hf_hub_download
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
#
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
```
|
| 71 |
|
| 72 |
## Citation
|
|
|
|
| 48 |
|
| 49 |
See the paper for full details.
|
| 50 |
|
| 51 |
+
## Self-Contained Inference Example
|
| 52 |
+
|
| 53 |
+
The code below loads and evaluates an FTerViT checkpoint **without any external dependencies beyond `torch`, `timm`, `torchvision`, and `huggingface_hub`**. All ternary layer definitions are included inline.
|
| 54 |
|
| 55 |
```python
|
| 56 |
+
"""
|
| 57 |
+
FTerViT — self-contained inference example.
|
| 58 |
+
Requirements: pip install torch timm huggingface_hub torchvision
|
| 59 |
+
"""
|
| 60 |
import torch
|
| 61 |
+
import torch.nn as nn
|
| 62 |
+
import torch.nn.functional as F
|
| 63 |
from huggingface_hub import hf_hub_download
|
| 64 |
|
| 65 |
+
# ============================================================================
|
| 66 |
+
# Ternary quantization primitives
|
| 67 |
+
# ============================================================================
|
| 68 |
+
|
| 69 |
+
def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize activations onto an INT8 grid, one scale per last-dim slice.

    The scale is derived from the absolute maximum along the last dimension,
    so each slice along that dimension is quantized independently. Values are
    rounded and clamped to [-128, 127] and then divided by the scale again, so
    the result stays in the input's value range.

    Args:
        x: Activation tensor; the reduction runs over its last dimension.

    Returns:
        Tensor with the same shape as ``x`` holding the de-quantized values.
    """
    # Floor the abs-max at 1e-5 so an all-zero slice does not divide by zero.
    slice_absmax = torch.clamp(
        torch.amax(torch.abs(x), dim=-1, keepdim=True), min=1e-5
    )
    factor = 127.0 / slice_absmax
    levels = torch.clamp(torch.round(x * factor), -128, 127)
    return levels / factor
|
| 73 |
+
|
| 74 |
+
def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize a 4-D activation tensor onto an INT8 grid.

    One scale is computed for every (dim 0, dim 1) pair by taking the absolute
    maximum over dims 2 and 3 (the spatial dims of an NCHW tensor). Values are
    rounded, clamped to [-128, 127] and divided by the scale again.

    Args:
        x: Activation tensor with at least four dimensions; dims 2 and 3 are
            reduced over.

    Returns:
        Tensor with the same shape as ``x`` holding the de-quantized values.
    """
    # Floor the abs-max at 1e-5 so an all-zero plane does not divide by zero.
    plane_absmax = torch.clamp(
        torch.amax(torch.abs(x), dim=(2, 3), keepdim=True), min=1e-5
    )
    factor = 127.0 / plane_absmax
    levels = torch.clamp(torch.round(x * factor), -128, 127)
    return levels / factor
|
| 78 |
+
|
| 79 |
+
def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
    """Fake-quantize a weight tensor to three levels using one absmean scale.

    The tensor is divided by its mean absolute value (taken over all
    elements), rounded, clamped to {-1, 0, +1} and multiplied back, so the
    result contains only ``-m``, ``0`` and ``+m`` where ``m`` is that mean.

    Args:
        w: Weight tensor of any shape.

    Returns:
        Tensor with the same shape as ``w`` holding the de-quantized values.
    """
    # Floor the mean at 1e-5 so an all-zero tensor does not divide by zero.
    mean_abs = torch.clamp(torch.mean(torch.abs(w)), min=1e-5)
    factor = 1.0 / mean_abs
    ternary = torch.clamp(torch.round(w * factor), -1, 1)
    return ternary / factor
|
| 83 |
+
|
| 84 |
+
def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
    """Fake-quantize weights to three levels with one scale per dim-0 slice.

    For every index along dim 0 (the output channel of a conv/linear weight)
    the mean absolute value over all remaining dims is used as the scale.
    Each slice is divided by its scale, rounded, clamped to {-1, 0, +1} and
    multiplied back.

    Args:
        w: Weight tensor whose first dimension indexes the channels.

    Returns:
        Tensor with the same shape as ``w`` holding the de-quantized values.
    """
    trailing_dims = tuple(range(1, w.dim()))
    # Floor the mean at 1e-5 so an all-zero channel does not divide by zero.
    channel_mean = w.abs().mean(dim=trailing_dims, keepdim=True).clamp_(min=1e-5)
    factor = 1.0 / channel_mean
    ternary = torch.clamp(torch.round(w * factor), -1, 1)
    return ternary / factor
|
| 88 |
+
|
| 89 |
+
# ============================================================================
|
| 90 |
+
# Ternary layer definitions
|
| 91 |
+
# ============================================================================
|
| 92 |
+
|
| 93 |
+
class BitLinear(nn.Linear):
|
| 94 |
+
"""Linear layer with ternary weights and INT8 activations."""
|
| 95 |
+
def __init__(self, in_features, out_features, bias=True):
|
| 96 |
+
super().__init__(in_features, out_features, bias)
|
| 97 |
+
self.norm = nn.RMSNorm(in_features, eps=1e-5)
|
| 98 |
+
|
| 99 |
+
def forward(self, x):
|
| 100 |
+
x_norm = self.norm(x)
|
| 101 |
+
if not self.training:
|
| 102 |
+
max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
|
| 103 |
+
x_scale = 127.0 / max_val
|
| 104 |
+
x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
|
| 105 |
+
w_f = self.weight.float()
|
| 106 |
+
w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
|
| 107 |
+
w_q = (w_f * w_scale).round().clamp_(-1, 1)
|
| 108 |
+
y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
|
| 109 |
+
return y.to(x_norm.dtype)
|
| 110 |
+
# Training path (STE)
|
| 111 |
+
x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
|
| 112 |
+
w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
|
| 113 |
+
return F.linear(x_q, w_q)
|
| 114 |
+
|
| 115 |
+
class BitConv2d(nn.Conv2d):
    """Conv2d evaluated with per-channel ternary weights and INT8 activations.

    Weights are ternarized per output channel and then multiplied by a
    learnable per-output-channel factor ``channel_scale`` (initialized to
    ones). Activations are fake-quantized per (sample, channel) plane. In
    training mode both quantizers are wrapped in a straight-through estimator.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: Forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=padding, dilation=dilation, groups=groups,
                         bias=bias, padding_mode=padding_mode)
        self.channel_scale = nn.Parameter(torch.ones(out_channels))

    def _quant_weight(self):
        """Return the ternarized weight multiplied by ``channel_scale``."""
        w_q = weight_quant_ternary_per_channel(self.weight)
        # Reshape (out_channels,) -> (out_channels, 1, 1, ...) for broadcasting.
        return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))

    def forward(self, x):
        """Quantize ``x`` and the weight, then run the convolution."""
        if self.training:
            # Straight-through estimator: quantized values in the forward
            # pass, identity gradient in the backward pass.
            x_q = x + (activation_quant_2d(x) - x).detach()
            w_q = self.weight + (self._quant_weight() - self.weight).detach()
        else:
            x_q = activation_quant_2d(x)
            w_q = self._quant_weight()
        # Delegate to nn.Conv2d's own conv routine so that a non-"zeros"
        # padding_mode is honored; for "zeros" this is the plain F.conv2d call
        # with stride/padding/dilation/groups taken from the module.
        return self._conv_forward(x_q, w_q, self.bias)
|
| 136 |
+
|
| 137 |
+
class TernaryLayerNorm(nn.Module):
    """Layer normalization whose affine weight and bias are ternarized.

    Both affine parameters are passed through ``weight_quant_ternary`` before
    being handed to ``F.layer_norm``. In training mode the quantization is
    wrapped in a straight-through estimator.

    Args:
        normalized_shape: Int or sequence of ints giving the trailing
            dimensions to normalize over.
        eps: Value forwarded to ``F.layer_norm`` as its ``eps`` argument.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        shape = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))

    def _ternarize(self, param):
        """Quantize one affine parameter, with STE while training."""
        quantized = weight_quant_ternary(param)
        if not self.training:
            return quantized
        # Forward value equals ``quantized``; gradient flows to ``param``.
        return param + (quantized - param).detach()

    def forward(self, x):
        """Apply layer norm using the ternarized weight and bias."""
        gamma = self._ternarize(self.weight)
        beta = self._ternarize(self.bias)
        return F.layer_norm(x, self.normalized_shape, gamma, beta, self.eps)
|
| 156 |
+
|
| 157 |
+
# ============================================================================
|
| 158 |
+
# Model conversion: FP32 timm model -> fully ternary
|
| 159 |
+
# ============================================================================
|
| 160 |
+
|
| 161 |
+
def _swap_modules(model, target_type, build):
    """Replace every non-root submodule of ``target_type`` with ``build(old)``."""
    # Snapshot the names first: the module tree is mutated while we walk it.
    for name, module in list(model.named_modules()):
        if not name or not isinstance(module, target_type):
            # The root module (empty name) has no parent attribute to rebind.
            continue
        parent_name, _, attr = name.rpartition(".")
        # get_submodule("") returns the model itself, covering top-level children.
        setattr(model.get_submodule(parent_name), attr, build(module))


def make_ternary(model: nn.Module) -> nn.Module:
    """Convert all Linear, LayerNorm, and patch embed Conv2d to ternary.

    The model is modified in place. Replacement layers are freshly constructed
    from the old layers' shape/hyper-parameters only; no weights are copied,
    so a checkpoint must be loaded afterwards.

    Args:
        model: Module tree to convert. If ``model`` itself is an ``nn.Linear``
            or ``nn.LayerNorm`` it is left as is, since it cannot be swapped
            in place.

    Returns:
        The same ``model`` object, after conversion.
    """
    # Linear -> BitLinear
    _swap_modules(
        model,
        nn.Linear,
        lambda old: BitLinear(old.in_features, old.out_features, bias=old.bias is not None),
    )
    # LayerNorm -> TernaryLayerNorm
    _swap_modules(
        model,
        nn.LayerNorm,
        lambda old: TernaryLayerNorm(old.normalized_shape, eps=old.eps),
    )
    # Patch embed Conv2d -> BitConv2d
    patch_embed = getattr(model, "patch_embed", None)
    old_proj = getattr(patch_embed, "proj", None)
    if isinstance(old_proj, nn.Conv2d):
        # Carry over every conv hyper-parameter, not just stride/padding, so a
        # dilated/grouped/non-zero-padded projection keeps its geometry.
        patch_embed.proj = BitConv2d(
            old_proj.in_channels,
            old_proj.out_channels,
            old_proj.kernel_size,
            stride=old_proj.stride,
            padding=old_proj.padding,
            dilation=old_proj.dilation,
            groups=old_proj.groups,
            bias=old_proj.bias is not None,
            padding_mode=old_proj.padding_mode,
        )
    return model
|
| 183 |
+
|
| 184 |
+
# ============================================================================
|
| 185 |
+
# Load and evaluate
|
| 186 |
+
# ============================================================================
|
| 187 |
|
| 188 |
+
import timm
from timm.data import create_transform, resolve_data_config
from torchvision import datasets, transforms

# --- Configuration (change these) ---
MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
DATASET = "imagenet"          # "imagenet", "cifar10", or "cifar100"
DATA_DIR = "./data/imagenet"  # ImageNet root containing val/, or CIFAR download dir
NUM_CLASSES = 1000            # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
BATCH_SIZE = 128
# ------------------------------------

# The guard keeps DataLoader worker processes (num_workers > 0) from re-running
# the whole script when they re-import this module under the "spawn" start
# method.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Build model + ternary conversion
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    model = make_ternary(model)

    # 2. Load checkpoint
    path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects; only use it with checkpoints you trust. If the file is a plain
    # tensor state dict, weights_only=True should work -- TODO confirm.
    sd = torch.load(path, map_location="cpu", weights_only=False)
    sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
    load_result = model.load_state_dict(sd, strict=False)
    # strict=False hides mismatches; surface them so a wrong MODEL_NAME or
    # CHECKPOINT does not silently evaluate randomly initialized layers.
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} missing keys, "
              f"e.g. {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} unexpected keys, "
              f"e.g. {load_result.unexpected_keys[:5]}")
    model = model.to(device).eval()

    # 3. Build eval dataloader
    # Reuse the model built above instead of constructing a second one just to
    # read its data config.
    config = resolve_data_config({}, model=model)

    if DATASET == "imagenet":
        transform = create_transform(**config, is_training=False)
        val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
    elif DATASET in ("cifar10", "cifar100"):
        # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
        transform = transforms.Compose([
            transforms.Resize((config["input_size"][1], config["input_size"][2])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        cifar_cls = datasets.CIFAR10 if DATASET == "cifar10" else datasets.CIFAR100
        val_dataset = cifar_cls(root=DATA_DIR, train=False, download=True, transform=transform)
    else:
        raise ValueError(
            f"Unknown DATASET {DATASET!r}; expected 'imagenet', 'cifar10', or 'cifar100'"
        )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    # 4. Evaluate
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
    print(f"Evaluated {total} samples")
|
| 244 |
```
|
| 245 |
|
| 246 |
## Citation
|