# Source: msbackup/GenCT-ageencoder/mask_generation_pipeline.py
# (Hugging Face Hub upload by user qic999, revision 3bc8d9b — page residue
# converted to comments so the file is valid Python.)
import glob
import nibabel as nib
import numpy as np
import os
import torch
from einops import rearrange
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
from ldm.util import instantiate_from_config
import argparse
# ---------------------------------------------------------------------------
# CLI, config, and validation-data setup.
parser = argparse.ArgumentParser(
    description="Generate CT mask volumes with a latent diffusion model."
)
parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=True,  # FIX(review): was default=None, which crashed later in os.path.exists(None)
    help="Directory where generated .nii.gz volumes are written.",
)
parser.add_argument(
    "-t",
    "--time_steps",
    type=int,
    default=50,  # TODO(review): confirm default DDIM step count against the training config
    help="Number of DDIM sampling steps.",
)
args = parser.parse_args()
# FIX(review): the original read args.time_steps without ever declaring the
# argument, raising AttributeError at startup.
ddim_steps = args.time_steps

config = OmegaConf.load('./configs/latent-diffusion/mask_generation.yaml')
# NOTE(review): only config.data is instantiated here; `model` is used later
# in this file but never defined — it must be built from config.model and a
# checkpoint before the generation loop runs. Confirm against the source repo.
data = instantiate_from_config(config.data)
data.prepare_data()
data.setup()

save_path = args.save_path
os.makedirs(save_path, exist_ok=True)  # race-free replacement for exists()+makedirs()

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                       num_workers=2, pin_memory=True)
val_num = len(val_dataset)
# ---------------------------------------------------------------------------
# Sliding-window latent generation + chunked decoding for each validation
# volume, saved as NIfTI with the reference scan's affine.
#
# NOTE(review): `model` is used below but never defined in this file — it must
# be instantiated and checkpoint-loaded before this loop; confirm upstream.
save_gt = True  # NOTE(review): never read in this script; kept for compatibility.
for idx, batch in tqdm(enumerate(valloader)):
    if idx >= val_num:
        # Defensive guard; with batch_size == 1 the loader yields exactly
        # val_num batches, so this should never fire.
        break
    name = batch['name'][0]
    volume_data = batch['volume_data']

    window_length = 16  # latent slices produced per diffusion call
    h = 1               # slices of overlap carried between windows
    stride = window_length - h
    slice_num = volume_data.shape[1]

    # Latent buffer: (batch, z, 4, 64, 64) — presumably the first-stage
    # latent of a 512x512 image; TODO confirm against the VAE config.
    result = torch.zeros((batch_size, slice_num, 4, 64, 64)).cuda()

    # Windows needed to cover slice_num slices with overlap h:
    # ceil((slice_num - h) / stride).
    upper_iters = (slice_num - h + stride - 1) // stride
    print('upper_iters', upper_iters)

    for i in range(upper_iters):
        print('i', i)
        input_data = {}
        # FIX(review): the original moved tensors with .to(device) but
        # `device` is never defined (NameError); use .cuda() consistently
        # with the `result` buffer above.
        if i == upper_iters - 1:
            # Final window is anchored to the end of the volume so the tail
            # is covered even when (slice_num - h) is not a multiple of stride.
            input_data['name'] = batch['name']
            input_data['volume_data'] = batch['volume_data'][:, -window_length:].cuda()
            input_data['masked_data'] = batch['masked_data'][:, -window_length:].cuda()
            input_data['tumor_mask'] = batch['tumor_mask'][:, -window_length:].cuda()
        else:
            start = i * stride
            stop = start + window_length
            input_data['volume_data'] = batch['volume_data'][:, start:stop].cuda()
            input_data['masked_data'] = batch['masked_data'][:, start:stop].cuda()
            input_data['tumor_mask'] = batch['tumor_mask'][:, start:stop].cuda()

        with torch.no_grad():
            _, c = model.get_input(input_data, model.first_stage_key)
            if i == 0:
                samples_i, _ = model.sample_log(cond=c, batch_size=window_length,
                                                ddim=True, eta=1.,
                                                ddim_steps=ddim_steps)
            else:
                # Condition on the last h slices of the previous window to
                # keep cross-window continuity.
                samples_i, _ = model.sample_log(cond=c, batch_size=window_length,
                                                ddim=True, eta=1.,
                                                ddim_steps=ddim_steps,
                                                previous=x_minus1)

        samples_i = rearrange(samples_i, '(b z) c h w -> b z c h w', z=window_length)

        # FIX(review): test i == 0 before the last-window case. The original
        # checked `i == upper_iters-1` first, so a volume fitting in a single
        # window (upper_iters == 1) left result[:, :h] as zeros.
        if i == 0:
            result[:, :window_length] = samples_i
        elif i == upper_iters - 1:
            result[:, -window_length + h:] = samples_i[:, h:, ...]
        else:
            start = i * stride
            # Skip the first h slices of the sample: they duplicate the
            # previous window's tail.
            result[:, start + h:start + window_length] = samples_i[:, h:]
        x_minus1 = samples_i[:, -h:, ...]  # overlap fed to the next window

    # Decode latents back to image space in chunks of dec_unit slices.
    result = rearrange(result, 'b z c h w -> (b z) c h w')
    x_result = torch.zeros((result.shape[0], 3, 512, 512))
    dec_unit = 16
    num_dec_iter = (slice_num + dec_unit - 1) // dec_unit  # ceil(slice_num / dec_unit)
    for i in range(num_dec_iter):
        if i == num_dec_iter - 1:
            # FIX(review): the original had no `else` here and decoded the
            # final (possibly partial) chunk twice; decode the trailing
            # dec_unit slices exactly once.
            x_result[-dec_unit:] = model.decode_first_stage(result[-dec_unit:])
        else:
            x_result[i * dec_unit:(i + 1) * dec_unit] = \
                model.decode_first_stage(result[i * dec_unit:(i + 1) * dec_unit])

    # Clamp decoder output to [-1, 1], then map to [0, 1].
    x_result = x_result.clamp_(-1.0, 1.0)
    x_result = (x_result + 1) / 2
    x_result = rearrange(x_result, '(b z) c h w -> b z c h w', z=slice_num)

    # Average the 3 channels, detach to numpy, and reorder axes to
    # (x, y, z) for NIfTI storage.
    x_mean = x_result[0].mean(axis=1).detach().cpu().numpy()
    x_mean = x_mean.transpose(2, 1, 0)

    # Reuse the affine of the matching reference scan so the output aligns
    # with the original geometry.
    # NOTE(review): hard-coded dataset path; consider promoting to a CLI flag.
    ref_root = '/storage/chenqi/data/CT/Task03_Liver/imagesTr'
    ref_nii = os.path.join(ref_root, name + '.nii.gz')
    affine = nib.load(ref_nii).affine

    # Rescale [0, 1] to intensities in [-175, 250] — presumably the inverse
    # of the training HU-window normalization; TODO confirm.
    x_mean = x_mean * 425.0 - 175.0
    data_path = os.path.join(save_path, f'{name}.nii.gz')
    data_nii = nib.Nifti1Image(x_mean.astype(np.int16), affine)
    nib.save(data_nii, data_path)