| |
| """ |
Orchestrator: distribute the avg-attention-by-number computation across up to
8 GPUs for all checkpoints in 1000k-checkpoints/.
| """ |
| import glob, json, os, subprocess, sys, time |
|
|
# Directory containing this script; every other path below is derived from it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Outputs live under <script dir>/outputs; the per-GPU JSON task files go in
# its task_files/ subdirectory.
OUTPUT_BASE = os.path.join(SCRIPT_DIR, 'outputs')
TASK_DIR = os.path.join(OUTPUT_BASE, 'task_files')

# Worker script that is launched once per GPU.
WORKER = os.path.join(SCRIPT_DIR, 'attn_by_number_worker.py')
|
|
|
|
def get_iter(basename):
    """Extract the training iteration encoded in a checkpoint file name.

    The name is expected to look like ``<prefix>__<tag>.pt`` where ``<tag>``
    is either ``ckpt<N>`` (iteration ``N``) or ``final`` (mapped to iteration
    1000000).

    Args:
        basename: Checkpoint file name, without directory components.

    Returns:
        The iteration as an ``int``, or ``None`` when the name has no ``__``
        separator, the tag is not recognised, or the ``ckpt`` tag is not
        followed by an integer.
    """
    parts = basename.replace('.pt', '').split('__')
    if len(parts) < 2:
        # No '__' separator: not a checkpoint name this script understands.
        return None
    tag = parts[1]
    if tag.startswith('ckpt'):
        try:
            return int(tag.replace('ckpt', ''))
        except ValueError:
            # e.g. 'ckpt_best' or a bare 'ckpt': skip instead of crashing.
            return None
    if tag == 'final':
        return 1000000
    return None
|
|
|
|
def folder_name(itr):
    """Return the output folder name used for the checkpoint at iteration *itr*.

    The iteration is substituted into both the ``MI`` and ``ckpt`` fields of
    the otherwise fixed name.
    """
    name = f"plots_V256_B16_LR3e-2_MI{itr}_E64_H1_L2_ds1337_is1337_ckpt{itr}"
    return name
|
|
|
|
def main():
    """Discover checkpoints, split them across GPUs and run one worker per GPU.

    Steps:
      1. Collect every ``*.pt`` file in ``SCRIPT_DIR`` whose name ``get_iter``
         can parse, keeping the first file seen for each iteration.
      2. Assign the tasks, sorted by iteration, round-robin to at most 8 GPUs.
      3. Write each GPU's task list to a JSON file in ``TASK_DIR`` and launch
         ``WORKER`` with ``--tasks-file`` and ``--gpu``.
      4. Wait for all workers and print their status and stdout, plus the
         tail of stderr for any worker that exited non-zero.

    Exits via ``sys.exit`` with a message when there are checkpoints to
    process but no CUDA device is visible.
    """
    import torch
    from concurrent.futures import ThreadPoolExecutor

    n_gpus = min(torch.cuda.device_count(), 8)

    pt_files = sorted(glob.glob(os.path.join(SCRIPT_DIR, '*.pt')))
    seen, tasks = set(), []
    for pt in pt_files:
        itr = get_iter(os.path.basename(pt))
        if itr is None or itr in seen:
            continue
        seen.add(itr)
        tasks.append({'ckpt_path': pt,
                      'out_dir': os.path.join(OUTPUT_BASE, folder_name(itr)),
                      'itr': itr})
    tasks.sort(key=lambda t: t['itr'])

    print(f"{len(tasks)} checkpoints → {n_gpus} GPUs")

    if tasks and n_gpus == 0:
        # The round-robin split below would otherwise divide by zero.
        sys.exit("No CUDA devices visible; cannot distribute checkpoints.")

    # Round-robin assignment: task i goes to GPU i % n_gpus.
    gpu_tasks = [tasks[gid::n_gpus] for gid in range(n_gpus)]

    os.makedirs(TASK_DIR, exist_ok=True)
    # Start timing before launch so the reported time covers the workers'
    # whole run.
    t0 = time.time()
    procs = []
    for gid, assigned in enumerate(gpu_tasks):
        if not assigned:
            continue
        tf = os.path.join(TASK_DIR, f'attn_by_number_gpu{gid}.json')
        with open(tf, 'w') as f:
            json.dump(assigned, f)
        labels = [f"ckpt{t['itr']//1000}k" for t in assigned]
        print(f" GPU {gid}: {', '.join(labels)}")
        p = subprocess.Popen(
            [sys.executable, WORKER, '--tasks-file', tf, '--gpu', str(gid)],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        procs.append((p, gid))

    # Drain every worker's pipes concurrently. Calling communicate() on one
    # process at a time would leave the other workers blocked as soon as they
    # fill their OS pipe buffer, serialising the run.
    with ThreadPoolExecutor(max_workers=max(len(procs), 1)) as pool:
        pending = [pool.submit(p.communicate) for p, _ in procs]
        # Results are still reported in GPU order.
        for (p, gid), fut in zip(procs, pending):
            stdout, stderr = fut.result()
            rc = p.returncode
            tag = "OK" if rc == 0 else f"FAIL(rc={rc})"
            print(f" GPU {gid}: {tag}")
            if stdout:
                text = stdout.decode(errors='replace')
                for line in text.strip().split('\n'):
                    print(f" {line}")
            if rc != 0 and stderr:
                # Only the tail of stderr is shown to keep the report short.
                print(f" STDERR: {stderr.decode(errors='replace')[-800:]}")

    print(f"\nDone in {time.time()-t0:.1f}s")
|
|
|
|
# Run the orchestrator only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()
|
|