---
language: en
tags:
- image-classification
- image-captioning
---

# Poster2Plot

An image captioning model that generates a movie or TV show plot from its poster. The generated plots are decent but by no means perfect, and we are still working on improving the model.

## Live demo on Hugging Face Spaces: https://huggingface.co/spaces/deepklarity/poster2plot

# Model Details

The base model uses a Vision Transformer (ViT) model as an image encoder and GPT-2 as a decoder.

We used the following models:

* Encoder: [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k)
* Decoder: [gpt2](https://huggingface.co/gpt2)
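
The ViT-encoder / GPT-2-decoder wiring can be sketched with the `transformers` `VisionEncoderDecoderModel` class. This is a minimal, hypothetical sketch, not the authors' training code: to stay runnable offline it uses tiny, randomly initialised configs instead of the real checkpoints, and every size below is an assumption. With network access, the two listed checkpoints could instead be combined with `VisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "gpt2")`.

```python
import torch
from transformers import (
    GPT2Config,
    ViTConfig,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

# Tiny stand-ins for google/vit-base-patch16-224-in21k and gpt2 (assumed sizes,
# random weights) so the sketch runs without downloading anything.
encoder_config = ViTConfig(
    image_size=32, patch_size=8, hidden_size=32,
    num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
)
decoder_config = GPT2Config(
    vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2,
    bos_token_id=0, eos_token_id=1,
)

# from_encoder_decoder_configs marks the GPT-2 config as a decoder and enables
# cross-attention, so each decoder layer attends to the ViT patch embeddings.
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# GPT-2 has no dedicated start or padding token, so reuse its BOS/EOS ids
# (mirroring `tokenizer.pad_token = tokenizer.eos_token` in the usage code).
config.decoder_start_token_id = decoder_config.bos_token_id
config.pad_token_id = decoder_config.eos_token_id

model = VisionEncoderDecoderModel(config=config)
model.eval()

pixel_values = torch.randn(1, 3, 32, 32)  # one fake 32x32 RGB "poster"
labels = torch.tensor([[5, 6, 7]])        # three fake plot token ids

with torch.no_grad():
    outputs = model(pixel_values=pixel_values, labels=labels)

# One vocabulary distribution per target token: torch.Size([1, 3, 100])
print(outputs.logits.shape)
```

When `labels` are passed, the model shifts them one position to the right and prepends `decoder_start_token_id` to build the decoder inputs, which is why both special token ids are set on the config before the forward pass.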

# Datasets

Publicly available IMDb datasets were used to train the model.

# How to use

## In PyTorch

```python
import re

import requests
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

# Pattern used to ignore all text after two or more consecutive full stops
regex_pattern = "[.]{2,}"


def post_process(text):
    # Trim surrounding whitespace and keep only the text before the first "..".
    text = text.strip()
    return re.split(regex_pattern, text)[0]


def predict(image, max_length=64, num_beams=4):
    # Turn the PIL image into the normalized pixel tensor the ViT encoder expects.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Generate the plot with beam search; no gradients are needed for inference.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    pred = post_process(preds[0])

    return pred


model_name_or_path = "deepklarity/poster2plot"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model.
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

# Load the image preprocessor matching the ViT encoder.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

# Load the GPT-2 tokenizer; GPT-2 has no padding token, so reuse the EOS token.
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    tokenizer.pad_token = tokenizer.eos_token

print("Loaded tokenizer")

url = "https://upload.wikimedia.org/wikipedia/en/2/26/Moana_Teaser_Poster.jpg"
with Image.open(requests.get(url, stream=True, timeout=30).raw) as image:
    # The ViT preprocessing expects 3-channel RGB input.
    pred = predict(image.convert("RGB"))

print(pred)
```
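
The truncation rule in `post_process` can be checked on its own. The sketch below uses a hypothetical standalone helper, `truncate_at_ellipsis` (not part of the model card's code), that applies the same "two or more consecutive full stops" rule; the input string is illustrative, not real model output.

```python
import re

# Same rule as post_process above: a run of two or more periods ("..", "...")
# marks the cut-off point. Hypothetical helper for illustration only.
ELLIPSIS = re.compile(r"\.{2,}")


def truncate_at_ellipsis(text: str) -> str:
    # Strip surrounding whitespace, then keep only the text before the first run.
    return ELLIPSIS.split(text.strip(), maxsplit=1)[0]


# Single full stops are kept; everything from the first "..." onward is dropped.
print(truncate_at_ellipsis("  A young girl sets sail. She meets a demigod... who"))
```

Because the split only triggers on runs of two or more periods, ordinary sentence boundaries inside the generated plot survive.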
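
To run `predict` on a poster saved on disk rather than fetched from a URL, a loader along these lines may help. `load_poster` is a hypothetical helper, not part of this model card; it assumes inputs are either local paths or http(s) URLs, and it converts every image to RGB because PNG posters can be RGBA or palette images while the ViT preprocessing expects three channels.

```python
import io
from pathlib import Path

import requests
from PIL import Image


def load_poster(source: str, timeout: float = 10.0) -> Image.Image:
    """Load a poster from an http(s) URL or a local path as an RGB image.

    Hypothetical helper for illustration; not part of the Poster2Plot code.
    """
    if source.startswith(("http://", "https://")):
        response = requests.get(source, timeout=timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
    else:
        image = Image.open(Path(source))
    # convert() returns a new, fully loaded image, so the source can be closed.
    with image:
        return image.convert("RGB")
```

With the model card's code loaded, usage would look like `pred = predict(load_poster("my_poster.png"))`, where `my_poster.png` is a placeholder path.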