| from itertools import product |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| try: |
| import cupy as cp |
| from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix, eye as cp_eye, diags as cp_diags |
| from cupyx.scipy.sparse import linalg as cp_s_linalg |
| except ImportError: |
| print("Cupy not installed") |
| import numpy as np |
| from scipy.sparse import csr_matrix, eye, diags |
| from scipy.sparse import linalg as s_linalg |
| from kornia.color import rgb_to_lab |
|
|
|
|
def make_input_divisible(x: torch.Tensor, patch_size=16) -> torch.Tensor:
    """Zero-pad the bottom and right edges so H and W are multiples of ``patch_size``.

    Args:
        x: Tensor whose last two dimensions are height and width, e.g.
            ``(B, C, H, W)`` or ``(C, H, W)``.
        patch_size: Positive integer the spatial size must be divisible by.

    Returns:
        ``x`` padded with zeros on the bottom/right; the original content
        stays in the top-left corner. Sizes that are already divisible get
        zero padding.
    """
    # Only the spatial dims matter, so read them from the end instead of
    # unpacking a fixed 4-D shape (which also bound an unused batch size).
    height, width = x.shape[-2:]
    pad_h = -height % patch_size
    pad_w = -width % patch_size

    # F.pad lists the pads from the last dimension backwards:
    # (left, right, top, bottom).
    return F.pad(x, (0, pad_w, 0, pad_h), value=0)
|
|
|
|
def reshape_windows(x):
    """Flatten a list of per-window feature maps into one row-per-position matrix.

    Args:
        x: Non-empty sequence of tensors; the first two dimensions of each
            are recorded as its (height, width) and the last dimension is the
            feature size, taken from the first element.

    Returns:
        A tuple ``(features, height_width)`` where ``features`` stacks every
        window's rows in list order (shape ``(total_rows, dim)``) and
        ``height_width`` lists the ``(height, width)`` of each window.
    """
    feature_dim = x[0].shape[-1]

    window_sizes = []
    flattened = []
    for window in x:
        window_sizes.append((window.shape[0], window.shape[1]))
        flattened.append(window.reshape(-1, feature_dim))

    return torch.cat(flattened, dim=0), window_sizes
|
|
|
|
def normalize_connection_graph_cupy(G):
    """Symmetrically normalise an affinity graph on the GPU (CuPy backend).

    Self-loops are removed, then the matrix is scaled on both sides by the
    inverse square root of the row sums.

    Args:
        G: Anything ``cupyx.scipy.sparse.csr_matrix`` accepts as an (N, N)
            affinity matrix.

    Returns:
        The normalised sparse matrix.
    """
    adjacency = cp_csr_matrix(G)
    # Drop the diagonal so a node never counts itself as a neighbour.
    adjacency = adjacency - cp_diags(adjacency.diagonal(), 0)

    degree = adjacency.sum(axis=1)
    # Isolated nodes would divide by zero; their rows are all-zero anyway.
    degree[degree == 0] = 1

    inv_sqrt_degree = cp.array(1.0 / cp.sqrt(degree))
    # NaN (e.g. from a negative row sum) and inf entries are zeroed out.
    inv_sqrt_degree[~cp.isfinite(inv_sqrt_degree)] = 0

    scaling = cp_diags(inv_sqrt_degree.reshape(-1), 0)
    # NOTE(review): `*` is kept from the original; presumably matrix product
    # for cupyx sparse matrices, as in SciPy's spmatrix — confirm.
    return scaling * adjacency * scaling
|
|
|
|
def normalize_connection_graph(G):
    """Symmetrically normalise an affinity graph (SciPy / CPU backend).

    Self-loops are removed, then the matrix is scaled on both sides by the
    inverse square root of the row sums: ``D^-1/2 (G - diag(G)) D^-1/2``.

    Args:
        G: Anything ``scipy.sparse.csr_matrix`` accepts as an (N, N)
            affinity matrix (dense array or sparse matrix).

    Returns:
        The normalised sparse matrix. Rows of isolated nodes stay all-zero.
    """
    adjacency = csr_matrix(G)
    # Drop the diagonal so a node never counts itself as a neighbour.
    adjacency = adjacency - diags(adjacency.diagonal(), 0)

    # Flatten explicitly: the row sum comes back as an (N, 1) np.matrix for
    # sparse matrices but as a 1-D array for sparse arrays.
    degree = np.asarray(adjacency.sum(axis=1)).ravel()
    # Isolated nodes would divide by zero; their rows are all-zero anyway.
    degree[degree == 0] = 1

    # Negative row sums produce NaN here; they are zeroed right below, so
    # the floating-point warning would only be noise.
    with np.errstate(divide="ignore", invalid="ignore"):
        inv_sqrt_degree = 1.0 / np.sqrt(degree)
    inv_sqrt_degree[~np.isfinite(inv_sqrt_degree)] = 0

    scaling = diags(inv_sqrt_degree, 0)
    # `@` is always the matrix product, whereas `*` is element-wise for the
    # newer scipy sparse-array classes.
    return scaling @ adjacency @ scaling
|
|
|
|
def cp_dfs_search(L, Y, tol=1e-6, maxiter=10):
    """Approximately solve ``L x = Y`` with CuPy's conjugate-gradient solver.

    Args:
        L: System matrix passed to ``cupyx.scipy.sparse.linalg.cg``.
        Y: Right-hand-side vector.
        tol: Tolerance forwarded to the solver as ``tol``.
        maxiter: Maximum number of CG iterations.

    Returns:
        The solver's solution estimate; the remaining return values of
        ``cg`` (convergence status) are discarded.
    """
    solution, *_ = cp_s_linalg.cg(L, Y, tol=tol, maxiter=maxiter)
    return solution
|
|
|
|
| def dfs_search(L, Y, tol=1e-6, maxiter=10): |
| out = s_linalg.cg(L, Y, rtol=tol, maxiter=maxiter)[0] |
|
|
| return out |
|
|
|
|
def perform_lp(L, preds):
    """Propagate per-class scores by solving ``L x = y`` for every class column.

    The backend is chosen with ``torch.cuda.is_available()``: CuPy when CUDA
    is available, SciPy otherwise. ``L`` must have been built for the same
    backend.

    Args:
        L: System matrix handed to the backend's conjugate-gradient solver.
        preds: 2-D tensor with one row per node and one column per class.

    Returns:
        A tensor of the same shape as ``preds`` holding the propagated
        scores, stored in the backend's default float type (float64) on
        ``"cuda"`` when CUDA is available and on ``"cpu"`` otherwise.
    """
    # The solvers run outside autograd; detaching keeps the array conversion
    # below from failing on tensors that require grad.
    preds = preds.detach()

    if torch.cuda.is_available():
        scores = cp.asarray(preds)
        propagated = cp.zeros(scores.shape)
        solve = cp_dfs_search
        device = "cuda"
    else:
        # Hand SciPy a real NumPy array instead of relying on implicit
        # tensor-to-array conversion inside the solver.
        scores = preds.cpu().numpy()
        propagated = np.zeros(scores.shape)
        solve = dfs_search
        device = "cpu"

    # One independent linear solve per class column.
    for cls_idx, y_cls in enumerate(scores.T):
        propagated[:, cls_idx] = solve(L, y_cls)

    return torch.as_tensor(propagated, device=device)
|
|
|
|
def get_lposs_laplacian(feats, locations, height_width, sigma=0.0, pix_dist_pow=2, k=100, gamma=1.0, alpha=0.95, patch_size=16):
    """Build the system matrix ``I - alpha * Wn`` of a k-NN patch graph.

    Edge weights are the top-k feature similarities of every patch (negative
    values clipped to 0, raised to ``gamma``) multiplied by
    ``exp(-sigma * distance ** pix_dist_pow)`` between patch centres. The
    graph is symmetrised and normalised before the system matrix is formed.

    Args:
        feats: (N, dim) patch features; similarity is ``feats @ feats.T``.
        locations: Per-window tensor; column 0 is used as the vertical and
            column 2 as the horizontal pixel origin of each window.
        height_width: ``(h, w)`` patch-grid size of each window, in the same
            order as the rows of ``feats``.
        sigma: Decay rate of the spatial term (0 disables it).
        pix_dist_pow: Exponent applied to the Euclidean centre distance.
        k: Neighbours kept per patch; clamped to N when fewer patches exist.
        gamma: Exponent applied to the clipped similarities.
        alpha: Weight of the normalised graph in ``I - alpha * Wn``.
        patch_size: Patch side in pixels, used to place patch centres.

    Returns:
        A CuPy sparse matrix when ``torch.cuda.is_available()``, otherwise a
        SciPy sparse matrix.
    """
    device = feats.device
    N = feats.shape[0]
    # torch.topk raises when k exceeds the row length, which happens for
    # small images that yield fewer than k patches.
    k = min(k, N)

    # For every patch: which window it belongs to and its (row, col) inside
    # that window's grid, in the same row-major order as ``feats``.
    idx_window = torch.cat([
        torch.full((h * w,), window, device=device, dtype=torch.int64)
        for window, (h, w) in enumerate(height_width)
    ])
    idx_h = torch.cat([torch.arange(h).repeat_interleave(w) for h, w in height_width]).to(device)
    idx_w = torch.cat([torch.arange(w).repeat(h) for h, w in height_width]).to(device)

    # Patch centres in image pixel coordinates.
    loc_h = locations[idx_window, 0] + (patch_size // 2) + idx_h * patch_size
    loc_w = locations[idx_window, 2] + (patch_size // 2) + idx_w * patch_size
    locs = torch.stack((loc_h, loc_w), 1).unsqueeze(0)
    dist = torch.cdist(locs, locs, p=2)[0, ...] ** pix_dist_pow
    geometry_affinity = torch.exp(-sigma * dist)

    affinity = feats @ feats.T
    sims, ks = torch.topk(affinity, k=k, dim=1)

    sims[sims < 0] = 0
    sims = sims ** gamma
    # Keep only the spatial weights of the selected neighbours.
    sims = sims.flatten() * geometry_affinity.gather(1, ks).flatten()
    ks = ks.flatten()
    # Created on the feature device so all three COO arrays live together.
    rows = torch.arange(N, device=device).repeat_interleave(k)

    if torch.cuda.is_available():
        W = cp_csr_matrix(
            (cp.asarray(sims), (cp.asarray(rows), cp.asarray(ks))),
            shape=(N, N),
        )
        # k-NN relations are not mutual; adding the transpose symmetrises.
        W = W + W.T
        Wn = normalize_connection_graph_cupy(W)
        L = cp_eye(Wn.shape[0]) - alpha * Wn
    else:
        W = csr_matrix(
            (sims.cpu().numpy(), (rows.cpu().numpy(), ks.cpu().numpy())),
            shape=(N, N),
        )
        W = W + W.T
        Wn = normalize_connection_graph(W)
        L = eye(Wn.shape[0]) - alpha * Wn

    return L
|
|
|
|
def _sliding_window_boxes(h_img, w_img, window_size, window_stride):
    """Return the (y1, y2, x1, x2) pixel box of every sliding window, row-major."""
    h_crop, w_crop = window_size
    h_stride, w_stride = window_stride
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

    boxes = []
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Clamp to the image, then shift the start back so border
            # windows keep the full crop size whenever the image allows it.
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            boxes.append((y1, y2, x1, x2))
    return boxes


def lposs(clip, dino, img, classnames, window_size=(224, 224), window_stride=(112, 112), sigma=0.01, pix_dist_pow=1, lp_k_image=400, lp_gamma=3.0, lp_alpha=0.95):
    """Sliding-window class prediction refined by label propagation over patches.

    For every window, DINO and CLIP features are extracted; CLIP features are
    scored against the text classifier and those scores are propagated over
    a patch graph built from the DINO features. Window results are upsampled
    to pixel resolution and averaged where windows overlap.

    Only the first image of the batch is processed (features are taken from
    batch index 0); its result is broadcast over the batch dimension.

    Args:
        clip: Callable model exposing ``get_classifier(classnames)`` and
            ``patch_size``; ``clip(x)`` is treated as a (B, C, H, W) feature
            map (inferred from the interpolate/permute calls — confirm).
        dino: Callable model exposing ``patch_size``; ``dino(x)`` returns
            ``(features, (h, w))`` with features reshaped here to
            ``(B, -1, h, w)``.
        img: (B, C, H, W) image tensor.
        classnames: Passed through to ``clip.get_classifier``.
        window_size: (height, width) of the sliding window in pixels.
        window_stride: (vertical, horizontal) stride in pixels.
        sigma, pix_dist_pow, lp_k_image, lp_gamma, lp_alpha: Forwarded to
            ``get_lposs_laplacian`` as ``sigma``, ``pix_dist_pow``, ``k``,
            ``gamma`` and ``alpha``.

    Returns:
        (B, num_classes, H, W) tensor of averaged per-pixel class scores.

    Raises:
        AssertionError: If the windows leave some pixels uncovered.
    """
    batch_size, _, h_img, w_img = img.size()
    boxes = _sliding_window_boxes(h_img, w_img, window_size, window_stride)

    clf = clip.get_classifier(classnames)
    num_classes = clf.shape[0]

    locations = img.new_zeros((len(boxes), 4))
    dino_feats = []
    clip_feats = []
    for win_id, (y1, y2, x1, x2) in enumerate(boxes):
        crop_img = img[:, :, y1:y2, x1:x2]

        img_dino_feats, (h_dino, w_dino) = dino(make_input_divisible(crop_img, dino.patch_size))
        img_dino_feats = img_dino_feats.reshape((batch_size, -1, h_dino, w_dino)).permute(0, 2, 3, 1)
        img_clip_feats = clip(make_input_divisible(crop_img, clip.patch_size))

        # CLIP is still channels-first here while DINO is already
        # channels-last, so compare CLIP's last two dims with DINO's
        # dims 1 and 2 (the spatial grid of each).
        dino_hw = tuple(img_dino_feats.shape[1:3])
        if tuple(img_clip_feats.shape[-2:]) != dino_hw:
            img_clip_feats = F.interpolate(img_clip_feats, size=dino_hw, mode='bilinear', align_corners=False)

        img_clip_feats = img_clip_feats.permute(0, 2, 3, 1)

        dino_feats.append(img_dino_feats[0, ...])
        clip_feats.append(img_clip_feats[0, ...])
        locations[win_id] = locations.new_tensor((y1, y2, x1, x2))

    patch_size = dino.patch_size

    dino_feats, height_width = reshape_windows(dino_feats)
    clip_feats, _ = reshape_windows(clip_feats)
    dino_feats = F.normalize(dino_feats, p=2, dim=-1)
    clip_feats = F.normalize(clip_feats, p=2, dim=-1)

    L = get_lposs_laplacian(dino_feats, locations, height_width, sigma=sigma, pix_dist_pow=pix_dist_pow, k=lp_k_image, gamma=lp_gamma, alpha=lp_alpha, patch_size=patch_size)
    clip_preds = clip_feats @ clf.T

    lp_preds = perform_lp(L, clip_preds)

    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

    # Rows of lp_preds are concatenated window by window, so each window
    # owns one contiguous slice.
    start = 0
    for (y1, y2, x1, x2), (h_win, w_win) in zip(boxes, height_width):
        stop = start + h_win * w_win
        crop_seg_logit = lp_preds[start:stop, :].reshape((h_win, w_win, num_classes))
        start = stop

        crop_seg_logit = crop_seg_logit.unsqueeze(0).permute(0, 3, 1, 2)
        crop_seg_logit = F.interpolate(
            input=crop_seg_logit,
            size=(y2 - y1, x2 - x1),
            mode='bilinear',
            align_corners=False,
        )
        # perform_lp puts its result on "cuda" whenever CUDA is available,
        # regardless of where img lives; align devices before accumulating.
        preds[:, :, y1:y2, x1:x2] += crop_seg_logit.to(preds.device)
        count_mat[:, :, y1:y2, x1:x2] += 1

    # A stride larger than the window leaves gaps and would divide by zero.
    assert not (count_mat == 0).any(), "sliding windows do not cover the whole image"
    preds = preds / count_mat

    return preds
|
|
|
|
def get_pixel_connections(img, neigh=1):
    """List all pixel pairs within a square neighbourhood and their colour distance.

    Only the first image of the batch is used. Colours are converted with
    ``rgb_to_lab`` and the three channels divided by 100, 128 and 128.

    Args:
        img: Batched image tensor; ``img[0]`` is passed to ``rgb_to_lab``.
        neigh: Neighbourhood radius; every offset in
            ``[-neigh, neigh] x [-neigh, neigh]`` except (0, 0) is connected.

    Returns:
        ``(rows, cols, pixel_pixel_data, locs)``: flat source and target
        pixel indices (``row * width + col``), the squared scaled-Lab
        distance of each pair, and the (row, col) location of every pixel.
    """
    first_image = img[0, ...]
    device = first_image.device

    lab = rgb_to_lab(first_image).permute((1, 2, 0))
    lab /= torch.tensor([100, 128, 128], device=device)
    img_h, img_w, _ = lab.shape
    lab = lab.reshape((img_h * img_w, -1))

    pixel_ids = torch.arange(img_h * img_w).to(device)
    locs = torch.stack((pixel_ids // img_w, pixel_ids % img_w), 1)

    rows, cols = [], []
    offsets = range(-neigh, neigh + 1)
    for d_h, d_w in product(offsets, repeat=2):
        if (d_h, d_w) == (0, 0):
            # A pixel is never connected to itself.
            continue
        shifted = locs + torch.tensor((d_h, d_w)).to(device)
        new_h, new_w = shifted[:, 0], shifted[:, 1]
        # Keep only neighbours that fall inside the image.
        inside = (new_h >= 0) & (new_w >= 0) & (new_h < img_h) & (new_w < img_w)
        rows.append(torch.where(inside)[0])
        cols.append(new_h[inside] * img_w + new_w[inside])

    rows = torch.cat(rows)
    cols = torch.cat(cols)
    colour_diff = lab[rows, :] - lab[cols, :]
    pixel_pixel_data = (colour_diff ** 2).sum(dim=-1)

    return rows, cols, pixel_pixel_data, locs
|
|
|
|
def get_laplacian(rows, cols, data, N, alpha=0.99):
    """Build ``I - alpha * Wn`` from a weighted edge list.

    The edges are assembled into an (N, N) sparse matrix, normalised with the
    backend's ``normalize_connection_graph*`` function, and subtracted from
    the identity.

    Args:
        rows: Tensor of source node indices.
        cols: Tensor of target node indices.
        data: Tensor of edge weights, aligned with ``rows`` / ``cols``.
        N: Number of nodes.
        alpha: Weight of the normalised graph.

    Returns:
        A CuPy sparse matrix when ``torch.cuda.is_available()``, otherwise a
        SciPy sparse matrix.
    """
    if not torch.cuda.is_available():
        # CPU path: move everything to NumPy and use SciPy.
        edge_index = (rows.cpu().numpy(), cols.cpu().numpy())
        graph = csr_matrix((data.cpu().numpy(), edge_index), shape=(N, N))
        normalized = normalize_connection_graph(graph)
        return eye(normalized.shape[0]) - alpha * normalized

    # GPU path: wrap the tensors as CuPy arrays and use cupyx sparse.
    gpu_rows = cp.asarray(rows)
    gpu_cols = cp.asarray(cols)
    gpu_data = cp.asarray(data)
    graph = cp_csr_matrix((gpu_data, (gpu_rows, gpu_cols)), shape=(N, N))
    normalized = normalize_connection_graph_cupy(graph)
    return cp_eye(normalized.shape[0]) - alpha * normalized
|
|
|
|
def lposs_plus(img, preds, tau=0.01, alpha=0.95, r=13):
    """Refine per-pixel class scores by label propagation over a pixel graph.

    Pixels within a neighbourhood of radius ``r // 2`` are connected with
    weight ``exp(-d / tau)``, where ``d`` is the square root of the colour
    distance returned by ``get_pixel_connections``. Only the first element of
    the batch in ``preds`` is refined.

    Args:
        img: Batched image tensor, forwarded to ``get_pixel_connections``.
        preds: (B, num_classes, H, W) class scores.
        tau: Temperature of the colour-distance kernel.
        alpha: Forwarded to ``get_laplacian``.
        r: Neighbourhood window size; the radius used is ``r // 2``.

    Returns:
        (1, num_classes, H, W) tensor of propagated scores.
    """
    class_maps = preds[0, ...]
    num_classes, h_img, w_img = class_maps.shape
    # One row per pixel, one column per class.
    pixel_scores = class_maps.permute((1, 2, 0)).reshape((h_img * w_img, -1))

    rows, cols, sq_colour_dist, _locs = get_pixel_connections(img, neigh=r // 2)
    edge_weights = torch.exp(-torch.sqrt(sq_colour_dist) / tau)
    L = get_laplacian(rows, cols, edge_weights, pixel_scores.shape[0], alpha=alpha)

    refined = perform_lp(L, pixel_scores)

    refined = refined.reshape((h_img, w_img, num_classes))
    return refined.permute((2, 0, 1)).unsqueeze(0)
|
|