# workspace_Feb13 / MERGE_SEQUENCE.sh
# (The lines above/below this note are Hugging Face upload-page chrome captured
#  with the file — "Linksome's picture", "Add files using upload-large-folder
#  tool", "7c31071 verified" — not script code. NOTE: anything above the
#  "#!/usr/bin/env bash" line defeats direct ./execution; run this file via
#  `bash MERGE_SEQUENCE.sh` or delete these header lines entirely.)
#!/usr/bin/env bash
set -euo pipefail
# ============================================================
# RUNME.sh (self-contained, auto-discovers checkpoints)
#
# Merges an "instruction residual" adapter (LR_DIR) into each training
# checkpoint found under ROOT, writing the merged model into
# <checkpoint>/residued/.
#
# Modes:
# - Default (WATCH=0): snapshot checkpoints once and process them.
# - WATCH=1 : keep polling for new checkpoint-* dirs and process new ones.
#
# Output:
# /workspace/v126rc_exp3/F_r10000/checkpoint-*/residued
# Logs:
# /workspace/v126rc_exp3/F_r10000/checkpoint-*/residued/merge.log
#
# RAM safety:
# Each checkpoint merge runs in a fresh Python process.
#
# Auto-stop (WATCH=1):
# If no new checkpoints appear for IDLE_LIMIT_SECONDS (default 1200),
# the script exits. (RunPod pod removal exists but is commented out in
# cleanup_and_exit_watch below.)
#
# Robustness:
# If a checkpoint is incomplete and merge fails, we log it and retry later
# (WATCH=1 keeps running; WATCH=0 continues to next checkpoint).
#
# Cleanup:
# After a SUCCESSFUL merge, delete everything inside checkpoint-* except residued/
# ============================================================

# ---------------- CONFIG ----------------
# Every setting below is overridable via the environment (VAR=value bash RUNME.sh).
# Adapter directory; must contain adapter_model.bin (checked by the Python merge).
LR_DIR="${LR_DIR:-/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter}"
# Parent directory that holds the checkpoint-* folders to merge.
ROOT="${ROOT:-/workspace/v126rc_exp3/F_r10000}"
# Python interpreter used for the per-checkpoint merge subprocess.
PYTHON_BIN="${PYTHON_BIN:-python}"
# Save dtype: bf16 (default), fp16, fp32
SAVE_DTYPE="${SAVE_DTYPE:-bf16}"
# Watch mode: 0 = run once, 1 = keep discovering new checkpoints
WATCH="${WATCH:-0}"
# Seconds between discovery passes in WATCH mode.
POLL_SECONDS="${POLL_SECONDS:-60}"
# Auto-stop when no new checkpoints for this long (WATCH=1 only)
IDLE_LIMIT_SECONDS="${IDLE_LIMIT_SECONDS:-1200}" # 20 minutes default
# Optional: require a checkpoint directory to be "stable" (no mtime changes) before merging
# Set to 0 to disable. A small value (e.g. 30-120) helps avoid half-written checkpoints.
STABLE_SECONDS="${STABLE_SECONDS:-0}"
# Optional toggles (best-effort)
DROP_CACHES="${DROP_CACHES:-0}" # requires sudo; 0/1
GPU_RESET="${GPU_RESET:-0}" # 0/1
# Skip if output already exists and looks complete (has model files)
SKIP_DONE="${SKIP_DONE:-1}" # 0/1
# How long to wait before retrying a failed checkpoint (WATCH=1 only)
RETRY_COOLDOWN_SECONDS="${RETRY_COOLDOWN_SECONDS:-120}"
# -------------- CHECKS --------------
# Fail fast when the required directories are missing.
if [[ ! -d "$LR_DIR" ]]; then
  echo "ERROR: LR_DIR not found: $LR_DIR" >&2
  exit 1
fi
if [[ ! -d "$ROOT" ]]; then
  echo "ERROR: ROOT not found: $ROOT" >&2
  exit 1
fi
# Guardrails on the numeric knobs.
if (( POLL_SECONDS <= 0 )); then
  echo "ERROR: POLL_SECONDS must be > 0" >&2
  exit 1
fi
if (( IDLE_LIMIT_SECONDS < 0 )); then
  echo "ERROR: IDLE_LIMIT_SECONDS must be >= 0" >&2
  exit 1
fi
if (( STABLE_SECONDS < 0 )); then
  echo "ERROR: STABLE_SECONDS must be >= 0" >&2
  exit 1
fi
if (( RETRY_COOLDOWN_SECONDS < 0 )); then
  echo "ERROR: RETRY_COOLDOWN_SECONDS must be >= 0" >&2
  exit 1
fi
# Print the effective configuration so runs are reproducible from logs.
printf '%s\n' \
  "LR_DIR : $LR_DIR" \
  "ROOT : $ROOT" \
  "PYTHON_BIN : $PYTHON_BIN" \
  "SAVE_DTYPE : $SAVE_DTYPE" \
  "WATCH : $WATCH" \
  "POLL_SECONDS : $POLL_SECONDS" \
  "IDLE_LIMIT_SECONDS : $IDLE_LIMIT_SECONDS" \
  "STABLE_SECONDS : $STABLE_SECONDS" \
  "SKIP_DONE : $SKIP_DONE" \
  "RETRY_COOLDOWN_SECONDS : $RETRY_COOLDOWN_SECONDS" \
  ""
# -------------- Helpers --------------
# is_done OUT_DIR
# Returns 0 if OUT_DIR already contains a saved model, 1 otherwise.
is_done() {
  local out_dir="$1"
  # save_pretrained writes a single model.safetensors for small models, but
  # SHARDS anything above its 5GB default (model-*.safetensors plus
  # model.safetensors.index.json, and no plain model.safetensors). The
  # original single-file check never detected sharded saves as done, causing
  # endless re-merges under SKIP_DONE=1. Accept the index files too, plus the
  # legacy PyTorch .bin layouts.
  if [[ -f "$out_dir/model.safetensors" ]] \
    || [[ -f "$out_dir/model.safetensors.index.json" ]] \
    || [[ -f "$out_dir/pytorch_model.bin" ]] \
    || [[ -f "$out_dir/pytorch_model.bin.index.json" ]]; then
    return 0
  fi
  return 1
}
# cleanup_and_exit_watch REASON
# Print a stop banner with REASON and terminate the whole script (exit 0).
cleanup_and_exit_watch() {
  local reason="$1"
  local bar="============================================================"
  echo
  echo "$bar"
  echo "WATCH STOP: $reason"
  echo "$bar"
  # Best-effort pod self-termination is intentionally disabled. To re-enable
  # (WARNING: this TERMINATES the pod), uncomment:
  # if [[ -n "${RUNPOD_POD_ID:-}" ]]; then
  #   if command -v runpodctl >/dev/null 2>&1; then
  #     runpodctl remove pod "$RUNPOD_POD_ID" || true
  #   else
  #     echo "WARNING: runpodctl not found in PATH; cannot remove pod automatically."
  #   fi
  # else
  #   echo "RUNPOD_POD_ID not set; skipping pod removal."
  # fi
  echo "Exiting."
  exit 0
}
# cleanup_checkpoint_keep_residued CKPT_DIR
# After a successful merge, delete every top-level entry of CKPT_DIR except
# the residued/ output folder. Refuses to delete anything if either the
# checkpoint dir or residued/ is missing (always returns 0).
cleanup_checkpoint_keep_residued() {
  local target="$1"
  local protected="${target}/residued"
  if [[ ! -d "$target" ]]; then
    echo "WARN: ckpt_dir missing: $target"
    return 0
  fi
  if [[ ! -d "$protected" ]]; then
    echo "WARN: residued missing (won't delete): $protected"
    return 0
  fi
  echo "🧹 Cleaning checkpoint (keeping only residued/): $target"
  # Remove only direct children of the checkpoint dir, sparing 'residued'.
  find "$target" -mindepth 1 -maxdepth 1 ! -name "residued" -exec rm -rf {} +
}
# is_checkpoint_stable CKPT_DIR
# Returns 0 when the checkpoint looks "stable enough" to merge: either the
# feature is disabled (STABLE_SECONDS=0) or no file under CKPT_DIR has been
# modified within the last STABLE_SECONDS seconds. Returns 1 while files are
# still fresh, or when no files are found at all.
is_checkpoint_stable() {
  local dir="$1"
  local threshold="$STABLE_SECONDS"
  if (( threshold == 0 )); then
    return 0
  fi
  # Newest file mtime (epoch seconds) anywhere under the checkpoint tree.
  # Relies on GNU find's -printf; RunPod images ship coreutils/findutils,
  # but busybox variants would need a different stat approach.
  local newest
  newest="$(find "$dir" -type f -printf '%T@\n' 2>/dev/null | sort -n | tail -1 | cut -d. -f1 || true)"
  if [[ -z "${newest:-}" ]]; then
    return 1
  fi
  local elapsed=$(( $(date +%s) - newest ))
  (( elapsed >= threshold ))
}
# run_merge_for_checkpoint CKPT_DIR
# Merge the LR adapter into one checkpoint: output goes to CKPT_DIR/residued,
# all merge output is appended to CKPT_DIR/residued/merge.log. On success the
# checkpoint dir is pruned down to residued/. ALWAYS returns 0 — a failed
# merge is logged and left for a later retry, so the WATCH loop (and the
# `set -e` shell) is never killed by an incomplete checkpoint.
run_merge_for_checkpoint() {
  local ckpt_dir="$1"
  local out_dir="${ckpt_dir}/residued"
  local log_file="${out_dir}/merge.log"
  mkdir -p "$out_dir"
  # Fast path: output already looks complete and skipping is enabled.
  if [[ "$SKIP_DONE" == "1" ]] && is_done "$out_dir"; then
    echo "SKIP (already merged): $ckpt_dir"
    echo " -> $out_dir"
    echo " -> $log_file"
    return 0
  fi
  # Defer half-written checkpoints (no-op unless STABLE_SECONDS > 0).
  if ! is_checkpoint_stable "$ckpt_dir"; then
    echo "HOLD (checkpoint not stable yet): $ckpt_dir (STABLE_SECONDS=$STABLE_SECONDS)"
    return 0
  fi
  echo "============================================================"
  echo "Checkpoint : $ckpt_dir"
  echo "Output : $out_dir"
  echo "Log : $log_file"
  echo "============================================================"
  # IMPORTANT:
  # We must not let a failed merge kill the whole WATCH loop.
  # So we run the merge, capture failure, log it, and return 0 (retry later).
  # The inner { } group funnels everything (stdout+stderr) into the log file;
  # the outer `if ! { ... }` captures the group's overall exit status.
  if ! {
    {
      echo "[$(date -Is)] START merge"
      echo "Base model : $ckpt_dir"
      echo "LR adapter : $LR_DIR"
      echo "Output dir : $out_dir"
      echo "SAVE_DTYPE : $SAVE_DTYPE"
      echo
      # Fresh Python process per checkpoint => frees RAM on exit.
      # Inputs are passed via environment; the heredoc is quoted ('PY') so the
      # shell performs no expansion inside the Python source.
      LR_DIR="$LR_DIR" BASE_DIR="$ckpt_dir" OUT_DIR="$out_dir" SAVE_DTYPE="$SAVE_DTYPE" \
        "$PYTHON_BIN" - <<'PY'
import os
import shutil
import gc
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

# Parameters handed over from the shell wrapper via the environment.
LR_DIR = os.environ["LR_DIR"]
BASE_DIR = os.environ["BASE_DIR"]
OUT_DIR = os.environ["OUT_DIR"]
SAVE_DTYPE = os.environ.get("SAVE_DTYPE", "bf16").lower().strip()

def _load_model_fp32(model_dir: str):
    # Load on CPU in fp32 so the merge arithmetic is done at full precision.
    # Newer transformers accept `dtype=`; older releases raise TypeError and
    # want `torch_dtype=` instead, hence the fallback.
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
    except TypeError:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )

def _to_save_dtype(model, save_dtype: str):
    # Cast the merged model to the requested on-disk dtype.
    if save_dtype == "bf16":
        return model.to(torch.bfloat16)
    if save_dtype == "fp16":
        return model.to(torch.float16)
    if save_dtype == "fp32":
        return model.to(torch.float32)
    raise ValueError(f"Unknown SAVE_DTYPE={save_dtype}. Use bf16|fp16|fp32")

def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    # Add the residual state dict (adapter_model.bin) onto the base model's
    # weights key-by-key, then save model + config + tokenizer to output_dir.
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")
    print("Loading residual adapter...")
    residual_state_dict = torch.load(adapter_file, map_location="cpu")
    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)
    base_state_dict = base_model.state_dict()
    merged_state_dict = {}
    mismatched = []
    for key, base_tensor in base_state_dict.items():
        # Keys absent from the residual pass through unchanged.
        if key not in residual_state_dict:
            merged_state_dict[key] = base_tensor
            continue
        res_tensor = residual_state_dict[key]
        # Exact match
        if base_tensor.shape == res_tensor.shape:
            merged_state_dict[key] = (base_tensor + res_tensor).to(torch.float32)
            continue
        # Vocab resized: dim0 differs, rest matches. Add the residual to the
        # overlapping leading rows only; extra rows keep their base values.
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged_state_dict[key] = out
            mismatched.append((key, tuple(base_tensor.shape), tuple(res_tensor.shape), n))
            continue
        # Anything else is a genuine shape conflict — abort this merge.
        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )
    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f" - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f" ... and {len(mismatched) - 20} more")
    base_model.load_state_dict(merged_state_dict, strict=True)
    base_model = _to_save_dtype(base_model, SAVE_DTYPE)
    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)
    base_config = AutoConfig.from_pretrained(base_model_dir)
    base_config.save_pretrained(output_dir)
    # Tokenizer: prefer a proper load/save round-trip; fall back to copying
    # the raw tokenizer files if loading fails.
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        for file_name in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)
    print(f"\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved {SAVE_DTYPE} at: {output_dir}")

merge_instruction_residual(LR_DIR, BASE_DIR, OUT_DIR)
gc.collect()
PY
      echo
      echo "[$(date -Is)] DONE merge"
    } >>"$log_file" 2>&1
  }; then
    echo "⚠️ Merge failed (likely incomplete checkpoint): $ckpt_dir"
    echo " -> See log: $log_file"
    echo " -> Will retry later"
    return 0
  fi
  # Delete everything except residued/ after a SUCCESSFUL merge
  cleanup_checkpoint_keep_residued "$ckpt_dir"
  echo "✅ Finished $ckpt_dir (log: $log_file)"
  echo
  # Optional cleanup
  if [[ "$DROP_CACHES" == "1" ]]; then
    echo "Dropping Linux page cache (best-effort; requires sudo)..."
    sync || true
    sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches' || true
  fi
  if [[ "$GPU_RESET" == "1" ]]; then
    echo "Attempting GPU reset (best-effort)..."
    nvidia-smi --gpu-reset -i 0 >/dev/null 2>&1 || true
  fi
  sleep 1
}
# discover_checkpoints_sorted
# Emit every checkpoint-* directory directly under $ROOT, one path per line,
# in natural version order (…-900 sorts before …-1000).
discover_checkpoints_sorted() {
  local pattern='checkpoint-*'
  find "$ROOT" -maxdepth 1 -type d -name "$pattern" | sort -V
}
# -------------- MAIN --------------
# One-shot mode: snapshot the current checkpoint list, merge each in order,
# then exit. (WATCH=1 falls through to the polling loop below.)
if [[ "$WATCH" == "0" ]]; then
  mapfile -t checkpoint_list < <(discover_checkpoints_sorted)
  if (( ${#checkpoint_list[@]} == 0 )); then
    echo "No checkpoint-* directories found under: $ROOT" >&2
    exit 1
  fi
  echo "Found ${#checkpoint_list[@]} checkpoints:"
  printf ' - %s\n' "${checkpoint_list[@]}"
  echo
  for checkpoint_dir in "${checkpoint_list[@]}"; do
    run_merge_for_checkpoint "$checkpoint_dir"
  done
  echo "All merges complete."
  exit 0
fi
# WATCH=1 mode: keep discovering new checkpoints
# State:
#   SEEN         - set of checkpoint paths picked up at least once.
#   LAST_FAIL_TS - per-checkpoint epoch seconds of the last failed merge
#                  attempt; enforces RETRY_COOLDOWN_SECONDS between retries.
#   last_new_ts  - epoch seconds when a checkpoint was last first-seen;
#                  drives the IDLE_LIMIT_SECONDS auto-stop.
declare -A SEEN=()
declare -A LAST_FAIL_TS=()
echo "WATCH mode enabled. Polling every ${POLL_SECONDS}s for new checkpoint-* directories..."
echo "Auto-stop if idle for ${IDLE_LIMIT_SECONDS}s (no new checkpoints)."
echo
last_new_ts="$(date +%s)"
while true; do
  found_new=0
  while IFS= read -r ckpt; do
    [[ -z "$ckpt" ]] && continue
    # If we've seen it before and it failed, allow retries with cooldown
    if [[ -n "${SEEN[$ckpt]+x}" ]]; then
      # If it is already merged (residued has model), we can ignore forever.
      # NOTE(review): with SKIP_DONE=0 this guard is bypassed, so already-seen
      # checkpoints are re-merged on every poll — presumably intentional,
      # but worth confirming.
      if [[ "$SKIP_DONE" == "1" ]] && is_done "${ckpt}/residued"; then
        continue
      fi
      # Cooldown logic for retries
      if [[ -n "${LAST_FAIL_TS[$ckpt]+x}" ]] && (( RETRY_COOLDOWN_SECONDS > 0 )); then
        now_ts="$(date +%s)"
        since_fail=$(( now_ts - LAST_FAIL_TS[$ckpt] ))
        if (( since_fail < RETRY_COOLDOWN_SECONDS )); then
          continue
        fi
      fi
      # Retry eligible
      :
    else
      # First sighting: remember it and reset the idle clock.
      SEEN[$ckpt]=1
      found_new=1
      last_new_ts="$(date +%s)"
    fi
    # Try merging; if it fails, record fail time (for cooldown)
    # NOTE(review): before_done is computed but never read afterwards —
    # looks like leftover scaffolding; candidate for removal.
    before_done=0
    if [[ "$SKIP_DONE" == "1" ]] && is_done "${ckpt}/residued"; then
      before_done=1
    fi
    run_merge_for_checkpoint "$ckpt"
    # If still not done after attempting, mark as failed attempt time
    # (run_merge_for_checkpoint always returns 0, so "done" is judged by
    # whether the output directory now holds model files).
    if [[ "$SKIP_DONE" == "1" ]] && ! is_done "${ckpt}/residued"; then
      LAST_FAIL_TS[$ckpt]="$(date +%s)"
    else
      # Success clears failure timestamp
      unset 'LAST_FAIL_TS[$ckpt]' || true
    fi
  done < <(discover_checkpoints_sorted)
  if [[ "$found_new" -eq 0 ]]; then
    now_ts="$(date +%s)"
    idle_for=$(( now_ts - last_new_ts ))
    echo "[$(date -Is)] No new checkpoints found. Idle for ${idle_for}s. Sleeping ${POLL_SECONDS}s..."
    # Auto-stop once idle long enough (cleanup_and_exit_watch exits 0).
    if (( IDLE_LIMIT_SECONDS > 0 )) && (( idle_for >= IDLE_LIMIT_SECONDS )); then
      cleanup_and_exit_watch "No new checkpoints for ${idle_for}s (>= ${IDLE_LIMIT_SECONDS}s)."
    fi
    sleep "$POLL_SECONDS"
  fi
done