molforge / molforge_oracles.py
Adhitya122's picture
Prepare MolForge OpenEnv Docker Space submission
bf9e424 verified
"""RDKit/TDC-backed molecular oracle helpers for MolForge."""
from __future__ import annotations
import math
from functools import lru_cache
from typing import Any, Dict, Mapping, Optional
# Fragment vocabularies (fragment name -> SMILES). Every fragment is written so
# that its FIRST atom is the attachment point: the assembler below places each
# fragment inside a "(...)" branch hanging off one atom of the aryl core.
WARHEAD_SMILES = {
    "acrylamide": "C(=O)NC=C",
    "reversible_cyanoacrylamide": "C(=O)NC(=C)C#N",
    "nitrile": "C#N",
    "vinyl_sulfonamide": "S(=O)(=O)NC=C",
}
# NOTE(review): the "azaindole" SMILES below encodes a benzimidazol-2-yl group
# (imidazole fused to benzene, no pyridine-type ring N). Confirm whether
# 7-azaindole was intended before changing it; every score derived from this
# fragment would shift.
HINGE_SMILES = {
    "azaindole": "c1[nH]c2ccccc2n1",
    "pyridine": "c1ccncc1",
    "fluorophenyl": "c1ccc(F)cc1",
    "quinazoline": "c1ncnc2ccccc12",
}
TAIL_SMILES = {
    "morpholine": "N1CCOCC1",
    "piperazine": "N1CCNCC1",
    "cyclopropyl": "C1CC1",
    "dimethylamino": "N(C)C",
}
BACK_POCKET_SMILES = {
    "methoxy": "OC",
    "chloro": "Cl",
    "trifluoromethyl": "C(F)(F)F",
    "cyano": "C#N",
}


def assemble_surrogate_smiles(molecule: Mapping[str, str]) -> str:
    """Render a fragment selection as a single substituted-benzene SMILES.

    The warhead, hinge, solvent-tail and back-pocket fragments, in that order,
    are attached to four consecutive atoms of a six-membered aromatic core.
    The remaining two core atoms are left unsubstituted.

    Args:
        molecule: Mapping with keys ``"warhead"``, ``"hinge"``,
            ``"solvent_tail"`` and ``"back_pocket"``. Each value must be a key
            of the matching ``*_SMILES`` table.

    Returns:
        A SMILES string intended for RDKit/TDC scoring.

    Raises:
        KeyError: if a required key is missing or names an unknown fragment.
    """
    substituents = (
        WARHEAD_SMILES[molecule["warhead"]],
        HINGE_SMILES[molecule["hinge"]],
        TAIL_SMILES[molecule["solvent_tail"]],
        BACK_POCKET_SMILES[molecule["back_pocket"]],
    )
    # The core ring is opened with label %10 and stays open while the fragment
    # branches are written. Fragments use ring labels 1 and 2, so the core label
    # must differ from those, or a fragment would close the core ring instead.
    template = "c%10({})c({})c({})c({})cc%10"
    return template.format(*substituents)
def oracle_backend_status() -> Dict[str, bool]:
    """Tell which optional chemistry backends can be used in this process.

    Returns:
        ``{"rdkit": bool, "tdc": bool}``. Each flag is ``True`` when the
        corresponding import helper succeeded. Both helpers memoise their
        result, so repeated calls do not re-import anything.
    """
    rdkit_ready = _rdkit_modules() is not None
    tdc_ready = _tdc_oracle_class() is not None
    return {"rdkit": rdkit_ready, "tdc": tdc_ready}
def evaluate_with_rdkit_tdc(
    molecule: Mapping[str, str],
    fallback_properties: Mapping[str, float],
) -> Dict[str, float]:
    """Blend RDKit/TDC medicinal-chemistry signals into MolForge properties.

    The fragment selection is rendered with ``assemble_surrogate_smiles`` and
    parsed by RDKit. It is then scored with:

    * TDC's ``QED``/``SA`` oracles when they are usable, with RDKit-only
      fallbacks otherwise;
    * RDKit descriptors.

    Each cheminformatics signal is linearly blended with the matching
    heuristic value from ``fallback_properties``.

    Args:
        molecule: Fragment choices keyed by ``"warhead"``, ``"hinge"``,
            ``"solvent_tail"`` and ``"back_pocket"``.
        fallback_properties: Heuristic property scores. The blended path reads
            ``"potency"``, ``"toxicity"``, ``"synth"`` and ``"novelty"``.

    Returns:
        A new dict. If RDKit cannot be imported, or the surrogate SMILES does
        not parse, it is a plain copy of ``fallback_properties``. Otherwise it
        holds blended ``"potency"``, ``"safety"`` (``1 - toxicity``),
        ``"toxicity"``, ``"synth"`` and ``"novelty"``, each rounded to 4
        decimals. These are followed by any other keys of
        ``fallback_properties``, unchanged.

    Raises:
        KeyError: when RDKit is available and a fragment name is unknown, or
            one of the fallback keys listed above is missing.
    """
    modules = _rdkit_modules()
    if modules is None:
        return dict(fallback_properties)
    Chem = modules["Chem"]
    Descriptors = modules["Descriptors"]
    Crippen = modules["Crippen"]
    Lipinski = modules["Lipinski"]
    QED = modules["QED"]
    rdFingerprintGenerator = modules["rdFingerprintGenerator"]
    rdMolDescriptors = modules["rdMolDescriptors"]
    DataStructs = modules["DataStructs"]

    smiles = assemble_surrogate_smiles(molecule)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return dict(fallback_properties)
    # The TDC oracles receive RDKit's canonical SMILES rather than the raw template.
    canonical = Chem.MolToSmiles(mol)

    # Drug-likeness: prefer the TDC oracle, otherwise RDKit's own QED.
    qed_value = _tdc_oracle_score("QED", canonical)
    if qed_value is None:
        qed_value = float(QED.qed(mol))
    qed_score = _clamp01(qed_value)

    # Synthetic ease: use the normalised TDC SA score, otherwise a descriptor heuristic.
    synth_score = _normalize_sa(_tdc_oracle_score("SA", canonical))
    if synth_score is None:
        synth_score = _rdkit_synth_proxy(mol, Descriptors, Lipinski, rdMolDescriptors)

    logp = float(Crippen.MolLogP(mol))
    tpsa = float(Descriptors.TPSA(mol))
    mol_wt = float(Descriptors.MolWt(mol))
    rotatable = float(Lipinski.NumRotatableBonds(mol))
    aromatic_rings = float(rdMolDescriptors.CalcNumAromaticRings(mol))

    property_risk = _property_risk(logp=logp, tpsa=tpsa, mol_wt=mol_wt, rotatable=rotatable)
    structural_risk = _structural_alert_risk(molecule)
    rdkit_toxicity = _clamp01(0.55 * property_risk + 0.45 * structural_risk)
    target_fit = _target_fit_proxy(
        molecule,
        qed_score=qed_score,
        logp=logp,
        tpsa=tpsa,
        aromatic_rings=aromatic_rings,
    )
    novelty = _novelty_proxy(mol, Chem, rdFingerprintGenerator, DataStructs)

    potency = _blend(fallback_properties["potency"], target_fit, 0.35)
    # Blend toxicity once and derive safety from it, so the two keys always agree.
    toxicity = _blend(fallback_properties["toxicity"], rdkit_toxicity, 0.25)
    synth = _blend(fallback_properties["synth"], synth_score, 0.55)
    novelty_score = _blend(fallback_properties["novelty"], novelty, 0.50)

    result = {
        "potency": round(potency, 4),
        "safety": round(_clamp01(1.0 - toxicity), 4),
        "toxicity": round(toxicity, 4),
        "synth": round(synth, 4),
        "novelty": round(novelty_score, 4),
    }
    # The early-return paths above pass every fallback key through. Do the same
    # here, so extra keys do not vanish only when RDKit happens to be installed.
    for key, value in fallback_properties.items():
        result.setdefault(key, value)
    return result
@lru_cache(maxsize=1)
def _rdkit_modules() -> Optional[Dict[str, Any]]:
    """Import the RDKit modules used by this file and return them keyed by name.

    Returns ``None`` when any of the imports raises an ``Exception``, so callers
    can fall back to heuristic scoring. ``lru_cache`` memoises the outcome,
    ``None`` included, so the imports are attempted only on the first call.
    """
    try:
        from rdkit import Chem, DataStructs
        from rdkit.Chem import Crippen, Descriptors, Lipinski, QED, rdFingerprintGenerator, rdMolDescriptors
    except Exception:  # any import-time failure is treated as "RDKit unavailable"
        return None
    else:
        return {
            "Chem": Chem,
            "Crippen": Crippen,
            "DataStructs": DataStructs,
            "Descriptors": Descriptors,
            "Lipinski": Lipinski,
            "QED": QED,
            "rdFingerprintGenerator": rdFingerprintGenerator,
            "rdMolDescriptors": rdMolDescriptors,
        }
@lru_cache(maxsize=1)
def _tdc_oracle_class() -> Optional[Any]:
    """Return ``tdc.Oracle``, or ``None`` if importing it raises an ``Exception``.

    ``lru_cache`` memoises the outcome, so repeated calls reuse the first result.
    """
    try:
        from tdc import Oracle
    except Exception:  # missing or broken install: treat TDC as unavailable
        return None
    else:
        return Oracle
@lru_cache(maxsize=8)
def _tdc_oracle(name: str) -> Optional[Any]:
    """Construct the TDC oracle called *name*, memoised per name.

    Returns ``None`` when TDC is unavailable or construction raises. That
    ``None`` is cached as well, so a failed construction is not retried
    while its cache entry survives (up to 8 names are kept).
    """
    factory = _tdc_oracle_class()
    if factory is None:
        return None
    try:
        oracle = factory(name=name)
    except Exception:
        # Any failure while building the oracle simply means "no TDC signal".
        return None
    return oracle
def _tdc_oracle_score(name: str, smiles: str) -> Optional[float]:
    """Score *smiles* with the TDC oracle *name*, or return ``None``.

    ``None`` means "no usable TDC score". This happens when:

    * the oracle is unavailable;
    * calling it raises;
    * its result cannot be converted to ``float``;
    * the converted value is not finite.

    Callers treat ``None`` as the signal to use an RDKit-based fallback.
    """
    oracle = _tdc_oracle(name)
    if oracle is None:
        return None
    try:
        value = oracle(smiles)
    except Exception:
        # Best-effort: any oracle failure just disables the TDC signal.
        return None
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    # NaN is not clipped by _clamp01 (min/max keep it), so a NaN or inf score
    # would poison the blended properties. Report it as "no score" instead.
    if not math.isfinite(score):
        return None
    return score
def _normalize_sa(value: Optional[float]) -> Optional[float]:
    """Convert a synthetic-accessibility value into a [0, 1] "ease" score.

    * ``None`` is passed through unchanged.
    * Values already inside [0, 1] are treated as pre-normalised and only
      clamped.
    * Anything else is presumably a raw SA score on the conventional 1-10
      scale. It is mapped linearly so that 1 -> 1.0 and 10 -> 0.0, then
      clamped.
    """
    if value is None:
        return None
    already_unit = 0.0 <= value <= 1.0
    return _clamp01(value if already_unit else (10.0 - value) / 9.0)
def _rdkit_synth_proxy(mol: Any, Descriptors: Any, Lipinski: Any, rdMolDescriptors: Any) -> float:
    """Estimate synthetic ease in [0, 1] from simple RDKit size/complexity counts.

    The complexity penalty sums five terms:

    * molecular weight above 350 (per 260);
    * rotatable bonds (per 12);
    * atom stereocentres (per 4);
    * rings beyond three (per 4);
    * aromatic rings (per 8).

    The result is ``1 - 0.35 * complexity``, clamped to [0, 1].
    """
    penalties = (
        max(0.0, float(Descriptors.MolWt(mol)) - 350.0) / 260.0,
        float(Lipinski.NumRotatableBonds(mol)) / 12.0,
        float(rdMolDescriptors.CalcNumAtomStereoCenters(mol)) / 4.0,
        max(0.0, float(rdMolDescriptors.CalcNumRings(mol)) - 3.0) / 4.0,
        float(rdMolDescriptors.CalcNumAromaticRings(mol)) / 8.0,
    )
    # sum() adds left to right, matching a chained "a + b + c + ..." expression.
    complexity = sum(penalties)
    return _clamp01(1.0 - 0.35 * complexity)
def _property_risk(*, logp: float, tpsa: float, mol_wt: float, rotatable: float) -> float:
    """Combine smooth (sigmoid) penalties for high-end physicochemical properties.

    Each term is exactly 0.5 at its threshold and increases with the property:
    logP 3.5, MW 500, 8 rotatable bonds and TPSA 130. The weights
    0.42 / 0.24 / 0.20 / 0.14 total 1.0, and the weighted sum is clamped
    to [0, 1].
    """
    weighted_terms = (
        (0.42, _sigmoid((logp - 3.5) / 1.15)),
        (0.24, _sigmoid((mol_wt - 500.0) / 90.0)),
        (0.20, _sigmoid((rotatable - 8.0) / 2.5)),
        (0.14, _sigmoid((tpsa - 130.0) / 32.0)),
    )
    return _clamp01(sum(weight * term for weight, term in weighted_terms))
def _structural_alert_risk(molecule: Mapping[str, str]) -> float:
    """Score fragment-level structural-alert risk in [0, 1].

    Starts from a 0.18 baseline and applies fixed adjustments, in order.

    Raise the risk:

    * acrylamide or vinyl-sulfonamide warheads;
    * a dimethylamino tail;
    * a trifluoromethyl back-pocket group;
    * a fluorophenyl hinge combined with a chloro/trifluoromethyl back pocket.

    Lower the risk:

    * morpholine/piperazine tails;
    * a nitrile warhead.

    The total is clamped to [0, 1].
    """
    warhead = molecule["warhead"]
    tail = molecule["solvent_tail"]
    pocket = molecule["back_pocket"]
    hinge = molecule["hinge"]
    adjustments = (
        (warhead == "acrylamide", 0.12),
        (warhead == "vinyl_sulfonamide", 0.22),
        (tail == "dimethylamino", 0.24),
        (pocket == "trifluoromethyl", 0.20),
        (hinge == "fluorophenyl" and pocket in {"chloro", "trifluoromethyl"}, 0.12),
        (tail in {"morpholine", "piperazine"}, -0.08),
        (warhead == "nitrile", -0.08),
    )
    risk = 0.18
    # Apply in table order so the floating-point accumulation is unchanged.
    for applies, delta in adjustments:
        if applies:
            risk += delta
    return _clamp01(risk)
def _target_fit_proxy(
    molecule: Mapping[str, str],
    *,
    qed_score: float,
    logp: float,
    tpsa: float,
    aromatic_rings: float,
) -> float:
    """Heuristic target-fit score in [0, 1].

    Combines a 0.20 baseline, closeness of logP to 3 and of TPSA to 85, the
    QED score, and fixed fragment/aromaticity bonuses. The total is clamped.
    """
    logp_fit = 1.0 - min(abs(logp - 3.0) / 4.0, 1.0)
    tpsa_fit = 1.0 - min(abs(tpsa - 85.0) / 110.0, 1.0)
    bonuses = (
        (molecule["hinge"] in {"azaindole", "quinazoline"}, 0.18),
        (molecule["back_pocket"] in {"cyano", "chloro", "trifluoromethyl"}, 0.14),
        (molecule["warhead"] in {"acrylamide", "reversible_cyanoacrylamide", "nitrile"}, 0.12),
        (aromatic_rings >= 2, 0.08),
    )
    # Start from 0.0 and add in table order, exactly like a running "+=".
    pocket_bonus = sum((amount for earned, amount in bonuses if earned), 0.0)
    return _clamp01(0.20 + 0.30 * logp_fit + 0.22 * tpsa_fit + 0.18 * qed_score + pocket_bonus)
def _novelty_proxy(mol: Any, Chem: Any, rdFingerprintGenerator: Any, DataStructs: Any) -> float:
    """Score novelty as ``1 - max Tanimoto similarity`` to fixed reference molecules.

    Uses Morgan fingerprints (radius 2, 1024 bits). The references are surrogate
    SMILES in the same core-plus-four-branches layout used for scoring.
    References that fail to parse are skipped. If none parse, returns 0.5.
    """
    reference_smiles = (
        "c%10(C(=O)NC=C)c(c1ccncc1)c(C1CC1)c(OC)cc%10",
        "c%10(C#N)c(c1ccncc1)c(N1CCOCC1)c(C#N)cc%10",
        "c%10(C(=O)NC=C)c(c1ccc(F)cc1)c(N(C)C)c(Cl)cc%10",
    )
    morgan = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)
    query_fp = morgan.GetFingerprint(mol)
    reference_mols = (Chem.MolFromSmiles(text) for text in reference_smiles)
    similarities = [
        float(DataStructs.TanimotoSimilarity(query_fp, morgan.GetFingerprint(ref_mol)))
        for ref_mol in reference_mols
        if ref_mol is not None
    ]
    if similarities:
        return _clamp01(1.0 - max(similarities))
    return 0.5
def _blend(fallback_value: float, oracle_value: float, oracle_weight: float) -> float:
    """Linearly interpolate from *fallback_value* toward *oracle_value*.

    Before clamping, an ``oracle_weight`` of 0 yields the fallback value and
    1 yields the oracle value. The result is clamped to [0, 1].
    """
    fallback_weight = 1.0 - oracle_weight
    mixed = fallback_weight * fallback_value + oracle_weight * oracle_value
    return _clamp01(mixed)
def _sigmoid(value: float) -> float:
    """Return the logistic function ``1 / (1 + exp(-value))``.

    The naive form raises ``OverflowError`` from ``math.exp`` once ``-value``
    exceeds about 709. This version branches on the sign so that ``exp`` is
    only ever called with a non-positive argument, which can underflow to 0.0
    but never overflows.
    """
    if value >= 0.0:
        return 1.0 / (1.0 + math.exp(-value))
    # Algebraically equal to 1 / (1 + exp(-value)) for negative inputs.
    exp_value = math.exp(value)
    return exp_value / (1.0 + exp_value)
def _clamp01(value: float) -> float:
    """Coerce *value* to ``float`` and clip it into the closed interval [0, 1].

    NaN is not clipped: it is returned unchanged, as is ``-0.0``.
    """
    number = float(value)
    if number < 0.0:
        return 0.0
    if number > 1.0:
        return 1.0
    return number