import itertools
import os
from typing import Sequence, Tuple, List, Union
import pickle
import re
import shutil
import torch
from pathlib import Path
from esm.constants import proteinseq_toks


RawMSA = Sequence[Tuple[str, str]]
class FastaBatchedDataset(object):
    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:
                    buf.append(line.strip())

        _flush_current_seq()

        assert len(set(sequence_labels)) == len(
            sequence_labels
        ), "Found duplicate sequence labels"

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]
    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        """Greedily bucket sequences (sorted by length) into batches whose padded
        token count stays at or below ``toks_per_batch`` (a single sequence longer
        than the budget still forms its own batch). Returns a list of batches,
        each a list of dataset indices."""
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches
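

# A minimal usage sketch (the FASTA path "proteins.fasta" is hypothetical, and
# `alphabet` refers to an Alphabet instance built further below in this module):
#
#     dataset = FastaBatchedDataset.from_file("proteins.fasta")
#     batches = dataset.get_batch_indices(toks_per_batch=4096, extra_toks_per_seq=1)
#     loader = torch.utils.data.DataLoader(
#         dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
#     )
#     # each iteration yields (labels, strs, tokens) as produced by BatchConverter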


class Alphabet(object):
    def __init__(
        self,
        standard_toks: Sequence[str],
        prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
        append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
        prepend_bos: bool = True,
        append_eos: bool = False,
        use_msa: bool = False,
    ):
        self.standard_toks = list(standard_toks)
        self.prepend_toks = list(prepend_toks)
        self.append_toks = list(append_toks)
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.use_msa = use_msa

        self.all_toks = list(self.prepend_toks)
        self.all_toks.extend(self.standard_toks)
        # pad with extra <null_i> tokens so the prepend + standard vocabulary
        # reaches a multiple of 8 before the append tokens are added
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks.extend(self.append_toks)

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.unk_idx = self.tok_to_idx["<unk>"]
        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.eos_idx = self.get_idx("<eos>")
        self.all_special_tokens = ["<eos>", "<unk>", "<pad>", "<cls>", "<mask>"]
        self.unique_no_split_tokens = self.all_toks

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx.get(tok, self.unk_idx)

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        return self.tok_to_idx.copy()

    def get_batch_converter(self, truncation_seq_length: int = None):
        if self.use_msa:
            return MSABatchConverter(self, truncation_seq_length)
        else:
            return BatchConverter(self, truncation_seq_length)

    @classmethod
    def from_architecture(cls, name: str) -> "Alphabet":
        if name in ("ESM-1", "protein_bert_base"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        elif name in ("ESM-1b", "roberta_large"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = True
            use_msa = False
        elif name in ("MSA Transformer", "msa_transformer"):
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>",)
            prepend_bos = True
            append_eos = False
            use_msa = True
        elif "invariant_gvp" in name.lower():
            standard_toks = proteinseq_toks["toks"]
            prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
            append_toks = ("<mask>", "<cath>", "<af2>")
            prepend_bos = True
            append_eos = False
            use_msa = False
        else:
            raise ValueError("Unknown architecture selected")
        return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
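
    # Example (sketch): building an alphabet by architecture name; the flag
    # values below follow directly from the branches above.
    #
    #     alphabet = Alphabet.from_architecture("ESM-1b")
    #     alphabet.prepend_bos, alphabet.append_eos  # -> (True, True)
    #     alphabet.padding_idx, alphabet.cls_idx     # integer indices into alphabet.all_toks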

    def _tokenize(self, text) -> List[str]:
        return text.split()

    def tokenize(self, text, **kwargs) -> List[str]:
        """
        Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
        Converts a string into a sequence of tokens, using the tokenizer.

        Args:
            text (:obj:`str`):
                The sequence to be encoded.

        Returns:
            :obj:`List[str]`: The list of tokens.
        """

        def split_on_token(tok, text):
            result = []
            split_text = text.split(tok)
            for i, sub_text in enumerate(split_text):
                # strip whitespace adjacent to the matched token so the
                # surrounding pieces do not keep stray spaces
                if i < len(split_text) - 1:
                    sub_text = sub_text.rstrip()
                if i > 0:
                    sub_text = sub_text.lstrip()

                if i == 0 and not sub_text:
                    result.append(tok)
                elif i == len(split_text) - 1:
                    if sub_text:
                        result.append(sub_text)
                    else:
                        pass
                else:
                    if sub_text:
                        result.append(sub_text)
                    result.append(tok)
            return result

        def split_on_tokens(tok_list, text):
            if not text.strip():
                return []

            tokenized_text = []
            text_list = [text]
            for tok in tok_list:
                tokenized_text = []
                for sub_text in text_list:
                    if sub_text not in self.unique_no_split_tokens:
                        tokenized_text.extend(split_on_token(tok, sub_text))
                    else:
                        tokenized_text.append(sub_text)
                text_list = tokenized_text

            return list(
                itertools.chain.from_iterable(
                    (
                        self._tokenize(token)
                        if token not in self.unique_no_split_tokens
                        else [token]
                        for token in tokenized_text
                    )
                )
            )

        no_split_token = self.unique_no_split_tokens
        tokenized_text = split_on_tokens(no_split_token, text)
        return tokenized_text

    def encode(self, text):
        return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
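
    # Example (sketch): every entry of ``all_toks`` (including the single-letter
    # residue tokens) is treated as a no-split token, so plain sequences split
    # into single residues while special tokens stay intact.
    #
    #     alphabet = Alphabet.from_architecture("ESM-1b")
    #     alphabet.tokenize("MK<mask>TV")  # -> ['M', 'K', '<mask>', 'T', 'V']
    #     alphabet.encode("MKTV")          # -> [alphabet.tok_to_idx[t] for t in "MKTV"]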


class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet, truncation_seq_length: int = None):
        self.alphabet = alphabet
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
        batch_size = len(raw_batch)
        batch_labels, seq_str_list = zip(*raw_batch)
        seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
        if self.truncation_seq_length:
            seq_encoded_list = [seq[: self.truncation_seq_length] for seq in seq_encoded_list]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        tokens = torch.empty(
            (
                batch_size,
                max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str, seq_encoded) in enumerate(
            zip(batch_labels, seq_str_list, seq_encoded_list)
        ):
            labels.append(label)
            strs.append(seq_str)
            if self.alphabet.prepend_bos:
                tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[
                i,
                int(self.alphabet.prepend_bos) : len(seq_encoded)
                + int(self.alphabet.prepend_bos),
            ] = seq
            if self.alphabet.append_eos:
                tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx

        return labels, strs, tokens


class MSABatchConverter(BatchConverter):
    def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
        if isinstance(inputs[0][0], str):
            # a single MSA was passed; wrap it into a batch of one
            raw_batch: Sequence[RawMSA] = [inputs]
        else:
            raw_batch = inputs

        batch_size = len(raw_batch)
        max_alignments = max(len(msa) for msa in raw_batch)
        max_seqlen = max(len(msa[0][1]) for msa in raw_batch)

        tokens = torch.empty(
            (
                batch_size,
                max_alignments,
                max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
            ),
            dtype=torch.int64,
        )
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, msa in enumerate(raw_batch):
            msa_seqlens = set(len(seq) for _, seq in msa)
            if not len(msa_seqlens) == 1:
                raise RuntimeError(
                    "Received unaligned sequences for input to MSA, all sequence "
                    "lengths must be equal."
                )
            msa_labels, msa_strs, msa_tokens = super().__call__(msa)
            labels.append(msa_labels)
            strs.append(msa_strs)
            tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens

        return labels, strs, tokens
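

# Usage sketch (the aligned sequences are hypothetical; assumes an MSA-style
# alphabet, e.g. Alphabet.from_architecture("msa_transformer")):
#
#     msa = [("seq1", "MKTV-A"), ("seq2", "MKTVQA"), ("seq3", "MK-VQA")]
#     msa_batch_converter = alphabet.get_batch_converter()
#     labels, strs, tokens = msa_batch_converter(msa)
#     # `tokens` has shape (num_msas, max_alignments, max_seqlen + bos/eos columns);
#     # `labels` and `strs` are nested lists, one inner list per MSA.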


def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    with open(path, "r") as f:
        for result in read_alignment_lines(
            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
        ):
            yield result


def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        if len(line) > 0 and line[0] == ">":
            if seq is not None:
                yield desc, parse(seq)
            desc = line.strip().lstrip(">")
            seq = ""
        else:
            assert isinstance(seq, str)
            seq += line.strip()
    assert isinstance(seq, str) and isinstance(desc, str)
    yield desc, parse(seq)
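

# Usage sketch (the alignment path "alignment.a3m" is hypothetical; lowercase
# residues are treated as insertions and dropped when keep_insertions=False):
#
#     msa = [
#         (desc, seq)
#         for desc, seq in read_fasta("alignment.a3m", keep_insertions=False)
#     ]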


class ESMStructuralSplitDataset(torch.utils.data.Dataset):
    """
    Structural Split Dataset as described in section A.10 of the supplement of our paper.
    https://doi.org/10.1101/622803

    We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
    generated on January 23, 2020.

    For each SCOPe domain:
        - We extract the sequence from the corresponding PDB file
        - We extract the 3D coordinates of the Carbon beta atoms, aligning them
          to the sequence. We put NaN where Cb atoms are missing.
        - From the 3D coordinates, we calculate a pairwise distance map, based
          on L2 distance
        - We use DSSP to generate secondary structure labels for the corresponding
          PDB file. This is also aligned to the sequence. We put - where SSP
          labels are missing.

    For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
    we have split the data into 5 partitions for cross validation. These are provided
    in a downloaded splits folder, in the format:
        splits/{split_level}/{cv_partition}/{train|valid}.txt
    where train is the partition and valid is the concatenation of the remaining 4.

    For each SCOPe domain, we provide a pkl dump that contains:
        - seq : The domain sequence, stored as an L-length string
        - ssp : The secondary structure labels, stored as an L-length string
        - dist : The distance map, stored as an LxL numpy array
        - coords : The 3D coordinates, stored as an Lx3 numpy array
    """

    base_folder = "structural-data"
    file_list = [
        # (url, tar filename, extracted directory name, md5 of the tar file)
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
            "splits.tar.gz",
            "splits",
            "456fe1c7f22c9d3d8dfe9735da52411d",
        ),
        (
            "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
            "pkl.tar.gz",
            "pkl",
            "644ea91e56066c750cd50101d390f5db",
        ),
    ]

    def __init__(
        self,
        split_level,
        cv_partition,
        split,
        root_path=os.path.expanduser("~/.cache/torch/data/esm"),
        download=False,
    ):
        super().__init__()
        assert split in [
            "train",
            "valid",
        ], "split must be 'train' or 'valid'"
        self.root_path = root_path
        self.base_path = os.path.join(self.root_path, self.base_folder)

        if download:
            self.download()

        self.split_file = os.path.join(
            self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
        )
        self.pkl_dir = os.path.join(self.base_path, "pkl")
        self.names = []
        with open(self.split_file) as f:
            self.names = f.read().splitlines()

    def __len__(self):
        return len(self.names)

    def _check_exists(self) -> bool:
        for (_, _, filename, _) in self.file_list:
            fpath = os.path.join(self.base_path, filename)
            if not os.path.exists(fpath) or not os.path.isdir(fpath):
                return False
        return True

    def download(self):
        if self._check_exists():
            print("Files already downloaded and verified")
            return

        from torchvision.datasets.utils import download_url

        for url, tar_filename, filename, md5_hash in self.file_list:
            download_path = os.path.join(self.base_path, tar_filename)
            download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
            shutil.unpack_archive(download_path, self.base_path)

    def __getitem__(self, idx):
        """
        Returns a dict with the following entries
            - seq : Str (domain sequence)
            - ssp : Str (SSP labels)
            - dist : np.array (distance map)
            - coords : np.array (3D coordinates)
        """
        name = self.names[idx]
        pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
        with open(pkl_fname, "rb") as f:
            obj = pickle.load(f)
        return obj