#!/usr/bin/env bash
#
# Batch inference across the saved checkpoints of the Qwen3 plain-AR run.
#
# For each entry in CHECKPOINTS this runs `python "${SCRIPT_PATH}" infer` on
# validation samples 0..NUM_SAMPLES-1 and writes the outputs to
# "${OUTPUT_ROOT}/<checkpoint-name>/". The combined stdout/stderr of every
# call is appended to "<checkpoint-output-dir>/run.log", and a one-line
# outcome per sample goes to SUCCESS_LOG or FAILED_LOG.
#
# Run it from the directory that contains SCRIPT_PATH (the path is relative).

# `-e` is deliberately NOT set: a failed sample or a missing checkpoint is
# recorded and skipped instead of aborting the whole batch; exit statuses are
# checked explicitly where it matters.
set -uo pipefail

# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------

# Inference entry point. Relative, so it is resolved against the current
# working directory.
readonly SCRIPT_PATH="qwen3_plain_ar.py"

# Passed as --dataset_path; the validation split is used.
readonly DATASET_PATH="muse_mucodec_chord.ds"

# Training output directory holding the checkpoints and the `final` export.
readonly RUN_DIR="/algo-intern/user/leonchen/cond_gen/output_qwen3_plain_ar"

# A single tokenizer (from the `final` export) is shared by every checkpoint.
readonly TOKENIZER_PATH="${RUN_DIR}/final"

# Checkpoints to evaluate, in increasing step order (every 907 steps).
readonly -a CHECKPOINTS=(
  "${RUN_DIR}/checkpoint-907"
  "${RUN_DIR}/checkpoint-1814"
  "${RUN_DIR}/checkpoint-2721"
  "${RUN_DIR}/checkpoint-3628"
  "${RUN_DIR}/checkpoint-4535"
  "${RUN_DIR}/checkpoint-5442"
  "${RUN_DIR}/checkpoint-6349"
  "${RUN_DIR}/checkpoint-7256"
  "${RUN_DIR}/checkpoint-8163"
  "${RUN_DIR}/checkpoint-9070"
  "${RUN_DIR}/checkpoint-9977"
  "${RUN_DIR}/checkpoint-10884"
  "${RUN_DIR}/checkpoint-11791"
  "${RUN_DIR}/checkpoint-12698"
  "${RUN_DIR}/checkpoint-13605"
  "${RUN_DIR}/checkpoint-14512"
  "${RUN_DIR}/checkpoint-15419"
  "${RUN_DIR}/checkpoint-16326"
  "${RUN_DIR}/checkpoint-17233"
  "${RUN_DIR}/checkpoint-18140"
)

# ---------------------------------------------------------------------------
# Outputs
# ---------------------------------------------------------------------------

# Results go to ${OUTPUT_ROOT}/<checkpoint-name>/.
# NOTE(review): "preditions" looks like a typo of "predictions"; renaming it
# would change where results are written, so it is left as-is.
readonly OUTPUT_ROOT="/root/batch_preditions_ablation"

# Number of validation samples per checkpoint (indices 0..NUM_SAMPLES-1).
readonly NUM_SAMPLES=20

# ---------------------------------------------------------------------------
# Inference options (forwarded verbatim to SCRIPT_PATH)
# ---------------------------------------------------------------------------

# Forwarded as --device / --dtype / --attn_implementation.
readonly DEVICE="cuda:0"
readonly DTYPE="bfloat16"
readonly ATTN_IMPLEMENTATION="sdpa"

# Sampling options, forwarded as --temperature / --top_k / --top_p.
readonly TEMPERATURE=1.0
readonly TOP_K=50
readonly TOP_P=0.9

# Forwarded as --max_new_tokens_per_section (presumably a per-section cap,
# judging by the flag name — confirm in SCRIPT_PATH).
readonly MAX_NEW_TOKENS=4096

# Set to exactly `true` to pass --skip_decode; any other value omits the flag.
readonly SKIP_DECODE=false

# ---------------------------------------------------------------------------
# Logs
# ---------------------------------------------------------------------------

# Append-only logs shared by all runs of this script (never truncated).
readonly FAILED_LOG="${OUTPUT_ROOT}/failed_cases.log"
readonly SUCCESS_LOG="${OUTPUT_ROOT}/success_cases.log"
# ---------------------------------------------------------------------------
# Setup: validate the environment, create the output root and logs, and write
# the start-of-run banner.
# ---------------------------------------------------------------------------

# Print a fatal message to stderr and abort. Used for setup problems that
# would otherwise make every single inference call below fail.
die() {
  echo "[FATAL] $*" >&2
  exit 1
}

# Fail fast on a broken environment instead of logging one failure per sample.
command -v python >/dev/null 2>&1 || die "python not found in PATH"
# SCRIPT_PATH may be relative, in which case it only resolves when the script
# is launched from the right working directory.
[[ -f "${SCRIPT_PATH}" ]] || die "inference script not found: ${SCRIPT_PATH} (cwd: ${PWD})"

# `set -e` is not active, so these must be checked explicitly: if they fail,
# every later `tee -a` and per-sample log write fails too.
mkdir -p -- "${OUTPUT_ROOT}" || die "cannot create output root: ${OUTPUT_ROOT}"
touch -- "${FAILED_LOG}" "${SUCCESS_LOG}" || die "cannot create log files under ${OUTPUT_ROOT}"

START_TIME=$(date)

# Both logs are append-only across runs; stamp the failure log as well so its
# entries can be attributed to a specific run.
echo "=== Batch inference started at ${START_TIME} ===" >>"${FAILED_LOG}"

{
  echo "======================================"
  echo "Batch inference started at ${START_TIME}"
  echo "Output root: ${OUTPUT_ROOT}"
  echo "======================================"
} | tee -a "${SUCCESS_LOG}"
#######################################
# Run inference for one validation sample of one checkpoint and record the
# outcome.
# Globals (read): SCRIPT_PATH, TOKENIZER_PATH, DATASET_PATH, DEVICE, DTYPE,
#   ATTN_IMPLEMENTATION, TEMPERATURE, TOP_K, TOP_P, MAX_NEW_TOKENS,
#   SKIP_DECODE, FAILED_LOG, SUCCESS_LOG
# Arguments:
#   $1 - checkpoint directory (passed as --model_path)
#   $2 - checkpoint name used in log lines
#   $3 - output directory for this checkpoint
#   $4 - per-checkpoint log that receives the command line and its output
#   $5 - sample index
# Outputs: an [INFO] line and an [OK]/[ERROR] line on stdout, also appended to
#   the per-checkpoint log and SUCCESS_LOG/FAILED_LOG respectively.
# Returns: the exit status of the inference command.
#######################################
run_sample() {
  local ckpt=$1 ckpt_name=$2 out_dir=$3 ckpt_log=$4 idx=$5

  echo "[INFO] checkpoint=${ckpt_name} sample_idx=${idx}" | tee -a "${ckpt_log}"

  local -a cmd=(
    python "${SCRIPT_PATH}" infer
    --model_path "${ckpt}"
    --tokenizer_path "${TOKENIZER_PATH}"
    --dataset_path "${DATASET_PATH}"
    --split validation
    --sample_idx "${idx}"
    --device "${DEVICE}"
    --dtype "${DTYPE}"
    --attn_implementation "${ATTN_IMPLEMENTATION}"
    --temperature "${TEMPERATURE}"
    --top_k "${TOP_K}"
    --top_p "${TOP_P}"
    --max_new_tokens_per_section "${MAX_NEW_TOKENS}"
    --output_dir "${out_dir}"
    --output_prefix "sample_${idx}"
  )
  if [[ "${SKIP_DECODE}" == true ]]; then
    cmd+=(--skip_decode)
  fi

  # %q shell-quotes every word so the logged line can be copy-pasted to rerun
  # this sample, even when a path contains spaces or other special characters.
  local cmd_str rc=0
  printf -v cmd_str '%q ' "${cmd[@]}"
  {
    echo "[CMD] ${cmd_str% }"
    "${cmd[@]}"
  } >>"${ckpt_log}" 2>&1 || rc=$?

  if (( rc != 0 )); then
    echo "[ERROR] checkpoint=${ckpt_name} sample_idx=${idx} exit_code=${rc}" | tee -a "${FAILED_LOG}"
  else
    echo "[OK] checkpoint=${ckpt_name} sample_idx=${idx}" | tee -a "${SUCCESS_LOG}"
  fi
  return "${rc}"
}

#######################################
# Run all NUM_SAMPLES validation samples for one checkpoint.
# Globals (read): OUTPUT_ROOT, NUM_SAMPLES, FAILED_LOG, SUCCESS_LOG (plus
#   everything run_sample reads)
# Arguments:
#   $1 - checkpoint directory
# Outputs: progress banners on stdout; results and run.log under
#   ${OUTPUT_ROOT}/<checkpoint-name>/.
# Returns: 0 if every sample succeeded; 1 if the checkpoint directory is
#   missing, its output directory/log cannot be created, or any sample failed.
#######################################
run_checkpoint() {
  local ckpt=$1
  local ckpt_name out_dir ckpt_log
  ckpt_name=$(basename -- "${ckpt}")
  out_dir="${OUTPUT_ROOT}/${ckpt_name}"
  ckpt_log="${out_dir}/run.log"

  echo "======================================"
  echo "Running checkpoint: ${ckpt_name}"
  echo "Output dir: ${out_dir}"
  echo "======================================"

  if [[ ! -d "${ckpt}" ]]; then
    echo "[ERROR] checkpoint directory not found: ${ckpt}" | tee -a "${FAILED_LOG}"
    return 1
  fi

  # `set -e` is not active; without this check a failure here would surface
  # only as NUM_SAMPLES confusing redirection errors.
  if ! mkdir -p -- "${out_dir}" || ! touch -- "${ckpt_log}"; then
    echo "[ERROR] cannot create output dir or log: ${ckpt_log}" | tee -a "${FAILED_LOG}"
    return 1
  fi

  local i n_ok=0 n_fail=0
  for ((i = 0; i < NUM_SAMPLES; i++)); do
    if run_sample "${ckpt}" "${ckpt_name}" "${out_dir}" "${ckpt_log}" "${i}"; then
      n_ok=$((n_ok + 1))
    else
      n_fail=$((n_fail + 1))
    fi
  done

  # Include counts so a checkpoint whose samples all failed is not mistaken
  # for a clean run when reading the success log.
  echo "[DONE] checkpoint=${ckpt_name} ok=${n_ok} failed=${n_fail}" | tee -a "${SUCCESS_LOG}"
  (( n_fail == 0 ))
}

for ckpt in "${CHECKPOINTS[@]}"; do
  # Failures are already recorded in FAILED_LOG; continue with the next one.
  run_checkpoint "${ckpt}" || true
done
# Closing banner: shown on the console and appended to the success log in a
# single pass, followed by a console-only completion message.
{
  echo "======================================"
  echo "Batch inference finished at $(date)"
  echo "Success log: ${SUCCESS_LOG}"
  echo "Failed log: ${FAILED_LOG}"
} | tee -a "${SUCCESS_LOG}"
echo "All done."