Spaces:
Running
Running
| import os | |
| from pathlib import Path | |
| import yaml | |
| from rdkit import RDLogger | |
| from flare.definitions import REPO_DIR | |
| from flare.utils.config import resolve_repo_paths | |
| from flare.utils.data import get_mol_featurizer, get_spec_featurizer | |
| from flare.utils.models import get_model | |
# Module-level RDKit logger handle (kept as the public name `lg`). Its level is
# raised to CRITICAL on import so that RDKit messages of lower severity are not
# emitted by anything that imports this module.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
def load_model_components():
    """Load the FLARE config and build the featurizers and model it describes.

    The YAML config path is taken from the ``FLARE_PARAMS`` environment
    variable (default: ``REPO_DIR / "params.yaml"``). The checkpoint path is
    taken from ``FLARE_CHECKPOINT`` (default:
    ``REPO_DIR / "pretrained_models" / "flare.ckpt"``). A leading ``~`` in
    either value is expanded to the user's home directory.

    Returns:
        A ``(spec_featurizer, mol_featurizer, model)`` tuple, built from the
        config's ``"spectra_view"``, ``"molecule_view"`` and ``"model"``
        entries respectively.

    Raises:
        OSError: if the params file cannot be opened (e.g. FileNotFoundError).
        TypeError: if the params file does not contain a top-level mapping
            (for example, it is empty).
        KeyError: if the config lacks one of the keys read above.
    """
    param_pth = Path(os.environ.get("FLARE_PARAMS", REPO_DIR / "params.yaml")).expanduser()
    with open(param_pth, encoding="utf-8") as f:
        # NOTE(review): FullLoader can construct some Python objects. If the
        # params file may come from an untrusted source, prefer yaml.safe_load,
        # provided the config uses no Python-specific YAML tags.
        params = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(params, dict):
        # An empty file loads as None. Fail here with a clear message instead
        # of an opaque subscript error further down.
        raise TypeError(
            f"Expected a mapping at the top level of {param_pth}, "
            f"got {type(params).__name__}"
        )
    # The return value is ignored, so this presumably rewrites path entries in
    # `params` in place. TODO confirm against flare.utils.config.
    resolve_repo_paths(params)

    spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
    mol_featurizer = get_mol_featurizer(params["molecule_view"], params)

    ckpt = os.environ.get("FLARE_CHECKPOINT", str(REPO_DIR / "pretrained_models" / "flare.ckpt"))
    # Expand "~" the same way FLARE_PARAMS is expanded. The value stays a
    # plain string, as before. expanduser() leaves paths without "~" unchanged.
    params["checkpoint_pth"] = os.path.expanduser(ckpt)
    model = get_model(params["model"], params)
    return spec_featurizer, mol_featurizer, model