| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, Sequence, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from openfold.model.dropout import ( |
| DropoutRowwise, |
| DropoutColumnwise, |
| ) |
| from openfold.model.evoformer import ( |
| EvoformerBlock, |
| EvoformerStack, |
| ) |
| from openfold.model.outer_product_mean import OuterProductMean |
| from openfold.model.msa import ( |
| MSARowAttentionWithPairBias, |
| MSAColumnAttention, |
| MSAColumnGlobalAttention, |
| ) |
| from openfold.model.pair_transition import PairTransition |
| from openfold.model.primitives import Attention, GlobalAttention |
| from openfold.model.structure_module import ( |
| InvariantPointAttention, |
| BackboneUpdate, |
| ) |
| from openfold.model.template import TemplatePairStackBlock |
| from openfold.model.triangular_attention import ( |
| TriangleAttentionStartingNode, |
| TriangleAttentionEndingNode, |
| ) |
| from openfold.model.triangular_multiplicative_update import ( |
| TriangleMultiplicationOutgoing, |
| TriangleMultiplicationIncoming, |
| ) |
|
|
|
|
| def script_preset_(model: torch.nn.Module): |
| """ |
| TorchScript a handful of low-level but frequently used submodule types |
| that are known to be scriptable. |
| |
| Args: |
| model: |
| A torch.nn.Module. It should contain at least some modules from |
| this repository, or this function won't do anything. |
| """ |
| script_submodules_( |
| model, |
| [ |
| nn.Dropout, |
| Attention, |
| GlobalAttention, |
| EvoformerBlock, |
| |
| ], |
| attempt_trace=False, |
| batch_dims=None, |
| ) |
|
|
| |
| def _get_module_device(module: torch.nn.Module) -> torch.device: |
| """ |
| Fetches the device of a module, assuming that all of the module's |
| parameters reside on a single device |
| |
| Args: |
| module: A torch.nn.Module |
| Returns: |
| The module's device |
| """ |
| return next(module.parameters()).device |
|
|
|
|
| def _trace_module(module, batch_dims=None): |
| if(batch_dims is None): |
| batch_dims = () |
|
|
| |
| n_seq = 10 |
| n_res = 10 |
|
|
| device = _get_module_device(module) |
|
|
| def msa(channel_dim): |
| return torch.rand( |
| (*batch_dims, n_seq, n_res, channel_dim), |
| device=device, |
| ) |
|
|
| def pair(channel_dim): |
| return torch.rand( |
| (*batch_dims, n_res, n_res, channel_dim), |
| device=device, |
| ) |
|
|
| if(isinstance(module, MSARowAttentionWithPairBias)): |
| inputs = { |
| "forward": ( |
| msa(module.c_in), |
| pair(module.c_z), |
| torch.randint( |
| 0, 2, |
| (*batch_dims, n_seq, n_res) |
| ), |
| ), |
| } |
| elif(isinstance(module, MSAColumnAttention)): |
| inputs = { |
| "forward": ( |
| msa(module.c_in), |
| torch.randint( |
| 0, 2, |
| (*batch_dims, n_seq, n_res) |
| ), |
| ), |
| } |
| elif(isinstance(module, OuterProductMean)): |
| inputs = { |
| "forward": ( |
| msa(module.c_m), |
| torch.randint( |
| 0, 2, |
| (*batch_dims, n_seq, n_res) |
| ) |
| ) |
| } |
| else: |
| raise TypeError( |
| f"tracing is not supported for modules of type {type(module)}" |
| ) |
|
|
| return torch.jit.trace_module(module, inputs) |
|
|
|
|
| def _script_submodules_helper_( |
| model, |
| types, |
| attempt_trace, |
| to_trace, |
| ): |
| for name, child in model.named_children(): |
| if(types is None or any(isinstance(child, t) for t in types)): |
| try: |
| scripted = torch.jit.script(child) |
| setattr(model, name, scripted) |
| continue |
| except (RuntimeError, torch.jit.frontend.NotSupportedError) as e: |
| if(attempt_trace): |
| to_trace.add(type(child)) |
| else: |
| raise e |
| |
| _script_submodules_helper_(child, types, attempt_trace, to_trace) |
|
|
|
|
| def _trace_submodules_( |
| model, |
| types, |
| batch_dims=None, |
| ): |
| for name, child in model.named_children(): |
| if(any(isinstance(child, t) for t in types)): |
| traced = _trace_module(child, batch_dims=batch_dims) |
| setattr(model, name, traced) |
| else: |
| _trace_submodules_(child, types, batch_dims=batch_dims) |
|
|
|
|
| def script_submodules_( |
| model: nn.Module, |
| types: Optional[Sequence[type]] = None, |
| attempt_trace: Optional[bool] = True, |
| batch_dims: Optional[Tuple[int]] = None, |
| ): |
| """ |
| Convert all submodules whose types match one of those in the input |
| list to recursively scripted equivalents in place. To script the entire |
| model, just call torch.jit.script on it directly. |
| |
| When types is None, all submodules are scripted. |
| |
| Args: |
| model: |
| A torch.nn.Module |
| types: |
| A list of types of submodules to script |
| attempt_trace: |
| Whether to attempt to trace specified modules if scripting |
| fails. Recall that tracing eliminates all conditional |
| logic---with great tracing comes the mild responsibility of |
| having to remember to ensure that the modules in question |
| perform the same computations no matter what. |
| """ |
| to_trace = set() |
|
|
| |
| _script_submodules_helper_(model, types, attempt_trace, to_trace) |
| |
| |
| if(attempt_trace and len(to_trace) > 0): |
| _trace_submodules_(model, to_trace, batch_dims=batch_dims) |
|
|