| # AC-DiT Reproduction Notes |
|
|
| ## Final results (all-7-combo baseline, ckpt-25000) |
|
|
| | Combo | Ours, ckpt-25000 | Ours, ckpt-30000 | Paper (100×3 runs) | |
| |---|---|---|---| |
| | pick_apple | 26.0% | **32.0%** | 33.3 ± 1.9 | |
| | pick_bowl | **42.0%** | 24.0% | 36.0 ± 6.5 | |
| | place_apple | **34.0%** | 18.0% | 33.3 ± 9.4 | |
| | place_bowl | **48.0%** | 24.0% | 17.3 ± 6.8 | |
| | open_fridge | 92.0% | **96.0%** | 90.7 ± 5.0 | |
| | open_kc | 74.0% | **78.0%** | 81.3 ± 6.8 | |
| | close_kc | **100.0%** | **100.0%** | 97.3 ± 1.9 | |
| | **Mean S.R.** | **59.4%** | 53.1% | **55.6%** | |
| |
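| **ckpt-25000 mean 59.4% > paper 55.6%.** 4 combos beat the paper, 1 matches (place_apple), and 2 fall below (pick_apple and open_kc, each 7.3 points under the paper mean; the paper reports ±1.9 and ±6.8 for them, so both are outside that band). |
| ckpt-30000 shows late-stage overfitting: pick_bowl, place_apple, and place_bowl drop by 18, 16, and 24 points respectively. |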
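
As a quick sanity check on the table, the sketch below recomputes the mean success rates, the above/match/below split against the paper, and the 25k → 30k drops. The 1-point tolerance used to call a combo a "match" is an assumption made for illustration; the paper does not define one.

```python
# Per-combo success rates (%) copied from the results table.
ours_25k = {"pick_apple": 26.0, "pick_bowl": 42.0, "place_apple": 34.0,
            "place_bowl": 48.0, "open_fridge": 92.0, "open_kc": 74.0,
            "close_kc": 100.0}
ours_30k = {"pick_apple": 32.0, "pick_bowl": 24.0, "place_apple": 18.0,
            "place_bowl": 24.0, "open_fridge": 96.0, "open_kc": 78.0,
            "close_kc": 100.0}
paper = {"pick_apple": 33.3, "pick_bowl": 36.0, "place_apple": 33.3,
         "place_bowl": 17.3, "open_fridge": 90.7, "open_kc": 81.3,
         "close_kc": 97.3}

MATCH_TOL = 1.0  # assumption: within 1 point of the paper mean counts as a match


def mean(rates):
    """Unweighted mean over the combos."""
    return sum(rates.values()) / len(rates)


def compare(ours, ref, tol=MATCH_TOL):
    """Bucket each combo as above / match / below the reference mean."""
    buckets = {"above": [], "match": [], "below": []}
    for combo, rate in ours.items():
        diff = rate - ref[combo]
        if abs(diff) <= tol:
            buckets["match"].append(combo)
        elif diff > 0:
            buckets["above"].append(combo)
        else:
            buckets["below"].append(combo)
    return buckets


summary = compare(ours_25k, paper)
# Combos that got worse between ckpt-25000 and ckpt-30000, with the drop in points.
drops = {c: ours_25k[c] - ours_30k[c] for c in ours_25k if ours_30k[c] < ours_25k[c]}

if __name__ == "__main__":
    print(round(mean(ours_25k), 1), round(mean(ours_30k), 1), round(mean(paper), 1))
    print(summary)
    print(drops)
```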
|
|
| **Recommended weights for release: ckpt-25000.** |
|
|
| --- |
|
|
| ## Checkpoint paths |
|
|
| ### HPC3 |
| ``` |
| /data/user/jhe724/workspace/AC-DiT/checkpoints/ |
| ├── DiT-mshab-base-only/checkpoint-30000/ # Stage 1 mobility head (7 combos pooled, --base_only) |
| │ └── pytorch_model/mp_rank_00_model_states.pt |
| └── AC-DiT-mshab-all7_baseline/ # Stage 2 main model |
|     ├── checkpoint-25000/   ← best |
|     ├── checkpoint-27500/ |
|     ├── checkpoint-30000/   (slightly overfit) |
|     └── checkpoint-{2500..22500}/   (intermediate ckpts) |
| ``` |
|
|
| ### Local (LFT-W02) |
| ``` |
| /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/ |
| ├── DiT-mshab-base-only/checkpoint-30000/ |
| └── AC-DiT-mshab-all7_baseline/ |
|     ├── checkpoint-25000/   ← use this one for evaluation |
|     ├── checkpoint-15000/ |
|     ├── checkpoint-20000/ |
|     └── checkpoint-30000/ |
| ``` |
|
|
| --- |
|
|
| ## Training scripts (HPC3, 8×H100, ~16h) |
|
|
| ### 1. sbatch wrapper |
| `/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch` |
| ```bash |
| #!/bin/bash |
| #SBATCH -p acd_u |
| #SBATCH --nodes=1 |
| #SBATCH --ntasks-per-node=1 |
| #SBATCH --cpus-per-task=64 |
| #SBATCH --mem=384G |
| #SBATCH --gres=gpu:8 |
| #SBATCH -t 7-00:00:00 |
| #SBATCH -J all7_base |
| PROJECT=/data/user/jhe724/workspace/AC-DiT |
| echo "=== sbatch host=$(hostname) job=$SLURM_JOB_ID time=$(date) ===" |
| cd $PROJECT |
| exec bash scripts/launch_s2_all7_baseline.sh |
| ``` |
|
|
| Submit with `sbatch scripts/train_all7.sbatch`. |
|
|
| ### 2. launch script |
| `scripts/launch_s2_all7_baseline.sh` does three things: |
| 1. Restores `task_subtask_obj` in `data/hdf5_mshab_dataset.py` to all 7 combos |
| 2. Patches `scripts/finetune_mshab_acdit.py` (restored from its `.orig` copy, then edited with `sed`): |
|    - `--data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories` (read directly from Lustre, avoiding a 2T+ copy to /tmp) |
|    - `--num_episode_per_task=1000` (all 1000 traj/combo = 7000 total) |
|    - `--train_batch_size=20` (× 8 GPUs = 160 effective) |
|    - `--sample_batch_size=8` (used for val eval) |
|    - `--max_train_steps=30000` |
|    - `--sample_period=200` (every 200 steps: sample + val eval, logged to wandb) |
|    - `--mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt` |
|    - EXP_NAME = `AC-DiT-mshab-all7_baseline` |
| 3. Runs `python -m scripts.finetune_mshab_acdit` |
|
|
| ### 3. finetune driver |
| The main `accelerate_command` in `scripts/finetune_mshab_acdit.py`: |
| ```python |
| 'accelerate', 'launch', |
| '--main_process_port=29905', |
| 'main.py', |
| '--deepspeed=./configs/zero2.json', |
| '--method_name=AC-DiT', |
| '--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b', |
| '--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384', |
| '--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384', |
| '--mobility_head_ckpt_path=...', |
| '--config_path=configs/config.yaml', |
| '--hdf5_dataset_name=mshab', |
| '--data_dir=...', |
| '--in_context_cond_dim=18', |
| '--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline', |
| '--resume_from_checkpoint=latest', |
| '--train_batch_size=20', |
| '--sample_batch_size=8', |
| '--gradient_accumulation_steps=1', |
| '--max_train_steps=30000', |
| '--checkpointing_period=2500', |
| '--sample_period=200', |
| '--checkpoints_total_limit=40', |
| '--lr_scheduler=constant', |
| '--learning_rate=1e-5', |
| '--mixed_precision=bf16', |
| '--dataloader_num_workers=8', |
| '--image_aug', |
| '--dataset_type=finetune', |
| '--state_noise_snr=40', |
| '--load_from_hdf5', |
| '--precomp_lang_embed', |
| '--num_episode_per_task=1000', |
| '--report_to=wandb', |
| ``` |
| |
| ### 4. config.yaml |
| `configs/config.yaml` is left at the baseline values: |
| - `action_chunk_size: 2` |
| - `img_history_size: 2` |
| - `num_cameras: 3` |
| - `state_dim: 128` |
|
|
| --- |
|
|
| ## Code changes (minimal; only adds val eval support) |
|
|
| Diff against commit `90ad00a` (init: release codes): |
|
|
| ### `data/hdf5_mshab_dataset.py` |
| - Add constructor parameters `split: Literal["train","val"]="train"` and `num_episode_per_task: int=100` |
| - Replace the hard-coded `'train'` path component with `self.split` (3 places) |
| - Replace the hard-coded `100` with `self._cfg_num_episode_per_task` |
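
A minimal sketch of what these changes amount to, assuming the `<task>/<subtask>/<split>/<object>` directory layout seen in the eval `DATASET_DIR`. The class `ToyMSHABDataset`, its `combo_dir` and `select_episodes` helpers, and the root path are hypothetical stand-ins, not the repository's actual `HDF5MSHABDataset` code.

```python
from pathlib import PurePosixPath
from typing import List, Literal


class ToyMSHABDataset:
    """Hypothetical stand-in illustrating the two new constructor parameters."""

    def __init__(
        self,
        root: str,
        split: Literal["train", "val"] = "train",
        num_episode_per_task: int = 100,
    ) -> None:
        if split not in ("train", "val"):
            raise ValueError(f"unknown split: {split!r}")
        self.root = PurePosixPath(root)
        self.split = split
        self._cfg_num_episode_per_task = num_episode_per_task

    def combo_dir(self, task: str, subtask: str, obj: str) -> PurePosixPath:
        # self.split replaces what used to be a hard-coded 'train' component.
        return self.root / task / subtask / self.split / obj

    def select_episodes(self, episode_files: List[str]) -> List[str]:
        # The configured cap replaces what used to be a hard-coded 100.
        return sorted(episode_files)[: self._cfg_num_episode_per_task]


if __name__ == "__main__":
    val_ds = ToyMSHABDataset("/data/root", split="val", num_episode_per_task=100)
    print(val_ds.combo_dir("set_table", "pick", "013_apple"))
```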
|
|
| ### `train/dataset.py` |
| - `VLAConsumerDataset.__init__` gains `split="train"` and `num_episode_per_task=100` parameters |
| - Both are passed through to `HDF5MSHABDataset` |
|
|
| ### `train/train.py` |
| - Create `val_dataset` (split="val", num_episode_per_task=100) |
| - Create `val_dataloader` (prepared separately via `accelerator.prepare_data_loader`) |
| - When `sample_period` triggers, run val in addition to the original `sample_loss_for_log`: |
| ```python |
| val_loss_for_log = log_sample_res(..., val_dataloader, ...) |
| val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()} |
| accelerator.log(val_loss_for_log, step=global_step) |
| ``` |
| |
| ### `main.py` |
| - Add a `--num_episode_per_task` argparse argument |
|
|
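| ### Data preparation (one-time) |
| 1. Generate 100 val traj/combo (using `gen_val_combo.sbatch` with SPLIT=val, run over the 7 combos) |
| 2. In each val directory, create an `instructions` symlink pointing at the corresponding train directory (the language embeddings are shared): |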
| ```bash |
| for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do |
| ln -s $ROOT/set_table/$(dirname $sub_obj)/train/$(basename $sub_obj)/instructions $ROOT/set_table/$(dirname $sub_obj)/val/$(basename $sub_obj)/instructions |
| done |
| ``` |
|
|
| --- |
|
|
| ## Evaluation scripts (local, RTX A6000) |
|
|
| ### eval command template |
| ```bash |
| export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH |
| PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT |
| CKPT_STEP=25000 |
| |
| PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \ |
| MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \ |
| DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \ |
| TASK=set_table SUBTASK=pick OBJECT=013_apple \ |
| NUM_TRAJ=50 CUDA=0 \ |
| RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \ |
| bash scripts/eval_acdit.sh |
| ``` |
|
|
| ### DATASET_DIR / TASK SUBTASK OBJECT for the 7 combos |
| | short | TASK | SUBTASK | OBJECT | |
| |---|---|---|---| |
| | pick_apple | set_table | pick | 013_apple | |
| | pick_bowl | set_table | pick | 024_bowl | |
| | place_apple | set_table | place | 013_apple | |
| | place_bowl | set_table | place | 024_bowl | |
| | open_fridge | set_table | open | fridge | |
| | open_kc | set_table | open | kitchen_counter | |
| | close_kc | set_table | close | kitchen_counter | |
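
The table does not list `DATASET_DIR` explicitly; following the command template, it can be derived as `<data root>/<TASK>/<SUBTASK>/train/<OBJECT>`. The sketch below builds the per-combo variables under that assumption. `eval_env` is a hypothetical helper, and naming `RESULT_DIR` after the combo's short name is also an assumption (the template uses `apple`). `PRETRAINED` and `MH_CKPT` are unchanged across combos and still come from the template.

```python
from typing import Dict, Tuple

# short name -> (TASK, SUBTASK, OBJECT), copied from the table.
COMBOS: Dict[str, Tuple[str, str, str]] = {
    "pick_apple": ("set_table", "pick", "013_apple"),
    "pick_bowl": ("set_table", "pick", "024_bowl"),
    "place_apple": ("set_table", "place", "013_apple"),
    "place_bowl": ("set_table", "place", "024_bowl"),
    "open_fridge": ("set_table", "open", "fridge"),
    "open_kc": ("set_table", "open", "kitchen_counter"),
    "close_kc": ("set_table", "close", "kitchen_counter"),
}


def eval_env(short: str, data_root: str, ckpt_step: int = 25000,
             num_traj: int = 50, cuda: int = 0) -> Dict[str, str]:
    """Build the combo-specific environment variables for eval_acdit.sh."""
    task, subtask, obj = COMBOS[short]
    return {
        "DATASET_DIR": f"{data_root}/{task}/{subtask}/train/{obj}",
        "TASK": task,
        "SUBTASK": subtask,
        "OBJECT": obj,
        "NUM_TRAJ": str(num_traj),
        "CUDA": str(cuda),
        # Assumption: result directory keyed by the combo's short name.
        "RESULT_DIR": f"./eval_results_all7_{ckpt_step}_{short}_{num_traj}",
    }


if __name__ == "__main__":
    root = "third_party/mshab/mshab_data/gen_data_save_trajectories"
    for name in COMBOS:
        print(" ".join(f"{k}={v}" for k, v in eval_env(name, root).items()))
```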
| |
| ### Inside eval_acdit.sh |
| It calls `python -m scripts.eval_mshab` directly and runs 50 real trials in the ManiSkill env. |
| Each trial times out after 200 steps. |
| Results are written to `$RESULT_DIR/<combo>/`: |
| - `trial_<N>_{success|failure}.mp4` |
| - `success_rate.txt` |
|
|
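| ### Parallel evaluation tips |
| The local machine has 2 A6000s. At roughly 10GB per eval, each card can run 2 combos concurrently (different ckpts, or the same ckpt on different combos). |
| Each eval takes about 25-90 min, depending on how heavily the GPU is shared. |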
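
One way to lay out the 7 combos on this machine, assuming 2 GPUs with 2 concurrent evals each, is to run them in waves of 4. `plan_waves` is a hypothetical helper; it only computes which `CUDA` index each combo would get and does not launch anything.

```python
from typing import Dict, List

COMBOS = ["pick_apple", "pick_bowl", "place_apple", "place_bowl",
          "open_fridge", "open_kc", "close_kc"]


def plan_waves(combos: List[str], num_gpus: int = 2,
               per_gpu: int = 2) -> List[Dict[str, int]]:
    """Split combos into waves; within a wave, map each combo to a GPU index."""
    slots = num_gpus * per_gpu
    waves = []
    for start in range(0, len(combos), slots):
        wave = combos[start:start + slots]
        # Fill GPU 0 first, then GPU 1, with per_gpu combos on each.
        waves.append({combo: i // per_gpu for i, combo in enumerate(wave)})
    return waves


if __name__ == "__main__":
    for n, wave in enumerate(plan_waves(COMBOS), start=1):
        for combo, gpu in wave.items():
            print(f"wave {n}: CUDA={gpu} {combo}")
```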
|
|
| --- |
|
|
| ## Reproduction steps |
|
|
| ### Training |
| ```bash |
| # On HPC3 |
| ssh HPC3_jhe724 |
| cd /data/user/jhe724/workspace/AC-DiT |
| |
| # Submit training (8 GPUs × ~16h) |
| sbatch scripts/train_all7.sbatch |
| |
| # Monitor |
| squeue -u jhe724 |
| tail -f logs/train_s2_all7_baseline.log |
| ``` |
|
|
| ### Evaluation |
| ```bash |
| # On the local LFT-W02 machine |
| cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT |
| |
| # rsync ckpt |
| rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \ |
| ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/ |
| |
| # Run evaluation (one command per combo; see the eval command template) |
| ``` |
|
|
| --- |
|
|
| ## Lessons learned (pitfalls to avoid) |
|
|
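| 1. **Misremembering**: early on I believed the paper's pick_apple was 90%, but it is actually 33% (90% is open_fridge) → do not quote numbers from memory |
| 2. **Data volume**: training on only 100 traj/combo × 2 combos kept yielding 12% despite repeated attempts; switching to 1000 traj/combo × 7 combos immediately reached 32% (matching the paper) → **training data scale is the key factor** |
| 3. **Overfitting**: the 30k-step ckpt regresses on 3 combos relative to 25k (loss keeps falling while eval degrades) → **stop early rather than training to the end** |
| 4. **Action anomaly**: demo actions for the open/close tasks actually exceed [-1, 1] (max 20). They are raw controller outputs that were never clipped, and the env clips them automatically, so they still work. **The spikes in the loss come from these samples.** This does not affect the final success rate (the env clips at eval time too), but it makes the training supervision unsound |
| 5. **val data preparation**: generating with `gen_data.py` and SPLIT=val does not create the `instructions/` subdirectory → training fails with an "empty" error from `np.random.choice`. Fix: symlink `instructions` from the train directory |
| 6. **Premature evaluation**: the earlier, meaningless 12% figure came from an incomplete setup (2 combos, 100 traj). Before running a reproduction, first **align fully with the paper's configuration** |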
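
For lesson 4, a small check like the following can quantify how many demo action values fall outside [-1, 1] and clip them before they are used as training targets. This is a sketch on plain Python lists with made-up values; the helper names are hypothetical, and whether clipping the targets actually helps training was not tested in this reproduction.

```python
from typing import List, Sequence, Tuple


def out_of_range_stats(actions: Sequence[Sequence[float]],
                       low: float = -1.0, high: float = 1.0) -> Tuple[float, float]:
    """Return (fraction of values outside [low, high], max absolute value)."""
    values = [v for step in actions for v in step]
    if not values:
        return 0.0, 0.0
    bad = sum(1 for v in values if v < low or v > high)
    return bad / len(values), max(abs(v) for v in values)


def clip_actions(actions: Sequence[Sequence[float]],
                 low: float = -1.0, high: float = 1.0) -> List[List[float]]:
    """Clip every action dimension into [low, high], as the env does at runtime."""
    return [[min(max(v, low), high) for v in step] for step in actions]


# Made-up demo: 3 timesteps, 2 action dims, with two out-of-range values.
demo = [[0.2, -0.5], [20.0, 0.9], [-3.0, 1.0]]
frac, max_abs = out_of_range_stats(demo)
clipped = clip_actions(demo)

if __name__ == "__main__":
    print(f"out-of-range fraction: {frac:.3f}, max |a|: {max_abs}")
    print(clipped)
```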
|
|
| --- |
|
|
| ## Wandb metric naming |
| - Training sample loss: `mshab_sample_mse`, `mshab_sample_l2err`, `overall_avg_sample_mse`, `overall_avg_sample_l2err` |
| - val loss (700 held-out traj): `val/mshab_sample_mse`, `val/mshab_sample_l2err`, `val/overall_avg_sample_mse`, `val/overall_avg_sample_l2err` |
| - per-step: `loss`, `lr` |
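
The `val/` names are the training sample-loss keys with a prefix added, mirroring the dict comprehension in `train/train.py`. A standalone sketch of that renaming follows; `add_prefix` is a hypothetical helper and the metric values are placeholders, not measured results.

```python
from typing import Dict


def add_prefix(metrics: Dict[str, float], prefix: str = "val") -> Dict[str, float]:
    """Return a copy of metrics with every key renamed to '<prefix>/<key>'."""
    return {f"{prefix}/{key}": value for key, value in metrics.items()}


# Placeholder values only; the keys are the training sample-loss names.
sample_metrics = {
    "mshab_sample_mse": 0.0,
    "mshab_sample_l2err": 0.0,
    "overall_avg_sample_mse": 0.0,
    "overall_avg_sample_l2err": 0.0,
}
val_metrics = add_prefix(sample_metrics)

if __name__ == "__main__":
    for name in val_metrics:
        print(name)
```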
|
|
| EXP_NAME = `AC-DiT-mshab-all7_baseline` (wandb run name) |
|
|