| """ |
| SimFeatUp upsamplers for dense feature restoration. |
| From SegEarth-OV/OV-2 simfeatup_dev. Used by CLIP-based variants (OV, OV-2). |
| """ |
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| try: |
| from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv |
| except Exception: |
| AdaptiveConv = None |
|
|
|
|
def adaptive_conv_py_simple(input, filters):
    """Pure PyTorch fallback when featup CUDA is unavailable.

    Applies a spatially-varying (per-pixel) convolution: each output pixel
    (h, w) is the dot product of its own ``f1 x f2`` kernel with the
    corresponding ``f1 x f2`` patch of ``input``.

    Args:
        input: (B, C, H1, W1) feature map. The caller pads beforehand so that
            unfolding with an ``f1``-sized kernel yields H2 x W2 patches.
        filters: (B, H2, W2, f1, f2) per-pixel square kernels.

    Returns:
        (B, C, H2, W2) filtered features.

    Raises:
        ValueError: if the per-pixel kernels are not square.
    """
    b, c, h1, w1 = input.shape
    _, h2, w2, f1, f2 = filters.shape
    # Explicit check instead of `assert`: asserts vanish under `python -O`.
    if f1 != f2:
        raise ValueError(f"adaptive filters must be square, got {f1}x{f2}")
    flat_filters = filters.reshape(b, h2, w2, f1 * f2)
    # Unfold extracts every f1 x f2 patch; the einsum contracts each patch
    # with its matching per-pixel kernel.
    patches = F.unfold(input, f1).view((b, c, f1 * f2, h2, w2))
    return torch.einsum("bhwf,bcfhw->bchw", flat_filters, patches)
|
|
|
|
| def _meshgrid(device, diameter): |
| dist_range = torch.linspace(-1, 1, diameter, device=device) |
| x, y = torch.meshgrid(dist_range, dist_range, indexing="ij") |
| return torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0) |
|
|
|
|
class Bilinear(torch.nn.Module):
    """Baseline upsampler: plain bilinear resize to the guidance resolution."""

    def forward(self, source, guidance):
        target_size = guidance.shape[-2:]
        return F.interpolate(source, tuple(target_size), mode="bilinear")
|
|
|
|
class LayeredResizeConv(torch.nn.Module):
    """Four stacked 2x resize-conv stages (16x total upsampling).

    Each stage bilinearly doubles the feature resolution, concatenates a
    matching-resolution resize of the guidance image, and adds a conv
    refinement as a residual. The final stage applies no activation.
    """

    def __init__(self, dim, kernel_size=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Each conv sees features plus the 3-channel guidance.
        self.conv1 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv2 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv3 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")
        self.conv4 = nn.Conv2d(dim + 3, dim, kernel_size, padding="same")

    def apply_conv(self, source, guidance, conv, activation):
        """One 2x stage: upsample, concat resized guidance, residual conv."""
        upsampled = F.interpolate(source, scale_factor=2, mode="bilinear")
        target_hw = upsampled.shape[-2:]
        guidance_small = F.interpolate(guidance, tuple(target_hw), mode="bilinear")
        refinement = activation(conv(torch.cat([upsampled, guidance_small], dim=1)))
        return upsampled + refinement

    def forward(self, source, guidance):
        stages = (
            (self.conv1, F.relu),
            (self.conv2, F.relu),
            (self.conv3, F.relu),
            (self.conv4, lambda t: t),
        )
        feats = source
        for conv, act in stages:
            feats = self.apply_conv(feats, guidance, conv, act)
        return feats
|
|
|
|
class SimpleImplicitFeaturizer(torch.nn.Module):
    """Fourier positional encoding appended to the input features.

    Builds a 2-channel coordinate grid in [-1, 1], scales it by `n_freqs`
    exponentially-spaced frequencies, and concatenates sin, cos, and the
    original input: output has n_freqs * 2 + n_freqs * 2 + C channels.
    """

    def __init__(self, n_freqs=20):
        super().__init__()
        self.n_freqs = n_freqs
        self.dim_multiplier = 2  # two coordinate channels (row, col)

    def forward(self, x):
        b, c, h, w = x.shape
        dtype = x.dtype
        ys = torch.linspace(-1, 1, h, device=x.device, dtype=dtype)
        xs = torch.linspace(-1, 1, w, device=x.device, dtype=dtype)
        coords = torch.stack(torch.meshgrid(ys, xs, indexing="ij"))[None]
        coords = coords.broadcast_to((b, coords.shape[1], h, w))
        # Geometric frequency ladder from e^-2 to e^10.
        freqs = torch.exp(torch.linspace(-2, 10, self.n_freqs, device=x.device))
        freqs = freqs.to(dtype).reshape(1, self.n_freqs, 1, 1, 1)
        scaled = (coords.unsqueeze(1) * freqs).reshape(b, self.n_freqs * self.dim_multiplier, h, w)
        return torch.cat([scaled.sin(), scaled.cos(), x], dim=1)
|
|
|
|
class IFA(torch.nn.Module):
    """Implicit Feature Alignment upsampler.

    Repeatedly doubles the source resolution (nearest-neighbor) and refines
    each doubling with an MLP conditioned on Fourier-encoded sub-pixel
    coordinate offsets, until the guidance resolution is reached.
    """

    def __init__(self, feat_dim, num_scales=20):
        super().__init__()
        self.feat_dim = feat_dim
        # Fix: forward num_scales so the MLP's input channel count below
        # stays consistent for num_scales != 20 (default behavior unchanged).
        self.sin_feats = SimpleImplicitFeaturizer(n_freqs=num_scales)
        # Input channels: feat_dim + sin/cos encodings (num_scales * 2 each)
        # of the 2-channel coordinate offset + the raw 2-channel offset.
        self.mlp = nn.Sequential(
            nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1),
            nn.BatchNorm2d(feat_dim),
            nn.LeakyReLU(),
            nn.Conv2d(feat_dim, feat_dim, 1),
        )

    def _upsample_2x(self, source):
        """Double spatial resolution with coordinate-conditioned refinement.

        Fix over the original: coordinate grids are built per axis (h and w
        separately), so non-square feature maps get correctly scaled
        horizontal coordinates. Identical for square inputs.
        """
        b, c, h, w = source.shape
        dtype = source.dtype
        up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest")
        lr_y = torch.linspace(0, h, steps=h, device=source.device, dtype=dtype)
        lr_x = torch.linspace(0, w, steps=w, device=source.device, dtype=dtype)
        hr_y = torch.linspace(0, h, steps=2 * h, device=source.device, dtype=dtype)
        hr_x = torch.linspace(0, w, steps=2 * w, device=source.device, dtype=dtype)
        lr_coords = torch.stack(torch.meshgrid(lr_y, lr_x, indexing="ij")).unsqueeze(0)
        hr_coords = torch.stack(torch.meshgrid(hr_y, hr_x, indexing="ij")).unsqueeze(0)
        # Offset between each high-res pixel and its nearest low-res source.
        up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest")
        coord_diff = up_lr_coords - hr_coords
        coord_diff_feats = self.sin_feats(coord_diff).to(dtype)
        bcast_coord_feats = coord_diff_feats.broadcast_to((b, coord_diff_feats.shape[1], h * 2, w * 2))
        return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1))

    def forward(self, source, guidance):
        """Upsample `source` to the spatial size of `guidance`.

        Doubles until at least the guidance size, then bilinearly snaps to
        the exact (gh, gw) if the power-of-two chain overshoots/mismatches.
        """
        _, _, gh, gw = guidance.shape
        x = source
        while x.shape[2] < gh or x.shape[3] < gw:
            x = self._upsample_2x(x)
        if x.shape[2] != gh or x.shape[3] != gw:
            x = F.interpolate(x, (gh, gw), mode="bilinear")
        return x
|
|
|
|
class JBULearnedRange(torch.nn.Module):
    """Joint bilateral upsampling with a learned range kernel.

    Combines a Gaussian spatial kernel (learned sigma) with a
    guidance-derived "range" kernel, adds a learned fixup correction,
    and applies the resulting per-pixel filters to a bicubically
    upsampled source.
    """

    def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3):
        super().__init__()
        self.scale = scale
        self.radius = radius
        self.diameter = self.radius * 2 + 1
        self.guidance_dim = guidance_dim
        self.key_dim = key_dim
        self.feat_dim = feat_dim
        # Log-temperature for the range-kernel softmax (exp'd in use).
        self.range_temp = nn.Parameter(torch.tensor(0.0))
        # Projects guidance pixels into a key space for range similarity.
        self.range_proj = nn.Sequential(
            nn.Conv2d(guidance_dim, key_dim, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(key_dim, key_dim, 1, 1),
        )
        # Predicts a residual correction to the combined kernel from the
        # kernel itself concatenated with the guidance.
        self.fixup_proj = nn.Sequential(
            nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1),
        )
        # Width of the Gaussian spatial kernel (learned).
        self.sigma_spatial = nn.Parameter(torch.tensor(1.0))

    def get_range_kernel(self, x):
        """Return a (B, diameter^2, H, W) softmax kernel of similarities
        between each guidance pixel and its neighborhood in key space."""
        GB, GC, GH, GW = x.shape
        proj_x = self.range_proj(x)
        # Reflect-pad so every pixel has a full diameter x diameter window.
        proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode="reflect")
        queries = (
            torch.nn.Unfold(self.diameter)(proj_x_padded)
            .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW))
            .permute(0, 1, 3, 4, 2)
        )
        # exp() keeps the temperature positive; clamps guard against extremes.
        pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4)
        return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1)

    def get_spatial_kernel(self, device):
        """Return a Gaussian over normalized [-1, 1] patch offsets, shaped
        (1, diameter^2, 1, 1) for broadcasting against the range kernel."""
        patch = _meshgrid(device, self.diameter)
        return torch.exp(-patch.square().sum(0) / (2 * self.sigma_spatial ** 2)).reshape(
            1, self.diameter * self.diameter, 1, 1
        )

    def forward(self, source, guidance):
        """Upsample `source` (B, feat_dim, SH, SW) to the spatial size of
        `guidance` (B, guidance_dim, GH, GW) via the combined kernel."""
        GB, GC, GH, GW = guidance.shape
        SB, SC, SH, SQ = source.shape
        assert SB == GB
        dtype = source.dtype
        guidance = guidance.to(dtype)
        spatial_kernel = self.get_spatial_kernel(source.device).to(dtype)
        range_kernel = self.get_range_kernel(guidance).to(dtype)
        combined_kernel = (range_kernel * spatial_kernel).to(dtype)
        # Normalize per pixel; clamp avoids division by ~0.
        combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7)
        # Small learned residual correction on top of the normalized kernel.
        combined_kernel += 0.1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1))
        combined_kernel = combined_kernel.permute(0, 2, 3, 1).reshape(
            GB, GH, GW, self.diameter, self.diameter
        )
        hr_source = F.interpolate(source, size=(GH, GW), mode="bicubic", align_corners=False)
        # Reflect-pad so the per-pixel filters stay within bounds.
        hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode="reflect")
        combined_kernel = combined_kernel.to(hr_source_padded.dtype)
        if AdaptiveConv is not None:
            # featup's CUDA op when installed; else the pure-PyTorch fallback.
            result = AdaptiveConv.apply(hr_source_padded, combined_kernel)
        else:
            result = adaptive_conv_py_simple(hr_source_padded, combined_kernel)
        return result
|
|
|
|
class JBUStack(torch.nn.Module):
    """Four chained learned-JBU stages (16x upsampling), each with its own
    weights, followed by a small residual fixup projection."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one JBU stage against guidance pooled to twice the source size."""
        height, width = source.shape[-2:]
        pooled_guidance = F.adaptive_avg_pool2d(guidance, (height * 2, width * 2))
        return up(source, pooled_guidance)

    def forward(self, source, guidance):
        feats = source
        for stage in (self.up1, self.up2, self.up3, self.up4):
            feats = self.upsample(feats, guidance, stage)
        # Scaled residual correction on the final features.
        return self.fixup_proj(feats) * 0.1 + feats
|
|
|
|
class JBUOne(torch.nn.Module):
    """A single learned-JBU module applied four times — the same weights are
    shared across all 16x of upsampling — plus a residual fixup projection."""

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up = JBULearnedRange(3, feat_dim, 32, radius=5)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run the shared JBU stage against guidance pooled to 2x source size."""
        height, width = source.shape[-2:]
        pooled_guidance = F.adaptive_avg_pool2d(guidance, (height * 2, width * 2))
        return up(source, pooled_guidance)

    def forward(self, source, guidance):
        feats = source
        for _ in range(4):  # same module each pass: weight sharing is the point
            feats = self.upsample(feats, guidance, self.up)
        # Scaled residual correction on the final features.
        return self.fixup_proj(feats) * 0.1 + feats
|
|
|
|
# Pretrained SimFeatUp checkpoint paths, keyed by upsampler variant.
# NOTE(review): paths are relative — presumably resolved against a weights
# root directory by whoever loads them; confirm against the caller.
FEATUP_CHECKPOINTS = {
    "jbu_one": "simfeatup/xclip_jbu_one_million_aid.ckpt",
    "jbu_stack": "simfeatup/clip_jbu_stack_cocostuff.ckpt",
    "jbu_stack_maskclip": "simfeatup/maskclip_jbu_stack_cocostuff.ckpt",
}
|
|
|
|
def get_upsampler(name: str, feat_dim: int):
    """Build an upsampler module by name.

    Args:
        name: one of "bilinear", "jbu_one", "jbu_stack", "resize_conv", "ifa".
        feat_dim: channel dimension of the features to upsample.

    Returns:
        A freshly constructed ``torch.nn.Module`` upsampler.

    Raises:
        ValueError: if ``name`` is not a known upsampler.
    """
    factories = {
        "bilinear": lambda: Bilinear(),
        "jbu_one": lambda: JBUOne(feat_dim),
        "jbu_stack": lambda: JBUStack(feat_dim),
        "resize_conv": lambda: LayeredResizeConv(feat_dim, 1),
        "ifa": lambda: IFA(feat_dim),
    }
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError(f"Unknown upsampler: {name}. Use: bilinear, jbu_one, jbu_stack, resize_conv, ifa") from None
    return factory()
|
|