File size: 9,627 Bytes
f57e782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# AC-DiT 复现记录

## 最终结果(all-7-combo baseline,ckpt-25000)

| Combo | 我们 ckpt-25000 | 我们 ckpt-30000 | 论文 (100×3 runs) |
|---|---|---|---|
| pick_apple | 26.0% | **32.0%** | 33.3 ± 1.9 |
| pick_bowl | **42.0%** | 24.0% | 36.0 ± 6.5 |
| place_apple | **34.0%** | 18.0% | 33.3 ± 9.4 |
| place_bowl | **48.0%** | 24.0% | 17.3 ± 6.8 |
| open_fridge | 92.0% | **96.0%** | 90.7 ± 5.0 |
| open_kc | 74.0% | **78.0%** | 81.3 ± 6.8 |
| close_kc | **100.0%** | **100.0%** | 97.3 ± 1.9 |
| **Mean S.R.** | **59.4%** | 53.1% | **55.6%** |

**ckpt-25000 mean 59.4% > 论文 55.6%**,4 个 combo 超过论文,1 个匹配,2 个低于(在 std 内)。
ckpt-30000 后期过拟合(pick_bowl / place_apple / place_bowl 各掉 16-24%)。

**推荐发布权重:ckpt-25000。**

---

## 权重路径

### HPC3
```
/data/user/jhe724/workspace/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/         # Stage 1 mobility head (7 combos pooled, --base_only)
│   └── pytorch_model/mp_rank_00_model_states.pt
└── AC-DiT-mshab-all7_baseline/                   # Stage 2 main model
    ├── checkpoint-25000/  ← 最佳
    ├── checkpoint-27500/
    ├── checkpoint-30000/  (轻微过拟合)
    └── checkpoint-{2500..22500}/  (中间 ckpt)
```

### 本地 LFT-W02
```
/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/
└── AC-DiT-mshab-all7_baseline/
    ├── checkpoint-25000/  ← 评测用这个
    ├── checkpoint-15000/
    ├── checkpoint-20000/
    └── checkpoint-30000/
```

---

## 训练脚本(HPC3, 8×H100, ~16h)

### 1. sbatch wrapper
`/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch`
```bash
#!/bin/bash
#SBATCH -p acd_u
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=384G
#SBATCH --gres=gpu:8
#SBATCH -t 7-00:00:00
#SBATCH -J all7_base
PROJECT=/data/user/jhe724/workspace/AC-DiT
echo "=== sbatch host=$(hostname)  job=$SLURM_JOB_ID  time=$(date) ==="
cd $PROJECT
exec bash scripts/launch_s2_all7_baseline.sh
```

提交:`sbatch scripts/train_all7.sbatch`

### 2. launch 脚本
`scripts/launch_s2_all7_baseline.sh` 做三件事:
1.`data/hdf5_mshab_dataset.py``task_subtask_obj` 恢复成全 7 combo
2. patch `scripts/finetune_mshab_acdit.py`(用 .orig 还原后 sed 改):
   - `--data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories`(lustre 直读,省 2T+ /tmp copy)
   - `--num_episode_per_task=1000` (全 1000 traj/combo = 7000 total)
   - `--train_batch_size=20`(× 8 GPU = 160 effective)
   - `--sample_batch_size=8`(val eval 用)
   - `--max_train_steps=30000`
   - `--sample_period=200`(每 200 步 sample + val eval 写 wandb)
   - `--mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt`
   - EXP_NAME = `AC-DiT-mshab-all7_baseline`
3.`python -m scripts.finetune_mshab_acdit`

### 3. finetune driver
`scripts/finetune_mshab_acdit.py` 主要 accelerate_command:
```python
'accelerate', 'launch',
'--main_process_port=29905',
'main.py',
'--deepspeed=./configs/zero2.json',
'--method_name=AC-DiT',
'--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b',
'--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
'--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
'--mobility_head_ckpt_path=...',
'--config_path=configs/config.yaml',
'--hdf5_dataset_name=mshab',
'--data_dir=...',
'--in_context_cond_dim=18',
'--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline',
'--resume_from_checkpoint=latest',
'--train_batch_size=20',
'--sample_batch_size=8',
'--gradient_accumulation_steps=1',
'--max_train_steps=30000',
'--checkpointing_period=2500',
'--sample_period=200',
'--checkpoints_total_limit=40',
'--lr_scheduler=constant',
'--learning_rate=1e-5',
'--mixed_precision=bf16',
'--dataloader_num_workers=8',
'--image_aug',
'--dataset_type=finetune',
'--state_noise_snr=40',
'--load_from_hdf5',
'--precomp_lang_embed',
'--num_episode_per_task=1000',
'--report_to=wandb',
```

### 4. config.yaml
`configs/config.yaml` 保持 baseline:
- `action_chunk_size: 2`
- `img_history_size: 2`
- `num_cameras: 3`
- `state_dim: 128`

---

## 代码改动(最小化, 仅加 val eval 支持)

vs commit `90ad00a`(init: release codes)的差异:

### `data/hdf5_mshab_dataset.py`
-`split: Literal["train","val"]="train"``num_episode_per_task: int=100` 构造参数
-`self.split` 替代硬编码 `'train'` 路径(3 处)
-`self._cfg_num_episode_per_task` 替代硬编码 `100`

### `train/dataset.py`
- `VLAConsumerDataset.__init__``split="train"``num_episode_per_task=100` 参数
- 透传给 `HDF5MSHABDataset`

### `train/train.py`
- 新建 `val_dataset` (split="val", num_episode_per_task=100)
- 新建 `val_dataloader`(用 `accelerator.prepare_data_loader` 单独 prepare)
-`sample_period` 触发时除了原 sample_loss_for_log,还跑 val:
```python
val_loss_for_log = log_sample_res(..., val_dataloader, ...)
val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()}
accelerator.log(val_loss_for_log, step=global_step)
```

### `main.py`
- 加 `--num_episode_per_task` argparse arg

### 数据准备(一次性)
1. 生成 val 数据 100 traj/combo(用 `gen_val_combo.sbatch`,SPLIT=val 跑 7 个 combo)
2. 给每个 val 目录建 `instructions` 符号链接指向对应 train 目录(语言 embed 通用):
```bash
for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do
  ln -s $ROOT/set_table/$(dirname $sub_obj)/train/$(basename $sub_obj)/instructions $ROOT/set_table/$(dirname $sub_obj)/val/$(basename $sub_obj)/instructions
done
```

---

## 评测脚本(本地, RTX A6000)

### eval 命令模板
```bash
export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH
PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
CKPT_STEP=25000

PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \
MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \
DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \
TASK=set_table SUBTASK=pick OBJECT=013_apple \
NUM_TRAJ=50 CUDA=0 \
RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \
bash scripts/eval_acdit.sh
```

### 7 个 combo 的 DATASET_DIR / TASK SUBTASK OBJECT
| short | TASK | SUBTASK | OBJECT |
|---|---|---|---|
| pick_apple | set_table | pick | 013_apple |
| pick_bowl | set_table | pick | 024_bowl |
| place_apple | set_table | place | 013_apple |
| place_bowl | set_table | place | 024_bowl |
| open_fridge | set_table | open | fridge |
| open_kc | set_table | open | kitchen_counter |
| close_kc | set_table | close | kitchen_counter |

### eval_acdit.sh 内部
直接调 `python -m scripts.eval_mshab` 跑 ManiSkill env 50 trial 实测。
每 trial 200 步 timeout。
结果写 `$RESULT_DIR/<combo>/`- `trial_<N>_{success|failure}.mp4`
- `success_rate.txt`

### 并行评测建议
本地 2 张 A6000,每张 ~10GB 可同时跑 2 个 combo(不同 ckpt 或同 ckpt 不同 combo)。
每 eval 约 25-90 min(依赖 GPU 共享情况)。

---

## 复现步骤

### 训练
```bash
# 在 HPC3 上
ssh HPC3_jhe724
cd /data/user/jhe724/workspace/AC-DiT

# 提交训练(8 GPU × 16h)
sbatch scripts/train_all7.sbatch

# 监控
squeue -u jhe724
tail -f logs/train_s2_all7_baseline.log
```

### 评测
```bash
# 在本地 LFT-W02 上
cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT

# rsync ckpt
rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \
  ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/

# 跑评测(每 combo 一条命令,参考上面)
```

---

## 失败教训(避坑)

1. **错觉**:早期以为论文 pick_apple 是 90%,其实是 33%(90% 是 open_fridge)→ 别凭记忆 quote 数字
2. **数据量**:之前只用 100 traj/combo × 2 combo 训练,结果 12% 反复折腾;改成 1000/combo × 7 combo 立刻到 32%(匹配论文)→ **训练数据规模是关键**
3. **过拟合**:30k 步 ckpt 比 25k 后 4 个 combo 退步(loss 还在降但 eval 衰减)→ **早停而非训满**
4. **Action 异常**:open/close 任务的 demo action 实际超 [-1, 1](最大 20),是 controller 未 clip 的 raw 输出,env 自动 clip 兼容了。**Loss 上的尖峰来自这些样本**。不影响最终成功率(因为 env 也 clip)但训练监督不科学
5. **val 数据准备**:用 `gen_data.py` SPLIT=val 生成时不会自动建 `instructions/` 子目录 → 训练时 `np.random.choice` 报"empty"。修:从 train 目录软链 `instructions`
6. **过早评测**:bullshit number 12% 之前来自不完整 setup(2 combo, 100 traj)。复现实验前先 **完全对齐论文配置**

---

## Wandb metrics 命名
- 训练 sample loss:`mshab_sample_mse`, `mshab_sample_l2err`, `overall_avg_sample_mse`, `overall_avg_sample_l2err`
- val loss(held-out 700 traj):`val/mshab_sample_mse`, `val/mshab_sample_l2err`, `val/overall_avg_sample_mse`, `val/overall_avg_sample_l2err`
- per-step:`loss`, `lr`

EXP_NAME = `AC-DiT-mshab-all7_baseline` (wandb run name)