| """Optional Weights & Biases logging for Ref-AVS training.""" |
| import os |
|
|
| import torchvision |
| import wandb |
|
|
|
|
class Tensorboard:
    """Thin wrapper around a Weights & Biases run used for training logs.

    Despite the class name, all logging goes through ``wandb``. The run is
    created in ``online`` mode only when ``config['wandb_online']`` is truthy;
    otherwise wandb is started with ``mode='disabled'``.
    """

    def __init__(self, config):
        """Start the wandb run and build the image de-normalisation transform.

        Args:
            config: mapping read for ``'wandb_key'`` (optional),
                ``'wandb_online'`` (optional, default False), ``'proj_name'``,
                ``'experiment_name'``, ``'image_mean'`` and ``'image_std'``.
                The whole mapping is also passed to wandb as the run config.
        """
        # A key given in the config takes precedence over one already present
        # in the environment; either way it is exported as WANDB_API_KEY.
        key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '')
        if key:
            os.environ['WANDB_API_KEY'] = key
        mode = 'online' if config.get('wandb_online', False) else 'disabled'
        self.tensor_board = wandb.init(
            project=config['proj_name'],
            name=config['experiment_name'],
            config=config,
            mode=mode,
            settings=wandb.Settings(code_dir='.'),
        )
        # Undoes the (mean, std) input normalisation and converts to a PIL
        # image. Note: not used by the logging methods of this class.
        self.restore_transform = torchvision.transforms.Compose([
            DeNormalize(config['image_mean'], config['image_std']),
            torchvision.transforms.ToPILImage(),
        ])

    def upload_wandb_info(self, info_dict):
        """Log a mapping of metrics as a single wandb step.

        All keys go into one ``log`` call so the metrics of one update share
        a step; logging key-by-key would advance wandb's step counter between
        them. An empty mapping is skipped so it does not create an empty step.

        Args:
            info_dict: mapping of metric name to value.
        """
        if info_dict:
            self.tensor_board.log(dict(info_dict))

    def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4):
        """Log up to ``img_number`` frames together with two label maps.

        Args:
            frames: indexable batch of images; the first ``n`` items are
                logged as-is under ``'image'``.
            pseudo_label_from_pred: tensor whose first dimension is the batch;
                logged under ``'logits'``.
            pseudo_label_from_sam: tensor whose first dimension is the batch;
                logged under ``'label'``.
            img_number: maximum number of samples to log.

        In both label tensors, entries equal to 255 are displayed as 0.5
        (presumably 255 is an ignore index -- TODO confirm). The caller's
        tensors are not modified.
        """
        n = min(pseudo_label_from_pred.shape[0], img_number)
        frames = frames[:n]
        label_sam = self._prepare_label(pseudo_label_from_sam, n)
        label_pred = self._prepare_label(pseudo_label_from_pred, n)
        self.tensor_board.log({
            'image': self._to_wandb_images(frames),
            'label': self._to_wandb_images(label_sam, squeeze=True),
            'logits': self._to_wandb_images(label_pred, squeeze=True),
        })

    @staticmethod
    def _prepare_label(label, n):
        """Return a float copy of the first ``n`` label maps with 255 -> 0.5."""
        label = label[:n].float()
        # Out-of-place on purpose: ``.float()`` returns the same tensor when it
        # is already float32, so an in-place write would corrupt the caller's
        # data (and fails on leaf tensors that require grad).
        return label.masked_fill(label == 255., 0.5)

    @staticmethod
    def _to_wandb_images(items, squeeze=False):
        """Wrap each item in a ``wandb.Image`` captioned with its index."""
        return [
            wandb.Image(item.squeeze() if squeeze else item, caption=f'id {i}')
            for i, item in enumerate(items)
        ]

    def finish(self):
        """Close the underlying wandb run."""
        self.tensor_board.finish()
|
|
|
|
class DeNormalize:
    """Reverse a per-channel ``(x - mean) / std`` normalisation.

    Calling an instance rewrites channel ``c`` of the input as
    ``x * std[c] + mean[c]`` **in place** and hands back the same tensor
    object. Channels are paired with the statistics by position along the
    first dimension; pairing stops at the shortest of the three sequences,
    so any leftover channels are left untouched.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """De-normalise ``tensor`` in place and return it."""
        per_channel_stats = zip(self.mean, self.std)
        for channel, (mu, sigma) in zip(tensor, per_channel_stats):
            channel.mul_(sigma).add_(mu)
        return tensor
|
|