#!/usr/bin/python3
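"""Builder helpers that pair a BSQ VAE with the Infinity transformer variants
(Infinity, BInfinity, AInfinity, FAInfinity, FInfinity, CInfinity) registered
with timm's `create_model`. Each `build_vae_*gpt` function constructs the VAE
from a state dict, assembles the transformer constructor kwargs from
`arg_util.Args`, and returns `(vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)`
(plus a teacher model in the case of `build_vae_bdgpt`)."""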
import gc
import os
import os.path as osp
import random
import sys
from copy import deepcopy
from typing import Tuple, Union
import colorama
import torch
import yaml
import infinity.utils.dist as dist
from infinity.models import Infinity
from infinity.models.ema import get_ema_model
from infinity.utils import arg_util, misc
from infinity.utils.misc import os_system
def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
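    """Build the BSQ VAE from `vae_st` and an `Infinity` transformer via timm's
    `create_model`, using the registry name derived from `args.model`.
    Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""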
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import Infinity
from timm.models import create_model
#model_str:infinity_2b
gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
def build_vae_bgpt_lora(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
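    """Same as `build_vae_gpt`, but builds the VAE with `vae_model_lora` and a
    `BInfinity` transformer (registry name prefixed with 'b').
    Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""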
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model_lora
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model_lora(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import BInfinity
from timm.models import create_model
model_str = 'b'+model_str
#model_str:binfinity_2b
gpt_wo_ddp: BInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
def build_vae_bgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
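    """Same as `build_vae_gpt`, but builds a `BInfinity` transformer (registry
    name prefixed with 'b'). Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""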
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import BInfinity
from timm.models import create_model
model_str = 'b'+model_str
#model_str:binfinity_2b
gpt_wo_ddp: BInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
def build_vae_bdgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
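    """Same as `build_vae_bgpt`, but additionally builds a plain `Infinity`
    teacher model from the same constructor kwargs.
    Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema, gpt_teacher)."""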
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import Infinity, BInfinity
from timm.models import create_model
    # teacher: plain Infinity model built from the same constructor kwargs
gpt_teacher: Infinity = create_model(model_str, **gpt_kw)
gpt_teacher = gpt_teacher.to(device)
    # main model: BInfinity variant, returned as gpt_wo_ddp
model_str = 'b'+model_str
#model_str:binfinity_2b
gpt_wo_ddp: BInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema, gpt_teacher
def build_vae_agpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
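    """Same as `build_vae_gpt`, but builds an `AInfinity` transformer (registry
    name prefixed with 'a'). Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""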
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import AInfinity
from timm.models import create_model
model_str = 'a'+model_str
print(model_str)
#model_str:ainfinity_2b
gpt_wo_ddp: AInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
def build_vae_fagpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
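    """Same as `build_vae_gpt`, but builds an `FAInfinity` transformer (registry
    name prefixed with 'fa'). Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""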
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import FAInfinity
from timm.models import create_model
model_str = 'fa' + model_str
#model_str:fainfinity_2b
gpt_wo_ddp: FAInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
def build_vae_fgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
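    """Same as `build_vae_gpt`, but builds an `FInfinity` transformer (registry
    name prefixed with 'f'). Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""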
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import FInfinity
from timm.models import create_model
model_str = 'f'+model_str
print(model_str)
#model_str:finfinity_2b
gpt_wo_ddp: FInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
def build_vae_cgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
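    """Same as `build_vae_gpt`, but builds a `CInfinity` transformer (registry
    name prefixed with 'c'); the trainable-parameter assertion is skipped for
    this variant. Returns (vae_local, gpt_wo_ddp, gpt_wo_ddp_ema)."""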
if args.vae_type in [8,16,18,20,24,32,64,128]:
from infinity.models.bsq_vae.vae import vae_model
schedule_mode = "dynamic"
codebook_dim = args.vae_type # 18
codebook_size = 2**codebook_dim
if args.apply_spatial_patchify:
patch_size = 8
encoder_ch_mult=[1, 2, 4, 4]
decoder_ch_mult=[1, 2, 4, 4]
else:
patch_size = 16
encoder_ch_mult=[1, 2, 4, 4, 4]
decoder_ch_mult=[1, 2, 4, 4, 4]
vae_local = vae_model(vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True).to(args.device)
if args.fake_vae_input:
vae_local.encoder = None
vae_local.decoder = None
torch.cuda.empty_cache()
else:
raise ValueError(f"vae_type {args.vae_type} not supported")
if force_flash: args.flash = True
gpt_kw = dict(
pretrained=False, global_pool='',
text_channels=args.Ct5, text_maxlen=args.tlen,
norm_eps=args.norm_eps, rms_norm=args.rms,
shared_aln=args.saln, head_aln=args.haln,
cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
raw_scale_schedule=args.scale_schedule,
head_depth=args.dec,
top_p=args.tp, top_k=args.tk,
customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
checkpointing=args.enable_checkpointing,
pad_to_multiplier=args.pad_to_multiplier,
use_flex_attn=args.use_flex_attn,
batch_size=args.batch_size,
add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
use_bit_label=args.use_bit_label,
rope2d_each_sa_layer=args.rope2d_each_sa_layer,
rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
pn=args.pn,
train_h_div_w_list=args.train_h_div_w_list,
always_training_scales=args.always_training_scales,
apply_spatial_patchify=args.apply_spatial_patchify,
)
if args.dp >= 0: gpt_kw['drop_path_rate'] = args.dp
if args.hd > 0: gpt_kw['num_heads'] = args.hd
print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
gpt_kw['vae_local'] = vae_local
model_str = args.model.replace('vgpt', 'infinity') # legacy
print(f"{model_str=}")
if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
block_chunks = int(block_chunks)
else:
block_chunks = 1
gpt_kw['block_chunks'] = block_chunks
from infinity.models import CInfinity
from timm.models import create_model
model_str = 'c'+model_str
print(model_str)
#model_str:cinfinity_2b
gpt_wo_ddp: CInfinity = create_model(model_str, **gpt_kw)
if args.use_fsdp_model_ema:
gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
else:
gpt_wo_ddp_ema = None
gpt_wo_ddp = gpt_wo_ddp.to(device)
assert all(not p.requires_grad for p in vae_local.parameters())
# assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
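# Minimal usage sketch (illustrative only): the checkpoint path and the way
# `args` is obtained are assumptions, not defined by this module.
#
#   import torch
#   from infinity.utils import arg_util
#
#   args: arg_util.Args = ...                                             # however the entry point builds Args
#   vae_st = torch.load('local_output/vae_ckpt.pth', map_location='cpu')  # hypothetical checkpoint path
#   vae_local, gpt_wo_ddp, gpt_ema = build_vae_gpt(args, vae_st, skip_gpt=False)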
if __name__ == '__main__':
    # NOTE: `ld` is not defined in this module; running the file directly will
    # raise NameError. This entry point appears to be debugging residue.
    ld(sys.argv[1])