import torch
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv
from performer_pytorch import Performer


model_params = {
    "dim": 320,            # embedding width
    "bins": 10,            # importance bins per GBFormer block
    "gb_repeat": 1,        # number of GBFormer blocks
    "p_repeat": 2,         # global Performer layers per block
    "bin_head": 8,         # attention heads in each per-bin Performer
    "full_head": 4,        # attention heads in each global Performer
    "gene_length": 19357   # genes per expression profile
}
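# The dictionary above matches BulkFormer's keyword arguments, so the
# full-size model can be built directly from it once the gene graph and
# ESM2 gene embeddings are loaded (both names below are placeholders,
# not defined in this file):
#
#     model = BulkFormer(graph=edge_index, gene_emb=esm2_emb, **model_params)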
class PositionalExprEmbedding(nn.Module):
    """
    Rotary Expression Embedding (REE):
    Converts continuous gene expression values into a sinusoidal
    embedding usable by Performer/Transformer blocks. Deterministic,
    not learned. Masked positions (expression == mask_token, -10 by
    default) map to the zero vector.
    """
    def __init__(self, dim, mask_token=-10):
        super().__init__()
        self.mask_token = mask_token
        # Fixed inverse frequencies, as in sinusoidal positional
        # encodings; registered as a buffer so the tensor moves with
        # the module but is never trained.
        self.register_buffer(
            "inv_freq",
            1.0 / (100 ** (torch.arange(0, dim, 2).float() / dim))
        )

    def forward(self, x):
        # x: [B, G] continuous expression values.
        # Locate masked entries before x is overwritten below.
        mask = (x == self.mask_token).nonzero(as_tuple=False)
        # [B, G] x [dim/2] -> [B, G, dim/2]; sin/cos concat -> [B, G, dim].
        x = torch.einsum("bi,j->bij", x, self.inv_freq)
        x = torch.cat([x.sin(), x.cos()], dim=-1)
        # Zero out the embedding at masked positions.
        x[mask[:, 0], mask[:, 1]] = 0
        return x
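# Illustrative only (not part of the original model): a tiny shape and
# mask check for the embedding above. Batch size and gene count here
# are arbitrary placeholders.
def _demo_expr_embedding():
    emb = PositionalExprEmbedding(dim=320)
    expr = torch.randn(2, 100)            # [B, G] expression values
    expr[0, 0] = -10.0                    # mark one entry as masked
    out = emb(expr)                       # -> [B, G, dim] = [2, 100, 320]
    assert out.shape == (2, 100, 320)
    assert torch.all(out[0, 0] == 0)      # masked entry maps to zeros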
class GBFormer(nn.Module):
    """
    A single GBFormer block:
    - LayerNorm
    - GCNConv (gene-gene propagation)
    - Binning by learned importance score
    - Local Performer per bin
    - Global Performer
    """
    def __init__(self, dim, gene_length, bin_head, full_head, bins, p_repeat):
        super().__init__()

        self.dim = dim
        self.bins = bins
        self.bin_head = bin_head
        self.full_head = full_head
        self.p_repeat = p_repeat

        self.layernorm = nn.LayerNorm(dim)
        self.gcn = GCNConv(dim, dim, cached=True, add_self_loops=False)

        # Scores each gene; genes are sorted by score and split into bins.
        self.which_bin = nn.Linear(dim, 1)

        # One local Performer per bin.
        self.bin_layers = nn.ModuleList([
            Performer(
                dim=dim,
                heads=bin_head,
                depth=1,
                dim_head=dim // bin_head,
                attn_dropout=0.2,
                ff_dropout=0.2
            )
            for _ in range(bins)
        ])

        # Global Performers over the full gene sequence.
        self.global_layers = nn.Sequential(*[
            Performer(
                dim=dim,
                heads=full_head,
                depth=1,
                dim_head=dim // full_head
            )
            for _ in range(p_repeat)
        ])

    def forward(self, x, graph):
        # x: [B, G, D]; graph: edge_index shared across the batch.
        B, G, D = x.shape

        x = self.layernorm(x)
        x = x + self.gcn(x, graph)  # residual gene-gene propagation

        if self.bins > 0:
            # Rank genes by learned importance, descending.
            scores = self.which_bin(x).squeeze(-1)         # [B, G]
            order = torch.argsort(scores, dim=1, descending=True)
            order_full = order.unsqueeze(-1).expand(-1, -1, D)

            # Sort genes by score and split into contiguous bins.
            x_sorted = x.gather(1, order_full)
            bin_size = (G - 1) // self.bins + 1            # ceil(G / bins)
            chunks = torch.split(x_sorted, bin_size, dim=1)

            # Run each bin through its own local Performer.
            processed = [
                layer(chunk)
                for chunk, layer in zip(chunks, self.bin_layers)
            ]

            # Restore the original gene order; order_full is a
            # permutation, so every position is written exactly once.
            x_cat = torch.cat(processed, dim=1)
            x = torch.empty_like(x_cat).scatter_(1, order_full, x_cat)

        x = self.global_layers(x)
        return x
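# Illustrative only (not from the original code): the
# sort -> per-bin processing -> scatter-back pattern used in
# GBFormer.forward, with identity "processing" so the round trip
# recovers the input exactly.
def _demo_sort_and_restore():
    x = torch.randn(2, 7, 4)                           # [B, G, D]
    scores = x.mean(-1)                                # stand-in importance
    order = torch.argsort(scores, dim=1, descending=True)
    order_full = order.unsqueeze(-1).expand(-1, -1, 4)
    x_sorted = x.gather(1, order_full)                 # sort genes by score
    restored = torch.empty_like(x_sorted).scatter_(1, order_full, x_sorted)
    assert torch.equal(restored, x)                    # original order back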
class BulkFormer(nn.Module):
    """
    CancerTranscriptome-Mini-48M:
    A compact BulkFormer-style masked-expression model.
    Combines:
    - ESM2 gene identity embeddings
    - Rotary Expression Embeddings (REE)
    - Graph convolution (GCNConv)
    - Local/global Performer attention
    - Optional intermediate repr_layers for feature extraction
    """
    def __init__(
        self,
        dim,
        graph,
        gene_emb,
        gene_length,
        bin_head=4,
        full_head=4,
        bins=10,
        gb_repeat=1,
        p_repeat=1
    ):
        super().__init__()

        self.dim = dim
        # edge_index of the gene-gene graph; expected to live on the
        # same device as the input expression profiles.
        self.graph = graph
        self.gene_length = gene_length

        # Gene identity embeddings (initialized from ESM2), fine-tuned
        # during training and projected into the model dimension.
        self.gene_emb = nn.Parameter(gene_emb)
        self.gene_proj = nn.Sequential(
            nn.Linear(gene_emb.shape[1], 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        # Deterministic expression-value embedding.
        self.expr_emb = PositionalExprEmbedding(dim)

        # Fuses identity and expression embeddings.
        self.mix = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim)
        )

        # Stacked GBFormer blocks.
        self.gb_blocks = nn.ModuleList([
            GBFormer(dim, gene_length, bin_head, full_head, bins, p_repeat)
            for _ in range(gb_repeat)
        ])

        self.final_norm = nn.LayerNorm(dim)

        # Per-gene regression head; the final ReLU keeps predicted
        # expression values non-negative.
        self.head = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, 1),
            nn.ReLU()
        )

    def forward(self, x, repr_layers=None):
        # x: [B, G] expression values; repr_layers: optional collection
        # of block indices whose hidden states should be returned.
        B, G = x.shape
        hidden = {}

        # [B, G, D] expression embedding plus [G, D] identity
        # embedding, broadcast over the batch dimension.
        x = self.expr_emb(x) + self.gene_proj(self.gene_emb)

        x = self.mix(x)

        for i, block in enumerate(self.gb_blocks):
            x = block(x, self.graph)
            if repr_layers and i in repr_layers:
                hidden[i] = x

        x = self.final_norm(x)
        out = self.head(x).squeeze(-1)  # [B, G] predicted expression

        if repr_layers:
            return out, hidden
        return out
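# Illustrative only: an end-to-end smoke test with dimensions scaled
# far below model_params. The gene embeddings and graph are random
# placeholders; the real model uses ESM2 embeddings and a curated
# gene-gene edge_index.
def _demo_bulkformer():
    genes, esm_dim, dim = 50, 16, 32
    gene_emb = torch.randn(genes, esm_dim)
    edge_index = torch.randint(0, genes, (2, 200))     # random gene graph
    model = BulkFormer(
        dim=dim,
        graph=edge_index,
        gene_emb=gene_emb,
        gene_length=genes,
        bins=5
    )
    expr = torch.randn(2, genes)
    expr[:, :10] = -10.0                               # mask some genes
    out, hidden = model(expr, repr_layers={0})
    assert out.shape == (2, genes)                     # per-gene predictions
    assert hidden[0].shape == (2, genes, dim)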