| """This module implements the SPICE metric.""" |
|
|
| import os |
| import shutil |
| import subprocess |
| import json |
| import tempfile |
| from typing import List, Dict |
|
|
| import evaluate |
| import datasets |
| from evaluate.utils.logging import get_logger |
|
|
# Module-level logger obtained from the `evaluate` logging utilities.
logger = get_logger(__name__)


# Name of the Stanford CoreNLP release: used both to build the download URL
# (`<name>.zip`) and as the name of the folder inside the extracted archive.
CORENLP = "stanford-corenlp-full-2015-12-09"
# Directory the two CoreNLP jars are copied into. It is a relative path, so it
# is resolved against the current working directory of the process.
SPICELIB = "lib"
# Jar handed to `java -jar` when scoring; also a relative path.
# NOTE(review): presumably the jar is expected to sit next to `SPICELIB` in the
# working directory — confirm against the deployment layout.
SPICE_JAR = "spice-1.0.jar"
|
|
# BibTeX entry for the SPICE paper; passed to `evaluate.MetricInfo` as `citation`.
_CITATION = """\
@inproceedings{spice2016,
title = {SPICE: Semantic Propositional Image Caption Evaluation},
author = {Peter Anderson and Basura Fernando and Mark Johnson and Stephen Gould},
year = {2016},
booktitle = {ECCV}
}
"""
|
|
# Short human-readable summary of the metric; used as the `description` of the
# `evaluate.MetricInfo` and prepended to the class docstring by the decorator.
_DESCRIPTION = """\
This module is designed to evaluate the quality of image captions using the SPICE metric.
It compares generated captions with reference captions to assess their semantic similarity.
"""
|
|
| _KWARGS_DESCRIPTION = """ |
| Compute SPICE score. |
| Args: |
| predictions: list of predictions to score. Each predictions |
| should be a string. |
| references: list of reference for each prediction. Each |
| reference should be a string. |
| Returns: |
| spice: SPICE score |
| Examples: |
| >>> metric = evaluate.load("sunhill/spice") |
| >>> results = metric.compute( |
| predictions=[['train traveling down a track in front of a road']], |
| references=[ |
| [ |
| 'a train traveling down tracks next to lights', |
| 'a blue and silver train next to train station and trees', |
| 'a blue train is next to a sidewalk on the rails', |
| 'a passenger train pulls into a train station', |
| 'a train coming down the tracks arriving at a station' |
| ] |
| ] |
| ) |
| >>> print(results) |
| [ |
| { |
| "All": { |
| "pr": 0.25, |
| "re": 0.07142857142857142, |
| "f": 0.11111111111111112, |
| "fn": 13.0, |
| "numImages": 1.0, |
| "fp": 3.0, |
| "tp": 1.0, |
| }, |
| "Relation": { |
| "pr": 0.0, |
| "re": 0.0, |
| "f": 0.0, |
| "fn": 5.0, |
| "numImages": 1.0, |
| "fp": 1.0, |
| "tp": 0.0, |
| }, |
| "Cardinality": { |
| "pr": nan, |
| "re": nan, |
| "f": nan, |
| "fn": 0.0, |
| "numImages": 1.0, |
| "fp": 0.0, |
| "tp": 0.0, |
| }, |
| "Attribute": { |
| "pr": 0.0, |
| "re": 0.0, |
| "f": 0.0, |
| "fn": 4.0, |
| "numImages": 1.0, |
| "fp": 0.0, |
| "tp": 0.0, |
| }, |
| "Size": { |
| "pr": nan, |
| "re": nan, |
| "f": nan, |
| "fn": 0.0, |
| "numImages": 1.0, |
| "fp": 0.0, |
| "tp": 0.0, |
| }, |
| "Color": { |
| "pr": 0.0, |
| "re": 0.0, |
| "f": 0.0, |
| "fn": 1.0, |
| "numImages": 1.0, |
| "fp": 0.0, |
| "tp": 0.0, |
| }, |
| "Object": { |
| "pr": 0.3333333333333333, |
| "re": 0.2, |
| "f": 0.25, |
| "fn": 4.0, |
| "numImages": 1.0, |
| "fp": 2.0, |
| "tp": 1.0, |
| }, |
| } |
| ] |
| """ |
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class SPICE(evaluate.Metric):
    """SPICE metric for evaluating image captioning models.

    Scoring is delegated to the SPICE jar (``SPICE_JAR``), which is run with
    ``java`` in a subprocess. The per-image counts it reports for one SPICE
    category are then summed and turned into a single precision / recall /
    F-score.
    """

    def _info(self):
        """Return the metric metadata and the accepted input schemas."""
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Two accepted schemas: one reference string per prediction, or a
            # sequence of reference strings per prediction.
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Value("string"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string"),
                        "references": datasets.Sequence(datasets.Value("string")),
                    }
                ),
            ],
            homepage="https://huggingface.co/spaces/sunhill/spice",
            codebase_urls=[
                "https://github.com/peteanderson80/SPICE",
                "https://github.com/EricWWWW/image-caption-metrics",
            ],
            reference_urls=["https://panderson.me/spice"],
        )

    def _download_and_prepare(self, dl_manager):
        """Download the Stanford CoreNLP jars into ``SPICELIB`` if they are missing.

        Args:
            dl_manager: download manager supplied by the caller; only its
                ``download_and_extract`` method is used.
        """
        jar_names = (
            "stanford-corenlp-3.6.0-models.jar",
            "stanford-corenlp-3.6.0.jar",
        )
        # Check the same directory the jars are copied to below, instead of a
        # separately hard-coded "lib/..." path that could drift from SPICELIB.
        if all(os.path.exists(os.path.join(SPICELIB, name)) for name in jar_names):
            logger.info("`stanford-corenlp` already exists. Skip downloading.")
            return
        logger.info("Downloading `stanford-corenlp`...")
        url = f"http://nlp.stanford.edu/software/{CORENLP}.zip"
        extracted_path = dl_manager.download_and_extract(url)
        tmp_path = os.path.join(extracted_path, CORENLP)
        # shutil.copyfile does not create missing parent directories.
        os.makedirs(SPICELIB, exist_ok=True)
        for name in jar_names:
            shutil.copyfile(
                os.path.join(tmp_path, name),
                os.path.join(SPICELIB, name),
            )
        logger.info(f"`stanford-corenlp` has been downloaded to {SPICELIB}")

    def float_convert(self, obj):
        """Convert ``obj`` to ``float``; return NaN when it cannot be converted."""
        try:
            return float(obj)
        except (ValueError, TypeError):
            return float("nan")

    def _compute_batch(self, scores: List[Dict]) -> Dict[str, float]:
        """Pool per-image counts into one precision, recall and F-score.

        Args:
            scores: one dict per image; only its ``"tp"``, ``"fp"`` and
                ``"fn"`` entries are read (missing entries are skipped).

        Returns:
            Dict with keys ``pr``, ``re``, ``f``, ``fn``, ``numImages``,
            ``fp`` and ``tp``. For an empty ``scores`` every value is ``0.0``.
            Precision (recall) is NaN when ``tp + fp`` (``tp + fn``) is not
            positive; the F-score is NaN when either of them is NaN.
        """
        aggregate_scores = {
            "pr": 0.0,
            "re": 0.0,
            "f": 0.0,
            "fn": 0.0,
            "numImages": 0.0,
            "fp": 0.0,
            "tp": 0.0,
        }
        if not scores:
            return aggregate_scores

        # Sum the raw counts; the ratios are computed once from the totals
        # rather than averaged per image.
        for score in scores:
            for key in ("fn", "fp", "tp"):
                if key in score:
                    aggregate_scores[key] += score[key]
            aggregate_scores["numImages"] += 1

        tp = aggregate_scores["tp"]
        fp = aggregate_scores["fp"]
        fn = aggregate_scores["fn"]
        # These comparisons are False for NaN totals, so NaN counts propagate
        # to NaN ratios instead of raising.
        has_precision = (tp + fp) > 0
        has_recall = (tp + fn) > 0
        precision = tp / (tp + fp) if has_precision else float("nan")
        recall = tp / (tp + fn) if has_recall else float("nan")
        if not (has_precision and has_recall):
            f_score = float("nan")
        elif (precision + recall) > 0:
            f_score = 2 * precision * recall / (precision + recall)
        else:
            # Precision and recall are both defined and both zero (no true
            # positives). Report 0.0 rather than NaN, matching the per-image
            # output of the SPICE jar for that case (f == 0.0 when pr == re == 0).
            f_score = 0.0
        aggregate_scores["pr"] = precision
        aggregate_scores["re"] = recall
        aggregate_scores["f"] = f_score
        return aggregate_scores

    def _run_spice(self, in_path, out_path):
        """Run the SPICE jar on ``in_path`` and return the parsed JSON it writes.

        Args:
            in_path: path of the JSON input file.
            out_path: path the jar writes its JSON results to.

        Raises:
            RuntimeError: if the ``java`` process exits with a non-zero status.
        """
        # The jar's cache directory only needs to live for this one run.
        with tempfile.TemporaryDirectory() as cache_dir:
            spice_cmd = [
                "java",
                "-jar",
                "-Xmx8G",
                SPICE_JAR,
                in_path,
                "-cache",
                cache_dir,
                "-out",
                out_path,
                "-subset",
                "-silent",
            ]
            try:
                subprocess.run(
                    spice_cmd,
                    check=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                )
            except subprocess.CalledProcessError as e:
                # errors="replace": undecodable bytes in stderr must not hide
                # the real failure behind a UnicodeDecodeError.
                raise RuntimeError(
                    f"SPICE command '{' '.join(spice_cmd)}' returned non-zero exit status {e.returncode}. "
                    f"stderr: {e.stderr.decode('utf-8', errors='replace')}"
                ) from e

        with open(out_path, "r") as f:
            return json.load(f)

    def _compute(self, predictions, references, spice_name="All"):
        """Score ``predictions`` against ``references`` with SPICE.

        Args:
            predictions: sequence of caption strings.
            references: sequence of the same length; each entry is a string or
                a list of strings.
            spice_name: key of the per-image ``"scores"`` entry to aggregate.

        Returns:
            Dict mapping ``spice_<key>`` to the aggregated value for every key
            produced by ``_compute_batch``.

        Raises:
            AssertionError: if the lengths differ or an entry has the wrong type.
            RuntimeError: if the SPICE subprocess fails.
        """
        assert len(predictions) == len(references), (
            "The number of predictions and references should be the same. "
            f"Got {len(predictions)} predictions and {len(references)} references."
        )
        input_data = []
        for i, (prediction, reference) in enumerate(zip(predictions, references)):
            assert isinstance(prediction, str), (
                "Each prediction should be a string. "
                f"Got {type(prediction)} for image {i}."
            )
            # A single reference string is treated as a one-element list.
            if isinstance(reference, str):
                reference = [reference]
            assert isinstance(reference, list) and all(
                isinstance(ref, str) for ref in reference
            ), (
                "Each reference should be a list of strings. "
                f"Got {type(reference)} with elements of type {[type(ref) for ref in reference]} for index {i}."
            )
            # The position in the input doubles as the image id, which is how
            # results are matched back to predictions below.
            input_data.append({"image_id": i, "test": prediction, "refs": reference})

        in_path = None
        out_path = None
        try:
            # delete=False: the files must outlive these handles so the Java
            # process can open them by name.
            with tempfile.NamedTemporaryFile(delete=False) as in_file:
                in_path = in_file.name
                in_file.write(json.dumps(input_data, indent=2).encode("utf-8"))
            with tempfile.NamedTemporaryFile(delete=False) as out_file:
                out_path = out_file.name
            results = self._run_spice(in_path, out_path)
        finally:
            # Remove the temp files on failure too, so a failing SPICE run
            # does not leave them behind.
            for path in (in_path, out_path):
                if path is None:
                    continue
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass

        img_id_to_scores = {
            item["image_id"]: item["scores"][spice_name] for item in results
        }
        scores = [
            {k: self.float_convert(v) for k, v in img_id_to_scores[image_id].items()}
            for image_id in range(len(predictions))
        ]
        return {f"spice_{k}": v for k, v in self._compute_batch(scores).items()}
|
|