# openfold/model/heads.py (P2DFlow repository)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.loss import (
compute_plddt,
compute_tm,
compute_predicted_aligned_error,
)


class AuxiliaryHeads(nn.Module):
def __init__(self, config):
super(AuxiliaryHeads, self).__init__()
self.plddt = PerResidueLDDTCaPredictor(
**config["lddt"],
)
self.distogram = DistogramHead(
**config["distogram"],
)
self.masked_msa = MaskedMSAHead(
**config["masked_msa"],
)
self.experimentally_resolved = ExperimentallyResolvedHead(
**config["experimentally_resolved"],
)
if config.tm.enabled:
self.tm = TMScoreHead(
**config.tm,
)
self.config = config

    def forward(self, outputs):
aux_out = {}
lddt_logits = self.plddt(outputs["sm"]["single"])
aux_out["lddt_logits"] = lddt_logits
# Required for relaxation later on
aux_out["plddt"] = compute_plddt(lddt_logits)
distogram_logits = self.distogram(outputs["pair"])
aux_out["distogram_logits"] = distogram_logits
masked_msa_logits = self.masked_msa(outputs["msa"])
aux_out["masked_msa_logits"] = masked_msa_logits
experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"]
)
aux_out[
"experimentally_resolved_logits"
] = experimentally_resolved_logits
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm
)
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
**self.config.tm,
)
)
return aux_out
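

# A hedged usage sketch (an addition, not part of the original file) showing
# how AuxiliaryHeads might be wired up end to end. The channel sizes below are
# illustrative placeholders chosen to resemble the published AlphaFold 2
# dimensions, and ml_collections is assumed because OpenFold's model config is
# a ConfigDict, which supports both config["key"] and config.key access as
# used above.
def _example_auxiliary_heads():
    import ml_collections

    config = ml_collections.ConfigDict({
        "lddt": {"no_bins": 50, "c_in": 384, "c_hidden": 128},
        "distogram": {"c_z": 128, "no_bins": 64},
        "masked_msa": {"c_m": 256, "c_out": 23},
        "experimentally_resolved": {"c_s": 384, "c_out": 37},
        "tm": {"enabled": False},
    })
    heads = AuxiliaryHeads(config)

    n_seq, n_res = 4, 16
    outputs = {
        "sm": {"single": torch.randn(n_res, 384)},  # structure module single
        "pair": torch.randn(n_res, n_res, 128),
        "msa": torch.randn(n_seq, n_res, 256),
        "single": torch.randn(n_res, 384),
    }
    aux = heads(outputs)
    # aux now contains "lddt_logits", "plddt", "distogram_logits",
    # "masked_msa_logits", and "experimentally_resolved_logits".
    return aux
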
class PerResidueLDDTCaPredictor(nn.Module):
def __init__(self, no_bins, c_in, c_hidden):
super(PerResidueLDDTCaPredictor, self).__init__()
self.no_bins = no_bins
self.c_in = c_in
self.c_hidden = c_hidden
self.layer_norm = LayerNorm(self.c_in)
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
self.relu = nn.ReLU()

    def forward(self, s):
s = self.layer_norm(s)
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
return s
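

# For reference, a minimal sketch (an addition, mirroring what
# openfold.utils.loss.compute_plddt does) of how the per-residue logits above
# become pLDDT values in [0, 100]: a softmax over the bins followed by an
# expectation over the bin centers.
def _plddt_from_logits(lddt_logits: torch.Tensor) -> torch.Tensor:
    no_bins = lddt_logits.shape[-1]
    bin_width = 1.0 / no_bins
    # Bin centers at 0.5/no_bins, 1.5/no_bins, ..., scaled to percent below
    bin_centers = torch.arange(
        0.5 * bin_width, 1.0, bin_width, device=lddt_logits.device
    )
    probs = torch.softmax(lddt_logits, dim=-1)
    return torch.sum(probs * bin_centers, dim=-1) * 100.0
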
class DistogramHead(nn.Module):
"""
Computes a distogram probability distribution.

    For use in computation of distogram loss, subsection 1.9.8 of the
    AlphaFold 2 supplement
    """

    def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of distogram bins
"""
super(DistogramHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")

    def forward(self, z):  # [*, N, N, C_z]
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
                [*, N_res, N_res, no_bins] distogram logits, symmetrized
                over the two residue dimensions
"""
# [*, N, N, no_bins]
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
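

# The transpose-and-add in forward() makes the distogram symmetric in its two
# residue dimensions, since pairwise distances are symmetric. A quick sanity
# check (sizes here are arbitrary, not taken from any config):
def _check_distogram_symmetry():
    head = DistogramHead(c_z=128, no_bins=64)
    z = torch.randn(8, 8, 128)
    logits = head(z)
    assert torch.allclose(logits, logits.transpose(-2, -3))
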
class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
"""

    def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of bins
"""
super(TMScoreHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")

    def forward(self, z):
"""
Args:
z:
[*, N_res, N_res, C_z] pairwise embedding
Returns:
[*, N_res, N_res, no_bins] prediction
"""
# [*, N, N, no_bins]
logits = self.linear(z)
return logits
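

# A hedged sketch of how these pairwise logits become a predicted TM-score
# (pTM), as AuxiliaryHeads does above via openfold.utils.loss.compute_tm.
# max_bin=31 and no_bins=64 are the AlphaFold 2 defaults; they should match
# the values carried in config.tm.
def _example_ptm():
    head = TMScoreHead(c_z=128, no_bins=64)
    z = torch.randn(8, 8, 128)
    ptm = compute_tm(head(z), max_bin=31, no_bins=64)  # scalar in [0, 1]
    return ptm
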
class MaskedMSAHead(nn.Module):
"""
For use in computation of masked MSA loss, subsection 1.9.9
"""

    def __init__(self, c_m, c_out, **kwargs):
"""
Args:
c_m:
MSA channel dimension
c_out:
Output channel dimension
"""
super(MaskedMSAHead, self).__init__()
self.c_m = c_m
self.c_out = c_out
self.linear = Linear(self.c_m, self.c_out, init="final")

    def forward(self, m):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
Returns:
[*, N_seq, N_res, C_out] reconstruction
"""
# [*, N_seq, N_res, C_out]
logits = self.linear(m)
return logits
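

# These logits feed a BERT-style masked-token objective: cross-entropy against
# the true residue identities at masked MSA positions. A minimal sketch with
# illustrative sizes (c_out=23 matches AlphaFold 2's vocabulary of 20 amino
# acids plus the X, gap, and mask tokens):
def _example_masked_msa_loss():
    head = MaskedMSAHead(c_m=256, c_out=23)
    m = torch.randn(4, 16, 256)               # [N_seq, N_res, c_m]
    true_msa = torch.randint(0, 23, (4, 16))  # ground-truth token indices
    logits = head(m)
    return nn.functional.cross_entropy(
        logits.reshape(-1, 23), true_msa.reshape(-1)
    )
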
class ExperimentallyResolvedHead(nn.Module):
"""
For use in computation of "experimentally resolved" loss, subsection
1.9.10
"""

    def __init__(self, c_s, c_out, **kwargs):
        """
        Args:
            c_s:
                Input channel dimension
            c_out:
                Output channel dimension
"""
super(ExperimentallyResolvedHead, self).__init__()
self.c_s = c_s
self.c_out = c_out
self.linear = Linear(self.c_s, self.c_out, init="final")

    def forward(self, s):
"""
Args:
s:
[*, N_res, C_s] single embedding
Returns:
                [*, N_res, C_out] logits
"""
# [*, N, C_out]
logits = self.linear(s)
return logits
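

# Since each of the c_out logits is an independent binary prediction (in
# AlphaFold 2, one per atom37 atom type: "was this atom resolved in the
# experiment?"), a sigmoid rather than a softmax converts them to
# probabilities. A minimal sketch with illustrative sizes:
def _example_resolved_probs():
    head = ExperimentallyResolvedHead(c_s=384, c_out=37)
    s = torch.randn(16, 384)        # [N_res, c_s] single representation
    return torch.sigmoid(head(s))   # [16, 37], values in (0, 1)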