Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn.functional as F | |
# Corner numbering of one 2x2 pixel cell (row index grows downward):
#
#     0 2
#     1 3
#
# i.e. corner k sits at row offset k % 2 and column offset k // 2.
# VERTEX_SLICES[k] is a (row_slice, col_slice) pair that, applied to an
# (H, W, ...) tensor, gathers corner k of every cell at once and yields an
# (H - 1, W - 1, ...) tensor.
VERTEX_SLICES = [
    (row_slice, col_slice)
    for col_slice in (slice(None, -1), slice(1, None))  # column offset 0, then 1
    for row_slice in (slice(None, -1), slice(1, None))  # row offset 0, then 1
]

# Each cell is split into four triangles, listed as (v0, v1, v2) corner
# indices. The 1-2 diagonal gives (0, 1, 2) and (2, 1, 3); the 0-3 diagonal
# gives (2, 0, 3) and (0, 1, 3); so every cell is covered twice. With this
# corner order all four triangles have the same winding in the (row, col)
# grid, so cross(v1 - v0, v2 - v0) points to the same side of a planar
# surface for every triangle.
TRIANGLE_SLICES = [
    [VERTEX_SLICES[corner] for corner in corners]
    for corners in ((0, 1, 2), (2, 0, 3), (0, 1, 3), (2, 1, 3))
]
NUM_TRIANGLE = len(TRIANGLE_SLICES)
| def _fetch_pixel_val(x: torch.Tensor, vertex_slice): | |
| ''' | |
| :param x: shape (H, W, ...) | |
| :param vertex_slice: | |
| :return: shape (H - 1, W - 1, ...) | |
| ''' | |
| return x[vertex_slice[0], vertex_slice[1]] | |
def get_triangle_valid(valid: torch.Tensor):
    '''
    Mark every mesh triangle whose three corner pixels are all valid.

    :param valid: pixel mask, shape (H, W); corners are combined with `&`,
        so a bool tensor is expected
    :return: triangle_valid
        triangle_valid: bool tensor, shape (H - 1, W - 1, NUM_TRIANGLE)
    '''
    height, width = valid.shape
    triangle_valid = torch.empty(
        (height - 1, width - 1, NUM_TRIANGLE), dtype=torch.bool, device=valid.device
    )
    for tri_idx, (corner_a, corner_b, corner_c) in enumerate(TRIANGLE_SLICES):
        # A triangle survives only if none of its corners is masked out.
        all_corners_ok = (
            _fetch_pixel_val(valid, corner_a)
            & _fetch_pixel_val(valid, corner_b)
            & _fetch_pixel_val(valid, corner_c)
        )
        triangle_valid[..., tri_idx] = all_corners_ok
    return triangle_valid
def get_triangle_normal(xyz: torch.Tensor, eps: float = 1e-5):
    '''
    Compute a unit normal for every triangle of the pixel-grid mesh.

    For a triangle with corners (p0, p1, p2), the normal is
    cross(normalize(p1 - p0), normalize(p2 - p0)), rescaled to unit length.
    Both edges are unit vectors before the cross product, so the length
    before rescaling is the sine of the angle at p0. `eps` therefore
    rejects near-collinear or zero-length-edge triangles; a NaN length also
    compares False and is rejected.

    The result is built with torch.stack rather than written into a
    preallocated tensor in place. This keeps the computation differentiable
    with respect to `xyz`; in-place writes would invalidate tensors that
    autograd saved for backward.

    :param xyz: shape (H, W, 3)
    :param eps: minimum pre-rescale cross-product length for a triangle to
        count as valid; also the floor of the divisor used when rescaling
    :return: normal, normal_valid
        normal: shape (H - 1, W - 1, NUM_TRIANGLE, 3), same dtype as xyz;
            unit length where normal_valid is True, not meaningful elsewhere
        normal_valid: bool, shape (H - 1, W - 1, NUM_TRIANGLE)
    '''
    normals = []
    normal_valids = []
    for vertex0, vertex1, vertex2 in TRIANGLE_SLICES:
        p0 = _fetch_pixel_val(xyz, vertex0)
        edge1 = F.normalize(_fetch_pixel_val(xyz, vertex1) - p0, dim=-1)
        edge2 = F.normalize(_fetch_pixel_val(xyz, vertex2) - p0, dim=-1)
        cross = torch.linalg.cross(edge1, edge2, dim=-1)
        cross_len = torch.linalg.vector_norm(cross, dim=-1)  # (H - 1, W - 1)
        normal_valids.append(cross_len > eps)
        # Out-of-place division: an in-place `/=` would overwrite the tensor
        # autograd saved for the norm's backward pass.
        normals.append(cross / cross_len.clamp(min=eps).unsqueeze(-1))
    normal = torch.stack(normals, dim=-2)
    normal_valid = torch.stack(normal_valids, dim=-1)
    return normal, normal_valid
def get_triangle_normal_and_valid(xyz: torch.Tensor, valid: torch.Tensor, flatten: bool = True):
    '''
    Compute per-triangle normals and a combined validity mask for a pixel-grid mesh.

    A triangle is kept only if all three of its corner pixels are valid
    (get_triangle_valid) AND its normal is non-degenerate (get_triangle_normal).

    :param xyz: shape (H, W, 3)
    :param valid: pixel mask, shape (H, W)
    :param flatten: if True, flatten the triangle axes so the outputs are
        (N, 3) and (N,), with N = (H - 1) * (W - 1) * NUM_TRIANGLE
    :return: normal, valid
        normal: shape (H - 1, W - 1, NUM_TRIANGLE, 3), or (N, 3) if flatten
        valid: bool, shape (H - 1, W - 1, NUM_TRIANGLE), or (N,) if flatten
    :raises ValueError: if valid.shape does not match xyz.shape[:2]
    '''
    if tuple(valid.shape) != tuple(xyz.shape[:2]):
        # Without this check a mismatch can silently broadcast in the `&`
        # below and yield a mask whose length differs from `normal`.
        raise ValueError(
            f'valid has shape {tuple(valid.shape)}, expected {tuple(xyz.shape[:2])} to match xyz'
        )
    normal, normal_valid = get_triangle_normal(xyz)
    tri_valid = get_triangle_valid(valid)
    triangle_valid = normal_valid & tri_valid
    if flatten:
        normal = normal.reshape(-1, 3)
        triangle_valid = triangle_valid.reshape(-1)
    return normal, triangle_valid