|
|
| import json, torch, timm |
| from PIL import Image |
| from safetensors.torch import load_file |
| from torchvision import transforms |
|
|
# timm architecture identifier; the weights in model.safetensors are loaded
# strictly into this architecture, so the two must correspond.
MODEL_NAME: str = "vit_base_patch16_224"

# Side length (pixels) of the square center crop fed to the network.
IMG_SIZE: int = 224

# Per-channel RGB normalization statistics (the widely used ImageNet
# mean / standard deviation), applied after ToTensor scales pixels to [0, 1].
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
|
def load_model(repo_dir="."):
    """Build the timm classifier and load its weights from *repo_dir*.

    Reads ``config.json`` (label metadata) and ``model.safetensors``
    (weights), both expected directly inside *repo_dir*, creates a
    ``MODEL_NAME`` model with a head sized to the number of labels, loads
    the weights, and switches the model to eval mode.

    Args:
        repo_dir: Directory holding ``config.json`` and
            ``model.safetensors``. Defaults to the current directory.

    Returns:
        A ``(model, cfg)`` tuple: the model in eval mode and the parsed
        config dictionary.

    Raises:
        FileNotFoundError: If either file is missing.
        KeyError: If the config has neither ``num_labels`` nor ``id2label``.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            architecture (``load_state_dict`` runs in strict mode).
    """
    # JSON is UTF-8 by specification; don't depend on the platform default
    # encoding, which can mis-decode non-ASCII label names (e.g. on Windows).
    with open(f"{repo_dir}/config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    # Hugging Face style configs do not always serialize ``num_labels``;
    # fall back to the size of the ``id2label`` mapping in that case.
    num_labels = cfg.get("num_labels")
    if num_labels is None:
        num_labels = len(cfg["id2label"])

    # pretrained=False: the weights come from the local safetensors file,
    # not from timm's download cache.
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=num_labels)
    state = load_file(f"{repo_dir}/model.safetensors")
    model.load_state_dict(state)
    model.eval()  # inference mode for dropout / normalization layers
    return model, cfg
|
|
def predict(image_path, repo_dir="."):
    """Classify one image and return the predicted label name.

    Args:
        image_path: Path (or file object accepted by ``PIL.Image.open``)
            of the image to classify.
        repo_dir: Directory holding ``config.json`` and
            ``model.safetensors``; forwarded to ``load_model``.

    Returns:
        The entry of ``cfg["id2label"]`` for the arg-max class index.

    Raises:
        KeyError: If the predicted index has no entry in ``id2label``.
    """
    # NOTE(review): the model is rebuilt and its weights re-read on every
    # call; callers classifying many images should hoist load_model().
    model, cfg = load_model(repo_dir)

    # Eval-time preprocessing: resize the shorter side to 256, center-crop
    # to IMG_SIZE x IMG_SIZE, convert to a float tensor in [0, 1], and
    # normalize per channel with MEAN / STD.
    tfm = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ])

    # Image.open is lazy and keeps the underlying file open; convert() loads
    # the pixels into a new image, so the handle can be closed right after.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    x = tfm(img).unsqueeze(0)  # add a leading batch dimension of size 1

    with torch.no_grad():
        logits = model(x)
        pred = logits.argmax(-1).item()

    # JSON object keys are always strings, hence the str() on the index.
    return cfg["id2label"][str(pred)]
|
|