"""Sample from a trained latent-diffusion checkpoint and save decoded slices as PNGs.

Loads the model config and checkpoint from ``logdir``, builds the validation
split of the data module described by ``data_config_file``, runs DDIM sampling
conditioned on each validation case, decodes the sampled latents with the
model's first stage, and writes every decoded slice to ``<logdir>/inference``.
"""
import glob
import os

import nibabel as nib  # noqa: F401  (not used below; kept from the original script)
import numpy as np
import torch
from einops import rearrange  # noqa: F401  (not used below; kept from the original script)
from omegaconf import OmegaConf
from skimage import io
from torch.utils.data import DataLoader
from tqdm import tqdm

from ldm.util import instantiate_from_config

# ---------------------------------------------------------------------------
# Run settings
# ---------------------------------------------------------------------------
logdir = 'logs/full_ct_2d_with_body_mask/'
ckpt = os.path.join(logdir, "checkpoints", "epoch=000064.ckpt")
data_config_file = './configs/latent-diffusion/full_ct_2d_with_body_mask.yaml'

# Number of latents drawn per case by ``sample_log``. They end up stacked along
# dim 2 of the 5-D tensor passed to ``decode_first_stage`` (see the loop below).
num_samples = 17
ddim_steps = 100
ddim_eta = 1.0
# Factor applied to the decoder output before clipping to [-1, 1].
# NOTE(review): the reason for this factor is not visible here -- confirm.
output_scale = 2.0
# Only the first ``max_cases`` validation cases are processed. The original
# script stopped after one case; set to None to run the whole split.
max_cases = 1

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# The second YAML file (in sorted order) under <logdir>/configs is used as the
# model config.
# NOTE(review): index 1 is hard-coded; confirm which file it should pick when
# several runs have written configs into this directory.
config_files = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
if len(config_files) < 2:
    raise FileNotFoundError(
        f"expected at least two config files in {os.path.join(logdir, 'configs')}, "
        f"found {len(config_files)}"
    )
configs_file = config_files[1]
configs = OmegaConf.load(configs_file)
model = instantiate_from_config(configs.model)
model.init_from_ckpt(ckpt)
model.eval()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
model = model.to(device)

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
config = OmegaConf.load(data_config_file)
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()

save_path = os.path.join(logdir, 'inference')
os.makedirs(save_path, exist_ok=True)

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory=device.type == "cuda",
)

# Batch entries moved to the model's device before building the conditioning.
tensor_keys = ('volume_data', 'volume_seg', 'age_value')

# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
# Inference only: no autograd graph is needed.
with torch.no_grad():
    # Named `batch` (not `data`) so the data module above is not shadowed.
    for idx, batch in enumerate(tqdm(valloader)):
        # Output file prefix: everything before the first '.' of the case name.
        name = batch['name'][0].split('.')[0]
        for key in tensor_keys:
            batch[key] = batch[key].to(device)

        _, c = model.get_input(batch, model.first_stage_key)

        samples, _ = model.sample_log(
            cond=c,
            batch_size=num_samples,
            ddim=True,
            eta=ddim_eta,
            ddim_steps=ddim_steps,
        )
        # samples is 4-D, assumed (num_samples, C, h, w); swap the first two
        # dims and add a leading batch dim -> (1, C, num_samples, h, w).
        samples = samples.permute(1, 0, 2, 3)
        res = model.decode_first_stage(samples[None, :])

        # Scale, clip to [-1, 1], then map to [0, 1].
        res = ((res * output_scale).clamp(-1.0, 1.0) + 1) / 2
        # Drop the batch dim and average over dim 0 (channels, presumably),
        # leaving a stack of 2-D images indexed by the leading axis.
        res = res[0].mean(dim=0).cpu().numpy()

        for z_slice, img in enumerate(res):
            # Grey (H, W) -> 3-channel (H, W, 3) uint8; round rather than truncate.
            rgb = np.repeat(img[..., None], 3, axis=-1)
            rgb = np.rint(rgb * 255).astype(np.uint8)
            io.imsave(os.path.join(save_path, f'{name}_sample_{z_slice}.png'), rgb)

        # Stop after `max_cases` cases without fetching an extra batch.
        if max_cases is not None and idx + 1 >= max_cases:
            break
|
|
|
|
|
|