mineself2016 committed on
Commit
c174f3b
·
verified ·
1 Parent(s): 2fde376

Unify repo: default 24l-512d at root, add size variants via subfolder

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. 24l-512d/config.json +28 -0
  2. 24l-512d/configuration_genemamba.py +97 -0
  3. 24l-512d/model.safetensors +3 -0
  4. 24l-512d/modeling_genemamba.py +395 -0
  5. 24l-512d/modeling_outputs.py +81 -0
  6. 24l-512d/special_tokens_map.json +4 -0
  7. 24l-512d/tokenizer.json +0 -0
  8. 24l-512d/tokenizer_config.json +8 -0
  9. 24l-768d/config.json +28 -0
  10. 24l-768d/configuration_genemamba.py +97 -0
  11. 24l-768d/model.safetensors +3 -0
  12. 24l-768d/modeling_genemamba.py +395 -0
  13. 24l-768d/modeling_outputs.py +81 -0
  14. 24l-768d/special_tokens_map.json +4 -0
  15. 24l-768d/tokenizer.json +0 -0
  16. 24l-768d/tokenizer_config.json +8 -0
  17. 48l-512d/config.json +28 -0
  18. 48l-512d/configuration_genemamba.py +97 -0
  19. 48l-512d/model.safetensors +3 -0
  20. 48l-512d/modeling_genemamba.py +395 -0
  21. 48l-512d/modeling_outputs.py +81 -0
  22. 48l-512d/special_tokens_map.json +4 -0
  23. 48l-512d/tokenizer.json +0 -0
  24. 48l-512d/tokenizer_config.json +8 -0
  25. 48l-768d/config.json +28 -0
  26. 48l-768d/configuration_genemamba.py +97 -0
  27. 48l-768d/model.safetensors +3 -0
  28. 48l-768d/modeling_genemamba.py +395 -0
  29. 48l-768d/modeling_outputs.py +81 -0
  30. 48l-768d/special_tokens_map.json +4 -0
  31. 48l-768d/tokenizer.json +0 -0
  32. 48l-768d/tokenizer_config.json +8 -0
  33. README.md +133 -0
  34. config.json +28 -0
  35. configuration_genemamba.py +97 -0
  36. examples/00_preprocess_to_input_ids.py +75 -0
  37. examples/01_extract_embeddings.py +150 -0
  38. examples/downstream/10_finetune_classification.py +248 -0
  39. examples/downstream/11_zero_shot_logreg.py +98 -0
  40. examples/downstream/12_batch_integration_eval.py +79 -0
  41. examples/downstream/20_continue_pretraining_reference.py +265 -0
  42. examples/downstream/21_pretrain_from_scratch_reference.py +280 -0
  43. examples/downstream/README.md +35 -0
  44. examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_with_label.py +378 -0
  45. examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label.py +161 -0
  46. examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py +197 -0
  47. model.safetensors +3 -0
  48. modeling_genemamba.py +395 -0
  49. modeling_outputs.py +81 -0
  50. special_tokens_map.json +4 -0
24l-512d/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "genemamba",
3
+ "architectures": [
4
+ "GeneMambaModel"
5
+ ],
6
+ "vocab_size": 25426,
7
+ "max_position_embeddings": 2048,
8
+ "hidden_size": 512,
9
+ "num_hidden_layers": 24,
10
+ "intermediate_size": 2048,
11
+ "hidden_dropout_prob": 0.1,
12
+ "initializer_range": 0.02,
13
+ "mamba_mode": "gate",
14
+ "embedding_pooling": "mean",
15
+ "num_labels": 2,
16
+ "pad_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "bos_token_id": 0,
19
+ "use_cache": true,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.40.2",
22
+ "auto_map": {
23
+ "AutoConfig": "configuration_genemamba.GeneMambaConfig",
24
+ "AutoModel": "modeling_genemamba.GeneMambaModel",
25
+ "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
26
+ "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
27
+ }
28
+ }
24l-512d/configuration_genemamba.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for GeneMamba model.
3
+ Defines all hyperparameters and settings for the GeneMamba architecture.
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional
8
+
9
+
10
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for GeneMamba model.

    This class stores the configuration of a GeneMamba model, inheriting from
    PretrainedConfig. It can be used to instantiate models from pretrained
    checkpoints or customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers.
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length (seq_len).
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of the normal weight initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling to get the cell embedding.
            Options: "CLS", "mean" (the modeling code raises for other values).
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Forward the special-token ids and num_labels to PretrainedConfig so
        # the base class sets up id2label/label2id and token-id attributes
        # itself. Previously bos/eos/num_labels were only assigned *after*
        # super().__init__, bypassing that machinery. The former identity
        # attribute_map ({"hidden_size": "hidden_size", ...}) was a no-op and
        # has been dropped.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            num_labels=num_labels,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
24l-512d/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
3
+ size 262998656
24l-512d/modeling_genemamba.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of GeneMamba model for Hugging Face Transformers.
3
+ Includes backbone model and task-specific heads for various downstream tasks.
4
+ """
5
+
6
+ import math
7
+ import logging
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_, constant_
14
+
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
17
+ from transformers.models.auto import register_model_for_auto_class
18
+
19
+ from mamba_ssm import Mamba
20
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm
21
+
22
+ from .configuration_genemamba import GeneMambaConfig
23
+ from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ===========================
29
+ # Core Architecture Components
30
+ # ===========================
31
+
32
class EncoderLayer(nn.Module):
    """One residual Mamba block: ``x -> mamba(x) + x``.

    Args:
        hidden_size (int): Dimension of the hidden states fed to the
            Mamba layer (d_model).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyperparameters; these must match the pretrained
        # checkpoints shipped alongside this code.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the Mamba layer and add the residual input.

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Tensor of the same shape as ``X``.
        """
        residual = X
        return self.mamba(X) + residual
55
+
56
+
57
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.

    Each layer processes the sequence in the forward direction and, with the
    *same shared weights*, a length-aware reversed copy; the two outputs are
    then combined according to ``mode``.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a learned projection from 2*hidden_size
        # back down to hidden_size; "mean"/"sum" are parameter-free.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse a sequence based on actual length (ignoring padding).

        NOTE(review): ``lengths = (~mask).sum(dim=1)`` counts positions where
        ``mask`` is False as real tokens, i.e. this function treats True as
        *padding*. However, GeneMambaModel.forward passes an HF-style
        attention_mask (1 = real token) straight into this module — confirm
        which convention the pretrained weights were trained with before
        changing anything.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: Tensor with the first ``length`` positions of each
            row reversed; positions beyond ``length`` keep their original slot.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No mask: plain full-sequence reversal.
            return X.flip([1])

        # Per-row token counts (see NOTE above about mask polarity).
        lengths = (~mask).sum(dim=1)
        # pos_tensor[b, i] = i for every batch row.
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        # True only inside the non-padded prefix of each row.
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        # Inside the prefix map i -> length-1-i; outside keep i (padding stays put).
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        # Gather along the sequence axis using the computed permutation.
        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process sequence through bidirectional Mamba layers.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask; forwarded
                unchanged to :meth:`flip_sequence` (see polarity NOTE there).

        Returns:
            torch.Tensor: Output of shape (batch_size, seq_len, hidden_size)
            after all layers and per-layer forward/backward aggregation.
        """

        for layer in self.layers:
            # Flip sequence for reverse processing
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes share the same layer weights.
            X_f = layer(X)
            X_b = layer(X_flip)

            # Flip the reverse output back so positions align with X_f.
            X_b = self.flip_sequence(X_b, padding_mask)

            # Aggregate forward and reverse
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned convex combination: z in (0,1) per feature.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
159
+
160
+
161
+ # ===========================
162
+ # Base Model Classes
163
+ # ===========================
164
+
165
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.

    Handles weight initialization and provides the standard Hugging Face
    model interface (config handling, save/load, gradient checkpointing).
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule.

        Linear and Embedding weights are drawn from N(0, initializer_range);
        biases are zeroed, the padding embedding row is zeroed, and LayerNorm
        is reset to the identity transform (weight=1, bias=0).
        """
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Fix: LayerNorm(elementwise_affine=False) has weight/bias = None;
            # the previous unconditional constant_ calls would crash on it.
            if module.weight is not None:
                constant_(module.weight, 1.0)
            if module.bias is not None:
                constant_(module.bias, 0.0)
188
+
189
+
190
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Token embedding; the pad row is kept at zero via padding_idx.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Stack of bidirectional Mamba layers with configurable aggregation.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final RMS layer normalization.
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Mask of shape
                (batch_size, seq_len). The mean-pooling branch below treats
                1 as "real token"; NOTE(review): the same tensor is passed to
                MambaMixer.flip_sequence, which treats True as *padding* —
                confirm the intended convention matches the trained weights.
            output_hidden_states (bool): Whether to attach hidden states to
                the output. NOTE(review): only the final-layer tensor is
                returned (not a per-layer tuple as the output class docs imply).

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Masked mean over the sequence axis. NOTE(review): an all-zero
            # attention_mask row would divide by zero here.
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            # "weighted" is documented in the config but not implemented here.
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
270
+
271
+
272
+ # ===========================
273
+ # Task-Specific Models
274
+ # ===========================
275
+
276
# Fix: this head was previously registered under "AutoModel", which conflicts
# with config.json's auto_map (AutoModel -> GeneMambaModel) and with the
# explicit AutoModelForMaskedLM registration at the bottom of this file.
@register_model_for_auto_class("AutoModelForMaskedLM")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba model for masked language modeling (MLM).
    Suitable for pretraining and domain adaptation.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Language modeling head: hidden states -> vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for the MLM loss.
                Positions set to -100 are ignored (CrossEntropyLoss default
                ignore_index).
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and optional loss.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        # Per-position vocabulary scores.
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        return GeneMambaMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )
330
+
331
+
332
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    Sequence-level classifier on top of the GeneMamba backbone.

    Suitable for tasks that predict one label per input sequence, such as
    cell type annotation or tissue classification.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head: dropout followed by a linear projection of the
        # pooled (cell-level) embedding onto the label space.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class labels for the
                cross-entropy classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss,
            and the pooled cell embedding.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        cell_embedding = backbone_out.pooled_embedding
        logits = self.classifier(self.dropout(cell_embedding))

        # Loss is only computed when labels are supplied.
        if labels is None:
            loss = None
        else:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=cell_embedding,
        )
392
+
393
+
394
# Register the masked-LM head for AutoModelForMaskedLM auto-class discovery
# (the original comment incorrectly said "tokenizer").
# NOTE(review): `register_model_for_auto_class` is imported from
# transformers.models.auto at the top of this file — verify it exists in the
# pinned transformers version (4.40.2); the documented mechanism is
# `GeneMambaForMaskedLM.register_for_auto_class(...)` or the `auto_map`
# entries already present in config.json.
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
24l-512d/modeling_outputs.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom ModelOutput classes for GeneMamba.
3
+ Defines the output structure for different GeneMamba tasks.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+ import torch
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
    """

    # Annotated Optional because every field defaults to None until the
    # model's forward pass populates it.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    embedding_pooling: str = "mean"
36
+
37
+
38
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional: None until populated by the forward pass.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
61
+
62
+
63
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional: None until populated by the forward pass.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
24l-512d/special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]"
4
+ }
24l-512d/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
24l-512d/tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "clean_up_tokenization_spaces": true,
4
+ "model_max_length": 1000000000000000019884624838656,
5
+ "pad_token": "[PAD]",
6
+ "tokenizer_class": "PreTrainedTokenizerFast",
7
+ "unk_token": "[UNK]"
8
+ }
24l-768d/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "genemamba",
3
+ "architectures": [
4
+ "GeneMambaModel"
5
+ ],
6
+ "vocab_size": 25426,
7
+ "max_position_embeddings": 2048,
8
+ "hidden_size": 768,
9
+ "num_hidden_layers": 24,
10
+ "intermediate_size": 2048,
11
+ "hidden_dropout_prob": 0.1,
12
+ "initializer_range": 0.02,
13
+ "mamba_mode": "gate",
14
+ "embedding_pooling": "mean",
15
+ "num_labels": 2,
16
+ "pad_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "bos_token_id": 0,
19
+ "use_cache": true,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.40.2",
22
+ "auto_map": {
23
+ "AutoConfig": "configuration_genemamba.GeneMambaConfig",
24
+ "AutoModel": "modeling_genemamba.GeneMambaModel",
25
+ "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
26
+ "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
27
+ }
28
+ }
24l-768d/configuration_genemamba.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for GeneMamba model.
3
+ Defines all hyperparameters and settings for the GeneMamba architecture.
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional
8
+
9
+
10
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for GeneMamba model.

    This class stores the configuration of a GeneMamba model, inheriting from
    PretrainedConfig. It can be used to instantiate models from pretrained
    checkpoints or customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers.
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length (seq_len).
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of the normal weight initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling to get the cell embedding.
            Options: "CLS", "mean" (the modeling code raises for other values).
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Forward the special-token ids and num_labels to PretrainedConfig so
        # the base class sets up id2label/label2id and token-id attributes
        # itself. Previously bos/eos/num_labels were only assigned *after*
        # super().__init__, bypassing that machinery. The former identity
        # attribute_map ({"hidden_size": "hidden_size", ...}) was a no-op and
        # has been dropped.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            num_labels=num_labels,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
24l-768d/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b423a3555eecacc88ff587c1d3f689a2caa05ede0a01d09dbaae175f23a2e7e1
3
+ size 508241792
24l-768d/modeling_genemamba.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of GeneMamba model for Hugging Face Transformers.
3
+ Includes backbone model and task-specific heads for various downstream tasks.
4
+ """
5
+
6
+ import math
7
+ import logging
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_, constant_
14
+
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
17
+ from transformers.models.auto import register_model_for_auto_class
18
+
19
+ from mamba_ssm import Mamba
20
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm
21
+
22
+ from .configuration_genemamba import GeneMambaConfig
23
+ from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ===========================
29
+ # Core Architecture Components
30
+ # ===========================
31
+
32
class EncoderLayer(nn.Module):
    """
    One Mamba block wrapped with a residual (skip) connection.

    Args:
        hidden_size (int): Dimension of the hidden states (Mamba's d_model).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyperparameters: state dim 64, conv kernel 4, expansion 2.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Mamba block and add the input back (residual connection).

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Same shape as the input.
        """
        return X + self.mamba(X)
55
+
56
+
57
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.
    Processes sequences in both forward and reverse directions, then aggregates.

    Each EncoderLayer is shared between the forward and the reverse pass
    (the same weights process both directions).

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a learned projection from 2*hidden to hidden.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence up to its actual length, leaving padding in place.

        NOTE(review): ``(~mask).sum(dim=1)`` treats truthy entries of ``mask``
        as PADDING positions. This is the opposite of the Hugging Face
        ``attention_mask`` convention (1 = real token), and ``~`` on an
        integer 0/1 tensor yields -1/-2 rather than a boolean inversion.
        Confirm that callers pass a boolean padding mask (True = pad) — TODO.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: Reversed tensor; padded tail positions keep their index.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding info: plain reversal along the time dimension.
            return X.flip([1])

        # Number of real (non-pad) tokens per sequence — see NOTE above.
        lengths = (~mask).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        # Positions inside the valid prefix are mirrored; padding positions
        # keep their original index so padding stays at the end.
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process sequence through bidirectional Mamba layers.

        Each layer runs the sequence left-to-right and (via flipping)
        right-to-left with the same weights, then merges the two directions
        according to ``self.mode``.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask (see
                flip_sequence for the expected convention).

        Returns:
            torch.Tensor: Output after processing all layers and aggregation.
        """

        for layer in self.layers:
            # Flip sequence for reverse processing
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes (shared layer weights)
            X_f = layer(X)
            X_b = layer(X_flip)

            # Flip back the reverse output so positions align with X_f
            X_b = self.flip_sequence(X_b, padding_mask)

            # Aggregate forward and reverse
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned per-element convex combination of the two directions.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
159
+
160
+
161
+ # ===========================
162
+ # Base Model Classes
163
+ # ===========================
164
+
165
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.
    Handles weight initialization and provides standard model interfaces.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights.

        Linear/Embedding weights are drawn from N(0, initializer_range),
        biases and the padding embedding row are zeroed, and LayerNorm is
        reset to the identity transform (weight=1, bias=0).
        """
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be created with elementwise_affine=False, in
            # which case weight/bias are None — guard before initializing.
            if module.bias is not None:
                constant_(module.bias, 0.0)
            if module.weight is not None:
                constant_(module.weight, 1.0)
188
+
189
+
190
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Token embedding table; the padding row is zeroed at init.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Bidirectional Mamba stack with configurable direction aggregation.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final layer normalization (RMSNorm, matching the Mamba convention).
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask of shape
                (batch_size, seq_len). NOTE(review): this tensor is forwarded to
                MambaMixer.flip_sequence, which treats truthy entries as PADDING,
                while the mean pooling below treats 1 as a real token — confirm
                the intended mask convention with callers.
            output_hidden_states (bool): Whether to output hidden states from all layers.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Mean pooling over unmasked positions.
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                # Clamp the per-row token count so an all-zero mask cannot
                # cause a division by zero (which would yield NaNs).
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            # NOTE(review): this is the final-layer tensor, not the per-layer
            # tuple the output class docstring suggests — verify consumers.
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
270
+
271
+
272
+ # ===========================
273
+ # Task-Specific Models
274
+ # ===========================
275
+
276
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba with a token-level language-modeling head (MLM objective).
    Suitable for pretraining and domain adaptation.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)
        # Projects hidden states onto the gene-token vocabulary.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Run the backbone and score every position against the vocabulary.

        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for the MLM loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Token logits plus the loss when labels
            are provided.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        prediction_scores = self.lm_head(backbone_out.last_hidden_state)

        mlm_loss = None
        if labels is not None:
            # Flatten (batch, seq) into one axis for token-level cross entropy.
            mlm_loss = nn.CrossEntropyLoss()(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        return GeneMambaMaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
        )
330
+
331
+
332
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba with a linear classification head on the pooled cell embedding.
    Ideal for cell type annotation, tissue classification, etc.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head: dropout then linear projection to label space.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Classify each sequence from its pooled (cell-level) embedding.

        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class ids for the classification loss.
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Logits, optional loss, and the
            pooled embedding the classifier saw.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        cell_embedding = backbone_out.pooled_embedding
        logits = self.classifier(self.dropout(cell_embedding))

        cls_loss = None
        if labels is not None:
            cls_loss = nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=cls_loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=cell_embedding,
        )
392
+
393
+
394
# Also expose the MLM model via AutoModelForMaskedLM (the class decorator
# above already registered it under AutoModel).
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
24l-768d/modeling_outputs.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom ModelOutput classes for GeneMamba.
3
+ Defines the output structure for different GeneMamba tasks.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+ import torch
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
    """

    # Fields default to None so the dataclass can be built incrementally;
    # annotations reflect that.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    embedding_pooling: str = "mean"
36
+
37
+
38
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional annotation matches the None default used before population.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
61
+
62
+
63
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional annotation matches the None default used before population.
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
24l-768d/special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]"
4
+ }
24l-768d/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
24l-768d/tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "clean_up_tokenization_spaces": true,
4
+ "model_max_length": 1000000000000000019884624838656,
5
+ "pad_token": "[PAD]",
6
+ "tokenizer_class": "PreTrainedTokenizerFast",
7
+ "unk_token": "[UNK]"
8
+ }
48l-512d/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "genemamba",
3
+ "architectures": [
4
+ "GeneMambaModel"
5
+ ],
6
+ "vocab_size": 25426,
7
+ "max_position_embeddings": 2048,
8
+ "hidden_size": 512,
9
+ "num_hidden_layers": 48,
10
+ "intermediate_size": 2048,
11
+ "hidden_dropout_prob": 0.1,
12
+ "initializer_range": 0.02,
13
+ "mamba_mode": "gate",
14
+ "embedding_pooling": "mean",
15
+ "num_labels": 2,
16
+ "pad_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "bos_token_id": 0,
19
+ "use_cache": true,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.40.2",
22
+ "auto_map": {
23
+ "AutoConfig": "configuration_genemamba.GeneMambaConfig",
24
+ "AutoModel": "modeling_genemamba.GeneMambaModel",
25
+ "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
26
+ "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
27
+ }
28
+ }
48l-512d/configuration_genemamba.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for GeneMamba model.
3
+ Defines all hyperparameters and settings for the GeneMamba architecture.
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional
8
+
9
+
10
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for GeneMamba model.

    This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig.
    It can be used to instantiate models from pretrained checkpoints or customize model initialization.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).

        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).

        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers (mamba_layer).

        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in MLP.

        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length (seq_len).

        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.

        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of truncated normal initializer.

        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".

        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling to get cell embedding.
            Options: "CLS", "mean", "weighted".

        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.

        pad_token_id (int, optional, defaults to 1):
            Token ID for padding.

        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.

        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"
    # NOTE: the previous identity attribute_map ({"hidden_size": "hidden_size",
    # ...}) was a no-op and has been removed; PretrainedConfig defaults to an
    # empty map.

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Pass all special-token ids through PretrainedConfig so its token
        # machinery (serialization, generation defaults) sees them; they end
        # up as the same self.{pad,bos,eos}_token_id attributes as before.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
        # num_labels is a PretrainedConfig property; assigning it also builds
        # the default id2label/label2id mappings.
        self.num_labels = num_labels
48l-512d/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a715342c6cc00b20161a05941d9d181cca73c7ecc9cae17fd3a04bf92590a7d
3
+ size 421748360
48l-512d/modeling_genemamba.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of GeneMamba model for Hugging Face Transformers.
3
+ Includes backbone model and task-specific heads for various downstream tasks.
4
+ """
5
+
6
+ import math
7
+ import logging
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_, constant_
14
+
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
17
+ from transformers.models.auto import register_model_for_auto_class
18
+
19
+ from mamba_ssm import Mamba
20
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm
21
+
22
+ from .configuration_genemamba import GeneMambaConfig
23
+ from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ===========================
29
+ # Core Architecture Components
30
+ # ===========================
31
+
32
class EncoderLayer(nn.Module):
    """
    One Mamba block wrapped with a residual (skip) connection.

    Args:
        hidden_size (int): Dimension of the hidden states (Mamba's d_model).
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyperparameters: state dim 64, conv kernel 4, expansion 2.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Apply the Mamba block and add the input back (residual connection).

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Same shape as the input.
        """
        return X + self.mamba(X)
55
+
56
+
57
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.
    Processes sequences in both forward and reverse directions, then aggregates.

    Each EncoderLayer is shared between the forward and the reverse pass
    (the same weights process both directions).

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a learned projection from 2*hidden to hidden.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence up to its actual length, leaving padding in place.

        NOTE(review): ``(~mask).sum(dim=1)`` treats truthy entries of ``mask``
        as PADDING positions. This is the opposite of the Hugging Face
        ``attention_mask`` convention (1 = real token), and ``~`` on an
        integer 0/1 tensor yields -1/-2 rather than a boolean inversion.
        Confirm that callers pass a boolean padding mask (True = pad) — TODO.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: Reversed tensor; padded tail positions keep their index.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding info: plain reversal along the time dimension.
            return X.flip([1])

        # Number of real (non-pad) tokens per sequence — see NOTE above.
        lengths = (~mask).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        # Positions inside the valid prefix are mirrored; padding positions
        # keep their original index so padding stays at the end.
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process sequence through bidirectional Mamba layers.

        Each layer runs the sequence left-to-right and (via flipping)
        right-to-left with the same weights, then merges the two directions
        according to ``self.mode``.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask (see
                flip_sequence for the expected convention).

        Returns:
            torch.Tensor: Output after processing all layers and aggregation.
        """

        for layer in self.layers:
            # Flip sequence for reverse processing
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes (shared layer weights)
            X_f = layer(X)
            X_b = layer(X_flip)

            # Flip back the reverse output so positions align with X_f
            X_b = self.flip_sequence(X_b, padding_mask)

            # Aggregate forward and reverse
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned per-element convex combination of the two directions.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
159
+
160
+
161
+ # ===========================
162
+ # Base Model Classes
163
+ # ===========================
164
+
165
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.
    Handles weight initialization and provides standard model interfaces.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights.

        Linear/Embedding weights are drawn from N(0, initializer_range),
        biases and the padding embedding row are zeroed, and LayerNorm is
        reset to the identity transform (weight=1, bias=0).
        """
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding token's embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be created with elementwise_affine=False, in
            # which case weight/bias are None — guard before initializing.
            if module.bias is not None:
                constant_(module.bias, 0.0)
            if module.weight is not None:
                constant_(module.weight, 1.0)
188
+
189
+
190
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Token embedding table; the padding row is zeroed at init.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Bidirectional Mamba stack with configurable direction aggregation.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final layer normalization (RMSNorm, matching the Mamba convention).
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask of shape
                (batch_size, seq_len). NOTE(review): this tensor is forwarded to
                MambaMixer.flip_sequence, which treats truthy entries as PADDING,
                while the mean pooling below treats 1 as a real token — confirm
                the intended mask convention with callers.
            output_hidden_states (bool): Whether to output hidden states from all layers.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Mean pooling over unmasked positions.
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                # Clamp the per-row token count so an all-zero mask cannot
                # cause a division by zero (which would yield NaNs).
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            # NOTE(review): this is the final-layer tensor, not the per-layer
            # tuple the output class docstring suggests — verify consumers.
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
270
+
271
+
272
+ # ===========================
273
+ # Task-Specific Models
274
+ # ===========================
275
+
276
+ @register_model_for_auto_class("AutoModel")
277
+ class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
278
+ """
279
+ GeneMamba model for masked language modeling (MLM).
280
+ Suitable for pretraining and domain adaptation.
281
+
282
+ Args:
283
+ config (GeneMambaConfig): Model configuration class.
284
+ """
285
+
286
+ def __init__(self, config: GeneMambaConfig):
287
+ super().__init__(config)
288
+ self.genemamba = GeneMambaModel(config)
289
+
290
+ # Language modeling head
291
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
292
+
293
+ self.apply(self._init_weights)
294
+
295
+ def forward(
296
+ self,
297
+ input_ids: torch.Tensor,
298
+ attention_mask: Optional[torch.Tensor] = None,
299
+ labels: Optional[torch.Tensor] = None,
300
+ output_hidden_states: bool = False,
301
+ ) -> GeneMambaMaskedLMOutput:
302
+ """
303
+ Args:
304
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
305
+ attention_mask (torch.Tensor, optional): Attention mask.
306
+ labels (torch.Tensor, optional): Target token ids for MLM loss.
307
+ output_hidden_states (bool): Whether to output hidden states.
308
+
309
+ Returns:
310
+ GeneMambaMaskedLMOutput: Contains logits and optional loss.
311
+ """
312
+ outputs = self.genemamba(
313
+ input_ids=input_ids,
314
+ attention_mask=attention_mask,
315
+ output_hidden_states=output_hidden_states,
316
+ )
317
+
318
+ logits = self.lm_head(outputs.last_hidden_state)
319
+
320
+ loss = None
321
+ if labels is not None:
322
+ loss_fct = nn.CrossEntropyLoss()
323
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
324
+
325
+ return GeneMambaMaskedLMOutput(
326
+ loss=loss,
327
+ logits=logits,
328
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
329
+ )
330
+
331
+
332
+ @register_model_for_auto_class("AutoModelForSequenceClassification")
333
+ class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
334
+ """
335
+ GeneMamba model for sequence classification tasks.
336
+ Ideal for cell type annotation, tissue classification, etc.
337
+
338
+ Args:
339
+ config (GeneMambaConfig): Model configuration class.
340
+ """
341
+
342
+ def __init__(self, config: GeneMambaConfig):
343
+ super().__init__(config)
344
+ self.num_labels = config.num_labels
345
+ self.config = config
346
+
347
+ self.genemamba = GeneMambaModel(config)
348
+
349
+ # Classification head
350
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
351
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
352
+
353
+ self.apply(self._init_weights)
354
+
355
+ def forward(
356
+ self,
357
+ input_ids: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ labels: Optional[torch.Tensor] = None,
360
+ output_hidden_states: bool = False,
361
+ ) -> GeneMambaSequenceClassifierOutput:
362
+ """
363
+ Args:
364
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
365
+ attention_mask (torch.Tensor, optional): Attention mask.
366
+ labels (torch.Tensor, optional): Class labels for classification loss.
367
+ output_hidden_states (bool): Whether to output hidden states.
368
+
369
+ Returns:
370
+ GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
371
+ """
372
+ outputs = self.genemamba(
373
+ input_ids=input_ids,
374
+ attention_mask=attention_mask,
375
+ output_hidden_states=output_hidden_states,
376
+ )
377
+
378
+ pooled_embedding = outputs.pooled_embedding
379
+ logits = self.classifier(self.dropout(pooled_embedding))
380
+
381
+ loss = None
382
+ if labels is not None:
383
+ loss_fct = nn.CrossEntropyLoss()
384
+ loss = loss_fct(logits, labels)
385
+
386
+ return GeneMambaSequenceClassifierOutput(
387
+ loss=loss,
388
+ logits=logits,
389
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
390
+ pooled_embedding=pooled_embedding,
391
+ )
392
+
393
+
394
+ # Register tokenizer class
395
+ register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
48l-512d/modeling_outputs.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom ModelOutput classes for GeneMamba.
3
+ Defines the output structure for different GeneMamba tasks.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+ import torch
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
+ @dataclass
13
+ class GeneMambaModelOutput(ModelOutput):
14
+ """
15
+ Base output class for GeneMamba models.
16
+
17
+ Attributes:
18
+ last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
19
+ Sequence of hidden-states at the output of the last layer of the model.
20
+
21
+ hidden_states (tuple(torch.FloatTensor), optional):
22
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
23
+
24
+ pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
25
+ Cell/sequence-level embedding (pooled representation) used for downstream tasks.
26
+ This is the recommended embedding to use for classification, clustering, etc.
27
+
28
+ embedding_pooling (str):
29
+ The pooling method used to generate pooled_embedding.
30
+ """
31
+
32
+ last_hidden_state: torch.FloatTensor = None
33
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
34
+ pooled_embedding: torch.FloatTensor = None
35
+ embedding_pooling: str = "mean"
36
+
37
+
38
+ @dataclass
39
+ class GeneMambaSequenceClassifierOutput(ModelOutput):
40
+ """
41
+ Output class for GeneMamba sequence classification models.
42
+
43
+ Attributes:
44
+ loss (torch.FloatTensor of shape (), optional):
45
+ Classification loss (if labels were provided).
46
+
47
+ logits (torch.FloatTensor of shape (batch_size, num_labels)):
48
+ Classification scores (before softmax).
49
+
50
+ hidden_states (tuple(torch.FloatTensor), optional):
51
+ Hidden-states of the model at the output of each layer.
52
+
53
+ pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
54
+ Cell embedding before classification head.
55
+ """
56
+
57
+ loss: Optional[torch.FloatTensor] = None
58
+ logits: torch.FloatTensor = None
59
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
60
+ pooled_embedding: Optional[torch.FloatTensor] = None
61
+
62
+
63
+ @dataclass
64
+ class GeneMambaMaskedLMOutput(ModelOutput):
65
+ """
66
+ Output class for GeneMamba masked language modeling.
67
+
68
+ Attributes:
69
+ loss (torch.FloatTensor of shape (), optional):
70
+ MLM loss (if labels were provided).
71
+
72
+ logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
73
+ Prediction scores of the language modeling head.
74
+
75
+ hidden_states (tuple(torch.FloatTensor), optional):
76
+ Hidden-states of the model at the output of each layer.
77
+ """
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+ logits: torch.FloatTensor = None
81
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
48l-512d/special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]"
4
+ }
48l-512d/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
48l-512d/tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "clean_up_tokenization_spaces": true,
4
+ "model_max_length": 1000000000000000019884624838656,
5
+ "pad_token": "[PAD]",
6
+ "tokenizer_class": "PreTrainedTokenizerFast",
7
+ "unk_token": "[UNK]"
8
+ }
48l-768d/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "genemamba",
3
+ "architectures": [
4
+ "GeneMambaModel"
5
+ ],
6
+ "vocab_size": 25426,
7
+ "max_position_embeddings": 2048,
8
+ "hidden_size": 768,
9
+ "num_hidden_layers": 48,
10
+ "intermediate_size": 2048,
11
+ "hidden_dropout_prob": 0.1,
12
+ "initializer_range": 0.02,
13
+ "mamba_mode": "gate",
14
+ "embedding_pooling": "mean",
15
+ "num_labels": 2,
16
+ "pad_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "bos_token_id": 0,
19
+ "use_cache": true,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.40.2",
22
+ "auto_map": {
23
+ "AutoConfig": "configuration_genemamba.GeneMambaConfig",
24
+ "AutoModel": "modeling_genemamba.GeneMambaModel",
25
+ "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
26
+ "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
27
+ }
28
+ }
48l-768d/configuration_genemamba.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for GeneMamba model.
3
+ Defines all hyperparameters and settings for the GeneMamba architecture.
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional
8
+
9
+
10
+ class GeneMambaConfig(PretrainedConfig):
11
+ """
12
+ Configuration class for GeneMamba model.
13
+
14
+ This class stores the configuration of a GeneMamba model, inheriting from PretrainedConfig.
15
+ It can be used to instantiate models from pretrained checkpoints or customize model initialization.
16
+
17
+ Args:
18
+ vocab_size (int, optional, defaults to 25426):
19
+ Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
20
+
21
+ hidden_size (int, optional, defaults to 512):
22
+ Dimensionality of the hidden/embedding layers (d_model in Mamba).
23
+
24
+ num_hidden_layers (int, optional, defaults to 24):
25
+ Number of Mamba layers (mamba_layer).
26
+
27
+ intermediate_size (int, optional, defaults to 2048):
28
+ Dimensionality of intermediate representations in MLP.
29
+
30
+ max_position_embeddings (int, optional, defaults to 2048):
31
+ Maximum sequence length (seq_len).
32
+
33
+ hidden_dropout_prob (float, optional, defaults to 0.1):
34
+ Dropout probability for hidden states.
35
+
36
+ initializer_range (float, optional, defaults to 0.02):
37
+ Standard deviation of truncated normal initializer.
38
+
39
+ mamba_mode (str, optional, defaults to "gate"):
40
+ Aggregation mode for bidirectional Mamba layers.
41
+ Options: "mean", "sum", "concat", "gate".
42
+
43
+ embedding_pooling (str, optional, defaults to "mean"):
44
+ Method for pooling to get cell embedding.
45
+ Options: "CLS", "mean", "weighted".
46
+
47
+ num_labels (int, optional, defaults to 2):
48
+ Number of labels for sequence classification tasks.
49
+
50
+ pad_token_id (int, optional, defaults to 1):
51
+ Token ID for padding.
52
+
53
+ bos_token_id (int, optional, defaults to None):
54
+ Token ID for beginning of sequence.
55
+
56
+ eos_token_id (int, optional, defaults to None):
57
+ Token ID for end of sequence.
58
+ """
59
+
60
+ model_type = "genemamba"
61
+ attribute_map = {
62
+ "hidden_size": "hidden_size",
63
+ "num_hidden_layers": "num_hidden_layers",
64
+ }
65
+
66
+ def __init__(
67
+ self,
68
+ vocab_size: int = 25426,
69
+ hidden_size: int = 512,
70
+ num_hidden_layers: int = 24,
71
+ intermediate_size: int = 2048,
72
+ max_position_embeddings: int = 2048,
73
+ hidden_dropout_prob: float = 0.1,
74
+ initializer_range: float = 0.02,
75
+ mamba_mode: str = "gate",
76
+ embedding_pooling: str = "mean",
77
+ num_labels: int = 2,
78
+ pad_token_id: int = 1,
79
+ bos_token_id: Optional[int] = None,
80
+ eos_token_id: Optional[int] = None,
81
+ **kwargs
82
+ ):
83
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
84
+
85
+ self.vocab_size = vocab_size
86
+ self.hidden_size = hidden_size
87
+ self.num_hidden_layers = num_hidden_layers
88
+ self.intermediate_size = intermediate_size
89
+ self.max_position_embeddings = max_position_embeddings
90
+ self.hidden_dropout_prob = hidden_dropout_prob
91
+ self.initializer_range = initializer_range
92
+ self.mamba_mode = mamba_mode
93
+ self.embedding_pooling = embedding_pooling
94
+ self.num_labels = num_labels
95
+ self.pad_token_id = pad_token_id
96
+ self.bos_token_id = bos_token_id
97
+ self.eos_token_id = eos_token_id
48l-768d/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:728514a211350e69937d73398dffa4c6bbb7f59366fb6c8b39f27437a6a5af77
3
+ size 860161160
48l-768d/modeling_genemamba.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of GeneMamba model for Hugging Face Transformers.
3
+ Includes backbone model and task-specific heads for various downstream tasks.
4
+ """
5
+
6
+ import math
7
+ import logging
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_, constant_
14
+
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
17
+ from transformers.models.auto import register_model_for_auto_class
18
+
19
+ from mamba_ssm import Mamba
20
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm
21
+
22
+ from .configuration_genemamba import GeneMambaConfig
23
+ from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ===========================
29
+ # Core Architecture Components
30
+ # ===========================
31
+
32
+ class EncoderLayer(nn.Module):
33
+ """
34
+ Single Mamba encoder layer with residual connection.
35
+ Applies a Mamba2 or Mamba layer followed by addition with input.
36
+
37
+ Args:
38
+ hidden_size (int): Dimension of hidden states.
39
+ """
40
+
41
+ def __init__(self, hidden_size: int):
42
+ super(EncoderLayer, self).__init__()
43
+ self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)
44
+
45
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
46
+ """
47
+ Args:
48
+ X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
49
+
50
+ Returns:
51
+ torch.Tensor: Output after Mamba layer and residual connection.
52
+ """
53
+ output = self.mamba(X) + X
54
+ return output
55
+
56
+
57
+ class MambaMixer(nn.Module):
58
+ """
59
+ Stack of Mamba encoder layers with bidirectional processing and aggregation.
60
+ Processes sequences in both forward and reverse directions, then aggregates.
61
+
62
+ Args:
63
+ mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
64
+ hidden_size (int): Dimension of hidden states.
65
+ num_hidden_layers (int): Number of Mamba layers.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ mode: str = "gate",
71
+ hidden_size: int = 512,
72
+ num_hidden_layers: int = 24
73
+ ):
74
+ super(MambaMixer, self).__init__()
75
+ self.mode = mode
76
+ self.hidden_size = hidden_size
77
+
78
+ # Create Mamba layers
79
+ self.layers = nn.ModuleList(
80
+ [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
81
+ )
82
+
83
+ # Aggregation modules for certain modes
84
+ if mode in ["concat", "gate"]:
85
+ self.aggr = nn.Linear(hidden_size * 2, hidden_size)
86
+
87
+ def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
88
+ """
89
+ Reverse a sequence based on actual length (ignoring padding).
90
+
91
+ Args:
92
+ X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
93
+ mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).
94
+
95
+ Returns:
96
+ torch.Tensor: Reversed tensor.
97
+ """
98
+ batch_size, seq_length, embedding_dim = X.size()
99
+
100
+ if mask is None:
101
+ # Simple flip
102
+ return X.flip([1])
103
+
104
+ # Flip based on actual sequence length (marked by mask)
105
+ lengths = (~mask).sum(dim=1)
106
+ pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
107
+ flip_mask = pos_tensor < lengths.unsqueeze(1)
108
+ reversed_positions = torch.where(
109
+ flip_mask,
110
+ lengths.unsqueeze(1) - 1 - pos_tensor,
111
+ pos_tensor
112
+ )
113
+
114
+ X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
115
+ return X_reverse
116
+
117
+ def forward(
118
+ self,
119
+ X: torch.Tensor,
120
+ padding_mask: Optional[torch.Tensor] = None
121
+ ) -> torch.Tensor:
122
+ """
123
+ Process sequence through bidirectional Mamba layers.
124
+
125
+ Args:
126
+ X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
127
+ padding_mask (torch.Tensor, optional): Padding mask.
128
+
129
+ Returns:
130
+ torch.Tensor: Output after processing all layers and aggregation.
131
+ """
132
+
133
+ for layer in self.layers:
134
+ # Flip sequence for reverse processing
135
+ X_flip = self.flip_sequence(X, padding_mask)
136
+
137
+ # Forward and reverse passes
138
+ X_f = layer(X)
139
+ X_b = layer(X_flip)
140
+
141
+ # Flip back the reverse output
142
+ X_b = self.flip_sequence(X_b, padding_mask)
143
+
144
+ # Aggregate forward and reverse
145
+ if self.mode == "mean":
146
+ X = (X_f + X_b) / 2
147
+ elif self.mode == "sum":
148
+ X = X_f + X_b
149
+ elif self.mode == "concat":
150
+ X = torch.cat([X_f, X_b], dim=-1)
151
+ X = self.aggr(X)
152
+ elif self.mode == "gate":
153
+ z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
154
+ X = z * X_f + (1 - z) * X_b
155
+ else:
156
+ raise ValueError(f"Invalid aggregation mode: {self.mode}")
157
+
158
+ return X
159
+
160
+
161
+ # ===========================
162
+ # Base Model Classes
163
+ # ===========================
164
+
165
+ class GeneMambaPreTrainedModel(PreTrainedModel):
166
+ """
167
+ Base class for all GeneMamba models.
168
+ Handles weight initialization and provides standard model interfaces.
169
+ """
170
+
171
+ config_class = GeneMambaConfig
172
+ base_model_prefix = "genemamba"
173
+ supports_gradient_checkpointing = True
174
+
175
+ def _init_weights(self, module):
176
+ """Initialize module weights."""
177
+ if isinstance(module, nn.Linear):
178
+ normal_(module.weight, std=self.config.initializer_range)
179
+ if module.bias is not None:
180
+ constant_(module.bias, 0.0)
181
+ elif isinstance(module, nn.Embedding):
182
+ normal_(module.weight, std=self.config.initializer_range)
183
+ if module.padding_idx is not None:
184
+ module.weight.data[module.padding_idx].zero_()
185
+ elif isinstance(module, nn.LayerNorm):
186
+ constant_(module.bias, 0.0)
187
+ constant_(module.weight, 1.0)
188
+
189
+
190
+ class GeneMambaModel(GeneMambaPreTrainedModel):
191
+ """
192
+ GeneMamba backbone model - outputs cell embeddings and hidden states.
193
+ This is the core model used by task-specific heads.
194
+
195
+ Args:
196
+ config (GeneMambaConfig): Model configuration class.
197
+ """
198
+
199
+ def __init__(self, config: GeneMambaConfig):
200
+ super().__init__(config)
201
+ self.config = config
202
+
203
+ # Embedding layer
204
+ self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
205
+
206
+ # Mamba layers with bidirectional aggregation
207
+ self.mamba_mixer = MambaMixer(
208
+ mode=config.mamba_mode,
209
+ hidden_size=config.hidden_size,
210
+ num_hidden_layers=config.num_hidden_layers
211
+ )
212
+
213
+ # Final layer normalization
214
+ self.norm = RMSNorm(config.hidden_size)
215
+
216
+ self.apply(self._init_weights)
217
+
218
+ def get_input_embeddings(self) -> nn.Embedding:
219
+ """Return embedding layer."""
220
+ return self.embeddings
221
+
222
+ def set_input_embeddings(self, value: nn.Embedding):
223
+ """Set embedding layer."""
224
+ self.embeddings = value
225
+
226
+ def forward(
227
+ self,
228
+ input_ids: torch.Tensor,
229
+ attention_mask: Optional[torch.Tensor] = None,
230
+ output_hidden_states: bool = False,
231
+ ) -> GeneMambaModelOutput:
232
+ """
233
+ Args:
234
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
235
+ attention_mask (torch.Tensor, optional): Attention mask of shape (batch_size, seq_len).
236
+ output_hidden_states (bool): Whether to output hidden states from all layers.
237
+
238
+ Returns:
239
+ GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
240
+ """
241
+ # Get embeddings
242
+ hidden_states = self.embeddings(input_ids)
243
+
244
+ # Pass through Mamba layers
245
+ hidden_states = self.mamba_mixer(hidden_states, attention_mask)
246
+
247
+ # Apply final normalization
248
+ hidden_states = self.norm(hidden_states)
249
+
250
+ # Compute pooled embedding (cell representation)
251
+ if self.config.embedding_pooling == "CLS":
252
+ # Use first token (CLS)
253
+ pooled_embedding = hidden_states[:, 0, :]
254
+ elif self.config.embedding_pooling == "mean":
255
+ # Mean pooling over sequence
256
+ if attention_mask is not None:
257
+ mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
258
+ pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
259
+ else:
260
+ pooled_embedding = hidden_states.mean(dim=1)
261
+ else:
262
+ raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")
263
+
264
+ return GeneMambaModelOutput(
265
+ last_hidden_state=hidden_states,
266
+ pooled_embedding=pooled_embedding,
267
+ hidden_states=hidden_states if output_hidden_states else None,
268
+ embedding_pooling=self.config.embedding_pooling,
269
+ )
270
+
271
+
272
+ # ===========================
273
+ # Task-Specific Models
274
+ # ===========================
275
+
276
+ @register_model_for_auto_class("AutoModel")
277
+ class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
278
+ """
279
+ GeneMamba model for masked language modeling (MLM).
280
+ Suitable for pretraining and domain adaptation.
281
+
282
+ Args:
283
+ config (GeneMambaConfig): Model configuration class.
284
+ """
285
+
286
+ def __init__(self, config: GeneMambaConfig):
287
+ super().__init__(config)
288
+ self.genemamba = GeneMambaModel(config)
289
+
290
+ # Language modeling head
291
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
292
+
293
+ self.apply(self._init_weights)
294
+
295
+ def forward(
296
+ self,
297
+ input_ids: torch.Tensor,
298
+ attention_mask: Optional[torch.Tensor] = None,
299
+ labels: Optional[torch.Tensor] = None,
300
+ output_hidden_states: bool = False,
301
+ ) -> GeneMambaMaskedLMOutput:
302
+ """
303
+ Args:
304
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
305
+ attention_mask (torch.Tensor, optional): Attention mask.
306
+ labels (torch.Tensor, optional): Target token ids for MLM loss.
307
+ output_hidden_states (bool): Whether to output hidden states.
308
+
309
+ Returns:
310
+ GeneMambaMaskedLMOutput: Contains logits and optional loss.
311
+ """
312
+ outputs = self.genemamba(
313
+ input_ids=input_ids,
314
+ attention_mask=attention_mask,
315
+ output_hidden_states=output_hidden_states,
316
+ )
317
+
318
+ logits = self.lm_head(outputs.last_hidden_state)
319
+
320
+ loss = None
321
+ if labels is not None:
322
+ loss_fct = nn.CrossEntropyLoss()
323
+ loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
324
+
325
+ return GeneMambaMaskedLMOutput(
326
+ loss=loss,
327
+ logits=logits,
328
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
329
+ )
330
+
331
+
332
+ @register_model_for_auto_class("AutoModelForSequenceClassification")
333
+ class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
334
+ """
335
+ GeneMamba model for sequence classification tasks.
336
+ Ideal for cell type annotation, tissue classification, etc.
337
+
338
+ Args:
339
+ config (GeneMambaConfig): Model configuration class.
340
+ """
341
+
342
+ def __init__(self, config: GeneMambaConfig):
343
+ super().__init__(config)
344
+ self.num_labels = config.num_labels
345
+ self.config = config
346
+
347
+ self.genemamba = GeneMambaModel(config)
348
+
349
+ # Classification head
350
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
351
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
352
+
353
+ self.apply(self._init_weights)
354
+
355
+ def forward(
356
+ self,
357
+ input_ids: torch.Tensor,
358
+ attention_mask: Optional[torch.Tensor] = None,
359
+ labels: Optional[torch.Tensor] = None,
360
+ output_hidden_states: bool = False,
361
+ ) -> GeneMambaSequenceClassifierOutput:
362
+ """
363
+ Args:
364
+ input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
365
+ attention_mask (torch.Tensor, optional): Attention mask.
366
+ labels (torch.Tensor, optional): Class labels for classification loss.
367
+ output_hidden_states (bool): Whether to output hidden states.
368
+
369
+ Returns:
370
+ GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
371
+ """
372
+ outputs = self.genemamba(
373
+ input_ids=input_ids,
374
+ attention_mask=attention_mask,
375
+ output_hidden_states=output_hidden_states,
376
+ )
377
+
378
+ pooled_embedding = outputs.pooled_embedding
379
+ logits = self.classifier(self.dropout(pooled_embedding))
380
+
381
+ loss = None
382
+ if labels is not None:
383
+ loss_fct = nn.CrossEntropyLoss()
384
+ loss = loss_fct(logits, labels)
385
+
386
+ return GeneMambaSequenceClassifierOutput(
387
+ loss=loss,
388
+ logits=logits,
389
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
390
+ pooled_embedding=pooled_embedding,
391
+ )
392
+
393
+
394
+ # Register tokenizer class
395
+ register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
48l-768d/modeling_outputs.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom ModelOutput classes for GeneMamba.
3
+ Defines the output structure for different GeneMamba tasks.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+ import torch
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
+ @dataclass
13
+ class GeneMambaModelOutput(ModelOutput):
14
+ """
15
+ Base output class for GeneMamba models.
16
+
17
+ Attributes:
18
+ last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
19
+ Sequence of hidden-states at the output of the last layer of the model.
20
+
21
+ hidden_states (tuple(torch.FloatTensor), optional):
22
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
23
+
24
+ pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
25
+ Cell/sequence-level embedding (pooled representation) used for downstream tasks.
26
+ This is the recommended embedding to use for classification, clustering, etc.
27
+
28
+ embedding_pooling (str):
29
+ The pooling method used to generate pooled_embedding.
30
+ """
31
+
32
+ last_hidden_state: torch.FloatTensor = None
33
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
34
+ pooled_embedding: torch.FloatTensor = None
35
+ embedding_pooling: str = "mean"
36
+
37
+
38
+ @dataclass
39
+ class GeneMambaSequenceClassifierOutput(ModelOutput):
40
+ """
41
+ Output class for GeneMamba sequence classification models.
42
+
43
+ Attributes:
44
+ loss (torch.FloatTensor of shape (), optional):
45
+ Classification loss (if labels were provided).
46
+
47
+ logits (torch.FloatTensor of shape (batch_size, num_labels)):
48
+ Classification scores (before softmax).
49
+
50
+ hidden_states (tuple(torch.FloatTensor), optional):
51
+ Hidden-states of the model at the output of each layer.
52
+
53
+ pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
54
+ Cell embedding before classification head.
55
+ """
56
+
57
+ loss: Optional[torch.FloatTensor] = None
58
+ logits: torch.FloatTensor = None
59
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
60
+ pooled_embedding: Optional[torch.FloatTensor] = None
61
+
62
+
63
+ @dataclass
64
+ class GeneMambaMaskedLMOutput(ModelOutput):
65
+ """
66
+ Output class for GeneMamba masked language modeling.
67
+
68
+ Attributes:
69
+ loss (torch.FloatTensor of shape (), optional):
70
+ MLM loss (if labels were provided).
71
+
72
+ logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
73
+ Prediction scores of the language modeling head.
74
+
75
+ hidden_states (tuple(torch.FloatTensor), optional):
76
+ Hidden-states of the model at the output of each layer.
77
+ """
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+ logits: torch.FloatTensor = None
81
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
48l-768d/special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]"
4
+ }
48l-768d/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
48l-768d/tokenizer_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "clean_up_tokenization_spaces": true,
4
+ "model_max_length": 1000000000000000019884624838656,
5
+ "pad_token": "[PAD]",
6
+ "tokenizer_class": "PreTrainedTokenizerFast",
7
+ "unk_token": "[UNK]"
8
+ }
README.md ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags:
4
+ - genomics
5
+ - single-cell
6
+ - mamba
7
+ - biology
8
+ pipeline_tag: feature-extraction
9
+ ---
10
+
11
+ # GeneMamba
12
+
13
+ This repository contains a **default GeneMamba model** plus full usage assets:
14
+ - default model weights at repository root (**24l-512d**)
15
+ - custom modeling/config files for `trust_remote_code=True`
16
+ - preprocessing example from `h5ad` to `input_ids`
17
+ - tokenizer assets and id mapping files
18
+
19
+ Additional model sizes are provided as subfolders:
20
+ - `24l-512d` (same architecture class as default)
21
+ - `24l-768d`
22
+ - `48l-512d`
23
+ - `48l-768d`
24
+
25
+ ## 1) Input format (very important)
26
+
27
+ GeneMamba input is **ranked gene token IDs** per cell:
28
+ 1. Start from one cell expression vector
29
+ 2. Keep genes with expression > 0
30
+ 3. Sort genes by expression descending
31
+ 4. Convert each gene ID (Ensembl, e.g. `ENSG00000000003`) to token ID
32
+ 5. Use resulting list as `input_ids`
33
+
34
+ Each sample is one list of integers:
35
+
36
+ ```python
37
+ {"input_ids": [145, 2088, 531, 91, ...]}
38
+ ```
39
+
40
+ For batch input, shape is typically `(batch_size, seq_len)` after padding/truncation.
41
+
42
+ ## 2) Where tokenizer and id mapping come from
43
+
44
+ - Main tokenizer used for model inference: `tokenizer.json`
45
+ - Original full tokenizer table: `tokenizer_assets/gene_tokenizer.json`
46
+ - Gene symbol -> token id mapping: `tokenizer_assets/symbol2id.pkl`
47
+ - Token id -> gene symbol mapping: `tokenizer_assets/id2symbol.pkl`
48
+
49
+ Special tokens:
50
+ - `[UNK]` = 0
51
+ - `[PAD]` = 1
52
+
53
+ ## 3) Preprocess your data
54
+
55
+ See script:
56
+ - `examples/00_preprocess_to_input_ids.py`
57
+
58
+ Example:
59
+
60
+ ```bash
61
+ python examples/00_preprocess_to_input_ids.py \
62
+ --h5ad /path/to/your_data.h5ad \
63
+ --tokenizer_json tokenizer.json \
64
+ --output_arrow ./my_data/sorted_gene_token_ids.arrow
65
+ ```
66
+
67
+ This output Arrow file has one column: `input_ids`.
68
+
69
+ ## 4) Load model and extract embedding
70
+
71
+ ### Default load (24l-512d)
72
+
73
+ ```python
74
+ from transformers import AutoModel, AutoTokenizer
75
+
76
+ model = AutoModel.from_pretrained(
77
+ "mineself2016/GeneMamba",
78
+ trust_remote_code=True
79
+ )
80
+
81
+ tokenizer = AutoTokenizer.from_pretrained(
82
+ "mineself2016/GeneMamba",
83
+ trust_remote_code=True
84
+ )
85
+ ```
86
+
87
+ ### Load other sizes (via `subfolder`)
88
+
89
+ ```python
90
+ from transformers import AutoModel
91
+
92
+ model_24l_768d = AutoModel.from_pretrained(
93
+ "mineself2016/GeneMamba",
94
+ subfolder="24l-768d",
95
+ trust_remote_code=True,
96
+ )
97
+
98
+ model_48l_512d = AutoModel.from_pretrained(
99
+ "mineself2016/GeneMamba",
100
+ subfolder="48l-512d",
101
+ trust_remote_code=True,
102
+ )
103
+
104
+ model_48l_768d = AutoModel.from_pretrained(
105
+ "mineself2016/GeneMamba",
106
+ subfolder="48l-768d",
107
+ trust_remote_code=True,
108
+ )
109
+ ```
110
+
111
+ More complete example:
112
+ - `examples/01_extract_embeddings.py`
113
+
114
+ ## 6) Downstream task examples (added)
115
+
116
+ See:
117
+ - `examples/downstream/README.md`
118
+
119
+ Included downstream tasks:
120
+ - cell type annotation fine-tuning
121
+ - zero-shot embedding + logistic regression
122
+ - batch integration proxy evaluation
123
+ - original legacy downstream scripts from `gene_mamba/analysis/cell_type_annotation`
124
+
125
+ ## 7) Source of preprocessing logic
126
+
127
+ The preprocessing/tokenization pipeline is aligned with assets from:
128
+ - `/project/zhiwei/cq5/PythonWorkSpace/gene_mamba`
129
+
130
+ Key references used:
131
+ - tokenizer: `gene_tokenizer.json`
132
+ - mappings: `symbol2id.pkl`, `id2symbol.pkl`
133
+ - dataset build logic (Arrow + `input_ids`): `utils.py` (`build_dataset`)
config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "genemamba",
3
+ "architectures": [
4
+ "GeneMambaModel"
5
+ ],
6
+ "vocab_size": 25426,
7
+ "max_position_embeddings": 2048,
8
+ "hidden_size": 512,
9
+ "num_hidden_layers": 24,
10
+ "intermediate_size": 2048,
11
+ "hidden_dropout_prob": 0.1,
12
+ "initializer_range": 0.02,
13
+ "mamba_mode": "gate",
14
+ "embedding_pooling": "mean",
15
+ "num_labels": 2,
16
+ "pad_token_id": 1,
17
+ "eos_token_id": 2,
18
+ "bos_token_id": 0,
19
+ "use_cache": true,
20
+ "torch_dtype": "float32",
21
+ "transformers_version": "4.40.2",
22
+ "auto_map": {
23
+ "AutoConfig": "configuration_genemamba.GeneMambaConfig",
24
+ "AutoModel": "modeling_genemamba.GeneMambaModel",
25
+ "AutoModelForMaskedLM": "modeling_genemamba.GeneMambaForMaskedLM",
26
+ "AutoModelForSequenceClassification": "modeling_genemamba.GeneMambaForSequenceClassification"
27
+ }
28
+ }
configuration_genemamba.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration for GeneMamba model.
3
+ Defines all hyperparameters and settings for the GeneMamba architecture.
4
+ """
5
+
6
+ from transformers import PretrainedConfig
7
+ from typing import Optional
8
+
9
+
10
class GeneMambaConfig(PretrainedConfig):
    """
    Configuration class for GeneMamba model.

    Stores all hyperparameters needed to instantiate a GeneMamba model and
    inherits (de)serialization behavior from ``PretrainedConfig``.

    Args:
        vocab_size (int, optional, defaults to 25426):
            Vocabulary size of the model. Number of gene tokens (Ensembl Gene IDs).
        hidden_size (int, optional, defaults to 512):
            Dimensionality of the hidden/embedding layers (d_model in Mamba).
        num_hidden_layers (int, optional, defaults to 24):
            Number of Mamba layers (mamba_layer).
        intermediate_size (int, optional, defaults to 2048):
            Dimensionality of intermediate representations in MLP.
        max_position_embeddings (int, optional, defaults to 2048):
            Maximum sequence length (seq_len).
        hidden_dropout_prob (float, optional, defaults to 0.1):
            Dropout probability for hidden states.
        initializer_range (float, optional, defaults to 0.02):
            Standard deviation of truncated normal initializer.
        mamba_mode (str, optional, defaults to "gate"):
            Aggregation mode for bidirectional Mamba layers.
            Options: "mean", "sum", "concat", "gate".
        embedding_pooling (str, optional, defaults to "mean"):
            Method for pooling to get cell embedding.
            Options: "CLS", "mean", "weighted".
        num_labels (int, optional, defaults to 2):
            Number of labels for sequence classification tasks.
        pad_token_id (int, optional, defaults to 1):
            Token ID for padding.
        bos_token_id (int, optional, defaults to None):
            Token ID for beginning of sequence.
        eos_token_id (int, optional, defaults to None):
            Token ID for end of sequence.
    """

    model_type = "genemamba"
    # NOTE: the former identity attribute_map ({"hidden_size": "hidden_size", ...})
    # was a no-op and has been removed.

    def __init__(
        self,
        vocab_size: int = 25426,
        hidden_size: int = 512,
        num_hidden_layers: int = 24,
        intermediate_size: int = 2048,
        max_position_embeddings: int = 2048,
        hidden_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        mamba_mode: str = "gate",
        embedding_pooling: str = "mean",
        num_labels: int = 2,
        pad_token_id: int = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        **kwargs
    ):
        # Route all special token ids through the base class so that every
        # transformers utility (padding, generation helpers) sees them
        # consistently, instead of assigning bos/eos after the fact.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.mamba_mode = mamba_mode
        self.embedding_pooling = embedding_pooling
        # PretrainedConfig.num_labels is a property; assigning it also builds
        # the default id2label/label2id mappings (same effect as the original
        # post-super assignment).
        self.num_labels = num_labels
examples/00_preprocess_to_input_ids.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scanpy as sc
8
+ import pyarrow as pa
9
+
10
+
11
def load_vocab(tokenizer_json_path: str):
    """Load the gene vocabulary from a HF-style tokenizer JSON file.

    Args:
        tokenizer_json_path: Path to tokenizer.json / gene_tokenizer.json.

    Returns:
        Tuple ``(vocab, pad_id, unk_id)`` where ``vocab`` maps token string
        to integer id. Falls back to ``pad_id=1`` / ``unk_id=0`` when the
        special tokens are absent from the vocabulary.
    """
    # Explicit encoding: tokenizer files are UTF-8; relying on the platform
    # default encoding breaks on some Windows locales.
    with open(tokenizer_json_path, "r", encoding="utf-8") as f:
        tokenizer = json.load(f)
    vocab = tokenizer["model"]["vocab"]
    pad_id = vocab.get("[PAD]", 1)
    unk_id = vocab.get("[UNK]", 0)
    return vocab, pad_id, unk_id
18
+
19
+
20
def ranked_gene_ids_for_cell(expr_values, gene_names, vocab):
    """Rank expressed genes by descending expression and map them to token ids.

    Genes with zero expression are dropped, as are genes missing from the
    vocabulary. Returns a (possibly empty) list of integer token ids.
    """
    expressed = np.flatnonzero(expr_values > 0)
    if expressed.size == 0:
        return []

    names = np.asarray(gene_names)[expressed]
    levels = expr_values[expressed]

    # argsort on the negated values yields a descending ordering.
    descending = names[np.argsort(-levels)]
    return [vocab[name] for name in descending if name in vocab]
33
+
34
+
35
def main():
    """CLI entry point: convert an .h5ad matrix into an Arrow file of input_ids."""
    parser = argparse.ArgumentParser(description="Convert h5ad to GeneMamba input_ids (Arrow)")
    parser.add_argument("--h5ad", required=True, help="Input h5ad file")
    parser.add_argument("--tokenizer_json", required=True, help="Path to tokenizer.json or gene_tokenizer.json")
    parser.add_argument("--output_arrow", required=True, help="Output arrow file path")
    parser.add_argument("--max_cells", type=int, default=None, help="Optional: process first N cells only")
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    vocab, _, _ = load_vocab(args.tokenizer_json)

    gene_names = list(adata.var_names)
    if args.max_cells is None:
        n_cells = adata.n_obs
    else:
        n_cells = min(args.max_cells, adata.n_obs)

    matrix = adata.X
    rows = []
    for cell_idx in range(n_cells):
        cell = matrix[cell_idx]
        # Sparse rows expose .toarray(); dense rows are handled via asarray.
        if hasattr(cell, "toarray"):
            expr = cell.toarray().ravel()
        else:
            expr = np.asarray(cell).ravel()
        rows.append(ranked_gene_ids_for_cell(expr, gene_names, vocab))

    table = pa.Table.from_pandas(pd.DataFrame({"input_ids": rows}))

    output_path = Path(args.output_arrow)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Write as an Arrow IPC stream, one list<int> column named input_ids.
    with pa.OSFile(str(output_path), "wb") as sink:
        with pa.ipc.new_stream(sink, table.schema) as writer:
            writer.write_table(table)

    print(f"Saved {len(rows)} cells to {output_path}")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
examples/01_extract_embeddings.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 1: Extract Cell Embeddings
3
+ Demonstrates how to load GeneMamba and extract cell embeddings for downstream analysis.
4
+
5
+ Usage:
6
+ python examples/01_extract_embeddings.py
7
+ """
8
+
9
+ import torch
10
+ import numpy as np
11
+ from transformers import AutoTokenizer, AutoModel
12
+
13
+
14
def main():
    """Run the Phase 1 demo end to end and return (model, cell_embeddings)."""
    print("=" * 80)
    print("GeneMamba Phase 1: Extract Cell Embeddings")
    print("=" * 80)

    # --- Step 1: Load pretrained model and tokenizer ---
    print("\n[Step 1] Loading model and tokenizer...")

    # For this example, we use a local model path.
    # In practice, you would use: "username/GeneMamba-24l-512d"
    model_name = "GeneMamba-24l-512d"  # Change to HF Hub path when available

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True  # Try local first
        )
        model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            local_files_only=True
        )
    except Exception as e:
        print(f"Note: Could not load from '{model_name}': {e}")
        print("Using mock data for demonstration...")

        # Fallback: build a randomly initialized model so the demo still runs.
        # NOTE(review): these imports resolve only when the repo root (where
        # configuration_genemamba.py lives) is on sys.path — confirm.
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaModel

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            embedding_pooling="mean",
        )
        model = GeneMambaModel(config)
        tokenizer = None  # tokenizer unused below; random ids are generated directly

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    print(f"✓ Model loaded on device: {device}")
    print(f"✓ Model config: hidden_size={model.config.hidden_size}, "
          f"num_layers={model.config.num_hidden_layers}")

    # --- Step 2: Prepare simulated single-cell data ---
    print("\n[Step 2] Preparing sample data...")

    batch_size = 8
    seq_len = 2048
    vocab_size = 25426

    # Simulate ranked gene sequences.
    # In practice, this would come from your scRNA-seq data;
    # genes should be ranked by expression (highest first).
    # Sampling starts at 2 so special token ids (0/1) never appear.
    input_ids = torch.randint(2, vocab_size, (batch_size, seq_len)).to(device)

    print(f"✓ Created sample input:")
    print(f"  - Batch size: {batch_size}")
    print(f"  - Sequence length: {seq_len}")
    print(f"  - Input shape: {input_ids.shape}")

    # --- Step 3: Inference - extract embeddings ---
    print("\n[Step 3] Extracting cell embeddings...")

    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=False)

    # Get the pooled embedding (one vector per cell).
    # NOTE(review): assumes the model output exposes .pooled_embedding and
    # .embedding_pooling (custom GeneMamba output class) — confirm against
    # modeling_outputs.py.
    cell_embeddings = outputs.pooled_embedding

    print(f"✓ Extraction complete!")
    print(f"  - Cell embeddings shape: {cell_embeddings.shape}")
    print(f"  - Pooling method used: {outputs.embedding_pooling}")
    print(f"  - Embedding type: {cell_embeddings.dtype}")

    # --- Step 4: Example downstream analyses ---
    print("\n[Step 4] Example downstream uses...")

    # Example 1: Clustering (KMeans).
    # NOTE(review): no random_state is set, so cluster assignments are not
    # reproducible between runs — acceptable for a demo.
    from sklearn.cluster import KMeans
    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, n_init=10)
    clusters = kmeans.fit_predict(cell_embeddings.cpu().numpy())
    print(f"✓ Clustering: Assigned {len(np.unique(clusters))} clusters")

    # Example 2: Dimensionality reduction (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=2)
    embedding_2d = pca.fit_transform(cell_embeddings.cpu().numpy())
    print(f"✓ PCA reduction: {cell_embeddings.shape} → {embedding_2d.shape}")

    # Example 3: Similarity search —
    # find the most similar cell to the first cell (cosine similarity).
    similarities = torch.nn.functional.cosine_similarity(
        cell_embeddings[0:1],
        cell_embeddings
    )
    most_similar_idx = torch.argmax(similarities).item()
    print(f"✓ Similarity search: Most similar cell to cell 0 is cell {most_similar_idx} "
          f"(similarity: {similarities[most_similar_idx]:.4f})")

    # Example 4: Simple summary statistics over the embedding matrix
    print("\n[Step 5] Embedding statistics:")
    print(f"  - Mean: {cell_embeddings.mean(dim=0).norm():.4f}")
    print(f"  - Std: {cell_embeddings.std(dim=0).mean():.4f}")
    print(f"  - Min: {cell_embeddings.min():.4f}")
    print(f"  - Max: {cell_embeddings.max():.4f}")

    # --- Step 6: Save embeddings (optional) ---
    print("\n[Step 6] Saving embeddings...")

    np.save("cell_embeddings.npy", cell_embeddings.cpu().numpy())
    print("✓ Embeddings saved to 'cell_embeddings.npy'")

    print("\n" + "=" * 80)
    print("Phase 1 Complete!")
    print("=" * 80)

    return model, cell_embeddings
147
+
148
+
149
+ if __name__ == "__main__":
150
+ model, embeddings = main()
examples/downstream/10_finetune_classification.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 2: Downstream Task - Fine-tune for Classification
3
+ Demonstrates cell type annotation and other sequence classification tasks.
4
+
5
+ Usage:
6
+ python examples/downstream/10_finetune_classification.py
7
+ """
8
+
9
+ import torch
10
+ import numpy as np
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
13
+
14
+
15
class GeneExpressionDataset(Dataset):
    """
    Minimal torch Dataset pairing ranked-gene token sequences with labels.
    In practice, this would load from h5ad or other single-cell formats.
    """

    def __init__(self, input_ids, labels, max_length=2048):
        self.input_ids = input_ids
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # Convert the stored sequence/label pair into long tensors on access.
        sample = {
            "input_ids": torch.tensor(self.input_ids[idx], dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }
        return sample
37
+
38
+
39
def create_mock_data(n_samples=1000, n_features=2048, n_classes=5):
    """Create mock single-cell data for demonstration."""

    print("Creating mock dataset...")

    # Random ranked gene sequences first, then random class labels —
    # this call order is kept so a seeded RNG reproduces the original data.
    input_ids = np.random.randint(2, 25426, (n_samples, n_features))
    labels = np.random.randint(0, n_classes, n_samples)

    # 70/15/15 train/val/test split by position.
    n_train = int(0.7 * n_samples)
    n_val = int(0.15 * n_samples)

    splits = (
        (input_ids[:n_train], labels[:n_train]),
        (input_ids[n_train:n_train + n_val], labels[n_train:n_train + n_val]),
        (input_ids[n_train + n_val:], labels[n_train + n_val:]),
    )

    print(f"✓ Dataset created:")
    print(f"  - Train: {len(splits[0][0])} samples")
    print(f"  - Val: {len(splits[1][0])} samples")
    print(f"  - Test: {len(splits[2][0])} samples")
    print(f"  - Classes: {n_classes}")

    return tuple(GeneExpressionDataset(ids, ys) for ids, ys in splits)
74
+
75
+
76
def main():
    """Fine-tune GeneMamba for cell type classification on mock data.

    Returns the trained model and the HF Trainer instance.
    """
    print("=" * 80)
    print("GeneMamba Phase 2: Downstream Classification")
    print("=" * 80)

    # --- Step 1: Load pretrained model with classification head ---
    print("\n[Step 1] Loading pretrained model with classification head...")

    num_classes = 5

    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            "GeneMamba-24l-512d",
            num_labels=num_classes,
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")

        # Fallback: randomly initialized classifier so the demo still runs.
        # NOTE(review): these imports resolve only with the repo root on
        # sys.path — confirm.
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForSequenceClassification

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
            num_labels=num_classes,
        )
        model = GeneMambaForSequenceClassification(config)

    print(f"✓ Model loaded")
    print(f"  - Classification head: input={model.config.hidden_size} → output={num_classes}")

    # --- Step 2: Prepare data ---
    print("\n[Step 2] Preparing dataset...")

    train_dataset, val_dataset, test_dataset = create_mock_data(
        n_samples=1000,
        n_features=2048,
        n_classes=num_classes,
    )

    # --- Step 3: Set up training arguments ---
    print("\n[Step 3] Setting up training...")

    output_dir = "./classification_results"

    # NOTE(review): `eval_strategy` is the newer transformers name (older
    # releases spell it `evaluation_strategy`) — confirm the pinned version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=100,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        report_to="none",  # Disable W&B logging
        seed=42,
    )

    print(f"✓ Training config:")
    print(f"  - Output dir: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")

    # --- Step 4: Train using Trainer ---
    print("\n[Step 4] Training model...")

    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    def compute_metrics(eval_pred):
        """Compute accuracy / weighted F1 / precision / recall from logits."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)

        return {
            "accuracy": accuracy_score(labels, predictions),
            "f1": f1_score(labels, predictions, average="weighted", zero_division=0),
            "precision": precision_score(labels, predictions, average="weighted", zero_division=0),
            "recall": recall_score(labels, predictions, average="weighted", zero_division=0),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # --- Step 5: Evaluate on test set ---
    print("\n[Step 5] Evaluating on test set...")

    test_results = trainer.evaluate(test_dataset)

    print(f"✓ Test Results:")
    for metric, value in test_results.items():
        if isinstance(value, float):
            print(f"  - {metric}: {value:.4f}")

    # --- Step 6: Make predictions ---
    print("\n[Step 6] Making predictions...")

    predictions = trainer.predict(test_dataset)
    predicted_classes = np.argmax(predictions.predictions, axis=1)

    print(f"✓ Predictions made:")
    print(f"  - Predicted classes: {len(predicted_classes)} samples")
    print(f"  - Class distribution: {np.bincount(predicted_classes)}")

    # --- Step 7: Save model ---
    print("\n[Step 7] Saving model...")

    save_dir = "./my_genemamba_classifier"
    model.save_pretrained(save_dir)
    print(f"✓ Model saved to '{save_dir}'")

    # --- Step 8: Load and test saved model ---
    print("\n[Step 8] Testing model reloading...")

    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()

    # Smoke test the reloaded model on a single random batch.
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (1, 2048))
        output = loaded_model(sample_input)
        logits = output.logits
        prediction = torch.argmax(logits, dim=1)

    print(f"✓ Loaded model test prediction: class {prediction.item()}")

    print("\n" + "=" * 80)
    print("Phase 2 Complete! Model ready for deployment.")
    print("=" * 80)

    return model, trainer
245
+
246
+
247
+ if __name__ == "__main__":
248
+ model, trainer = main()
examples/downstream/11_zero_shot_logreg.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Zero-shot downstream baseline:
3
+ 1) Extract frozen GeneMamba embeddings
4
+ 2) Train LogisticRegression on train split
5
+ 3) Evaluate on test split
6
+
7
+ Expected h5ad columns:
8
+ - obs['celltype']
9
+ - obs['partition'] with values in {'train', 'test'}
10
+ """
11
+
12
+ import argparse
13
+ import numpy as np
14
+ import scanpy as sc
15
+ import torch
16
+ from sklearn.linear_model import LogisticRegression
17
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
18
+ from sklearn.preprocessing import LabelEncoder
19
+ from transformers import AutoModel
20
+
21
+
22
def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Build fixed-length, expression-ranked token-id rows for every cell.

    Each row lists the ids of expressed genes in descending expression order,
    truncated to seq_len and right-padded with pad_id.
    """
    names = np.array(adata.var_names)
    matrix = adata.X
    ids_matrix = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)

    for cell in range(adata.n_obs):
        row = matrix[cell]
        # Sparse rows expose .toarray(); dense rows go through asarray.
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        expressed = np.flatnonzero(expr > 0)
        if expressed.size == 0:
            continue  # all-zero cell: leave the row fully padded

        ranked = names[expressed][np.argsort(-expr[expressed])]
        mapped = [symbol2id[g] for g in ranked if g in symbol2id][:seq_len]
        ids_matrix[cell, : len(mapped)] = mapped

    return ids_matrix
46
+
47
+
48
def main():
    """Zero-shot pipeline: frozen embeddings + logistic regression classifier."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", default=None, help="Optional .npy dumped dict path")
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    assert "celltype" in adata.obs, "h5ad must include obs['celltype']"
    assert "partition" in adata.obs, "h5ad must include obs['partition']"

    if args.symbol2id_npy is None:
        raise ValueError("Please provide --symbol2id_npy (dict saved by np.save(..., allow_pickle=True))")

    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()

    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)
    labels = LabelEncoder().fit_transform(adata.obs["celltype"].values)

    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    model.eval().cuda()

    # Embed all cells in fixed-size chunks with the frozen backbone.
    chunks = []
    with torch.no_grad():
        for start in range(0, input_ids.shape[0], args.batch_size):
            ids = torch.tensor(input_ids[start : start + args.batch_size], dtype=torch.long, device="cuda")
            chunks.append(model(ids).pooled_embedding.detach().cpu().numpy())
    embeds = np.concatenate(chunks, axis=0)

    partitions = adata.obs["partition"].values
    X_train, y_train = embeds[partitions == "train"], labels[partitions == "train"]
    X_test, y_test = embeds[partitions == "test"], labels[partitions == "test"]

    clf = LogisticRegression(max_iter=2000)
    clf.fit(X_train, y_train)
    pred = clf.predict(X_test)

    print("accuracy:", accuracy_score(y_test, pred))
    print("micro_f1:", f1_score(y_test, pred, average="micro"))
    print("macro_f1:", f1_score(y_test, pred, average="macro"))
    print("precision_weighted:", precision_score(y_test, pred, average="weighted", zero_division=0))
    print("recall_weighted:", recall_score(y_test, pred, average="weighted", zero_division=0))
95
+
96
+
97
+ if __name__ == "__main__":
98
+ main()
examples/downstream/12_batch_integration_eval.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch integration downstream example:
3
+ - Extract embeddings with frozen GeneMamba
4
+ - Evaluate simple batch mixing score proxy (silhouette by batch)
5
+
6
+ Expected h5ad columns:
7
+ - obs['batch']
8
+ """
9
+
10
+ import argparse
11
+ import numpy as np
12
+ import scanpy as sc
13
+ import torch
14
+ from sklearn.metrics import silhouette_score
15
+ from sklearn.preprocessing import LabelEncoder
16
+ from transformers import AutoModel
17
+
18
+
19
def build_ranked_input_ids(adata, symbol2id, seq_len=2048, pad_id=1):
    """Build fixed-length, expression-ranked token-id rows for every cell.

    Each row lists the ids of expressed genes in descending expression order,
    truncated to seq_len and right-padded with pad_id.
    """
    names = np.array(adata.var_names)
    matrix = adata.X
    ids_matrix = np.full((adata.n_obs, seq_len), pad_id, dtype=np.int64)

    for cell in range(adata.n_obs):
        row = matrix[cell]
        # Sparse rows expose .toarray(); dense rows go through asarray.
        if hasattr(row, "toarray"):
            expr = row.toarray().ravel()
        else:
            expr = np.asarray(row).ravel()

        expressed = np.flatnonzero(expr > 0)
        if expressed.size == 0:
            continue  # all-zero cell: leave the row fully padded

        ranked = names[expressed][np.argsort(-expr[expressed])]
        mapped = [symbol2id[g] for g in ranked if g in symbol2id][:seq_len]
        ids_matrix[cell, : len(mapped)] = mapped

    return ids_matrix
43
+
44
+
45
def main():
    """Batch-integration proxy: frozen embeddings + silhouette score by batch."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--h5ad", required=True)
    parser.add_argument("--symbol2id_npy", required=True)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--batch_size", type=int, default=64)
    args = parser.parse_args()

    adata = sc.read_h5ad(args.h5ad)
    assert "batch" in adata.obs, "h5ad must include obs['batch']"

    symbol2id = np.load(args.symbol2id_npy, allow_pickle=True).item()
    input_ids = build_ranked_input_ids(adata, symbol2id, seq_len=args.seq_len)

    model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True)
    model.eval().cuda()

    # Embed all cells in fixed-size chunks with the frozen backbone.
    chunks = []
    with torch.no_grad():
        for start in range(0, input_ids.shape[0], args.batch_size):
            ids = torch.tensor(input_ids[start : start + args.batch_size], dtype=torch.long, device="cuda")
            chunks.append(model(ids).pooled_embedding.detach().cpu().numpy())
    embeds = np.concatenate(chunks, axis=0)

    batch_labels = LabelEncoder().fit_transform(adata.obs["batch"].values)
    score = silhouette_score(embeds, batch_labels, metric="euclidean")

    print("silhouette_by_batch:", score)
    print("(Closer to 0 typically indicates better batch mixing than very high positive values.)")
76
+
77
+
78
+ if __name__ == "__main__":
79
+ main()
examples/downstream/20_continue_pretraining_reference.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 3: Continue Pretraining
3
+ Demonstrates how to continue pretraining GeneMamba on your own data using masked LM objective.
4
+
5
+ Usage:
6
+ python examples/downstream/20_continue_pretraining_reference.py
7
+ """
8
+
9
+ import torch
10
+ import numpy as np
11
+ from torch.utils.data import Dataset
12
+ from transformers import (
13
+ AutoModelForMaskedLM,
14
+ AutoTokenizer,
15
+ Trainer,
16
+ TrainingArguments,
17
+ DataCollatorForLanguageModeling,
18
+ )
19
+
20
+
21
class PretrainingDataset(Dataset):
    """
    Dataset for pretraining/continued pretraining.
    Pads or truncates each ranked-gene sequence to a fixed length.
    """

    def __init__(self, input_ids_list, max_length=2048):
        self.input_ids_list = input_ids_list
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        seq = self.input_ids_list[idx]
        if len(seq) >= self.max_length:
            seq = seq[:self.max_length]
        else:
            # Right-pad with the pad token id (1) up to max_length.
            seq = np.pad(seq, (0, self.max_length - len(seq)), constant_values=1)
        return {"input_ids": torch.tensor(seq, dtype=torch.long)}
50
+
51
+
52
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Create mock single-cell sequences for pretraining."""

    print("Creating mock pretraining dataset...")

    # One RNG draw per sequence (same call pattern as the original, so a
    # seeded run reproduces identical data). Ids start at 2 to avoid the
    # special tokens 0/1.
    sequences = [np.random.randint(2, 25426, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences of length {seq_len}")

    return sequences
68
+
69
+
70
def main():
    """Run the continued-pretraining demo end to end.

    Loads a pretrained GeneMamba checkpoint (falling back to a fresh local
    initialization when the checkpoint is unavailable), builds a mock MLM
    dataset, trains with the HF Trainer, evaluates, saves the model, and
    smoke-tests masked inference.

    Returns:
        Tuple ``(model, trainer)`` after training.
    """
    print("=" * 80)
    print("GeneMamba Phase 3: Continue Pretraining")
    print("=" * 80)

    # ============================================================
    # Step 1: Load pretrained model for masked LM
    # ============================================================
    print("\n[Step 1] Loading model for masked LM...")

    try:
        model = AutoModelForMaskedLM.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "GeneMamba-24l-512d",
            trust_remote_code=True,
            local_files_only=True,
        )
    except Exception as e:
        print(f"Note: Could not load from hub ({e})")
        print("Using local initialization...")

        # Fall back to a randomly initialized model with the same shape.
        from configuration_genemamba import GeneMambaConfig
        from modeling_genemamba import GeneMambaForMaskedLM

        config = GeneMambaConfig(
            vocab_size=25426,
            hidden_size=512,
            num_hidden_layers=24,
        )
        model = GeneMambaForMaskedLM(config)
        tokenizer = None

    print(f"✓ Model loaded")
    print(f"  - Architecture: {model.config.num_hidden_layers} layers, "
          f"hidden_size={model.config.hidden_size}")

    # ============================================================
    # Step 2: Prepare pretraining data
    # ============================================================
    print("\n[Step 2] Preparing pretraining dataset...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # 90/10 train/eval split.
    train_size = int(0.9 * len(sequences))
    train_dataset = PretrainingDataset(sequences[:train_size])
    eval_dataset = PretrainingDataset(sequences[train_size:])

    print(f"✓ Datasets created:")
    print(f"  - Training: {len(train_dataset)} samples")
    print(f"  - Evaluation: {len(eval_dataset)} samples")

    # ============================================================
    # Step 3: Set up data collator for MLM
    # ============================================================
    print("\n[Step 3] Setting up data collator...")

    if tokenizer is not None:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=True,
            mlm_probability=0.15,  # Mask 15% of tokens
        )
    else:
        # Custom collator if no tokenizer is available.
        class CustomDataCollator:
            """Minimal MLM collator used when no tokenizer is available."""

            def __call__(self, batch):
                input_ids = torch.stack([item["input_ids"] for item in batch])

                # Labels keep the original tokens; loss is computed only on
                # masked positions.
                labels = input_ids.clone()

                # Mask 15% of tokens, but never padding (pad id = 1).
                # Fix: the previous version could also mask pad positions,
                # training the model to predict padding.
                mask = (torch.rand(input_ids.shape) < 0.15) & (input_ids != 1)

                # Replace masked inputs with the [MASK] token (id = 0).
                input_ids[mask] = 0

                # -100 makes the cross-entropy loss ignore unmasked tokens.
                labels[~mask] = -100

                return {"input_ids": input_ids, "labels": labels}

        data_collator = CustomDataCollator()

    print(f"✓ Data collator ready (MLM probability: 0.15)")

    # ============================================================
    # Step 4: Set up training arguments
    # ============================================================
    print("\n[Step 4] Setting up training...")

    output_dir = "./pretrain_results"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=100,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",  # Disable W&B
        seed=42,
    )

    print(f"✓ Training config:")
    print(f"  - Output dir: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")
    print(f"  - MLM masking: 15%")

    # ============================================================
    # Step 5: Train
    # ============================================================
    print("\n[Step 5] Starting continued pretraining...")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 6: Evaluate
    # ============================================================
    print("\n[Step 6] Evaluating on held-out set...")

    eval_results = trainer.evaluate()

    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")

    # ============================================================
    # Step 7: Save model
    # ============================================================
    print("\n[Step 7] Saving continued pretrained model...")

    save_dir = "./genemamba_continued_pretrain"
    model.save_pretrained(save_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(save_dir)

    print(f"✓ Model saved to '{save_dir}'")

    # ============================================================
    # Step 8: Test model inference
    # ============================================================
    print("\n[Step 8] Testing inference on masked input...")

    model.eval()

    # Create sample input with the first 10 positions masked.
    sample_input = torch.randint(2, 25426, (1, 2048))
    sample_input[0, :10] = 0  # Mask first 10 tokens

    with torch.no_grad():
        outputs = model(sample_input)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

    print(f"✓ Sample predictions generated")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Output logits shape: {logits.shape}")
    print(f"  - Top predicted genes (tokens): {predictions[0, :10].tolist()}")

    print("\n" + "=" * 80)
    print("Phase 3 Complete! Model ready for downstream tasks or further training.")
    print("=" * 80)

    return model, trainer
262
+
263
+
264
if __name__ == "__main__":
    # Run the full continued-pretraining pipeline when executed as a script.
    model, trainer = main()
examples/downstream/21_pretrain_from_scratch_reference.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 4: Train from Scratch
3
+ Demonstrates how to initialize and train a GeneMamba model from scratch.
4
+
5
+ Usage:
6
+ python examples/4_pretrain_from_scratch.py
7
+ """
8
+
9
+ import torch
10
+ import numpy as np
11
+ from torch.utils.data import Dataset
12
+ from transformers import (
13
+ AutoConfig,
14
+ Trainer,
15
+ TrainingArguments,
16
+ DataCollatorForLanguageModeling,
17
+ )
18
+
19
+
20
class PretrainingDataset(Dataset):
    """Dataset of pre-tokenized ranked-gene sequences for MLM pretraining.

    Each item is a dict with a single ``input_ids`` long tensor of exactly
    ``max_length`` tokens: longer sequences are truncated, shorter ones are
    right-padded with the pad token (id 1).
    """

    def __init__(self, input_ids_list, max_length=2048):
        self.input_ids_list = input_ids_list
        self.max_length = max_length

    def __len__(self):
        return len(self.input_ids_list)

    def __getitem__(self, idx):
        seq = self.input_ids_list[idx]

        # Truncate or right-pad to the fixed length.
        if len(seq) >= self.max_length:
            seq = seq[:self.max_length]
        else:
            pad_width = (0, self.max_length - len(seq))
            seq = np.pad(seq, pad_width, constant_values=1)

        return {
            "input_ids": torch.tensor(seq, dtype=torch.long),
        }
46
+
47
+
48
def create_mock_pretraining_data(n_sequences=5000, seq_len=2048):
    """Create mock pretraining data.

    Returns a list of ``n_sequences`` random gene-token arrays, each of
    length ``seq_len``, with ids drawn uniformly from ``[2, 25426)``.
    """
    print("Creating mock pretraining dataset for from-scratch training...")

    sequences = [np.random.randint(2, 25426, seq_len) for _ in range(n_sequences)]

    print(f"✓ Created {n_sequences} sequences")

    return sequences
61
+
62
+
63
def main():
    """Initialize a small GeneMamba from scratch, pretrain it with MLM,
    save, reload, and smoke-test it.

    Returns:
        Tuple ``(model, trainer, config)``.
    """
    print("=" * 80)
    print("GeneMamba Phase 4: Train from Scratch")
    print("=" * 80)

    # ============================================================
    # Step 1: Create config from scratch
    # ============================================================
    print("\n[Step 1] Creating model configuration...")

    from configuration_genemamba import GeneMambaConfig
    from modeling_genemamba import GeneMambaForMaskedLM

    # Deliberately small architecture so the demo finishes quickly.
    config = GeneMambaConfig(
        vocab_size=25426,
        hidden_size=256,       # Smaller for faster demo
        num_hidden_layers=12,  # Reduced for demo
        intermediate_size=1024,
        max_position_embeddings=2048,
        mamba_mode="gate",
        embedding_pooling="mean",
        num_labels=2,
        hidden_dropout_prob=0.1,
        initializer_range=0.02,
    )

    print(f"✓ Config created:")
    print(f"  - Model type: {config.model_type}")
    print(f"  - Hidden size: {config.hidden_size}")
    print(f"  - Num layers: {config.num_hidden_layers}")
    print(f"  - Vocab size: {config.vocab_size}")
    print(f"  - Mode: {config.mamba_mode}")

    # ============================================================
    # Step 2: Initialize model from config
    # ============================================================
    print("\n[Step 2] Initializing model from config...")

    model = GeneMambaForMaskedLM(config)

    # Count parameters for a quick size report.
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"✓ Model initialized:")
    print(f"  - Total parameters: {total_params / 1e6:.2f}M")
    print(f"  - Trainable parameters: {trainable_params / 1e6:.2f}M")

    # ============================================================
    # Step 3: Prepare data
    # ============================================================
    print("\n[Step 3] Preparing training data...")

    sequences = create_mock_pretraining_data(n_sequences=5000, seq_len=2048)

    # 80/20 train/eval split.
    train_size = int(0.8 * len(sequences))
    train_dataset = PretrainingDataset(sequences[:train_size])
    eval_dataset = PretrainingDataset(sequences[train_size:])

    print(f"✓ Datasets created:")
    print(f"  - Train: {len(train_dataset)}")
    print(f"  - Eval: {len(eval_dataset)}")

    # ============================================================
    # Step 4: Data collator for MLM
    # ============================================================
    print("\n[Step 4] Setting up data collator...")

    class CustomDataCollator:
        """Custom collator for MLM without tokenizer."""

        def __call__(self, batch):
            input_ids = torch.stack([item["input_ids"] for item in batch])

            # Labels keep the original tokens.
            labels = input_ids.clone()

            # Mask 15% of tokens, excluding padding (pad id = 1).
            # Fix: previously pad positions could be masked too, so the
            # model was also trained to predict padding.
            mask = (torch.rand(input_ids.shape) < 0.15) & (input_ids != 1)
            input_ids[mask] = 0  # [MASK] token

            # Don't compute loss on non-masked tokens.
            labels[~mask] = -100

            return {"input_ids": input_ids, "labels": labels}

    data_collator = CustomDataCollator()
    print(f"✓ Data collator ready")

    # ============================================================
    # Step 5: Training arguments
    # ============================================================
    print("\n[Step 5] Setting up training...")

    output_dir = "./from_scratch_pretrain"

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=5e-4,
        weight_decay=0.01,
        warmup_steps=500,
        logging_steps=50,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none",
        seed=42,
        optim="adamw_torch",
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
    )

    print(f"✓ Training config:")
    print(f"  - Output: {output_dir}")
    print(f"  - Epochs: {training_args.num_train_epochs}")
    print(f"  - Batch size: {training_args.per_device_train_batch_size}")
    print(f"  - Learning rate: {training_args.learning_rate}")

    # ============================================================
    # Step 6: Train
    # ============================================================
    print("\n[Step 6] Starting training from scratch...")
    print("(This may take a while. In practice, use more GPUs/data for real pretraining)")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    train_result = trainer.train()

    print(f"✓ Training complete!")
    print(f"  - Final training loss: {train_result.training_loss:.4f}")

    # ============================================================
    # Step 7: Evaluate
    # ============================================================
    print("\n[Step 7] Evaluating...")

    eval_results = trainer.evaluate()

    print(f"✓ Evaluation Results:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"  - {metric}: {value:.4f}")

    # ============================================================
    # Step 8: Save model and config
    # ============================================================
    print("\n[Step 8] Saving model...")

    save_dir = "./my_genemamba_from_scratch"
    model.save_pretrained(save_dir)
    config.save_pretrained(save_dir)

    print(f"✓ Model and config saved to '{save_dir}'")
    print(f"  Files created:")
    print(f"  - config.json")
    print(f"  - model.safetensors (or pytorch_model.bin)")

    # ============================================================
    # Step 9: Reload and verify
    # ============================================================
    print("\n[Step 9] Reloading model from checkpoint...")

    from transformers import AutoModelForMaskedLM

    loaded_model = AutoModelForMaskedLM.from_pretrained(
        save_dir,
        trust_remote_code=True,
    )
    loaded_model.eval()

    # Smoke-test inference on a batch with the first 10 positions masked.
    with torch.no_grad():
        sample_input = torch.randint(2, 25426, (2, 2048))
        sample_input[:, :10] = 0  # Mask first 10 tokens

        outputs = loaded_model(sample_input)
        logits = outputs.logits

    print(f"✓ Model reloaded and tested!")
    print(f"  - Input shape: {sample_input.shape}")
    print(f"  - Logits shape: {logits.shape}")

    # ============================================================
    # Step 10: Optional - Convert to different format
    # ============================================================
    print("\n[Step 10] Model ready for conversion/deployment!")
    print(f"✓ You can now:")
    print(f"  1. Push to Hugging Face Hub:")
    print(f"     model.push_to_hub('your-username/GeneMamba-custom')")
    print(f"  2. Use with downstream tasks:")
    print(f"     AutoModelForSequenceClassification.from_pretrained('{save_dir}', num_labels=N)")
    print(f"  3. Extract embeddings:")
    print(f"     AutoModel.from_pretrained('{save_dir}')")

    print("\n" + "=" * 80)
    print("Phase 4 Complete! Model trained from scratch and ready to use.")
    print("=" * 80)

    return model, trainer, config
277
+
278
+
279
if __name__ == "__main__":
    # Execute the full from-scratch training pipeline as a script.
    model, trainer, config = main()
examples/downstream/README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Downstream Examples
2
+
3
+ This folder now contains both **ready-to-run** examples and **legacy scripts** from the original GeneMamba project.
4
+
5
+ ## Ready-to-run scripts
6
+
7
+ - `10_finetune_classification.py`
8
+ Fine-tune `AutoModelForSequenceClassification` for cell-type annotation.
9
+
10
+ - `11_zero_shot_logreg.py`
11
+ Freeze GeneMamba, extract `pooled_embedding`, train LogisticRegression on train split, evaluate on test split.
12
+
13
+ - `12_batch_integration_eval.py`
14
+ Batch integration proxy evaluation using silhouette score by `obs['batch']`.
15
+
16
+ ## Reference training scripts
17
+
18
+ - `20_continue_pretraining_reference.py`
19
+ - `21_pretrain_from_scratch_reference.py`
20
+
21
+ ## Legacy scripts from original repo
22
+
23
+ - `legacy_from_gene_mamba/mamba2_classification_finetune_with_label.py`
24
+ - `legacy_from_gene_mamba/mamba2_classification_finetune_without_label.py`
25
+ - `legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py`
26
+
27
+ ## Required h5ad conventions
28
+
29
+ For downstream compatibility, standardize columns in `adata.obs`:
30
+
31
+ - `celltype` for label
32
+ - `batch` for batch id
33
+ - `partition` in `{train, test}` for train/test split
34
+
35
+ This matches conventions described in the original `dataset/downstream/README.md`.
examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_with_label.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import torch
3
+ from transformers import Trainer
4
+ import os
5
+
6
+ import pyarrow as pa
7
+ import pandas as pd
8
+ import numpy as np
9
+
10
+ from matplotlib import pyplot as plt
11
+
12
+ from torch.utils.data import Dataset
13
+ from transformers import AutoTokenizer, TrainingArguments
14
+
15
+ import argparse
16
+
17
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
18
+ from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM
19
+
20
+ from dotmap import DotMap
21
+
22
+ import sys
23
+ import os
24
+ import torch
25
+
26
+ # from trange import trange
27
+
28
+ sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")
29
+ from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
30
+ from utils import permute_genes_by_expression
31
+ from utils2 import standardize_columns
32
+
33
+ import importlib
34
+ importlib.reload(sys.modules['models'])
35
+ importlib.reload(sys.modules['utils'])
36
+ importlib.reload(sys.modules['utils2'])
37
+
38
+
39
# %%
# Fixed project paths: downstream datasets, gene tokenizer, and where the
# cell embeddings / results are written.
DATA_PATH = "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/"
# CHECKPOINT_PATH = "/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_48l_512d/1/3m/checkpoint-31250"
TOKENIZER_PATH = "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json"
SAVE_PATH = "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/embeddings/cell"
44
+
45
# %%
import argparse

# Command-line configuration for the fine-tuning run.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str)
parser.add_argument("--ckpt_path", type=str)
parser.add_argument("--seq_len", type=int, default=2048)
parser.add_argument("--batch_size", type=int, default=24)
parser.add_argument("--num_epochs", type=int, default=5)
parser.add_argument("--test_size", type=float, default=0.1)
# Accept truthy strings ("true", "1", "yes") as booleans.
parser.add_argument(
    "--split",
    type=lambda x: x.lower() in ["true", "1", "yes"],
    default=False,
)

args = parser.parse_args()

# Hard-coded args for interactive (notebook) use:
# args = DotMap({
#     "dataset_name": "ms",
#     "seq_len": 512,
#     "batch_size": 24,
#     "num_epochs": 5,
#     "test_size": 0.1
# })
67
+
68
+
69
#%%
# Derive the architecture from the checkpoint path; the model-name directory
# four levels up encodes it, e.g. "GeneMamba2_48l_512d" -> 48 layers, 512 dims.
CHECKPOINT_PATH = args.ckpt_path
model_name = CHECKPOINT_PATH.split("/")[-4]
name_parts = model_name.split("_")
mamba_layer = int(name_parts[1][:-1])  # strip trailing "l"
d_model = int(name_parts[2][:-1])      # strip trailing "d"

# Make the sub-directories that will hold the results.
SAVE_PATH = os.path.join(SAVE_PATH, model_name)
for sub_dir in ["predictions", "metrics", "figures", "repr"]:
    os.makedirs(os.path.join(SAVE_PATH, sub_dir), exist_ok=True)
81
+
82
+
83
# %%
import scanpy as sc

# Load the .h5ad file; --split selects the pre-split variant of the dataset.
dataset_name = args.dataset_name
# adata = sc.read_h5ad(os.path.join(DATA_PATH ,f'{dataset_name}.h5ad'))

if args.split:
    rel_path = f'split/{dataset_name}_split.h5ad'
else:
    rel_path = f'processed/{dataset_name}_processed.h5ad'

adata = sc.read_h5ad(os.path.join(DATA_PATH, rel_path))
print(f"Read data from {os.path.basename(rel_path)}")

# The split variant gets a distinct name so saved artifacts don't collide.
if args.split:
    dataset_name = dataset_name + "_split"

# Display basic information about the data
print(adata)
102
+
103
+ # %%
104
+ # adata = standardize_columns(adata, dataset_name)
105
+ # assert "batch" in adata.obs.columns and "celltype" in adata.obs.columns
106
+
107
# %%
from sklearn.preprocessing import LabelEncoder

# Encode string cell-type labels as contiguous integer class ids.
y_names = np.array(adata.obs['celltype'].values.tolist())

label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y_names)

# Number of distinct cell types = output dimension of the classifier head.
num_class = len(label_encoder.classes_)
116
+
117
# %%
from transformers import PretrainedConfig

# Minimal config carrying only the architecture parsed from the checkpoint path.
config = PretrainedConfig.from_dict(
    {
        "d_model": d_model,
        "mamba_layer": mamba_layer,
    }
)
124
+
125
+
126
# %%
# Cell classifier: pretrained GeneMamba backbone + 4-layer MLP head (dim 512).
model_cell_cls = GeneMamba2ForCellClassification(
    config,
    model_path=CHECKPOINT_PATH,
    tokenizer_path=TOKENIZER_PATH,
    args=None,
    output_dim_cls=num_class,
    hidden_dim=512,
    num_layers_cls=4,
)
128
+
129
# %%
# Rank each cell's genes by expression and map them to token ids.
permuted_gene_ids = permute_genes_by_expression(
    adata, dataset_name, model_cell_cls.tokenizer, model_cell_cls.symbol2id
)

# %%
# Truncate every cell's ranked gene list to the requested sequence length.
seq_len = args.seq_len
input_data = permuted_gene_ids[:, :seq_len]

# %%
# Prepend the [CLS] token id to every row; its final representation is what
# the classification head consumes.
n_cells = input_data.shape[0]
cls_column = np.full((n_cells, 1), model_cell_cls.tokenizer.cls_token_id)
input_data = np.hstack([cls_column, input_data])
160
+
161
#%%
from sklearn.model_selection import train_test_split
import numpy as np

def manual_stratified_split(X, y, test_size=0.1, random_state=None):
    """Stratified train/test split that tolerates singleton classes.

    Splits each class independently with ``train_test_split``; a class with
    a single sample goes entirely to the training set (sklearn's stratified
    split would fail on it).
    """
    X_train, X_test, y_train, y_test = [], [], [], []

    for cls in np.unique(y):
        cls_indices = np.where(y == cls)[0]

        if len(cls_indices) > 1:
            cls_train, cls_test = train_test_split(
                cls_indices, test_size=test_size, random_state=random_state
            )
        else:
            # A singleton class cannot be split; keep it for training.
            cls_train, cls_test = cls_indices, []

        X_train.extend(X[cls_train])
        y_train.extend(y[cls_train])
        X_test.extend(X[cls_test])
        y_test.extend(y[cls_test])

    return np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)
186
+
187
# %%
#%%
# The train/test split is pre-computed and stored in adata.obs["partition"],
# so select rows directly instead of re-splitting with manual_stratified_split.
train_rows = adata.obs["partition"] == "train"
test_rows = adata.obs["partition"] == "test"

X_train = input_data[train_rows]
X_test = input_data[test_rows]
y_train = y[train_rows]
y_test = y[test_rows]
202
+
203
# %%
from torch.utils.data import DataLoader, Dataset

class GeneDataset(Dataset):
    """Pairs of (token-id sequence, integer label) for cell classification."""

    def __init__(self, data, y):
        self.data = data
        self.labels = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
216
+
217
# Wrap the splits (and the full dataset) as loaders. Only training shuffles;
# evaluation order must stay fixed so saved predictions/embeddings align with
# adata rows.
train_dataset = GeneDataset(X_train, y_train)
test_dataset = GeneDataset(X_test, y_test)
all_dataset = GeneDataset(input_data, y)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)
all_loader = DataLoader(all_dataset, batch_size=args.batch_size, shuffle=False)
224
+
225
# %%
from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score, roc_auc_score

# %%
def compute_metrics(y_pred, y_prob, y_true):
    """Return a dict of classification metrics.

    ``y_prob`` is accepted for the (currently disabled) AUC-ROC metric.
    """
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "Micro-F1 score": f1_score(y_true, y_pred, average='micro'),
        "Macro-F1 score": f1_score(y_true, y_pred, average='macro'),
        "precision": precision_score(y_true, y_pred, average='weighted'),
        "recall": recall_score(y_true, y_pred, average='weighted'),
        # "auc_roc": roc_auc_score(y_true, y_prob, multi_class = 'ovr'),
    }
240
+
241
# %%
# Fine-tuning loop: train one epoch, then evaluate on the test split, save
# predictions/metrics, and plot a PCA scatter of the cell representations.
epochs = args.num_epochs
optimizer = torch.optim.Adam(model_cell_cls.parameters(), lr=1e-4)
loss = torch.nn.CrossEntropyLoss()

# Fix: move the model to its device once, instead of on every batch.
model_cell_cls = model_cell_cls.to(model_cell_cls.device)

for epoch in range(epochs):
    # ---- train ----
    model_cell_cls.train()
    for i, batch in enumerate(train_loader):
        data = batch[0].to(model_cell_cls.device)
        target = batch[1].to(model_cell_cls.device)

        optimizer.zero_grad()
        output = model_cell_cls(data, None)
        loss_val = loss(output, target)
        loss_val.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f"Epoch {epoch}, Iteration {i}, Loss: {loss_val}")

    # ---- evaluate on the held-out test split ----
    model_cell_cls.eval()
    with torch.no_grad():
        pred_prob = []
        pred_label = []
        targets = []
        cell_repr = []

        for i, batch in enumerate(test_loader):
            data = batch[0].to(model_cell_cls.device)
            target = batch[1].to(model_cell_cls.device)

            # return_cls=True also yields the [CLS] cell representation.
            output, output_test_repr = model_cell_cls(data, None, return_cls=True)
            cell_repr.append(output_test_repr.cpu().numpy())

            # Class probabilities and hard predictions.
            pred_prob.append(torch.nn.functional.softmax(output, dim=1).cpu().numpy())
            _, predicted = torch.max(output, 1)
            pred_label.append(predicted.cpu().numpy())
            targets.append(target.cpu().numpy())

        pred_prob = np.concatenate(pred_prob)
        pred_label = np.concatenate(pred_label)
        targets = np.concatenate(targets)
        cell_repr = np.concatenate(cell_repr)

        # Persist raw predictions for later analysis.
        np.save(os.path.join(SAVE_PATH, f"predictions/pred_prob_{dataset_name}_{epoch}.npy"), pred_prob)
        np.save(os.path.join(SAVE_PATH, f"predictions/pred_label_{dataset_name}_{epoch}.npy"), pred_label)
        np.save(os.path.join(SAVE_PATH, f"predictions/targets_{dataset_name}_{epoch}.npy"), targets)

        metrics = compute_metrics(pred_label, pred_prob, targets)

        with open(os.path.join(SAVE_PATH, f"metrics/metrics_{dataset_name}_{epoch}.txt"), "w") as f:
            print(metrics, file=f)
        print(metrics)

        # 2-D PCA scatter of cell representations, colored by true label.
        from sklearn.decomposition import PCA

        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(cell_repr)

        plt.figure(figsize=(8, 8))
        plt.scatter(pca_result[:, 0], pca_result[:, 1], c=targets)
        plt.savefig(os.path.join(SAVE_PATH, f"figures/scatter_{dataset_name}_{epoch}.png"))
        # Fix: close the figure so figures don't accumulate across epochs.
        plt.close()
317
+
318
+
319
# %%
model_cell_cls.eval()

def cell_embeddings(data_loader, model_cell_cls):
    """Return the [CLS] cell representations for every batch in data_loader.

    The model is assumed to already live on ``model_cell_cls.device``.
    """
    cell_repr = []

    # Fix: run under no_grad so inference does not build an autograd graph.
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            data = batch[0].to(model_cell_cls.device)
            # Labels are not needed for embedding extraction.

            _output, batch_repr = model_cell_cls(data, None, return_cls=True)
            cell_repr.append(batch_repr.detach().cpu().numpy())
            if i % 10 == 0:
                print(f"Processed {i} batches")

    return np.concatenate(cell_repr)


# Extract and save embeddings for test, train, and all cells, freeing each
# large array before computing the next to keep peak memory down.
test_cell_repr = cell_embeddings(test_loader, model_cell_cls)
np.save(os.path.join(SAVE_PATH, f"repr/{dataset_name}_test_cell_repr.npy"), test_cell_repr)
del test_cell_repr

train_cell_repr = cell_embeddings(train_loader, model_cell_cls)
np.save(os.path.join(SAVE_PATH, f"repr/{dataset_name}_train_cell_repr.npy"), train_cell_repr)
del train_cell_repr

all_cell_repr = cell_embeddings(all_loader, model_cell_cls)
np.save(os.path.join(SAVE_PATH, f"repr/{dataset_name}_cell_repr.npy"), all_cell_repr)
del all_cell_repr
356
+
357
+
358
+ # %%
359
+ # original_data = adata.X.toarray()
360
+ # original_data.shape
361
+
362
+ # %%
363
+ # draw the scatter figure on the original data
364
+ # from sklearn.decomposition import PCA
365
+
366
+ # pca = PCA(n_components=2)
367
+ # pca_result = pca.fit_transform(original_data)
368
+
369
+ # plt.figure(figsize=(8, 8))
370
+
371
+ # plt.scatter(pca_result[:, 0], pca_result[:, 1], c = y)
372
+ # plt.show()
373
+
374
+
375
+ # %%
376
+
377
+
378
+
examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import torch
3
+ from transformers import Trainer
4
+ import os
5
+
6
+ import pyarrow as pa
7
+ import pandas as pd
8
+ import numpy as np
9
+
10
+ from matplotlib import pyplot as plt
11
+
12
+ from torch.utils.data import Dataset
13
+ from transformers import AutoTokenizer, TrainingArguments
14
+
15
+ import argparse
16
+
17
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
18
+ from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM
19
+
20
+ from dotmap import DotMap
21
+
22
+ import sys
23
+ import os
24
+ import torch
25
+
26
+ sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")
27
+
28
+ from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
29
+ from utils import permute_genes_by_expression, build_downstream_dataset
30
+
31
+
32
+ import importlib
33
+ importlib.reload(sys.modules['models'])
34
+ importlib.reload(sys.modules['utils'])
35
+
36
# %%
import scanpy as sc

import argparse

# CLI: a single argument selects which preprocessed downstream dataset to use.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_name", type=str)

args2 = parser.parse_args()

# Load the .h5ad file
dataset_name = args2.dataset_name


# NOTE(review): `assert` is stripped under `python -O`; raising ValueError
# would be safer for CLI input validation.
assert dataset_name in ["pbmc12k", "perirhinal_cortex", "covid19"]

adata = sc.read_h5ad(f'/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/processed/{dataset_name}_processed.h5ad')

# Downstream fine-tuning relies on ground-truth cell-type labels being present.
assert "celltype" in adata.obs

print(adata)
57
+
58
# %%
from sklearn.preprocessing import LabelEncoder

# Encode string cell-type labels into contiguous integer class ids.
y_names = np.array(adata.obs['celltype'].values.tolist())

label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y_names)

num_class = len(label_encoder.classes_)

# %%
from transformers import PretrainedConfig

# Backbone hyperparameters matching the 24-layer / 512-dim GeneMamba2 checkpoint.
config = PretrainedConfig.from_dict({
    "d_model": 512,
    "mamba_layer": 24,
})
75
+
76
+
77
# %%
# Load the pretrained GeneMamba2 backbone and its gene tokenizer from fixed
# cluster paths.
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)

# %%
# Rank genes per cell by expression and map gene symbols to token ids.
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model.tokenizer, model.symbol2id)
permuted_gene_ids

# %%
num_samples = permuted_gene_ids.shape[0]
num_avaliable_gpu = torch.cuda.device_count()

# %%
from dotmap import DotMap

# Run configuration; DotMap gives attribute-style access (args.batch_size).
args = DotMap(
    {
        # "model": "state-spaces/mamba-130m-hf",
        # "tokenizer": "state-spaces/mamba-130m-hf",
        "learning_rate": 5e-5,
        "batch_size": 16,
        "gradient_accumulation_steps": 1,
        "optim": "adamw_torch",
        # "data_path": "/home/cong/study/codeSpace/VSCodeSpace/PythonWorkSpace/TCRPrediction/mamba_transformer/smiles_data.txt",
        # "num_epochs": args2.num_epochs,
        "seq_len": 2048,
        "num_samples": num_samples,
        "num_gpus": num_avaliable_gpu,
        "output_dir": "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned/debug",
    }
)

# %%
# Truncate each cell's ranked gene list to the model context length.
input_data = permuted_gene_ids[:, :args.seq_len]

# %%
input_data.shape

#%%
# check if cls_token in the tokenizer:
if model.tokenizer.cls_token_id is None:
    model.tokenizer.add_special_tokens({'cls_token': '[CLS]'})

#%%
# Prepend one [CLS] token id to every cell so a cell-level representation can
# be read off position 0 (sequence length becomes seq_len + 1).
input_data = np.hstack([np.array([model.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1), input_data])

#%%
input_data.shape

# %%
sample_dataset = build_downstream_dataset(input_data, model.tokenizer)
sample_dataset

# input_data = np.hstack([np.array([model.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1), input_data])
# input_data
131
+
132
# %%
# Build HF TrainingArguments from the DotMap config. Two fixes over the
# original: (1) the result is bound to `training_args` so the DotMap `args`
# is no longer shadowed by an object of a different type; (2) the logging/save
# cadence is clamped to at least 1 step — for small datasets
# `num_samples // batch_size // 10` evaluates to 0, which Trainer rejects.
steps_per_tenth = max(1, args.num_samples // args.batch_size // 10)
training_args = TrainingArguments(
    learning_rate=args.learning_rate,
    num_train_epochs=4,
    per_device_train_batch_size=args.batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    optim=args.optim,
    output_dir=os.path.join(args.output_dir, dataset_name),
    # output_dir=f"/scratch/zhiwei/cq5/logs/mamba/test/context_length",
    # logging_dir=f"{args.output_dir}/{args.num_epochs}/{args.num_samples // 1000000 + args.bulk_id}m_logging",
    logging_steps=steps_per_tenth,
    save_steps=steps_per_tenth,
)


# %%
model.finetune(sample_dataset, training_args)

# %%
# ckpt_pth = get_last_checkpoint(os.path.join(args.output_dir, dataset_name))
# ckpt_pth

# #%%
# model = GeneMamba2(config, model_path=, tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)

#%%
158
+
159
+
160
+
161
+
examples/downstream/legacy_from_gene_mamba/mamba2_classification_finetune_without_label_zero_shot.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import torch
3
+ from transformers import Trainer
4
+ import os
5
+
6
+ import pyarrow as pa
7
+ import pandas as pd
8
+ import numpy as np
9
+
10
+ from matplotlib import pyplot as plt
11
+
12
+ from torch.utils.data import Dataset
13
+ from transformers import AutoTokenizer, TrainingArguments
14
+
15
+ import argparse
16
+
17
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
18
+ from transformers import AutoTokenizer, TrainingArguments, MambaForCausalLM
19
+
20
+ from dotmap import DotMap
21
+
22
+ import sys
23
+ import os
24
+ import torch
25
+
26
+ sys.path.append("/project/zhiwei/cq5/PythonWorkSpace/gene_mamba")
27
+
28
+ from models import Classifier, GeneMamba, GeneMambaForCellAnnotation, GeneMambaForGeneClassification, GeneMamba2, GeneMamba2ForCellClassification
29
+ from utils import permute_genes_by_expression, build_downstream_dataset, get_last_checkpoint
30
+
31
+
32
+ import importlib
33
+ importlib.reload(sys.modules['models'])
34
+ importlib.reload(sys.modules['utils'])
35
+
36
# %%
import scanpy as sc

# CLI parsing is disabled in this variant; the dataset is hard-coded below.
# import argparse

# parser = argparse.ArgumentParser()
# parser.add_argument("--dataset_name", type=str)

# args2 = parser.parse_args()

# dataset_name = args2.dataset_name

dataset_name = "pbmc12k"

assert dataset_name in ["pbmc12k", "perirhinal_cortex", "covid19"]

adata = sc.read_h5ad(f'/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/dataset/downstream/processed/{dataset_name}_processed.h5ad')

assert "celltype" in adata.obs

print(adata)

# %%
from transformers import PretrainedConfig

# Backbone hyperparameters matching the 24-layer / 512-dim GeneMamba2 checkpoint.
config = PretrainedConfig.from_dict({
    "d_model": 512,
    "mamba_layer": 24,
})


# %%
# First model instance is only used for its tokenizer / symbol table below.
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)

# %%
# Rank genes per cell by expression and map gene symbols to token ids.
permuted_gene_ids = permute_genes_by_expression(adata, dataset_name, model.tokenizer, model.symbol2id)
permuted_gene_ids

# %%
num_samples = permuted_gene_ids.shape[0]
num_avaliable_gpu = torch.cuda.device_count()

# %%
from dotmap import DotMap

args = DotMap(
    {
        # "model": "state-spaces/mamba-130m-hf",
        # "tokenizer": "state-spaces/mamba-130m-hf",
        "learning_rate": 5e-5,
        "batch_size": 16,
        "gradient_accumulation_steps": 1,
        "optim": "adamw_torch",
        # "data_path": "/home/cong/study/codeSpace/VSCodeSpace/PythonWorkSpace/TCRPrediction/mamba_transformer/smiles_data.txt",
        # "num_epochs": args2.num_epochs,
        "seq_len": 2048,
        "num_samples": num_samples,
        "num_gpus": num_avaliable_gpu,
        "output_dir": "/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned",
    }
)


#%%
# NOTE(review): this re-constructs the same model a second time before loading
# the fine-tuned state dict; reusing the instance above looks possible — confirm.
model = GeneMamba2(config, model_path="/project/zhiwei/cq5/LLM_checkpoints/GeneMamba/GeneMamba2_24l_512d/1/16m/checkpoint-31250", tokenizer_path="/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/gene_tokenizer.json", args=None)

# Resize embeddings so the fine-tuned state dict (which may include the added
# [CLS] token) loads without a size mismatch.
model.resize_token_embeddings()
103
+
104
+ #%%
105
def get_last_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Args:
        output_dir: Directory containing Hugging Face Trainer checkpoints.

    Returns:
        Path ``<output_dir>/checkpoint-<max_step>``.

    Raises:
        FileNotFoundError: If ``output_dir`` contains no ``checkpoint-<int>``
            entries (the original code raised an opaque IndexError here).
    """
    steps = []
    for name in os.listdir(output_dir):
        prefix, sep, suffix = name.partition("-")
        # Strict "checkpoint-<int>" match: the original substring test crashed
        # with ValueError on entries such as "checkpoint_old" or "best-checkpoint".
        if prefix == "checkpoint" and sep and suffix.isdigit():
            steps.append(int(suffix))
    if not steps:
        raise FileNotFoundError(f"No checkpoint-* directories found in {output_dir}")
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")
113
+
114
# Directory where the fine-tuning run for this dataset stored its checkpoints.
ckpt_pth = f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/fine-tuned/{dataset_name}"

last_checkpoint = get_last_checkpoint(ckpt_pth)
state_dict_pth = os.path.join(last_checkpoint, "model.safetensors")

print(state_dict_pth)

#%%
from safetensors.torch import load_file

# Load the fine-tuned weights into the inner backbone (model.model).
state_dict = load_file(state_dict_pth)

model.model.load_state_dict(state_dict)
127
+
128
+
129
# %%
# Truncate each cell's ranked gene list to the model context length.
input_data = permuted_gene_ids[:, :args.seq_len]

# %%
input_data.shape

#%%
# check if cls_token in the tokenizer:
if model.tokenizer.cls_token_id is None:
    model.tokenizer.add_special_tokens({'cls_token': '[CLS]'})

#%%
# Prepend one [CLS] token id per cell; its position-0 hidden state is used as
# the cell embedding downstream.
input_data = np.hstack([np.array([model.tokenizer.cls_token_id for _ in range(input_data.shape[0])]).reshape(-1, 1), input_data])

#%%
input_data.shape
145
+
146
+
147
+ #%%
148
+ from torch.utils.data import DataLoader, Dataset
149
+
150
class GeneDataset(Dataset):
    """Minimal map-style dataset over the rows of a pre-tokenized id matrix."""

    def __init__(self, data):
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
159
+
160
+
161
#%%
# shuffle=False so output row order stays aligned with the input matrix
# (presumably one row per cell in adata.obs order — confirm upstream).
all_dataset = GeneDataset(input_data)
all_loader = DataLoader(all_dataset, batch_size = args.batch_size, shuffle=False)
164
+
165
+ # %%
166
def cell_embeddings(data_loader, model):
    """Embed every batch with the model and stack the [CLS] vectors.

    Args:
        data_loader: Iterable of input-id tensor batches.
        model: Callable returning an output object with a ``hidden_states``
            tensor of shape (batch, seq_len, hidden_dim); exposes ``.device``.

    Returns:
        numpy array of shape (num_cells, hidden_dim).
    """
    chunks = []
    for step, batch in enumerate(data_loader):
        outputs = model(batch.to(model.device))
        # Position 0 holds the prepended [CLS] token for each cell.
        chunks.append(outputs.hidden_states[:, 0, :].detach().cpu().numpy())
        if step % 10 == 0:
            print(f"Processed {step} batches")
    return np.concatenate(chunks)
183
+
184
# %%
# Inference on GPU with dropout disabled.
model = model.to("cuda")
model.eval()

# %%
cell_repr = cell_embeddings(all_loader, model)
cell_repr.shape


# cell_repr = np.concatenate(cell_repr)
# %%
# Persist the embeddings for downstream clustering / annotation analysis.
np.save(f"/project/zhiwei/cq5/PythonWorkSpace/gene_mamba/analysis/cell_type_annotation/embeddings/fine-tuned/{dataset_name}_cell_repr.npy", cell_repr)

# %%
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccb1fcb0ee4b3ea2013099b9b187455e160d3b66b76c606715231b70b13c2784
3
+ size 262998656
modeling_genemamba.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch implementation of GeneMamba model for Hugging Face Transformers.
3
+ Includes backbone model and task-specific heads for various downstream tasks.
4
+ """
5
+
6
+ import math
7
+ import logging
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.init import normal_, constant_
14
+
15
+ from transformers import PreTrainedModel, PretrainedConfig
16
+ from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
17
+ from transformers.models.auto import register_model_for_auto_class
18
+
19
+ from mamba_ssm import Mamba
20
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm
21
+
22
+ from .configuration_genemamba import GeneMambaConfig
23
+ from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ # ===========================
29
+ # Core Architecture Components
30
+ # ===========================
31
+
32
class EncoderLayer(nn.Module):
    """One residual Mamba block: ``y = Mamba(x) + x``.

    Args:
        hidden_size: Dimension of the hidden states.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyperparameters shared across the GeneMamba stack.
        self.mamba = Mamba(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the Mamba layer with a residual (skip) connection.

        Args:
            X: Tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor with the same shape as ``X``.
        """
        return self.mamba(X) + X
55
+
56
+
57
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.
    Each layer is applied to the sequence in both forward and reversed order
    (sharing the same layer weights for both directions), then the two passes
    are merged per the configured aggregation mode.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        # Create Mamba layers
        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a learned projection to merge both directions.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence over its actual (unpadded) length, leaving the
        padding tail in place.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len).
                NOTE(review): `lengths = (~mask).sum(dim=1)` treats truthy mask
                entries as PADDING. GeneMambaModel passes an HF-style
                `attention_mask` (1 = real token) straight through, which would
                invert the meaning — confirm which polarity callers use.

        Returns:
            torch.Tensor: Reversed tensor.
        """
        batch_size, seq_length, embedding_dim = X.size()

        if mask is None:
            # No padding info: flip the whole time axis.
            return X.flip([1])

        # Per-sample valid length, assuming mask is True at padding positions.
        lengths = (~mask).sum(dim=1)
        pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        # Positions inside the valid prefix map to (length - 1 - pos); padding
        # positions keep their index, so padding stays at the tail.
        flip_mask = pos_tensor < lengths.unsqueeze(1)
        reversed_positions = torch.where(
            flip_mask,
            lengths.unsqueeze(1) - 1 - pos_tensor,
            pos_tensor
        )

        X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
        return X_reverse

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process the sequence through the bidirectional Mamba stack.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Padding mask (see
                flip_sequence for the polarity caveat).

        Returns:
            torch.Tensor: Output of shape (batch_size, seq_len, hidden_size).
        """

        for layer in self.layers:
            # Flip sequence for reverse processing
            X_flip = self.flip_sequence(X, padding_mask)

            # Forward and reverse passes through the SAME layer (shared weights).
            X_f = layer(X)
            X_b = layer(X_flip)

            # Flip the reverse output back so both passes are time-aligned.
            X_b = self.flip_sequence(X_b, padding_mask)

            # Merge the two directions according to the configured mode.
            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = torch.cat([X_f, X_b], dim=-1)
                X = self.aggr(X)
            elif self.mode == "gate":
                # Learned convex combination: z in (0, 1) per feature.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
159
+
160
+
161
+ # ===========================
162
+ # Base Model Classes
163
+ # ===========================
164
+
165
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Common base for all GeneMamba models: wires up the config class and
    provides the weight-initialization scheme shared by every head.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize one submodule (invoked recursively via ``self.apply``)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            normal_(module.weight, std=std)
            # Keep the padding row at exactly zero so PAD contributes nothing.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            normal_(module.weight, std=std)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.LayerNorm):
            constant_(module.weight, 1.0)
            constant_(module.bias, 0.0)
188
+
189
+
190
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Token embedding table; the PAD row is zeroed by _init_weights.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Bidirectional Mamba stack with configurable direction aggregation.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Set embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Mask of shape (batch_size, seq_len).
                NOTE(review): this tensor is used with two conflicting
                conventions — MambaMixer.flip_sequence treats truthy entries as
                padding, while the mean-pooling branch below treats 1 as a real
                token. Confirm the intended polarity with callers.
            output_hidden_states (bool): If True, `hidden_states` in the output
                is set to the FINAL layer output (not per-layer states).

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        # Get embeddings
        hidden_states = self.embeddings(input_ids)

        # Pass through Mamba layers
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        # Apply final normalization
        hidden_states = self.norm(hidden_states)

        # Compute pooled embedding (cell representation)
        if self.config.embedding_pooling == "CLS":
            # Use first token (CLS)
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            # Mean pooling over sequence, weighting positions by the mask
            # (mask value 1 = included in the average).
            if attention_mask is not None:
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            # Duplicates last_hidden_state; intermediate layer states are not kept.
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
270
+
271
+
272
+ # ===========================
273
+ # Task-Specific Models
274
+ # ===========================
275
+
276
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba with a token-level language-modeling head, used for masked-LM
    pretraining and domain adaptation.

    NOTE(review): `register_model_for_auto_class` is not a documented
    transformers API (the usual pattern is
    ``GeneMambaForMaskedLM.register_for_auto_class("AutoModel")``) — confirm
    the import resolves with the pinned transformers version.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Untied projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids: Token ids of shape (batch_size, seq_len).
            attention_mask: Optional mask, forwarded to the backbone.
            labels: Optional MLM targets; positions set to -100 are skipped by
                ``nn.CrossEntropyLoss`` (its default ``ignore_index``).
            output_hidden_states: Whether to expose backbone hidden states.

        Returns:
            GeneMambaMaskedLMOutput with ``logits`` and optional ``loss``.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        logits = self.lm_head(backbone_out.last_hidden_state)

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        return GeneMambaMaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
        )
330
+
331
+
332
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba with a classification head over the pooled cell embedding —
    suited to cell-type annotation, tissue classification, and similar tasks.

    NOTE(review): `register_model_for_auto_class` is not a documented
    transformers API — confirm the import resolves; the usual pattern is
    ``Cls.register_for_auto_class(...)``.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels

        self.genemamba = GeneMambaModel(config)

        # Classification head applied to the pooled (cell-level) embedding.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids: Token ids of shape (batch_size, seq_len).
            attention_mask: Optional mask, forwarded to the backbone.
            labels: Optional class ids for a cross-entropy loss.
            output_hidden_states: Whether to expose backbone hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput with logits, optional loss, and
            the pooled cell embedding used by the head.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        pooled = backbone_out.pooled_embedding
        logits = self.classifier(self.dropout(pooled))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
            pooled_embedding=pooled,
        )
392
+
393
+
394
# Also expose the MLM head under AutoModelForMaskedLM (it is already exposed
# under AutoModel via its class decorator).
# NOTE(review): `register_model_for_auto_class` is not a documented transformers
# API — confirm the import exists; the standard pattern is
# GeneMambaForMaskedLM.register_for_auto_class("AutoModelForMaskedLM").
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
modeling_outputs.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom ModelOutput classes for GeneMamba.
3
+ Defines the output structure for different GeneMamba tasks.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+ import torch
9
+ from transformers.utils import ModelOutput
10
+
11
+
12
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba backbone models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden states exposed when requested (the backbone currently fills
            this with the final layer output only).

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str):
            The pooling method used to generate pooled_embedding.
            NOTE(review): a str field in a ModelOutput is unconventional — it
            participates in tuple-style unpacking alongside tensors; confirm
            downstream code indexes by name, not position.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: torch.FloatTensor = None
    embedding_pooling: str = "mean"
36
+
37
+
38
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (present only if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Backbone hidden states, when requested.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding fed to the classification head (pre-dropout).
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
61
+
62
+
63
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (present only if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Backbone hidden states, when requested.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
special_tokens_map.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pad_token": "[PAD]",
3
+ "unk_token": "[UNK]"
4
+ }