| import torch |
| import tempfile |
| import pathlib |
| import lightning as L |
| from huggingface_hub import PyTorchModelHubMixin, HfApi, hf_hub_download |
|
|
# Reserved vocabulary slots: index 0 catches out-of-vocabulary tokens, index 1
# is padding.  The PAD column of the multi-hot input/target is zeroed out
# before encoding and before the loss (see GSFM.encode / GSFM.training_step).
UNK_IDX, PAD_IDX = 0, 1
special_symbols = ['<unk>', '<pad>']
|
|
def multihot_tensor(indices: torch.Tensor, num_classes: int, dtype=torch.int64, device=None):
    """Convert a tensor of class indices into a multi-hot tensor.

    Parameters
    ----------
    indices : torch.Tensor
        Integer tensor of shape (*batch, k); each entry is a class index
        in [0, num_classes).
    num_classes : int
        Size of the class dimension appended as the last axis.
    dtype, device
        Forwarded to the output tensor (defaults: int64, current device).

    Returns
    -------
    torch.Tensor
        Tensor of shape (*batch, num_classes) with a 1 at every listed
        index and 0 elsewhere.  Duplicate indices still produce a single 1.
    """
    *batch_shape, _ = indices.shape
    out = torch.zeros((*batch_shape, num_classes), device=device, dtype=dtype)
    # Scatter along the LAST dim (was hard-coded to dim 1), so 1-D index
    # tensors and inputs with several leading batch dims both work.
    return out.scatter_(-1, indices, 1)
|
|
class Vocab:
    """Minimal token-to-index vocabulary with an out-of-vocabulary fallback.

    A token's index is its position in ``vocab``; tokens not present map
    to ``default_index``.
    """

    def __init__(self, vocab, default_index=0):
        """
        Parameters
        ----------
        vocab : sequence of str
            Tokens in index order; position defines each token's id.
        default_index : int
            Index returned for tokens absent from ``vocab``.
        """
        self.vocab = vocab
        self.default_index = default_index
        self.lookup = {token: i for i, token in enumerate(vocab)}

    def __call__(self, sentence):
        """Map an iterable of tokens to a list of integer indices."""
        return [self.lookup.get(token, self.default_index) for token in sentence]

    @staticmethod
    def build_vocab_from_iterator(it, min_freq=1, specials=(), special_first=True):
        """Build a Vocab from an iterator of token sequences.

        Tokens are ordered by descending frequency (ties keep first-seen
        order); tokens with frequency below ``min_freq`` are dropped.
        ``specials`` are placed before (``special_first=True``) or after
        the counted tokens.  Note: the default is an immutable tuple to
        avoid the shared-mutable-default pitfall.
        """
        from collections import Counter
        counts = Counter()
        for sentence in it:
            counts.update(sentence)
        frequent = [token for token, freq in counts.most_common() if freq >= min_freq]
        vocab = [*specials, *frequent] if special_first else [*frequent, *specials]
        return Vocab(vocab)

    def set_default_index(self, default_index):
        """Set the index returned for out-of-vocabulary tokens."""
        self.default_index = default_index

    def __len__(self):
        return len(self.vocab)

    def __reduce__(self):
        # Include default_index so pickling round-trips the full state;
        # previously it was silently reset to 0 on unpickling.
        return (Vocab, (self.vocab, self.default_index))

    def save_txt(self, filename):
        """Write one token per line, in index order."""
        with open(filename, 'w') as fw:
            for token in self.vocab:
                print(token, file=fw)

    @staticmethod
    def from_txt(filename):
        """Load a Vocab from a one-token-per-line text file (blank lines skipped)."""
        with open(filename, 'r') as fr:
            return Vocab([line for line in map(str.rstrip, fr) if line])

    @staticmethod
    def from_pretrained(repo_id: str, path_in_repo='vocab.txt'):
        """Download the vocabulary file from a Hugging Face Hub repo and load it."""
        vocab_txt = hf_hub_download(
            repo_id=repo_id,
            filename=path_in_repo,
        )
        return Vocab.from_txt(vocab_txt)

    def push_to_hub(self, repo_id: str, path_in_repo='vocab.txt'):
        """Upload the vocabulary as a text file to a Hugging Face Hub repo."""
        api = HfApi()
        api.create_repo(repo_id, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = pathlib.Path(tmpdir)
            self.save_txt(tmpdir/'vocab.txt')
            return api.upload_file(path_or_fileobj=tmpdir/'vocab.txt', repo_id=repo_id, path_in_repo=path_in_repo)
|
|
class MLP(torch.nn.Module):
    """Multi-layer perceptron: Linear → activation → Dropout per hidden
    layer, with a plain Linear (no activation/dropout) as the final layer.

    Parameters
    ----------
    *dims : int
        Layer widths, e.g. ``MLP(in, hidden, out)``.
    activation : type[torch.nn.Module]
        Activation class, instantiated once PER layer.
    dropout : float
        Dropout probability applied after each hidden activation.
    """

    def __init__(self, *dims, activation=torch.nn.ReLU, dropout=0.2):
        super().__init__()
        layers = []
        for fan_in, fan_out in zip(dims, dims[1:]):
            layers += [
                torch.nn.Linear(fan_in, fan_out),
                # Fresh instance per layer: the original reused one shared
                # module, which would tie parameters across layers for
                # parametric activations such as PReLU.
                activation(),
                torch.nn.Dropout(dropout),
            ]
        # Drop the trailing activation+dropout so the network outputs raw
        # logits from the last Linear.
        self.layers = torch.nn.ModuleList(layers[:-2])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
|
|
class GSFM(
    L.LightningModule,
    PyTorchModelHubMixin,
    tags=["gene", "gene set", "bioinformatics"],
):
    """Gene-set autoencoder: a multi-hot bag of token indices is encoded to a
    d_model-dim embedding and decoded back, trained with BCE-with-logits to
    reconstruct the input set (the PAD column is excluded on both sides).
    """

    def __init__(self, vocab_size, d_model=256, depth=2, dropout=0.2, partition=0, weighted_loss=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.depth = depth
        self.dropout = dropout
        # NOTE(review): `partition` and `weighted_loss` are stored as
        # hyperparameters but not used anywhere in this file — confirm
        # whether they are consumed elsewhere or are dead options.
        self.partition = partition
        self.weighted_loss = weighted_loss
        # Encoder widths for depth=2: (vocab_size, 2*d_model, d_model);
        # decoder widths: (d_model, d_model, vocab_size).
        # NOTE(review): the decoder hidden widths (2**(n-1) over range(1, depth))
        # are NOT the mirror of the encoder's — e.g. depth=3 gives encoder
        # (vocab, 4d, 2d, d) but decoder (d, d, 2d, vocab). Confirm whether
        # this asymmetry is intentional.
        self.encoder = MLP(vocab_size, *[d_model*(2**(n-1)) for n in range(depth, 1, -1)], d_model, dropout=dropout)
        self.decoder = MLP(d_model, *[d_model*(2**(n-1)) for n in range(1, depth)], vocab_size, dropout=dropout)
        self.save_hyperparameters()


    def encode(self, x):
        """Embed a batch of index tensors: multi-hot encode (PAD column
        zeroed so padding never contributes) and pass through the encoder."""
        x = multihot_tensor(x, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
        x[:, PAD_IDX] = 0
        return self.encoder(x)


    def forward(self, x):
        """Return reconstruction logits of shape (batch, vocab_size)."""
        x = self.encode(x)
        x = self.decoder(x)
        return x


    def training_step(self, batch, batch_idx):
        # Autoencoding objective: input indices are also the target.
        x_idx = y_idx = batch
        y_ = self(x_idx)
        y = multihot_tensor(y_idx, num_classes=self.vocab_size, device=self.device, dtype=torch.float)
        # Exclude PAD from the target, matching encode().
        y[:, PAD_IDX] = 0
        criterion = torch.nn.BCEWithLogitsLoss()
        loss = criterion(y_, y)
        self.log('loss', loss, prog_bar=True)
        return loss


    def validation_step(self, batch, batch_idx):
        # NOTE(review): delegates to training_step, so the validation loss is
        # also logged under the key 'loss' (same key the scheduler monitors)
        # — confirm this metric naming is intended.
        return self.training_step(batch, batch_idx)


    def configure_optimizers(self):
        """Adam (default lr) with ReduceLROnPlateau monitoring the logged
        'loss' metric, stepped every epoch."""
        optimizer = torch.optim.Adam(self.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.25)
        return [optimizer], [{
            "scheduler": scheduler,
            "monitor": "loss",
            "frequency": 1,
        }]
|
|