| import os |
| from typing import * |
|
|
| import numpy as np |
| import torch |
| from scipy.spatial import distance |
| |
| from src.common.geo_utils import rmsd, _find_rigid_alignment, squared_deviation |
| from scipy.linalg import fractional_matrix_power |
| from sklearn.mixture import GaussianMixture |
| from Bio.PDB import PDBParser |
| |
| from Bio.PDB.Polypeptide import PPBuilder |
| import multiprocessing as mp |
|
|
| EPS = 1e-12 |
| PSEUDO_C = 1e-6 |
|
|
|
|
| def adjacent_ca_distance(coords): |
| """Calculate distance array for a single chain of CA atoms. Only k=1 neighbors. |
| Args: |
| coords: (..., L, 3) |
| return |
| dist: (..., L-1) |
| """ |
| assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}" |
| dX = coords[..., :-1, :] - coords[..., 1:, :] |
| dist = np.sqrt(np.sum(dX**2, axis=-1)) |
| return dist |
|
|
|
|
def distance_matrix_ca(coords):
    """All-vs-all Euclidean distance matrix for CA coordinates (neighbors included).

    Args:
        coords: (..., L, 3)
    Returns:
        dist: (..., L, L)
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    # Broadcast (..., 1, L, 3) against (..., L, 1, 3) to form all pair deltas.
    diff = np.expand_dims(coords, -3) - np.expand_dims(coords, -2)
    return np.linalg.norm(diff, axis=-1)
|
|
|
|
def pairwise_distance_ca(coords, k=1):
    """Flattened upper-triangular pairwise CA distances (neighbors not excluded).

    Args:
        coords: (..., L, 3)
        k: diagonal offset for np.triu_indices; k=1 skips self-pairs.
    Returns:
        dist: (..., D) with D = L * (L - 1) // 2 when k=1.
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    full = distance_matrix_ca(coords)
    n_atoms = full.shape[-1]
    rows, cols = np.triu_indices(n_atoms, k=k)
    return full[..., rows, cols]
|
|
|
|
def radius_of_gyration(coords, masses=None):
    """Compute the radius of gyration for every frame.

    Args:
        coords: (..., num_atoms, 3)
        masses: (num_atoms,). If None, all atoms are weighted equally.

    Returns:
        Rg: (..., )
    """
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"

    if masses is None:
        masses = np.ones(coords.shape[-2])
    else:
        assert len(masses.shape) == 1, f"masses should be 1D, got {masses.shape}"
        assert masses.shape[0] == coords.shape[-2], f"masses {masses.shape} != number of particles {coords.shape[-2]}"

    weights = masses / masses.sum()
    # BUGFIX: center on the mass-weighted centroid (center of mass). The
    # previous version subtracted the unweighted mean, which is incorrect for
    # non-uniform masses; identical results when masses is None (uniform).
    com = (coords * weights[:, None]).sum(-2, keepdims=True)
    centered = coords - com
    squared_dists = (centered ** 2).sum(-1)
    Rg = (squared_dists * weights).sum(-1) ** 0.5
    return Rg
|
|
|
|
def _steric_clash(coords, ca_vdw_radius=1.7, allowable_overlap=0.4, k_exclusion=0):
    """ https://www.schrodinger.com/sites/default/files/s3/public/python_api/2022-3/_modules/schrodinger/structutils/interactions/steric_clash.html#clash_iterator
    Count steric clashes in a single chain of CA atoms.

    A pair of CA atoms clashes when their distance is below
    2 * ca_vdw_radius - allowable_overlap.

    Usage:
        n_clash = _steric_clash(coords)

    Args:
        coords: (L, 3) or (B, L, 3) CA coordinates; coords should come from
            one protein chain.
        ca_vdw_radius: float, default 1.7.
        allowable_overlap: float, default 0.4.
        k_exclusion: int, default 0. Exclude sequence neighbors within
            [i-k-1, i+k+1] from the count.

    Returns:
        int clash count for 2D input, (B,) int array for 3D input.
    """
    assert np.isnan(coords).sum() == 0, "coords should not contain nan"
    assert len(coords.shape) in (2, 3), f"CA coords should be 2D or 3D, got {coords.shape}"
    assert k_exclusion >= 0, "k_exclusion should be non-negative"
    bar = 2 * ca_vdw_radius - allowable_overlap

    pwd = pairwise_distance_ca(coords, k=k_exclusion + 1)
    # BUGFIX: the previous `assert len(pwd.shape) == 2` rejected single
    # (2D) structures even though the input check above allows them.
    # Summing over the last axis handles both 1D (single) and 2D (batched)
    # pairwise-distance arrays; batched behavior is unchanged.
    n_clash = np.sum(pwd < bar, axis=-1)
    return n_clash.astype(int)
|
|
|
|
def validity(ca_coords_dict, **clash_kwargs):
    """Calculate clash validity of ensembles.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        **clash_kwargs: forwarded to _steric_clash.
    Returns:
        {k: validity in [0, 1]}, rounded to 4 decimals.
    """
    n_res = float(ca_coords_dict['target'].shape[1])
    clashes = {
        name: _steric_clash(coords, **clash_kwargs)
        for name, coords in ca_coords_dict.items()
    }
    # One minus the mean per-residue clash rate of the ensemble.
    frac_valid = {
        name: 1.0 - (counts / n_res).mean()
        for name, counts in clashes.items()
    }
    return {name: np.around(val, decimals=4) for name, val in frac_valid.items()}
|
|
|
|
def bonding_validity(ca_coords_dict, ref_key='target', eps=1e-6):
    """Calculate bonding dissociation validity of ensembles.

    A frame's adjacent CA-CA distance counts as valid when it is below the
    longest adjacent distance observed in the reference ensemble (plus eps,
    so reference frames themselves always count as valid).

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        ref_key: key of the reference ensemble used to set the threshold.
        eps: tolerance added to the reference maximum.
    Returns:
        {k: fraction of valid adjacent distances}, rounded to 4 decimals.
    """
    adj_dist = {k: adjacent_ca_distance(v)
                for k, v in ca_coords_dict.items()
                }
    # BUGFIX: the `eps` parameter was declared but ignored (a hard-coded 1e-6
    # was added instead). Default value is unchanged, so behavior for
    # existing callers is identical.
    thres = adj_dist[ref_key].max() + eps

    results = {
        k: (v < thres).mean()
        for k, v in adj_dist.items()
    }

    results = {k: np.around(v, decimals=4) for k, v in results.items()}
    return results
|
|
|
|
def js_pwd(ca_coords_dict, ref_key='target', n_bins=50, pwd_offset=3, weights=None):
    """Mean Jensen-Shannon divergence between per-pair CA-distance histograms.

    Each pairwise distance dimension is binned with the same range (min/max of
    the reference ensemble); a pseudo-count is added to every bin before the
    JS divergence is taken, and divergences are averaged over dimensions.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        ref_key: reference ensemble key (gets divergence 0.0).
        n_bins: histogram bins per dimension.
        pwd_offset: triu offset passed to pairwise_distance_ca.
        weights: optional {k: (B,)} per-frame histogram weights.
    Returns:
        {k: mean JS divergence}, rounded to 4 decimals.
    """
    ca_pwd = {
        name: pairwise_distance_ca(coords, k=pwd_offset)
        for name, coords in ca_coords_dict.items()
    }

    if weights is None:
        weights = {}
    for name, coords in ca_coords_dict.items():
        if name not in weights:
            weights[name] = np.ones(len(coords))

    ref_min = ca_pwd[ref_key].min(axis=0)
    ref_max = ca_pwd[ref_key].max(axis=0)

    binned = {}
    for name, pwd in ca_pwd.items():
        hists = np.empty((n_bins, pwd.shape[1]))
        # Histogram each pairwise-distance dimension over the reference range.
        for j in range(pwd.shape[1]):
            hists[:, j] = np.histogram(
                pwd[:, j], bins=n_bins, weights=weights[name],
                range=(ref_min[j], ref_max[j]))[0]
        binned[name] = hists + PSEUDO_C

    results = {name: distance.jensenshannon(hist, binned[ref_key], axis=0).mean()
               for name, hist in binned.items() if name != ref_key}
    results[ref_key] = 0.0
    return {name: np.around(val, decimals=4) for name, val in results.items()}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def js_rg(ca_coords_dict, ref_key='target', n_bins=50, weights=None):
    """Jensen-Shannon divergence between radius-of-gyration histograms.

    Histograms share the reference ensemble's Rg range; a pseudo-count is
    added to every bin before the divergence is computed.

    Args:
        ca_coords_dict: {k: (B, L, 3)}
        ref_key: reference ensemble key (gets divergence 0.0).
        n_bins: number of histogram bins.
        weights: optional {k: (B,)} per-frame histogram weights.
    Returns:
        {k: JS divergence}, rounded to 4 decimals.
    """
    rg = {name: radius_of_gyration(coords) for name, coords in ca_coords_dict.items()}

    if weights is None:
        weights = {}
    for name, coords in ca_coords_dict.items():
        if name not in weights:
            weights[name] = np.ones(len(coords))

    lo = rg[ref_key].min()
    hi = rg[ref_key].max()
    binned = {
        name: np.histogram(vals, bins=n_bins, weights=weights[name], range=(lo, hi))[0] + PSEUDO_C
        for name, vals in rg.items()
    }

    results = {name: distance.jensenshannon(hist, binned[ref_key], axis=0).mean()
               for name, hist in binned.items() if name != ref_key}
    results[ref_key] = 0.0
    return {name: np.around(val, decimals=4) for name, val in results.items()}
|
|
def div_rmsd(ca_coords_dict):
    """Ensemble diversity as mean pairwise RMSD.

    For every ensemble, averages the squared deviation over all ordered frame
    pairs (self-pairs included), takes the square root per residue, then the
    mean over residues. The 'pred' entry is finally reported relative to
    'target': (pred - target) / target.

    Args:
        ca_coords_dict: {k: (B, L, 3)}, must contain 'target' and 'pred'.
    Returns:
        {k: value}, rounded to 4 decimals.
    """
    results = {}
    for name, coords in ca_coords_dict.items():
        ensemble = torch.as_tensor(coords)

        n_pairs = 0
        sq_sum = 0
        # All ordered pairs, including frame paired with itself.
        for frame_a in ensemble:
            for frame_b in ensemble:
                n_pairs += 1
                sq_sum += squared_deviation(frame_a, frame_b, reduction='none')

        per_residue = torch.sqrt(sq_sum / n_pairs)
        results[name] = np.around(float(torch.mean(per_residue)), decimals=4)

    results['pred'] = (results['pred'] - results['target']) / results['target']
    return {k: np.around(v, decimals=4) for k, v in results.items()}
| |
def div_rmsf(ca_coords_dict):
    """RMSF-style ensemble diversity: deviation of frames from the mean structure.

    Averages the squared deviation of every frame from the ensemble mean,
    takes the square root per residue, then the mean over residues. The
    'pred' entry is finally reported relative to 'target':
    (pred - target) / target.

    Args:
        ca_coords_dict: {k: (B, L, 3)}, must contain 'target' and 'pred'.
    Returns:
        {k: value}, rounded to 4 decimals.
    """
    results = {}
    for name, coords in ca_coords_dict.items():
        ensemble = torch.as_tensor(coords)
        mean_structure = torch.mean(ensemble, dim=0)

        n_frames = 0
        sq_sum = 0
        for frame in ensemble:
            n_frames += 1
            sq_sum += squared_deviation(frame, mean_structure, reduction='none')

        per_residue = torch.sqrt(sq_sum / n_frames)
        results[name] = np.around(float(torch.mean(per_residue)), decimals=4)

    results['pred'] = (results['pred'] - results['target']) / results['target']
    return {k: np.around(v, decimals=4) for k, v in results.items()}
|
|
def w2_rmwd(ca_coords_dict):
    """Mean per-residue squared 2-Wasserstein distance between Gaussian fits.

    For each residue, a single Gaussian (mean + full covariance) is fitted to
    the (B, 3) cloud of CA positions of each ensemble. The squared
    2-Wasserstein distance between the 'target' and 'pred' Gaussians is

        ||mu_t - mu_p||^2 + trace(S_t + S_p - 2 * (S_t S_p)^{1/2})

    using the identity trace((S_t S_p)^{1/2}) == trace((S_t^{1/2} S_p S_t^{1/2})^{1/2})
    (the two products are similar matrices, so their eigenvalues agree).
    The reported value is the mean over residues, rounded to 4 decimals.

    Args:
        ca_coords_dict: {'target': (B, L, 3), 'pred': (B, L, 3)}
    Returns:
        {'pred': float}
    """
    # Removed dead locals from the previous version: `count` and `v_ref`
    # were assigned but never used.
    result = {}
    means_total = {}
    covariances_total = {}
    for k, v in ca_coords_dict.items():
        v = torch.as_tensor(v)
        means_total[k] = []
        covariances_total[k] = []
        # Fit an independent Gaussian to every residue's point cloud.
        for idx_residue in range(v.shape[1]):
            gmm = GaussianMixture(n_components=1)
            gmm.fit(v[:, idx_residue, :])
            means_total[k].append(torch.as_tensor(gmm.means_[0]))
            covariances_total[k].append(torch.as_tensor(gmm.covariances_[0]))
        means_total[k] = torch.stack(means_total[k], dim=0)              # (L, 3)
        covariances_total[k] = torch.stack(covariances_total[k], dim=0)  # (L, 3, 3)

    # (S_t S_p)^{1/2} per residue. NOTE(review): fractional_matrix_power can
    # return a complex result for ill-conditioned products -- confirm inputs
    # keep this real in practice.
    sigma_1_2_sqrt = [torch.as_tensor(fractional_matrix_power(i, 0.5))
                      for i in torch.matmul(covariances_total['target'], covariances_total['pred'])]
    sigma_1_2_sqrt = torch.stack(sigma_1_2_sqrt, dim=0)
    sigma_trace = covariances_total['target'] + covariances_total['pred'] - 2 * sigma_1_2_sqrt
    sigma_trace = torch.stack([torch.trace(i) for i in sigma_trace], dim=0)

    result_1D = torch.sum((means_total['target'] - means_total['pred'])**2, dim=-1) + sigma_trace
    result['pred'] = np.around(float(torch.mean(result_1D)), decimals=4)
    return result
|
|
def pro_w_contacts(ca_coords_dict, cry_ca_coords, dist_threshold = 8.0, percent_threshold = 0.1):
    """Jaccard similarity of crystal contacts that break in the ensembles.

    A residue pair is a crystal contact when its CA distance in the crystal
    structure is below dist_threshold. For each ensemble, a pair is flagged
    when it is separated beyond dist_threshold in more than percent_threshold
    of the frames AND is a crystal contact. The returned value is the Jaccard
    index between the 'target' and 'pred' flag sets.

    Args:
        ca_coords_dict: {k: (B, L, 3)}, must contain 'target' and 'pred'.
        cry_ca_coords: (L, 3) crystal-structure CA coordinates.
        dist_threshold: contact distance cutoff in Angstrom.
        percent_threshold: minimum fraction of frames with the pair separated.
    Returns:
        {'pred': float}, rounded to 4 decimals.
    """
    result = {}
    flagged = {}

    # Upper-triangular (i < j) contact mask of the crystal structure.
    cry_dist = distance_matrix_ca(cry_ca_coords)
    n_res = cry_dist.shape[-1]
    iu, ju = np.triu_indices(n_res, k=1)
    crystal_contacts = cry_dist[..., iu, ju] < dist_threshold

    for key, ensemble in ca_coords_dict.items():
        dmat = distance_matrix_ca(ensemble)
        n_res = dmat.shape[-1]
        iu, ju = np.triu_indices(n_res, k=1)
        pair_dists = torch.tensor(dmat[..., iu, ju])

        # Fraction of frames in which each pair is separated beyond cutoff.
        broken_frac = torch.mean((pair_dists > dist_threshold).type(torch.float32), dim=0)
        broken = broken_frac > percent_threshold
        flagged[key] = broken & crystal_contacts

    jac = torch.sum(flagged['target'] & flagged['pred']) / torch.sum(flagged['target'] | flagged['pred'])
    result['pred'] = np.around(float(jac), decimals=4)
    return result
|
|
def pro_t_contacts(ca_coords_dict, cry_ca_coords, dist_threshold = 8.0, percent_threshold = 0.1):
    """Jaccard similarity of transient contacts formed in the ensembles.

    A residue pair is a crystal non-contact when its CA distance in the
    crystal structure is at least dist_threshold. For each ensemble, a pair
    is flagged when it comes within dist_threshold in more than
    percent_threshold of the frames AND is a crystal non-contact. The
    returned value is the Jaccard index between the 'target' and 'pred'
    flag sets.

    Args:
        ca_coords_dict: {k: (B, L, 3)}, must contain 'target' and 'pred'.
        cry_ca_coords: (L, 3) crystal-structure CA coordinates.
        dist_threshold: contact distance cutoff in Angstrom.
        percent_threshold: minimum fraction of frames with the pair in contact.
    Returns:
        {'pred': float}, rounded to 4 decimals.
    """
    result = {}
    flagged = {}

    # Upper-triangular (i < j) non-contact mask of the crystal structure.
    cry_dist = distance_matrix_ca(cry_ca_coords)
    n_res = cry_dist.shape[-1]
    iu, ju = np.triu_indices(n_res, k=1)
    crystal_noncontacts = cry_dist[..., iu, ju] >= dist_threshold

    for key, ensemble in ca_coords_dict.items():
        dmat = distance_matrix_ca(ensemble)
        n_res = dmat.shape[-1]
        iu, ju = np.triu_indices(n_res, k=1)
        pair_dists = torch.tensor(dmat[..., iu, ju])

        # Fraction of frames in which each pair is within the cutoff.
        formed_frac = torch.mean((pair_dists <= dist_threshold).type(torch.float32), dim=0)
        formed = formed_frac > percent_threshold
        flagged[key] = formed & crystal_noncontacts

    jac = torch.sum(flagged['target'] & flagged['pred']) / torch.sum(flagged['target'] | flagged['pred'])
    result['pred'] = np.around(float(jac), decimals=4)
    return result
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |