| import copy | |
| import logging | |
| import os | |
| import os.path as osp | |
| from os.path import join | |
| import io | |
| from copy import deepcopy | |
| import torch | |
| from torch.utils.data import ConcatDataset, DataLoader | |
| from models_viclip.backbones.clip.clip_vision import interpolate_pos_embed_vit | |
| from models_viclip.backbones.bert.tokenization_bert import BertTokenizer | |
| #from utils_viclip.optimizer import create_optimizer | |
| #from utils_viclip.scheduler import create_scheduler | |
| #from utils_viclip.distributed import get_world_size | |
| #import deepspeed | |
| # | |
| #logger = logging.getLogger(__name__) | |
def get_media_types(datasources):
    """Get the media type of every dataloader or dataset in ``datasources``.

    Args:
        datasources (Iterable): Dataloaders and/or datasets. Each element is
            inspected on its own, so dataloaders and datasets may be mixed.
            A ``DataLoader`` is unwrapped to its ``.dataset``. A
            ``ConcatDataset`` contributes the ``media_type`` of its first
            sub-dataset.

    Returns:
        List: One media type per element of ``datasources``, in input order.
        An empty input yields an empty list.
    """
    media_types = []
    for source in datasources:
        # Unwrap a DataLoader to the dataset it iterates over.
        dataset = source.dataset if isinstance(source, DataLoader) else source
        # A ConcatDataset exposes no media_type itself. Only its first
        # sub-dataset is consulted; presumably all sub-datasets share one
        # media type (not checked here).
        if isinstance(dataset, ConcatDataset):
            dataset = dataset.datasets[0]
        media_types.append(dataset.media_type)
    return media_types
def setup_model(
    config,
    model_cls,
    has_decoder=False,
    pretrain=False,
    find_unused_parameters=False,
    num_steps_per_epoch=-1,
    num_classes=101,
    tokenizer_name="bert-base-uncased",
):
    """Build a tokenizer and instantiate ``model_cls`` with it.

    This is a trimmed-down model setup. It only constructs the tokenizer and
    the model. It does not:

    - move the model to a device;
    - wrap it for distributed training;
    - create an optimizer or scheduler;
    - initialise DeepSpeed;
    - load or resume checkpoints.

    Args:
        config: Model configuration. A deep copy is handed to ``model_cls``,
            so model construction cannot mutate the caller's object.
        model_cls: Model class, called as ``model_cls(config=..., tokenizer=...,
            is_pretrain=..., num_classes=...)``.
        has_decoder: Unused; kept so existing call sites keep working.
        pretrain: Forwarded to ``model_cls`` as ``is_pretrain``.
        find_unused_parameters: Unused; kept so existing call sites keep
            working.
        num_steps_per_epoch: Unused; kept so existing call sites keep working.
        num_classes: Forwarded to ``model_cls``.
        tokenizer_name: Name or path passed to
            ``BertTokenizer.from_pretrained``. Defaults to
            ``"bert-base-uncased"``, the value previously hard-coded here.

    Returns:
        The instantiated model, exactly as returned by ``model_cls``.
    """
    # Copy so that anything model_cls does to the config stays local.
    config = copy.deepcopy(config)
    # NOTE: an earlier version read the tokenizer name from
    # config.model.text_encoder.pretrained; pass that value as
    # tokenizer_name to restore that behaviour.
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    model = model_cls(
        config=config,
        tokenizer=tokenizer,
        is_pretrain=pretrain,
        num_classes=num_classes,
    )
    # NOTE: the former training setup (DDP/DeepSpeed wrapping, optimizer and
    # scheduler creation, auto-resume, checkpoint loading and WiSE-FT
    # interpolation) was removed from this function; recover it from version
    # control if needed.
    return model