# dann-od's picture
# MobileVit-Small First commit
# d31c7d4 verified
# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Mobilevit Small INT8 model via torch.export.
This script loads the shipped ``embedl_mobilevit_small_int8.pt2``
artifact with ``torch.export.load`` and runs a single image through
it. No TensorRT or ONNX runtime is required — just PyTorch.
Usage::
python infer_pt2.py --image path/to/image.jpg
"""
import argparse
from pathlib import Path
import numpy as np
import torch
from PIL import Image
# Exported INT8 program; looked up in the same directory as this script.
PT2_PATH = Path(__file__).parent / "embedl_mobilevit_small_int8.pt2"

# Target size handed to ``Image.resize`` in ``preprocess``.
INPUT_SIZE = (256, 256)

# Per-channel normalization constants. Zero mean and unit std make the
# ``(arr - MEAN) / STD`` step in ``preprocess`` an identity transform.
MEAN = np.zeros(3, dtype=np.float32)
STD = np.ones(3, dtype=np.float32)
def preprocess(image_path: Path) -> torch.Tensor:
    """Load an image file and convert it into a model input tensor.

    The image is decoded as RGB, resized to ``INPUT_SIZE``, scaled to the
    ``[0, 1]`` range, normalized with ``MEAN``/``STD``, flipped to BGR
    channel order and laid out as a single-image NCHW batch.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        A ``float32`` tensor of shape ``(1, 3, H, W)``.
    """
    # MobileViT-Small uses BGR channel order, [0, 1] range, NO mean/std
    # normalization (matches the upstream HF processor: do_normalize=None).
    # ``Image.open`` is lazy and keeps the file handle open; the context
    # manager releases it as soon as the pixels have been converted.
    with Image.open(image_path) as source:
        image = source.convert("RGB").resize(INPUT_SIZE)
    arr = np.asarray(image, dtype=np.float32) / 255.0
    arr = (arr - MEAN) / STD
    arr = arr[..., ::-1].copy()  # RGB -> BGR; copy() removes the negative stride
    arr = arr.transpose(2, 0, 1)[None]  # HWC -> NCHW with a batch axis of 1
    return torch.from_numpy(arr)
def main() -> None:
    """Run one image through the exported model and print the top-k classes.

    Command-line arguments:
        --image: Path to the input image (required).
        --topk: Number of predictions to print (default 5). Must be at
            least 1; values larger than the number of classes are clamped.

    Raises:
        SystemExit: If ``--topk`` is not positive, if the ``.pt2`` artifact
            is not next to this script, or if the image file does not exist.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--image", required=True, type=Path)
    parser.add_argument("--topk", type=int, default=5)
    args = parser.parse_args()

    # Reject a non-positive k before any heavy work; a negative k would
    # otherwise fail inside ``Tensor.topk`` with an opaque RuntimeError.
    if args.topk < 1:
        parser.error("--topk must be a positive integer")

    if not PT2_PATH.exists():
        raise SystemExit(
            f"Expected {PT2_PATH.name} next to this script. "
            "Did you `huggingface-cli download` the repo?"
        )
    # Same clean-exit style as the model check above, instead of a
    # traceback from inside the image loader.
    if not args.image.is_file():
        raise SystemExit(f"Image file not found: {args.image}")

    # The ExportedProgram captured the model in eval mode at export
    # time, so no further .eval() / no_grad toggling is needed (and
    # neither is supported on the .module() wrapper).
    model = torch.export.load(str(PT2_PATH)).module()
    x = preprocess(args.image)
    logits = model(x)
    probs = torch.softmax(logits, dim=-1).squeeze(0)

    # ``topk`` raises when k exceeds the size of the last dimension, so
    # never ask for more entries than there are classes.
    k = min(args.topk, probs.shape[-1])
    topk_vals, topk_idx = probs.topk(k)
    print(f"Top-{k} predictions for {args.image}:")
    for rank, (idx, val) in enumerate(zip(topk_idx.tolist(), topk_vals.tolist()), 1):
        print(f" {rank}. class {idx:>5d} ({val * 100:5.2f}%)")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()