| import wandb |
| import torch |
| from torchvision.utils import make_grid |
| import torch.distributed as dist |
| from PIL import Image |
| import os |
| import argparse |
| import hashlib |
| import math |
|
|
|
|
def is_main_process():
    """Return whether the current process is the main (rank 0) process.

    When ``torch.distributed`` is not available in this build, or no default
    process group has been initialized (for example a plain single-process
    run), the current process is the only one, so it is treated as the main
    process instead of letting ``dist.get_rank()`` raise.

    Returns:
        bool: ``True`` if this process has rank 0 or is not part of an
        initialized process group, ``False`` otherwise.
    """
    # get_rank() requires an initialized default process group; guard first.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0
|
|
def namespace_to_dict(namespace):
    """Recursively convert an ``argparse.Namespace`` into a plain ``dict``.

    Args:
        namespace: The namespace whose attributes are read with ``vars()``.

    Returns:
        dict: One entry per attribute. Attribute values that are themselves
        ``argparse.Namespace`` instances are converted recursively; every
        other value is stored as-is (containers are not traversed).
    """
    converted = {}
    for key, value in vars(namespace).items():
        if isinstance(value, argparse.Namespace):
            value = namespace_to_dict(value)
        converted[key] = value
    return converted
|
|
|
|
def generate_run_id(exp_name):
    """Derive a deterministic numeric run id from an experiment name.

    The UTF-8 encoding of ``exp_name`` is hashed with SHA-256, the digest is
    read as one big-endian integer, and that integer is reduced modulo
    ``10 ** 8``. The same name therefore always yields the same id.

    Args:
        exp_name: Experiment name (a ``str``).

    Returns:
        str: Decimal representation of the reduced hash. It is not
        zero-padded, so it has at most eight digits.
    """
    digest = hashlib.sha256(exp_name.encode("utf-8")).digest()
    # Big-endian interpretation of the raw digest equals int(hexdigest, 16).
    hash_value = int.from_bytes(digest, "big")
    return str(hash_value % 100_000_000)
|
|
|
|
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start the run for an experiment.

    The run id is derived from ``exp_name`` via ``generate_run_id`` and the
    run is opened with ``resume="allow"``.
    NOTE(review): presumably this lets a restarted job with the same
    experiment name continue its earlier run -- confirm against the wandb
    resume documentation.

    Args:
        args: ``argparse.Namespace`` of settings; converted with
            ``namespace_to_dict`` and stored as the run config.
        entity: Passed to ``wandb.init`` as ``entity``.
        exp_name: Passed to ``wandb.init`` as the run ``name`` and used to
            derive the run id.
        project_name: Passed to ``wandb.init`` as ``project``.

    Raises:
        KeyError: If the ``WANDB_KEY`` environment variable is not set.
    """
    run_config = namespace_to_dict(args)

    # The API key is read from the environment rather than passed in.
    api_key = os.environ["WANDB_KEY"]
    wandb.login(key=api_key)

    init_kwargs = {
        "entity": entity,
        "project": project_name,
        "name": exp_name,
        "config": run_config,
        "id": generate_run_id(exp_name),
        "resume": "allow",
    }
    wandb.init(**init_kwargs)
|
|
|
|
def log(stats, step=None):
    """Send a set of metrics to wandb from the main process only.

    Does nothing when ``is_main_process()`` is false.

    Args:
        stats: Mapping of metric names to values; a shallow copy is passed
            to ``wandb.log``.
        step: Forwarded to ``wandb.log`` as its ``step`` argument.
    """
    if not is_main_process():
        return
    payload = dict(stats.items())
    wandb.log(payload, step=step)
|
|
|
|
def log_image(name, sample, step=None):
    """Log a batch of samples to wandb as one image grid (main process only).

    Does nothing when ``is_main_process()`` is false.

    Args:
        name: Key under which the image is logged.
        sample: Batch handed to ``array2grid`` to build the grid image.
        step: Logged as the value of the ``"train_step"`` key in the same
            record. It is not passed as wandb's own ``step`` argument.
    """
    if not is_main_process():
        return
    grid = array2grid(sample)
    record = {f"{name}": wandb.Image(grid), "train_step": step}
    wandb.log(record)
|
|
|
|
def array2grid(x):
    """Tile a batch into a single image grid and return it as a uint8 array.

    Args:
        x: Tensor whose first dimension is the number of items to tile.
            Presumably an (N, C, H, W) image batch with values in [-1, 1],
            given ``value_range=(-1, 1)`` -- TODO confirm with callers.

    Returns:
        numpy.ndarray: ``uint8`` array on the CPU with the grid's first axis
        moved to the last position.
    """
    batch_size = x.size(0)
    # Roughly square layout: the row width is the rounded square root.
    nrow = round(math.sqrt(batch_size))
    grid = make_grid(x, nrow=nrow, normalize=True, value_range=(-1, 1))

    # mul() allocates a new tensor, so the in-place ops below do not touch
    # make_grid's output.
    scaled = grid.mul(255)
    scaled.add_(0.5).clamp_(0, 255)  # +0.5 so the uint8 cast rounds to nearest

    channels_last = scaled.permute(1, 2, 0)
    return channels_last.to("cpu", torch.uint8).numpy()