| |
| |
| import argparse |
| import json |
| import os |
| import random |
| from copy import deepcopy |
| from collections import defaultdict |
| from tqdm import tqdm |
| from tqdm.contrib.concurrent import process_map |
| import statistics |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| import numpy as np |
| from scipy.stats import spearmanr |
|
|
| from data.converter.pdb_to_list_blocks import pdb_to_list_blocks |
| from evaluation import diversity |
| from evaluation.dockq import dockq |
| from evaluation.rmsd import compute_rmsd |
| from utils.random_seed import setup_seed |
| from evaluation.seq_metric import aar, slide_aar |
|
|
|
|
def _get_ref_pdb(_id, root_dir):
    """Return the path of the reference PDB file for entry *_id* under *root_dir*."""
    filename = f'{_id}_ref.pdb'
    return os.path.join(root_dir, 'references', filename)
|
|
|
|
def _get_gen_pdb(_id, number, root_dir, use_rosetta):
    """Return the path of generated candidate *number* for entry *_id*.

    When *use_rosetta* is true, the rosetta-refined variant of the file
    (``*_rosetta.pdb``) is selected instead of the raw generation.
    """
    tag = '_rosetta' if use_rosetta else ''
    filename = f'{_id}_gen_{number}{tag}.pdb'
    return os.path.join(root_dir, 'candidates', _id, filename)
|
|
|
|
def cal_metrics(items):
    """Evaluate all generated candidates for one target id.

    Args:
        items: list of result records sharing the same ``id``. Each record
            carries the shared configuration (``root_dir``, ``rec_chain``,
            ``lig_chain``, ``seq_only``/``struct_only``/``backbone_only``
            flags) plus its own ``number``, ``gen_seq``/``ref_seq`` and
            ``pmetric`` (priority metric used for best-of-N selection).

    Returns:
        dict mapping metric name to an aggregation dict with keys
        ``max``/``min``/``mean``/``random``/``max*``/``min*``/``pmet_corr``/
        ``individual``/``individual_pmet``; the diversity metrics (added only
        when >1 sequence candidate exists) are plain scalars.
    """
    root_dir = items[0]['root_dir']
    rec_chain, lig_chain = items[0]['rec_chain'], items[0]['lig_chain']
    # Resolve the reference PDB from the id; the previous read of
    # items[0]['ref_pdb'] was dead (immediately overwritten) and is removed.
    ref_pdb = _get_ref_pdb(items[0]['id'], root_dir)
    seq_only, struct_only, backbone_only = items[0]['seq_only'], items[0]['struct_only'], items[0]['backbone_only']

    results = defaultdict(list)
    cand_seqs, cand_ca_xs = [], []
    # Receptor blocks are not needed here, only the reference peptide chain.
    _, ref_pep_blocks = pdb_to_list_blocks(ref_pdb, [rec_chain, lig_chain])

    # Reference CA coordinates; residues without a CA atom are masked out so
    # they never enter the RMSD computation (the [0, 0, 0] is a placeholder).
    ref_ca_x, ca_mask = [], []
    for ref_block in ref_pep_blocks:
        if ref_block.has_unit('CA'):
            ca_mask.append(1)
            ref_ca_x.append(ref_block.get_unit_by_name('CA').get_coord())
        else:
            ca_mask.append(0)
            ref_ca_x.append([0, 0, 0])
    ref_ca_x, ca_mask = np.array(ref_ca_x), np.array(ca_mask).astype(bool)

    for item in items:
        # Sequence recovery (skipped when only structures were generated)
        if not struct_only:
            cand_seqs.append(item['gen_seq'])
            results['Slide AAR'].append(slide_aar(item['gen_seq'], item['ref_seq'], aar))

        gen_pdb = _get_gen_pdb(item['id'], item['number'], root_dir, item['rosetta'])
        _, gen_pep_blocks = pdb_to_list_blocks(gen_pdb, [rec_chain, lig_chain])
        assert len(gen_pep_blocks) == len(ref_pep_blocks), f'{item}\t{len(ref_pep_blocks)}\t{len(gen_pep_blocks)}'

        # CA RMSD (aligned=True: superimpose before measuring)
        gen_ca_x = np.array([block.get_unit_by_name('CA').get_coord() for block in gen_pep_blocks])
        cand_ca_xs.append(gen_ca_x)
        rmsd = compute_rmsd(ref_ca_x[ca_mask], gen_ca_x[ca_mask], aligned=True)
        results['RMSD(CA)'].append(rmsd)
        if struct_only:
            # success-rate indicators at common thresholds
            results['RMSD<=2.0'].append(1 if rmsd <= 2.0 else 0)
            results['RMSD<=5.0'].append(1 if rmsd <= 5.0 else 0)
            results['RMSD<=10.0'].append(1 if rmsd <= 10.0 else 0)

        if backbone_only:
            continue  # metrics below require full-atom structures

        # Interface quality
        dockq_score = dockq(gen_pdb, ref_pdb, lig_chain)
        results['DockQ'].append(dockq_score)
        if struct_only:
            # CAPRI-style cutoffs: acceptable / medium / high quality
            results['DockQ>=0.23'].append(1 if dockq_score >= 0.23 else 0)
            results['DockQ>=0.49'].append(1 if dockq_score >= 0.49 else 0)
            results['DockQ>=0.80'].append(1 if dockq_score >= 0.80 else 0)

        # Full-atom RMSD over atoms present in both structures
        if struct_only:
            gen_all_x, ref_all_x = [], []
            for gen_block, ref_block in zip(gen_pep_blocks, ref_pep_blocks):
                for ref_atom in ref_block:
                    if gen_block.has_unit(ref_atom.name):
                        ref_all_x.append(ref_atom.get_coord())
                        gen_all_x.append(gen_block.get_unit_by_name(ref_atom.name).get_coord())
            results['RMSD(full-atom)'].append(compute_rmsd(
                np.array(gen_all_x), np.array(ref_all_x), aligned=True
            ))

    # Aggregate each metric across the candidates of this target.
    pmets = [item['pmetric'] for item in items]
    indexes = list(range(len(items)))
    for name in results:
        vals = results[name]
        corr = spearmanr(vals, pmets, nan_policy='omit').statistic
        if np.isnan(corr):
            corr = 0  # constant inputs yield nan correlation; treat as none
        aggr_res = {
            'max': max(vals),
            'min': min(vals),
            'mean': sum(vals) / len(vals),
            'random': vals[0],  # first sample stands in for a random pick
            # "oracle" selection: value of the candidate ranked best/worst by
            # pmetric, with the ranking direction chosen by the sign of corr
            'max*': vals[(max if corr > 0 else min)(indexes, key=lambda i: pmets[i])],
            'min*': vals[(min if corr > 0 else max)(indexes, key=lambda i: pmets[i])],
            'pmet_corr': corr,
            'individual': vals,
            'individual_pmet': pmets
        }
        results[name] = aggr_res

    # Diversity metrics require at least two sequence candidates.
    if len(cand_seqs) > 1 and not seq_only:
        seq_div, struct_div, co_div, consistency = diversity.diversity(cand_seqs, np.array(cand_ca_xs))
        results['Sequence Diversity'] = seq_div
        results['Struct Diversity'] = struct_div
        results['Codesign Diversity'] = co_div
        results['Consistency'] = consistency

    return results
|
|
|
|
def cnt_aa_dist(seqs):
    """Print the amino-acid frequency distribution across *seqs*.

    Each residue type is printed on its own tab-indented line as a fraction
    of all residues, ordered from rarest to most common. Does nothing for an
    empty input (no division by zero).
    """
    from collections import Counter  # local import keeps the block self-contained
    cnts = Counter(aa for seq in seqs for aa in seq)
    total = sum(cnts.values())
    # sorted() is stable, so ties keep first-occurrence order as before
    for aa in sorted(cnts, key=lambda a: cnts[a]):
        print(f'\t{aa}: {cnts[aa] / total}')
|
|
|
|
def main(args):
    """Load generation results, evaluate every target and print a report.

    Reads the JSONL results file, groups records by target id, optionally
    drops samples with non-negative dG, computes per-target metrics (in
    parallel when ``args.num_workers > 1``), writes them to
    ``eval_report.json`` next to the results file, and prints aggregated
    point-wise statistics.
    """
    root_dir = os.path.dirname(args.results)

    # Optional dG-based filtering: keep only samples with predicted dG < 0
    if args.filter_dG is None:
        filter_func = lambda _id, n: True
    else:
        with open(args.filter_dG, 'r') as f:  # close the handle (was leaked)
            dG_results = json.load(f)
        filter_func = lambda _id, n: dG_results[_id]['all'][str(n)] < 0

    # Group result records by target id
    with open(args.results, 'r') as fin:
        lines = fin.read().strip().split('\n')
    id2items = {}
    for line in lines:
        item = json.loads(line)
        _id = item['id']
        if not filter_func(_id, item['number']):
            continue
        if _id not in id2items:
            id2items[_id] = []
        item['root_dir'] = root_dir
        item['rosetta'] = args.rosetta
        id2items[_id].append(item)

    if args.filter_dG is not None:
        # Best-of-N / diversity statistics need at least two surviving samples
        del_ids = [_id for _id in id2items if len(id2items[_id]) < 2]
        for _id in del_ids:
            print(f'Deleting {_id} since it only has one sample passed the filter')
            del id2items[_id]

    # BUGFIX: collect ids only AFTER filtering; previously the list was built
    # before deletion, so indices into `metrics` mapped to the wrong ids in
    # the lowest/highest report below whenever targets were dropped.
    ids = list(id2items.keys())

    if args.num_workers > 1:
        metrics = process_map(cal_metrics, id2items.values(), max_workers=args.num_workers, chunksize=1)
    else:
        metrics = [cal_metrics(inputs) for inputs in tqdm(id2items.values())]

    # Dump one JSON line of metrics per target
    eval_results_path = os.path.join(os.path.dirname(args.results), 'eval_report.json')
    with open(eval_results_path, 'w') as fout:
        for i, _id in enumerate(id2items):
            metric = deepcopy(metrics[i])
            metric['id'] = _id
            fout.write(json.dumps(metric) + '\n')

    print('Point-wise evaluation results:')
    for name in metrics[0]:
        vals = [item[name] for item in metrics]
        if isinstance(vals[0], dict):
            # aggregated metric: best-per-target is min for raw RMSD values,
            # max for everything else (including <=/>= success indicators)
            if 'RMSD' in name and '<=' not in name:
                aggr = 'min'
            else:
                aggr = 'max'
            aggr_vals = [val[aggr] for val in vals]
            if '>=' in name or '<=' in name:
                # threshold indicators report a success rate
                print(f'{name}: {sum(aggr_vals) / len(aggr_vals)}')
            else:
                if 'RMSD' in name:
                    # median is robust to outlier targets for RMSD
                    print(f'{name}(median): {statistics.median(aggr_vals)}')
                else:
                    print(f'{name}(mean): {sum(aggr_vals) / len(aggr_vals)}')
                lowest_i = min(range(len(aggr_vals)), key=lambda i: aggr_vals[i])
                highest_i = max(range(len(aggr_vals)), key=lambda i: aggr_vals[i])
                print(f'\tlowest: {aggr_vals[lowest_i]}, id: {ids[lowest_i]}', end='')
                print(f'\thighest: {aggr_vals[highest_i]}, id: {ids[highest_i]}')
        else:
            # scalar metric (e.g. diversity values)
            print(f'{name} (mean): {sum(vals) / len(vals)}')
            lowest_i = min(range(len(vals)), key=lambda i: vals[i])
            highest_i = max(range(len(vals)), key=lambda i: vals[i])
            print(f'\tlowest: {vals[lowest_i]}, id: {ids[lowest_i]}')
            print(f'\thighest: {vals[highest_i]}, id: {ids[highest_i]}')
|
|
|
|
def parse():
    """Build and run the command-line interface of the metric script."""
    parser = argparse.ArgumentParser(description='calculate metrics')
    # (flags, keyword options) spec table — added in order
    option_specs = [
        (('--results',), dict(type=str, required=True, help='Path to test set')),
        (('--num_workers',), dict(type=int, default=8, help='Number of workers to use')),
        (('--rosetta',), dict(action='store_true', help='Use the rosetta-refined structure')),
        (('--filter_dG',), dict(type=str, default=None, help='Only calculate results on samples with dG<0')),
    ]
    for flags, kwargs in option_specs:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
|
|
|
|
if __name__ == '__main__':
    # Fix the RNG seed for reproducibility, then run the CLI entry point.
    setup_seed(0)
    main(parse())
|
|