Image Classification
vision
ternary
quantization
vit
szymonrucinski commited on
Commit
c6803f8
·
verified ·
1 Parent(s): 31e143d

Add comprehensive README with self-contained inference example

Browse files
Files changed (1) hide show
  1. README.md +187 -13
README.md CHANGED
@@ -48,25 +48,199 @@ Training uses a two-phase knowledge distillation approach:
48
 
49
  See the paper for full details.
50
 
51
- ## Loading a Checkpoint
 
 
52
 
53
  ```python
54
- import timm
 
 
 
55
  import torch
 
 
56
  from huggingface_hub import hf_hub_download
57
 
58
- # 1. Build model and apply ternary conversion
59
- model = timm.create_model("deit3_small_patch16_224.fb_in22k_ft_in1k", pretrained=False, num_classes=1000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- # Replace Linear -> BitLinear, LayerNorm -> TernaryLayerNorm, Conv2d -> BitConv2d
62
- # (see repo for ternary conversion utilities)
63
-
64
- # 2. Download and load checkpoint
65
- path = hf_hub_download("szymonrucinski/FTerViT", "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth")
66
- state_dict = torch.load(path, map_location="cpu")
67
- # Strip wrapper prefix if present
68
- state_dict = {k.removeprefix("timm_model."): v for k, v in state_dict.items()}
69
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  ```
71
 
72
  ## Citation
 
48
 
49
  See the paper for full details.
50
 
51
+ ## Self-Contained Inference Example
52
+
53
+ The code below loads and evaluates a FTerViT checkpoint **without any external dependencies beyond `torch`, `timm`, and `huggingface_hub`**. All ternary layer definitions are included inline.
54
 
55
  ```python
56
+ """
57
+ FTerViT — self-contained inference example.
58
+ Requirements: pip install torch timm huggingface_hub torchvision
59
+ """
60
  import torch
61
+ import torch.nn as nn
62
+ import torch.nn.functional as F
63
  from huggingface_hub import hf_hub_download
64
 
65
+ # ============================================================================
66
+ # Ternary quantization primitives
67
+ # ============================================================================
68
+
69
+ def activation_quant(x: torch.Tensor) -> torch.Tensor:
70
+ """Per-token INT8 activation quantization."""
71
+ scale = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
72
+ return (x * scale).round().clamp_(-128, 127) / scale
73
+
74
+ def activation_quant_2d(x: torch.Tensor) -> torch.Tensor:
75
+ """Per-channel INT8 activation quantization for Conv2d (NCHW)."""
76
+ scale = 127.0 / x.abs().amax(dim=(2, 3), keepdim=True).clamp_(min=1e-5)
77
+ return (x * scale).round().clamp_(-128, 127) / scale
78
+
79
+ def weight_quant_ternary(w: torch.Tensor) -> torch.Tensor:
80
+ """Ternary weight quantization: {-1, 0, +1} with absmean scaling."""
81
+ scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
82
+ return (w * scale).round().clamp_(-1, 1) / scale
83
+
84
+ def weight_quant_ternary_per_channel(w: torch.Tensor) -> torch.Tensor:
85
+ """Per-output-channel ternary quantization."""
86
+ scale = 1.0 / w.abs().mean(dim=tuple(range(1, w.dim())), keepdim=True).clamp_(min=1e-5)
87
+ return (w * scale).round().clamp_(-1, 1) / scale
88
+
89
+ # ============================================================================
90
+ # Ternary layer definitions
91
+ # ============================================================================
92
+
93
+ class BitLinear(nn.Linear):
94
+ """Linear layer with ternary weights and INT8 activations."""
95
+ def __init__(self, in_features, out_features, bias=True):
96
+ super().__init__(in_features, out_features, bias)
97
+ self.norm = nn.RMSNorm(in_features, eps=1e-5)
98
+
99
+ def forward(self, x):
100
+ x_norm = self.norm(x)
101
+ if not self.training:
102
+ max_val = x_norm.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-5)
103
+ x_scale = 127.0 / max_val
104
+ x_q = (x_norm * x_scale).round().clamp_(-128, 127).to(torch.bfloat16)
105
+ w_f = self.weight.float()
106
+ w_scale = 1.0 / w_f.abs().mean().clamp_(min=1e-5)
107
+ w_q = (w_f * w_scale).round().clamp_(-1, 1)
108
+ y = F.linear(x_q, w_q.to(torch.bfloat16)) / (w_scale * x_scale.to(torch.bfloat16))
109
+ return y.to(x_norm.dtype)
110
+ # Training path (STE)
111
+ x_q = x_norm + (activation_quant(x_norm) - x_norm).detach()
112
+ w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
113
+ return F.linear(x_q, w_q)
114
+
115
+ class BitConv2d(nn.Conv2d):
116
+ """Conv2d with per-channel ternary weights and INT8 activations."""
117
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
118
+ padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
119
+ super().__init__(in_channels, out_channels, kernel_size, stride=stride,
120
+ padding=padding, dilation=dilation, groups=groups,
121
+ bias=bias, padding_mode=padding_mode)
122
+ self.channel_scale = nn.Parameter(torch.ones(out_channels))
123
+
124
+ def _quant_weight(self):
125
+ w_q = weight_quant_ternary_per_channel(self.weight)
126
+ return w_q * self.channel_scale.view(-1, *([1] * (self.weight.dim() - 1)))
127
+
128
+ def forward(self, x):
129
+ if self.training:
130
+ x_q = x + (activation_quant_2d(x) - x).detach()
131
+ w_q = self.weight + (self._quant_weight() - self.weight).detach()
132
+ else:
133
+ x_q = activation_quant_2d(x)
134
+ w_q = self._quant_weight()
135
+ return F.conv2d(x_q, w_q, self.bias, self.stride, self.padding, self.dilation, self.groups)
136
+
137
+ class TernaryLayerNorm(nn.Module):
138
+ """LayerNorm with ternary affine parameters (gamma, beta)."""
139
+ def __init__(self, normalized_shape, eps=1e-5):
140
+ super().__init__()
141
+ if isinstance(normalized_shape, int):
142
+ normalized_shape = (normalized_shape,)
143
+ self.normalized_shape = tuple(normalized_shape)
144
+ self.eps = eps
145
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape))
146
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
147
+
148
+ def forward(self, x):
149
+ if self.training:
150
+ w_q = self.weight + (weight_quant_ternary(self.weight) - self.weight).detach()
151
+ b_q = self.bias + (weight_quant_ternary(self.bias) - self.bias).detach()
152
+ else:
153
+ w_q = weight_quant_ternary(self.weight)
154
+ b_q = weight_quant_ternary(self.bias)
155
+ return F.layer_norm(x, self.normalized_shape, w_q, b_q, self.eps)
156
+
157
+ # ============================================================================
158
+ # Model conversion: FP32 timm model -> fully ternary
159
+ # ============================================================================
160
+
161
+ def make_ternary(model: nn.Module) -> nn.Module:
162
+ """Convert all Linear, LayerNorm, and patch embed Conv2d to ternary."""
163
+ # Linear -> BitLinear
164
+ for name, module in list(model.named_modules()):
165
+ if isinstance(module, nn.Linear):
166
+ parent_name, attr = name.rsplit(".", 1) if "." in name else ("", name)
167
+ parent = model if not parent_name else dict(model.named_modules())[parent_name]
168
+ setattr(parent, attr, BitLinear(module.in_features, module.out_features, bias=module.bias is not None))
169
+ # LayerNorm -> TernaryLayerNorm
170
+ for name, module in list(model.named_modules()):
171
+ if isinstance(module, nn.LayerNorm):
172
+ parent_name, attr = name.rsplit(".", 1) if "." in name else ("", name)
173
+ parent = model if not parent_name else dict(model.named_modules())[parent_name]
174
+ setattr(parent, attr, TernaryLayerNorm(module.normalized_shape, eps=module.eps))
175
+ # Patch embed Conv2d -> BitConv2d
176
+ patch_embed = getattr(model, "patch_embed", None)
177
+ if patch_embed and hasattr(patch_embed, "proj") and isinstance(patch_embed.proj, nn.Conv2d):
178
+ old = patch_embed.proj
179
+ new = BitConv2d(old.in_channels, old.out_channels, old.kernel_size,
180
+ stride=old.stride, padding=old.padding, bias=old.bias is not None)
181
+ patch_embed.proj = new
182
+ return model
183
+
184
+ # ============================================================================
185
+ # Load and evaluate
186
+ # ============================================================================
187
 
188
+ import timm
189
+ from torchvision import datasets, transforms
190
+
191
+ # --- Configuration (change these) ---
192
+ MODEL_NAME = "deit3_small_patch16_224.fb_in22k_ft_in1k"
193
+ CHECKPOINT = "imagenet1k/phase2_ep010_acc79.64_deit3_small_224.pth"
194
+ DATASET = "imagenet" # "imagenet", "cifar10", or "cifar100"
195
+ DATA_DIR = "./data/imagenet" # path to ImageNet val/ or CIFAR download dir
196
+ NUM_CLASSES = 1000 # 1000 for ImageNet, 10 for CIFAR-10, 100 for CIFAR-100
197
+ BATCH_SIZE = 128
198
+ # ------------------------------------
199
+
200
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
201
+
202
+ # 1. Build model + ternary conversion
203
+ model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
204
+ model = make_ternary(model)
205
+
206
+ # 2. Load checkpoint
207
+ path = hf_hub_download("szymonrucinski/FTerViT", CHECKPOINT)
208
+ sd = torch.load(path, map_location="cpu", weights_only=False)
209
+ sd = {k.removeprefix("timm_model."): v for k, v in sd.items()}
210
+ model.load_state_dict(sd, strict=False)
211
+ model = model.to(device).eval()
212
+
213
+ # 3. Build eval dataloader
214
+ from timm.data import resolve_data_config, create_transform
215
+ config = resolve_data_config({}, model=timm.create_model(MODEL_NAME, pretrained=False))
216
+
217
+ if DATASET == "imagenet":
218
+ transform = create_transform(**config, is_training=False)
219
+ val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=transform)
220
+ else:
221
+ # CIFAR models were trained with mean/std = [0.5, 0.5, 0.5]
222
+ transform = transforms.Compose([
223
+ transforms.Resize((config["input_size"][1], config["input_size"][2])),
224
+ transforms.ToTensor(),
225
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
226
+ ])
227
+ cls = datasets.CIFAR10 if NUM_CLASSES == 10 else datasets.CIFAR100
228
+ val_dataset = cls(root=DATA_DIR, train=False, download=True, transform=transform)
229
+
230
+ val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
231
+ num_workers=4, pin_memory=True)
232
+
233
+ # 4. Evaluate
234
+ correct = total = 0
235
+ with torch.no_grad():
236
+ for images, labels in val_loader:
237
+ images, labels = images.to(device), labels.to(device)
238
+ preds = model(images).argmax(dim=1)
239
+ correct += (preds == labels).sum().item()
240
+ total += labels.size(0)
241
+
242
+ print(f"Top-1 accuracy: {correct / total:.4f} ({correct / total * 100:.2f}%)")
243
+ print(f"Evaluated {total} samples")
244
  ```
245
 
246
  ## Citation