from typing import Any, Iterator, List, Optional, Sequence, Union

from mmengine.dataset import pseudo_collate
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric


@EVALUATOR.register_module()
class Evaluator:
    """Wrapper class to compose multiple :class:`BaseMetric` instances.

    Args:
        metrics (dict or BaseMetric or Sequence): The config of metrics.
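
    Examples:
        >>> # A minimal sketch; ``ToyMetric`` is a hypothetical metric that
        >>> # has been registered in the ``METRICS`` registry.
        >>> evaluator = Evaluator(metrics=[dict(type='ToyMetric')])
        >>> evaluator.dataset_meta = dict(classes=('cat', 'dog'))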
    """

    def __init__(self, metrics: Union[dict, BaseMetric, Sequence]):
        self._dataset_meta: Optional[dict] = None
        if not isinstance(metrics, Sequence):
            metrics = [metrics]
        self.metrics: List[BaseMetric] = []
        for metric in metrics:
            if isinstance(metric, dict):
                # Build metrics from config dicts via the METRICS registry.
                self.metrics.append(METRICS.build(metric))
            else:
                self.metrics.append(metric)

    @property
    def dataset_meta(self) -> Optional[dict]:
        """Optional[dict]: Meta info of the dataset."""
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        """Set the dataset meta info to the evaluator and its metrics."""
        self._dataset_meta = dataset_meta
        for metric in self.metrics:
            metric.dataset_meta = dataset_meta

    def process(self,
                data_samples: Sequence[BaseDataElement],
                data_batch: Optional[Any] = None):
        """Convert ``BaseDataElement`` to dict and invoke the ``process``
        method of each metric.

        Args:
            data_samples (Sequence[BaseDataElement]): Predictions of the
                model and the ground truth of the validation set.
            data_batch (Any, optional): A batch of data from the dataloader.
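
        Examples:
            >>> # A minimal sketch; assumes ``evaluator`` was built as in the
            >>> # class docstring. Samples may be ``BaseDataElement`` objects
            >>> # or plain dicts.
            >>> sample = BaseDataElement(pred_label=0, gt_label=0)
            >>> evaluator.process(data_samples=[sample], data_batch=None)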
        """
        _data_samples = []
        for data_sample in data_samples:
            # Convert ``BaseDataElement`` instances to plain dicts so that
            # the metrics do not depend on the specific data element type.
            if isinstance(data_sample, BaseDataElement):
                _data_samples.append(data_sample.to_dict())
            else:
                _data_samples.append(data_sample)

        for metric in self.metrics:
            metric.process(data_batch, _data_samples)

    def evaluate(self, size: int) -> dict:
        """Invoke ``evaluate`` method of each metric and collect the metrics
        dictionary.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation results of all metrics. The keys are the names
            of the metrics, and the values are corresponding results.
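
        Examples:
            >>> # A minimal sketch; call after every batch has been fed to
            >>> # ``process``. The dataset length and the result shown in the
            >>> # comment below are hypothetical.
            >>> results = evaluator.evaluate(size=1000)
            >>> # results == {'accuracy': 0.95}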
        """
        metrics = {}
        for metric in self.metrics:
            _results = metric.evaluate(size)

            # Check metric name conflicts so that results from different
            # metrics do not silently overwrite each other.
            for name in _results.keys():
                if name in metrics:
                    raise ValueError(
                        'There are multiple evaluation results with the same '
                        f'metric name {name}. Please make sure all metrics '
                        'have different prefixes.')

            metrics.update(_results)
        return metrics

    def offline_evaluate(self,
                         data_samples: Sequence,
                         data: Optional[Sequence] = None,
                         chunk_size: int = 1):
        """Offline evaluate the dumped predictions on the given data.

        Args:
            data_samples (Sequence): All predictions and ground truth of the
                model and the validation set.
            data (Sequence, optional): All data of the validation set.
            chunk_size (int): The number of data samples and predictions to
                be processed in a batch.
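
        Examples:
            >>> # A minimal sketch; the pickle file of dumped predictions and
            >>> # its path are hypothetical.
            >>> import pickle
            >>> with open('predictions.pkl', 'rb') as f:
            ...     data_samples = pickle.load(f)
            >>> results = evaluator.offline_evaluate(data_samples,
            ...                                      chunk_size=128)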
        """

        # Divide elements of ``seq`` into chunks of at most ``chunk_size``
        # items, yielding one chunk at a time.
        def get_chunks(seq: Iterator, chunk_size=1):
            stop = False
            while not stop:
                chunk = []
                for _ in range(chunk_size):
                    try:
                        chunk.append(next(seq))
                    except StopIteration:
                        stop = True
                        break
                if chunk:
                    yield chunk

        if data is not None:
            assert len(data_samples) == len(data), (
                'data_samples and data should have the same length, but got '
                f'data_samples length: {len(data_samples)} '
                f'data length: {len(data)}')
            data = get_chunks(iter(data), chunk_size)

        size = 0
        for output_chunk in get_chunks(iter(data_samples), chunk_size):
            if data is not None:
                # Re-collate the raw data chunk to mirror dataloader output.
                data_chunk = pseudo_collate(next(data))
            else:
                data_chunk = None
            size += len(output_chunk)
            self.process(output_chunk, data_chunk)
        return self.evaluate(size)