| from abc import ABC, abstractmethod |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| from PIL import Image |
| from transformers import BatchEncoding, BatchFeature |
|
|
| from .torch_utils import get_torch_device |
|
|
|
|
class BaseVisualRetrieverProcessor(ABC):
    """
    Base class for visual retriever processors.

    Concrete subclasses implement the model-specific pre-processing
    (`process_images`, `process_queries`) and scoring (`score`). The static
    helpers `score_single_vector` and `score_multi_vector` implement the two
    standard scoring schemes and can be reused by subclasses.
    """

    @abstractmethod
    def process_images(
        self,
        images: List[Image.Image],
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a batch of images into model-ready inputs.
        """
        pass

    @abstractmethod
    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> Union[BatchFeature, BatchEncoding]:
        """
        Process a batch of text queries into model-ready inputs.

        Args:
            queries: Query strings.
            max_length: Maximum token length used by the tokenizer.
            suffix: Optional string appended to each query before tokenization.
        """
        pass

    @abstractmethod
    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the score matrix between the given query and passage embeddings.
        """
        pass

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        Args:
            qs: Query embeddings; each tensor is a single vector of shape (dim,).
                All tensors must share the same shape (required by `torch.stack`).
            ps: Passage embeddings; each tensor is a single vector of shape (dim,).
            device: Device on which to compute; defaults to `get_torch_device("auto")`.

        Returns:
            A float32 tensor of shape (len(qs), len(ps)).

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if not qs:
            raise ValueError("No queries provided")
        if not ps:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        # (n_queries, dim) @ (n_passages, dim)^T -> (n_queries, n_passages)
        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @staticmethod
    def score_multi_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        Args:
            qs: Query embeddings; each tensor has shape (n_query_tokens_i, dim).
            ps: Passage embeddings; each tensor has shape (n_passage_tokens_j, dim).
            batch_size: Number of queries/passages scored per chunk (bounds peak memory on `device`).
            device: Device on which to compute; defaults to `get_torch_device("auto")`.

        Returns:
            A float32 tensor of shape (len(qs), len(ps)), assembled on CPU.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        device = device or get_torch_device("auto")

        if not qs:
            raise ValueError("No queries provided")
        if not ps:
            raise ValueError("No passages provided")

        # Pad each passage chunk once instead of re-padding it for every query
        # chunk (the padding is invariant w.r.t. the outer loop); only the
        # transfer to `device` happens per use.
        ps_padded_batches = [
            torch.nn.utils.rnn.pad_sequence(ps[j : j + batch_size], batch_first=True, padding_value=0)
            for j in range(0, len(ps), batch_size)
        ]

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            qs_batch = torch.nn.utils.rnn.pad_sequence(
                qs[i : i + batch_size], batch_first=True, padding_value=0
            ).to(device)
            scores_batch = []
            for ps_padded in ps_padded_batches:
                ps_batch = ps_padded.to(device)
                # MaxSim: for each query token take the max similarity over all
                # passage tokens, then sum over query tokens. Zero-padding rows
                # contribute 0 to the max/sum as in the original implementation.
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            # Move each finished chunk back to CPU to free device memory early.
            scores_list.append(torch.cat(scores_batch, dim=1).cpu())

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @abstractmethod
    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int = 14,
        *args,
        **kwargs,
    ) -> Tuple[int, int]:
        """
        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an
        image of size (height, width) with the given patch size.
        """
        pass
|
|