import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from tqdm import tqdm

STATE_DIM = 4
ACTION_DIM = 2
INPUT_DIM = STATE_DIM + ACTION_DIM

class _DynamicsNet(nn.Module):
    def __init__(self, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(INPUT_DIM, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, STATE_DIM),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class EnsembleDynamicsModel:
    """
    Ensemble of MLPs that learn P(s' | s, a) from an offline dataset.
    Uncertainty is measured as the std of next-state predictions across
    ensemble members: high std → out-of-distribution transition.
    """

    def __init__(self, n_models: int = 5, hidden_dim: int = 256, lr: float = 1e-3):
        self.n_models = n_models
        self.models = [_DynamicsNet(hidden_dim) for _ in range(n_models)]
        self.optimizers = [torch.optim.Adam(m.parameters(), lr=lr) for m in self.models]
        # Normalisation statistics (set during train)
        self.state_mean: np.ndarray = np.zeros(STATE_DIM, dtype=np.float32)
        self.state_std: np.ndarray = np.ones(STATE_DIM, dtype=np.float32)

    # ── Normalisation helpers ──────────────────────────────────────────────
    def _norm_state(self, s: np.ndarray) -> np.ndarray:
        return (s - self.state_mean) / (self.state_std + 1e-8)

    def _denorm_state(self, s: np.ndarray) -> np.ndarray:
        return s * (self.state_std + 1e-8) + self.state_mean

    # ── Training ──────────────────────────────────────────────────────────
    def train(
        self,
        dataset: list,
        n_epochs: int = 100,
        batch_size: int = 512,
    ) -> None:
        from .collect_data import dataset_to_arrays

        states, actions, next_states = dataset_to_arrays(dataset)
        # Fit normalisation on full dataset
        self.state_mean = states.mean(axis=0)
        self.state_std = states.std(axis=0) + 1e-8
        states_n = (states - self.state_mean) / self.state_std
        next_states_n = (next_states - self.state_mean) / self.state_std
        # Inputs are normalised states + actions; targets are the normalised
        # next states (absolute values, not deltas).
        X = np.concatenate([states_n, actions], axis=1).astype(np.float32)  # (N, 6)
        Y = next_states_n.astype(np.float32)  # (N, 4)
        N = len(X)
        loss_fn = nn.MSELoss()
        for model_idx, (model, opt) in enumerate(zip(self.models, self.optimizers)):
            # Bootstrap sample → each model sees a different data subset
            boot_idx = np.random.choice(N, N, replace=True)
            Xb, Yb = X[boot_idx], Y[boot_idx]
            for _ in tqdm(range(n_epochs), desc=f"Dynamics model {model_idx + 1}/{self.n_models}", leave=False):
                perm = np.random.permutation(N)
                for i in range(0, N, batch_size):
                    idx = perm[i : i + batch_size]
                    x_t = torch.from_numpy(Xb[idx])
                    y_t = torch.from_numpy(Yb[idx])
                    pred = model(x_t)
                    loss = loss_fn(pred, y_t)
                    opt.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    opt.step()
        print("Dynamics model training complete.")

    # ── Inference ─────────────────────────────────────────────────────────
    def predict(
        self,
        state: np.ndarray,
        action: int,
    ) -> tuple[np.ndarray, float]:
        """
        Returns (mean_next_state, transition_uncertainty).
        transition_uncertainty = std across ensemble members, averaged over state dims.
        Higher value → ensemble disagrees → out-of-distribution.
        """
        state_n = self._norm_state(state).astype(np.float32)
        action_oh = np.eye(ACTION_DIM, dtype=np.float32)[int(action)]
        x = torch.from_numpy(np.concatenate([state_n, action_oh])).unsqueeze(0)  # (1, 6)
        preds = []
        with torch.no_grad():
            for model in self.models:
                preds.append(model(x).squeeze(0).numpy())  # (4,)
        preds = np.stack(preds)      # (n_models, 4)
        mean_n = preds.mean(axis=0)  # normalised space
        std_n = preds.std(axis=0)    # normalised space
        mean = self._denorm_state(mean_n)
        uncertainty = float(std_n.mean())  # scalar summary of disagreement
        return mean, uncertainty

    # ── Persistence ───────────────────────────────────────────────────────
    def save(self, path: str) -> None:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        payload = {
            "state_dicts": [m.state_dict() for m in self.models],
            "state_mean": self.state_mean,
            "state_std": self.state_std,
        }
        torch.save(payload, path)
        print(f"Dynamics model saved → {path}")

    def load(self, path: str) -> None:
        payload = torch.load(path, map_location="cpu", weights_only=False)
        for model, sd in zip(self.models, payload["state_dicts"]):
            model.load_state_dict(sd)
        self.state_mean = payload["state_mean"]
        self.state_std = payload["state_std"]
        print(f"Dynamics model loaded ← {path}")