"""Load the FLARE spectrum/molecule featurizers and model from config."""

import os
from pathlib import Path

import yaml
from rdkit import RDLogger

from flare.definitions import REPO_DIR
from flare.utils.config import resolve_repo_paths
from flare.utils.data import get_mol_featurizer, get_spec_featurizer
from flare.utils.models import get_model

# Silence RDKit log output below CRITICAL; this runs once, at import time.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)


def _path_from_env(var, default):
    """Return env var ``var`` (or ``default`` if unset/empty) as a str with ``~`` expanded."""
    # `or` (not a .get default) so an empty variable also falls back to the default.
    value = os.environ.get(var) or str(default)
    return os.path.expanduser(value)


def load_model_components():
    """Build the spectrum featurizer, molecule featurizer, and model.

    The params YAML is read from ``$FLARE_PARAMS`` (default
    ``REPO_DIR / "params.yaml"``). The checkpoint path comes from
    ``$FLARE_CHECKPOINT`` (default
    ``REPO_DIR / "pretrained_models" / "flare.ckpt"``) and is stored in
    ``params["checkpoint_pth"]`` before the model is built. Unset or empty
    variables fall back to the defaults, and a leading ``~`` is expanded in
    both paths.

    Returns:
        tuple: ``(spec_featurizer, mol_featurizer, model)``.

    Raises:
        OSError: if the params file cannot be opened.
        ValueError: if the params file does not contain a YAML mapping.
    """
    param_pth = Path(_path_from_env("FLARE_PARAMS", REPO_DIR / "params.yaml"))
    with open(param_pth, encoding="utf-8") as f:
        # NOTE(security): the path comes from the environment, and FullLoader
        # can construct some Python-specific types. Only point FLARE_PARAMS at
        # trusted files. yaml.safe_load would be stricter, but it may reject
        # params files that use python/* tags, so it is not swapped in here.
        params = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(params, dict):
        # An empty file loads as None; fail here with a clear message instead
        # of deep inside resolve_repo_paths / featurizer construction.
        raise ValueError(
            f"Expected a YAML mapping in {param_pth}, got {type(params).__name__}"
        )
    resolve_repo_paths(params)

    spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
    mol_featurizer = get_mol_featurizer(params["molecule_view"], params)

    # The checkpoint path is injected into params so get_model can read it.
    params["checkpoint_pth"] = _path_from_env(
        "FLARE_CHECKPOINT", REPO_DIR / "pretrained_models" / "flare.ckpt"
    )
    model = get_model(params["model"], params)
    return spec_featurizer, mol_featurizer, model