Add minified training script for submission
Browse files- train_final.py +482 -0
train_final.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PG-v2: SP8192+ParallelRes+DepthRec+TTT+Int6GPTQ+EMA"""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
import copy,glob,io,math,os,random,subprocess,sys,time,uuid,zlib
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
import sentencepiece as spm
|
| 7 |
+
import torch,torch.distributed as dist,torch.nn.functional as F
|
| 8 |
+
from torch import Tensor,nn
|
| 9 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 10 |
+
|
| 11 |
+
class H:
    """Hyperparameters; every field can be overridden via an environment variable."""
    # Data / tokenizer locations.
    dp = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192")
    tf = os.path.join(dp, "fineweb_train_*.bin")
    vf = os.path.join(dp, "fineweb_val_*.bin")
    tp = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
    # Run bookkeeping.
    rid = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", "1337"))
    # Eval / logging cadence.
    vbs = int(os.environ.get("VBS", "524288"))
    vle = int(os.environ.get("VLE", "1000"))
    tle = int(os.environ.get("TLE", "200"))
    # Schedule: total iters, LR-decay iters, warmup iters, batch tokens,
    # sequence length, wall-clock budget in seconds.
    iters = int(os.environ.get("ITERS", "20000"))
    wdi = int(os.environ.get("WDI", "3500"))
    wui = int(os.environ.get("WUI", "20"))
    tbt = int(os.environ.get("TBT", "524288"))
    tsl = int(os.environ.get("TSL", "1024"))
    mws = float(os.environ.get("MWS", "600.0"))
    # Model shape: vocab, width, heads, KV heads, MLP mult, unique layers,
    # train recurrence count, eval recurrence count (0 -> derived).
    V = int(os.environ.get("V", "8192"))
    D = int(os.environ.get("D", "768"))
    nh = int(os.environ.get("NH", "12"))
    nkv = int(os.environ.get("NKV", "4"))
    mm = int(os.environ.get("MM", "4"))
    nul = int(os.environ.get("NUL", "3"))
    nr = int(os.environ.get("NR", "8"))
    ner = int(os.environ.get("NER", "0"))
    # Rotary base, logit soft-cap, initial query gain.
    rb = float(os.environ.get("RB", "10000.0"))
    lsc = float(os.environ.get("LSC", "30.0"))
    qkg = float(os.environ.get("QKG", "5.25"))
    # Sliding-window eval: stride and window length.
    sws = int(os.environ.get("SWS", "64"))
    swl = int(os.environ.get("SWL", "1024"))
    # Test-time training: enable flag, LR, chunk size, layer selection.
    tte = int(os.environ.get("TTE", "1"))
    ttlr = float(os.environ.get("TTLR", "0.01"))
    ttcs = int(os.environ.get("TTCS", "64"))
    ttly = os.environ.get("TTLY", "all")
    # Optimizer settings: embedding LR, Muon LR, scalar LR, Muon momentum,
    # Newton-Schulz steps, Muon weight decay, Adam betas/eps.
    elr = float(os.environ.get("ELR", "0.05"))
    mlr = float(os.environ.get("MLR", "0.04"))
    slr = float(os.environ.get("SLR", "0.04"))
    mmo = float(os.environ.get("MMO", "0.95"))
    mbs = int(os.environ.get("MBS", "5"))
    mwd = float(os.environ.get("MWD", "0.09"))
    b1 = float(os.environ.get("B1", "0.9"))
    b2 = float(os.environ.get("B2", "0.95"))
    ae = float(os.environ.get("AE", "1e-8"))
    # Checkpoint quantization bits, sigma-clip, and EMA start fraction.
    gb = int(os.environ.get("GB", "6"))
    sdn = float(os.environ.get("SDN", "2.5"))
    esf = float(os.environ.get("ESF", "0.4"))
|
| 49 |
+
|
| 50 |
+
# Substring markers for parameters intended to stay in higher precision in
# qsd() and be routed away from the Muon group.
# NOTE(review): the minified modules use short attribute names (as_, ms, rm,
# qg), so verify these substrings actually match the state-dict keys.
CP = ("attn_scale", "mlp_scale", "resid_mix", "q_gain")
|
| 51 |
+
|
| 52 |
+
def zp5(G, steps=10, e=1e-7):
    """Approximately orthogonalize a 2-D gradient via a quintic
    Newton-Schulz iteration (Muon-style "zeropower").

    Args:
        G: 2-D gradient tensor.
        steps: number of Newton-Schulz iterations. BUG FIX: this parameter
            was named ``s``, but Muon.step calls ``zp5(gr, steps=...)``,
            which raised TypeError; renamed to match the call site.
        e: epsilon guarding the initial norm division.

    Returns:
        bfloat16 tensor with G's shape, approximately orthogonal.
    """
    # Quintic iteration coefficients.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X /= X.norm() + e  # normalize so the iteration converges
    tr = G.size(0) > G.size(1)
    if tr:  # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if tr else X
|
| 60 |
+
|
| 61 |
+
class Muon(torch.optim.Optimizer):
    """Distributed Muon: momentum plus Newton-Schulz-orthogonalized updates.

    Each rank orthogonalizes every ws-th parameter's gradient (``i % ws ==
    rank``), writes it into a flat buffer at that parameter's global offset,
    and an all_reduce(SUM) recombines the full update on every rank.
    """

    def __init__(s, p, lr, mom, bs, wd=0., nest=True):
        super().__init__(p, dict(lr=lr, mom=mom, bs=bs, wd=wd, nest=nest))

    @torch.no_grad()
    def step(s, cl=None):
        lo = None
        if cl:
            with torch.enable_grad():
                lo = cl()
        dd = dist.is_available() and dist.is_initialized()
        ws = dist.get_world_size() if dd else 1
        rk = dist.get_rank() if dd else 0
        for g in s.param_groups:
            ps = g["params"]; lr = g["lr"]; mo = g["mom"]; bs = g["bs"]; wd = g["wd"]; ne = g["nest"]
            tot = sum(int(p.numel()) for p in ps)
            fl = torch.zeros(tot, device=ps[0].device, dtype=torch.bfloat16)
            cur = 0
            for i, p in enumerate(ps):
                if i % ws == rk and p.grad is not None:
                    gr = p.grad
                    if wd:  # decoupled weight decay folded into the gradient
                        gr = gr + wd * p.data.to(gr.dtype)
                    st = s.state[p]
                    if "mb" not in st:
                        st["mb"] = torch.zeros_like(gr)
                    buf = st["mb"]
                    buf.mul_(mo).add_(gr)
                    if ne:  # Nesterov-style lookahead
                        gr = gr.add(buf, alpha=mo)
                    # BUG FIX: pass the step count positionally — zp5's
                    # keyword name did not match ``steps=`` here.
                    gr = zp5(gr, bs)
                    # Rescale tall matrices so update RMS is shape-independent.
                    gr *= max(1, gr.size(0) / gr.size(1)) ** 0.5
                    fl[cur:cur + p.numel()] = gr.reshape(-1)
                # Advance the offset for EVERY parameter (not only owned
                # ones) so all ranks agree on the flat-buffer layout and the
                # all_reduce below sums aligned slices.
                cur += p.numel()
            if dd:
                dist.all_reduce(fl, op=dist.ReduceOp.SUM)
            cur = 0
            for p in ps:
                gr = fl[cur:cur + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(gr, alpha=-lr)
                cur += p.numel()
        return lo
|
| 95 |
+
|
| 96 |
+
def build_sp_luts(sp, vs, dev):
    """Build per-token lookup tables from a SentencePiece processor.

    Returns three tensors of length max(vocab_size, vs) on ``dev``:
      * UTF-8 byte count of each piece (int16; 1 for byte tokens),
      * whether the piece starts with the word-boundary marker U+2581 (bool),
      * whether the id is invalid, i.e. control/unknown/unused (bool).
    """
    actual = int(sp.vocab_size())
    size = max(actual, vs)
    byte_counts = np.zeros(size, dtype=np.int16)
    has_space = np.zeros(size, dtype=bool)
    invalid = np.ones(size, dtype=bool)
    for tok in range(actual):
        if sp.is_control(tok) or sp.is_unknown(tok) or sp.is_unused(tok):
            continue
        invalid[tok] = False
        if sp.is_byte(tok):
            byte_counts[tok] = 1
            continue
        piece = sp.id_to_piece(tok)
        if piece.startswith("\u2581"):
            has_space[tok] = True
            piece = piece[1:]  # count bytes of the piece without the marker
        byte_counts[tok] = len(piece.encode("utf-8"))
    return (torch.tensor(byte_counts, dtype=torch.int16, device=dev),
            torch.tensor(has_space, dtype=torch.bool, device=dev),
            torch.tensor(invalid, dtype=torch.bool, device=dev))
|
| 109 |
+
|
| 110 |
+
def eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False):
    """Sliding-window evaluation: returns (mean token NLL, bits-per-byte).

    Windows of length ``a.swl`` advance by ``a.sws`` tokens; only the last
    ``a.sws`` positions of each window are scored. Window starts are
    sharded across ranks and the partial sums are all_reduced at the end.
    """
    sl=a.swl;st=a.sws;T=vt.numel()
    starts=list(range(0,T-sl-1,st))
    my=starts[rk::ws]  # this rank's shard of window starts
    ls=torch.zeros((),device=dev,dtype=torch.float64)  # summed NLL
    tc=torch.zeros((),device=dev,dtype=torch.float64)  # scored-token count
    bc=torch.zeros((),device=dev,dtype=torch.float64)  # byte count for bpb
    # Unwrap DDP (.module) and torch.compile (._orig_mod) wrappers.
    rm=mdl
    while hasattr(rm,'module'):rm=rm.module
    if hasattr(rm,'_orig_mod'):rm=rm._orig_mod
    rm.eval()
    # TTT mutates weights in-place, which inference_mode forbids.
    ctx=torch.no_grad if ttt else torch.inference_mode
    with ctx():
        for s in my:
            e=s+sl
            x=vt[s:e].unsqueeze(0).to(dev,dtype=torch.int64)
            y=vt[s+1:e+1].unsqueeze(0).to(dev,dtype=torch.int64)
            with torch.autocast("cuda",dtype=torch.bfloat16):
                if ttt and a.tte:ptl=rm.ptl_ttt(x,y,a)
                else:ptl=rm.ptl(x,y)
            # Score only the last `st` (stride) positions of the window.
            lo=sl-st;ps=ptl[0,lo:];ys=y[0,lo:];xs=x[0,lo:]
            ls+=ps.to(torch.float64).sum();tc+=ps.numel()
            # Bytes per target: piece byte length, plus one extra byte (the
            # implied space) when the target starts a word and the previous
            # token is a valid piece.
            tb=bbl[ys].to(torch.float64)
            tb+=(hsl[ys]&~ibl[xs]).to(torch.float64)
            bc+=tb.sum()
    if dist.is_available() and dist.is_initialized():
        for t in(ls,tc,bc):dist.all_reduce(t,op=dist.ReduceOp.SUM)
    vl=float((ls/tc).item());bpb=float((ls/math.log(2)/bc).item())
    rm.train();return vl,bpb
|
| 139 |
+
|
| 140 |
+
def sdclip(t, n=2.5):
    """Clamp tensor values to within ``n`` standard deviations of the mean."""
    stats = t.float()
    mean = stats.mean()
    std = stats.std()
    lower = (mean - n * std).item()
    upper = (mean + n * std).item()
    return t.clamp(lower, upper)
|
| 143 |
+
|
| 144 |
+
def qi6(t, ns=2.5):
    """Symmetric int6 quantization (codes in [-31, 31]) with outlier
    clipping at ``ns`` standard deviations.

    2-D tensors are quantized per row and return per-row fp16 scales;
    any other shape uses one scalar fp32 scale.
    """
    t32 = t.float()
    mx = 31
    if t32.ndim == 2:
        # Per-row path: clip each row to mean +/- ns*std, then scale by
        # the row's max magnitude.
        mean = t32.mean(1, keepdim=True)
        std = t32.std(1, keepdim=True).clamp_min(1e-9)
        clipped = t32.clamp((mean - ns * std).expand_as(t32),
                            (mean + ns * std).expand_as(t32))
        rowmax = clipped.abs().amax(1).clamp_min(1e-9)
        scale = rowmax / mx
        q = torch.clamp(torch.round(clipped / scale[:, None]), -mx, mx).to(torch.int8)
        return q.contiguous(), scale.to(torch.float16).contiguous()
    # Scalar path: sigma-clip globally, one scale for the whole tensor.
    clipped = sdclip(t32, ns)
    peak = float(clipped.abs().max().item())
    scale = torch.tensor(max(peak / mx, 1. / mx), dtype=torch.float32)
    q = torch.clamp(torch.round(clipped / scale), -mx, mx).to(torch.int8)
    return q.contiguous(), scale
|
| 157 |
+
|
| 158 |
+
def qi8(t, ns=2.5):
    """Symmetric int8 quantization with ``ns``-sigma outlier clipping.

    Per-row fp16 scales for 2-D tensors; one scalar fp32 scale otherwise.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Per-row: clip to mean +/- ns*std row-wise, scale by row max.
        mean = t32.mean(1, keepdim=True)
        std = t32.std(1, keepdim=True).clamp_min(1e-9)
        clipped = t32.clamp((mean - ns * std).expand_as(t32),
                            (mean + ns * std).expand_as(t32))
        rowmax = clipped.abs().amax(1).clamp_min(1e-9)
        scale = rowmax / 127.
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8)
        return q.contiguous(), scale.to(torch.float16).contiguous()
    # Scalar path: sigma-clip to find the peak, then quantize.
    peak = float(sdclip(t32, ns).abs().max().item())
    scale = torch.tensor(max(peak / 127., 1. / 127.), dtype=torch.float32)
    q = torch.clamp(torch.round(t32.clamp(-peak, peak) / scale), -127, 127).to(torch.int8)
    return q.contiguous(), scale
|
| 171 |
+
|
| 172 |
+
def qsd(sd,gb=6,ns=2.5):
    """Quantize a state dict for the submission artifact.

    Returns ``(obj, st)``: ``obj`` packs int8 codes ("q"), scales ("s"),
    original dtype names ("d"), passthrough tensors ("p"), and optional
    per-tensor metadata ("m") / dtype overrides ("o"); ``st`` accumulates
    param count ("pc"), tensor count ("nt"), original bytes ("bb"), and
    quantized bytes ("qb").

    Policy: "tok_emb" tensors -> int8 per-row; CP-named or small
    (<=65536 elems) tensors -> unquantized fp32/fp16; everything else ->
    int``gb`` via qi6/qi8.
    """
    qf=qi6 if gb==6 else qi8
    qu,sc,dt,pt,po,qm={},{},{},{},{},{}
    st={k:0 for k in("pc","nt","bb","qb")}
    for n,t in sd.items():
        t=t.detach().cpu().contiguous()
        st["pc"]+=t.numel();st["nt"]+=1;st["bb"]+=t.numel()*t.element_size()
        # Non-float tensors (integer buffers) pass through untouched.
        if not t.is_floating_point():pt[n]=t;st["qb"]+=t.numel()*t.element_size();continue
        ic=any(p in n for p in CP);ism=t.numel()<=65536
        # NOTE(review): RGPT names its embedding "te", so keys look like
        # "te.weight" — this "tok_emb" branch may never match; verify intent.
        if "tok_emb" in n:
            po[n]=str(t.dtype).removeprefix("torch.")
            q,s=qi8(t,ns);qu[n]=q;sc[n]=s;dt[n]=po[n]
            if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":8}
            st["qb"]+=q.numel()+s.numel()*s.element_size();continue
        if ic or ism:
            # Sensitive scale params and small tensors stay unquantized:
            # fp32 for CP-named, fp16 for merely-small ones.
            if t.dtype in(torch.float32,torch.bfloat16):po[n]=str(t.dtype).removeprefix("torch.")
            pt[n]=t.float() if ic else t.to(torch.float16)
            pt[n]=pt[n].contiguous();st["qb"]+=pt[n].numel()*pt[n].element_size();continue
        q,s=qf(t,ns)
        if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":gb}
        qu[n]=q;sc[n]=s;dt[n]=str(t.dtype).removeprefix("torch.")
        st["qb"]+=q.numel()+s.numel()*s.element_size()
    obj={"__qf__":f"i{gb}sd","q":qu,"s":sc,"d":dt,"p":pt}
    if qm:obj["m"]=qm
    if po:obj["o"]=po
    return obj,st
|
| 198 |
+
|
| 199 |
+
def dqsd(obj):
    """Invert qsd(): rebuild a dense state dict from a quantized blob.

    ``obj`` holds int8 codes ("q"), scales ("s"), target dtype names
    ("d"), passthrough tensors ("p"), plus optional quantization metadata
    ("m") and original-dtype overrides ("o").
    """
    result = {}
    meta = obj.get("m", {})
    orig_dtypes = obj.get("o", {})
    for name, codes in obj["q"].items():
        target = getattr(torch, obj["d"][name])
        scale = obj["s"][name]
        per_row = meta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0
        if per_row:
            scale = scale.to(torch.float32)
            shape = (codes.shape[0],) + (1,) * (codes.ndim - 1)
            dense = codes.float() * scale.view(*shape)
        else:
            dense = codes.float() * float(scale.item())
        result[name] = dense.to(target).contiguous()
    for name, tensor in obj["p"].items():
        kept = tensor.detach().cpu().contiguous()
        override = orig_dtypes.get(name)
        if isinstance(override, str):
            kept = kept.to(dtype=getattr(torch, override)).contiguous()
        result[name] = kept
    return result
|
| 212 |
+
|
| 213 |
+
def lds(f):
|
| 214 |
+
h=np.fromfile(f,dtype="<i4",count=256)
|
| 215 |
+
if h.size!=256 or int(h[0])!=20240520 or int(h[1])!=1:raise ValueError(f"Bad:{f}")
|
| 216 |
+
n=int(h[2]);t=np.fromfile(f,dtype="<u2",count=n,offset=256*4)
|
| 217 |
+
return torch.from_numpy(t.astype(np.uint16,copy=False))
|
| 218 |
+
|
| 219 |
+
def lvt(pat, sl):
    """Concatenate all validation shards matching glob ``pat`` and trim to
    a whole number of length-``sl`` windows plus one extra target token."""
    files = [Path(p) for p in sorted(glob.glob(pat))]
    if not files:
        raise FileNotFoundError(f"No val:{pat}")
    tokens = torch.cat([lds(f) for f in files]).contiguous()
    usable = ((tokens.numel() - 1) // sl) * sl
    return tokens[:usable + 1]
|
| 224 |
+
|
| 225 |
+
class TS:
    """Infinite token stream over a sorted list of shard files, cycling
    back to the first shard when the last is exhausted."""
    def __init__(s,pat):
        fs=[Path(p) for p in sorted(glob.glob(pat))]
        if not fs:raise FileNotFoundError(f"No:{pat}")
        # File list, current file index, loaded tokens, read position.
        s.fs=fs;s.i=0;s.t=lds(fs[0]);s.p=0
    def take(s,n):
        """Return the next ``n`` tokens, spanning shard boundaries."""
        ch=[];r=n
        while r>0:
            av=s.t.numel()-s.p
            # Current shard exhausted: advance (wrapping) and reload.
            if av<=0:s.i=(s.i+1)%len(s.fs);s.t=lds(s.fs[s.i]);s.p=0;av=s.t.numel()
            k=min(r,av);ch.append(s.t[s.p:s.p+k]);s.p+=k;r-=k
        return ch[0] if len(ch)==1 else torch.cat(ch)
|
| 237 |
+
|
| 238 |
+
class DTL:
    """Distributed train loader: each rank reads a disjoint slice of a
    shared token stream per micro-batch."""
    def __init__(s,pat,rk,ws,dev):s.rk=rk;s.ws=ws;s.dev=dev;s.st=TS(pat)
    def nb(s,gt,sl,ga):
        """Next micro-batch: ``gt`` global tokens split over ``ws`` ranks
        and ``ga`` accumulation steps; returns (inputs, shifted targets)."""
        # Per-rank token count, +1 so targets can be shifted by one.
        lt=gt//(s.ws*ga);ps=lt+1
        ch=s.st.take(ps*s.ws);st=s.rk*ps
        lo=ch[st:st+ps].to(torch.int64)
        x=lo[:-1].reshape(-1,sl);y=lo[1:].reshape(-1,sl)
        return x.to(s.dev,non_blocking=True),y.to(s.dev,non_blocking=True)
|
| 246 |
+
|
| 247 |
+
class RN(nn.Module):
|
| 248 |
+
def __init__(s,eps=None):super().__init__();s.eps=eps
|
| 249 |
+
def forward(s,x):return F.rms_norm(x,(x.size(-1),),eps=s.eps)
|
| 250 |
+
|
| 251 |
+
class Rot(nn.Module):
    """Rotary-embedding cos/sin tables with a one-entry cache keyed on
    (sequence length, device)."""

    def __init__(s, d, b=10000.):
        super().__init__()
        inv_freq = 1. / (b ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
        s.register_buffer("if_", inv_freq, persistent=False)
        s._cl = 0    # cached sequence length
        s._c = None  # cached cos table, shape (1, 1, sl, d//2)
        s._s = None  # cached sin table

    def forward(s, sl, dev, dt):
        stale = s._c is None or s._cl != sl or s._c.device != dev
        if stale:
            positions = torch.arange(sl, device=dev, dtype=s.if_.dtype)
            angles = torch.outer(positions, s.if_.to(dev))
            s._c = angles.cos()[None, None, :, :]
            s._s = angles.sin()[None, None, :, :]
            s._cl = sl
        return s._c.to(dtype=dt), s._s.to(dtype=dt)
|
| 262 |
+
|
| 263 |
+
def arot(x, c, si):
    """Apply a rotary embedding (half-split layout) to ``x``: the two
    halves of the last dim are rotated against each other by cos table
    ``c`` and sin table ``si``."""
    half = x.size(-1) // 2
    lo, hi = x[..., :half], x[..., half:]
    rot_lo = lo * c + hi * si
    rot_hi = hi * c - lo * si
    return torch.cat((rot_lo, rot_hi), dim=-1)
|
| 266 |
+
|
| 267 |
+
class CSA(nn.Module):
    """Causal self-attention with grouped-query KV heads, QK RMSNorm,
    rotary embeddings, and a learned per-head query gain."""

    def __init__(s, d, nh, nk, rb, qkg):
        super().__init__()
        assert d % nh == 0 and nh % nk == 0
        s.nh = nh        # query heads
        s.nk = nk        # key/value heads (grouped-query)
        s.hd = d // nh   # per-head dim
        kd = nk * s.hd
        s.cq = nn.Linear(d, d, bias=False)
        s.ck = nn.Linear(d, kd, bias=False)
        s.cv = nn.Linear(d, kd, bias=False)
        s.pr = nn.Linear(d, d, bias=False)
        # Per-head gain applied to queries after QK-norm.
        s.qg = nn.Parameter(torch.full((nh,), qkg, dtype=torch.float32))
        # BUG FIX: Rot's second parameter is named ``b``; the original call
        # ``Rot(s.hd, base=rb)`` raised TypeError. Pass it positionally.
        s.rot = Rot(s.hd, rb)

    def forward(s, x):
        B, T, _ = x.shape
        q = s.cq(x).reshape(B, T, s.nh, s.hd).transpose(1, 2)
        k = s.ck(x).reshape(B, T, s.nk, s.hd).transpose(1, 2)
        v = s.cv(x).reshape(B, T, s.nk, s.hd).transpose(1, 2)
        # QK-norm stabilizes attention logits.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        c, si = s.rot(T, x.device, q.dtype)
        q = arot(q, c, si)
        k = arot(k, c, si)
        q = q * s.qg.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True,
                                           enable_gqa=(s.nk != s.nh))
        return s.pr(y.transpose(1, 2).contiguous().reshape(B, T, -1))
|
| 288 |
+
|
| 289 |
+
class MLP(nn.Module):
    """Feed-forward block: linear -> ReLU^2 -> linear, no biases."""

    def __init__(s, d, m):
        super().__init__()
        hidden = d * m
        s.fc = nn.Linear(d, hidden, bias=False)
        s.pr = nn.Linear(hidden, d, bias=False)

    def forward(s, x):
        gated = torch.relu(s.fc(x))
        return s.pr(gated * gated)
|
| 294 |
+
|
| 295 |
+
class PB(nn.Module):
    """Parallel residual block: one shared norm feeds attention and MLP
    side by side, each with a learned per-channel branch scale, plus a
    learned mix back to the initial embedding stream x0."""
    def __init__(s,d,nh,nk,mm,rb,qkg):
        super().__init__()
        s.n=RN();s.a=CSA(d,nh,nk,rb,qkg);s.m=MLP(d,mm)
        # Per-channel scales for the attention and MLP branches.
        s.as_=nn.Parameter(torch.ones(d,dtype=torch.float32))
        s.ms=nn.Parameter(torch.ones(d,dtype=torch.float32))
        # Two learned d-vectors mixing the running stream with x0;
        # initialized to identity (keep x, ignore x0).
        s.rm=nn.Parameter(torch.stack([torch.ones(d),torch.zeros(d)]).float())
    def forward(s,x,x0):
        # x = rm[0]*x + rm[1]*x0 (learned skip back to the embedding).
        mx=s.rm.to(x.dtype)
        x=mx[0][None,None,:]*x+mx[1][None,None,:]*x0
        h=s.n(x)
        # Attention and MLP applied to the same normed input, summed.
        x=x+s.as_.to(x.dtype)[None,None,:]*s.a(h)+s.ms.to(x.dtype)[None,None,:]*s.m(h)
        return x
|
| 309 |
+
|
| 310 |
+
class RGPT(nn.Module):
    """Depth-recurrent GPT: a small stack of PB blocks applied ``nr``
    times during training (``ner``, defaulting to 2*nr, at eval), with
    tied input/output embeddings, tanh logit soft-capping, and an
    optional test-time-training scorer."""
    def __init__(s,a):
        super().__init__()
        # lsc: logit soft-cap; _tr/_er: train/eval recurrence counts.
        s.lsc=a.lsc;s._tr=a.nr;s._er=a.ner or a.nr*2;s._V=a.V
        s.te=nn.Embedding(a.V,a.D)
        s.bl=nn.ModuleList([PB(a.D,a.nh,a.nkv,a.mm,a.rb,a.qkg) for _ in range(a.nul)])
        s.fn=RN();nn.init.normal_(s.te.weight,std=0.005)
    def _fh(s,ids):
        """Embed ``ids`` and run the recurrent block stack to final
        hidden states; x0 (the normed embedding) is fed to every block."""
        x=F.rms_norm(s.te(ids),(s.te.embedding_dim,));x0=x
        # More recurrence loops at eval time than during training.
        n=s._tr if s.training else s._er
        for _ in range(n):
            for b in s.bl:x=b(x,x0)
        return s.fn(x)
    def forward(s,ids,tgt):
        """Mean cross-entropy loss with tied-embedding logits."""
        h=s._fh(ids);lo=F.linear(h.reshape(-1,h.size(-1)),s.te.weight)
        lo=s.lsc*torch.tanh(lo/s.lsc)  # soft-cap logits to +/- lsc
        return F.cross_entropy(lo.float(),tgt.reshape(-1),reduction="mean")
    def ptl(s,ids,tgt):
        """Per-token losses, shape (B, T)."""
        h=s._fh(ids);B,T,D=h.shape
        lo=F.linear(h.reshape(B*T,D),s.te.weight)
        lo=s.lsc*torch.tanh(lo/s.lsc)
        return F.cross_entropy(lo.float(),tgt.reshape(B*T),reduction="none").reshape(B,T)
    @torch.no_grad()
    def ptl_ttt(s,ids,tgt,a):
        """Score-first TTT: score chunk, then update MLP W_down for next chunk."""
        cs=a.ttcs;lr=a.ttlr;B,T=ids.shape
        # Which blocks' MLP projection (m.pr) to adapt.
        if a.ttly=="all":li=list(range(len(s.bl)))
        else:li=[int(x) for x in a.ttly.split(",")]
        # Save originals so TTT leaves the model unchanged afterwards.
        ow={i:s.bl[i].m.pr.weight.data.clone() for i in li}
        ap=[];nc=(T+cs-1)//cs
        for ci in range(nc):
            lo=ci*cs;hi=min((ci+1)*cs,T)
            # NOTE(review): the full sequence is re-forwarded per chunk,
            # i.e. O(nc) forward passes — presumably acceptable at eval.
            h=s._fh(ids);hc=h[:,lo:hi,:];yc=tgt[:,lo:hi]
            lg=F.linear(hc.reshape(-1,hc.size(-1)),s.te.weight)
            lg=s.lsc*torch.tanh(lg/s.lsc)
            pt=F.cross_entropy(lg.float(),yc.reshape(-1),reduction="none").reshape(B,hi-lo)
            ap.append(pt)
            if ci<nc-1:
                # One gradient-like step per adapted block, nudging the MLP
                # projection toward reconstructing its normed input on this
                # chunk before the next chunk is scored.
                for i in li:
                    blk=s.bl[i];hn=F.rms_norm(hc.reshape(-1,hc.size(-1)).float(),(hc.size(-1),))
                    z=torch.relu(blk.m.fc(hn.to(hc.dtype))).square()
                    pred=z@blk.m.pr.weight.T
                    res=pred-hn.to(pred.dtype)
                    gw=res.T@z/z.size(0)
                    blk.m.pr.weight.data-=lr*gw.to(blk.m.pr.weight.dtype)
        # Restore the original weights.
        for i in li:s.bl[i].m.pr.weight.data=ow[i]
        return torch.cat(ap,dim=1)
|
| 357 |
+
|
| 358 |
+
class EMA:
    """Exponential moving average of a module's parameters, with
    apply/restore swapping for evaluation."""

    def __init__(s, m, d=0.999):
        s.m = m
        s.d = d
        # Shadow copy of every parameter.
        s.sh = {name: p.data.clone() for name, p in m.named_parameters()}
        s.bk = {}  # backup of live weights while the EMA is applied

    def up(s):
        """Fold current params into the shadow: sh = d*sh + (1-d)*p."""
        decay = s.d
        for name, p in s.m.named_parameters():
            s.sh[name].mul_(decay).add_(p.data, alpha=1. - decay)

    def ap(s):
        """Swap EMA weights into the model, backing up the live weights."""
        s.bk = {}
        for name, p in s.m.named_parameters():
            s.bk[name] = p.data.clone()
            p.data.copy_(s.sh[name])

    def re(s):
        """Restore the live weights saved by ap()."""
        for name, p in s.m.named_parameters():
            p.data.copy_(s.bk[name])
        s.bk = {}
|
| 368 |
+
|
| 369 |
+
def main():
    """Train the recurrent GPT, then evaluate the EMA weights with TTT and
    report the size of the quantized submission artifact."""
    global zp5
    code = Path(__file__).read_text(encoding="utf-8")
    a = H()
    zp5 = torch.compile(zp5)
    # ---- distributed setup ----
    dd = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rk = int(os.environ.get("RANK", "0"))
    ws = int(os.environ.get("WORLD_SIZE", "1"))
    lr_ = int(os.environ.get("LOCAL_RANK", "0"))
    ga = max(1, 8 // ws)  # gradient-accumulation steps
    gs = 1. / ga          # loss scale per micro-batch
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")
    dev = torch.device("cuda", lr_)
    torch.cuda.set_device(dev)
    if dd:
        dist.init_process_group("nccl", device_id=dev)
        dist.barrier()
    ma = rk == 0  # master rank
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp, enable_cudnn_sdp
    enable_flash_sdp(True)
    enable_math_sdp(False)
    enable_mem_efficient_sdp(False)
    enable_cudnn_sdp(False)
    # ---- logging ----
    lf = None
    if ma:
        os.makedirs("logs", exist_ok=True)
        lf = f"logs/{a.rid}.txt"
        print(lf)

    def l0(m, console=True):
        # Log to the run file on the master rank; echo to stdout unless
        # console=False. BUG FIX: the parameter was named ``c``, but the
        # call sites below pass ``console=False``, which raised TypeError.
        if not ma:
            return
        if console:
            print(m)
        if lf:
            with open(lf, "a") as f:
                print(m, file=f)

    l0(code, console=False)
    l0(f"Python {sys.version}", console=False)
    l0(f"PyTorch {torch.__version__}", console=False)
    try:
        l0(subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False).stdout, console=False)
    except FileNotFoundError:
        pass
    # ---- seeding / data ----
    random.seed(a.seed)
    np.random.seed(a.seed)
    torch.manual_seed(a.seed)
    torch.cuda.manual_seed_all(a.seed)
    sp = spm.SentencePieceProcessor(model_file=a.tp)
    bbl, hsl, ibl = build_sp_luts(sp, a.V, dev)
    vt = lvt(a.vf, a.swl)
    l0(f"val_tokens:{vt.numel()}")
    # ---- model ----
    bm = RGPT(a).to(dev).bfloat16()
    cm = torch.compile(bm, dynamic=False, fullgraph=True)
    mdl = DDP(cm, device_ids=[lr_], broadcast_buffers=False) if dd else cm
    nu = sum(p.numel() for p in bm.parameters())
    ed = a.nul * a.nr
    l0(f"params:{nu} eff_depth:{ed} train_loops:{a.nr} eval_loops:{bm._er}")
    l0(f"ws:{ws} ga:{ga}")
    # ---- optimizers: Adam for embeddings and scalars, Muon for 2-D block weights ----
    bp = list(bm.bl.named_parameters())
    mp = [p for n, p in bp if p.ndim == 2 and not any(c in n for c in CP)]
    sp_ = [p for n, p in bp if p.ndim < 2 or any(c in n for c in CP)]
    ot = torch.optim.Adam([{"params": [bm.te.weight], "lr": a.elr, "base_lr": a.elr}], betas=(a.b1, a.b2), eps=a.ae, fused=True)
    om = Muon(mp, lr=a.mlr, mom=a.mmo, bs=a.mbs, wd=a.mwd)
    for g in om.param_groups:
        g["base_lr"] = a.mlr
    os_ = torch.optim.Adam([{"params": sp_, "lr": a.slr, "base_lr": a.slr}], betas=(a.b1, a.b2), eps=a.ae, fused=True)
    opts = [ot, om, os_]
    ema = EMA(bm, d=0.999)
    ess = int(a.iters * a.esf)                   # step at which EMA accumulation starts
    mms = 1000. * a.mws if a.mws > 0 else None   # wall-clock budget in ms

    def lrm(st, el):
        # LR multiplier: 1.0 through most of training, then a linear decay
        # over the last wdi steps (step-based) or, with a time budget,
        # proportional to the remaining time.
        if a.wdi <= 0:
            return 1.
        if mms is None:
            w = max(a.iters - a.wdi, 0)
            return max((a.iters - st) / max(a.wdi, 1), 0.) if w <= st < a.iters else 1.
        sm = el / max(st, 1)
        rm = max(mms - el, 0.)
        wm = a.wdi * sm
        return rm / max(wm, 1e-9) if rm <= wm else 1.

    def za():
        for o in opts:
            o.zero_grad(set_to_none=True)

    # ---- warmup iterations (trigger compilation / allocator), then reset state ----
    if a.wui > 0:
        im = {n: t.detach().cpu().clone() for n, t in bm.state_dict().items()}
        io_ = [copy.deepcopy(o.state_dict()) for o in opts]
        mdl.train()
        tw = DTL(a.tf, rk, ws, dev)
        for _ in range(a.wui):
            za()
            for mi in range(ga):
                if dd:
                    mdl.require_backward_grad_sync = (mi == ga - 1)
                x, y = tw.nb(a.tbt, a.tsl, ga)
                with torch.autocast("cuda", torch.bfloat16):
                    (mdl(x, y) * gs).backward()
            for o in opts:
                o.step()
        za()
        # Throw the warmup away: restore initial weights and optimizer state.
        bm.load_state_dict(im, strict=True)
        for o, s in zip(opts, io_):
            o.load_state_dict(s)
        za()
    if dd:
        mdl.require_backward_grad_sync = True
    # ---- training loop ----
    tl = DTL(a.tf, rk, ws, dev)
    tms = 0.
    ss = None  # stop step, chosen once the time budget runs out
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0
    while True:
        ls = step == a.iters or (ss is not None and step >= ss)
        dv = ls or (a.vle > 0 and step % a.vle == 0)
        if dv:
            torch.cuda.synchronize()
            tms += 1000. * (time.perf_counter() - t0)
            vl, vb = eval_sw(a, mdl, rk, ws, dev, vt, bbl, hsl, ibl, ttt=False)
            l0(f"step:{step}/{a.iters} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_ms:{tms:.0f} step_avg:{tms/max(step,1):.2f}ms")
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if ls:
            # BUG FIX: this final block must run on ALL ranks — eval_sw
            # all_reduces, so gating it on the master rank (as the original
            # did) would deadlock under DDP. l0 already no-ops off-master.
            ema.ap()
            l0("EMA+TTT eval...")
            vle, vbe = eval_sw(a, bm, rk, ws, dev, vt, bbl, hsl, ibl, ttt=True)
            l0(f"ema_ttt val_loss:{vle:.4f} val_bpb:{vbe:.4f}")
            # Quantize, measure the compressed artifact, then verify the
            # round-tripped weights still evaluate well.
            sd = bm.state_dict()
            obj, st = qsd(sd, a.gb, a.sdn)
            buf = io.BytesIO()
            torch.save(obj, buf)
            cmp = zlib.compress(buf.getvalue(), level=9)
            cb = len(code.encode())
            mb = len(cmp)
            tb = cb + mb
            l0(f"artifact code:{cb} model:{mb} total:{tb} ({tb/1e6:.3f}MB) params:{st['pc']}")
            sd2 = dqsd(obj)
            bm.load_state_dict(sd2, strict=True)
            vl2, vb2 = eval_sw(a, bm, rk, ws, dev, vt, bbl, hsl, ibl, ttt=True)
            l0(f"quant+ttt val_loss:{vl2:.4f} val_bpb:{vb2:.4f}")
            ema.re()
            break
        if ss is None and mms is not None:
            torch.cuda.synchronize()
            el = 1000. * (time.perf_counter() - t0) + tms
            if el >= mms:
                ss = step + 1  # take one more step, then stop
        za()
        for mi in range(ga):
            if dd:
                mdl.require_backward_grad_sync = (mi == ga - 1)
            x, y = tl.nb(a.tbt, a.tsl, ga)
            with torch.autocast("cuda", torch.bfloat16):
                (mdl(x, y) * gs).backward()
        torch.cuda.synchronize()
        el = 1000. * (time.perf_counter() - t0) + tms
        m = lrm(step, el)
        for o in opts:
            for g in o.param_groups:
                g["lr"] = g["base_lr"] * m
        for o in opts:
            o.step()
        if step >= ess:
            ema.up()
        if step % a.tle == 0 and ma:
            l0(f"step:{step} lr_mul:{m:.4f}")
        step += 1
|
| 481 |
+
|
| 482 |
+
if __name__ == "__main__":
    main()
|