# parameter-golf-v2 / train_final.py
# rtferraz's picture
# Add minified training script for submission
# 768495c verified
"""PG-v2: SP8192+ParallelRes+DepthRec+TTT+Int6GPTQ+EMA"""
from __future__ import annotations
import copy,glob,io,math,os,random,subprocess,sys,time,uuid,zlib
from pathlib import Path
import numpy as np
import sentencepiece as spm
import torch,torch.distributed as dist,torch.nn.functional as F
from torch import Tensor,nn
from torch.nn.parallel import DistributedDataParallel as DDP
class H:
    """Run configuration.

    Every attribute is read once, from an environment variable, when the
    class body executes; the string default is used when the variable is unset.
    """

    # --- data / tokenizer locations -------------------------------------
    dp = os.getenv("DATA_PATH", "./data/datasets/fineweb10B_sp8192")  # dataset directory
    tf = os.path.join(dp, "fineweb_train_*.bin")  # glob for training shards
    vf = os.path.join(dp, "fineweb_val_*.bin")  # glob for validation shards
    tp = os.getenv("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model")
    rid = os.getenv("RUN_ID", str(uuid.uuid4()))  # names the log file logs/<rid>.txt
    seed = int(os.getenv("SEED", "1337"))

    # --- schedule ---------------------------------------------------------
    vbs = int(os.getenv("VBS", "524288"))  # not referenced elsewhere in this file
    vle = int(os.getenv("VLE", "1000"))  # validate every `vle` steps (0 disables)
    tle = int(os.getenv("TLE", "200"))  # log the LR multiplier every `tle` steps
    iters = int(os.getenv("ITERS", "20000"))  # maximum number of training steps
    wdi = int(os.getenv("WDI", "3500"))  # LR warm-down length, in steps
    wui = int(os.getenv("WUI", "20"))  # throw-away warm-up steps run before training
    tbt = int(os.getenv("TBT", "524288"))  # tokens per global training batch
    tsl = int(os.getenv("TSL", "1024"))  # training sequence length
    mws = float(os.getenv("MWS", "600.0"))  # wall-clock training budget in seconds (<=0: none)

    # --- model shape --------------------------------------------------------
    V = int(os.getenv("V", "8192"))  # vocabulary size
    D = int(os.getenv("D", "768"))  # model width
    nh = int(os.getenv("NH", "12"))  # query heads
    nkv = int(os.getenv("NKV", "4"))  # key/value heads
    mm = int(os.getenv("MM", "4"))  # MLP hidden width = D * mm
    nul = int(os.getenv("NUL", "3"))  # number of distinct blocks
    nr = int(os.getenv("NR", "8"))  # passes over the block stack in training mode
    ner = int(os.getenv("NER", "0"))  # passes in eval mode; 0 falls back to 2 * nr
    rb = float(os.getenv("RB", "10000.0"))  # rotary embedding base
    lsc = float(os.getenv("LSC", "30.0"))  # tanh soft-cap applied to the logits
    qkg = float(os.getenv("QKG", "5.25"))  # initial per-head query gain

    # --- sliding-window evaluation / test-time training -----------------------
    sws = int(os.getenv("SWS", "64"))  # window stride
    swl = int(os.getenv("SWL", "1024"))  # window length
    tte = int(os.getenv("TTE", "1"))  # non-zero enables test-time training in eval
    ttlr = float(os.getenv("TTLR", "0.01"))  # test-time training step size
    ttcs = int(os.getenv("TTCS", "64"))  # test-time training chunk size
    ttly = os.getenv("TTLY", "all")  # "all" or comma-separated block indices

    # --- optimizers ---------------------------------------------------------
    elr = float(os.getenv("ELR", "0.05"))  # Adam LR for the embedding
    mlr = float(os.getenv("MLR", "0.04"))  # Muon LR
    slr = float(os.getenv("SLR", "0.04"))  # Adam LR for the remaining block parameters
    mmo = float(os.getenv("MMO", "0.95"))  # Muon momentum
    mbs = int(os.getenv("MBS", "5"))  # iterations passed to zp5 by Muon
    mwd = float(os.getenv("MWD", "0.09"))  # Muon weight decay
    b1 = float(os.getenv("B1", "0.9"))  # Adam beta1
    b2 = float(os.getenv("B2", "0.95"))  # Adam beta2
    ae = float(os.getenv("AE", "1e-8"))  # Adam epsilon

    # --- export -------------------------------------------------------------
    gb = int(os.getenv("GB", "6"))  # quantization bits (6 selects qi6, anything else qi8)
    sdn = float(os.getenv("SDN", "2.5"))  # clip range, in standard deviations, before quantizing
    esf = float(os.getenv("ESF", "0.4"))  # EMA updates begin at step int(iters * esf)
# Substring patterns that mark "control" parameters (per-channel scales, the
# residual mix and the query gain).  Callers test `any(p in name for p in CP)`
# to keep these tensors out of the Muon group and out of low-bit quantization.
# The first four entries are the long, pre-minification names; the dotted
# entries are the attribute names the modules in this file actually register
# (PB.as_, PB.ms, PB.rm, CSA.qg).  Without them nothing ever matched, so the
# 2-D `rm` tensor was handed to Muon.  The leading dot anchors each short name
# to a full path component and avoids accidental substring hits.
CP = (
    "attn_scale",
    "mlp_scale",
    "resid_mix",
    "q_gain",
    ".as_",
    ".ms",
    ".rm",
    ".qg",
)
def zp5(G, s=10, e=1e-7):
    """Approximately orthogonalize a 2-D matrix with a quintic polynomial iteration.

    The input is cast to bfloat16, scaled by its norm, and then `s` rounds of
    ``X <- a*X + (b*A + c*A@A) @ X`` with ``A = X @ X.T`` are applied (the
    Newton-Schulz-style update used by the Muon optimizer).  Tall matrices are
    processed transposed so that ``A`` is the smaller Gram matrix.

    Args:
        G: 2-D tensor. It is never modified.
        s: number of iterations.
        e: value added to the norm to avoid division by zero.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    # Out-of-place on purpose: `.bfloat16()` returns G itself when G is already
    # bfloat16, so an in-place divide would silently rescale the caller's tensor.
    X = X / (X.norm() + e)
    tr = G.size(0) > G.size(1)
    if tr:
        X = X.T
    for _ in range(s):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if tr else X
class Muon(torch.optim.Optimizer):
    """Momentum optimizer whose per-matrix update is orthogonalized by `zp5`.

    Work is split across ranks when torch.distributed is initialised:
    parameter ``i`` of a group is processed by rank ``i % world_size``, each
    rank writes its updates into a flat bfloat16 buffer, and the buffers are
    summed with an all-reduce so that every rank applies the same update.

    Args:
        p: parameters (or parameter groups) to optimize.
        lr: learning rate.
        mom: momentum coefficient.
        bs: number of iterations passed to `zp5`.
        wd: weight-decay coefficient; it is added to the gradient before the
            momentum update.
        nest: use the Nesterov form (gradient + mom * buffer); otherwise the
            momentum buffer itself is the update.
    """

    def __init__(s, p, lr, mom, bs, wd=0., nest=True):
        super().__init__(p, dict(lr=lr, mom=mom, bs=bs, wd=wd, nest=nest))

    @torch.no_grad()
    def step(s, cl=None):
        """Apply one update; returns the closure's result, or None without a closure."""
        lo = None
        if cl is not None:
            with torch.enable_grad():
                lo = cl()
        dd = dist.is_available() and dist.is_initialized()
        ws = dist.get_world_size() if dd else 1
        rk = dist.get_rank() if dd else 0
        for g in s.param_groups:
            ps = g["params"]
            if not ps:
                continue
            lr = g["lr"]
            mo = g["mom"]
            bs = g["bs"]
            wd = g["wd"]
            ne = g["nest"]
            tot = sum(int(p.numel()) for p in ps)
            # Slots owned by other ranks stay zero and are filled in by the all-reduce.
            fl = torch.zeros(tot, device=ps[0].device, dtype=torch.bfloat16)
            cur = 0
            for i, p in enumerate(ps):
                if i % ws == rk and p.grad is not None:
                    gr = p.grad
                    if wd:
                        gr = gr + wd * p.data.to(gr.dtype)
                    st = s.state[p]
                    if "mb" not in st:
                        st["mb"] = torch.zeros_like(gr)
                    buf = st["mb"]
                    buf.mul_(mo).add_(gr)
                    # Without Nesterov the update is the momentum buffer; clone it
                    # so the steps below can never write into optimizer state.
                    gr = gr.add(buf, alpha=mo) if ne else buf.clone()
                    # zp5's iteration count is its second positional parameter
                    # (there is no `steps` keyword).
                    gr = zp5(gr, bs)
                    gr *= max(1, gr.size(0) / gr.size(1)) ** 0.5
                    fl[cur:cur + p.numel()] = gr.reshape(-1)
                cur += p.numel()
            if dd:
                dist.all_reduce(fl, op=dist.ReduceOp.SUM)
            cur = 0
            for p in ps:
                gr = fl[cur:cur + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(gr, alpha=-lr)
                cur += p.numel()
        return lo
def build_sp_luts(sp, vs, dev):
    """Build per-token lookup tables used to convert token loss into bits per byte.

    Args:
        sp: tokenizer object exposing vocab_size / is_control / is_unknown /
            is_unused / is_byte / id_to_piece.
        vs: model vocabulary size; tables are sized max(tokenizer vocab, vs).
        dev: device for the returned tensors.

    Returns:
        (byte_len int16, has_leading_space bool, is_boundary bool) tensors.
        Control, unknown and unused ids, and ids beyond the tokenizer's
        vocabulary, keep byte_len 0 and is_boundary True.
    """
    n_pieces = int(sp.vocab_size())
    table_len = max(n_pieces, vs)
    byte_len = np.zeros(table_len, dtype=np.int16)
    lead_space = np.zeros(table_len, dtype=bool)
    boundary = np.ones(table_len, dtype=bool)
    for tid in range(n_pieces):
        special = sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid)
        if special:
            continue
        boundary[tid] = False
        if sp.is_byte(tid):
            # byte-fallback pieces stand for exactly one byte
            byte_len[tid] = 1
        else:
            piece = sp.id_to_piece(tid)
            if piece.startswith("\u2581"):
                # the word-start marker is tracked separately, not counted here
                lead_space[tid] = True
                piece = piece[1:]
            byte_len[tid] = len(piece.encode("utf-8"))
    tables = ((byte_len, torch.int16), (lead_space, torch.bool), (boundary, torch.bool))
    return tuple(torch.tensor(arr, dtype=kind, device=dev) for arr, kind in tables)
def eval_sw(a,mdl,rk,ws,dev,vt,bbl,hsl,ibl,ttt=False):
sl=a.swl;st=a.sws;T=vt.numel()
starts=list(range(0,T-sl-1,st))
my=starts[rk::ws]
ls=torch.zeros((),device=dev,dtype=torch.float64)
tc=torch.zeros((),device=dev,dtype=torch.float64)
bc=torch.zeros((),device=dev,dtype=torch.float64)
rm=mdl
while hasattr(rm,'module'):rm=rm.module
if hasattr(rm,'_orig_mod'):rm=rm._orig_mod
rm.eval()
ctx=torch.no_grad if ttt else torch.inference_mode
with ctx():
for s in my:
e=s+sl
x=vt[s:e].unsqueeze(0).to(dev,dtype=torch.int64)
y=vt[s+1:e+1].unsqueeze(0).to(dev,dtype=torch.int64)
with torch.autocast("cuda",dtype=torch.bfloat16):
if ttt and a.tte:ptl=rm.ptl_ttt(x,y,a)
else:ptl=rm.ptl(x,y)
lo=sl-st;ps=ptl[0,lo:];ys=y[0,lo:];xs=x[0,lo:]
ls+=ps.to(torch.float64).sum();tc+=ps.numel()
tb=bbl[ys].to(torch.float64)
tb+=(hsl[ys]&~ibl[xs]).to(torch.float64)
bc+=tb.sum()
if dist.is_available() and dist.is_initialized():
for t in(ls,tc,bc):dist.all_reduce(t,op=dist.ReduceOp.SUM)
vl=float((ls/tc).item());bpb=float((ls/math.log(2)/bc).item())
rm.train();return vl,bpb
def sdclip(t, n=2.5):
    """Clamp `t` to mean +/- n standard deviations, both taken over the whole tensor in float32."""
    as_f32 = t.float()
    center = as_f32.mean()
    spread = as_f32.std()
    lower = (center - n * spread).item()
    upper = (center + n * spread).item()
    return t.clamp(lower, upper)
def qi6(t, ns=2.5):
    """Quantize a float tensor to signed 6-bit levels (-31..31) stored as int8.

    2-D input: each row is clipped to its own mean +/- ns*std and gets its own
    scale.  Other ranks: the tensor is clipped with `sdclip` and one scalar
    scale is used.

    Args:
        t: tensor to quantize.
        ns: clip range in standard deviations.

    Returns:
        (q, scale): q is int8; scale is a float16 vector with one entry per
        row (2-D input) or a 0-dim float32 tensor.  q * scale dequantizes.
    """
    t32 = t.float()
    mx = 31
    if t32.ndim == 2:
        m = t32.mean(1, keepdim=True)
        s = t32.std(1, keepdim=True).clamp_min(1e-9)
        lo = m - ns * s
        hi = m + ns * s
        tc = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
        sc = tc.abs().amax(1).clamp_min(1e-9) / mx
        q = torch.clamp(torch.round(tc / sc[:, None]), -mx, mx).to(torch.int8)
        return q.contiguous(), sc.to(torch.float16).contiguous()
    tc = sdclip(t32, ns)
    cv = float(tc.abs().max().item())
    # Map the clipped maximum onto the top level.  The previous floor of 1/mx
    # forced a step of at least 1/31, which rounded tensors whose values are
    # well below 1 to (almost) all zeros; only an all-zero tensor needs a fallback.
    sc = torch.tensor(cv / mx if cv > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(tc / sc), -mx, mx).to(torch.int8)
    return q.contiguous(), sc
def qi8(t, ns=2.5):
    """Quantize a float tensor to int8 levels (-127..127).

    2-D input: each row is clipped to its own mean +/- ns*std and gets its own
    scale.  Other ranks: the clip magnitude is the largest absolute value left
    after `sdclip`, and one scalar scale is used.

    Args:
        t: tensor to quantize.
        ns: clip range in standard deviations.

    Returns:
        (q, scale): q is int8; scale is a float16 vector with one entry per
        row (2-D input) or a 0-dim float32 tensor.  q * scale dequantizes.
    """
    t32 = t.float()
    if t32.ndim == 2:
        m = t32.mean(1, keepdim=True)
        s = t32.std(1, keepdim=True).clamp_min(1e-9)
        lo = m - ns * s
        hi = m + ns * s
        tc = t32.clamp(lo.expand_as(t32), hi.expand_as(t32))
        sc = tc.abs().amax(1).clamp_min(1e-9) / 127.
        q = torch.clamp(torch.round(tc / sc[:, None]), -127, 127).to(torch.int8)
        return q.contiguous(), sc.to(torch.float16).contiguous()
    cv = float(sdclip(t32, ns).abs().max().item())
    # Map the clip magnitude onto level 127.  The previous floor of 1/127
    # wasted most of the range (or zeroed the tensor) whenever the values were
    # well below 1; only an all-zero tensor needs a fallback scale.
    sc = torch.tensor(cv / 127. if cv > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(t32.clamp(-cv, cv) / sc), -127, 127).to(torch.int8)
    return q.contiguous(), sc
def qsd(sd, gb=6, ns=2.5):
    """Pack a state dict into a quantized export object.

    Per tensor (after moving it to CPU):
      * non-float tensors are stored unchanged;
      * names containing "tok_emb" are quantized with qi8;
      * tensors matching a CP pattern are stored as float32, and other tensors
        with at most 65536 elements as float16;
      * everything else is quantized with qi6 (gb == 6) or qi8.

    Args:
        sd: mapping name -> tensor.
        gb: bit width selector, also recorded in the format tag.
        ns: clip range (standard deviations) forwarded to the quantizers.

    Returns:
        (obj, stats): obj holds "__qf__", "q" (int8 data), "s" (scales),
        "d" (dtype to restore), "p" (pass-through tensors) and, when non-empty,
        "m" (per-row metadata) and "o" (original dtypes).  stats counts
        elements ("pc"), tensors ("nt"), input bytes ("bb") and packed bytes ("qb").
    """
    quantize = qi6 if gb == 6 else qi8
    qu, sc, dt, pt, po, qm = {}, {}, {}, {}, {}, {}
    st = dict.fromkeys(("pc", "nt", "bb", "qb"), 0)
    for n, raw in sd.items():
        t = raw.detach().cpu().contiguous()
        raw_bytes = t.numel() * t.element_size()
        st["pc"] += t.numel()
        st["nt"] += 1
        st["bb"] += raw_bytes
        if not t.is_floating_point():
            pt[n] = t
            st["qb"] += raw_bytes
            continue
        dtype_name = str(t.dtype).removeprefix("torch.")
        is_control = any(p in n for p in CP)
        is_small = t.numel() <= 65536
        # NOTE(review): no module in this file registers a parameter whose name
        # contains "tok_emb" (the embedding is `te`), so this branch looks
        # unreachable here -- confirm whether the embedding was meant to be int8.
        if "tok_emb" in n:
            po[n] = dtype_name
            q, s = qi8(t, ns)
            qu[n] = q
            sc[n] = s
            dt[n] = dtype_name
            if s.ndim > 0:
                qm[n] = {"scheme": "per_row", "axis": 0, "bits": 8}
            st["qb"] += q.numel() + s.numel() * s.element_size()
            continue
        if is_control or is_small:
            if t.dtype in (torch.float32, torch.bfloat16):
                po[n] = dtype_name
            kept = t.float() if is_control else t.to(torch.float16)
            kept = kept.contiguous()
            pt[n] = kept
            st["qb"] += kept.numel() * kept.element_size()
            continue
        q, s = quantize(t, ns)
        if s.ndim > 0:
            qm[n] = {"scheme": "per_row", "axis": 0, "bits": gb}
        qu[n] = q
        sc[n] = s
        dt[n] = dtype_name
        st["qb"] += q.numel() + s.numel() * s.element_size()
    obj = {"__qf__": f"i{gb}sd", "q": qu, "s": sc, "d": dt, "p": pt}
    if qm:
        obj["m"] = qm
    if po:
        obj["o"] = po
    return obj, st
def dqsd(obj):
    """Rebuild a state dict from an export object produced by qsd.

    Quantized entries are multiplied by their scale (broadcast along axis 0 for
    per-row scales) and cast to the dtype recorded in obj["d"].  Pass-through
    entries are copied to CPU and cast back to the dtype recorded in obj["o"],
    when one was recorded.

    Returns:
        dict name -> contiguous tensor.
    """
    row_meta = obj.get("m", {})
    orig_dtypes = obj.get("o", {})
    restored = {}
    for name, q in obj["q"].items():
        target = getattr(torch, obj["d"][name])
        scale = obj["s"][name]
        per_row = row_meta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0
        if per_row:
            # reshape the scale to (rows, 1, ..., 1) so it broadcasts over each row
            bshape = (q.shape[0],) + (1,) * (q.ndim - 1)
            values = q.float() * scale.to(torch.float32).view(bshape)
        else:
            values = q.float() * float(scale.item())
        restored[name] = values.to(target).contiguous()
    for name, t in obj["p"].items():
        kept = t.detach().cpu().contiguous()
        wanted = orig_dtypes.get(name)
        if isinstance(wanted, str):
            kept = kept.to(dtype=getattr(torch, wanted)).contiguous()
        restored[name] = kept
    return restored
def lds(f):
h=np.fromfile(f,dtype="<i4",count=256)
if h.size!=256 or int(h[0])!=20240520 or int(h[1])!=1:raise ValueError(f"Bad:{f}")
n=int(h[2]);t=np.fromfile(f,dtype="<u2",count=n,offset=256*4)
return torch.from_numpy(t.astype(np.uint16,copy=False))
def lvt(pat, sl):
    """Load every validation shard matching `pat` and trim to whole sequences.

    Shards are concatenated in sorted path order.  The result keeps
    k * sl + 1 tokens, where k is the largest count of complete `sl`-token
    sequences that still leaves one extra token for the shifted targets.

    Raises:
        FileNotFoundError: when no file matches `pat`.
    """
    shards = [Path(match) for match in sorted(glob.glob(pat))]
    if len(shards) == 0:
        raise FileNotFoundError(f"No val:{pat}")
    tokens = torch.cat([lds(shard) for shard in shards]).contiguous()
    whole_seqs = (tokens.numel() - 1) // sl
    return tokens[:whole_seqs * sl + 1]
class TS:
    """Sequential token stream over a sorted list of shard files, wrapping around at the end."""

    def __init__(s, pat):
        """Find the shards matching `pat` and load the first one.

        Raises:
            FileNotFoundError: when no file matches.
        """
        found = [Path(p) for p in sorted(glob.glob(pat))]
        if not found:
            raise FileNotFoundError(f"No:{pat}")
        s.fs = found  # shard paths
        s.i = 0  # index of the shard currently loaded
        s.t = lds(found[0])  # tokens of the current shard
        s.p = 0  # read position inside the current shard

    def take(s, n):
        """Return the next `n` tokens, moving on to the next shard whenever one is exhausted."""
        pieces = []
        need = n
        while need > 0:
            left = s.t.numel() - s.p
            if left <= 0:
                s.i = (s.i + 1) % len(s.fs)
                s.t = lds(s.fs[s.i])
                s.p = 0
                left = s.t.numel()
            got = min(need, left)
            pieces.append(s.t[s.p:s.p + got])
            s.p += got
            need -= got
        if len(pieces) == 1:
            return pieces[0]
        return torch.cat(pieces)
class DTL:
    """Per-rank batch loader on top of a TS token stream."""

    def __init__(s, pat, rk, ws, dev):
        s.rk = rk  # this process's rank
        s.ws = ws  # number of ranks
        s.dev = dev  # device the batches are moved to
        s.st = TS(pat)

    def nb(s, gt, sl, ga):
        """Return the next (inputs, targets) batch for this rank.

        Every rank draws the same span of tokens from the stream and keeps its
        own slice of gt // (ws * ga) + 1 tokens; targets are the inputs shifted
        by one position.

        Args:
            gt: tokens per global batch.
            sl: sequence length.
            ga: gradient-accumulation steps.

        Returns:
            Two int64 tensors with rows of `sl` tokens, on `dev`.
        """
        span = gt // (s.ws * ga) + 1  # one extra token so targets can be shifted
        everyone = s.st.take(span * s.ws)
        first = s.rk * span
        local = everyone[first:first + span].to(torch.int64)
        inputs = local[:-1].reshape(-1, sl)
        targets = local[1:].reshape(-1, sl)
        return (
            inputs.to(s.dev, non_blocking=True),
            targets.to(s.dev, non_blocking=True),
        )
class RN(nn.Module):
def __init__(s,eps=None):super().__init__();s.eps=eps
def forward(s,x):return F.rms_norm(x,(x.size(-1),),eps=s.eps)
class Rot(nn.Module):
    """Rotary-embedding cos/sin tables with a one-entry cache.

    Args:
        d: head dimension; one frequency per pair of channels.
        b: frequency base.
    """

    def __init__(s, d, b=10000.):
        super().__init__()
        exponents = torch.arange(0, d, 2, dtype=torch.float32) / d
        # inverse frequencies; non-persistent, so they are not part of the state dict
        s.register_buffer("if_", 1. / (b ** exponents), persistent=False)
        s._cl = 0  # sequence length the cache was built for
        s._c = None  # cached cos table
        s._s = None  # cached sin table

    def forward(s, sl, dev, dt):
        """Return (cos, sin), each shaped (1, 1, sl, d // 2), cast to dtype `dt`.

        The tables are rebuilt when the length or device differs from the cache.
        """
        stale = s._c is None or s._cl != sl or s._c.device != dev
        if stale:
            positions = torch.arange(sl, device=dev, dtype=s.if_.dtype)
            angles = torch.outer(positions, s.if_.to(dev))
            s._c = angles.cos()[None, None, :, :]
            s._s = angles.sin()[None, None, :, :]
            s._cl = sl
        return s._c.to(dtype=dt), s._s.to(dtype=dt)
def arot(x, c, si):
    """Rotate the two halves of the last dimension of `x` by the angles behind (c, si).

    With x split into (x1, x2) the result is (x1*c + x2*si, -x1*si + x2*c).
    """
    half = x.size(-1) // 2
    front = x[..., :half]
    back = x[..., half:]
    rotated_front = front * c + back * si
    rotated_back = front * (-si) + back * c
    return torch.cat((rotated_front, rotated_back), dim=-1)
class CSA(nn.Module):
    """Causal self-attention with grouped key/value heads, RMS-normed q/k, rotary
    position encoding and a learned per-head query gain.

    Args:
        d: model width; must be divisible by `nh`.
        nh: number of query heads; must be divisible by `nk`.
        nk: number of key/value heads.
        rb: rotary frequency base.
        qkg: initial value of the per-head query gain.
    """

    def __init__(s, d, nh, nk, rb, qkg):
        super().__init__()
        assert d % nh == 0 and nh % nk == 0
        s.nh = nh
        s.nk = nk
        s.hd = d // nh
        kd = nk * s.hd
        s.cq = nn.Linear(d, d, bias=False)
        s.ck = nn.Linear(d, kd, bias=False)
        s.cv = nn.Linear(d, kd, bias=False)
        s.pr = nn.Linear(d, d, bias=False)
        s.qg = nn.Parameter(torch.full((nh,), qkg, dtype=torch.float32))
        # Rot's second parameter is named `b`; the former `base=` keyword raised
        # TypeError, so the base is passed positionally.
        s.rot = Rot(s.hd, rb)

    def forward(s, x):
        """Attend over `x` of shape (batch, time, d); returns the same shape."""
        B, T, _ = x.shape
        q = s.cq(x).reshape(B, T, s.nh, s.hd).transpose(1, 2)
        k = s.ck(x).reshape(B, T, s.nk, s.hd).transpose(1, 2)
        v = s.cv(x).reshape(B, T, s.nk, s.hd).transpose(1, 2)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        c, si = s.rot(T, x.device, q.dtype)
        q = arot(q, c, si)
        k = arot(k, c, si)
        q = q * s.qg.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=True,
            enable_gqa=(s.nk != s.nh),  # K/V heads are shared across query heads
        )
        return s.pr(y.transpose(1, 2).contiguous().reshape(B, T, -1))
class MLP(nn.Module):
    """Two-layer feed-forward block with a squared-ReLU activation and no biases.

    Args:
        d: input and output width.
        m: hidden width multiplier (hidden = d * m).
    """

    def __init__(s, d, m):
        super().__init__()
        hidden = d * m
        s.fc = nn.Linear(d, hidden, bias=False)
        s.pr = nn.Linear(hidden, d, bias=False)

    def forward(s, x):
        activated = torch.relu(s.fc(x))
        return s.pr(activated.square())
class PB(nn.Module):
    """Parallel residual block.

    Attention and MLP both read the same normalized input, and their outputs
    are added to the residual stream in a single step.  Before normalizing,
    the stream is blended with the initial embedding `x0` through the learned
    per-channel mix `rm` (initialised to keep x and ignore x0).

    Args:
        d: model width.
        nh, nk: query and key/value head counts.
        mm: MLP hidden width multiplier.
        rb: rotary frequency base.
        qkg: initial query gain.
    """

    def __init__(s, d, nh, nk, mm, rb, qkg):
        super().__init__()
        s.n = RN()
        s.a = CSA(d, nh, nk, rb, qkg)
        s.m = MLP(d, mm)
        s.as_ = nn.Parameter(torch.ones(d, dtype=torch.float32))  # attention output scale
        s.ms = nn.Parameter(torch.ones(d, dtype=torch.float32))  # MLP output scale
        # row 0 weights the running stream, row 1 weights x0
        s.rm = nn.Parameter(torch.stack([torch.ones(d), torch.zeros(d)]).float())

    def forward(s, x, x0):
        mix = s.rm.to(x.dtype)
        keep = mix[0][None, None, :]
        inject = mix[1][None, None, :]
        x = keep * x + inject * x0
        h = s.n(x)
        attn_out = s.as_.to(x.dtype)[None, None, :] * s.a(h)
        mlp_out = s.ms.to(x.dtype)[None, None, :] * s.m(h)
        return x + attn_out + mlp_out
class RGPT(nn.Module):
    """Depth-recurrent language model with tied input/output embeddings.

    A small stack of PB blocks is applied repeatedly: `a.nr` passes in training
    mode and `a.ner` passes in eval mode (2 * `a.nr` when `a.ner` is 0).
    Logits come from the embedding matrix and are soft-capped with tanh.

    Args:
        a: config object; reads lsc, nr, ner, V, D, nh, nkv, mm, rb, qkg, nul.
    """

    def __init__(s, a):
        super().__init__()
        s.lsc = a.lsc
        s._tr = a.nr
        s._er = a.ner or a.nr * 2
        s._V = a.V
        s.te = nn.Embedding(a.V, a.D)
        s.bl = nn.ModuleList([PB(a.D, a.nh, a.nkv, a.mm, a.rb, a.qkg) for _ in range(a.nul)])
        s.fn = RN()
        nn.init.normal_(s.te.weight, std=0.005)

    def _fh(s, ids):
        """Return the final normalized hidden states for token ids of shape (B, T)."""
        x = F.rms_norm(s.te(ids), (s.te.embedding_dim,))
        x0 = x
        n = s._tr if s.training else s._er
        for _ in range(n):
            for b in s.bl:
                x = b(x, x0)
        return s.fn(x)

    def _lg(s, h2):
        """Soft-capped logits for hidden states flattened to (N, D)."""
        lo = F.linear(h2, s.te.weight)
        return s.lsc * torch.tanh(lo / s.lsc)

    def forward(s, ids, tgt):
        """Mean cross-entropy of `tgt` given `ids`."""
        h = s._fh(ids)
        lo = s._lg(h.reshape(-1, h.size(-1)))
        return F.cross_entropy(lo.float(), tgt.reshape(-1), reduction="mean")

    def ptl(s, ids, tgt):
        """Per-token cross-entropy, shaped like `ids`."""
        h = s._fh(ids)
        B, T, D = h.shape
        lo = s._lg(h.reshape(B * T, D))
        return F.cross_entropy(lo.float(), tgt.reshape(B * T), reduction="none").reshape(B, T)

    @torch.no_grad()
    def ptl_ttt(s, ids, tgt, a):
        """Score-first TTT: score chunk, then update MLP W_down for next chunk.

        The sequence is handled in chunks of `a.ttcs` positions.  Each chunk is
        scored with the current weights; afterwards the `m.pr` weight of the
        blocks selected by `a.ttly` ("all" or comma-separated indices) takes one
        step of size `a.ttlr` computed from that chunk's hidden states.  The
        original weights are always restored before returning, including when
        an exception is raised part-way through.

        Returns:
            Per-token cross-entropy, shaped like `ids`.

        Raises:
            ValueError: if `a.ttcs` is not positive.
        """
        cs = a.ttcs
        lr = a.ttlr
        B, T = ids.shape
        if cs <= 0:
            raise ValueError(f"ttcs must be positive, got {cs}")
        if a.ttly == "all":
            li = list(range(len(s.bl)))
        else:
            li = [int(x) for x in a.ttly.split(",")]
        ow = {i: s.bl[i].m.pr.weight.data.clone() for i in li}
        ap = []
        nc = (T + cs - 1) // cs
        try:
            for ci in range(nc):
                lo = ci * cs
                hi = min((ci + 1) * cs, T)
                # full forward every time: the weights changed after the previous chunk
                h = s._fh(ids)
                hc = h[:, lo:hi, :]
                yc = tgt[:, lo:hi]
                lg = s._lg(hc.reshape(-1, hc.size(-1)))
                pt = F.cross_entropy(lg.float(), yc.reshape(-1), reduction="none").reshape(B, hi - lo)
                ap.append(pt)
                if ci < nc - 1:
                    # the normalized chunk is identical for every layer, so build it once
                    hn = F.rms_norm(hc.reshape(-1, hc.size(-1)).float(), (hc.size(-1),))
                    for i in li:
                        blk = s.bl[i]
                        z = torch.relu(blk.m.fc(hn.to(hc.dtype))).square()
                        pred = z @ blk.m.pr.weight.T
                        res = pred - hn.to(pred.dtype)
                        gw = res.T @ z / z.size(0)
                        blk.m.pr.weight.data -= lr * gw.to(blk.m.pr.weight.dtype)
        finally:
            # an error mid-way must not leave the model with adapted weights
            for i in li:
                s.bl[i].m.pr.weight.data.copy_(ow[i])
        return torch.cat(ap, dim=1)
class EMA:
    """Exponential moving average of a module's parameters.

    Averaging starts at the first `up()` call, which seeds the average from the
    parameters as they are at that moment.  (Seeding from the construction-time
    values biased a late-started average toward the initial weights.)

    Args:
        m: module whose named parameters are tracked.
        d: decay; each later update does shadow = d * shadow + (1 - d) * param.
    """

    def __init__(s, m, d=0.999):
        s.m = m
        s.d = d
        s.sh = {n: p.data.clone() for n, p in m.named_parameters()}  # shadow weights
        s.bk = {}  # live weights saved by ap(), consumed by re()
        s.k = 0  # number of up() calls so far

    def up(s):
        """Fold the current parameters into the average."""
        for n, p in s.m.named_parameters():
            if s.k == 0:
                s.sh[n].copy_(p.data)
            else:
                s.sh[n].mul_(s.d).add_(p.data, alpha=1. - s.d)
        s.k += 1

    def ap(s):
        """Back up the live parameters and load the average into the module.

        If `up()` has never been called there is no average yet, so the live
        parameters are left in place (they are still backed up for `re()`).
        """
        s.bk = {}
        for n, p in s.m.named_parameters():
            s.bk[n] = p.data.clone()
            if s.k > 0:
                p.data.copy_(s.sh[n])

    def re(s):
        """Restore the parameters saved by the last `ap()`; no-op without a backup."""
        if not s.bk:
            return
        for n, p in s.m.named_parameters():
            p.data.copy_(s.bk[n])
        s.bk = {}
def main():
    """Train, validate periodically, then quantize and score the final weights.

    Flow: read config from H, set up (optionally distributed) CUDA training,
    run `a.wui` throw-away warm-up steps and restore the initial state, then
    train until `a.iters` steps or the wall-clock budget `a.mws` is used up.
    At the end the EMA weights are evaluated with test-time training, the state
    dict is quantized and compressed to report the artifact size, and the
    dequantized weights are evaluated again.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    global zp5
    code = Path(__file__).read_text(encoding="utf-8")
    a = H()
    zp5 = torch.compile(zp5)
    dd = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rk = int(os.environ.get("RANK", "0"))
    ws = int(os.environ.get("WORLD_SIZE", "1"))
    lr_ = int(os.environ.get("LOCAL_RANK", "0"))
    ga = max(1, 8 // ws)  # gradient-accumulation steps: ws * ga micro-batches per step
    gs = 1. / ga
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA required")
    dev = torch.device("cuda", lr_)
    torch.cuda.set_device(dev)
    if dd:
        dist.init_process_group("nccl", device_id=dev)
        dist.barrier()
    ma = rk == 0  # only the master rank prints and writes the log
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp, enable_cudnn_sdp
    enable_flash_sdp(True)
    enable_math_sdp(False)
    enable_mem_efficient_sdp(False)
    enable_cudnn_sdp(False)
    lf = None
    if ma:
        os.makedirs("logs", exist_ok=True)
        lf = f"logs/{a.rid}.txt"
        print(lf)

    def l0(m, c=True):
        """Log `m` on the master rank: to the console when `c` is true, always to the log file."""
        if not ma:
            return
        if c:
            print(m)
        if lf:
            with open(lf, "a") as f:
                print(m, file=f)

    # The flag is named `c`; the former `console=` keyword raised TypeError.
    l0(code, c=False)
    l0(f"Python {sys.version}", c=False)
    l0(f"PyTorch {torch.__version__}", c=False)
    try:
        l0(subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False).stdout, c=False)
    except FileNotFoundError:
        pass  # best effort: nvidia-smi may not be installed
    random.seed(a.seed)
    np.random.seed(a.seed)
    torch.manual_seed(a.seed)
    torch.cuda.manual_seed_all(a.seed)
    sp = spm.SentencePieceProcessor(model_file=a.tp)
    bbl, hsl, ibl = build_sp_luts(sp, a.V, dev)
    vt = lvt(a.vf, a.swl)
    l0(f"val_tokens:{vt.numel()}")
    bm = RGPT(a).to(dev).bfloat16()
    cm = torch.compile(bm, dynamic=False, fullgraph=True)
    mdl = DDP(cm, device_ids=[lr_], broadcast_buffers=False) if dd else cm
    nu = sum(p.numel() for p in bm.parameters())
    ed = a.nul * a.nr
    l0(f"params:{nu} eff_depth:{ed} train_loops:{a.nr} eval_loops:{bm._er}")
    l0(f"ws:{ws} ga:{ga}")

    # Parameter groups: the embedding gets its own Adam, 2-D block matrices go
    # to Muon, everything else in the blocks goes to a second Adam.
    bp = list(bm.bl.named_parameters())
    mp = [p for n, p in bp if p.ndim == 2 and not any(c in n for c in CP)]
    sp_ = [p for n, p in bp if p.ndim < 2 or any(c in n for c in CP)]
    ot = torch.optim.Adam([{"params": [bm.te.weight], "lr": a.elr, "base_lr": a.elr}], betas=(a.b1, a.b2), eps=a.ae, fused=True)
    om = Muon(mp, lr=a.mlr, mom=a.mmo, bs=a.mbs, wd=a.mwd)
    for g in om.param_groups:
        g["base_lr"] = a.mlr
    os_ = torch.optim.Adam([{"params": sp_, "lr": a.slr, "base_lr": a.slr}], betas=(a.b1, a.b2), eps=a.ae, fused=True)
    opts = [ot, om, os_]
    ema = EMA(bm, d=0.999)
    ess = int(a.iters * a.esf)  # first step at which the EMA is updated
    mms = 1000. * a.mws if a.mws > 0 else None  # wall-clock budget in ms

    def lrm(st, el):
        """LR multiplier for step `st` after `el` ms of training (linear warm-down)."""
        if a.wdi <= 0:
            return 1.
        if mms is None:
            w = max(a.iters - a.wdi, 0)
            return max((a.iters - st) / max(a.wdi, 1), 0.) if w <= st < a.iters else 1.
        # With a time budget the warm-down length is estimated from the average step time.
        sm = el / max(st, 1)
        rm = max(mms - el, 0.)
        wm = a.wdi * sm
        return rm / max(wm, 1e-9) if rm <= wm else 1.

    def za():
        """Clear the gradients of every optimizer."""
        for o in opts:
            o.zero_grad(set_to_none=True)

    if a.wui > 0:
        # Throw-away warm-up steps; model and optimizer state are restored afterwards.
        im = {n: t.detach().cpu().clone() for n, t in bm.state_dict().items()}
        io_ = [copy.deepcopy(o.state_dict()) for o in opts]
        mdl.train()
        tw = DTL(a.tf, rk, ws, dev)
        for _ in range(a.wui):
            za()
            for mi in range(ga):
                if dd:
                    mdl.require_backward_grad_sync = (mi == ga - 1)
                x, y = tw.nb(a.tbt, a.tsl, ga)
                with torch.autocast("cuda", torch.bfloat16):
                    (mdl(x, y) * gs).backward()
            for o in opts:
                o.step()
            za()
        bm.load_state_dict(im, strict=True)
        for o, s in zip(opts, io_):
            o.load_state_dict(s)
        za()
        if dd:
            mdl.require_backward_grad_sync = True

    tl = DTL(a.tf, rk, ws, dev)
    tms = 0.  # accumulated training time in ms (evaluation excluded)
    ss = None  # step at which to stop once the time budget is hit
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    step = 0
    while True:
        ls = step == a.iters or (ss is not None and step >= ss)
        dv = ls or (a.vle > 0 and step % a.vle == 0)
        if dv:
            torch.cuda.synchronize()
            tms += 1000. * (time.perf_counter() - t0)
            vl, vb = eval_sw(a, mdl, rk, ws, dev, vt, bbl, hsl, ibl, ttt=False)
            l0(f"step:{step}/{a.iters} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_ms:{tms:.0f} step_avg:{tms/max(step,1):.2f}ms")
            torch.cuda.synchronize()
            t0 = time.perf_counter()
        if ls:
            # Every rank runs the final evaluation: eval_sw shards the windows
            # by rank and all-reduces the totals, so restricting it to rank 0
            # scored only a fraction of the windows and then blocked forever in
            # the collective.  l0 already limits output to the master rank.
            ema.ap()
            l0("EMA+TTT eval...")
            vle, vbe = eval_sw(a, bm, rk, ws, dev, vt, bbl, hsl, ibl, ttt=True)
            l0(f"ema_ttt val_loss:{vle:.4f} val_bpb:{vbe:.4f}")
            sd = bm.state_dict()
            obj, st = qsd(sd, a.gb, a.sdn)
            if ma:
                # The compressed size is only reported, so one rank is enough.
                buf = io.BytesIO()
                torch.save(obj, buf)
                cmp = zlib.compress(buf.getvalue(), level=9)
                cb = len(code.encode())
                mb = len(cmp)
                tb = cb + mb
                l0(f"artifact code:{cb} model:{mb} total:{tb} ({tb/1e6:.3f}MB) params:{st['pc']}")
            sd2 = dqsd(obj)
            bm.load_state_dict(sd2, strict=True)
            vl2, vb2 = eval_sw(a, bm, rk, ws, dev, vt, bbl, hsl, ibl, ttt=True)
            l0(f"quant+ttt val_loss:{vl2:.4f} val_bpb:{vb2:.4f}")
            ema.re()
            break
        if ss is None and mms is not None:
            torch.cuda.synchronize()
            el = 1000. * (time.perf_counter() - t0) + tms
            hit = el >= mms
            if dd:
                # Clocks differ slightly between ranks; agree on the stop step so
                # no rank is left waiting in a collective for one that already left.
                flag = torch.tensor(int(hit), device=dev)
                dist.all_reduce(flag, op=dist.ReduceOp.MAX)
                hit = bool(flag.item())
            if hit:
                ss = step + 1
        za()
        for mi in range(ga):
            if dd:
                mdl.require_backward_grad_sync = (mi == ga - 1)
            x, y = tl.nb(a.tbt, a.tsl, ga)
            with torch.autocast("cuda", torch.bfloat16):
                (mdl(x, y) * gs).backward()
        torch.cuda.synchronize()
        el = 1000. * (time.perf_counter() - t0) + tms
        m = lrm(step, el)
        for o in opts:
            for g in o.param_groups:
                g["lr"] = g["base_lr"] * m
        for o in opts:
            o.step()
        if step >= ess:
            ema.up()
        if step % a.tle == 0 and ma:
            l0(f"step:{step} lr_mul:{m:.4f}")
        step += 1
    if dd:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()