ARX5 MuJoCo World Model: Delta Representation Sweep (500k steps)
Ablation of action/state representation for an offline world model trained on ARX5 MuJoCo simulation data. The four configs sweep action_representation (absolute|delta) × state_prediction_mode (absolute|delta).
Architecture
- Backbone: nanoGPT causal transformer (type: gpt)
- Size: 6 layers, 12 heads, 384 embed dim
- History horizon: 32 (h32)
- Ensemble: 2 members
- State dim: 14 (7 joint pos + 7 joint vel)
- Action dim: 7
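For quick reference, the hyperparameters above can be collected in a small config object. This is only a sketch: the class name `WorldModelConfig` and all field names are hypothetical and are not taken from the training code, and the parameter estimate assumes stock nanoGPT blocks (4x MLP expansion), which the list above does not confirm.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WorldModelConfig:
    # Values come from the Architecture list; the field names are hypothetical.
    n_layer: int = 6
    n_head: int = 12
    n_embd: int = 384
    history_horizon: int = 32
    ensemble_size: int = 2
    state_dim: int = 14   # 7 joint pos + 7 joint vel
    action_dim: int = 7

    @property
    def head_dim(self) -> int:
        # The embedding width must split evenly across attention heads.
        assert self.n_embd % self.n_head == 0
        return self.n_embd // self.n_head

    def approx_block_params(self) -> int:
        # Rough per-member count for the transformer blocks only, assuming
        # stock GPT blocks: ~4*d^2 (attention) + ~8*d^2 (MLP) per layer.
        # Biases, LayerNorm, and input/output projections are ignored.
        return 12 * self.n_layer * self.n_embd ** 2


cfg = WorldModelConfig()
print(cfg.head_dim, cfg.approx_block_params())
```

Under these assumptions each head is 32-dimensional and each ensemble member has roughly 10.6M parameters in its transformer blocks.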
Training
- Dataset: train.pt (1024 episodes, 3.5M frames, ARX5 dual-arm MuJoCo)
- Steps: 500,000
- Batch size: 64
- LR: 3e-4 → 3e-5 (cosine)
- Noise: linear context noise (0 → 0.05 over 50k steps)
- Hardware: NVIDIA GH200 (Isambard AI Phase 2)
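The two schedules listed above can be sketched as plain functions of the training step. This is a minimal illustration, not the training code: it assumes no warmup, that the cosine decay spans all 500,000 steps, and that the noise value is a scale that ramps linearly and then stays at 0.05. The function names are hypothetical.

```python
import math

TOTAL_STEPS = 500_000
LR_MAX, LR_MIN = 3e-4, 3e-5
NOISE_MAX, NOISE_RAMP_STEPS = 0.05, 50_000


def cosine_lr(step: int) -> float:
    """Cosine decay from LR_MAX to LR_MIN over TOTAL_STEPS (assumed: no warmup)."""
    progress = min(max(step, 0), TOTAL_STEPS) / TOTAL_STEPS
    return LR_MIN + 0.5 * (LR_MAX - LR_MIN) * (1.0 + math.cos(math.pi * progress))


def context_noise_scale(step: int) -> float:
    """Linear ramp from 0 to NOISE_MAX over NOISE_RAMP_STEPS, then constant (assumed)."""
    return NOISE_MAX * min(max(step, 0), NOISE_RAMP_STEPS) / NOISE_RAMP_STEPS


for step in (0, 25_000, 250_000, 500_000):
    print(step, cosine_lr(step), context_noise_scale(step))
```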
Results
| config | action_input | state_target | best_val_loss | best_step | grad_norm | verdict |
|---|---|---|---|---|---|---|
| cfgA | absolute | absolute | 0.0982 | 500k | 0.034 | baseline |
| cfgB | absolute | delta | 0.4244 | 490k | 3.909 | discard |
| cfgC | delta | absolute | 0.0764 | 500k | 0.040 | winner |
| cfgD | delta | delta | 0.2282 | 425k | 4.295 | discard |
Config C (delta action input, absolute state targets) achieves 22% lower val loss than the baseline (0.0764 vs. 0.0982). Delta state targets (B, D) hurt significantly: the state is 14D (7 pos + 7 vel), so predicting s_{t+1} - s_t forces the velocity dims to regress a quantity proportional to acceleration (second-order and high-variance). This explains both the poor loss and the grad norms of roughly 4, about 100× those of configs A and C.
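To make the acceleration argument concrete, the sketch below builds absolute and delta state targets for a toy single-joint trajectory with state [pos, vel]. The helper `state_targets`, the time step, and the acceleration value are all hypothetical and chosen only for illustration; the actual training code may construct and normalize targets differently.

```python
def state_targets(states, mode):
    """Build next-state training targets from a list of state vectors."""
    if mode == "absolute":
        # Target is the next state itself: s_{t+1}
        return [list(s) for s in states[1:]]
    if mode == "delta":
        # Target is the per-dimension change: s_{t+1} - s_t
        return [[b - a for a, b in zip(cur, nxt)]
                for cur, nxt in zip(states, states[1:])]
    raise ValueError(f"unknown state_prediction_mode: {mode!r}")


# Toy trajectory under constant acceleration (hypothetical values).
DT, ACC = 0.1, 2.0
states = [[0.5 * ACC * (k * DT) ** 2, ACC * k * DT] for k in range(5)]  # [pos, vel]

absolute_targets = state_targets(states, "absolute")
delta_targets = state_targets(states, "delta")

for d_pos, d_vel in delta_targets:
    # d_vel equals ACC * DT at every step: the velocity delta is acceleration-scaled.
    print(f"delta pos = {d_pos:.3f}, delta vel = {d_vel:.3f}")
```

In delta mode the velocity component of every target is the acceleration times the time step, so any noise in the simulated acceleration shows up directly in the regression target.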
Checkpoint Hashes
Computed via cd checkpoints/cfg<X> && find params -type f | sort | xargs sha256sum | sha256sum (verified by running the command twice):
| config | sha256 |
|---|---|
| cfgA | 5058f01a1171ecd2a79e8c90ca0f83e6cbf352523d7f58d9cf4f814699f02ca5 |
| cfgB | 4736e0c2cc027167c19a93e2939b97d7d13e6d5a692fa8e91ec1d0da396359df |
| cfgC | 621708c7716c97be1e9509e8d8840197680dd1c83c6f36c830ac7df02bced3b6 |
| cfgD | ca03b56a893e5e041806cf8b569313a3c893a6e396b0621af0931663f0d3f761 |
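Where the shell tools are not available, the same digest can in principle be reproduced in Python. This is a sketch under two assumptions: sha256sum prints each entry as the hex digest, two spaces, the relative path, and a newline (GNU coreutils text-mode output), and the shell sort order matches plain string ordering, which holds for the two file names listed in this repo. The function names are hypothetical.

```python
import hashlib
from pathlib import Path


def file_sha256(path: Path) -> str:
    """SHA-256 of a file, read in 1 MiB chunks."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def checkpoint_digest(cfg_dir) -> str:
    """Digest of the per-file sha256sum listing for <cfg_dir>/params."""
    cfg_dir = Path(cfg_dir)
    rel_paths = sorted(
        p.relative_to(cfg_dir).as_posix()
        for p in (cfg_dir / "params").rglob("*")
        if p.is_file()
    )
    # Assumed sha256sum line format: "<hex digest>  <relative path>\n"
    listing = "".join(f"{file_sha256(cfg_dir / rel)}  {rel}\n" for rel in rel_paths)
    return hashlib.sha256(listing.encode()).hexdigest()
```

Calling `checkpoint_digest("checkpoints/cfgC")` should then be compared against the cfgC row of the table above; a mismatch may indicate either a corrupted download or a difference in the assumed output format.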
Contents
README.md
TRAINING_LOG.md
checkpoints/
  cfgA/params/params.pt                 # absolute/absolute (baseline)
  cfgA/params/normalization_stats.json
  cfgB/params/params.pt                 # absolute/delta
  cfgB/params/normalization_stats.json
  cfgC/params/params.pt                 # delta/absolute (winner)
  cfgC/params/normalization_stats.json
  cfgD/params/params.pt                 # delta/delta
  cfgD/params/normalization_stats.json
Each params.pt contains: model_state_dict, best_val_loss, best_step, global_step, action_representation, state_prediction_mode, config_label.
Usage
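import torch

# Load the winning config on CPU; the checkpoint is a dict with the keys listed under Contents.
ckpt = torch.load("checkpoints/cfgC/params/params.pt", map_location="cpu")
# `model` must already be constructed with the architecture above; its class is not defined in this snippet.
model.load_state_dict(ckpt["model_state_dict"])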