#!/usr/bin/env bash
#
# Batch inference driver: runs SCRIPT_PATH once per (checkpoint, sample index)
# pair, appends each run's stdout/stderr to <OUTPUT_ROOT>/<checkpoint>/run.log,
# and records one [OK]/[ERROR] line per run in the aggregate success/failed logs.
#
# `-e` is deliberately NOT enabled: a failing sample is logged and the sweep
# moves on to the next sample/checkpoint instead of aborting everything.
set -uo pipefail

########################################
# Configuration (the only section you normally need to edit)
########################################
# Resolved relative to the current working directory.
SCRIPT_PATH="qwen3_plain_ar.py"
DATASET_PATH="muse_mucodec_chord.ds"

# Tokenizer (must ship a chat_template)
TOKENIZER_PATH="/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar/final"

# Training run directory that contains the checkpoints below.
RUN_DIR="/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar"

# Checkpoints to evaluate
CHECKPOINTS=(
  "${RUN_DIR}/checkpoint-907"
  "${RUN_DIR}/checkpoint-1814"
  "${RUN_DIR}/checkpoint-2721"
  "${RUN_DIR}/checkpoint-3628"
  "${RUN_DIR}/checkpoint-4535"
  "${RUN_DIR}/checkpoint-5442"
  "${RUN_DIR}/checkpoint-6349"
  "${RUN_DIR}/checkpoint-7256"
  "${RUN_DIR}/checkpoint-8163"
  "${RUN_DIR}/checkpoint-9070"
  "${RUN_DIR}/checkpoint-9977"
  "${RUN_DIR}/checkpoint-10884"
  "${RUN_DIR}/checkpoint-11791"
  "${RUN_DIR}/checkpoint-12698"
  "${RUN_DIR}/checkpoint-13605"
  "${RUN_DIR}/checkpoint-14512"
  "${RUN_DIR}/checkpoint-15419"
  "${RUN_DIR}/checkpoint-16326"
  "${RUN_DIR}/checkpoint-17233"
  "${RUN_DIR}/checkpoint-18140"
)

# Output root directory.
# NOTE(review): "preditions" spelling left unchanged so existing results stay
# where they are; rename deliberately if desired.
OUTPUT_ROOT="/root/batch_preditions_ablation"

# Number of inference runs per checkpoint (sample indices 0..NUM_SAMPLES-1)
NUM_SAMPLES=20

########################################
# Inference parameters (tunable)
########################################
# Interpreter used to run SCRIPT_PATH; override via the environment if needed.
PYTHON_BIN="${PYTHON_BIN:-python}"
DEVICE="cuda:0"
DTYPE="bfloat16"
ATTN_IMPLEMENTATION="sdpa"
TEMPERATURE=1.0
TOP_K=50
TOP_P=0.9
MAX_NEW_TOKENS=4096

# Whether to skip audio decoding (recommended to turn on first while debugging)
SKIP_DECODE=false

########################################
# Log files
########################################
FAILED_LOG="${OUTPUT_ROOT}/failed_cases.log"
SUCCESS_LOG="${OUTPUT_ROOT}/success_cases.log"

# Print a fatal message to stderr and abort the whole sweep.
die() {
  printf '[FATAL] %s\n' "$*" >&2
  exit 1
}

########################################
# Pre-flight checks: fail fast instead of logging the same error for every
# (checkpoint, sample) pair.
########################################
[[ -f "${SCRIPT_PATH}" ]] \
  || die "inference script not found: ${SCRIPT_PATH} (cwd: ${PWD})"
[[ -e "${TOKENIZER_PATH}" ]] \
  || die "tokenizer path not found: ${TOKENIZER_PATH}"
command -v "${PYTHON_BIN}" >/dev/null \
  || die "python interpreter not found: ${PYTHON_BIN}"
[[ "${NUM_SAMPLES}" =~ ^[0-9]+$ ]] \
  || die "NUM_SAMPLES must be a non-negative integer, got: ${NUM_SAMPLES}"

########################################
# Arguments shared by every run (loop-invariant, built once).
# NOTE(review): the original python invocation was lost when this file was
# mangled (everything between '<' and '>' in the inner loop was stripped).
# The flag names below are ASSUMED from the config variables -- confirm them
# against `python qwen3_plain_ar.py --help` before launching a sweep.
########################################
COMMON_ARGS=(
  --tokenizer_path "${TOKENIZER_PATH}"
  --dataset_path "${DATASET_PATH}"
  --device "${DEVICE}"
  --dtype "${DTYPE}"
  --attn_implementation "${ATTN_IMPLEMENTATION}"
  --temperature "${TEMPERATURE}"
  --top_k "${TOP_K}"
  --top_p "${TOP_P}"
  --max_new_tokens "${MAX_NEW_TOKENS}"
)
if [[ "${SKIP_DECODE}" == true ]]; then
  COMMON_ARGS+=(--skip_decode)
fi

########################################
# Run
########################################
mkdir -p "${OUTPUT_ROOT}" || die "cannot create output root: ${OUTPUT_ROOT}"
touch "${FAILED_LOG}" "${SUCCESS_LOG}" \
  || die "cannot create log files under: ${OUTPUT_ROOT}"

echo "======================================" | tee -a "${SUCCESS_LOG}"
echo "Batch inference started at $(date)" | tee -a "${SUCCESS_LOG}"
echo "Output root: ${OUTPUT_ROOT}" | tee -a "${SUCCESS_LOG}"
echo "======================================" | tee -a "${SUCCESS_LOG}"

for CKPT in "${CHECKPOINTS[@]}"; do
  CKPT_NAME=$(basename "${CKPT}")
  OUT_DIR="${OUTPUT_ROOT}/${CKPT_NAME}"
  CKPT_LOG="${OUT_DIR}/run.log"

  echo "======================================"
  echo "Running checkpoint: ${CKPT_NAME}"
  echo "Output dir: ${OUT_DIR}"
  echo "======================================"

  if [[ ! -d "${CKPT}" ]]; then
    echo "[ERROR] checkpoint directory not found: ${CKPT}" | tee -a "${FAILED_LOG}"
    continue
  fi

  mkdir -p "${OUT_DIR}"
  touch "${CKPT_LOG}"

  for ((i = 0; i < NUM_SAMPLES; i++)); do
    # Marker so each sample's output can be located inside the shared run.log.
    printf '===== sample_idx=%d started at %s =====\n' "${i}" "$(date)" \
      >> "${CKPT_LOG}"

    # NOTE(review): --model_path / --output_dir / --sample_idx are assumed
    # flag names (see the COMMON_ARGS note above).
    "${PYTHON_BIN}" "${SCRIPT_PATH}" \
      --model_path "${CKPT}" \
      --output_dir "${OUT_DIR}" \
      --sample_idx "${i}" \
      "${COMMON_ARGS[@]}" \
      >> "${CKPT_LOG}" 2>&1
    EXIT_CODE=$?

    if (( EXIT_CODE != 0 )); then
      echo "[ERROR] checkpoint=${CKPT_NAME} sample_idx=${i} exit_code=${EXIT_CODE}" | tee -a "${FAILED_LOG}"
    else
      echo "[OK] checkpoint=${CKPT_NAME} sample_idx=${i}" | tee -a "${SUCCESS_LOG}"
    fi
  done

  echo "[DONE] checkpoint=${CKPT_NAME}" | tee -a "${SUCCESS_LOG}"
done

echo "======================================" | tee -a "${SUCCESS_LOG}"
echo "Batch inference finished at $(date)" | tee -a "${SUCCESS_LOG}"
echo "Success log: ${SUCCESS_LOG}" | tee -a "${SUCCESS_LOG}"
echo "Failed log: ${FAILED_LOG}" | tee -a "${SUCCESS_LOG}"
echo "All done."