| |
| |
| |
| |
| |
|
|
import argparse
import pathlib
import sys

# Print the interpreter in use up front — this script manipulates sys.path
# below, so knowing which Python is running helps debug environment mix-ups.
print("using", sys.executable)

# NOTE(review): hard-coded, user-specific paths — presumably a deployment
# workaround to pick up a local site-packages and an in-tree `esm` checkout.
# Confirm these are still required before removing. They must run before the
# `esm`/`torch` imports below so those resolve to the intended copies.
sys.path.insert(0, "/home/user/.local/lib/python3.8/site-packages")
sys.path.insert(0, "/home/user/app/esm/")

import os  # noqa: E402

import torch  # noqa: E402

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer  # noqa: E402
|
|
|
|
def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the representation-extraction script.

    Positional arguments: a model location, an input FASTA file, and an
    output directory. Options control batching, which layers and which
    summary representations to extract, sequence truncation, and GPU use.

    Returns:
        argparse.ArgumentParser: the configured parser (call ``parse_args``
        on it to obtain the namespace consumed by ``main``).
    """
    parser = argparse.ArgumentParser(
        description="Extract per-token representations and model outputs for sequences in a FASTA file"
    )

    parser.add_argument(
        "model_location",
        type=str,
        help="PyTorch model file OR name of pretrained model to download (see README for models)",
    )
    parser.add_argument(
        "fasta_file",
        type=pathlib.Path,
        help="FASTA file on which to extract representations",
    )
    parser.add_argument(
        "output_dir",
        type=pathlib.Path,
        help="output directory for extracted representations",
    )

    # --toks_per_batch caps the *token* count per batch, not the sequence count.
    parser.add_argument("--toks_per_batch", type=int, default=4096, help="maximum batch size")
    parser.add_argument(
        "--repr_layers",
        type=int,
        default=[-1],
        nargs="+",
        # Fixed typo in user-facing help: "layers indices" -> "layer indices".
        help="layer indices from which to extract representations (0 to num_layers, inclusive)",
    )
    parser.add_argument(
        "--include",
        type=str,
        nargs="+",
        choices=["mean", "per_tok", "bos", "contacts"],
        help="specify which representations to return",
        required=True,
    )
    parser.add_argument(
        "--truncation_seq_length",
        type=int,
        default=1022,
        help="truncate sequences longer than the given value",
    )

    parser.add_argument("--nogpu", action="store_true", help="Do not use GPU even if available")
    return parser
|
|
|
|
def main(args):
    """Extract representations from a FASTA file and save one ``.pt`` per sequence.

    Loads the model named by ``args.model_location``, batches the FASTA
    sequences by token count, runs a forward pass, and writes a dict per
    sequence label to ``args.output_dir / f"{label}.pt"`` containing the
    subset of {representations, mean_representations, bos_representations,
    contacts} selected via ``args.include``.

    Raises:
        ValueError: if the loaded model is an MSA Transformer (unsupported here).
    """
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )

    # Decide device policy once; reused below for the per-batch token transfer.
    use_gpu = torch.cuda.is_available() and not args.nogpu
    if use_gpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    # extra_toks_per_seq=1 budgets for the BOS token prepended by the batch converter.
    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(args.truncation_seq_length),
        batch_sampler=batches,
    )
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in args.include

    # Map possibly-negative layer indices (e.g. -1 = last) onto 0..num_layers.
    assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if use_gpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)

            # out["logits"] is never saved by this script, so it is not moved to CPU.
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for i, label in enumerate(labels):
                # Local variable instead of mutating the args namespace.
                output_file = args.output_dir / f"{label}.pt"
                output_file.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}

                # BUG FIX: the batch converter truncates sequences longer than
                # args.truncation_seq_length, so slicing by len(strs[i]) would
                # overrun the token dimension and silently include the EOS
                # position (and mis-slice contacts). Slice by the truncated length.
                truncate_len = min(args.truncation_seq_length, len(strs[i]))

                # Token 0 is BOS; residue representations live at 1..truncate_len.
                if "per_tok" in args.include:
                    result["representations"] = {
                        layer: t[i, 1 : truncate_len + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in args.include:
                    result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in args.include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone() for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[i, :truncate_len, :truncate_len].clone()

                torch.save(result, output_file)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI and run the extraction pipeline.
    cli_args = create_parser().parse_args()
    main(cli_args)
|
|