| import torch |
| from PIL import Image |
| from .preprocess import preprocess_image |
| from .utils import load_model |
|
|
|
|
def predict_with_model(model, inputs):
    """Return the index of the highest-scoring class for ``inputs``.

    The model is put into evaluation mode (and left there), then called with
    ``inputs`` unpacked as keyword arguments while gradient tracking is off.

    Args:
        model: Callable module with an ``eval()`` method whose call result
            exposes a ``logits`` attribute.
        inputs: Mapping from keyword-argument names to the tensors ``model``
            expects.

    Returns:
        The argmax over the last axis of the logits, as a Python number via
        ``.item()``. Because ``.item()`` needs a single-element tensor, this
        raises when the argmax has more than one element (e.g. a batch of
        several inputs).
    """
    model.eval()
    with torch.no_grad():
        scores = model(**inputs).logits
    # argmax is non-differentiable, so computing it outside no_grad is harmless.
    best_index = torch.argmax(scores, dim=-1)
    return best_index.item()
|
|
|
|
def predict(image_path, *, model=None):
    """Load an image, preprocess it, run the model, and return the prediction.

    The image is decoded and converted to RGB, passed through
    ``preprocess_image``, moved onto the model's device, and classified with
    ``predict_with_model``.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a
            filesystem path).
        model: Optional, already-loaded model. When omitted, ``load_model()``
            is called on every invocation, exactly as before. Pass a model
            explicitly to avoid reloading it for each image.

    Returns:
        The predicted class index returned by ``predict_with_model``.
    """
    # Image.open is lazy. For multi-frame formats Pillow keeps the file open
    # even after the pixels are loaded, so close it deterministically.
    # convert() loads the data and returns an independent image first.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = preprocess_image(image)

    if model is None:
        model = load_model()

    # Inputs must live on the same device as the model.
    # NOTE(review): relies on the model exposing a ``.device`` attribute (as
    # Hugging Face models do); a plain torch.nn.Module does not. Confirm.
    device = model.device
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    return predict_with_model(model, inputs)
|
|