from typing import Any, Dict, List

import collections
import dataclasses
import io
import os
import pickle
import string

import numpy as np
import torch
from Bio import PDB
from Bio.PDB.Chain import Chain
from openfold.utils import rigid_utils as ru
from torch_scatter import scatter, scatter_add

from data import protein
from data import residue_constants

Rigid = ru.Rigid
Protein = protein.Protein

# Chain IDs are encoded as integers over the alphanumeric alphabet
# (letters, digits, and space).
ALPHANUMERIC = string.ascii_letters + string.digits + ' '
CHAIN_TO_INT = {
    chain_char: i for i, chain_char in enumerate(ALPHANUMERIC)
}
INT_TO_CHAIN = {
    i: chain_char for i, chain_char in enumerate(ALPHANUMERIC)
}

# Unit conversions between nanometers and Angstroms.
NM_TO_ANG_SCALE = 10.0
ANG_TO_NM_SCALE = 1 / NM_TO_ANG_SCALE

# Per-chain features kept when parsing a PDB chain.
CHAIN_FEATS = [
    'atom_positions', 'aatype', 'atom_mask', 'residue_index', 'b_factors'
]

# Move a tensor to a numpy array; map an aatype index sequence to a
# one-letter amino-acid string.
to_numpy = lambda x: x.detach().cpu().numpy()
aatype_to_seq = lambda aatype: ''.join([
    residue_constants.restypes_with_x[x] for x in aatype])


class CPU_Unpickler(pickle.Unpickler):
    """PyTorch pickle loading workaround.

    https://github.com/pytorch/pytorch/issues/16797
    """
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        else:
            return super().find_class(module, name)


def create_rigid(rots, trans):
    """Build a Rigid transform from rotation matrices and translations."""
    rots = ru.Rotation(rot_mats=rots)
    return Rigid(rots=rots, trans=trans)


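# Illustrative sketch (not part of the original module): building identity
# frames with create_rigid. The shapes are arbitrary, and the example assumes
# openfold's Rigid exposes a get_trans() accessor, as in recent versions.
def _example_create_rigid(num_res: int = 8) -> Rigid:
    rots = torch.eye(3).expand(num_res, 3, 3)   # [N, 3, 3] rotation matrices
    trans = torch.zeros(num_res, 3)             # [N, 3] translations
    rigids = create_rigid(rots, trans)
    assert rigids.get_trans().shape == (num_res, 3)
    return rigids

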
def batch_align_structures(pos_1, pos_2, mask=None):
    """Align each structure in pos_1 to the corresponding structure in pos_2.

    Args:
        pos_1: [B, N, 3] batch of coordinates to align.
        pos_2: [B, N, 3] batch of reference coordinates.
        mask: optional [B, N] mask of positions to use for the alignment.

    Returns:
        Aligned pos_1, centered pos_2, and the per-batch rotation matrices.
    """
    if pos_1.shape != pos_2.shape:
        raise ValueError('pos_1 and pos_2 must have the same shape.')
    if pos_1.ndim != 3:
        raise ValueError('Expected inputs to have shape [B, N, 3].')
    num_batch = pos_1.shape[0]
    device = pos_1.device

    # Assign each position to its batch element for the sparse alignment.
    batch_indices = (
        torch.ones(*pos_1.shape[:2], device=device, dtype=torch.int64)
        * torch.arange(num_batch, device=device)[:, None]
    )
    flat_pos_1 = pos_1.reshape(-1, 3)
    flat_pos_2 = pos_2.reshape(-1, 3)
    flat_batch_indices = batch_indices.reshape(-1)
    if mask is None:
        mask = torch.ones(*pos_1.shape[:2], device=device).bool()
    flat_mask = mask.reshape(-1).bool()

    aligned_pos_1, aligned_pos_2, align_rots = align_structures(
        flat_pos_1[flat_mask],
        flat_batch_indices[flat_mask],
        flat_pos_2[flat_mask])
    aligned_pos_1 = aligned_pos_1.reshape(num_batch, -1, 3)
    aligned_pos_2 = aligned_pos_2.reshape(num_batch, -1, 3)
    return aligned_pos_1, aligned_pos_2, align_rots


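# Illustrative sketch (not part of the original module): recovering a known
# rotation with batch_align_structures and measuring the post-alignment RMSD.
# The sizes and the rotation below are arbitrary.
def _example_batch_align(num_batch: int = 2, num_res: int = 16) -> torch.Tensor:
    ref = torch.randn(num_batch, num_res, 3)
    # Rotate the reference by 90 degrees about the z-axis to create the "model".
    rot_z = torch.tensor([
        [0.0, -1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
    ])
    model = ref @ rot_z.T
    aligned_model, centered_ref, _ = batch_align_structures(model, ref)
    # RMSD should be ~0 since the structures differ only by a rigid rotation.
    return torch.sqrt(((aligned_model - centered_ref) ** 2).sum(dim=-1).mean())

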
def adjust_oxygen_pos(
    atom_37: torch.Tensor, pos_is_known=None
) -> torch.Tensor:
    """Imputes the position of the backbone oxygen atom from adjacent frames.

    The oxygen is placed in the plane defined by the CA and C atoms of the
    current frame and the N atom of the next frame, at a C-O bond length of
    1.23 Angstrom from the C atom, pointing away from the CA-C-N triangle.

    When the next frame is unavailable (the C-terminus, or the next residue is
    missing from the data), the oxygen is instead placed in the plane of the
    current frame's N-CA-C atoms, along the average of the CA->C and CA->N
    directions.

    Args:
        atom_37 (torch.Tensor): (N, 37, 3) backbone atom positions in atom37
            ordering, i.e. ['N', 'CA', 'C', 'CB', 'O', ...].
        pos_is_known (torch.Tensor): (N,) mask for known residues.

    Returns:
        The atom_37 tensor with the oxygen (index 4) positions imputed.
    """
    N = atom_37.shape[0]
    assert atom_37.shape == (N, 37, 3)

    # Unit vector from CA to C within the current frame. (N-1, 3)
    # Masked positions are all zeros, so 1e-7 avoids division by zero.
    calpha_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[:-1, 1, :]) / (
        torch.norm(atom_37[:-1, 2, :] - atom_37[:-1, 1, :], keepdim=True, dim=1) + 1e-7
    )

    # Unit vector from the next frame's N to the current frame's C. (N-1, 3)
    nitrogen_to_carbonyl: torch.Tensor = (atom_37[:-1, 2, :] - atom_37[1:, 0, :]) / (
        torch.norm(atom_37[:-1, 2, :] - atom_37[1:, 0, :], keepdim=True, dim=1) + 1e-7
    )

    # C->O direction: the normalized bisector of the CA->C and N_next->C vectors.
    carbonyl_to_oxygen: torch.Tensor = calpha_to_carbonyl + nitrogen_to_carbonyl
    carbonyl_to_oxygen = carbonyl_to_oxygen / (
        torch.norm(carbonyl_to_oxygen, dim=1, keepdim=True) + 1e-7
    )

    # Place the oxygen 1.23 Angstrom from the carbonyl carbon.
    atom_37[:-1, 4, :] = atom_37[:-1, 2, :] + carbonyl_to_oxygen * 1.23

    # Handle frames for which the next frame is not available.

    # Unit vector from CA to C within the current frame. (N, 3)
    calpha_to_carbonyl_term: torch.Tensor = (atom_37[:, 2, :] - atom_37[:, 1, :]) / (
        torch.norm(atom_37[:, 2, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
    )
    # Unit vector from CA to N within the current frame. (N, 3)
    calpha_to_nitrogen_term: torch.Tensor = (atom_37[:, 0, :] - atom_37[:, 1, :]) / (
        torch.norm(atom_37[:, 0, :] - atom_37[:, 1, :], keepdim=True, dim=1) + 1e-7
    )
    carbonyl_to_oxygen_term: torch.Tensor = (
        calpha_to_carbonyl_term + calpha_to_nitrogen_term
    )
    carbonyl_to_oxygen_term = carbonyl_to_oxygen_term / (
        torch.norm(carbonyl_to_oxygen_term, dim=1, keepdim=True) + 1e-7
    )

    # Mask of residues whose next residue is missing or non-existent
    # (the C-terminus, or the next position is unknown in the data).
    if pos_is_known is None:
        pos_is_known = torch.ones((atom_37.shape[0],), dtype=torch.int64, device=atom_37.device)

    next_res_gone: torch.Tensor = ~pos_is_known.bool()
    next_res_gone = torch.cat(
        [next_res_gone, torch.ones((1,), device=pos_is_known.device).bool()], dim=0
    )
    next_res_gone = next_res_gone[1:]

    # Overwrite the oxygen for those residues using the terminal placement.
    atom_37[next_res_gone, 4, :] = (
        atom_37[next_res_gone, 2, :]
        + carbonyl_to_oxygen_term[next_res_gone, :] * 1.23
    )

    return atom_37


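# Illustrative sketch (not part of the original module): imputing oxygens for a
# toy backbone. Only the N, CA, and C columns are filled here; after the call,
# the O column (index 4) is populated for every residue.
def _example_adjust_oxygen(num_res: int = 5) -> torch.Tensor:
    atom_37 = torch.zeros(num_res, 37, 3)
    atom_37[:, :3, :] = torch.randn(num_res, 3, 3)  # random N, CA, C coordinates
    pos_is_known = torch.ones(num_res, dtype=torch.int64)
    return adjust_oxygen_pos(atom_37, pos_is_known)

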
def write_pkl(
        save_path: str, pkl_data: Any, create_dir: bool = False, use_torch=False):
    """Serialize data into a pickle file."""
    if create_dir:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if use_torch:
        torch.save(pkl_data, save_path, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    else:
        with open(save_path, 'wb') as handle:
            pickle.dump(pkl_data, handle, protocol=pickle.HIGHEST_PROTOCOL)


def read_pkl(read_path: str, verbose=True, use_torch=False, map_location=None):
    """Read data from a pickle file."""
    try:
        if use_torch:
            return torch.load(read_path, map_location=map_location)
        else:
            with open(read_path, 'rb') as handle:
                return pickle.load(handle)
    except Exception as e:
        # Fall back to loading GPU-pickled tensors onto the CPU.
        try:
            with open(read_path, 'rb') as handle:
                return CPU_Unpickler(handle).load()
        except Exception as e2:
            if verbose:
                print(f'Failed to read {read_path}. First error: {e}\n Second error: {e2}')
            raise e


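# Illustrative sketch (not part of the original module): a simple round trip
# through write_pkl / read_pkl. The file name below is an arbitrary placeholder.
def _example_pkl_roundtrip(tmp_path: str = './example_feats.pkl') -> bool:
    feats = {'aatype': np.arange(10), 'bb_mask': np.ones(10)}
    write_pkl(tmp_path, feats)
    loaded = read_pkl(tmp_path)
    return bool(np.array_equal(loaded['aatype'], feats['aatype']))

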
def chain_str_to_int(chain_str: str):
    chain_int = 0
    if len(chain_str) == 1:
        return CHAIN_TO_INT[chain_str]
    for i, chain_char in enumerate(chain_str):
        chain_int += CHAIN_TO_INT[chain_char] + (i * len(ALPHANUMERIC))
    return chain_int


def parse_chain_feats(chain_feats, scale_factor=1.):
    """Center, scale, and mask per-chain atom features in place."""
    ca_idx = residue_constants.atom_order['CA']
    chain_feats['bb_mask'] = chain_feats['atom_mask'][:, ca_idx]
    bb_pos = chain_feats['atom_positions'][:, ca_idx]
    bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['bb_mask']) + 1e-5)
    centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
    scaled_pos = centered_pos / scale_factor
    chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
    chain_feats['bb_positions'] = chain_feats['atom_positions'][:, ca_idx]
    return chain_feats


def concat_np_features(
        np_dicts: List[Dict[str, np.ndarray]], add_batch_dim: bool):
    """Performs a nested concatenation of feature dicts.

    Args:
        np_dicts: list of dicts with the same structure.
            Each dict must have the same keys and numpy arrays as the values.
        add_batch_dim: whether to add a batch dimension to each feature.

    Returns:
        A single dict with all the features concatenated.
    """
    combined_dict = collections.defaultdict(list)
    for chain_dict in np_dicts:
        for feat_name, feat_val in chain_dict.items():
            if add_batch_dim:
                feat_val = feat_val[None]
            combined_dict[feat_name].append(feat_val)

    for feat_name, feat_vals in combined_dict.items():
        combined_dict[feat_name] = np.concatenate(feat_vals, axis=0)
    return combined_dict


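# Illustrative sketch (not part of the original module): stacking two per-chain
# feature dicts into a single batched dict with concat_np_features.
def _example_concat_features() -> Dict[str, np.ndarray]:
    chain_a = {'aatype': np.zeros(5, dtype=np.int64), 'bb_mask': np.ones(5)}
    chain_b = {'aatype': np.ones(5, dtype=np.int64), 'bb_mask': np.ones(5)}
    batched = concat_np_features([chain_a, chain_b], add_batch_dim=True)
    assert batched['aatype'].shape == (2, 5)
    return batched

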
def center_zero(pos: torch.Tensor, batch_indexes: torch.LongTensor) -> torch.Tensor:
    """
    Move the molecule center to zero for sparse position tensors.

    Args:
        pos: [N, 3] batch positions of atoms in the molecule in sparse batch format.
        batch_indexes: [N] batch index for each atom in sparse batch format.

    Returns:
        pos: [N, 3] zero-centered batch positions of atoms in the molecule in sparse batch format.
    """
    assert len(pos.shape) == 2 and pos.shape[-1] == 3, "pos must have shape [N, 3]"

    means = scatter(pos, batch_indexes, dim=0, reduce="mean")
    return pos - means[batch_indexes]


@torch.no_grad()
def align_structures(
    batch_positions: torch.Tensor,
    batch_indices: torch.Tensor,
    reference_positions: torch.Tensor,
    broadcast_reference: bool = False,
):
    """
    Align structures in a ChemGraph batch to a reference, e.g. for RMSD computation. This uses the
    sparse formulation of pytorch geometric. If the ChemGraph is composed of a single system, then
    the reference can be given as a single structure and broadcast. Returns the structure
    coordinates shifted to the geometric center and the batch structures rotated to match the
    reference structures. Uses the Kabsch algorithm (see e.g. [kabsch_align1]_). No permutation of
    atoms is carried out.

    Args:
        batch_positions (Tensor): Batch of structures (e.g. from ChemGraph) which should be aligned
            to a reference.
        batch_indices (Tensor): Index tensor mapping each node / atom in batch to the respective
            system (e.g. batch attribute of ChemGraph batch).
        reference_positions (Tensor): Reference structure. Can either be a batch of structures or a
            single structure. In the second case, broadcasting is possible if the input batch is
            composed exclusively of this structure.
        broadcast_reference (bool, optional): If the reference contains only a single structure,
            broadcast it to match the ChemGraph batch. Defaults to False.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Centered positions of the batch
        structures rotated into the reference, the centered reference batch, and the
        optimal rotation matrices.

    References
    ----------
    .. [kabsch_align1] Lawrence, Bernal, Witzgall:
       A purely algebraic justification of the Kabsch-Umeyama algorithm.
       Journal of Research of the National Institute of Standards and Technology, 124, 1. 2019.
    """
    if batch_positions.shape[0] != reference_positions.shape[0]:
        if broadcast_reference:
            # Broadcast the single reference structure to every molecule in the batch.
            num_molecules = int(torch.max(batch_indices) + 1)
            reference_positions = reference_positions.repeat(num_molecules, 1)
        else:
            raise ValueError("Mismatch in batch dimensions.")

    # Center both structures at the origin.
    batch_positions = center_zero(batch_positions, batch_indices)
    reference_positions = center_zero(reference_positions, batch_indices)

    # Per-system covariance matrix for the optimal rotation -> [B, 3, 3].
    cov = scatter_add(
        batch_positions[:, None, :] * reference_positions[:, :, None], batch_indices, dim=0
    )

    # Singular value decomposition (all tensors are [B, 3, 3]).
    u, _, v_t = torch.linalg.svd(cov)
    u_t = u.transpose(1, 2)
    v = v_t.transpose(1, 2)

    # Sign correction to guarantee a right-handed coordinate system (proper rotation,
    # no reflection): flip the last row of U.T wherever det(V @ U.T) is negative.
    sign_correction = torch.sign(torch.linalg.det(torch.bmm(v, u_t)))
    u_t[:, 2, :] = u_t[:, 2, :] * sign_correction[:, None]

    # Optimal rotation matrices R = V @ diag(1, 1, sign) @ U.T.
    rotation_matrices = torch.bmm(v, u_t)

    # Ensure the rotation matrices match the input dtype.
    rotation_matrices = rotation_matrices.type(batch_positions.dtype)

    # Rotate the batch positions into alignment with the reference.
    batch_positions_rotated = torch.bmm(
        batch_positions[:, None, :],
        rotation_matrices[batch_indices],
    ).squeeze(1)

    return batch_positions_rotated, reference_positions, rotation_matrices


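# Illustrative sketch (not part of the original module): aligning two systems in
# sparse (flattened) format directly with align_structures. The atom counts per
# system are arbitrary and may differ, which is the point of the sparse layout.
def _example_align_sparse() -> torch.Tensor:
    sizes = [12, 20]
    batch_indices = torch.repeat_interleave(
        torch.arange(len(sizes)), torch.tensor(sizes))
    reference = torch.randn(sum(sizes), 3)
    # Shift the reference to create the "model"; the translation is removed by centering.
    model = reference + torch.tensor([1.0, -2.0, 0.5])
    rotated_model, centered_ref, _ = align_structures(
        model, batch_indices, reference)
    # Per-system RMSD after alignment (should be ~0 here).
    sq_err = ((rotated_model - centered_ref) ** 2).sum(dim=-1)
    return torch.sqrt(scatter(sq_err, batch_indices, dim=0, reduce="mean"))

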
def parse_pdb_feats(
        pdb_name: str,
        pdb_path: str,
        scale_factor=1.,
        chain_id='A',
):
    """
    Args:
        pdb_name: name of PDB to parse.
        pdb_path: path to PDB file to read.
        scale_factor: factor to scale atom positions.
        chain_id: chain(s) to extract. A single chain ID string, a list of
            chain IDs, or None to process every chain in the structure.

    Returns:
        Dict with CHAIN_FEATS features extracted from PDB with specified
        preprocessing.
    """
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, pdb_path)
    struct_chains = {
        chain.id: chain
        for chain in structure.get_chains()}

    def _process_chain_id(x):
        chain_prot = process_chain(struct_chains[x], x)
        chain_dict = dataclasses.asdict(chain_prot)

        # Keep only the features listed in CHAIN_FEATS, then center and scale.
        feat_dict = {x: chain_dict[x] for x in CHAIN_FEATS}
        return parse_chain_feats(
            feat_dict, scale_factor=scale_factor)

    if isinstance(chain_id, str):
        return _process_chain_id(chain_id)
    elif isinstance(chain_id, list):
        return {
            x: _process_chain_id(x) for x in chain_id
        }
    elif chain_id is None:
        return {
            x: _process_chain_id(x) for x in struct_chains
        }
    else:
        raise ValueError(f'Unrecognized chain_id: {chain_id}')


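# Illustrative sketch (not part of the original module): extracting backbone
# coordinates and a sequence string from one chain of a local PDB file. The
# path below is a placeholder; any standard PDB file should work.
def _example_parse_pdb(pdb_path: str = '/path/to/structure.pdb') -> str:
    feats = parse_pdb_feats('example', pdb_path, chain_id='A')
    ca_coords = feats['bb_positions']       # [num_res, 3] CA positions
    seq = aatype_to_seq(feats['aatype'])    # one-letter sequence
    assert ca_coords.shape[0] == len(seq)
    return seq

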
def rigid_transform_3D(A, B, verbose=False):
    """Find the rigid transform (R, t) that best maps point set A onto B.

    A and B are [N, 3] numpy arrays of paired points. Returns A transformed by
    the optimal (R, t), along with R, t, and whether a reflection was corrected.
    """
    assert A.shape == B.shape
    A = A.T
    B = B.T

    num_rows, num_cols = A.shape
    if num_rows != 3:
        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")

    num_rows, num_cols = B.shape
    if num_rows != 3:
        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")

    # Find the centroid of each point set.
    centroid_A = np.mean(A, axis=1)
    centroid_B = np.mean(B, axis=1)

    # Ensure centroids are 3x1.
    centroid_A = centroid_A.reshape(-1, 1)
    centroid_B = centroid_B.reshape(-1, 1)

    # Subtract the centroids from the points.
    Am = A - centroid_A
    Bm = B - centroid_B

    # Cross-covariance matrix.
    H = Am @ np.transpose(Bm)

    # Find the optimal rotation via SVD.
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # Special reflection case: force a proper rotation.
    reflection_detected = False
    if np.linalg.det(R) < 0:
        if verbose:
            print("det(R) < 0, reflection detected! Correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T
        reflection_detected = True

    t = -R @ centroid_A + centroid_B
    optimal_A = R @ A + t

    return optimal_A.T, R, t, reflection_detected


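# Illustrative sketch (not part of the original module): recovering a known
# rotation and translation between two numpy point clouds with rigid_transform_3D.
def _example_rigid_transform(num_points: int = 30) -> bool:
    rng = np.random.default_rng(0)
    A = rng.normal(size=(num_points, 3))
    # Rotate 90 degrees about the z-axis and translate.
    R_true = np.array([
        [0.0, -1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
    ])
    t_true = np.array([1.0, 2.0, 3.0])
    B = A @ R_true.T + t_true
    aligned_A, R, t, reflected = rigid_transform_3D(A, B)
    return np.allclose(aligned_A, B) and not reflected

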
def process_chain(chain: Chain, chain_id: str) -> Protein:
    """Convert a PDB chain object into an AlphaFold Protein instance.

    Forked from alphafold.common.protein.from_pdb_string.

    WARNING: All non-standard residue types will be converted into UNK. All
    non-standard atoms will be ignored.

    Took out lines 94-97 which don't allow insertions in the PDB.
    SAbDab uses insertions for the Chothia numbering so we need to allow them.

    Took out lines 110-112 since that would mess up CDR numbering.

    Args:
        chain: Instance of Biopython's Chain class.
        chain_id: ID to assign to every residue of the chain.

    Returns:
        Protein object with protein features.
    """
    atom_positions = []
    aatype = []
    atom_mask = []
    residue_index = []
    b_factors = []
    chain_ids = []
    for res in chain:
        res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
        restype_idx = residue_constants.restype_order.get(
            res_shortname, residue_constants.restype_num)
        pos = np.zeros((residue_constants.atom_type_num, 3))
        mask = np.zeros((residue_constants.atom_type_num,))
        res_b_factors = np.zeros((residue_constants.atom_type_num,))
        for atom in res:
            if atom.name not in residue_constants.atom_types:
                continue
            pos[residue_constants.atom_order[atom.name]] = atom.coord
            mask[residue_constants.atom_order[atom.name]] = 1.
            res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
        aatype.append(restype_idx)
        atom_positions.append(pos)
        atom_mask.append(mask)
        residue_index.append(res.id[1])
        b_factors.append(res_b_factors)
        chain_ids.append(chain_id)

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=np.array(chain_ids),
        b_factors=np.array(b_factors))