| import torch |
| from model import LVL |
| from transformers import RobertaTokenizer |
| from PIL import Image |
| from torchvision import transforms |
|
|
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| |
| model = LVL() |
| model.load_state_dict(torch.load("scold.pth", map_location=device)) |
| model.to(device) |
| model.eval() |
|
|
| |
| tokenizer = RobertaTokenizer.from_pretrained("roberta-base") |
|
|
| |
| transform = transforms.Compose([ |
| transforms.Resize((224, 224)), |
| transforms.ToTensor() |
| ]) |
|
|
|
|
| def predict(image_path, text): |
| image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device) |
| tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device) |
|
|
| with torch.no_grad(): |
| img_feat, txt_feat = model(image, tokens["input_ids"], tokens["attention_mask"]) |
| similarity = torch.matmul(img_feat, txt_feat.T).squeeze() |
|
|
| return similarity.item() |
|
|