File size: 9,627 Bytes
f57e782 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 | # AC-DiT 复现记录
## 最终结果(all-7-combo baseline,ckpt-25000)
| Combo | 我们 ckpt-25000 | 我们 ckpt-30000 | 论文 (100×3 runs) |
|---|---|---|---|
| pick_apple | 26.0% | **32.0%** | 33.3 ± 1.9 |
| pick_bowl | **42.0%** | 24.0% | 36.0 ± 6.5 |
| place_apple | **34.0%** | 18.0% | 33.3 ± 9.4 |
| place_bowl | **48.0%** | 24.0% | 17.3 ± 6.8 |
| open_fridge | 92.0% | **96.0%** | 90.7 ± 5.0 |
| open_kc | 74.0% | **78.0%** | 81.3 ± 6.8 |
| close_kc | **100.0%** | **100.0%** | 97.3 ± 1.9 |
| **Mean S.R.** | **59.4%** | 53.1% | **55.6%** |
**ckpt-25000 mean 59.4% > 论文 55.6%**,4 个 combo 超过论文,1 个匹配,2 个低于(在 std 内)。
ckpt-30000 后期过拟合(pick_bowl / place_apple / place_bowl 各掉 16-24%)。
**推荐发布权重:ckpt-25000。**
---
## 权重路径
### HPC3
```
/data/user/jhe724/workspace/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/ # Stage 1 mobility head (7 combos pooled, --base_only)
│ └── pytorch_model/mp_rank_00_model_states.pt
└── AC-DiT-mshab-all7_baseline/ # Stage 2 main model
├── checkpoint-25000/ ← 最佳
├── checkpoint-27500/
├── checkpoint-30000/ (轻微过拟合)
└── checkpoint-{2500..22500}/ (中间 ckpt)
```
### 本地 LFT-W02
```
/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/
└── AC-DiT-mshab-all7_baseline/
├── checkpoint-25000/ ← 评测用这个
├── checkpoint-15000/
├── checkpoint-20000/
└── checkpoint-30000/
```
---
## 训练脚本(HPC3, 8×H100, ~16h)
### 1. sbatch wrapper
`/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch`
```bash
#!/bin/bash
#SBATCH -p acd_u
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=384G
#SBATCH --gres=gpu:8
#SBATCH -t 7-00:00:00
#SBATCH -J all7_base
PROJECT=/data/user/jhe724/workspace/AC-DiT
echo "=== sbatch host=$(hostname) job=$SLURM_JOB_ID time=$(date) ==="
cd $PROJECT
exec bash scripts/launch_s2_all7_baseline.sh
```
提交:`sbatch scripts/train_all7.sbatch`
### 2. launch 脚本
`scripts/launch_s2_all7_baseline.sh` 做三件事:
1. 把 `data/hdf5_mshab_dataset.py` 的 `task_subtask_obj` 恢复成全 7 combo
2. patch `scripts/finetune_mshab_acdit.py`(用 .orig 还原后 sed 改):
- `--data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories`(lustre 直读,省 2T+ /tmp copy)
- `--num_episode_per_task=1000` (全 1000 traj/combo = 7000 total)
- `--train_batch_size=20`(× 8 GPU = 160 effective)
- `--sample_batch_size=8`(val eval 用)
- `--max_train_steps=30000`
- `--sample_period=200`(每 200 步 sample + val eval 写 wandb)
- `--mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt`
- EXP_NAME = `AC-DiT-mshab-all7_baseline`
3. 跑 `python -m scripts.finetune_mshab_acdit`
### 3. finetune driver
`scripts/finetune_mshab_acdit.py` 主要 accelerate_command:
```python
'accelerate', 'launch',
'--main_process_port=29905',
'main.py',
'--deepspeed=./configs/zero2.json',
'--method_name=AC-DiT',
'--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b',
'--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
'--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
'--mobility_head_ckpt_path=...',
'--config_path=configs/config.yaml',
'--hdf5_dataset_name=mshab',
'--data_dir=...',
'--in_context_cond_dim=18',
'--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline',
'--resume_from_checkpoint=latest',
'--train_batch_size=20',
'--sample_batch_size=8',
'--gradient_accumulation_steps=1',
'--max_train_steps=30000',
'--checkpointing_period=2500',
'--sample_period=200',
'--checkpoints_total_limit=40',
'--lr_scheduler=constant',
'--learning_rate=1e-5',
'--mixed_precision=bf16',
'--dataloader_num_workers=8',
'--image_aug',
'--dataset_type=finetune',
'--state_noise_snr=40',
'--load_from_hdf5',
'--precomp_lang_embed',
'--num_episode_per_task=1000',
'--report_to=wandb',
```
### 4. config.yaml
`configs/config.yaml` 保持 baseline:
- `action_chunk_size: 2`
- `img_history_size: 2`
- `num_cameras: 3`
- `state_dim: 128`
---
## 代码改动(最小化, 仅加 val eval 支持)
vs commit `90ad00a`(init: release codes)的差异:
### `data/hdf5_mshab_dataset.py`
- 加 `split: Literal["train","val"]="train"` 和 `num_episode_per_task: int=100` 构造参数
- 用 `self.split` 替代硬编码 `'train'` 路径(3 处)
- 用 `self._cfg_num_episode_per_task` 替代硬编码 `100`
### `train/dataset.py`
- `VLAConsumerDataset.__init__` 加 `split="train"` 和 `num_episode_per_task=100` 参数
- 透传给 `HDF5MSHABDataset`
### `train/train.py`
- 新建 `val_dataset` (split="val", num_episode_per_task=100)
- 新建 `val_dataloader`(用 `accelerator.prepare_data_loader` 单独 prepare)
- 在 `sample_period` 触发时除了原 sample_loss_for_log,还跑 val:
```python
val_loss_for_log = log_sample_res(..., val_dataloader, ...)
val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()}
accelerator.log(val_loss_for_log, step=global_step)
```
### `main.py`
- 加 `--num_episode_per_task` argparse arg
### 数据准备(一次性)
1. 生成 val 数据 100 traj/combo(用 `gen_val_combo.sbatch`,SPLIT=val 跑 7 个 combo)
2. 给每个 val 目录建 `instructions` 符号链接指向对应 train 目录(语言 embed 通用):
```bash
for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do
ln -s $ROOT/set_table/$(dirname $sub_obj)/train/$(basename $sub_obj)/instructions $ROOT/set_table/$(dirname $sub_obj)/val/$(basename $sub_obj)/instructions
done
```
---
## 评测脚本(本地, RTX A6000)
### eval 命令模板
```bash
export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH
PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
CKPT_STEP=25000
PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \
MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \
DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \
TASK=set_table SUBTASK=pick OBJECT=013_apple \
NUM_TRAJ=50 CUDA=0 \
RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \
bash scripts/eval_acdit.sh
```
### 7 个 combo 的 DATASET_DIR / TASK SUBTASK OBJECT
| short | TASK | SUBTASK | OBJECT |
|---|---|---|---|
| pick_apple | set_table | pick | 013_apple |
| pick_bowl | set_table | pick | 024_bowl |
| place_apple | set_table | place | 013_apple |
| place_bowl | set_table | place | 024_bowl |
| open_fridge | set_table | open | fridge |
| open_kc | set_table | open | kitchen_counter |
| close_kc | set_table | close | kitchen_counter |
### eval_acdit.sh 内部
直接调 `python -m scripts.eval_mshab` 跑 ManiSkill env 50 trial 实测。
每 trial 200 步 timeout。
结果写 `$RESULT_DIR/<combo>/`:
- `trial_<N>_{success|failure}.mp4`
- `success_rate.txt`
### 并行评测建议
本地 2 张 A6000,每张 ~10GB 可同时跑 2 个 combo(不同 ckpt 或同 ckpt 不同 combo)。
每 eval 约 25-90 min(依赖 GPU 共享情况)。
---
## 复现步骤
### 训练
```bash
# 在 HPC3 上
ssh HPC3_jhe724
cd /data/user/jhe724/workspace/AC-DiT
# 提交训练(8 GPU × 16h)
sbatch scripts/train_all7.sbatch
# 监控
squeue -u jhe724
tail -f logs/train_s2_all7_baseline.log
```
### 评测
```bash
# 在本地 LFT-W02 上
cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
# rsync ckpt
rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \
./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/
# 跑评测(每 combo 一条命令,参考上面)
```
---
## 失败教训(避坑)
1. **错觉**:早期以为论文 pick_apple 是 90%,其实是 33%(90% 是 open_fridge)→ 别凭记忆 quote 数字
2. **数据量**:之前只用 100 traj/combo × 2 combo 训练,结果 12% 反复折腾;改成 1000/combo × 7 combo 立刻到 32%(匹配论文)→ **训练数据规模是关键**
3. **过拟合**:30k 步 ckpt 比 25k 后 4 个 combo 退步(loss 还在降但 eval 衰减)→ **早停而非训满**
4. **Action 异常**:open/close 任务的 demo action 实际超 [-1, 1](最大 20),是 controller 未 clip 的 raw 输出,env 自动 clip 兼容了。**Loss 上的尖峰来自这些样本**。不影响最终成功率(因为 env 也 clip)但训练监督不科学
5. **val 数据准备**:用 `gen_data.py` SPLIT=val 生成时不会自动建 `instructions/` 子目录 → 训练时 `np.random.choice` 报"empty"。修:从 train 目录软链 `instructions`
6. **过早评测**:bullshit number 12% 之前来自不完整 setup(2 combo, 100 traj)。复现实验前先 **完全对齐论文配置**
---
## Wandb metrics 命名
- 训练 sample loss:`mshab_sample_mse`, `mshab_sample_l2err`, `overall_avg_sample_mse`, `overall_avg_sample_l2err`
- val loss(held-out 700 traj):`val/mshab_sample_mse`, `val/mshab_sample_l2err`, `val/overall_avg_sample_mse`, `val/overall_avg_sample_l2err`
- per-step:`loss`, `lr`
EXP_NAME = `AC-DiT-mshab-all7_baseline` (wandb run name)
|