# AC-DiT Reproduction Notes
## Final results (all-7-combo baseline, ckpt-25000)
| Combo | Ours, ckpt-25000 (50 trials) | Ours, ckpt-30000 (50 trials) | Paper (100×3 runs) |
|---|---|---|---|
| pick_apple | 26.0% | **32.0%** | 33.3 ± 1.9 |
| pick_bowl | **42.0%** | 24.0% | 36.0 ± 6.5 |
| place_apple | **34.0%** | 18.0% | 33.3 ± 9.4 |
| place_bowl | **48.0%** | 24.0% | 17.3 ± 6.8 |
| open_fridge | 92.0% | **96.0%** | 90.7 ± 5.0 |
| open_kc | 74.0% | **78.0%** | 81.3 ± 6.8 |
| close_kc | **100.0%** | **100.0%** | 97.3 ± 1.9 |
| **Mean S.R.** | **59.4%** | 53.1% | **55.6%** |
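**ckpt-25000 mean 59.4% > paper 55.6%**. Per combo: 4 are above the paper, 1 roughly matches, and 2 are below (pick_apple and open_kc, each 7.3 points lower and outside the paper's reported std).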
ckpt-30000 overfits late in training (pick_bowl / place_apple / place_bowl each drop by 16-24 percentage points).
**Recommended checkpoint for release: ckpt-25000.**
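As a sanity check on the table, the sketch below recomputes the mean S.R. and the binomial standard error of each of our per-combo numbers. It assumes every entry is an independent 50-trial eval (the `NUM_TRAJ=50` setting of the eval template); at that sample size, a 7-point gap to the paper is comparable to one standard error of our own estimate.

```python
import math

# Per-combo success rates (%) copied from the results table above.
OURS_25K = {
    "pick_apple": 26.0, "pick_bowl": 42.0, "place_apple": 34.0,
    "place_bowl": 48.0, "open_fridge": 92.0, "open_kc": 74.0, "close_kc": 100.0,
}
PAPER = {
    "pick_apple": 33.3, "pick_bowl": 36.0, "place_apple": 33.3,
    "place_bowl": 17.3, "open_fridge": 90.7, "open_kc": 81.3, "close_kc": 97.3,
}


def mean_sr(rates: dict[str, float]) -> float:
    """Unweighted mean over combos, as in the 'Mean S.R.' row."""
    return sum(rates.values()) / len(rates)


def binomial_se(rate_pct: float, n_trials: int = 50) -> float:
    """Standard error (percentage points) of a success rate from n Bernoulli trials."""
    p = rate_pct / 100.0
    return 100.0 * math.sqrt(p * (1.0 - p) / n_trials)


if __name__ == "__main__":
    print(f"ours mean {mean_sr(OURS_25K):.1f}%, paper mean {mean_sr(PAPER):.1f}%")
    for combo, rate in OURS_25K.items():
        gap = rate - PAPER[combo]
        print(f"{combo:12s} gap {gap:+6.1f} pts, SE(ours) {binomial_se(rate):4.1f} pts")
```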
---
## Checkpoint paths
### HPC3
```
/data/user/jhe724/workspace/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/ # Stage 1 mobility head (7 combos pooled, --base_only)
│   └── pytorch_model/mp_rank_00_model_states.pt
└── AC-DiT-mshab-all7_baseline/ # Stage 2 main model
    ├── checkpoint-25000/ ← best
    ├── checkpoint-27500/
    ├── checkpoint-30000/ (slightly overfit)
    └── checkpoint-{2500..22500}/ (intermediate ckpts)
```
### Local machine LFT-W02
```
/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/
└── AC-DiT-mshab-all7_baseline/
    ├── checkpoint-25000/ ← used for evaluation
    ├── checkpoint-15000/
    ├── checkpoint-20000/
    └── checkpoint-30000/
```
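Both trees store each step's weights under the same relative path, `checkpoint-<step>/pytorch_model/mp_rank_00_model_states.pt`, which is the file the training and eval commands point to. A small pathlib helper like the one below can list which steps are actually present (with their weights file) before launching evals or an rsync; `weights_path` and `available_steps` are hypothetical names for illustration, not part of the AC-DiT code.

```python
import re
from pathlib import Path

# Layout used by both trees:
# <exp_dir>/checkpoint-<step>/pytorch_model/mp_rank_00_model_states.pt
_CKPT_DIR = re.compile(r"^checkpoint-(\d+)$")
_WEIGHTS = Path("pytorch_model") / "mp_rank_00_model_states.pt"


def weights_path(exp_dir: Path, step: int) -> Path:
    """Path of the DeepSpeed model-states file for one checkpoint step."""
    return Path(exp_dir) / f"checkpoint-{step}" / _WEIGHTS


def available_steps(exp_dir: Path) -> list[int]:
    """Sorted steps whose checkpoint directory actually contains the weights file."""
    steps = []
    for child in Path(exp_dir).iterdir():
        m = _CKPT_DIR.match(child.name)
        if m and child.is_dir() and (child / _WEIGHTS).is_file():
            steps.append(int(m.group(1)))
    return sorted(steps)
```

For example, `available_steps(Path("checkpoints/AC-DiT-mshab-all7_baseline"))` on the local machine should list the steps that have been rsynced so far.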
---
## Training scripts (HPC3, 8×H100, ~16 h)
### 1. sbatch wrapper
`/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch`
```bash
#!/bin/bash
#SBATCH -p acd_u
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=384G
#SBATCH --gres=gpu:8
#SBATCH -t 7-00:00:00
#SBATCH -J all7_base
PROJECT=/data/user/jhe724/workspace/AC-DiT
echo "=== sbatch host=$(hostname) job=$SLURM_JOB_ID time=$(date) ==="
cd $PROJECT
exec bash scripts/launch_s2_all7_baseline.sh
```
Submit: `sbatch scripts/train_all7.sbatch`
### 2. Launch script
`scripts/launch_s2_all7_baseline.sh` does three things:
1. Restores `task_subtask_obj` in `data/hdf5_mshab_dataset.py` to all 7 combos.
2. Patches `scripts/finetune_mshab_acdit.py` (restores it from the `.orig` copy, then edits it with `sed`):
   - `--data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories` (read directly from Lustre, which avoids copying 2 TB+ to `/tmp`)
   - `--num_episode_per_task=1000` (all 1000 trajectories per combo, 7000 in total)
   - `--train_batch_size=20` (× 8 GPUs = 160 effective)
   - `--sample_batch_size=8` (used for val eval)
   - `--max_train_steps=30000`
   - `--sample_period=200` (every 200 steps, run sampling + val eval and log to wandb)
   - `--mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt`
   - EXP_NAME = `AC-DiT-mshab-all7_baseline`
3. Runs `python -m scripts.finetune_mshab_acdit`.
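Step 2 rewrites `--flag=value` entries with `sed` on a fresh copy of the `.orig` file. The same override step, sketched in Python on the argument list rather than the source file; `override_flags` is a hypothetical helper, not part of the repo, and it returns a new list so the unpatched version stays intact (the role the `.orig` copy plays).

```python
def override_flags(args: list[str], overrides: dict[str, str]) -> list[str]:
    """Return a copy of args with each '--name=value' replaced according to overrides.

    Flags that are not present yet are appended at the end; bare switches such as
    '--image_aug' (no '=') are passed through unchanged.
    """
    out = []
    seen = set()
    for arg in args:
        name, sep, _ = arg.partition("=")
        if sep and name in overrides:
            out.append(f"{name}={overrides[name]}")
            seen.add(name)
        else:
            out.append(arg)
    out.extend(f"{name}={value}" for name, value in overrides.items() if name not in seen)
    return out


if __name__ == "__main__":
    # Hypothetical starting values, only to show the behavior.
    base = ["accelerate", "launch", "main.py", "--train_batch_size=32", "--image_aug"]
    patched = override_flags(base, {"--train_batch_size": "20", "--max_train_steps": "30000"})
    print(patched)
    # ['accelerate', 'launch', 'main.py', '--train_batch_size=20', '--image_aug', '--max_train_steps=30000']
```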
### 3. finetune driver
The main `accelerate_command` in `scripts/finetune_mshab_acdit.py`:
```python
accelerate_command = [
    'accelerate', 'launch',
    '--main_process_port=29905',
    'main.py',
    '--deepspeed=./configs/zero2.json',
    '--method_name=AC-DiT',
    '--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b',
    '--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
    '--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
    '--mobility_head_ckpt_path=...',
    '--config_path=configs/config.yaml',
    '--hdf5_dataset_name=mshab',
    '--data_dir=...',
    '--in_context_cond_dim=18',
    '--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline',
    '--resume_from_checkpoint=latest',
    '--train_batch_size=20',
    '--sample_batch_size=8',
    '--gradient_accumulation_steps=1',
    '--max_train_steps=30000',
    '--checkpointing_period=2500',
    '--sample_period=200',
    '--checkpoints_total_limit=40',
    '--lr_scheduler=constant',
    '--learning_rate=1e-5',
    '--mixed_precision=bf16',
    '--dataloader_num_workers=8',
    '--image_aug',
    '--dataset_type=finetune',
    '--state_noise_snr=40',
    '--load_from_hdf5',
    '--precomp_lang_embed',
    '--num_episode_per_task=1000',
    '--report_to=wandb',
]
```
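The batch and checkpoint flags above fix two numbers worth checking before submitting: the effective batch size (per-GPU batch × GPUs × accumulation steps) and which checkpoint steps get written. A minimal arithmetic sketch, assuming one checkpoint every `checkpointing_period` steps up to `max_train_steps`, which is consistent with the checkpoint-2500 through checkpoint-30000 directories listed earlier.

```python
def effective_batch_size(per_gpu: int, num_gpus: int, grad_accum: int = 1) -> int:
    """Samples consumed per optimizer step across all data-parallel ranks."""
    return per_gpu * num_gpus * grad_accum


def checkpoint_steps(max_train_steps: int, period: int) -> list[int]:
    """Steps at which a checkpoint is saved, assuming a save at every multiple of period."""
    return list(range(period, max_train_steps + 1, period))


if __name__ == "__main__":
    # Values from the accelerate command above, on 8 GPUs.
    bs = effective_batch_size(per_gpu=20, num_gpus=8, grad_accum=1)
    steps = checkpoint_steps(max_train_steps=30000, period=2500)
    print(bs, len(steps), steps[-1])  # 160 12 30000
    # 12 checkpoints, well under --checkpoints_total_limit=40.
    print("samples seen:", bs * 30000)
```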
### 4. config.yaml
`configs/config.yaml` is kept at the baseline values:
- `action_chunk_size: 2`
- `img_history_size: 2`
- `num_cameras: 3`
- `state_dim: 128`
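Because a mismatched config silently changes the model, it can help to assert these four values before launching. The hypothetical check below works on the config after it has been loaded into a dict (for example with PyYAML's `yaml.safe_load`); it searches nested sections because the exact nesting of these keys in `configs/config.yaml` is not shown in these notes.

```python
from typing import Any

# Baseline values listed above.
BASELINE = {"action_chunk_size": 2, "img_history_size": 2, "num_cameras": 3, "state_dim": 128}


def find_key(cfg: Any, key: str) -> Any:
    """Depth-first search for key in a nested dict; returns None if it is absent."""
    if isinstance(cfg, dict):
        if key in cfg:
            return cfg[key]
        for value in cfg.values():
            found = find_key(value, key)
            if found is not None:
                return found
    return None


def baseline_mismatches(cfg: dict) -> dict[str, tuple[Any, Any]]:
    """Map each key that deviates from BASELINE to (expected, actual)."""
    out = {}
    for key, expected in BASELINE.items():
        actual = find_key(cfg, key)
        if actual != expected:
            out[key] = (expected, actual)
    return out
```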
---
## Code changes (minimal; only adds val-eval support)
Diff against commit `90ad00a` (init: release codes):
### `data/hdf5_mshab_dataset.py`
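- Adds the constructor parameters `split: Literal["train","val"]="train"` and `num_episode_per_task: int=100`
- Uses `self.split` instead of the hard-coded `'train'` path component (3 places)
- Uses `self._cfg_num_episode_per_task` instead of the hard-coded `100`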
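A minimal sketch of what this parameterization amounts to: the split becomes a path component instead of the literal `'train'`, and the episode cap comes from the constructor instead of `100`. The `<root>/<task>/<subtask>/<split>/<object>` layout is taken from the symlink loop and eval paths in these notes; the function names and the "first N in sorted order" selection are assumptions, not the actual `HDF5MSHABDataset` code.

```python
from pathlib import Path
from typing import Literal


def combo_dir(root: Path, task: str, subtask: str, obj: str,
              split: Literal["train", "val"] = "train") -> Path:
    """Directory holding one combo's trajectories, e.g. set_table/pick/val/013_apple."""
    return Path(root) / task / subtask / split / obj


def select_episodes(files: list[str], num_episode_per_task: int = 100) -> list[str]:
    """Keep at most num_episode_per_task episodes (deterministic: first N in sorted order)."""
    if num_episode_per_task <= 0:
        raise ValueError("num_episode_per_task must be positive")
    return sorted(files)[:num_episode_per_task]
```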
### `train/dataset.py`
- `VLAConsumerDataset.__init__` gains `split="train"` and `num_episode_per_task=100` parameters
- Both are passed through to `HDF5MSHABDataset`
### `train/train.py`
- Creates `val_dataset` (`split="val"`, `num_episode_per_task=100`)
- Creates `val_dataloader`, prepared separately with `accelerator.prepare_data_loader`
- When `sample_period` triggers, runs the val pass in addition to the original `sample_loss_for_log`:
```python
val_loss_for_log = log_sample_res(..., val_dataloader, ...)
val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()}
accelerator.log(val_loss_for_log, step=global_step)
```
### `main.py`
- Adds the `--num_episode_per_task` argparse argument
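### Data preparation (one-time)
1. Generate val data, 100 traj/combo (run `gen_val_combo.sbatch` with `SPLIT=val` for each of the 7 combos).
2. In each val directory, create an `instructions` symlink pointing to the corresponding train directory (the language embeddings are shared across splits):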
```bash
# $ROOT is the trajectory root, i.e. the directory passed as --data_dir.
for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do
    subtask=$(dirname "$sub_obj")
    obj=$(basename "$sub_obj")
    ln -s "$ROOT/set_table/$subtask/train/$obj/instructions" "$ROOT/set_table/$subtask/val/$obj/instructions"
done
```
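Lesson 5 in the pitfalls list shows what a missing link leads to (`np.random.choice` failing on an empty list). A pre-flight check along these lines, a hypothetical helper to run after the loop, catches a missing or dangling `instructions` link before a long training job starts; it assumes the same `$ROOT/set_table/<subtask>/<split>/<object>` layout as the loop.

```python
from pathlib import Path

COMBOS = ["pick/013_apple", "pick/024_bowl", "place/013_apple", "place/024_bowl",
          "open/fridge", "open/kitchen_counter", "close/kitchen_counter"]


def missing_instructions(root: Path, split: str = "val", task: str = "set_table") -> list[str]:
    """Combos whose <split> directory lacks a non-empty 'instructions' directory.

    Path.is_dir() follows symlinks, so a dangling link is reported as missing too.
    """
    bad = []
    for combo in COMBOS:
        subtask, obj = combo.split("/")
        inst = Path(root) / task / subtask / split / obj / "instructions"
        if not inst.is_dir() or not any(inst.iterdir()):
            bad.append(combo)
    return bad
```

Running `missing_instructions(Path(root))` and getting back an empty list means every val combo can see its language embeddings.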
---
## Evaluation scripts (local, RTX A6000)
### Eval command template
```bash
export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH
PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
CKPT_STEP=25000
PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \
MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \
DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \
TASK=set_table SUBTASK=pick OBJECT=013_apple \
NUM_TRAJ=50 CUDA=0 \
RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \
bash scripts/eval_acdit.sh
```
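### DATASET_DIR / TASK / SUBTASK / OBJECT for the 7 combos
`DATASET_DIR` follows the template's pattern `$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/<TASK>/<SUBTASK>/train/<OBJECT>`.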
| short | TASK | SUBTASK | OBJECT |
|---|---|---|---|
| pick_apple | set_table | pick | 013_apple |
| pick_bowl | set_table | pick | 024_bowl |
| place_apple | set_table | place | 013_apple |
| place_bowl | set_table | place | 024_bowl |
| open_fridge | set_table | open | fridge |
| open_kc | set_table | open | kitchen_counter |
| close_kc | set_table | close | kitchen_counter |
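Since only `SUBTASK`, `OBJECT`, `DATASET_DIR` and `RESULT_DIR` change between combos, the seven invocations of the template can be generated rather than typed. The sketch below only builds the command strings (it does not run them). `DATASET_DIR` follows the pattern stated above; the per-combo `RESULT_DIR` naming is an assumption modeled on the template's `_apple_50` suffix.

```python
import shlex

# (short name, TASK, SUBTASK, OBJECT) from the table above.
EVAL_COMBOS = [
    ("pick_apple", "set_table", "pick", "013_apple"),
    ("pick_bowl", "set_table", "pick", "024_bowl"),
    ("place_apple", "set_table", "place", "013_apple"),
    ("place_bowl", "set_table", "place", "024_bowl"),
    ("open_fridge", "set_table", "open", "fridge"),
    ("open_kc", "set_table", "open", "kitchen_counter"),
    ("close_kc", "set_table", "close", "kitchen_counter"),
]


def eval_command(proj: str, step: int, combo: tuple[str, str, str, str],
                 num_traj: int = 50, cuda: int = 0) -> str:
    """One eval_acdit.sh invocation in the same VAR=value form as the template."""
    short, task, subtask, obj = combo
    ckpt_root = f"{proj}/checkpoints"
    env = {
        "PRETRAINED": f"{ckpt_root}/AC-DiT-mshab-all7_baseline/checkpoint-{step}"
                      "/pytorch_model/mp_rank_00_model_states.pt",
        "MH_CKPT": f"{ckpt_root}/DiT-mshab-base-only/checkpoint-30000"
                   "/pytorch_model/mp_rank_00_model_states.pt",
        "DATASET_DIR": f"{proj}/third_party/mshab/mshab_data/gen_data_save_trajectories"
                       f"/{task}/{subtask}/train/{obj}",
        "TASK": task, "SUBTASK": subtask, "OBJECT": obj,
        "NUM_TRAJ": str(num_traj), "CUDA": str(cuda),
        # Per-combo result directory; this naming is an assumption.
        "RESULT_DIR": f"./eval_results_all7_{step}_{short}_{num_traj}",
    }
    assigns = " ".join(f"{k}={shlex.quote(v)}" for k, v in env.items())
    return f"{assigns} bash scripts/eval_acdit.sh"
```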
### Inside eval_acdit.sh
It calls `python -m scripts.eval_mshab` directly and runs 50 trials in the ManiSkill env.
Each trial times out after 200 steps.
Results are written to `$RESULT_DIR/<combo>/`:
- `trial_<N>_{success|failure}.mp4`
- `success_rate.txt`
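The format of `success_rate.txt` is not described in these notes, so a cross-check that depends only on the video filenames can be handy. A sketch, assuming exactly one `trial_<N>_success.mp4` or `trial_<N>_failure.mp4` per trial as listed above; the function name is hypothetical.

```python
import re
from pathlib import Path

_TRIAL = re.compile(r"^trial_(\d+)_(success|failure)\.mp4$")


def success_rate_from_videos(combo_dir: Path) -> tuple[int, int, float]:
    """Return (successes, trials, rate in %) counted from trial_<N>_{success|failure}.mp4."""
    outcomes = {}
    for f in Path(combo_dir).iterdir():
        m = _TRIAL.match(f.name)
        if m:
            outcomes[int(m.group(1))] = m.group(2) == "success"
    n = len(outcomes)
    wins = sum(outcomes.values())
    return wins, n, (100.0 * wins / n if n else 0.0)
```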
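### Parallel evaluation tips
There are 2 A6000s locally; each eval uses ~10 GB, so each card can run 2 combos at once (different ckpts, or different combos of the same ckpt).
Each eval takes roughly 25-90 min, depending on how heavily the GPU is shared.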
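With 2 GPUs and 2 concurrent evals per GPU, the seven combos fit in two waves. A hypothetical round-robin planner that only assigns `CUDA` ids; launching the jobs and waiting for each wave is left to the shell.

```python
def plan_waves(jobs: list[str], num_gpus: int = 2,
               per_gpu: int = 2) -> list[list[tuple[str, int]]]:
    """Split jobs into waves of at most num_gpus * per_gpu, assigning CUDA ids round-robin."""
    slots = num_gpus * per_gpu
    waves = []
    for start in range(0, len(jobs), slots):
        chunk = jobs[start:start + slots]
        waves.append([(job, i % num_gpus) for i, job in enumerate(chunk)])
    return waves
```

At 25-90 min per eval, two waves take roughly 50 min to 3 h in total, assuming each wave waits for its slowest job.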
---
## Reproduction steps
### Training
```bash
# On HPC3
ssh HPC3_jhe724
cd /data/user/jhe724/workspace/AC-DiT
# Submit training (8 GPUs × ~16 h)
sbatch scripts/train_all7.sbatch
# Monitor
squeue -u jhe724
tail -f logs/train_s2_all7_baseline.log
```
### Evaluation
```bash
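# On the local machine LFT-W02
cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
# rsync the ckpt (create the target directory first; rsync does not create missing parent directories)
mkdir -p ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model
rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \
./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/
# Run the evals (one command per combo; see the template above)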
```
---
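## Lessons learned (pitfalls)
1. **Misremembered target**: early on I believed the paper's pick_apple S.R. was 90%; it is actually 33% (90% is open_fridge) → don't quote numbers from memory.
2. **Data scale**: training on only 100 traj/combo × 2 combos got stuck at 12% despite repeated tweaking; switching to 1000/combo × 7 combos immediately reached 32% (matching the paper) → **training-data scale is the key factor**.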
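3. **Overfitting**: compared with the 25k checkpoint, the 30k checkpoint regresses on 3 combos (pick_bowl, place_apple, place_bowl), even though the training loss is still going down → **stop early instead of training to the end**.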
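4. **Out-of-range actions**: in the open/close tasks the demo actions actually exceed [-1, 1] (up to 20); they are the controller's raw, unclipped output, which the env tolerates by clipping automatically. **The spikes in the loss come from these samples.** They do not affect the final success rate (the env clips at rollout time as well), but they make the training supervision unsound.
5. **Val data preparation**: running `gen_data.py` with `SPLIT=val` does not create the `instructions/` subdirectory, so during training `np.random.choice` fails with an "empty" error. Fix: symlink `instructions` from the train directory.
6. **Evaluating too early**: the bogus 12% number came from an incomplete setup (2 combos, 100 traj). Before running a reproduction, **fully align with the paper's configuration** first.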
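Lesson 4 can be quantified before training. A sketch, assuming the demo actions of a combo have already been loaded as a sequence of per-step action vectors (how they are read from the HDF5 files is not shown here): it reports the fraction of steps with any component outside [-1, 1] and the largest magnitude, and can return a clipped copy. Clipping the training targets is one possible way (an assumption, not what the repo does) to make the supervision match what the env actually executes.

```python
from typing import Sequence


def action_range_report(actions: Sequence[Sequence[float]],
                        bound: float = 1.0) -> tuple[float, float]:
    """(fraction of steps with any |a_i| > bound, max |a_i| over all steps)."""
    if not actions:
        return 0.0, 0.0
    out_of_range = sum(1 for a in actions if any(abs(x) > bound for x in a))
    max_abs = max((abs(x) for a in actions for x in a), default=0.0)
    return out_of_range / len(actions), max_abs


def clip_actions(actions: Sequence[Sequence[float]], bound: float = 1.0) -> list[list[float]]:
    """Elementwise clip to [-bound, bound], mirroring the env-side clipping."""
    return [[min(max(x, -bound), bound) for x in a] for a in actions]
```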
---
## Wandb metric names
- Training sample loss: `mshab_sample_mse`, `mshab_sample_l2err`, `overall_avg_sample_mse`, `overall_avg_sample_l2err`
- Val loss (held-out 700 traj, 100 per combo): `val/mshab_sample_mse`, `val/mshab_sample_l2err`, `val/overall_avg_sample_mse`, `val/overall_avg_sample_l2err`
- Per-step: `loss`, `lr`
EXP_NAME = `AC-DiT-mshab-all7_baseline` (wandb run name)