# Source: sim-oprl / simoprl/dynamics_model.py
# Uploaded by singhanshuman via huggingface_hub (commit 248472d, verified)
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from tqdm import tqdm
STATE_DIM = 4
ACTION_DIM = 2
INPUT_DIM = STATE_DIM + ACTION_DIM
class _DynamicsNet(nn.Module):
def __init__(self, hidden_dim: int = 256):
super().__init__()
self.net = nn.Sequential(
nn.Linear(INPUT_DIM, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, STATE_DIM),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class EnsembleDynamicsModel:
"""
Ensemble of MLPs that learn P(s' | s, a) from an offline dataset.
Uncertainty is measured as the std of next-state predictions across
ensemble members β€” high std β†’ out-of-distribution transition.
"""
def __init__(self, n_models: int = 5, hidden_dim: int = 256, lr: float = 1e-3):
self.n_models = n_models
self.models = [_DynamicsNet(hidden_dim) for _ in range(n_models)]
self.optimizers = [torch.optim.Adam(m.parameters(), lr=lr) for m in self.models]
# Normalisation statistics (set during train)
self.state_mean: np.ndarray = np.zeros(STATE_DIM, dtype=np.float32)
self.state_std: np.ndarray = np.ones(STATE_DIM, dtype=np.float32)
# ── Normalisation helpers ──────────────────────────────────────────────
def _norm_state(self, s: np.ndarray) -> np.ndarray:
return (s - self.state_mean) / (self.state_std + 1e-8)
def _denorm_state(self, s: np.ndarray) -> np.ndarray:
return s * (self.state_std + 1e-8) + self.state_mean
# ── Training ──────────────────────────────────────────────────────────
def train(
self,
dataset: list,
n_epochs: int = 100,
batch_size: int = 512,
) -> None:
from .collect_data import dataset_to_arrays
states, actions, next_states = dataset_to_arrays(dataset)
# Fit normalisation on full dataset
self.state_mean = states.mean(axis=0)
self.state_std = states.std(axis=0) + 1e-8
states_n = (states - self.state_mean) / self.state_std
next_states_n = (next_states - self.state_mean) / self.state_std
X = np.concatenate([states_n, actions], axis=1).astype(np.float32) # (N, 6)
Y = next_states_n.astype(np.float32) # (N, 4)
N = len(X)
for model_idx, (model, opt) in enumerate(zip(self.models, self.optimizers)):
# Bootstrap sample β€” each model sees a different data subset
boot_idx = np.random.choice(N, N, replace=True)
Xb, Yb = X[boot_idx], Y[boot_idx]
for epoch in tqdm(range(n_epochs), desc=f"Dynamics model {model_idx+1}/{self.n_models}", leave=False):
perm = np.random.permutation(N)
epoch_loss = 0.0
for i in range(0, N, batch_size):
idx = perm[i : i + batch_size]
x_t = torch.from_numpy(Xb[idx])
y_t = torch.from_numpy(Yb[idx])
pred = model(x_t)
loss = nn.MSELoss()(pred, y_t)
opt.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
epoch_loss += loss.item()
print("Dynamics model training complete.")
# ── Inference ─────────────────────────────────────────────────────────
def predict(
self,
state: np.ndarray,
action,
) -> tuple:
"""
Returns (mean_next_state, transition_uncertainty).
transition_uncertainty = mean std across state dims, averaged over ensemble.
Higher value β†’ ensemble disagrees β†’ out-of-distribution.
"""
state_n = self._norm_state(state).astype(np.float32)
action_oh = np.eye(ACTION_DIM, dtype=np.float32)[int(action)]
x = torch.from_numpy(np.concatenate([state_n, action_oh])).unsqueeze(0) # (1, 6)
preds = []
with torch.no_grad():
for model in self.models:
preds.append(model(x).squeeze(0).numpy()) # (4,)
preds = np.stack(preds) # (n_models, 4)
mean_n = preds.mean(axis=0) # normalised space
std_n = preds.std(axis=0) # normalised space
mean = self._denorm_state(mean_n)
uncertainty = float(std_n.mean()) # scalar summary of disagreement
return mean, uncertainty
# ── Persistence ───────────────────────────────────────────────────────
def save(self, path: str) -> None:
Path(path).parent.mkdir(parents=True, exist_ok=True)
payload = {
"state_dicts": [m.state_dict() for m in self.models],
"state_mean": self.state_mean,
"state_std": self.state_std,
}
torch.save(payload, path)
print(f"Dynamics model saved β†’ {path}")
def load(self, path: str) -> None:
payload = torch.load(path, map_location="cpu", weights_only=False)
for model, sd in zip(self.models, payload["state_dicts"]):
model.load_state_dict(sd)
self.state_mean = payload["state_mean"]
self.state_std = payload["state_std"]
print(f"Dynamics model loaded ← {path}")