import evaluate
from evaluate.utils.file_utils import add_start_docstrings
import datasets
import torch
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
|
|
| _DESCRIPTION = """ |
| This metric evaluates CLIP models on image-text retrieval tasks using standard datasets. |
| It calculates Recall@K metrics for both text-to-image and image-to-text retrieval. |
| """ |
|
|
| _KWARGS_DESCRIPTION = """ |
| Args: |
| model_name: Name or path of the CLIP model to evaluate (e.g., "openai/clip-vit-base-patch32") |
| dataset_names: List of dataset names to evaluate on (choices: "mscoco", "flickr") |
| n_examples: Number of examples to use for evaluation (-1 for all) |
| |
| Returns: |
| Dictionary containing Recall@K metrics for each dataset and retrieval direction |
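
Examples:
    Illustrative usage (the module path below is an assumption, and the run
    needs Hub access for the model and the clip-benchmark datasets):

    >>> metric = evaluate.load("dmx_clip_eval")
    >>> results = metric.compute(
    ...     dataset_names=["mscoco"],
    ...     model=["openai/clip-vit-base-patch32"],
    ...     n_examples=[100],
    ... )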
| """ |
|
|
| _CITATION = """ |
| @inproceedings{radford2021learning, |
| title={Learning transferable visual models from natural language supervision}, |
| author={Radford, Alec and Kim, Jong Wook and Hallacy, Chris and Ramesh, Aditya and others}, |
| booktitle={International Conference on Machine Learning}, |
| year={2021}, |
| } |
| """ |
|
|
@add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DmxClipEval(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "dataset_names": datasets.Value("string"),
                }
            ),
        )
|
|
    def clip_dataset_evaluator(
        self, model, device, dataset_name="mscoco", n_examples=-1
    ):
        # Load the processor that matches the model checkpoint.
        processor = CLIPProcessor.from_pretrained(model.config._name_or_path)
        if dataset_name == "mscoco":
            ds = datasets.load_dataset(
                "clip-benchmark/wds_mscoco_captions", split="test"
            )
        elif dataset_name == "flickr":
            ds = datasets.load_dataset("clip-benchmark/wds_flickr8k", split="test")
        else:
            raise ValueError(f"invalid dataset name: {dataset_name}")

        if n_examples != -1:
            ds = ds.select(range(min(n_examples, len(ds))))

        # Batch over dataset indices so image/text pairs stay aligned.
        dl = torch.utils.data.DataLoader(torch.arange(len(ds)), batch_size=8)
        all_image_embeds = []
        all_text_embeds = []

        for indices in tqdm(dl, desc=f"Processing {dataset_name}"):
            batch = ds[indices.tolist()]
            inputs = processor(
                text=batch["txt"],
                images=batch["jpg"],
                return_tensors="pt",
                padding=True,
            )
            # Truncate captions to CLIP's 77-token context length.
            inputs["input_ids"] = inputs["input_ids"][:, :77]
            inputs["attention_mask"] = inputs["attention_mask"][:, :77]
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                output = model(**inputs)

            all_image_embeds.append(output.image_embeds.cpu())
            all_text_embeds.append(output.text_embeds.cpu())

        all_image_embeds = torch.cat(all_image_embeds, dim=0)
        all_text_embeds = torch.cat(all_text_embeds, dim=0)
        # CLIPModel returns L2-normalized embeddings, so the dot product is
        # cosine similarity; rows are texts, columns are images.
        text_img_sim = all_text_embeds @ all_image_embeds.t()
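
        # Recall@K: the ground-truth match for query i is assumed to sit on the
        # diagonal of the similarity matrix (the i-th caption is paired with the
        # i-th image), so a hit means the true index appears in the top-K ranking.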
        def get_top_k(sim_mat, k_arr):
            ordered_winners = torch.argsort(sim_mat, dim=-1, descending=True)
            correct_winner_mask = (
                ordered_winners
                == torch.arange(ordered_winners.shape[0])
                .unsqueeze(1)
                .to(ordered_winners.device)
            ).long()
            return [
                correct_winner_mask[:, :k].sum(-1).float().mean().item() for k in k_arr
            ]

        k_arr = [1, 5, 10]
        metrics = {
            # Text-to-image retrieval: each caption queries the image index.
            **{
                f"{dataset_name}:image_recall@{k}": val
                for k, val in zip(k_arr, get_top_k(text_img_sim, k_arr))
            },
            # Image-to-text retrieval: transpose so each image queries the captions.
            **{
                f"{dataset_name}:text_recall@{k}": val
                for k, val in zip(k_arr, get_top_k(text_img_sim.t(), k_arr))
            },
        }
        return metrics
|
|
    def clip_evaluator(self, model, device, n_examples=-1):
        # Convenience helper: evaluate on all supported datasets at once.
        metrics = {}
        for dataset_name in ["mscoco", "flickr"]:
            metrics.update(
                self.clip_dataset_evaluator(model, device, dataset_name, n_examples)
            )
        return metrics
|
|
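    # Note: `dataset_names` arrives as a list because it is the declared feature
    # column; `model` and `n_examples` are forwarded untouched through
    # `compute(**kwargs)` and are therefore expected as single-element lists.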
    def _compute(self, model, dataset_names, n_examples, **kwargs):
        num_examples = n_examples[0]
        model_input = model[0]

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Accept either a model name/path or an already-instantiated CLIPModel.
        if isinstance(model_input, str):
            actual_model = CLIPModel.from_pretrained(model_input).to(device)
        else:
            actual_model = model_input

        # Evaluate every requested dataset once, preserving input order.
        datasets_to_evaluate = list(dict.fromkeys(dataset_names))

        metrics = {}
        for ds_name in datasets_to_evaluate:
            dataset_metrics = self.clip_dataset_evaluator(
                model=actual_model,
                device=device,
                dataset_name=ds_name,
                n_examples=num_examples,
            )
            metrics.update(dataset_metrics)

        return metrics
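

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the metric API):
    # assumes Hub access for the model weights and the clip-benchmark datasets.
    metric = DmxClipEval()
    results = metric.compute(
        dataset_names=["mscoco"],
        model=["openai/clip-vit-base-patch32"],  # list-valued, matching _compute
        n_examples=[100],  # small subset keeps the smoke test fast
    )
    print(results)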