MaskEdit / instruct-pix2pix-BioMedCLIP-concat-newdata-data-Opacity /stable_diffusion /ldm /data /ct_clip_data_train.py
| import os | |
| import glob | |
| import json | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| import torchvision.transforms as transforms | |
| from functools import partial | |
| import torch.nn.functional as F | |
| import nibabel as nib | |
| import tqdm | |
| import copy | |
def resize_array(array, current_spacing, target_hw=(256, 256)):
    """Resample a volume tensor in-plane with trilinear interpolation.

    The first spatial axis (``array.shape[2]``) keeps its length; the last two
    spatial axes are resized to ``target_hw``.

    Args:
        array (torch.Tensor): 5-D floating-point tensor whose ``shape[2:]`` are
            the three spatial axes (trilinear ``F.interpolate`` needs 5-D input).
        current_spacing (sequence of float): voxel spacing of the three spatial
            axes, in the same order as ``array.shape[2:]``.
        target_hw (int or tuple of two ints): output size of the last two
            spatial axes; an int means a square output. Defaults to
            ``(256, 256)``, the size this function has always produced.

    Returns:
        tuple: ``(resized, resized_spacing)`` where ``resized`` is a CPU
        ``np.ndarray`` of shape ``(*array.shape[:2], depth, *target_hw)`` and
        ``resized_spacing`` is a list with ``current_spacing[i]`` divided by the
        new/old size ratio of spatial axis ``i``.
    """
    if isinstance(target_hw, int):
        target_hw = (target_hw, target_hw)
    original_shape = array.shape[2:]
    new_shape = [original_shape[0], int(target_hw[0]), int(target_hw[1])]
    # Spacing scales inversely with the size ratio, so size * spacing is preserved per axis.
    scaling_factors = [new_shape[i] / original_shape[i] for i in range(len(original_shape))]
    resized_spacing = [current_spacing[i] / scaling_factors[i] for i in range(len(original_shape))]
    resized_array = F.interpolate(
        array, size=new_shape, mode='trilinear', align_corners=False
    ).cpu().numpy()
    return resized_array, resized_spacing
def resize_mask(array, current_spacing, target_hw=(256, 256)):
    """Resample a mask volume tensor in-plane with nearest-neighbour interpolation.

    The first spatial axis (``array.shape[2]``) keeps its length; the last two
    spatial axes are resized to ``target_hw``. Nearest-neighbour sampling copies
    existing voxel values, so label values are never blended.

    Args:
        array (torch.Tensor): 5-D tensor whose ``shape[2:]`` are the three
            spatial axes.
        current_spacing: accepted for signature symmetry with ``resize_array``;
            not used.
        target_hw (int or tuple of two ints): output size of the last two
            spatial axes; an int means a square output. Defaults to
            ``(256, 256)``, the size this function has always produced.

    Returns:
        np.ndarray: CPU array of shape ``(*array.shape[:2], depth, *target_hw)``.
    """
    if isinstance(target_hw, int):
        target_hw = (target_hw, target_hw)
    depth = array.shape[2]
    new_shape = [depth, int(target_hw[0]), int(target_hw[1])]
    # NOTE(review): PyTorch's 'nearest' picks source index floor(dst * scale)
    # (see 'nearest-exact' in the F.interpolate docs), so the mask grid may be
    # offset by up to half a source voxel from resize_array's trilinear grid.
    # Kept as-is for compatibility with existing training data.
    return F.interpolate(array, size=new_shape, mode='nearest').cpu().numpy()
class CTReportDataset(Dataset):
    """Dataset of single CT slices paired with one disease-mask channel and findings text.

    Records come from the JSON-Lines manifest ``SAMPLES_JSON``. The volume
    loaded as the input image is taken from each record's ``organ_mask`` key;
    the mask volume from ``disease_mask``, indexed on its first axis by
    ``disease_mask_channel``. Only the first ``train_percent`` percent of the
    valid records (in file order) are kept.
    """

    # Prefix joined (by plain string concatenation) to the relative manifest paths.
    DATA_ROOT = '/sd/shuhan/CT-RATE/'
    # JSON-Lines manifest: one JSON object per line.
    SAMPLES_JSON = '/sd/shuhan/CT-RATE/single_disease_mask_json/train_single_prompt_opacity.json'
    # Metadata table; must provide VolumeName, XYSpacing and ZSpacing columns.
    METADATA_CSV = '/sd/shuhan/CT-RATE/metadata/all_metadata.csv'
    # Manifest keys every record must carry, in the order stored in each sample tuple.
    _REQUIRED_KEYS = ('organ_mask', 'disease_mask', 'disease_findings',
                      'disease_mask_channel', 'disease_label', 'disease_class')

    def __init__(self, data_folder, csv_file, min_slices=20, resize_dim=500,
                 force_num_frames=True, train_percent=80):
        """Build the sample list from ``SAMPLES_JSON``.

        Args:
            data_folder: stored as ``self.data_folder``; paths are resolved
                against ``DATA_ROOT``, not this value.
            csv_file: unused (report-text lookup via ``load_accession_text``
                is disabled).
            min_slices: stored as ``self.min_slices``; not used by this class.
            resize_dim: side length for ``self.transform`` (Resize + ToTensor).
            force_num_frames: unused; kept for signature compatibility.
            train_percent: percentage (0-100) of valid manifest records, taken
                from the start of the file, that form this dataset.

        Raises:
            ValueError: if ``train_percent`` is outside ``[0, 100]``.
        """
        if not 0 <= train_percent <= 100:
            raise ValueError(f"train_percent must be in [0, 100], got {train_percent!r}")
        self.data_folder = data_folder
        self.min_slices = min_slices
        self.paths = []
        self._metadata = None  # metadata DataFrame, read lazily once per instance
        self.samples = self.prepare_samples()
        num_files = int((len(self.samples) * train_percent) / 100)
        self.samples = self.samples[:num_files]
        # NOTE(review): self.paths still lists every valid manifest record,
        # including those dropped by the train_percent cut above.
        print(len(self.samples))
        self.count = 0
        self.transform = transforms.Compose([
            transforms.Resize((resize_dim, resize_dim)),
            transforms.ToTensor(),
        ])
        # nii_img_to_tensor accepts ``transform`` but does not apply it.
        self.nii_to_tensor = partial(self.nii_img_to_tensor, transform=self.transform)
        self.sample_length = 64

    def load_accession_text(self, csv_file):
        """Map each ``VolumeName`` in *csv_file* to its ``(Findings_EN, Impressions_EN)`` pair.

        Later rows overwrite earlier rows with the same ``VolumeName``. Not
        called by this class at present.
        """
        df = pd.read_csv(csv_file)
        return dict(zip(df['VolumeName'], zip(df['Findings_EN'], df['Impressions_EN'])))

    def prepare_samples(self):
        """Read ``SAMPLES_JSON`` and return the list of sample tuples.

        Each tuple is ``(image_path, mask_path, findings, mask_channel, label,
        class)``; both paths are prefixed with ``DATA_ROOT``. Every image path
        is also appended to ``self.paths``.
        """
        with open(self.SAMPLES_JSON, 'r') as f:
            # JSON Lines: one object per non-blank line.
            items = [json.loads(line) for line in f if line.strip()]
        samples = self._records_from_items(tqdm.tqdm(items), self.DATA_ROOT)
        self.paths.extend(sample[0] for sample in samples)
        return samples

    @classmethod
    def _records_from_items(cls, items, data_root):
        """Convert manifest dicts to sample tuples, dropping records missing any required key.

        Each record is validated as a whole so its fields always stay together
        (filtering each field separately and zipping would misalign records).
        """
        samples = []
        skipped = 0
        for item in items:
            if any(key not in item for key in cls._REQUIRED_KEYS):
                skipped += 1
                continue
            samples.append((
                data_root + item['organ_mask'],
                data_root + item['disease_mask'],
                item['disease_findings'],
                item['disease_mask_channel'],
                item['disease_label'],
                item['disease_class'],
            ))
        if skipped:
            print(f"prepare_samples: skipped {skipped} record(s) missing one of {cls._REQUIRED_KEYS}")
        return samples

    def __len__(self):
        """Number of samples kept after the train_percent cut."""
        return len(self.samples)

    @staticmethod
    def _parse_first_float(value):
        """Return the first number of a list-like string such as ``"[0.75, 0.75]"``."""
        return float(str(value).strip().strip('[]()').split(',')[0])

    def _lookup_spacing(self, file_name):
        """Return ``(xy_spacing, z_spacing)`` for *file_name* from the cached metadata table.

        The first matching ``VolumeName`` row is used.

        Raises:
            KeyError: if *file_name* has no row in ``METADATA_CSV``.
        """
        meta = getattr(self, '_metadata', None)
        if meta is None:
            # Read once and reuse: re-reading the full CSV per item dominated load time.
            meta = pd.read_csv(self.METADATA_CSV)
            self._metadata = meta
        row = meta[meta['VolumeName'] == file_name]
        if row.empty:
            raise KeyError(f"{file_name!r} not found in {self.METADATA_CSV}")
        xy_spacing = self._parse_first_float(row['XYSpacing'].iloc[0])
        z_spacing = float(row['ZSpacing'].iloc[0])
        return xy_spacing, z_spacing

    @staticmethod
    def _to_volume_tensor(volume):
        """Reorder a 3-D array's axes (a, b, c) -> (c, a, b) and add two leading singleton dims."""
        return torch.tensor(volume.transpose(2, 0, 1)).unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _sample_slice_index(depth):
        """Return a uniformly random index in ``[0, depth)`` from NumPy's global RNG."""
        if depth < 1:
            raise ValueError(f"cannot sample a slice from a volume with depth {depth}")
        # randint's upper bound is exclusive, so pass depth (not depth - 1) to
        # make the last slice reachable.
        return int(np.random.randint(0, depth))

    @staticmethod
    def _to_model_range(slice_2d):
        """Scale a 2-D array by ``x / 300 * 2 - 1`` as float32 and repeat it to 3 channels."""
        scaled = (slice_2d / 300).astype(np.float32) * 2 - 1
        return torch.tensor(scaled).unsqueeze(0).repeat(3, 1, 1)

    def nii_img_to_tensor(self, path, seg_file, disease_mask_channel, disease_label,
                          disease_class, transform):
        """Load an image volume and one mask channel, and return a random matching slice pair.

        Args:
            path: image NIfTI path; its basename is looked up in ``METADATA_CSV``.
            seg_file: mask NIfTI path, indexed on its first axis by
                ``disease_mask_channel``.
            disease_mask_channel: channel index (cast with ``int``).
            disease_label, disease_class, transform: accepted but unused.

        Returns:
            tuple ``(img, mask, spacing, file_name)``. ``img`` and ``mask`` are
            float32 tensors of shape ``(3, 256, 256)`` (one slice repeated on 3
            channels) scaled by ``x / 300 * 2 - 1``. ``spacing`` is the resampled
            spacing from ``resize_array``; ``file_name`` is the basename of *path*.

        Raises:
            KeyError: if the basename of *path* is missing from the metadata table.
            ValueError: if the resampled image and mask shapes differ.
        """
        file_name = path.split("/")[-1]
        xy_spacing, z_spacing = self._lookup_spacing(file_name)
        current = (z_spacing, xy_spacing, xy_spacing)

        # Raw stored values are used; RescaleSlope/RescaleIntercept are not applied.
        img_volume = nib.load(str(path)).get_fdata()
        # Slice the channel through the array proxy so only that channel is
        # converted to float64 (get_fdata() would convert every channel).
        seg_proxy = nib.load(str(seg_file)).dataobj
        mask_volume = np.asarray(seg_proxy[int(disease_mask_channel)], dtype=np.float64)

        img_volume, target_spacing = resize_array(self._to_volume_tensor(img_volume), current)
        img_volume = img_volume[0][0]
        mask_volume = resize_mask(self._to_volume_tensor(mask_volume), current)[0][0]
        if mask_volume.shape != img_volume.shape:
            raise ValueError(
                f"image {img_volume.shape} and mask {mask_volume.shape} shapes differ "
                f"for {path} / {seg_file}"
            )

        start_id = self._sample_slice_index(img_volume.shape[0])
        img_slice = img_volume[start_id]
        mask_slice = mask_volume[start_id]
        # Binarise: every positive mask voxel becomes 280 before the shared scaling.
        mask_all = np.zeros_like(img_slice)
        mask_all[mask_slice > 0] = 280
        return (self._to_model_range(img_slice), self._to_model_range(mask_all),
                target_spacing, file_name)

    @staticmethod
    def _clean_text(text):
        """Return ``str(text)`` with double quotes, single quotes and parentheses removed."""
        return str(text).translate(str.maketrans('', '', '"\'()'))

    def __getitem__(self, index):
        """Return one training example as a dict.

        Keys: ``name`` (image basename), ``edited`` (image and mask concatenated
        along the last axis) and ``edit`` holding ``c_concat`` (image
        concatenated with an all ``-1`` tensor along the last axis) and
        ``c_crossattn`` (findings text with quotes/parentheses stripped).
        """
        (nii_file, seg_file, input_text, disease_mask_channel,
         disease_label, disease_class) = self.samples[index]
        video_tensor, volume_seg, _spacing, file_name = self.nii_to_tensor(
            nii_file, seg_file, disease_mask_channel, disease_label, disease_class)
        image = video_tensor.float()
        # All -1 is what an empty mask becomes after scaling (0 / 300 * 2 - 1).
        empty_mask = torch.full_like(image, -1.0)
        return dict(
            name=file_name,
            edited=torch.cat([image, volume_seg.float()], dim=-1),
            edit=dict(
                c_concat=torch.cat([image, empty_mask], dim=-1),
                c_crossattn=self._clean_text(input_text),
            ),
        )