| |
| |
| |
| |
|
|
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| from skimage import filters |
| import matplotlib.pyplot as plt |
| from scipy.ndimage import maximum_filter, label, find_objects |
|
|
def dilate_mask(latents_mask, k, latents_dtype):
    """Morphologically dilate a flattened square latent mask by ``k`` pixels.

    Args:
        latents_mask: Flattened binary mask with a square number of elements
            (e.g. 4096 for a 64x64 latent grid). Any shape with ``side**2``
            elements is accepted.
        k: Dilation radius in latent pixels; each foreground pixel grows into
            a ``(2k+1) x (2k+1)`` square.
        latents_dtype: dtype of the returned mask (matched to the latents it
            will be multiplied with).

    Returns:
        Dilated binary mask of shape ``(side**2, 1)`` cast to ``latents_dtype``.
    """
    # Infer the grid side instead of hard-coding 64x64 so any square
    # latent resolution works; 4096-element masks behave exactly as before.
    numel = latents_mask.numel()
    side = int(round(numel ** 0.5))
    mask_2d = latents_mask.view(side, side)

    # Dilation as a convolution with an all-ones kernel: any overlap with a
    # foreground pixel produces a positive response.
    kernel = torch.ones(2 * k + 1, 2 * k + 1, device=mask_2d.device, dtype=mask_2d.dtype)

    mask_4d = mask_2d.unsqueeze(0).unsqueeze(0)
    dilated_mask = F.conv2d(mask_4d, kernel.unsqueeze(0).unsqueeze(0), padding=k)

    # Binarize: any positive response means at least one neighbor was set.
    dilated_mask = (dilated_mask > 0).to(mask_2d.dtype)

    # Flatten back to a column vector matching the latents layout.
    dilated_mask = dilated_mask.view(numel, 1).to(latents_dtype)

    return dilated_mask
|
|
def clipseg_predict(model, processor, image, text, device):
    """Segment ``text`` in ``image`` with CLIPSeg, binarized by Otsu's method.

    Args:
        model: CLIPSeg segmentation model.
        processor: Matching CLIPSeg processor.
        image: Input image (PIL).
        text: Text prompt(s) describing the subject.
        device: Device to run inference on.

    Returns:
        Float tensor mask (0.0 / 1.0) at the model's output resolution.
    """
    batch = processor(text=text, images=image, return_tensors="pt")
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        logits = model(**batch).logits

    # Sigmoid turns logits into per-pixel probabilities; Otsu picks a
    # data-driven threshold separating subject from background.
    probs = torch.sigmoid(logits.unsqueeze(1))
    otsu_threshold = filters.threshold_otsu(probs.cpu().numpy())

    return (probs > otsu_threshold).float()
|
|
def grounding_sam_predict(model, processor, sam_predictor, image, text, device):
    """Ground ``text`` to a box with Grounding DINO, then segment it with SAM.

    Args:
        model: Grounding DINO detection model.
        processor: Matching Grounding DINO processor.
        sam_predictor: SAM predictor with ``set_image`` / ``predict``.
        image: Input image (PIL).
        text: Text query for the detector.
        device: Device for inference and the returned mask.

    Returns:
        Boolean/float mask tensor for the first detected box, or an all-ones
        64x64 tensor when nothing is detected.
        NOTE(review): the fallback is latent-resolution (64x64) while the SAM
        path returns an image-resolution mask — callers presumably resize;
        confirm before relying on the shape.
    """
    encoded = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        detections = model(**encoded)

    # image.size is (W, H); detector post-processing wants (H, W).
    processed = processor.post_process_grounded_object_detection(
        detections,
        encoded.input_ids,
        box_threshold=0.4,
        text_threshold=0.3,
        target_sizes=[image.size[::-1]],
    )

    boxes = processed[0]["boxes"].cpu().numpy()

    # No detection above threshold: fall back to a fully-on latent mask.
    if len(boxes) == 0:
        return torch.ones((64, 64), device=device)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, _scores, _logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=boxes,
            multimask_output=False,
        )

    return torch.tensor(masks[0], device=device)
|
|
def mask_to_box_sam_predict(mask, sam_predictor, image, text, device):
    """Convert a coarse mask to its bounding box and refine with SAM.

    Args:
        mask: Coarse 2D mask tensor (any resolution; must contain at least
            one nonzero element).
        sam_predictor: SAM predictor with ``set_image`` / ``predict``.
        image: Input image (PIL).
        text: Unused; kept for signature parity with sibling predictors.
        device: Device for the returned mask.

    Returns:
        Tuple of (refined mask tensor, box as ``[x1, y1, x2, y2]``).
    """
    # BUG FIX: PIL's Image.size is (width, height). The original
    # `H, W = image.size` transposed the axes, so the mask was resized to
    # (width, height) and box coordinates were wrong for non-square images.
    W, H = image.size

    # Upsample the coarse mask to full image resolution (rows=H, cols=W).
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    mask_indices = torch.nonzero(mask)  # rows of (y, x) indices
    top_left = mask_indices.min(dim=0)[0]
    bottom_right = mask_indices.max(dim=0)[0]

    # SAM expects boxes as [x1, y1, x2, y2] in pixel coordinates.
    input_boxes = np.array([[top_left[1].item(), top_left[0].item(), bottom_right[1].item(), bottom_right[0].item()]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )

    # Union of all candidate masks: a pixel is foreground if any mask says so.
    subject_mask = torch.tensor(np.max(masks, axis=0), device=device)

    return subject_mask, input_boxes[0]
|
|
def mask_to_mask_sam_predict(mask, sam_predictor, image, text, device):
    """Refine a coarse mask by feeding it to SAM as a low-res mask prompt.

    Args:
        mask: Coarse 2D mask tensor at any resolution.
        sam_predictor: SAM predictor with ``set_image`` / ``predict``.
        image: Input image (PIL).
        text: Unused; kept for signature parity with sibling predictors.
        device: Device for the returned mask.

    Returns:
        Refined mask tensor for the single best SAM output.
    """
    # SAM's mask prompt interface expects a (1, 256, 256) low-res mask.
    low_res = (256, 256)
    resized = F.interpolate(
        mask.view(1, 1, mask.shape[-2], mask.shape[-1]),
        size=low_res,
        mode='bilinear',
    ).view(1, low_res[0], low_res[1])
    prompt = resized.float().cpu().numpy()

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, _scores, _logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            mask_input=prompt,
            multimask_output=False,
        )

    return torch.tensor(masks[0], device=device)
|
|
def mask_to_points_sam_predict(mask, sam_predictor, image, text, device):
    """Sample random foreground points from a coarse mask and prompt SAM.

    Args:
        mask: Coarse 2D mask tensor at any resolution.
        sam_predictor: SAM predictor with ``set_image`` / ``predict``.
        image: Input image (PIL).
        text: Unused; kept for signature parity with sibling predictors.
        device: Device for the returned mask.

    Returns:
        Refined mask tensor for the single best SAM output.
    """
    # BUG FIX: PIL's Image.size is (width, height); the original
    # `H, W = image.size` transposed axes for non-square images.
    W, H = image.size

    # Upsample the coarse mask to full image resolution (rows=H, cols=W).
    mask = F.interpolate(mask.view(1, 1, mask.shape[-2], mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    mask_indices = torch.nonzero(mask)  # rows of (y, x) indices

    # Sample up to n_points random foreground pixels as positive prompts.
    n_points = 2
    chosen = mask_indices[torch.randperm(mask_indices.shape[0])[:n_points]]
    # BUG FIX: torch.nonzero yields (y, x) pairs but SAM's point_coords are
    # (x, y); flip the last axis so the points land on the subject.
    point_coords = chosen.flip(-1).float().cpu().numpy()
    # BUG FIX: label count must match the points actually sampled (the mask
    # may contain fewer than n_points foreground pixels).
    point_labels = np.ones(point_coords.shape[0], dtype=np.float32)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)

    return subject_mask
|
|
def attention_to_points_sam_predict(subject_attention, subject_mask, sam_predictor, image, text, device):
    """Pick spaced attention peaks inside the subject and prompt SAM with them.

    Greedily selects up to ``n_points`` local attention maxima, suppressing a
    window around each pick so points are spread across the subject's
    bounding box, then uses them as positive point prompts for SAM.

    Args:
        subject_attention: 2D attention map over the subject (any resolution).
        subject_mask: 2D binary mask of the subject (any resolution; must
            contain at least one nonzero element).
        sam_predictor: SAM predictor with ``set_image`` / ``predict``.
        image: Input image (PIL).
        text: Unused; kept for signature parity with sibling predictors.
        device: Device for the returned mask.

    Returns:
        Tuple of (refined mask tensor, ``(x, y)`` point array used as prompts).
    """
    # BUG FIX: PIL's Image.size is (width, height). The original
    # `H, W = image.size` transposed the axes, which also corrupted the
    # suppression-window clamps below for non-square images.
    W, H = image.size

    # Upsample both maps to full image resolution (rows=H, cols=W).
    subject_attention = F.interpolate(subject_attention.view(1, 1, subject_attention.shape[-2], subject_attention.shape[-1]), size=(H, W), mode='bilinear').view(H, W)
    subject_mask = F.interpolate(subject_mask.view(1, 1, subject_mask.shape[-2], subject_mask.shape[-1]), size=(H, W), mode='bilinear').view(H, W)

    # Bounding box of the subject mask; nonzero rows are (y, x) indices.
    subject_mask_indices = torch.nonzero(subject_mask)
    top_left = subject_mask_indices.min(dim=0)[0]
    bottom_right = subject_mask_indices.max(dim=0)[0]
    box_width = bottom_right[1] - top_left[1]
    box_height = bottom_right[0] - top_left[0]

    # Stop picking peaks once they drop below 35% of the global maximum;
    # enforce a minimum spacing derived from the box size.
    n_points = 3
    max_thr = 0.35
    max_attention = torch.max(subject_attention)
    min_distance = max(box_width, box_height) // (n_points + 1)

    selected_points = []

    # Work on a copy so suppression does not clobber the caller's attention.
    remaining_attention = subject_attention.clone()

    for _ in range(n_points):
        if remaining_attention.max() < max_thr * max_attention:
            break

        # Current strongest remaining attention location.
        point = torch.argmax(remaining_attention)
        y, x = torch.unravel_index(point, remaining_attention.shape)
        y, x = y.item(), x.item()

        selected_points.append((x, y))  # SAM expects (x, y) order

        # Zero a window around the pick to enforce the minimum spacing.
        y_min = max(0, y - min_distance)
        y_max = min(H, y + min_distance + 1)
        x_min = max(0, x - min_distance)
        x_max = min(W, x + min_distance + 1)
        remaining_attention[y_min:y_max, x_min:x_max] = 0

    point_coords = np.array(selected_points)
    point_labels = np.ones(point_coords.shape[0], dtype=int)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, scores, logits = sam_predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels,
            multimask_output=False,
        )

    subject_mask = torch.tensor(masks[0], device=device)

    return subject_mask, point_coords
|
|
def sam_refine_step(mask, sam_predictor, image, device):
    """Refine ``mask`` by boxing its extent and re-running SAM on the box.

    Args:
        mask: 2D mask tensor at image resolution with at least one nonzero
            element.
        sam_predictor: SAM predictor with ``set_image`` / ``predict``.
        image: Input image (PIL).
        device: Device for the returned mask.

    Returns:
        Tuple of (refined mask tensor, box as ``[x1, y1, x2, y2]``).
    """
    # Tight bounding box of the current mask; nonzero rows are (y, x).
    foreground = torch.nonzero(mask)
    y_min, x_min = foreground.min(dim=0)[0].tolist()
    y_max, x_max = foreground.max(dim=0)[0].tolist()

    # SAM box prompt format: [x1, y1, x2, y2].
    input_boxes = np.array([[x_min, y_min, x_max, y_max]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        sam_predictor.set_image(image)
        masks, _scores, _logits = sam_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=True,
        )

    # Union of the candidate masks across the multimask outputs.
    refined = torch.tensor(np.max(masks, axis=0), device=device)

    return refined, input_boxes[0]
|
|
|
|