File size: 1,837 Bytes
204d384
 
e1c2355
204d384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration
import sentencepiece as spm
from PIL import Image

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face checkpoint identifiers for the two models loaded below.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_T5_CHECKPOINT = "t5-small"

# BLIP (image -> caption): the processor prepares images and decodes tokens.
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(device)

# T5 (text -> text), used here to summarize the caption plus OCR text.
t5_tokenizer = T5Tokenizer.from_pretrained(_T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(_T5_CHECKPOINT).to(device)

def generate_description(image: Image.Image) -> str:
    """
    Caption an image with the module-level BLIP model.

    Parameters:
    image (PIL.Image.Image): The picture to caption.

    Returns:
    str: The decoded text of the first generated sequence, with special
    tokens removed.
    """
    # Turn the image into model-ready tensors on the same device as the model.
    model_inputs = processor(images=image, return_tensors="pt")
    model_inputs = model_inputs.to(device)

    # No generation arguments are supplied, so the model's default
    # generation settings govern caption length and decoding strategy.
    token_ids = blip_model.generate(**model_inputs)

    first_sequence = token_ids[0]
    return processor.decode(first_sequence, skip_special_tokens=True)

def summarize_text_and_image(
    description: str,
    ocr_text: str,
    *,
    max_length: int = 50,
    num_beams: int = 4,
) -> str:
    """
    Generates a summary combining the image description and OCR-extracted text.

    The inputs are joined into a single prompt that starts with the
    ``"summarize: "`` task prefix. T5 checkpoints such as ``t5-small`` were
    trained with task prefixes, so without one the model is not told to
    summarize.

    Parameters:
    description (str): The generated description of the image.
    ocr_text (str): The text extracted from the image using OCR. Surrounding
        whitespace is stripped. If nothing remains (or ``None`` is passed),
        the "Text:" segment is left out of the prompt entirely.
    max_length (int): Upper bound on the generated summary length, in tokens.
    num_beams (int): Beam width used for beam-search decoding.

    Returns:
    str: The generated summary.
    """
    # Normalise the OCR text: OCR output often carries stray newlines or
    # spaces, and an empty result should not leave a dangling "Text:" label.
    ocr_text = (ocr_text or "").strip()

    combined_input = f"summarize: Image Description: {description}"
    if ocr_text:
        combined_input += f" Text: {ocr_text}"

    # truncation=True without max_length truncates to the tokenizer's
    # model_max_length.
    input_ids = t5_tokenizer.encode(
        combined_input, return_tensors="pt", truncation=True
    ).to(device)
    outputs = t5_model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)