#!/usr/bin/env python3
"""
Comprehensive intervention worker for 100k checkpoints (Layer 0).

For each checkpoint: generates N_SEQ random sequences, intervenes at every
sorted-output position with multiple intensities, records per-trial details.

Methodology matches existing perlocation/pernumber experiments:
  - unsorted_lb_num=0, unsorted_ub_num=1 (boost one wrong unsorted number)
  - ub=60 (wide neighbourhood)
  - Same GPTIntervention mechanism from grid-run/model_analysis.py
"""
import argparse
import json
import os
import re
import sys
import time

import numpy as np
import torch

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run'))
from model_analysis import GPT, GPTConfig, GPTIntervention  # noqa: E402

N_SEQ = 3000
INTENSITIES = [2.0, 4.0, 6.0, 10.0]
UB = 60

# Matches the per-block submodule prefix of a checkpoint key, for any block
# index (not just 0-9), e.g. "transformer.h.12.attn.".
_BLOCK_SUBMODULE_RE = re.compile(r'(transformer\.h\.[0-9]+)\.(attn|mlp)\.')
_BLOCK_SUBMODULE_NAMES = {'attn': 'c_attn', 'mlp': 'c_fc'}


def remap_state_dict(sd):
    """Return a copy of *sd* with checkpoint key names remapped.

    Inside every ``transformer.h.<i>.`` block the ``attn.`` component is
    renamed to ``c_attn.`` and the ``mlp.`` component to ``c_fc.``; all other
    keys are copied unchanged.  Values are not copied or modified.

    Any block index is handled, so models deeper than ten layers are remapped
    as well instead of being left with unrenamed (and hence unloaded) keys.
    """
    def _rename(match):
        return f'{match.group(1)}.{_BLOCK_SUBMODULE_NAMES[match.group(2)]}.'

    return {_BLOCK_SUBMODULE_RE.sub(_rename, key): value
            for key, value in sd.items()}


def load_model(ckpt_path, device):
    """Load a checkpoint into a ``GPT`` model and put it in eval mode.

    The checkpoint is expected to contain ``'model_config'`` (with
    ``block_size``, ``vocab_size`` and optionally ``use_final_LN``) and
    ``'model_state_dict'``.

    Returns:
        ``(model, config)`` with the model moved to *device*.
    """
    # NOTE: torch.load unpickles the file; only use with trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    mc = ckpt['model_config']
    config = GPTConfig(
        block_size=mc['block_size'],
        # The stored vocab size is one larger than what GPTConfig is given
        # here -- presumably it counts the separator token; confirm against
        # the training code.
        vocab_size=mc['vocab_size'] - 1,
        with_layer_norm=mc.get('use_final_LN', True),
    )
    model = GPT(config)
    sd = remap_state_dict(ckpt['model_state_dict'])

    # Truncate the positional-embedding table to at most 4 * block_size + 1
    # rows (presumably the size model_analysis.GPT allocates -- TODO confirm).
    wpe_max = config.block_size * 4 + 1
    if ('transformer.wpe.weight' in sd
            and sd['transformer.wpe.weight'].shape[0] > wpe_max):
        sd['transformer.wpe.weight'] = sd['transformer.wpe.weight'][:wpe_max]

    # Drop "*.c_attn.bias" entries that are not the nested
    # "c_attn.c_attn" projection bias, and drop the output head weight.
    dropped = [k for k in sd
               if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k]
    for key in dropped:
        del sd[key]
    sd.pop('lm_head.weight', None)

    # strict=False: keys removed above (and any other mismatch) are tolerated.
    model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model, config


def get_batch(vs, bs, device):
    """Build one input sequence of shape ``(1, 2 * bs + 1)`` on *device*.

    Layout: ``bs`` distinct random values drawn from ``range(vs)``, then the
    separator token ``vs``, then the same values in ascending order.
    """
    x = torch.randperm(vs)[:bs]
    vals, _ = torch.sort(x)
    seq = torch.cat((x, torch.tensor([vs]), vals), dim=0)
    return seq.unsqueeze(0).to(device)


@torch.no_grad()
def run_checkpoint(model, config, device):
    """Run baseline and intervention trials on ``N_SEQ`` random sequences.

    For every sequence and every sorted-half position ``p`` in
    ``range(block_size - 1)`` one baseline row (intensity 0.0) is recorded,
    followed by one row per entry of ``INTENSITIES`` for which the
    intervention ran successfully.

    Failures are best-effort: they never abort the run.  They are counted in
    the progress line, and the first one per checkpoint is reported on stderr
    so that systematic problems do not go unnoticed.

    Returns:
        dict of equally long 1-D NumPy arrays with keys ``position``,
        ``number``, ``next_number``, ``gap``, ``intensity``, ``correct`` and
        ``predicted``.
    """
    bs = config.block_size
    vs = config.vocab_size
    pos_l, num_l, nxt_l, gap_l = [], [], [], []
    int_l, cor_l, pred_l = [], [], []

    def record(p, num_val, nxt_val, intensity, correct, pred):
        # Append one trial row to the column lists.
        pos_l.append(p)
        num_l.append(num_val)
        nxt_l.append(nxt_val)
        gap_l.append(nxt_val - num_val)
        int_l.append(intensity)
        cor_l.append(correct)
        pred_l.append(pred)

    failure_reported = False

    def note_failure(stage, exc):
        # Report only the first failure; later ones are just counted.
        nonlocal failure_reported
        if not failure_reported:
            failure_reported = True
            print(f"  warning: {stage} failed: {exc!r} "
                  f"(further failures are only counted)",
                  file=sys.stderr, flush=True)

    n_ok = n_fail = 0
    trials_per_seq = (bs - 1) * len(INTENSITIES)

    for si in range(N_SEQ):
        idx = get_batch(vs, bs, device)

        # --- baseline predictions (intensity=0) ---
        logits, _ = model(idx)
        # One host transfer per tensor instead of several .item() calls
        # per position.
        bpreds = torch.argmax(logits, dim=-1)[0].tolist()
        seq = idx[0].tolist()

        # (position, sequence index, number at that index, following number)
        targets = [(p, bs + 1 + p, seq[bs + 1 + p], seq[bs + 2 + p])
                   for p in range(bs - 1)]

        for p, loc, num_val, nxt_val in targets:
            pr = bpreds[loc]
            record(p, num_val, nxt_val, 0.0, int(pr == nxt_val), pr)

        # --- interventions ---
        try:
            im = GPTIntervention(model, idx)
        except Exception as exc:
            # The whole sequence is skipped; count every trial it would
            # have contributed so the progress line reflects the loss.
            im = None
            n_fail += trials_per_seq
            note_failure('GPTIntervention setup', exc)

        if im is not None:
            for p, loc, num_val, nxt_val in targets:
                for intensity in INTENSITIES:
                    try:
                        im.intervent_attention(
                            attention_layer_num=0,
                            location=loc,
                            unsorted_lb=UB,
                            unsorted_ub=UB,
                            unsorted_lb_num=0,
                            unsorted_ub_num=1,
                            unsorted_intensity_inc=intensity,
                            sorted_lb=0,
                            sorted_num=0,
                            sorted_intensity_inc=0.0)
                        pr, ac = im.check_if_still_works()
                        record(p, num_val, nxt_val, intensity,
                               int(pr == ac), pr)
                        im.revert_attention(0)
                        n_ok += 1
                    except Exception as exc:
                        # Make sure a half-applied intervention does not
                        # leak into the next trial.
                        try:
                            im.revert_attention(0)
                        except Exception:
                            pass
                        n_fail += 1
                        note_failure('intervention trial', exc)

        if (si + 1) % 500 == 0:
            print(f"    {si+1}/{N_SEQ}  ok={n_ok} fail={n_fail}", flush=True)

    return dict(
        position=np.array(pos_l, dtype=np.int16),
        number=np.array(num_l, dtype=np.int16),
        next_number=np.array(nxt_l, dtype=np.int16),
        gap=np.array(gap_l, dtype=np.int16),
        intensity=np.array(int_l, dtype=np.float32),
        correct=np.array(cor_l, dtype=np.int8),
        predicted=np.array(pred_l, dtype=np.int16),
    )


def _npz_path(path):
    """Return the file name ``np.savez_compressed(path, ...)`` writes to.

    NumPy appends ``.npz`` when the given name does not already end with it.
    """
    return path if path.endswith('.npz') else path + '.npz'


def _save_npz_atomic(out_path, arrays):
    """Write *arrays* to *out_path* as a compressed ``.npz`` atomically.

    The archive is written to a temporary sibling file and then moved into
    place, so an interrupted run never leaves a truncated file that a later
    run would mistake for a finished result.
    """
    out_dir = os.path.dirname(out_path)
    if out_dir:  # os.makedirs('') raises for bare file names
        os.makedirs(out_dir, exist_ok=True)
    tmp_path = f'{out_path}.tmp.{os.getpid()}'
    try:
        # Passing a file object keeps NumPy from altering the file name.
        with open(tmp_path, 'wb') as f:
            np.savez_compressed(f, **arrays)
        os.replace(tmp_path, out_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main():
    """Process every checkpoint listed in ``--tasks-file`` on one GPU.

    The tasks file is a JSON list of objects with the keys ``name``,
    ``ckpt_path`` and ``out``.  Tasks whose output file already exists are
    skipped.  One JSON status line is printed per finished checkpoint.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--tasks-file', required=True)
    ap.add_argument('--gpu', type=int, required=True)
    args = ap.parse_args()

    # Restrict the visible devices before the first CUDA call so that
    # 'cuda' resolves to the requested GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    with open(args.tasks_file) as f:
        tasks = json.load(f)

    print(f"GPU {args.gpu}: {len(tasks)} checkpoints to process", flush=True)

    for t in tasks:
        # Check the path that is actually written (with the .npz suffix).
        out_path = _npz_path(t['out'])
        if os.path.exists(out_path):
            print(f"  Skip {t['name']} (cached)", flush=True)
            continue

        t0 = time.time()
        model, config = load_model(t['ckpt_path'], device)
        print(f"  Loaded {t['name']} ({time.time()-t0:.1f}s)", flush=True)

        t0 = time.time()
        res = run_checkpoint(model, config, device)
        _save_npz_atomic(out_path, res)
        dt = time.time() - t0

        n = len(res['position'])
        print(json.dumps({
            'done': t['name'], 'gpu': args.gpu,
            'elapsed': round(dt, 1), 'n_trials': n
        }), flush=True)

    print(f"GPU {args.gpu}: all done", flush=True)


if __name__ == '__main__':
    main()