# ViT-Small + LoRA on CIFAR-100

Parameter-efficient fine-tuning of ViT-S/16 on CIFAR-100 · Val Acc: 90.46% · Test Acc: 90.44%
## Overview

This repository contains `best_model.pt`, the full merged state dict of a ViT-Small/16 model fine-tuned on CIFAR-100 using Low-Rank Adaptation (LoRA). Only the Q, K, V attention projections and the classification head were updated during training; all other weights remained frozen.
| | |
|---|---|
| Base model | WinKawaks/vit-small-patch16-224 (ImageNet pre-trained) |
| Dataset | CIFAR-100 (50,000 train · 10,000 test · 100 classes) |
| Method | LoRA on Query, Key, Value attention projections + trainable head |
| Best config | rank=8, alpha=8, dropout=0.1 |
| Trainable params | 259,684 / 21,925,348 (1.18%) |
| Val accuracy | 90.46% (+9.69 pp over frozen-backbone baseline) |
| Test accuracy | 90.44% |
| Hardware | NVIDIA GTX 1080 Ti (11.7 GB VRAM) |
## Architecture

```
ViT-Small/16 (patch=16, dim=384, heads=6, layers=12)
├── Patch Embedding                        [frozen]
├── Transformer Encoder × 12
│   ├── Multi-Head Self-Attention
│   │   ├── Query ── LoRA(A·B, r=8)       ✓ trained
│   │   ├── Key   ── LoRA(A·B, r=8)       ✓ trained
│   │   └── Value ── LoRA(A·B, r=8)       ✓ trained
│   ├── LayerNorm                          [frozen]
│   └── MLP (FFN)                          [frozen]
└── Classification Head (384 → 100)        ✓ trained
```
LoRA update rule: W' = W + (α/r) · B·A,
where W is frozen, and A ∈ ℝ^{r×d} and B ∈ ℝ^{d×r} are learned.
With r=8 and α=8, the scaling factor α/r = 1.0.
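The update rule can be sketched as a minimal LoRA-adapted linear layer. This is an illustrative toy, not the PEFT library's internals; the zero-initialised B matrix follows the standard LoRA convention so the adapter starts as a no-op.

```python
import torch

class LoRALinear(torch.nn.Module):
    """Toy LoRA layer: frozen W plus a learned low-rank update (alpha/r) * B @ A."""

    def __init__(self, d_in, d_out, r=8, alpha=8, dropout=0.1):
        super().__init__()
        # Frozen pretrained weight W (random here, for illustration only)
        self.weight = torch.nn.Parameter(torch.randn(d_out, d_in), requires_grad=False)
        # A (r x d_in) gets a small random init; B (d_out x r) starts at zero,
        # so the initial update B @ A is exactly zero.
        self.A = torch.nn.Parameter(torch.randn(r, d_in) * 0.02)
        self.B = torch.nn.Parameter(torch.zeros(d_out, r))
        self.scale = alpha / r  # with r=8, alpha=8 this is 1.0
        self.drop = torch.nn.Dropout(dropout)

    def forward(self, x):
        base = x @ self.weight.T                       # frozen path: x W^T
        delta = self.drop(x) @ self.A.T @ self.B.T     # low-rank path: x (B A)^T
        return base + self.scale * delta
```

Because B is initialised to zero, the layer reproduces the frozen base output exactly at the start of training.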
## Hyperparameters

### LoRA (best configuration)
| Parameter | Value |
|---|---|
| Rank (r) | 8 |
| Alpha (α) | 8 |
| Scaling (α/r) | 1.0 |
| Dropout | 0.1 |
| Target modules | query, key, value |
| Bias | none |
| Trainable params | 259,684 (1.18%) |
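The trainable-parameter count in the table can be reproduced from the architecture figures above (12 layers, hidden dim 384, rank 8, 100 classes):

```python
# Sanity-check of the reported trainable-parameter count.
layers, dim, r, classes = 12, 384, 8, 100

# Each of Q, K, V in each layer gets an A (r x dim) and a B (dim x r) matrix.
lora_params = layers * 3 * (r * dim + dim * r)   # 221,184
# Classification head: weight (dim x classes) + bias (classes).
head_params = dim * classes + classes            # 38,500 (the exp01 baseline count)

total = lora_params + head_params
print(total)  # 259684, matching the table
```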
### Training
| Parameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning rate | 3e-4 |
| Weight decay | 1e-4 |
| LR scheduler | CosineAnnealingLR |
| Batch size | 128 |
| Epochs | 10 |
| Input resolution | 224 × 224 |
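The optimizer and schedule from the table can be set up as follows. This is a sketch: `model` stands in for the LoRA-wrapped ViT, and the cosine schedule is assumed to step once per epoch over the 10 epochs.

```python
import torch

model = torch.nn.Linear(384, 100)  # placeholder for the LoRA-wrapped ViT

# Only LoRA matrices and the head require grad in the real model, so we
# filter on requires_grad before handing parameters to AdamW.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=3e-4,
    weight_decay=1e-4,
)
# Cosine decay from 3e-4 down to 0 over the 10 training epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
```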
### Data augmentation (train)
| Transform | Setting |
|---|---|
| RandomHorizontalFlip | p = 0.5 |
| RandomCrop | 224 × 224, padding = 28 |
| ColorJitter | brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05 |
| Normalize mean | (0.5071, 0.4867, 0.4408) |
| Normalize std | (0.2675, 0.2565, 0.2761) |
## Experiment Results

### Grid search (all 10 runs)
| Experiment | Rank | Alpha | Dropout | Val Acc | Test Acc | Trainable Params |
|---|---|---|---|---|---|---|
| exp01 (baseline, no LoRA) | – | – | 0.1 | 80.77% | 80.77% | 38,500 |
| exp02 | 2 | 2 | 0.1 | 89.65% | 89.65% | 93,796 |
| exp03 | 2 | 4 | 0.1 | 90.03% | 90.03% | 93,796 |
| exp04 | 2 | 8 | 0.1 | 89.98% | 89.97% | 93,796 |
| exp05 | 4 | 2 | 0.1 | 89.91% | 89.91% | 149,092 |
| exp06 | 4 | 4 | 0.1 | 90.11% | 90.11% | 149,092 |
| exp07 | 4 | 8 | 0.1 | 90.28% | 90.28% | 149,092 |
| exp08 | 8 | 2 | 0.1 | 90.09% | 89.97% | 259,684 |
| exp09 | 8 | 4 | 0.1 | 90.17% | 90.17% | 259,684 |
| exp10 (best) | 8 | 8 | 0.1 | 90.46% | 90.44% | 259,684 |
### Optuna hyperparameter search (10 trials)

Optuna searched over rank ∈ {2, 4, 8}, alpha ∈ {2, 4, 8}, and dropout ∈ [0.05, 0.30].
| Trial | Rank | Alpha | Dropout | Val Acc |
|---|---|---|---|---|
| t0 | 2 | 4 | 0.15 | 90.06% |
| t1 | 4 | 8 | 0.30 | 90.32% |
| t2 | 4 | 2 | 0.15 | 90.03% |
| t3 | 4 | 8 | 0.25 | 90.08% |
| t4 | 4 | 2 | 0.15 | 90.10% |
| t5 (best) | 8 | 8 | 0.30 | 90.39% |
| t6 | 4 | 4 | 0.05 | 90.27% |
| t7 | 2 | 2 | 0.10 | 89.90% |
| t8 | 8 | 2 | 0.20 | 90.01% |
| t9 | 8 | 4 | 0.15 | 90.06% |
**Key findings:**
- rank=8, alpha=8 consistently tops the leaderboard across both search phases.
- Higher dropout (0.30 vs 0.10) with the best config yields nearly identical accuracy (90.39% vs 90.46%), confirming robustness.
- Increasing rank beyond 8 or alpha beyond 8 was not explored but is unlikely to yield significant gains given the plateau.
- LoRA provides +9.69 pp over the frozen-backbone baseline at just 1.18% parameter cost.
## Quickstart

### Install dependencies

```bash
pip install torch torchvision transformers peft huggingface_hub Pillow
```
### Load the model and run inference

```python
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from peft import LoraConfig, get_peft_model
from huggingface_hub import hf_hub_download
from PIL import Image

REPO = "MSG1999/vit-lora-cifar100"
BASE = "WinKawaks/vit-small-patch16-224"

CIFAR100_CLASSES = [
    "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle",
    "bicycle", "bottle", "bowl", "boy", "bridge", "bus", "butterfly", "camel",
    "can", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock",
    "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster",
    "house", "kangaroo", "keyboard", "lamp", "lawn_mower", "leopard", "lion",
    "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain", "mouse",
    "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear",
    "pickup_truck", "pine_tree", "plain", "plate", "poppy", "porcupine",
    "possum", "rabbit", "raccoon", "ray", "road", "rocket", "rose", "sea",
    "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank",
    "telephone", "television", "tiger", "tractor", "train", "trout", "tulip",
    "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm",
]
id2label = {i: c for i, c in enumerate(CIFAR100_CLASSES)}
label2id = {c: i for i, c in id2label.items()}

# 1. Reconstruct the model with the same LoRA config used during training
base_model = ViTForImageClassification.from_pretrained(
    BASE,
    num_labels=100,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],
    bias="none",
)
model = get_peft_model(base_model, lora_config)

# 2. Download and load best_model.pt
ckpt_path = hf_hub_download(repo_id=REPO, filename="best_model.pt")
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.eval()
print("Model loaded successfully.")

# 3. Inference
processor = ViTImageProcessor.from_pretrained(BASE)
image = Image.open("your_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
pred_id = logits.argmax(-1).item()
confidence = logits.softmax(-1)[0, pred_id].item()
print(f"Predicted class : {id2label[pred_id]}")
print(f"Confidence      : {confidence * 100:.1f}%")
```
### Batch inference

```python
images = [Image.open(p).convert("RGB") for p in image_paths]
inputs = processor(images=images, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
preds = logits.argmax(-1).tolist()
for path, pred in zip(image_paths, preds):
    print(f"{path} → {id2label[pred]}")
```
## Repository files

| File | Description |
|---|---|
| `best_model.pt` | Full state dict of the best ViT-S + LoRA model (exp10, r=8, α=8) |
| `README.md` | This model card |

Training code, logs, and all experiment weights are available in the GitHub repository.
## Citation

```bibtex
@misc{gadiya2026vitlora,
  title  = {ViT-Small + LoRA Fine-tuning on CIFAR-100},
  author = {Mahek Gadiya},
  year   = {2026},
  note   = {DLOps Assignment 5, Q1, IIT Jodhpur},
  url    = {https://huggingface.co/MSG1999/vit-lora-cifar100},
}
```