# msbackup / GenCT-ageencoder / inference_abdomen_our_3d_new4.py
# (HuggingFace upload metadata, kept as comments: uploaded by qic999
# via huggingface_hub, commit 3bc8d9b verified)
import glob
import nibabel as nib
import numpy as np
import os
import torch
from einops import rearrange
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm
from ldm.util import instantiate_from_config
import argparse
def compute_orientation(init_axcodes, final_axcodes):
    """
    Thin wrapper around ``nib.orientations.ornt_transform``.

    :param init_axcodes: Initial orientation codes (e.g. ('R', 'A', 'S'))
    :param final_axcodes: Target orientation codes
    :return: (transform array, start orientation, end orientation)
    """
    start_ornt = nib.orientations.axcodes2ornt(init_axcodes)
    end_ornt = nib.orientations.axcodes2ornt(final_axcodes)
    transform = nib.orientations.ornt_transform(start_ornt, end_ornt)
    return transform, start_ornt, end_ornt
def do_reorientation(data_array, init_axcodes, final_axcodes):
    """
    Reorient a 3D array from one axis-code convention to another.

    source: https://niftynet.readthedocs.io/en/dev/_modules/niftynet/io/misc_io.html#do_reorientation

    :param data_array: 3D array to reorient
    :param init_axcodes: Initial orientation codes
    :param final_axcodes: Target orientation codes
    :return: the array in its reoriented form (unchanged if already oriented)
    """
    transform, start_ornt, end_ornt = compute_orientation(init_axcodes, final_axcodes)
    # Source and target orientations already agree: nothing to do.
    if np.array_equal(start_ornt, end_ornt):
        return data_array
    return nib.orientations.apply_orientation(data_array, transform)
def compute_fusion_boundaries(upper_iters, latent_length, decode_ratio=4, blend_width=16):
    """
    Compute fusion boundaries between consecutive generated segments.

    One boundary is produced per internal segment junction (none for a
    single-segment volume).

    :param upper_iters: Number of iterations in the generation loop
    :param latent_length: Length of one segment in latent slices
    :param decode_ratio: Latent-to-decoded slice ratio
    :param blend_width: Half-width of the blending region, in decoded slices
    :return: list of dicts with keys 'boundary', 'blend_start', 'blend_end'
    """
    boundaries = []
    for seg_idx in range(1, upper_iters):
        # Junction position in latent space, mapped into decoded-slice space.
        decoded = seg_idx * latent_length * decode_ratio + 1
        boundaries.append({
            'boundary': decoded,
            # Blending region is boundary +/- blend_width, clamped at 0.
            'blend_start': max(0, decoded - blend_width),
            'blend_end': decoded + blend_width,
        })
    return boundaries
def alternating_volume_fusion_with_boundaries(vol1, vol2, boundaries, offset=32, replace_width=16):
    """
    Fuse two volumes by copying vol2 over vol1 around each boundary.

    vol1 is used as the base; for every boundary, the band
    [boundary - replace_width, boundary + replace_width) of vol1 is replaced
    with the corresponding band of vol2 (which is shifted by ``offset``
    slices along z). Bands whose mapped range falls outside vol2 are skipped.

    :param vol1: Base volume, shape (x, y, z)
    :param vol2: Offset volume, shape (x, y, z - offset)
    :param boundaries: Output of compute_fusion_boundaries
    :param offset: z-shift of vol2 relative to vol1
    :param replace_width: Half-width of each replacement band
    :return: Fused copy of vol1 (vol1 itself is not modified)
    """
    depth = vol1.shape[2]
    fused = vol1.copy()
    ordered = sorted(boundaries, key=lambda item: item['boundary'])
    print(f"Fusion starting with vol1 as base")
    for band_idx, info in enumerate(ordered):
        z = info['boundary']
        lo = max(0, z - replace_width)
        hi = min(z + replace_width, depth)
        # Same band expressed in vol2's (offset) coordinate system.
        src_lo = lo - offset
        src_hi = hi - offset
        if src_lo >= 0 and src_hi <= vol2.shape[2]:
            fused[:, :, lo:hi] = vol2[:, :, src_lo:src_hi]
            print(f" Boundary {band_idx+1}: z={z}, replaced z[{lo}:{hi}] with vol2[{src_lo}:{src_hi}]")
        else:
            print(f" Boundary {band_idx+1}: z={z}, skipped (vol2 out of range)")
    print(f"Fusion complete with {len(ordered)} boundaries")
    return fused
# --- CLI: number of DDIM sampling steps. ---
parser = argparse.ArgumentParser()
parser.add_argument(
    "-t",
    "--time_steps",
    type=int,
    default=20,
)
args = parser.parse_args()
ddim_steps = args.time_steps

# Experiment directory, trained checkpoint, and eval config.
logdir = 'logs/full_ct_3d_with_body_mask'
ckpt = os.path.join(logdir, "checkpoints", "epoch=000070.ckpt")
configs_file = "configs/latent-diffusion/full_ct_3d_with_body_mask_eval.yaml"
configs = OmegaConf.load(configs_file)

# Build the diffusion model, load weights, and move it to the compute device.
model = instantiate_from_config(configs.model)
model.init_from_ckpt(ckpt)
model.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)
model = model.to(device)

# Data module from the same config (previously the identical YAML was
# loaded from disk a second time; reuse the parsed config instead).
data = instantiate_from_config(configs.data)
data.prepare_data()
data.setup()

# save_path = f'3d_results_step{ddim_steps}_train_latest'
save_path = os.path.join(logdir, 'evaluation')
os.makedirs(save_path, exist_ok=True)  # race-safe vs. exists()+makedirs()

val_dataset = data.datasets['validation']
batch_size = 1
valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
val_num = len(val_dataset)
save_gt = True
# --- Main inference loop: generate and save one fused NIfTI volume per
# --- validation scan. Generation runs twice (second pass offset by 32
# --- slices) and the two results are fused at segment boundaries.
for idx, data in tqdm(enumerate(valloader)):
    if idx >= val_num:
        break
    # Reference NIfTI path is reconstructed from the sample name; its affine
    # is reused when saving. Assumes names like "<a>_<b>_<c>" mapping to
    # <ref_root>/<a>_<b>/<a>_<b>_<c>/<name>.nii.gz — TODO confirm layout.
    name=data['name'][0].split('.')[0]
    ref_root = '/sd/shuhan/CT-RATE/dataset/valid_fixed'
    ref_nii = os.path.join(ref_root, name.split('_')[0]+'_'+name.split('_')[1], name.split('_')[0]+'_'+name.split('_')[1]+'_'+name.split('_')[2],name+'.nii.gz')
    data_path = os.path.join(save_path, str(f'{name}.nii.gz'))
    # Skip scans already generated by a previous run (resume support).
    if os.path.exists(data_path):
        continue
    volume_data = data['volume_data']
    volume_seg = data['volume_seg']
    # One generation window = 65 image slices conditioning 17 latent slices
    # (presumably the VAE's temporal compression — TODO confirm ratio).
    window_length = 65
    latent_lenght=17  # NOTE(review): "lenght" typo kept; renaming touches many lines
    slice_num = volume_data.shape[2]
    # ceil(slice_num / window_length), spelled out without math.ceil.
    upper_iters = slice_num // window_length + 1 if slice_num % window_length != 0 else slice_num // window_length
    # Latent buffer for all segments.
    # NOTE(review): hard-coded .cuda() conflicts with the CPU fallback in
    # `device` above — confirm a GPU is actually required here.
    result = torch.zeros((batch_size, upper_iters*latent_lenght, 16, 64, 64)).cuda()
    print('upper_iters', upper_iters)
    for i in range(upper_iters):
        print('i', i)
        input_data={}
        if i == upper_iters-1:
            # Last window is anchored to the END of the volume, so it may
            # overlap the previous window when slice_num % window_length != 0.
            input_data['name'] = data['name']
            input_data['volume_data'] = data['volume_data'][:, :, -window_length:].to(device)
            input_data['volume_seg'] = data['volume_seg'][:, :, -window_length:].to(device)
            input_data['input_text'] = data['input_text']
        else:
            input_data['volume_data'] = data['volume_data'][:, :, i*window_length:(i+1)*window_length].to(device)
            input_data['volume_seg'] = data['volume_seg'][:, :, i*window_length:(i+1)*window_length].to(device)
            input_data['input_text'] = data['input_text']
        with torch.no_grad():
            # c: conditioning extracted by the model from the window.
            _, c = model.get_input(input_data, model.first_stage_key)
            if i == 0:
                samples_i, _ = model.sample_log(cond=c, batch_size=latent_lenght, ddim=True, eta=1., ddim_steps=ddim_steps)
            else:
                # Autoregressive continuity: seed with the last latent slice
                # of the previous window (x_minus1).
                samples_i, _ = model.sample_log(cond=c, batch_size=latent_lenght, ddim=True, eta=1., ddim_steps=ddim_steps, previous=x_minus1)
        # Fold the per-slice batch dimension back into a z axis.
        samples_i = rearrange(samples_i, '(b z) c h w -> b z c h w', z=latent_lenght)
        if i == upper_iters-1:
            # End-anchored last window overwrites the tail of the buffer.
            result[:, -latent_lenght:] = samples_i
        else:
            result[:, i*latent_lenght:(i+1)*latent_lenght] = samples_i
        x_minus1 = samples_i[:, -1:, ...] # Take last slice for next iteration
    # (b, z, c, h, w) -> (b, c, z, h, w) for the decoder.
    result = result.permute(0,2,1,3,4)
    # Decode each 17-latent chunk back into 65 image slices of 512x512.
    x_result = torch.zeros((3,slice_num,512,512))
    dec_unit = 65
    dec_latent_unit=17
    num_dec_iter = slice_num // dec_unit + 1 if slice_num % dec_unit != 0 else slice_num // dec_unit
    for i in range(num_dec_iter):
        if i == num_dec_iter - 1:
            # Last chunk: keep only the slices not covered by earlier chunks.
            remaining_slices = slice_num % dec_unit if slice_num % dec_unit != 0 else dec_unit
            x_result[:, -remaining_slices:] = model.decode_first_stage(result[:, :, -latent_lenght:])[0][:, -remaining_slices:]
        else:
            x_result[:, i*dec_unit:(i+1)*dec_unit] = model.decode_first_stage(result[:, :, i*latent_lenght:(i+1)*latent_lenght])[0]
    # Rescale x2, clamp to [-1, 1], then map to [0, 1].
    x_result = x_result*2
    x_result[x_result>1.0] = 1.0
    x_result[x_result<-1.0] = -1.0
    x_result = (x_result+1)/2
    # Average the 3 decoded channels into one grayscale volume, then
    # reorder to (H, W, Z) for NIfTI.
    x_result_ = x_result.mean(axis=0).detach().cpu().numpy()
    x_result = x_result_.transpose(1,2,0)
    # Process second volume with offset
    # --- Second pass: identical pipeline on the volume shifted 32 slices,
    # --- so its window seams land away from the first pass's seams.
    data2={}
    data2['volume_data'] = data['volume_data'][:, :, 32:]
    data2['volume_seg'] = data['volume_seg'][:, :, 32:]
    data2['name'] = data['name']
    data2['input_text'] = data['input_text']
    volume_data = data2['volume_data']
    volume_seg = data2['volume_seg']
    window_length = 65
    latent_lenght=17
    slice_num2 = volume_data.shape[2]
    upper_iters2 = slice_num2 // window_length + 1 if slice_num2 % window_length != 0 else slice_num2 // window_length
    result2 = torch.zeros((batch_size, upper_iters2*latent_lenght, 16, 64, 64)).cuda()
    print('upper_iters2', upper_iters2)
    for i in range(upper_iters2):
        print('i', i)
        input_data={}
        if i == upper_iters2-1:
            input_data['name'] = data2['name']
            input_data['volume_data'] = data2['volume_data'][:, :, -window_length:].to(device)
            input_data['volume_seg'] = data2['volume_seg'][:, :, -window_length:].to(device)
            input_data['input_text'] = data2['input_text']
        else:
            input_data['volume_data'] = data2['volume_data'][:, :, i*window_length:(i+1)*window_length].to(device)
            input_data['volume_seg'] = data2['volume_seg'][:, :, i*window_length:(i+1)*window_length].to(device)
            input_data['input_text'] = data2['input_text']
        with torch.no_grad():
            _, c = model.get_input(input_data, model.first_stage_key)
            if i == 0:
                # NOTE(review): x_minus1 from pass 1 is NOT reused here;
                # pass 2 starts unseeded — confirm this is intended.
                samples_i, _ = model.sample_log(cond=c, batch_size=latent_lenght, ddim=True, eta=1., ddim_steps=ddim_steps)
            else:
                samples_i, _ = model.sample_log(cond=c, batch_size=latent_lenght, ddim=True, eta=1., ddim_steps=ddim_steps, previous=x_minus1)
        samples_i = rearrange(samples_i, '(b z) c h w -> b z c h w', z=latent_lenght)
        if i == upper_iters2-1:
            result2[:, -latent_lenght:] = samples_i
        else:
            result2[:, i*latent_lenght:(i+1)*latent_lenght] = samples_i
        x_minus1 = samples_i[:, -1:, ...] # Take last slice for next iteration
    result2 = result2.permute(0,2,1,3,4)
    # Decode / normalize pass 2, mirroring pass 1 exactly.
    x_result2 = torch.zeros((3,slice_num2,512,512))
    dec_unit = 65
    dec_latent_unit=17
    num_dec_iter = slice_num2 // dec_unit + 1 if slice_num2 % dec_unit != 0 else slice_num2 // dec_unit
    for i in range(num_dec_iter):
        if i == num_dec_iter - 1:
            remaining_slices = slice_num2 % dec_unit if slice_num2 % dec_unit != 0 else dec_unit
            x_result2[:, -remaining_slices:] = model.decode_first_stage(result2[:, :, -latent_lenght:])[0][:, -remaining_slices:]
        else:
            x_result2[:, i*dec_unit:(i+1)*dec_unit] = model.decode_first_stage(result2[:, :, i*latent_lenght:(i+1)*latent_lenght])[0]
    x_result2 = x_result2*2
    x_result2[x_result2>1.0] = 1.0
    x_result2[x_result2<-1.0] = -1.0
    x_result2 = (x_result2+1)/2
    x_result2_ = x_result2.mean(axis=0).detach().cpu().numpy()
    x_result2 = x_result2_.transpose(1,2,0)
    # Compute fusion boundaries based on result segments
    # NOTE(review): latent_length=16 here, but 17 (latent_lenght) was used
    # for generation above — verify 16 is intentional.
    boundaries = compute_fusion_boundaries(
        upper_iters=upper_iters,
        latent_length=16,
        decode_ratio=4,
        blend_width=16
    )
    # Fuse volumes: start with vol1, replace regions around boundaries with vol2
    print("Fusing volumes: replacing boundary regions with vol2...")
    x_result_fused = alternating_volume_fusion_with_boundaries(
        x_result,
        x_result2,
        boundaries,
        offset=32,
        replace_width=16
    )
    # Load reference and save fused result
    # Map [0, 1] back to [-1000, 1000] (presumably Hounsfield units) and save
    # as int16 with the reference scan's affine.
    affine = nib.load(ref_nii).affine
    x_result_fused = x_result_fused*2000.0 - 1000.0
    data_nii = nib.Nifti1Image(x_result_fused.astype(np.int16), affine)
    nib.save(data_nii, data_path)
    print(f"Saved fused volume: {data_path}")
    # breakpoint()