Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from abc import ABC, abstractmethod | |
| from omegaconf import DictConfig | |
| import cv2 | |
| import numpy as np | |
| import imageio | |
| from ppd.utils.logger import Log | |
| import time | |
| import h5py | |
| import torch | |
| from torchvision.transforms import Compose | |
| from PIL import Image | |
class Dataset(ABC):
    """Base dataset yielding RGB images with optional depth maps and masks.

    Subclasses implement :meth:`build_metas` to populate ``self.rgb_files``
    and, when depth is available, ``self.depth_files`` (index-aligned lists
    of file paths). Keyword arguments given to the constructor are stored in
    ``self.cfg`` as an OmegaConf ``DictConfig``.
    """

    # A sample is kept only if its (post-transform) mask sums to at least this
    # value, i.e. has this many valid pixels when the mask is 0/1.
    min_valid_pixels = 10
    # Number of random re-draws __getitem__ tolerates for samples failing the
    # check above before it raises instead of looping forever.
    max_invalid_retries = 10

    def __init__(self, **kwargs):
        """Store the config, build file lists and transforms, log the size.

        Config keys read here: ``dataset_name`` (default ``'unknown'``),
        ``use_low`` (default ``True``), ``split`` (only used for logging) and
        ``transforms`` (see :meth:`build_transforms`).
        """
        super().__init__()
        self.cfg = DictConfig(kwargs)
        self.dataset_name = self.cfg.get('dataset_name', 'unknown')
        self.use_low = self.cfg.get('use_low', True)
        self.build_metas()
        self.build_transforms()
        # .get: attribute access on a missing key raises in recent OmegaConf.
        split = self.cfg.get('split', 'unknown')
        Log.info(
            f'{split} split of {self.dataset_name} dataset: {len(self.rgb_files)} frames in total.')

    def build_metas(self):
        """Prepare ``rgb_files``, ``depth_files`` and ``low_files``.

        The base implementation does nothing; subclasses must at least set
        ``self.rgb_files``. ``self.depth_files`` is optional — when it is
        absent, :meth:`read_depth` returns ``(None, None)``.
        """
        pass

    def build_transforms(self):
        """Set ``self.transform`` from the ``transforms`` config entry.

        An empty or missing list yields the identity; otherwise the entries
        are composed with ``torchvision.transforms.Compose`` and logged one
        per line.
        """
        transforms = self.cfg.get('transforms', [])
        if not transforms:  # also covers an explicit ``transforms: null``
            self.transform = lambda x: x
            return
        Log.info(f'{self.dataset_name} transform layers: \n'
                 + '\n'.join(str(transform) for transform in transforms))
        self.transform = Compose(transforms)

    def read_rgb(self, index):
        """Load frame ``index`` as an RGB float32 array scaled by 1/255.

        Raises:
            OSError: if OpenCV cannot read the file (missing or corrupt).
        """
        img_path = self.rgb_files[index]
        start_time = time.time()
        bgr = cv2.imread(img_path)
        elapsed = time.time() - start_time
        if elapsed > 1:
            Log.warn(f'Long time to read {img_path}: {elapsed}')
        # cv2.imread signals failure by returning None rather than raising.
        if bgr is None:
            raise OSError(f'cv2.imread failed to read {img_path}')
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return np.asarray(rgb / 255.).astype(np.float32)

    def read_rgb_name(self, index):
        """Return ``'<parent_dir>__<file_name>'`` for frame ``index``."""
        return '__'.join(self.rgb_files[index].split('/')[-2:])

    def _load_depth_file(self, depth_path):
        """Load a raw depth array from ``depth_path``, dispatching on extension.

        Raises:
            OSError: if a ``.png`` depth file cannot be read by OpenCV.
            ValueError: for an unsupported file extension.
        """
        if depth_path.endswith('.png'):
            raw = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR |
                             cv2.IMREAD_ANYDEPTH)
            if raw is None:
                raise OSError(f'cv2.imread failed to read {depth_path}')
            # Integer PNG depth is divided by 1000 (presumably mm -> m;
            # TODO confirm unit convention).
            return raw / 1000.
        if depth_path.endswith('.npz'):
            # Context manager closes the underlying zip file handle.
            with np.load(depth_path) as archive:
                return archive['data']
        if depth_path.endswith('.hdf5'):
            # Read fully into memory before the file is closed.
            with h5py.File(depth_path, 'r') as h5_file:
                return np.asarray(h5_file['dataset'])
        if depth_path.endswith('.npy'):
            return np.load(depth_path)
        raise ValueError(f"Invalid depth file: {depth_path}")

    def read_depth(self, index, depth=None):
        """Load, validate and clean the depth map of frame ``index``.

        Args:
            index: frame index into ``self.depth_files``.
            depth: optional pre-loaded depth array. When given, file loading
                is skipped, but the array is still validated and cleaned.
                Non-finite entries are overwritten in place.

        Returns:
            ``(depth, mask)``, or ``(None, None)`` when the dataset has no
            ``depth_files``. ``mask`` is uint8 and is 1 where the depth is
            finite and ``> 0.01``. Non-finite depth values are replaced by the
            maximum valid depth, when at least one valid pixel exists.

        Raises:
            ValueError: for an unsupported extension or a depth array that is
                neither 2-D nor (H, W, 1).
        """
        if not hasattr(self, 'depth_files'):
            return None, None
        depth_path = self.depth_files[index]
        Log.debug(index, depth_path)
        start_time = time.time()
        if depth is None:
            depth = self._load_depth_file(depth_path)
        if len(depth.shape) == 3 and depth.shape[2] == 1:
            depth = depth[:, :, 0]
        elif len(depth.shape) != 2:
            raise ValueError(f"Invalid depth file: {depth_path}")
        elapsed = time.time() - start_time
        if elapsed > 1:
            Log.warn(f'Long time to read {depth_path}: {elapsed}')

        valid_mask = np.isfinite(depth) & (depth > 0.01)
        if not valid_mask.any():
            Log.warn('No valid mask in the depth map of {}'.format(depth_path))
        else:
            # Replace NaN/±inf with the largest valid depth.
            non_finite = ~np.isfinite(depth)
            if non_finite.any():
                depth[non_finite] = depth[valid_mask].max()
        return depth, valid_mask.astype(np.uint8)

    def check_shape(self, rgb, dpt):
        """Assert ``rgb`` is (H, W, C) and ``dpt`` is (H, W) with matching H, W."""
        # NOTE(review): plain asserts are skipped under ``python -O``.
        assert (rgb.shape[:2] == dpt.shape[:2]), "rgb.shape: {}, dpt.shape: {}".format(
            rgb.shape, dpt.shape)
        assert (len(rgb.shape) == 3), "rgb.shape: {}".format(rgb.shape)
        assert (len(dpt.shape) == 2), "dpt.shape: {}".format(dpt.shape)

    def __getitem__(self, index):
        """Return the transformed sample for ``index`` (wrapped modulo length).

        If the transformed mask has fewer than ``min_valid_pixels`` valid
        pixels, a random other frame is drawn instead. After more than
        ``max_invalid_retries`` consecutive failures a ``ValueError`` is
        raised. The returned dict gains ``dataset_name``, ``image_name`` and
        ``image_path`` for the frame actually used.
        """
        index = index % len(self.rgb_files)
        repeat_num = 0
        while True:
            rgb = self.read_rgb(index)
            dpt, msk = self.read_depth(index)
            sample = {'image': rgb}
            if dpt is not None:
                self.check_shape(rgb, dpt)
                sample['depth'] = dpt
                sample['mask'] = msk
            sample = self.transform(sample)
            if 'mask' not in sample or sample['mask'].sum() >= self.min_valid_pixels:
                break
            # Record the failing frame before re-drawing so messages name it.
            failed_path = self.rgb_files[index]
            repeat_num += 1
            if repeat_num > self.max_invalid_retries:
                raise ValueError(
                    f'No valid mask in the depth map of {failed_path} '
                    f'(gave up after {repeat_num} attempts).')
            Log.warn(f'No valid mask in the depth map of {failed_path}.')
            index = int(np.random.randint(0, len(self.rgb_files)))
        sample['dataset_name'] = self.dataset_name
        sample['image_name'] = self.read_rgb_name(index)
        sample['image_path'] = self.rgb_files[index]
        return sample

    def __len__(self):
        """Number of frames, i.e. ``len(self.rgb_files)``."""
        return len(self.rgb_files)