#!/usr/bin/env bash
set -euo pipefail

# ============================================================
# RUNME.sh (self-contained, auto-discovers checkpoints)
#
# Modes:
#   - Default (WATCH=0): snapshot checkpoints once and process them.
#   - WATCH=1          : keep polling for new checkpoint-* dirs and process new ones.
#
# Output:
#   /workspace/v126rc_exp3/F_r10000/checkpoint-*/residued
# Logs:
#   /workspace/v126rc_exp3/F_r10000/checkpoint-*/residued/merge.log
#
# RAM safety:
#   Each checkpoint merge runs in a fresh Python process.
#
# Auto-stop (WATCH=1):
#   If no new checkpoints appear for IDLE_LIMIT_SECONDS (default 600),
#   the script exits and (optionally) deletes the RunPod pod.
#
# Robustness:
#   If a checkpoint is incomplete and merge fails, we log it and retry later
#   (WATCH=1 keeps running; WATCH=0 continues to next checkpoint).
#
# Cleanup:
#   After a SUCCESSFUL merge, delete everything inside checkpoint-* except residued/
# ============================================================

# ---------------- CONFIG ----------------
LR_DIR="${LR_DIR:-/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter}"
ROOT="${ROOT:-/workspace/v126rc_exp3/F_r10000}"
PYTHON_BIN="${PYTHON_BIN:-python}"

# Save dtype: bf16 (default), fp16, fp32
SAVE_DTYPE="${SAVE_DTYPE:-bf16}"

# Watch mode: 0 = run once, 1 = keep discovering new checkpoints
WATCH="${WATCH:-0}"
POLL_SECONDS="${POLL_SECONDS:-60}"

# Auto-stop when no new checkpoints for this long (WATCH=1 only)
IDLE_LIMIT_SECONDS="${IDLE_LIMIT_SECONDS:-1200}" # 20 minutes default

# Optional: require a checkpoint directory to be "stable" (no mtime changes) before merging.
# Set to 0 to disable. A small value (e.g. 30-120) helps avoid half-written checkpoints.
STABLE_SECONDS="${STABLE_SECONDS:-0}"

# Optional toggles (best-effort)
DROP_CACHES="${DROP_CACHES:-0}" # requires sudo; 0/1
GPU_RESET="${GPU_RESET:-0}"     # 0/1

# Skip if output already exists and looks complete (has model files)
SKIP_DONE="${SKIP_DONE:-1}" # 0/1

# How long to wait before retrying a failed checkpoint (WATCH=1 only)
RETRY_COOLDOWN_SECONDS="${RETRY_COOLDOWN_SECONDS:-120}"

# -------------- CHECKS --------------
[[ -d "$LR_DIR" ]] || { echo "ERROR: LR_DIR not found: $LR_DIR" >&2; exit 1; }
[[ -d "$ROOT" ]] || { echo "ERROR: ROOT not found: $ROOT" >&2; exit 1; }

# Guardrails: catch nonsensical numeric config before entering any loop.
(( POLL_SECONDS > 0 )) || { echo "ERROR: POLL_SECONDS must be > 0" >&2; exit 1; }
(( IDLE_LIMIT_SECONDS >= 0 )) || { echo "ERROR: IDLE_LIMIT_SECONDS must be >= 0" >&2; exit 1; }
(( STABLE_SECONDS >= 0 )) || { echo "ERROR: STABLE_SECONDS must be >= 0" >&2; exit 1; }
(( RETRY_COOLDOWN_SECONDS >= 0 )) || { echo "ERROR: RETRY_COOLDOWN_SECONDS must be >= 0" >&2; exit 1; }

echo "LR_DIR : $LR_DIR"
echo "ROOT : $ROOT"
echo "PYTHON_BIN : $PYTHON_BIN"
echo "SAVE_DTYPE : $SAVE_DTYPE"
echo "WATCH : $WATCH"
echo "POLL_SECONDS : $POLL_SECONDS"
echo "IDLE_LIMIT_SECONDS : $IDLE_LIMIT_SECONDS"
echo "STABLE_SECONDS : $STABLE_SECONDS"
echo "SKIP_DONE : $SKIP_DONE"
echo "RETRY_COOLDOWN_SECONDS : $RETRY_COOLDOWN_SECONDS"
echo

# -------------- Helpers --------------

# Returns 0 if the merged output directory already contains a model file.
# Heuristic: if safetensors exists (or pytorch_model.bin), consider done.
is_done() {
  local out_dir="$1"
  if [[ -f "$out_dir/model.safetensors" ]] || [[ -f "$out_dir/pytorch_model.bin" ]]; then
    return 0
  fi
  return 1
}

# Announce why the watch loop is stopping and exit 0.
# Pod self-removal is intentionally left commented out (destructive).
cleanup_and_exit_watch() {
  local reason="$1"
  echo
  echo "============================================================"
  echo "WATCH STOP: $reason"
  echo "============================================================"
  # # Best-effort pod removal. This will TERMINATE the pod.
  # if [[ -n "${RUNPOD_POD_ID:-}" ]]; then
  #   echo "Attempting: runpodctl remove pod \"$RUNPOD_POD_ID\""
  #   if command -v runpodctl >/dev/null 2>&1; then
  #     runpodctl remove pod "$RUNPOD_POD_ID" || true
  #   else
  #     echo "WARNING: runpodctl not found in PATH; cannot remove pod automatically."
  #   fi
  # else
  #   echo "RUNPOD_POD_ID not set; skipping pod removal."
  # fi
  echo "Exiting."
  exit 0
}

# Delete everything at the top level of a checkpoint dir EXCEPT residued/.
# Refuses to delete anything if residued/ is missing (merge may not have run).
cleanup_checkpoint_keep_residued() {
  local ckpt_dir="$1"
  local keep_dir="${ckpt_dir}/residued"

  # Safety checks
  [[ -d "$ckpt_dir" ]] || { echo "WARN: ckpt_dir missing: $ckpt_dir"; return 0; }
  [[ -d "$keep_dir" ]] || { echo "WARN: residued missing (won't delete): $keep_dir"; return 0; }

  echo "🧹 Cleaning checkpoint (keeping only residued/): $ckpt_dir"
  # Delete everything at top-level of checkpoint dir EXCEPT 'residued'
  find "$ckpt_dir" -mindepth 1 -maxdepth 1 \
    ! -name "residued" \
    -exec rm -rf {} +
}

# Returns 0 if checkpoint looks "stable enough" to attempt merge.
# If STABLE_SECONDS=0, always returns 0.
is_checkpoint_stable() {
  local ckpt_dir="$1"
  local stable_s="$STABLE_SECONDS"
  (( stable_s == 0 )) && return 0

  # Find newest mtime under the checkpoint dir.
  # (GNU find -printf; on busybox stat/find flags differ, but RunPod usually has coreutils)
  local newest_epoch
  newest_epoch="$(find "$ckpt_dir" -type f -printf '%T@\n' 2>/dev/null | sort -n | tail -1 | cut -d. -f1 || true)"
  [[ -n "${newest_epoch:-}" ]] || return 1

  local now_epoch
  now_epoch="$(date +%s)"
  local age=$(( now_epoch - newest_epoch ))
  (( age >= stable_s ))
}

# Merge the LR adapter into one checkpoint, logging to residued/merge.log.
# ALWAYS returns 0 so a failed/held merge never kills the caller's loop
# (WATCH mode infers failure by re-checking is_done afterwards).
run_merge_for_checkpoint() {
  local ckpt_dir="$1"
  local out_dir="${ckpt_dir}/residued"
  local log_file="${out_dir}/merge.log"

  mkdir -p "$out_dir"

  if [[ "$SKIP_DONE" == "1" ]] && is_done "$out_dir"; then
    echo "SKIP (already merged): $ckpt_dir"
    echo " -> $out_dir"
    echo " -> $log_file"
    return 0
  fi

  if ! is_checkpoint_stable "$ckpt_dir"; then
    echo "HOLD (checkpoint not stable yet): $ckpt_dir (STABLE_SECONDS=$STABLE_SECONDS)"
    return 0
  fi

  echo "============================================================"
  echo "Checkpoint : $ckpt_dir"
  echo "Output : $out_dir"
  echo "Log : $log_file"
  echo "============================================================"

  # IMPORTANT:
  # We must not let a failed merge kill the whole WATCH loop.
  # So we run the merge, capture failure, log it, and return 0 (retry later).
  if ! {
    {
      echo "[$(date -Is)] START merge"
      echo "Base model : $ckpt_dir"
      echo "LR adapter : $LR_DIR"
      echo "Output dir : $out_dir"
      echo "SAVE_DTYPE : $SAVE_DTYPE"
      echo

      # Fresh Python process per checkpoint => frees RAM on exit.
      LR_DIR="$LR_DIR" BASE_DIR="$ckpt_dir" OUT_DIR="$out_dir" SAVE_DTYPE="$SAVE_DTYPE" \
        "$PYTHON_BIN" - <<'PY'
import os
import shutil
import gc

import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

LR_DIR = os.environ["LR_DIR"]
BASE_DIR = os.environ["BASE_DIR"]
OUT_DIR = os.environ["OUT_DIR"]
SAVE_DTYPE = os.environ.get("SAVE_DTYPE", "bf16").lower().strip()


def _load_model_fp32(model_dir: str):
    """Load a CausalLM on CPU in fp32.

    Newer transformers use the `dtype` kwarg; older releases only accept
    `torch_dtype`, hence the TypeError fallback.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
    except TypeError:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )


def _to_save_dtype(model, save_dtype: str):
    """Cast the merged model to the requested on-disk dtype."""
    if save_dtype == "bf16":
        return model.to(torch.bfloat16)
    if save_dtype == "fp16":
        return model.to(torch.float16)
    if save_dtype == "fp32":
        return model.to(torch.float32)
    raise ValueError(f"Unknown SAVE_DTYPE={save_dtype}. Use bf16|fp16|fp32")


def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    """Add the residual adapter tensors onto the base model and save it.

    Math is done in fp32; a dim-0 (vocab-resize) mismatch is handled by
    adding the residual to the first min(n_base, n_res) rows only.
    """
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")

    print("Loading residual adapter...")
    residual_state_dict = torch.load(adapter_file, map_location="cpu")

    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)
    base_state_dict = base_model.state_dict()

    merged_state_dict = {}
    mismatched = []

    for key, base_tensor in base_state_dict.items():
        if key not in residual_state_dict:
            merged_state_dict[key] = base_tensor
            continue

        res_tensor = residual_state_dict[key]

        # Exact match
        if base_tensor.shape == res_tensor.shape:
            merged_state_dict[key] = (base_tensor + res_tensor).to(torch.float32)
            continue

        # Vocab resized: dim0 differs, rest matches
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged_state_dict[key] = out
            mismatched.append((key, tuple(base_tensor.shape), tuple(res_tensor.shape), n))
            continue

        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )

    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f" - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f" ... and {len(mismatched) - 20} more")

    base_model.load_state_dict(merged_state_dict, strict=True)
    base_model = _to_save_dtype(base_model, SAVE_DTYPE)

    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)

    base_config = AutoConfig.from_pretrained(base_model_dir)
    base_config.save_pretrained(output_dir)

    # Prefer a full tokenizer save; fall back to copying the raw files.
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        for file_name in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)

    print(f"\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved {SAVE_DTYPE} at: {output_dir}")


merge_instruction_residual(LR_DIR, BASE_DIR, OUT_DIR)
gc.collect()
PY
      echo
      echo "[$(date -Is)] DONE merge"
    } >>"$log_file" 2>&1
  }; then
    echo "⚠️ Merge failed (likely incomplete checkpoint): $ckpt_dir"
    echo " -> See log: $log_file"
    echo " -> Will retry later"
    return 0
  fi

  # Delete everything except residued/ after a SUCCESSFUL merge
  cleanup_checkpoint_keep_residued "$ckpt_dir"

  echo "✅ Finished $ckpt_dir (log: $log_file)"
  echo

  # Optional cleanup
  if [[ "$DROP_CACHES" == "1" ]]; then
    echo "Dropping Linux page cache (best-effort; requires sudo)..."
    sync || true
    sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches' || true
  fi

  if [[ "$GPU_RESET" == "1" ]]; then
    echo "Attempting GPU reset (best-effort)..."
    nvidia-smi --gpu-reset -i 0 >/dev/null 2>&1 || true
  fi

  sleep 1
}

# Print checkpoints, one per line, sorted (natural version sort so
# checkpoint-900 orders before checkpoint-10000).
discover_checkpoints_sorted() {
  find "$ROOT" -maxdepth 1 -type d -name "checkpoint-*" | sort -V
}

# -------------- MAIN --------------

if [[ "$WATCH" == "0" ]]; then
  mapfile -t CKPTS < <(discover_checkpoints_sorted)
  [[ ${#CKPTS[@]} -gt 0 ]] || { echo "No checkpoint-* directories found under: $ROOT" >&2; exit 1; }

  echo "Found ${#CKPTS[@]} checkpoints:"
  printf ' - %s\n' "${CKPTS[@]}"
  echo

  for ckpt in "${CKPTS[@]}"; do
    run_merge_for_checkpoint "$ckpt"
  done

  echo "All merges complete."
  exit 0
fi

# WATCH=1 mode: keep discovering new checkpoints
declare -A SEEN=()
declare -A LAST_FAIL_TS=()

echo "WATCH mode enabled. Polling every ${POLL_SECONDS}s for new checkpoint-* directories..."
echo "Auto-stop if idle for ${IDLE_LIMIT_SECONDS}s (no new checkpoints)."
echo

last_new_ts="$(date +%s)"

while true; do
  found_new=0

  while IFS= read -r ckpt; do
    [[ -z "$ckpt" ]] && continue

    # If we've seen it before and it failed, allow retries with cooldown
    if [[ -n "${SEEN[$ckpt]+x}" ]]; then
      # If it is already merged (residued has model), we can ignore forever
      if [[ "$SKIP_DONE" == "1" ]] && is_done "${ckpt}/residued"; then
        continue
      fi
      # Cooldown logic for retries
      if [[ -n "${LAST_FAIL_TS[$ckpt]+x}" ]] && (( RETRY_COOLDOWN_SECONDS > 0 )); then
        now_ts="$(date +%s)"
        since_fail=$(( now_ts - LAST_FAIL_TS[$ckpt] ))
        if (( since_fail < RETRY_COOLDOWN_SECONDS )); then
          continue
        fi
      fi
      # Retry eligible: fall through to the merge attempt below.
    else
      SEEN[$ckpt]=1
      found_new=1
      last_new_ts="$(date +%s)"
    fi

    # Try merging (always returns 0; failure is inferred from the output dir).
    run_merge_for_checkpoint "$ckpt"

    # If still not done after attempting, mark as failed attempt time (for cooldown)
    if [[ "$SKIP_DONE" == "1" ]] && ! is_done "${ckpt}/residued"; then
      LAST_FAIL_TS[$ckpt]="$(date +%s)"
    else
      # Success clears failure timestamp
      unset 'LAST_FAIL_TS[$ckpt]' || true
    fi
  done < <(discover_checkpoints_sorted)

  if [[ "$found_new" -eq 0 ]]; then
    now_ts="$(date +%s)"
    idle_for=$(( now_ts - last_new_ts ))
    echo "[$(date -Is)] No new checkpoints found. Idle for ${idle_for}s. Sleeping ${POLL_SECONDS}s..."
    if (( IDLE_LIMIT_SECONDS > 0 )) && (( idle_for >= IDLE_LIMIT_SECONDS )); then
      cleanup_and_exit_watch "No new checkpoints for ${idle_for}s (>= ${IDLE_LIMIT_SECONDS}s)."
    fi
    sleep "$POLL_SECONDS"
  fi
done