llm-sort / run_hijack_layer1_all.py
gatmiry's picture
Upload folder using huggingface_hub
c7f1373 verified
#!/usr/bin/env python3
"""
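Run the layer-1 hijack intervention on ALL checkpoints (*.pt) in 1000k-checkpoints/
(the directory containing this script), distributed across the available GPUs.
Checkpoints whose heatmap already exists are skipped; new heatmaps are saved in
the matching subfolder of outputs/.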
"""
import os
import sys
import glob
import json
import subprocess
import time
# Directory holding this script; checkpoint files (*.pt) are looked up here.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Root directory under which each checkpoint gets its own plot folder.
OUTPUT_BASE = os.path.join(SCRIPT_DIR, "outputs")
# Per-checkpoint worker, launched once per checkpoint as a subprocess.
WORKER_SCRIPT = os.path.join(SCRIPT_DIR, "hijack_layer1_worker.py")
def get_iter_from_filename(basename, final_iter=1000000):
    """Map a checkpoint file name to the training iteration it represents.

    A trailing ``.pt`` is removed, the rest is split on ``__`` and the second
    field is taken as the checkpoint tag:

    * ``ckpt<N>`` -> ``N`` (e.g. ``run__ckpt250000.pt`` -> 250000)
    * ``final``, or a name with no ``__`` field at all -> ``final_iter``
    * anything else, including ``ckpt`` followed by a non-integer -> ``None``

    Args:
        basename: File name without its directory.
        final_iter: Iteration reported for the final checkpoint.

    Returns:
        The iteration as an int, or ``None`` if the name is not recognised.
    """
    # Strip only the extension; str.replace would also remove '.pt' found
    # elsewhere in the name.
    stem = basename[:-len('.pt')] if basename.endswith('.pt') else basename
    parts = stem.split('__')
    ckpt_type = parts[1] if len(parts) > 1 else 'final'
    if ckpt_type == 'final':
        return final_iter
    if ckpt_type.startswith('ckpt'):
        try:
            return int(ckpt_type[len('ckpt'):])
        except ValueError:
            # e.g. "ckptbest": not a numbered checkpoint. Returning None lets
            # the caller skip it instead of aborting the whole run.
            return None
    return None
def make_folder_name(itr):
    """Return the name of the plot folder for iteration ``itr``.

    All other tags in the name are fixed; ``itr`` is substituted into both
    the ``MI`` and the ``ckpt`` fields.
    """
    template = "plots_V256_B16_LR3e-2_MI{n}_E64_H1_L2_ds1337_is1337_ckpt{n}"
    return template.format(n=itr)
def _visible_gpu_ids(num_gpus):
    """Return the CUDA_VISIBLE_DEVICES value to give each of ``num_gpus`` workers.

    If this process is itself restricted by CUDA_VISIBLE_DEVICES (e.g. "2,3"),
    its GPUs are renumbered 0..n-1. Pinning a worker to str(i) would then
    select the wrong physical GPU, so the inherited entries are reused. When
    there are too few inherited entries, "0".."num_gpus-1" is used.
    """
    inherited = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    ids = [d.strip() for d in inherited.split(',') if d.strip()]
    if len(ids) >= num_gpus:
        return ids[:num_gpus]
    return [str(i) for i in range(num_gpus)]


def _collect_tasks():
    """Return (pt_path, out_dir, itr) for each checkpoint next to this script
    whose layer-1 heatmap does not exist yet. Unrecognised names are skipped."""
    tasks = []
    for pt in sorted(glob.glob(os.path.join(SCRIPT_DIR, '*.pt'))):
        bn = os.path.basename(pt)
        itr = get_iter_from_filename(bn)
        if itr is None:
            continue
        out_dir = os.path.join(OUTPUT_BASE, make_folder_name(itr))
        check_file = os.path.join(out_dir, 'hijack_breaking_rate_heatmap_layer1.png')
        if os.path.exists(check_file):
            print(f"Already exists, skipping: {bn}")
            continue
        tasks.append((pt, out_dir, itr))
    return tasks


def _run_task(pt_path, out_dir, device, print_lock):
    """Run the worker on one checkpoint with only ``device`` visible.

    Prints a launch line, then either the last three non-empty stdout lines
    plus "Done", or the tail of stderr on failure. Returns True on success.
    """
    bn = os.path.basename(pt_path)
    env = os.environ.copy()
    env['CUDA_VISIBLE_DEVICES'] = device
    cmd = [sys.executable, WORKER_SCRIPT, pt_path, '--output-dir', out_dir]
    with print_lock:
        print(f" [GPU {device}] {bn} -> {os.path.basename(out_dir)}")
    # subprocess.run drains stdout and stderr together and each worker has its
    # own thread. A chatty worker therefore cannot stall on a full pipe while
    # another worker is being waited on.
    result = subprocess.run(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        report = f" FAILED: {bn}\n{result.stderr.decode(errors='replace')[-500:]}"
    else:
        out_text = result.stdout.decode(errors='replace').strip()
        last_lines = [line for line in out_text.split('\n') if line.strip()][-3:]
        report = '\n'.join([f" {line}" for line in last_lines] + [f" Done: {bn}"])
    with print_lock:
        print(report)
    return result.returncode == 0


def main():
    """Run the layer-1 hijack worker on every checkpoint whose heatmap is missing.

    Checkpoints are dealt round-robin onto the GPUs (task k -> GPU lane
    k % num_gpus). Each GPU lane works through its share sequentially and
    independently of the others, so a slow checkpoint only delays its own
    lane. Failed checkpoints are listed at the end.
    """
    import threading

    import torch

    num_gpus = torch.cuda.device_count() or 1
    devices = _visible_gpu_ids(num_gpus)
    tasks = _collect_tasks()
    if not tasks:
        print("All layer-1 hijack heatmaps already generated.")
        return
    print(f"Running layer-1 hijack for {len(tasks)} checkpoints on {num_gpus} GPUs ...\n")

    print_lock = threading.Lock()  # keeps multi-line reports from interleaving
    failed = []

    def run_lane(lane):
        for pt_path, out_dir, _itr in tasks[lane::num_gpus]:
            if not _run_task(pt_path, out_dir, devices[lane], print_lock):
                failed.append(os.path.basename(pt_path))

    # daemon=True: an interrupt in the main thread exits the script instead of
    # waiting on lanes that would otherwise keep launching new checkpoints.
    lanes = [
        threading.Thread(target=run_lane, args=(lane,), daemon=True)
        for lane in range(min(num_gpus, len(tasks)))
    ]
    for lane_thread in lanes:
        lane_thread.start()
    for lane_thread in lanes:
        lane_thread.join()

    print()
    if failed:
        print(f"{len(failed)} checkpoint(s) failed: {', '.join(sorted(failed))}")
    print("All done.")
if __name__ == "__main__":
    # Only launch the batch when run as a script, not when imported.
    main()