Spaces:
Running on Zero
Running on Zero
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from tqdm import tqdm | |
| # from open_clip import get_cast_dtype | |
| # from training.precision import get_autocast | |
| from .voc_dataset import voc_extended_classes | |
| import os | |
| import re | |
| import scipy | |
| import numpy as np | |
| import time | |
| from eval.models import dust3r, mast3r, Sab3r | |
| from datasets import load_dataset | |
| import torch | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| import PIL.Image | |
| import PIL | |
| from featup.util import norm, unnorm | |
| from featup.plotting import plot_feats, plot_lang_heatmaps | |
| import torchvision.transforms as tvf | |
| from pytorch_lightning import seed_everything | |
| from featup.util import pca, remove_axes | |
| import matplotlib.pyplot as plt | |
| from featup.featurizers.maskclip.clip import tokenize | |
| from tqdm import tqdm | |
class PascalMIoU:
    """Mean-IoU metric for Pascal VOC semantic segmentation (20 classes + background).

    Accumulates an integer confusion matrix over successive ``update`` calls
    and reports mean intersection-over-union from ``compute``.
    """

    def __init__(self):
        self._num_classes = 20 + 1  # 20 VOC object classes + background (index 0)
        self.confusion_matrix = None
        self._num_examples = 0
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.confusion_matrix = np.zeros((self._num_classes,) * 2, dtype=np.int64)
        self._num_examples = 0

    def update(self, prediction, label):
        """Accumulate one batch into the confusion matrix.

        Args:
            prediction: integer class-index array, same shape as ``label``.
                Values >= num_classes are counted as background (0).
            label: integer ground-truth array; pixels equal to 255 (the VOC
                "ignore"/void value) are skipped.
        """
        prediction = np.asarray(prediction)
        label = np.asarray(label)
        self._num_examples += label.shape[0]
        # BUGFIX: clamp out-of-range class ids on a copy — the original wrote
        # through to the caller's array, silently corrupting it for later use.
        prediction = prediction.copy()
        prediction[prediction >= self._num_classes] = 0
        mask = label < 255  # 255 marks ignore/void pixels in VOC ground truth
        # Encode (true, predicted) pairs as flat indices and histogram them.
        indices = self._num_classes * label[mask] + prediction[mask]
        m = np.bincount(indices, minlength=self._num_classes**2).reshape(
            self._num_classes, self._num_classes
        )
        self.confusion_matrix += m

    def compute(self):
        """Return {"mIoU": float, "per_class": list}, or 0 if nothing was accumulated.

        (The bare-0 empty case is kept for backward compatibility with callers.)
        """
        if self._num_examples == 0:
            return 0
        cm = self.confusion_matrix
        # IoU = TP / (TP + FP + FN); eps guards classes never seen nor predicted.
        iou_list = cm.diagonal() / (
            cm.sum(axis=0) + cm.sum(axis=1) - cm.diagonal() + np.finfo(np.float32).eps
        )
        return {"mIoU": np.nanmean(iou_list), "per_class": iou_list.tolist()}
def get_similarity(
    image_encodings,
    label_encodings,
    target_shape,
    interpolation="bilinear",
    do_argmax=False,
):
    """Compute per-pixel similarity between image features and class embeddings.

    Args:
        image_encodings: float tensor of shape (B, D, H, W).
        label_encodings: float tensor of shape (C, D), one embedding per class.
        target_shape: shape whose last two entries give the output spatial size
            (H_out, W_out); the similarity map is resized to it. Pass None to
            skip resizing.
        interpolation: resize mode, "nearest" or "bilinear".
        do_argmax: if True, return the per-pixel argmax class index
            (shape (B, 1, H_out, W_out), dtype float64) instead of raw scores.

    Returns:
        Tensor of shape (B, C, H_out, W_out) of similarity scores, or the
        argmax map when ``do_argmax`` is set.
    """
    image_encodings = image_encodings.cpu()
    label_encodings = label_encodings.cpu()
    # (B, D, H, W) -> (B, H, W, D) so the matmul contracts the feature dim D.
    similarity = image_encodings.permute(0, 2, 3, 1) @ label_encodings.T
    # (B, H, W, C) -> (B, C, H, W)
    similarity = similarity.permute(0, 3, 1, 2)
    # BUGFIX: ``target_shape`` and ``interpolation`` were previously accepted
    # but silently ignored, leaving the map at feature resolution — it could
    # never be compared against full-resolution labels. Resize before argmax.
    if target_shape is not None:
        similarity = F.interpolate(
            similarity, size=tuple(target_shape[-2:]), mode=interpolation
        )
    if do_argmax:
        similarity = torch.argmax(similarity, dim=1, keepdim=True).to(torch.float64)
    return similarity
class PascalEvaluator:
    """Zero-shot semantic-segmentation evaluation on Pascal VOC.

    Encodes one text prompt per class with a CLIP text encoder, computes
    per-pixel similarity between image features and the class embeddings,
    and accumulates a PascalMIoU metric over the dataset.
    """

    def __init__(self, model, upsampler, dataset, metric, opts):
        self.classes = voc_extended_classes
        self._class_prompts = self.get_class_prompts()
        self.class_embeddings = None  # lazily built on first evaluate()
        self.model = model
        self.upsampler = upsampler
        self.dataset = dataset
        self.metric: PascalMIoU = metric
        self.opts = opts
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=opts.batch_size, num_workers=opts.workers
        )

    def evaluate(self):
        """Run the full dataset through the model and accumulate the metric."""
        self.metric.reset()
        if self.class_embeddings is None:
            self.class_embeddings = self.get_class_embeddings()
        # BUGFIX: removed a leftover ``breakpoint()`` inside the batch loop
        # that dropped every evaluation run into the debugger.
        with torch.no_grad():
            for images, target in tqdm(
                self.dataloader, unit_scale=self.opts.batch_size
            ):
                image_features = self.model.predict(images)
                # NOTE(review): assumes get_clip() yields per-image features of
                # shape (H, W, 512) — confirm against the model implementation.
                clip_features = image_features.get_clip()
                # First image only: (H, W, D) -> (1, D, H, W) on GPU.
                pred_0 = (
                    torch.tensor(clip_features[0]).cuda().contiguous().permute(2, 0, 1)
                ).unsqueeze(0)
                similarity = get_similarity(
                    pred_0, self.class_embeddings, target.shape, do_argmax=True
                )
                similarity = similarity[:, 0, :, :]  # drop singleton channel dim
                pred = similarity.detach().to(torch.int64)
                target = target.to(torch.int64)
                self.metric.update(pred, target)

    def get_class_embeddings(self):
        """Encode one prompt per class with the upsampler's CLIP text encoder.

        Returns:
            Tensor of shape (num_classes, embed_dim) on the encoder's device.
        """
        cls_names = [name.lower() for name in self._class_prompts.values()]
        with torch.no_grad():
            tokenized_text = tokenize(cls_names).to("cuda")
            class_embeddings = self.upsampler.model.model.encode_text(tokenized_text)
            class_embeddings = F.normalize(class_embeddings, dim=-1)
            # NOTE(review): this second division rescales the already
            # row-normalized matrix by its global norm — a no-op for argmax
            # prediction; kept to preserve existing similarity magnitudes.
            class_embeddings /= class_embeddings.norm()
        return class_embeddings

    def get_class_prompts(self):
        """Map class index -> "a photo of a(n) <class>" with the correct article."""
        class_prompts = {}
        for idx, name in enumerate(self.classes):
            article = "an" if name.startswith(tuple("aeiou")) else "a"
            class_prompts[idx] = f"a photo of {article} {name}"
        return class_prompts