#!/usr/bin/python3
import gc
import os
import os.path as osp
import random
import sys
from copy import deepcopy
from typing import Tuple, Union

import colorama
import torch
import yaml

import infinity.utils.dist as dist
from infinity.models import Infinity
from infinity.models.ema import get_ema_model
from infinity.utils import arg_util, misc
from infinity.utils.misc import os_system


def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import Infinity
    from timm.models import create_model
    # model_str: infinity_2b
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
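
# Minimal call-site sketch for build_vae_gpt (hypothetical -- the checkpoint
# path is a placeholder and `args` would come from this repo's CLI parsing
# via arg_util, not from code shown here):
#
#   vae_st = torch.load('weights/vae.pth', map_location='cpu')
#   vae, gpt, gpt_ema = build_vae_gpt(args, vae_st, skip_gpt=False)
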

def build_vae_bgpt_lora(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model_lora
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model_lora(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import BInfinity
    from timm.models import create_model
    model_str = 'b' + model_str
    # model_str: binfinity_2b
    gpt_wo_ddp: BInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
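
# Helper sketch for the LoRA variant above. It assumes the adapters inside
# vae_model_lora follow the common "lora_" parameter-naming convention; that
# convention is an assumption here, not something this file guarantees.
def lora_parameters(module: torch.nn.Module):
    """Yield only parameters whose names mark them as LoRA adapters."""
    for name, param in module.named_parameters():
        if 'lora_' in name:  # assumed naming convention for adapter weights
            yield param
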

def build_vae_bgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import BInfinity
    from timm.models import create_model
    model_str = 'b' + model_str
    # model_str: binfinity_2b
    gpt_wo_ddp: BInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
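
# The builders parse a trailing 'c<digits>' suffix on the model name to pick
# block chunking, e.g. 'infinity_2bc4' -> ('infinity_2b', 4). A standalone
# sketch of that inline logic, for reference:
def split_block_chunks(model_str: str):
    """Split a trailing 'c<digits>' suffix into (base_name, block_chunks)."""
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        base, chunks = model_str.rsplit('c', maxsplit=1)
        return base, int(chunks)
    return model_str, 1
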

def build_vae_bdgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import Infinity, BInfinity
    from timm.models import create_model
    # build the plain-Infinity teacher first, before the 'b' prefix is added
    # for the student model
    gpt_teacher: Infinity = create_model(model_str, **gpt_kw)
    gpt_teacher = gpt_teacher.to(device)
    model_str = 'b' + model_str
    # model_str: binfinity_2b
    gpt_wo_ddp: BInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema, gpt_teacher
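
# Caller-side sketch for the teacher/student pair returned by build_vae_bdgpt.
# The builder leaves the teacher's gradients enabled; for distillation one
# would typically freeze it (an assumption about intended use, not something
# enforced by this file):
def freeze_teacher(teacher: torch.nn.Module) -> torch.nn.Module:
    """Put a distillation teacher into inference-only mode."""
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    return teacher
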

def build_vae_agpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import AInfinity
    from timm.models import create_model
    model_str = 'a' + model_str
    print(model_str)
    # model_str: ainfinity_2b
    gpt_wo_ddp: AInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
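
# Reference sketch of the usual exponential-moving-average update rule. This
# is NOT the interface of get_ema_model (see infinity.models.ema for that);
# it is a generic, standalone illustration of what an EMA copy tracks:
@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    """In-place update: ema <- decay * ema + (1 - decay) * online."""
    for p_ema, p_online in zip(ema_model.parameters(), model.parameters()):
        p_ema.lerp_(p_online, 1.0 - decay)
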

def build_vae_fagpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import FAInfinity
    from timm.models import create_model
    model_str = 'fa' + model_str
    # model_str: fainfinity_2b
    gpt_wo_ddp: FAInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema


def build_vae_fgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import FInfinity
    from timm.models import create_model
    model_str = 'f' + model_str
    print(model_str)
    # model_str: finfinity_2b
    gpt_wo_ddp: FInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
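
# For reference: create_model resolves 'infinity_2b', 'binfinity_2b',
# 'finfinity_2b', etc. through timm's model registry, so each prefixed
# variant must be registered in infinity.models with timm's @register_model
# decorator. A minimal registration looks like the sketch below (hypothetical
# body -- the real constructor arguments live in infinity.models):
#
#   from timm.models import register_model
#
#   @register_model
#   def binfinity_2b(**kwargs):
#       return BInfinity(**kwargs)
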

def build_vae_cgpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size,
            patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult,
            decoder_ch_mult=decoder_ch_mult,
            test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")

    if force_flash:
        args.flash = True
    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond,
        drop_rate=args.drop, cross_attn_layer_scale=args.ca_gamma,
        nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd

    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy
    print(f"{model_str=}")
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from infinity.models import CInfinity
    from timm.models import create_model
    model_str = 'c' + model_str
    print(model_str)
    # model_str: cinfinity_2b
    gpt_wo_ddp: CInfinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    assert all(not p.requires_grad for p in vae_local.parameters())
    # unlike the other builders, this variant allows frozen GPT parameters:
    # assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema


if __name__ == '__main__':
    # `ld` is not defined in this module; running the file directly will
    # raise NameError until a loader is provided.
    ld(sys.argv[1])
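
# Possible consolidation (sketch, not part of the current API): the eight
# build_vae_* functions differ only in the VAE class, the model-name prefix,
# and whether a teacher is built, so a table-driven builder could replace
# them, e.g.
#
#   _VARIANT_PREFIX = {'gpt': '', 'bgpt': 'b', 'agpt': 'a',
#                      'fagpt': 'fa', 'fgpt': 'f', 'cgpt': 'c'}
#
# with the existing names kept as thin wrappers so call sites stay unchanged.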