yzhouchen001 commited on
Commit
f4a27d9
·
1 Parent(s): a872583
Files changed (2) hide show
  1. app_utils/model_utils.py +16 -18
  2. app_utils/upload_model.py +19 -6
app_utils/model_utils.py CHANGED
@@ -1,31 +1,29 @@
1
- import sys
2
- # sys.path.insert(0, "/data/yzhouc01/MassSpecGym")
3
- # sys.path.insert(0, "/data/yzhouc01/FILIP-MS")
4
 
 
5
  from rdkit import RDLogger
6
- from flare.utils.data import get_spec_featurizer, get_mol_featurizer, get_ms_dataset
7
- from flare.utils.models import get_model
8
 
9
- import yaml
 
 
 
10
 
11
- # Suppress RDKit warnings and errors
12
  lg = RDLogger.logger()
13
  lg.setLevel(RDLogger.CRITICAL)
14
 
15
- # Load model and data
16
 
17
  def load_model_components():
18
- param_pth = 'hparams.yaml'
19
- with open(param_pth) as f:
20
  params = yaml.load(f, Loader=yaml.FullLoader)
 
21
 
22
- spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
23
- mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
24
-
25
- # load model
26
 
27
- checkpoint_pth = "pretrained_models/flare.ckpt"
28
- params['checkpoint_pth'] = checkpoint_pth
29
- model = get_model(params['model'], params)
30
 
31
- return spec_featurizer, mol_featurizer, model
 
1
+ import os
2
+ from pathlib import Path
 
3
 
4
+ import yaml
5
  from rdkit import RDLogger
 
 
6
 
7
+ from flare.definitions import REPO_DIR
8
+ from flare.utils.config import resolve_repo_paths
9
+ from flare.utils.data import get_mol_featurizer, get_spec_featurizer
10
+ from flare.utils.models import get_model
11
 
 
12
  lg = RDLogger.logger()
13
  lg.setLevel(RDLogger.CRITICAL)
14
 
 
15
 
16
  def load_model_components():
17
+ param_pth = Path(os.environ.get("FLARE_PARAMS", REPO_DIR / "params.yaml")).expanduser()
18
+ with open(param_pth, encoding="utf-8") as f:
19
  params = yaml.load(f, Loader=yaml.FullLoader)
20
+ resolve_repo_paths(params)
21
 
22
+ spec_featurizer = get_spec_featurizer(params["spectra_view"], params)
23
+ mol_featurizer = get_mol_featurizer(params["molecule_view"], params)
 
 
24
 
25
+ ckpt = os.environ.get("FLARE_CHECKPOINT", str(REPO_DIR / "pretrained_models" / "flare.ckpt"))
26
+ params["checkpoint_pth"] = ckpt
27
+ model = get_model(params["model"], params)
28
 
29
+ return spec_featurizer, mol_featurizer, model
app_utils/upload_model.py CHANGED
@@ -1,10 +1,23 @@
 
 
 
1
  from huggingface_hub import HfApi
2
 
 
 
 
 
 
 
 
 
 
 
3
  api = HfApi()
4
  api.upload_file(
5
- path_or_fileobj="experiments/20250913_optimized_filip-model/epoch=1993-train_loss=0.10.ckpt",
6
- path_in_repo="epoch=1993-train_loss=0.10.ckpt",
7
- repo_id="yzhouchen001/flare",
8
- repo_type="space", # automatically uses your saved token
9
- use_auth_token=True
10
- )
 
1
+ """Upload a checkpoint to a Hugging Face Space (or model repo). Configure via environment variables."""
2
+ import os
3
+
4
  from huggingface_hub import HfApi
5
 
6
+ # Local checkpoint path (relative to repo root or absolute).
7
+ CKPT_PATH = os.environ.get(
8
+ "FLARE_UPLOAD_CKPT",
9
+ str(os.path.join(os.path.dirname(__file__), "..", "pretrained_models", "flare.ckpt")),
10
+ )
11
+ # Destination path inside the repo.
12
+ PATH_IN_REPO = os.environ.get("FLARE_UPLOAD_PATH_IN_REPO", "pretrained_models/flare.ckpt")
13
+ REPO_ID = os.environ.get("HF_REPO_ID", "HassounLab/FLARE")
14
+ REPO_TYPE = os.environ.get("HF_REPO_TYPE", "space")
15
+
16
  api = HfApi()
17
  api.upload_file(
18
+ path_or_fileobj=CKPT_PATH,
19
+ path_in_repo=PATH_IN_REPO,
20
+ repo_id=REPO_ID,
21
+ repo_type=REPO_TYPE,
22
+ token=os.environ.get("HF_TOKEN"),
23
+ )