| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import TransformerEncoder, TransformerEncoderLayer |
|
|
| from transformers import PreTrainedModel |
| from configuration_neuroclr import NeuroCLRConfig |
|
|
|
|
class NeuroCLR(nn.Module):
    """
    Transformer-based time-series encoder with a projection head.

    The transformer uses the last axis as its model dimension
    (d_model = TSlength), so inputs must be shaped [B, S, TSlength].
    forward() returns the pooled representation ``h`` together with its
    contrastive projection ``z``.
    """

    def __init__(self, config: NeuroCLRConfig):
        super().__init__()

        layer = TransformerEncoderLayer(
            d_model=config.TSlength,
            dim_feedforward=2 * config.TSlength,
            nhead=config.nhead,
            batch_first=True,
        )
        self.transformer_encoder = TransformerEncoder(layer, config.nlayer)

        # Projection head: maps the pooled representation h to the
        # embedding z used for the contrastive objective.
        self.projector = nn.Sequential(
            nn.Linear(config.TSlength, config.projector_out1),
            nn.BatchNorm1d(config.projector_out1),
            nn.ReLU(),
            nn.Linear(config.projector_out1, config.projector_out2),
        )

        # Cached config values consulted in forward().
        self.normalize_input = config.normalize_input
        self.pooling = config.pooling
        self.TSlength = config.TSlength

    def forward(self, x: torch.Tensor):
        """
        Encode a batch of sequences.

        Args:
            x: tensor of shape [B, S, TSlength] (see class docstring).

        Returns:
            (h, z): pooled representation and its projection.

        Raises:
            ValueError: on an unknown pooling mode, or when
                pooling='flatten' is used with a sequence length != 1.
        """
        if self.normalize_input:
            # L2-normalize each time series along its last axis.
            x = F.normalize(x, dim=-1)

        encoded = self.transformer_encoder(x)

        # Collapse the sequence axis into a single vector per sample.
        mode = self.pooling
        if mode == "mean":
            h = encoded.mean(dim=1)
        elif mode == "last":
            h = encoded[:, -1, :]
        elif mode == "flatten":
            h = encoded.reshape(encoded.shape[0], -1)
            # The projector expects TSlength features, so flattening is
            # only valid when the sequence length is exactly 1.
            if h.shape[1] != self.TSlength:
                raise ValueError(
                    f"pooling='flatten' requires seq_len==1 so h dim == TSlength. "
                    f"Got h dim {h.shape[1]} vs TSlength {self.TSlength}."
                )
        else:
            raise ValueError(f"Unknown pooling='{self.pooling}'. Use 'mean', 'last', or 'flatten'.")

        z = self.projector(h)
        return h, z
|
|
|
|
class NeuroCLRModel(PreTrainedModel):
    """
    Hugging Face wrapper around the NeuroCLR encoder.

    Loads with:
        AutoModel.from_pretrained(..., trust_remote_code=True)
    """

    config_class = NeuroCLRConfig
    base_model_prefix = "neuroclr"

    def __init__(self, config: NeuroCLRConfig):
        super().__init__(config)
        self.neuroclr = NeuroCLR(config)
        # Standard PreTrainedModel hook (weight init / final setup).
        self.post_init()

    def forward(self, x: torch.Tensor, **kwargs):
        """
        Run the encoder and return its outputs as a dict.

        Args:
            x: input tensor forwarded unchanged to NeuroCLR
               (expected [B, S, TSlength] per the encoder's contract).
            **kwargs: accepted for AutoModel compatibility; ignored here.

        Returns:
            {"h": pooled representation, "z": projected embedding}
        """
        representation, projection = self.neuroclr(x)
        return {"h": representation, "z": projection}
|
|