# molecular-shadows-h2-v10 / inference.py
# Author: aniketdesh
# Commit: d8e765d (verified) — "upload molecular-shadows-h2-v10 (v10)"
"""
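Standalone loader for the molecular-shadows observable regressor.
Requires the `observable_regressor` module (imported below) to be importable,
e.g. by running from inside the repo checkout.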
Usage:
from inference import MolecularShadowsRegressor
m = MolecularShadowsRegressor.from_local(".") # after cloning the HF repo
# or
m = MolecularShadowsRegressor.from_hub("aniketdesh/molecular-shadows-h2-v10",
revision="v10", # tag, branch, or commit
token="hf_...") # only for private repos
# Predict 120 (or 28 for H2) observable expectations at (R, t):
y = m.predict(R=1.4, t=12.5) # scalar -> (n_observables,)
y = m.predict(R=[1.0, 1.4], t=[5, 10]) # batched -> (B, n_observables)
"""
from __future__ import annotations
import json
import os
from pathlib import Path
import numpy as np
import torch
from observable_regressor import (
ObservableRegressor,
ObservableRegressorConfig,
init_observable_regressor,
)
class MolecularShadowsRegressor:
    """Inference wrapper around a trained ``ObservableRegressor``.

    Holds the network, the checkpoint payload (model config, observable keys),
    and the R-indexed lookup tables (orbital energies and, optionally,
    omega_op). Some checkpoints need these tables as extra per-sample inputs.
    The tables are linearly interpolated in R at prediction time.
    """

    def __init__(self, model: ObservableRegressor, payload: dict, orb_grid: np.ndarray,
                 orb_table: np.ndarray, omega_op_table: np.ndarray | None = None,
                 device: str = "cpu"):
        """
        Args:
            model: the network; moved to ``device`` and switched to eval mode.
            payload: checkpoint dict. ``"model_config"`` is read by
                ``n_observables`` and ``predict``, and ``"observable_keys"``
                by ``observable_keys``.
            orb_grid: R grid of the lookup tables, shape (n_R,). Must be
                increasing, as ``np.interp`` requires.
            orb_table: orbital energies sampled on ``orb_grid``, shape (n_R, n_orb).
            omega_op_table: optional omega_op values on ``orb_grid``, shape (n_R,).
            device: torch device string used for the model and all inputs.
        """
        self.model = model.to(device).eval()
        self.payload = payload
        self.orb_grid = orb_grid  # (n_R,)
        self.orb_table = orb_table  # (n_R, n_orb)
        self.omega_op_table = omega_op_table  # (n_R,) or None
        self.device = device

    @property
    def n_observables(self) -> int:
        """Number of observables predicted per (R, t) query."""
        return self.payload["model_config"]["n_observables"]

    @property
    def observable_keys(self):
        """Observable identifiers stored in the checkpoint payload."""
        return self.payload["observable_keys"]

    @property
    def R_range(self):
        """(min, max) of the bundled R grid; outside it, features are clamped."""
        return float(self.orb_grid.min()), float(self.orb_grid.max())

    @classmethod
    def from_local(cls, repo_dir: str | os.PathLike, device: str = "cpu") -> "MolecularShadowsRegressor":
        """Load from a directory holding ``regressor.pt`` and ``orbital_energies.npz``."""
        repo_dir = Path(repo_dir)
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute arbitrary code. Only load checkpoints
        # from sources you trust.
        payload = torch.load(repo_dir / "regressor.pt", map_location=device, weights_only=False)
        config = ObservableRegressorConfig(**payload["model_config"])
        model = init_observable_regressor(**config.to_dict())
        model.load_state_dict(payload["state_dict"])
        # NpzFile keeps the archive open until closed, so use it as a context
        # manager. Indexing it reads each array fully into memory, so the
        # arrays stay valid after the file is closed.
        with np.load(repo_dir / "orbital_energies.npz") as orb_npz:
            orb_grid = orb_npz["R_grid"]
            orb_table = orb_npz["orbital_energies"]
            omega_op_table = orb_npz["omega_op"] if "omega_op" in orb_npz.files else None
        return cls(model, payload, orb_grid, orb_table, omega_op_table, device=device)

    @classmethod
    def from_hub(cls, repo_id: str, revision: str | None = None,
                 token: str | None = None, device: str = "cpu",
                 cache_dir: str | None = None) -> "MolecularShadowsRegressor":
        """Download (or reuse a cached) snapshot of ``repo_id`` and load it via ``from_local``."""
        from huggingface_hub import snapshot_download
        local = snapshot_download(repo_id=repo_id, revision=revision, token=token, cache_dir=cache_dir)
        return cls.from_local(local, device=device)

    def _interp_orb(self, R: np.ndarray) -> np.ndarray:
        """Linearly interpolate each orbital column at 1-D ``R``; returns (len(R), n_orb) float32.

        ``np.interp`` clamps to the first/last table value outside the grid.
        """
        out = np.empty((R.shape[0], self.orb_table.shape[1]), dtype=np.float32)
        for k in range(self.orb_table.shape[1]):
            out[:, k] = np.interp(R, self.orb_grid, self.orb_table[:, k])
        return out

    def _interp_omega_op(self, R: np.ndarray) -> np.ndarray | None:
        """Interpolate omega_op at 1-D ``R`` (float32), or return None if no table is bundled."""
        if self.omega_op_table is None:
            return None
        return np.interp(R, self.orb_grid, self.omega_op_table).astype(np.float32)

    @torch.no_grad()
    def predict(self, R, t):
        """Predict observable expectations at bond length(s) ``R`` and time(s) ``t``.

        ``R`` and ``t`` are broadcast against each other. If both are 0-d
        (Python/numpy scalars or 0-d arrays), the result has shape
        (n_observables,). Otherwise it has shape broadcast_shape +
        (n_observables,), e.g. (B, n_observables) for 1-D inputs.

        Orbital-energy / omega_op features are linearly interpolated on the
        bundled grid. ``np.interp`` clamps to the end values, so an R outside
        ``R_range`` silently reuses the boundary features.

        Raises:
            ValueError: if ``R`` and ``t`` cannot be broadcast together, or
                the checkpoint needs omega_op and none is bundled.
        """
        R_arr = np.asarray(R, dtype=np.float32)
        t_arr = np.asarray(t, dtype=np.float32)
        # Use rank rather than np.isscalar: np.isscalar is False for 0-d
        # arrays/tensors, which would then get a spurious leading batch axis.
        scalar_query = R_arr.ndim == 0 and t_arr.ndim == 0
        R_arr, t_arr = np.broadcast_arrays(np.atleast_1d(R_arr), np.atleast_1d(t_arr))
        batch_shape = R_arr.shape
        # The model and the interpolation helpers work on a flat (B,) batch,
        # so flatten here and restore the batch shape on the output.
        R_flat = np.ascontiguousarray(R_arr.reshape(-1))
        t_flat = np.ascontiguousarray(t_arr.reshape(-1))
        rt = torch.from_numpy(np.stack([R_flat, t_flat], axis=-1)).to(self.device)

        config = self.payload["model_config"]
        kwargs = {}
        if config.get("n_orb_features", 0) > 0:
            orb = self._interp_orb(R_flat)
            kwargs["orb_energies"] = torch.from_numpy(orb).to(self.device)
        if config.get("adaptive_bandwidth", False):
            omega_op = self._interp_omega_op(R_flat)
            if omega_op is None:
                raise ValueError("This checkpoint requires omega_op but none is bundled.")
            kwargs["omega_op"] = torch.from_numpy(omega_op).to(self.device)

        y = self.model(rt, **kwargs).cpu().numpy()
        if scalar_query:
            return y[0]
        return y.reshape(batch_shape + y.shape[1:])

    @torch.no_grad()
    def predict_trajectory(self, R: float, t_grid: np.ndarray):
        """Convenience: full time series at fixed R. Returns (len(t_grid), n_observables)."""
        t_grid = np.asarray(t_grid, dtype=np.float32)
        return self.predict(np.full_like(t_grid, R, dtype=np.float32), t_grid)
def _print_metadata_summary(m: MolecularShadowsRegressor):
    """Write a three-line summary of a loaded regressor to stdout.

    The lines show the observable count, the R interval covered by the
    bundled grid, and the model config serialised as JSON.
    """
    print(f" n_observables: {m.n_observables}")
    r_lo, r_hi = m.R_range
    print(f" R range: {r_lo:.3f} -> {r_hi:.3f} A")
    config_json = json.dumps(m.payload['model_config'])
    print(f" config: {config_json}")
if __name__ == "__main__":
    # Smoke test: load the local checkout, print its metadata, and run a
    # single (R, t) prediction.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--repo_dir", default=".", help="Local checkout of the HF repo")
    parser.add_argument("--R", type=float, default=1.4)
    parser.add_argument("--t", type=float, default=10.0)
    cli = parser.parse_args()

    regressor = MolecularShadowsRegressor.from_local(cli.repo_dir)
    _print_metadata_summary(regressor)

    preds = regressor.predict(cli.R, cli.t)
    lo, hi, avg = preds.min(), preds.max(), preds.mean()
    print(f" predict(R={cli.R}, t={cli.t}) -> shape {preds.shape}, "
          f"min={lo:.3e}, max={hi:.3e}, mean={avg:.3e}")