| from transformers import DonutProcessor, VisionEncoderDecoderModel |
| from PIL import Image |
| import torch |
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
| def load_image(image_path): |
| image = Image.open(image_path).convert("RGB") |
| return image |
|
|
|
|
| def load_donut_model_and_processor(trained_model_repo): |
| donut_processor = DonutProcessor.from_pretrained(trained_model_repo) |
| model = VisionEncoderDecoderModel.from_pretrained(trained_model_repo) |
| model.to(device) |
| return donut_processor, model |
|
|
|
|
| def prepare_data_using_processor(donut_processor, image, task_prompt): |
| |
| pixel_values = donut_processor(image, return_tensors="pt").pixel_values |
| pixel_values = pixel_values.to(device) |
|
|
| |
| decoder_input_ids = donut_processor.tokenizer( |
| task_prompt, add_special_tokens=False, return_tensors="pt" |
| )["input_ids"] |
| decoder_input_ids = decoder_input_ids.to(device) |
|
|
| return pixel_values, decoder_input_ids |
|
|