"""CASCADE worker process.

Blocking worker that pops prediction tasks from a Redis queue, generates
conformers for the requested SMILES, runs the SchNet edge-update NMR models
(13C or 1H), Boltzmann-averages the per-conformer shifts, and writes the
JSON-encoded result back to Redis under ``task_result_<task_id>``.
"""
import os
import sys
import warnings

# Silence noisy library warnings before the heavy imports below trigger them.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

import datetime
import json
import math
import pickle
from io import StringIO

import redis
import numpy as np

# Must be set before keras/TF import to suppress TF C++ log spam.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import keras
from keras.models import load_model

from NMR_Prediction.apply import (
    preprocess_C,
    preprocess_H,
    evaluate_C,
    evaluate_H,
    RBFSequence,  # imported for side effects / model deserialization; keep even if unused here
)
from nfp.layers import (
    MessageLayer,
    GRUStep,
    Squeeze,
    EdgeNetwork,
    ReduceBondToPro,
    ReduceBondToAtom,
    GatherAtomToBond,
    ReduceAtomToPro,
)
from nfp.models import GraphModel
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import SDWriter
from NMR_Prediction.genConf import genConf

MODEL_PATH_C = os.path.join("NMR_Prediction", "schnet_edgeupdate", "best_model.hdf5")
MODEL_PATH_H = os.path.join("NMR_Prediction", "schnet_edgeupdate_H", "best_model.hdf5")
PREPROCESSOR_PATH = os.path.join("NMR_Prediction", "preprocessor.p")

# RT in kcal/mol at 298.15 K (R = 0.001987 kcal/(mol*K)); used for Boltzmann weights.
_RT_KCAL = 0.001987 * 298.15

# Custom layers required to deserialize the saved keras models.
custom_objects = {
    "MessageLayer": MessageLayer,
    "GRUStep": GRUStep,
    "Squeeze": Squeeze,
    "EdgeNetwork": EdgeNetwork,
    "ReduceBondToPro": ReduceBondToPro,
    "ReduceBondToAtom": ReduceBondToAtom,
    "GatherAtomToBond": GatherAtomToBond,
    "ReduceAtomToPro": ReduceAtomToPro,
    "GraphModel": GraphModel,
}

print("Loading 13C model...", flush=True)
model_C = load_model(MODEL_PATH_C, custom_objects=custom_objects)
print("Loading 1H model...", flush=True)
model_H = load_model(MODEL_PATH_H, custom_objects=custom_objects)
print("Both models loaded.", flush=True)

with open(PREPROCESSOR_PATH, "rb") as f:
    preprocessor = pickle.load(f)["preprocessor"]

redis_client = redis.StrictRedis(
    host="localhost", port=6379, db=0, decode_responses=True
)

# ── Analytics logging to HF Dataset ──────────────────────────────────────────
_HF_TOKEN = os.environ.get("HF_TOKEN", "")
_ANALYTICS_REPO = "patonlab/analytics"
_ANALYTICS_FILE = "visits.csv"


def _log_prediction():
    """Append one row to the existing patonlab/analytics data.csv.

    Format matches the alfabet log: space,timestamp

    Best-effort: any failure is printed and swallowed so analytics can
    never break a prediction. No-op when HF_TOKEN is not configured.
    """
    if not _HF_TOKEN:
        return
    try:
        from huggingface_hub import HfApi, hf_hub_download
        import tempfile

        api = HfApi(token=_HF_TOKEN)
        # utcnow() is deprecated; this produces the identical naive-UTC
        # ISO string the analytics file already contains.
        timestamp = (
            datetime.datetime.now(datetime.timezone.utc)
            .replace(tzinfo=None)
            .isoformat()
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            local_path = hf_hub_download(
                repo_id=_ANALYTICS_REPO,
                filename=_ANALYTICS_FILE,
                repo_type="dataset",
                token=_HF_TOKEN,
                local_dir=tmpdir,
            )
            with open(local_path, "a") as f:
                f.write(f"patonlab/cascade,{timestamp}\n")
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=_ANALYTICS_FILE,
                repo_id=_ANALYTICS_REPO,
                repo_type="dataset",
                commit_message=f"log: cascade prediction {timestamp[:10]}",
            )
    except Exception as e:
        print(f"Analytics logging failed (non-fatal): {e}", flush=True)


def _mol_to_sdf(mol, conf_id=0):
    """Serialize one conformer of *mol* to an SDF string."""
    sio = StringIO()
    w = SDWriter(sio)
    w.write(mol, confId=conf_id)
    w.close()
    return sio.getvalue()


def _build_sdfs_from_genconf(mol_with_confs, ids):
    """Serialize each (energy, conf_id) pair from genConf to SDF text.

    Returns (sdfs, energy_order) where energy_order lists the conformer ids
    in the same (energy-sorted) order genConf produced them; conformers that
    fail to serialize are skipped with a printed warning.
    """
    sdfs = []
    energy_order = []
    for energy, conf_id in ids:
        try:
            sdf = _mol_to_sdf(mol_with_confs, conf_id=int(conf_id))
            if sdf.strip():
                sdfs.append(sdf)
                energy_order.append(int(conf_id))
        except Exception as e:
            print(f"SDF error for conf_id={conf_id}: {e}", flush=True)
    return sdfs, energy_order


def _boltzmann_average(spread_df):
    """Boltzmann-average predicted shifts over conformers per atom.

    Expects columns mol_id / atom_index / cf_id / relative_E / predicted.
    Returns a DataFrame [mol_id, atom_index, Shift] with 1-based atom
    indices, rounded to 2 decimals.
    """
    spread_df["b_weight"] = spread_df["relative_E"].apply(
        lambda x: math.exp(-x / _RT_KCAL)
    )
    df_group = spread_df.set_index(["mol_id", "atom_index", "cf_id"]).groupby(level=[0, 1])
    final = []
    for (m_id, a_id), df in df_group:
        ws = (df["b_weight"] * df["predicted"]).sum() / df["b_weight"].sum()
        final.append([m_id, a_id, ws])
    final = pd.DataFrame(final, columns=["mol_id", "atom_index", "Shift"])
    final["atom_index"] = final["atom_index"].apply(lambda x: x + 1)  # 0- → 1-based
    return final.round(2).astype(dtype={"atom_index": "int"})


def _fmt_weighted(final_df):
    """Format weighted shifts as 'index,shift;index,shift;...'."""
    return "".join(f"{int(r['atom_index'])},{r['Shift']:.2f};" for _, r in final_df.iterrows())


def _fmt_conf_shifts(spread_df, energy_order):
    """Format per-conformer shifts, conformers joined by '!' in energy order."""
    parts = []
    for cf_id in energy_order:
        sub = spread_df[spread_df["cf_id"] == cf_id]
        if len(sub) == 0:
            continue
        parts.append("".join(f"{int(r['atom_index'])},{r['predicted']:.2f};" for _, r in sub.iterrows()))
    return "!".join(parts)


def _fmt_relative_E(spread_df, energy_order):
    """Format 'relative_E,normalized_boltzmann_weight,' per conformer, '!'-joined."""
    total_bw = spread_df.groupby("cf_id")["b_weight"].first().sum()
    parts = []
    for cf_id in energy_order:
        sub = spread_df[spread_df["cf_id"] == cf_id]
        if len(sub) == 0:
            continue
        e = round(sub["relative_E"].iloc[0], 2)
        bw = round(sub["b_weight"].iloc[0] / total_bw, 4)
        parts.append(f"{e},{bw},")
    return "!".join(parts)


def run_job(task_id, smiles, type_):
    """Run one prediction task end-to-end and store the result in Redis.

    type_ selects the model: "C" for 13C, anything else for 1H. On any
    failure the exception text is stored under the same result key as
    {"errMessage": ...} so the frontend can report it.
    """
    result_key = f"task_result_{task_id}"
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Fail early with a readable message instead of the opaque
            # Boost error EmbedMolecule(None) would raise.
            raise ValueError(f"Could not parse SMILES: {smiles}")
        AllChem.EmbedMolecule(mol, useRandomCoords=True)
        mol_with_h = Chem.AddHs(mol, addCoords=True)
        mol_with_confs, ids, nr = genConf(mol_with_h, rms=-1, nc=200, efilter=10.0, rmspost=0.5)
        print(f"genConf: {len(ids)} conformers", flush=True)
        conf_sdfs, energy_order = _build_sdfs_from_genconf(mol_with_confs, ids)

        mols = [Chem.MolFromSmiles(smiles)]
        for m in mols:
            AllChem.EmbedMolecule(m, useRandomCoords=True)
        mols = [Chem.AddHs(m, addCoords=True) for m in mols]

        # Suppress duplicate genConf stdout during preprocess
        _stdout, _stderr = sys.stdout, sys.stderr
        sys.stdout = sys.stderr = open(os.devnull, 'w')
        try:
            if type_ == "C":
                inputs, df, conf_mols = preprocess_C(mols, preprocessor, keep_all_cf=True)
            else:
                inputs, df, conf_mols = preprocess_H(mols, preprocessor, keep_all_cf=True)
        finally:
            sys.stdout.close()
            sys.stdout, sys.stderr = _stdout, _stderr

        # Guard BEFORE evaluation: the original checked after, so an empty
        # input set crashed inside evaluate_* before the message could fire.
        if len(inputs) == 0:
            raise RuntimeError("No conformers generated")

        if type_ == "C":
            predicted = evaluate_C(inputs, preprocessor, model_C)
        else:
            predicted = evaluate_H(inputs, preprocessor, model_H)

        # Explode the per-conformer rows (one row per atom). Collect frames
        # and concat once: concat-in-a-loop is quadratic. Seeding with the
        # empty typed frame preserves the original column/dtype behavior.
        seed = pd.DataFrame(columns=["mol_id", "atom_index", "relative_E", "cf_id"])
        frames = [seed]
        for _, r in df.iterrows():
            n = len(r["atom_index"])
            frames.append(pd.DataFrame({
                "mol_id": [r["mol_id"]] * n,
                "atom_index": r["atom_index"],
                "relative_E": [r["relative_E"]] * n,
                "cf_id": [r["cf_id"]] * n,
            }))
        spread_df = pd.concat(frames, sort=True)

        spread_df["predicted"] = predicted
        spread_df["b_weight"] = spread_df["relative_E"].apply(
            lambda x: math.exp(-x / _RT_KCAL)
        )
        spread_df["atom_index"] = spread_df["atom_index"].apply(lambda x: x + 1)
        spread_df = spread_df.round(2)

        # _boltzmann_average re-applies the +1 shift, so hand it 0-based indices.
        final_df = _boltzmann_average(
            spread_df.copy().assign(
                atom_index=spread_df["atom_index"].apply(lambda x: x - 1)
            )
        )

        result = {
            "smiles": smiles,
            "type_": type_,
            "conf_sdfs": conf_sdfs,
            "weightedShiftTxt": _fmt_weighted(final_df),
            "confShiftTxt": _fmt_conf_shifts(spread_df, energy_order),
            "relative_E": _fmt_relative_E(spread_df, energy_order),
        }
        redis_client.set(result_key, json.dumps(result), ex=3600)
        print(f"Task {task_id} complete — {len(conf_sdfs)} conformers", flush=True)
        # Log to analytics dataset (non-blocking)
        _log_prediction()
    except Exception as e:
        import traceback
        traceback.print_exc()
        redis_client.set(result_key, json.dumps({"errMessage": str(e)}), ex=3600)


def _main():
    """Blocking poll loop: pop task ids from 'task_queue' and run them."""
    print("Worker ready, waiting for jobs...", flush=True)
    while True:
        item = redis_client.blpop("task_queue", timeout=5)
        if item is None:
            continue
        _, task_id = item
        detail = redis_client.get(f"task_detail_{task_id}")
        if not detail:
            # Detail expired or was never written; skip the orphaned id.
            continue
        detail = json.loads(detail)
        print(f"Processing task {task_id} smiles={detail['smiles']} type={detail['type_']}", flush=True)
        run_job(task_id, detail["smiles"], detail["type_"])


if __name__ == "__main__":
    # Guarded so importing this module for its helpers no longer blocks
    # forever; running the script behaves exactly as before.
    _main()