| from dataclasses import dataclass |
|
|
| import numpy as np |
| import timm |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange, repeat |
| from segmentation_models_pytorch.base import SegmentationHead |
| from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder |
| from timm.layers.create_act import create_act_layer |
| from transformers import PretrainedConfig, PreTrainedModel |
| from transformers.modeling_outputs import SemanticSegmenterOutput |
|
|
| from .convlstm import ConvLSTM |
|
|
|
|
class ACTUConfig(PretrainedConfig):
    """Configuration for the ACTU segmentation model.

    Bundles the image-encoder/ConvLSTM settings with the optional DEM input
    and climate-branch hyperparameters consumed by
    ``ACTUForImageSegmentation``.
    """

    model_type = "actu"

    def __init__(
        self,
        in_channels: int = 3,
        kernel_size: tuple[int, int] = (3, 3),
        padding="same",
        stride=(1, 1),
        backbone="resnet34",
        bias=True,
        batch_first=True,
        bidirectional=False,
        original_resolution=(256, 256),
        act_layer="sigmoid",
        n_classes=1,
        use_dem_input: bool = False,
        use_climate_branch: bool = False,
        climate_seq_len=5,
        climate_input_dim=6,
        lstm_hidden_dim=128,
        num_lstm_layers=1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # ConvLSTM temporal-encoder settings.
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.bias = bias
        self.batch_first = batch_first
        self.bidirectional = bidirectional

        # Image encoder / segmentation-head settings.
        self.backbone = backbone
        self.original_resolution = original_resolution
        self.act_layer = act_layer
        self.n_classes = n_classes

        # Optional input branches.
        self.use_dem_input = use_dem_input
        self.use_climate_branch = use_climate_branch

        # Climate-branch LSTM settings.
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers

        # The DEM raster is concatenated as one extra input channel.
        # NOTE(review): this mutates the stored value, so a config that is
        # saved after the bump and re-loaded with use_dem_input=True would be
        # incremented a second time — confirm round-trip behavior.
        self.in_channels = in_channels + 1 if use_dem_input else in_channels
|
|
|
|
class ACTUForImageSegmentation(PreTrainedModel):
    """U-Net-style segmentation model over image time series (ACTU).

    Each frame of a (B, T, C, H, W) sequence is encoded by a timm backbone;
    the per-scale feature sequences are aggregated over time by one ConvLSTM
    per scale, optionally gated with climate features, then decoded by a
    ``UnetDecoder`` and a segmentation head back to input resolution.
    """

    config_class = ACTUConfig

    def __init__(self, config: ACTUConfig):
        super().__init__(config)
        self.config = config

        # Multi-scale frame encoder; features_only yields one map per stage.
        self.encoder: nn.Module = timm.create_model(
            config.backbone, features_only=True, in_chans=config.in_channels
        )

        # Probe the encoder once with a dummy frame to discover the shape and
        # channel count of every feature stage.
        with torch.no_grad():
            dummy_input = torch.randn(
                1, config.in_channels, *config.original_resolution, device=self.device
            )
            embs = self.encoder(dummy_input)
            self.embs_shape = [e.shape for e in embs]
            self.encoder_channels = [shape[1] for shape in self.embs_shape]

        # One ConvLSTM per encoder stage, operating at that stage's width.
        self.convlstm = nn.ModuleList(
            [
                ConvLSTM(
                    in_channels=shape[1],
                    hidden_channels=shape[1],
                    kernel_size=config.kernel_size,
                    padding=config.padding,
                    stride=config.stride,
                    bias=config.bias,
                    batch_first=config.batch_first,
                    bidirectional=config.bidirectional,
                )
                for shape in self.embs_shape
            ]
        )

        if self.config.use_climate_branch:
            # Climate encoder produces one feature map per stage, plus a
            # gated fuser per stage to merge it with the image features.
            self.climate_branch = ClimateBranchLSTM(
                output_shapes=[e[1:] for e in self.embs_shape],
                lstm_hidden_dim=config.lstm_hidden_dim,
                climate_seq_len=config.climate_seq_len,
                climate_input_dim=config.climate_input_dim,
                num_lstm_layers=config.num_lstm_layers,
            )
            self.fusers = nn.ModuleList(
                GatedFusion(enc, enc) for enc in self.encoder_channels
            )

        # The leading 1 stands in for the (unused) input-resolution skip.
        self.decoder = UnetDecoder(
            encoder_channels=[1] + self.encoder_channels,
            decoder_channels=self.encoder_channels[::-1],
            n_blocks=len(self.encoder_channels),
        )

        # Head maps to n_classes and applies the configured activation
        # (sigmoid by default), so the model output is a probability map.
        self.seg_head = nn.Sequential(
            SegmentationHead(
                in_channels=self.encoder_channels[0],
                out_channels=config.n_classes,
            ),
            create_act_layer(config.act_layer, inplace=True),
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        climate: torch.Tensor = None,
        dem: torch.Tensor = None,
        labels: torch.Tensor = None,
        **kwargs,
    ) -> SemanticSegmenterOutput:
        """Segment a (B, T, C, H, W) image sequence.

        Args:
            pixel_values: image time series, shape (B, T, C, H, W).
            climate: climate series, required iff ``use_climate_branch``.
            dem: static DEM raster (B, 1, H, W), required iff ``use_dem_input``.
            labels: optional (B, H, W) targets used to compute the loss.

        Raises:
            ValueError: if a required optional input is missing.
        """
        b, t = pixel_values.shape[:2]
        original_size = pixel_values.shape[-2:]

        # Append the static DEM as an extra channel on every frame.
        if self.config.use_dem_input:
            if dem is None:
                raise ValueError(
                    "DEM tensor must be provided when use_dem_input is True."
                )
            dem_repeated = repeat(dem, "b c h w -> b t c h w", t=t)
            pixel_values = torch.cat([pixel_values, dem_repeated], dim=2)

        # Per-frame multi-scale features: list of (B, T, C_s, H_s, W_s).
        encoded_sequence = self._encode_images(pixel_values)

        if self.config.use_climate_branch:
            if climate is None:
                raise ValueError(
                    "Climate tensor must be provided when use_climate_branch is True."
                )

            climate_features = self.climate_branch(climate)

            # Flatten (batch, time) so the 2D fusers see one frame at a time.
            encoded_sequence_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in encoded_sequence
            ]
            climate_features_reshaped = [
                rearrange(f, "b t c h w -> (b t) c h w") for f in climate_features
            ]

            fused_features = [
                fuser(img, clim)
                for fuser, img, clim in zip(
                    self.fusers, encoded_sequence_reshaped, climate_features_reshaped
                )
            ]

            # Restore the (batch, time) split.
            encoded_sequence = [
                rearrange(f, "(b t) c h w -> b t c h w", b=b) for f in fused_features
            ]

        # Collapse the time axis with the per-scale ConvLSTMs.
        temporal_features = self._encode_timeseries(encoded_sequence)

        # Decode back to the input resolution.
        logits = self._decode(temporal_features, size=original_size)

        loss = None
        if labels is not None:
            if self.config.n_classes == 1:
                # BUGFIX: CrossEntropyLoss on a single-channel output is
                # degenerate — log_softmax over one channel is identically 0,
                # so the loss was always 0 and gradients vanished.  The head
                # already applies a sigmoid by default, so binary
                # cross-entropy on probabilities is the appropriate loss.
                # (Assumes act_layer yields values in [0, 1] — confirm if a
                # non-sigmoid activation is configured.)
                loss = nn.BCELoss()(logits, labels.float().unsqueeze(1))
            else:
                # Multi-class: CE expects integer class indices of (B, H, W).
                loss = nn.CrossEntropyLoss()(logits, labels.long())

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
        )

    def _encode_images(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode every frame independently; return one tensor per stage,
        each shaped (B, T, C_s, H_s, W_s)."""
        B = x.size(0)
        encoded_frames = self.encoder(rearrange(x, "b t c h w -> (b t) c h w"))
        return [
            rearrange(frames, "(b t) c h w -> b t c h w", b=B)
            for frames in encoded_frames
        ]

    def _encode_timeseries(self, timeseries: list[torch.Tensor]) -> list[torch.Tensor]:
        """Run each stage's sequence through its ConvLSTM and keep the last
        time step; stages are returned deepest-first (reversed)."""
        outs = []
        for convlstm, encoded in reversed(list(zip(self.convlstm, timeseries))):
            lstm_out, (_, _) = convlstm(encoded)
            outs.append(lstm_out[:, -1, :, :, :])
        return outs

    def _decode(self, x: list[torch.Tensor], size: tuple[int, int]) -> torch.Tensor:
        """Decode deepest-first stage features to a map at ``size``.

        ``x[::-1]`` restores shallow-first order; the leading None stands in
        for the unused input-resolution skip connection.
        """
        trend_map = self.decoder(*[None] + x[::-1])
        trend_map = self.seg_head(trend_map)
        trend_map = F.interpolate(
            trend_map, size=size, mode="bilinear", align_corners=False
        )
        return trend_map
|
|
|
|
class ClimateBranchLSTM(nn.Module):
    """
    Processes climate time series data using an LSTM.

    Input shape: (B_img, B_cli, T, C_clim) -> e.g., (B, 5, 6, 5)
    Output: one map per encoder stage, each (B_img, B_cli, C_s, H_s, H_s),
    built by projecting the final LSTM hidden state and upsampling it to the
    stage's spatial size (square spatial output — see ``_build_upsampler``).
    """

    def __init__(
        self,
        output_shapes: list[tuple[int, int, int]],
        climate_input_dim=5,
        climate_seq_len=6,
        lstm_hidden_dim=64,
        num_lstm_layers=1,
    ):
        super().__init__()
        self.climate_seq_len = climate_seq_len
        self.climate_input_dim = climate_input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.num_lstm_layers = num_lstm_layers
        # Fixed projection width fed to every per-stage upsampler.
        self.proj_dim = 128
        self.output_shapes = output_shapes

        self.lstm = nn.LSTM(
            input_size=climate_input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=num_lstm_layers,
            batch_first=True,
            # Inter-layer dropout is only defined for stacked LSTMs.
            dropout=0.3 if num_lstm_layers > 1 else 0,
            bidirectional=False,
        )

        # Project the final hidden state to the upsampler input width.
        self.fc = nn.Linear(lstm_hidden_dim, self.proj_dim)

        # One upsampler per encoder stage: (proj_dim, 1, 1) -> (C_s, H_s, H_s).
        self.upsamples = nn.ModuleList(
            _build_upsampler(self.proj_dim, *shape[:2]) for shape in output_shapes
        )

    def forward(self, climate_data: torch.Tensor) -> list[torch.Tensor]:
        B_img, B_cli, T, C = climate_data.shape

        # Fold the two leading batch dims so the LSTM sees (N, T, C).
        lstm_input = rearrange(climate_data, "Bi Bc T C -> (Bi Bc) T C")

        # hidden: (num_layers * num_directions, N, hidden_dim).
        # Call the module, not .forward(), so nn.Module hooks still fire.
        _, (hidden, _) = self.lstm(lstm_input)
        # Last layer's hidden state; a bidirectional LSTM contributes two
        # direction states that are averaged below.  (self.lstm is built
        # unidirectional above, so only the else-branch is exercised here.)
        last_hidden = (
            hidden[[hidden.size(0) // 2, -1]] if self.lstm.bidirectional else hidden[-1]
        )
        if last_hidden.ndim == 3:
            # BUGFIX: previously averaged over *all* of `hidden` (every layer
            # and direction) instead of the two selected direction states.
            last_hidden = last_hidden.mean(dim=0)

        # Project and broadcast to a 1x1 map, then upsample per stage.
        climate_features = self.fc(last_hidden)
        climate_features = rearrange(climate_features, "b c -> b c 1 1")
        climate_features = [
            rearrange(
                u(climate_features), "(Bi Bc) C H W -> Bi Bc C H W", Bi=B_img, Bc=B_cli
            )
            for u in self.upsamples
        ]

        return climate_features
|
|
|
|
class GatedFusion(nn.Module):
    """Fuse image and climate feature maps with a learned per-pixel gate.

    out = g * img + (1 - g) * clim, where g = sigmoid(conv(cat(img, clim)))
    is in (0, 1), so the output is an element-wise convex combination of the
    two inputs.
    """

    def __init__(self, img_channels, clim_channels):
        super().__init__()
        # Flattened from a redundant doubly-nested nn.Sequential; the
        # computation is identical, but state-dict key paths change from
        # gate.0.0.* to gate.0.* (old checkpoints need key remapping).
        self.gate = nn.Sequential(
            nn.Conv2d(
                img_channels + clim_channels, img_channels, kernel_size=3, padding=1
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(img_channels, img_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, img_feat, clim_feat):
        """Blend the two (N, C, H, W) feature maps via the learned gate."""
        gate = self.gate(torch.cat([img_feat, clim_feat], dim=1))
        return gate * img_feat + (1 - gate) * clim_feat
|
|
|
|
| def _build_upsampler( |
| in_channels: int, target_channels: int, target_h: int |
| ) -> nn.Sequential: |
| layers = [] |
| current_h = 1 |
|
|
| |
| layers += [nn.Conv2d(in_channels, target_channels, kernel_size=1), nn.GELU()] |
|
|
| |
| while current_h < target_h: |
| next_h = min(current_h * 2, target_h) |
| layers += [ |
| nn.Upsample(scale_factor=2, mode="nearest"), |
| nn.Conv2d(target_channels, target_channels, kernel_size=3, padding=1), |
| nn.GELU(), |
| ] |
| current_h = next_h |
|
|
| return nn.Sequential(*layers) |
|
|