| """ |
| GPU worker for per-location intervention experiments. |
| Loads model once per checkpoint, runs all location/layer tasks, saves results. |
| """ |
import argparse
import json
import os
import re
import sys
import time

import numpy as np
import torch
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run')) |
| from model_analysis import GPT, GPTConfig, GPTIntervention |
|
|
# Values swept as `unsorted_intensity_inc` in the attention intervention.
INTENSITY_VALUES = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
# Passed through as `unsorted_ub` and `unsorted_lb` to intervent_attention.
UB, LB = 60, 60
# Number of completed (non-raising) trials to collect per intensity value.
MIN_VALID = 200
|
|
|
|
# Per-layer renames applied to checkpoint keys before loading into GPT:
# "transformer.h.<n>.attn." -> "transformer.h.<n>.c_attn." and
# "transformer.h.<n>.mlp."  -> "transformer.h.<n>.c_fc.", for ANY layer n
# (a fixed range(10) loop would silently skip layers >= 10).
_LAYER_KEY_RENAMES = (
    (re.compile(r'(transformer\.h\.\d+\.)attn\.'), r'\1c_attn.'),
    (re.compile(r'(transformer\.h\.\d+\.)mlp\.'), r'\1c_fc.'),
)


def _rename_key(key):
    """Apply every per-layer rename in ``_LAYER_KEY_RENAMES`` to one key."""
    for pattern, replacement in _LAYER_KEY_RENAMES:
        key = pattern.sub(replacement, key)
    return key


def remap_state_dict(sd):
    """Return a new state dict with per-layer submodule names rewritten.

    Keys containing ``transformer.h.<n>.attn.`` / ``transformer.h.<n>.mlp.``
    become ``transformer.h.<n>.c_attn.`` / ``transformer.h.<n>.c_fc.`` for
    every layer index ``n``; all other keys are kept as-is. Values are not
    copied, and the input dict is not modified. Insertion order is preserved.

    Args:
        sd: Mapping of parameter names to tensors (any values are accepted).

    Returns:
        A new ``dict`` with renamed keys.
    """
    return {_rename_key(k): v for k, v in sd.items()}
|
|
|
|
def load_model(ckpt_path, device):
    """Build a ``GPT`` from a checkpoint file and move it to ``device``.

    The checkpoint is read onto the CPU and must provide ``'model_config'``
    (keys ``block_size``, ``vocab_size``, optionally ``use_final_LN``) and
    ``'model_state_dict'``. The state dict is renamed with
    ``remap_state_dict``, trimmed (see comments below), and loaded
    non-strictly. Any missing or unexpected keys, other than the ones dropped
    on purpose, are reported on stderr instead of being silently ignored.

    Args:
        ckpt_path: Path to a checkpoint written with ``torch.save``.
        device: Device the model is moved to (e.g. ``'cuda'``).

    Returns:
        ``(model, config)``; the model is in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    mc = ckpt['model_config']
    # NOTE(review): the analysis config uses one fewer token than stored;
    # get_batch uses config.vocab_size itself as the separator id, so the
    # stored vocab presumably includes that separator -- confirm.
    config = GPTConfig(block_size=mc['block_size'], vocab_size=mc['vocab_size'] - 1,
                       with_layer_norm=mc.get('use_final_LN', True))
    model = GPT(config)
    sd = remap_state_dict(ckpt['model_state_dict'])

    # Keep at most block_size * 4 + 1 positional-embedding rows
    # (presumably the size GPT allocates for this config -- confirm).
    wpe_key = 'transformer.wpe.weight'
    wpe_max = config.block_size * 4 + 1
    if wpe_key in sd and sd[wpe_key].shape[0] > wpe_max:
        sd[wpe_key] = sd[wpe_key][:wpe_max]

    # Intentionally dropped entries: the attention module's own `.bias`
    # (after remapping, `...c_attn.bias` but not `...c_attn.c_attn.bias`;
    # presumably a causal-mask buffer) and `lm_head.weight` (presumably tied
    # to the token embedding). Tracked so they are not reported as missing.
    dropped = {k for k in sd
               if k.endswith('.c_attn.bias') and 'c_attn.c_attn' not in k}
    if 'lm_head.weight' in sd:
        dropped.add('lm_head.weight')
    for k in dropped:
        del sd[k]

    result = model.load_state_dict(sd, strict=False)
    # strict=False would otherwise hide renaming mistakes, leaving randomly
    # initialised weights in the analysed model without any signal.
    missing = [k for k in result.missing_keys if k not in dropped]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        print(f"WARNING: {ckpt_path}: missing keys {missing}; "
              f"unexpected keys {unexpected}", file=sys.stderr, flush=True)

    model.to(device)
    model.eval()
    return model, config
|
|
|
|
def get_batch(vs, bs, device):
    """Sample one sorting-task sequence.

    Draws ``bs`` distinct ids from ``range(vs)`` in random order, then
    appends the separator id ``vs`` and the same ids in ascending order.

    Args:
        vs: Number of sortable token ids; also used as the separator id.
        bs: Number of ids to draw (assumed ``<= vs``).
        device: Device for the returned tensor.

    Returns:
        An int64 tensor of shape ``(1, 2 * bs + 1)`` on ``device`` when
        ``bs <= vs``.
    """
    unsorted = torch.randperm(vs)[:bs]
    separator = torch.tensor([vs])
    ascending = unsorted.sort().values
    sequence = torch.cat([unsorted, separator, ascending])
    return sequence[None, :].to(device)
|
|
|
|
def _run_trial(model, idx, attn_layer, location, intensity):
    """Intervene on one sequence, check the model, and always revert.

    Returns ``g == n`` for the ``(g, n)`` pair from ``check_if_still_works``.
    Exceptions from construction, intervention or the check propagate to the
    caller; once the intervention has been applied it is reverted even if the
    check raises.
    """
    im = GPTIntervention(model, idx)
    im.intervent_attention(
        attention_layer_num=attn_layer, location=location,
        unsorted_lb=LB, unsorted_ub=UB,
        unsorted_lb_num=0, unsorted_ub_num=1,
        unsorted_intensity_inc=intensity,
        sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
    try:
        g, n = im.check_if_still_works()
    finally:
        # NOTE(review): assumes revert_attention undoes an in-place change to
        # the shared model. Reverting in `finally` stops a failed check from
        # leaking the intervention into every later trial.
        im.revert_attention(attn_layer)
    return g == n


def compute_intensity_at_location(model, config, device, attn_layer, sorted_pos,
                                  max_rounds=2000):
    """Run intervention at a specific sorted-output position.

    For each value in ``INTENSITY_VALUES``, fresh sequences from ``get_batch``
    are intervened on at attention layer ``attn_layer`` and sequence index
    ``config.block_size + sorted_pos``, and the trial outcome is
    ``check_if_still_works``'s ``g == n``. Trials that raise are discarded
    (best-effort sampling). Sampling for an intensity stops after
    ``MIN_VALID`` completed trials or ``max_rounds`` draws, whichever is
    first.

    Args:
        model: Model handed to ``GPTIntervention``.
        config: Provides ``block_size`` and ``vocab_size``.
        device: Device for the sampled sequences.
        attn_layer: Attention layer index to intervene on.
        sorted_pos: Offset of the target position from the separator.
        max_rounds: Maximum draws per intensity (default 2000).

    Returns:
        ``(intensities, success_rates, counts)`` as NumPy arrays with one
        entry per intensity. A rate is 0.0 when no trial completed.
    """
    bs = config.block_size
    vs = config.vocab_size
    # get_batch lays sequences out as [unsorted (bs) | separator | sorted (bs)],
    # so the separator sits at index bs and sorted_pos is an offset from it.
    location = bs + sorted_pos

    rates, counts = [], []
    for intens in INTENSITY_VALUES:
        outcomes = []
        for _ in range(max_rounds):
            if len(outcomes) >= MIN_VALID:
                break
            idx = get_batch(vs, bs, device)
            try:
                outcome = _run_trial(model, idx, attn_layer, location, intens)
            except Exception:
                # Deliberate best-effort: drop this draw and sample another.
                # (Unlike a bare `except:`, Ctrl-C / SystemExit still propagate.)
                continue
            outcomes.append(outcome)
        counts.append(len(outcomes))
        rates.append(sum(outcomes) / len(outcomes) if outcomes else 0.0)

    return np.array(INTENSITY_VALUES), np.array(rates), np.array(counts)
|
|
|
|
def _npz_path(out):
    """Return the filename ``np.savez`` actually writes for path ``out``.

    ``np.savez`` appends ``.npz`` to string paths that lack it, so the resume
    check and the write must both use this normalized name.
    """
    return out if out.endswith('.npz') else out + '.npz'


def _save_npz_atomic(path, **arrays):
    """Write ``arrays`` to ``path`` via a temp file and ``os.replace``.

    An interrupted run then never leaves a truncated file at ``path`` that a
    later resume would mistake for a finished task.
    """
    tmp = path + '.tmp'
    # Passing an open file keeps np.savez from appending another extension.
    with open(tmp, 'wb') as f:
        np.savez(f, **arrays)
    os.replace(tmp, path)


def main():
    """Run every task in ``--tasks-file`` on one GPU, skipping finished ones.

    The tasks file is a JSON list of objects with keys ``ckpt_path``,
    ``out``, ``layer``, ``sorted_pos`` and ``name``. Each checkpoint is loaded
    once for a run of consecutive tasks that share it. One JSON status line
    (``done`` or ``fail``) is printed per attempted task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks-file', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    args = parser.parse_args()

    # Must be set before CUDA is first initialised in this process; the
    # device string 'cuda' then refers to the selected GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    with open(args.tasks_file) as f:
        task_list = json.load(f)
    total = len(task_list)

    n_ckpts = len({t['ckpt_path'] for t in task_list})
    print(f"GPU {args.gpu}: {total} tasks across {n_ckpts} checkpoints", flush=True)

    current_ckpt = None
    model = config = None
    load_error = None
    done = 0

    for task in task_list:
        out_path = _npz_path(task['out'])
        if os.path.exists(out_path):
            done += 1
            continue

        if task['ckpt_path'] != current_ckpt:
            current_ckpt = task['ckpt_path']
            # Release the previous model before loading the next one.
            model = config = None
            t0 = time.time()
            try:
                model, config = load_model(current_ckpt, device)
            except Exception as e:
                # Report per task instead of killing the remaining checkpoints.
                load_error = f'checkpoint load failed: {e}'
            else:
                load_error = None
                print(f" Loaded {os.path.basename(current_ckpt)} ({time.time()-t0:.1f}s)", flush=True)

        t0 = time.time()
        error = load_error
        counts = None
        if error is None:
            try:
                parent = os.path.dirname(out_path)
                if parent:  # os.makedirs('') raises for bare filenames
                    os.makedirs(parent, exist_ok=True)
                intensities, rates, counts = compute_intensity_at_location(
                    model, config, device, task['layer'], task['sorted_pos'])
                _save_npz_atomic(out_path, intensities=intensities, success_rates=rates,
                                 counts=counts, sorted_pos=task['sorted_pos'],
                                 layer=task['layer'])
            except Exception as e:
                error = str(e)

        done += 1
        if error is None:
            print(json.dumps({
                'status': 'done', 'task': task['name'], 'gpu': args.gpu,
                'elapsed': round(time.time() - t0, 1), 'progress': f'{done}/{total}',
                'counts': counts.tolist(),
            }), flush=True)
        else:
            print(json.dumps({
                'status': 'fail', 'task': task['name'], 'error': error,
                'gpu': args.gpu, 'progress': f'{done}/{total}',
            }), flush=True)

    print(f"GPU {args.gpu}: all done ({done}/{total})", flush=True)
|
|
|
|
if __name__ == "__main__":
    # Start the worker only when run as a script, not when imported.
    main()
|
|