"""
SimFeatUp upsamplers for dense feature restoration.

From SegEarth-OV/OV-2 simfeatup_dev. Used by CLIP-based variants (OV, OV-2).
"""

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv
except Exception:  # best-effort: the CUDA extension may be absent or fail to load
    AdaptiveConv = None


def adaptive_conv_py_simple(input, filters):
    """Pure PyTorch fallback when featup CUDA is unavailable.

    Applies a separate ``f1 x f2`` filter at every output location.

    Args:
        input: ``(b, c, h1, w1)`` tensor. It must already be padded so that an
            unpadded ``f1 x f1`` sliding window yields exactly ``h2 * w2``
            positions (the ``view`` below fails otherwise).
        filters: ``(b, h2, w2, f1, f2)`` per-pixel filters; ``f1 == f2``.

    Returns:
        ``(b, c, h2, w2)`` tensor: each output pixel is the dot product of its
        filter with the corresponding input window, per channel.
    """
    b, c, h1, w1 = input.shape
    b, h2, w2, f1, f2 = filters.shape
    assert f1 == f2
    t_filters = filters.reshape(b, h2, w2, f1 * f2)
    # Unfold -> (b, c * f1 * f2, h2 * w2), regrouped to (b, c, taps, h2, w2).
    patches = torch.nn.Unfold(f1)(input).view((b, c, f1 * f2, h2, w2))
    return torch.einsum("bhwf,bcfhw->bchw", t_filters, patches)


def _meshgrid(device, diameter):
    """Return a ``(2, diameter, diameter)`` grid of (row, col) offsets in [-1, 1]."""
    dist_range = torch.linspace(-1, 1, diameter, device=device)
    x, y = torch.meshgrid(dist_range, dist_range, indexing="ij")
    return torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)


class Bilinear(torch.nn.Module):
    """Parameter-free baseline: bilinearly resize ``source`` to the guidance size."""

    def forward(self, source, guidance):
        _, _, h, w = guidance.shape
        return F.interpolate(source, (h, w), mode="bilinear")


class LayeredResizeConv(torch.nn.Module):
    """Four 2x bilinear upsampling stages, each refined by a residual conv.

    Every stage concatenates the guidance (resized to the stage resolution)
    with the upsampled features, so the guidance must have 3 channels to match
    the ``dim + 3`` conv inputs. The output is 16x the source resolution.
    """

    def __init__(self, dim, kernel_size=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv2 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv3 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv4 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")

    def apply_conv(self, source, guidance, conv, activation):
        """Upsample ``source`` 2x and add ``activation(conv([up, guidance]))``."""
        big_source = F.interpolate(source, scale_factor=2, mode="bilinear")
        _, _, h, w = big_source.shape
        small_guidance = F.interpolate(guidance, (h, w), mode="bilinear")
        output = activation(conv(torch.cat([big_source, small_guidance], dim=1)))
        return big_source + output

    def forward(self, source, guidance):
        source_2 = self.apply_conv(source, guidance, self.conv1, F.relu)
        source_4 = self.apply_conv(source_2, guidance, self.conv2, F.relu)
        source_8 = self.apply_conv(source_4, guidance, self.conv3, F.relu)
        # Last stage is linear so the residual can take negative values.
        source_16 = self.apply_conv(source_8, guidance, self.conv4, lambda x: x)
        return source_16


class SimpleImplicitFeaturizer(torch.nn.Module):
    """Append sinusoidal encodings of the pixel grid to ``x``.

    The encoding depends only on the spatial size of ``x``. A 2-channel (row,
    col) grid over [-1, 1] is scaled by ``n_freqs`` log-spaced frequencies
    ``exp(linspace(-2, 10))``. The output is ``[sin(...), cos(...), x]`` with
    ``4 * n_freqs + C`` channels, where ``C`` is the channel count of ``x``.
    """

    def __init__(self, n_freqs=20):
        super().__init__()
        self.n_freqs = n_freqs
        self.dim_multiplier = 2  # the grid has 2 channels (row, col)

    def forward(self, x):
        b, c, h, w = x.shape
        dtype = x.dtype
        grid_h = torch.linspace(-1, 1, h, device=x.device, dtype=dtype)
        grid_w = torch.linspace(-1, 1, w, device=x.device, dtype=dtype)
        feats = torch.stack(torch.meshgrid(grid_h, grid_w, indexing="ij")).unsqueeze(0)
        feats = feats.broadcast_to((b, feats.shape[1], h, w))
        freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=x.device)).to(dtype).reshape(
            1, self.n_freqs, 1, 1, 1
        )
        # (b, 1, 2, h, w) * (1, n_freqs, 1, 1, 1) -> (b, 2 * n_freqs, h, w)
        feats = (feats.unsqueeze(1) * freqs).reshape(b, self.n_freqs * self.dim_multiplier, h, w)
        return torch.cat([torch.sin(feats), torch.cos(feats), x], dim=1)


class IFA(torch.nn.Module):
    """Learned repeated-2x upsampler.

    Each 2x step does three things:

    1. Nearest-upsamples the features.
    2. Concatenates ``SimpleImplicitFeaturizer`` applied to the offsets between
       the nearest-upsampled low-res coordinates and the high-res coordinates.
    3. Applies a 1x1-conv MLP.

    Works for non-square feature maps.
    """

    def __init__(self, feat_dim, num_scales=20):
        super().__init__()
        self.feat_dim = feat_dim
        # The featurizer emits 4 * n_freqs + 2 channels for the 2-channel
        # coordinate offsets; the first conv expects 4 * num_scales + 2 of
        # them, so the two must be tied (previously n_freqs was always 20).
        self.sin_feats = SimpleImplicitFeaturizer(n_freqs=num_scales)
        self.mlp = nn.Sequential(
            nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1),
            nn.BatchNorm2d(feat_dim),
            nn.LeakyReLU(),
            nn.Conv2d(feat_dim, feat_dim, 1),
        )

    def _upsample_2x(self, source):
        """Return a learned ``(b, feat_dim, 2h, 2w)`` upsampling of ``source``."""
        b, c, h, w = source.shape
        dtype = source.dtype
        device = source.device
        up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest")
        # Build each axis from its own size; using h for both axes breaks
        # (shape mismatch in coord_diff) whenever h != w.
        lr_h = torch.linspace(0, h, steps=h, device=device, dtype=dtype)
        lr_w = torch.linspace(0, w, steps=w, device=device, dtype=dtype)
        hr_h = torch.linspace(0, h, steps=2 * h, device=device, dtype=dtype)
        hr_w = torch.linspace(0, w, steps=2 * w, device=device, dtype=dtype)
        lr_coords = torch.stack(torch.meshgrid(lr_h, lr_w, indexing="ij")).unsqueeze(0)
        hr_coords = torch.stack(torch.meshgrid(hr_h, hr_w, indexing="ij")).unsqueeze(0)
        up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest")
        coord_diff = up_lr_coords - hr_coords
        coord_diff_feats = self.sin_feats(coord_diff).to(dtype)
        bcast_coord_feats = coord_diff_feats.broadcast_to((b, coord_diff_feats.shape[1], h * 2, w * 2))
        return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1))

    def forward(self, source, guidance):
        """Double ``source`` until it covers the guidance size, then snap to it."""
        _, _, gh, gw = guidance.shape
        x = source
        while x.shape[2] < gh or x.shape[3] < gw:
            x = self._upsample_2x(x)
        if x.shape[2] != gh or x.shape[3] != gw:
            x = F.interpolate(x, (gh, gw), mode="bilinear")
        return x


class JBULearnedRange(torch.nn.Module):
    """Joint bilateral upsampler with a learned range kernel.

    Resizes ``source`` to the guidance resolution (bicubic). It then filters
    the result with a per-pixel ``diameter x diameter`` kernel built from:

    * a spatial Gaussian over the window offsets (learned ``sigma_spatial``);
    * a range kernel: softmax over the window of dot products between
      projected guidance features (``range_proj``), scaled by
      ``exp(range_temp)``;
    * a learned additive correction (``0.1 * fixup_proj(...)``), applied after
      normalising the product of the two kernels.

    Args:
        guidance_dim: channels of the guidance tensor.
        feat_dim: stored on the module; not read by this class.
        key_dim: channels of the projected guidance used for the range kernel.
        scale: stored on the module; not read by this class.
        radius: window radius; the window is ``2 * radius + 1`` wide. Reflect
            padding requires ``radius`` to be smaller than the guidance
            height and width.
    """

    def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3):
        super().__init__()
        self.scale = scale
        self.radius = radius
        self.diameter = self.radius * 2 + 1
        self.guidance_dim = guidance_dim
        self.key_dim = key_dim
        self.feat_dim = feat_dim
        self.range_temp = nn.Parameter(torch.tensor(0.0))
        self.range_proj = nn.Sequential(
            nn.Conv2d(guidance_dim, key_dim, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(key_dim, key_dim, 1, 1),
        )
        self.fixup_proj = nn.Sequential(
            nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1),
        )
        self.sigma_spatial = nn.Parameter(torch.tensor(1.0))

    def get_range_kernel(self, x):
        """Return a ``(B, diameter**2, H, W)`` softmax range kernel for guidance ``x``."""
        GB, GC, GH, GW = x.shape
        proj_x = self.range_proj(x)
        proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode="reflect")
        # Keys of every neighbour in the window: (B, key_dim, H, W, diameter**2).
        queries = (
            torch.nn.Unfold(self.diameter)(proj_x_padded)
            .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW))
            .permute(0, 1, 3, 4, 2)
        )
        # Clamp the learned temperature to keep the softmax numerically sane.
        pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4)
        return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1)

    def get_spatial_kernel(self, device):
        """Return the ``(1, diameter**2, 1, 1)`` Gaussian over window offsets."""
        patch = _meshgrid(device, self.diameter)
        return torch.exp(-patch.square().sum(0) / (2 * self.sigma_spatial ** 2)).reshape(
            1, self.diameter * self.diameter, 1, 1
        )

    def forward(self, source, guidance):
        """Upsample ``source`` to the spatial size of ``guidance``."""
        GB, GC, GH, GW = guidance.shape
        SB, SC, SH, SW = source.shape
        assert SB == GB
        dtype = source.dtype
        guidance = guidance.to(dtype)
        spatial_kernel = self.get_spatial_kernel(source.device).to(dtype)
        range_kernel = self.get_range_kernel(guidance).to(dtype)
        combined_kernel = (range_kernel * spatial_kernel).to(dtype)
        combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7)
        combined_kernel += 0.1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1))
        combined_kernel = combined_kernel.permute(0, 2, 3, 1).reshape(
            GB, GH, GW, self.diameter, self.diameter
        )
        hr_source = F.interpolate(source, size=(GH, GW), mode="bicubic", align_corners=False)
        hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode="reflect")
        combined_kernel = combined_kernel.to(hr_source_padded.dtype)
        if AdaptiveConv is not None:
            result = AdaptiveConv.apply(hr_source_padded, combined_kernel)
        else:
            result = adaptive_conv_py_simple(hr_source_padded, combined_kernel)
        return result


class JBUStack(torch.nn.Module):
    """Four independent JBU stages (radius 3), each doubling resolution.

    Each stage's guidance is average-pooled from the full guidance to that
    stage's target size, so the output is 16x the source resolution
    regardless of the guidance size. A 0.1-scaled residual 1x1 conv is applied
    last. The guidance must have 3 channels.
    """

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run ``up`` with guidance pooled to twice the size of ``source``."""
        _, _, h, w = source.shape
        small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2))
        return up(source, small_guidance)

    def forward(self, source, guidance):
        source_2 = self.upsample(source, guidance, self.up1)
        source_4 = self.upsample(source_2, guidance, self.up2)
        source_8 = self.upsample(source_4, guidance, self.up3)
        source_16 = self.upsample(source_8, guidance, self.up4)
        return self.fixup_proj(source_16) * 0.1 + source_16


class JBUOne(torch.nn.Module):
    """Like ``JBUStack`` but one shared JBU module (radius 5) for all 4 stages."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up = JBULearnedRange(3, feat_dim, 32, radius=5)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run ``up`` with guidance pooled to twice the size of ``source``."""
        _, _, h, w = source.shape
        small_guidance = F.adaptive_avg_pool2d(guidance, (h * 2, w * 2))
        return up(source, small_guidance)

    def forward(self, source, guidance):
        source_2 = self.upsample(source, guidance, self.up)
        source_4 = self.upsample(source_2, guidance, self.up)
        source_8 = self.upsample(source_4, guidance, self.up)
        source_16 = self.upsample(source_8, guidance, self.up)
        return self.fixup_proj(source_16) * 0.1 + source_16


# Relative checkpoint paths per upsampler variant (not read within this module).
FEATUP_CHECKPOINTS = {
    "jbu_one": "simfeatup/xclip_jbu_one_million_aid.ckpt",
    "jbu_stack": "simfeatup/clip_jbu_stack_cocostuff.ckpt",
    "jbu_stack_maskclip": "simfeatup/maskclip_jbu_stack_cocostuff.ckpt",
}


def get_upsampler(name: str, feat_dim: int) -> nn.Module:
    """Build an (untrained) upsampler by name.

    Args:
        name: one of ``bilinear``, ``jbu_one``, ``jbu_stack``, ``resize_conv``,
            ``ifa``.
        feat_dim: channel count of the features to upsample.

    Raises:
        ValueError: if ``name`` is not recognised.
    """
    # NOTE(review): FEATUP_CHECKPOINTS also has "jbu_stack_maskclip", which is
    # not accepted here; presumably callers map it to "jbu_stack" — confirm.
    if name == "bilinear":
        return Bilinear()
    elif name == "jbu_one":
        return JBUOne(feat_dim)
    elif name == "jbu_stack":
        return JBUStack(feat_dim)
    elif name == "resize_conv":
        return LayeredResizeConv(feat_dim, 1)
    elif name == "ifa":
        return IFA(feat_dim)
    else:
        raise ValueError(f"Unknown upsampler: {name}. Use: bilinear, jbu_one, jbu_stack, resize_conv, ifa")