| import torch
|
| from torch import nn
|
| import torch.nn.functional as F
|
| from config import CFG
|
| import utils
|
| import math
|
| import numpy as np
|
| from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
|
| from GNN import layers as gly
|
|
|
class MolGNNEncoder(nn.Module):
    """Encode a batch of molecular graphs into fixed-size vectors.

    Pipeline:
      1. a stack of ``gly.GConvBlockNoGF`` layers that update the node
         features ``V`` using the adjacency ``A``;
      2. ``gly.MultiHeadGlobalAttention`` that pools the node features
         per molecule (it is given ``mol_size``);
      3. a stack of ``nn.Linear`` readout layers, each followed by
         ``QuickGELU``, projecting to ``outdim``.
    """

    def __init__(self,
                 outdim,
                 n_feats=74,
                 n_filters_list=(256, 256, 256),
                 n_head=4,
                 mols=1,
                 adj_chans=6,
                 readout_layers=2,
                 bias=True):
        """Build the encoder.

        Args:
            outdim: width of the output vector (and of every readout layer).
            n_feats: width of the input node features (first conv block input).
            n_filters_list: output width of each conv block. ``None`` entries
                are skipped. Any iterable is accepted (default is a tuple to
                avoid a shared mutable default).
            n_head: number of attention heads in the global attention pooling.
            mols: forwarded unchanged to ``gly.GConvBlockNoGF``.
            adj_chans: forwarded unchanged to ``gly.GConvBlockNoGF``.
            readout_layers: number of ``nn.Linear`` layers in the readout.
                Values below 1 still yield a single layer.
            bias: whether conv blocks, attention and every readout layer use
                a bias term.

        Raises:
            ValueError: if ``n_filters_list`` has no non-None entries.
        """
        super().__init__()

        # None entries are dropped so a config can disable a layer by setting it to None.
        n_filters_list = [nf for nf in n_filters_list if nf is not None]
        if not n_filters_list:
            raise ValueError("n_filters_list must contain at least one non-None entry")

        # Each conv block consumes the previous block's width; the first consumes n_feats.
        in_dims = [n_feats] + n_filters_list[:-1]
        self.block_layers = nn.ModuleList(
            gly.GConvBlockNoGF(nf_in, nf_out, mols, adj_chans, bias)
            for nf_in, nf_out in zip(in_dims, n_filters_list)
        )

        last_nf = n_filters_list[-1]
        self.attention_layer = gly.MultiHeadGlobalAttention(last_nf, n_head=n_head, concat=True, bias=bias)
        # With concat=True the pooled vector is expected to be last_nf * n_head wide,
        # which is what the first readout layer consumes.
        self.readout_layers = nn.ModuleList(
            [nn.Linear(last_nf * n_head, outdim, bias=bias)]
            + [nn.Linear(outdim, outdim, bias=bias) for _ in range(readout_layers - 1)]
        )
        self.gelu = QuickGELU()

    def forward(self, batch):
        """Encode a batch of graphs.

        Args:
            batch: mapping providing ``'V'`` (node features), ``'A'``
                (adjacency) and ``'mol_size'`` (passed to the attention
                pooling). Exact tensor shapes are defined by the ``gly``
                layers -- presumably V is (batch, atoms, n_feats); confirm.

        Returns:
            The readout output after the final QuickGELU activation.
        """
        V = batch['V']
        A = batch['A']
        mol_size = batch['mol_size']

        for ly in self.block_layers:
            V = ly(V, A)

        X = self.attention_layer(V, mol_size)

        # Note: the activation is applied after every readout layer, including the last.
        for ly in self.readout_layers:
            X = self.gelu(ly(X))

        return X
|
|
|
class ProjectionHead(nn.Module):
    """Project an embedding into ``projection_dim`` space and optionally refine it.

    Forward pipeline::

        Linear -> (GELU | MSTsfmEncoder) -> [optional nn.LSTM] -> Dropout

    When the transformer is enabled it replaces the GELU activation rather
    than following it.
    """

    def __init__(self,
                 embedding_dim,
                 projection_dim,
                 cfg,
                 transformer=True,
                 lstm=False):
        """Build the head.

        Args:
            embedding_dim: width of the incoming features.
            projection_dim: width of the projected output.
            cfg: config object. Reads ``tsfm_layers`` and ``tsfm_heads`` when
                ``transformer`` is truthy, ``lstm_layers`` when ``lstm`` is
                truthy, and always ``dropout``.
            transformer: if truthy, refine the projection with ``MSTsfmEncoder``.
            lstm: if truthy, run the result through an ``nn.LSTM``.
        """
        super().__init__()

        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        # Disabled optional stages are stored as None so forward() can skip them.
        self.transformer = (
            MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)
            if transformer else None
        )
        self.lstm = (
            nn.LSTM(input_size=projection_dim,
                    hidden_size=projection_dim,
                    num_layers=cfg.lstm_layers,
                    batch_first=True)
            if lstm else None
        )
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Project ``x`` and apply the configured refinement stages, then dropout."""
        hidden = self.projection(x)
        hidden = self.gelu(hidden) if self.transformer is None else self.transformer(hidden)
        if self.lstm is not None:
            # NOTE(review): nn.LSTM treats a 2-D input as one unbatched sequence, so if
            # x is (batch, dim) the batch axis acts as the time axis -- confirm intended.
            hidden = self.lstm(hidden)[0]  # keep the output sequence, drop (h_n, c_n)
        return self.dropout(hidden)
|
|
|
|
|
class FragSimiModel(nn.Module):
    """Dual encoder embedding mass spectra and molecules into a shared space.

    Spectrum features (``batch["ms_bins"]``) and molecule features are each
    passed through their own ``ProjectionHead`` and L2-normalised along
    dim 1, so the two outputs can be compared by dot product.

    Molecule features are assembled from the sources named in
    ``cfg.mol_encoder`` (substring match, in this order):
      * ``'gnn'`` -> output of ``MolGNNEncoder`` applied to ``batch``
      * ``'fp'``  -> ``batch["mol_fps"]``
      * ``'fm'``  -> ``batch["mol_fmvec"]``
    """

    def __init__(
        self,
        cfg
    ):
        """Build the encoders and projection heads from ``cfg``."""
        super().__init__()

        self.cfg = cfg
        self.mol_gnn_encoder = None
        mol_embedding_dim = cfg.mol_embedding_dim

        if 'gnn' in self.cfg.mol_encoder:
            self.mol_gnn_encoder = MolGNNEncoder(outdim=cfg.mol_embedding_dim,
                                                 n_filters_list=cfg.molgnn_n_filters_list,
                                                 n_head=cfg.molgnn_nhead,
                                                 readout_layers=cfg.molgnn_readout_layers)
        # NOTE(review): the width below is 2 * mol_embedding_dim whether or not 'gnn' is
        # also enabled, so 'fp' alone and 'gnn'+'fp' imply different fingerprint widths;
        # likewise an 'fm'-only config still starts from mol_embedding_dim. Confirm
        # these match the actual feature widths for the configs in use.
        if 'fp' in self.cfg.mol_encoder:
            mol_embedding_dim = 2 * cfg.mol_embedding_dim

        # Assumes batch["mol_fmvec"] has 10 features -- TODO confirm against the data pipeline.
        if 'fm' in self.cfg.mol_encoder:
            mol_embedding_dim += 10

        self.ms_projection = ProjectionHead(cfg.ms_embedding_dim,
                                            cfg.projection_dim,
                                            cfg,
                                            cfg.tsfm_in_ms,
                                            cfg.lstm_in_ms)

        self.mol_projection = ProjectionHead(mol_embedding_dim,
                                             cfg.projection_dim,
                                             cfg,
                                             cfg.tsfm_in_mol,
                                             cfg.lstm_in_mol)

    def forward(self, batch):
        """Embed the spectra and molecules in ``batch``.

        Args:
            batch: mapping with ``"ms_bins"`` plus whichever molecule inputs
                ``cfg.mol_encoder`` selects (see the class docstring).

        Returns:
            ``(mol_embeddings, ms_embeddings)``, each L2-normalised along dim 1.

        Raises:
            ValueError: if ``cfg.mol_encoder`` selects none of 'gnn', 'fp', 'fm'.
        """
        ms_features = batch["ms_bins"]

        mol_feat_list = []
        if 'gnn' in self.cfg.mol_encoder:
            mol_feat_list.append(self.mol_gnn_encoder(batch))
        if 'fp' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fps"])
        if 'fm' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fmvec"])

        if not mol_feat_list:
            raise ValueError(
                f"cfg.mol_encoder={self.cfg.mol_encoder!r} selects no molecule features; "
                "expected it to contain 'gnn', 'fp' and/or 'fm'"
            )

        # Multiple sources are concatenated along dim 1 (presumably the feature axis).
        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1)
        else:
            mol_features = mol_feat_list[0]

        # Order matters: both heads may contain dropout, which consumes RNG state.
        ms_embeddings = self.ms_projection(ms_features)
        mol_embeddings = self.mol_projection(mol_features)

        mol_embeddings = F.normalize(mol_embeddings, p=2, dim=1)
        ms_embeddings = F.normalize(ms_embeddings, p=2, dim=1)

        # A symmetric contrastive loss (logits = mol_embeddings @ ms_embeddings.t(),
        # cross-entropy against torch.arange targets in both directions) was previously
        # left here as an unexecuted string; no loss is computed in this module.
        return mol_embeddings, ms_embeddings
|
|
|