chengCCC committed on
Commit
cd118dc
·
verified ·
1 Parent(s): fc89b2b

Upload CrossDNA 28.6M pretrained files

Browse files
28.6M/README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: gpl-3.0
3
+ ---
4
+
5
+ ## Using CrossDNA 28.6M (140K sequence inputs)
6
+ ```python
7
+ import os
8
+
9
+
10
+ os.environ["DISABLE_TORCH_COMPILE"] = "1"
11
+ os.environ["TORCHDYNAMO_DISABLE"] = "1"
12
+
13
+ import torch
14
+
15
+
16
+ if hasattr(torch, "compile"):
17
+ def _no_compile(fn=None, *args, **kwargs):
18
+ if fn is None:
19
+ def deco(f):
20
+ return f
21
+ return deco
22
+ return fn
23
+ torch.compile = _no_compile
24
+
25
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
26
+
27
+ MODEL_DIR = "/path/to/your/local/crossdna"  # replace with the directory where you downloaded these model files
28
+
29
+ tok = AutoTokenizer.from_pretrained(
30
+ MODEL_DIR,
31
+ trust_remote_code=True,
32
+ local_files_only=True,
33
+ )
34
+
35
+ model = AutoModelForMaskedLM.from_pretrained(
36
+ MODEL_DIR,
37
+ trust_remote_code=True,
38
+ local_files_only=True,
39
+ ).eval()
40
+
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ model.to(device)
43
+
44
+ seq = "ACGT" * 128
45
+ enc = tok(seq, return_tensors="pt", add_special_tokens=False)
46
+ x = enc["input_ids"].to(device)
47
+
48
+ with torch.inference_mode():
49
+ out = model(input_ids=x)
50
+ emb = model.extract_embeddings(x)
51
+
52
+ print("input_ids.shape =", tuple(x.shape))
53
+ print("logits.shape =", tuple(out.logits.shape))
54
+ print("embeddings.shape =", tuple(emb.shape))
55
+
56
+ ```
28.6M/huggingface_crossdna_140K_len/crossdna/__pycache__/configuration_crossdna.cpython-311.pyc ADDED
Binary file (6.14 kB). View file
 
28.6M/huggingface_crossdna_140K_len/crossdna/__pycache__/modeling_crossdna.cpython-311.pyc ADDED
Binary file (99.4 kB). View file
 
28.6M/huggingface_crossdna_140K_len/crossdna/config.json ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alphabet_size": 5,
3
+ "architectures": [
4
+ "CrossDNAForMaskedLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_crossdna.CrossDNAConfig",
8
+ "AutoModelForMaskedLM": "modeling_crossdna.CrossDNAForMaskedLM",
9
+ "AutoTokenizer": "tokenization_crossdna.CrossDNATokenizer"
10
+ },
11
+ "auto_update_ema_in_forward": true,
12
+ "aux_ce_weight": 0.0,
13
+ "block_size": 4096,
14
+ "bos_token_id": 2,
15
+ "bridge_dropout": 0.05,
16
+ "checkpoint_chunk_size": 1,
17
+ "checkpoint_core_layers": true,
18
+ "cls_token_id": 0,
19
+ "comba_cfg": {
20
+ "conv_size": 4,
21
+ "correction_factor": 0.02,
22
+ "expand_v": 1,
23
+ "head_dim": 64,
24
+ "hidden_size": 256,
25
+ "mode": "chunk",
26
+ "norm_eps": 1e-05,
27
+ "num_heads": 8,
28
+ "use_gate": true,
29
+ "use_short_conv": true
30
+ },
31
+ "compact_n_token_id": 4,
32
+ "core_checkpoint_chunk_size": 1,
33
+ "d_model": 256,
34
+ "depth": 8,
35
+ "detach_gate": false,
36
+ "disable_cross_view": false,
37
+ "dna_token_ids": {
38
+ "A": 7,
39
+ "C": 8,
40
+ "G": 9,
41
+ "N": 11,
42
+ "T": 10
43
+ },
44
+ "dna_token_start_id": 7,
45
+ "dna_tokens": [
46
+ "A",
47
+ "C",
48
+ "G",
49
+ "T",
50
+ "N"
51
+ ],
52
+ "drop_path_rates": [
53
+ 0.0,
54
+ 0.08
55
+ ],
56
+ "dropout": 0.1,
57
+ "dtype": "float32",
58
+ "ema_decay": 0.9995,
59
+ "eos_token_id": 1,
60
+ "for_representation": false,
61
+ "gate_freeze_steps": 1000,
62
+ "gate_sup_warmup_steps": 1000,
63
+ "gate_sup_weight": 0.003,
64
+ "gate_temp": 1.2,
65
+ "mask_token_id": 3,
66
+ "model_type": "crossdna",
67
+ "pad_token_id": 4,
68
+ "pretrain": true,
69
+ "rc_bidirectional_stopgrad": true,
70
+ "rc_max_weight": 0.2,
71
+ "rc_tau": 1.5,
72
+ "rc_warmup_steps": 2000,
73
+ "return_ab_logits": true,
74
+ "sem_max_weight": 0.1,
75
+ "sem_warmup_steps": 8000,
76
+ "sep_token_id": 1,
77
+ "streaming_loss": true,
78
+ "streaming_report_ab": true,
79
+ "transformer_cfg": {
80
+ "attn": {
81
+ "num_heads": 8,
82
+ "num_kv_heads": 8,
83
+ "qkv_bias": false,
84
+ "rope_theta": 10000,
85
+ "window_size": 512
86
+ },
87
+ "fuse_swiglu": true,
88
+ "hidden_act": "swish",
89
+ "hidden_ratio": 4.0,
90
+ "hidden_size": 256,
91
+ "max_position_embeddings": 4096,
92
+ "norm_eps": 1e-05
93
+ },
94
+ "transformers_version": "4.57.1",
95
+ "unk_token_id": 6,
96
+ "use_barlow": false,
97
+ "use_bridge": true,
98
+ "use_checkpointing": true,
99
+ "use_ema_teacher": true,
100
+ "use_final_conv": false,
101
+ "use_mem": false,
102
+ "use_rc_kl": false,
103
+ "use_s_scan": true,
104
+ "use_tv": false,
105
+ "vocab_size": 12
106
+ }
28.6M/huggingface_crossdna_140K_len/crossdna/configuration_crossdna.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class CrossDNAConfig(PretrainedConfig):
    """Hugging Face configuration for the CrossDNA hybrid DNA language model.

    Holds the compact DNA alphabet mapping, backbone dimensions, training-time
    auxiliary-loss knobs, checkpointing switches, and the nested configs for
    the Comba mixer and the sliding-window transformer sub-blocks.
    """

    model_type = "crossdna"

    def __init__(
        self,
        alphabet_size=5,
        vocab_size=12,
        dna_tokens=("A", "C", "G", "T", "N"),
        dna_token_start_id=7,
        compact_n_token_id=4,
        dna_token_ids=None,
        d_model=256,
        block_size=4096,
        depth=8,
        drop_path_rates=(0.0, 0.08),
        dropout=0.10,
        pretrain=True,
        use_s_scan=True,
        for_representation=False,
        use_bridge=True,
        bridge_dropout=0.05,
        use_ema_teacher=True,
        ema_decay=0.9995,
        auto_update_ema_in_forward=True,
        use_mem=False,
        use_rc_kl=False,
        use_barlow=False,
        use_tv=False,
        use_final_conv=False,
        sem_max_weight=0.10,
        sem_warmup_steps=8000,
        aux_ce_weight=0.0,
        gate_freeze_steps=1000,
        detach_gate=False,
        gate_sup_weight=0.003,
        gate_sup_warmup_steps=1000,
        gate_temp=1.2,
        use_checkpointing=True,
        checkpoint_chunk_size=1,
        checkpoint_core_layers=True,
        core_checkpoint_chunk_size=1,
        return_ab_logits=True,
        streaming_loss=True,
        streaming_report_ab=True,
        disable_cross_view=False,
        rc_bidirectional_stopgrad=True,
        rc_max_weight=0.2,
        rc_tau=1.5,
        rc_warmup_steps=2000,
        transformer_cfg=None,
        comba_cfg=None,
        pad_token_id=4,
        bos_token_id=2,
        eos_token_id=1,
        sep_token_id=1,
        cls_token_id=0,
        mask_token_id=3,
        unk_token_id=6,
        **kwargs,
    ):
        # Special-token ids are forwarded to PretrainedConfig so tokenizer /
        # generation utilities can see them.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            sep_token_id=sep_token_id,
            cls_token_id=cls_token_id,
            mask_token_id=mask_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )

        # --- vocabulary / compact DNA alphabet ---
        self.alphabet_size = int(alphabet_size)
        self.vocab_size = int(vocab_size)
        self.dna_tokens = list(dna_tokens)
        self.dna_token_start_id = int(dna_token_start_id)
        self.compact_n_token_id = int(compact_n_token_id)
        if dna_token_ids is None:
            # Default: nucleotides occupy consecutive ids from dna_token_start_id.
            dna_token_ids = {
                base: self.dna_token_start_id + offset
                for offset, base in enumerate(self.dna_tokens)
            }
        self.dna_token_ids = dict(dna_token_ids)

        # --- backbone dimensions ---
        self.d_model = int(d_model)
        self.block_size = int(block_size)
        self.depth = int(depth)
        self.drop_path_rates = list(drop_path_rates) if drop_path_rates is not None else None
        self.dropout = float(dropout)

        # --- training mode switches ---
        self.pretrain = bool(pretrain)
        self.use_s_scan = bool(use_s_scan)
        self.for_representation = bool(for_representation)

        # --- cross-strand bridge ---
        self.use_bridge = bool(use_bridge)
        self.bridge_dropout = float(bridge_dropout)

        # --- EMA teacher ---
        self.use_ema_teacher = bool(use_ema_teacher)
        self.ema_decay = float(ema_decay)
        self.auto_update_ema_in_forward = bool(auto_update_ema_in_forward)

        # --- optional auxiliary components ---
        self.use_mem = bool(use_mem)
        self.use_rc_kl = bool(use_rc_kl)
        self.use_barlow = bool(use_barlow)
        self.use_tv = bool(use_tv)
        self.use_final_conv = bool(use_final_conv)

        # --- auxiliary loss weights / warmups ---
        self.sem_max_weight = float(sem_max_weight)
        self.sem_warmup_steps = int(sem_warmup_steps)
        self.aux_ce_weight = float(aux_ce_weight)

        # --- gate supervision schedule ---
        self.gate_freeze_steps = int(gate_freeze_steps)
        self.detach_gate = bool(detach_gate)
        self.gate_sup_weight = float(gate_sup_weight)
        self.gate_sup_warmup_steps = int(gate_sup_warmup_steps)
        self.gate_temp = float(gate_temp)

        # --- gradient checkpointing ---
        self.use_checkpointing = bool(use_checkpointing)
        self.checkpoint_chunk_size = int(checkpoint_chunk_size)
        self.checkpoint_core_layers = bool(checkpoint_core_layers)
        self.core_checkpoint_chunk_size = int(core_checkpoint_chunk_size)

        # --- logits / loss reporting ---
        self.return_ab_logits = bool(return_ab_logits)
        self.streaming_loss = bool(streaming_loss)
        self.streaming_report_ab = bool(streaming_report_ab)

        self.disable_cross_view = bool(disable_cross_view)

        # --- reverse-complement consistency schedule ---
        self.rc_bidirectional_stopgrad = bool(rc_bidirectional_stopgrad)
        self.rc_max_weight = float(rc_max_weight)
        self.rc_tau = float(rc_tau)
        self.rc_warmup_steps = int(rc_warmup_steps)

        # Nested sub-configs; defaults mirror the released 28.6M checkpoint.
        self.transformer_cfg = transformer_cfg or {
            "hidden_size": self.d_model,
            "norm_eps": 1e-5,
            "max_position_embeddings": self.block_size,
            "hidden_ratio": 4.0,
            "hidden_act": "swish",
            "fuse_swiglu": True,
            "attn": {
                "num_heads": 8,
                "num_kv_heads": 8,
                "qkv_bias": False,
                "window_size": 512,
                "rope_theta": 10000,
            },
        }

        self.comba_cfg = comba_cfg or {
            "hidden_size": self.d_model,
            "expand_v": 1,
            "head_dim": 64,
            "num_heads": 8,
            "use_gate": True,
            "mode": "chunk",
            "use_short_conv": True,
            "correction_factor": 0.02,
            "conv_size": 4,
            "norm_eps": 1e-5,
        }
28.6M/huggingface_crossdna_140K_len/crossdna/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6788405dcb78275faeaf1b94ae14a27465a37a3ff1872b34bffdb16b07fdf333
3
+ size 454953820
28.6M/huggingface_crossdna_140K_len/crossdna/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2bfbc40499d6b55d3ef4a3ba1e2f598dce843f083ad8e413621a7378b564a25a
3
+ size 225803888
28.6M/huggingface_crossdna_140K_len/crossdna/modeling_crossdna.py ADDED
@@ -0,0 +1,1702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+ import copy
5
+ from functools import partial
6
+ from contextlib import contextmanager
7
+ from collections import namedtuple
8
+ from typing import Dict, Optional, Tuple, Any
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as cp
14
+
15
+ from transformers import PreTrainedModel
16
+ from transformers.modeling_outputs import MaskedLMOutput
17
+
18
+ # speed
19
+ torch.backends.cuda.matmul.allow_tf32 = True
20
+ torch.backends.cudnn.allow_tf32 = True
21
+
22
+ from fla.layers import comba
23
+ from fla.layers.attn import Attention
24
+ from fla.modules import GatedMLP as SambaMLP
25
+ from fla.modules import RMSNorm
26
+
27
+ try:
28
+ from omegaconf import OmegaConf
29
+ except Exception:
30
+ OmegaConf = None
31
+
32
+ try:
33
+ from .configuration_crossdna import CrossDNAConfig
34
+ except ImportError:
35
+ from configuration_crossdna import CrossDNAConfig
36
+
37
+
38
+ # ========================
39
+ # OmegaConf helpers
40
+ # ========================
41
def _to_plain_container(x: Any) -> Any:
    """Convert an OmegaConf node to a plain dict/list; pass other values through."""
    if OmegaConf is None:
        # omegaconf not installed: nothing to convert.
        return x
    try:
        if OmegaConf.is_config(x):
            return OmegaConf.to_container(x, resolve=True)
    except Exception:
        # Best-effort: fall back to the original object on any conversion error.
        pass
    return x
49
+
50
+
51
+ def _cfg_get(cfg: Any, key: str, default: Any = None) -> Any:
52
+ if cfg is None:
53
+ return default
54
+
55
+ try:
56
+ if isinstance(cfg, dict):
57
+ return cfg.get(key, default)
58
+ except Exception:
59
+ pass
60
+
61
+ try:
62
+ if hasattr(cfg, key):
63
+ return getattr(cfg, key)
64
+ except Exception:
65
+ pass
66
+
67
+ try:
68
+ return cfg[key]
69
+ except Exception:
70
+ return default
71
+
72
+
73
+ # ========================
74
+ # Utils
75
+ # ========================
76
def complement(seq: torch.Tensor) -> torch.Tensor:
    """Complement compact DNA ids: A(0)<->T(3), C(1)<->G(2), N(4) fixed.

    Input must already use the compact 5-symbol encoding; output keeps the
    input dtype.
    """
    lut = torch.tensor([3, 2, 1, 0, 4], device=seq.device, dtype=torch.long)
    return lut[seq.long()].to(seq.dtype)
83
+
84
+
85
def reverse_complement(seq: torch.Tensor) -> torch.Tensor:
    """Reverse-complement compact DNA ids along the sequence axis (dim 1)."""
    comp = complement(seq)
    # Flip the length dimension so position 0 becomes the last base.
    return torch.flip(comp, dims=[1])
88
+
89
+
90
def make_complement_perm(C=5, device=None, dtype=torch.float32):
    """Build the base-complement permutation for a C-symbol alphabet.

    Returns (P, perm): ``perm`` maps each id to its complement
    (A<->T, C<->G, N fixed), and ``P`` is the corresponding one-hot
    permutation matrix in ``dtype``.
    """
    perm = torch.arange(C, device=device)
    if C >= 4:
        # Swap A<->T and C<->G.
        perm[0], perm[1], perm[2], perm[3] = 3, 2, 1, 0
    if C >= 5:
        perm[4] = 4  # N maps to itself

    P = torch.zeros(C, C, device=device, dtype=dtype)
    rows = torch.arange(C, device=device)
    P[rows, perm] = 1.0
    return P, perm
103
+
104
+
105
def ensure_finite(x: torch.Tensor, name: str):
    """Return ``x`` unchanged, raising FloatingPointError if it has NaN/Inf.

    ``name`` identifies the offending tensor in the error message.
    """
    if torch.isfinite(x).all():
        return x
    raise FloatingPointError(f"Non-finite values detected in {name}")
109
+
110
+
111
def linear_warmup_weight(step: int, warmup_steps: int, max_w: float):
    """Linearly ramp a loss weight from 0 to ``max_w`` over ``warmup_steps``.

    A non-positive ``warmup_steps`` disables warmup (always ``max_w``);
    non-positive steps yield 0.0.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return max_w
    if step <= 0:
        return 0.0
    return max_w * (step / warmup_steps)
119
+
120
+
121
def preferred_amp_dtype():
    """Pick the autocast dtype: bfloat16 on capable CUDA devices, else float16."""
    bf16_ok = False
    try:
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    except Exception:
        # Any probing failure falls back to fp16.
        bf16_ok = False
    return torch.bfloat16 if bf16_ok else torch.float16
128
+
129
+
130
def one_hot_float(x: torch.Tensor, num_classes: int, *, dtype: torch.dtype) -> torch.Tensor:
    """One-hot encode integer ids (B, L) into (B, L, num_classes) in ``dtype``."""
    batch, length = x.shape
    encoded = torch.zeros((batch, length, num_classes), device=x.device, dtype=dtype)
    # scatter_ writes 1.0 at each id position along the class axis.
    encoded.scatter_(2, x.unsqueeze(-1), 1.0)
    return encoded
135
+
136
+
137
+ # ========================
138
+ # RC / Barlow / TV
139
+ # ========================
140
def rc_consistency_kl(logits_A, logits_B_fwd, P, tau: float = 1.0, eps: float = 1e-6):
    """One-sided RC consistency: KL(pA || P·pB), temperature-scaled by tau^2.

    ``P`` is the complement permutation matrix mapping strand-B class
    probabilities into strand-A's coordinate frame; ``eps`` floors the mapped
    probabilities before the log.
    """
    zA = logits_A.float() / tau
    zB = logits_B_fwd.float() / tau

    pA = F.softmax(zA, dim=-1)
    log_pA = F.log_softmax(zA, dim=-1)

    pB_mapped = torch.matmul(F.softmax(zB, dim=-1), P.t()).clamp_min(eps)
    kl = (pA * (log_pA - pB_mapped.log())).sum(dim=-1).mean()
    return (tau * tau) * kl
150
+
151
+
152
def rc_consistency_bidirectional_stopgrad(logits_A, logits_B_fwd, P, tau: float = 1.5, eps: float = 1e-6):
    """Symmetric RC-consistency KL with stop-gradient targets, scaled by tau^2.

    Each strand is pulled toward the complement-mapped (detached) distribution
    of the other strand; the two directions are averaged.
    """
    zA = logits_A.float() / tau
    zB = logits_B_fwd.float() / tau

    # A chases a frozen, complement-mapped copy of B.
    with torch.no_grad():
        target_from_B = torch.matmul(F.softmax(zB, dim=-1), P.t()).clamp_min(eps)
    loss_A = F.kl_div(
        F.log_softmax(zA, dim=-1), target_from_B.log(),
        reduction="batchmean", log_target=True,
    )

    # B chases a frozen, complement-mapped copy of A.
    with torch.no_grad():
        target_from_A = torch.matmul(F.softmax(zA, dim=-1), P.t()).clamp_min(eps)
    loss_B = F.kl_div(
        F.log_softmax(zB, dim=-1), target_from_A.log(),
        reduction="batchmean", log_target=True,
    )

    return 0.5 * (tau * tau) * (loss_A + loss_B)
165
+
166
+
167
def barlow_strand_loss_v2(z1, z2, λ_off=0.04, λ_diag=0.04, eps=1e-3):
    """Barlow-Twins-style decorrelation loss between two strand embeddings.

    Adds a variance floor (hinge on per-dimension std below 1) to the usual
    cross-correlation terms: diagonal pulled toward 1 (weight ``λ_diag``) and
    off-diagonal pushed toward 0 (weight ``λ_off``).
    """
    B, L, H = z1.shape
    n = B * L
    z1 = z1.reshape(n, H)
    z2 = z2.reshape(n, H)

    def _std(z):
        return torch.sqrt(z.var(dim=0, unbiased=False) + eps)

    std1 = _std(z1)
    std2 = _std(z2)
    # Hinge penalty when any feature's std collapses below 1.
    var_term = F.relu(1 - std1).pow(2).mean() + F.relu(1 - std2).pow(2).mean()

    # Standardize, then form the (H, H) cross-correlation matrix.
    z1 = (z1 - z1.mean(0)) / (std1 + eps)
    z2 = (z2 - z2.mean(0)) / (std2 + eps)
    corr = (z1.t() @ z2) / max(1, n)

    diag = torch.diagonal(corr)
    off_diag = corr - torch.diag_embed(diag)
    cov_term = λ_diag * (1 - diag).pow(2).mean() + λ_off * off_diag.pow(2).mean()
    return var_term + cov_term
187
+
188
+
189
def tv_mixed(h: torch.Tensor):
    """Mixed total-variation penalty along the sequence axis (dim 1):
    mean |first difference| + mean squared second difference."""
    first = h[:, 1:, :] - h[:, :-1, :]
    second = first[:, 1:, :] - first[:, :-1, :]
    return first.abs().mean() + second.pow(2).mean()
193
+
194
+
195
class Mlp(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear.

    When ``return_residual`` is set, ``forward`` returns ``(output, input)``
    so callers can fuse the residual add themselves. Hidden/output widths
    default to the input width.
    """

    def __init__(self, input_dimension, hidden_dimension=None, output_dimension=None,
                 activation=F.gelu, return_residual=False):
        super().__init__()
        self.return_residual = return_residual
        # `or` keeps the original falsy-default semantics.
        hidden_width = hidden_dimension or input_dimension
        output_width = output_dimension or input_dimension
        self.linear1 = nn.Linear(input_dimension, hidden_width)
        self.activation = activation
        self.linear2 = nn.Linear(hidden_width, output_width)

    def forward(self, x: torch.Tensor):
        projected = self.linear2(self.activation(self.linear1(x)))
        if self.return_residual:
            return projected, x
        return projected
210
+
211
+
212
def create_comba_cls(comba_kwargs=None, device=None, dtype=None):
    """Return a factory producing Comba mixer layers bound to the given
    kwargs and (optionally) device/dtype."""
    factory_kwargs = {}
    if device is not None:
        factory_kwargs["device"] = device
    if dtype is not None:
        factory_kwargs["dtype"] = dtype
    try:
        merged = dict(comba_kwargs or {})
        mixer_cls = partial(comba.Comba, **merged, **factory_kwargs)
    except ImportError:
        # NOTE(review): partial() cannot raise ImportError, and a missing `fla`
        # would already have failed at the module-level import — this fallback
        # appears unreachable. Kept verbatim for behavioral parity.
        class FallbackComba(nn.Module):
            def forward(self, x, *args, **kwargs):
                return x
        mixer_cls = lambda *args, **kwargs: FallbackComba()
    return mixer_cls
227
+
228
+
229
class SlidingWindowAttention(nn.Module):
    """Pre-norm sliding-window attention block followed by a gated MLP.

    Both sub-layers run under autocast on CUDA (bf16 when supported, else
    fp16). Attention input is scaled by 1/sqrt(2) before the mixer and
    rescaled afterwards for numerical headroom; finiteness is asserted after
    each residual add.
    """

    def __init__(self, config: Any):
        super().__init__()
        cfg = _to_plain_container(config)

        hidden_size = _cfg_get(cfg, "hidden_size")
        norm_eps = _cfg_get(cfg, "norm_eps", 1e-5)
        attn_cfg = _to_plain_container(_cfg_get(cfg, "attn", {}) or {})

        # Attention sub-layer (pre-norm).
        self.mixer_norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
        self.mixer = Attention(
            hidden_size=hidden_size,
            num_heads=_cfg_get(attn_cfg, "num_heads"),
            num_kv_heads=_cfg_get(attn_cfg, "num_kv_heads"),
            qkv_bias=_cfg_get(attn_cfg, "qkv_bias"),
            window_size=_cfg_get(attn_cfg, "window_size"),
            rope_theta=_cfg_get(attn_cfg, "rope_theta"),
            max_position_embeddings=_cfg_get(cfg, "max_position_embeddings"),
        )

        # Feed-forward sub-layer (pre-norm).
        self.mlp_norm = RMSNorm(hidden_size, eps=norm_eps)
        self.mlp = SambaMLP(
            hidden_size=hidden_size,
            hidden_ratio=_cfg_get(cfg, "hidden_ratio", 4.0),
            hidden_act=_cfg_get(cfg, "hidden_act", "swish"),
            fuse_swiglu=_cfg_get(cfg, "fuse_swiglu", True),
        )
        self.pre_scale = 1.0 / math.sqrt(2.0)

    def forward(self, hidden_states: torch.Tensor, cache_params: Optional[Any] = None, **kwargs) -> Tuple[torch.Tensor, Any]:
        amp_dtype = preferred_amp_dtype()
        dev = hidden_states.device.type
        device_type = dev if dev in ["cuda", "cpu", "xpu"] else "cuda"
        autocast_on = device_type == "cuda"

        # --- attention sub-layer ---
        residual = hidden_states
        x = self.mixer_norm(hidden_states)
        with torch.autocast(device_type=device_type, enabled=autocast_on, dtype=amp_dtype):
            attn_out, _, cache_params = self.mixer(
                hidden_states=x * self.pre_scale,
                past_key_values=cache_params,
                **kwargs,
            )
            attn_out = attn_out / self.pre_scale
        ensure_finite(attn_out, "attention_out")
        h = residual + attn_out.to(x.dtype)

        # --- MLP sub-layer ---
        residual = h
        x = self.mlp_norm(h)
        with torch.autocast(device_type=device_type, enabled=autocast_on, dtype=amp_dtype):
            x = self.mlp(x, **kwargs)
        h = residual + x
        ensure_finite(h, "block_output")
        return h, cache_params
282
+
283
+
284
class EnhancedHybridCore(nn.Module):
    """One hybrid layer: a Comba mixer feeding a sliding-window attention
    block, with the two outputs fused by a learned sigmoid gate and
    normalized by LayerNorm."""

    def __init__(self, hidden_dim, comba_cfg, transformer_cfg, layer_idx=0, device=None, dtype=None):
        super().__init__()
        comba_cfg = _to_plain_container(comba_cfg)
        transformer_cfg = _to_plain_container(transformer_cfg)

        self.comba_cls = create_comba_cls(comba_kwargs=comba_cfg, device=device, dtype=dtype)
        try:
            self.comba = self.comba_cls(layer_idx=layer_idx)
        except TypeError:
            # Older Comba signatures take no layer_idx.
            self.comba = self.comba_cls()

        self.transformer = SlidingWindowAttention(config=transformer_cfg)
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.out_norm = nn.LayerNorm(hidden_dim)

    @staticmethod
    def _first(x):
        # Some mixers return (output, extras); normalize to the output tensor.
        if isinstance(x, tuple):
            return x[0]
        return x

    def forward(self, x):
        orig_dtype = x.dtype
        dev = x.device.type
        device_type = dev if dev in ["cuda", "cpu", "xpu"] else "cuda"

        # Comba runs in fp32 with autocast disabled for numerical stability.
        x_fp32 = x.float()
        with torch.autocast(device_type=device_type, enabled=False):
            m_out = self._first(self.comba(x_fp32))
        m_out = m_out.to(orig_dtype)
        del x_fp32

        t_out, _ = self.transformer(m_out)

        # Convex combination of the two paths, gated per feature.
        g = torch.sigmoid(self.gate(torch.cat([m_out, t_out], dim=-1)))
        fused = g * t_out + (1 - g) * m_out
        y = self.out_norm(fused)
        ensure_finite(y, "EnhancedHybridCore.out")
        return y
323
+
324
+
325
class DeepEnhancedBranch(nn.Module):
    """A stack of ``depth`` EnhancedHybridCore layers with an optional
    chunked gradient-checkpointing path and a final LayerNorm."""

    def __init__(
        self,
        hidden_dim: int,
        comba_cfg: Optional[Dict],
        transformer_cfg: Any,
        depth: int = 4,
        drop_path_rates=None,
        *,
        device=None,
        dtype=None,
        checkpoint_core_layers: bool = False,
        core_checkpoint_chunk_size: int = 1,
    ):
        super().__init__()
        self.layers = nn.ModuleList()

        transformer_cfg = _to_plain_container(transformer_cfg)
        comba_cfg = _to_plain_container(comba_cfg)

        self.checkpoint_core_layers = bool(checkpoint_core_layers)
        self.core_checkpoint_chunk_size = int(core_checkpoint_chunk_size)

        # Resolve the per-layer drop-path schedule:
        # None -> linear ramp to 0.05; scalar -> constant; sequence -> padded
        # with its last value up to `depth`.
        if drop_path_rates is None:
            rates = [0.05 * (i / max(1, depth - 1)) for i in range(depth)]
        elif isinstance(drop_path_rates, (float, int)):
            rates = [float(drop_path_rates)] * depth
        else:
            given = list(_to_plain_container(drop_path_rates))
            rates = given + [given[-1]] * (depth - len(given))

        for i in range(depth):
            layer_cfg = dict(transformer_cfg) if isinstance(transformer_cfg, dict) else transformer_cfg.copy()
            layer_cfg["drop_path_prob"] = rates[i]
            self.layers.append(EnhancedHybridCore(hidden_dim, comba_cfg, layer_cfg, i, device, dtype))

        self.output_norm = nn.LayerNorm(hidden_dim)

    def _run_layers(self, x: torch.Tensor, start: int, end: int):
        # Sequentially apply layers[start:end].
        for idx in range(start, end):
            x = self.layers[idx](x)
        return x

    def forward(self, x: torch.Tensor):
        if self.training and self.checkpoint_core_layers:
            # Checkpoint in chunks to trade compute for activation memory.
            chunk = max(1, self.core_checkpoint_chunk_size)
            total = len(self.layers)
            for s in range(0, total, chunk):
                e = min(s + chunk, total)

                def _segment(inp, s=s, e=e):
                    return self._run_layers(inp, s, e)

                x = cp.checkpoint(_segment, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x)

        y = self.output_norm(x)
        ensure_finite(y, "DeepEnhancedBranch.out")
        return y
386
+
387
+
388
class TokenBridge(nn.Module):
    """Bidirectional cross-strand exchange module.

    Each branch receives the other branch's multi-dilation depthwise-conv
    context (plus an optional global mean token), gated per feature and added
    residually, then LayerNorm-ed. Dilated convs use "same" padding so the
    sequence length is preserved.
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.0,
                 kernel_size: int = 9, dilations=(1, 2, 4, 8, 16),
                 use_global_token: bool = True):
        super().__init__()
        h = hidden_dim

        def same_pad(d):
            # "same" padding for an odd dilated kernel.
            return d * (kernel_size // 2)

        # Context extractors, one depthwise conv per dilation, per branch.
        self.dw_B = nn.ModuleList(
            nn.Conv1d(h, h, kernel_size, padding=same_pad(d), dilation=d, groups=h, bias=False)
            for d in dilations
        )
        self.mix_B = nn.Conv1d(h * len(dilations), h, 1)

        self.dw_A = nn.ModuleList(
            nn.Conv1d(h, h, kernel_size, padding=same_pad(d), dilation=d, groups=h, bias=False)
            for d in dilations
        )
        self.mix_A = nn.Conv1d(h * len(dilations), h, 1)

        # Cross-branch projections (B's context into A, and vice versa).
        self.proj_B2A = nn.Linear(h, h)
        self.proj_A2B = nn.Linear(h, h)

        self.use_global_token = use_global_token
        if use_global_token:
            self.glb_B2A = nn.Linear(h, h)
            self.glb_A2B = nn.Linear(h, h)

        # One gate head emits both branches' gates (chunked in forward).
        self.gate = nn.Linear(h * 4, h * 2)
        self.dropout = nn.Dropout(dropout)
        self.normA = nn.LayerNorm(h)
        self.normB = nn.LayerNorm(h)

    @staticmethod
    def _agg(x: torch.Tensor, branches: nn.ModuleList, mix: nn.Module) -> torch.Tensor:
        # (B, L, H) -> (B, H, L) for conv, concat dilation branches, mix back.
        channel_first = x.transpose(1, 2)
        stacked = torch.cat([conv(channel_first) for conv in branches], dim=1)
        return mix(stacked).transpose(1, 2).contiguous()

    def forward(self, xA: torch.Tensor, xB: torch.Tensor):
        ctxB = self._agg(xB, self.dw_B, self.mix_B)
        ctxA = self._agg(xA, self.dw_A, self.mix_A)

        # Each side's update comes from the *other* side's features + context.
        locA = self.proj_B2A(xB + ctxB)
        locB = self.proj_A2B(xA + ctxA)

        if self.use_global_token:
            globB = self.glb_B2A(xB.mean(dim=1, keepdim=True))
            globA = self.glb_A2B(xA.mean(dim=1, keepdim=True))
            locA = locA + globB.expand(-1, xA.size(1), -1)
            locB = locB + globA.expand(-1, xB.size(1), -1)

        # Gate on a rich pairwise feature: [A, B, A-B, A*B].
        pair = torch.cat([xA, xB, xA - xB, xA * xB], dim=-1)
        gateA, gateB = self.gate(pair).chunk(2, dim=-1)
        gateA = torch.sigmoid(gateA)
        gateB = torch.sigmoid(gateB)

        yA = self.normA(xA + self.dropout(gateA * locA))
        yB = self.normB(xB + self.dropout(gateB * locB))
        return yA, yB
450
+
451
+
452
def semantic_preservation_loss(R_plus: torch.Tensor, H_S_plus: torch.Tensor,
                               λ_recon: float = 1.0, λ_local: float = 0.5, λ_global: float = 0.2):
    """Weighted sum of reconstruction, local-difference, and Gram-matrix MSE
    between a reference representation ``R_plus`` and a student ``H_S_plus``.

    The local term compares adjacent-position differences (skipped for
    length-1 sequences); the global term compares normalized Gram matrices.
    """
    recon = F.mse_loss(H_S_plus, R_plus)

    if R_plus.size(1) >= 2:
        ref_diff = R_plus[:, 1:] - R_plus[:, :-1]
        stu_diff = H_S_plus[:, 1:] - H_S_plus[:, :-1]
        local = F.mse_loss(stu_diff, ref_diff)
    else:
        local = torch.tensor(0.0, device=R_plus.device)

    def gram_norm(x):
        G = torch.einsum("b i d, b j d -> b i j", x, x)
        return G / (G.norm(dim=(1, 2), keepdim=True) + 1e-6)

    glob = F.mse_loss(gram_norm(H_S_plus), gram_norm(R_plus))
    return λ_recon * recon + λ_local * local + λ_global * glob
469
+
470
+
471
@contextmanager
def eval_mode(*modules):
    """Temporarily switch the given nn.Modules to eval() mode, restoring each
    module's previous training flag on exit. ``None`` entries are skipped.

    Fix: the original captured ``m.training`` unconditionally, so passing a
    ``None`` module (which the body and restore loop explicitly tolerate)
    raised AttributeError before the context was ever entered.
    """
    # Record prior training flags; None placeholders keep zip() aligned.
    states = [m.training if m is not None else None for m in modules]
    try:
        for m in modules:
            if m is not None:
                m.eval()
        yield
    finally:
        # Restore even if the body raised.
        for m, s in zip(modules, states):
            if m is not None:
                m.train(s)
483
+
484
+
485
+ class SSScanDNAHybridModel(nn.Module):
486
+ """
487
+ Latest training-structure CrossDNA backbone.
488
+
489
+ HF wrapper strategy:
490
+ - backbone remains compact 5-token A/C/G/T/N model
491
+ - HF wrapper maps tokenizer ids -> compact ids
492
+ - HF wrapper expands 5-way logits -> tokenizer vocab logits
493
+ """
494
+
495
    def __init__(
        self,
        config: Optional[Any] = None,
        alphabet_size=5,
        d_model=128,
        block_size=2048,
        comba_cfg=None,
        transformer_cfg=None,
        depth=4,
        drop_path_rates=None,
        pretrain=False,
        for_representation=False,
        use_final_conv=False,
        use_s_scan: bool = True,
        use_mem: bool = False,
        use_rc_kl: bool = False,
        use_barlow: bool = False,
        use_tv: bool = False,
        sem_max_weight: float = 0.2,
        sem_warmup_steps: int = 3000,
        rc_max_weight: float = 0.2,
        rc_warmup_steps: int = 2000,
        rc_tau: float = 1.5,
        rc_bidirectional_stopgrad: bool = True,
        aux_ce_weight: float = 0.1,
        gate_freeze_steps: int = 1000,
        detach_gate: bool = False,
        gate_sup_weight: float = 0.005,
        gate_sup_warmup_steps: int = 500,
        gate_temp: float = 2.0,
        dropout=0.1,
        use_ema_teacher: bool = True,
        ema_decay: float = 0.999,
        auto_update_ema_in_forward: bool = True,
        use_bridge: bool = True,
        bridge_dropout: float = 0.0,
        use_checkpointing: bool = True,
        checkpoint_chunk_size: int = 2,
        checkpoint_core_layers: bool = False,
        core_checkpoint_chunk_size: int = 1,
        return_ab_logits: bool = True,
        streaming_loss: bool = True,
        streaming_report_ab: bool = True,
        disable_cross_view: bool = False,
        **unused_kwargs,
    ):
        """Build the dual-branch (forward / reverse-complement) backbone.

        Every keyword argument acts as a default.  When ``config`` is
        provided, each same-named field is re-read from it via
        ``_cfg_get`` and overrides the keyword value, so an HF config
        object fully drives construction.  Unknown config keys are
        swallowed by ``**unused_kwargs`` and stashed for debugging so
        extra fields never break ``from_pretrained``.
        """
        super().__init__()

        self.config = config
        if config is not None:
            # Normalize omegaconf / HF config objects into plain
            # dict-like containers before reading fields.
            cfg = _to_plain_container(config)

            # --- config overrides: each field falls back to the keyword default ---
            alphabet_size = _cfg_get(cfg, "alphabet_size", alphabet_size)
            d_model = _cfg_get(cfg, "d_model", d_model)
            block_size = _cfg_get(cfg, "block_size", block_size)
            depth = _cfg_get(cfg, "depth", depth)
            drop_path_rates = _cfg_get(cfg, "drop_path_rates", drop_path_rates)

            pretrain = _cfg_get(cfg, "pretrain", pretrain)
            for_representation = _cfg_get(cfg, "for_representation", for_representation)
            use_final_conv = _cfg_get(cfg, "use_final_conv", use_final_conv)

            use_s_scan = _cfg_get(cfg, "use_s_scan", use_s_scan)
            use_mem = _cfg_get(cfg, "use_mem", use_mem)
            use_rc_kl = _cfg_get(cfg, "use_rc_kl", use_rc_kl)
            use_barlow = _cfg_get(cfg, "use_barlow", use_barlow)
            use_tv = _cfg_get(cfg, "use_tv", use_tv)

            sem_max_weight = _cfg_get(cfg, "sem_max_weight", sem_max_weight)
            sem_warmup_steps = _cfg_get(cfg, "sem_warmup_steps", sem_warmup_steps)
            rc_max_weight = _cfg_get(cfg, "rc_max_weight", rc_max_weight)
            rc_warmup_steps = _cfg_get(cfg, "rc_warmup_steps", rc_warmup_steps)
            rc_tau = _cfg_get(cfg, "rc_tau", rc_tau)
            rc_bidirectional_stopgrad = _cfg_get(cfg, "rc_bidirectional_stopgrad", rc_bidirectional_stopgrad)

            aux_ce_weight = _cfg_get(cfg, "aux_ce_weight", aux_ce_weight)
            gate_freeze_steps = _cfg_get(cfg, "gate_freeze_steps", gate_freeze_steps)
            detach_gate = _cfg_get(cfg, "detach_gate", detach_gate)
            gate_sup_weight = _cfg_get(cfg, "gate_sup_weight", gate_sup_weight)
            gate_sup_warmup_steps = _cfg_get(cfg, "gate_sup_warmup_steps", gate_sup_warmup_steps)
            gate_temp = _cfg_get(cfg, "gate_temp", gate_temp)
            dropout = _cfg_get(cfg, "dropout", dropout)

            use_bridge = _cfg_get(cfg, "use_bridge", use_bridge)
            bridge_dropout = _cfg_get(cfg, "bridge_dropout", bridge_dropout)

            use_checkpointing = _cfg_get(cfg, "use_checkpointing", use_checkpointing)
            checkpoint_chunk_size = _cfg_get(cfg, "checkpoint_chunk_size", checkpoint_chunk_size)

            checkpoint_core_layers = _cfg_get(cfg, "checkpoint_core_layers", checkpoint_core_layers)
            core_checkpoint_chunk_size = _cfg_get(cfg, "core_checkpoint_chunk_size", core_checkpoint_chunk_size)

            return_ab_logits = _cfg_get(cfg, "return_ab_logits", return_ab_logits)
            streaming_loss = _cfg_get(cfg, "streaming_loss", streaming_loss)
            streaming_report_ab = _cfg_get(cfg, "streaming_report_ab", streaming_report_ab)

            use_ema_teacher = _cfg_get(cfg, "use_ema_teacher", use_ema_teacher)
            ema_decay = _cfg_get(cfg, "ema_decay", ema_decay)
            auto_update_ema_in_forward = _cfg_get(cfg, "auto_update_ema_in_forward", auto_update_ema_in_forward)
            disable_cross_view = _cfg_get(cfg, "disable_cross_view", disable_cross_view)

            transformer_cfg = _cfg_get(cfg, "transformer_cfg", transformer_cfg)
            comba_cfg = _cfg_get(cfg, "comba_cfg", comba_cfg)

        # Sub-configs may themselves be omegaconf nodes; flatten them too.
        transformer_cfg = _to_plain_container(transformer_cfg)
        comba_cfg = _to_plain_container(comba_cfg)
        drop_path_rates = _to_plain_container(drop_path_rates)

        # --- core hyperparameters (coerced to plain python types) ---
        self.alphabet_size = int(alphabet_size)
        self.pretrain = bool(pretrain)
        self.for_representation = bool(for_representation)
        self.block_size = int(block_size)
        self.use_final_conv = bool(use_final_conv)
        self.d_model = int(d_model)

        self.use_checkpointing = bool(use_checkpointing)
        self.checkpoint_chunk_size = int(checkpoint_chunk_size)

        self.checkpoint_core_layers = bool(checkpoint_core_layers)
        self.core_checkpoint_chunk_size = int(core_checkpoint_chunk_size)

        self.return_ab_logits = bool(return_ab_logits)
        self.streaming_loss = bool(streaming_loss)
        self.streaming_report_ab = bool(streaming_report_ab)

        self.use_ema_teacher = bool(use_ema_teacher)
        self.ema_decay = float(ema_decay)
        self.auto_update_ema_in_forward = bool(auto_update_ema_in_forward)
        self.disable_cross_view = bool(disable_cross_view)

        # Global training-step counter; registered as a buffer so it is
        # saved/restored with the state dict.  Incremented in forward()
        # while training.
        self.register_buffer("g_step", torch.zeros(1, dtype=torch.long))

        # Input stems: one conv over the forward one-hot sequence, a
        # separate conv over the reverse-complement one-hot.
        # NOTE(review): kernel_size=9 with padding=4 keeps sequence
        # length unchanged.
        self.linear = nn.Conv1d(self.alphabet_size, self.d_model, kernel_size=9, padding=4)
        self.rc_linear = nn.Conv1d(self.alphabet_size, self.d_model, kernel_size=9, padding=4)

        # Learned additive embeddings for MLM-masked and special-token
        # positions (added onto the conv features in forward()).
        self.mlm_mask_embed = nn.Parameter(torch.zeros(self.d_model))
        self.special_token_embed = nn.Parameter(torch.zeros(self.d_model))
        nn.init.normal_(self.mlm_mask_embed, mean=0.0, std=0.02)
        nn.init.normal_(self.special_token_embed, mean=0.0, std=0.02)

        # Two independent deep branches; each may see the forward or the
        # RC view of a block depending on _branch_receives_rc().
        self.branchA_core = DeepEnhancedBranch(
            hidden_dim=self.d_model,
            comba_cfg=comba_cfg,
            transformer_cfg=transformer_cfg,
            depth=int(depth),
            drop_path_rates=drop_path_rates,
            checkpoint_core_layers=self.checkpoint_core_layers,
            core_checkpoint_chunk_size=self.core_checkpoint_chunk_size,
        )
        self.branchB_core = DeepEnhancedBranch(
            hidden_dim=self.d_model,
            comba_cfg=comba_cfg,
            transformer_cfg=transformer_cfg,
            depth=int(depth),
            drop_path_rates=drop_path_rates,
            checkpoint_core_layers=self.checkpoint_core_layers,
            core_checkpoint_chunk_size=self.core_checkpoint_chunk_size,
        )

        # Optional token-level bridge that exchanges information between
        # the two branches after the cores.
        self.use_bridge = bool(use_bridge)
        if self.use_bridge:
            self.bridge = TokenBridge(self.d_model, dropout=float(bridge_dropout))
        else:
            self.bridge = None

        # Frozen EMA copies of the cores (and bridge) act as teachers for
        # the semantic-preservation loss; updated via update_ema().
        if self.use_ema_teacher:
            self.branchA_core_ema = copy.deepcopy(self.branchA_core)
            self.branchB_core_ema = copy.deepcopy(self.branchB_core)
            for p in self.branchA_core_ema.parameters():
                p.requires_grad_(False)
            for p in self.branchB_core_ema.parameters():
                p.requires_grad_(False)

            if self.use_bridge:
                self.bridge_ema = copy.deepcopy(self.bridge)
                for p in self.bridge_ema.parameters():
                    p.requires_grad_(False)
            else:
                self.bridge_ema = None
        else:
            self.branchA_core_ema = None
            self.branchB_core_ema = None
            self.bridge_ema = None

        # Per-branch residual MLP heads, gated fusion, and output head.
        self.proj_A = Mlp(self.d_model, self.d_model * 2, self.d_model, activation=F.gelu, return_residual=True)
        self.proj_B = Mlp(self.d_model, self.d_model * 2, self.d_model, activation=F.gelu, return_residual=True)
        self.gate_fuse = nn.Linear(2 * self.d_model, self.d_model)
        self.out_linear = nn.Linear(self.d_model, self.alphabet_size)
        self.dropout = nn.Dropout(float(dropout))

        # Complement permutation matrix over the compact alphabet, used
        # by the RC-consistency losses.
        P_comp, _ = make_complement_perm(self.alphabet_size)
        self.register_buffer("P_comp", P_comp)

        # --- feature flags for auxiliary losses ---
        self.use_s_scan = bool(use_s_scan)
        self.use_mem = bool(use_mem)
        self.use_rc_kl = bool(use_rc_kl)
        self.use_barlow = bool(use_barlow)
        self.use_tv = bool(use_tv)

        # Loss weights / warmup schedules (consumed via linear_warmup_weight).
        self.sem_max_weight = float(sem_max_weight)
        self.sem_warmup_steps = int(sem_warmup_steps)
        self.rc_max_weight = float(rc_max_weight)
        self.rc_warmup_steps = int(rc_warmup_steps)
        self.rc_tau = float(rc_tau)
        self.rc_bidirectional_stopgrad = bool(rc_bidirectional_stopgrad)

        self.aux_ce_weight = float(aux_ce_weight)
        self.gate_freeze_steps = int(gate_freeze_steps)
        self.detach_gate = bool(detach_gate)
        self.gate_sup_weight = float(gate_sup_weight)
        self.gate_sup_warmup_steps = int(gate_sup_warmup_steps)
        self.gate_temp = float(gate_temp)

        if self.use_final_conv:
            self.final_conv = nn.Conv1d(self.d_model, self.d_model, kernel_size=3, padding=1)

        # Keep unrecognized kwargs around for debugging instead of raising.
        self._unused_init_kwargs = dict(unused_kwargs) if unused_kwargs else {}
712
+
713
+ def _branch_receives_rc(self, branch: str, block_idx: int) -> bool:
714
+ if self.disable_cross_view:
715
+ return branch == "B"
716
+
717
+ if branch == "A":
718
+ return (block_idx % 2) == 1
719
+ if branch == "B":
720
+ return (block_idx % 2) == 0
721
+ raise ValueError(f"Unknown branch: {branch}")
722
+
723
+ def _route_block_inputs(self, t_block: int, fwd_in: torch.Tensor, rc_in: torch.Tensor):
724
+ if self._branch_receives_rc("A", t_block):
725
+ return rc_in, fwd_in
726
+ return fwd_in, rc_in
727
+
728
+ def _realign_chunk_outputs(self, H_A: torch.Tensor, H_B: torch.Tensor, chunk_start: int):
729
+ num_blocks = H_A.size(0)
730
+ for c in range(num_blocks):
731
+ actual_t = chunk_start + c
732
+ if self._branch_receives_rc("A", actual_t):
733
+ H_A[c] = torch.flip(H_A[c], dims=[1])
734
+ if self._branch_receives_rc("B", actual_t):
735
+ H_B[c] = torch.flip(H_B[c], dims=[1])
736
+ return H_A, H_B
737
+
738
+ def _build_chunk_rc_masks(self, chunk_start: int, num_blocks: int, B: int, L_blk: int, device):
739
+ rcA = torch.tensor(
740
+ [self._branch_receives_rc("A", chunk_start + c) for c in range(num_blocks)],
741
+ device=device,
742
+ dtype=torch.bool,
743
+ )
744
+ maskA_row = rcA.repeat_interleave(B).unsqueeze(1)
745
+ maskA = maskA_row.expand(-1, L_blk)
746
+ maskB = ~maskA
747
+ return maskA, maskB
748
+
749
+ @torch.no_grad()
750
+ def update_ema(self):
751
+ if not self.use_ema_teacher:
752
+ return
753
+ if self.branchA_core_ema is None or self.branchB_core_ema is None:
754
+ return
755
+
756
+ d = float(self.ema_decay)
757
+ for m_ema, m in [
758
+ (self.branchA_core_ema, self.branchA_core),
759
+ (self.branchB_core_ema, self.branchB_core),
760
+ ]:
761
+ for p_ema, p in zip(m_ema.parameters(), m.parameters()):
762
+ p_ema.data.lerp_(p.data, 1.0 - d)
763
+
764
+ if self.use_bridge and (self.bridge is not None) and (self.bridge_ema is not None):
765
+ for p_ema, p in zip(self.bridge_ema.parameters(), self.bridge.parameters()):
766
+ p_ema.data.lerp_(p.data, 1.0 - d)
767
+
768
    def _forward_s_scan_chunk_streaming(
        self,
        X_A: torch.Tensor,
        X_B: torch.Tensor,
        A_emb_fwd: torch.Tensor,
        B_emb_rc: torch.Tensor,
        mlm_labels: torch.Tensor,
        chunk_start_t: torch.Tensor,
        num_blocks_t: torch.Tensor,
        step_t: torch.Tensor,
        report_ab_t: torch.Tensor,
    ):
        """Run one chunk of blocks and reduce MLM outputs to streaming scalars.

        Unlike ``_forward_s_scan_chunk`` this never returns per-token
        logits: the cross-entropy is reduced to a sum immediately, so the
        full-sequence logits tensor is never materialized.  Scalar control
        arguments arrive as tensors (``*_t``) because the function is
        wrapped by ``torch.utils.checkpoint`` in ``forward``.

        Returns a 9-tuple:
        (ce_sum, n_masked, total_aux, correct1, correct3,
         correctA1, correctB1, correctA3, correctB3).
        """
        # Unpack tensor-wrapped scalars (checkpoint-compatible calling convention).
        chunk_start = int(chunk_start_t.item())
        num_blocks = int(num_blocks_t.item())
        step = int(step_t.item())
        report_ab = bool(int(report_ab_t.item()))

        # NOTE(review): X_A/X_B appear to be (num_blocks * B, L_blk, H)
        # stacks of per-block branch inputs — confirm against forward().
        BC, L_blk, H = X_A.shape
        B = BC // max(1, num_blocks)
        device = X_A.device

        # Branch cores over their routed (forward or RC) inputs.
        H_A = self.branchA_core(X_A)
        H_B = self.branchB_core(X_B)

        # Flip RC-fed blocks back to forward orientation, then restack.
        H_A = H_A.view(num_blocks, B, L_blk, H)
        H_B = H_B.view(num_blocks, B, L_blk, H)
        H_A, H_B = self._realign_chunk_outputs(H_A, H_B, chunk_start)
        H_A = H_A.reshape(BC, L_blk, H)
        H_B = H_B.reshape(BC, L_blk, H)

        # Optional cross-branch token bridge.
        if self.use_bridge and self.bridge is not None:
            H_A, H_B = self.bridge(H_A, H_B)

        # Residual MLP heads per branch (Mlp with return_residual=True
        # yields (out, residual)).
        fA, rA = self.proj_A(H_A)
        FA = fA + rA
        fB, rB = self.proj_B(H_B)
        FB = fB + rB

        # Sigmoid gate mixing the two branch features per channel.
        gate_in = torch.cat([FA, FB], dim=-1)
        g_logits = self.gate_fuse(gate_in)
        g_raw = torch.sigmoid(g_logits / max(1e-6, self.gate_temp))

        # During the freeze window, force an even 50/50 mix so the gate
        # cannot collapse early in training.
        if step < self.gate_freeze_steps:
            g = 0.5 * torch.ones_like(g_raw)
        else:
            g = g_raw

        # Optionally stop gradients through the gate itself.
        if self.detach_gate:
            mix = g.detach() * FA + (1 - g.detach()) * FB
        else:
            mix = g * FA + (1 - g) * FB

        fused = F.layer_norm(mix, (mix.size(-1),))
        fused = ensure_finite(fused, "fused_blk")

        if self.use_final_conv:
            fused = self.final_conv(fused.permute(0, 2, 1)).permute(0, 2, 1)

        logits = self.out_linear(fused)
        C = logits.size(-1)

        logits2d = logits.reshape(-1, C)
        labels1d = mlm_labels.reshape(-1)

        # Sum (not mean) so partial chunks can be combined exactly by the caller.
        ce_sum = F.cross_entropy(logits2d, labels1d, ignore_index=-100, reduction="sum")

        with torch.no_grad():
            valid = (labels1d != -100)
            n_masked = valid.sum()

        # Top-1 / top-3 accuracy counts over masked positions (stats only).
        with torch.no_grad():
            correct1 = torch.zeros([], device=device, dtype=torch.long)
            correct3 = torch.zeros([], device=device, dtype=torch.long)
            if n_masked.item() > 0:
                sel_logits = logits2d[valid]
                sel_labels = labels1d[valid]
                pred1 = sel_logits.argmax(dim=-1)
                correct1 = pred1.eq(sel_labels).sum()
                top3 = sel_logits.topk(3, dim=-1).indices
                correct3 = top3.eq(sel_labels.unsqueeze(-1)).any(dim=-1).sum()

        total_aux = torch.zeros([], device=device, dtype=torch.float32)

        if self.pretrain:
            # Per-token masks marking which branch consumed the RC view.
            maskA, maskB = self._build_chunk_rc_masks(chunk_start, num_blocks, B, L_blk, device)

            # --- semantic-preservation loss against (EMA) teacher features ---
            need_sem = self.sem_max_weight > 0.0
            if need_sem:
                with torch.no_grad():
                    # Fall back to the online cores when no EMA teacher exists.
                    teacherA = self.branchA_core_ema if self.use_ema_teacher else self.branchA_core
                    teacherB = self.branchB_core_ema if self.use_ema_teacher else self.branchB_core
                    tbridge = self.bridge_ema if (self.use_bridge and self.use_ema_teacher and self.bridge_ema is not None) else (
                        self.bridge if self.use_bridge else None
                    )

                    mods = [teacherA, teacherB] + ([tbridge] if tbridge is not None else [])
                    with eval_mode(*mods):
                        # Teacher features on the pure forward view ...
                        R_plus_A = teacherA(A_emb_fwd)
                        R_plus_B = teacherB(A_emb_fwd)
                        if tbridge is not None:
                            R_plus_A, R_plus_B = tbridge(R_plus_A, R_plus_B)

                        # ... and on the pure RC view, flipped back to
                        # forward coordinates before the bridge.
                        R_minus_A_rc = teacherA(B_emb_rc)
                        R_minus_B_rc = teacherB(B_emb_rc)
                        R_minus_A_fwd = torch.flip(R_minus_A_rc, dims=[1])
                        R_minus_B_fwd = torch.flip(R_minus_B_rc, dims=[1])
                        if tbridge is not None:
                            R_minus_A_fwd, R_minus_B_fwd = tbridge(R_minus_A_fwd, R_minus_B_fwd)

                    # Select, per token, the teacher view matching what the
                    # student branch actually saw.
                    R_A_teacher = torch.where(maskA.unsqueeze(-1), R_minus_A_fwd, R_plus_A)
                    R_B_teacher = torch.where(maskB.unsqueeze(-1), R_minus_B_fwd, R_plus_B)

                # Loss is computed outside no_grad so gradients flow into FA/FB.
                sem_A = semantic_preservation_loss(R_A_teacher.float(), FA.float())
                sem_B = semantic_preservation_loss(R_B_teacher.float(), FB.float())
                w_sem = linear_warmup_weight(step, self.sem_warmup_steps, self.sem_max_weight)
                total_aux = total_aux + w_sem * (sem_A + sem_B)

            # --- gate supervision: after the freeze window, nudge the gate
            # toward the branch that saw the forward strand ---
            if (self.gate_sup_weight > 0.0) and (step >= self.gate_freeze_steps):
                g_target = (~maskA).float().unsqueeze(-1)
                g_token_logits = g_logits.mean(dim=-1, keepdim=True) / max(1e-6, self.gate_temp)
                w_gate = linear_warmup_weight(
                    step - self.gate_freeze_steps,
                    self.gate_sup_warmup_steps,
                    self.gate_sup_weight,
                )
                total_aux = total_aux + w_gate * F.binary_cross_entropy_with_logits(
                    g_token_logits.float(), g_target.float()
                )

            # Per-branch logits are only materialized when RC-consistency
            # or A/B reporting needs them.
            need_rc = bool(self.use_rc_kl and (self.rc_max_weight > 0.0))
            need_ab = bool(report_ab or need_rc)

            logitsA = logitsB = None
            if need_ab:
                logitsA = self.out_linear(FA)
                logitsB = self.out_linear(FB)

            # --- reverse-complement consistency between the branches ---
            if need_rc:
                if self.rc_bidirectional_stopgrad:
                    rc = rc_consistency_bidirectional_stopgrad(logitsA, logitsB, self.P_comp, tau=self.rc_tau)
                else:
                    rc = rc_consistency_kl(logitsA, logitsB, self.P_comp, tau=self.rc_tau)
                w_rc = linear_warmup_weight(step, self.rc_warmup_steps, self.rc_max_weight)
                total_aux = total_aux + w_rc * rc

            if self.use_barlow:
                total_aux = total_aux + barlow_strand_loss_v2(H_A.float(), H_B.float())
            if self.use_tv:
                total_aux = total_aux + tv_mixed(fused.float())

        # --- optional per-branch accuracy stats (no gradients) ---
        with torch.no_grad():
            correctA1 = torch.zeros([], device=device, dtype=torch.long)
            correctB1 = torch.zeros([], device=device, dtype=torch.long)
            correctA3 = torch.zeros([], device=device, dtype=torch.long)
            correctB3 = torch.zeros([], device=device, dtype=torch.long)

            if report_ab and n_masked.item() > 0:
                # logitsA/logitsB may not exist when pretrain is False, and
                # may be None when neither RC nor reporting needed them above.
                if ('logitsA' not in locals()) or (logitsA is None):
                    logitsA = self.out_linear(FA)
                    logitsB = self.out_linear(FB)

                _, perm = make_complement_perm(C, device=device)

                valid = (labels1d != -100)
                labels_safe = labels1d.clamp_min(0)
                labels_comp = perm[labels_safe]

                # A branch that saw the RC view predicts complement bases,
                # so compare it against complement-permuted labels.
                maskA_tok, maskB_tok = self._build_chunk_rc_masks(chunk_start, num_blocks, B, L_blk, device)
                maskA_tok = maskA_tok.reshape(-1)
                maskB_tok = maskB_tok.reshape(-1)

                yA = torch.where(maskA_tok, labels_comp, labels_safe)[valid]
                yB = torch.where(maskB_tok, labels_comp, labels_safe)[valid]

                A2d = logitsA.reshape(-1, C)[valid]
                B2d = logitsB.reshape(-1, C)[valid]

                predA1 = A2d.argmax(dim=-1)
                predB1 = B2d.argmax(dim=-1)
                correctA1 = predA1.eq(yA).sum()
                correctB1 = predB1.eq(yB).sum()

                topA3 = A2d.topk(3, dim=-1).indices
                topB3 = B2d.topk(3, dim=-1).indices
                correctA3 = topA3.eq(yA.unsqueeze(-1)).any(dim=-1).sum()
                correctB3 = topB3.eq(yB.unsqueeze(-1)).any(dim=-1).sum()

        return ce_sum, n_masked, total_aux, correct1, correct3, correctA1, correctB1, correctA3, correctB3
956
+
957
    def _forward_s_scan_chunk(
        self,
        X_A: torch.Tensor,
        X_B: torch.Tensor,
        A_emb_fwd: torch.Tensor,
        B_emb_rc: torch.Tensor,
        chunk_start_t: torch.Tensor,
        num_blocks_t: torch.Tensor,
        step_t: torch.Tensor,
        need_logits_t: torch.Tensor,
        need_ab_t: torch.Tensor,
    ):
        """Run one chunk of blocks and return fused features plus logits.

        Non-streaming counterpart of ``_forward_s_scan_chunk_streaming``:
        per-token tensors are returned to the caller, which stitches them
        into full-length outputs.  Outputs that are not requested are
        returned as empty (0-element) tensors rather than ``None`` so the
        function stays ``torch.utils.checkpoint``-friendly.

        Returns ``(fused_blk, logits_blk, logitsA_blk, logitsB_blk,
        total_aux_blk)``.
        """
        # Unpack tensor-wrapped scalars (checkpoint-compatible calling convention).
        chunk_start = int(chunk_start_t.item())
        num_blocks = int(num_blocks_t.item())
        step = int(step_t.item())
        need_logits = bool(int(need_logits_t.item()))
        need_ab = bool(int(need_ab_t.item()))

        # NOTE(review): X_A/X_B appear to be (num_blocks * B, L_blk, H)
        # stacks of per-block branch inputs — confirm against forward().
        BC, L_blk, H = X_A.shape
        B = BC // max(1, num_blocks)
        device = X_A.device

        # Branch cores over their routed (forward or RC) inputs.
        H_A = self.branchA_core(X_A)
        H_B = self.branchB_core(X_B)

        # Flip RC-fed blocks back to forward orientation, then restack.
        H_A = H_A.view(num_blocks, B, L_blk, H)
        H_B = H_B.view(num_blocks, B, L_blk, H)
        H_A, H_B = self._realign_chunk_outputs(H_A, H_B, chunk_start)
        H_A = H_A.reshape(BC, L_blk, H)
        H_B = H_B.reshape(BC, L_blk, H)

        # Optional cross-branch token bridge.
        if self.use_bridge and self.bridge is not None:
            H_A, H_B = self.bridge(H_A, H_B)

        # Residual MLP heads per branch (Mlp with return_residual=True
        # yields (out, residual)).
        fA, rA = self.proj_A(H_A)
        FA = fA + rA
        fB, rB = self.proj_B(H_B)
        FB = fB + rB

        # Sigmoid gate mixing the two branch features per channel.
        gate_in_blk = torch.cat([FA, FB], dim=-1)
        g_logits_blk = self.gate_fuse(gate_in_blk)
        g_raw_blk = torch.sigmoid(g_logits_blk / max(1e-6, self.gate_temp))

        # During the freeze window, force an even 50/50 mix so the gate
        # cannot collapse early in training.
        if step < self.gate_freeze_steps:
            g_blk = 0.5 * torch.ones_like(g_raw_blk)
        else:
            g_blk = g_raw_blk

        # Optionally stop gradients through the gate itself.
        if self.detach_gate:
            mix_blk = g_blk.detach() * FA + (1 - g_blk.detach()) * FB
        else:
            mix_blk = g_blk * FA + (1 - g_blk) * FB

        fused_blk = F.layer_norm(mix_blk, (mix_blk.size(-1),))
        fused_blk = ensure_finite(fused_blk, "fused_blk")

        if self.use_final_conv:
            fused_blk = self.final_conv(fused_blk.permute(0, 2, 1)).permute(0, 2, 1)

        # Empty tensors stand in for "not requested" (checkpoint cannot
        # return None).
        logits_blk = self.out_linear(fused_blk) if need_logits else fused_blk.new_empty((0,))

        need_rc_logits = bool(self.use_rc_kl and (self.rc_max_weight > 0.0))
        need_ab_internal = bool(need_ab or need_rc_logits)

        logitsA_blk = self.out_linear(FA) if need_ab_internal else fused_blk.new_empty((0,))
        logitsB_blk = self.out_linear(FB) if need_ab_internal else fused_blk.new_empty((0,))

        total_aux_blk = torch.zeros([], device=device, dtype=torch.float32)

        if self.pretrain:
            # Per-token masks marking which branch consumed the RC view.
            maskA, maskB = self._build_chunk_rc_masks(chunk_start, num_blocks, B, L_blk, device)

            # --- semantic-preservation loss against (EMA) teacher features ---
            need_sem = self.sem_max_weight > 0.0
            if need_sem:
                with torch.no_grad():
                    # Fall back to the online cores when no EMA teacher exists.
                    teacherA = self.branchA_core_ema if self.use_ema_teacher else self.branchA_core
                    teacherB = self.branchB_core_ema if self.use_ema_teacher else self.branchB_core
                    tbridge = self.bridge_ema if (self.use_bridge and self.use_ema_teacher and self.bridge_ema is not None) else (
                        self.bridge if self.use_bridge else None
                    )

                    mods = [teacherA, teacherB] + ([tbridge] if tbridge is not None else [])
                    with eval_mode(*mods):
                        # Teacher features on the pure forward view ...
                        R_plus_A = teacherA(A_emb_fwd)
                        R_plus_B = teacherB(A_emb_fwd)
                        if tbridge is not None:
                            R_plus_A, R_plus_B = tbridge(R_plus_A, R_plus_B)

                        # ... and on the pure RC view, flipped back to
                        # forward coordinates before the bridge.
                        R_minus_A_rc = teacherA(B_emb_rc)
                        R_minus_B_rc = teacherB(B_emb_rc)
                        R_minus_A_fwd = torch.flip(R_minus_A_rc, dims=[1])
                        R_minus_B_fwd = torch.flip(R_minus_B_rc, dims=[1])
                        if tbridge is not None:
                            R_minus_A_fwd, R_minus_B_fwd = tbridge(R_minus_A_fwd, R_minus_B_fwd)

                    # Select, per token, the teacher view matching what the
                    # student branch actually saw.
                    R_A_teacher = torch.where(maskA.unsqueeze(-1), R_minus_A_fwd, R_plus_A)
                    R_B_teacher = torch.where(maskB.unsqueeze(-1), R_minus_B_fwd, R_plus_B)

                # Loss is computed outside no_grad so gradients flow into FA/FB.
                sem_A = semantic_preservation_loss(R_A_teacher.float(), FA.float())
                sem_B = semantic_preservation_loss(R_B_teacher.float(), FB.float())
                w_sem = linear_warmup_weight(step, self.sem_warmup_steps, self.sem_max_weight)
                total_aux_blk = total_aux_blk + w_sem * (sem_A + sem_B)

            # --- gate supervision: after the freeze window, nudge the gate
            # toward the branch that saw the forward strand ---
            if (self.gate_sup_weight > 0.0) and (step >= self.gate_freeze_steps):
                g_target_blk = (~maskA).float().unsqueeze(-1)
                g_token_logits_blk = g_logits_blk.mean(dim=-1, keepdim=True) / max(1e-6, self.gate_temp)
                w_gate = linear_warmup_weight(
                    step - self.gate_freeze_steps,
                    self.gate_sup_warmup_steps,
                    self.gate_sup_weight,
                )
                total_aux_blk = total_aux_blk + w_gate * F.binary_cross_entropy_with_logits(
                    g_token_logits_blk.float(),
                    g_target_blk.float(),
                )

            # --- reverse-complement consistency between the branches ---
            if self.use_rc_kl and (self.rc_max_weight > 0.0):
                # Defensive: materialize branch logits if the sentinels are
                # still empty.
                if logitsA_blk.numel() == 0:
                    logitsA_blk = self.out_linear(FA)
                if logitsB_blk.numel() == 0:
                    logitsB_blk = self.out_linear(FB)

                if self.rc_bidirectional_stopgrad:
                    rc = rc_consistency_bidirectional_stopgrad(logitsA_blk, logitsB_blk, self.P_comp, tau=self.rc_tau)
                else:
                    rc = rc_consistency_kl(logitsA_blk, logitsB_blk, self.P_comp, tau=self.rc_tau)

                w_rc = linear_warmup_weight(step, self.rc_warmup_steps, self.rc_max_weight)
                total_aux_blk = total_aux_blk + w_rc * rc

            if self.use_barlow:
                total_aux_blk = total_aux_blk + barlow_strand_loss_v2(H_A.float(), H_B.float())
            if self.use_tv:
                total_aux_blk = total_aux_blk + tv_mixed(fused_blk.float())

        return fused_blk, logits_blk, logitsA_blk, logitsB_blk, total_aux_blk
1093
+
1094
+ def forward(self, seq, t=None, cls=None, return_embedding=False, state=None, mask=None, **kwargs):
1095
+ step = int(self.g_step.item())
1096
+ if self.training:
1097
+ self.g_step += 1
1098
+
1099
+ _ = mask
1100
+
1101
+ mlm_mask = None
1102
+ mlm_labels = None
1103
+ special_mask = None
1104
+
1105
+ if self.pretrain:
1106
+ if isinstance(seq, (tuple, list)):
1107
+ mlm_mask = seq[1] if len(seq) >= 2 else None
1108
+ mlm_labels = seq[2] if len(seq) >= 3 else None
1109
+ special_mask = seq[3] if len(seq) >= 4 else None
1110
+ seq = seq[0]
1111
+
1112
+ device_type = seq.device.type if seq.device.type in ["cuda", "cpu", "xpu"] else "cuda"
1113
+ amp_dtype = preferred_amp_dtype()
1114
+
1115
+ rc_seq = reverse_complement(seq)
1116
+
1117
+ with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=(device_type == "cuda")):
1118
+ seq_oh = one_hot_float(seq, self.alphabet_size, dtype=amp_dtype)
1119
+ rc_oh = one_hot_float(rc_seq, self.alphabet_size, dtype=amp_dtype)
1120
+
1121
+ if special_mask is not None:
1122
+ special_mask = special_mask.to(dtype=torch.bool, device=seq.device)
1123
+ rc_special_mask = torch.flip(special_mask, dims=[1])
1124
+
1125
+ seq_oh = seq_oh.masked_fill(special_mask.unsqueeze(-1), 0.0)
1126
+ rc_oh = rc_oh.masked_fill(rc_special_mask.unsqueeze(-1), 0.0)
1127
+ else:
1128
+ rc_special_mask = None
1129
+
1130
+ h = F.gelu(self.linear(seq_oh.permute(0, 2, 1)))
1131
+ rc_h = F.gelu(self.rc_linear(rc_oh.permute(0, 2, 1)))
1132
+ del seq_oh, rc_oh
1133
+
1134
+ if special_mask is not None:
1135
+ non_special = (~special_mask).to(dtype=h.dtype).unsqueeze(1)
1136
+ rc_non_special = (~rc_special_mask).to(dtype=rc_h.dtype).unsqueeze(1)
1137
+ h = h * non_special
1138
+ rc_h = rc_h * rc_non_special
1139
+
1140
+ if mlm_mask is not None:
1141
+ mlm_mask_f = mlm_mask.to(dtype=h.dtype, device=h.device).unsqueeze(1)
1142
+ rc_mlm_mask_f = torch.flip(mlm_mask, dims=[1]).to(dtype=rc_h.dtype, device=rc_h.device).unsqueeze(1)
1143
+ h = h + mlm_mask_f * self.mlm_mask_embed.view(1, -1, 1)
1144
+ rc_h = rc_h + rc_mlm_mask_f * self.mlm_mask_embed.view(1, -1, 1)
1145
+
1146
+ if special_mask is not None:
1147
+ special_mask_f = special_mask.to(dtype=h.dtype, device=h.device).unsqueeze(1)
1148
+ rc_special_mask_f = rc_special_mask.to(dtype=rc_h.dtype, device=rc_h.device).unsqueeze(1)
1149
+ h = h + special_mask_f * self.special_token_embed.view(1, -1, 1)
1150
+ rc_h = rc_h + rc_special_mask_f * self.special_token_embed.view(1, -1, 1)
1151
+
1152
+ use_streaming = bool(
1153
+ self.pretrain
1154
+ and self.use_s_scan
1155
+ and self.streaming_loss
1156
+ and (mlm_labels is not None)
1157
+ and (not self.for_representation)
1158
+ )
1159
+
1160
+ if use_streaming:
1161
+ B, H, L = h.shape
1162
+ l = self.block_size
1163
+ K = (L + l - 1) // l
1164
+ chunk_size = max(1, self.checkpoint_chunk_size)
1165
+
1166
+ ce_sum_total = torch.zeros([], device=seq.device, dtype=torch.float32)
1167
+ n_total = torch.zeros([], device=seq.device, dtype=torch.long)
1168
+ total_aux = torch.zeros([], device=seq.device, dtype=torch.float32)
1169
+
1170
+ correct1 = torch.zeros([], device=seq.device, dtype=torch.long)
1171
+ correct3 = torch.zeros([], device=seq.device, dtype=torch.long)
1172
+ correctA1 = torch.zeros([], device=seq.device, dtype=torch.long)
1173
+ correctB1 = torch.zeros([], device=seq.device, dtype=torch.long)
1174
+ correctA3 = torch.zeros([], device=seq.device, dtype=torch.long)
1175
+ correctB3 = torch.zeros([], device=seq.device, dtype=torch.long)
1176
+
1177
+ keep_rate = mlm_mask.float().mean() if mlm_mask is not None else torch.tensor(1.0, device=seq.device)
1178
+ report_ab_t = torch.tensor(int(self.streaming_report_ab), device=seq.device)
1179
+
1180
+ for chunk_start in range(0, K, chunk_size):
1181
+ chunk_end = min(chunk_start + chunk_size, K)
1182
+
1183
+ X_A_batch, X_B_batch = [], []
1184
+ Aemb_batch, Bemb_batch = [], []
1185
+ labels_batch = []
1186
+ lengths = []
1187
+
1188
+ for t_block in range(chunk_start, chunk_end):
1189
+ start = t_block * l
1190
+ end = min(start + l, L)
1191
+ blk_len = end - start
1192
+ lengths.append(blk_len)
1193
+
1194
+ fwd_emb = h[:, :, start:end].transpose(1, 2).contiguous()
1195
+ rc_emb = rc_h[:, :, start:end].transpose(1, 2).contiguous()
1196
+
1197
+ fwd_in = self.dropout(h[:, :, start:end]).transpose(1, 2).contiguous()
1198
+ rc_in = self.dropout(rc_h[:, :, start:end]).transpose(1, 2).contiguous()
1199
+
1200
+ XA, XB = self._route_block_inputs(t_block, fwd_in, rc_in)
1201
+ X_A_batch.append(XA)
1202
+ X_B_batch.append(XB)
1203
+
1204
+ Aemb_batch.append(fwd_emb)
1205
+ Bemb_batch.append(rc_emb)
1206
+ labels_batch.append(mlm_labels[:, start:end])
1207
+
1208
+ if len(set(lengths)) == 1:
1209
+ nb = len(X_A_batch)
1210
+
1211
+ X_A_tensor = torch.cat(X_A_batch, dim=0)
1212
+ X_B_tensor = torch.cat(X_B_batch, dim=0)
1213
+ Aemb_tensor = torch.cat(Aemb_batch, dim=0)
1214
+ Bemb_tensor = torch.cat(Bemb_batch, dim=0)
1215
+ labels_tensor = torch.cat(labels_batch, dim=0)
1216
+
1217
+ if self.training and self.use_checkpointing:
1218
+ ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = cp.checkpoint(
1219
+ self._forward_s_scan_chunk_streaming,
1220
+ X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor, labels_tensor,
1221
+ torch.tensor(chunk_start, device=seq.device),
1222
+ torch.tensor(nb, device=seq.device),
1223
+ torch.tensor(step, device=seq.device),
1224
+ report_ab_t,
1225
+ use_reentrant=False,
1226
+ )
1227
+ else:
1228
+ ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = self._forward_s_scan_chunk_streaming(
1229
+ X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor, labels_tensor,
1230
+ torch.tensor(chunk_start, device=seq.device),
1231
+ torch.tensor(nb, device=seq.device),
1232
+ torch.tensor(step, device=seq.device),
1233
+ report_ab_t,
1234
+ )
1235
+
1236
+ ce_sum_total = ce_sum_total + ce_sum
1237
+ n_total = n_total + n_masked
1238
+ total_aux = total_aux + aux_blk
1239
+
1240
+ correct1 += c1
1241
+ correct3 += c3
1242
+ correctA1 += a1
1243
+ correctB1 += b1
1244
+ correctA3 += a3
1245
+ correctB3 += b3
1246
+
1247
+ del X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor, labels_tensor
1248
+
1249
+ else:
1250
+ for idx, t_block in enumerate(range(chunk_start, chunk_end)):
1251
+ if self.training and self.use_checkpointing:
1252
+ ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = cp.checkpoint(
1253
+ self._forward_s_scan_chunk_streaming,
1254
+ X_A_batch[idx], X_B_batch[idx], Aemb_batch[idx], Bemb_batch[idx], labels_batch[idx],
1255
+ torch.tensor(t_block, device=seq.device),
1256
+ torch.tensor(1, device=seq.device),
1257
+ torch.tensor(step, device=seq.device),
1258
+ report_ab_t,
1259
+ use_reentrant=False,
1260
+ )
1261
+ else:
1262
+ ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = self._forward_s_scan_chunk_streaming(
1263
+ X_A_batch[idx], X_B_batch[idx], Aemb_batch[idx], Bemb_batch[idx], labels_batch[idx],
1264
+ torch.tensor(t_block, device=seq.device),
1265
+ torch.tensor(1, device=seq.device),
1266
+ torch.tensor(step, device=seq.device),
1267
+ report_ab_t,
1268
+ )
1269
+
1270
+ ce_sum_total = ce_sum_total + ce_sum
1271
+ n_total = n_total + n_masked
1272
+ total_aux = total_aux + aux_blk
1273
+
1274
+ correct1 += c1
1275
+ correct3 += c3
1276
+ correctA1 += a1
1277
+ correctB1 += b1
1278
+ correctA3 += a3
1279
+ correctB3 += b3
1280
+
1281
+ del h, rc_h
1282
+
1283
+ if self.training and self.use_ema_teacher and self.auto_update_ema_in_forward:
1284
+ self.update_ema()
1285
+
1286
+ HybridOutput = namedtuple("HybridOutput", ["logits"])
1287
+ step_t = torch.tensor(step, device=seq.device, dtype=torch.long)
1288
+
1289
+ stats = torch.stack([
1290
+ keep_rate.to(torch.float32),
1291
+ correct1.to(torch.float32),
1292
+ correct3.to(torch.float32),
1293
+ correctA1.to(torch.float32),
1294
+ correctB1.to(torch.float32),
1295
+ correctA3.to(torch.float32),
1296
+ correctB3.to(torch.float32),
1297
+ ], dim=0)
1298
+
1299
+ return HybridOutput(logits=(ce_sum_total, n_total, total_aux, stats, step_t)), None
1300
+
1301
+ fused = None
1302
+
1303
+ if self.use_s_scan:
1304
+ B, H, L = h.shape
1305
+ l = self.block_size
1306
+ K = (L + l - 1) // l
1307
+ chunk_size = max(1, self.checkpoint_chunk_size)
1308
+
1309
+ collect_fused = bool(self.for_representation)
1310
+ collect_logits = (not self.for_representation) or self.pretrain
1311
+ need_ab_logits = bool((self.pretrain and self.return_ab_logits) or self.use_rc_kl)
1312
+
1313
+ fused_out = torch.empty((B, L, self.d_model), device=seq.device, dtype=amp_dtype) if collect_fused else None
1314
+ logits_out = torch.empty((B, L, self.alphabet_size), device=seq.device, dtype=amp_dtype) if collect_logits else None
1315
+ logitsA_out = torch.empty((B, L, self.alphabet_size), device=seq.device, dtype=amp_dtype) if need_ab_logits else None
1316
+ logitsB_out = torch.empty((B, L, self.alphabet_size), device=seq.device, dtype=amp_dtype) if need_ab_logits else None
1317
+
1318
+ mask_A_rc = torch.empty((B, L), device=seq.device, dtype=torch.bool)
1319
+ mask_B_rc = torch.empty((B, L), device=seq.device, dtype=torch.bool)
1320
+
1321
+ total_aux = torch.zeros([], device=seq.device, dtype=torch.float32)
1322
+
1323
+ for chunk_start in range(0, K, chunk_size):
1324
+ chunk_end = min(chunk_start + chunk_size, K)
1325
+
1326
+ X_A_batch, X_B_batch = [], []
1327
+ Aemb_batch, Bemb_batch = [], []
1328
+ lengths = []
1329
+
1330
+ for t_block in range(chunk_start, chunk_end):
1331
+ start = t_block * l
1332
+ end = min(start + l, L)
1333
+ blk_len = end - start
1334
+ lengths.append(blk_len)
1335
+
1336
+ fwd_emb = h[:, :, start:end].transpose(1, 2).contiguous()
1337
+ rc_emb = rc_h[:, :, start:end].transpose(1, 2).contiguous()
1338
+
1339
+ fwd_in = self.dropout(h[:, :, start:end]).transpose(1, 2).contiguous()
1340
+ rc_in = self.dropout(rc_h[:, :, start:end]).transpose(1, 2).contiguous()
1341
+
1342
+ XA, XB = self._route_block_inputs(t_block, fwd_in, rc_in)
1343
+ X_A_batch.append(XA)
1344
+ X_B_batch.append(XB)
1345
+
1346
+ Aemb_batch.append(fwd_emb)
1347
+ Bemb_batch.append(rc_emb)
1348
+
1349
+ mask_A_rc[:, start:end] = self._branch_receives_rc("A", t_block)
1350
+ mask_B_rc[:, start:end] = self._branch_receives_rc("B", t_block)
1351
+
1352
+ if len(set(lengths)) == 1:
1353
+ blk_len = lengths[0]
1354
+ X_A_tensor = torch.cat(X_A_batch, dim=0)
1355
+ X_B_tensor = torch.cat(X_B_batch, dim=0)
1356
+ Aemb_tensor = torch.cat(Aemb_batch, dim=0)
1357
+ Bemb_tensor = torch.cat(Bemb_batch, dim=0)
1358
+
1359
+ need_logits_t = torch.tensor(int(collect_logits), device=seq.device)
1360
+ need_ab_t = torch.tensor(int(need_ab_logits), device=seq.device)
1361
+
1362
+ if self.training and self.use_checkpointing:
1363
+ fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = cp.checkpoint(
1364
+ self._forward_s_scan_chunk,
1365
+ X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
1366
+ torch.tensor(chunk_start, device=seq.device),
1367
+ torch.tensor(len(X_A_batch), device=seq.device),
1368
+ torch.tensor(step, device=seq.device),
1369
+ need_logits_t, need_ab_t,
1370
+ use_reentrant=False,
1371
+ )
1372
+ else:
1373
+ fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = self._forward_s_scan_chunk(
1374
+ X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
1375
+ torch.tensor(chunk_start, device=seq.device),
1376
+ torch.tensor(len(X_A_batch), device=seq.device),
1377
+ torch.tensor(step, device=seq.device),
1378
+ need_logits_t, need_ab_t,
1379
+ )
1380
+
1381
+ total_aux = total_aux + aux_blk
1382
+
1383
+ nb = len(X_A_batch)
1384
+ fused_view = fused_blk.view(nb, B, blk_len, -1)
1385
+ logits_view = logits_blk.view(nb, B, blk_len, -1) if (collect_logits and logits_blk.numel() > 0) else None
1386
+ logitsA_view = logitsA_blk.view(nb, B, blk_len, -1) if (need_ab_logits and logitsA_blk.numel() > 0) else None
1387
+ logitsB_view = logitsB_blk.view(nb, B, blk_len, -1) if (need_ab_logits and logitsB_blk.numel() > 0) else None
1388
+
1389
+ for c, t_block in enumerate(range(chunk_start, chunk_end)):
1390
+ start = t_block * l
1391
+ end = min(start + l, L)
1392
+
1393
+ if collect_fused:
1394
+ fused_out[:, start:end, :] = fused_view[c]
1395
+ if collect_logits:
1396
+ logits_out[:, start:end, :] = logits_view[c]
1397
+ if need_ab_logits:
1398
+ logitsA_out[:, start:end, :] = logitsA_view[c]
1399
+ logitsB_out[:, start:end, :] = logitsB_view[c]
1400
+
1401
+ del X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor
1402
+ del fused_blk, logits_blk, logitsA_blk, logitsB_blk
1403
+
1404
+ else:
1405
+ for idx, t_block in enumerate(range(chunk_start, chunk_end)):
1406
+ start = t_block * l
1407
+ end = min(start + l, L)
1408
+
1409
+ X_A_tensor = X_A_batch[idx]
1410
+ X_B_tensor = X_B_batch[idx]
1411
+ Aemb_tensor = Aemb_batch[idx]
1412
+ Bemb_tensor = Bemb_batch[idx]
1413
+
1414
+ need_logits_t = torch.tensor(int(collect_logits), device=seq.device)
1415
+ need_ab_t = torch.tensor(int(need_ab_logits), device=seq.device)
1416
+
1417
+ if self.training and self.use_checkpointing:
1418
+ fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = cp.checkpoint(
1419
+ self._forward_s_scan_chunk,
1420
+ X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
1421
+ torch.tensor(t_block, device=seq.device),
1422
+ torch.tensor(1, device=seq.device),
1423
+ torch.tensor(step, device=seq.device),
1424
+ need_logits_t, need_ab_t,
1425
+ use_reentrant=False,
1426
+ )
1427
+ else:
1428
+ fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = self._forward_s_scan_chunk(
1429
+ X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
1430
+ torch.tensor(t_block, device=seq.device),
1431
+ torch.tensor(1, device=seq.device),
1432
+ torch.tensor(step, device=seq.device),
1433
+ need_logits_t, need_ab_t,
1434
+ )
1435
+
1436
+ total_aux = total_aux + aux_blk
1437
+
1438
+ if collect_fused:
1439
+ fused_out[:, start:end, :] = fused_blk
1440
+ if collect_logits:
1441
+ logits_out[:, start:end, :] = logits_blk
1442
+ if need_ab_logits and logitsA_blk.numel() > 0:
1443
+ logitsA_out[:, start:end, :] = logitsA_blk
1444
+ logitsB_out[:, start:end, :] = logitsB_blk
1445
+
1446
+ del fused_blk, logits_blk, logitsA_blk, logitsB_blk
1447
+
1448
+ del h, rc_h
1449
+
1450
+ logits = logits_out if collect_logits else None
1451
+ logits_A_only = logitsA_out if need_ab_logits else None
1452
+ logits_B_only = logitsB_out if need_ab_logits else None
1453
+ fused = fused_out if collect_fused else None
1454
+
1455
+ else:
1456
+ feat = self.dropout(h).transpose(1, 2).contiguous()
1457
+ rc_feat = self.dropout(rc_h).transpose(1, 2).contiguous()
1458
+
1459
+ H_A = self.branchA_core(feat)
1460
+ H_Br = self.branchB_core(rc_feat)
1461
+ R_A = H_A
1462
+ R_B = torch.flip(H_Br, dims=[1])
1463
+
1464
+ if self.use_bridge and self.bridge is not None:
1465
+ R_A, R_B = self.bridge(R_A, R_B)
1466
+
1467
+ fA, rA = self.proj_A(R_A)
1468
+ FA = fA + rA
1469
+ fB, rB = self.proj_B(R_B)
1470
+ FB = fB + rB
1471
+
1472
+ gate_in = torch.cat([FA, FB], dim=-1)
1473
+ g_logits = self.gate_fuse(gate_in)
1474
+ g_raw = torch.sigmoid(g_logits / max(1e-6, self.gate_temp))
1475
+
1476
+ if step < self.gate_freeze_steps:
1477
+ g = 0.5 * torch.ones_like(g_raw)
1478
+ else:
1479
+ g = g_raw
1480
+
1481
+ if self.detach_gate:
1482
+ mix = g.detach() * FA + (1 - g.detach()) * FB
1483
+ else:
1484
+ mix = g * FA + (1 - g) * FB
1485
+
1486
+ fused = F.layer_norm(mix, (mix.size(-1),))
1487
+ fused = ensure_finite(fused, "fused")
1488
+
1489
+ if self.use_final_conv:
1490
+ fused = self.final_conv(fused.permute(0, 2, 1)).permute(0, 2, 1)
1491
+
1492
+ logits = self.out_linear(fused) if (not self.for_representation or self.pretrain) else None
1493
+
1494
+ need_ab_logits = bool((self.pretrain and self.return_ab_logits) or self.use_rc_kl)
1495
+ logits_A_only = self.out_linear(FA) if need_ab_logits else None
1496
+ logits_B_only = self.out_linear(FB) if need_ab_logits else None
1497
+
1498
+ mask_A_rc = torch.zeros(FA.size()[:2], dtype=torch.bool, device=FA.device)
1499
+ mask_B_rc = torch.ones_like(mask_A_rc)
1500
+
1501
+ total_aux = logits.new_zeros(()) if self.pretrain and logits is not None else None
1502
+
1503
+ del h, rc_h, feat, rc_feat
1504
+
1505
+ if self.for_representation:
1506
+ return fused, None
1507
+
1508
+ if self.training and self.use_ema_teacher and self.auto_update_ema_in_forward:
1509
+ self.update_ema()
1510
+
1511
+ if self.pretrain:
1512
+ HybridOutput = namedtuple("HybridOutput", ["logits"])
1513
+ return HybridOutput(
1514
+ logits=(
1515
+ logits,
1516
+ mlm_mask,
1517
+ total_aux,
1518
+ logits_A_only.detach() if logits_A_only is not None else None,
1519
+ logits_B_only.detach() if logits_B_only is not None else None,
1520
+ mask_A_rc.detach() if mask_A_rc is not None else None,
1521
+ mask_B_rc.detach() if mask_B_rc is not None else None,
1522
+ int(step),
1523
+ )
1524
+ ), None
1525
+
1526
+ return logits, None
1527
+
1528
+ @property
1529
+ def d_output(self):
1530
+ if getattr(self, "d_model", None) is None:
1531
+ raise NotImplementedError("SequenceModule instantiation must set d_output")
1532
+ return self.d_model
1533
+
1534
+
1535
class CrossDNAForMaskedLM(PreTrainedModel):
    """HF masked-LM wrapper around the CrossDNA backbone.

    The tokenizer uses a full vocabulary (special tokens at ids 0..6, DNA
    bases starting at ``dna_token_start_id``), while the backbone works on a
    compact alphabet of size ``config.alphabet_size`` (A, C, G, T, N).  This
    wrapper converts ids/labels to the compact alphabet on the way in and
    expands the compact logits back to the tokenizer vocabulary on the way
    out.
    """

    config_class = CrossDNAConfig
    base_model_prefix = "backbone"
    main_input_name = "input_ids"

    # allow older HF checkpoints to load with warnings rather than hard-failing
    _keys_to_ignore_on_load_missing = [
        r"backbone\.mlm_mask_embed",
        r"backbone\.special_token_embed",
        r"backbone\.branchA_core_ema\..*",
        r"backbone\.branchB_core_ema\..*",
        r"backbone\.bridge_ema\..*",
    ]

    def __init__(self, config: CrossDNAConfig):
        """Build the backbone from *config* and run HF weight init hooks."""
        super().__init__(config)
        self.config = config
        self.backbone = SSScanDNAHybridModel(config=config)
        self.post_init()

    @property
    def tokenizer_vocab_size(self) -> int:
        """Size of the full tokenizer vocabulary (falls back to the compact alphabet size)."""
        return int(getattr(self.config, "vocab_size", self.config.alphabet_size))

    @property
    def dna_token_ids(self) -> Dict[str, int]:
        """Mapping base letter -> tokenizer id.

        Uses ``config.dna_token_ids`` when present; otherwise assumes the five
        bases occupy contiguous ids starting at ``dna_token_start_id``
        (default 7, matching CrossDNATokenizer).
        """
        cfg_ids = getattr(self.config, "dna_token_ids", None)
        if cfg_ids is not None:
            return dict(cfg_ids)
        start = int(getattr(self.config, "dna_token_start_id", 7))
        return {"A": start + 0, "C": start + 1, "G": start + 2, "T": start + 3, "N": start + 4}

    @property
    def compact_n_token_id(self) -> int:
        """Compact id used for 'N' / unknown positions (default: last compact id)."""
        return int(getattr(self.config, "compact_n_token_id", self.config.alphabet_size - 1))

    def _to_compact_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map tokenizer ids to compact backbone ids.

        If every id is already inside ``[0, alphabet_size)`` the tensor is
        assumed to be compact already and returned unchanged.  Otherwise each
        A/C/G/T/N tokenizer id is mapped to its compact id, and every other
        token (specials, unknowns) is mapped to ``compact_n_token_id``.
        """
        input_ids = input_ids.long()
        if input_ids.numel() == 0:
            return input_ids

        mn, mx = int(input_ids.min()), int(input_ids.max())
        if 0 <= mn and mx < self.config.alphabet_size:
            # Already in compact space; nothing to translate.
            return input_ids

        compact = torch.full_like(input_ids, self.compact_n_token_id)
        for compact_id, base in enumerate(["A", "C", "G", "T", "N"]):
            tok_id = self.dna_token_ids[base]
            compact[input_ids == tok_id] = compact_id
        return compact

    def _labels_to_compact(self, labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Translate MLM labels into compact ids, preserving the -100 ignore index.

        Labels already in ``[0, alphabet_size)`` pass through; A/C/G/T/N
        tokenizer ids are remapped; everything else stays at -100 (ignored by
        the loss).
        """
        if labels is None:
            return None

        labels = labels.long()
        compact = torch.full_like(labels, -100)

        direct_mask = (labels >= 0) & (labels < self.config.alphabet_size)
        compact[direct_mask] = labels[direct_mask]

        for compact_id, base in enumerate(["A", "C", "G", "T", "N"]):
            tok_id = self.dna_token_ids[base]
            compact[labels == tok_id] = compact_id

        return compact

    def _build_special_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Boolean mask, True where the token is NOT one of A/C/G/T/N.

        When the input already looks compact (all ids within the compact
        alphabet) there are no special tokens, so an all-False mask is
        returned.
        """
        input_ids = input_ids.long()
        if input_ids.numel() == 0:
            return torch.zeros_like(input_ids, dtype=torch.bool)

        mn, mx = int(input_ids.min()), int(input_ids.max())
        if 0 <= mn and mx < self.config.alphabet_size:
            return torch.zeros_like(input_ids, dtype=torch.bool)

        special_mask = torch.ones_like(input_ids, dtype=torch.bool)
        for base in ["A", "C", "G", "T", "N"]:
            special_mask[input_ids == self.dna_token_ids[base]] = False
        return special_mask

    def _expand_logits_to_tokenizer_vocab(self, compact_logits: torch.Tensor) -> torch.Tensor:
        """Scatter (B, L, alphabet_size) compact logits into the full tokenizer vocab.

        Non-DNA vocabulary slots are filled with -1e4 so they get negligible
        probability mass after softmax.
        """
        B, L, C = compact_logits.shape
        V = self.tokenizer_vocab_size
        full_logits = compact_logits.new_full((B, L, V), -1e4)

        for compact_id, base in enumerate(["A", "C", "G", "T", "N"]):
            full_logits[:, :, self.dna_token_ids[base]] = compact_logits[:, :, compact_id]

        return full_logits

    @torch.no_grad()
    def extract_embeddings(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return per-token backbone embeddings (no LM head, no gradients).

        Temporarily flips the backbone into representation mode
        (``pretrain=False``, ``for_representation=True``) and restores the
        previous flags afterwards, even on error.
        """
        compact_ids = self._to_compact_ids(input_ids)

        was_pretrain = bool(getattr(self.backbone, "pretrain", False))
        was_for_repr = bool(getattr(self.backbone, "for_representation", False))
        try:
            self.backbone.pretrain = False
            self.backbone.for_representation = True
            embeddings, _ = self.backbone(compact_ids, mask=attention_mask)
        finally:
            self.backbone.pretrain = was_pretrain
            self.backbone.for_representation = was_for_repr
        return embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """Standard HF masked-LM forward.

        Args:
            input_ids: token ids in either the full tokenizer vocabulary or
                the compact alphabet (auto-detected).
            attention_mask: optional mask forwarded to the backbone.
            labels: optional MLM labels with -100 as ignore index.
            return_dict: HF return-format toggle; defaults to
                ``config.use_return_dict``.

        Returns:
            ``MaskedLMOutput`` (or a tuple when ``return_dict`` is False) with
            logits expanded to the full tokenizer vocabulary and, when labels
            are given, a cross-entropy loss plus the backbone's auxiliary loss.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        compact_input_ids = self._to_compact_ids(input_ids)
        compact_labels = self._labels_to_compact(labels)
        special_mask = self._build_special_mask(input_ids)

        if compact_labels is not None:
            # Positions with a real label are the MLM-masked positions.
            mlm_mask = compact_labels.ne(-100)
        else:
            mlm_mask = input_ids.eq(getattr(self.config, "mask_token_id", 3))

        # HF wrapper always asks for dense token logits, so temporarily disable streaming-loss forward.
        was_streaming = bool(getattr(self.backbone, "streaming_loss", False))
        self.backbone.streaming_loss = False
        try:
            if self.config.pretrain:
                outputs, _ = self.backbone(
                    (compact_input_ids, mlm_mask, compact_labels, special_mask),
                    mask=attention_mask,
                )
                # Backbone pretrain output packs (logits, mlm_mask, aux_loss, ...)
                # into the .logits tuple; unpack the pieces we need.
                compact_logits = outputs.logits[0]
                aux_loss = outputs.logits[2]
            else:
                compact_logits, _ = self.backbone(compact_input_ids, mask=attention_mask)
                aux_loss = None
        finally:
            self.backbone.streaming_loss = was_streaming

        logits = self._expand_logits_to_tokenizer_vocab(compact_logits)

        loss = None
        if compact_labels is not None:
            # Loss is computed in compact space; the expanded logits are for callers only.
            loss = F.cross_entropy(
                compact_logits.reshape(-1, self.config.alphabet_size),
                compact_labels.reshape(-1),
                ignore_index=-100,
            )
            if aux_loss is not None:
                loss = loss + aux_loss.to(loss.dtype)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
28.6M/huggingface_crossdna_140K_len/crossdna/special_tokens_map.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "cls_token": "[CLS]",
4
+ "eos_token": "[SEP]",
5
+ "mask_token": "[MASK]",
6
+ "pad_token": "[PAD]",
7
+ "sep_token": "[SEP]",
8
+ "unk_token": "[UNK]",
9
+ "additional_special_tokens": [
10
+ "[RESERVED]"
11
+ ]
12
+ }
28.6M/huggingface_crossdna_140K_len/crossdna/tokenization_crossdna.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Dict, List, Optional, Sequence, Union
6
+
7
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
8
+
9
+
10
class CrossDNATokenizer(PreTrainedTokenizer):
    """Character-level DNA tokenizer for CrossDNA.

    Special tokens occupy ids 0..6 and the DNA characters (default
    A, C, G, T, N) occupy contiguous ids starting at ``dna_token_start_id``
    (default 7).  Tokenization is simply per-character, upper-cased.
    """

    def __init__(
        self,
        characters: Sequence[str] = ("A", "C", "G", "T", "N"),
        model_max_length: int = 143360,
        padding_side: str = "left",
        dna_token_start_id: int = 7,
        **kwargs,
    ):
        """Create the tokenizer.

        Args:
            characters: single-character DNA alphabet (upper-cased on input).
            model_max_length: maximum sequence length advertised to HF.
            padding_side: "left" or "right" padding.
            dna_token_start_id: first id assigned to the DNA characters.
        """
        self.characters = [str(ch).upper() for ch in characters]
        self.model_max_length = int(model_max_length)
        self.dna_token_start_id = int(dna_token_start_id)

        # The vocab must exist BEFORE super().__init__, which may invoke
        # token-conversion hooks during base-class initialization.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: self.dna_token_start_id + i for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)
        mask_token = AddedToken("[MASK]", lstrip=False, rstrip=False)

        # "add_special_tokens" is an encode-time flag, not an __init__ kwarg;
        # drop it so super().__init__ doesn't choke on it.
        if "add_special_tokens" in kwargs:
            kwargs.pop("add_special_tokens")

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=self.model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    def __len__(self):
        # Fixed vocabulary: specials + DNA characters.
        return len(self._vocab_str_to_int)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary."""
        return len(self._vocab_str_to_int)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        # Character-level tokenization, case-insensitive.
        return list(text.upper())

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown characters fall back to [UNK].
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str.get(index, "[UNK]")

    def convert_tokens_to_string(self, tokens):
        """Join character tokens back into a plain string."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append [SEP] after each segment (no BOS/CLS is prepended)."""
        sep = [self.sep_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions, matching build_inputs_with_special_tokens."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Segment ids: 0 for the first segment (+SEP), 1 for the second (+SEP)."""
        sep = [self.sep_token_id]
        result = len(token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        """Serializable config written to tokenizer_config.json by save_pretrained."""
        return {
            "characters": self.characters,
            "model_max_length": self.model_max_length,
            "padding_side": self.padding_side,
            "dna_token_start_id": self.dna_token_start_id,
            "bos_token": "[BOS]",
            "eos_token": "[SEP]",
            "sep_token": "[SEP]",
            "cls_token": "[CLS]",
            "pad_token": "[PAD]",
            "mask_token": "[MASK]",
            "unk_token": "[UNK]",
            "tokenizer_class": "CrossDNATokenizer",
            "auto_map": {
                "AutoTokenizer": [
                    "tokenization_crossdna.CrossDNATokenizer",
                    None
                ]
            }
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CrossDNATokenizer":
        """Rebuild a tokenizer from a get_config()-style dict (extra keys ignored)."""
        return cls(
            characters=config.get("characters", ["A", "C", "G", "T", "N"]),
            model_max_length=config.get("model_max_length", 143360),
            padding_side=config.get("padding_side", "left"),
            dna_token_start_id=config.get("dna_token_start_id", 7),
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write tokenizer_config.json and special_tokens_map.json to *save_directory*.

        NOTE(review): this overrides the HF base save_pretrained and writes
        only these two JSON files; no vocab file is produced (the vocab is
        derived from the config).
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        cfg_file = save_directory / "tokenizer_config.json"
        stm_file = save_directory / "special_tokens_map.json"

        with open(cfg_file, "w", encoding="utf-8") as f:
            json.dump(self.get_config(), f, indent=2, ensure_ascii=False)

        special_tokens_map = {
            "bos_token": "[BOS]",
            "cls_token": "[CLS]",
            "eos_token": "[SEP]",
            "mask_token": "[MASK]",
            "pad_token": "[PAD]",
            "sep_token": "[SEP]",
            "unk_token": "[UNK]",
            "additional_special_tokens": ["[RESERVED]"]
        }
        with open(stm_file, "w", encoding="utf-8") as f:
            json.dump(special_tokens_map, f, indent=2, ensure_ascii=False)

        return (str(cfg_file), str(stm_file))

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        """Load from a local directory containing tokenizer_config.json.

        NOTE(review): unlike the HF base method this only supports local
        directories, not hub model ids; extra kwargs override config entries.
        """
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file, encoding="utf-8") as f:
            cfg = json.load(f)
        cfg.update(kwargs)
        # These keys are metadata for AutoTokenizer, not constructor args.
        cfg.pop("tokenizer_class", None)
        cfg.pop("auto_map", None)
        return cls.from_config(cfg)
28.6M/huggingface_crossdna_140K_len/crossdna/tokenizer_config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "characters": [
3
+ "A",
4
+ "C",
5
+ "G",
6
+ "T",
7
+ "N"
8
+ ],
9
+ "model_max_length": 143360,
10
+ "padding_side": "left",
11
+ "dna_token_start_id": 7,
12
+ "bos_token": "[BOS]",
13
+ "eos_token": "[SEP]",
14
+ "sep_token": "[SEP]",
15
+ "cls_token": "[CLS]",
16
+ "pad_token": "[PAD]",
17
+ "mask_token": "[MASK]",
18
+ "unk_token": "[UNK]",
19
+ "tokenizer_class": "CrossDNATokenizer",
20
+ "auto_map": {
21
+ "AutoTokenizer": [
22
+ "tokenization_crossdna.CrossDNATokenizer",
23
+ null
24
+ ]
25
+ }
26
+ }
28.6M/huggingface_crossdna_140K_len/crossdna/transfer.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# transfer.py
import os
import json

# 1) Disable dynamo / torch.compile first; this must run before
#    modeling_crossdna is imported.
os.environ["DISABLE_TORCH_COMPILE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch

# 2) Additionally hard monkey-patch torch.compile into a no-op so that
#    `fla` does not trigger @torch.compile at import time.
if hasattr(torch, "compile"):
    def _no_compile(fn=None, *args, **kwargs):
        # Supports both @torch.compile and @torch.compile(...) usage:
        # without a positional function, return a pass-through decorator.
        if fn is None:
            def deco(f):
                return f
            return deco
        return fn
    torch.compile = _no_compile

from configuration_crossdna import CrossDNAConfig
from modeling_crossdna import CrossDNAForMaskedLM

# Source checkpoint and target HF model directory.
CKPT = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna/last.ckpt"
MODEL_DIR = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna"
26
+
27
+
28
def adapt_state_dict(sd):
    """Normalize checkpoint keys so they load into the HF-wrapped model.

    Strips the common wrapper prefixes added by Lightning ("state_dict.",
    "model."), DDP ("module.") and torch.compile ("_orig_mod."), then mounts
    any key that is not already HF-wrapped under "backbone.".

    Bug fix: the original stripped each prefix at most once, in a fixed
    order, so nested or out-of-order wrappers (e.g. "module.model.<param>")
    kept a stale prefix and ended up as "backbone.model.<param>".  Prefixes
    are now stripped repeatedly until the key is stable.

    Args:
        sd: mapping of parameter name -> tensor (values pass through as-is).

    Returns:
        dict with normalized keys in the same insertion order.
    """
    wrapper_prefixes = ("state_dict.", "model.", "module.", "_orig_mod.")
    new_sd = {}
    for k, v in sd.items():
        k2 = k

        # Peel wrapper prefixes until none match; wrappers can stack in any
        # order (Lightning around DDP around torch.compile, etc.).
        stripped = True
        while stripped:
            stripped = False
            for pref in wrapper_prefixes:
                if k2.startswith(pref):
                    k2 = k2[len(pref):]
                    stripped = True

        # If the checkpoint key is not already in HF-wrapped form,
        # mount it under the backbone submodule.
        if not k2.startswith("backbone."):
            k2 = "backbone." + k2

        new_sd[k2] = v
    return new_sd
49
+
50
+
51
+ # 1) 读 config.json 构建 HF 模型骨架
52
+ with open(os.path.join(MODEL_DIR, "config.json"), "r", encoding="utf-8") as f:
53
+ cfg_dict = json.load(f)
54
+
55
+ cfg = CrossDNAConfig(**cfg_dict)
56
+ model = CrossDNAForMaskedLM(cfg)
57
+
58
+ # 2) 读 ckpt
59
+ raw = torch.load(CKPT, map_location="cpu")
60
+ sd = raw.get("state_dict", raw) if isinstance(raw, dict) else raw
61
+ sd = adapt_state_dict(sd)
62
+
63
+ # 3) 加载权重
64
+ missing, unexpected = model.load_state_dict(sd, strict=False)
65
+
66
+ print("[Missing]", len(missing), "keys")
67
+ if missing:
68
+ print(" first 30 missing:")
69
+ for k in missing[:30]:
70
+ print(" ", k)
71
+
72
+ print("[Unexpected]", len(unexpected), "keys")
73
+ if unexpected:
74
+ print(" first 30 unexpected:")
75
+ for k in unexpected[:30]:
76
+ print(" ", k)
77
+
78
+ # 4) 直接保存为 HF safetensors
79
+ model.save_pretrained(MODEL_DIR, safe_serialization=True)
80
+
81
+ print("Saved HF weights to:", os.path.join(MODEL_DIR, "model.safetensors"))
28.6M/huggingface_crossdna_140K_len/crossdna_140K_infer.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Must come first, before torch / transformers are imported.
os.environ["DISABLE_TORCH_COMPILE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch

# Must run before from_pretrained: turn torch.compile into a no-op so
# remote code does not trigger compilation at import/load time.
if hasattr(torch, "compile"):
    def _no_compile(fn=None, *args, **kwargs):
        # Supports both @torch.compile and @torch.compile(...) usage.
        if fn is None:
            def deco(f):
                return f
            return deco
        return fn
    torch.compile = _no_compile

from transformers import AutoTokenizer, AutoModelForMaskedLM

MODEL_DIR = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna"

tok = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    local_files_only=True,
)

model = AutoModelForMaskedLM.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    local_files_only=True,
).eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Demo input: a 512-bp sequence; add_special_tokens=False keeps it pure DNA.
seq = "ACGT" * 128
enc = tok(seq, return_tensors="pt", add_special_tokens=False)
x = enc["input_ids"].to(device)

with torch.inference_mode():
    out = model(input_ids=x)
    emb = model.extract_embeddings(x)

print("input_ids.shape =", tuple(x.shape))
print("logits.shape =", tuple(out.logits.shape))
print("embeddings.shape =", tuple(emb.shape))