| """ |
| Dataset and data loading for S2F training. |
| Expects folder structure: each subfolder has BF_001.tif (bright field), *_gray.jpg (heatmap), and optionally .txt (cell_area, sum_force). |
| """ |
| import os |
| import cv2 |
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from sklearn.model_selection import train_test_split |
| from concurrent.futures import ThreadPoolExecutor |
| import numpy as np |
|
|
| from utils import config |
|
|
|
|
def blur_force_map(force_map, ksize=25, sigma=10):
    """Apply a Gaussian blur to every channel of a force-map tensor.

    Args:
        force_map: Tensor of shape (C, H, W) or (B, C, H, W).
        ksize: Gaussian kernel size; bumped to the next odd value when even
            because cv2.GaussianBlur requires an odd kernel size.
        sigma: Gaussian standard deviation (sigmaX; sigmaY defaults to it).

    Returns:
        Blurred float32 tensor with the SAME shape as the input, on the
        input's original device. (The previous implementation dropped the
        channel dimension for batched input and ignored channels > 0.)
    """
    if ksize % 2 == 0:
        ksize += 1  # cv2 rejects even kernel sizes
    squeeze_back = force_map.dim() == 3
    if squeeze_back:
        force_map = force_map.unsqueeze(0)
    device = force_map.device
    cpu_map = force_map.cpu()  # OpenCV operates on host memory
    blurred = np.empty(cpu_map.shape, dtype=np.float32)
    for b in range(cpu_map.size(0)):
        for c in range(cpu_map.size(1)):
            plane = cpu_map[b, c].numpy().astype(np.float32)
            blurred[b, c] = cv2.GaussianBlur(plane, (ksize, ksize), sigmaX=sigma)
    out = torch.from_numpy(blurred).to(device)
    # Restore the original rank so callers get back what they passed in.
    return out.squeeze(0) if squeeze_back else out
|
|
|
|
class ImageDataset(Dataset):
    """Paired (bright-field, force-heatmap) samples with scalar targets.

    Each entry of ``image_pairs`` is ``(bf, hm, numbers)`` — or
    ``(bf, hm, numbers, metadata)`` when ``return_metadata`` is True —
    where ``numbers`` is either a ``(cell_area, sum_force)`` tuple or a
    bare sum-force scalar (cell_area then defaults to 0).
    """

    def __init__(self, image_pairs, transform=None, channel_first=True,
                 blur_heatmap=False, threshold=0.0, return_metadata=False):
        self.image_pairs = image_pairs
        self.transform = transform
        self.channel_first = channel_first
        self.blur_heatmap = blur_heatmap
        self.threshold = threshold
        self.return_metadata = return_metadata

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        entry = self.image_pairs[idx]
        if self.return_metadata:
            bf_arr, hm_arr, numbers, meta = entry
        else:
            bf_arr, hm_arr, numbers = entry
        # ``numbers`` may be the full (area, force) pair or just the force.
        cell_area, sum_force = numbers if isinstance(numbers, tuple) else (0, numbers)

        image = torch.from_numpy(bf_arr).float().unsqueeze(0)
        heatmap = torch.from_numpy(hm_arr).float().unsqueeze(0)
        if self.transform:
            image, heatmap = self.transform(image, heatmap)
        cell_area = torch.tensor(cell_area, dtype=torch.float32)
        sum_force = torch.tensor(sum_force, dtype=torch.float32)
        # Zero out sub-threshold responses before any optional blurring.
        heatmap[heatmap <= self.threshold] = 0
        if self.blur_heatmap:
            heatmap = blur_force_map(heatmap)
        if not self.channel_first:
            # NOTE(review): permute(2, 1, 0) swaps H and W in addition to
            # moving the channel axis; preserved exactly as the original.
            image = image.permute(2, 1, 0)
            heatmap = heatmap.permute(2, 1, 0)
        if self.return_metadata:
            return image, heatmap, cell_area, sum_force, meta
        return image, heatmap, cell_area, sum_force
|
|
|
|
def load_image(filepath, target_size):
    """Load a grayscale image, resize it, and scale pixel values to [0, 1].

    Args:
        filepath: Path to the image file.
        target_size: An int (square output) or a (width, height) tuple as
            expected by ``cv2.resize``.

    Returns:
        float32 numpy array with values in [0, 1].

    Raises:
        FileNotFoundError: If the file is missing or cannot be decoded
            (cv2.imread signals failure by returning None, which previously
            surfaced as a cryptic error inside cv2.resize).
    """
    img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {filepath}")
    if isinstance(target_size, int):
        target_size = (target_size, target_size)
    img = cv2.resize(img, target_size)
    return (img / 255.0).astype(np.float32)
|
|
|
|
def load_text_data(filepath):
    """Parse a label file into a scaled (cell_area, sum_force) tuple.

    The first two non-empty lines are expected to look like
    ``label: value``; the numeric parts are multiplied by the config
    scale factors before being returned.
    """
    with open(filepath, 'r') as fh:
        values = [ln.strip() for ln in fh if ln.strip()]
    # assumes fixed line order: area first, force second — TODO confirm
    area = float(values[0].split(":")[1].strip()) * config.SCALE_FACTOR_AREA
    force = float(values[1].split(":")[1].strip()) * config.SCALE_FACTOR_FORCE
    return (area, force)
|
|
|
|
def load_images_from_subfolders(root_folder, target_size, load_numerical_data=True,
                                load_force_sum=False, return_metadata=False, substrate=None):
    """Scan the immediate subfolders of ``root_folder`` for image pairs.

    Each subfolder is expected to contain a ``*BF_001.tif`` bright-field
    image, a ``*_gray.jpg`` heatmap and, optionally, a ``*.txt`` label file.

    Args:
        root_folder: Folder whose subfolders each hold one sample.
        target_size: Size forwarded to ``load_image`` (int or tuple).
        load_numerical_data: Read (cell_area, sum_force) from the txt file;
            subfolders without a txt file are skipped.
        load_force_sum: Derive sum_force from the heatmap pixel sum instead
            (ignored when ``load_numerical_data`` is also set, matching the
            original elif precedence).
        return_metadata: Also attach a per-sample metadata dict.
        substrate: Substrate label stored in metadata; required when
            ``return_metadata`` is True.

    Returns:
        List of (bf_image, hm_image, (cell_area, sum_force)) tuples, each
        extended with a metadata dict when ``return_metadata`` is True.
    """
    if return_metadata and substrate is None:
        from utils.substrate_settings import list_substrates
        raise ValueError("substrate must be passed when return_metadata=True. Options: " +
                         ", ".join(list_substrates()))

    paired_images = []
    numerical_data = []
    metadata = []
    for subfolder in os.listdir(root_folder):
        subfolder_path = os.path.join(root_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        bf_image_path = hm_image_path = txt_file_path = None
        for filename in os.listdir(subfolder_path):
            if filename.endswith("BF_001.tif"):
                bf_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith("_gray.jpg"):
                hm_image_path = os.path.join(subfolder_path, filename)
            elif filename.endswith(".txt"):
                txt_file_path = os.path.join(subfolder_path, filename)

        # Skip incomplete subfolders; the txt file is mandatory only when
        # numerical labels were requested.
        if not (bf_image_path and hm_image_path):
            continue
        if load_numerical_data and not txt_file_path:
            continue

        paired_images.append((bf_image_path, hm_image_path))
        if load_numerical_data:
            numerical_data.append(load_text_data(txt_file_path))
        # BUG FIX: record metadata only for subfolders that actually yield a
        # sample. The old code appended metadata for EVERY subfolder before
        # the existence checks, misaligning the final zip with paired_images
        # whenever any subfolder was incomplete.
        if return_metadata:
            metadata.append({'folder_name': subfolder, 'substrate': substrate,
                             'root_folder': root_folder})

    # Image loading is I/O bound, so a thread pool parallelizes it cleanly.
    with ThreadPoolExecutor() as executor:
        bf_loaded = list(executor.map(lambda p: load_image(p[0], target_size), paired_images))
        hm_loaded = list(executor.map(lambda p: load_image(p[1], target_size), paired_images))

    if load_force_sum and not load_numerical_data:
        # Derive force sums from the already-loaded heatmaps instead of
        # reading every heatmap file a second time as the old code did.
        numerical_data = [(0, float(np.sum(hm)) * config.SCALE_FACTOR_FORCE)
                          for hm in hm_loaded]
    if not numerical_data:
        numerical_data = [(0, 0)] * len(bf_loaded)

    if return_metadata:
        return list(zip(bf_loaded, hm_loaded, numerical_data, metadata))
    return list(zip(bf_loaded, hm_loaded, numerical_data))
|
|
|
|
def prepare_data(input_folder, batch_size=8, target_size=(1024, 1024), split_size=0.2,
                 use_augmentations=True, train_test_sep_folder=True, channel_first=True,
                 load_numerical_data=False, load_force_sum=False, blur_heatmap=False,
                 threshold=0.0, return_metadata=False, substrate=None):
    """Build train/validation DataLoaders from ``input_folder``.

    With ``train_test_sep_folder`` the folder must contain ``train`` and
    ``test`` subfolders; otherwise a single pool of samples is split with a
    fixed random seed. Augmentations are applied to the training set only.
    """
    if load_numerical_data and load_force_sum:
        raise ValueError("load_numerical_data and load_force_sum cannot be True at the same time")

    def _load(folder):
        # One shared loader so both layouts use identical options.
        return load_images_from_subfolders(folder, target_size=target_size,
                                           load_numerical_data=load_numerical_data,
                                           load_force_sum=load_force_sum,
                                           return_metadata=return_metadata,
                                           substrate=substrate)

    if train_test_sep_folder:
        train_folder = os.path.join(input_folder, 'train')
        test_folder = os.path.join(input_folder, 'test')
        if not (os.path.exists(train_folder) and os.path.exists(test_folder)):
            raise ValueError(f"train/test folders not found in {input_folder}")
        train_pairs = _load(train_folder)
        val_pairs = _load(test_folder)
    else:
        # Fixed seed keeps the split reproducible across runs.
        train_pairs, val_pairs = train_test_split(_load(input_folder),
                                                  test_size=split_size, random_state=42)

    train_transform = None
    if use_augmentations:
        from .augmentations import AdvancedAugmentations
        train_transform = AdvancedAugmentations(target_size)

    dataset_name = os.path.basename(input_folder)
    train_dataset = ImageDataset(train_pairs, transform=train_transform,
                                 channel_first=channel_first, blur_heatmap=blur_heatmap,
                                 threshold=threshold, return_metadata=return_metadata)
    train_dataset.name = dataset_name
    val_dataset = ImageDataset(val_pairs, channel_first=channel_first,
                               blur_heatmap=blur_heatmap, threshold=threshold,
                               return_metadata=return_metadata)
    val_dataset.name = dataset_name
    return (DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
            DataLoader(val_dataset, batch_size=batch_size, shuffle=False))
|
|
|
|
def load_folder_data(folder_path, substrate=None, img_size=1024, blur_heatmap=False,
                     batch_size=2, threshold=0.0, return_metadata=False):
    """Build a single evaluation DataLoader over every sample in ``folder_path``.

    No labels are read from txt files and no force sums are derived; the
    loader is unshuffled, suitable for deterministic evaluation.
    """
    pairs = load_images_from_subfolders(folder_path, target_size=img_size,
                                        load_numerical_data=False, load_force_sum=False,
                                        return_metadata=return_metadata, substrate=substrate)
    dataset = ImageDataset(pairs, channel_first=True, blur_heatmap=blur_heatmap,
                           threshold=threshold, return_metadata=return_metadata)
    dataset.name = os.path.basename(folder_path)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|
|
|
def collect_image_paths(folder_path, exts=None):
    """Recursively collect image file paths under ``folder_path``.

    Args:
        folder_path: Root folder to walk.
        exts: Set of accepted lowercase extensions (with leading dot);
            defaults to common TIFF/JPEG/PNG extensions.

    Returns:
        Sorted list of matching file paths.
    """
    accepted = {".tif", ".tiff", ".jpg", ".jpeg", ".png"} if exts is None else exts
    matches = [
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(os.path.normpath(folder_path))
        for name in names
        if os.path.splitext(name)[1].lower() in accepted
    ]
    return sorted(matches)
|
|
|
|
class BrightfieldOnlyDataset(Dataset):
    """Dataset of brightfield images only (no labels), for inference."""

    def __init__(self, folder_path, target_size=1024):
        # Collect every image path under the folder up front.
        self.paths = collect_image_paths(folder_path)
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        self.target_size = target_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # Yields a (1, H, W) float tensor for the i-th image.
        arr = load_image(self.paths[i], self.target_size)
        return torch.from_numpy(arr).float().unsqueeze(0)
|
|
|
|
def load_brightfield_loader(folder_path, img_size=1024, batch_size=2):
    """Return an unshuffled DataLoader over label-free brightfield images."""
    dataset = BrightfieldOnlyDataset(folder_path, target_size=img_size)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
|
|