AC-DiT-MSHab-Reproduction / REPRODUCTION.md
JJho1314's picture
Upload REPRODUCTION.md with huggingface_hub
f57e782 verified

AC-DiT 复现记录

最终结果(all-7-combo baseline,ckpt-25000)

Combo 我们 ckpt-25000 我们 ckpt-30000 论文 (100×3 runs)
pick_apple 26.0% 32.0% 33.3 ± 1.9
pick_bowl 42.0% 24.0% 36.0 ± 6.5
place_apple 34.0% 18.0% 33.3 ± 9.4
place_bowl 48.0% 24.0% 17.3 ± 6.8
open_fridge 92.0% 96.0% 90.7 ± 5.0
open_kc 74.0% 78.0% 81.3 ± 6.8
close_kc 100.0% 100.0% 97.3 ± 1.9
Mean S.R. 59.4% 53.1% 55.6%

**ckpt-25000 mean 59.4% > 论文 55.6%**,4 个 combo 超过论文,1 个匹配,2 个低于(在 std 内)。 ckpt-30000 后期过拟合(pick_bowl / place_apple / place_bowl 各掉 16-24%)。

推荐发布权重:ckpt-25000。


权重路径

HPC3

/data/user/jhe724/workspace/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/         # Stage 1 mobility head (7 combos pooled, --base_only)
│   └── pytorch_model/mp_rank_00_model_states.pt
└── AC-DiT-mshab-all7_baseline/                   # Stage 2 main model
    ├── checkpoint-25000/  ← 最佳
    ├── checkpoint-27500/
    ├── checkpoint-30000/  (轻微过拟合)
    └── checkpoint-{2500..22500}/  (中间 ckpt)

本地 LFT-W02

/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/
├── DiT-mshab-base-only/checkpoint-30000/
└── AC-DiT-mshab-all7_baseline/
    ├── checkpoint-25000/  ← 评测用这个
    ├── checkpoint-15000/
    ├── checkpoint-20000/
    └── checkpoint-30000/

训练脚本(HPC3, 8×H100, ~16h)

1. sbatch wrapper

/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch

#!/bin/bash
#SBATCH -p acd_u
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=64
#SBATCH --mem=384G
#SBATCH --gres=gpu:8
#SBATCH -t 7-00:00:00
#SBATCH -J all7_base
PROJECT=/data/user/jhe724/workspace/AC-DiT
echo "=== sbatch host=$(hostname)  job=$SLURM_JOB_ID  time=$(date) ==="
cd $PROJECT
exec bash scripts/launch_s2_all7_baseline.sh

提交:sbatch scripts/train_all7.sbatch

2. launch 脚本

scripts/launch_s2_all7_baseline.sh 做三件事:

  1. data/hdf5_mshab_dataset.pytask_subtask_obj 恢复成全 7 combo
  2. patch scripts/finetune_mshab_acdit.py(用 .orig 还原后 sed 改):
    • --data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories(lustre 直读,省 2T+ /tmp copy)
    • --num_episode_per_task=1000 (全 1000 traj/combo = 7000 total)
    • --train_batch_size=20(× 8 GPU = 160 effective)
    • --sample_batch_size=8(val eval 用)
    • --max_train_steps=30000
    • --sample_period=200(每 200 步 sample + val eval 写 wandb)
    • --mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt
    • EXP_NAME = AC-DiT-mshab-all7_baseline
  3. python -m scripts.finetune_mshab_acdit

3. finetune driver

scripts/finetune_mshab_acdit.py 主要 accelerate_command:

'accelerate', 'launch',
'--main_process_port=29905',
'main.py',
'--deepspeed=./configs/zero2.json',
'--method_name=AC-DiT',
'--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b',
'--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
'--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
'--mobility_head_ckpt_path=...',
'--config_path=configs/config.yaml',
'--hdf5_dataset_name=mshab',
'--data_dir=...',
'--in_context_cond_dim=18',
'--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline',
'--resume_from_checkpoint=latest',
'--train_batch_size=20',
'--sample_batch_size=8',
'--gradient_accumulation_steps=1',
'--max_train_steps=30000',
'--checkpointing_period=2500',
'--sample_period=200',
'--checkpoints_total_limit=40',
'--lr_scheduler=constant',
'--learning_rate=1e-5',
'--mixed_precision=bf16',
'--dataloader_num_workers=8',
'--image_aug',
'--dataset_type=finetune',
'--state_noise_snr=40',
'--load_from_hdf5',
'--precomp_lang_embed',
'--num_episode_per_task=1000',
'--report_to=wandb',

4. config.yaml

configs/config.yaml 保持 baseline:

  • action_chunk_size: 2
  • img_history_size: 2
  • num_cameras: 3
  • state_dim: 128

代码改动(最小化, 仅加 val eval 支持)

vs commit 90ad00a(init: release codes)的差异:

data/hdf5_mshab_dataset.py

  • split: Literal["train","val"]="train"num_episode_per_task: int=100 构造参数
  • self.split 替代硬编码 'train' 路径(3 处)
  • self._cfg_num_episode_per_task 替代硬编码 100

train/dataset.py

  • VLAConsumerDataset.__init__split="train"num_episode_per_task=100 参数
  • 透传给 HDF5MSHABDataset

train/train.py

  • 新建 val_dataset (split="val", num_episode_per_task=100)
  • 新建 val_dataloader(用 accelerator.prepare_data_loader 单独 prepare)
  • sample_period 触发时除了原 sample_loss_for_log,还跑 val:
val_loss_for_log = log_sample_res(..., val_dataloader, ...)
val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()}
accelerator.log(val_loss_for_log, step=global_step)

main.py

  • --num_episode_per_task argparse arg

数据准备(一次性)

  1. 生成 val 数据 100 traj/combo(用 gen_val_combo.sbatch,SPLIT=val 跑 7 个 combo)
  2. 给每个 val 目录建 instructions 符号链接指向对应 train 目录(语言 embed 通用):
for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do
  ln -s $ROOT/set_table/$(dirname $sub_obj)/train/$(basename $sub_obj)/instructions $ROOT/set_table/$(dirname $sub_obj)/val/$(basename $sub_obj)/instructions
done

评测脚本(本地, RTX A6000)

eval 命令模板

export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH
PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
CKPT_STEP=25000

PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \
MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \
DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \
TASK=set_table SUBTASK=pick OBJECT=013_apple \
NUM_TRAJ=50 CUDA=0 \
RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \
bash scripts/eval_acdit.sh

7 个 combo 的 DATASET_DIR / TASK SUBTASK OBJECT

short TASK SUBTASK OBJECT
pick_apple set_table pick 013_apple
pick_bowl set_table pick 024_bowl
place_apple set_table place 013_apple
place_bowl set_table place 024_bowl
open_fridge set_table open fridge
open_kc set_table open kitchen_counter
close_kc set_table close kitchen_counter

eval_acdit.sh 内部

直接调 python -m scripts.eval_mshab 跑 ManiSkill env 50 trial 实测。 每 trial 200 步 timeout。 结果写 $RESULT_DIR/<combo>/

  • trial_<N>_{success|failure}.mp4
  • success_rate.txt

并行评测建议

本地 2 张 A6000,每张 ~10GB 可同时跑 2 个 combo(不同 ckpt 或同 ckpt 不同 combo)。 每 eval 约 25-90 min(依赖 GPU 共享情况)。


复现步骤

训练

# 在 HPC3 上
ssh HPC3_jhe724
cd /data/user/jhe724/workspace/AC-DiT

# 提交训练(8 GPU × 16h)
sbatch scripts/train_all7.sbatch

# 监控
squeue -u jhe724
tail -f logs/train_s2_all7_baseline.log

评测

# 在本地 LFT-W02 上
cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT

# rsync ckpt
rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \
  ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/

# 跑评测(每 combo 一条命令,参考上面)

失败教训(避坑)

  1. 错觉:早期以为论文 pick_apple 是 90%,其实是 33%(90% 是 open_fridge)→ 别凭记忆 quote 数字
  2. 数据量:之前只用 100 traj/combo × 2 combo 训练,结果 12% 反复折腾;改成 1000/combo × 7 combo 立刻到 32%(匹配论文)→ 训练数据规模是关键
  3. 过拟合:30k 步 ckpt 比 25k 后 4 个 combo 退步(loss 还在降但 eval 衰减)→ 早停而非训满
  4. Action 异常:open/close 任务的 demo action 实际超 [-1, 1](最大 20),是 controller 未 clip 的 raw 输出,env 自动 clip 兼容了。Loss 上的尖峰来自这些样本。不影响最终成功率(因为 env 也 clip)但训练监督不科学
  5. val 数据准备:用 gen_data.py SPLIT=val 生成时不会自动建 instructions/ 子目录 → 训练时 np.random.choice 报"empty"。修:从 train 目录软链 instructions
  6. 过早评测:bullshit number 12% 之前来自不完整 setup(2 combo, 100 traj)。复现实验前先 完全对齐论文配置

Wandb metrics 命名

  • 训练 sample loss:mshab_sample_mse, mshab_sample_l2err, overall_avg_sample_mse, overall_avg_sample_l2err
  • val loss(held-out 700 traj):val/mshab_sample_mse, val/mshab_sample_l2err, val/overall_avg_sample_mse, val/overall_avg_sample_l2err
  • per-step:loss, lr

EXP_NAME = AC-DiT-mshab-all7_baseline (wandb run name)