Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # Adapted from [VGGT-Long](https://github.com/DengKaiCQ/VGGT-Long) | |
| import numpy as np | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| def apply_transformation_residual_kernel( | |
| src_ptr, # [n, 3] | |
| tgt_ptr, # [n, 3] | |
| transformed_ptr, # [n, 3] | |
| residuals_ptr, # [n] | |
| s, | |
| R00, | |
| R01, | |
| R02, | |
| R10, | |
| R11, | |
| R12, | |
| R20, | |
| R21, | |
| R22, | |
| t0, | |
| t1, | |
| t2, | |
| n_points, | |
| BLOCK_SIZE: tl.constexpr, | |
| ): | |
| pid = tl.program_id(0) | |
| offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | |
| mask = offsets < n_points | |
| src_x = tl.load(src_ptr + offsets * 3 + 0, mask=mask) | |
| src_y = tl.load(src_ptr + offsets * 3 + 1, mask=mask) | |
| src_z = tl.load(src_ptr + offsets * 3 + 2, mask=mask) | |
| tgt_x = tl.load(tgt_ptr + offsets * 3 + 0, mask=mask) | |
| tgt_y = tl.load(tgt_ptr + offsets * 3 + 1, mask=mask) | |
| tgt_z = tl.load(tgt_ptr + offsets * 3 + 2, mask=mask) | |
| # transformed = s * (R @ p) + t | |
| transformed_x = s * (R00 * src_x + R01 * src_y + R02 * src_z) + t0 | |
| transformed_y = s * (R10 * src_x + R11 * src_y + R12 * src_z) + t1 | |
| transformed_z = s * (R20 * src_x + R21 * src_y + R22 * src_z) + t2 | |
| tl.store(transformed_ptr + offsets * 3 + 0, transformed_x, mask=mask) | |
| tl.store(transformed_ptr + offsets * 3 + 1, transformed_y, mask=mask) | |
| tl.store(transformed_ptr + offsets * 3 + 2, transformed_z, mask=mask) | |
| dx = tgt_x - transformed_x | |
| dy = tgt_y - transformed_y | |
| dz = tgt_z - transformed_z | |
| residual = tl.sqrt(dx * dx + dy * dy + dz * dz) | |
| tl.store(residuals_ptr + offsets, residual, mask=mask) | |
| def weighted_covariance_kernel( | |
| src_ptr, # [n, 3] | |
| tgt_ptr, # [n, 3] | |
| weights_ptr, # [n] | |
| mu_src0, | |
| mu_src1, | |
| mu_src2, | |
| mu_tgt0, | |
| mu_tgt1, | |
| mu_tgt2, | |
| H_ptr, # [3, 3] | |
| n_points, | |
| BLOCK_SIZE: tl.constexpr, | |
| ): | |
| pid = tl.program_id(0) | |
| offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | |
| mask = offsets < n_points | |
| w = tl.load(weights_ptr + offsets, mask=mask) | |
| src_x = tl.load(src_ptr + offsets * 3 + 0, mask=mask) | |
| src_y = tl.load(src_ptr + offsets * 3 + 1, mask=mask) | |
| src_z = tl.load(src_ptr + offsets * 3 + 2, mask=mask) | |
| tgt_x = tl.load(tgt_ptr + offsets * 3 + 0, mask=mask) | |
| tgt_y = tl.load(tgt_ptr + offsets * 3 + 1, mask=mask) | |
| tgt_z = tl.load(tgt_ptr + offsets * 3 + 2, mask=mask) | |
| src_centered_x = src_x - mu_src0 | |
| src_centered_y = src_y - mu_src1 | |
| src_centered_z = src_z - mu_src2 | |
| tgt_centered_x = tgt_x - mu_tgt0 | |
| tgt_centered_y = tgt_y - mu_tgt1 | |
| tgt_centered_z = tgt_z - mu_tgt2 | |
| sqrt_w = tl.sqrt(w) | |
| weighted_src_x = src_centered_x * sqrt_w | |
| weighted_src_y = src_centered_y * sqrt_w | |
| weighted_src_z = src_centered_z * sqrt_w | |
| weighted_tgt_x = tgt_centered_x * sqrt_w | |
| weighted_tgt_y = tgt_centered_y * sqrt_w | |
| weighted_tgt_z = tgt_centered_z * sqrt_w | |
| h00 = weighted_src_x * weighted_tgt_x | |
| h01 = weighted_src_x * weighted_tgt_y | |
| h02 = weighted_src_x * weighted_tgt_z | |
| h10 = weighted_src_y * weighted_tgt_x | |
| h11 = weighted_src_y * weighted_tgt_y | |
| h12 = weighted_src_y * weighted_tgt_z | |
| h20 = weighted_src_z * weighted_tgt_x | |
| h21 = weighted_src_z * weighted_tgt_y | |
| h22 = weighted_src_z * weighted_tgt_z | |
| tl.atomic_add(H_ptr + 0, tl.sum(h00, axis=0)) | |
| tl.atomic_add(H_ptr + 1, tl.sum(h01, axis=0)) | |
| tl.atomic_add(H_ptr + 2, tl.sum(h02, axis=0)) | |
| tl.atomic_add(H_ptr + 3, tl.sum(h10, axis=0)) | |
| tl.atomic_add(H_ptr + 4, tl.sum(h11, axis=0)) | |
| tl.atomic_add(H_ptr + 5, tl.sum(h12, axis=0)) | |
| tl.atomic_add(H_ptr + 6, tl.sum(h20, axis=0)) | |
| tl.atomic_add(H_ptr + 7, tl.sum(h21, axis=0)) | |
| tl.atomic_add(H_ptr + 8, tl.sum(h22, axis=0)) | |
| def compute_huber_weights_kernel( | |
| residuals_ptr, | |
| weights_ptr, | |
| delta, | |
| n_points, | |
| BLOCK_SIZE: tl.constexpr, | |
| ): | |
| pid = tl.program_id(0) | |
| offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | |
| mask = offsets < n_points | |
| r = tl.load(residuals_ptr + offsets, mask=mask) | |
| weight = tl.where(r > delta, delta / r, 1.0) | |
| tl.store(weights_ptr + offsets, weight, mask=mask) | |
| def weighted_mean_kernel( | |
| points_ptr, # [n, 3] | |
| weights_ptr, # [n] | |
| mean_ptr, # [sum(w*x), sum(w*y), sum(w*z), sum(w)] | |
| n_points, | |
| BLOCK_SIZE: tl.constexpr, | |
| ): | |
| pid = tl.program_id(0) | |
| offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | |
| mask = offsets < n_points | |
| w = tl.load(weights_ptr + offsets, mask=mask) | |
| x = tl.load(points_ptr + offsets * 3 + 0, mask=mask) | |
| y = tl.load(points_ptr + offsets * 3 + 1, mask=mask) | |
| z = tl.load(points_ptr + offsets * 3 + 2, mask=mask) | |
| wx = w * x | |
| wy = w * y | |
| wz = w * z | |
| tl.atomic_add(mean_ptr + 0, tl.sum(wx, axis=0)) | |
| tl.atomic_add(mean_ptr + 1, tl.sum(wy, axis=0)) | |
| tl.atomic_add(mean_ptr + 2, tl.sum(wz, axis=0)) | |
| tl.atomic_add(mean_ptr + 3, tl.sum(w, axis=0)) | |
| def apply_transformation_residual_triton(src, tgt, s, R, t): | |
| n_points = src.shape[0] | |
| transformed = torch.empty_like(src) | |
| residuals = torch.empty(n_points, device=src.device, dtype=src.dtype) | |
| BLOCK_SIZE = 256 | |
| grid = (triton.cdiv(n_points, BLOCK_SIZE),) | |
| R_flat = R.contiguous().view(-1) | |
| t_flat = t.contiguous().view(-1) | |
| apply_transformation_residual_kernel[grid]( | |
| src, | |
| tgt, | |
| transformed, | |
| residuals, | |
| float(s), | |
| float(R_flat[0]), | |
| float(R_flat[1]), | |
| float(R_flat[2]), | |
| float(R_flat[3]), | |
| float(R_flat[4]), | |
| float(R_flat[5]), | |
| float(R_flat[6]), | |
| float(R_flat[7]), | |
| float(R_flat[8]), | |
| float(t_flat[0]), | |
| float(t_flat[1]), | |
| float(t_flat[2]), | |
| n_points, | |
| BLOCK_SIZE=BLOCK_SIZE, | |
| ) | |
| return transformed, residuals | |
| def compute_weighted_mean_triton(points, weights): | |
| n_points = points.shape[0] | |
| # [sum(w*x), sum(w*y), sum(w*z), sum(w)] | |
| mean_buffer = torch.zeros(4, device=points.device, dtype=points.dtype) | |
| BLOCK_SIZE = 256 | |
| grid = (triton.cdiv(n_points, BLOCK_SIZE),) | |
| weighted_mean_kernel[grid](points, weights, mean_buffer, n_points, BLOCK_SIZE=BLOCK_SIZE) | |
| total_weight = mean_buffer[3] | |
| if total_weight > 1e-12: | |
| mean = mean_buffer[:3] / total_weight | |
| else: | |
| mean = torch.zeros(3, device=points.device, dtype=points.dtype) | |
| return mean, total_weight | |
| def compute_weighted_covariance_triton(src, tgt, weights, mu_src, mu_tgt): | |
| n_points = src.shape[0] | |
| H = torch.zeros(9, device=src.device, dtype=src.dtype) | |
| BLOCK_SIZE = 256 | |
| grid = (triton.cdiv(n_points, BLOCK_SIZE),) | |
| mu_src_flat = mu_src.contiguous().view(-1) | |
| mu_tgt_flat = mu_tgt.contiguous().view(-1) | |
| weighted_covariance_kernel[grid]( | |
| src, | |
| tgt, | |
| weights, | |
| float(mu_src_flat[0]), | |
| float(mu_src_flat[1]), | |
| float(mu_src_flat[2]), | |
| float(mu_tgt_flat[0]), | |
| float(mu_tgt_flat[1]), | |
| float(mu_tgt_flat[2]), | |
| H, | |
| n_points, | |
| BLOCK_SIZE=BLOCK_SIZE, | |
| ) | |
| return H.reshape(3, 3) | |
| def compute_huber_weights_triton(residuals, delta): | |
| n_points = residuals.shape[0] | |
| weights = torch.empty_like(residuals) | |
| BLOCK_SIZE = 256 | |
| grid = (triton.cdiv(n_points, BLOCK_SIZE),) | |
| compute_huber_weights_kernel[grid]( | |
| residuals, weights, float(delta), n_points, BLOCK_SIZE=BLOCK_SIZE | |
| ) | |
| return weights | |
| def weighted_estimate_se3_triton(source_points, target_points, weights): | |
| source_points = torch.from_numpy(source_points).cuda().float() | |
| target_points = torch.from_numpy(target_points).cuda().float() | |
| weights = torch.from_numpy(weights).cuda().float() | |
| total_weight = torch.sum(weights) | |
| if total_weight < 1e-6: | |
| return ( | |
| 1.0, | |
| np.zeros(3, dtype=np.float32), | |
| np.zeros(3, dtype=np.float32), | |
| np.zeros((3, 3), dtype=np.float32), | |
| ) | |
| normalized_weights = weights / total_weight | |
| mu_src, _ = compute_weighted_mean_triton(source_points, normalized_weights) | |
| mu_tgt, _ = compute_weighted_mean_triton(target_points, normalized_weights) | |
| H = compute_weighted_covariance_triton( | |
| source_points, target_points, normalized_weights, mu_src, mu_tgt | |
| ) | |
| return 1.0, mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy() | |
| def weighted_estimate_sim3_triton(source_points, target_points, weights): | |
| source_points = torch.from_numpy(source_points).cuda().float() | |
| target_points = torch.from_numpy(target_points).cuda().float() | |
| weights = torch.from_numpy(weights).cuda().float() | |
| total_weight = torch.sum(weights) | |
| if total_weight < 1e-6: | |
| return ( | |
| -1.0, | |
| np.zeros(3, dtype=np.float32), | |
| np.zeros(3, dtype=np.float32), | |
| np.zeros((3, 3), dtype=np.float32), | |
| ) | |
| normalized_weights = weights / total_weight | |
| mu_src, _ = compute_weighted_mean_triton(source_points, normalized_weights) | |
| mu_tgt, _ = compute_weighted_mean_triton(target_points, normalized_weights) | |
| src_centered = source_points - mu_src | |
| tgt_centered = target_points - mu_tgt | |
| scale_src = torch.sqrt(torch.sum(normalized_weights * torch.sum(src_centered**2, dim=1))) | |
| scale_tgt = torch.sqrt(torch.sum(normalized_weights * torch.sum(tgt_centered**2, dim=1))) | |
| s = scale_tgt / scale_src | |
| weighted_src = s * src_centered | |
| H = compute_weighted_covariance_triton( | |
| weighted_src, | |
| tgt_centered, | |
| normalized_weights, | |
| torch.zeros_like(mu_src), | |
| torch.zeros_like(mu_tgt), | |
| ) | |
| return s.cpu().numpy(), mu_src.cpu().numpy(), mu_tgt.cpu().numpy(), H.cpu().numpy() | |
| def weighted_estimate_sim3_numba_triton( | |
| source_points, target_points, weights, align_method="sim3" | |
| ): | |
| if align_method == "sim3": | |
| s, mu_src, mu_tgt, H = weighted_estimate_sim3_triton(source_points, target_points, weights) | |
| elif align_method == "se3" or align_method == "scale+se3": | |
| s, mu_src, mu_tgt, H = weighted_estimate_se3_triton(source_points, target_points, weights) | |
| if s < 0: | |
| raise ValueError("Total weight too small for meaningful estimation") | |
| H_torch = torch.from_numpy(H).cuda().float() | |
| U, _, Vt = torch.linalg.svd(H_torch) | |
| U = U.cpu().numpy() | |
| Vt = Vt.cpu().numpy() | |
| R = Vt.T @ U.T | |
| if np.linalg.det(R) < 0: | |
| Vt[2, :] *= -1 | |
| R = Vt.T @ U.T | |
| mu_src = mu_src.astype(np.float32) | |
| mu_tgt = mu_tgt.astype(np.float32) | |
| R = R.astype(np.float32) | |
| if align_method == "se3" or align_method == "scale+se3": | |
| t = mu_tgt - R @ mu_src | |
| else: | |
| t = mu_tgt - s * R @ mu_src | |
| return s, R, t.astype(np.float32) | |
| def robust_weighted_estimate_sim3_triton( | |
| src, tgt, init_weights, delta=0.1, max_iters=20, tol=1e-9, align_method="sim3" | |
| ): | |
| src = src.astype(np.float32) | |
| tgt = tgt.astype(np.float32) | |
| init_weights = init_weights.astype(np.float32) | |
| src_torch = torch.from_numpy(src).cuda().float() | |
| tgt_torch = torch.from_numpy(tgt).cuda().float() | |
| init_weights_torch = torch.from_numpy(init_weights).cuda().float() | |
| s, R, t = weighted_estimate_sim3_numba_triton( | |
| src, tgt, init_weights, align_method=align_method | |
| ) | |
| R_torch = torch.from_numpy(R).cuda().float() | |
| t_torch = torch.from_numpy(t).cuda().float() | |
| s_torch = torch.tensor(s, device="cuda", dtype=torch.float32) | |
| prev_error = float("inf") | |
| for iter in range(max_iters): | |
| transformed, residuals = apply_transformation_residual_triton( | |
| src_torch, tgt_torch, s_torch, R_torch, t_torch | |
| ) | |
| mean_residual = torch.mean(residuals).cpu().numpy() | |
| print(f"Iter {iter}: Mean residual = {mean_residual:.6f}") | |
| huber_weights = compute_huber_weights_triton(residuals, delta) | |
| combined_weights = init_weights_torch * huber_weights | |
| combined_weights_sum = torch.sum(combined_weights) | |
| if combined_weights_sum > 1e-12: | |
| combined_weights /= combined_weights_sum | |
| else: | |
| combined_weights = init_weights_torch / torch.sum(init_weights_torch) | |
| combined_weights_np = combined_weights.cpu().numpy() | |
| s_new, R_new, t_new = weighted_estimate_sim3_numba_triton( | |
| src, tgt, combined_weights_np, align_method=align_method | |
| ) | |
| param_change = np.abs(s_new - s) + np.linalg.norm(t_new - t) | |
| rot_angle = np.arccos(min(1.0, max(-1.0, (np.trace(R_new @ R.T) - 1) / 2))) | |
| residuals_np = residuals.cpu().numpy() | |
| huber_loss_values = np.where( | |
| residuals_np <= delta, 0.5 * residuals_np**2, delta * (residuals_np - 0.5 * delta) | |
| ) | |
| current_error = np.sum(huber_loss_values * init_weights) | |
| if (param_change < tol and rot_angle < np.radians(0.1)) or ( | |
| abs(prev_error - current_error) < tol * prev_error | |
| ): | |
| print(f"Converged at iteration {iter}") | |
| break | |
| s, R, t = s_new, R_new, t_new | |
| s_torch = torch.tensor(s, device="cuda", dtype=torch.float32) | |
| R_torch = torch.from_numpy(R).cuda().float() | |
| t_torch = torch.from_numpy(t).cuda().float() | |
| prev_error = current_error | |
| return s, R, t | |
| def warmup_triton(): | |
| print("\nWarming up Triton functions...") | |
| n_points = 10000 | |
| src = np.random.randn(n_points, 3).astype(np.float32) | |
| tgt = np.random.randn(n_points, 3).astype(np.float32) | |
| weights = np.ones(n_points, dtype=np.float32) | |
| src_torch = torch.from_numpy(src).cuda().float() | |
| tgt_torch = torch.from_numpy(tgt).cuda().float() | |
| weights_torch = torch.from_numpy(weights).cuda().float() | |
| R = np.eye(3, dtype=np.float32) | |
| t = np.zeros(3, dtype=np.float32) | |
| s = np.float32(1.0) | |
| delta = np.float32(0.1) | |
| R_torch = torch.from_numpy(R).cuda().float() | |
| t_torch = torch.from_numpy(t).cuda().float() | |
| s_torch = torch.tensor(s, device="cuda", dtype=torch.float32) | |
| try: | |
| _, _ = apply_transformation_residual_triton( | |
| src_torch, tgt_torch, s_torch, R_torch, t_torch | |
| ) | |
| print(" - apply_transformation_residual_triton warmed up.") | |
| except Exception as e: | |
| print(f" ! Failed to warm up apply_transformation_residual_triton: {e}") | |
| try: | |
| _, _ = compute_weighted_mean_triton(src_torch, weights_torch) | |
| print(" - compute_weighted_mean_triton warmed up.") | |
| except Exception as e: | |
| print(f" ! Failed to warm up compute_weighted_mean_triton: {e}") | |
| try: | |
| mu_src, _ = compute_weighted_mean_triton(src_torch, weights_torch) | |
| mu_tgt, _ = compute_weighted_mean_triton(tgt_torch, weights_torch) | |
| _ = compute_weighted_covariance_triton(src_torch, tgt_torch, weights_torch, mu_src, mu_tgt) | |
| print(" - compute_weighted_covariance_triton warmed up.") | |
| except Exception as e: | |
| print(f" ! Failed to warm up compute_weighted_covariance_triton: {e}") | |
| try: | |
| residuals = torch.abs(torch.randn(n_points, device="cuda", dtype=torch.float32)) | |
| _ = compute_huber_weights_triton(residuals, delta) | |
| print(" - compute_huber_weights_triton warmed up.") | |
| except Exception as e: | |
| print(f" ! Failed to warm up compute_huber_weights_triton: {e}") | |
| print("Triton warm-up complete.\n") | |
| def print_gpu_memory(): | |
| if torch.cuda.is_available(): | |
| allocated = torch.cuda.memory_allocated() / 1024**3 # GB | |
| cached = torch.cuda.memory_reserved() / 1024**3 # GB | |
| print(f"GPU Memory Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB") | |
| if __name__ == "__main__": | |
| warmup_triton() | |
| n_points = 7_500_000 | |
| src = np.random.randn(n_points, 3).astype(np.float32) | |
| true_R = np.array([[0.866, -0.5, 0], [0.5, 0.866, 0], [0, 0, 1]], dtype=np.float32) | |
| true_t = np.array([1.0, 2.0, 0.5], dtype=np.float32) | |
| true_s = 1.2 | |
| tgt = true_s * (src @ true_R.T) + true_t | |
| tgt += 0.01 * np.random.randn(*tgt.shape).astype(np.float32) | |
| weights = np.ones(n_points, dtype=np.float32) | |
| print_gpu_memory() | |
| s, R, t = robust_weighted_estimate_sim3_triton( | |
| src, tgt, weights, delta=0.1, max_iters=5, align_method="sim3" | |
| ) | |
| print(f"\nEstimated scale: {s:.6f}") | |
| print(f"Estimated rotation:\n{R}") | |
| print(f"Estimated translation: {t}") | |
| print_gpu_memory() | |