| from typing import Optional |
|
|
| import torch |
| from torch import Tensor |
| from torch.nn import Linear, Module |
| from transformers import PreTrainedModel |
|
|
| from .encoder import MarlinEncoder |
| from .decoder import MarlinDecoder |
|
|
| from .config import MarlinConfig |
|
|
|
|
class Marlin(Module):
    """Encoder with an optional decoder and encoder-to-decoder projection.

    When ``as_feature_extractor`` is true only the encoder is built;
    ``decoder`` and ``enc_dec_proj`` are set to ``None`` and ``forward``
    refuses to run. Otherwise a ``MarlinDecoder`` and a bias-free ``Linear``
    projection from ``encoder_embed_dim`` to ``decoder_embed_dim`` are
    created as well.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        n_frames: int,
        encoder_embed_dim: int,
        encoder_depth: int,
        encoder_num_heads: int,
        decoder_embed_dim: int,
        decoder_depth: int,
        decoder_num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop_rate: float,
        attn_drop_rate: float,
        norm_layer: str,
        init_values: float,
        tubelet_size: int,
        as_feature_extractor: bool = True,
    ):
        """Build the encoder and, unless used as a feature extractor, the decoder.

        The ``encoder_*`` arguments are forwarded to ``MarlinEncoder`` and the
        ``decoder_*`` arguments to ``MarlinDecoder``; the remaining
        hyper-parameters are passed to both. ``n_frames`` is passed to the
        encoder only and is also stored as ``clip_frames``.
        """
        super().__init__()
        self.encoder = MarlinEncoder(
            img_size=img_size,
            patch_size=patch_size,
            n_frames=n_frames,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
        )
        self.as_feature_extractor = as_feature_extractor
        self.clip_frames = n_frames
        if as_feature_extractor:
            # Feature extraction only needs the encoder.
            self.enc_dec_proj = None
            self.decoder = None
        else:
            # Assignment order is kept (decoder, then projection) so the
            # registered submodule / state_dict order does not change.
            self.decoder = MarlinDecoder(
                img_size=img_size,
                patch_size=patch_size,
                embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                norm_layer=norm_layer,
                init_values=init_values,
                tubelet_size=tubelet_size,
            )
            self.enc_dec_proj = Linear(encoder_embed_dim, decoder_embed_dim, bias=False)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Run encoder, projection and decoder on ``x`` using ``mask``.

        Args:
            x: Input tensor handed to the encoder.
            mask: Mask handed to both the encoder and the decoder. Required.

        Returns:
            The decoder output.

        Raises:
            RuntimeError: If the model was built with
                ``as_feature_extractor=True`` (no decoder exists).
            ValueError: If ``mask`` is ``None``.
        """
        if self.as_feature_extractor:
            raise RuntimeError(
                "For feature extraction, please use `extract_features` or `extract_video`."
            )
        # Explicit check instead of `assert`: asserts are removed under
        # `python -O`, which would let a missing mask reach the encoder.
        if mask is None:
            raise ValueError(
                "`mask` must be provided for the encoder-decoder forward pass."
            )
        x = self.encoder(x, mask)
        x = self.enc_dec_proj(x)
        x = self.decoder(x, mask)
        return x

    @property
    def device(self):
        """Device on which the encoder's ``norm.weight`` parameter lives."""
        return self.encoder.norm.weight.device

    def extract_features(self, x: Tensor, keep_seq: bool = True):
        """Extract features for one video clip (v).

        Delegates to ``self.encoder.extract_features`` with
        ``seq_mean_pool=not keep_seq``. In eval mode the call runs with
        gradients disabled; in training mode the ambient gradient mode is
        left untouched.
        """
        # Training: keep whatever grad mode the caller set. Eval: force off.
        grad_mode = self.training and torch.is_grad_enabled()
        with torch.set_grad_enabled(grad_mode):
            return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)
|
|
|
|
class MarlinModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes a ``Marlin`` feature extractor.

    The wrapped ``Marlin`` is built from the fields of a ``MarlinConfig`` and
    is stored as ``self.marlin``. ``as_feature_extractor`` is not passed, so
    ``Marlin`` uses its own default for it.
    """

    config_class = MarlinConfig

    def __init__(self, config: MarlinConfig):
        """Create the wrapped ``Marlin`` from the hyper-parameters in ``config``."""
        super().__init__(config)
        self.config = config

        # Names shared one-to-one between MarlinConfig attributes and
        # Marlin constructor keyword arguments.
        field_names = (
            "img_size",
            "patch_size",
            "n_frames",
            "encoder_embed_dim",
            "encoder_depth",
            "encoder_num_heads",
            "decoder_embed_dim",
            "decoder_depth",
            "decoder_num_heads",
            "mlp_ratio",
            "qkv_bias",
            "qk_scale",
            "drop_rate",
            "attn_drop_rate",
            "norm_layer",
            "init_values",
            "tubelet_size",
        )
        marlin_kwargs = {}
        for field in field_names:
            marlin_kwargs[field] = getattr(config, field)
        self.marlin = Marlin(**marlin_kwargs)

    def forward(self, x: Tensor, keep_seq: bool = True):
        """Return ``self.marlin.extract_features(x, keep_seq=keep_seq)``."""
        features = self.marlin.extract_features(x, keep_seq=keep_seq)
        return features
|
|