# llm-sort / pernumber_worker.py
"""
GPU worker for per-number intervention experiments.
For each target number, generates sequences containing that number,
finds its position in the sorted output, and intervenes there.
"""
import argparse
import json
import os
import re
import sys
import time
import numpy as np
import torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'grid-run'))
from model_analysis import GPT, GPTConfig, GPTIntervention
# Intervention strengths swept for every target number.
INTENSITY_VALUES = [1.0, 2.0, 4.0, 6.0, 8.0, 10.0]
# Upper/lower bounds forwarded to intervent_attention for the unsorted
# region (both 60 here; exact semantics live in GPTIntervention — TODO confirm).
UB = 60
LB = 60
# Minimum number of valid intervention samples required per intensity value.
MIN_VALID = 200
# Matches 'transformer.h.<layer>.attn.' / 'transformer.h.<layer>.mlp.' prefixes
# for ANY layer index (the old loop only covered layers 0-9).
_LAYER_PREFIX_RE = re.compile(r'^(transformer\.h\.\d+\.)(attn|mlp)\.')

def remap_state_dict(sd):
    """Rename checkpoint keys to the names the analysis GPT model expects.

    Training checkpoints name each block's submodules ``attn`` / ``mlp``,
    while the analysis model uses ``c_attn`` / ``c_fc``.  The previous
    implementation iterated ``range(10)`` and therefore silently left keys
    for layer indices >= 10 unmapped; matching the layer index with a regex
    handles any depth.

    Args:
        sd: state dict mapping parameter names to tensors (values are
            passed through untouched).

    Returns:
        A new dict with remapped keys.
    """
    rename = {'attn': 'c_attn', 'mlp': 'c_fc'}
    new = {}
    for k, v in sd.items():
        m = _LAYER_PREFIX_RE.match(k)
        if m:
            k = m.group(1) + rename[m.group(2)] + '.' + k[m.end():]
        new[k] = v
    return new
def load_model(ckpt_path, device):
    """Build a GPT from a training checkpoint and prepare it for eval.

    Reconstructs the model config from checkpoint metadata (note the
    ``vocab_size - 1`` adjustment), remaps state-dict keys to the analysis
    model's naming, trims an oversized positional-embedding table, drops
    keys the analysis model does not carry (standalone c_attn biases and
    the tied lm_head), then moves the model to *device* in eval mode.

    Returns:
        (model, config) tuple.
    """
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    meta = checkpoint['model_config']
    config = GPTConfig(
        block_size=meta['block_size'],
        vocab_size=meta['vocab_size'] - 1,
        with_layer_norm=meta.get('use_final_LN', True),
    )
    model = GPT(config)

    state = remap_state_dict(checkpoint['model_state_dict'])

    # The positional table in the checkpoint may exceed what the analysis
    # model allocates (block_size * 4 + 1 rows); truncate if so.
    pe_limit = 4 * config.block_size + 1
    wpe = state.get('transformer.wpe.weight')
    if wpe is not None and wpe.shape[0] > pe_limit:
        state['transformer.wpe.weight'] = wpe[:pe_limit]

    # Drop biases the analysis model's attention wrapper does not have;
    # the inner projection ('c_attn.c_attn') keeps its bias, so spare it.
    stale = [key for key in state
             if key.endswith('.c_attn.bias') and 'c_attn.c_attn' not in key]
    for key in stale:
        del state[key]
    state.pop('lm_head.weight', None)

    # strict=False: remaining naming mismatches are tolerated deliberately.
    model.load_state_dict(state, strict=False)
    model.to(device)
    model.eval()
    return model, config
def get_batch_with_number(vs, bs, target_num, device):
    """Build one (1, 2*bs + 1) sequence guaranteed to contain target_num.

    Layout: ``bs`` distinct unsorted values drawn from ``range(vs)``, the
    separator token ``vs``, then the same values in ascending order.
    """
    # Draw bs - 1 distinct values, none equal to the target, then mix the
    # target in and shuffle so its position in the unsorted half is random.
    candidates = [v for v in range(vs) if v != target_num]
    pick = torch.randperm(len(candidates))[:bs - 1]
    unsorted = torch.cat([
        torch.tensor([candidates[int(i)] for i in pick]),
        torch.tensor([target_num]),
    ])
    unsorted = unsorted[torch.randperm(bs)]

    separator = torch.tensor([vs])
    sorted_vals = torch.sort(unsorted).values
    sequence = torch.cat((unsorted, separator, sorted_vals), dim=0)
    return sequence.unsqueeze(0).to(device)
def find_location_of_number(idx, bs, target_num):
    """Return the absolute index of target_num within the sorted half of idx.

    The sorted half occupies positions ``bs + 1 .. 2*bs`` of the (1, 2*bs+1)
    sequence.  Returns None when the number is absent, or when it sits at
    the final sorted slot — there is no next token to predict there.
    """
    sorted_region = idx[0, bs + 1: 2 * bs + 1]
    hits = (sorted_region == target_num).nonzero(as_tuple=True)[0]
    if hits.numel() == 0:
        return None
    offset = int(hits[0])
    # Last sorted element has no successor token to intervene on.
    return None if offset >= bs - 1 else bs + 1 + offset
def compute_intensity_for_number(model, config, device, attn_layer, target_num):
    """Measure intervention success rate vs. intensity for one number.

    For each intensity in INTENSITY_VALUES, keeps generating batches that
    contain *target_num* until MIN_VALID usable samples are collected (or
    3000 rounds elapse), intervening on attention layer *attn_layer* at the
    number's sorted position and recording whether the model's output is
    unchanged (``g == n`` from check_if_still_works).

    Returns:
        (intensities, success_rates, valid_counts) as numpy arrays, all of
        length len(INTENSITY_VALUES).
    """
    bs = config.block_size
    vs = config.vocab_size
    rates, counts = [], []
    for intens in INTENSITY_VALUES:
        attempts, rounds = [], 0
        while len(attempts) < MIN_VALID and rounds < 3000:
            rounds += 1
            idx = get_batch_with_number(vs, bs, target_num, device)
            location = find_location_of_number(idx, bs, target_num)
            if location is None:
                # Target landed in the last sorted slot; sample is unusable.
                continue
            try:
                im = GPTIntervention(model, idx)
                im.intervent_attention(
                    attention_layer_num=attn_layer, location=location,
                    unsorted_lb=LB, unsorted_ub=UB,
                    unsorted_lb_num=0, unsorted_ub_num=1,
                    unsorted_intensity_inc=intens,
                    sorted_lb=0, sorted_num=0, sorted_intensity_inc=0.0)
                g, n = im.check_if_still_works()
                attempts.append(g == n)
                im.revert_attention(attn_layer)
            except Exception:
                # Best-effort: skip batches the intervention machinery
                # rejects.  Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit — now narrowed.
                # NOTE(review): if the body raises after intervent_attention
                # but before revert_attention, the model may stay modified
                # for later rounds — confirm whether revert belongs in a
                # finally block.
                continue
        counts.append(len(attempts))
        rates.append(sum(attempts) / len(attempts) if attempts else 0.0)
    return np.array(INTENSITY_VALUES), np.array(rates), np.array(counts)
def main():
    """Run every task assigned to this worker on a single pinned GPU."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--tasks-file', required=True)
    parser.add_argument('--gpu', type=int, required=True)
    args = parser.parse_args()

    # Pin this process to one physical GPU; everything else sees 'cuda'.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    device = 'cuda'

    with open(args.tasks_file) as f:
        jobs = json.load(f)
    n_ckpts = len({job['ckpt_path'] for job in jobs})
    print(f"GPU {args.gpu}: {len(jobs)} tasks across {n_ckpts} checkpoints", flush=True)

    current_ckpt = None
    model = None
    completed = 0
    for job in jobs:
        # Resume support: a result already on disk is counted and skipped.
        if os.path.exists(job['out']):
            completed += 1
            continue

        # Jobs are grouped by checkpoint; only reload when it changes.
        if job['ckpt_path'] != current_ckpt:
            started = time.time()
            model, config = load_model(job['ckpt_path'], device)
            current_ckpt = job['ckpt_path']
            print(f" Loaded {os.path.basename(current_ckpt)} ({time.time()-started:.1f}s)", flush=True)

        os.makedirs(os.path.dirname(job['out']), exist_ok=True)
        started = time.time()
        try:
            intensities, rates, counts = compute_intensity_for_number(
                model, config, device, job['layer'], job['target_num'])
            np.savez(job['out'], intensities=intensities, success_rates=rates,
                     counts=counts, target_num=job['target_num'], layer=job['layer'])
            took = time.time() - started
            completed += 1
            # One JSON status line per task so a supervisor can tail progress.
            print(json.dumps({
                'status': 'done', 'task': job['name'], 'gpu': args.gpu,
                'elapsed': round(took, 1), 'progress': f'{completed}/{len(jobs)}',
                'counts': counts.tolist(),
            }), flush=True)
        except Exception as e:
            # Failures still advance the counter so progress reflects
            # tasks attempted, not just tasks succeeded.
            completed += 1
            print(json.dumps({
                'status': 'fail', 'task': job['name'], 'error': str(e),
            }), flush=True)
    print(f"GPU {args.gpu}: all done ({completed}/{len(jobs)})", flush=True)

if __name__ == '__main__':
    main()