# SegEarth-OV / OV / upsamplers.py
# Provenance (from the hosting page): Dingyi111, "Duplicate from BiliSakura/SegEarth-OV", revision fabc606.
"""
SimFeatUp upsamplers for dense feature restoration.
From SegEarth-OV/OV-2 simfeatup_dev. Used by CLIP-based variants (OV, OV-2).
"""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# Optional compiled adaptive-convolution kernel from FeatUp. Any exception
# raised by this import (for example when featup is not installed) is
# swallowed on purpose: AdaptiveConv is then None and JBULearnedRange falls
# back to the pure-PyTorch adaptive_conv_py_simple.
try:
    from featup.adaptive_conv_cuda.adaptive_conv import AdaptiveConv
except Exception:
    AdaptiveConv = None
def adaptive_conv_py_simple(input, filters):
    """Pure PyTorch fallback when featup CUDA is unavailable.

    Applies a different square filter at every output location. The filter is
    shared across channels. This is a "valid" correlation: no padding is added
    here, so callers pad ``input`` beforehand.

    Args:
        input: Tensor of shape ``(b, c, h2 + f - 1, w2 + f - 1)``.
        filters: Tensor of shape ``(b, h2, w2, f, f)`` holding one ``f x f``
            kernel per output pixel.

    Returns:
        Tensor of shape ``(b, c, h2, w2)`` with
        ``out[n, ch, y, x] = sum_{i, j} filters[n, y, x, i, j] * input[n, ch, y + i, x + j]``.

    Raises:
        ValueError: if the filters are not square, or if the batch or spatial
            sizes of ``input`` and ``filters`` are inconsistent. Without the
            spatial check, a transposed filter grid with the same number of
            pixels would silently produce a wrong result.
    """
    b, c, h1, w1 = input.shape
    fb, h2, w2, f1, f2 = filters.shape
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if f1 != f2:
        raise ValueError(f"filters must be square, got {f1}x{f2}")
    if fb != b:
        raise ValueError(f"input and filters batch sizes differ: {b} vs {fb}")
    if (h1 - f1 + 1, w1 - f2 + 1) != (h2, w2):
        raise ValueError(
            "input spatial size must equal the filter grid size plus kernel_size - 1: "
            f"got input {h1}x{w1}, filter grid {h2}x{w2}, kernel {f1}"
        )
    flat_filters = filters.reshape(b, h2, w2, f1 * f2)
    # F.unfold returns (b, c * f * f, h2 * w2) with channels outermost. Its
    # kernel-position order (row-major i * f + j) matches flat_filters' last axis.
    patches = F.unfold(input, kernel_size=f1).view(b, c, f1 * f2, h2, w2)
    return torch.einsum("bhwf,bcfhw->bchw", flat_filters, patches)
def _meshgrid(device, diameter):
    """Return a ``(2, diameter, diameter)`` grid of coordinates spanning [-1, 1].

    Channel 0 holds the row coordinate and channel 1 the column coordinate
    (``indexing="ij"``).
    """
    axis = torch.linspace(-1, 1, diameter, device=device)
    rows, cols = torch.meshgrid(axis, axis, indexing="ij")
    return torch.stack((rows, cols), dim=0)
class Bilinear(torch.nn.Module):
    """Parameter-free baseline: bilinearly resize ``source`` to the guidance size."""

    def forward(self, source, guidance):
        # Only the spatial extent of the guidance is used; its values are ignored.
        _batch, _channels, out_h, out_w = guidance.shape
        return F.interpolate(source, size=(out_h, out_w), mode="bilinear")
class LayeredResizeConv(torch.nn.Module):
    """Four-stage resize-conv upsampler (2x per stage, 16x overall).

    Each stage bilinearly doubles the features, concatenates the guidance
    resized to the same size, and adds a conv residual. The concatenation
    feeds convs with ``dim + 3`` input channels, so the guidance must have
    3 channels (presumably an RGB image; confirm with callers).
    """

    def __init__(self, dim, kernel_size=1, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def make_conv():
            # Input: `dim` feature channels + 3 guidance channels.
            return nn.Conv2d(dim + 3, dim, kernel_size, padding="same")

        # Attribute names conv1..conv4 are the state_dict keys; keep them.
        self.conv1 = make_conv()
        self.conv2 = make_conv()
        self.conv3 = make_conv()
        self.conv4 = make_conv()

    def apply_conv(self, source, guidance, conv, activation):
        """Double ``source`` and add ``activation(conv([upsampled, guidance]))``."""
        upsampled = F.interpolate(source, scale_factor=2, mode="bilinear")
        resized_guidance = F.interpolate(
            guidance, size=tuple(upsampled.shape[2:]), mode="bilinear"
        )
        stacked = torch.cat([upsampled, resized_guidance], dim=1)
        residual = activation(conv(stacked))
        return upsampled + residual

    def forward(self, source, guidance):
        convs = (self.conv1, self.conv2, self.conv3, self.conv4)
        out = source
        for stage, conv in enumerate(convs, start=1):
            # ReLU on the first three residuals; the last residual is left linear.
            act = F.relu if stage < len(convs) else (lambda t: t)
            out = self.apply_conv(out, guidance, conv, act)
        return out
class SimpleImplicitFeaturizer(torch.nn.Module):
    """Prepend sin/cos Fourier features of a normalized pixel grid to ``x``.

    For an input of shape ``(b, c, h, w)`` the output has
    ``2 * dim_multiplier * n_freqs + c`` channels, laid out as
    ``[sin(features), cos(features), x]``. Feature channel ``k`` is frequency
    ``k // 2`` applied to grid coordinate ``k % 2`` (0 = row, 1 = column).
    """

    def __init__(self, n_freqs=20):
        super().__init__()
        self.n_freqs = n_freqs
        # Number of encoded grid coordinates (row and column).
        self.dim_multiplier = 2

    def forward(self, x):
        batch, _, height, width = x.shape
        dtype, device = x.dtype, x.device
        rows = torch.linspace(-1, 1, height, device=device, dtype=dtype)
        cols = torch.linspace(-1, 1, width, device=device, dtype=dtype)
        grid = torch.stack(torch.meshgrid(rows, cols, indexing="ij"))[None]
        grid = grid.broadcast_to((batch, grid.shape[1], height, width))
        # Log-spaced frequencies exp(-2) .. exp(10), built in float32 then cast.
        freqs = torch.linspace(-2, 10, self.n_freqs, device=device).exp().to(dtype)
        freqs = freqs.view(1, self.n_freqs, 1, 1, 1)
        scaled = (grid[:, None] * freqs).reshape(
            batch, self.n_freqs * self.dim_multiplier, height, width
        )
        return torch.cat([scaled.sin(), scaled.cos(), x], dim=1)
class IFA(torch.nn.Module):
    """IFA upsampler: repeated learned 2x steps, then an exact bilinear resize.

    Each 2x step nearest-upsamples the features. It then passes them through a
    1x1-conv MLP, together with sin/cos features of the offset between every
    high-resolution coordinate and the low-resolution coordinate it was copied
    from. Steps repeat until the map is at least the guidance size; a final
    bilinear resize matches that size exactly.

    Args:
        feat_dim: channels of the features being upsampled.
        num_scales: number of Fourier frequencies used for the offset features.
    """

    def __init__(self, feat_dim, num_scales=20):
        super().__init__()
        self.feat_dim = feat_dim
        # The featurizer must emit exactly `num_scales` frequencies: the MLP
        # input below expects 4 * num_scales Fourier channels plus the
        # 2 raw coordinate-difference channels.
        self.sin_feats = SimpleImplicitFeaturizer(n_freqs=num_scales)
        self.mlp = nn.Sequential(
            nn.Conv2d(feat_dim + (num_scales * 4) + 2, feat_dim, 1),
            nn.BatchNorm2d(feat_dim),
            nn.LeakyReLU(),
            nn.Conv2d(feat_dim, feat_dim, 1),
        )

    def _upsample_2x(self, source):
        """Double the spatial size of ``source`` and refine it with ``self.mlp``."""
        b, _, h, w = source.shape
        dtype = source.dtype
        device = source.device
        up_source = F.interpolate(source, (h * 2, w * 2), mode="nearest")
        # Separate row/column axes so non-square maps work. For h == w this is
        # identical to reusing the row axis for both dimensions.
        lr_rows = torch.linspace(0, h, steps=h, device=device, dtype=dtype)
        lr_cols = torch.linspace(0, w, steps=w, device=device, dtype=dtype)
        hr_rows = torch.linspace(0, h, steps=2 * h, device=device, dtype=dtype)
        hr_cols = torch.linspace(0, w, steps=2 * w, device=device, dtype=dtype)
        lr_coords = torch.stack(torch.meshgrid(lr_rows, lr_cols, indexing="ij")).unsqueeze(0)
        hr_coords = torch.stack(torch.meshgrid(hr_rows, hr_cols, indexing="ij")).unsqueeze(0)
        up_lr_coords = F.interpolate(lr_coords, (h * 2, w * 2), mode="nearest")
        # Offset of each high-res pixel from the low-res pixel it was copied from.
        coord_diff = up_lr_coords - hr_coords
        coord_diff_feats = self.sin_feats(coord_diff).to(dtype)
        bcast_coord_feats = coord_diff_feats.broadcast_to(
            (b, coord_diff_feats.shape[1], h * 2, w * 2)
        )
        return self.mlp(torch.cat([up_source, bcast_coord_feats], dim=1))

    def forward(self, source, guidance):
        """Upsample ``source`` to the spatial size of ``guidance``."""
        _, _, gh, gw = guidance.shape
        x = source
        while x.shape[2] < gh or x.shape[3] < gw:
            x = self._upsample_2x(x)
        if x.shape[2] != gh or x.shape[3] != gw:
            x = F.interpolate(x, (gh, gw), mode="bilinear")
        return x
class JBULearnedRange(torch.nn.Module):
    """Joint-bilateral upsampling layer with a learned range kernel.

    ``source`` features are resized (bicubic) to the guidance resolution. They
    are then filtered with a per-pixel ``diameter x diameter`` kernel that
    combines a Gaussian spatial term with a softmax affinity computed from
    projected guidance features.

    Args:
        guidance_dim: channels of the guidance tensor.
        feat_dim: stored as ``self.feat_dim``; not used by the layers here.
        key_dim: channels of the guidance projection used for affinities.
        scale: stored as ``self.scale``; not used inside this class.
        radius: kernel radius; ``diameter = 2 * radius + 1``.
    """

    def __init__(self, guidance_dim, feat_dim, key_dim, scale=2, radius=3):
        super().__init__()
        self.scale = scale
        self.radius = radius
        self.diameter = self.radius * 2 + 1
        self.guidance_dim = guidance_dim
        self.key_dim = key_dim
        self.feat_dim = feat_dim
        # Log of the range-softmax temperature (exp(0) = 1 at init).
        self.range_temp = nn.Parameter(torch.tensor(0.0))
        self.range_proj = nn.Sequential(
            nn.Conv2d(guidance_dim, key_dim, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(key_dim, key_dim, 1, 1),
        )
        # Predicts an additive correction to the normalized kernel from the
        # kernel itself concatenated with the guidance.
        self.fixup_proj = nn.Sequential(
            nn.Conv2d(guidance_dim + self.diameter ** 2, self.diameter ** 2, 1, 1),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(self.diameter ** 2, self.diameter ** 2, 1, 1),
        )
        # Std-dev of the spatial Gaussian, in units of the [-1, 1] patch grid.
        self.sigma_spatial = nn.Parameter(torch.tensor(1.0))

    def get_range_kernel(self, x):
        """Return softmax affinities of each pixel to its neighbours.

        Output shape is ``(GB, diameter ** 2, GH, GW)``. The softmax is taken
        over the neighbourhood axis (dim 1).
        """
        GB, GC, GH, GW = x.shape
        proj_x = self.range_proj(x)
        # Reflect padding needs GH and GW to be larger than self.radius.
        proj_x_padded = F.pad(proj_x, pad=[self.radius] * 4, mode="reflect")
        # Every pixel's neighbourhood of projected keys:
        # (GB, key_dim, GH, GW, diameter ** 2).
        queries = (
            torch.nn.Unfold(self.diameter)(proj_x_padded)
            .reshape((GB, self.key_dim, self.diameter * self.diameter, GH, GW))
            .permute(0, 1, 3, 4, 2)
        )
        # Learned temperature, clamped to [1e-4, 1e4].
        pos_temp = self.range_temp.exp().clamp_min(1e-4).clamp_max(1e4)
        return F.softmax(pos_temp * torch.einsum("bchwp,bchw->bphw", queries, proj_x), dim=1)

    def get_spatial_kernel(self, device):
        """Gaussian weights over the patch grid, shape ``(1, diameter ** 2, 1, 1)``."""
        patch = _meshgrid(device, self.diameter)
        return torch.exp(-patch.square().sum(0) / (2 * self.sigma_spatial ** 2)).reshape(
            1, self.diameter * self.diameter, 1, 1
        )

    def forward(self, source, guidance):
        """Upsample ``source`` to the spatial size of ``guidance``.

        Args:
            source: 4-D features ``(B, C, h, w)``.
            guidance: 4-D tensor ``(B, guidance_dim, GH, GW)``. Its spatial
                size sets the output resolution.

        Returns:
            The adaptive-conv result, ``(B, C, GH, GW)`` for the PyTorch fallback.

        Raises:
            ValueError: if ``source`` and ``guidance`` have different batch sizes.
        """
        GB, GC, GH, GW = guidance.shape
        SB, _, _, _ = source.shape
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if SB != GB:
            raise ValueError(
                f"source and guidance must have the same batch size, got {SB} and {GB}"
            )
        dtype = source.dtype
        guidance = guidance.to(dtype)
        spatial_kernel = self.get_spatial_kernel(source.device).to(dtype)
        range_kernel = self.get_range_kernel(guidance).to(dtype)
        combined_kernel = (range_kernel * spatial_kernel).to(dtype)
        # Normalize so each pixel's neighbourhood weights sum to 1 (the clamp guards against /0).
        combined_kernel /= combined_kernel.sum(1, keepdim=True).clamp(1e-7)
        # Learned residual correction of the normalized kernel.
        combined_kernel += 0.1 * self.fixup_proj(torch.cat([combined_kernel, guidance], dim=1))
        # (GB, d*d, GH, GW) -> (GB, GH, GW, d, d), the filter layout of the adaptive conv.
        combined_kernel = combined_kernel.permute(0, 2, 3, 1).reshape(
            GB, GH, GW, self.diameter, self.diameter
        )
        hr_source = F.interpolate(source, size=(GH, GW), mode="bicubic", align_corners=False)
        # Pad by the radius so the unpadded correlation returns a GH x GW map.
        hr_source_padded = F.pad(hr_source, pad=[self.radius] * 4, mode="reflect")
        combined_kernel = combined_kernel.to(hr_source_padded.dtype)
        if AdaptiveConv is not None:
            result = AdaptiveConv.apply(hr_source_padded, combined_kernel)
        else:
            result = adaptive_conv_py_simple(hr_source_padded, combined_kernel)
        return result
class JBUStack(torch.nn.Module):
    """Upsample 16x with four independent 2x ``JBULearnedRange`` stages.

    Each stage is guided by ``guidance`` average-pooled to twice the current
    feature size. A 1x1 conv adds a 0.1-scaled residual at the end.
    """

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Separate weights per stage: 3 guidance channels, 32-dim keys,
        # radius 3 (7x7 kernels). The names up1..up4 are state_dict keys.
        self.up1 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up2 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up3 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.up4 = JBULearnedRange(3, feat_dim, 32, radius=3)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one stage with guidance pooled to twice ``source``'s spatial size."""
        target_size = tuple(2 * side for side in source.shape[-2:])
        pooled_guidance = F.adaptive_avg_pool2d(guidance, target_size)
        return up(source, pooled_guidance)

    def forward(self, source, guidance):
        feats = source
        for stage in (self.up1, self.up2, self.up3, self.up4):
            feats = self.upsample(feats, guidance, stage)
        return feats + 0.1 * self.fixup_proj(feats)
class JBUOne(torch.nn.Module):
    """Upsample 16x by applying one shared 2x ``JBULearnedRange`` four times.

    The single layer uses radius 5 (11x11 kernels). Each application is guided
    by ``guidance`` average-pooled to twice the current feature size. A 1x1
    conv adds a 0.1-scaled residual at the end.
    """

    def __init__(self, feat_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.up = JBULearnedRange(3, feat_dim, 32, radius=5)
        self.fixup_proj = nn.Sequential(
            nn.Dropout2d(0.2),
            nn.Conv2d(feat_dim, feat_dim, kernel_size=1),
        )

    def upsample(self, source, guidance, up):
        """Run one 2x step with guidance pooled to twice ``source``'s spatial size."""
        target_size = tuple(2 * side for side in source.shape[-2:])
        pooled_guidance = F.adaptive_avg_pool2d(guidance, target_size)
        return up(source, pooled_guidance)

    def forward(self, source, guidance):
        feats = source
        # Same weights at every step.
        for _ in range(4):
            feats = self.upsample(feats, guidance, self.up)
        return feats + 0.1 * self.fixup_proj(feats)
# Checkpoint file paths for the SimFeatUp upsamplers, keyed by variant name.
FEATUP_CHECKPOINTS = dict(
    jbu_one="simfeatup/xclip_jbu_one_million_aid.ckpt",
    jbu_stack="simfeatup/clip_jbu_stack_cocostuff.ckpt",
    jbu_stack_maskclip="simfeatup/maskclip_jbu_stack_cocostuff.ckpt",
)
def get_upsampler(name: str, feat_dim: int):
    """Instantiate the upsampler registered under ``name``.

    Args:
        name: one of ``"bilinear"``, ``"jbu_one"``, ``"jbu_stack"``,
            ``"resize_conv"`` or ``"ifa"``.
        feat_dim: channel count of the features to upsample (unused by
            ``"bilinear"``).

    Returns:
        A freshly constructed upsampler module.

    Raises:
        ValueError: if ``name`` is not one of the names above.
    """
    # Zero-argument factories, so only the requested module gets constructed.
    factories = {
        "bilinear": lambda: Bilinear(),
        "jbu_one": lambda: JBUOne(feat_dim),
        "jbu_stack": lambda: JBUStack(feat_dim),
        "resize_conv": lambda: LayeredResizeConv(feat_dim, 1),
        "ifa": lambda: IFA(feat_dim),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"Unknown upsampler: {name}. Use: {', '.join(factories)}")
    return factory()