| """ |
| Self-contained loader for the molecular-shadows observable regressor. |
| |
| Usage: |
| from inference import MolecularShadowsRegressor |
| m = MolecularShadowsRegressor.from_local(".") # after cloning the HF repo |
| # or |
| m = MolecularShadowsRegressor.from_hub("aniketdesh/molecular-shadows-h2-v10", |
| revision="v10", # tag, branch, or commit |
| token="hf_...") # only for private repos |
| |
| # Predict 120 (or 28 for H2) observable expectations at (R, t): |
| y = m.predict(R=1.4, t=12.5) # scalar -> (n_observables,) |
| y = m.predict(R=[1.0, 1.4], t=[5, 10]) # batched -> (B, n_observables) |
| """ |
from __future__ import annotations

import json
import os
from pathlib import Path

import numpy as np
import torch

from observable_regressor import (
    ObservableRegressor,
    ObservableRegressorConfig,
    init_observable_regressor,
)


class MolecularShadowsRegressor:
    """Wraps a trained ObservableRegressor together with the tabulated
    orbital-energy data used to evaluate it at arbitrary (R, t)."""

    def __init__(self, model: ObservableRegressor, payload: dict, orb_grid: np.ndarray,
                 orb_table: np.ndarray, omega_op_table: np.ndarray | None = None,
                 device: str = "cpu"):
        self.model = model.to(device).eval()
        self.payload = payload
        self.orb_grid = orb_grid
        self.orb_table = orb_table
        self.omega_op_table = omega_op_table
        self.device = device

    @property
    def n_observables(self) -> int:
        return self.payload["model_config"]["n_observables"]

    @property
    def observable_keys(self):
        return self.payload["observable_keys"]

    @property
    def R_range(self):
        """(min, max) of the tabulated R grid."""
        return float(self.orb_grid.min()), float(self.orb_grid.max())
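
    # Note (sketch): np.interp clamps to the grid endpoints, so queries outside
    # R_range are effectively evaluated at the nearest tabulated R; keep queries
    # inside this window for meaningful predictions. Illustrative check:
    #   R_min, R_max = m.R_range
    #   assert R_min <= R_query <= R_max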

    @classmethod
    def from_local(cls, repo_dir: str | os.PathLike, device: str = "cpu") -> "MolecularShadowsRegressor":
        """Load from a local checkout containing regressor.pt and orbital_energies.npz."""
        repo_dir = Path(repo_dir)
        payload = torch.load(repo_dir / "regressor.pt", map_location=device, weights_only=False)
        config = ObservableRegressorConfig(**payload["model_config"])
        model = init_observable_regressor(**config.to_dict())
        model.load_state_dict(payload["state_dict"])

        orb_npz = np.load(repo_dir / "orbital_energies.npz")
        orb_grid = orb_npz["R_grid"]
        orb_table = orb_npz["orbital_energies"]
        omega_op_table = orb_npz["omega_op"] if "omega_op" in orb_npz.files else None

        return cls(model, payload, orb_grid, orb_table, omega_op_table, device=device)

    @classmethod
    def from_hub(cls, repo_id: str, revision: str | None = None,
                 token: str | None = None, device: str = "cpu",
                 cache_dir: str | None = None) -> "MolecularShadowsRegressor":
        """Download a snapshot of the repo from the Hugging Face Hub and load it."""
        from huggingface_hub import snapshot_download

        local = snapshot_download(repo_id=repo_id, revision=revision, token=token, cache_dir=cache_dir)
        return cls.from_local(local, device=device)
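
    # Example (sketch; assumes a CUDA-capable machine and network access to the Hub):
    #   m = MolecularShadowsRegressor.from_hub("aniketdesh/molecular-shadows-h2-v10",
    #                                          revision="v10", device="cuda")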

    def _interp_orb(self, R: np.ndarray) -> np.ndarray:
        """Linearly interpolate each orbital-energy column of the table onto R."""
        out = np.empty((R.shape[0], self.orb_table.shape[1]), dtype=np.float32)
        for k in range(self.orb_table.shape[1]):
            out[:, k] = np.interp(R, self.orb_grid, self.orb_table[:, k])
        return out

    def _interp_omega_op(self, R: np.ndarray) -> np.ndarray | None:
        """Linearly interpolate the omega_op table onto R; None if not bundled."""
        if self.omega_op_table is None:
            return None
        return np.interp(R, self.orb_grid, self.omega_op_table).astype(np.float32)

    @torch.no_grad()
    def predict(self, R, t):
        """Predict observable expectations at (R, t).

        R and t may be scalars or array-likes; mismatched shapes are broadcast
        against each other. Returns (n_observables,) for scalar inputs,
        otherwise (B, n_observables).
        """
        R_arr = np.atleast_1d(np.asarray(R, dtype=np.float32))
        t_arr = np.atleast_1d(np.asarray(t, dtype=np.float32))
        if R_arr.shape != t_arr.shape:
            R_arr, t_arr = np.broadcast_arrays(R_arr, t_arr)
        R_arr = np.ascontiguousarray(R_arr)
        t_arr = np.ascontiguousarray(t_arr)

        rt = torch.from_numpy(np.stack([R_arr, t_arr], axis=-1)).to(self.device)

        kwargs = {}
        if self.payload["model_config"].get("n_orb_features", 0) > 0:
            orb = self._interp_orb(R_arr)
            kwargs["orb_energies"] = torch.from_numpy(orb).to(self.device)
        if self.payload["model_config"].get("adaptive_bandwidth", False):
            omega_op = self._interp_omega_op(R_arr)
            if omega_op is None:
                raise ValueError("This checkpoint requires omega_op but none is bundled.")
            kwargs["omega_op"] = torch.from_numpy(omega_op).to(self.device)

        y = self.model(rt, **kwargs).cpu().numpy()
        return y[0] if np.isscalar(R) and np.isscalar(t) else y
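
    # Example (sketch): a bond-length sweep at fixed t; the number of points is illustrative.
    #   m = MolecularShadowsRegressor.from_local(".")
    #   y = m.predict(R=np.linspace(*m.R_range, 50), t=10.0)   # -> (50, n_observables)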

    @torch.no_grad()
    def predict_trajectory(self, R: float, t_grid: np.ndarray):
        """Convenience: full time series at fixed R. Returns (len(t_grid), n_observables)."""
        t_grid = np.asarray(t_grid, dtype=np.float32)
        return self.predict(np.full_like(t_grid, R, dtype=np.float32), t_grid)
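
    # Example (sketch; the time grid is illustrative):
    #   t = np.linspace(0.0, 50.0, 200)
    #   traj = m.predict_trajectory(R=1.4, t_grid=t)   # -> (200, n_observables)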


def _print_metadata_summary(m: MolecularShadowsRegressor):
    print(f" n_observables: {m.n_observables}")
    print(f" R range: {m.R_range[0]:.3f} -> {m.R_range[1]:.3f} A")
    print(f" config: {json.dumps(m.payload['model_config'])}")


if __name__ == "__main__":
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument("--repo_dir", default=".", help="Local checkout of the HF repo")
    ap.add_argument("--R", type=float, default=1.4)
    ap.add_argument("--t", type=float, default=10.0)
    args = ap.parse_args()

    m = MolecularShadowsRegressor.from_local(args.repo_dir)
    _print_metadata_summary(m)
    y = m.predict(args.R, args.t)
    print(f" predict(R={args.R}, t={args.t}) -> shape {y.shape}, "
          f"min={y.min():.3e}, max={y.max():.3e}, mean={y.mean():.3e}")