singhanshuman committed
Commit 248472d · verified · 1 Parent(s): 58b6f47

Upload simoprl/dynamics_model.py with huggingface_hub

Files changed (1)
  1. simoprl/dynamics_model.py +143 -0
simoprl/dynamics_model.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from pathlib import Path
+ from tqdm import tqdm
+
+ STATE_DIM = 4
+ ACTION_DIM = 2
+ INPUT_DIM = STATE_DIM + ACTION_DIM
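+ # 4-dim state + 2-dim one-hot action → 6-dim network input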
+
+
+ class _DynamicsNet(nn.Module):
+     def __init__(self, hidden_dim: int = 256):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(INPUT_DIM, hidden_dim),
+             nn.SiLU(),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.SiLU(),
+             nn.Linear(hidden_dim, STATE_DIM),
+         )
+
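+     # Maps a normalised (state, one-hot action) input to the normalised next state.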
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.net(x)
+
+
+ class EnsembleDynamicsModel:
+     """
+     Ensemble of MLPs that learn P(s' | s, a) from an offline dataset.
+
+     Uncertainty is measured as the std of next-state predictions across
+     ensemble members; high std → out-of-distribution transition.
+     """
+
+     def __init__(self, n_models: int = 5, hidden_dim: int = 256, lr: float = 1e-3):
+         self.n_models = n_models
+         self.models = [_DynamicsNet(hidden_dim) for _ in range(n_models)]
+         self.optimizers = [torch.optim.Adam(m.parameters(), lr=lr) for m in self.models]
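+         # One optimiser per member, so each network is trained independently.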
+         # Normalisation statistics (set during train)
+         self.state_mean: np.ndarray = np.zeros(STATE_DIM, dtype=np.float32)
+         self.state_std: np.ndarray = np.ones(STATE_DIM, dtype=np.float32)
+
+     # ── Normalisation helpers ──────────────────────────────────────────────
+     # state_std already has 1e-8 folded in by train(), so the helpers use it
+     # directly; adding another epsilon here would make them inconsistent with
+     # the normalisation applied to the training data.
+
+     def _norm_state(self, s: np.ndarray) -> np.ndarray:
+         return (s - self.state_mean) / self.state_std
+
+     def _denorm_state(self, s: np.ndarray) -> np.ndarray:
+         return s * self.state_std + self.state_mean
+
+     # ── Training ──────────────────────────────────────────────────────────
+
+     def train(
+         self,
+         dataset: list,
+         n_epochs: int = 100,
+         batch_size: int = 512,
+     ) -> None:
+         from .collect_data import dataset_to_arrays
+
+         states, actions, next_states = dataset_to_arrays(dataset)
+
+         # Fit normalisation on full dataset
+         self.state_mean = states.mean(axis=0)
+         self.state_std = states.std(axis=0) + 1e-8
+
+         states_n = (states - self.state_mean) / self.state_std
+         next_states_n = (next_states - self.state_mean) / self.state_std
+         X = np.concatenate([states_n, actions], axis=1).astype(np.float32)  # (N, 6)
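+         # Actions are concatenated unnormalised; predict() builds one-hot action
+         # vectors, so the dataset is assumed to store actions in that same format.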
+         Y = next_states_n.astype(np.float32)  # (N, 4)
+         N = len(X)
+         loss_fn = nn.MSELoss()
+
+         for model_idx, (model, opt) in enumerate(zip(self.models, self.optimizers)):
+             # Bootstrap sample: each model sees a different data subset
+             boot_idx = np.random.choice(N, N, replace=True)
+             Xb, Yb = X[boot_idx], Y[boot_idx]
+
+             pbar = tqdm(range(n_epochs), desc=f"Dynamics model {model_idx+1}/{self.n_models}", leave=False)
+             for epoch in pbar:
+                 perm = np.random.permutation(N)
+                 epoch_loss = 0.0
+                 for i in range(0, N, batch_size):
+                     idx = perm[i : i + batch_size]
+                     x_t = torch.from_numpy(Xb[idx])
+                     y_t = torch.from_numpy(Yb[idx])
+                     pred = model(x_t)
+                     loss = loss_fn(pred, y_t)
+                     opt.zero_grad()
+                     loss.backward()
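+                     # Clip gradients to keep individual updates stable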
+                     nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                     opt.step()
+                     epoch_loss += loss.item()
+                 n_batches = (N + batch_size - 1) // batch_size
+                 pbar.set_postfix(loss=f"{epoch_loss / n_batches:.4f}")
+
+         print("Dynamics model training complete.")
+
+     # ── Inference ─────────────────────────────────────────────────────────
+
+     def predict(
+         self,
+         state: np.ndarray,
+         action: int,
+     ) -> tuple:
+         """
+         Returns (mean_next_state, transition_uncertainty).
+
+         transition_uncertainty = std across ensemble members, averaged over
+         state dims. Higher value → ensemble disagrees → out-of-distribution.
+         """
+         state_n = self._norm_state(state).astype(np.float32)
+         action_oh = np.eye(ACTION_DIM, dtype=np.float32)[int(action)]
+         x = torch.from_numpy(np.concatenate([state_n, action_oh])).unsqueeze(0)  # (1, 6)
+
+         preds = []
+         with torch.no_grad():
+             for model in self.models:
+                 preds.append(model(x).squeeze(0).numpy())  # (4,)
+
+         preds = np.stack(preds)      # (n_models, 4)
+         mean_n = preds.mean(axis=0)  # normalised space
+         std_n = preds.std(axis=0)    # normalised space
+
+         mean = self._denorm_state(mean_n)
+         uncertainty = float(std_n.mean())  # scalar summary of disagreement
+         return mean, uncertainty
+
+     # ── Persistence ───────────────────────────────────────────────────────
+
+     def save(self, path: str) -> None:
+         Path(path).parent.mkdir(parents=True, exist_ok=True)
+         payload = {
+             "state_dicts": [m.state_dict() for m in self.models],
+             "state_mean": self.state_mean,
+             "state_std": self.state_std,
+         }
+         torch.save(payload, path)
+         print(f"Dynamics model saved → {path}")
+
+     def load(self, path: str) -> None:
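+         # weights_only=False: the checkpoint payload contains numpy arrays, not just tensors.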
+         payload = torch.load(path, map_location="cpu", weights_only=False)
+         for model, sd in zip(self.models, payload["state_dicts"]):
+             model.load_state_dict(sd)
+         self.state_mean = payload["state_mean"]
+         self.state_std = payload["state_std"]
+         print(f"Dynamics model loaded ← {path}")