cond_gen / batch_infer.sh
Leon299's picture
Add files using upload-large-folder tool
fd1afc8 verified
#!/usr/bin/env bash
# Unset variables are errors and a pipeline fails if any of its stages fails.
# -e is not enabled, so a failing command does not stop the script by itself.
set -uo pipefail

########################################
# Configuration (this is the only section you need to edit)
########################################
SCRIPT_PATH="qwen3_plain_ar.py"
DATASET_PATH="muse_mucodec_chord.ds"

# Directory that holds the tokenizer and every checkpoint listed below.
CKPT_ROOT="/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar"

# Tokenizer (must be one that ships a chat_template).
TOKENIZER_PATH="${CKPT_ROOT}/final"

# Checkpoint list: checkpoint-907, checkpoint-1814, ..., checkpoint-18140,
# i.e. the first 20 multiples of 907, in ascending order.
CHECKPOINTS=()
for ((_ckpt_k = 1; _ckpt_k <= 20; _ckpt_k++)); do
  CHECKPOINTS+=("${CKPT_ROOT}/checkpoint-$((_ckpt_k * 907))")
done
unset _ckpt_k

# Root directory that receives all outputs.
OUTPUT_ROOT="/root/batch_preditions_ablation"

# Number of samples to run per checkpoint.
NUM_SAMPLES=20

########################################
# Inference parameters (tunable)
########################################
DEVICE="cuda:0"
DTYPE="bfloat16"
ATTN_IMPLEMENTATION="sdpa"
TEMPERATURE=1.0
TOP_K=50
TOP_P=0.9
MAX_NEW_TOKENS=4096

# Whether to skip audio decoding (recommended to turn on first when debugging).
SKIP_DECODE=false

########################################
# Log files
########################################
FAILED_LOG="${OUTPUT_ROOT}/failed_cases.log"
SUCCESS_LOG="${OUTPUT_ROOT}/success_cases.log"
########################################
# Start of run: prepare the output root and the two summary logs
########################################
# Every per-checkpoint directory and both summary logs live under OUTPUT_ROOT,
# so nothing later in the script can succeed if it cannot be created or
# written. Stop here with a clear message instead of carrying on.
if ! mkdir -p -- "${OUTPUT_ROOT}"; then
  echo "[FATAL] cannot create output root: ${OUTPUT_ROOT}" >&2
  exit 1
fi
if ! touch -- "${FAILED_LOG}" "${SUCCESS_LOG}"; then
  echo "[FATAL] cannot create log files under: ${OUTPUT_ROOT}" >&2
  exit 1
fi

# Start banner: shown on stdout and appended to the success log.
{
  echo "======================================"
  echo "Batch inference started at $(date)"
  echo "Output root: ${OUTPUT_ROOT}"
  echo "======================================"
} | tee -a "${SUCCESS_LOG}"
########################################
# Per-checkpoint inference
########################################

#######################################
# Run inference on a single validation sample with one checkpoint.
# Globals:   SCRIPT_PATH, TOKENIZER_PATH, DATASET_PATH, DEVICE, DTYPE,
#            ATTN_IMPLEMENTATION, TEMPERATURE, TOP_K, TOP_P,
#            MAX_NEW_TOKENS, SKIP_DECODE (all read)
# Arguments: $1 - checkpoint directory (passed as --model_path)
#            $2 - output directory (passed as --output_dir)
#            $3 - log file; the command line and all of the command's
#                 stdout/stderr are appended to it
#            $4 - sample index (passed as --sample_idx)
# Returns:   exit status of the inference command, or non-zero if the log
#            file cannot be opened
#######################################
run_sample() {
  local ckpt=$1 out_dir=$2 log_file=$3 idx=$4
  local -a cmd=(
    python "${SCRIPT_PATH}" infer
    --model_path "${ckpt}"
    --tokenizer_path "${TOKENIZER_PATH}"
    --dataset_path "${DATASET_PATH}"
    --split validation
    --sample_idx "${idx}"
    --device "${DEVICE}"
    --dtype "${DTYPE}"
    --attn_implementation "${ATTN_IMPLEMENTATION}"
    --temperature "${TEMPERATURE}"
    --top_k "${TOP_K}"
    --top_p "${TOP_P}"
    --max_new_tokens_per_section "${MAX_NEW_TOKENS}"
    --output_dir "${out_dir}"
    --output_prefix "sample_${idx}"
  )
  if [[ "${SKIP_DECODE}" == true ]]; then
    cmd+=(--skip_decode)
  fi

  # The group's status is that of its last command (the inference run), or
  # non-zero if the redirection itself fails.
  {
    # %q keeps argument boundaries visible, so the logged line can be pasted
    # back into a shell even when a path contains spaces.
    printf '[CMD]'
    printf ' %q' "${cmd[@]}"
    printf '\n'
    "${cmd[@]}"
  } >> "${log_file}" 2>&1
}

#######################################
# Run NUM_SAMPLES inferences for one checkpoint and record each outcome.
# Globals:   OUTPUT_ROOT, NUM_SAMPLES, FAILED_LOG, SUCCESS_LOG (read)
# Arguments: $1 - checkpoint directory
# Outputs:   progress on stdout; failures appended to FAILED_LOG, successes
#            and the closing [DONE] line appended to SUCCESS_LOG; full
#            command output appended to <OUTPUT_ROOT>/<checkpoint>/run.log
# Returns:   0 always; a failing sample or unusable checkpoint is logged and
#            does not stop the batch
#######################################
run_checkpoint() {
  local ckpt=$1
  local ckpt_name out_dir ckpt_log
  ckpt_name=$(basename -- "${ckpt}")
  out_dir="${OUTPUT_ROOT}/${ckpt_name}"
  ckpt_log="${out_dir}/run.log"

  echo "======================================"
  echo "Running checkpoint: ${ckpt_name}"
  echo "Output dir: ${out_dir}"
  echo "======================================"

  if [[ ! -d "${ckpt}" ]]; then
    echo "[ERROR] checkpoint directory not found: ${ckpt}" | tee -a "${FAILED_LOG}"
    return 0
  fi

  # Without the output directory and its log no sample can be recorded, so
  # report the checkpoint once and skip it rather than failing every sample.
  if ! { mkdir -p -- "${out_dir}" && touch -- "${ckpt_log}"; }; then
    echo "[ERROR] checkpoint=${ckpt_name} cannot create output dir or log: ${out_dir}" | tee -a "${FAILED_LOG}"
    return 0
  fi

  local i rc
  for ((i = 0; i < NUM_SAMPLES; i++)); do
    echo "[INFO] checkpoint=${ckpt_name} sample_idx=${i}" | tee -a "${ckpt_log}"
    rc=0
    run_sample "${ckpt}" "${out_dir}" "${ckpt_log}" "${i}" || rc=$?
    if ((rc != 0)); then
      echo "[ERROR] checkpoint=${ckpt_name} sample_idx=${i} exit_code=${rc}" | tee -a "${FAILED_LOG}"
    else
      echo "[OK] checkpoint=${ckpt_name} sample_idx=${i}" | tee -a "${SUCCESS_LOG}"
    fi
  done
  echo "[DONE] checkpoint=${ckpt_name}" | tee -a "${SUCCESS_LOG}"
}

#######################################
# Process every entry of CHECKPOINTS in order.
# Globals:   CHECKPOINTS (read)
# Returns:   0 always
#######################################
run_all_checkpoints() {
  local ckpt
  for ckpt in "${CHECKPOINTS[@]}"; do
    run_checkpoint "${ckpt}"
  done
}

run_all_checkpoints
# Closing banner: the first four lines go to stdout and are appended to the
# success log through a single tee; the final line goes to stdout only.
{
  printf '%s\n' "======================================"
  printf '%s\n' "Batch inference finished at $(date)"
  printf '%s\n' "Success log: ${SUCCESS_LOG}"
  printf '%s\n' "Failed log: ${FAILED_LOG}"
} | tee -a "${SUCCESS_LOG}"
printf '%s\n' "All done."