| import os
|
|
|
class WandbLogger:
    """
    Log using `Weights and Biases`.

    Thin wrapper around the ``wandb`` client: it attaches to the active run
    (starting one if none exists) and exposes helpers to log metrics, images,
    checkpoint artifacts and row-wise result tables.

    eval_table: table receiving rows that carry psnr/ssim (None until used)
    infer_table: table receiving image-only rows (None until used)
    """

    # Column layouts of the tables created on first use by `log_eval_data`;
    # the names mirror that method's parameters.
    _EVAL_COLUMNS = ('fake_image', 'sr_image', 'hr_image', 'psnr', 'ssim')
    _INFER_COLUMNS = ('fake_image', 'sr_image', 'hr_image')

    def __init__(self, project='diff_derain', log_dir='./experiments'):
        """
        Import wandb and make sure a run is active.

        project: W&B project used when a new run has to be started
        log_dir: directory handed to `wandb.init` as `dir`

        Raises ImportError when the `wandb` package is not installed.
        """
        try:
            # Imported lazily so the module can be loaded without wandb.
            import wandb
        except ImportError as err:
            raise ImportError(
                "To use the Weights and Biases Logger please install wandb. "
                "Run `pip install wandb` to install it."
            ) from err

        self._wandb = wandb

        # Reuse a run that is already active; only start one otherwise.
        if self._wandb.run is None:
            self._wandb.init(project=project, dir=log_dir)

        self.config = self._wandb.config

        # Tables are created on demand by `log_eval_data`.
        self.eval_table = None
        self.infer_table = None

    def log_metrics(self, metrics, commit=True):
        """
        Log train/validation metrics onto W&B.

        metrics: dictionary of metrics to be logged
        commit: forwarded unchanged to `wandb.log`
        """
        self._wandb.log(metrics, commit=commit)

    def log_image(self, key_name, image_array):
        """
        Log image array onto W&B.

        key_name: name of the key
        image_array: numpy array of image.
        """
        self._wandb.log({key_name: self._wandb.Image(image_array)})

    def log_images(self, key_name, list_images):
        """
        Log list of image array onto W&B

        key_name: name of the key
        list_images: list of numpy image arrays
        """
        wrapped = [self._wandb.Image(img) for img in list_images]
        self._wandb.log({key_name: wrapped})

    def log_checkpoint(self, current_epoch, current_step):
        """
        Log the model checkpoint as W&B artifacts

        Adds the files `I<step>_E<epoch>_gen.pth` and
        `I<step>_E<epoch>_opt.pth` from the checkpoint directory to a single
        "model" artifact named after the run id, logged under the alias
        "latest".

        current_epoch: the current epoch
        current_step: the current batch step
        """
        model_artifact = self._wandb.Artifact(
            self._wandb.run.id + "_model", type="model"
        )

        # NOTE(review): assumes the run config carries a `path` entry with a
        # 'checkpoint' key; nothing in this class sets it - confirm that it is
        # supplied when the run is initialised.
        checkpoint_dir = self.config.path['checkpoint']
        for suffix in ('gen', 'opt'):
            file_name = 'I{}_E{}_{}.pth'.format(
                current_step, current_epoch, suffix)
            model_artifact.add_file(os.path.join(checkpoint_dir, file_name))

        self._wandb.log_artifact(model_artifact, aliases=["latest"])

    def log_eval_data(self, fake_img, sr_img, hr_img, psnr=None, ssim=None):
        """
        Add data row-wise to the matching table, creating it on first use.

        When both psnr and ssim are given the row goes to `eval_table`;
        otherwise an image-only row goes to `infer_table`.

        fake_img, sr_img, hr_img: images, each wrapped in `wandb.Image`
        psnr, ssim: optional metric values stored next to the images
        """
        images = [self._wandb.Image(img) for img in (fake_img, sr_img, hr_img)]

        if psnr is not None and ssim is not None:
            if self.eval_table is None:
                self.eval_table = self._wandb.Table(
                    columns=list(self._EVAL_COLUMNS))
            self.eval_table.add_data(*images, psnr, ssim)
        else:
            if self.infer_table is None:
                self.infer_table = self._wandb.Table(
                    columns=list(self._INFER_COLUMNS))
            self.infer_table.add_data(*images)

    def log_eval_table(self, commit=False):
        """
        Log the table

        The eval table takes precedence: the infer table is logged only when
        no eval table exists. Nothing is logged when neither exists.

        commit: forwarded unchanged to `wandb.log`
        """
        if self.eval_table is not None:
            self._wandb.log({'eval_data': self.eval_table}, commit=commit)
        elif self.infer_table is not None:
            self._wandb.log({'infer_data': self.infer_table}, commit=commit)
|
|
|