Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Reference View Selection Strategies | |
| This module provides different strategies for selecting a reference view | |
| from multiple input views in multi-view depth estimation. | |
| """ | |
from typing import Literal, get_args

import torch
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]

# Every accepted strategy name, derived from the Literal so runtime validation
# and the type annotation cannot drift apart.
_REF_VIEW_STRATEGIES = get_args(RefViewStrategy)


def _normalize_metric(metric: torch.Tensor) -> torch.Tensor:
    """Min-max rescale a (B, S) metric to [0, 1] independently for each batch row.

    The 1e-8 in the denominator avoids division by zero when every view has
    the same value (that row then maps to all zeros).
    """
    min_val = metric.min(dim=1, keepdim=True).values
    max_val = metric.max(dim=1, keepdim=True).values
    return (metric - min_val) / (max_val - min_val + 1e-8)


def select_reference_view(
    x: torch.Tensor,
    strategy: RefViewStrategy = "saddle_balanced",
) -> torch.Tensor:
    """
    Select a reference view from multiple views using the specified strategy.

    Args:
        x: Input tensor of shape (B, S, N, C) where
            B = batch size
            S = number of views
            N = number of tokens
            C = channel dimension
            The feature-based strategies use token 0 of each view as that
            view's summary ("class") feature.
        strategy: Selection strategy, one of:
            - "first": Always select the first view
            - "middle": Select the middle view (index S // 2)
            - "saddle_balanced": Select the view whose mean cosine similarity
              to the other views, class-token norm, and class-token variance
              are jointly closest to the midpoint (0.5) of their per-batch
              [0, 1]-rescaled ranges
            - "saddle_sim_range": Select the view whose cosine similarities to
              the *other* views span the largest range (max - min)

    Returns:
        b_idx: Long tensor of shape (B,) on ``x.device`` containing the
        selected view index for each batch element (all zeros when S <= 1).

    Raises:
        ValueError: If ``strategy`` is not a known name (checked even when
            S <= 1) or ``x`` is not 4-D.
    """
    # Validate first so typos fail fast, even for single-view inputs.
    if strategy not in _REF_VIEW_STRATEGIES:
        options = ", ".join(repr(name) for name in _REF_VIEW_STRATEGIES)
        raise ValueError(
            f"Unknown reference view selection strategy: {strategy}. "
            f"Must be one of: {options}"
        )
    if x.dim() != 4:
        raise ValueError(f"Expected x of shape (B, S, N, C), got {tuple(x.shape)}")
    B, S = x.shape[0], x.shape[1]

    # Single view, or a position-based strategy: no features needed.
    if S <= 1 or strategy == "first":
        return torch.zeros(B, dtype=torch.long, device=x.device)
    if strategy == "middle":
        return torch.full((B,), S // 2, dtype=torch.long, device=x.device)

    cls_feat = x[:, :, 0]  # B S C
    # F.normalize clamps the norm at a tiny eps: an all-zero token becomes a
    # zero vector instead of NaNs; otherwise identical to dividing by the norm.
    cls_unit = torch.nn.functional.normalize(cls_feat, dim=-1)  # B S C
    sim = torch.matmul(cls_unit, cls_unit.transpose(1, 2))  # B S S cosine similarity

    if strategy == "saddle_balanced":
        # The diagonal holds self-similarity (~1); subtracting the identity
        # removes it so this is the mean similarity to the S - 1 other views.
        sim_score = (sim - torch.eye(S, device=sim.device).unsqueeze(0)).sum(dim=-1) / (S - 1)  # B S
        feat_norm = cls_feat.norm(dim=-1)  # B S
        feat_var = cls_unit.var(dim=-1)  # B S
        # Total distance from the midpoint across the rescaled metrics; smallest wins.
        balance_score = (
            (_normalize_metric(sim_score) - 0.5).abs()
            + (_normalize_metric(feat_norm) - 0.5).abs()
            + (_normalize_metric(feat_var) - 0.5).abs()
        )
        return balance_score.argmin(dim=1)

    # strategy == "saddle_sim_range"
    # Exclude the diagonal with -inf / +inf instead of zeroing it: a zeroed
    # self-similarity would be picked up by min() when all other similarities
    # are positive (or by max() when all are negative), distorting the range.
    diag = torch.eye(S, dtype=torch.bool, device=sim.device)
    sim_max = sim.masked_fill(diag, float("-inf")).max(dim=-1).values  # B S
    sim_min = sim.masked_fill(diag, float("inf")).min(dim=-1).values  # B S
    return (sim_max - sim_min).argmax(dim=1)
def reorder_by_reference(
    x: torch.Tensor,
    b_idx: torch.Tensor,
) -> torch.Tensor:
    """
    Reorder views to place the selected reference view first.

    The reference view moves to position 0 and all other views keep their
    relative order. ``restore_original_order`` undoes this permutation.

    Args:
        x: Input tensor of shape (B, S, ...); dim 1 indexes views and any
            trailing dimensions are carried along unchanged.
        b_idx: Integer tensor of shape (B,) with the reference view index for
            each batch element; values are expected to lie in [0, S).

    Returns:
        Tensor with the same shape as ``x`` and the reference view at
        position 0. When S <= 1, ``x`` itself is returned.

    Example:
        If b_idx = [2] and S = 5 (views [0,1,2,3,4]),
        result order is [2,0,1,3,4] (ref_idx first, then others in order)
    """
    B, S = x.shape[0], x.shape[1]
    # For single view, no reordering needed
    if S <= 1:
        return x
    # Align index dtype/device with x so the comparison and indexing below
    # also work for int32 or CPU-resident indices.
    ref = b_idx.to(device=x.device, dtype=torch.long).unsqueeze(1)  # B 1
    positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)  # B S
    # Output slot p with 1 <= p <= ref takes source view p - 1 (views before the
    # reference shift right by one); slots after ref keep their own view.
    reorder_indices = torch.where(
        (positions > 0) & (positions <= ref),
        positions - 1,
        positions,
    )
    # Slot 0 takes the reference view itself.
    reorder_indices[:, 0] = ref.squeeze(1)
    batch_indices = torch.arange(B, device=x.device).unsqueeze(1)  # B 1
    return x[batch_indices, reorder_indices]
def restore_original_order(
    x: torch.Tensor,
    b_idx: torch.Tensor,
) -> torch.Tensor:
    """
    Restore original view order after processing.

    Inverts the "reference view first" permutation: the view at position 0
    goes back to index ``b_idx``, and the views at positions 1..b_idx shift
    back left by one.

    Args:
        x: Reordered tensor of shape (B, S, ...); dim 1 indexes views.
        b_idx: Integer tensor of shape (B,) with the original reference view
            indices; values are expected to lie in [0, S). Any integer dtype
            and device is accepted.

    Returns:
        Tensor with original view order restored. When S <= 1, ``x`` itself
        is returned.

    Example:
        If original order was [0, 1, 2, 3, 4] and b_idx=2,
        reordered becomes [2, 0, 1, 3, 4] (reference at position 0),
        restore should return [0, 1, 2, 3, 4] (original order).
    """
    B, S = x.shape[0], x.shape[1]
    # For single view, no restoration needed
    if S <= 1:
        return x
    # scatter() requires an int64 index (and a src of the same dtype as the
    # target), so normalise dtype/device up front; an int32 b_idx used to crash.
    ref = b_idx.to(device=x.device, dtype=torch.long).unsqueeze(1)  # B 1
    target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)  # B S
    # Original views 0..ref-1 sit one slot to the right in the reordered tensor;
    # views after ref never moved.
    restore_indices = torch.where(
        target_positions < ref,
        target_positions + 1,
        target_positions,
    )
    # Original view `ref` is the one stored at reordered slot 0.
    restore_indices = restore_indices.scatter(1, ref, torch.zeros_like(ref))
    batch_indices = torch.arange(B, device=x.device).unsqueeze(1)  # B 1
    return x[batch_indices, restore_indices]