Spaces:
Running
Running
| """RDKit/TDC-backed molecular oracle helpers for MolForge.""" | |
| from __future__ import annotations | |
| import math | |
| from functools import lru_cache | |
| from typing import Any, Dict, Mapping, Optional | |
# Fragment tables: each maps a fragment name (the value stored under the
# matching slot of a molecule mapping) to the SMILES substring that
# assemble_surrogate_smiles() splices in as a branch of the core ring.

# Looked up with molecule["warhead"].
WARHEAD_SMILES: Dict[str, str] = {
    "acrylamide": "C(=O)NC=C",
    "reversible_cyanoacrylamide": "C(=O)NC(=C)C#N",
    "nitrile": "C#N",
    "vinyl_sulfonamide": "S(=O)(=O)NC=C",
}

# Looked up with molecule["hinge"].
# NOTE(review): the "azaindole" SMILES (two ring nitrogens flanking the
# attachment carbon of a benzo-fused five-membered ring) looks like a
# benzimidazole rather than an azaindole -- confirm this is intended.
HINGE_SMILES: Dict[str, str] = {
    "azaindole": "c1[nH]c2ccccc2n1",
    "pyridine": "c1ccncc1",
    "fluorophenyl": "c1ccc(F)cc1",
    "quinazoline": "c1ncnc2ccccc12",
}

# Looked up with molecule["solvent_tail"].
TAIL_SMILES: Dict[str, str] = {
    "morpholine": "N1CCOCC1",
    "piperazine": "N1CCNCC1",
    "cyclopropyl": "C1CC1",
    "dimethylamino": "N(C)C",
}

# Looked up with molecule["back_pocket"].
BACK_POCKET_SMILES: Dict[str, str] = {
    "methoxy": "OC",
    "chloro": "Cl",
    "trifluoromethyl": "C(F)(F)F",
    "cyano": "C#N",
}
def assemble_surrogate_smiles(molecule: Mapping[str, str]) -> str:
    """Compose the surrogate SMILES string for a fragment selection.

    ``molecule`` must provide the keys ``warhead``, ``hinge``,
    ``solvent_tail`` and ``back_pocket``.  Each value is resolved through the
    corresponding module-level fragment table; a ``KeyError`` propagates when
    a slot is missing or names an unknown fragment.

    The four fragments are attached, in that order, as branches on
    consecutive atoms of a six-membered aromatic carbon ring.
    """
    warhead = WARHEAD_SMILES[molecule["warhead"]]
    hinge = HINGE_SMILES[molecule["hinge"]]
    tail = TAIL_SMILES[molecule["solvent_tail"]]
    back_pocket = BACK_POCKET_SMILES[molecule["back_pocket"]]

    # The core ring is closed with the two-digit label %10, which keeps it
    # distinct from the single-digit ring-closure labels used inside the
    # fragment SMILES.
    pieces = ("c%10(", warhead, ")c(", hinge, ")c(", tail, ")c(", back_pocket, ")cc%10")
    return "".join(pieces)
def oracle_backend_status() -> Dict[str, bool]:
    """Return availability flags for the optional chemistry backends.

    The result has two keys: ``"rdkit"`` is true when ``_rdkit_modules()``
    yields a module table, and ``"tdc"`` is true when ``_tdc_oracle_class()``
    yields the oracle class.
    """
    rdkit_available = _rdkit_modules() is not None
    tdc_available = _tdc_oracle_class() is not None
    return {"rdkit": rdkit_available, "tdc": tdc_available}
def evaluate_with_rdkit_tdc(
    molecule: Mapping[str, str],
    fallback_properties: Mapping[str, float],
) -> Dict[str, float]:
    """Mix RDKit/TDC-derived signals into the fallback MolForge properties.

    Args:
        molecule: Fragment selection with the keys ``warhead``, ``hinge``,
            ``solvent_tail`` and ``back_pocket``.
        fallback_properties: Baseline scores.  On the oracle path the keys
            ``potency``, ``toxicity``, ``synth`` and ``novelty`` are read.

    Returns:
        A plain-dict copy of ``fallback_properties`` when RDKit cannot be
        imported or the surrogate SMILES does not parse.  Otherwise a dict
        with exactly the keys ``potency``, ``safety``, ``toxicity``, ``synth``
        and ``novelty``, each rounded to four decimals.
    """
    rdkit = _rdkit_modules()
    if rdkit is None:
        return dict(fallback_properties)

    chem = rdkit["Chem"]
    descriptors = rdkit["Descriptors"]
    crippen = rdkit["Crippen"]
    lipinski = rdkit["Lipinski"]
    qed_module = rdkit["QED"]
    fingerprint_generators = rdkit["rdFingerprintGenerator"]
    mol_descriptors = rdkit["rdMolDescriptors"]
    data_structs = rdkit["DataStructs"]

    mol = chem.MolFromSmiles(assemble_surrogate_smiles(molecule))
    if mol is None:
        return dict(fallback_properties)
    canonical_smiles = chem.MolToSmiles(mol)

    # Drug-likeness: take the TDC oracle value when available, else RDKit's QED.
    tdc_qed = _tdc_oracle_score("QED", canonical_smiles)
    qed_score = _clamp01(float(qed_module.qed(mol)) if tdc_qed is None else tdc_qed)

    # Synthesizability: normalized TDC SA score, else a descriptor-based proxy.
    synth_score = _normalize_sa(_tdc_oracle_score("SA", canonical_smiles))
    if synth_score is None:
        synth_score = _rdkit_synth_proxy(mol, descriptors, lipinski, mol_descriptors)

    logp = float(crippen.MolLogP(mol))
    tpsa = float(descriptors.TPSA(mol))
    mol_wt = float(descriptors.MolWt(mol))
    rotatable = float(lipinski.NumRotatableBonds(mol))
    aromatic_rings = float(mol_descriptors.CalcNumAromaticRings(mol))

    property_risk = _property_risk(logp=logp, tpsa=tpsa, mol_wt=mol_wt, rotatable=rotatable)
    structural_risk = _structural_alert_risk(molecule)
    rdkit_toxicity = _clamp01(0.55 * property_risk + 0.45 * structural_risk)

    target_fit = _target_fit_proxy(
        molecule,
        qed_score=qed_score,
        logp=logp,
        tpsa=tpsa,
        aromatic_rings=aromatic_rings,
    )
    novelty = _novelty_proxy(mol, chem, fingerprint_generators, data_structs)

    # Blend each oracle signal with its fallback value; the third argument is
    # the weight given to the oracle side.
    blended_potency = _blend(fallback_properties["potency"], target_fit, 0.35)
    blended_toxicity = _blend(fallback_properties["toxicity"], rdkit_toxicity, 0.25)
    blended_synth = _blend(fallback_properties["synth"], synth_score, 0.55)
    blended_novelty = _blend(fallback_properties["novelty"], novelty, 0.50)

    return {
        "potency": round(blended_potency, 4),
        # Safety is the complement of the blended toxicity.
        "safety": round(_clamp01(1.0 - blended_toxicity), 4),
        "toxicity": round(blended_toxicity, 4),
        "synth": round(blended_synth, 4),
        "novelty": round(blended_novelty, 4),
    }
| def _rdkit_modules() -> Optional[Dict[str, Any]]: | |
| try: | |
| from rdkit import Chem, DataStructs | |
| from rdkit.Chem import Crippen, Descriptors, Lipinski, QED, rdFingerprintGenerator, rdMolDescriptors | |
| except Exception: | |
| return None | |
| return { | |
| "Chem": Chem, | |
| "Crippen": Crippen, | |
| "DataStructs": DataStructs, | |
| "Descriptors": Descriptors, | |
| "Lipinski": Lipinski, | |
| "QED": QED, | |
| "rdFingerprintGenerator": rdFingerprintGenerator, | |
| "rdMolDescriptors": rdMolDescriptors, | |
| } | |
| def _tdc_oracle_class() -> Optional[Any]: | |
| try: | |
| from tdc import Oracle | |
| except Exception: | |
| return None | |
| return Oracle | |
@lru_cache(maxsize=16)
def _cached_tdc_oracle(oracle_class: Any, name: str) -> Any:
    """Build ``oracle_class(name=name)`` once per ``(class, name)`` pair.

    ``lru_cache`` stores a result only when the call returns normally, so a
    construction that raises is retried on the next request.
    """
    return oracle_class(name=name)


def _tdc_oracle(name: str) -> Optional[Any]:
    """Return a TDC oracle instance for ``name``, or ``None`` if unavailable.

    Successful instances are memoized so that repeated scoring calls reuse
    the same oracle instead of constructing a new one every time.  Failures
    (TDC not importable, or the constructor raising) yield ``None`` and are
    not cached.

    Args:
        name: Oracle name passed to the TDC ``Oracle`` constructor.
    """
    oracle_class = _tdc_oracle_class()
    if oracle_class is None:
        return None
    try:
        return _cached_tdc_oracle(oracle_class, name)
    except Exception:
        # Best effort: an oracle that cannot be built is treated as absent.
        return None
def _tdc_oracle_score(name: str, smiles: str) -> Optional[float]:
    """Score ``smiles`` with the TDC oracle ``name`` and return it as a float.

    Returns ``None`` when the oracle is unavailable, when calling it raises,
    or when its result cannot be converted with ``float()``.
    """
    scorer = _tdc_oracle(name)
    if scorer is not None:
        try:
            raw_score = scorer(smiles)
        except Exception:
            # Any oracle failure is reported as "no score".
            return None
        try:
            return float(raw_score)
        except (TypeError, ValueError):
            pass
    return None
| def _normalize_sa(value: Optional[float]) -> Optional[float]: | |
| if value is None: | |
| return None | |
| if 0.0 <= value <= 1.0: | |
| return _clamp01(value) | |
| return _clamp01((10.0 - value) / 9.0) | |
def _rdkit_synth_proxy(mol: Any, Descriptors: Any, Lipinski: Any, rdMolDescriptors: Any) -> float:
    """Estimate synthesizability in ``[0, 1]`` from simple RDKit descriptors.

    A complexity total is accumulated from molecular weight above 350,
    rotatable bonds, atom stereocenters, rings beyond three and aromatic
    rings; the score is ``1 - 0.35 * complexity`` clamped to ``[0, 1]``.
    """
    weight = float(Descriptors.MolWt(mol))
    n_rotatable = float(Lipinski.NumRotatableBonds(mol))
    n_stereo = float(rdMolDescriptors.CalcNumAtomStereoCenters(mol))
    n_rings = float(rdMolDescriptors.CalcNumRings(mol))
    n_aromatic = float(rdMolDescriptors.CalcNumAromaticRings(mol))

    # Terms are added one at a time, left to right, to keep the exact
    # floating-point result of the single-expression form.
    complexity = max(0.0, weight - 350.0) / 260.0
    complexity += n_rotatable / 12.0
    complexity += n_stereo / 4.0
    complexity += max(0.0, n_rings - 3.0) / 4.0
    complexity += n_aromatic / 8.0

    return _clamp01(1.0 - 0.35 * complexity)
def _property_risk(*, logp: float, tpsa: float, mol_wt: float, rotatable: float) -> float:
    """Combine four descriptor-based risk terms into one value in ``[0, 1]``.

    Each descriptor is passed through a sigmoid centred on a midpoint
    (logP 3.5, weight 500, 8 rotatable bonds, TPSA 130) and the results are
    mixed with the weights 0.42, 0.24, 0.20 and 0.14 respectively.
    """

    def risk_above(measurement: float, midpoint: float, width: float) -> float:
        # Sigmoid of the distance from the midpoint, scaled by the width.
        return _sigmoid((measurement - midpoint) / width)

    from_logp = risk_above(logp, 3.5, 1.15)
    from_size = risk_above(mol_wt, 500.0, 90.0)
    from_flexibility = risk_above(rotatable, 8.0, 2.5)
    from_polarity = risk_above(tpsa, 130.0, 32.0)

    combined = 0.42 * from_logp
    combined += 0.24 * from_size
    combined += 0.20 * from_flexibility
    combined += 0.14 * from_polarity
    return _clamp01(combined)
def _structural_alert_risk(molecule: Mapping[str, str]) -> float:
    """Score fragment-based risk for ``molecule`` in ``[0, 1]``.

    Starts from a base of 0.18 and applies a fixed adjustment for every
    fragment rule that matches the ``warhead``, ``solvent_tail``,
    ``back_pocket`` and ``hinge`` selections, then clamps the total.
    """
    warhead = molecule["warhead"]
    tail = molecule["solvent_tail"]
    back_pocket = molecule["back_pocket"]
    hinge = molecule["hinge"]

    # (rule matches, adjustment) pairs, applied in this order.
    adjustments = (
        (warhead == "acrylamide", 0.12),
        (warhead == "vinyl_sulfonamide", 0.22),
        (tail == "dimethylamino", 0.24),
        (back_pocket == "trifluoromethyl", 0.20),
        (hinge == "fluorophenyl" and back_pocket in {"chloro", "trifluoromethyl"}, 0.12),
        (tail in {"morpholine", "piperazine"}, -0.08),
        (warhead == "nitrile", -0.08),
    )

    total = 0.18
    for matches, delta in adjustments:
        if matches:
            total += delta
    return _clamp01(total)
def _target_fit_proxy(
    molecule: Mapping[str, str],
    *,
    qed_score: float,
    logp: float,
    tpsa: float,
    aromatic_rings: float,
) -> float:
    """Estimate target fit in ``[0, 1]`` from descriptors and fragment choices.

    The score is a 0.20 base plus weighted closeness of logP to 3.0 and of
    TPSA to 85.0, a weighted QED contribution, and fixed bonuses for
    particular hinge, back-pocket and warhead fragments and for having at
    least two aromatic rings.
    """
    # Closeness terms: 1.0 at the reference value, falling linearly to 0.0.
    logp_distance = min(abs(logp - 3.0) / 4.0, 1.0)
    tpsa_distance = min(abs(tpsa - 85.0) / 110.0, 1.0)
    lipophilic_match = 1.0 - logp_distance
    polarity_match = 1.0 - tpsa_distance

    # (rule matches, bonus) pairs, applied in this order.
    bonuses = (
        (molecule["hinge"] in {"azaindole", "quinazoline"}, 0.18),
        (molecule["back_pocket"] in {"cyano", "chloro", "trifluoromethyl"}, 0.14),
        (molecule["warhead"] in {"acrylamide", "reversible_cyanoacrylamide", "nitrile"}, 0.12),
        (aromatic_rings >= 2, 0.08),
    )
    pocket_match = 0.0
    for matches, bonus in bonuses:
        if matches:
            pocket_match += bonus

    fit = (
        0.20
        + 0.30 * lipophilic_match
        + 0.22 * polarity_match
        + 0.18 * qed_score
        + pocket_match
    )
    return _clamp01(fit)
def _novelty_proxy(mol: Any, Chem: Any, rdFingerprintGenerator: Any, DataStructs: Any) -> float:
    """Score novelty as one minus the highest similarity to fixed references.

    ``mol`` and three hard-coded reference SMILES are fingerprinted with a
    Morgan generator (radius 2, 1024 bits) and compared by Tanimoto
    similarity.  References that fail to parse are skipped; if none parse the
    result is 0.5.
    """
    reference_smiles = (
        "c%10(C(=O)NC=C)c(c1ccncc1)c(C1CC1)c(OC)cc%10",
        "c%10(C#N)c(c1ccncc1)c(N1CCOCC1)c(C#N)cc%10",
        "c%10(C(=O)NC=C)c(c1ccc(F)cc1)c(N(C)C)c(Cl)cc%10",
    )
    fp_generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)
    query_fp = fp_generator.GetFingerprint(mol)

    # Lazily parsed, so each reference is parsed, fingerprinted and compared
    # before the next one is touched.
    parsed_references = (Chem.MolFromSmiles(text) for text in reference_smiles)
    scores = [
        float(DataStructs.TanimotoSimilarity(query_fp, fp_generator.GetFingerprint(reference)))
        for reference in parsed_references
        if reference is not None
    ]

    if scores:
        return _clamp01(1.0 - max(scores))
    return 0.5
def _blend(fallback_value: float, oracle_value: float, oracle_weight: float) -> float:
    """Return the weighted mix of the two values, clamped to ``[0, 1]``.

    ``oracle_weight`` is applied to ``oracle_value`` and its complement
    ``1 - oracle_weight`` to ``fallback_value``.
    """
    fallback_weight = 1.0 - oracle_weight
    mixed = fallback_weight * fallback_value + oracle_weight * oracle_value
    return _clamp01(mixed)
| def _sigmoid(value: float) -> float: | |
| return 1.0 / (1.0 + math.exp(-value)) | |
| def _clamp01(value: float) -> float: | |
| return min(max(float(value), 0.0), 1.0) | |