| """ |
| CASCADE worker process |
| """ |
|
|
| import os |
| import sys |
| import warnings |
| warnings.filterwarnings("ignore", category=UserWarning) |
| warnings.filterwarnings("ignore", category=DeprecationWarning) |
| warnings.filterwarnings("ignore", category=FutureWarning) |
|
|
| import json |
| import math |
| import pickle |
| import datetime |
| from io import StringIO |
|
|
| import redis |
| import numpy as np |
|
|
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" |
|
|
| import keras |
| from keras.models import load_model |
| from NMR_Prediction.apply import ( |
| preprocess_C, preprocess_H, |
| evaluate_C, evaluate_H, |
| RBFSequence, |
| ) |
| from nfp.layers import ( |
| MessageLayer, GRUStep, Squeeze, EdgeNetwork, |
| ReduceBondToPro, ReduceBondToAtom, |
| GatherAtomToBond, ReduceAtomToPro, |
| ) |
| from nfp.models import GraphModel |
|
|
| import pandas as pd |
| from rdkit import Chem |
| from rdkit.Chem import AllChem |
| from rdkit.Chem import SDWriter |
| from NMR_Prediction.genConf import genConf |
|
|
# Saved Keras model weights (HDF5) for the 13C and 1H shift predictors,
# plus the pickled preprocessor loaded below; all are read with these
# relative paths, so the worker must be started from the repo root.
MODEL_PATH_C = os.path.join("NMR_Prediction", "schnet_edgeupdate", "best_model.hdf5")
MODEL_PATH_H = os.path.join("NMR_Prediction", "schnet_edgeupdate_H", "best_model.hdf5")
PREPROCESSOR_PATH = os.path.join("NMR_Prediction", "preprocessor.p")
|
|
# Custom nfp layer/model classes that keras.load_model needs in order to
# deserialize the saved graph networks.
custom_objects = dict(
    MessageLayer=MessageLayer,
    GRUStep=GRUStep,
    Squeeze=Squeeze,
    EdgeNetwork=EdgeNetwork,
    ReduceBondToPro=ReduceBondToPro,
    ReduceBondToAtom=ReduceBondToAtom,
    GatherAtomToBond=GatherAtomToBond,
    ReduceAtomToPro=ReduceAtomToPro,
    GraphModel=GraphModel,
)
|
|
def _load_nmr_model(path, banner):
    # Print the banner before the (slow) Keras load so the console shows
    # progress while the worker starts up.
    print(banner, flush=True)
    return load_model(path, custom_objects=custom_objects)


model_C = _load_nmr_model(MODEL_PATH_C, "Loading 13C model...")
model_H = _load_nmr_model(MODEL_PATH_H, "Loading 1H model...")
print("Both models loaded.", flush=True)

# The pickle holds a dict; only the "preprocessor" entry is needed here.
with open(PREPROCESSOR_PATH, "rb") as fh:
    preprocessor = pickle.load(fh)["preprocessor"]

# Job queue / result store. decode_responses=True so keys and values come
# back as str rather than bytes.
redis_client = redis.StrictRedis(
    host="localhost", port=6379, db=0, decode_responses=True
)
|
|
| |
| _HF_TOKEN = os.environ.get("HF_TOKEN", "") |
| _ANALYTICS_REPO = "patonlab/analytics" |
| _ANALYTICS_FILE = "visits.csv" |
|
|
| def _log_prediction(): |
| """Append one row to the existing patonlab/analytics data.csv. |
| Format matches the alfabet log: space,timestamp |
| """ |
| if not _HF_TOKEN: |
| return |
| try: |
| from huggingface_hub import HfApi, hf_hub_download |
| import tempfile |
|
|
| api = HfApi(token=_HF_TOKEN) |
| timestamp = datetime.datetime.utcnow().isoformat() |
|
|
| with tempfile.TemporaryDirectory() as tmpdir: |
| local_path = hf_hub_download( |
| repo_id=_ANALYTICS_REPO, |
| filename=_ANALYTICS_FILE, |
| repo_type="dataset", |
| token=_HF_TOKEN, |
| local_dir=tmpdir, |
| ) |
| with open(local_path, "a") as f: |
| f.write(f"patonlab/cascade,{timestamp}\n") |
|
|
| api.upload_file( |
| path_or_fileobj=local_path, |
| path_in_repo=_ANALYTICS_FILE, |
| repo_id=_ANALYTICS_REPO, |
| repo_type="dataset", |
| commit_message=f"log: cascade prediction {timestamp[:10]}", |
| ) |
| except Exception as e: |
| print(f"Analytics logging failed (non-fatal): {e}", flush=True) |
|
|
|
|
def _mol_to_sdf(mol, conf_id=0):
    """Serialize one conformer of *mol* to an SDF-format string."""
    buffer = StringIO()
    writer = SDWriter(buffer)
    writer.write(mol, confId=conf_id)
    writer.close()  # flushes the record into the buffer
    return buffer.getvalue()
|
|
|
|
def _build_sdfs_from_genconf(mol_with_confs, ids):
    """Turn genConf's ``(energy, conf_id)`` pairs into SDF strings.

    Returns ``(sdfs, energy_order)``: the non-empty SDF blocks and the
    matching conformer ids, both kept in the energy order given by *ids*.
    Conformers that fail to serialize are reported and skipped.
    """
    sdf_blocks = []
    kept_ids = []
    for _energy, conf_id in ids:
        try:
            block = _mol_to_sdf(mol_with_confs, conf_id=int(conf_id))
            if block.strip():
                sdf_blocks.append(block)
                kept_ids.append(int(conf_id))
        except Exception as e:
            print(f"SDF error for conf_id={conf_id}: {e}", flush=True)
    return sdf_blocks, kept_ids
|
|
|
|
| def _boltzmann_average(spread_df): |
| spread_df["b_weight"] = spread_df["relative_E"].apply( |
| lambda x: math.exp(-x / (0.001987 * 298.15)) |
| ) |
| df_group = spread_df.set_index(["mol_id", "atom_index", "cf_id"]).groupby(level=[0, 1]) |
| final = [] |
| for (m_id, a_id), df in df_group: |
| ws = (df["b_weight"] * df["predicted"]).sum() / df["b_weight"].sum() |
| final.append([m_id, a_id, ws]) |
| final = pd.DataFrame(final, columns=["mol_id", "atom_index", "Shift"]) |
| final["atom_index"] = final["atom_index"].apply(lambda x: x + 1) |
| return final.round(2).astype(dtype={"atom_index": "int"}) |
|
|
|
|
| def _fmt_weighted(final_df): |
| return "".join(f"{int(r['atom_index'])},{r['Shift']:.2f};" for _, r in final_df.iterrows()) |
|
|
|
|
| def _fmt_conf_shifts(spread_df, energy_order): |
| parts = [] |
| for cf_id in energy_order: |
| sub = spread_df[spread_df["cf_id"] == cf_id] |
| if len(sub) == 0: |
| continue |
| parts.append("".join(f"{int(r['atom_index'])},{r['predicted']:.2f};" for _, r in sub.iterrows())) |
| return "!".join(parts) |
|
|
|
|
| def _fmt_relative_E(spread_df, energy_order): |
| total_bw = spread_df.groupby("cf_id")["b_weight"].first().sum() |
| parts = [] |
| for cf_id in energy_order: |
| sub = spread_df[spread_df["cf_id"] == cf_id] |
| if len(sub) == 0: |
| continue |
| e = round(sub["relative_E"].iloc[0], 2) |
| bw = round(sub["b_weight"].iloc[0] / total_bw, 4) |
| parts.append(f"{e},{bw},") |
| return "!".join(parts) |
|
|
|
|
def run_job(task_id, smiles, type_):
    """Run one shift-prediction job and store the outcome in Redis.

    Parameters
    ----------
    task_id : queue id; the JSON result is stored under
        ``task_result_<task_id>`` with a 1-hour expiry.
    smiles : input molecule as a SMILES string.
    type_ : ``"C"`` for the 13C model; any other value uses the 1H model.

    On any failure the traceback is printed and ``{"errMessage": ...}`` is
    stored under the same key so the front-end can report the error.
    """
    result_key = f"task_result_{task_id}"
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles returns None for unparseable input; fail with a
            # clear message instead of an AttributeError on the next line.
            raise ValueError(f"Could not parse SMILES: {smiles}")
        AllChem.EmbedMolecule(mol, useRandomCoords=True)
        mol_with_h = Chem.AddHs(mol, addCoords=True)

        # nr (third return value) is unused here.
        mol_with_confs, ids, nr = genConf(
            mol_with_h, rms=-1, nc=200, efilter=10.0, rmspost=0.5
        )
        print(f"genConf: {len(ids)} conformers", flush=True)

        conf_sdfs, energy_order = _build_sdfs_from_genconf(mol_with_confs, ids)

        # preprocess_* expects a list of embedded, H-added molecules; kept as
        # a separate embedding from the genConf one above (as in the original
        # pipeline).
        mols = [Chem.MolFromSmiles(smiles)]
        for m in mols:
            AllChem.EmbedMolecule(m, useRandomCoords=True)
        mols = [Chem.AddHs(m, addCoords=True) for m in mols]

        # Silence the chatty preprocessors; restore the real streams even on
        # error, then close the devnull handle (no fd leak).
        devnull = open(os.devnull, "w")
        saved_out, saved_err = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = devnull
        try:
            if type_ == "C":
                inputs, df, conf_mols = preprocess_C(mols, preprocessor, keep_all_cf=True)
            else:
                inputs, df, conf_mols = preprocess_H(mols, preprocessor, keep_all_cf=True)
        finally:
            sys.stdout, sys.stderr = saved_out, saved_err
            devnull.close()

        # Guard BEFORE evaluating — previously this check came after
        # evaluate_*, so an empty input set crashed inside the model call
        # instead of raising this message.
        if len(inputs) == 0:
            raise RuntimeError("No conformers generated")

        if type_ == "C":
            predicted = evaluate_C(inputs, preprocessor, model_C)
        else:
            predicted = evaluate_H(inputs, preprocessor, model_H)

        # Explode each per-molecule row into one row per (atom, conformer).
        # Build the pieces first and concat once: avoids the O(n^2) loop of
        # repeated concat and the version-dependent dtypes that come from
        # seeding with an empty object-dtype frame.
        pieces = []
        for _, r in df.iterrows():
            n = len(r["atom_index"])
            pieces.append(pd.DataFrame({
                "mol_id": [r["mol_id"]] * n,
                "atom_index": r["atom_index"],
                "relative_E": [r["relative_E"]] * n,
                "cf_id": [r["cf_id"]] * n,
            }))
        spread_df = pd.concat(pieces, sort=True)

        spread_df["predicted"] = predicted
        spread_df["b_weight"] = spread_df["relative_E"].apply(
            lambda x: math.exp(-x / (0.001987 * 298.15))
        )
        spread_df = spread_df.round(2)

        # Average on the 0-based atom indices first, then switch the
        # per-conformer table to the 1-based indices used in the output
        # (replaces the old +1 / assign(-1) round trip).
        final_df = _boltzmann_average(spread_df.copy())
        spread_df["atom_index"] = spread_df["atom_index"].apply(lambda x: x + 1)

        result = {
            "smiles": smiles,
            "type_": type_,
            "conf_sdfs": conf_sdfs,
            "weightedShiftTxt": _fmt_weighted(final_df),
            "confShiftTxt": _fmt_conf_shifts(spread_df, energy_order),
            "relative_E": _fmt_relative_E(spread_df, energy_order),
        }
        redis_client.set(result_key, json.dumps(result), ex=3600)
        print(f"Task {task_id} complete β {len(conf_sdfs)} conformers", flush=True)

        # Best-effort analytics; never raises.
        _log_prediction()

    except Exception as e:
        import traceback; traceback.print_exc()
        redis_client.set(result_key, json.dumps({"errMessage": str(e)}), ex=3600)
|
|
|
|
print("Worker ready, waiting for jobs...", flush=True)
while True:
    # Block for at most 5 s so the loop wakes up periodically.
    popped = redis_client.blpop("task_queue", timeout=5)
    if popped is None:
        continue
    _list_name, task_id = popped
    raw_detail = redis_client.get(f"task_detail_{task_id}")
    if not raw_detail:
        # Detail may have expired or was never written; drop the task id.
        continue
    detail = json.loads(raw_detail)
    print(f"Processing task {task_id} smiles={detail['smiles']} type={detail['type_']}", flush=True)
    run_job(task_id, detail["smiles"], detail["type_"])
|
|