| """PG-v2: SP8192+ParallelRes+DepthRec+TTT+Int6GPTQ+EMA""" |
| from __future__ import annotations |
| import copy,glob,io,math,os,random,subprocess,sys,time,uuid,zlib |
| from pathlib import Path |
| import numpy as np |
| import sentencepiece as spm |
| import torch,torch.distributed as dist,torch.nn.functional as F |
| from torch import Tensor,nn |
| from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
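# Hyperparameters. Every field reads an environment variable with a default,
# so runs are configured like: ITERS=1000 MLR=0.02 python <this file>.
# Names are terse: e.g. D=model width, NH=attention heads, NKV=KV heads,
# NUL=unique layers, NR=depth-recurrence loops, TT*=test-time-training knobs.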
| class H: |
| dp=os.environ.get("DATA_PATH","./data/datasets/fineweb10B_sp8192") |
| tf=os.path.join(dp,"fineweb_train_*.bin") |
| vf=os.path.join(dp,"fineweb_val_*.bin") |
| tp=os.environ.get("TOKENIZER_PATH","./data/tokenizers/fineweb_8192_bpe.model") |
| rid=os.environ.get("RUN_ID",str(uuid.uuid4())) |
| seed=int(os.environ.get("SEED","1337")) |
| vbs=int(os.environ.get("VBS","524288"));vle=int(os.environ.get("VLE","1000")) |
| tle=int(os.environ.get("TLE","200")) |
| iters=int(os.environ.get("ITERS","20000")) |
| wdi=int(os.environ.get("WDI","3500"));wui=int(os.environ.get("WUI","20")) |
| tbt=int(os.environ.get("TBT","524288"));tsl=int(os.environ.get("TSL","1024")) |
| mws=float(os.environ.get("MWS","600.0")) |
| V=int(os.environ.get("V","8192"));D=int(os.environ.get("D","768")) |
| nh=int(os.environ.get("NH","12"));nkv=int(os.environ.get("NKV","4")) |
| mm=int(os.environ.get("MM","4")) |
| nul=int(os.environ.get("NUL","3"));nr=int(os.environ.get("NR","8")) |
| ner=int(os.environ.get("NER","0")) |
| rb=float(os.environ.get("RB","10000.0")) |
| lsc=float(os.environ.get("LSC","30.0")) |
| qkg=float(os.environ.get("QKG","5.25")) |
| sws=int(os.environ.get("SWS","64"));swl=int(os.environ.get("SWL","1024")) |
| tte=int(os.environ.get("TTE","1")) |
| ttlr=float(os.environ.get("TTLR","0.01")) |
| ttcs=int(os.environ.get("TTCS","64")) |
| ttly=os.environ.get("TTLY","all") |
| elr=float(os.environ.get("ELR","0.05")) |
| mlr=float(os.environ.get("MLR","0.04")) |
| slr=float(os.environ.get("SLR","0.04")) |
| mmo=float(os.environ.get("MMO","0.95")) |
| mbs=int(os.environ.get("MBS","5")) |
| mwd=float(os.environ.get("MWD","0.09")) |
| b1=float(os.environ.get("B1","0.9")) |
| b2=float(os.environ.get("B2","0.95")) |
| ae=float(os.environ.get("AE","1e-8")) |
| gb=int(os.environ.get("GB","6")) |
| sdn=float(os.environ.get("SDN","2.5")) |
| esf=float(os.environ.get("ESF","0.4")) |
|
|
# Control-parameter name fragments: routed to Adam (not Muon) and kept
# unquantized. These must match the actual attribute names used below
# (as_/ms/rm on PB, qg on CSA).
CP=(".as_",".ms",".rm",".qg")
|
|
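# Orthogonalize G via a quintic Newton-Schulz iteration (the Muon recipe):
# X <- a*X + (b*A + c*A^2)X with A = X X^T drives the singular values toward 1.
# Coefficients (3.4445, -4.7750, 2.0315) are the standard Muon constants;
# runs in bfloat16 on the (possibly transposed) wide orientation.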
| def zp5(G,s=10,e=1e-7): |
| a,b,c=3.4445,-4.7750,2.0315 |
| X=G.bfloat16();X/=X.norm()+e |
| tr=G.size(0)>G.size(1) |
| if tr:X=X.T |
| for _ in range(s): |
| A=X@X.T;B=b*A+c*A@A;X=a*X+B@X |
| return X.T if tr else X |
|
|
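# Muon: momentum SGD whose 2-D updates are orthogonalized by zp5. Params are
# sharded round-robin across ranks; each rank packs its updates into one flat
# bf16 buffer that is summed with a single all_reduce, after which every rank
# applies the identical update. The sqrt(rows/cols) scaling follows Muon.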
| class Muon(torch.optim.Optimizer): |
| def __init__(s,p,lr,mom,bs,wd=0.,nest=True): |
| super().__init__(p,dict(lr=lr,mom=mom,bs=bs,wd=wd,nest=nest)) |
| @torch.no_grad() |
| def step(s,cl=None): |
| lo=None |
| if cl: |
| with torch.enable_grad():lo=cl() |
| dd=dist.is_available() and dist.is_initialized() |
| ws=dist.get_world_size() if dd else 1 |
| rk=dist.get_rank() if dd else 0 |
| for g in s.param_groups: |
| ps=g["params"];lr=g["lr"];mo=g["mom"];bs=g["bs"];wd=g["wd"];ne=g["nest"] |
| tot=sum(int(p.numel()) for p in ps) |
| fl=torch.zeros(tot,device=ps[0].device,dtype=torch.bfloat16) |
| cur=0 |
            for i,p in enumerate(ps):
                if i%ws==rk and p.grad is not None:
                    gr=p.grad
                    if wd:gr=gr+wd*p.data.to(gr.dtype)
                    st=s.state[p]
                    if "mb" not in st:st["mb"]=torch.zeros_like(gr)
                    buf=st["mb"];buf.mul_(mo).add_(gr)
                    if ne:gr=gr.add(buf,alpha=mo)
                    gr=zp5(gr,s=bs)
                    gr*=max(1,gr.size(0)/gr.size(1))**0.5
                    fl[cur:cur+p.numel()]=gr.reshape(-1)
                cur+=p.numel()  # advance for every param so offsets agree across ranks
| if dd:dist.all_reduce(fl,op=dist.ReduceOp.SUM) |
| cur=0 |
| for p in ps: |
| gr=fl[cur:cur+p.numel()].view_as(p).to(dtype=p.dtype) |
| p.add_(gr,alpha=-lr);cur+=p.numel() |
| return lo |
|
|
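# Per-token lookup tables for bits-per-byte eval: bb = UTF-8 byte length of
# each piece (1 for byte tokens), hs = piece begins with the SentencePiece
# word-boundary marker U+2581 (its leading space byte is counted separately),
# ib = control/unknown/unused tokens to ignore.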
| def build_sp_luts(sp,vs,dev): |
| sv=int(sp.vocab_size());sz=max(sv,vs) |
| bb=np.zeros(sz,dtype=np.int16);hs=np.zeros(sz,dtype=bool);ib=np.ones(sz,dtype=bool) |
| for t in range(sv): |
| if sp.is_control(t) or sp.is_unknown(t) or sp.is_unused(t):continue |
| ib[t]=False |
| if sp.is_byte(t):bb[t]=1;continue |
| pc=sp.id_to_piece(t) |
| if pc.startswith("\u2581"):hs[t]=True;pc=pc[1:] |
| bb[t]=len(pc.encode("utf-8")) |
| return(torch.tensor(bb,dtype=torch.int16,device=dev), |
| torch.tensor(hs,dtype=torch.bool,device=dev), |
| torch.tensor(ib,dtype=torch.bool,device=dev)) |
|
|
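# Sliding-window validation: windows of length swl advance by stride sws and
# only the last sws tokens of each window are scored, so every scored token
# sees at least swl-sws tokens of context. Returns (mean NLL, bits per byte);
# the sums are all-reduced, so every rank must call this together.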
| def eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False): |
| sl=a.swl;st=a.sws;T=vt.numel() |
| starts=list(range(0,T-sl-1,st)) |
| my=starts[rk::ws] |
| ls=torch.zeros((),device=dev,dtype=torch.float64) |
| tc=torch.zeros((),device=dev,dtype=torch.float64) |
| bc=torch.zeros((),device=dev,dtype=torch.float64) |
| rm=mdl |
| while hasattr(rm,'module'):rm=rm.module |
| if hasattr(rm,'_orig_mod'):rm=rm._orig_mod |
| rm.eval() |
| ctx=torch.no_grad if ttt else torch.inference_mode |
| with ctx(): |
| for s in my: |
| e=s+sl |
| x=vt[s:e].unsqueeze(0).to(dev,dtype=torch.int64) |
| y=vt[s+1:e+1].unsqueeze(0).to(dev,dtype=torch.int64) |
| with torch.autocast("cuda",dtype=torch.bfloat16): |
| if ttt and a.tte:ptl=rm.ptl_ttt(x,y,a) |
| else:ptl=rm.ptl(x,y) |
| lo=sl-st;ps=ptl[0,lo:];ys=y[0,lo:];xs=x[0,lo:] |
| ls+=ps.to(torch.float64).sum();tc+=ps.numel() |
| tb=bbl[ys].to(torch.float64) |
| tb+=(hsl[ys]&~ibl[xs]).to(torch.float64) |
| bc+=tb.sum() |
| if dist.is_available() and dist.is_initialized(): |
| for t in(ls,tc,bc):dist.all_reduce(t,op=dist.ReduceOp.SUM) |
| vl=float((ls/tc).item());bpb=float((ls/math.log(2)/bc).item()) |
| rm.train();return vl,bpb |
|
|
| def sdclip(t,n=2.5): |
| m=t.float().mean();s=t.float().std() |
| return t.clamp((m-n*s).item(),(m+n*s).item()) |
|
|
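# Symmetric int6 quantization (levels -31..31): values are clipped to
# mean +/- ns*std, then scaled per row (2-D tensors) or globally (otherwise).
# qi8 below is the same recipe with levels -127..127, used for the embedding.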
| def qi6(t,ns=2.5): |
| t32=t.float();mx=31 |
| if t32.ndim==2: |
| m=t32.mean(1,keepdim=True);s=t32.std(1,keepdim=True).clamp_min(1e-9) |
| lo=m-ns*s;hi=m+ns*s |
| tc=t32.clamp(lo.expand_as(t32),hi.expand_as(t32)) |
| cv=tc.abs().amax(1).clamp_min(1e-9);sc=cv/mx |
| q=torch.clamp(torch.round(tc/sc[:,None]),-mx,mx).to(torch.int8) |
| return q.contiguous(),sc.to(torch.float16).contiguous() |
| tc=sdclip(t32,ns);cv=float(tc.abs().max().item()) |
| sc=torch.tensor(max(cv/mx,1./mx),dtype=torch.float32) |
| q=torch.clamp(torch.round(tc/sc),-mx,mx).to(torch.int8) |
| return q.contiguous(),sc |
|
|
| def qi8(t,ns=2.5): |
| t32=t.float() |
| if t32.ndim==2: |
| m=t32.mean(1,keepdim=True);s=t32.std(1,keepdim=True).clamp_min(1e-9) |
| lo=m-ns*s;hi=m+ns*s |
| tc=t32.clamp(lo.expand_as(t32),hi.expand_as(t32)) |
| cv=tc.abs().amax(1).clamp_min(1e-9);sc=cv/127. |
| q=torch.clamp(torch.round(tc/sc[:,None]),-127,127).to(torch.int8) |
| return q.contiguous(),sc.to(torch.float16).contiguous() |
| cv=float(sdclip(t32,ns).abs().max().item()) |
| sc=torch.tensor(max(cv/127.,1./127.),dtype=torch.float32) |
| q=torch.clamp(torch.round(t32.clamp(-cv,cv)/sc),-127,127).to(torch.int8) |
| return q.contiguous(),sc |
|
|
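# Pack a state dict for the compressed artifact: floating tensors are
# quantized with qi6/qi8; control params (CP) and small tensors (<=65536
# elements) pass through unquantized; "o" records original dtypes so dqsd
# can restore them. Returns (payload dict, size stats in bytes).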
| def qsd(sd,gb=6,ns=2.5): |
| qf=qi6 if gb==6 else qi8 |
| qu,sc,dt,pt,po,qm={},{},{},{},{},{} |
| st={k:0 for k in("pc","nt","bb","qb")} |
| for n,t in sd.items(): |
| t=t.detach().cpu().contiguous() |
| st["pc"]+=t.numel();st["nt"]+=1;st["bb"]+=t.numel()*t.element_size() |
| if not t.is_floating_point():pt[n]=t;st["qb"]+=t.numel()*t.element_size();continue |
| ic=any(p in n for p in CP);ism=t.numel()<=65536 |
| if "tok_emb" in n: |
| po[n]=str(t.dtype).removeprefix("torch.") |
| q,s=qi8(t,ns);qu[n]=q;sc[n]=s;dt[n]=po[n] |
| if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":8} |
| st["qb"]+=q.numel()+s.numel()*s.element_size();continue |
| if ic or ism: |
| if t.dtype in(torch.float32,torch.bfloat16):po[n]=str(t.dtype).removeprefix("torch.") |
| pt[n]=t.float() if ic else t.to(torch.float16) |
| pt[n]=pt[n].contiguous();st["qb"]+=pt[n].numel()*pt[n].element_size();continue |
| q,s=qf(t,ns) |
| if s.ndim>0:qm[n]={"scheme":"per_row","axis":0,"bits":gb} |
| qu[n]=q;sc[n]=s;dt[n]=str(t.dtype).removeprefix("torch.") |
| st["qb"]+=q.numel()+s.numel()*s.element_size() |
| obj={"__qf__":f"i{gb}sd","q":qu,"s":sc,"d":dt,"p":pt} |
| if qm:obj["m"]=qm |
| if po:obj["o"]=po |
| return obj,st |
|
|
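# Inverse of qsd: rescale int tensors (per-row or scalar scale), cast back to
# the recorded dtypes, and return a plain state dict for load_state_dict.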
| def dqsd(obj): |
| out={};qm=obj.get("m",{});po=obj.get("o",{}) |
| for n,q in obj["q"].items(): |
| dt=getattr(torch,obj["d"][n]);s=obj["s"][n] |
| if qm.get(n,{}).get("scheme")=="per_row" or s.ndim>0: |
| s=s.to(torch.float32) |
| out[n]=(q.float()*s.view(q.shape[0],*([1]*(q.ndim-1)))).to(dt).contiguous() |
| else:out[n]=(q.float()*float(s.item())).to(dt).contiguous() |
| for n,t in obj["p"].items(): |
| ot=t.detach().cpu().contiguous();od=po.get(n) |
| if isinstance(od,str):ot=ot.to(dtype=getattr(torch,od)).contiguous() |
| out[n]=ot |
| return out |
|
|
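# Data shards use the llm.c-style .bin layout: a 256-entry little-endian
# int32 header (magic 20240520, version 1, token count) followed by uint16 tokens.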
| def lds(f): |
| h=np.fromfile(f,dtype="<i4",count=256) |
| if h.size!=256 or int(h[0])!=20240520 or int(h[1])!=1:raise ValueError(f"Bad:{f}") |
| n=int(h[2]);t=np.fromfile(f,dtype="<u2",count=n,offset=256*4) |
| return torch.from_numpy(t.astype(np.uint16,copy=False)) |
|
|
| def lvt(pat,sl): |
| fs=[Path(p) for p in sorted(glob.glob(pat))] |
| if not fs:raise FileNotFoundError(f"No val:{pat}") |
| t=torch.cat([lds(f) for f in fs]).contiguous() |
| u=((t.numel()-1)//sl)*sl;return t[:u+1] |
|
|
| class TS: |
| def __init__(s,pat): |
| fs=[Path(p) for p in sorted(glob.glob(pat))] |
| if not fs:raise FileNotFoundError(f"No:{pat}") |
| s.fs=fs;s.i=0;s.t=lds(fs[0]);s.p=0 |
| def take(s,n): |
| ch=[];r=n |
| while r>0: |
| av=s.t.numel()-s.p |
| if av<=0:s.i=(s.i+1)%len(s.fs);s.t=lds(s.fs[s.i]);s.p=0;av=s.t.numel() |
| k=min(r,av);ch.append(s.t[s.p:s.p+k]);s.p+=k;r-=k |
| return ch[0] if len(ch)==1 else torch.cat(ch) |
|
|
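# Distributed loader: every rank advances through the same token stream in
# lockstep and slices a disjoint contiguous span per step, yielding ws*ga
# micro-batches per global batch of gt tokens.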
| class DTL: |
| def __init__(s,pat,rk,ws,dev):s.rk=rk;s.ws=ws;s.dev=dev;s.st=TS(pat) |
| def nb(s,gt,sl,ga): |
| lt=gt//(s.ws*ga);ps=lt+1 |
| ch=s.st.take(ps*s.ws);st=s.rk*ps |
| lo=ch[st:st+ps].to(torch.int64) |
| x=lo[:-1].reshape(-1,sl);y=lo[1:].reshape(-1,sl) |
| return x.to(s.dev,non_blocking=True),y.to(s.dev,non_blocking=True) |
|
|
| class RN(nn.Module): |
| def __init__(s,eps=None):super().__init__();s.eps=eps |
| def forward(s,x):return F.rms_norm(x,(x.size(-1),),eps=s.eps) |
|
|
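# Rotary position embeddings: inverse frequencies 1/b^(2i/d), with cos/sin
# tables cached per (seq_len, device) and re-cast to the activation dtype.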
| class Rot(nn.Module): |
| def __init__(s,d,b=10000.): |
| super().__init__() |
| s.register_buffer("if_",1./(b**(torch.arange(0,d,2,dtype=torch.float32)/d)),persistent=False) |
| s._cl=0;s._c=None;s._s=None |
| def forward(s,sl,dev,dt): |
| if s._c is None or s._cl!=sl or s._c.device!=dev: |
| t=torch.arange(sl,device=dev,dtype=s.if_.dtype) |
| fr=torch.outer(t,s.if_.to(dev)) |
| s._c=fr.cos()[None,None,:,:];s._s=fr.sin()[None,None,:,:];s._cl=sl |
| return s._c.to(dtype=dt),s._s.to(dtype=dt) |
|
|
| def arot(x,c,si): |
| h=x.size(-1)//2;x1,x2=x[...,:h],x[...,h:] |
| return torch.cat((x1*c+x2*si,x1*(-si)+x2*c),dim=-1) |
|
|
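# Causal self-attention with grouped-query KV (nh query heads share nk KV
# heads), RMS-normed q/k ("QK-norm"), rotary embeddings, and a learned
# per-head query gain qg applied on top of SDPA's built-in 1/sqrt(d) scaling.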
| class CSA(nn.Module): |
| def __init__(s,d,nh,nk,rb,qkg): |
| super().__init__() |
| assert d%nh==0 and nh%nk==0 |
| s.nh=nh;s.nk=nk;s.hd=d//nh;kd=nk*s.hd |
| s.cq=nn.Linear(d,d,bias=False);s.ck=nn.Linear(d,kd,bias=False) |
| s.cv=nn.Linear(d,kd,bias=False);s.pr=nn.Linear(d,d,bias=False) |
| s.qg=nn.Parameter(torch.full((nh,),qkg,dtype=torch.float32)) |
        s.rot=Rot(s.hd,b=rb)
| def forward(s,x): |
| B,T,_=x.shape |
| q=s.cq(x).reshape(B,T,s.nh,s.hd).transpose(1,2) |
| k=s.ck(x).reshape(B,T,s.nk,s.hd).transpose(1,2) |
| v=s.cv(x).reshape(B,T,s.nk,s.hd).transpose(1,2) |
| q=F.rms_norm(q,(q.size(-1),));k=F.rms_norm(k,(k.size(-1),)) |
| c,si=s.rot(T,x.device,q.dtype) |
| q=arot(q,c,si);k=arot(k,c,si) |
| q=q*s.qg.to(dtype=q.dtype)[None,:,None,None] |
| y=F.scaled_dot_product_attention(q,k,v,attn_mask=None,is_causal=True, |
| enable_gqa=(s.nk!=s.nh)) |
| return s.pr(y.transpose(1,2).contiguous().reshape(B,T,-1)) |
|
|
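# Squared-ReLU MLP: pr(relu(fc(x))^2), the activation popularized by Primer.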
| class MLP(nn.Module): |
| def __init__(s,d,m): |
| super().__init__() |
| h=d*m;s.fc=nn.Linear(d,h,bias=False);s.pr=nn.Linear(h,d,bias=False) |
| def forward(s,x):return s.pr(torch.relu(s.fc(x)).square()) |
|
|
| class PB(nn.Module): |
| """Parallel residual block.""" |
| def __init__(s,d,nh,nk,mm,rb,qkg): |
| super().__init__() |
| s.n=RN();s.a=CSA(d,nh,nk,rb,qkg);s.m=MLP(d,mm) |
| s.as_=nn.Parameter(torch.ones(d,dtype=torch.float32)) |
| s.ms=nn.Parameter(torch.ones(d,dtype=torch.float32)) |
| s.rm=nn.Parameter(torch.stack([torch.ones(d),torch.zeros(d)]).float()) |
| def forward(s,x,x0): |
| mx=s.rm.to(x.dtype) |
| x=mx[0][None,None,:]*x+mx[1][None,None,:]*x0 |
| h=s.n(x) |
| x=x+s.as_.to(x.dtype)[None,None,:]*s.a(h)+s.ms.to(x.dtype)[None,None,:]*s.m(h) |
| return x |
|
|
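# Depth-recurrent GPT: nul unique parallel-residual blocks are looped nr
# times during training (effective depth nul*nr) and ner times at eval
# (default 2*nr), sharing weights across loops. The embedding matrix is tied
# to the output head, and logits are soft-capped as lsc*tanh(logits/lsc).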
| class RGPT(nn.Module): |
| def __init__(s,a): |
| super().__init__() |
| s.lsc=a.lsc;s._tr=a.nr;s._er=a.ner or a.nr*2;s._V=a.V |
| s.te=nn.Embedding(a.V,a.D) |
| s.bl=nn.ModuleList([PB(a.D,a.nh,a.nkv,a.mm,a.rb,a.qkg) for _ in range(a.nul)]) |
| s.fn=RN();nn.init.normal_(s.te.weight,std=0.005) |
| def _fh(s,ids): |
| x=F.rms_norm(s.te(ids),(s.te.embedding_dim,));x0=x |
| n=s._tr if s.training else s._er |
| for _ in range(n): |
| for b in s.bl:x=b(x,x0) |
| return s.fn(x) |
| def forward(s,ids,tgt): |
| h=s._fh(ids);lo=F.linear(h.reshape(-1,h.size(-1)),s.te.weight) |
| lo=s.lsc*torch.tanh(lo/s.lsc) |
| return F.cross_entropy(lo.float(),tgt.reshape(-1),reduction="mean") |
| def ptl(s,ids,tgt): |
| h=s._fh(ids);B,T,D=h.shape |
| lo=F.linear(h.reshape(B*T,D),s.te.weight) |
| lo=s.lsc*torch.tanh(lo/s.lsc) |
| return F.cross_entropy(lo.float(),tgt.reshape(B*T),reduction="none").reshape(B,T) |
| @torch.no_grad() |
| def ptl_ttt(s,ids,tgt,a): |
| """Score-first TTT: score chunk, then update MLP W_down for next chunk.""" |
| cs=a.ttcs;lr=a.ttlr;B,T=ids.shape |
| if a.ttly=="all":li=list(range(len(s.bl))) |
| else:li=[int(x) for x in a.ttly.split(",")] |
| ow={i:s.bl[i].m.pr.weight.data.clone() for i in li} |
| ap=[];nc=(T+cs-1)//cs |
| for ci in range(nc): |
| lo=ci*cs;hi=min((ci+1)*cs,T) |
| h=s._fh(ids);hc=h[:,lo:hi,:];yc=tgt[:,lo:hi] |
| lg=F.linear(hc.reshape(-1,hc.size(-1)),s.te.weight) |
| lg=s.lsc*torch.tanh(lg/s.lsc) |
| pt=F.cross_entropy(lg.float(),yc.reshape(-1),reduction="none").reshape(B,hi-lo) |
| ap.append(pt) |
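            # Test-time update: after scoring a chunk, nudge each chosen block's
            # MLP down-projection by one gradient step on the MSE between
            # z @ W_down^T and the normalized final hidden state of that chunk;
            # the original weights are restored after the last chunk.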
| if ci<nc-1: |
| for i in li: |
| blk=s.bl[i];hn=F.rms_norm(hc.reshape(-1,hc.size(-1)).float(),(hc.size(-1),)) |
| z=torch.relu(blk.m.fc(hn.to(hc.dtype))).square() |
| pred=z@blk.m.pr.weight.T |
| res=pred-hn.to(pred.dtype) |
| gw=res.T@z/z.size(0) |
| blk.m.pr.weight.data-=lr*gw.to(blk.m.pr.weight.dtype) |
| for i in li:s.bl[i].m.pr.weight.data=ow[i] |
| return torch.cat(ap,dim=1) |
|
|
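# Exponential moving average of parameters: up() blends the shadow copy,
# ap()/re() swap the shadow weights in and out around evaluation.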
| class EMA: |
| def __init__(s,m,d=0.999):s.m=m;s.d=d;s.sh={n:p.data.clone() for n,p in m.named_parameters()};s.bk={} |
| def up(s): |
| for n,p in s.m.named_parameters():s.sh[n].mul_(s.d).add_(p.data,alpha=1.-s.d) |
| def ap(s): |
| s.bk={}; |
| for n,p in s.m.named_parameters():s.bk[n]=p.data.clone();p.data.copy_(s.sh[n]) |
| def re(s): |
| for n,p in s.m.named_parameters():p.data.copy_(s.bk[n]) |
| s.bk={} |
|
|
| def main(): |
| global zp5 |
| code=Path(__file__).read_text(encoding="utf-8") |
| a=H();zp5=torch.compile(zp5) |
| dd="RANK" in os.environ and "WORLD_SIZE" in os.environ |
| rk=int(os.environ.get("RANK","0"));ws=int(os.environ.get("WORLD_SIZE","1")) |
| lr_=int(os.environ.get("LOCAL_RANK","0")) |
| ga=max(1,8//ws);gs=1./ga |
| if not torch.cuda.is_available():raise RuntimeError("CUDA required") |
| dev=torch.device("cuda",lr_);torch.cuda.set_device(dev) |
| if dd:dist.init_process_group("nccl",device_id=dev);dist.barrier() |
| ma=rk==0 |
| torch.backends.cuda.matmul.allow_tf32=True;torch.backends.cudnn.allow_tf32=True |
| from torch.backends.cuda import enable_flash_sdp,enable_math_sdp,enable_mem_efficient_sdp,enable_cudnn_sdp |
| enable_flash_sdp(True);enable_math_sdp(False);enable_mem_efficient_sdp(False);enable_cudnn_sdp(False) |
| lf=None |
| if ma:os.makedirs("logs",exist_ok=True);lf=f"logs/{a.rid}.txt";print(lf) |
    def l0(m,console=True):
        if not ma:return
        if console:print(m)
        if lf:
            with open(lf,"a") as f:print(m,file=f)
| l0(code,console=False);l0(f"Python {sys.version}",console=False);l0(f"PyTorch {torch.__version__}",console=False) |
| try:l0(subprocess.run(["nvidia-smi"],capture_output=True,text=True,check=False).stdout,console=False) |
| except FileNotFoundError:pass |
| random.seed(a.seed);np.random.seed(a.seed);torch.manual_seed(a.seed);torch.cuda.manual_seed_all(a.seed) |
| sp=spm.SentencePieceProcessor(model_file=a.tp) |
| bbl,hsl,ibl=build_sp_luts(sp,a.V,dev) |
| vt=lvt(a.vf,a.swl);l0(f"val_tokens:{vt.numel()}") |
| bm=RGPT(a).to(dev).bfloat16() |
| cm=torch.compile(bm,dynamic=False,fullgraph=True) |
| mdl=DDP(cm,device_ids=[lr_],broadcast_buffers=False) if dd else cm |
| nu=sum(p.numel() for p in bm.parameters()) |
| ed=a.nul*a.nr |
| l0(f"params:{nu} eff_depth:{ed} train_loops:{a.nr} eval_loops:{bm._er}") |
| l0(f"ws:{ws} ga:{ga}") |
| bp=list(bm.bl.named_parameters()) |
| mp=[p for n,p in bp if p.ndim==2 and not any(c in n for c in CP)] |
| sp_=[p for n,p in bp if p.ndim<2 or any(c in n for c in CP)] |
| ot=torch.optim.Adam([{"params":[bm.te.weight],"lr":a.elr,"base_lr":a.elr}],betas=(a.b1,a.b2),eps=a.ae,fused=True) |
| om=Muon(mp,lr=a.mlr,mom=a.mmo,bs=a.mbs,wd=a.mwd) |
| for g in om.param_groups:g["base_lr"]=a.mlr |
| os_=torch.optim.Adam([{"params":sp_,"lr":a.slr,"base_lr":a.slr}],betas=(a.b1,a.b2),eps=a.ae,fused=True) |
| opts=[ot,om,os_] |
| ema=EMA(bm,d=0.999);ess=int(a.iters*a.esf) |
| mms=1000.*a.mws if a.mws>0 else None |
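    # LR multiplier schedule: flat, then a linear warmdown over the final wdi
    # steps; with a wall-clock budget (mws seconds) the warmdown length is
    # derived from the measured ms/step instead of the step count.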
| def lrm(st,el): |
| if a.wdi<=0:return 1. |
| if mms is None: |
| w=max(a.iters-a.wdi,0) |
| return max((a.iters-st)/max(a.wdi,1),0.) if w<=st<a.iters else 1. |
| sm=el/max(st,1);rm=max(mms-el,0.);wm=a.wdi*sm |
| return rm/max(wm,1e-9) if rm<=wm else 1. |
    def za():
        for o in opts:o.zero_grad(set_to_none=True)
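    # Warmup: run wui throwaway steps to trigger torch.compile and allocator
    # warmup, then restore the initial model and optimizer state so the timed
    # run starts from scratch.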
| if a.wui>0: |
| im={n:t.detach().cpu().clone() for n,t in bm.state_dict().items()} |
| io_=[copy.deepcopy(o.state_dict()) for o in opts] |
| mdl.train();tw=DTL(a.tf,rk,ws,dev) |
| for _ in range(a.wui): |
| za() |
| for mi in range(ga): |
| if dd:mdl.require_backward_grad_sync=(mi==ga-1) |
| x,y=tw.nb(a.tbt,a.tsl,ga) |
| with torch.autocast("cuda",torch.bfloat16):(mdl(x,y)*gs).backward() |
| for o in opts:o.step() |
| za() |
| bm.load_state_dict(im,strict=True) |
| for o,s in zip(opts,io_):o.load_state_dict(s) |
| za() |
| if dd:mdl.require_backward_grad_sync=True |
| tl=DTL(a.tf,rk,ws,dev);tms=0.;ss=None |
| torch.cuda.synchronize();t0=time.perf_counter();step=0 |
| while True: |
| ls=step==a.iters or(ss is not None and step>=ss) |
| dv=ls or(a.vle>0 and step%a.vle==0) |
| if dv: |
| torch.cuda.synchronize();tms+=1000.*(time.perf_counter()-t0) |
| vl,vb=eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False) |
| l0(f"step:{step}/{a.iters} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_ms:{tms:.0f} step_avg:{tms/max(step,1):.2f}ms") |
| torch.cuda.synchronize();t0=time.perf_counter() |
        if ls:
            # all ranks must run these evals: eval_sw all-reduces across the group
            ema.ap();l0("EMA+TTT eval...")
            vle,vbe=eval_sw(a,bm,rk,ws,dev,vt,bbl,hsl,ibl,ttt=True)
            l0(f"ema_ttt val_loss:{vle:.4f} val_bpb:{vbe:.4f}")
            sd=bm.state_dict();obj,st=qsd(sd,a.gb,a.sdn)
            buf=io.BytesIO();torch.save(obj,buf)
            cmp=zlib.compress(buf.getvalue(),level=9)
            cb=len(code.encode());mb=len(cmp);tb=cb+mb
            l0(f"artifact code:{cb} model:{mb} total:{tb} ({tb/1e6:.3f}MB) params:{st['pc']}")
            sd2=dqsd(obj);bm.load_state_dict(sd2,strict=True)
            vl2,vb2=eval_sw(a,bm,rk,ws,dev,vt,bbl,hsl,ibl,ttt=True)
            l0(f"quant+ttt val_loss:{vl2:.4f} val_bpb:{vb2:.4f}")
            ema.re()
            break
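        # Once past the wall-clock budget, schedule a stop one step later so
        # the final eval/quantization branch above runs.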
| if ss is None and mms is not None: |
| torch.cuda.synchronize() |
| el=1000.*(time.perf_counter()-t0)+tms |
| if el>=mms:ss=step+1 |
| za() |
| for mi in range(ga): |
| if dd:mdl.require_backward_grad_sync=(mi==ga-1) |
| x,y=tl.nb(a.tbt,a.tsl,ga) |
| with torch.autocast("cuda",torch.bfloat16):(mdl(x,y)*gs).backward() |
| torch.cuda.synchronize();el=1000.*(time.perf_counter()-t0)+tms |
| m=lrm(step,el) |
| for o in opts: |
| for g in o.param_groups:g["lr"]=g["base_lr"]*m |
| for o in opts:o.step() |
| if step>=ess:ema.up() |
| if step%a.tle==0 and ma:l0(f"step:{step} lr_mul:{m:.4f}") |
        step+=1
    if dd:dist.destroy_process_group()
|
|
| if __name__=="__main__":main() |
|
|