|
|
| |
| import re, json, torch, torch.nn.functional as F |
| from pathlib import Path |
| from PIL import Image |
| from torchvision import transforms |
| from transformers import BertTokenizer |
| |
| |
|
|
# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model metadata (label names, image stats, training config).
# Path.read_text closes the file deterministically; the previous
# json.load(open(...)) left the descriptor for the GC to reclaim.
META = json.loads(Path("deployment/model_meta.json").read_text(encoding="utf-8"))
CONFIG = META["config"]

# Tokenizer matching the text encoder the checkpoint was trained with.
tokenizer = BertTokenizer.from_pretrained(CONFIG["BERT_MODEL"])

# Image preprocessing: resize to the training resolution, convert to a
# float tensor in [0, 1], then normalize with the training-set statistics.
img_transform = transforms.Compose([
    transforms.Resize((CONFIG["IMAGE_SIZE"], CONFIG["IMAGE_SIZE"])),
    transforms.ToTensor(),
    transforms.Normalize(META["img_mean"], META["img_std"]),
])
|
|
def load_model(ckpt_path: str = "deployment/best_model.pt"):
    """Build the multimodal model and restore its best checkpoint.

    Args:
        ckpt_path: Checkpoint file to load. Defaults to the deployed
            checkpoint, so existing callers are unaffected.

    Returns:
        The model on DEVICE, switched to eval mode.
    """
    # NOTE(review): assumes MultimodalSentimentModel is imported/defined
    # elsewhere in this file — confirm.
    model = MultimodalSentimentModel(CONFIG).to(DEVICE)
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model
|
|
def predict(model, text: str, image_path: str) -> dict:
    """Run one multimodal sentiment prediction on a (text, image) pair.

    Args:
        model: Model returned by load_model(), already in eval mode.
        text: Raw input text; URLs and @mentions are stripped and
            hashtags reduced to their word before tokenization.
        image_path: Path to the input image file.

    Returns:
        dict with "label" (predicted class name), "confidence"
        (probability of that class), and "probabilities" (per-class map).
    """
    # Mirror the training-time text cleanup; fall back to a placeholder
    # so the tokenizer never sees an empty string.
    text = re.sub(r"http\S+", "", text)
    text = re.sub(r"@\w+", "", text)
    text = re.sub(r"#(\w+)", r"\1", text).strip() or "no text"

    enc = tokenizer(text, max_length=CONFIG["MAX_TEXT_LEN"],
                    padding="max_length", truncation=True, return_tensors="pt")
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # Close the image file promptly; a bare Image.open leaks the file
    # descriptor until the garbage collector gets to it.
    with Image.open(image_path) as pil_img:
        img = img_transform(pil_img.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(input_ids, attention_mask, img)
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # int(): numpy integer scalars are not JSON-serializable downstream.
    pred_idx = int(probs.argmax())
    return {
        "label": META["label_names"][pred_idx],
        "confidence": float(probs[pred_idx]),
        "probabilities": {n: float(p) for n, p in zip(META["label_names"], probs)},
    }
|
|