Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Last modified: 2025-01-14 | |
| # | |
| # Copyright 2025 Ziyang Song, USTC. All rights reserved. | |
| # | |
| # This file has been modified from the original version. | |
| # Original copyright (c) 2023 Bingxin Ke, ETH Zurich. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # -------------------------------------------------------------------------- | |
| # If you find this code useful, we kindly ask you to cite our paper in your work. | |
| # Please find bibtex at: https://github.com/indu1ge/DepthMaster#-citation | |
| # More information about the method can be found at https://indu1ge.github.io/DepthMaster_page | |
| # -------------------------------------------------------------------------- | |
| import io | |
| import os | |
| import random | |
| import tarfile | |
| from enum import Enum | |
| from typing import Union | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision.transforms import InterpolationMode, Resize, RandomResizedCrop | |
| from src.util.depth_transform import DepthNormalizerBase | |
| from src.util.alignment import depth2disparity | |
class DatasetMode(Enum):
    """Operating modes a depth dataset can be constructed with."""

    # Only RGB images are loaded; depth files are not read.
    RGB_ONLY = "rgb_only"
    # RGB and depth are loaded, without training-time preprocessing.
    EVAL = "evaluate"
    # RGB and depth are loaded and passed through training preprocessing.
    TRAIN = "train"
class DepthFileNameMode(Enum):
    """Naming schemes of the RGB files, used to derive prediction file names."""

    id = 1  # RGB file is named "<id>.png"
    rgb_id = 2  # RGB file is named "rgb_<id>.png"
    i_d_rgb = 3  # RGB file is named "<i>_<d>_<1>_rgb.png"
    rgb_i_d = 4  # RGB file is named "rgb_<i>_<d>.png"
def read_image_from_tar(tar_obj, img_rel_path):
    """Read an image stored inside an opened tar archive.

    Args:
        tar_obj: An opened ``tarfile.TarFile``.
        img_rel_path: Path of the image relative to the archive root; it is
            looked up under the member name ``"./" + img_rel_path``.

    Returns:
        The ``PIL.Image.Image`` decoded from the member's bytes.
    """
    member = tar_obj.extractfile("./" + img_rel_path)
    # Buffer the bytes so the image no longer depends on the tar file handle.
    raw_bytes = member.read()
    image = Image.open(io.BytesIO(raw_bytes))
    # The image was previously opened but never returned, so callers got None.
    return image
class BaseDepthDataset(Dataset):
    """Base dataset yielding RGB images together with depth maps.

    Samples are listed in a text file, one sample per line, with
    whitespace-separated relative paths: ``rgb [depth [filled_depth]]``.
    The data itself lives either in a directory or inside a tar archive
    located at ``dataset_dir``.

    Subclasses are expected to override ``_read_depth_file`` to decode depth
    according to the dataset definition.
    """

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        min_depth: float,
        max_depth: float,
        has_filled_depth: bool,
        has_egde_mask: bool,
        name_mode: DepthFileNameMode,
        depth_transform: Union[DepthNormalizerBase, None] = None,
        augmentation_args: dict = None,
        resize_to_hw=None,
        move_invalid_to_far_plane: bool = True,
        rgb_transform=lambda x: x / 255.0 * 2 - 1,  # [0, 255] -> [-1, 1],
        **kwargs,
    ) -> None:
        """Store the configuration and load the list of sample file names.

        Args:
            mode: Selects what is loaded and whether training preprocessing runs.
            filename_ls_path: Text file listing the samples (one per line).
            dataset_dir: Directory or tar archive holding the data; must exist.
            disp_name: Stored on the instance; not used in this class.
            min_depth: Exclusive lower bound of valid depth values.
            max_depth: Exclusive upper bound of valid depth values.
            has_filled_depth: Whether each line has a third, filled-depth path.
            has_egde_mask: Stored on the instance; not used in this class.
                (The spelling is kept as-is for caller compatibility.)
            name_mode: File naming scheme, stored on the instance.
            depth_transform: Normalizer applied to depth in training mode.
            augmentation_args: Augmentation config; ``lr_flip_p`` is read from
                it via attribute access. ``None`` disables augmentation.
            resize_to_hw: Target size handed to ``Resize`` in training mode;
                ``None`` disables resizing.
            move_invalid_to_far_plane: Whether invalid pixels of the normalized
                filled depth are overwritten with the far-plane value.
            rgb_transform: Stored on the instance.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.mode = mode

        # dataset info
        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        assert os.path.exists(
            self.dataset_dir
        ), f"Dataset does not exist at: {self.dataset_dir}"
        self.disp_name = disp_name
        self.has_filled_depth = has_filled_depth
        self.has_egde_mask = has_egde_mask
        self.name_mode: DepthFileNameMode = name_mode
        self.min_depth = min_depth
        self.max_depth = max_depth

        # training arguments
        self.depth_transform: DepthNormalizerBase = depth_transform
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw
        self.rgb_transform = rgb_transform
        self.move_invalid_to_far_plane = move_invalid_to_far_plane

        # Load filenames: [['rgb.png', 'depth.tif'], ...]
        # Blank lines (e.g. trailing newlines) are skipped; otherwise they would
        # become empty entries that inflate __len__ and fail on indexing.
        with open(self.filename_ls_path, "r") as f:
            self.filenames = [
                fields for fields in (line.split() for line in f) if fields
            ]

        # Tar dataset: the archive is opened lazily on first read.
        self.tar_obj = None
        self.is_tar = os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load one sample and return rasters and metadata in a single dict."""
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)
        # merge metadata into the raster dict
        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load RGB (and, unless in RGB_ONLY mode, depth) for one sample.

        Returns:
            A tuple ``(rasters, other)`` where ``rasters`` maps names to tensors
            and ``other`` holds the sample index and the relative RGB path.
        """
        rgb_rel_path, depth_rel_path, filled_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            # load data
            depth_data = self._load_depth_data(
                depth_rel_path=depth_rel_path, filled_rel_path=filled_rel_path
            )
            rasters.update(depth_data)
            # valid masks are computed on the loaded depth, before any conversion
            rasters["valid_mask_raw"] = self._get_valid_mask(
                rasters["depth_raw_linear"]
            ).clone()
            rasters["valid_mask_filled"] = self._get_valid_mask(
                rasters["depth_filled_linear"]
            ).clone()

            if DatasetMode.TRAIN == self.mode:
                # NOTE(review): when has_filled_depth is False,
                # "depth_filled_linear" is left untouched by both steps below
                # while "depth_raw_linear" is converted -- confirm this is
                # intended.
                # depth2disp
                rasters["depth_raw_linear"] = depth2disparity(
                    rasters["depth_raw_linear"]
                ).clone()
                if self.has_filled_depth:
                    rasters["depth_filled_linear"] = depth2disparity(
                        rasters["depth_filled_linear"]
                    ).clone()
                # sqrt(x)
                rasters["depth_raw_linear"] = torch.sqrt(
                    rasters["depth_raw_linear"]
                ).clone()
                if self.has_filled_depth:
                    rasters["depth_filled_linear"] = torch.sqrt(
                        rasters["depth_filled_linear"]
                    ).clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other

    def _load_rgb_data(self, rgb_rel_path):
        """Read an RGB file and return its integer and normalized tensors."""
        # Read RGB data
        rgb = self._read_rgb_file(rgb_rel_path)
        # NOTE(review): the mapping is hard-coded here; self.rgb_transform is
        # stored by __init__ but not consulted.
        rgb_norm = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _load_depth_data(self, depth_rel_path, filled_rel_path):
        """Read raw (and optionally filled) depth as float tensors.

        When the dataset has no filled depth, ``"depth_filled_linear"`` is a
        copy of the raw depth.
        """
        # Read depth data
        outputs = {}
        depth_raw = self._read_depth_file(depth_rel_path).squeeze()
        depth_raw_linear = torch.from_numpy(depth_raw).float().unsqueeze(0)  # [1, H, W]
        outputs["depth_raw_linear"] = depth_raw_linear.clone()

        if self.has_filled_depth:
            depth_filled = self._read_depth_file(filled_rel_path).squeeze()
            depth_filled_linear = torch.from_numpy(depth_filled).float().unsqueeze(0)
            outputs["depth_filled_linear"] = depth_filled_linear
        else:
            outputs["depth_filled_linear"] = depth_raw_linear.clone()

        return outputs

    def _get_data_path(self, index):
        """Return ``(rgb, depth, filled_depth)`` relative paths for a sample.

        Paths that do not apply to the current mode/configuration are ``None``.
        """
        filename_line = self.filenames[index]

        # Get data path
        rgb_rel_path = filename_line[0]

        depth_rel_path, filled_rel_path = None, None
        if DatasetMode.RGB_ONLY != self.mode:
            depth_rel_path = filename_line[1]
            if self.has_filled_depth:
                filled_rel_path = filename_line[2]
        return rgb_rel_path, depth_rel_path, filled_rel_path

    def _read_image(self, img_rel_path) -> np.ndarray:
        """Read an image from the dataset directory or tar archive as an array."""
        if self.is_tar:
            # Open the archive on first use and keep it open for later reads.
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            image_to_read = self.tar_obj.extractfile("./" + img_rel_path)
            image_to_read = image_to_read.read()
            image_to_read = io.BytesIO(image_to_read)
        else:
            image_to_read = os.path.join(self.dataset_dir, img_rel_path)
        image = Image.open(image_to_read)  # [H, W, rgb]
        image = np.asarray(image)
        return image

    def _read_rgb_file(self, rel_path) -> np.ndarray:
        """Read an RGB image and move the channel axis to the front."""
        rgb = self._read_image(rel_path)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)  # [rgb, H, W]
        return rgb

    def _read_depth_file(self, rel_path):
        """Read a depth file; returns the image array unchanged by default."""
        depth_in = self._read_image(rel_path)
        # Replace code below to decode depth according to dataset definition
        depth_decoded = depth_in

        return depth_decoded

    def _get_valid_mask(self, depth: torch.Tensor):
        """Return a boolean mask of pixels strictly inside (min_depth, max_depth)."""
        valid_mask = torch.logical_and(
            (depth > self.min_depth), (depth < self.max_depth)
        ).bool()
        return valid_mask

    def _training_preprocess(self, rasters):
        """Augment, normalize and resize the rasters of a training sample."""
        # Augmentation
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        # Normalization
        rasters["depth_raw_norm"] = self.depth_transform(
            rasters["depth_raw_linear"], rasters["valid_mask_raw"]
        ).clone()
        rasters["depth_filled_norm"] = self.depth_transform(
            rasters["depth_filled_linear"], rasters["valid_mask_filled"]
        ).clone()

        # Set invalid pixel to far plane
        if self.move_invalid_to_far_plane:
            if self.depth_transform.far_plane_at_max:
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
                    self.depth_transform.norm_max
                )
            else:
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = (
                    self.depth_transform.norm_min
                )

        # Resize every raster (images, depths and masks) with the same transform
        if self.resize_to_hw is not None:
            resize_transform = Resize(
                size=self.resize_to_hw, interpolation=InterpolationMode.NEAREST_EXACT
            )
            rasters = {k: resize_transform(v) for k, v in rasters.items()}

        return rasters

    def _augment_data(self, rasters_dict):
        """Randomly flip all rasters along their last axis."""
        # lr flipping
        # augm_args is read via attribute access, so it must expose `lr_flip_p`
        # as an attribute (a plain dict would not work here).
        lr_flip_p = self.augm_args.lr_flip_p
        if random.random() < lr_flip_p:
            rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}

        return rasters_dict

    def __del__(self):
        """Close the tar archive, if one was opened."""
        # hasattr guard: __init__ may have failed before tar_obj was assigned.
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
def get_pred_name(rgb_basename, name_mode, suffix=".png"):
    """Derive the prediction file name that corresponds to an RGB file name.

    Args:
        rgb_basename: Base name of the RGB file.
        name_mode: ``DepthFileNameMode`` describing how the RGB file is named.
        suffix: Extension given to the returned name, replacing the original one.

    Returns:
        The prediction base name ending in ``suffix``.

    Raises:
        NotImplementedError: If ``name_mode`` is not a supported mode.
    """
    if name_mode == DepthFileNameMode.id:
        # "<id>.png" -> "pred_<id>.png"
        renamed = "pred_" + rgb_basename
    elif name_mode == DepthFileNameMode.rgb_id:
        # keep only the second "_"-separated field
        fields = rgb_basename.split("_")
        renamed = "pred_" + fields[1]
    elif name_mode == DepthFileNameMode.rgb_i_d:
        # drop the first "_"-separated field, keep the rest
        trailing_fields = rgb_basename.split("_")[1:]
        renamed = "pred_" + "_".join(trailing_fields)
    elif name_mode == DepthFileNameMode.i_d_rgb:
        renamed = rgb_basename.replace("_rgb.", "_pred.")
    else:
        raise NotImplementedError

    # change suffix
    stem = os.path.splitext(renamed)[0]
    return stem + suffix