# 6Net 2.0 — 6-Axis Visual Robot Policy (~228M)
Custom transformer policy for visual 6-DoF robot arm control, v2. Optimised for LDX-218, LD-1501MG, and LFD-01M hardware.
| Component | Detail | Params |
|---|---|---|
| Visual Encoder | ResNet-50 (ImageNet V2, shared) | ~23.5M |
| Visual Projection (overhead) | Linear(2048→1024) | ~2.1M |
| Visual Projection (wrist) | Linear(2048→1024) | ~2.1M |
| State Encoder | MLP(6→256→1024) | ~0.3M |
| Transformer | 16L · d=1024 · 16h · ffn=4096 | ~201.6M |
| Action Head | MLP(1024→512→K×6) | ~0.5M |
| Total | | ~228M |
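The transformer row dominates the parameter budget and can be sanity-checked from the listed dimensions. A back-of-envelope count, assuming standard blocks with biased Q/K/V/output projections, a two-layer FFN, and two LayerNorms per layer (the exact block layout is an assumption, not taken from the training script):

```python
d, ffn, layers = 1024, 4096, 16

attn = 4 * d * d + 4 * d           # Q, K, V, output projections with bias
mlp = d * ffn + ffn + ffn * d + d  # two linear layers with bias
ln = 2 * (2 * d)                   # two LayerNorms (gain + bias)

total = layers * (attn + mlp + ln)
print(f"{total / 1e6:.1f}M")  # → 201.5M, in line with the ~201.6M in the table
```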
Dataset: lerobot/pusht_image · Steps: 910 · Eff. batch: 32
Hardware profile: LDX-218
## Key improvements over v1
- 2× parameter count via ResNet-50 backbone + wider/deeper transformer
- Dual-camera overhead + wrist tokens
- Action chunking (K=10): predicts 10 future steps; returns step 0 at inference
- Hardware profiles: joint limits, max velocity, and gravity-comp for LDX-218 / LD-1501MG / LFD-01M
- Streaming fallback: tries streaming download before falling back to synthetic data
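The hardware profiles bundle per-servo constraints that the policy applies at inference time. A minimal sketch of what such a profile and its clamping step might look like (field names and limit values here are illustrative assumptions, not the actual `HARDWARE_PROFILES` shipped in `train_6net_v2`):

```python
import math

# Illustrative only: real profiles for LDX-218 / LD-1501MG / LFD-01M
# would carry the servos' actual limits and speeds.
HARDWARE_PROFILES = {
    "LDX-218": {
        "joint_limits_rad": [(-math.pi / 2, math.pi / 2)] * 6,  # (min, max) per joint
        "max_velocity_rad_s": 6.0,
        "gravity_comp": True,
    },
}

def clamp_action(action, profile):
    """Clamp each joint command to the profile's joint limits."""
    return [
        min(max(a, lo), hi)
        for a, (lo, hi) in zip(action, profile["joint_limits_rad"])
    ]

print(clamp_action([2.0, -3.0, 0.1, 0.0, 1.0, -1.0],
                   HARDWARE_PROFILES["LDX-218"]))
```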
## Inference
```python
import torch
import torchvision.transforms as T
from PIL import Image

from train_6net_v2 import SixNetV2, Config, HARDWARE_PROFILES

cfg = Config(hardware="LDX-218")
model = SixNetV2(cfg)
ckpt = torch.load("6net_v2_final.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state"])
model.eval()

# ImageNet preprocessing, matching the ResNet-50 backbone
tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = tf(Image.open("overhead.jpg")).unsqueeze(0)
wrist = tf(Image.open("wrist.jpg")).unsqueeze(0)
jts = torch.zeros(1, 6)  # current joint angles (rad)

with torch.no_grad():
    action = model.predict(img, jts, wrist=wrist,
                           hw=HARDWARE_PROFILES["LDX-218"])
# → tensor of shape (1, 6), clamped to LDX-218 joint limits
```
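Internally, `predict` returns only the first step of the K=10 action chunk (receding-horizon execution). The reshape-and-select step can be sketched independently of the model; names below are illustrative, not the actual `predict()` internals:

```python
# Action chunking sketch: the head emits a flat (K*6) vector per sample,
# reshaped to (K, 6); at inference only step 0 is executed before re-planning.
K, DOF = 10, 6

def first_step(flat_chunk):
    """Reshape a flat K*DOF prediction into K steps and return step 0."""
    assert len(flat_chunk) == K * DOF
    steps = [flat_chunk[i * DOF:(i + 1) * DOF] for i in range(K)]
    return steps[0]

pred = list(range(K * DOF))  # stand-in for the action head's output
print(first_step(pred))      # → [0, 1, 2, 3, 4, 5]
```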