| import time |
| import torch |
| from detectron2.engine import SimpleTrainer |
| from typing import Iterable, Generator |
|
|
|
|
def cycle(iterable: Iterable) -> Generator:
    """Yield the items of *iterable* repeatedly, re-iterating it on every pass.

    Unlike :func:`itertools.cycle`, which saves the items from the first pass
    and replays that saved copy, this calls ``iter(iterable)`` afresh for each
    pass. Nothing is cached in memory, and per-pass behaviour of the iterable
    is kept (e.g. a loader that reshuffles on every new iteration).

    The generator stops as soon as a complete pass yields no items. That
    covers an empty iterable, and also a one-shot iterator such as a
    generator that has already been consumed. Without this check the
    ``while True`` loop would spin forever without ever yielding.

    Args:
        iterable: A re-iterable collection. A one-shot iterator is consumed
            once, after which this generator stops.

    Yields:
        The items of ``iterable``, in order, repeated pass after pass.
    """
    while True:
        yielded_any = False
        for item in iterable:
            yielded_any = True
            yield item
        # A pass that produced nothing will keep producing nothing; finish
        # instead of busy-looping (same result as itertools.cycle([])).
        if not yielded_any:
            return
|
|
|
|
class CycleTrainer(SimpleTrainer):
    """A ``SimpleTrainer`` that draws batches from a cycling data-loader iterator.

    The constructor is a straight pass-through to ``SimpleTrainer``. The only
    override with new behaviour is the ``_data_loader_iter`` property. It hands
    out a ``cycle(self.data_loader)`` generator, which begins a new pass over
    the loader whenever the previous pass is exhausted.

    NOTE(review): this assumes the installed detectron2 ``SimpleTrainer``
    pulls batches through ``self._data_loader_iter`` and initialises
    ``self._data_loader_iter_obj`` to ``None`` — confirm against the
    detectron2 version in use.
    """

    def __init__(
        self,
        model,
        data_loader,
        optimizer,
        gather_metric_period=1,
        zero_grad_before_forward=False,
        async_write_metrics=False,
    ):
        """Forward every argument unchanged to ``SimpleTrainer.__init__``.

        See ``SimpleTrainer`` for the meaning of each parameter; the defaults
        here mirror the ones this subclass has always declared.
        """
        super().__init__(
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            gather_metric_period=gather_metric_period,
            zero_grad_before_forward=zero_grad_before_forward,
            async_write_metrics=async_write_metrics,
        )

    @property
    def _data_loader_iter(self):
        """Return the cached cycling iterator over ``self.data_loader``.

        The generator is created lazily on first access and stored in
        ``self._data_loader_iter_obj``. While that attribute is not ``None``,
        the same generator is returned, so successive ``next()`` calls
        continue from where the previous one stopped. Setting the attribute
        back to ``None`` makes the next access build a fresh ``cycle`` over
        the current ``self.data_loader``.
        """
        if self._data_loader_iter_obj is not None:
            return self._data_loader_iter_obj
        self._data_loader_iter_obj = cycle(self.data_loader)
        return self._data_loader_iter_obj
|
|