| |
|
|
| |
| |
|
|
|
|
| from pathlib import Path |
| import sys,os |
| import argparse |
| import logging |
| import sys |
| import typing as T |
| from pathlib import Path |
| from timeit import default_timer as timer |
|
|
| import torch |
|
|
| import esm |
| from esm.data import read_fasta |
|
|
# Configure the root logger so every module-level logger inherits INFO-level
# output; handler wiring happens once at import time.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Timestamped record format, e.g. "23/01/01 12:00:00 | INFO | root | msg".
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%y/%m/%d %H:%M:%S",
)

# Emit to stdout (not the default stderr) so logs interleave with normal
# console output when the script is piped or captured.
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Anything accepted as a filesystem path by this script.
PathLike = T.Union[str, Path]
|
|
|
|
def enable_cpu_offloading(model):
    """Wrap ``model``'s transformer layers in FSDP with CPU parameter offload.

    Parameters stay on CPU and are streamed to the GPU layer-by-layer during
    the forward pass, trading speed for a much smaller GPU memory footprint.

    Args:
        model: A module with a ``.layers`` child container (the ESM LM trunk).

    Returns:
        The FSDP-wrapped model.

    Raises:
        RuntimeError: if the distributed backend cannot be initialized
            (e.g. no NCCL/GPU available).
    """
    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import enable_wrap, wrap

    # FSDP needs a process group even for a single process. Guard against
    # re-initialization: the original unconditionally called
    # init_process_group, which raises if this function runs twice.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend="nccl", init_method="tcp://localhost:9999", world_size=1, rank=0
        )

    wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))

    with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
        # Wrap each transformer layer individually so its parameters can be
        # offloaded/fetched independently, then wrap the whole model.
        for layer_name, layer in model.layers.named_children():
            wrapped_layer = wrap(layer)
            setattr(model.layers, layer_name, wrapped_layer)
        model = wrap(model)

    return model
|
|
|
|
def init_model_on_gpu_with_cpu_offloading(model):
    """Move ``model`` to the GPU while keeping its ESM submodule CPU-offloaded.

    The folding trunk/head live on the GPU; the large ESM language model is
    wrapped with FSDP CPU offload (see ``enable_cpu_offloading``).
    """
    model = model.eval()
    model_esm = enable_cpu_offloading(model.esm)
    # Detach the ESM submodule BEFORE .cuda() so its (very large) weights are
    # not copied to the GPU; they stay under the FSDP CPU-offload wrapper.
    del model.esm
    model.cuda()
    # Reattach the offloaded ESM after the rest of the model is on the GPU.
    model.esm = model_esm
    return model
|
|
|
|
def create_batched_sequence_datasest(
    sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
    """Group ``(header, sequence)`` pairs into batches of bounded total size.

    Batches hold at most ``max_tokens_per_batch`` residues in total; a single
    sequence longer than the limit is still emitted as its own batch.

    Args:
        sequences: Iterable of (header, sequence) pairs.
        max_tokens_per_batch: Residue budget per yielded batch.

    Yields:
        ``(headers, sequences)`` list pairs. Never yields an empty batch
        (fix: the original yielded ``([], [])`` for empty input, which would
        trigger a pointless ``model.infer([])`` call downstream).
    """
    batch_headers: T.List[str] = []
    batch_sequences: T.List[str] = []
    num_tokens = 0
    for header, seq in sequences:
        # Flush the current batch if adding this sequence would blow the
        # budget — but never flush an empty batch (num_tokens > 0 guard),
        # so an oversized single sequence still gets emitted alone.
        if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
            yield batch_headers, batch_sequences
            batch_headers, batch_sequences, num_tokens = [], [], 0
        batch_headers.append(header)
        batch_sequences.append(seq)
        num_tokens += len(seq)

    # Emit the trailing partial batch only if it actually holds sequences.
    if batch_headers:
        yield batch_headers, batch_sequences
|
|
|
|
def create_parser():
    """Build the command-line argument parser for ESMFold bulk inference."""
    parser = argparse.ArgumentParser()

    # Required I/O locations.
    parser.add_argument(
        "-i", "--fasta", type=Path, required=True,
        help="Path to input FASTA file",
    )
    parser.add_argument(
        "-o", "--pdb", type=Path, required=True,
        help="Path to output PDB directory",
    )

    # Optional weights location (defaults to torch.hub cache).
    parser.add_argument(
        "-m", "--model-dir", type=Path, default=None,
        help="Parent path to Pretrained ESM data directory. ",
    )

    # Inference-tuning knobs.
    parser.add_argument(
        "--num-recycles", type=int, default=None,
        help="Number of recycles to run. Defaults to number used in training (4).",
    )
    parser.add_argument(
        "--max-tokens-per-batch", type=int, default=1024,
        help=(
            "Maximum number of tokens per gpu forward-pass. This will group shorter sequences together "
            "for batched prediction. Lowering this can help with out of memory issues, if these occur on "
            "short sequences."
        ),
    )
    parser.add_argument(
        "--chunk-size", type=int, default=None,
        help=(
            "Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). "
            "Equivalent to running a for loop over chunks of of each dimension. Lower values will "
            "result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. "
            "Default: None."
        ),
    )

    # Device-placement flags.
    parser.add_argument("--cpu-only", help="CPU only", action="store_true")
    parser.add_argument("--cpu-offload", help="Enable CPU offloading", action="store_true")
    return parser
|
|
|
|
def run(args):
    """Predict structures for every sequence in ``args.fasta`` and write PDBs.

    Loads ESMFold, batches the input sequences by token budget, runs
    inference batch-by-batch (skipping batches that hit CUDA OOM), and
    writes one ``<header>.pdb`` file per sequence into ``args.pdb``.

    Raises:
        FileNotFoundError: if the input FASTA does not exist.
    """
    if not args.fasta.exists():
        raise FileNotFoundError(args.fasta)

    args.pdb.mkdir(exist_ok=True)

    # Sort by sequence length so similarly-sized sequences batch together,
    # minimizing padding waste per forward pass.
    logger.info(f"Reading sequences from {args.fasta}")
    all_sequences = sorted(read_fasta(args.fasta), key=lambda header_seq: len(header_seq[1]))
    logger.info(f"Loaded {len(all_sequences)} sequences from {args.fasta}")

    logger.info("Loading model")

    # Point torch.hub at a pre-downloaded weights directory, if given.
    if args.model_dir is not None:
        torch.hub.set_dir(args.model_dir)

    model = esm.pretrained.esmfold_v1()

    model = model.eval()
    model.set_chunk_size(args.chunk_size)

    # Device placement: plain CPU, GPU with CPU-offloaded ESM, or plain GPU.
    if args.cpu_only:
        # fp16 ESM is not usable on CPU; promote to fp32 first.
        model.esm.float()
        model.cpu()
    elif args.cpu_offload:
        model = init_model_on_gpu_with_cpu_offloading(model)
    else:
        model.cuda()
    logger.info("Starting Predictions")
    batched_sequences = create_batched_sequence_datasest(all_sequences, args.max_tokens_per_batch)

    num_completed = 0
    num_sequences = len(all_sequences)
    for headers, sequences in batched_sequences:
        start = timer()
        try:
            output = model.infer(sequences, num_recycles=args.num_recycles)
        except RuntimeError as e:
            # CUDA OOM: log advice and skip this batch; re-raise anything else.
            if e.args[0].startswith("CUDA out of memory"):
                if len(sequences) > 1:
                    logger.info(
                        f"Failed (CUDA out of memory) to predict batch of size {len(sequences)}. "
                        "Try lowering `--max-tokens-per-batch`."
                    )
                else:
                    logger.info(
                        f"Failed (CUDA out of memory) on sequence {headers[0]} of length {len(sequences[0])}."
                    )

                continue
            raise

        # Move outputs to CPU before PDB conversion to free GPU memory.
        output = {key: value.cpu() for key, value in output.items()}
        pdbs = model.output_to_pdb(output)
        tottime = timer() - start
        # Per-sequence wall time; amortized over the batch when batched.
        time_string = f"{tottime / len(headers):0.1f}s"
        if len(sequences) > 1:
            time_string = time_string + f" (amortized, batch size {len(sequences)})"
        for header, seq, pdb_string, mean_plddt, ptm in zip(
            headers, sequences, pdbs, output["mean_plddt"], output["ptm"]
        ):
            # NOTE(review): header is used verbatim as a filename — assumes
            # FASTA headers contain no path separators; confirm upstream.
            output_file = args.pdb / f"{header}.pdb"
            output_file.write_text(pdb_string)
            num_completed += 1
            logger.info(
                f"Predicted structure for {header} with length {len(seq)}, pLDDT {mean_plddt:0.1f}, "
                f"pTM {ptm:0.3f} in {time_string}. "
                f"{num_completed} / {num_sequences} completed."
            )
|
|
|
|
def main():
    """Script entry point: parse CLI arguments and run bulk prediction."""
    run(create_parser().parse_args())


if __name__ == "__main__":
    main()
|
|