UV-DOC / run_overfit_train_infer_consistency.sh
zhaxie's picture
Add files using upload-large-folder tool
fcd99cd verified
#!/usr/bin/env bash
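# Single-sample overfit + train/inference data-pipeline consistency check.
#
# 1) Quick assertion: train.py and verify_ckpt_val_pipeline.py produce identical
#    UVDocDataset tensors.
# 2) Optional: single-sample overfit training (same hyperparameters as
#    run_overfit_official_uvdoc.sh); skipped when PREPROCESS_ONLY=1.
# 3) Run verify_ckpt_val_pipeline.py with the same preprocessing; its mean_mse
#    should match the Val MSE logged by training for that epoch.
#
# Usage:
#   PREPROCESS_ONLY=1 ./run_overfit_train_infer_consistency.sh   # step 1 only
#   ./run_overfit_train_infer_consistency.sh                     # steps 1-3
#
# Environment overrides: PYTHON, UV_DOC_ROOT, LOGDIR, CKPT_GLOB, NUM_WORKERS, DEVICE.
#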
# Strict mode: abort on command failure, on use of an unset variable, and when
# any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# ROOT is the absolute path of the directory that contains this script.
script_dir="$(dirname "${BASH_SOURCE[0]}")"
ROOT="$(cd "$script_dir" && pwd)"

# Python interpreter; override with PYTHON.
PY="${PYTHON:-python3}"
# Dataset root handed to every step as --data_path_UVDoc; override with UV_DOC_ROOT.
UV="${UV_DOC_ROOT:-$ROOT/UVDoc_final}"
# Directory holding the Python entry points invoked below.
OFF="$ROOT/UVDoc_official"
# Training log/checkpoint directory; keep a caller-supplied value when non-empty.
: "${LOGDIR:=$ROOT/log_overfit_consistency}"
# Optional explicit checkpoint path; normalised to empty so it is always defined
# under nounset.
: "${CKPT_GLOB:=}"

# All Python scripts are referenced by bare file name, so run from their directory.
cd "$OFF"
# Step 1: run the preprocessing alignment check on a single overfit sample.
# Under errexit, a non-zero exit from the checker stops the whole script here.
echo "== (1) Preprocess alignment: train vs verify_ckpt constructors =="
preprocess_check_cmd=(
  "$PY" verify_uvdoc_train_infer_preprocess.py
  --data_path_UVDoc "$UV"
  --overfit_n 1
  --mode overfit
  --check_dataloader
  --batch_size 8
  --num_workers 0
)
"${preprocess_check_cmd[@]}"

# PREPROCESS_ONLY=1 ends the run successfully after the alignment check, so
# neither training nor checkpoint verification is executed.
case "${PREPROCESS_ONLY:-0}" in
  1)
    echo "PREPROCESS_ONLY=1, skip training and checkpoint verify."
    exit 0
    ;;
esac
# Step 2: overfit train.py on a single sample, writing logs and checkpoints to
# $LOGDIR. NUM_WORKERS (default 4) and DEVICE (default cuda:0) may be overridden
# from the environment; every other hyperparameter is fixed here.
echo "== (2) Single-sample overfit training =="
train_cmd=(
  "$PY" train.py
  --data_to_use uvdoc
  --data_path_UVDoc "$UV"
  --overfit_n 1
  --batch_size 8
  --n_epochs 10
  --n_epochs_decay 10
  --lr 0.0002
  --alpha_w 5.0
  --beta_w 5.0
  --gamma_w 1.0
  --ep_gamma_start 10
  --num_workers "${NUM_WORKERS:-4}"
  --device "${DEVICE:-cuda:0}"
  --logdir "$LOGDIR"
)
"${train_cmd[@]}"
# Choose the checkpoint to verify and prepare a clean output directory.
#
# Reads:  LOGDIR (training log dir), CKPT_GLOB (optional explicit checkpoint path).
# Writes: CKPT (checkpoint path passed to the verify step), OUT (fresh, empty dir).
# Exits 1 when no override is given and no ep_*_best_model.pkl exists in $LOGDIR.
#
# An explicit CKPT_GLOB always wins, and is honoured even when $LOGDIR holds no
# best checkpoint. Otherwise the most recently modified best checkpoint is used.
CKPT="${CKPT_GLOB:-}"
if [[ -z "$CKPT" ]]; then
  # Compare modification times with -nt rather than parsing `ls -t` output,
  # which mis-splits unusual file names and hides listing errors.
  for candidate in "$LOGDIR"/ep_*_best_model.pkl; do
    # An unmatched glob is left as the literal pattern; skip non-existent paths.
    [[ -e "$candidate" ]] || continue
    if [[ -z "$CKPT" || "$candidate" -nt "$CKPT" ]]; then
      CKPT="$candidate"
    fi
  done
fi
if [[ -z "$CKPT" ]]; then
  echo "No ep_*_best_model.pkl under $LOGDIR" >&2
  exit 1
fi
echo "Using checkpoint: $CKPT"

# Start from an empty output directory so results of earlier runs cannot be
# mistaken for this one. ${OUT:?} aborts instead of deleting an empty path, and
# `--` stops a path beginning with '-' from being read as an option.
OUT="$LOGDIR/verify_infer_same_preprocess"
rm -rf -- "${OUT:?}"
mkdir -p -- "$OUT"
# Step 3: evaluate the selected checkpoint through verify_ckpt_val_pipeline.py
# using the same dataset root and overfit_n as training, writing results to $OUT.
echo "== (3) Inference with SAME dataset kwargs as train val/overfit =="
verify_cmd=(
  "$PY" verify_ckpt_val_pipeline.py
  --ckpt "$CKPT"
  --data_path_UVDoc "$UV"
  --overfit_n 1
  --out_dir "$OUT"
  --max_save_images 1
  --device "${DEVICE:-cuda:0}"
)
"${verify_cmd[@]}"

# Tell the operator which two numbers to compare by hand.
echo "Done. Compare mean_mse in $OUT/metrics.txt to the Val MSE line in train log under $LOGDIR"