
ARX5 MuJoCo World Model: Delta Representation Sweep (500k steps)

Ablation of action/state representation for an offline world model trained on ARX5 MuJoCo simulation data. Four configs sweep action_representation (absolute | delta) × state_prediction_mode (absolute | delta).
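The 2×2 sweep enumerates cleanly in a few lines (a sketch, not the actual training harness; the cfgA–cfgD labels follow the order used in the results table below):

```python
from itertools import product

# Enumerate the 2x2 sweep: action_representation x state_prediction_mode.
# product() yields (absolute, absolute), (absolute, delta),
# (delta, absolute), (delta, delta) -> cfgA..cfgD.
configs = [
    {
        "label": f"cfg{chr(ord('A') + i)}",
        "action_representation": action_rep,
        "state_prediction_mode": state_mode,
    }
    for i, (action_rep, state_mode) in enumerate(
        product(["absolute", "delta"], repeat=2)
    )
]
```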

Architecture

  • Backbone: nanoGPT causal transformer (type: gpt)
  • Size: 6 layers, 12 heads, 384 embed dim
  • History horizon: 32 (h32)
  • Ensemble: 2 members
  • State dim: 14 (7 joint pos + 7 joint vel)
  • Action dim: 7
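The hyperparameters above can be collected into a single config object. This is illustrative only; WorldModelConfig and its field names are not the repo's actual API:

```python
from dataclasses import dataclass

@dataclass
class WorldModelConfig:
    # nanoGPT-style causal transformer backbone (type: gpt)
    n_layer: int = 6
    n_head: int = 12
    n_embd: int = 384
    history_horizon: int = 32   # h32: frames of context
    ensemble_size: int = 2
    state_dim: int = 14         # 7 joint positions + 7 joint velocities
    action_dim: int = 7

cfg = WorldModelConfig()
```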

Training

  • Dataset: train.pt (1024 episodes, 3.5M frames, ARX5 dual-arm MuJoCo)
  • Steps: 500,000
  • Batch size: 64
  • LR: 3e-4 → 3e-5 (cosine)
  • Noise: linear context noise (0 → 0.05 over 50k steps)
  • Hardware: NVIDIA GH200 (Isambard AI Phase 2)
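The two schedules above can be sketched as plain functions. This only reproduces the listed endpoints; any warmup or clipping details of the actual run are unknown:

```python
import math

TOTAL_STEPS = 500_000

def learning_rate(step, lr_max=3e-4, lr_min=3e-5, total=TOTAL_STEPS):
    # Cosine decay from 3e-4 at step 0 to 3e-5 at step 500k.
    t = min(step, total) / total
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * t))

def context_noise_std(step, sigma_max=0.05, ramp_steps=50_000):
    # Linear ramp 0 -> 0.05 over the first 50k steps, then constant.
    return sigma_max * min(step, ramp_steps) / ramp_steps
```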

Results

config  action_input  state_target  best_val_loss  best_step  grad_norm  verdict
cfgA    absolute      absolute      0.0982         500k       0.034      baseline
cfgB    absolute      delta         0.4244         490k       3.909      discard
cfgC    delta         absolute      0.0764         500k       0.040      winner
cfgD    delta         delta         0.2282         425k       4.295      discard

Config C (delta action input, absolute state targets) achieves 22% lower val loss than the baseline. Delta state targets (B, D) hurt significantly: the state is 14D (7 pos + 7 vel), so predicting s_{t+1} - s_t forces the velocity dims to regress acceleration, a second-order, high-variance quantity. This explains both the poor loss and the grad norms of ~4 versus ~0.04 for the absolute-target configs (roughly 100x higher).
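The state-target construction for the two state_prediction_mode settings amounts to the following (a sketch, not the repo's code). Note that in delta mode the velocity half of the target is v_{t+1} - v_t, i.e. acceleration times the timestep:

```python
import numpy as np

def make_state_targets(states: np.ndarray, mode: str = "absolute") -> np.ndarray:
    """states: (T, 14) array; columns 0..6 joint pos, 7..13 joint vel."""
    if mode == "absolute":
        return states[1:]            # predict s_{t+1} directly
    # delta mode: predict s_{t+1} - s_t; for columns 7..13 this is
    # v_{t+1} - v_t, a second-order (acceleration-like) quantity
    return states[1:] - states[:-1]
```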

Checkpoint Hashes

Computed via cd checkpoints/cfg<X> && find params -type f | sort | xargs sha256sum | sha256sum (verified by running twice):

config sha256
cfgA 5058f01a1171ecd2a79e8c90ca0f83e6cbf352523d7f58d9cf4f814699f02ca5
cfgB 4736e0c2cc027167c19a93e2939b97d7d13e6d5a692fa8e91ec1d0da396359df
cfgC 621708c7716c97be1e9509e8d8840197680dd1c83c6f36c830ac7df02bced3b6
cfgD ca03b56a893e5e041806cf8b569313a3c893a6e396b0621af0931663f0d3f761
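A Python equivalent of the shell pipeline above, for checking the hashes without GNU coreutils. It assumes the "digest, two spaces, relative path" line format that sha256sum emits and that sorted() matches find | sort; if the ordering matters for your tree, verify once against the shell command:

```python
import hashlib
from pathlib import Path

def tree_sha256(params_dir: str) -> str:
    # Mirrors: find params -type f | sort | xargs sha256sum | sha256sum
    base = Path(params_dir)
    lines = []
    for p in sorted(base.rglob("*")):
        if p.is_file():
            digest = hashlib.sha256(p.read_bytes()).hexdigest()
            rel = p.relative_to(base.parent).as_posix()  # e.g. "params/params.pt"
            lines.append(f"{digest}  {rel}\n")
    return hashlib.sha256("".join(lines).encode()).hexdigest()

# e.g. tree_sha256("checkpoints/cfgC/params") should match the cfgC row above
```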

Contents

README.md
TRAINING_LOG.md
checkpoints/
  cfgA/params/params.pt          # absolute/absolute baseline
  cfgA/params/normalization_stats.json
  cfgB/params/params.pt          # absolute/delta
  cfgB/params/normalization_stats.json
  cfgC/params/params.pt          # delta/absolute (winner)
  cfgC/params/normalization_stats.json
  cfgD/params/params.pt          # delta/delta
  cfgD/params/normalization_stats.json

Each params.pt contains: model_state_dict, best_val_loss, best_step, global_step, action_representation, state_prediction_mode, config_label.

Usage

import torch

# Assumes a model instance matching the checkpoint architecture
# (6-layer / 12-head / 384-dim nanoGPT world model) already exists.
ckpt = torch.load("checkpoints/cfgC/params/params.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
print(ckpt["config_label"], ckpt["best_val_loss"], ckpt["best_step"])
