# FLARE / app_utils / model_utils.py
# yzhouchen001's picture
# clean up
# f4a27d9
# raw
# history blame contribute delete
# 968 Bytes
import os
from pathlib import Path
import yaml
from rdkit import RDLogger
from flare.definitions import REPO_DIR
from flare.utils.config import resolve_repo_paths
from flare.utils.data import get_mol_featurizer, get_spec_featurizer
from flare.utils.models import get_model
# Raise RDKit's log threshold so that only CRITICAL-level RDKit messages are
# emitted. This runs once, as a side effect of importing this module, and
# affects RDKit logging for the whole process.
_RDKIT_LOG_LEVEL = RDLogger.CRITICAL
lg = RDLogger.logger()
lg.setLevel(_RDKIT_LOG_LEVEL)
def load_model_components(params_path=None, checkpoint_path=None):
    """Load the FLARE parameter file and build the featurizers and the model.

    Both file locations are resolved with the same precedence (highest
    first): the explicit argument, then the environment variable, then the
    in-repository default. A leading ``~`` is expanded in every case.

    Args:
        params_path: Path to the YAML parameter file. When ``None``, uses
            ``$FLARE_PARAMS`` if set, else ``REPO_DIR / "params.yaml"``.
        checkpoint_path: Checkpoint handed to the model factory via
            ``params["checkpoint_pth"]``. When ``None``, uses
            ``$FLARE_CHECKPOINT`` if set, else
            ``REPO_DIR / "pretrained_models" / "flare.ckpt"``.

    Returns:
        A tuple ``(spec_featurizer, mol_featurizer, model)``. These are built
        by ``get_spec_featurizer``, ``get_mol_featurizer`` and ``get_model``
        from the ``spectra_view``, ``molecule_view`` and ``model`` entries of
        the parameter file.

    Raises:
        OSError: If the parameter file cannot be opened.
        ValueError: If the parameter file does not hold a YAML mapping.
        KeyError: If ``spectra_view``, ``molecule_view`` or ``model`` is
            missing from the parameters.
    """
    if params_path is None:
        params_path = os.environ.get("FLARE_PARAMS", REPO_DIR / "params.yaml")
    param_pth = Path(params_path).expanduser()

    with open(param_pth, encoding="utf-8") as f:
        # NOTE(review): FullLoader can build some Python objects
        # (e.g. !!python/tuple). If parameter files may come from untrusted
        # sources, consider yaml.safe_load. First verify that no config
        # relies on python/* tags.
        params = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML file loads as None. Reject anything that is not a mapping
    # here, rather than failing later on the params[...] lookups below.
    if not isinstance(params, dict):
        raise ValueError(
            f"{param_pth} must contain a YAML mapping of parameters, "
            f"got {type(params).__name__}"
        )

    # Return value is unused. Presumably this rewrites repo-relative paths in
    # `params` in place; confirm against flare.utils.config.
    resolve_repo_paths(params)

    spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
    mol_featurizer = get_mol_featurizer(params["molecule_view"], params)

    if checkpoint_path is None:
        checkpoint_path = os.environ.get(
            "FLARE_CHECKPOINT",
            str(REPO_DIR / "pretrained_models" / "flare.ckpt"),
        )
    # Expand "~" here as well, matching the params path.
    # os.path.expanduser returns a plain str that is otherwise unchanged.
    # Note that the checkpoint is set after resolve_repo_paths has run.
    params["checkpoint_pth"] = os.path.expanduser(str(checkpoint_path))

    model = get_model(params["model"], params)
    return spec_featurizer, mol_featurizer, model