| import numpy as np |
|
|
| import torch |
| from transformers import GPT2TokenizerFast |
| from .models import VisionGPT2Model |
|
|
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
|
|
| from PIL import Image |
| import matplotlib.pyplot as plt |
| from types import SimpleNamespace |
| import pathlib |
| from tkinter import filedialog |
|
|
def download(url: str, filename: str) -> pathlib.Path:
    """Stream *url* to *filename* with a tqdm progress bar.

    Parameters
    ----------
    url : str
        Direct download link for the file.
    filename : str
        Destination path; parent directories are created as needed.

    Returns
    -------
    pathlib.Path
        The resolved path the file was written to.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status.
    RuntimeError
        If the status is non-200 but not an HTTP error (e.g. an
        unfollowed redirect), so raise_for_status() did not raise.
    """
    import functools
    import shutil
    import requests
    from tqdm.auto import tqdm

    # stream=True avoids holding the whole file in memory; the context
    # manager guarantees the connection is released even on error.
    with requests.get(url, stream=True, allow_redirects=True) as r:
        if r.status_code != 200:
            r.raise_for_status()  # raises HTTPError for 4xx/5xx
            # Only reachable for non-error, non-200 statuses.
            raise RuntimeError(f"Request to {url} returned status code {r.status_code}\n Please download the captioner.pt file manually from the link provided in the README.md file.")
        file_size = int(r.headers.get('Content-Length', 0))

        path = pathlib.Path(filename).expanduser().resolve()
        path.parent.mkdir(parents=True, exist_ok=True)

        desc = "(Unknown total file size)" if file_size == 0 else ""
        # Force transparent decompression so bytes read match Content-Length.
        r.raw.read = functools.partial(r.raw.read, decode_content=True)
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
            with path.open("wb") as f:
                shutil.copyfileobj(r_raw, f)

    return path
|
|
def main():
    """Load (downloading if needed) the captioning model, prompt the user
    for an image, and display the image with its predicted caption."""
    model_config = SimpleNamespace(
        vocab_size = 50257,
        embed_dim = 768,
        num_heads = 12,
        seq_len = 1024,
        depth = 12,
        attention_dropout = 0.1,
        residual_dropout = 0.1,
        mlp_ratio = 4,
        mlp_dropout = 0.1,
        emb_dropout = 0.1,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = VisionGPT2Model(model_config).to(device)
    try:
        sd = torch.load("captioner.pt", map_location=device)
    except (FileNotFoundError, OSError):
        # Weights are not bundled with the repo; fetch them on first run.
        # (Narrow except: a bare `except:` would also swallow
        # KeyboardInterrupt and hide unrelated load failures.)
        print("Model not found. Downloading Model ")
        url = "https://drive.usercontent.google.com/download?id=1X51wAI7Bsnrhd2Pa4WUoHIXvvhIcRH7Y&export=download&authuser=0&confirm=t&uuid=ae5c4861-4411-4f81-88cd-66ea30b6fe2b&at=APZUnTWodeDt1upcQVMej2TDcADs%3A1722666079498"
        path = download(url, "captioner.pt")
        # NOTE(review): torch.load unpickles arbitrary data — acceptable
        # here only because the URL is a fixed, project-controlled file.
        sd = torch.load(path, map_location=device)

    model.load_state_dict(sd)
    model.eval()
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    # Preprocessing: resize, scale to [-1, 1], convert HWC ndarray -> CHW tensor.
    # Normalize applies with p=1.0 by default; `always_apply` is deprecated.
    tfms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2()
    ])

    test_img:str = filedialog.askopenfilename(title = "Select an image",
                        filetypes = (("jpeg files","*.jpg"),("png files",'*.png'),("all files","*.*")))
    # askopenfilename returns "" when the dialog is cancelled; bail out
    # instead of crashing inside Image.open("").
    if not test_img:
        print("No image selected.")
        return

    im = Image.open(test_img).convert("RGB")

    det = True        # deterministic (greedy) decoding flag passed to generate
    temp = 1.0        # sampling temperature (presumably ignored when det=True — TODO confirm)
    max_tokens = 50   # cap on generated caption length

    image = np.array(im)
    image:torch.Tensor = tfms(image=image)['image']
    image = image.unsqueeze(0).to(device)   # add batch dimension
    # Seed the decoder with a single BOS token.
    seq = torch.ones(1,1).to(device).long()*tokenizer.bos_token_id

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        caption = model.generate(image, seq, max_tokens, temp, det)
    # .cpu() is required before .numpy(): a CUDA tensor cannot be
    # converted to a NumPy array directly.
    caption = tokenizer.decode(caption.cpu().numpy(), skip_special_tokens=True)

    plt.imshow(im)
    plt.title(f"Predicted : {caption}")
    plt.axis('off')
    plt.show()
|
|
|
|
# Run the interactive captioning flow only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()