| import sys |
| import open_clip |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from torchvision.transforms import transforms |
|
|
| from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL |
|
|
|
|
|
|
def print_statistics(arr):
    """Print summary statistics (mean, median, min, max, std, count) of a 1-D array."""
    assert len(arr.shape) == 1
    summary = (
        f'[mean] {arr.mean():.4f} [median] {np.median(arr):.4f} '
        f'[min] {arr.min():.4f} [max] {arr.max():.4f} '
        f'[std] {arr.std():.4f} [n] {len(arr)}\n'
    )
    print(summary)
|
|
def interpolate_state_dict(m1, beta=0., m2=None):
    """Linearly interpolate between two state dicts, key by key.

    Args:
        m1: state dict mapping parameter names to tensors.
        beta: interpolation weight for ``m2``; ``beta=0`` reproduces ``m1``.
        m2: optional second state dict with the same keys as ``m1``. When
            None (the default, preserving the original behavior), the CLIP
            ViT-L visual weights are loaded from one of two known paths.

    Returns:
        dict: ``(1 - beta) * m1[k] + beta * m2[k]`` for every key of ``m1``.
    """
    if m2 is None:
        try:
            m2 = torch.load("/mnt/nsingh/project_multimodal/models/clip-vit-l-visual.pt", map_location='cpu')
        except OSError:
            # Was a bare `except:` — narrowed so real errors (bad pickle,
            # KeyboardInterrupt, ...) are no longer silently swallowed; only
            # a missing/unreadable file falls back to the second path.
            m2 = torch.load("/data/naman_deep_singh/project_multimodal/clip-vit-l-visual.pt", map_location='cpu')
    return {k: (1 - beta) * m1[k] + beta * m2[k] for k in m1.keys()}
|
|
|
|
def load_clip_model(clip_model_name, pretrained, beta=0.):
    """Load an open_clip model and optionally swap in custom vision-encoder weights.

    Args:
        clip_model_name: open_clip architecture name (e.g. 'ViT-L-14').
        pretrained: 'openai' to keep the stock weights, a checkpoint path
            (str), or an already-loaded state dict.
        beta: if non-zero, interpolate the checkpoint with the original CLIP
            visual weights via ``interpolate_state_dict``.

    Returns:
        tuple: (model in eval mode,
                preprocessing pipeline without the final normalization step,
                the normalization transform itself).
    """
    try:
        model, _, image_processor = open_clip.create_model_and_transforms(
            clip_model_name, pretrained='openai', device='cpu'
        )
        if pretrained != 'openai':
            if isinstance(pretrained, str):
                checkpoint = torch.load(pretrained, map_location=torch.device('cpu'))
            else:
                checkpoint = pretrained

            if beta != 0.:
                print("beta", beta)
                # Bug fix: previously passed `pretrained` here — a path string
                # has no .keys(), and the resulting AttributeError was not
                # caught by the `except RuntimeError` below. Interpolate the
                # loaded checkpoint instead.
                checkpoint = interpolate_state_dict(checkpoint, beta)

            # Some checkpoints nest the visual weights under a dedicated key.
            if 'vision_encoder_state_dict' in checkpoint.keys():
                model.visual.load_state_dict(checkpoint['vision_encoder_state_dict'])
            else:
                model.visual.load_state_dict(checkpoint)
    except RuntimeError as e:
        # State-dict mismatch: retry by letting open_clip load the full model.
        print(f'error: {e}', file=sys.stderr)
        print('retrying by loading whole model..', file=sys.stderr)
        torch.cuda.empty_cache()
        model, _, image_processor = open_clip.create_model_and_transforms(
            clip_model_name, pretrained=pretrained, force_quick_gelu=True, device='cpu'
        )
    model.eval()

    # Split the preprocessing pipeline: everything but the last transform,
    # plus the last transform on its own (assumes the last step is Normalize,
    # as in open_clip's default preprocessing — TODO confirm for custom processors).
    preprocessor_no_norm = transforms.Compose(image_processor.transforms[:-1])
    normalizer = image_processor.transforms[-1]
    return model, preprocessor_no_norm, normalizer
|
|
@torch.no_grad()
def get_text_embeddings(model, dataset, texts):
    """Compute L2-normalized text embeddings for a dataset's class prompts or raw texts.

    Args:
        model: wrapper whose ``model.model.encode_text`` produces normalized
            text embeddings (runs on CUDA).
        dataset: currently only 'imagenet' is supported; mutually exclusive
            with ``texts``. Class names are expanded with a fixed prompt
            template.
        texts: explicit list of strings to embed; mutually exclusive with
            ``dataset``.

    Returns:
        torch.Tensor: embeddings transposed to shape (embed_dim, n_texts),
        on CPU.

    Raises:
        ValueError: if neither ``dataset`` nor ``texts`` is provided.
    """
    assert not (dataset and texts)
    if dataset:
        assert dataset == 'imagenet'
        template = 'This is a photo of a {}'
        texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
        text_tokens = open_clip.tokenize(texts)
    elif texts:
        text_tokens = open_clip.tokenize(texts)
    else:
        # Bug fix: previously fell through with `text_tokens` unassigned,
        # producing a confusing NameError below.
        raise ValueError('either `dataset` or `texts` must be provided')
    embedding_text_labels_norm = []
    # Encode in chunks to bound GPU memory usage.
    chunk_size = 500
    for i in range(0, len(text_tokens), chunk_size):
        el = text_tokens[i:i + chunk_size]
        embedding_text_labels_norm.append(
            model.model.encode_text(el.cuda(), normalize=True).detach().cpu()
        )
    embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T
    if dataset == 'imagenet':
        # 512/768 are the text embed dims of the supported CLIP variants.
        assert (embedding_text_labels_norm.shape == (512, 1000)
                or embedding_text_labels_norm.shape == (768, 1000)), embedding_text_labels_norm.shape
    return embedding_text_labels_norm
|
|
|
|
| @torch.inference_mode() |
| def compute_accuracy_no_dataloader(model, data, targets, device, batch_size=1000): |
| |
| |
| train_flag = model.training |
| model.eval() |
| n_batches = int(np.ceil(data.shape[0] / batch_size)) |
| n_total = 0 |
| n_correct = 0 |
| for batch_idx in range(n_batches): |
| start_idx = batch_idx * batch_size |
| end_idx = min((batch_idx + 1) * batch_size, data.shape[0]) |
| data_batch = data[start_idx:end_idx, :].clone().to(device) |
| targets_batch = targets[start_idx:end_idx].clone().to(device) |
| logits = model(data_batch) |
| confs, preds = F.softmax(logits, dim=1).max(dim=1) |
| n_total += targets_batch.size(0) |
| n_correct += (preds.eq(targets_batch).sum()).item() |
| acc = n_correct / n_total |
|
|
| |
| |
| if train_flag: |
| model.train() |
| return acc |
|
|
|
|
|
|