| #!/usr/bin/env bash |
| set -euo pipefail |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
# ---------------------------------------------------------------------------
# Configuration — every knob is overridable via the environment.
# ---------------------------------------------------------------------------

# Directory holding the instruction-residual adapter (must contain adapter_model.bin).
LR_DIR="${LR_DIR:-/workspace/Llama-3.2-3B-Lr/instruction_residual_adapter}"
# Root directory scanned for checkpoint-* subdirectories.
ROOT="${ROOT:-/workspace/v126rc_exp3/F_r10000}"
PYTHON_BIN="${PYTHON_BIN:-python}"

# dtype of the saved merged weights: bf16 | fp16 | fp32 (math is always fp32).
SAVE_DTYPE="${SAVE_DTYPE:-bf16}"

# WATCH=1: keep polling ROOT for new checkpoints instead of a single pass.
WATCH="${WATCH:-0}"
POLL_SECONDS="${POLL_SECONDS:-60}"

# Stop watch mode after this many idle seconds with no new checkpoints (0 = never stop).
IDLE_LIMIT_SECONDS="${IDLE_LIMIT_SECONDS:-1200}"

# Require the newest file in a checkpoint to be at least this old before merging
# (0 = merge immediately; non-zero guards against half-written checkpoints).
STABLE_SECONDS="${STABLE_SECONDS:-0}"

# Best-effort host/GPU hygiene after each successful merge.
DROP_CACHES="${DROP_CACHES:-0}"
GPU_RESET="${GPU_RESET:-0}"

# Skip checkpoints whose residued/ already contains merged weights.
SKIP_DONE="${SKIP_DONE:-1}"

# Minimum wait before re-attempting a previously failed merge (watch mode only).
RETRY_COOLDOWN_SECONDS="${RETRY_COOLDOWN_SECONDS:-120}"

[[ -d "$LR_DIR" ]] || { echo "ERROR: LR_DIR not found: $LR_DIR" >&2; exit 1; }
[[ -d "$ROOT" ]] || { echo "ERROR: ROOT not found: $ROOT" >&2; exit 1; }

# Validate the integer knobs up front: a non-numeric override would otherwise
# die later inside (( … )) with a cryptic bash arithmetic error.
for _numeric_var in POLL_SECONDS IDLE_LIMIT_SECONDS STABLE_SECONDS RETRY_COOLDOWN_SECONDS; do
  [[ "${!_numeric_var}" =~ ^[0-9]+$ ]] \
    || { echo "ERROR: $_numeric_var must be a non-negative integer (got: ${!_numeric_var})" >&2; exit 1; }
done
unset _numeric_var

(( POLL_SECONDS > 0 )) || { echo "ERROR: POLL_SECONDS must be > 0" >&2; exit 1; }
(( IDLE_LIMIT_SECONDS >= 0 )) || { echo "ERROR: IDLE_LIMIT_SECONDS must be >= 0" >&2; exit 1; }
(( STABLE_SECONDS >= 0 )) || { echo "ERROR: STABLE_SECONDS must be >= 0" >&2; exit 1; }
(( RETRY_COOLDOWN_SECONDS >= 0 )) || { echo "ERROR: RETRY_COOLDOWN_SECONDS must be >= 0" >&2; exit 1; }

# Echo the effective configuration (now includes DROP_CACHES / GPU_RESET,
# which were previously the only knobs missing from this summary).
echo "LR_DIR : $LR_DIR"
echo "ROOT : $ROOT"
echo "PYTHON_BIN : $PYTHON_BIN"
echo "SAVE_DTYPE : $SAVE_DTYPE"
echo "WATCH : $WATCH"
echo "POLL_SECONDS : $POLL_SECONDS"
echo "IDLE_LIMIT_SECONDS : $IDLE_LIMIT_SECONDS"
echo "STABLE_SECONDS : $STABLE_SECONDS"
echo "DROP_CACHES : $DROP_CACHES"
echo "GPU_RESET : $GPU_RESET"
echo "SKIP_DONE : $SKIP_DONE"
echo "RETRY_COOLDOWN_SECONDS : $RETRY_COOLDOWN_SECONDS"
echo
|
|
| |
is_done() {
  # A checkpoint's merge output is "done" once merged weights exist in the
  # given directory, in either safetensors or legacy torch format.
  local merged_dir="$1"
  local weight_file
  for weight_file in model.safetensors pytorch_model.bin; do
    if [[ -f "$merged_dir/$weight_file" ]]; then
      return 0
    fi
  done
  return 1
}
|
|
cleanup_and_exit_watch() {
  # Print a stop banner explaining why watch mode is ending, then exit 0.
  local stop_reason="$1"
  local banner='============================================================'

  printf '\n%s\n' "$banner"
  printf 'WATCH STOP: %s\n' "$stop_reason"
  printf '%s\n' "$banner"

  printf 'Exiting.\n'
  exit 0
}
|
|
cleanup_checkpoint_keep_residued() {
  # Delete every top-level entry inside a checkpoint directory EXCEPT the
  # residued/ output, reclaiming the space used by the raw checkpoint.
  # Refuses to delete anything when residued/ is absent (merge not finished).
  local ckpt_dir="$1"
  local keep_dir="${ckpt_dir}/residued"

  if [[ ! -d "$ckpt_dir" ]]; then
    echo "WARN: ckpt_dir missing: $ckpt_dir"
    return 0
  fi
  if [[ ! -d "$keep_dir" ]]; then
    echo "WARN: residued missing (won't delete): $keep_dir"
    return 0
  fi

  echo "🧹 Cleaning checkpoint (keeping only residued/): $ckpt_dir"
  find "$ckpt_dir" -mindepth 1 -maxdepth 1 -not -name residued -exec rm -rf {} +
}
|
|
| |
| |
is_checkpoint_stable() {
  # A checkpoint counts as stable once its newest file is at least
  # STABLE_SECONDS old. STABLE_SECONDS=0 disables the check entirely.
  local dir="$1"
  local required_age="$STABLE_SECONDS"

  if (( required_age == 0 )); then
    return 0
  fi

  # Newest mtime of any file under the checkpoint, as integer epoch seconds
  # (GNU find %T@ prints fractional seconds; cut drops the fraction).
  # Empty when the directory contains no files yet -> not stable.
  local newest_mtime
  newest_mtime="$(find "$dir" -type f -printf '%T@\n' 2>/dev/null | sort -n | tail -n 1 | cut -d. -f1 || true)"
  if [[ -z "${newest_mtime:-}" ]]; then
    return 1
  fi

  local now
  now="$(date +%s)"
  (( now - newest_mtime >= required_age ))
}
|
|
run_merge_for_checkpoint() {
  # Merge the instruction-residual adapter (LR_DIR) into one checkpoint.
  # Output goes to <ckpt_dir>/residued; all merge output is appended to
  # residued/merge.log. On success the raw checkpoint is deleted (only
  # residued/ is kept). On failure the function returns 0 so a caller loop
  # keeps going and can retry later.
  local ckpt_dir="$1"
  local out_dir="${ckpt_dir}/residued"
  local log_file="${out_dir}/merge.log"

  # Create the output dir up front so the log file always has a home.
  mkdir -p "$out_dir"

  # Fast path: merged weights already present.
  if [[ "$SKIP_DONE" == "1" ]] && is_done "$out_dir"; then
    echo "SKIP (already merged): $ckpt_dir"
    echo " -> $out_dir"
    echo " -> $log_file"
    return 0
  fi

  # Hold off while the checkpoint may still be mid-write (see STABLE_SECONDS).
  if ! is_checkpoint_stable "$ckpt_dir"; then
    echo "HOLD (checkpoint not stable yet): $ckpt_dir (STABLE_SECONDS=$STABLE_SECONDS)"
    return 0
  fi

  echo "============================================================"
  echo "Checkpoint : $ckpt_dir"
  echo "Output : $out_dir"
  echo "Log : $log_file"
  echo "============================================================"

  # The inner { … } group funnels everything (header + python output) into the
  # log file; the outer `if ! { … }` turns any failure (most likely the python
  # merge on an incomplete checkpoint) into a soft "retry later" instead of
  # aborting the whole script under `set -e`.
  if ! {
    {
      echo "[$(date -Is)] START merge"
      echo "Base model : $ckpt_dir"
      echo "LR adapter : $LR_DIR"
      echo "Output dir : $out_dir"
      echo "SAVE_DTYPE : $SAVE_DTYPE"
      echo

      # Pass all paths via environment; the heredoc is single-quoted ('PY')
      # so the shell expands nothing inside the python source.
      LR_DIR="$LR_DIR" BASE_DIR="$ckpt_dir" OUT_DIR="$out_dir" SAVE_DTYPE="$SAVE_DTYPE" \
        "$PYTHON_BIN" - <<'PY'
import os
import shutil
import gc
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

LR_DIR = os.environ["LR_DIR"]
BASE_DIR = os.environ["BASE_DIR"]
OUT_DIR = os.environ["OUT_DIR"]
SAVE_DTYPE = os.environ.get("SAVE_DTYPE", "bf16").lower().strip()


def _load_model_fp32(model_dir: str):
    # Load the base model on CPU in fp32 so the residual add is done in full
    # precision. Newer transformers use `dtype=`; fall back to the older
    # `torch_dtype=` keyword on TypeError.
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
    except TypeError:
        return AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )


def _to_save_dtype(model, save_dtype: str):
    # Cast the merged model to the requested on-disk dtype.
    if save_dtype == "bf16":
        return model.to(torch.bfloat16)
    if save_dtype == "fp16":
        return model.to(torch.float16)
    if save_dtype == "fp32":
        return model.to(torch.float32)
    raise ValueError(f"Unknown SAVE_DTYPE={save_dtype}. Use bf16|fp16|fp32")


def merge_instruction_residual(lr_dir, base_model_dir, output_dir):
    # Add the residual state dict onto the base model's weights (fp32 math),
    # then save model + config + tokenizer to output_dir.
    adapter_file = os.path.join(lr_dir, "adapter_model.bin")
    if not os.path.exists(adapter_file):
        raise FileNotFoundError(f"Adapter checkpoint not found at {adapter_file}")

    print("Loading residual adapter...")
    residual_state_dict = torch.load(adapter_file, map_location="cpu")

    print(f"\nMerging residual into base model: {base_model_dir}")
    base_model = _load_model_fp32(base_model_dir)
    base_state_dict = base_model.state_dict()

    merged_state_dict = {}
    mismatched = []

    # Iterate over BASE keys only: residual keys absent from the base model
    # are silently ignored; base keys absent from the residual pass through.
    for key, base_tensor in base_state_dict.items():
        if key not in residual_state_dict:
            merged_state_dict[key] = base_tensor
            continue

        res_tensor = residual_state_dict[key]

        # Exact shape match: plain element-wise add.
        if base_tensor.shape == res_tensor.shape:
            merged_state_dict[key] = (base_tensor + res_tensor).to(torch.float32)
            continue

        # Dim-0-only mismatch (e.g. a resized embedding/vocab row count):
        # add the overlapping leading rows, keep the remaining rows unchanged.
        if (
            base_tensor.ndim == res_tensor.ndim
            and base_tensor.ndim >= 1
            and base_tensor.shape[1:] == res_tensor.shape[1:]
            and base_tensor.shape[0] != res_tensor.shape[0]
        ):
            n = min(base_tensor.shape[0], res_tensor.shape[0])
            out = base_tensor.clone().to(torch.float32)
            out[:n] += res_tensor[:n].to(torch.float32)
            merged_state_dict[key] = out
            mismatched.append((key, tuple(base_tensor.shape), tuple(res_tensor.shape), n))
            continue

        # Any other shape disagreement is a hard error.
        raise RuntimeError(
            f"Shape mismatch for key '{key}': base={tuple(base_tensor.shape)} "
            f"residual={tuple(res_tensor.shape)}. Not a simple vocab-resize mismatch."
        )

    if mismatched:
        print("\nHandled vocab-resize mismatches by partial add:")
        for k, bs, rs, n in mismatched[:20]:
            print(f" - {k}: base{bs} vs res{rs} → added first {n} rows, kept the rest unchanged")
        if len(mismatched) > 20:
            print(f" ... and {len(mismatched) - 20} more")

    base_model.load_state_dict(merged_state_dict, strict=True)

    base_model = _to_save_dtype(base_model, SAVE_DTYPE)
    os.makedirs(output_dir, exist_ok=True)
    base_model.save_pretrained(output_dir, safe_serialization=True)

    base_config = AutoConfig.from_pretrained(base_model_dir)
    base_config.save_pretrained(output_dir)

    # Best-effort tokenizer export: prefer AutoTokenizer, fall back to copying
    # the raw tokenizer files if it cannot be instantiated.
    try:
        tok = AutoTokenizer.from_pretrained(base_model_dir, trust_remote_code=True)
        tok.save_pretrained(output_dir)
    except Exception:
        for file_name in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
            src_path = os.path.join(base_model_dir, file_name)
            dst_path = os.path.join(output_dir, file_name)
            if os.path.exists(src_path):
                shutil.copyfile(src_path, dst_path)

    print(f"\n✅ Merge complete.")
    print(f"🧠 fp32 math → saved {SAVE_DTYPE} at: {output_dir}")


merge_instruction_residual(LR_DIR, BASE_DIR, OUT_DIR)
gc.collect()
PY

      echo
      echo "[$(date -Is)] DONE merge"
    } >>"$log_file" 2>&1
  }; then
    # Soft failure: leave the (partial) output and log in place for a retry.
    echo "⚠️ Merge failed (likely incomplete checkpoint): $ckpt_dir"
    echo " -> See log: $log_file"
    echo " -> Will retry later"
    return 0
  fi

  # Success: reclaim disk by deleting everything except residued/.
  cleanup_checkpoint_keep_residued "$ckpt_dir"

  echo "✅ Finished $ckpt_dir (log: $log_file)"
  echo

  # Optional post-merge host hygiene (best-effort; failures are ignored).
  if [[ "$DROP_CACHES" == "1" ]]; then
    echo "Dropping Linux page cache (best-effort; requires sudo)..."
    sync || true
    sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches' || true
  fi

  # NOTE(review): `nvidia-smi --gpu-reset` is a legacy flag — confirm it is
  # still accepted by the driver in use; errors are swallowed here anyway.
  if [[ "$GPU_RESET" == "1" ]]; then
    echo "Attempting GPU reset (best-effort)..."
    nvidia-smi --gpu-reset -i 0 >/dev/null 2>&1 || true
  fi

  sleep 1
}
|
|
discover_checkpoints_sorted() {
  # Emit every immediate checkpoint-* subdirectory of ROOT, one per line,
  # version-sorted so checkpoint-1000 orders after checkpoint-999.
  local root_dir="$ROOT"
  find "$root_dir" -maxdepth 1 -type d -name 'checkpoint-*' | sort -V
}
|
|
| |
# One-shot mode (WATCH=0): merge every checkpoint currently on disk, then exit.
if [[ "$WATCH" == "0" ]]; then
  mapfile -t CKPTS < <(discover_checkpoints_sorted)

  if (( ${#CKPTS[@]} == 0 )); then
    echo "No checkpoint-* directories found under: $ROOT" >&2
    exit 1
  fi

  echo "Found ${#CKPTS[@]} checkpoints:"
  printf ' - %s\n' "${CKPTS[@]}"
  echo

  for ckpt in "${CKPTS[@]}"; do
    run_merge_for_checkpoint "$ckpt"
  done

  echo "All merges complete."
  exit 0
fi
|
|
| |
# ---------------------------------------------------------------------------
# Watch mode (WATCH=1): poll ROOT forever, merging checkpoints as they appear.
# SEEN tracks checkpoints already discovered; LAST_FAIL_TS tracks when a
# checkpoint's merge last failed, to rate-limit retries.
# ---------------------------------------------------------------------------
declare -A SEEN=()
declare -A LAST_FAIL_TS=()

echo "WATCH mode enabled. Polling every ${POLL_SECONDS}s for new checkpoint-* directories..."
echo "Auto-stop if idle for ${IDLE_LIMIT_SECONDS}s (no new checkpoints)."
echo

# Timestamp of the most recent NEW checkpoint discovery; drives the idle timer.
last_new_ts="$(date +%s)"

while true; do
  found_new=0

  while IFS= read -r ckpt; do
    [[ -z "$ckpt" ]] && continue

    if [[ -n "${SEEN[$ckpt]+x}" ]]; then
      # Already-known checkpoint: skip if it is fully merged...
      if [[ "$SKIP_DONE" == "1" ]] && is_done "${ckpt}/residued"; then
        continue
      fi

      # ...or if its last merge attempt failed too recently (cooldown).
      if [[ -n "${LAST_FAIL_TS[$ckpt]+x}" ]] && (( RETRY_COOLDOWN_SECONDS > 0 )); then
        now_ts="$(date +%s)"
        since_fail=$(( now_ts - LAST_FAIL_TS[$ckpt] ))
        if (( since_fail < RETRY_COOLDOWN_SECONDS )); then
          continue
        fi
      fi

      # Known but unfinished and past cooldown: fall through to retry.
      :
    else
      # Brand-new checkpoint: record it and reset the idle timer.
      SEEN[$ckpt]=1
      found_new=1
      last_new_ts="$(date +%s)"
    fi

    # NOTE(review): before_done is computed but never read afterwards —
    # looks like leftover bookkeeping; candidate for removal.
    before_done=0
    if [[ "$SKIP_DONE" == "1" ]] && is_done "${ckpt}/residued"; then
      before_done=1
    fi

    run_merge_for_checkpoint "$ckpt"

    # Record failure time when the merge did not produce output (enables the
    # retry cooldown above); clear it once the checkpoint is done.
    if [[ "$SKIP_DONE" == "1" ]] && ! is_done "${ckpt}/residued"; then
      LAST_FAIL_TS[$ckpt]="$(date +%s)"
    else
      unset 'LAST_FAIL_TS[$ckpt]' || true
    fi
  done < <(discover_checkpoints_sorted)

  if [[ "$found_new" -eq 0 ]]; then
    now_ts="$(date +%s)"
    idle_for=$(( now_ts - last_new_ts ))

    echo "[$(date -Is)] No new checkpoints found. Idle for ${idle_for}s. Sleeping ${POLL_SECONDS}s..."
    # IDLE_LIMIT_SECONDS=0 means "watch forever".
    if (( IDLE_LIMIT_SECONDS > 0 )) && (( idle_for >= IDLE_LIMIT_SECONDS )); then
      cleanup_and_exit_watch "No new checkpoints for ${idle_for}s (>= ${IDLE_LIMIT_SECONDS}s)."
    fi

    sleep "$POLL_SECONDS"
  fi
done
|
|