File size: 2,216 Bytes
d31c7d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright (C) 2026 Embedl AB
"""Run inference on the Embedl Mobilevit Small INT8 model via torch.export.

This script loads the shipped ``embedl_mobilevit_small_int8.pt2``
artifact with ``torch.export.load`` and runs a single image through
it. No TensorRT or ONNX runtime is required — just PyTorch.

Usage::

    python infer_pt2.py --image path/to/image.jpg
"""

import argparse
from pathlib import Path

import numpy as np
import torch
from PIL import Image

# Exported model artifact, expected to sit next to this script.
PT2_PATH = Path(__file__).with_name("embedl_mobilevit_small_int8.pt2")
# Spatial size the model was exported with; square, so PIL's (width, height)
# argument order is irrelevant here.
INPUT_SIZE = (256, 256)
# Identity normalization (mean 0, std 1): the model expects raw [0, 1]
# pixel values — see the do_normalize=None note in preprocess().
MEAN = np.array([0.0, 0.0, 0.0], dtype=np.float32)
STD = np.array([1.0, 1.0, 1.0], dtype=np.float32)


def preprocess(image_path: Path) -> torch.Tensor:
    """Load *image_path* and return a float32 NCHW tensor of shape (1, 3, 256, 256).

    MobileViT-Small uses BGR channel order and a plain [0, 1] range with NO
    mean/std normalization (matches the upstream HF processor:
    do_normalize=None) — MEAN/STD are identity and kept only for symmetry.
    """
    img = Image.open(image_path).convert("RGB").resize(INPUT_SIZE)
    pixels = np.asarray(img, dtype=np.float32) / 255.0
    pixels = (pixels - MEAN) / STD
    # RGB -> BGR; the contiguous copy removes the negative stride so
    # torch.from_numpy can wrap the array.
    bgr = np.ascontiguousarray(pixels[..., ::-1])
    nchw = bgr.transpose(2, 0, 1)[np.newaxis]  # HWC -> NCHW with batch dim
    return torch.from_numpy(nchw)


def main() -> None:
    """Parse CLI args, run one image through the .pt2 model, print top-k classes."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--image", required=True, type=Path)
    cli.add_argument("--topk", type=int, default=5)
    opts = cli.parse_args()

    if not PT2_PATH.exists():
        raise SystemExit(
            f"Expected {PT2_PATH.name} next to this script. "
            "Did you `huggingface-cli download` the repo?"
        )

    # The ExportedProgram captured the model in eval mode at export time,
    # so no further .eval() / no_grad toggling is needed (and neither is
    # supported on the .module() wrapper).
    net = torch.export.load(str(PT2_PATH)).module()

    scores = net(preprocess(opts.image))
    probabilities = torch.softmax(scores, dim=-1).squeeze(0)

    top_vals, top_idx = probabilities.topk(opts.topk)
    print(f"Top-{opts.topk} predictions for {opts.image}:")
    rank = 0
    for idx, val in zip(top_idx.tolist(), top_vals.tolist()):
        rank += 1
        print(f"  {rank}. class {idx:>5d}  ({val * 100:5.2f}%)")


if __name__ == "__main__":
    main()