JJho1314 commited on
Commit
f57e782
·
verified ·
1 Parent(s): cdacf94

Upload REPRODUCTION.md with huggingface_hub

Browse files
Files changed (1) hide show
  1. REPRODUCTION.md +257 -0
REPRODUCTION.md ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AC-DiT 复现记录
2
+
3
+ ## 最终结果(all-7-combo baseline,ckpt-25000)
4
+
5
+ | Combo | 我们 ckpt-25000 | 我们 ckpt-30000 | 论文 (100×3 runs) |
6
+ |---|---|---|---|
7
+ | pick_apple | 26.0% | **32.0%** | 33.3 ± 1.9 |
8
+ | pick_bowl | **42.0%** | 24.0% | 36.0 ± 6.5 |
9
+ | place_apple | **34.0%** | 18.0% | 33.3 ± 9.4 |
10
+ | place_bowl | **48.0%** | 24.0% | 17.3 ± 6.8 |
11
+ | open_fridge | 92.0% | **96.0%** | 90.7 ± 5.0 |
12
+ | open_kc | 74.0% | **78.0%** | 81.3 ± 6.8 |
13
+ | close_kc | **100.0%** | **100.0%** | 97.3 ± 1.9 |
14
+ | **Mean S.R.** | **59.4%** | 53.1% | **55.6%** |
15
+
16
+ **ckpt-25000 mean 59.4% > 论文 55.6%**,4 个 combo 超过论文,1 个匹配,2 个低于(在 std 内)。
17
+ ckpt-30000 后期过拟合(pick_bowl / place_apple / place_bowl 各掉 16-24%)。
18
+
19
+ **推荐发布权重:ckpt-25000。**
20
+
21
+ ---
22
+
23
+ ## 权重路径
24
+
25
+ ### HPC3
26
+ ```
27
+ /data/user/jhe724/workspace/AC-DiT/checkpoints/
28
+ ├── DiT-mshab-base-only/checkpoint-30000/ # Stage 1 mobility head (7 combos pooled, --base_only)
29
+ │ └── pytorch_model/mp_rank_00_model_states.pt
30
+ └── AC-DiT-mshab-all7_baseline/ # Stage 2 main model
31
+ ├── checkpoint-25000/ ← 最佳
32
+ ├── checkpoint-27500/
33
+ ├── checkpoint-30000/ (轻微过拟合)
34
+ └── checkpoint-{2500..22500}/ (中间 ckpt)
35
+ ```
36
+
37
+ ### 本地 LFT-W02
38
+ ```
39
+ /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT/checkpoints/
40
+ ├── DiT-mshab-base-only/checkpoint-30000/
41
+ └── AC-DiT-mshab-all7_baseline/
42
+ ├── checkpoint-25000/ ← 评测用这个
43
+ ├── checkpoint-15000/
44
+ ├── checkpoint-20000/
45
+ └── checkpoint-30000/
46
+ ```
47
+
48
+ ---
49
+
50
+ ## 训练脚本(HPC3, 8×H100, ~16h)
51
+
52
+ ### 1. sbatch wrapper
53
+ `/data/user/jhe724/workspace/AC-DiT/scripts/train_all7.sbatch`
54
+ ```bash
55
+ #!/bin/bash
56
+ #SBATCH -p acd_u
57
+ #SBATCH --nodes=1
58
+ #SBATCH --ntasks-per-node=1
59
+ #SBATCH --cpus-per-task=64
60
+ #SBATCH --mem=384G
61
+ #SBATCH --gres=gpu:8
62
+ #SBATCH -t 7-00:00:00
63
+ #SBATCH -J all7_base
64
+ PROJECT=/data/user/jhe724/workspace/AC-DiT
65
+ echo "=== sbatch host=$(hostname) job=$SLURM_JOB_ID time=$(date) ==="
66
+ cd $PROJECT
67
+ exec bash scripts/launch_s2_all7_baseline.sh
68
+ ```
69
+
70
+ 提交:`sbatch scripts/train_all7.sbatch`
71
+
72
+ ### 2. launch 脚本
73
+ `scripts/launch_s2_all7_baseline.sh` 做三件事:
74
+ 1. 把 `data/hdf5_mshab_dataset.py` 的 `task_subtask_obj` 恢复成全 7 combo
75
+ 2. patch `scripts/finetune_mshab_acdit.py`(用 .orig 还原后 sed 改):
76
+ - `--data_dir=$PROJECT/third_party/mshab/mshab_data/gen_data_save_trajectories`(lustre 直读,省 2T+ /tmp copy)
77
+ - `--num_episode_per_task=1000` (全 1000 traj/combo = 7000 total)
78
+ - `--train_batch_size=20`(× 8 GPU = 160 effective)
79
+ - `--sample_batch_size=8`(val eval 用)
80
+ - `--max_train_steps=30000`
81
+ - `--sample_period=200`(每 200 步 sample + val eval 写 wandb)
82
+ - `--mobility_head_ckpt_path=$PROJECT/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt`
83
+ - EXP_NAME = `AC-DiT-mshab-all7_baseline`
84
+ 3. 跑 `python -m scripts.finetune_mshab_acdit`
85
+
86
+ ### 3. finetune driver
87
+ `scripts/finetune_mshab_acdit.py` 主要 accelerate_command:
88
+ ```python
89
+ 'accelerate', 'launch',
90
+ '--main_process_port=29905',
91
+ 'main.py',
92
+ '--deepspeed=./configs/zero2.json',
93
+ '--method_name=AC-DiT',
94
+ '--pretrained_model_name_or_path=robotics-diffusion-transformer/rdt-1b',
95
+ '--pretrained_text_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
96
+ '--pretrained_vision_encoder_name_or_path=/data/.../weights/siglip-so400m-patch14-384',
97
+ '--mobility_head_ckpt_path=...',
98
+ '--config_path=configs/config.yaml',
99
+ '--hdf5_dataset_name=mshab',
100
+ '--data_dir=...',
101
+ '--in_context_cond_dim=18',
102
+ '--output_dir=./checkpoints/AC-DiT-mshab-all7_baseline',
103
+ '--resume_from_checkpoint=latest',
104
+ '--train_batch_size=20',
105
+ '--sample_batch_size=8',
106
+ '--gradient_accumulation_steps=1',
107
+ '--max_train_steps=30000',
108
+ '--checkpointing_period=2500',
109
+ '--sample_period=200',
110
+ '--checkpoints_total_limit=40',
111
+ '--lr_scheduler=constant',
112
+ '--learning_rate=1e-5',
113
+ '--mixed_precision=bf16',
114
+ '--dataloader_num_workers=8',
115
+ '--image_aug',
116
+ '--dataset_type=finetune',
117
+ '--state_noise_snr=40',
118
+ '--load_from_hdf5',
119
+ '--precomp_lang_embed',
120
+ '--num_episode_per_task=1000',
121
+ '--report_to=wandb',
122
+ ```
123
+
124
+ ### 4. config.yaml
125
+ `configs/config.yaml` 保持 baseline:
126
+ - `action_chunk_size: 2`
127
+ - `img_history_size: 2`
128
+ - `num_cameras: 3`
129
+ - `state_dim: 128`
130
+
131
+ ---
132
+
133
+ ## 代码改动(最小化, 仅加 val eval 支持)
134
+
135
+ vs commit `90ad00a`(init: release codes)的差异:
136
+
137
+ ### `data/hdf5_mshab_dataset.py`
138
+ - 加 `split: Literal["train","val"]="train"` 和 `num_episode_per_task: int=100` 构造参数
139
+ - 用 `self.split` 替代硬编码 `'train'` 路径(3 处)
140
+ - 用 `self._cfg_num_episode_per_task` 替代硬编码 `100`
141
+
142
+ ### `train/dataset.py`
143
+ - `VLAConsumerDataset.__init__` 加 `split="train"` 和 `num_episode_per_task=100` 参数
144
+ - 透传给 `HDF5MSHABDataset`
145
+
146
+ ### `train/train.py`
147
+ - 新建 `val_dataset` (split="val", num_episode_per_task=100)
148
+ - 新建 `val_dataloader`(用 `accelerator.prepare_data_loader` 单独 prepare)
149
+ - 在 `sample_period` 触发时除了原 sample_loss_for_log,还跑 val:
150
+ ```python
151
+ val_loss_for_log = log_sample_res(..., val_dataloader, ...)
152
+ val_loss_for_log = {f"val/{k}": v for k, v in val_loss_for_log.items()}
153
+ accelerator.log(val_loss_for_log, step=global_step)
154
+ ```
155
+
156
+ ### `main.py`
157
+ - 加 `--num_episode_per_task` argparse arg
158
+
159
+ ### 数据准备(一次性)
160
+ 1. 生成 val 数据 100 traj/combo(用 `gen_val_combo.sbatch`,SPLIT=val 跑 7 个 combo)
161
+ 2. 给每个 val 目录建 `instructions` 符号链接指向对应 train 目录(语言 embed 通用):
162
+ ```bash
163
+ for sub_obj in "pick/013_apple" "pick/024_bowl" "place/013_apple" "place/024_bowl" "open/fridge" "open/kitchen_counter" "close/kitchen_counter"; do
164
+ ln -s $ROOT/set_table/$(dirname $sub_obj)/train/$(basename $sub_obj)/instructions $ROOT/set_table/$(dirname $sub_obj)/val/$(basename $sub_obj)/instructions
165
+ done
166
+ ```
167
+
168
+ ---
169
+
170
+ ## 评测脚本(本地, RTX A6000)
171
+
172
+ ### eval 命令模板
173
+ ```bash
174
+ export PATH=/data/LFT-W02_data/.conda/envs/acdit/bin:$PATH
175
+ PROJ=/data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
176
+ CKPT_STEP=25000
177
+
178
+ PRETRAINED=$PROJ/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-${CKPT_STEP}/pytorch_model/mp_rank_00_model_states.pt \
179
+ MH_CKPT=$PROJ/checkpoints/DiT-mshab-base-only/checkpoint-30000/pytorch_model/mp_rank_00_model_states.pt \
180
+ DATASET_DIR=$PROJ/third_party/mshab/mshab_data/gen_data_save_trajectories/set_table/pick/train/013_apple \
181
+ TASK=set_table SUBTASK=pick OBJECT=013_apple \
182
+ NUM_TRAJ=50 CUDA=0 \
183
+ RESULT_DIR=./eval_results_all7_${CKPT_STEP}_apple_50 \
184
+ bash scripts/eval_acdit.sh
185
+ ```
186
+
187
+ ### 7 个 combo 的 DATASET_DIR / TASK SUBTASK OBJECT
188
+ | short | TASK | SUBTASK | OBJECT |
189
+ |---|---|---|---|
190
+ | pick_apple | set_table | pick | 013_apple |
191
+ | pick_bowl | set_table | pick | 024_bowl |
192
+ | place_apple | set_table | place | 013_apple |
193
+ | place_bowl | set_table | place | 024_bowl |
194
+ | open_fridge | set_table | open | fridge |
195
+ | open_kc | set_table | open | kitchen_counter |
196
+ | close_kc | set_table | close | kitchen_counter |
197
+
198
+ ### eval_acdit.sh 内部
199
+ 直接调 `python -m scripts.eval_mshab` 跑 ManiSkill env 50 trial 实测。
200
+ 每 trial 200 步 timeout。
201
+ 结果写 `$RESULT_DIR/<combo>/`:
202
+ - `trial_<N>_{success|failure}.mp4`
203
+ - `success_rate.txt`
204
+
205
+ ### 并行评测建议
206
+ 本地 2 张 A6000,每张 ~10GB 可同时跑 2 个 combo(不同 ckpt 或同 ckpt 不同 combo)。
207
+ 每 eval 约 25-90 min(依赖 GPU 共享情况)。
208
+
209
+ ---
210
+
211
+ ## 复现步骤
212
+
213
+ ### 训练
214
+ ```bash
215
+ # 在 HPC3 上
216
+ ssh HPC3_jhe724
217
+ cd /data/user/jhe724/workspace/AC-DiT
218
+
219
+ # 提交训练(8 GPU × 16h)
220
+ sbatch scripts/train_all7.sbatch
221
+
222
+ # 监控
223
+ squeue -u jhe724
224
+ tail -f logs/train_s2_all7_baseline.log
225
+ ```
226
+
227
+ ### 评测
228
+ ```bash
229
+ # 在本地 LFT-W02 上
230
+ cd /data/LFT-W02_data/junjie/mobile_manipulation/AC-DiT
231
+
232
+ # rsync ckpt
233
+ rsync -a HPC3_jhe724:/data/user/jhe724/workspace/AC-DiT/checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/mp_rank_00_model_states.pt \
234
+ ./checkpoints/AC-DiT-mshab-all7_baseline/checkpoint-25000/pytorch_model/
235
+
236
+ # 跑评测(每 combo 一条命令,参考上面)
237
+ ```
238
+
239
+ ---
240
+
241
+ ## 失败教训(避坑)
242
+
243
+ 1. **错觉**:早期以为论文 pick_apple 是 90%,其实是 33%(90% 是 open_fridge)→ 别凭记忆 quote 数字
244
+ 2. **数据量**:之前只用 100 traj/combo × 2 combo 训练,结果 12% 反复折腾;改成 1000/combo × 7 combo 立刻到 32%(匹配论文)→ **训练数据规模是关键**
245
+ 3. **过拟合**:30k 步 ckpt 比 25k 后 4 个 combo 退步(loss 还在降但 eval 衰减)→ **早停而非训满**
246
+ 4. **Action 异常**:open/close 任务的 demo action 实际超 [-1, 1](最大 20),是 controller 未 clip 的 raw 输出,env 自动 clip 兼容了。**Loss 上的尖峰来自这些样本**。不影响最终成功率(因为 env 也 clip)但训练监督不科学
247
+ 5. **val 数据准备**:用 `gen_data.py` SPLIT=val 生成时不会自动建 `instructions/` 子目录 → 训练时 `np.random.choice` 报"empty"。修:从 train 目录软链 `instructions`
248
+ 6. **过早评测**:bullshit number 12% 之前来自不完整 setup(2 combo, 100 traj)。复现实验前先 **完全对齐论文配置**
249
+
250
+ ---
251
+
252
+ ## Wandb metrics 命名
253
+ - 训练 sample loss:`mshab_sample_mse`, `mshab_sample_l2err`, `overall_avg_sample_mse`, `overall_avg_sample_l2err`
254
+ - val loss(held-out 700 traj):`val/mshab_sample_mse`, `val/mshab_sample_l2err`, `val/overall_avg_sample_mse`, `val/overall_avg_sample_l2err`
255
+ - per-step:`loss`, `lr`
256
+
257
+ EXP_NAME = `AC-DiT-mshab-all7_baseline` (wandb run name)