Upload REPRODUCTION.md with huggingface_hub
# AC-DiT Reproduction Notes

## Final results (all-7-combo baseline, ckpt-25000)

| Combo | Ours, ckpt-25000 | Ours, ckpt-30000 | Paper, % (100×3 runs) |
|---|---|---|---|
| pick_apple | 26.0% | **32.0%** | 33.3 ± 1.9 |
| pick_bowl | **42.0%** | 24.0% | 36.0 ± 6.5 |
| place_apple | **34.0%** | 18.0% | 33.3 ± 9.4 |
| place_bowl | **48.0%** | 24.0% | 17.3 ± 6.8 |
| open_fridge | 92.0% | **96.0%** | 90.7 ± 5.0 |
| open_kc | 74.0% | **78.0%** | 81.3 ± 6.8 |
| close_kc | **100.0%** | **100.0%** | 97.3 ± 1.9 |
| **Mean S.R.** | **59.4%** | 53.1% | **55.6%** |

**ckpt-25000 reaches a mean of 59.4%, above the paper's 55.6%.** Four combos beat the paper (pick_bowl, place_bowl, open_fridge, close_kc), one matches it (place_apple), and two fall below it (pick_apple and open_kc). Both shortfalls are 7.3 points, which is slightly outside the paper's reported std for open_kc (6.8) and well outside it for pick_apple (1.9). Our numbers come from 50 trials per combo, so they carry sampling noise of their own.
ckpt-30000 overfits late in training: pick_bowl, place_apple, and place_bowl each drop 16-24 percentage points relative to ckpt-25000.
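
The mean success rates and the ckpt-30000 regressions can be re-derived from the per-combo numbers. The sketch below only recomputes the arithmetic from the table; the variable and function names are illustrative and are not part of the AC-DiT code.

```python
from statistics import mean

# Per-combo success rates (%) copied from the results table.
ours_25000 = {"pick_apple": 26.0, "pick_bowl": 42.0, "place_apple": 34.0,
              "place_bowl": 48.0, "open_fridge": 92.0, "open_kc": 74.0,
              "close_kc": 100.0}
ours_30000 = {"pick_apple": 32.0, "pick_bowl": 24.0, "place_apple": 18.0,
              "place_bowl": 24.0, "open_fridge": 96.0, "open_kc": 78.0,
              "close_kc": 100.0}
paper = {"pick_apple": 33.3, "pick_bowl": 36.0, "place_apple": 33.3,
         "place_bowl": 17.3, "open_fridge": 90.7, "open_kc": 81.3,
         "close_kc": 97.3}


def mean_sr(rates):
    """Unweighted mean success rate over combos, rounded to one decimal."""
    return round(mean(rates.values()), 1)


summary = {
    "ours_25000": mean_sr(ours_25000),
    "ours_30000": mean_sr(ours_30000),
    "paper": mean_sr(paper),
}

# Combos that get worse from ckpt-25000 to ckpt-30000, with the drop in points.
drops = {c: ours_25000[c] - ours_30000[c]
         for c in ours_25000 if ours_30000[c] < ours_25000[c]}

print(summary)
print(drops)
```

The mean is taken over combos with equal weight, which reproduces the "Mean S.R." row.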

**Recommended checkpoint to release: ckpt-25000.**

---

## Checkpoint paths

### HPC3
```
/data/user/jhe724/workspace/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/        # Stage 1 mobility head (7 combos pooled, --base_only)
│   └── pytorch_model/mp_rank_00_model_states.pt
└── AC-DiT-mshab-all7_baseline/                  # Stage 2 main model
    ├── checkpoint-25000/                        ← best
    ├── checkpoint-27500/
    ├── checkpoint-30000/                        (slightly overfit)
    └── checkpoint-{2500..22500}/                (intermediate checkpoints)
```

### Local machine (LFT-W02)
```
/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/
└── AC-DiT-mshab-all7_baseline/
    ├── checkpoint-25000/    ← use this one for evaluation
    ├── checkpoint-15000/
    ├── checkpoint-20000/
    └── checkpoint-30000/
```

---

## Training scripts (HPC3, 8×H100, ~16 h)

### 1. sbatch wrapper
`/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch`
```bash
#!/bin/bash
#SBATCH -p acd_u
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=384G
#SBATCH --gres=gpu:8
#SBATCH -t 7-00:00:00
#SBATCH -J all7_base
PROJECT=/data/user/jhe724/workspace/AC-DiT
echo "=== sbatch host=$(hostname) job=$SLURM_JOB_ID time=$(date) ==="
# Stop if the project directory is missing instead of launching from the wrong place.
cd "$PROJECT" || exit 1
# Replace this shell with the launch script so Slurm tracks it directly.
exec bash scripts/launch_s2_all7_baseline.sh
```

Submit with `sbatch scripts/train_all7.sbatch`.

### 2. Launch script
`scripts/launch_s2_all7_baseline.sh` does three things:
1. Restores `task_subtask_obj` in `data/hdf5_mshab_dataset.py` to all 7 combos.
2. Patches `scripts/finetune_mshab_acdit.py` (restores it from the `.orig` copy, then edits it with `sed`):
   - `--data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories` (reads directly from Lustre, avoiding a 2 TB+ copy to `/tmp`)
   - `--num_episode_per_task=1000` (all 1000 trajectories per combo, 7000 in total)
   - `--train_batch_size=20` (× 8 GPUs = 160 effective)
   - `--sample_batch_size=8` (used for val eval)
   - `--max_train_steps=30000`
   - `--sample_period=200` (every 200 steps: sample, run val eval, and log to wandb)
   - `--mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt`
   - EXP_NAME = `AC-DiT-mshab-all7_baseline`
3. Runs `python -m scripts.finetune_mshab_acdit`.
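
The patch step in item 2 amounts to rewriting `--name=value` strings inside the driver file. The actual launch script does this with `sed` and is not reproduced in these notes, so the following is only a Python stand-in for the same idea: `patch_flags` is a hypothetical helper, and the starting values in `original` are made up for the demonstration.

```python
import re


def patch_flags(source: str, overrides: dict[str, str]) -> str:
    """Return `source` with each `--name=value` flag in `overrides` rewritten."""
    for name, value in overrides.items():
        # Match the flag prefix, then the old value up to a quote or whitespace.
        pattern = re.compile(r"(--" + re.escape(name) + r"=)[^'\"\s]*")
        source, count = pattern.subn(lambda m, v=value: m.group(1) + v, source)
        if count == 0:
            # Fail loudly: a silent no-op would train with stale settings.
            raise KeyError(f"flag --{name} not found in source")
    return source


# Made-up starting values, only to show the rewrite.
original = "'--train_batch_size=32',\n'--max_train_steps=200000',\n"
patched = patch_flags(original, {"train_batch_size": "20", "max_train_steps": "30000"})
print(patched)
```

Raising when a flag is not found is a design choice of this sketch. It guards against a pattern that silently stops matching after the driver file changes.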

### 3. Finetune driver
The main `accelerate_command` in `scripts/finetune_mshab_acdit.py`:
```python
# Arguments passed to `accelerate launch`. Elided values ("...") are left as in the original notes.
accelerate_command = [
    'accelerate', 'launch',
    '--main_process_port=29905',
    'main.py',
    '--deepspeed=./configs/zero2.json',
    '--method_name=AC-DiT',
    '--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b',
    '--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
    '--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
    '--mobility_head_ckpt_path=...',
    '--config_path=configs/config.yaml',
    '--hdf5_dataset_name=mshab',
    '--data_dir=...',
    '--in_context_cond_dim=18',
    '--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline',
    '--resume_from_checkpoint=latest',
    '--train_batch_size=20',
    '--sample_batch_size=8',
    '--gradient_accumulation_steps=1',
    '--max_train_steps=30000',
    '--checkpointing_period=2500',
    '--sample_period=200',
    '--checkpoints_total_limit=40',
    '--lr_scheduler=constant',
    '--learning_rate=1e-5',
    '--mixed_precision=bf16',
    '--dataloader_num_workers=8',
    '--image_aug',
    '--dataset_type=finetune',
    '--state_noise_snr=40',
    '--load_from_hdf5',
    '--precomp_lang_embed',
    '--num_episode_per_task=1000',
    '--report_to=wandb',
]
```
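
A few derived quantities follow directly from these flags and are useful as a sanity check before submitting a job. The snippet below is plain arithmetic on values quoted in these notes. It assumes one process per GPU on the 8-GPU node, and it assumes `--checkpoints_total_limit` caps the number of checkpoints kept on disk.

```python
# Values copied from the command above and from the 8xH100 training setup.
num_gpus = 8
train_batch_size = 20            # per GPU
gradient_accumulation_steps = 1
max_train_steps = 30000
checkpointing_period = 2500
sample_period = 200
checkpoints_total_limit = 40

# Samples consumed per optimizer step across all GPUs.
effective_batch = train_batch_size * num_gpus * gradient_accumulation_steps
# Total training samples drawn over the whole run.
samples_seen = effective_batch * max_train_steps
# Number of checkpoints written (2500, 5000, ..., 30000).
num_checkpoints = max_train_steps // checkpointing_period
# Number of times the sample/val evaluation fires.
num_val_evals = max_train_steps // sample_period
# True when every checkpoint fits under the limit (assumed semantics).
all_checkpoints_kept = num_checkpoints <= checkpoints_total_limit

print(effective_batch, samples_seen, num_checkpoints, num_val_evals, all_checkpoints_kept)
```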

### 4. config.yaml
`configs/config.yaml` stays at the baseline settings:
- `action_chunk_size: 2`
- `img_history_size: 2`
- `num_cameras: 3`
- `state_dim: 128`

---

## Code changes (minimal; only adds val eval support)

Differences relative to commit `90ad00a` (init: release codes):

### `data/hdf5_mshab_dataset.py`
- Adds the constructor parameters `split: Literal["train","val"]="train"` and `num_episode_per_task: int=100`.
- Replaces the hard-coded `'train'` path component with `self.split` (3 places).
- Replaces the hard-coded `100` with `self._cfg_num_episode_per_task`.
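
To make the shape of this change concrete, here is a minimal sketch of a split-aware constructor. `SplitAwareDataset` and its two helper methods are hypothetical stand-ins, not the real `HDF5MSHABDataset`. Only the parameter names, their defaults, and the `<task>/<subtask>/<split>/<object>` directory layout come from these notes.

```python
from pathlib import Path
from typing import Literal


class SplitAwareDataset:
    """Hypothetical stand-in for the dataset class; only the path logic is shown."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] = "train",
        num_episode_per_task: int = 100,
    ):
        if split not in ("train", "val"):
            raise ValueError(f"unknown split: {split!r}")
        self.root = Path(root)
        self.split = split
        self._cfg_num_episode_per_task = num_episode_per_task

    def combo_dir(self, task: str, subtask: str, obj: str) -> Path:
        # The split replaces what used to be a hard-coded 'train' component.
        return self.root / task / subtask / self.split / obj

    def episode_indices(self, available: int) -> range:
        # Cap the episodes used per combo at the configured limit.
        return range(min(available, self._cfg_num_episode_per_task))


val_ds = SplitAwareDataset("/data", split="val")
print(val_ds.combo_dir("set_table", "pick", "013_apple"))
print(len(val_ds.episode_indices(1000)))
```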

### `train/dataset.py`
- `VLAConsumerDataset.__init__` gains `split="train"` and `num_episode_per_task=100` parameters.
- Both are passed through to `HDF5MSHABDataset`.

### `train/train.py`
- Creates a `val_dataset` (split="val", num_episode_per_task=100).
- Creates a `val_dataloader`, prepared separately with `accelerator.prepare_data_loader`.
- When `sample_period` triggers, runs val in addition to the original `sample_loss_for_log`:
```python
val_loss_for_log = log_sample_res(..., val_dataloader, ...)
# Prefix keys with "val/" so val metrics are distinguishable from the training sample metrics.
val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()}
accelerator.log(val_loss_for_log, step=global_step)
```
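
The key-prefixing step and the periodic trigger can be tried in isolation. This is a sketch under assumptions: `prefix_metrics` and `should_sample` are illustrative helpers, and the modulo-based trigger is a guess at how `sample_period` is checked, which these notes do not show.

```python
def prefix_metrics(metrics: dict[str, float], prefix: str = "val") -> dict[str, float]:
    """Return a copy of `metrics` with every key prefixed as '<prefix>/<key>'."""
    return {f"{prefix}/{k}": v for k, v in metrics.items()}


def should_sample(global_step: int, sample_period: int) -> bool:
    """Assumed trigger: fire on every positive multiple of `sample_period`."""
    return global_step > 0 and global_step % sample_period == 0


# Placeholder metric values, only to show the renaming.
logged = prefix_metrics({"mshab_sample_mse": 0.1, "overall_avg_sample_mse": 0.2})
trigger_steps = [s for s in range(1, 1001) if should_sample(s, 200)]
print(logged)
print(trigger_steps)
```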

### `main.py`
- Adds a `--num_episode_per_task` argparse argument.
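
A minimal sketch of such an argument, using only the standard `argparse` API. The default of 100 is an assumption that mirrors the dataset constructor default. The real `main.py` may use a different default or help text.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser containing only the new flag; the real main.py has many more."""
    parser = argparse.ArgumentParser(description="Sketch of the added CLI flag")
    parser.add_argument(
        "--num_episode_per_task",
        type=int,
        default=100,  # assumed default, matching the dataset constructor
        help="Number of trajectories to load per task combo.",
    )
    return parser


args = build_parser().parse_args(["--num_episode_per_task=1000"])
default_args = build_parser().parse_args([])
print(args.num_episode_per_task, default_args.num_episode_per_task)
```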

### Data preparation (one-time)
1. Generate val data, 100 trajectories per combo (use `gen_val_combo.sbatch` with SPLIT=val for the 7 combos).
2. In each val directory, create an `instructions` symlink pointing to the matching train directory (the language embeddings are shared across splits):
```bash
# ROOT must point at the dataset root, i.e. the directory that contains set_table/.
for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do
  subtask=$(dirname "$sub_obj")   # e.g. pick
  obj=$(basename "$sub_obj")      # e.g. 013_apple
  ln -s "$ROOT/set_table/$subtask/train/$obj/instructions" "$ROOT/set_table/$subtask/val/$obj/instructions"
done
```
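
Before launching training, it can help to confirm that every val combo has a non-empty `instructions` directory. The check below is a sketch: `missing_instructions` is a hypothetical helper, and it assumes the `set_table/<subtask>/<split>/<object>/instructions` layout used in the loop above. The demo builds a throwaway directory tree so the snippet runs without the real dataset.

```python
import tempfile
from pathlib import Path

COMBOS = [
    "pick/013_apple", "pick/024_bowl", "place/013_apple", "place/024_bowl",
    "open/fridge", "open/kitchen_counter", "close/kitchen_counter",
]


def missing_instructions(root, split="val", task="set_table"):
    """Return the combos whose instructions directory is absent or empty."""
    missing = []
    for combo in COMBOS:
        subtask, obj = combo.split("/")
        instr = Path(root) / task / subtask / split / obj / "instructions"
        # is_dir() follows symlinks, so a valid symlink to train/ passes.
        if not instr.is_dir() or not any(instr.iterdir()):
            missing.append(combo)
    return missing


# Demo on a temporary tree in which only pick/013_apple is set up.
with tempfile.TemporaryDirectory() as tmp:
    ok = Path(tmp) / "set_table" / "pick" / "val" / "013_apple" / "instructions"
    ok.mkdir(parents=True)
    (ok / "placeholder.txt").write_text("dummy instruction file")
    demo_missing = missing_instructions(tmp)

print(demo_missing)
```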

---

## Evaluation scripts (local, RTX A6000)

### Eval command template
```bash
export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH
PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
CKPT_STEP=25000

# Run from $PROJ (scripts/eval_acdit.sh is a relative path).
# The assignments below are passed to eval_acdit.sh as environment variables for this one command.
PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \
MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \
DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \
TASK=set_table SUBTASK=pick OBJECT=013_apple \
NUM_TRAJ=50 CUDA=0 \
RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \
bash scripts/eval_acdit.sh
```

### DATASET_DIR / TASK / SUBTASK / OBJECT for the 7 combos
| Short name | TASK | SUBTASK | OBJECT |
|---|---|---|---|
| pick_apple | set_table | pick | 013_apple |
| pick_bowl | set_table | pick | 024_bowl |
| place_apple | set_table | place | 013_apple |
| place_bowl | set_table | place | 024_bowl |
| open_fridge | set_table | open | fridge |
| open_kc | set_table | open | kitchen_counter |
| close_kc | set_table | close | kitchen_counter |

For each combo, DATASET_DIR follows the pattern of the eval command template: `$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/<TASK>/<SUBTASK>/train/<OBJECT>`.
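
The eval command template and the combo table can be combined into one environment-variable set per combo. The sketch below is hedged accordingly: `eval_env` is a hypothetical helper, the DATASET_DIR pattern is generalized from the single pick_apple example, and the RESULT_DIR naming for combos other than pick_apple is an assumption.

```python
# Combo short name -> (TASK, SUBTASK, OBJECT), copied from the combo table.
COMBOS = {
    "pick_apple": ("set_table", "pick", "013_apple"),
    "pick_bowl": ("set_table", "pick", "024_bowl"),
    "place_apple": ("set_table", "place", "013_apple"),
    "place_bowl": ("set_table", "place", "024_bowl"),
    "open_fridge": ("set_table", "open", "fridge"),
    "open_kc": ("set_table", "open", "kitchen_counter"),
    "close_kc": ("set_table", "close", "kitchen_counter"),
}


def eval_env(short, proj, ckpt_step=25000, num_traj=50, cuda=0):
    """Build the environment variables that the eval template sets for one combo."""
    task, subtask, obj = COMBOS[short]
    ckpt_file = "pytorch_model/mp_rank_00_model_states.pt"
    data_root = f"{proj}/third_party/mshab/mshab_data/gen_data_save_trajectories"
    return {
        "PRETRAINED": f"{proj}/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-{ckpt_step}/{ckpt_file}",
        "MH_CKPT": f"{proj}/checkpoints/DiT-mshab-base-only/checkpoint-30000/{ckpt_file}",
        "DATASET_DIR": f"{data_root}/{task}/{subtask}/train/{obj}",
        "TASK": task,
        "SUBTASK": subtask,
        "OBJECT": obj,
        "NUM_TRAJ": str(num_traj),
        "CUDA": str(cuda),
        # Assumed naming scheme; the notes only show the pick_apple directory.
        "RESULT_DIR": f"./eval_results_all7_{ckpt_step}_{short}_{num_traj}",
    }


# Alternate combos between CUDA devices 0 and 1 (round-robin).
envs = {short: eval_env(short, "/proj", cuda=i % 2) for i, short in enumerate(COMBOS)}
print(envs["open_kc"]["DATASET_DIR"])
```

Each dictionary can then be merged into `os.environ` and passed as `env=` to `subprocess.run(["bash", "scripts/eval_acdit.sh"], ...)` from the project directory.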

### Inside eval_acdit.sh
The script calls `python -m scripts.eval_mshab` directly and runs 50 real trials in the ManiSkill env.
Each trial times out after 200 steps.
Results are written to `$RESULT_DIR/<combo>/`:
- `trial_<N>_{success|failure}.mp4`
- `success_rate.txt`

### Parallel evaluation tips
The local machine has 2 A6000 GPUs; at ~10 GB per card, 2 combos can run at the same time (different checkpoints, or different combos of the same checkpoint).
Each eval takes roughly 25-90 min, depending on how the GPUs are being shared.

---

## Reproduction steps

### Training
```bash
# On HPC3
ssh HPC3_jhe724
cd /data/user/jhe724/workspace/AC-DiT

# Submit the training job (8 GPUs × ~16 h)
sbatch scripts/train_all7.sbatch

# Monitor
squeue -u jhe724
tail -f logs/train_s2_all7_baseline.log
```

### Evaluation
```bash
# On the local machine (LFT-W02)
cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT

# rsync the checkpoint (create the destination directory first; rsync does not create missing parents)
mkdir -p ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/
rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \
  ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/

# Run the evaluation (one command per combo; see the eval command template above)
```
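
---

## Lessons learned (pitfalls to avoid)

1. **Misremembered number**: early on we believed the paper reported 90% for pick_apple; it is actually 33% (90% is open_fridge) → don't quote numbers from memory.
2. **Data volume**: we previously trained on only 100 trajectories/combo × 2 combos and kept struggling at 12%; switching to 1000/combo × 7 combos immediately reached 32% on pick_apple (matching the paper) → **training data scale is the key factor**.
3. **Overfitting**: at 30k steps, 3 combos (pick_bowl, place_apple, place_bowl) regress relative to 25k; the loss was still falling while eval success dropped → **stop early instead of training to the end**.
4. **Out-of-range actions**: demo actions for the open/close tasks actually exceed [-1, 1] (up to 20). They are raw, unclipped controller outputs, which the env tolerates by clipping automatically. **The loss spikes come from these samples.** This does not hurt the final success rate (the env clips at eval time too), but it makes the training supervision unsound.
5. **Val data preparation**: generating with `gen_data.py` and SPLIT=val does not create the `instructions/` subdirectory, so `np.random.choice` fails with an "empty" error during training. Fix: symlink `instructions` from the train directory.
6. **Premature evaluation**: the earlier, meaningless 12% figure came from an incomplete setup (2 combos, 100 trajectories). Before running a reproduction, **align fully with the paper's configuration**.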
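
For lesson 4, a quick way to see how much of a demo set is affected is to measure the fraction of action values outside [-1, 1], and to compare against a clipped copy. The sketch below uses made-up values and hypothetical helper names. Whether the training targets should actually be clipped is not settled in these notes; the code only shows how to measure and clip.

```python
def out_of_range_fraction(actions, low=-1.0, high=1.0):
    """Fraction of individual action values that fall outside [low, high]."""
    values = [v for step in actions for v in step]
    if not values:
        return 0.0
    bad = sum(1 for v in values if v < low or v > high)
    return bad / len(values)


def clip_actions(actions, low=-1.0, high=1.0):
    """Return a copy of `actions` with every value clamped to [low, high]."""
    return [[min(max(v, low), high) for v in step] for step in actions]


# Made-up demo: three timesteps with two action dimensions each.
demo = [[0.2, -0.5], [20.0, 0.9], [-1.5, 1.0]]
frac = out_of_range_fraction(demo)
clipped = clip_actions(demo)
print(frac)
print(clipped)
```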

---

## Wandb metric names
- Training sample loss: `mshab_sample_mse`, `mshab_sample_l2err`, `overall_avg_sample_mse`, `overall_avg_sample_l2err`
- Val loss (held-out, 700 trajectories): `val/mshab_sample_mse`, `val/mshab_sample_l2err`, `val/overall_avg_sample_mse`, `val/overall_avg_sample_l2err`
- Per-step: `loss`, `lr`

EXP_NAME = `AC-DiT-mshab-all7_baseline` (the wandb run name)