| """ |
| Based upon ImageCaptionLoader in LangChain (langchain/document_loaders/image_captions.py), |
| but accepts a preloaded model to avoid slowness in use and CUDA forking issues. |
| |
| Loader that uses Pix2Struct models to caption images. |
| |
| """ |
| from typing import List, Union, Any, Tuple |
|
|
| from langchain.docstore.document import Document |
| from langchain.document_loaders import ImageCaptionLoader |
| from utils import get_device, clear_torch_cache |
| from PIL import Image |
|
|
|
|
class H2OPix2StructLoader(ImageCaptionLoader):
    """Loader that extracts text from images with a Pix2Struct captioning model.

    The processor and model are loaded lazily (on the first ``load()`` or an explicit
    ``load_model()``) and kept on the instance, so repeated ``load()`` calls reuse them.
    """

    def __init__(self, path_images: Union[str, List[str], None] = None,
                 model_type="google/pix2struct-textcaps-base",
                 max_new_tokens=50):
        """
        Args:
            path_images: A single image path or a list of image paths; may also be set
                later with ``set_image_paths()``.
            model_type: Model id passed to ``from_pretrained`` for both the processor
                and the model.
            max_new_tokens: Maximum number of new tokens generated per image.
        """
        super().__init__(path_images)
        self._pix2struct_model = None
        # Initialised up front so load()/load_model() can test for it instead of
        # failing with AttributeError when only a model has been attached.
        self._pix2struct_processor = None
        self._model_type = model_type
        self._max_new_tokens = max_new_tokens

    def set_context(self):
        """Choose ``self.device`` ('cuda' or 'cpu') for the model and its inputs.

        CUDA is used only when ``get_device()`` reports 'cuda' AND torch actually sees
        at least one GPU; otherwise everything runs on CPU.
        """
        self.device = 'cpu'
        if get_device() == 'cuda':
            import torch
            n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
            if n_gpus > 0:
                self.context_class = torch.device
                self.device = 'cuda'

    def load_model(self):
        """Load the Pix2Struct processor and model, or move an existing model onto the device.

        Returns:
            ``self``, so calls can be chained (e.g. ``Loader(...).load_model()``).

        Raises:
            ValueError: if the ``transformers`` package is not installed.
        """
        try:
            from transformers import AutoProcessor, Pix2StructForConditionalGeneration
        except ImportError as e:
            raise ValueError(
                "`transformers` package not found, please install with "
                "`pip install transformers`."
            ) from e
        if self._pix2struct_model is not None:
            # A model is already attached: it was loaded earlier (and possibly moved to
            # CPU by unload_model()), or it was assigned directly. Resolve the device if
            # that never happened, make sure a processor exists, then move the model
            # (back) onto the device.
            if not hasattr(self, 'device'):
                self.set_context()
            if getattr(self, '_pix2struct_processor', None) is None:
                self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
            self._pix2struct_model = self._pix2struct_model.to(self.device)
            return self
        self.set_context()
        self._pix2struct_processor = AutoProcessor.from_pretrained(self._model_type)
        self._pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(self._model_type).to(self.device)
        return self

    def unload_model(self):
        """Move the model to CPU and call ``clear_torch_cache()`` (presumably to free GPU memory).

        The model stays attached to the loader; a later ``load_model()`` moves it back
        to ``self.device``.
        """
        if hasattr(self._pix2struct_model, 'cpu'):
            self._pix2struct_model.cpu()
        clear_torch_cache()

    def set_image_paths(self, path_images: Union[str, List[str]]):
        """
        Set the image file(s) to caption: a single path or a list of paths.
        """
        if isinstance(path_images, str):
            self.image_paths = [path_images]
        else:
            self.image_paths = path_images

    def load(self, prompt=None) -> List[Document]:
        """Caption every image in ``self.image_paths``.

        Args:
            prompt: Accepted for interface compatibility; currently unused.

        Returns:
            One Document per image path, whose ``page_content`` is the generated text
            and whose ``metadata`` is ``{"image_path": <path>}``. The list is empty when
            no image paths are set.
        """
        if self._pix2struct_model is None or getattr(self, '_pix2struct_processor', None) is None:
            self.load_model()
        results = []
        # image_paths is None when the loader was built with the default path_images.
        for path_image in self.image_paths or []:
            caption, metadata = self._get_captions_and_metadata(
                processor=self._pix2struct_processor, model=self._pix2struct_model, path_image=path_image
            )
            results.append(Document(page_content=caption, metadata=metadata))
        return results

    def _get_captions_and_metadata(
            self, processor: Any, model: Any, path_image: str) -> Tuple[str, dict]:
        """Caption a single image file.

        Args:
            processor: Processor used to prepare the image and decode generated ids.
            model: Generation model producing the caption token ids.
            path_image: Path of the image file.

        Returns:
            ``(caption, metadata)`` where ``metadata`` is ``{"image_path": path_image}``.

        Raises:
            ValueError: if the image cannot be opened or decoded.
        """
        try:
            with Image.open(path_image) as image_file:
                # convert() forces a full decode, so the file handle can be closed
                # afterwards. It also normalizes modes such as RGBA, P or L to 3-channel
                # RGB, matching what LangChain's ImageCaptionLoader feeds its processor.
                image = image_file.convert("RGB")
        except Exception as e:
            raise ValueError(f"Could not get image data for {path_image}") from e
        inputs = processor(images=image, return_tensors="pt")
        # Put the inputs wherever the model currently lives. unload_model() may have
        # moved it to CPU while self.device is still 'cuda'.
        device = getattr(model, "device", None)
        if device is None:
            device = self.device
        inputs = inputs.to(device)
        generated_ids = model.generate(**inputs, max_new_tokens=self._max_new_tokens)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        metadata: dict = {"image_path": path_image}
        return generated_text, metadata
|
|