Upload CrossDNA 28.6M pretrained files
Browse files- 28.6M/README.md +56 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/__pycache__/configuration_crossdna.cpython-311.pyc +0 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/__pycache__/modeling_crossdna.cpython-311.pyc +0 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/config.json +106 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/configuration_crossdna.py +178 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/last.ckpt +3 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/model.safetensors +3 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/modeling_crossdna.py +1702 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/special_tokens_map.json +12 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/tokenization_crossdna.py +181 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/tokenizer_config.json +26 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna/transfer.py +81 -0
- 28.6M/huggingface_crossdna_140K_len/crossdna_140K_infer.py +48 -0
28.6M/README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: gpl-3.0
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
## Using CrossDNA 28.6M (140K sequence inputs)
|
| 6 |
+
```python
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
os.environ["DISABLE_TORCH_COMPILE"] = "1"
|
| 11 |
+
os.environ["TORCHDYNAMO_DISABLE"] = "1"
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
if hasattr(torch, "compile"):
|
| 17 |
+
def _no_compile(fn=None, *args, **kwargs):
|
| 18 |
+
if fn is None:
|
| 19 |
+
def deco(f):
|
| 20 |
+
return f
|
| 21 |
+
return deco
|
| 22 |
+
return fn
|
| 23 |
+
torch.compile = _no_compile
|
| 24 |
+
|
| 25 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 26 |
+
|
| 27 |
+
MODEL_DIR = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna"
|
| 28 |
+
|
| 29 |
+
tok = AutoTokenizer.from_pretrained(
|
| 30 |
+
MODEL_DIR,
|
| 31 |
+
trust_remote_code=True,
|
| 32 |
+
local_files_only=True,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
model = AutoModelForMaskedLM.from_pretrained(
|
| 36 |
+
MODEL_DIR,
|
| 37 |
+
trust_remote_code=True,
|
| 38 |
+
local_files_only=True,
|
| 39 |
+
).eval()
|
| 40 |
+
|
| 41 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 42 |
+
model.to(device)
|
| 43 |
+
|
| 44 |
+
seq = "ACGT" * 128
|
| 45 |
+
enc = tok(seq, return_tensors="pt", add_special_tokens=False)
|
| 46 |
+
x = enc["input_ids"].to(device)
|
| 47 |
+
|
| 48 |
+
with torch.inference_mode():
|
| 49 |
+
out = model(input_ids=x)
|
| 50 |
+
emb = model.extract_embeddings(x)
|
| 51 |
+
|
| 52 |
+
print("input_ids.shape =", tuple(x.shape))
|
| 53 |
+
print("logits.shape =", tuple(out.logits.shape))
|
| 54 |
+
print("embeddings.shape =", tuple(emb.shape))
|
| 55 |
+
|
| 56 |
+
```
|
28.6M/huggingface_crossdna_140K_len/crossdna/__pycache__/configuration_crossdna.cpython-311.pyc
ADDED
|
Binary file (6.14 kB). View file
|
|
|
28.6M/huggingface_crossdna_140K_len/crossdna/__pycache__/modeling_crossdna.cpython-311.pyc
ADDED
|
Binary file (99.4 kB). View file
|
|
|
28.6M/huggingface_crossdna_140K_len/crossdna/config.json
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alphabet_size": 5,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"CrossDNAForMaskedLM"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_crossdna.CrossDNAConfig",
|
| 8 |
+
"AutoModelForMaskedLM": "modeling_crossdna.CrossDNAForMaskedLM",
|
| 9 |
+
"AutoTokenizer": "tokenization_crossdna.CrossDNATokenizer"
|
| 10 |
+
},
|
| 11 |
+
"auto_update_ema_in_forward": true,
|
| 12 |
+
"aux_ce_weight": 0.0,
|
| 13 |
+
"block_size": 4096,
|
| 14 |
+
"bos_token_id": 2,
|
| 15 |
+
"bridge_dropout": 0.05,
|
| 16 |
+
"checkpoint_chunk_size": 1,
|
| 17 |
+
"checkpoint_core_layers": true,
|
| 18 |
+
"cls_token_id": 0,
|
| 19 |
+
"comba_cfg": {
|
| 20 |
+
"conv_size": 4,
|
| 21 |
+
"correction_factor": 0.02,
|
| 22 |
+
"expand_v": 1,
|
| 23 |
+
"head_dim": 64,
|
| 24 |
+
"hidden_size": 256,
|
| 25 |
+
"mode": "chunk",
|
| 26 |
+
"norm_eps": 1e-05,
|
| 27 |
+
"num_heads": 8,
|
| 28 |
+
"use_gate": true,
|
| 29 |
+
"use_short_conv": true
|
| 30 |
+
},
|
| 31 |
+
"compact_n_token_id": 4,
|
| 32 |
+
"core_checkpoint_chunk_size": 1,
|
| 33 |
+
"d_model": 256,
|
| 34 |
+
"depth": 8,
|
| 35 |
+
"detach_gate": false,
|
| 36 |
+
"disable_cross_view": false,
|
| 37 |
+
"dna_token_ids": {
|
| 38 |
+
"A": 7,
|
| 39 |
+
"C": 8,
|
| 40 |
+
"G": 9,
|
| 41 |
+
"N": 11,
|
| 42 |
+
"T": 10
|
| 43 |
+
},
|
| 44 |
+
"dna_token_start_id": 7,
|
| 45 |
+
"dna_tokens": [
|
| 46 |
+
"A",
|
| 47 |
+
"C",
|
| 48 |
+
"G",
|
| 49 |
+
"T",
|
| 50 |
+
"N"
|
| 51 |
+
],
|
| 52 |
+
"drop_path_rates": [
|
| 53 |
+
0.0,
|
| 54 |
+
0.08
|
| 55 |
+
],
|
| 56 |
+
"dropout": 0.1,
|
| 57 |
+
"dtype": "float32",
|
| 58 |
+
"ema_decay": 0.9995,
|
| 59 |
+
"eos_token_id": 1,
|
| 60 |
+
"for_representation": false,
|
| 61 |
+
"gate_freeze_steps": 1000,
|
| 62 |
+
"gate_sup_warmup_steps": 1000,
|
| 63 |
+
"gate_sup_weight": 0.003,
|
| 64 |
+
"gate_temp": 1.2,
|
| 65 |
+
"mask_token_id": 3,
|
| 66 |
+
"model_type": "crossdna",
|
| 67 |
+
"pad_token_id": 4,
|
| 68 |
+
"pretrain": true,
|
| 69 |
+
"rc_bidirectional_stopgrad": true,
|
| 70 |
+
"rc_max_weight": 0.2,
|
| 71 |
+
"rc_tau": 1.5,
|
| 72 |
+
"rc_warmup_steps": 2000,
|
| 73 |
+
"return_ab_logits": true,
|
| 74 |
+
"sem_max_weight": 0.1,
|
| 75 |
+
"sem_warmup_steps": 8000,
|
| 76 |
+
"sep_token_id": 1,
|
| 77 |
+
"streaming_loss": true,
|
| 78 |
+
"streaming_report_ab": true,
|
| 79 |
+
"transformer_cfg": {
|
| 80 |
+
"attn": {
|
| 81 |
+
"num_heads": 8,
|
| 82 |
+
"num_kv_heads": 8,
|
| 83 |
+
"qkv_bias": false,
|
| 84 |
+
"rope_theta": 10000,
|
| 85 |
+
"window_size": 512
|
| 86 |
+
},
|
| 87 |
+
"fuse_swiglu": true,
|
| 88 |
+
"hidden_act": "swish",
|
| 89 |
+
"hidden_ratio": 4.0,
|
| 90 |
+
"hidden_size": 256,
|
| 91 |
+
"max_position_embeddings": 4096,
|
| 92 |
+
"norm_eps": 1e-05
|
| 93 |
+
},
|
| 94 |
+
"transformers_version": "4.57.1",
|
| 95 |
+
"unk_token_id": 6,
|
| 96 |
+
"use_barlow": false,
|
| 97 |
+
"use_bridge": true,
|
| 98 |
+
"use_checkpointing": true,
|
| 99 |
+
"use_ema_teacher": true,
|
| 100 |
+
"use_final_conv": false,
|
| 101 |
+
"use_mem": false,
|
| 102 |
+
"use_rc_kl": false,
|
| 103 |
+
"use_s_scan": true,
|
| 104 |
+
"use_tv": false,
|
| 105 |
+
"vocab_size": 12
|
| 106 |
+
}
|
28.6M/huggingface_crossdna_140K_len/crossdna/configuration_crossdna.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CrossDNAConfig(PretrainedConfig):
|
| 5 |
+
model_type = "crossdna"
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
alphabet_size=5,
|
| 10 |
+
vocab_size=12,
|
| 11 |
+
dna_tokens=("A", "C", "G", "T", "N"),
|
| 12 |
+
dna_token_start_id=7,
|
| 13 |
+
compact_n_token_id=4,
|
| 14 |
+
dna_token_ids=None,
|
| 15 |
+
|
| 16 |
+
d_model=256,
|
| 17 |
+
block_size=4096,
|
| 18 |
+
depth=8,
|
| 19 |
+
drop_path_rates=(0.0, 0.08),
|
| 20 |
+
dropout=0.10,
|
| 21 |
+
|
| 22 |
+
pretrain=True,
|
| 23 |
+
use_s_scan=True,
|
| 24 |
+
for_representation=False,
|
| 25 |
+
|
| 26 |
+
use_bridge=True,
|
| 27 |
+
bridge_dropout=0.05,
|
| 28 |
+
|
| 29 |
+
use_ema_teacher=True,
|
| 30 |
+
ema_decay=0.9995,
|
| 31 |
+
auto_update_ema_in_forward=True,
|
| 32 |
+
|
| 33 |
+
use_mem=False,
|
| 34 |
+
use_rc_kl=False,
|
| 35 |
+
use_barlow=False,
|
| 36 |
+
use_tv=False,
|
| 37 |
+
use_final_conv=False,
|
| 38 |
+
|
| 39 |
+
sem_max_weight=0.10,
|
| 40 |
+
sem_warmup_steps=8000,
|
| 41 |
+
aux_ce_weight=0.0,
|
| 42 |
+
|
| 43 |
+
gate_freeze_steps=1000,
|
| 44 |
+
detach_gate=False,
|
| 45 |
+
gate_sup_weight=0.003,
|
| 46 |
+
gate_sup_warmup_steps=1000,
|
| 47 |
+
gate_temp=1.2,
|
| 48 |
+
|
| 49 |
+
use_checkpointing=True,
|
| 50 |
+
checkpoint_chunk_size=1,
|
| 51 |
+
checkpoint_core_layers=True,
|
| 52 |
+
core_checkpoint_chunk_size=1,
|
| 53 |
+
|
| 54 |
+
return_ab_logits=True,
|
| 55 |
+
streaming_loss=True,
|
| 56 |
+
streaming_report_ab=True,
|
| 57 |
+
|
| 58 |
+
disable_cross_view=False,
|
| 59 |
+
|
| 60 |
+
rc_bidirectional_stopgrad=True,
|
| 61 |
+
rc_max_weight=0.2,
|
| 62 |
+
rc_tau=1.5,
|
| 63 |
+
rc_warmup_steps=2000,
|
| 64 |
+
|
| 65 |
+
transformer_cfg=None,
|
| 66 |
+
comba_cfg=None,
|
| 67 |
+
|
| 68 |
+
pad_token_id=4,
|
| 69 |
+
bos_token_id=2,
|
| 70 |
+
eos_token_id=1,
|
| 71 |
+
sep_token_id=1,
|
| 72 |
+
cls_token_id=0,
|
| 73 |
+
mask_token_id=3,
|
| 74 |
+
unk_token_id=6,
|
| 75 |
+
**kwargs,
|
| 76 |
+
):
|
| 77 |
+
super().__init__(
|
| 78 |
+
pad_token_id=pad_token_id,
|
| 79 |
+
bos_token_id=bos_token_id,
|
| 80 |
+
eos_token_id=eos_token_id,
|
| 81 |
+
sep_token_id=sep_token_id,
|
| 82 |
+
cls_token_id=cls_token_id,
|
| 83 |
+
mask_token_id=mask_token_id,
|
| 84 |
+
unk_token_id=unk_token_id,
|
| 85 |
+
**kwargs,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.alphabet_size = int(alphabet_size)
|
| 89 |
+
self.vocab_size = int(vocab_size)
|
| 90 |
+
|
| 91 |
+
self.dna_tokens = list(dna_tokens)
|
| 92 |
+
self.dna_token_start_id = int(dna_token_start_id)
|
| 93 |
+
self.compact_n_token_id = int(compact_n_token_id)
|
| 94 |
+
|
| 95 |
+
if dna_token_ids is None:
|
| 96 |
+
dna_token_ids = {
|
| 97 |
+
ch: self.dna_token_start_id + i
|
| 98 |
+
for i, ch in enumerate(self.dna_tokens)
|
| 99 |
+
}
|
| 100 |
+
self.dna_token_ids = dict(dna_token_ids)
|
| 101 |
+
|
| 102 |
+
self.d_model = int(d_model)
|
| 103 |
+
self.block_size = int(block_size)
|
| 104 |
+
self.depth = int(depth)
|
| 105 |
+
self.drop_path_rates = list(drop_path_rates) if drop_path_rates is not None else None
|
| 106 |
+
self.dropout = float(dropout)
|
| 107 |
+
|
| 108 |
+
self.pretrain = bool(pretrain)
|
| 109 |
+
self.use_s_scan = bool(use_s_scan)
|
| 110 |
+
self.for_representation = bool(for_representation)
|
| 111 |
+
|
| 112 |
+
self.use_bridge = bool(use_bridge)
|
| 113 |
+
self.bridge_dropout = float(bridge_dropout)
|
| 114 |
+
|
| 115 |
+
self.use_ema_teacher = bool(use_ema_teacher)
|
| 116 |
+
self.ema_decay = float(ema_decay)
|
| 117 |
+
self.auto_update_ema_in_forward = bool(auto_update_ema_in_forward)
|
| 118 |
+
|
| 119 |
+
self.use_mem = bool(use_mem)
|
| 120 |
+
self.use_rc_kl = bool(use_rc_kl)
|
| 121 |
+
self.use_barlow = bool(use_barlow)
|
| 122 |
+
self.use_tv = bool(use_tv)
|
| 123 |
+
self.use_final_conv = bool(use_final_conv)
|
| 124 |
+
|
| 125 |
+
self.sem_max_weight = float(sem_max_weight)
|
| 126 |
+
self.sem_warmup_steps = int(sem_warmup_steps)
|
| 127 |
+
self.aux_ce_weight = float(aux_ce_weight)
|
| 128 |
+
|
| 129 |
+
self.gate_freeze_steps = int(gate_freeze_steps)
|
| 130 |
+
self.detach_gate = bool(detach_gate)
|
| 131 |
+
self.gate_sup_weight = float(gate_sup_weight)
|
| 132 |
+
self.gate_sup_warmup_steps = int(gate_sup_warmup_steps)
|
| 133 |
+
self.gate_temp = float(gate_temp)
|
| 134 |
+
|
| 135 |
+
self.use_checkpointing = bool(use_checkpointing)
|
| 136 |
+
self.checkpoint_chunk_size = int(checkpoint_chunk_size)
|
| 137 |
+
self.checkpoint_core_layers = bool(checkpoint_core_layers)
|
| 138 |
+
self.core_checkpoint_chunk_size = int(core_checkpoint_chunk_size)
|
| 139 |
+
|
| 140 |
+
self.return_ab_logits = bool(return_ab_logits)
|
| 141 |
+
self.streaming_loss = bool(streaming_loss)
|
| 142 |
+
self.streaming_report_ab = bool(streaming_report_ab)
|
| 143 |
+
|
| 144 |
+
self.disable_cross_view = bool(disable_cross_view)
|
| 145 |
+
|
| 146 |
+
self.rc_bidirectional_stopgrad = bool(rc_bidirectional_stopgrad)
|
| 147 |
+
self.rc_max_weight = float(rc_max_weight)
|
| 148 |
+
self.rc_tau = float(rc_tau)
|
| 149 |
+
self.rc_warmup_steps = int(rc_warmup_steps)
|
| 150 |
+
|
| 151 |
+
self.transformer_cfg = transformer_cfg or {
|
| 152 |
+
"hidden_size": self.d_model,
|
| 153 |
+
"norm_eps": 1e-5,
|
| 154 |
+
"max_position_embeddings": self.block_size,
|
| 155 |
+
"hidden_ratio": 4.0,
|
| 156 |
+
"hidden_act": "swish",
|
| 157 |
+
"fuse_swiglu": True,
|
| 158 |
+
"attn": {
|
| 159 |
+
"num_heads": 8,
|
| 160 |
+
"num_kv_heads": 8,
|
| 161 |
+
"qkv_bias": False,
|
| 162 |
+
"window_size": 512,
|
| 163 |
+
"rope_theta": 10000,
|
| 164 |
+
},
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
self.comba_cfg = comba_cfg or {
|
| 168 |
+
"hidden_size": self.d_model,
|
| 169 |
+
"expand_v": 1,
|
| 170 |
+
"head_dim": 64,
|
| 171 |
+
"num_heads": 8,
|
| 172 |
+
"use_gate": True,
|
| 173 |
+
"mode": "chunk",
|
| 174 |
+
"use_short_conv": True,
|
| 175 |
+
"correction_factor": 0.02,
|
| 176 |
+
"conv_size": 4,
|
| 177 |
+
"norm_eps": 1e-5,
|
| 178 |
+
}
|
28.6M/huggingface_crossdna_140K_len/crossdna/last.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6788405dcb78275faeaf1b94ae14a27465a37a3ff1872b34bffdb16b07fdf333
|
| 3 |
+
size 454953820
|
28.6M/huggingface_crossdna_140K_len/crossdna/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2bfbc40499d6b55d3ef4a3ba1e2f598dce843f083ad8e413621a7378b564a25a
|
| 3 |
+
size 225803888
|
28.6M/huggingface_crossdna_140K_len/crossdna/modeling_crossdna.py
ADDED
|
@@ -0,0 +1,1702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
import math
|
| 4 |
+
import copy
|
| 5 |
+
from functools import partial
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from collections import namedtuple
|
| 8 |
+
from typing import Dict, Optional, Tuple, Any
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint as cp
|
| 14 |
+
|
| 15 |
+
from transformers import PreTrainedModel
|
| 16 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 17 |
+
|
| 18 |
+
# speed
|
| 19 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 20 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 21 |
+
|
| 22 |
+
from fla.layers import comba
|
| 23 |
+
from fla.layers.attn import Attention
|
| 24 |
+
from fla.modules import GatedMLP as SambaMLP
|
| 25 |
+
from fla.modules import RMSNorm
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from omegaconf import OmegaConf
|
| 29 |
+
except Exception:
|
| 30 |
+
OmegaConf = None
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from .configuration_crossdna import CrossDNAConfig
|
| 34 |
+
except ImportError:
|
| 35 |
+
from configuration_crossdna import CrossDNAConfig
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ========================
|
| 39 |
+
# OmegaConf helpers
|
| 40 |
+
# ========================
|
| 41 |
+
def _to_plain_container(x: Any) -> Any:
|
| 42 |
+
if OmegaConf is not None:
|
| 43 |
+
try:
|
| 44 |
+
if OmegaConf.is_config(x):
|
| 45 |
+
return OmegaConf.to_container(x, resolve=True)
|
| 46 |
+
except Exception:
|
| 47 |
+
pass
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _cfg_get(cfg: Any, key: str, default: Any = None) -> Any:
|
| 52 |
+
if cfg is None:
|
| 53 |
+
return default
|
| 54 |
+
|
| 55 |
+
try:
|
| 56 |
+
if isinstance(cfg, dict):
|
| 57 |
+
return cfg.get(key, default)
|
| 58 |
+
except Exception:
|
| 59 |
+
pass
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
if hasattr(cfg, key):
|
| 63 |
+
return getattr(cfg, key)
|
| 64 |
+
except Exception:
|
| 65 |
+
pass
|
| 66 |
+
|
| 67 |
+
try:
|
| 68 |
+
return cfg[key]
|
| 69 |
+
except Exception:
|
| 70 |
+
return default
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ========================
|
| 74 |
+
# Utils
|
| 75 |
+
# ========================
|
| 76 |
+
def complement(seq: torch.Tensor) -> torch.Tensor:
|
| 77 |
+
"""
|
| 78 |
+
compact DNA ids only:
|
| 79 |
+
A=0, C=1, G=2, T=3, N=4
|
| 80 |
+
"""
|
| 81 |
+
perm = torch.tensor([3, 2, 1, 0, 4], device=seq.device, dtype=torch.long)
|
| 82 |
+
return perm[seq.long()].to(seq.dtype)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def reverse_complement(seq: torch.Tensor) -> torch.Tensor:
|
| 86 |
+
comp = complement(seq)
|
| 87 |
+
return torch.flip(comp, dims=[1])
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def make_complement_perm(C=5, device=None, dtype=torch.float32):
|
| 91 |
+
perm = torch.arange(C, device=device)
|
| 92 |
+
if C >= 4:
|
| 93 |
+
perm[0] = 3
|
| 94 |
+
perm[1] = 2
|
| 95 |
+
perm[2] = 1
|
| 96 |
+
perm[3] = 0
|
| 97 |
+
if C >= 5:
|
| 98 |
+
perm[4] = 4
|
| 99 |
+
|
| 100 |
+
P = torch.zeros(C, C, device=device, dtype=dtype)
|
| 101 |
+
P[torch.arange(C, device=device), perm] = 1.0
|
| 102 |
+
return P, perm
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def ensure_finite(x: torch.Tensor, name: str):
|
| 106 |
+
if not torch.isfinite(x).all():
|
| 107 |
+
raise FloatingPointError(f"Non-finite values detected in {name}")
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def linear_warmup_weight(step: int, warmup_steps: int, max_w: float):
|
| 112 |
+
if warmup_steps <= 0:
|
| 113 |
+
return max_w
|
| 114 |
+
if step <= 0:
|
| 115 |
+
return 0.0
|
| 116 |
+
if step >= warmup_steps:
|
| 117 |
+
return max_w
|
| 118 |
+
return max_w * (step / warmup_steps)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def preferred_amp_dtype():
|
| 122 |
+
try:
|
| 123 |
+
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
|
| 124 |
+
return torch.bfloat16
|
| 125 |
+
except Exception:
|
| 126 |
+
pass
|
| 127 |
+
return torch.float16
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def one_hot_float(x: torch.Tensor, num_classes: int, *, dtype: torch.dtype) -> torch.Tensor:
|
| 131 |
+
B, L = x.shape
|
| 132 |
+
out = torch.zeros((B, L, num_classes), device=x.device, dtype=dtype)
|
| 133 |
+
out.scatter_(2, x.unsqueeze(-1), 1.0)
|
| 134 |
+
return out
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ========================
|
| 138 |
+
# RC / Barlow / TV
|
| 139 |
+
# ========================
|
| 140 |
+
def rc_consistency_kl(logits_A, logits_B_fwd, P, tau: float = 1.0, eps: float = 1e-6):
|
| 141 |
+
zA = logits_A.float() / tau
|
| 142 |
+
zB = logits_B_fwd.float() / tau
|
| 143 |
+
pA = F.softmax(zA, dim=-1)
|
| 144 |
+
logpA = F.log_softmax(zA, dim=-1)
|
| 145 |
+
pB = F.softmax(zB, dim=-1)
|
| 146 |
+
pB_comp = torch.matmul(pB, P.t()).clamp_min(eps)
|
| 147 |
+
logpB_comp = pB_comp.log()
|
| 148 |
+
kl = (pA * (logpA - logpB_comp)).sum(dim=-1).mean()
|
| 149 |
+
return kl * (tau * tau)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def rc_consistency_bidirectional_stopgrad(logits_A, logits_B_fwd, P, tau: float = 1.5, eps: float = 1e-6):
|
| 153 |
+
zA = logits_A.float() / tau
|
| 154 |
+
zB = logits_B_fwd.float() / tau
|
| 155 |
+
with torch.no_grad():
|
| 156 |
+
pB_t = torch.matmul(F.softmax(zB, dim=-1), P.t()).clamp_min(eps)
|
| 157 |
+
logpB_t = pB_t.log()
|
| 158 |
+
loss_A = F.kl_div(F.log_softmax(zA, dim=-1), logpB_t, reduction="batchmean", log_target=True)
|
| 159 |
+
|
| 160 |
+
with torch.no_grad():
|
| 161 |
+
pA_t = torch.matmul(F.softmax(zA, dim=-1), P.t()).clamp_min(eps)
|
| 162 |
+
logpA_t = pA_t.log()
|
| 163 |
+
loss_B = F.kl_div(F.log_softmax(zB, dim=-1), logpA_t, reduction="batchmean", log_target=True)
|
| 164 |
+
return 0.5 * (tau * tau) * (loss_A + loss_B)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def barlow_strand_loss_v2(z1, z2, λ_off=0.04, λ_diag=0.04, eps=1e-3):
|
| 168 |
+
B, L, H = z1.shape
|
| 169 |
+
n = B * L
|
| 170 |
+
z1 = z1.reshape(n, H)
|
| 171 |
+
z2 = z2.reshape(n, H)
|
| 172 |
+
|
| 173 |
+
def _std(z):
|
| 174 |
+
var = z.var(dim=0, unbiased=False)
|
| 175 |
+
return torch.sqrt(var + eps)
|
| 176 |
+
|
| 177 |
+
std1, std2 = _std(z1), _std(z2)
|
| 178 |
+
var_term = (F.relu(1 - std1).pow(2).mean() + F.relu(1 - std2).pow(2).mean())
|
| 179 |
+
|
| 180 |
+
z1 = (z1 - z1.mean(0)) / (std1 + eps)
|
| 181 |
+
z2 = (z2 - z2.mean(0)) / (std2 + eps)
|
| 182 |
+
c = (z1.t() @ z2) / max(1, n)
|
| 183 |
+
diag = torch.diagonal(c)
|
| 184 |
+
off = c - torch.diag_embed(diag)
|
| 185 |
+
cov = λ_diag * (1 - diag).pow(2).mean() + λ_off * off.pow(2).mean()
|
| 186 |
+
return var_term + cov
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def tv_mixed(h: torch.Tensor):
    """Mixed total-variation smoothness penalty along the sequence axis.

    Returns L1 of the first difference plus L2 of the second difference of
    ``h`` (shape (B, L, H)), encouraging piecewise-smooth hidden states.
    """
    first_diff = h[:, 1:, :] - h[:, :-1, :]
    second_diff = first_diff[:, 1:, :] - first_diff[:, :-1, :]
    return first_diff.abs().mean() + second_diff.pow(2).mean()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Mlp(nn.Module):
|
| 196 |
+
def __init__(self, input_dimension, hidden_dimension=None, output_dimension=None,
|
| 197 |
+
activation=F.gelu, return_residual=False):
|
| 198 |
+
super().__init__()
|
| 199 |
+
self.return_residual = return_residual
|
| 200 |
+
hd = hidden_dimension or input_dimension
|
| 201 |
+
od = output_dimension or input_dimension
|
| 202 |
+
self.linear1 = nn.Linear(input_dimension, hd)
|
| 203 |
+
self.activation = activation
|
| 204 |
+
self.linear2 = nn.Linear(hd, od)
|
| 205 |
+
|
| 206 |
+
def forward(self, x: torch.Tensor):
|
| 207 |
+
h = self.activation(self.linear1(x))
|
| 208 |
+
y = self.linear2(h)
|
| 209 |
+
return (y, x) if self.return_residual else y
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def create_comba_cls(comba_kwargs=None, device=None, dtype=None):
    """Build a factory for the Comba sequence mixer.

    Returns a callable that constructs a ``comba.Comba`` module with the given
    keyword arguments (plus optional ``device``/``dtype``). When the optional
    ``comba`` dependency is unavailable, returns a factory for an identity
    fallback module so the surrounding model still instantiates.

    Args:
        comba_kwargs: extra keyword arguments forwarded to ``comba.Comba``.
        device: optional target device for the mixer parameters.
        dtype: optional parameter dtype for the mixer.

    Returns:
        A callable (accepting arbitrary args/kwargs) producing an ``nn.Module``.
    """
    factory_kwargs = {}
    if device is not None:
        factory_kwargs["device"] = device
    if dtype is not None:
        factory_kwargs["dtype"] = dtype
    try:
        base_kwargs = dict(comba_kwargs or {})
        # BUGFIX: `partial(...)` itself never raises ImportError. If the
        # optional `comba` package was never imported, evaluating
        # `comba.Comba` raises NameError (or AttributeError for a stub
        # module), so catching ImportError alone left the fallback dead.
        mixer_cls = partial(comba.Comba, **base_kwargs, **factory_kwargs)
    except (ImportError, NameError, AttributeError):
        class FallbackComba(nn.Module):
            """Identity stand-in used when `comba` is not installed."""

            def forward(self, x, *args, **kwargs):
                return x

        mixer_cls = lambda *args, **kwargs: FallbackComba()
    return mixer_cls
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class SlidingWindowAttention(nn.Module):
    """Pre-norm sliding-window attention block: attention + MLP, each with a
    residual connection (Samba-style layer).

    NOTE(review): ``Attention``, ``SambaMLP``, ``RMSNorm``, ``_cfg_get``,
    ``_to_plain_container``, ``preferred_amp_dtype`` and ``ensure_finite`` are
    project helpers defined elsewhere in this file/package.
    """

    def __init__(self, config: Any):
        super().__init__()
        # Normalize dict-like / OmegaConf-style configs into a plain container.
        config = _to_plain_container(config)

        hidden_size = _cfg_get(config, "hidden_size")
        norm_eps = _cfg_get(config, "norm_eps", 1e-5)
        # `attn` sub-config may be missing or None; default to an empty dict.
        attn_cfg = _cfg_get(config, "attn", {}) or {}
        attn_cfg = _to_plain_container(attn_cfg)

        self.mixer_norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)
        self.mixer = Attention(
            hidden_size=hidden_size,
            num_heads=_cfg_get(attn_cfg, "num_heads"),
            num_kv_heads=_cfg_get(attn_cfg, "num_kv_heads"),
            qkv_bias=_cfg_get(attn_cfg, "qkv_bias"),
            window_size=_cfg_get(attn_cfg, "window_size"),
            rope_theta=_cfg_get(attn_cfg, "rope_theta"),
            max_position_embeddings=_cfg_get(config, "max_position_embeddings"),
        )

        self.mlp_norm = RMSNorm(hidden_size, eps=norm_eps)
        self.mlp = SambaMLP(
            hidden_size=hidden_size,
            hidden_ratio=_cfg_get(config, "hidden_ratio", 4.0),
            hidden_act=_cfg_get(config, "hidden_act", "swish"),
            fuse_swiglu=_cfg_get(config, "fuse_swiglu", True),
        )
        # Inputs are scaled by 1/sqrt(2) before attention and unscaled after;
        # presumably for numerical headroom under autocast — TODO confirm.
        self.pre_scale = 1.0 / math.sqrt(2.0)

    def forward(self, hidden_states: torch.Tensor, cache_params: Optional[Any] = None, **kwargs) -> Tuple[torch.Tensor, Any]:
        """Run one attention + MLP block; returns (hidden_states, cache)."""
        residual = hidden_states
        x = self.mixer_norm(hidden_states)

        amp_dtype = preferred_amp_dtype()
        # torch.autocast rejects unknown device types; fall back to "cuda".
        device_type = x.device.type if x.device.type in ["cuda", "cpu", "xpu"] else "cuda"

        # Autocast only on CUDA; on CPU/XPU the context is created but disabled.
        with torch.autocast(device_type=device_type, enabled=(device_type == "cuda"), dtype=amp_dtype):
            x_scaled = x * self.pre_scale
            attn_out, _, cache_params = self.mixer(hidden_states=x_scaled, past_key_values=cache_params, **kwargs)
            attn_out = attn_out / self.pre_scale

        ensure_finite(attn_out, "attention_out")
        # Cast back so the residual add happens in the pre-autocast dtype.
        h = residual + attn_out.to(x.dtype)

        residual = h
        x = self.mlp_norm(h)
        with torch.autocast(device_type=device_type, enabled=(device_type == "cuda"), dtype=amp_dtype):
            x = self.mlp(x, **kwargs)

        h = residual + x
        ensure_finite(h, "block_output")
        return h, cache_params
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class EnhancedHybridCore(nn.Module):
    """Hybrid layer: Comba state-space mixer followed by sliding-window
    attention, fused token-wise through a learned sigmoid gate.

    NOTE(review): the Comba mixer comes from :func:`create_comba_cls` and may
    be an identity fallback when the `comba` package is unavailable.
    """

    def __init__(self, hidden_dim, comba_cfg, transformer_cfg, layer_idx=0, device=None, dtype=None):
        super().__init__()
        comba_cfg = _to_plain_container(comba_cfg)
        transformer_cfg = _to_plain_container(transformer_cfg)

        self.comba_cls = create_comba_cls(comba_kwargs=comba_cfg, device=device, dtype=dtype)
        try:
            # Some Comba versions accept layer_idx; fall back if not.
            self.comba = self.comba_cls(layer_idx=layer_idx)
        except TypeError:
            self.comba = self.comba_cls()

        self.transformer = SlidingWindowAttention(config=transformer_cfg)
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.out_norm = nn.LayerNorm(hidden_dim)

    @staticmethod
    def _first(x):
        # Some mixers return (output, *extras); keep only the output tensor.
        return x[0] if isinstance(x, tuple) else x

    def forward(self, x):
        orig_dtype = x.dtype
        x_fp32 = x.float()
        device_type = x.device.type if x.device.type in ["cuda", "cpu", "xpu"] else "cuda"

        # The state-space mixer runs in fp32 with autocast disabled,
        # presumably for recurrence numerical stability — TODO confirm.
        with torch.autocast(device_type=device_type, enabled=False):
            m_out = self._first(self.comba(x_fp32))

        m_out = m_out.to(orig_dtype)
        del x_fp32

        # Attention is applied on top of the mixer output (serial, not parallel).
        t_out, _ = self.transformer(m_out)

        # Token-wise sigmoid gate blends mixer and attention outputs.
        concat = torch.cat([m_out, t_out], dim=-1)
        g = torch.sigmoid(self.gate(concat))
        fused = g * t_out + (1 - g) * m_out
        y = self.out_norm(fused)
        ensure_finite(y, "EnhancedHybridCore.out")
        return y
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class DeepEnhancedBranch(nn.Module):
    """A stack of `depth` EnhancedHybridCore layers with stochastic-depth
    rates and optional chunked gradient checkpointing, ending in LayerNorm."""

    def __init__(
        self,
        hidden_dim: int,
        comba_cfg: Dict | None,
        transformer_cfg: Any,
        depth: int = 4,
        drop_path_rates=None,
        *,
        device=None,
        dtype=None,
        checkpoint_core_layers: bool = False,
        core_checkpoint_chunk_size: int = 1,
    ):
        super().__init__()
        self.layers = nn.ModuleList()

        transformer_cfg = _to_plain_container(transformer_cfg)
        comba_cfg = _to_plain_container(comba_cfg)

        # Whether (and in chunks of how many layers) to checkpoint in training.
        self.checkpoint_core_layers = bool(checkpoint_core_layers)
        self.core_checkpoint_chunk_size = int(core_checkpoint_chunk_size)

        # Stochastic-depth schedule: default linear ramp 0 -> 0.05 over depth;
        # a scalar is broadcast; a short list is padded with its last value.
        if drop_path_rates is None:
            rates = [0.05 * (i / max(1, depth - 1)) for i in range(depth)]
        elif isinstance(drop_path_rates, (float, int)):
            rates = [float(drop_path_rates)] * depth
        else:
            dpr = list(_to_plain_container(drop_path_rates))
            rates = dpr + [dpr[-1]] * (depth - len(dpr))

        for i in range(depth):
            # Per-layer copy so each layer gets its own drop-path probability.
            layer_cfg = dict(transformer_cfg) if isinstance(transformer_cfg, dict) else transformer_cfg.copy()
            layer_cfg["drop_path_prob"] = rates[i]
            self.layers.append(EnhancedHybridCore(hidden_dim, comba_cfg, layer_cfg, i, device, dtype))

        self.output_norm = nn.LayerNorm(hidden_dim)

    def _run_layers(self, x: torch.Tensor, start: int, end: int):
        """Apply layers[start:end] sequentially (checkpointing segment body)."""
        out = x
        for i in range(start, end):
            out = self.layers[i](out)
        return out

    def forward(self, x: torch.Tensor):
        if self.training and self.checkpoint_core_layers:
            # Recompute activations per chunk of layers to save memory.
            chunk = max(1, self.core_checkpoint_chunk_size)
            for s in range(0, len(self.layers), chunk):
                e = min(s + chunk, len(self.layers))

                # Bind s/e as defaults so the closure is safe under late binding.
                def _seg(inp, s=s, e=e):
                    return self._run_layers(inp, s, e)

                x = cp.checkpoint(_seg, x, use_reentrant=False)
        else:
            for layer in self.layers:
                x = layer(x)

        y = self.output_norm(x)
        ensure_finite(y, "DeepEnhancedBranch.out")
        return y
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class TokenBridge(nn.Module):
|
| 389 |
+
def __init__(self, hidden_dim: int, dropout: float = 0.0,
|
| 390 |
+
kernel_size: int = 9, dilations=(1, 2, 4, 8, 16),
|
| 391 |
+
use_global_token: bool = True):
|
| 392 |
+
super().__init__()
|
| 393 |
+
h = hidden_dim
|
| 394 |
+
pad = lambda d: d * (kernel_size // 2)
|
| 395 |
+
|
| 396 |
+
self.dw_B = nn.ModuleList([
|
| 397 |
+
nn.Conv1d(h, h, kernel_size, padding=pad(d), dilation=d, groups=h, bias=False)
|
| 398 |
+
for d in dilations
|
| 399 |
+
])
|
| 400 |
+
self.mix_B = nn.Conv1d(h * len(dilations), h, 1)
|
| 401 |
+
|
| 402 |
+
self.dw_A = nn.ModuleList([
|
| 403 |
+
nn.Conv1d(h, h, kernel_size, padding=pad(d), dilation=d, groups=h, bias=False)
|
| 404 |
+
for d in dilations
|
| 405 |
+
])
|
| 406 |
+
self.mix_A = nn.Conv1d(h * len(dilations), h, 1)
|
| 407 |
+
|
| 408 |
+
self.proj_B2A = nn.Linear(h, h)
|
| 409 |
+
self.proj_A2B = nn.Linear(h, h)
|
| 410 |
+
|
| 411 |
+
self.use_global_token = use_global_token
|
| 412 |
+
if use_global_token:
|
| 413 |
+
self.glb_B2A = nn.Linear(h, h)
|
| 414 |
+
self.glb_A2B = nn.Linear(h, h)
|
| 415 |
+
|
| 416 |
+
self.gate = nn.Linear(h * 4, h * 2)
|
| 417 |
+
self.dropout = nn.Dropout(dropout)
|
| 418 |
+
self.normA = nn.LayerNorm(h)
|
| 419 |
+
self.normB = nn.LayerNorm(h)
|
| 420 |
+
|
| 421 |
+
@staticmethod
|
| 422 |
+
def _agg(x: torch.Tensor, branches: nn.ModuleList, mix: nn.Module) -> torch.Tensor:
|
| 423 |
+
xch = x.transpose(1, 2)
|
| 424 |
+
ys = [conv(xch) for conv in branches]
|
| 425 |
+
y = torch.cat(ys, dim=1)
|
| 426 |
+
y = mix(y).transpose(1, 2).contiguous()
|
| 427 |
+
return y
|
| 428 |
+
|
| 429 |
+
def forward(self, xA: torch.Tensor, xB: torch.Tensor):
|
| 430 |
+
ctxB = self._agg(xB, self.dw_B, self.mix_B)
|
| 431 |
+
ctxA = self._agg(xA, self.dw_A, self.mix_A)
|
| 432 |
+
|
| 433 |
+
locA = self.proj_B2A(xB + ctxB)
|
| 434 |
+
locB = self.proj_A2B(xA + ctxA)
|
| 435 |
+
|
| 436 |
+
if self.use_global_token:
|
| 437 |
+
gB = self.glb_B2A(xB.mean(dim=1, keepdim=True))
|
| 438 |
+
gA = self.glb_A2B(xA.mean(dim=1, keepdim=True))
|
| 439 |
+
locA = locA + gB.expand(-1, xA.size(1), -1)
|
| 440 |
+
locB = locB + gA.expand(-1, xB.size(1), -1)
|
| 441 |
+
|
| 442 |
+
z = torch.cat([xA, xB, xA - xB, xA * xB], dim=-1)
|
| 443 |
+
gA, gB = self.gate(z).chunk(2, dim=-1)
|
| 444 |
+
gA = torch.sigmoid(gA)
|
| 445 |
+
gB = torch.sigmoid(gB)
|
| 446 |
+
|
| 447 |
+
yA = self.normA(xA + self.dropout(gA * locA))
|
| 448 |
+
yB = self.normB(xB + self.dropout(gB * locB))
|
| 449 |
+
return yA, yB
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def semantic_preservation_loss(R_plus: torch.Tensor, H_S_plus: torch.Tensor,
|
| 453 |
+
λ_recon: float = 1.0, λ_local: float = 0.5, λ_global: float = 0.2):
|
| 454 |
+
recon = F.mse_loss(H_S_plus, R_plus)
|
| 455 |
+
|
| 456 |
+
if R_plus.size(1) >= 2:
|
| 457 |
+
d_ref = R_plus[:, 1:] - R_plus[:, :-1]
|
| 458 |
+
d_S = H_S_plus[:, 1:] - H_S_plus[:, :-1]
|
| 459 |
+
local = F.mse_loss(d_S, d_ref)
|
| 460 |
+
else:
|
| 461 |
+
local = torch.tensor(0.0, device=R_plus.device)
|
| 462 |
+
|
| 463 |
+
def gram_norm(x):
|
| 464 |
+
G = torch.einsum("b i d, b j d -> b i j", x, x)
|
| 465 |
+
return G / (G.norm(dim=(1, 2), keepdim=True) + 1e-6)
|
| 466 |
+
|
| 467 |
+
glob = F.mse_loss(gram_norm(H_S_plus), gram_norm(R_plus))
|
| 468 |
+
return λ_recon * recon + λ_local * local + λ_global * glob
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
@contextmanager
def eval_mode(*modules):
    """Temporarily switch the given modules to eval(); restore on exit.

    ``None`` entries are allowed and skipped, so callers can pass optional
    submodules (e.g. a possibly-absent bridge) without filtering. Each
    module's original ``training`` flag is restored even if the body raises.
    """
    # BUGFIX: guard against None here too. The original read `m.training`
    # unconditionally, raising AttributeError for a None placeholder even
    # though the enter/exit loops below explicitly tolerate None.
    states = [m.training if m is not None else None for m in modules]
    try:
        for m in modules:
            if m is not None:
                m.eval()
        yield
    finally:
        for m, s in zip(modules, states):
            if m is not None:
                m.train(s)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class SSScanDNAHybridModel(nn.Module):
|
| 486 |
+
"""
|
| 487 |
+
Latest training-structure CrossDNA backbone.
|
| 488 |
+
|
| 489 |
+
HF wrapper strategy:
|
| 490 |
+
- backbone remains compact 5-token A/C/G/T/N model
|
| 491 |
+
- HF wrapper maps tokenizer ids -> compact ids
|
| 492 |
+
- HF wrapper expands 5-way logits -> tokenizer vocab logits
|
| 493 |
+
"""
|
| 494 |
+
|
| 495 |
+
def __init__(
    self,
    config: Optional[Any] = None,
    alphabet_size=5,
    d_model=128,
    block_size=2048,
    comba_cfg=None,
    transformer_cfg=None,
    depth=4,
    drop_path_rates=None,
    pretrain=False,
    for_representation=False,
    use_final_conv=False,
    use_s_scan: bool = True,
    use_mem: bool = False,
    use_rc_kl: bool = False,
    use_barlow: bool = False,
    use_tv: bool = False,
    sem_max_weight: float = 0.2,
    sem_warmup_steps: int = 3000,
    rc_max_weight: float = 0.2,
    rc_warmup_steps: int = 2000,
    rc_tau: float = 1.5,
    rc_bidirectional_stopgrad: bool = True,
    aux_ce_weight: float = 0.1,
    gate_freeze_steps: int = 1000,
    detach_gate: bool = False,
    gate_sup_weight: float = 0.005,
    gate_sup_warmup_steps: int = 500,
    gate_temp: float = 2.0,
    dropout=0.1,
    use_ema_teacher: bool = True,
    ema_decay: float = 0.999,
    auto_update_ema_in_forward: bool = True,
    use_bridge: bool = True,
    bridge_dropout: float = 0.0,
    use_checkpointing: bool = True,
    checkpoint_chunk_size: int = 2,
    checkpoint_core_layers: bool = False,
    core_checkpoint_chunk_size: int = 1,
    return_ab_logits: bool = True,
    streaming_loss: bool = True,
    streaming_report_ab: bool = True,
    disable_cross_view: bool = False,
    **unused_kwargs,
):
    """Build the dual-branch CrossDNA backbone.

    Every keyword has a sensible default; when ``config`` is provided, any
    value present in it overrides the corresponding keyword argument.
    ``**unused_kwargs`` absorbs unknown config keys for forward compatibility
    (stashed in ``self._unused_init_kwargs``).
    """
    super().__init__()

    self.config = config
    # Config values take precedence over keyword defaults.
    if config is not None:
        cfg = _to_plain_container(config)

        alphabet_size = _cfg_get(cfg, "alphabet_size", alphabet_size)
        d_model = _cfg_get(cfg, "d_model", d_model)
        block_size = _cfg_get(cfg, "block_size", block_size)
        depth = _cfg_get(cfg, "depth", depth)
        drop_path_rates = _cfg_get(cfg, "drop_path_rates", drop_path_rates)

        pretrain = _cfg_get(cfg, "pretrain", pretrain)
        for_representation = _cfg_get(cfg, "for_representation", for_representation)
        use_final_conv = _cfg_get(cfg, "use_final_conv", use_final_conv)

        use_s_scan = _cfg_get(cfg, "use_s_scan", use_s_scan)
        use_mem = _cfg_get(cfg, "use_mem", use_mem)
        use_rc_kl = _cfg_get(cfg, "use_rc_kl", use_rc_kl)
        use_barlow = _cfg_get(cfg, "use_barlow", use_barlow)
        use_tv = _cfg_get(cfg, "use_tv", use_tv)

        sem_max_weight = _cfg_get(cfg, "sem_max_weight", sem_max_weight)
        sem_warmup_steps = _cfg_get(cfg, "sem_warmup_steps", sem_warmup_steps)
        rc_max_weight = _cfg_get(cfg, "rc_max_weight", rc_max_weight)
        rc_warmup_steps = _cfg_get(cfg, "rc_warmup_steps", rc_warmup_steps)
        rc_tau = _cfg_get(cfg, "rc_tau", rc_tau)
        rc_bidirectional_stopgrad = _cfg_get(cfg, "rc_bidirectional_stopgrad", rc_bidirectional_stopgrad)

        aux_ce_weight = _cfg_get(cfg, "aux_ce_weight", aux_ce_weight)
        gate_freeze_steps = _cfg_get(cfg, "gate_freeze_steps", gate_freeze_steps)
        detach_gate = _cfg_get(cfg, "detach_gate", detach_gate)
        gate_sup_weight = _cfg_get(cfg, "gate_sup_weight", gate_sup_weight)
        gate_sup_warmup_steps = _cfg_get(cfg, "gate_sup_warmup_steps", gate_sup_warmup_steps)
        gate_temp = _cfg_get(cfg, "gate_temp", gate_temp)
        dropout = _cfg_get(cfg, "dropout", dropout)

        use_bridge = _cfg_get(cfg, "use_bridge", use_bridge)
        bridge_dropout = _cfg_get(cfg, "bridge_dropout", bridge_dropout)

        use_checkpointing = _cfg_get(cfg, "use_checkpointing", use_checkpointing)
        checkpoint_chunk_size = _cfg_get(cfg, "checkpoint_chunk_size", checkpoint_chunk_size)

        checkpoint_core_layers = _cfg_get(cfg, "checkpoint_core_layers", checkpoint_core_layers)
        core_checkpoint_chunk_size = _cfg_get(cfg, "core_checkpoint_chunk_size", core_checkpoint_chunk_size)

        return_ab_logits = _cfg_get(cfg, "return_ab_logits", return_ab_logits)
        streaming_loss = _cfg_get(cfg, "streaming_loss", streaming_loss)
        streaming_report_ab = _cfg_get(cfg, "streaming_report_ab", streaming_report_ab)

        use_ema_teacher = _cfg_get(cfg, "use_ema_teacher", use_ema_teacher)
        ema_decay = _cfg_get(cfg, "ema_decay", ema_decay)
        auto_update_ema_in_forward = _cfg_get(cfg, "auto_update_ema_in_forward", auto_update_ema_in_forward)
        disable_cross_view = _cfg_get(cfg, "disable_cross_view", disable_cross_view)

        transformer_cfg = _cfg_get(cfg, "transformer_cfg", transformer_cfg)
        comba_cfg = _cfg_get(cfg, "comba_cfg", comba_cfg)

    transformer_cfg = _to_plain_container(transformer_cfg)
    comba_cfg = _to_plain_container(comba_cfg)
    drop_path_rates = _to_plain_container(drop_path_rates)

    # Normalize and store scalar hyperparameters.
    self.alphabet_size = int(alphabet_size)
    self.pretrain = bool(pretrain)
    self.for_representation = bool(for_representation)
    self.block_size = int(block_size)
    self.use_final_conv = bool(use_final_conv)
    self.d_model = int(d_model)

    self.use_checkpointing = bool(use_checkpointing)
    self.checkpoint_chunk_size = int(checkpoint_chunk_size)

    self.checkpoint_core_layers = bool(checkpoint_core_layers)
    self.core_checkpoint_chunk_size = int(core_checkpoint_chunk_size)

    self.return_ab_logits = bool(return_ab_logits)
    self.streaming_loss = bool(streaming_loss)
    self.streaming_report_ab = bool(streaming_report_ab)

    self.use_ema_teacher = bool(use_ema_teacher)
    self.ema_decay = float(ema_decay)
    self.auto_update_ema_in_forward = bool(auto_update_ema_in_forward)
    self.disable_cross_view = bool(disable_cross_view)

    # Global training-step counter; a buffer so it survives checkpoints.
    self.register_buffer("g_step", torch.zeros(1, dtype=torch.long))

    # One-hot -> d_model stem convolutions (forward and reverse-complement).
    self.linear = nn.Conv1d(self.alphabet_size, self.d_model, kernel_size=9, padding=4)
    self.rc_linear = nn.Conv1d(self.alphabet_size, self.d_model, kernel_size=9, padding=4)

    # Learned embeddings substituted for MLM-masked / special tokens.
    self.mlm_mask_embed = nn.Parameter(torch.zeros(self.d_model))
    self.special_token_embed = nn.Parameter(torch.zeros(self.d_model))
    nn.init.normal_(self.mlm_mask_embed, mean=0.0, std=0.02)
    nn.init.normal_(self.special_token_embed, mean=0.0, std=0.02)

    # Two parallel strand branches with identical architecture.
    self.branchA_core = DeepEnhancedBranch(
        hidden_dim=self.d_model,
        comba_cfg=comba_cfg,
        transformer_cfg=transformer_cfg,
        depth=int(depth),
        drop_path_rates=drop_path_rates,
        checkpoint_core_layers=self.checkpoint_core_layers,
        core_checkpoint_chunk_size=self.core_checkpoint_chunk_size,
    )
    self.branchB_core = DeepEnhancedBranch(
        hidden_dim=self.d_model,
        comba_cfg=comba_cfg,
        transformer_cfg=transformer_cfg,
        depth=int(depth),
        drop_path_rates=drop_path_rates,
        checkpoint_core_layers=self.checkpoint_core_layers,
        core_checkpoint_chunk_size=self.core_checkpoint_chunk_size,
    )

    self.use_bridge = bool(use_bridge)
    if self.use_bridge:
        self.bridge = TokenBridge(self.d_model, dropout=float(bridge_dropout))
    else:
        self.bridge = None

    # Frozen EMA copies act as distillation teachers when enabled.
    if self.use_ema_teacher:
        self.branchA_core_ema = copy.deepcopy(self.branchA_core)
        self.branchB_core_ema = copy.deepcopy(self.branchB_core)
        for p in self.branchA_core_ema.parameters():
            p.requires_grad_(False)
        for p in self.branchB_core_ema.parameters():
            p.requires_grad_(False)

        if self.use_bridge:
            self.bridge_ema = copy.deepcopy(self.bridge)
            for p in self.bridge_ema.parameters():
                p.requires_grad_(False)
        else:
            self.bridge_ema = None
    else:
        self.branchA_core_ema = None
        self.branchB_core_ema = None
        self.bridge_ema = None

    # Per-branch projections, branch-fusion gate, and output head.
    self.proj_A = Mlp(self.d_model, self.d_model * 2, self.d_model, activation=F.gelu, return_residual=True)
    self.proj_B = Mlp(self.d_model, self.d_model * 2, self.d_model, activation=F.gelu, return_residual=True)
    self.gate_fuse = nn.Linear(2 * self.d_model, self.d_model)
    self.out_linear = nn.Linear(self.d_model, self.alphabet_size)
    self.dropout = nn.Dropout(float(dropout))

    # Complement permutation matrix used by the RC-consistency losses.
    P_comp, _ = make_complement_perm(self.alphabet_size)
    self.register_buffer("P_comp", P_comp)

    self.use_s_scan = bool(use_s_scan)
    self.use_mem = bool(use_mem)
    self.use_rc_kl = bool(use_rc_kl)
    self.use_barlow = bool(use_barlow)
    self.use_tv = bool(use_tv)

    self.sem_max_weight = float(sem_max_weight)
    self.sem_warmup_steps = int(sem_warmup_steps)
    self.rc_max_weight = float(rc_max_weight)
    self.rc_warmup_steps = int(rc_warmup_steps)
    self.rc_tau = float(rc_tau)
    self.rc_bidirectional_stopgrad = bool(rc_bidirectional_stopgrad)

    self.aux_ce_weight = float(aux_ce_weight)
    self.gate_freeze_steps = int(gate_freeze_steps)
    self.detach_gate = bool(detach_gate)
    self.gate_sup_weight = float(gate_sup_weight)
    self.gate_sup_warmup_steps = int(gate_sup_warmup_steps)
    self.gate_temp = float(gate_temp)

    if self.use_final_conv:
        self.final_conv = nn.Conv1d(self.d_model, self.d_model, kernel_size=3, padding=1)

    # Keep unknown kwargs around for debugging / forward compatibility.
    self._unused_init_kwargs = dict(unused_kwargs) if unused_kwargs else {}
|
| 712 |
+
|
| 713 |
+
def _branch_receives_rc(self, branch: str, block_idx: int) -> bool:
|
| 714 |
+
if self.disable_cross_view:
|
| 715 |
+
return branch == "B"
|
| 716 |
+
|
| 717 |
+
if branch == "A":
|
| 718 |
+
return (block_idx % 2) == 1
|
| 719 |
+
if branch == "B":
|
| 720 |
+
return (block_idx % 2) == 0
|
| 721 |
+
raise ValueError(f"Unknown branch: {branch}")
|
| 722 |
+
|
| 723 |
+
def _route_block_inputs(self, t_block: int, fwd_in: torch.Tensor, rc_in: torch.Tensor):
|
| 724 |
+
if self._branch_receives_rc("A", t_block):
|
| 725 |
+
return rc_in, fwd_in
|
| 726 |
+
return fwd_in, rc_in
|
| 727 |
+
|
| 728 |
+
def _realign_chunk_outputs(self, H_A: torch.Tensor, H_B: torch.Tensor, chunk_start: int):
|
| 729 |
+
num_blocks = H_A.size(0)
|
| 730 |
+
for c in range(num_blocks):
|
| 731 |
+
actual_t = chunk_start + c
|
| 732 |
+
if self._branch_receives_rc("A", actual_t):
|
| 733 |
+
H_A[c] = torch.flip(H_A[c], dims=[1])
|
| 734 |
+
if self._branch_receives_rc("B", actual_t):
|
| 735 |
+
H_B[c] = torch.flip(H_B[c], dims=[1])
|
| 736 |
+
return H_A, H_B
|
| 737 |
+
|
| 738 |
+
def _build_chunk_rc_masks(self, chunk_start: int, num_blocks: int, B: int, L_blk: int, device):
|
| 739 |
+
rcA = torch.tensor(
|
| 740 |
+
[self._branch_receives_rc("A", chunk_start + c) for c in range(num_blocks)],
|
| 741 |
+
device=device,
|
| 742 |
+
dtype=torch.bool,
|
| 743 |
+
)
|
| 744 |
+
maskA_row = rcA.repeat_interleave(B).unsqueeze(1)
|
| 745 |
+
maskA = maskA_row.expand(-1, L_blk)
|
| 746 |
+
maskB = ~maskA
|
| 747 |
+
return maskA, maskB
|
| 748 |
+
|
| 749 |
+
@torch.no_grad()
def update_ema(self):
    """Exponentially average student weights into the frozen EMA teachers.

    Each teacher parameter moves toward its student counterpart by a factor
    of (1 - ema_decay). No-op when EMA teachers are disabled or absent.
    """
    if not self.use_ema_teacher:
        return
    if self.branchA_core_ema is None or self.branchB_core_ema is None:
        return

    d = float(self.ema_decay)
    for m_ema, m in [
        (self.branchA_core_ema, self.branchA_core),
        (self.branchB_core_ema, self.branchB_core),
    ]:
        # zip relies on identical parameter ordering between student and its
        # deepcopy-created EMA twin.
        for p_ema, p in zip(m_ema.parameters(), m.parameters()):
            p_ema.data.lerp_(p.data, 1.0 - d)

    if self.use_bridge and (self.bridge is not None) and (self.bridge_ema is not None):
        for p_ema, p in zip(self.bridge_ema.parameters(), self.bridge.parameters()):
            p_ema.data.lerp_(p.data, 1.0 - d)
|
| 767 |
+
|
| 768 |
+
def _forward_s_scan_chunk_streaming(
    self,
    X_A: torch.Tensor,
    X_B: torch.Tensor,
    A_emb_fwd: torch.Tensor,
    B_emb_rc: torch.Tensor,
    mlm_labels: torch.Tensor,
    chunk_start_t: torch.Tensor,
    num_blocks_t: torch.Tensor,
    step_t: torch.Tensor,
    report_ab_t: torch.Tensor,
):
    """Process one chunk of sequence blocks and return streaming loss stats.

    Runs both branch cores on their routed views, realigns RC outputs,
    bridges/projects/gates them into fused features, and accumulates the
    MLM cross-entropy (as a sum) plus optional auxiliary losses.

    NOTE(review): scalar controls arrive as tensors (chunk_start_t, ...) —
    presumably so this function can be wrapped by gradient checkpointing,
    which only forwards tensor arguments; confirm against the caller.

    Returns:
        (ce_sum, n_masked, total_aux, correct1, correct3,
         correctA1, correctB1, correctA3, correctB3)
    """
    chunk_start = int(chunk_start_t.item())
    num_blocks = int(num_blocks_t.item())
    step = int(step_t.item())
    report_ab = bool(int(report_ab_t.item()))

    # X_A/X_B are flattened blocks: BC = num_blocks * batch.
    BC, L_blk, H = X_A.shape
    B = BC // max(1, num_blocks)
    device = X_A.device

    H_A = self.branchA_core(X_A)
    H_B = self.branchB_core(X_B)

    # Un-flatten, flip RC-view blocks back to forward orientation, re-flatten.
    H_A = H_A.view(num_blocks, B, L_blk, H)
    H_B = H_B.view(num_blocks, B, L_blk, H)
    H_A, H_B = self._realign_chunk_outputs(H_A, H_B, chunk_start)
    H_A = H_A.reshape(BC, L_blk, H)
    H_B = H_B.reshape(BC, L_blk, H)

    if self.use_bridge and self.bridge is not None:
        H_A, H_B = self.bridge(H_A, H_B)

    # proj_* return (output, residual input); sum gives residual features.
    fA, rA = self.proj_A(H_A)
    FA = fA + rA
    fB, rB = self.proj_B(H_B)
    FB = fB + rB

    # Token-wise gate fusing the two branch features.
    gate_in = torch.cat([FA, FB], dim=-1)
    g_logits = self.gate_fuse(gate_in)
    g_raw = torch.sigmoid(g_logits / max(1e-6, self.gate_temp))

    # Early in training the gate is frozen at an even 50/50 blend.
    if step < self.gate_freeze_steps:
        g = 0.5 * torch.ones_like(g_raw)
    else:
        g = g_raw

    if self.detach_gate:
        mix = g.detach() * FA + (1 - g.detach()) * FB
    else:
        mix = g * FA + (1 - g) * FB

    fused = F.layer_norm(mix, (mix.size(-1),))
    fused = ensure_finite(fused, "fused_blk")

    if self.use_final_conv:
        fused = self.final_conv(fused.permute(0, 2, 1)).permute(0, 2, 1)

    logits = self.out_linear(fused)
    C = logits.size(-1)

    logits2d = logits.reshape(-1, C)
    labels1d = mlm_labels.reshape(-1)

    # Sum (not mean) so the caller can normalize across chunks itself.
    ce_sum = F.cross_entropy(logits2d, labels1d, ignore_index=-100, reduction="sum")

    with torch.no_grad():
        valid = (labels1d != -100)
        n_masked = valid.sum()

    # Top-1 / top-3 accuracy counters over masked positions.
    with torch.no_grad():
        correct1 = torch.zeros([], device=device, dtype=torch.long)
        correct3 = torch.zeros([], device=device, dtype=torch.long)
        if n_masked.item() > 0:
            sel_logits = logits2d[valid]
            sel_labels = labels1d[valid]
            pred1 = sel_logits.argmax(dim=-1)
            correct1 = pred1.eq(sel_labels).sum()
            top3 = sel_logits.topk(3, dim=-1).indices
            correct3 = top3.eq(sel_labels.unsqueeze(-1)).any(dim=-1).sum()

    total_aux = torch.zeros([], device=device, dtype=torch.float32)

    if self.pretrain:
        maskA, maskB = self._build_chunk_rc_masks(chunk_start, num_blocks, B, L_blk, device)

        # Semantic-preservation distillation against (EMA) teacher features.
        need_sem = self.sem_max_weight > 0.0
        if need_sem:
            with torch.no_grad():
                teacherA = self.branchA_core_ema if self.use_ema_teacher else self.branchA_core
                teacherB = self.branchB_core_ema if self.use_ema_teacher else self.branchB_core
                tbridge = self.bridge_ema if (self.use_bridge and self.use_ema_teacher and self.bridge_ema is not None) else (
                    self.bridge if self.use_bridge else None
                )

                mods = [teacherA, teacherB] + ([tbridge] if tbridge is not None else [])
                # Teachers are run in eval() so dropout/BN do not perturb targets.
                with eval_mode(*mods):
                    R_plus_A = teacherA(A_emb_fwd)
                    R_plus_B = teacherB(A_emb_fwd)
                    if tbridge is not None:
                        R_plus_A, R_plus_B = tbridge(R_plus_A, R_plus_B)

                    R_minus_A_rc = teacherA(B_emb_rc)
                    R_minus_B_rc = teacherB(B_emb_rc)
                    # Flip RC-view teacher outputs into forward orientation.
                    R_minus_A_fwd = torch.flip(R_minus_A_rc, dims=[1])
                    R_minus_B_fwd = torch.flip(R_minus_B_rc, dims=[1])
                    if tbridge is not None:
                        R_minus_A_fwd, R_minus_B_fwd = tbridge(R_minus_A_fwd, R_minus_B_fwd)

                # Per token, pick the teacher view matching what the student saw.
                R_A_teacher = torch.where(maskA.unsqueeze(-1), R_minus_A_fwd, R_plus_A)
                R_B_teacher = torch.where(maskB.unsqueeze(-1), R_minus_B_fwd, R_plus_B)

            sem_A = semantic_preservation_loss(R_A_teacher.float(), FA.float())
            sem_B = semantic_preservation_loss(R_B_teacher.float(), FB.float())
            w_sem = linear_warmup_weight(step, self.sem_warmup_steps, self.sem_max_weight)
            total_aux = total_aux + w_sem * (sem_A + sem_B)

        # Weak supervision nudging the gate toward the forward-view branch.
        if (self.gate_sup_weight > 0.0) and (step >= self.gate_freeze_steps):
            g_target = (~maskA).float().unsqueeze(-1)
            g_token_logits = g_logits.mean(dim=-1, keepdim=True) / max(1e-6, self.gate_temp)
            w_gate = linear_warmup_weight(
                step - self.gate_freeze_steps,
                self.gate_sup_warmup_steps,
                self.gate_sup_weight,
            )
            total_aux = total_aux + w_gate * F.binary_cross_entropy_with_logits(
                g_token_logits.float(), g_target.float()
            )

        need_rc = bool(self.use_rc_kl and (self.rc_max_weight > 0.0))
        need_ab = bool(report_ab or need_rc)

        # Per-branch logits are only materialized if something needs them.
        logitsA = logitsB = None
        if need_ab:
            logitsA = self.out_linear(FA)
            logitsB = self.out_linear(FB)

        if need_rc:
            if self.rc_bidirectional_stopgrad:
                rc = rc_consistency_bidirectional_stopgrad(logitsA, logitsB, self.P_comp, tau=self.rc_tau)
            else:
                rc = rc_consistency_kl(logitsA, logitsB, self.P_comp, tau=self.rc_tau)
            w_rc = linear_warmup_weight(step, self.rc_warmup_steps, self.rc_max_weight)
            total_aux = total_aux + w_rc * rc

        if self.use_barlow:
            total_aux = total_aux + barlow_strand_loss_v2(H_A.float(), H_B.float())
        if self.use_tv:
            total_aux = total_aux + tv_mixed(fused.float())

    # Optional per-branch accuracy reporting (labels complemented where a
    # branch processed the RC view).
    with torch.no_grad():
        correctA1 = torch.zeros([], device=device, dtype=torch.long)
        correctB1 = torch.zeros([], device=device, dtype=torch.long)
        correctA3 = torch.zeros([], device=device, dtype=torch.long)
        correctB3 = torch.zeros([], device=device, dtype=torch.long)

        if report_ab and n_masked.item() > 0:
            # logitsA may not exist when pretrain=False; compute lazily.
            if ('logitsA' not in locals()) or (logitsA is None):
                logitsA = self.out_linear(FA)
                logitsB = self.out_linear(FB)

            _, perm = make_complement_perm(C, device=device)

            valid = (labels1d != -100)
            # clamp_min(0) makes ignored (-100) labels safe to index with.
            labels_safe = labels1d.clamp_min(0)
            labels_comp = perm[labels_safe]

            maskA_tok, maskB_tok = self._build_chunk_rc_masks(chunk_start, num_blocks, B, L_blk, device)
            maskA_tok = maskA_tok.reshape(-1)
            maskB_tok = maskB_tok.reshape(-1)

            yA = torch.where(maskA_tok, labels_comp, labels_safe)[valid]
            yB = torch.where(maskB_tok, labels_comp, labels_safe)[valid]

            A2d = logitsA.reshape(-1, C)[valid]
            B2d = logitsB.reshape(-1, C)[valid]

            predA1 = A2d.argmax(dim=-1)
            predB1 = B2d.argmax(dim=-1)
            correctA1 = predA1.eq(yA).sum()
            correctB1 = predB1.eq(yB).sum()

            topA3 = A2d.topk(3, dim=-1).indices
            topB3 = B2d.topk(3, dim=-1).indices
            correctA3 = topA3.eq(yA.unsqueeze(-1)).any(dim=-1).sum()
            correctB3 = topB3.eq(yB.unsqueeze(-1)).any(dim=-1).sum()

    return ce_sum, n_masked, total_aux, correct1, correct3, correctA1, correctB1, correctA3, correctB3
|
| 956 |
+
|
| 957 |
+
    def _forward_s_scan_chunk(
        self,
        X_A: torch.Tensor,
        X_B: torch.Tensor,
        A_emb_fwd: torch.Tensor,
        B_emb_rc: torch.Tensor,
        chunk_start_t: torch.Tensor,
        num_blocks_t: torch.Tensor,
        step_t: torch.Tensor,
        need_logits_t: torch.Tensor,
        need_ab_t: torch.Tensor,
    ):
        """Process one chunk of sequence blocks through both strand branches,
        gate-fuse them, and accumulate the auxiliary pretraining losses.

        All scalar controls arrive as 0-dim tensors so the function can be
        wrapped by ``torch.utils.checkpoint`` (which only threads tensors
        through the recomputation).

        Args:
            X_A / X_B: branch inputs, shape (B * num_blocks, L_blk, H).
            A_emb_fwd: dropout-free forward-strand embeddings (teacher input).
            B_emb_rc: dropout-free reverse-complement embeddings (teacher input).
            chunk_start_t: index of the first block in this chunk.
            num_blocks_t: number of blocks batched along the leading dim.
            step_t: global training step (drives warmups / gate freezing).
            need_logits_t / need_ab_t: whether fused / per-branch logits must
                be returned.

        Returns:
            ``(fused_blk, logits_blk, logitsA_blk, logitsB_blk, total_aux_blk)``.
            Logits that were not requested come back as empty tensors so the
            output structure stays checkpoint-compatible.
        """
        # Decode tensor-encoded scalars (checkpoint calling convention).
        chunk_start = int(chunk_start_t.item())
        num_blocks = int(num_blocks_t.item())
        step = int(step_t.item())
        need_logits = bool(int(need_logits_t.item()))
        need_ab = bool(int(need_ab_t.item()))

        BC, L_blk, H = X_A.shape
        B = BC // max(1, num_blocks)  # recover the true batch size
        device = X_A.device

        # Strand-specific cores over the whole chunk at once.
        H_A = self.branchA_core(X_A)
        H_B = self.branchB_core(X_B)

        # Undo per-block strand routing so both tensors are in forward coords.
        H_A = H_A.view(num_blocks, B, L_blk, H)
        H_B = H_B.view(num_blocks, B, L_blk, H)
        H_A, H_B = self._realign_chunk_outputs(H_A, H_B, chunk_start)
        H_A = H_A.reshape(BC, L_blk, H)
        H_B = H_B.reshape(BC, L_blk, H)

        if self.use_bridge and self.bridge is not None:
            H_A, H_B = self.bridge(H_A, H_B)

        # Each projection yields two components that are summed into the
        # branch feature used downstream.
        fA, rA = self.proj_A(H_A)
        FA = fA + rA
        fB, rB = self.proj_B(H_B)
        FB = fB + rB

        # Learned per-feature gate between the two branch features.
        gate_in_blk = torch.cat([FA, FB], dim=-1)
        g_logits_blk = self.gate_fuse(gate_in_blk)
        g_raw_blk = torch.sigmoid(g_logits_blk / max(1e-6, self.gate_temp))

        # Gate is frozen to an even 50/50 mix for the first steps.
        if step < self.gate_freeze_steps:
            g_blk = 0.5 * torch.ones_like(g_raw_blk)
        else:
            g_blk = g_raw_blk

        # Optionally block gradients flowing through the gate itself.
        if self.detach_gate:
            mix_blk = g_blk.detach() * FA + (1 - g_blk.detach()) * FB
        else:
            mix_blk = g_blk * FA + (1 - g_blk) * FB

        fused_blk = F.layer_norm(mix_blk, (mix_blk.size(-1),))
        fused_blk = ensure_finite(fused_blk, "fused_blk")

        if self.use_final_conv:
            fused_blk = self.final_conv(fused_blk.permute(0, 2, 1)).permute(0, 2, 1)

        logits_blk = self.out_linear(fused_blk) if need_logits else fused_blk.new_empty((0,))

        # RC-consistency loss needs per-branch logits even if the caller
        # did not explicitly request them.
        need_rc_logits = bool(self.use_rc_kl and (self.rc_max_weight > 0.0))
        need_ab_internal = bool(need_ab or need_rc_logits)

        logitsA_blk = self.out_linear(FA) if need_ab_internal else fused_blk.new_empty((0,))
        logitsB_blk = self.out_linear(FB) if need_ab_internal else fused_blk.new_empty((0,))

        total_aux_blk = torch.zeros([], device=device, dtype=torch.float32)

        if self.pretrain:
            # Per-token flags: did each branch receive the RC strand here?
            maskA, maskB = self._build_chunk_rc_masks(chunk_start, num_blocks, B, L_blk, device)

            # --- semantic-preservation loss against a (possibly EMA) teacher ---
            need_sem = self.sem_max_weight > 0.0
            if need_sem:
                with torch.no_grad():
                    teacherA = self.branchA_core_ema if self.use_ema_teacher else self.branchA_core
                    teacherB = self.branchB_core_ema if self.use_ema_teacher else self.branchB_core
                    tbridge = self.bridge_ema if (self.use_bridge and self.use_ema_teacher and self.bridge_ema is not None) else (
                        self.bridge if self.use_bridge else None
                    )

                    mods = [teacherA, teacherB] + ([tbridge] if tbridge is not None else [])
                    with eval_mode(*mods):
                        # NOTE(review): both teachers consume A_emb_fwd on the
                        # plus strand (and B_emb_rc on the minus strand) —
                        # presumably intentional, confirm against training code.
                        R_plus_A = teacherA(A_emb_fwd)
                        R_plus_B = teacherB(A_emb_fwd)
                        if tbridge is not None:
                            R_plus_A, R_plus_B = tbridge(R_plus_A, R_plus_B)

                        R_minus_A_rc = teacherA(B_emb_rc)
                        R_minus_B_rc = teacherB(B_emb_rc)
                        # Flip RC outputs back into forward coordinates.
                        R_minus_A_fwd = torch.flip(R_minus_A_rc, dims=[1])
                        R_minus_B_fwd = torch.flip(R_minus_B_rc, dims=[1])
                        if tbridge is not None:
                            R_minus_A_fwd, R_minus_B_fwd = tbridge(R_minus_A_fwd, R_minus_B_fwd)

                    # Pick the teacher view matching the strand each branch saw.
                    R_A_teacher = torch.where(maskA.unsqueeze(-1), R_minus_A_fwd, R_plus_A)
                    R_B_teacher = torch.where(maskB.unsqueeze(-1), R_minus_B_fwd, R_plus_B)

                sem_A = semantic_preservation_loss(R_A_teacher.float(), FA.float())
                sem_B = semantic_preservation_loss(R_B_teacher.float(), FB.float())
                w_sem = linear_warmup_weight(step, self.sem_warmup_steps, self.sem_max_weight)
                total_aux_blk = total_aux_blk + w_sem * (sem_A + sem_B)

            # --- gate supervision: push the gate toward the forward branch ---
            if (self.gate_sup_weight > 0.0) and (step >= self.gate_freeze_steps):
                g_target_blk = (~maskA).float().unsqueeze(-1)
                g_token_logits_blk = g_logits_blk.mean(dim=-1, keepdim=True) / max(1e-6, self.gate_temp)
                w_gate = linear_warmup_weight(
                    step - self.gate_freeze_steps,
                    self.gate_sup_warmup_steps,
                    self.gate_sup_weight,
                )
                total_aux_blk = total_aux_blk + w_gate * F.binary_cross_entropy_with_logits(
                    g_token_logits_blk.float(),
                    g_target_blk.float(),
                )

            # --- reverse-complement consistency (KL) between branch logits ---
            if self.use_rc_kl and (self.rc_max_weight > 0.0):
                if logitsA_blk.numel() == 0:
                    logitsA_blk = self.out_linear(FA)
                if logitsB_blk.numel() == 0:
                    logitsB_blk = self.out_linear(FB)

                if self.rc_bidirectional_stopgrad:
                    rc = rc_consistency_bidirectional_stopgrad(logitsA_blk, logitsB_blk, self.P_comp, tau=self.rc_tau)
                else:
                    rc = rc_consistency_kl(logitsA_blk, logitsB_blk, self.P_comp, tau=self.rc_tau)

                w_rc = linear_warmup_weight(step, self.rc_warmup_steps, self.rc_max_weight)
                total_aux_blk = total_aux_blk + w_rc * rc

            # --- redundancy-reduction and total-variation regularizers ---
            if self.use_barlow:
                total_aux_blk = total_aux_blk + barlow_strand_loss_v2(H_A.float(), H_B.float())
            if self.use_tv:
                total_aux_blk = total_aux_blk + tv_mixed(fused_blk.float())

        return fused_blk, logits_blk, logitsA_blk, logitsB_blk, total_aux_blk
|
| 1093 |
+
|
| 1094 |
+
    def forward(self, seq, t=None, cls=None, return_embedding=False, state=None, mask=None, **kwargs):
        """Forward pass of the dual-strand hybrid backbone.

        Args:
            seq: token ids of shape (B, L). In pretrain mode this may instead
                be a tuple/list ``(ids, mlm_mask, mlm_labels, special_mask)``.
            t, cls, return_embedding, state, **kwargs: accepted for interface
                compatibility with the surrounding framework; unused here.
            mask: accepted but intentionally ignored.

        Returns:
            - representation mode: ``(fused_embeddings, None)``.
            - pretrain + streaming: ``(HybridOutput(logits=(ce_sum, n_masked,
              total_aux, stats, step)), None)`` — losses reduced per chunk so
              full (B, L, V) logits are never materialized.
            - pretrain (dense): ``(HybridOutput(logits=(logits, mlm_mask,
              total_aux, logitsA, logitsB, mask_A_rc, mask_B_rc, step)), None)``.
            - otherwise: ``(logits, None)``.
        """
        # Global step lives in a buffer; incremented only while training.
        step = int(self.g_step.item())
        if self.training:
            self.g_step += 1

        _ = mask  # attention mask is intentionally unused

        mlm_mask = None
        mlm_labels = None
        special_mask = None

        # Pretraining batches arrive packed as (ids, mlm_mask, labels, special).
        if self.pretrain:
            if isinstance(seq, (tuple, list)):
                mlm_mask = seq[1] if len(seq) >= 2 else None
                mlm_labels = seq[2] if len(seq) >= 3 else None
                special_mask = seq[3] if len(seq) >= 4 else None
                seq = seq[0]

        device_type = seq.device.type if seq.device.type in ["cuda", "cpu", "xpu"] else "cuda"
        amp_dtype = preferred_amp_dtype()

        # Branch B consumes the reverse-complement strand.
        rc_seq = reverse_complement(seq)

        # Autocast only on CUDA; CPU/XPU run at full precision.
        with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=(device_type == "cuda")):
            seq_oh = one_hot_float(seq, self.alphabet_size, dtype=amp_dtype)
            rc_oh = one_hot_float(rc_seq, self.alphabet_size, dtype=amp_dtype)

            if special_mask is not None:
                special_mask = special_mask.to(dtype=torch.bool, device=seq.device)
                # RC view is position-flipped, so its special mask flips too.
                rc_special_mask = torch.flip(special_mask, dims=[1])

                # Blank out one-hot content at special-token positions.
                seq_oh = seq_oh.masked_fill(special_mask.unsqueeze(-1), 0.0)
                rc_oh = rc_oh.masked_fill(rc_special_mask.unsqueeze(-1), 0.0)
            else:
                rc_special_mask = None

            # Stem: (B, L, A) one-hots -> (B, H, L) channel-first features.
            h = F.gelu(self.linear(seq_oh.permute(0, 2, 1)))
            rc_h = F.gelu(self.rc_linear(rc_oh.permute(0, 2, 1)))
            del seq_oh, rc_oh

            if special_mask is not None:
                # Zero stem activations at special positions...
                non_special = (~special_mask).to(dtype=h.dtype).unsqueeze(1)
                rc_non_special = (~rc_special_mask).to(dtype=rc_h.dtype).unsqueeze(1)
                h = h * non_special
                rc_h = rc_h * rc_non_special

            if mlm_mask is not None:
                # Inject the learned [MASK] embedding at MLM positions (both strands).
                mlm_mask_f = mlm_mask.to(dtype=h.dtype, device=h.device).unsqueeze(1)
                rc_mlm_mask_f = torch.flip(mlm_mask, dims=[1]).to(dtype=rc_h.dtype, device=rc_h.device).unsqueeze(1)
                h = h + mlm_mask_f * self.mlm_mask_embed.view(1, -1, 1)
                rc_h = rc_h + rc_mlm_mask_f * self.mlm_mask_embed.view(1, -1, 1)

            if special_mask is not None:
                # ...then re-inject a dedicated learned embedding there.
                special_mask_f = special_mask.to(dtype=h.dtype, device=h.device).unsqueeze(1)
                rc_special_mask_f = rc_special_mask.to(dtype=rc_h.dtype, device=rc_h.device).unsqueeze(1)
                h = h + special_mask_f * self.special_token_embed.view(1, -1, 1)
                rc_h = rc_h + rc_special_mask_f * self.special_token_embed.view(1, -1, 1)

            # Streaming path: CE is reduced inside each chunk, so full-length
            # logits never exist in memory.
            use_streaming = bool(
                self.pretrain
                and self.use_s_scan
                and self.streaming_loss
                and (mlm_labels is not None)
                and (not self.for_representation)
            )

            if use_streaming:
                B, H, L = h.shape
                l = self.block_size
                K = (L + l - 1) // l  # number of blocks (ceil division)
                chunk_size = max(1, self.checkpoint_chunk_size)

                ce_sum_total = torch.zeros([], device=seq.device, dtype=torch.float32)
                n_total = torch.zeros([], device=seq.device, dtype=torch.long)
                total_aux = torch.zeros([], device=seq.device, dtype=torch.float32)

                # Top-1 / top-3 accuracy counters, fused and per-branch.
                correct1 = torch.zeros([], device=seq.device, dtype=torch.long)
                correct3 = torch.zeros([], device=seq.device, dtype=torch.long)
                correctA1 = torch.zeros([], device=seq.device, dtype=torch.long)
                correctB1 = torch.zeros([], device=seq.device, dtype=torch.long)
                correctA3 = torch.zeros([], device=seq.device, dtype=torch.long)
                correctB3 = torch.zeros([], device=seq.device, dtype=torch.long)

                keep_rate = mlm_mask.float().mean() if mlm_mask is not None else torch.tensor(1.0, device=seq.device)
                report_ab_t = torch.tensor(int(self.streaming_report_ab), device=seq.device)

                for chunk_start in range(0, K, chunk_size):
                    chunk_end = min(chunk_start + chunk_size, K)

                    X_A_batch, X_B_batch = [], []
                    Aemb_batch, Bemb_batch = [], []
                    labels_batch = []
                    lengths = []

                    for t_block in range(chunk_start, chunk_end):
                        start = t_block * l
                        end = min(start + l, L)
                        blk_len = end - start
                        lengths.append(blk_len)

                        # Teacher inputs: no dropout.
                        # NOTE(review): rc_h is position-flipped, so this slice
                        # covers a different genomic window than fwd_emb;
                        # alignment is presumably handled downstream — confirm.
                        fwd_emb = h[:, :, start:end].transpose(1, 2).contiguous()
                        rc_emb = rc_h[:, :, start:end].transpose(1, 2).contiguous()

                        # Student inputs: dropout applied.
                        fwd_in = self.dropout(h[:, :, start:end]).transpose(1, 2).contiguous()
                        rc_in = self.dropout(rc_h[:, :, start:end]).transpose(1, 2).contiguous()

                        # Decide which strand each branch sees for this block.
                        XA, XB = self._route_block_inputs(t_block, fwd_in, rc_in)
                        X_A_batch.append(XA)
                        X_B_batch.append(XB)

                        Aemb_batch.append(fwd_emb)
                        Bemb_batch.append(rc_emb)
                        labels_batch.append(mlm_labels[:, start:end])

                    if len(set(lengths)) == 1:
                        # Fast path: equal-length blocks are batched along dim 0.
                        nb = len(X_A_batch)

                        X_A_tensor = torch.cat(X_A_batch, dim=0)
                        X_B_tensor = torch.cat(X_B_batch, dim=0)
                        Aemb_tensor = torch.cat(Aemb_batch, dim=0)
                        Bemb_tensor = torch.cat(Bemb_batch, dim=0)
                        labels_tensor = torch.cat(labels_batch, dim=0)

                        if self.training and self.use_checkpointing:
                            ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = cp.checkpoint(
                                self._forward_s_scan_chunk_streaming,
                                X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor, labels_tensor,
                                torch.tensor(chunk_start, device=seq.device),
                                torch.tensor(nb, device=seq.device),
                                torch.tensor(step, device=seq.device),
                                report_ab_t,
                                use_reentrant=False,
                            )
                        else:
                            ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = self._forward_s_scan_chunk_streaming(
                                X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor, labels_tensor,
                                torch.tensor(chunk_start, device=seq.device),
                                torch.tensor(nb, device=seq.device),
                                torch.tensor(step, device=seq.device),
                                report_ab_t,
                            )

                        ce_sum_total = ce_sum_total + ce_sum
                        n_total = n_total + n_masked
                        total_aux = total_aux + aux_blk

                        correct1 += c1
                        correct3 += c3
                        correctA1 += a1
                        correctB1 += b1
                        correctA3 += a3
                        correctB3 += b3

                        # Free chunk tensors eagerly to cap peak memory.
                        del X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor, labels_tensor

                    else:
                        # Ragged last chunk: process each block individually.
                        for idx, t_block in enumerate(range(chunk_start, chunk_end)):
                            if self.training and self.use_checkpointing:
                                ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = cp.checkpoint(
                                    self._forward_s_scan_chunk_streaming,
                                    X_A_batch[idx], X_B_batch[idx], Aemb_batch[idx], Bemb_batch[idx], labels_batch[idx],
                                    torch.tensor(t_block, device=seq.device),
                                    torch.tensor(1, device=seq.device),
                                    torch.tensor(step, device=seq.device),
                                    report_ab_t,
                                    use_reentrant=False,
                                )
                            else:
                                ce_sum, n_masked, aux_blk, c1, c3, a1, b1, a3, b3 = self._forward_s_scan_chunk_streaming(
                                    X_A_batch[idx], X_B_batch[idx], Aemb_batch[idx], Bemb_batch[idx], labels_batch[idx],
                                    torch.tensor(t_block, device=seq.device),
                                    torch.tensor(1, device=seq.device),
                                    torch.tensor(step, device=seq.device),
                                    report_ab_t,
                                )

                            ce_sum_total = ce_sum_total + ce_sum
                            n_total = n_total + n_masked
                            total_aux = total_aux + aux_blk

                            correct1 += c1
                            correct3 += c3
                            correctA1 += a1
                            correctB1 += b1
                            correctA3 += a3
                            correctB3 += b3

                del h, rc_h

                if self.training and self.use_ema_teacher and self.auto_update_ema_in_forward:
                    self.update_ema()

                HybridOutput = namedtuple("HybridOutput", ["logits"])
                step_t = torch.tensor(step, device=seq.device, dtype=torch.long)

                # Pack accuracy counters into a single float tensor for logging.
                stats = torch.stack([
                    keep_rate.to(torch.float32),
                    correct1.to(torch.float32),
                    correct3.to(torch.float32),
                    correctA1.to(torch.float32),
                    correctB1.to(torch.float32),
                    correctA3.to(torch.float32),
                    correctB3.to(torch.float32),
                ], dim=0)

                return HybridOutput(logits=(ce_sum_total, n_total, total_aux, stats, step_t)), None

            fused = None

            if self.use_s_scan:
                # Dense block-scan path: materialize per-token outputs.
                B, H, L = h.shape
                l = self.block_size
                K = (L + l - 1) // l
                chunk_size = max(1, self.checkpoint_chunk_size)

                collect_fused = bool(self.for_representation)
                collect_logits = (not self.for_representation) or self.pretrain
                need_ab_logits = bool((self.pretrain and self.return_ab_logits) or self.use_rc_kl)

                # Pre-allocated full-length output buffers (filled block by block).
                fused_out = torch.empty((B, L, self.d_model), device=seq.device, dtype=amp_dtype) if collect_fused else None
                logits_out = torch.empty((B, L, self.alphabet_size), device=seq.device, dtype=amp_dtype) if collect_logits else None
                logitsA_out = torch.empty((B, L, self.alphabet_size), device=seq.device, dtype=amp_dtype) if need_ab_logits else None
                logitsB_out = torch.empty((B, L, self.alphabet_size), device=seq.device, dtype=amp_dtype) if need_ab_logits else None

                # Per-position record of which strand each branch received.
                mask_A_rc = torch.empty((B, L), device=seq.device, dtype=torch.bool)
                mask_B_rc = torch.empty((B, L), device=seq.device, dtype=torch.bool)

                total_aux = torch.zeros([], device=seq.device, dtype=torch.float32)

                for chunk_start in range(0, K, chunk_size):
                    chunk_end = min(chunk_start + chunk_size, K)

                    X_A_batch, X_B_batch = [], []
                    Aemb_batch, Bemb_batch = [], []
                    lengths = []

                    for t_block in range(chunk_start, chunk_end):
                        start = t_block * l
                        end = min(start + l, L)
                        blk_len = end - start
                        lengths.append(blk_len)

                        fwd_emb = h[:, :, start:end].transpose(1, 2).contiguous()
                        rc_emb = rc_h[:, :, start:end].transpose(1, 2).contiguous()

                        fwd_in = self.dropout(h[:, :, start:end]).transpose(1, 2).contiguous()
                        rc_in = self.dropout(rc_h[:, :, start:end]).transpose(1, 2).contiguous()

                        XA, XB = self._route_block_inputs(t_block, fwd_in, rc_in)
                        X_A_batch.append(XA)
                        X_B_batch.append(XB)

                        Aemb_batch.append(fwd_emb)
                        Bemb_batch.append(rc_emb)

                        mask_A_rc[:, start:end] = self._branch_receives_rc("A", t_block)
                        mask_B_rc[:, start:end] = self._branch_receives_rc("B", t_block)

                    if len(set(lengths)) == 1:
                        # Fast path: equal-length blocks batched along dim 0.
                        blk_len = lengths[0]
                        X_A_tensor = torch.cat(X_A_batch, dim=0)
                        X_B_tensor = torch.cat(X_B_batch, dim=0)
                        Aemb_tensor = torch.cat(Aemb_batch, dim=0)
                        Bemb_tensor = torch.cat(Bemb_batch, dim=0)

                        need_logits_t = torch.tensor(int(collect_logits), device=seq.device)
                        need_ab_t = torch.tensor(int(need_ab_logits), device=seq.device)

                        if self.training and self.use_checkpointing:
                            fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = cp.checkpoint(
                                self._forward_s_scan_chunk,
                                X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
                                torch.tensor(chunk_start, device=seq.device),
                                torch.tensor(len(X_A_batch), device=seq.device),
                                torch.tensor(step, device=seq.device),
                                need_logits_t, need_ab_t,
                                use_reentrant=False,
                            )
                        else:
                            fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = self._forward_s_scan_chunk(
                                X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
                                torch.tensor(chunk_start, device=seq.device),
                                torch.tensor(len(X_A_batch), device=seq.device),
                                torch.tensor(step, device=seq.device),
                                need_logits_t, need_ab_t,
                            )

                        total_aux = total_aux + aux_blk

                        # Un-batch the chunk back to (nb, B, blk_len, ·) views.
                        nb = len(X_A_batch)
                        fused_view = fused_blk.view(nb, B, blk_len, -1)
                        logits_view = logits_blk.view(nb, B, blk_len, -1) if (collect_logits and logits_blk.numel() > 0) else None
                        logitsA_view = logitsA_blk.view(nb, B, blk_len, -1) if (need_ab_logits and logitsA_blk.numel() > 0) else None
                        logitsB_view = logitsB_blk.view(nb, B, blk_len, -1) if (need_ab_logits and logitsB_blk.numel() > 0) else None

                        for c, t_block in enumerate(range(chunk_start, chunk_end)):
                            start = t_block * l
                            end = min(start + l, L)

                            if collect_fused:
                                fused_out[:, start:end, :] = fused_view[c]
                            if collect_logits:
                                logits_out[:, start:end, :] = logits_view[c]
                            if need_ab_logits:
                                logitsA_out[:, start:end, :] = logitsA_view[c]
                                logitsB_out[:, start:end, :] = logitsB_view[c]

                        del X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor
                        del fused_blk, logits_blk, logitsA_blk, logitsB_blk

                    else:
                        # Ragged last chunk: process blocks one at a time.
                        for idx, t_block in enumerate(range(chunk_start, chunk_end)):
                            start = t_block * l
                            end = min(start + l, L)

                            X_A_tensor = X_A_batch[idx]
                            X_B_tensor = X_B_batch[idx]
                            Aemb_tensor = Aemb_batch[idx]
                            Bemb_tensor = Bemb_batch[idx]

                            need_logits_t = torch.tensor(int(collect_logits), device=seq.device)
                            need_ab_t = torch.tensor(int(need_ab_logits), device=seq.device)

                            if self.training and self.use_checkpointing:
                                fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = cp.checkpoint(
                                    self._forward_s_scan_chunk,
                                    X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
                                    torch.tensor(t_block, device=seq.device),
                                    torch.tensor(1, device=seq.device),
                                    torch.tensor(step, device=seq.device),
                                    need_logits_t, need_ab_t,
                                    use_reentrant=False,
                                )
                            else:
                                fused_blk, logits_blk, logitsA_blk, logitsB_blk, aux_blk = self._forward_s_scan_chunk(
                                    X_A_tensor, X_B_tensor, Aemb_tensor, Bemb_tensor,
                                    torch.tensor(t_block, device=seq.device),
                                    torch.tensor(1, device=seq.device),
                                    torch.tensor(step, device=seq.device),
                                    need_logits_t, need_ab_t,
                                )

                            total_aux = total_aux + aux_blk

                            if collect_fused:
                                fused_out[:, start:end, :] = fused_blk
                            if collect_logits:
                                logits_out[:, start:end, :] = logits_blk
                            if need_ab_logits and logitsA_blk.numel() > 0:
                                logitsA_out[:, start:end, :] = logitsA_blk
                                logitsB_out[:, start:end, :] = logitsB_blk

                            del fused_blk, logits_blk, logitsA_blk, logitsB_blk

                del h, rc_h

                logits = logits_out if collect_logits else None
                logits_A_only = logitsA_out if need_ab_logits else None
                logits_B_only = logitsB_out if need_ab_logits else None
                fused = fused_out if collect_fused else None

            else:
                # Whole-sequence path (no block scan).
                feat = self.dropout(h).transpose(1, 2).contiguous()
                rc_feat = self.dropout(rc_h).transpose(1, 2).contiguous()

                H_A = self.branchA_core(feat)
                H_Br = self.branchB_core(rc_feat)
                R_A = H_A
                # Flip RC-branch output back into forward coordinates.
                R_B = torch.flip(H_Br, dims=[1])

                if self.use_bridge and self.bridge is not None:
                    R_A, R_B = self.bridge(R_A, R_B)

                fA, rA = self.proj_A(R_A)
                FA = fA + rA
                fB, rB = self.proj_B(R_B)
                FB = fB + rB

                # Learned gate fusing the two strand features (same scheme
                # as the block-scan path).
                gate_in = torch.cat([FA, FB], dim=-1)
                g_logits = self.gate_fuse(gate_in)
                g_raw = torch.sigmoid(g_logits / max(1e-6, self.gate_temp))

                if step < self.gate_freeze_steps:
                    g = 0.5 * torch.ones_like(g_raw)
                else:
                    g = g_raw

                if self.detach_gate:
                    mix = g.detach() * FA + (1 - g.detach()) * FB
                else:
                    mix = g * FA + (1 - g) * FB

                fused = F.layer_norm(mix, (mix.size(-1),))
                fused = ensure_finite(fused, "fused")

                if self.use_final_conv:
                    fused = self.final_conv(fused.permute(0, 2, 1)).permute(0, 2, 1)

                logits = self.out_linear(fused) if (not self.for_representation or self.pretrain) else None

                need_ab_logits = bool((self.pretrain and self.return_ab_logits) or self.use_rc_kl)
                logits_A_only = self.out_linear(FA) if need_ab_logits else None
                logits_B_only = self.out_linear(FB) if need_ab_logits else None

                # In this path branch A always sees forward, B always sees RC.
                mask_A_rc = torch.zeros(FA.size()[:2], dtype=torch.bool, device=FA.device)
                mask_B_rc = torch.ones_like(mask_A_rc)

                # No auxiliary losses are computed on this path.
                total_aux = logits.new_zeros(()) if self.pretrain and logits is not None else None

                del h, rc_h, feat, rc_feat

            if self.for_representation:
                return fused, None

            if self.training and self.use_ema_teacher and self.auto_update_ema_in_forward:
                self.update_ema()

            if self.pretrain:
                HybridOutput = namedtuple("HybridOutput", ["logits"])
                return HybridOutput(
                    logits=(
                        logits,
                        mlm_mask,
                        total_aux,
                        # Per-branch logits and strand masks are detached:
                        # they are for reporting, not gradient flow.
                        logits_A_only.detach() if logits_A_only is not None else None,
                        logits_B_only.detach() if logits_B_only is not None else None,
                        mask_A_rc.detach() if mask_A_rc is not None else None,
                        mask_B_rc.detach() if mask_B_rc is not None else None,
                        int(step),
                    )
                ), None

            return logits, None
|
| 1527 |
+
|
| 1528 |
+
@property
|
| 1529 |
+
def d_output(self):
|
| 1530 |
+
if getattr(self, "d_model", None) is None:
|
| 1531 |
+
raise NotImplementedError("SequenceModule instantiation must set d_output")
|
| 1532 |
+
return self.d_model
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
class CrossDNAForMaskedLM(PreTrainedModel):
    """Hugging Face masked-LM wrapper around the CrossDNA hybrid backbone."""

    config_class = CrossDNAConfig
    base_model_prefix = "backbone"
    main_input_name = "input_ids"

    # allow older HF checkpoints to load with warnings rather than hard-failing
    _keys_to_ignore_on_load_missing = [
        r"backbone\.mlm_mask_embed",
        r"backbone\.special_token_embed",
        r"backbone\.branchA_core_ema\..*",
        r"backbone\.branchB_core_ema\..*",
        r"backbone\.bridge_ema\..*",
    ]
|
| 1548 |
+
|
| 1549 |
+
    def __init__(self, config: CrossDNAConfig):
        """Build the SSScanDNAHybridModel backbone from an HF config."""
        super().__init__(config)
        self.config = config
        self.backbone = SSScanDNAHybridModel(config=config)
        # HF hook: weight init plus any config-driven final setup.
        self.post_init()
|
| 1554 |
+
|
| 1555 |
+
@property
|
| 1556 |
+
def tokenizer_vocab_size(self) -> int:
|
| 1557 |
+
return int(getattr(self.config, "vocab_size", self.config.alphabet_size))
|
| 1558 |
+
|
| 1559 |
+
@property
|
| 1560 |
+
def dna_token_ids(self) -> Dict[str, int]:
|
| 1561 |
+
cfg_ids = getattr(self.config, "dna_token_ids", None)
|
| 1562 |
+
if cfg_ids is not None:
|
| 1563 |
+
return dict(cfg_ids)
|
| 1564 |
+
start = int(getattr(self.config, "dna_token_start_id", 7))
|
| 1565 |
+
return {"A": start + 0, "C": start + 1, "G": start + 2, "T": start + 3, "N": start + 4}
|
| 1566 |
+
|
| 1567 |
+
@property
|
| 1568 |
+
def compact_n_token_id(self) -> int:
|
| 1569 |
+
return int(getattr(self.config, "compact_n_token_id", self.config.alphabet_size - 1))
|
| 1570 |
+
|
| 1571 |
+
def _to_compact_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 1572 |
+
input_ids = input_ids.long()
|
| 1573 |
+
if input_ids.numel() == 0:
|
| 1574 |
+
return input_ids
|
| 1575 |
+
|
| 1576 |
+
mn, mx = int(input_ids.min()), int(input_ids.max())
|
| 1577 |
+
if 0 <= mn and mx < self.config.alphabet_size:
|
| 1578 |
+
return input_ids
|
| 1579 |
+
|
| 1580 |
+
compact = torch.full_like(input_ids, self.compact_n_token_id)
|
| 1581 |
+
for compact_id, base in enumerate(["A", "C", "G", "T", "N"]):
|
| 1582 |
+
tok_id = self.dna_token_ids[base]
|
| 1583 |
+
compact[input_ids == tok_id] = compact_id
|
| 1584 |
+
return compact
|
| 1585 |
+
|
| 1586 |
+
def _labels_to_compact(self, labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 1587 |
+
if labels is None:
|
| 1588 |
+
return None
|
| 1589 |
+
|
| 1590 |
+
labels = labels.long()
|
| 1591 |
+
compact = torch.full_like(labels, -100)
|
| 1592 |
+
|
| 1593 |
+
direct_mask = (labels >= 0) & (labels < self.config.alphabet_size)
|
| 1594 |
+
compact[direct_mask] = labels[direct_mask]
|
| 1595 |
+
|
| 1596 |
+
for compact_id, base in enumerate(["A", "C", "G", "T", "N"]):
|
| 1597 |
+
tok_id = self.dna_token_ids[base]
|
| 1598 |
+
compact[labels == tok_id] = compact_id
|
| 1599 |
+
|
| 1600 |
+
return compact
|
| 1601 |
+
|
| 1602 |
+
def _build_special_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 1603 |
+
input_ids = input_ids.long()
|
| 1604 |
+
if input_ids.numel() == 0:
|
| 1605 |
+
return torch.zeros_like(input_ids, dtype=torch.bool)
|
| 1606 |
+
|
| 1607 |
+
mn, mx = int(input_ids.min()), int(input_ids.max())
|
| 1608 |
+
if 0 <= mn and mx < self.config.alphabet_size:
|
| 1609 |
+
return torch.zeros_like(input_ids, dtype=torch.bool)
|
| 1610 |
+
|
| 1611 |
+
special_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
| 1612 |
+
for base in ["A", "C", "G", "T", "N"]:
|
| 1613 |
+
special_mask[input_ids == self.dna_token_ids[base]] = False
|
| 1614 |
+
return special_mask
|
| 1615 |
+
|
| 1616 |
+
def _expand_logits_to_tokenizer_vocab(self, compact_logits: torch.Tensor) -> torch.Tensor:
|
| 1617 |
+
B, L, C = compact_logits.shape
|
| 1618 |
+
V = self.tokenizer_vocab_size
|
| 1619 |
+
full_logits = compact_logits.new_full((B, L, V), -1e4)
|
| 1620 |
+
|
| 1621 |
+
for compact_id, base in enumerate(["A", "C", "G", "T", "N"]):
|
| 1622 |
+
full_logits[:, :, self.dna_token_ids[base]] = compact_logits[:, :, compact_id]
|
| 1623 |
+
|
| 1624 |
+
return full_logits
|
| 1625 |
+
|
| 1626 |
+
@torch.no_grad()
|
| 1627 |
+
def extract_embeddings(
|
| 1628 |
+
self,
|
| 1629 |
+
input_ids: torch.LongTensor,
|
| 1630 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1631 |
+
) -> torch.Tensor:
|
| 1632 |
+
compact_ids = self._to_compact_ids(input_ids)
|
| 1633 |
+
|
| 1634 |
+
was_pretrain = bool(getattr(self.backbone, "pretrain", False))
|
| 1635 |
+
was_for_repr = bool(getattr(self.backbone, "for_representation", False))
|
| 1636 |
+
try:
|
| 1637 |
+
self.backbone.pretrain = False
|
| 1638 |
+
self.backbone.for_representation = True
|
| 1639 |
+
embeddings, _ = self.backbone(compact_ids, mask=attention_mask)
|
| 1640 |
+
finally:
|
| 1641 |
+
self.backbone.pretrain = was_pretrain
|
| 1642 |
+
self.backbone.for_representation = was_for_repr
|
| 1643 |
+
return embeddings
|
| 1644 |
+
|
| 1645 |
+
def forward(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    **kwargs
):
    """Masked-LM forward pass over the CrossDNA backbone.

    Maps tokenizer-vocab ids to the compact A/C/G/T/N alphabet, runs the
    backbone (pretrain or plain mode depending on ``config.pretrain``),
    then expands the compact logits back to the full tokenizer vocabulary.

    Args:
        input_ids: Token ids in the tokenizer's vocabulary.
        attention_mask: Optional mask forwarded to the backbone.
        labels: Optional MLM labels (tokenizer-vocab ids, -100 = ignore).
        return_dict: Whether to return a ``MaskedLMOutput``; defaults to
            ``config.use_return_dict``.

    Returns:
        ``MaskedLMOutput`` (or a tuple when ``return_dict`` is falsy) with
        logits shaped (batch, seq, tokenizer_vocab_size); when labels are
        given, loss is cross-entropy over the compact alphabet plus any
        auxiliary loss reported by the backbone.
    """
    return_dict = self.config.use_return_dict if return_dict is None else return_dict

    # Translate tokenizer-vocab ids/labels into the compact 5-letter alphabet.
    compact_input_ids = self._to_compact_ids(input_ids)
    compact_labels = self._labels_to_compact(labels)
    special_mask = self._build_special_mask(input_ids)

    # Positions the MLM objective should score: labelled positions when
    # labels exist, otherwise positions holding the [MASK] token id.
    if compact_labels is not None:
        mlm_mask = compact_labels.ne(-100)
    else:
        mlm_mask = input_ids.eq(getattr(self.config, "mask_token_id", 3))

    # HF wrapper always asks for dense token logits, so temporarily disable streaming-loss forward.
    was_streaming = bool(getattr(self.backbone, "streaming_loss", False))
    self.backbone.streaming_loss = False
    try:
        if self.config.pretrain:
            # Pretrain-mode backbone takes a packed tuple; per this call
            # site, logits[0] carries token logits and logits[2] an
            # auxiliary loss. NOTE(review): logits[1] is unused here —
            # confirm against the backbone's return convention.
            outputs, _ = self.backbone(
                (compact_input_ids, mlm_mask, compact_labels, special_mask),
                mask=attention_mask,
            )
            compact_logits = outputs.logits[0]
            aux_loss = outputs.logits[2]
        else:
            compact_logits, _ = self.backbone(compact_input_ids, mask=attention_mask)
            aux_loss = None
    finally:
        # Always restore the backbone's streaming-loss setting.
        self.backbone.streaming_loss = was_streaming

    # Scatter compact logits back into the tokenizer-vocabulary layout.
    logits = self._expand_logits_to_tokenizer_vocab(compact_logits)

    loss = None
    if compact_labels is not None:
        # Loss is computed in the compact alphabet; -100 positions are ignored.
        loss = F.cross_entropy(
            compact_logits.reshape(-1, self.config.alphabet_size),
            compact_labels.reshape(-1),
            ignore_index=-100,
        )
        if aux_loss is not None:
            loss = loss + aux_loss.to(loss.dtype)

    if not return_dict:
        output = (logits,)
        return ((loss,) + output) if loss is not None else output

    return MaskedLMOutput(
        loss=loss,
        logits=logits,
        hidden_states=None,
        attentions=None,
    )
|
28.6M/huggingface_crossdna_140K_len/crossdna/special_tokens_map.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "[BOS]",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"eos_token": "[SEP]",
|
| 5 |
+
"mask_token": "[MASK]",
|
| 6 |
+
"pad_token": "[PAD]",
|
| 7 |
+
"sep_token": "[SEP]",
|
| 8 |
+
"unk_token": "[UNK]",
|
| 9 |
+
"additional_special_tokens": [
|
| 10 |
+
"[RESERVED]"
|
| 11 |
+
]
|
| 12 |
+
}
|
28.6M/huggingface_crossdna_140K_len/crossdna/tokenization_crossdna.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Dict, List, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CrossDNATokenizer(PreTrainedTokenizer):
    """Character-level DNA tokenizer over the alphabet A/C/G/T/N.

    Special tokens occupy the fixed ids 0-6 ([CLS], [SEP], [BOS], [MASK],
    [PAD], [RESERVED], [UNK]); each DNA base gets a consecutive id starting
    at ``dna_token_start_id`` (A=7 ... N=11 by default). Persistence uses a
    custom minimal JSON format (``tokenizer_config.json`` +
    ``special_tokens_map.json``) rather than HF's standard save layout.
    """

    def __init__(
        self,
        characters: Sequence[str] = ("A", "C", "G", "T", "N"),
        model_max_length: int = 143360,
        padding_side: str = "left",
        dna_token_start_id: int = 7,
        **kwargs,
    ):
        # Upper-case the alphabet so it matches _tokenize's upper-cased output.
        self.characters = [str(ch).upper() for ch in characters]
        self.model_max_length = int(model_max_length)
        self.dna_token_start_id = int(dna_token_start_id)

        # Fixed special-token ids, then one id per DNA base starting at
        # dna_token_start_id.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: self.dna_token_start_id + i for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)
        mask_token = AddedToken("[MASK]", lstrip=False, rstrip=False)

        # PreTrainedTokenizer.__init__ does not accept add_special_tokens.
        if "add_special_tokens" in kwargs:
            kwargs.pop("add_special_tokens")

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=self.model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    def __len__(self):
        # Total vocabulary size (specials + DNA bases).
        return len(self._vocab_str_to_int)

    @property
    def vocab_size(self) -> int:
        """Number of entries in the fixed vocabulary (12 with defaults)."""
        return len(self._vocab_str_to_int)

    def get_vocab(self) -> Dict[str, int]:
        """Return a copy of the token -> id mapping."""
        return dict(self._vocab_str_to_int)

    def _tokenize(self, text: str) -> List[str]:
        # One token per character, upper-cased; unknown chars map to [UNK]
        # later in _convert_token_to_id.
        return list(text.upper())

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str.get(index, "[UNK]")

    def convert_tokens_to_string(self, tokens):
        """Concatenate character tokens back into a sequence string."""
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append [SEP] after each sequence: ``seq0 [SEP] (seq1 [SEP])``."""
        sep = [self.sep_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Mark the trailing [SEP] positions added by this tokenizer."""
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        # Mirrors build_inputs_with_special_tokens: one [SEP] per sequence.
        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Segment ids: 0 for the first sequence (+sep), 1 for the second."""
        sep = [self.sep_token_id]
        result = len(token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        """Return a JSON-serializable config mirroring tokenizer_config.json."""
        return {
            "characters": self.characters,
            "model_max_length": self.model_max_length,
            "padding_side": self.padding_side,
            "dna_token_start_id": self.dna_token_start_id,
            "bos_token": "[BOS]",
            "eos_token": "[SEP]",
            "sep_token": "[SEP]",
            "cls_token": "[CLS]",
            "pad_token": "[PAD]",
            "mask_token": "[MASK]",
            "unk_token": "[UNK]",
            "tokenizer_class": "CrossDNATokenizer",
            "auto_map": {
                "AutoTokenizer": [
                    "tokenization_crossdna.CrossDNATokenizer",
                    None
                ]
            }
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CrossDNATokenizer":
        """Build a tokenizer from a dict produced by :meth:`get_config`."""
        return cls(
            characters=config.get("characters", ["A", "C", "G", "T", "N"]),
            model_max_length=config.get("model_max_length", 143360),
            padding_side=config.get("padding_side", "left"),
            dna_token_start_id=config.get("dna_token_start_id", 7),
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """Write tokenizer_config.json and special_tokens_map.json.

        NOTE(review): this overrides HF's save_pretrained with a minimal
        custom format; only the two JSON files are produced, and the return
        value is the pair of file paths.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        cfg_file = save_directory / "tokenizer_config.json"
        stm_file = save_directory / "special_tokens_map.json"

        with open(cfg_file, "w", encoding="utf-8") as f:
            json.dump(self.get_config(), f, indent=2, ensure_ascii=False)

        special_tokens_map = {
            "bos_token": "[BOS]",
            "cls_token": "[CLS]",
            "eos_token": "[SEP]",
            "mask_token": "[MASK]",
            "pad_token": "[PAD]",
            "sep_token": "[SEP]",
            "unk_token": "[UNK]",
            "additional_special_tokens": ["[RESERVED]"]
        }
        with open(stm_file, "w", encoding="utf-8") as f:
            json.dump(special_tokens_map, f, indent=2, ensure_ascii=False)

        return (str(cfg_file), str(stm_file))

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        """Load a tokenizer from a directory written by :meth:`save_pretrained`.

        Extra ``kwargs`` override values from tokenizer_config.json;
        HF-specific bookkeeping keys are dropped before construction.
        """
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file, encoding="utf-8") as f:
            cfg = json.load(f)
        cfg.update(kwargs)
        cfg.pop("tokenizer_class", None)
        cfg.pop("auto_map", None)
        return cls.from_config(cfg)
|
28.6M/huggingface_crossdna_140K_len/crossdna/tokenizer_config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"characters": [
|
| 3 |
+
"A",
|
| 4 |
+
"C",
|
| 5 |
+
"G",
|
| 6 |
+
"T",
|
| 7 |
+
"N"
|
| 8 |
+
],
|
| 9 |
+
"model_max_length": 143360,
|
| 10 |
+
"padding_side": "left",
|
| 11 |
+
"dna_token_start_id": 7,
|
| 12 |
+
"bos_token": "[BOS]",
|
| 13 |
+
"eos_token": "[SEP]",
|
| 14 |
+
"sep_token": "[SEP]",
|
| 15 |
+
"cls_token": "[CLS]",
|
| 16 |
+
"pad_token": "[PAD]",
|
| 17 |
+
"mask_token": "[MASK]",
|
| 18 |
+
"unk_token": "[UNK]",
|
| 19 |
+
"tokenizer_class": "CrossDNATokenizer",
|
| 20 |
+
"auto_map": {
|
| 21 |
+
"AutoTokenizer": [
|
| 22 |
+
"tokenization_crossdna.CrossDNATokenizer",
|
| 23 |
+
null
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
}
|
28.6M/huggingface_crossdna_140K_len/crossdna/transfer.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# transfer.py
import os
import json

# 1) Disable dynamo / torch.compile first; this must run before
#    modeling_crossdna is imported.
os.environ["DISABLE_TORCH_COMPILE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch

# 2) Additionally hard monkey-patch torch.compile into a no-op, so that
#    `fla` does not trigger @torch.compile at import time.
if hasattr(torch, "compile"):
    # Works both as a bare decorator/direct call (torch.compile(fn)) and as
    # a decorator factory (torch.compile(**opts)); the function is returned
    # unchanged either way.
    def _no_compile(fn=None, *args, **kwargs):
        if fn is None:
            def deco(f):
                return f
            return deco
        return fn
    torch.compile = _no_compile
|
| 20 |
+
|
| 21 |
+
from configuration_crossdna import CrossDNAConfig
|
| 22 |
+
from modeling_crossdna import CrossDNAForMaskedLM
|
| 23 |
+
|
| 24 |
+
# Source Lightning checkpoint and target HF model directory.
# NOTE(review): machine-specific absolute paths — adjust per environment.
CKPT = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna/last.ckpt"
MODEL_DIR = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def adapt_state_dict(sd):
    """Normalize checkpoint keys into the HF wrapper's ``backbone.*`` namespace.

    Common Lightning / DDP / torch.compile prefixes are stripped (each at
    most once, checked in a fixed order), then any key not already under
    ``backbone.`` is moved there.
    """
    wrapper_prefixes = (
        "state_dict.",
        "model.",
        "module.",
        "_orig_mod.",
    )
    adapted = {}
    for key, value in sd.items():
        name = key
        for prefix in wrapper_prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
        # Keys that are not HF-wrapped yet get hung under the backbone.
        if not name.startswith("backbone."):
            name = "backbone." + name
        adapted[name] = value
    return adapted
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# 1) Read config.json and build the HF model skeleton.
with open(os.path.join(MODEL_DIR, "config.json"), "r", encoding="utf-8") as f:
    cfg_dict = json.load(f)

cfg = CrossDNAConfig(**cfg_dict)
model = CrossDNAForMaskedLM(cfg)

# 2) Load the raw checkpoint. Lightning checkpoints nest weights under
#    "state_dict"; a plain state dict is used as-is.
raw = torch.load(CKPT, map_location="cpu")
sd = raw.get("state_dict", raw) if isinstance(raw, dict) else raw
sd = adapt_state_dict(sd)

# 3) Load the weights non-strictly and report any key mismatches.
missing, unexpected = model.load_state_dict(sd, strict=False)

print("[Missing]", len(missing), "keys")
if missing:
    print("  first 30 missing:")
    for k in missing[:30]:
        print("   ", k)

print("[Unexpected]", len(unexpected), "keys")
if unexpected:
    print("  first 30 unexpected:")
    for k in unexpected[:30]:
        print("   ", k)

# 4) Save directly in HF safetensors format.
model.save_pretrained(MODEL_DIR, safe_serialization=True)

print("Saved HF weights to:", os.path.join(MODEL_DIR, "model.safetensors"))
|
28.6M/huggingface_crossdna_140K_len/crossdna_140K_infer.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

# Must run first, before torch is imported.
os.environ["DISABLE_TORCH_COMPILE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

import torch

# Must run before from_pretrained: turn torch.compile into a no-op.
if hasattr(torch, "compile"):
    # Handles both direct wrapping (torch.compile(fn)) and decorator-factory
    # usage (torch.compile(**opts)); the function is returned unchanged.
    def _no_compile(fn=None, *args, **kwargs):
        if fn is None:
            def deco(f):
                return f
            return deco
        return fn
    torch.compile = _no_compile

from transformers import AutoTokenizer, AutoModelForMaskedLM

# Local HF model directory (machine-specific absolute path).
MODEL_DIR = "/data/zhaol/projects/huggingface_crossdna_140K_len/crossdna"

# trust_remote_code is required because the tokenizer/model classes live in
# the model repo itself; local_files_only avoids any hub access.
tok = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    local_files_only=True,
)

model = AutoModelForMaskedLM.from_pretrained(
    MODEL_DIR,
    trust_remote_code=True,
    local_files_only=True,
).eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Smoke test: a 512-bp synthetic sequence, tokenized without special tokens.
seq = "ACGT" * 128
enc = tok(seq, return_tensors="pt", add_special_tokens=False)
x = enc["input_ids"].to(device)

with torch.inference_mode():
    out = model(input_ids=x)
    emb = model.extract_embeddings(x)

print("input_ids.shape =", tuple(x.shape))
print("logits.shape =", tuple(out.logits.shape))
print("embeddings.shape =", tuple(emb.shape))
|