| import glob |
| import nibabel as nib |
| import numpy as np |
| import os |
| import torch |
|
|
| from einops import rearrange |
| from omegaconf import OmegaConf |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from ldm.util import instantiate_from_config |
| import argparse |
|
|
def compute_orientation(init_axcodes, final_axcodes):
    """
    Build the orientation transform that maps one set of axis codes to another.

    Convenience wrapper around ``nib.orientations.ornt_transform``.

    :param init_axcodes: Axis codes describing the current orientation
    :param final_axcodes: Axis codes describing the desired orientation
    :return: tuple ``(transform, start_ornt, end_ornt)`` where ``start_ornt``
        and ``end_ornt`` are the orientation arrays obtained from the two sets
        of axis codes and ``transform`` is the transform between them
    """
    orientations = nib.orientations

    # Convert both code sets first (initial, then final), exactly in that order.
    source_ornt, target_ornt = (
        orientations.axcodes2ornt(codes)
        for codes in (init_axcodes, final_axcodes)
    )

    return (
        orientations.ornt_transform(source_ornt, target_ornt),
        source_ornt,
        target_ornt,
    )
|
|
def do_reorientation(data_array, init_axcodes, final_axcodes):
    """
    Reorient an array from one axis-code orientation to another.

    source: https://niftynet.readthedocs.io/en/dev/_modules/niftynet/io/misc_io.html#do_reorientation

    :param data_array: Array to reorient (documented upstream as 3D)
    :param init_axcodes: Orientation the array currently has
    :param final_axcodes: Orientation the array should end up in
    :return: ``data_array`` itself when both orientations are already equal,
        otherwise the result of ``nib.orientations.apply_orientation``
    """
    transform, source_ornt, target_ornt = compute_orientation(
        init_axcodes, final_axcodes
    )

    needs_reorientation = not np.array_equal(source_ornt, target_ornt)
    if needs_reorientation:
        return nib.orientations.apply_orientation(data_array, transform)

    # Nothing to do: hand back the very same object, not a copy.
    return data_array
|
|
def compute_fusion_boundaries(upper_iters, latent_length, decode_ratio=4, blend_width=16):
    """
    List the decoded-space boundaries between consecutive generation segments.

    For every segment index ``i`` in ``1 .. upper_iters - 1`` the boundary is
    ``i * latent_length * decode_ratio + 1``; a blend band of ``blend_width``
    slices is reported on each side of it (the lower edge is clamped at 0,
    the upper edge is not clamped).

    :param upper_iters: Number of iterations in the generation loop
    :param latent_length: Latent-space length of one segment
    :param decode_ratio: Ratio between decoded and latent dimensions
    :param blend_width: Width of the blending region on each side of a boundary
    :return: List of dicts with keys ``'boundary'``, ``'blend_start'`` and
        ``'blend_end'``, in increasing boundary order (empty when
        ``upper_iters <= 1``)
    """

    def describe(segment_index):
        # Position of the seam in decoded (slice) coordinates.
        seam = segment_index * latent_length * decode_ratio + 1
        return {
            'boundary': seam,
            'blend_start': max(0, seam - blend_width),
            'blend_end': seam + blend_width,
        }

    return [describe(segment_index) for segment_index in range(1, upper_iters)]
|
|
|
|
def alternating_volume_fusion_with_boundaries(vol1, vol2, boundaries, offset=32, replace_width=16):
    """
    Patch the regions around segment boundaries of ``vol1`` with data from ``vol2``.

    ``vol1`` is copied and, for each boundary ``b`` (processed in increasing
    order), slices ``[max(0, b - replace_width), min(b + replace_width, depth))``
    of the copy are overwritten with the slices of ``vol2`` located ``offset``
    positions earlier. A boundary whose source range would start before slice
    0 of ``vol2`` or end past its last slice is skipped entirely.

    :param vol1: Base volume with shape (x, y, z); it is not modified
    :param vol2: Second volume indexed on its third axis, shifted by ``offset``
        slices relative to ``vol1``
    :param boundaries: List of dicts carrying a ``'boundary'`` key, e.g. the
        output of ``compute_fusion_boundaries``
    :param offset: Slice offset between ``vol1`` and ``vol2`` (default 32)
    :param replace_width: Number of slices replaced on each side of a boundary
        (default 16)
    :return: Copy of ``vol1`` with the boundary regions taken from ``vol2``
    """
    # Unpacking (rather than indexing) keeps the "must be 3-D" check.
    _, _, depth = vol1.shape
    fused = vol1.copy()

    ordered = sorted(boundaries, key=lambda entry: entry['boundary'])

    print(f"Fusion starting with vol1 as base")

    for number, entry in enumerate(ordered, start=1):
        seam = entry['boundary']

        # Destination range in vol1 coordinates, clipped to the volume.
        dst_start = max(0, seam - replace_width)
        dst_stop = min(seam + replace_width, depth)

        # Same range expressed in vol2 coordinates.
        src_start = dst_start - offset
        src_stop = dst_stop - offset

        if src_start < 0 or src_stop > vol2.shape[2]:
            print(f" Boundary {number}: z={seam}, skipped (vol2 out of range)")
            continue

        fused[:, :, dst_start:dst_stop] = vol2[:, :, src_start:src_stop]
        print(f" Boundary {number}: z={seam}, replaced z[{dst_start}:{dst_stop}] with vol2[{src_start}:{src_stop}]")

    print(f"Fusion complete with {len(ordered)} boundaries")

    return fused
|
|
# ---------------------------------------------------------------------------
# Evaluation script.
#
# For every validation sample the volume is generated twice with the latent
# diffusion model: once from the full conditioning volume and once from the
# same volume shifted by FUSION_OFFSET slices. Generation runs window by
# window, so the first result has seams between windows; the regions around
# those seams are replaced with the shifted result, whose seams fall
# elsewhere. The fused volume is written as NIfTI, reusing the affine of the
# reference scan.
# ---------------------------------------------------------------------------

WINDOW_LENGTH = 65               # decoded slices written per generation window
LATENT_LENGTH = 17               # latent frames sampled per generation window
LATENT_FRAME_SHAPE = (16, 64, 64)  # trailing dims of the latent buffer
DECODED_CHANNELS = 3             # leading dim of the decoded buffer
DECODED_SIZE = 512               # in-plane size of the decoded buffer
FUSION_OFFSET = 32               # slice shift applied for the second pass
REPLACE_WIDTH = 16               # slices replaced on each side of a seam
REF_ROOT = '/sd/shuhan/CT-RATE/dataset/valid_fixed'


def _generate_volume(model, batch, device, ddim_steps, label, batch_size=1):
    """
    Sample and decode one full volume, window by window.

    The conditioning tensors are cut along dim 2 into windows of
    ``WINDOW_LENGTH`` slices; the final window is always the *last*
    ``WINDOW_LENGTH`` slices so it overlaps its predecessor when the slice
    count is not a multiple of the window length. Every window after the
    first is sampled with the last latent frame of the previous window
    passed as ``previous``.

    :param model: Model exposing ``get_input``, ``sample_log``,
        ``decode_first_stage`` and ``first_stage_key``
    :param batch: Dict with ``'volume_data'``, ``'volume_seg'``, ``'name'``
        and ``'input_text'``; the two tensors are indexed on dim 2
    :param device: Device used for sampling
    :param ddim_steps: Number of DDIM steps handed to ``sample_log``
    :param label: Text printed in front of the window count
    :param batch_size: Leading dimension of the latent buffer
    :return: tuple ``(volume, num_windows)``; ``volume`` is a numpy array of
        shape ``(DECODED_SIZE, DECODED_SIZE, slice_num)`` with values in
        ``[0, 1]``
    """
    volume_data = batch['volume_data']
    volume_seg = batch['volume_seg']
    slice_num = volume_data.shape[2]

    # Ceiling division: a partial trailing window still needs its own pass.
    num_windows = -(-slice_num // WINDOW_LENGTH)
    print(label, num_windows)

    # Allocate on the selected device (a bare .cuda() breaks the CPU fallback).
    latents = torch.zeros(
        (batch_size, num_windows * LATENT_LENGTH) + LATENT_FRAME_SHAPE,
        device=device,
    )
    decoded = torch.zeros((DECODED_CHANNELS, slice_num, DECODED_SIZE, DECODED_SIZE))

    # Pure inference: neither sampling nor decoding needs autograd state.
    with torch.no_grad():
        previous = None
        for i in range(num_windows):
            print('i', i)
            is_last = i == num_windows - 1
            if is_last:
                window = slice(-WINDOW_LENGTH, None)
            else:
                window = slice(i * WINDOW_LENGTH, (i + 1) * WINDOW_LENGTH)

            window_input = {
                'volume_data': volume_data[:, :, window].to(device),
                'volume_seg': volume_seg[:, :, window].to(device),
                'input_text': batch['input_text'],
            }
            if is_last:
                # NOTE(review): 'name' was only ever supplied for the final
                # window; kept as-is because get_input is not visible here.
                window_input['name'] = batch['name']

            _, cond = model.get_input(window_input, model.first_stage_key)

            sample_kwargs = dict(
                cond=cond,
                batch_size=LATENT_LENGTH,
                ddim=True,
                eta=1.,
                ddim_steps=ddim_steps,
            )
            if previous is not None:
                sample_kwargs['previous'] = previous
            samples, _ = model.sample_log(**sample_kwargs)

            samples = rearrange(samples, '(b z) c h w -> b z c h w', z=LATENT_LENGTH)
            latents[:, i * LATENT_LENGTH:(i + 1) * LATENT_LENGTH] = samples
            previous = samples[:, -1:, ...]

        # Move the latent-frame axis to dim 2 before decoding.
        latents = latents.permute(0, 2, 1, 3, 4)

        for i in range(num_windows):
            chunk = model.decode_first_stage(
                latents[:, :, i * LATENT_LENGTH:(i + 1) * LATENT_LENGTH]
            )[0]
            if i == num_windows - 1:
                # The last window overlaps the previous one; keep only the
                # slices that have not been written yet.
                remaining = slice_num % WINDOW_LENGTH or WINDOW_LENGTH
                decoded[:, -remaining:] = chunk[:, -remaining:]
            else:
                decoded[:, i * WINDOW_LENGTH:(i + 1) * WINDOW_LENGTH] = chunk

    # Scale by 2, clip to [-1, 1], then map to [0, 1].
    decoded = (decoded * 2).clamp_(-1.0, 1.0)
    decoded = (decoded + 1) / 2

    # Average the channels and move the slice axis last: (H, W, slices).
    volume = decoded.mean(dim=0).detach().cpu().numpy().transpose(1, 2, 0)
    return volume, num_windows


parser = argparse.ArgumentParser()
parser.add_argument(
    "-t",
    "--time_steps",
    type=int,
    default=20,
)
args = parser.parse_args()
ddim_steps = args.time_steps

logdir = 'logs/full_ct_3d_with_body_mask'
ckpt = os.path.join(logdir, "checkpoints", "epoch=000070.ckpt")

# One config file describes both the model and the data module.
configs_file = "configs/latent-diffusion/full_ct_3d_with_body_mask_eval.yaml"
configs = OmegaConf.load(configs_file)
model = instantiate_from_config(configs.model)
model.init_from_ckpt(ckpt)
model.eval()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
model = model.to(device)

data = instantiate_from_config(configs.data)
data.prepare_data()
data.setup()

save_path = os.path.join(logdir, 'evaluation')
os.makedirs(save_path, exist_ok=True)

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

for batch in tqdm(valloader):
    name = batch['name'][0].split('.')[0]

    # Skip samples that already have an output file so runs can be resumed.
    data_path = os.path.join(save_path, f'{name}.nii.gz')
    if os.path.exists(data_path):
        continue

    # Reference scan path: <root>/<a>_<b>/<a>_<b>_<c>/<name>.nii.gz
    name_parts = name.split('_')
    ref_nii = os.path.join(
        REF_ROOT,
        '_'.join(name_parts[:2]),
        '_'.join(name_parts[:3]),
        name + '.nii.gz',
    )

    # Pass 1: the full volume.
    x_result, upper_iters = _generate_volume(
        model, batch, device, ddim_steps, label='upper_iters', batch_size=batch_size
    )

    # Pass 2: the same volume with the first FUSION_OFFSET slices dropped, so
    # its window seams land away from the seams of pass 1.
    shifted_batch = {
        'volume_data': batch['volume_data'][:, :, FUSION_OFFSET:],
        'volume_seg': batch['volume_seg'][:, :, FUSION_OFFSET:],
        'name': batch['name'],
        'input_text': batch['input_text'],
    }
    x_result2, _ = _generate_volume(
        model, shifted_batch, device, ddim_steps, label='upper_iters2', batch_size=batch_size
    )

    # Decoded window k is written to slices [k*WINDOW_LENGTH, (k+1)*WINDOW_LENGTH),
    # so seam k sits exactly at k*WINDOW_LENGTH. The earlier 64*k+1 estimate
    # matched only the first seam and drifted by one slice per further window.
    boundaries = [
        {
            'boundary': k * WINDOW_LENGTH,
            'blend_start': max(0, k * WINDOW_LENGTH - REPLACE_WIDTH),
            'blend_end': k * WINDOW_LENGTH + REPLACE_WIDTH,
        }
        for k in range(1, upper_iters)
    ]

    print("Fusing volumes: replacing boundary regions with vol2...")
    x_result_fused = alternating_volume_fusion_with_boundaries(
        x_result,
        x_result2,
        boundaries,
        offset=FUSION_OFFSET,
        replace_width=REPLACE_WIDTH,
    )

    affine = nib.load(ref_nii).affine

    # Map [0, 1] to [-1000, 1000] (presumably Hounsfield units) before the
    # int16 cast.
    x_result_fused = x_result_fused * 2000.0 - 1000.0

    data_nii = nib.Nifti1Image(x_result_fused.astype(np.int16), affine)

    nib.save(data_nii, data_path)
    print(f"Saved fused volume: {data_path}")
| |