| import os |
| import random |
| import pandas as pd |
| import cv2 |
| import torch |
| import torch.nn.utils |
| import torch.nn.functional as F |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import matplotlib.colors as mcolors |
| from sklearn.model_selection import train_test_split |
| |
| from sam2.build_sam import build_sam2 |
| from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
def set_seeds(seed=42):
    """Seed every RNG the pipeline touches (python, numpy, torch, CUDA).

    Args:
        seed: Seed value applied to all generators (default 42, matching
            the `random_state` used for the train/test split below).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Reproducibility requires deterministic kernels AND disabling the
        # cuDNN autotuner: benchmark=True picks algorithms nondeterministically,
        # which defeats the deterministic=True setting above.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seeds()
|
|
# Dataset layout: images and masks live in sibling folders, paired via train.csv.
data_dir = "./sam2-data"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")

train_df = pd.read_csv(os.path.join(data_dir, "train.csv"))

# Hold out 10% of the rows for validation/inference demos.
train_df, test_df = train_test_split(train_df, test_size=0.1, random_state=42)

# Each entry pairs an image path with its annotation-mask path.
train_data = [
    {
        "image": os.path.join(images_dir, row["imageid"]),
        "annotation": os.path.join(masks_dir, row["maskid"]),
    }
    for _, row in train_df.iterrows()
]

test_data = [
    {
        "image": os.path.join(images_dir, row["imageid"]),
        "annotation": os.path.join(masks_dir, row["maskid"]),
    }
    for _, row in test_df.iterrows()
]
|
|
def read_batch(data, visualize_data=True):
    """Sample one (image, mask, prompt-points) training example.

    Picks a random entry, resizes image and annotation so the long side is
    <= 1024 px, binarizes the instance mask, and samples one point per
    instance id from the eroded foreground.

    Args:
        data: List of dicts with "image" and "annotation" file paths.
        visualize_data: If True, show the image/mask/points with matplotlib.

    Returns:
        (image_rgb, binary_mask[1,H,W], points[N,1,2] in (x, y), num_instances)
        or (None, None, None, 0) when either file cannot be read.
    """
    ent = data[np.random.randint(len(data))]
    # Read BEFORE converting channels: cv2.imread returns None on a bad path,
    # and indexing None with [..., ::-1] would raise before the guard below.
    Img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)

    if Img is None or ann_map is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    Img = Img[..., ::-1]  # BGR -> RGB

    # Uniform scale so max(H, W) <= 1024; nearest-neighbor keeps mask labels intact.
    r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])
    Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
                         interpolation=cv2.INTER_NEAREST)

    # Collapse all instance ids (skipping background, the first unique value)
    # into a single foreground mask.
    binary_mask = np.zeros_like(ann_map, dtype=np.uint8)
    points = []
    inds = np.unique(ann_map)[1:]
    for ind in inds:
        mask = (ann_map == ind).astype(np.uint8)
        binary_mask = np.maximum(binary_mask, mask)

    # Erode so sampled points sit safely inside the foreground, not on edges.
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)
    coords = np.argwhere(eroded_mask > 0)
    if len(coords) > 0:
        # One random interior point per instance id; stored as (x, y).
        for _ in inds:
            yx = np.array(coords[np.random.randint(len(coords))])
            points.append([yx[1], yx[0]])
    points = np.array(points)

    if visualize_data:
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(Img)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100)
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    # Shape to what the SAM2 predictor expects: mask (1,H,W), points (N,1,2).
    binary_mask = np.expand_dims(binary_mask, axis=-1)
    binary_mask = binary_mask.transpose((2, 0, 1))
    points = np.expand_dims(points, axis=1)
    return Img, binary_mask, points, len(inds)
| |
| |
| def _to_hydra_name(x): |
| if not x: |
| return None |
| s = str(x).replace("\\", "/") |
| if s.endswith(".yaml"): |
| s = s[:-5] |
| |
| |
| |
| if "/sam2/configs/" in s: |
| return s.split("/sam2/")[1] |
| if s.startswith("sam2/configs/"): |
| return s[len("sam2/"):] |
| return s |
|
|
# Build the base SAM2 model and wrap it in the image predictor.
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = _to_hydra_name("./sam2/configs/sam2.1/sam2.1_hiera_l.yaml")

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Fine-tune only the lightweight heads: mask decoder and prompt encoder.
for module in (predictor.model.sam_mask_decoder, predictor.model.sam_prompt_encoder):
    module.train(True)
|
|
# Mixed-precision loss scaling plus training hyperparameters.
scaler = torch.amp.GradScaler()
NO_OF_STEPS = 1200
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2"

optimizer = torch.optim.AdamW(
    predictor.model.parameters(), lr=5e-5, weight_decay=1e-4
)

# Decay LR by 0.6 every 1000 optimizer steps; accumulate 8 micro-batches.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.6)
accumulation_steps = 8
|
|
def train(predictor, train_data, step, mean_iou):
    """Run one training micro-step and return the updated IoU running mean.

    Samples one example, runs the SAM2 prompt-encoder/mask-decoder under
    autocast, accumulates gradients, and steps the optimizer every
    `accumulation_steps` calls. Uses module-level `scaler`, `optimizer`,
    `scheduler`, and `accumulation_steps`.

    Args:
        predictor: SAM2ImagePredictor whose decoder/encoder are being tuned.
        train_data: List of {"image", "annotation"} path dicts.
        step: 1-based global step counter (drives accumulation and logging).
        mean_iou: Running EMA of IoU; reset to 0.0 if None/NaN.

    Returns:
        Updated mean_iou (unchanged when the sampled example is unusable).
    """
    # NaN != NaN — reset a poisoned running mean.
    if mean_iou is None or (isinstance(mean_iou, float) and (mean_iou != mean_iou)):
        mean_iou = 0.0

    eps = 1e-6

    predictor.model.train()
    # Forward pass and loss under autocast; backward/step happen outside,
    # as recommended by the PyTorch AMP docs.
    with torch.amp.autocast(device_type='cuda'):
        image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)

        if image is None or mask is None or num_masks == 0:
            return mean_iou

        # All prompt points are positive (label 1), one per instance.
        input_label = np.ones((num_masks, 1), dtype=np.int64)

        if not isinstance(input_point, np.ndarray) or not isinstance(input_label, np.ndarray):
            return mean_iou
        if input_point.size == 0 or input_label.size == 0:
            return mean_iou

        predictor.set_image(image)
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )
        if (
            unnorm_coords is None or labels is None or
            unnorm_coords.shape[0] == 0 or labels.shape[0] == 0
        ):
            return mean_iou

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        # repeat_image is needed when several prompts share one embedded image.
        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]

        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        # Upscale decoder logits back to the original image resolution.
        prd_masks = predictor._transforms.postprocess_masks(
            low_res_masks, predictor._orig_hw[-1]
        )

        gt_mask = torch.tensor(mask.astype(np.float32), device='cuda')
        prd_mask = torch.sigmoid(prd_masks[:, 0])

        # Binary cross-entropy on the first mask hypothesis.
        seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                    - (1 - gt_mask) * torch.log((1 - prd_mask) + eps)).mean()

        # IoU of the thresholded prediction, used both as a metric and as the
        # regression target for the decoder's predicted scores.
        pred_bin = (prd_mask > 0.5).float()
        inter = (gt_mask * pred_bin).sum(dim=(1, 2))
        denom = gt_mask.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
        iou = inter / (denom + eps)

        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + 0.05 * score_loss

    # Scale the loss so accumulated gradients average over the micro-batches.
    loss = loss / accumulation_steps
    scaler.scale(loss).backward()

    did_optimizer_step = False
    if step % accumulation_steps == 0:
        # Gradients are scaled by the GradScaler; they must be unscaled
        # before clipping, otherwise max_norm applies to the scaled values
        # and the clip threshold is effectively meaningless.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        did_optimizer_step = True

    # Advance the LR schedule once per real optimizer step, not per micro-step.
    if did_optimizer_step:
        scheduler.step()

    # Update the running IoU EMA (decay 0.99), sanitizing any NaN/inf.
    iou_np = iou.detach().float().cpu().numpy()
    iou_np = np.nan_to_num(iou_np, nan=0.0, posinf=1.0, neginf=0.0)
    mean_iou = float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))

    if step % 100 == 0:
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"Step {step}: LR={current_lr:.6f} IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")

    return mean_iou
|
|
def validate(predictor, test_data, step, mean_iou):
    """Run one no-grad validation step and return the updated IoU running mean.

    Mirrors `train` but in eval mode with gradients disabled; also saves a
    checkpoint every 100 steps. Uses module-level `optimizer` (for logging)
    and `FINE_TUNED_MODEL_NAME`.

    Args:
        predictor: SAM2ImagePredictor being evaluated.
        test_data: List of {"image", "annotation"} path dicts.
        step: 1-based global step counter (drives checkpointing and logging).
        mean_iou: Running EMA of validation IoU; reset to 0.0 if None/NaN.

    Returns:
        Updated mean_iou (unchanged when the sampled example is unusable).
    """
    # NaN != NaN — reset a poisoned running mean.
    if mean_iou is None or (isinstance(mean_iou, float) and (mean_iou != mean_iou)):
        mean_iou = 0.0

    predictor.model.eval()
    with torch.amp.autocast(device_type='cuda'):
        with torch.no_grad():
            image, mask, input_point, num_masks = read_batch(test_data, visualize_data=False)

            if image is None or mask is None or num_masks == 0:
                return mean_iou

            # All prompt points are positive (label 1), one per instance.
            input_label = np.ones((num_masks, 1), dtype=np.int64)

            if not isinstance(input_point, np.ndarray) or not isinstance(input_label, np.ndarray):
                return mean_iou
            if input_point.size == 0 or input_label.size == 0:
                return mean_iou

            predictor.set_image(image)
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )

            if (
                unnorm_coords is None or labels is None or
                unnorm_coords.shape[0] == 0 or labels.shape[0] == 0
            ):
                return mean_iou

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None
            )

            # repeat_image is needed when several prompts share one embedded image.
            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [lvl[-1].unsqueeze(0) for lvl in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )

            # Upscale decoder logits back to the original image resolution.
            prd_masks = predictor._transforms.postprocess_masks(
                low_res_masks, predictor._orig_hw[-1]
            )

            gt_mask = torch.tensor(mask.astype(np.float32), device='cuda')
            prd_mask = torch.sigmoid(prd_masks[:, 0])

            # BCE on the first mask hypothesis (logged only; no backprop here).
            eps = 1e-6
            seg_loss = (-gt_mask * torch.log(prd_mask + eps)
                        - (1 - gt_mask) * torch.log((1 - prd_mask) + eps)).mean()

            # IoU of the thresholded prediction.
            pred_bin = (prd_mask > 0.5).float()
            inter = (gt_mask * pred_bin).sum(dim=(1, 2))
            denom = gt_mask.sum(dim=(1, 2)) + pred_bin.sum(dim=(1, 2)) - inter
            iou = inter / (denom + eps)
            # NOTE: score_loss / total loss are not computed here — they were
            # never used during validation.

            if step % 100 == 0:
                # Ensure the checkpoint directory exists; torch.save does not
                # create missing parent directories.
                os.makedirs("./checkpoints-ft", exist_ok=True)
                torch.save(predictor.model.state_dict(), f"./checkpoints-ft/{FINE_TUNED_MODEL_NAME}_{step}.pt")

            # Update the running IoU EMA (decay 0.99), sanitizing any NaN/inf.
            iou_np = iou.detach().float().cpu().numpy()
            iou_np = np.nan_to_num(iou_np, nan=0.0, posinf=1.0, neginf=0.0)
            mean_iou = float(mean_iou * 0.99 + 0.01 * float(np.mean(iou_np)))

            if step % 100 == 0:
                current_lr = optimizer.param_groups[0]["lr"]
                print(f"Step {step}: LR={current_lr:.6f} Valid_IoU={mean_iou:.6f} SegLoss={seg_loss.item():.6f}")

            return mean_iou
|
|
# Running EMA accumulators consumed/returned by train() and validate().
train_mean_iou = valid_mean_iou = 0
| |
| |
| |
| |
|
|
def read_image(image_path, mask_path):
    """Load an RGB image and its grayscale mask, resized to long side <= 1024.

    Args:
        image_path: Path to the RGB image file.
        mask_path: Path to the grayscale annotation mask.

    Returns:
        (image_rgb, mask) as numpy arrays, scaled by the same factor.

    Raises:
        FileNotFoundError: If either file cannot be read by OpenCV.
    """
    # Check for read failure before any indexing: cv2.imread returns None on
    # a bad path, and `None[..., ::-1]` would raise a cryptic TypeError.
    img = cv2.imread(image_path)
    mask = cv2.imread(mask_path, 0)
    if img is None or mask is None:
        raise FileNotFoundError(f"Could not read image '{image_path}' or mask '{mask_path}'")

    img = img[..., ::-1]  # BGR -> RGB

    # Uniform scale; nearest-neighbor preserves discrete mask labels.
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask
| |
def get_points(mask, num_points):
    """Sample `num_points` random foreground coordinates from a binary mask.

    Args:
        mask: 2D array; pixels > 0 are foreground.
        num_points: Number of points to sample (with replacement).

    Returns:
        Array of shape (num_points, 1, 2) with (x, y) coordinates, matching
        the per-prompt layout SAM2's predict() expects.

    Raises:
        ValueError: If the mask has no foreground pixels (previously this
            surfaced as a cryptic `np.random.randint` "low >= high" error).
    """
    coords = np.argwhere(mask > 0)
    if len(coords) == 0:
        raise ValueError("get_points: mask contains no foreground pixels to sample from")
    points = []
    for _ in range(num_points):
        yx = np.array(coords[np.random.randint(len(coords))])
        # argwhere yields (row, col) = (y, x); SAM2 wants (x, y).
        points.append([[yx[1], yx[0]]])
    return np.array(points)
|
|
# Build and load the fine-tuned predictor ONCE: the original rebuilt the full
# SAM2 model and reloaded the checkpoint on every loop iteration, which is
# pure overhead and produces the identical predictor each time.
FINE_TUNED_MODEL_WEIGHTS = "./checkpoints-ft/fine_tuned_sam2_1200.pt"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(FINE_TUNED_MODEL_WEIGHTS))

# Visualize predictions on three random held-out samples.
for n in range(3):
    selected_entry = random.choice(test_data)
    print(selected_entry)
    image_path = selected_entry['image']
    mask_path = selected_entry['annotation']
    print(mask_path,'mask path')

    image, target_mask = read_image(image_path, mask_path)

    # Prompt the model with 30 random foreground points.
    num_samples = 30
    input_points = get_points(target_mask, num_samples)

    with torch.no_grad():
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0], 1])
        )

    # Keep only the first mask hypothesis per prompt, best-scoring first.
    np_masks = np.array(masks[:, 0])
    np_scores = scores[:, 0]
    sorted_masks = np_masks[np.argsort(np_scores)][::-1]

    # Greedily paint masks into a label map, skipping heavily-overlapping ones.
    seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
    occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

    for i in range(sorted_masks.shape[0]):
        mask = sorted_masks[i]
        mask_area = mask.sum()
        # Guard mask_area == 0: an empty mask previously caused a 0/0 division.
        if mask_area == 0 or (mask * occupancy_mask).sum() / mask_area > 0.15:
            continue

        mask_bool = mask.astype(bool)
        mask_bool[occupancy_mask] = False  # only claim unoccupied pixels
        seg_map[mask_bool] = i + 1
        occupancy_mask[mask_bool] = True

    plt.figure(figsize=(18, 6))

    plt.subplot(1, 3, 1)
    plt.title('Test Image')
    plt.imshow(image)
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.title('Original Mask')
    plt.imshow(target_mask, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.title('Final Segmentation')
    plt.imshow(seg_map, cmap='jet')
    plt.axis('off')

    plt.tight_layout()
    plt.show()