| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import json |
| import logging |
| import os |
| import sys |
| import time |
| from datetime import timedelta |
| from pathlib import Path |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| from monai.networks.utils import copy_model_state |
| from monai.utils import RankFilter |
| from monai.networks.schedulers import RFlowScheduler |
| |
| from monai.networks.schedulers.ddpm import DDPMPredictionType |
| from torch.amp import GradScaler, autocast |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| from .utils import binarize_labels, define_instance, prepare_maisi_controlnet_json_dataloader, setup_ddp |
|
|
|
|
def _merge_json_configs(args):
    """Copy every top-level key of the three JSON config files onto ``args``.

    The files are read in the order environment, network config, training config,
    so a key present in a later file overwrites the same key from an earlier one
    (and any command-line attribute of the same name).

    Args:
        args: parsed command-line namespace providing ``environment_file``,
            ``config_file`` and ``training_config``; it is modified in place.
    """
    # The tuple is built before any attribute is overwritten, so the paths used
    # here are always the ones given on the command line.
    config_paths = (args.environment_file, args.config_file, args.training_config)
    for config_path in config_paths:
        with open(config_path, "r") as config_file:
            for key, value in json.load(config_file).items():
                setattr(args, key, value)


def _train_controlnet(args, logger, device, rank, world_size, use_ddp, tensorboard_writer):
    """Build the data loader and networks, then run the ControlNet training loop.

    Only the ControlNet is optimised; the diffusion UNet has its parameters frozen
    and is kept in eval mode. After every epoch rank 0 writes
    ``<model_dir>/<exp_name>_current.pt`` and, when the epoch loss improved,
    ``<model_dir>/<exp_name>_best.pt``.

    Args:
        args: namespace holding the command-line options merged with the JSON configs.
        logger: logger used for progress messages.
        device: CUDA device the networks and batches are moved to.
        rank: rank of this process; only rank 0 writes TensorBoard scalars and checkpoints.
        world_size: number of training processes, forwarded to the data loader factory.
        use_ddp: whether to wrap the ControlNet in ``DistributedDataParallel``.
        tensorboard_writer: ``SummaryWriter`` on rank 0; unused on other ranks.

    Raises:
        ValueError: if a configured checkpoint path does not exist, if the data loader
            yields no batch during an epoch, or if the noise scheduler uses an
            unsupported prediction type.
    """
    train_config = args.controlnet_train

    # Data loader (the second value returned by the factory is not used here).
    train_loader, _ = prepare_maisi_controlnet_json_dataloader(
        json_data_list=args.json_data_list,
        data_base_dir=args.data_base_dir,
        rank=rank,
        world_size=world_size,
        batch_size=train_config["batch_size"],
        cache_rate=train_config["cache_rate"],
        fold=train_config["fold"],
    )

    # Diffusion UNet; its own attributes decide which optional conditioning
    # inputs are passed to the networks later on.
    unet = define_instance(args, "diffusion_unet_def").to(device)
    include_body_region = unet.include_top_region_index_input
    include_modality = unet.num_class_embeds is not None

    if args.trained_diffusion_path is not None:
        if not os.path.exists(args.trained_diffusion_path):
            raise ValueError("Please download the trained diffusion unet checkpoint.")
        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints that come from a trusted source.
        diffusion_model_ckpt = torch.load(args.trained_diffusion_path, map_location=device, weights_only=False)
        unet.load_state_dict(diffusion_model_ckpt["unet_state_dict"])
        # The checkpoint also stores the factor the input images are multiplied by.
        scale_factor = diffusion_model_ckpt["scale_factor"]
        logger.info(f"Load trained diffusion model from {args.trained_diffusion_path}.")
        logger.info(f"loaded scale_factor from diffusion model ckpt -> {scale_factor}.")
    else:
        logger.info("trained diffusion model is not loaded.")
        scale_factor = 1.0
        logger.info(f"set scale_factor -> {scale_factor}.")

    # ControlNet, first initialised from the UNet's state dict (presumably only
    # the entries with matching names are copied; see MONAI's copy_model_state).
    controlnet = define_instance(args, "controlnet_def").to(device)
    copy_model_state(controlnet, unet.state_dict())
    # A ControlNet checkpoint, when configured, replaces that initialisation.
    if args.trained_controlnet_path is not None:
        if not os.path.exists(args.trained_controlnet_path):
            raise ValueError("Please download the trained ControlNet checkpoint.")
        # SECURITY: same caveat as above, weights_only=False runs the unpickler.
        controlnet.load_state_dict(
            torch.load(args.trained_controlnet_path, map_location=device, weights_only=False)["controlnet_state_dict"]
        )
        logger.info(f"load trained controlnet model from {args.trained_controlnet_path}")
    else:
        logger.info("train controlnet model from scratch.")

    # Freeze the UNet: gradients only flow into the ControlNet.
    for p in unet.parameters():
        p.requires_grad = False

    noise_scheduler = define_instance(args, "noise_scheduler")

    if use_ddp:
        controlnet = DDP(controlnet, device_ids=[device], output_device=rank, find_unused_parameters=True)

    # Optimiser and polynomial learning-rate decay.
    weighted_loss = train_config["weighted_loss"]
    weighted_loss_label = train_config["weighted_loss_label"]
    optimizer = torch.optim.AdamW(params=controlnet.parameters(), lr=train_config["lr"])
    # NOTE(review): this is a float, and it equals the real number of optimiser steps
    # only when the dataset length is a multiple of the batch size - confirm intended.
    total_steps = (train_config["n_epochs"] * len(train_loader.dataset)) / train_config["batch_size"]
    logger.info(f"total number of training steps: {total_steps}.")

    lr_scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, total_iters=total_steps, power=2.0)

    n_epochs = train_config["n_epochs"]
    scaler = GradScaler("cuda")
    total_step = 0
    # Start from +inf so the first finished epoch always produces a "best" checkpoint.
    best_loss = float("inf")

    if weighted_loss > 1.0:
        logger.info(f"apply weighted loss = {weighted_loss} on labels: {weighted_loss_label}")

    controlnet.train()
    unet.eval()
    prev_time = time.time()
    for epoch in range(n_epochs):
        epoch_loss_ = 0
        num_batches = 0
        for step, batch in enumerate(train_loader):
            # Inputs: images scaled by scale_factor, plus the label volume.
            images = batch["image"].to(device) * scale_factor
            labels = batch["label"].to(device)
            if include_body_region:
                top_region_index_tensor = batch["top_region_index"].to(device)
                bottom_region_index_tensor = batch["bottom_region_index"].to(device)
            if include_modality:
                # Every sample gets class index 1.
                # NOTE(review): presumably a single modality is trained - confirm.
                modality_tensor = torch.ones((len(images),), dtype=torch.long).to(device)
            spacing_tensor = batch["spacing"].to(device)

            optimizer.zero_grad(set_to_none=True)

            with autocast("cuda", enabled=True):
                # Noise with the same shape and dtype as the images.
                noise_shape = list(images.shape)
                noise = torch.randn(noise_shape, dtype=images.dtype).to(device)

                # Conditioning input derived from the labels by binarize_labels.
                controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float()

                # One timestep per sample; RFlowScheduler provides its own sampler.
                if isinstance(noise_scheduler, RFlowScheduler):
                    timesteps = noise_scheduler.sample_timesteps(images)
                else:
                    timesteps = torch.randint(
                        0, noise_scheduler.num_train_timesteps, (images.shape[0],), device=images.device
                    ).long()

                noisy_latent = noise_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

                # ControlNet forward pass produces residuals for the UNet.
                controlnet_inputs = {
                    "x": noisy_latent,
                    "timesteps": timesteps,
                    "controlnet_cond": controlnet_cond,
                }
                if include_modality:
                    controlnet_inputs["class_labels"] = modality_tensor
                down_block_res_samples, mid_block_res_sample = controlnet(**controlnet_inputs)

                # Frozen UNet forward pass, fed with the ControlNet residuals.
                unet_inputs = {
                    "x": noisy_latent,
                    "timesteps": timesteps,
                    "spacing_tensor": spacing_tensor,
                    "down_block_additional_residuals": down_block_res_samples,
                    "mid_block_additional_residual": mid_block_res_sample,
                }
                if include_body_region:
                    unet_inputs["top_region_index_tensor"] = top_region_index_tensor
                    unet_inputs["bottom_region_index_tensor"] = bottom_region_index_tensor
                if include_modality:
                    unet_inputs["class_labels"] = modality_tensor
                model_output = unet(**unet_inputs)

                # Regression target, chosen by the scheduler's prediction type.
                if noise_scheduler.prediction_type == DDPMPredictionType.EPSILON:
                    model_gt = noise
                elif noise_scheduler.prediction_type == DDPMPredictionType.SAMPLE:
                    model_gt = images
                elif noise_scheduler.prediction_type == DDPMPredictionType.V_PREDICTION:
                    # NOTE(review): confirm images - noise is the intended target for
                    # every scheduler that reports V_PREDICTION.
                    model_gt = images - noise
                else:
                    # Adjacent literals are concatenated into ONE message string.
                    raise ValueError(
                        "noise scheduler prediction type has to be chosen from "
                        f"[{DDPMPredictionType.EPSILON},{DDPMPredictionType.SAMPLE},"
                        f"{DDPMPredictionType.V_PREDICTION}]"
                    )

                if weighted_loss > 1.0:
                    # Voxels whose resized label is listed in weighted_loss_label get
                    # weight `weighted_loss`; every other voxel keeps weight 1.
                    weights = torch.ones_like(images)
                    roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(images.device)
                    interpolate_label = F.interpolate(labels, size=images.shape[2:], mode="nearest")
                    for label in weighted_loss_label:
                        roi[interpolate_label == label] = 1
                    weights[roi.repeat(1, images.shape[1], 1, 1, 1) == 1] = weighted_loss
                    loss = (F.l1_loss(model_output.float(), model_gt.float(), reduction="none") * weights).mean()
                else:
                    loss = F.l1_loss(model_output.float(), model_gt.float())

            # Mixed-precision backward pass and optimiser / scheduler update.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            total_step += 1

            if rank == 0:
                # Fetch the scalar once; it is used for TensorBoard and the log line.
                loss_value = loss.detach().cpu().item()
                tensorboard_writer.add_scalar("train/train_controlnet_loss_iter", loss_value, total_step)
                # ETA covers the rest of the current epoch, based on the last step's duration.
                batches_left = len(train_loader) - (step + 1)
                time_left = timedelta(seconds=batches_left * (time.time() - prev_time))
                prev_time = time.time()
                logger.info(
                    "\r[Epoch %d/%d] [Batch %d/%d] [LR: %.8f] [loss: %.4f] ETA: %s ",
                    epoch + 1,
                    n_epochs,
                    step + 1,
                    len(train_loader),
                    lr_scheduler.get_last_lr()[0],
                    loss_value,
                    time_left,
                )
            epoch_loss_ += loss.detach()
            num_batches = step + 1

        if num_batches == 0:
            # Without this check the averaging below would fail on an undefined `step`.
            raise ValueError("The training data loader produced no batches; check the data list and fold settings.")
        epoch_loss = epoch_loss_ / num_batches

        if use_ddp:
            # Average the epoch loss over all ranks.
            dist.barrier()
            dist.all_reduce(epoch_loss, op=torch.distributed.ReduceOp.AVG)

        if rank == 0:
            tensorboard_writer.add_scalar("train/train_controlnet_loss_epoch", epoch_loss.cpu().item(), total_step)
            # Unwrap exactly when the model was wrapped (use_ddp), so the saved keys
            # never carry DDP's "module." prefix.
            controlnet_state_dict = controlnet.module.state_dict() if use_ddp else controlnet.state_dict()
            torch.save(
                {
                    "epoch": epoch + 1,
                    "loss": epoch_loss,
                    "controlnet_state_dict": controlnet_state_dict,
                },
                f"{args.model_dir}/{args.exp_name}_current.pt",
            )

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                logger.info(f"best loss -> {best_loss}.")
                torch.save(
                    {
                        "epoch": epoch + 1,
                        "loss": best_loss,
                        "controlnet_state_dict": controlnet_state_dict,
                    },
                    f"{args.model_dir}/{args.exp_name}_best.pt",
                )

        torch.cuda.empty_cache()


def main():
    """Command-line entry point for ControlNet training.

    Parses the command line, selects the CUDA device (initialising distributed
    training through ``setup_ddp`` when ``--gpus`` is greater than 1), merges the
    three JSON config files into the argument namespace, creates the TensorBoard
    writer and the checkpoint directory on rank 0, and runs the training loop.

    Raises:
        KeyError: if ``--gpus`` is greater than 1 but ``LOCAL_RANK`` or ``WORLD_SIZE``
            is missing from the environment.
        ValueError: propagated from the training loop (missing checkpoint, empty
            data loader, unsupported prediction type).
    """
    parser = argparse.ArgumentParser(description="maisi.controlnet.training")
    parser.add_argument(
        "-e",
        "--environment-file",
        default="./configs/environment_maisi_controlnet_train.json",
        help="environment json file that stores environment path",
    )
    parser.add_argument(
        "-c",
        "--config-file",
        default="./configs/config_maisi-ddpm.json",
        help="config json file that stores network hyper-parameters",
    )
    parser.add_argument(
        "-t",
        "--training-config",
        default="./configs/config_maisi_controlnet_train.json",
        help="config json file that stores training hyper-parameters",
    )
    parser.add_argument("-g", "--gpus", default=1, type=int, help="number of gpus per node")
    args = parser.parse_args()

    logger = logging.getLogger("maisi.controlnet.training")

    # The DDP decision uses the command-line value of --gpus, read before the
    # JSON configs are merged into args.
    use_ddp = args.gpus > 1
    if use_ddp:
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        device = setup_ddp(rank, world_size)
        logger.addFilter(RankFilter())
    else:
        rank = 0
        world_size = 1
        device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    logger.info(f"Number of GPUs: {torch.cuda.device_count()}")
    logger.info(f"World_size: {world_size}")

    _merge_json_configs(args)

    tensorboard_writer = None
    if rank == 0:
        # Checkpoints are written here every epoch; torch.save does not create it.
        Path(args.model_dir).mkdir(parents=True, exist_ok=True)
        tensorboard_path = os.path.join(args.tfevent_path, args.exp_name)
        Path(tensorboard_path).mkdir(parents=True, exist_ok=True)
        tensorboard_writer = SummaryWriter(tensorboard_path)

    try:
        _train_controlnet(args, logger, device, rank, world_size, use_ddp, tensorboard_writer)
    finally:
        # Flush and close the event file even when training fails.
        if tensorboard_writer is not None:
            tensorboard_writer.close()

    if use_ddp:
        dist.destroy_process_group()
|
|
|
|
if __name__ == "__main__":
    # Send INFO-and-above records to stdout, stamped with date, time and
    # milliseconds, before handing control to the training entry point.
    log_format = "[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt=date_format)
    main()