import json

import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import models, transforms
|
|
class SmallCNN(nn.Module):
    """Small three-stage convolutional classifier for 3-channel images.

    The feature extractor is three 3x3 convolutions with ``width``,
    ``2 * width`` and ``4 * width`` output channels.  The first two are
    followed by ReLU and 2x2 max pooling, the last by ReLU and global
    average pooling.  The head flattens, applies dropout and projects to
    ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability applied before the final linear layer.
        width: Channel count of the first convolution.
    """

    def __init__(self, num_classes=2, dropout=0.2, width=32):
        super().__init__()
        c = width
        # Layer order is significant: it determines the state-dict keys
        # (features.0, features.3, features.6, head.2).
        self.features = nn.Sequential(
            nn.Conv2d(3, c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(c, 2 * c, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(2 * c, 4 * c, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(4 * c, num_classes),
        )

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        return self.head(self.features(x))


def build_model(arch, dropout, width, freeze_backbone, num_classes=2):
    """Construct a freshly initialised classifier for the given architecture.

    No pretrained weights are loaded; the torchvision models are created
    with ``weights=None``.

    Args:
        arch: One of ``"smallcnn"``, ``"resnet18"`` or
            ``"mobilenet_v3_small"``.
        dropout: Dropout probability placed before the final linear layer.
            Used by ``"smallcnn"`` and ``"resnet18"``; the
            ``"mobilenet_v3_small"`` branch does not read it.
        width: Base channel count; only used by ``"smallcnn"``.
        freeze_backbone: Accepted so the signature matches the stored
            config, but not read by this function (no parameters are
            frozen).
        num_classes: Number of output logits.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    if arch == "smallcnn":
        return SmallCNN(num_classes=num_classes, dropout=dropout, width=width)

    if arch == "resnet18":
        m = models.resnet18(weights=None)
        # Swap the stock classifier for dropout + a linear layer sized to
        # num_classes, reusing the original layer's input width.
        m.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(m.fc.in_features, num_classes),
        )
        return m

    if arch == "mobilenet_v3_small":
        m = models.mobilenet_v3_small(weights=None)
        # Only the last classifier layer is replaced; the rest of the
        # torchvision classifier stack is kept as constructed.
        m.classifier[-1] = nn.Linear(m.classifier[-1].in_features, num_classes)
        return m

    raise ValueError(f"Unknown arch: {arch!r}")
|
|
def load_model(model_path="model_state.pt", config_path="config.json", device="cpu"):
    """Rebuild a trained classifier and its preprocessing from saved files.

    Args:
        model_path: Path to the saved state dict.
        config_path: Path to a JSON config.  It must contain ``arch``,
            ``dropout``, ``width``, ``num_classes``, ``img_size``, ``mean``
            and ``std``; ``freeze_backbone`` is optional.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        A ``(model, tfm, cfg)`` tuple: the model in eval mode on ``device``,
        the torchvision preprocessing pipeline, and the parsed config dict.

    Raises:
        KeyError: If a required config key is missing.
        ValueError: If ``build_model`` rejects the configured arch.
        RuntimeError: If the state dict does not match the model exactly
            (``strict=True``).
    """
    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open(config_path, encoding="utf-8") as f:
        cfg = json.load(f)

    model = build_model(
        cfg["arch"],
        cfg["dropout"],
        cfg["width"],
        # build_model in this file never reads this flag, so a config
        # without it should not prevent loading.
        cfg.get("freeze_backbone", False),
        cfg["num_classes"],
    )

    # SECURITY: depending on the installed PyTorch version, torch.load may
    # unpickle arbitrary objects.  Only load checkpoints from trusted
    # sources, or pass weights_only=True where the installed version
    # supports it.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state, strict=True)
    model.to(device).eval()

    img_size = cfg["img_size"]
    tfm = transforms.Compose([
        # Resize to 1.14x the target size, then crop the centre back down
        # to img_size.
        transforms.Resize(int(img_size * 1.14)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg["mean"], std=cfg["std"]),
    ])
    return model, tfm, cfg
|
|
def predict_image(image_path, model, tfm, device="cpu"):
    """Classify a single image file.

    Args:
        image_path: Path of the image to classify.
        model: Callable returning logits for a batch; called with one
            batch of size 1.
        tfm: Preprocessing callable applied to the RGB PIL image; its
            result must support ``unsqueeze(0)`` and ``to(device)``.
        device: Device the input batch is moved to before the forward
            pass.  NOTE(review): presumably this must match the device
            the model lives on; confirm against callers.

    Returns:
        A ``(pred, probs)`` tuple: the index of the highest logit as an
        ``int``, and the softmax probabilities as a flat list of floats.
    """
    # Use a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")

    x = tfm(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(x)
        # Tensor.tolist() yields Python floats directly, so the NumPy
        # round-trip is not needed.
        probs = torch.softmax(logits, dim=1).cpu().flatten().tolist()
        pred = int(logits.argmax(dim=1).item())
    return pred, probs
|
|