rtferraz committed (verified)
Commit 768495c · Parent(s): 680a32f

Add minified training script for submission

Files changed (1): train_final.py (+482, -0)
train_final.py ADDED
@@ -0,0 +1,482 @@
+ """PG-v2: SP8192+ParallelRes+DepthRec+TTT+Int6GPTQ+EMA"""
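+ # Key to the title: SP8192 = SentencePiece BPE tokenizer with an 8192-token
+ # vocab; ParallelRes = parallel attention+MLP residual blocks (PB); DepthRec =
+ # depth recurrence, looping the shared block stack (RGPT); TTT = test-time
+ # training of the MLP down-projections (ptl_ttt); Int6GPTQ = int6 per-row
+ # checkpoint quantization (qi6/qsd); EMA = exponential moving average of weights.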
+ from __future__ import annotations
+ import copy,glob,io,math,os,random,subprocess,sys,time,uuid,zlib
+ from pathlib import Path
+ import numpy as np
+ import sentencepiece as spm
+ import torch,torch.distributed as dist,torch.nn.functional as F
+ from torch import Tensor,nn
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ class H: # hyperparameters; every field can be overridden via an env var
+  dp=os.environ.get("DATA_PATH","./data/datasets/fineweb10B_sp8192")
+  tf=os.path.join(dp,"fineweb_train_*.bin")
+  vf=os.path.join(dp,"fineweb_val_*.bin")
+  tp=os.environ.get("TOKENIZER_PATH","./data/tokenizers/fineweb_8192_bpe.model")
+  rid=os.environ.get("RUN_ID",str(uuid.uuid4()))
+  seed=int(os.environ.get("SEED","1337"))
+  vbs=int(os.environ.get("VBS","524288"));vle=int(os.environ.get("VLE","1000"))
+  tle=int(os.environ.get("TLE","200"))
+  iters=int(os.environ.get("ITERS","20000"))
+  wdi=int(os.environ.get("WDI","3500"));wui=int(os.environ.get("WUI","20"))
+  tbt=int(os.environ.get("TBT","524288"));tsl=int(os.environ.get("TSL","1024"))
+  mws=float(os.environ.get("MWS","600.0"))
+  V=int(os.environ.get("V","8192"));D=int(os.environ.get("D","768"))
+  nh=int(os.environ.get("NH","12"));nkv=int(os.environ.get("NKV","4"))
+  mm=int(os.environ.get("MM","4"))
+  nul=int(os.environ.get("NUL","3"));nr=int(os.environ.get("NR","8"))
+  ner=int(os.environ.get("NER","0"))
+  rb=float(os.environ.get("RB","10000.0"))
+  lsc=float(os.environ.get("LSC","30.0"))
+  qkg=float(os.environ.get("QKG","5.25"))
+  sws=int(os.environ.get("SWS","64"));swl=int(os.environ.get("SWL","1024"))
+  tte=int(os.environ.get("TTE","1"))
+  ttlr=float(os.environ.get("TTLR","0.01"))
+  ttcs=int(os.environ.get("TTCS","64"))
+  ttly=os.environ.get("TTLY","all")
+  elr=float(os.environ.get("ELR","0.05"))
+  mlr=float(os.environ.get("MLR","0.04"))
+  slr=float(os.environ.get("SLR","0.04"))
+  mmo=float(os.environ.get("MMO","0.95"))
+  mbs=int(os.environ.get("MBS","5"))
+  mwd=float(os.environ.get("MWD","0.09"))
+  b1=float(os.environ.get("B1","0.9"))
+  b2=float(os.environ.get("B2","0.95"))
+  ae=float(os.environ.get("AE","1e-8"))
+  gb=int(os.environ.get("GB","6"))
+  sdn=float(os.environ.get("SDN","2.5"))
+  esf=float(os.environ.get("ESF","0.4"))
+
+ # name fragments of the gate/scale params (as_=attn_scale, ms=mlp_scale,
+ # rm=resid_mix, qg=q_gain); these are excluded from Muon and from quantization
+ CP=tuple(p for p in "as_,ms,rm,qg".split(",") if p)
+
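+ # zp5: quintic Newton-Schulz iteration that approximately orthogonalizes a
+ # gradient matrix in bfloat16 (Muon-style update); the input is transposed if
+ # needed so the Gram matrix X@X.T is formed on the shorter side.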
+ def zp5(G,s=10,e=1e-7):
+  a,b,c=3.4445,-4.7750,2.0315
+  X=G.bfloat16();X/=X.norm()+e
+  tr=G.size(0)>G.size(1)
+  if tr:X=X.T
+  for _ in range(s):
+   A=X@X.T;B=b*A+c*A@A;X=a*X+B@X
+  return X.T if tr else X
+
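+ # Muon: momentum SGD whose 2D updates are orthogonalized by zp5. Params are
+ # sharded round-robin over ranks; each rank writes its share into a flat
+ # bfloat16 buffer and a single all_reduce(SUM) reassembles the full update.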
+ class Muon(torch.optim.Optimizer):
+  def __init__(s,p,lr,mom,bs,wd=0.,nest=True):
+   super().__init__(p,dict(lr=lr,mom=mom,bs=bs,wd=wd,nest=nest))
+  @torch.no_grad()
+  def step(s,cl=None):
+   lo=None
+   if cl:
+    with torch.enable_grad():lo=cl()
+   dd=dist.is_available() and dist.is_initialized()
+   ws=dist.get_world_size() if dd else 1
+   rk=dist.get_rank() if dd else 0
+   for g in s.param_groups:
+    ps=g["params"];lr=g["lr"];mo=g["mom"];bs=g["bs"];wd=g["wd"];ne=g["nest"]
+    tot=sum(int(p.numel()) for p in ps)
+    fl=torch.zeros(tot,device=ps[0].device,dtype=torch.bfloat16)
+    cur=0
+    for i,p in enumerate(ps):
+     if i%ws==rk and p.grad is not None:
+      gr=p.grad
+      if wd:gr=gr+wd*p.data.to(gr.dtype)
+      st=s.state[p]
+      if "mb" not in st:st["mb"]=torch.zeros_like(gr)
+      buf=st["mb"];buf.mul_(mo).add_(gr)
+      if ne:gr=gr.add(buf,alpha=mo)
+      gr=zp5(gr,s=bs)
+      gr*=max(1,gr.size(0)/gr.size(1))**0.5
+      fl[cur:cur+p.numel()]=gr.reshape(-1)
+     cur+=p.numel()
+    if dd:dist.all_reduce(fl,op=dist.ReduceOp.SUM)
+    cur=0
+    for p in ps:
+     gr=fl[cur:cur+p.numel()].view_as(p).to(dtype=p.dtype)
+     p.add_(gr,alpha=-lr);cur+=p.numel()
+   return lo
+
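+ # Per-token lookup tables for bits-per-byte: UTF-8 byte length of each piece,
+ # whether a piece carries the SentencePiece word-boundary marker (U+2581, which
+ # implies one extra space byte), and which ids are control/unknown/unused.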
+ def build_sp_luts(sp,vs,dev):
+  sv=int(sp.vocab_size());sz=max(sv,vs)
+  bb=np.zeros(sz,dtype=np.int16);hs=np.zeros(sz,dtype=bool);ib=np.ones(sz,dtype=bool)
+  for t in range(sv):
+   if sp.is_control(t) or sp.is_unknown(t) or sp.is_unused(t):continue
+   ib[t]=False
+   if sp.is_byte(t):bb[t]=1;continue
+   pc=sp.id_to_piece(t)
+   if pc.startswith("\u2581"):hs[t]=True;pc=pc[1:]
+   bb[t]=len(pc.encode("utf-8"))
+  return(torch.tensor(bb,dtype=torch.int16,device=dev),
+   torch.tensor(hs,dtype=torch.bool,device=dev),
+   torch.tensor(ib,dtype=torch.bool,device=dev))
+
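+ # Sliding-window eval: windows of length swl advance by stride sws and only the
+ # last sws positions of each window are scored, so every scored token sees at
+ # least swl-sws tokens of context; all-reduced into a mean NLL and bits/byte.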
+ def eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False):
+  sl=a.swl;st=a.sws;T=vt.numel()
+  starts=list(range(0,T-sl-1,st))
+  my=starts[rk::ws]
+  ls=torch.zeros((),device=dev,dtype=torch.float64)
+  tc=torch.zeros((),device=dev,dtype=torch.float64)
+  bc=torch.zeros((),device=dev,dtype=torch.float64)
+  rm=mdl
+  while hasattr(rm,'module'):rm=rm.module
+  if hasattr(rm,'_orig_mod'):rm=rm._orig_mod
+  rm.eval()
+  ctx=torch.no_grad if ttt else torch.inference_mode
+  with ctx():
+   for s in my:
+    e=s+sl
+    x=vt[s:e].unsqueeze(0).to(dev,dtype=torch.int64)
+    y=vt[s+1:e+1].unsqueeze(0).to(dev,dtype=torch.int64)
+    with torch.autocast("cuda",dtype=torch.bfloat16):
+     if ttt and a.tte:ptl=rm.ptl_ttt(x,y,a)
+     else:ptl=rm.ptl(x,y)
+    lo=sl-st;ps=ptl[0,lo:];ys=y[0,lo:];xs=x[0,lo:]
+    ls+=ps.to(torch.float64).sum();tc+=ps.numel()
+    tb=bbl[ys].to(torch.float64)
+    tb+=(hsl[ys]&~ibl[xs]).to(torch.float64)
+    bc+=tb.sum()
+  if dist.is_available() and dist.is_initialized():
+   for t in(ls,tc,bc):dist.all_reduce(t,op=dist.ReduceOp.SUM)
+  vl=float((ls/tc).item());bpb=float((ls/math.log(2)/bc).item())
+  rm.train();return vl,bpb
+
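+ # Symmetric quantizers: rows are clipped to mean±ns*std, then mapped onto 31
+ # (int6) or 127 (int8) levels with one float16 scale per row; tensors that are
+ # not 2D fall back to a single tensor-wide float32 scale.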
+ def sdclip(t,n=2.5):
+  m=t.float().mean();s=t.float().std()
+  return t.clamp((m-n*s).item(),(m+n*s).item())
+
+ def qi6(t,ns=2.5):
+  t32=t.float();mx=31
+  if t32.ndim==2:
+   m=t32.mean(1,keepdim=True);s=t32.std(1,keepdim=True).clamp_min(1e-9)
+   lo=m-ns*s;hi=m+ns*s
+   tc=t32.clamp(lo.expand_as(t32),hi.expand_as(t32))
+   cv=tc.abs().amax(1).clamp_min(1e-9);sc=cv/mx
+   q=torch.clamp(torch.round(tc/sc[:,None]),-mx,mx).to(torch.int8)
+   return q.contiguous(),sc.to(torch.float16).contiguous()
+  tc=sdclip(t32,ns);cv=float(tc.abs().max().item())
+  sc=torch.tensor(max(cv/mx,1./mx),dtype=torch.float32)
+  q=torch.clamp(torch.round(tc/sc),-mx,mx).to(torch.int8)
+  return q.contiguous(),sc
+
+ def qi8(t,ns=2.5):
+  t32=t.float()
+  if t32.ndim==2:
+   m=t32.mean(1,keepdim=True);s=t32.std(1,keepdim=True).clamp_min(1e-9)
+   lo=m-ns*s;hi=m+ns*s
+   tc=t32.clamp(lo.expand_as(t32),hi.expand_as(t32))
+   cv=tc.abs().amax(1).clamp_min(1e-9);sc=cv/127.
+   q=torch.clamp(torch.round(tc/sc[:,None]),-127,127).to(torch.int8)
+   return q.contiguous(),sc.to(torch.float16).contiguous()
+  cv=float(sdclip(t32,ns).abs().max().item())
+  sc=torch.tensor(max(cv/127.,1./127.),dtype=torch.float32)
+  q=torch.clamp(torch.round(t32.clamp(-cv,cv)/sc),-127,127).to(torch.int8)
+  return q.contiguous(),sc
+
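+ # qsd packs a state_dict for the size-counted artifact: the token embedding is
+ # int8, gate/scale params and tensors of <=65536 elements stay unquantized, and
+ # all remaining matrices go to int6 (GB=8 switches them to int8); dqsd reverses
+ # the packing so the quantized model can be evaluated in-place.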
+ def qsd(sd,gb=6,ns=2.5):
+  qf=qi6 if gb==6 else qi8
+  qu,sc,dt,pt,po,qm={},{},{},{},{},{}
+  st={k:0 for k in("pc","nt","bb","qb")}
+  for n,t in sd.items():
+   t=t.detach().cpu().contiguous()
+   st["pc"]+=t.numel();st["nt"]+=1;st["bb"]+=t.numel()*t.element_size()
+   if not t.is_floating_point():pt[n]=t;st["qb"]+=t.numel()*t.element_size();continue
+   ic=any(p in n for p in CP);ism=t.numel()<=65536
+   if "te.weight" in n: # token embedding: int8 per-row
+    po[n]=str(t.dtype).removeprefix("torch.")
+    q,s=qi8(t,ns);qu[n]=q;sc[n]=s;dt[n]=po[n]
+    if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":8}
+    st["qb"]+=q.numel()+s.numel()*s.element_size();continue
+   if ic or ism:
+    if t.dtype in(torch.float32,torch.bfloat16):po[n]=str(t.dtype).removeprefix("torch.")
+    pt[n]=t.float() if ic else t.to(torch.float16)
+    pt[n]=pt[n].contiguous();st["qb"]+=pt[n].numel()*pt[n].element_size();continue
+   q,s=qf(t,ns)
+   if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":gb}
+   qu[n]=q;sc[n]=s;dt[n]=str(t.dtype).removeprefix("torch.")
+   st["qb"]+=q.numel()+s.numel()*s.element_size()
+  obj={"__qf__":f"i{gb}sd","q":qu,"s":sc,"d":dt,"p":pt}
+  if qm:obj["m"]=qm
+  if po:obj["o"]=po
+  return obj,st
+
+ def dqsd(obj):
+  out={};qm=obj.get("m",{});po=obj.get("o",{})
+  for n,q in obj["q"].items():
+   dt=getattr(torch,obj["d"][n]);s=obj["s"][n]
+   if qm.get(n,{}).get("scheme")=="per_row" or s.ndim>0:
+    s=s.to(torch.float32)
+    out[n]=(q.float()*s.view(q.shape[0],*([1]*(q.ndim-1)))).to(dt).contiguous()
+   else:out[n]=(q.float()*float(s.item())).to(dt).contiguous()
+  for n,t in obj["p"].items():
+   ot=t.detach().cpu().contiguous();od=po.get(n)
+   if isinstance(od,str):ot=ot.to(dtype=getattr(torch,od)).contiguous()
+   out[n]=ot
+  return out
+
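+ # Shard format: 256 little-endian int32 header words (magic 20240520, version 1,
+ # token count) followed by uint16 token ids.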
+ def lds(f):
+  h=np.fromfile(f,dtype="<i4",count=256)
+  if h.size!=256 or int(h[0])!=20240520 or int(h[1])!=1:raise ValueError(f"Bad:{f}")
+  n=int(h[2]);t=np.fromfile(f,dtype="<u2",count=n,offset=256*4)
+  return torch.from_numpy(t.astype(np.uint16,copy=False))
+
+ def lvt(pat,sl):
+  fs=[Path(p) for p in sorted(glob.glob(pat))]
+  if not fs:raise FileNotFoundError(f"No val:{pat}")
+  t=torch.cat([lds(f) for f in fs]).contiguous()
+  u=((t.numel()-1)//sl)*sl;return t[:u+1]
+
+ class TS: # sequential token stream over sorted shards, wrapping around
+  def __init__(s,pat):
+   fs=[Path(p) for p in sorted(glob.glob(pat))]
+   if not fs:raise FileNotFoundError(f"No:{pat}")
+   s.fs=fs;s.i=0;s.t=lds(fs[0]);s.p=0
+  def take(s,n):
+   ch=[];r=n
+   while r>0:
+    av=s.t.numel()-s.p
+    if av<=0:s.i=(s.i+1)%len(s.fs);s.t=lds(s.fs[s.i]);s.p=0;av=s.t.numel()
+    k=min(r,av);ch.append(s.t[s.p:s.p+k]);s.p+=k;r-=k
+   return ch[0] if len(ch)==1 else torch.cat(ch)
+
+ class DTL: # distributed loader: every rank streams identically, slices its part
+  def __init__(s,pat,rk,ws,dev):s.rk=rk;s.ws=ws;s.dev=dev;s.st=TS(pat)
+  def nb(s,gt,sl,ga):
+   lt=gt//(s.ws*ga);ps=lt+1
+   ch=s.st.take(ps*s.ws);st=s.rk*ps
+   lo=ch[st:st+ps].to(torch.int64)
+   x=lo[:-1].reshape(-1,sl);y=lo[1:].reshape(-1,sl)
+   return x.to(s.dev,non_blocking=True),y.to(s.dev,non_blocking=True)
+
+ class RN(nn.Module): # parameter-free RMSNorm
+  def __init__(s,eps=None):super().__init__();s.eps=eps
+  def forward(s,x):return F.rms_norm(x,(x.size(-1),),eps=s.eps)
+
+ class Rot(nn.Module): # rotary embedding with cached cos/sin tables
+  def __init__(s,d,b=10000.):
+   super().__init__()
+   s.register_buffer("if_",1./(b**(torch.arange(0,d,2,dtype=torch.float32)/d)),persistent=False)
+   s._cl=0;s._c=None;s._s=None
+  def forward(s,sl,dev,dt):
+   if s._c is None or s._cl!=sl or s._c.device!=dev:
+    t=torch.arange(sl,device=dev,dtype=s.if_.dtype)
+    fr=torch.outer(t,s.if_.to(dev))
+    s._c=fr.cos()[None,None,:,:];s._s=fr.sin()[None,None,:,:];s._cl=sl
+   return s._c.to(dtype=dt),s._s.to(dtype=dt)
+
+ def arot(x,c,si):
+  h=x.size(-1)//2;x1,x2=x[...,:h],x[...,h:]
+  return torch.cat((x1*c+x2*si,x1*(-si)+x2*c),dim=-1)
+
+ class CSA(nn.Module): # causal self-attention: GQA, QK-norm, learned Q gain
+  def __init__(s,d,nh,nk,rb,qkg):
+   super().__init__()
+   assert d%nh==0 and nh%nk==0
+   s.nh=nh;s.nk=nk;s.hd=d//nh;kd=nk*s.hd
+   s.cq=nn.Linear(d,d,bias=False);s.ck=nn.Linear(d,kd,bias=False)
+   s.cv=nn.Linear(d,kd,bias=False);s.pr=nn.Linear(d,d,bias=False)
+   s.qg=nn.Parameter(torch.full((nh,),qkg,dtype=torch.float32))
+   s.rot=Rot(s.hd,b=rb)
+  def forward(s,x):
+   B,T,_=x.shape
+   q=s.cq(x).reshape(B,T,s.nh,s.hd).transpose(1,2)
+   k=s.ck(x).reshape(B,T,s.nk,s.hd).transpose(1,2)
+   v=s.cv(x).reshape(B,T,s.nk,s.hd).transpose(1,2)
+   q=F.rms_norm(q,(q.size(-1),));k=F.rms_norm(k,(k.size(-1),))
+   c,si=s.rot(T,x.device,q.dtype)
+   q=arot(q,c,si);k=arot(k,c,si)
+   q=q*s.qg.to(dtype=q.dtype)[None,:,None,None]
+   y=F.scaled_dot_product_attention(q,k,v,attn_mask=None,is_causal=True,
+    enable_gqa=(s.nk!=s.nh))
+   return s.pr(y.transpose(1,2).contiguous().reshape(B,T,-1))
+
+ class MLP(nn.Module): # ReLU^2 feed-forward
+  def __init__(s,d,m):
+   super().__init__()
+   h=d*m;s.fc=nn.Linear(d,h,bias=False);s.pr=nn.Linear(h,d,bias=False)
+  def forward(s,x):return s.pr(torch.relu(s.fc(x)).square())
+
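+ # Parallel residual: one RMSNorm feeds both the attention and MLP branches, and
+ # their outputs are added to the stream in a single step with per-channel gains
+ # (as_, ms); rm learns to mix the running stream x with the embedding stream x0.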
+ class PB(nn.Module):
+  """Parallel residual block."""
+  def __init__(s,d,nh,nk,mm,rb,qkg):
+   super().__init__()
+   s.n=RN();s.a=CSA(d,nh,nk,rb,qkg);s.m=MLP(d,mm)
+   s.as_=nn.Parameter(torch.ones(d,dtype=torch.float32))
+   s.ms=nn.Parameter(torch.ones(d,dtype=torch.float32))
+   s.rm=nn.Parameter(torch.stack([torch.ones(d),torch.zeros(d)]).float())
+  def forward(s,x,x0):
+   mx=s.rm.to(x.dtype)
+   x=mx[0][None,None,:]*x+mx[1][None,None,:]*x0
+   h=s.n(x)
+   x=x+s.as_.to(x.dtype)[None,None,:]*s.a(h)+s.ms.to(x.dtype)[None,None,:]*s.m(h)
+   return x
+
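+ # Depth recurrence: the nul blocks are looped nr times in training (effective
+ # depth nul*nr) and, unless NER overrides it, 2*nr times at eval. Logits are
+ # tied to the embedding matrix and soft-capped at +-lsc via tanh.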
+ class RGPT(nn.Module):
+  def __init__(s,a):
+   super().__init__()
+   s.lsc=a.lsc;s._tr=a.nr;s._er=a.ner or a.nr*2;s._V=a.V
+   s.te=nn.Embedding(a.V,a.D)
+   s.bl=nn.ModuleList([PB(a.D,a.nh,a.nkv,a.mm,a.rb,a.qkg) for _ in range(a.nul)])
+   s.fn=RN();nn.init.normal_(s.te.weight,std=0.005)
+  def _fh(s,ids):
+   x=F.rms_norm(s.te(ids),(s.te.embedding_dim,));x0=x
+   n=s._tr if s.training else s._er
+   for _ in range(n):
+    for b in s.bl:x=b(x,x0)
+   return s.fn(x)
+  def forward(s,ids,tgt):
+   h=s._fh(ids);lo=F.linear(h.reshape(-1,h.size(-1)),s.te.weight)
+   lo=s.lsc*torch.tanh(lo/s.lsc)
+   return F.cross_entropy(lo.float(),tgt.reshape(-1),reduction="mean")
+  def ptl(s,ids,tgt):
+   h=s._fh(ids);B,T,D=h.shape
+   lo=F.linear(h.reshape(B*T,D),s.te.weight)
+   lo=s.lsc*torch.tanh(lo/s.lsc)
+   return F.cross_entropy(lo.float(),tgt.reshape(B*T),reduction="none").reshape(B,T)
+  @torch.no_grad()
+  def ptl_ttt(s,ids,tgt,a):
+   """Score-first TTT: score chunk, then update MLP W_down for next chunk."""
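+   # Each chunk is scored with the current weights before any adaptation, then
+   # W_down takes one gradient step on the squared error of reconstructing the
+   # normed hidden state from the MLP activations; the original weights are
+   # restored once all chunks are scored.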
+   cs=a.ttcs;lr=a.ttlr;B,T=ids.shape
+   if a.ttly=="all":li=list(range(len(s.bl)))
+   else:li=[int(x) for x in a.ttly.split(",")]
+   ow={i:s.bl[i].m.pr.weight.data.clone() for i in li}
+   ap=[];nc=(T+cs-1)//cs
+   for ci in range(nc):
+    lo=ci*cs;hi=min((ci+1)*cs,T)
+    h=s._fh(ids);hc=h[:,lo:hi,:];yc=tgt[:,lo:hi]
+    lg=F.linear(hc.reshape(-1,hc.size(-1)),s.te.weight)
+    lg=s.lsc*torch.tanh(lg/s.lsc)
+    pt=F.cross_entropy(lg.float(),yc.reshape(-1),reduction="none").reshape(B,hi-lo)
+    ap.append(pt)
+    if ci<nc-1:
+     for i in li:
+      blk=s.bl[i];hn=F.rms_norm(hc.reshape(-1,hc.size(-1)).float(),(hc.size(-1),))
+      z=torch.relu(blk.m.fc(hn.to(hc.dtype))).square()
+      pred=z@blk.m.pr.weight.T
+      res=pred-hn.to(pred.dtype)
+      gw=res.T@z/z.size(0)
+      blk.m.pr.weight.data-=lr*gw.to(blk.m.pr.weight.dtype)
+   for i in li:s.bl[i].m.pr.weight.data=ow[i]
+   return torch.cat(ap,dim=1)
+
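+ # Weight EMA: a shadow copy updated with decay d once training passes esf*iters;
+ # ap()/re() swap the shadow in and out around the final evaluations.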
+ class EMA:
+  def __init__(s,m,d=0.999):s.m=m;s.d=d;s.sh={n:p.data.clone() for n,p in m.named_parameters()};s.bk={}
+  def up(s):
+   for n,p in s.m.named_parameters():s.sh[n].mul_(s.d).add_(p.data,alpha=1.-s.d)
+  def ap(s):
+   s.bk={}
+   for n,p in s.m.named_parameters():s.bk[n]=p.data.clone();p.data.copy_(s.sh[n])
+  def re(s):
+   for n,p in s.m.named_parameters():p.data.copy_(s.bk[n])
+   s.bk={}
+
+ def main():
+  global zp5
+  code=Path(__file__).read_text(encoding="utf-8")
+  a=H();zp5=torch.compile(zp5)
+  dd="RANK" in os.environ and "WORLD_SIZE" in os.environ
+  rk=int(os.environ.get("RANK","0"));ws=int(os.environ.get("WORLD_SIZE","1"))
+  lr_=int(os.environ.get("LOCAL_RANK","0"))
+  ga=max(1,8//ws);gs=1./ga
+  if not torch.cuda.is_available():raise RuntimeError("CUDA required")
+  dev=torch.device("cuda",lr_);torch.cuda.set_device(dev)
+  if dd:dist.init_process_group("nccl",device_id=dev);dist.barrier()
+  ma=rk==0
+  torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True
+  from torch.backends.cuda import enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp,enable_cudnn_sdp
+  enable_flash_sdp(True);enable_math_sdp(False);enable_mem_efficient_sdp(False);enable_cudnn_sdp(False)
+  lf=None
+  if ma:os.makedirs("logs",exist_ok=True);lf=f"logs/{a.rid}.txt";print(lf)
+  def l0(m,c=True): # rank-0 logger; c=False writes to the log file only
+   if not ma:return
+   if c:print(m)
+   if lf:
+    with open(lf,"a") as f:print(m,file=f)
+  l0(code,c=False);l0(f"Python {sys.version}",c=False);l0(f"PyTorch {torch.__version__}",c=False)
+  try:l0(subprocess.run(["nvidia-smi"],capture_output=True,text=True,check=False).stdout,c=False)
+  except FileNotFoundError:pass
+  random.seed(a.seed);np.random.seed(a.seed);torch.manual_seed(a.seed);torch.cuda.manual_seed_all(a.seed)
+  sp=spm.SentencePieceProcessor(model_file=a.tp)
+  bbl,hsl,ibl=build_sp_luts(sp,a.V,dev)
+  vt=lvt(a.vf,a.swl);l0(f"val_tokens:{vt.numel()}")
+  bm=RGPT(a).to(dev).bfloat16()
+  cm=torch.compile(bm,dynamic=False,fullgraph=True)
+  mdl=DDP(cm,device_ids=[lr_],broadcast_buffers=False) if dd else cm
+  nu=sum(p.numel() for p in bm.parameters())
+  ed=a.nul*a.nr
+  l0(f"params:{nu} eff_depth:{ed} train_loops:{a.nr} eval_loops:{bm._er}")
+  l0(f"ws:{ws} ga:{ga}")
+  bp=list(bm.bl.named_parameters())
+  mp=[p for n,p in bp if p.ndim==2 and not any(c in n for c in CP)]
+  sp_=[p for n,p in bp if p.ndim<2 or any(c in n for c in CP)]
+  ot=torch.optim.Adam([{"params":[bm.te.weight],"lr":a.elr,"base_lr":a.elr}],betas=(a.b1,a.b2),eps=a.ae,fused=True)
+  om=Muon(mp,lr=a.mlr,mom=a.mmo,bs=a.mbs,wd=a.mwd)
+  for g in om.param_groups:g["base_lr"]=a.mlr
+  os_=torch.optim.Adam([{"params":sp_,"lr":a.slr,"base_lr":a.slr}],betas=(a.b1,a.b2),eps=a.ae,fused=True)
+  opts=[ot,om,os_]
+  ema=EMA(bm,d=0.999);ess=int(a.iters*a.esf)
+  mms=1000.*a.mws if a.mws>0 else None
+  # LR multiplier: constant, then linear decay over the last wdi steps (or over
+  # the projected remaining time when a wall-clock budget mws is set)
+  def lrm(st,el):
+   if a.wdi<=0:return 1.
+   if mms is None:
+    w=max(a.iters-a.wdi,0)
+    return max((a.iters-st)/max(a.wdi,1),0.) if w<=st<a.iters else 1.
+   sm=el/max(st,1);rm=max(mms-el,0.);wm=a.wdi*sm
+   return rm/max(wm,1e-9) if rm<=wm else 1.
+  def za():[o.zero_grad(set_to_none=True) for o in opts]
+  if a.wui>0: # warmup steps to trigger compilation, then restore all state
+   im={n:t.detach().cpu().clone() for n,t in bm.state_dict().items()}
+   io_=[copy.deepcopy(o.state_dict()) for o in opts]
+   mdl.train();tw=DTL(a.tf,rk,ws,dev)
+   for _ in range(a.wui):
+    za()
+    for mi in range(ga):
+     if dd:mdl.require_backward_grad_sync=(mi==ga-1)
+     x,y=tw.nb(a.tbt,a.tsl,ga)
+     with torch.autocast("cuda",torch.bfloat16):(mdl(x,y)*gs).backward()
+    for o in opts:o.step()
+   za()
+   bm.load_state_dict(im,strict=True)
+   for o,s in zip(opts,io_):o.load_state_dict(s)
+   za()
+   if dd:mdl.require_backward_grad_sync=True
+  tl=DTL(a.tf,rk,ws,dev);tms=0.;ss=None
+  torch.cuda.synchronize();t0=time.perf_counter();step=0
+  while True:
+   ls=step==a.iters or(ss is not None and step>=ss)
+   dv=ls or(a.vle>0 and step%a.vle==0)
+   if dv:
+    torch.cuda.synchronize();tms+=1000.*(time.perf_counter()-t0)
+    vl,vb=eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False)
+    l0(f"step:{step}/{a.iters} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_ms:{tms:.0f} step_avg:{tms/max(step,1):.2f}ms")
+    torch.cuda.synchronize();t0=time.perf_counter()
+   if ls:
+    # all ranks run the final evals (eval_sw all_reduces); l0 logs on rank 0 only
+    ema.ap();l0("EMA+TTT eval...")
+    vle,vbe=eval_sw(a,bm,rk,ws,dev,vt,bbl,hsl,ibl,ttt=True)
+    l0(f"ema_ttt val_loss:{vle:.4f} val_bpb:{vbe:.4f}")
+    sd=bm.state_dict();obj,st=qsd(sd,a.gb,a.sdn)
+    buf=io.BytesIO();torch.save(obj,buf)
+    cmp=zlib.compress(buf.getvalue(),level=9)
+    cb=len(code.encode());mb=len(cmp);tb=cb+mb
+    l0(f"artifact code:{cb} model:{mb} total:{tb} ({tb/1e6:.3f}MB) params:{st['pc']}")
+    sd2=dqsd(obj);bm.load_state_dict(sd2,strict=True)
+    vl2,vb2=eval_sw(a,bm,rk,ws,dev,vt,bbl,hsl,ibl,ttt=True)
+    l0(f"quant+ttt val_loss:{vl2:.4f} val_bpb:{vb2:.4f}")
+    ema.re()
+    break
+   if ss is None and mms is not None:
+    torch.cuda.synchronize()
+    el=1000.*(time.perf_counter()-t0)+tms
+    if el>=mms:ss=step+1
+   za()
+   for mi in range(ga):
+    if dd:mdl.require_backward_grad_sync=(mi==ga-1)
+    x,y=tl.nb(a.tbt,a.tsl,ga)
+    with torch.autocast("cuda",torch.bfloat16):(mdl(x,y)*gs).backward()
+   torch.cuda.synchronize();el=1000.*(time.perf_counter()-t0)+tms
+   m=lrm(step,el)
+   for o in opts:
+    for g in o.param_groups:g["lr"]=g["base_lr"]*m
+   for o in opts:o.step()
+   if step>=ess:ema.up()
+   if step%a.tle==0 and ma:l0(f"step:{step} lr_mul:{m:.4f}")
+   step+=1
+
+ if __name__=="__main__":main()
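
A plausible launch, using the default paths from class H (which must point at real shards and a real tokenizer):

    DATA_PATH=./data/datasets/fineweb10B_sp8192 \
    TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model \
    torchrun --standalone --nproc_per_node=8 train_final.py

torchrun supplies the RANK/WORLD_SIZE/LOCAL_RANK variables the script checks for; a plain "python train_final.py" runs single-process with ga=max(1,8//ws)=8 gradient-accumulation steps, so the global batch size is unchanged.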