JaydeepR commited on
Commit
204d384
·
verified ·
1 Parent(s): 5eb4eea

Create summarization_model.py

Browse files
Files changed (1) hide show
  1. summarization_model.py +45 -0
summarization_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BlipProcessor, BlipForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration
3
+
4
+ from PIL import Image
5
+
6
+ # Load the BLIP model and processor for image description
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
9
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
10
+
11
+ # Load the T5 model for text summarization
12
+ t5_tokenizer = T5Tokenizer.from_pretrained("t5-small")
13
+ t5_model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
14
+
15
+ def generate_description(image: Image.Image) -> str:
16
+ """
17
+ Generates a detailed description for the given image.
18
+
19
+ Parameters:
20
+ image (PIL.Image.Image): The input image.
21
+
22
+ Returns:
23
+ str: The generated description.
24
+ """
25
+ inputs = processor(images=image, return_tensors="pt").to(device)
26
+ outputs = blip_model.generate(**inputs)
27
+ description = processor.decode(outputs[0], skip_special_tokens=True)
28
+ return description
29
+
30
+ def summarize_text_and_image(description: str, ocr_text: str) -> str:
31
+ """
32
+ Generates a summary combining the image description and OCR-extracted text.
33
+
34
+ Parameters:
35
+ description (str): The generated description of the image.
36
+ ocr_text (str): The text extracted from the image using OCR.
37
+
38
+ Returns:
39
+ str: The generated summary.
40
+ """
41
+ combined_input = f"Image Description: {description} Text: {ocr_text}"
42
+ input_ids = t5_tokenizer.encode(combined_input, return_tensors="pt", truncation=True).to(device)
43
+ outputs = t5_model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True)
44
+ summary = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)
45
+ return summary