""" GPU worker for per-location intervention experiments. Loads model once per checkpoint, runs all location/layer tasks, saves results. """ import argparse import json import os import sys import time import numpy as np import torch sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run')) from model_analysis import GPT, GPTConfig, GPTIntervention INTENSITY_VALUES = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0] UB = 60 LB = 60 MIN_VALID = 200 def remap_state_dict(sd): new = {} for k, v in sd.items(): nk = k for i in range(10): nk = nk.replace(f'transformer.h.{i}.attn.', f'transformer.h.{i}.c_attn.') nk = nk.replace(f'transformer.h.{i}.mlp.', f'transformer.h.{i}.c_fc.') new[nk] = v return new def load_model(ckpt_path, device): ckpt = torch.load(ckpt_path, map_location='cpu') mc = ckpt['model_config'] config = GPTConfig(block_size=mc['block_size'], vocab_size=mc['vocab_size'] - 1, with_layer_norm=mc.get('use_final_LN', True)) model = GPT(config) sd = remap_state_dict(ckpt['model_state_dict']) wpe_max = config.block_size * 4 + 1 if 'transformer.wpe.weight' in sd and sd['transformer.wpe.weight'].shape[0] > wpe_max: sd['transformer.wpe.weight'] = sd['transformer.wpe.weight'][:wpe_max] for k in [k for k in sd if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]: del sd[k] if 'lm_head.weight' in sd: del sd['lm_head.weight'] model.load_state_dict(sd, strict=False) model.to(device) model.eval() return model, config def get_batch(vs, bs, device): x = torch.randperm(vs)[:bs] vals, _ = torch.sort(x) return torch.cat((x, torch.tensor([vs]), vals), dim=0).unsqueeze(0).to(device) def compute_intensity_at_location(model, config, device, attn_layer, sorted_pos): """Run intervention at a specific sorted-output position.""" bs = config.block_size vs = config.vocab_size location = bs + sorted_pos rates, counts = [], [] for intens in INTENSITY_VALUES: attempts, rounds = [], 0 while len(attempts) < MIN_VALID and rounds < 2000: rounds += 1 idx = get_batch(vs, bs, device) try: im = GPTIntervention(model, idx) im.intervent_attention( attention_layer_num=attn_layer, location=location, unsorted_lb=LB, unsorted_ub=UB, unsorted_lb_num=0, unsorted_ub_num=1, unsorted_intensity_inc=intens, sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0) g, n = im.check_if_still_works() attempts.append(g == n) im.revert_attention(attn_layer) except: continue counts.append(len(attempts)) rates.append(sum(attempts) / len(attempts) if attempts else 0.0) return np.array(INTENSITY_VALUES), np.array(rates), np.array(counts) def main(): parser = argparse.ArgumentParser() parser.add_argument('--tasks-file', required=True) parser.add_argument('--gpu', type=int, required=True) args = parser.parse_args() os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) device = 'cuda' with open(args.tasks_file) as f: task_list = json.load(f) n_ckpts = len(set(t['ckpt_path'] for t in task_list)) print(f"GPU {args.gpu}: {len(task_list)} tasks across {n_ckpts} checkpoints", flush=True) current_ckpt = None model = None done = 0 for task in task_list: if os.path.exists(task['out']): done += 1 continue if task['ckpt_path'] != current_ckpt: t0 = time.time() model, config = load_model(task['ckpt_path'], device) current_ckpt = task['ckpt_path'] print(f" Loaded {os.path.basename(current_ckpt)} ({time.time()-t0:.1f}s)", flush=True) os.makedirs(os.path.dirname(task['out']), exist_ok=True) t0 = time.time() try: intensities, rates, counts = compute_intensity_at_location( model, config, device, task['layer'], task['sorted_pos']) np.savez(task['out'], intensities=intensities, success_rates=rates, counts=counts, sorted_pos=task['sorted_pos'], layer=task['layer']) dt = time.time() - t0 done += 1 print(json.dumps({ 'status': 'done', 'task': task['name'], 'gpu': args.gpu, 'elapsed': round(dt, 1), 'progress': f'{done}/{len(task_list)}', 'counts': counts.tolist(), }), flush=True) except Exception as e: done += 1 print(json.dumps({ 'status': 'fail', 'task': task['name'], 'error': str(e), }), flush=True) print(f"GPU {args.gpu}: all done ({done}/{len(task_list)})", flush=True) if __name__ == '__main__': main()