Tactile-1 v1: Cross-Modal Visual-Tactile World Model
A cross-modal world model that predicts tactile embeddings from visual observations using frozen pretrained encoders.
Architecture
| Component | Model | Params | Status |
|---|---|---|---|
| Visual encoder | DINOv2-Small (ViT-S/14) | 22M | Frozen |
| Tactile encoder | Sparsh-DINO (ViT-B/16, 6ch) | 86M | Frozen |
| Predictor | CrossAttnPredictor (6-layer, 384-dim) | 16M | Trained |
| Action encoder | MLP (7 → 384) | 0.15M | Initialized (not yet trained) |
| Total | – | 125M | 16M trainable |
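The predictor reads the frozen DINOv2 patch tokens as context and uses a small set of learned queries to produce a predicted tactile embedding. The internals of `CrossAttnPredictor` are not reproduced in this card; the sketch below is only one plausible way to realize the constructor signature shown under Usage, and every module detail in it is an assumption.

```python
# Illustrative sketch of a cross-attention predictor matching the constructor
# arguments used in the Usage section. Not the shipped implementation.
import torch
import torch.nn as nn

class CrossAttnPredictorSketch(nn.Module):
    def __init__(self, dim=384, depth=6, num_heads=6, num_queries=16,
                 visual_dim=384, output_dim=768, dropout=0.1):
        super().__init__()
        # Learned tactile queries that attend to the visual patch tokens.
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        self.visual_proj = nn.Linear(visual_dim, dim)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "norm1": nn.LayerNorm(dim),
                "attn": nn.MultiheadAttention(dim, num_heads, dropout=dropout,
                                              batch_first=True),
                "norm2": nn.LayerNorm(dim),
                "mlp": nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                     nn.Linear(4 * dim, dim)),
            })
            for _ in range(depth)
        ])
        # Project from the predictor width (384) to the tactile embedding width (768).
        self.out_proj = nn.Linear(dim, output_dim)

    def forward(self, visual_tokens):
        # visual_tokens: [B, N_patches, visual_dim] from the frozen DINOv2 encoder
        ctx = self.visual_proj(visual_tokens)
        x = self.queries.unsqueeze(0).expand(visual_tokens.size(0), -1, -1)
        for layer in self.layers:
            attn_out, _ = layer["attn"](layer["norm1"](x), ctx, ctx)
            x = x + attn_out
            x = x + layer["mlp"](layer["norm2"](x))
        return self.out_proj(x)  # [B, num_queries, output_dim]
```

The `[B, num_queries, output_dim]` output matches the `[B, N_queries, 768]` shape returned by `model.predict_tactile` in the Usage example below.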
Training
- Dataset: YCB-Slide (50 trajectories, 60K frames, paired RGB + GelSight tactile)
- Split: Leave-2-Objects-Out (L2OO); holdout: tomato soup can, hammer
- Hardware: NVIDIA A100 SXM 80GB, RunPod
- Time: 18 minutes (11 epochs, early stopped at epoch 10)
- Batch size: 128, num_workers=2, bfloat16 AMP
- Optimizer: AdamW, lr=3e-4, weight_decay=0.01, cosine warmup (2000 steps)
- Loss: Smooth L1 on LayerNorm-normalized tactile targets
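A minimal sketch of the objective and optimizer setup described above, assuming "Smooth L1 on LayerNorm-normalized tactile targets" means the frozen Sparsh embeddings are normalized (without learnable affine parameters) before the comparison, and that `model` is the `CrossModalJEPA` instance built as in the Usage section below. The shipped `CrossModalLoss` may differ in detail, and `total_steps` is a placeholder.

```python
import math
import torch
import torch.nn.functional as F

def cross_modal_loss(pred, tactile_target, target_dim=768):
    # Normalize the frozen tactile embeddings (no learnable affine), then
    # compare with Smooth L1, per the training description above.
    target = F.layer_norm(tactile_target, (target_dim,))
    return F.smooth_l1_loss(pred, target)

# Only the ~16M trainable predictor parameters are updated; this assumes the
# frozen encoders (and the not-yet-trained action encoder) have requires_grad=False.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=3e-4, weight_decay=0.01)

# 2000-step linear warmup followed by cosine decay. total_steps is a placeholder;
# the real schedule length depends on dataset size, batch size, and epochs.
total_steps = 10_000
warmup_steps = 2_000

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```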
Results
| Metric | Value |
|---|---|
| Val prediction loss (L2OO) | 0.0108 |
| Train loss (converged) | 0.0112 |
The model converged rapidly: most of the learning happened in epoch 0 (loss 0.52 → 0.011).
Usage
```python
import torch

from tactile_jepa.encoders import FrozenDINOv2, FrozenSparsh
from tactile_jepa.architectures import CrossAttnPredictor, ActionEncoder
from tactile_jepa.cross_modal_jepa import CrossModalJEPA
from tactile_jepa.losses import CrossModalLoss

# Build model
model = CrossModalJEPA(
    visual_encoder=FrozenDINOv2(),
    tactile_encoder=FrozenSparsh(),
    predictor=CrossAttnPredictor(dim=384, depth=6, num_heads=6, num_queries=16,
                                 visual_dim=384, output_dim=768, dropout=0.1),
    action_encoder=ActionEncoder(action_dim=7, hidden_dim=384),
    predcost=CrossModalLoss(target_dim=768),
).to("cuda")

# Load checkpoint
ckpt = torch.load("best.pth.tar", map_location="cuda", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"], strict=False)
model.eval()

# Predict tactile embeddings from visual input
# visual_obs: [B, T, 3, 224, 224], ImageNet-normalized
predicted_tactile = model.predict_tactile(visual_obs)  # [B, N_queries, 768]
```
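For completeness, one way to build `visual_obs` from raw RGB frames is sketched below. This preprocessing is illustrative only and not part of the shipped package; `rgb_frames` is a placeholder list of PIL images.

```python
import torch
from torchvision import transforms

# Resize frames to 224x224 and apply ImageNet normalization, as expected by
# the frozen DINOv2 encoder.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

frames = [preprocess(img) for img in rgb_frames]          # rgb_frames: list of PIL images
visual_obs = torch.stack(frames).unsqueeze(0).to("cuda")  # [1, T, 3, 224, 224]
predicted_tactile = model.predict_tactile(visual_obs)     # [1, N_queries, 768]
```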
Files
- `best.pth.tar`: Best checkpoint (epoch 0, val_pred=0.0108), 382 MB
- `config.yaml`: Training configuration
Limitations
- Trained only on YCB-Slide (50 trajectories, single sensor type)
- Actions not yet enabled (Phase A only; Phase B with IDM is planned)
- No cross-modal retrieval or effective rank evaluation yet
- Sparsh encoder is released under a CC-BY-NC license (Meta FAIR); not for commercial use
Citation
```bibtex
@software{tactile_jepa_2026,
  title  = {Tactile-JEPA: Cross-Modal Visual-Tactile World Model},
  author = {Julian Saks},
  year   = {2026},
  url    = {https://github.com/kingulio8238/tactile_jepa}
}
```