| import torch |
| from PIL import Image |
| from transformers import AutoModel, AutoProcessor |
| from typing import List, Union, Optional |
|
|
|
|
class OpsColQwen3Embedder:
    """
    Embedder for the OpsColQwen3-4B multi-vector (ColBERT-style) retrieval model.

    Wraps a Hugging Face ``AutoModel``/``AutoProcessor`` pair loaded with
    ``trust_remote_code=True`` and exposes helpers to embed text queries and
    images, and to score queries against images with late-interaction
    (multi-vector) similarity.
    """

    def __init__(
        self,
        model_name: str = "OpenSearch-AI/Ops-Colqwen3-4B",
        dims: int = 2560,
        device: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize the embedder.

        Args:
            model_name: Model path or hub name.
            dims: Embedding dimensions (forwarded to the remote-code model).
            device: Device to use for inference ('mps', 'cuda', or 'cpu').
                Ignored if an explicit ``device_map`` kwarg is given.
            **kwargs: Additional arguments passed to ``from_pretrained``.
                NOTE(review): after ``device_map``/``dtype`` are popped, the
                remaining kwargs (e.g. ``attn_implementation``) are forwarded
                to BOTH the model and the processor; unknown kwargs are
                typically ignored by ``from_pretrained``, but confirm for this
                remote-code processor.
        """
        # Resolve the target device: explicit `device_map` kwarg wins, then the
        # `device` argument, then the best available backend (cuda > mps > cpu).
        device_map = kwargs.pop('device_map', None)
        if not device_map:
            if device:
                device_map = device
            elif torch.cuda.is_available():
                device_map = "cuda"
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device_map = "mps"
            else:
                device_map = "cpu"

        # Half precision on accelerators, full precision on CPU (fp16 kernels
        # are slow or unsupported on most CPUs); caller may override via `dtype`.
        dtype = kwargs.pop('dtype', torch.float16 if device_map != "cpu" else torch.float32)

        self.model = AutoModel.from_pretrained(
            model_name,
            dims=dims,
            trust_remote_code=True,
            dtype=dtype,
            device_map=device_map,
            **kwargs
        )
        # Inference only: disable dropout/batch-norm updates.
        self.model.eval()

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
            **kwargs
        )

        # NOTE(review): assumes `device_map` is a single device string; a dict
        # placement or "auto" would break the `.to(self.device)` calls below —
        # confirm callers never pass those.
        self.device = device_map
        self.dims = dims

    def encode_queries(
        self,
        queries: List[str]
    ) -> List[torch.Tensor]:
        """
        Encode a list of text queries.

        Args:
            queries: List of query texts.

        Returns:
            List of per-query multi-vector embeddings, moved to CPU.
        """
        query_inputs = self.processor.process_queries(queries)
        query_inputs = {k: v.to(self.device) for k, v in query_inputs.items()}

        with torch.no_grad():
            query_embeddings = self.model(**query_inputs)

        # Move to CPU so results can be stored/compared without holding GPU memory.
        return [q.cpu() for q in query_embeddings]

    def encode_images(
        self,
        images: List[Union[str, Image.Image]]
    ) -> List[torch.Tensor]:
        """
        Encode a list of images.

        Args:
            images: List of image paths or PIL Images.

        Returns:
            List of per-image multi-vector embeddings, moved to CPU.

        Raises:
            ValueError: If an element is neither a path string nor a PIL Image.
        """
        image_objects: List[Image.Image] = []
        for img in images:
            if isinstance(img, str):
                # Use a context manager so the underlying file handle is
                # released promptly (PIL opens lazily and would otherwise
                # leak descriptors); convert() returns an independent
                # in-memory copy, so closing the source is safe.
                with Image.open(img) as opened:
                    image_objects.append(opened.convert("RGB"))
            elif isinstance(img, Image.Image):
                image_objects.append(img)
            else:
                raise ValueError(f"Unsupported image type: {type(img)}")

        image_inputs = self.processor.process_images(image_objects)
        image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()}

        with torch.no_grad():
            image_embeddings = self.model(**image_inputs)

        return [i.cpu() for i in image_embeddings]

    def compute_scores(
        self,
        query_embeddings: List[torch.Tensor],
        image_embeddings: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute similarity scores between queries and images.

        Args:
            query_embeddings: List of query embeddings.
            image_embeddings: List of image embeddings.

        Returns:
            Similarity scores matrix (queries x images), as produced by the
            processor's late-interaction scorer.
        """
        return self.processor.score_multi_vector(query_embeddings, image_embeddings)

    def encode_and_score(
        self,
        queries: List[str],
        images: List[Union[str, Image.Image]]
    ):
        """
        Convenience method to encode queries and images and compute scores.

        Args:
            queries: List of query texts.
            images: List of images (paths or PIL objects).

        Returns:
            Similarity scores between queries and images.
        """
        query_embeddings = self.encode_queries(queries)
        image_embeddings = self.encode_images(images)
        return self.compute_scores(query_embeddings, image_embeddings)
|
|
|
|
| |
| if __name__ == "__main__": |
| images = [Image.new("RGB", (32, 32), color="white"), Image.new("RGB", (16, 16), color="black")] |
| queries = ["Is attention really all you need?", "What is the amount of bananas farmed in Salvador?"] |
|
|
| embedder = OpsColQwen3Embedder( |
| model_name="OpenSearch-AI/Ops-Colqwen3-4B", |
| dims=2560, |
| dtype=torch.float16, |
| attn_implementation="flash_attention_2", |
| ) |
|
|
| query_embeddings = embedder.encode_queries(queries) |
| image_embeddings = embedder.encode_images(images) |
| print(query_embeddings[0].shape, image_embeddings[0].shape) |
|
|
| scores = embedder.compute_scores(query_embeddings, image_embeddings) |
|
|
| print(f"Scores:\n{scores}") |