Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
| from utils.metric import compute_matting_metrics | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| import argparse | |
| import torch | |
| import torch.nn.functional as F | |
| from tqdm import tqdm # 使用tqdm来显示进度条,对并行任务尤其友好 | |
| import multiprocessing as mp # 核心库 | |
| # Helper functions (resize_tensor, load_image_rgb_or_grayscale) 保持不变 | |
| def resize_tensor(input_tensor, target_height, target_width): | |
| """ | |
| 使用双线性插值调整深度图像大小 | |
| """ | |
| input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).unsqueeze(0).float() | |
| resized_tensor = F.interpolate( | |
| input_tensor, | |
| size=(target_height, target_width), | |
| mode='bilinear', | |
| align_corners=True | |
| ) | |
| return resized_tensor.squeeze().numpy() | |
| def load_image_rgb_or_grayscale(image_path): | |
| """ | |
| 加载图像,支持RGB和灰度图像,统一转换为numpy数组 | |
| """ | |
| try: | |
| if image_path.endswith('.npy'): | |
| img_array = np.load(image_path) | |
| else: | |
| img = Image.open(image_path) | |
| img_array = np.array(img) | |
| if len(img_array.shape) == 3 and img_array.shape[2] == 4: | |
| img_array = img_array[:, :, :3] | |
| return np.mean(img_array, axis=2) if len(img_array.shape) == 3 else img_array | |
| except Exception as e: | |
| # 增加错误日志,方便调试 | |
| print(f"Warning: Failed to load image from {image_path}. Error: {e}") | |
| return None | |
| # ------------------- 新增:独立的工作函数 ------------------- | |
| # 这个函数包含了处理单个样本的所有逻辑,它将在一个独立的进程中被调用。 | |
| def process_sample(task_args): | |
| """ | |
| 处理单个样本:加载数据、预处理、计算指标。 | |
| Args: | |
| task_args (tuple): 包含单个任务所需所有路径的元组 (pred_path, alpha_path, trimap_path) | |
| Returns: | |
| tuple: 包含 (mse, mad, sad, grad, conn) 的元组,如果出错则返回五个 np.nan | |
| """ | |
| pred_path, alpha_path, trimap_path = task_args | |
| try: | |
| # 1. 加载数据 | |
| pred_depth = load_image_rgb_or_grayscale(pred_path) | |
| alpha = load_image_rgb_or_grayscale(alpha_path) | |
| trimap = load_image_rgb_or_grayscale(trimap_path) | |
| # 如果任何一个文件加载失败,则跳过 | |
| if pred_depth is None or alpha is None or trimap is None: | |
| return (np.nan, np.nan, np.nan, np.nan, np.nan) | |
| # 2. 预处理 (与原eval函数中的逻辑相同) | |
| gt_depth = alpha | |
| gt_depth[np.isinf(gt_depth)] = 0 | |
| gt_depth[np.isnan(gt_depth)] = 0 | |
| pred_depth[np.isinf(pred_depth)] = 0 | |
| pred_depth[np.isnan(pred_depth)] = 0 | |
| if pred_depth.shape != gt_depth.shape: | |
| pred_depth = resize_tensor(pred_depth, gt_depth.shape[0], gt_depth.shape[1]) | |
| # pred_depth = pred_depth > 0.5 | |
| gt_depth = gt_depth.astype(np.float32) / 255.0 | |
| # 3. 计算指标 | |
| mse, mad, sad, grad, conn = compute_matting_metrics(pred_depth, gt_depth, trimap, whole=True) | |
| # 检查结果是否有效 | |
| if not all(v >= 0 for v in [mse, mad, sad, grad, conn]): | |
| return (np.nan, np.nan, np.nan, np.nan, np.nan) | |
| return (mse, mad, sad, grad, conn) | |
| except Exception as e: | |
| print(f'Error processing {os.path.basename(pred_path)}: {e}') | |
| # 返回NaN,方便后续统计时忽略错误样本 | |
| return (np.nan, np.nan, np.nan, np.nan, np.nan) | |
| # ------------------- 重构后的 test 函数 (接口不变) ------------------- | |
| def test(args): | |
| # 1. 准备任务列表 (与原代码相同) | |
| # 这部分逻辑不变,只是现在我们不加载图像,而是收集文件路径来创建任务 | |
| dataset_paths = { | |
| "comp": "./data_split/comp_matting/filenames_test.txt", | |
| "aim": "./data_split/AIM_matting/filenames_val.txt", | |
| "p3m": "./data_split/P3M_matting/filenames_val_P.txt", | |
| "p3m-np": "./data_split/P3M_matting/filenames_val_NP.txt", | |
| "am": "./data_split/AM_matting/filenames_val.txt", | |
| } | |
| filenames_path = dataset_paths.get(args.dataset) | |
| if not filenames_path: | |
| raise ValueError(f"Unknown dataset: {args.dataset}") | |
| with open(filenames_path, "r") as f: | |
| test_filenames = [line.strip() for line in f.readlines()] | |
| # 在非full模式下截断文件名列表 | |
| if "full" not in args.pred_path and "result" not in args.pred_path and args.dataset in ["comp", "p3m", "p3m-np", "am"]: | |
| test_filenames = test_filenames[:10] | |
| if getattr(args, 'max_samples', None) is not None: | |
| test_filenames = test_filenames[:args.max_samples] | |
| # 2. 创建任务列表,每个任务是一个包含所需文件路径的元组 | |
| tasks = [] | |
| for line in test_filenames: | |
| items = line.split() | |
| # 兼容不同格式的 filenames.txt | |
| merged_path, trimap_path, alpha_path = items[0], items[1], items[-1] | |
| pred_full_path = os.path.join(args.pred_path, merged_path.replace(".jpg",".npy").replace(".png",".npy")) | |
| alpha_full_path = os.path.join(args.gt_path, alpha_path) | |
| trimap_full_path = os.path.join(args.gt_path, trimap_path) | |
| tasks.append((pred_full_path, alpha_full_path, trimap_full_path)) | |
| print(f'### Starting evaluation for {len(tasks)} files on dataset "{args.dataset}"...') | |
| mp_context = mp.get_context('spawn') | |
| # 3. 使用多进程池执行任务 | |
| # os.cpu_count() 可以获取机器的CPU核心数,自动利用所有资源 | |
| # `with`语句可以确保进程池在使用后被正确关闭 | |
| results_list = [] | |
| with mp_context.Pool(processes=32) as pool: | |
| # pool.imap_unordered 比 map 更高效,因为它不保证返回结果的顺序 | |
| # 使用tqdm来包装,可以实时看到处理进度 | |
| results_iterator = pool.imap_unordered(process_sample, tasks) | |
| for result in tqdm(results_iterator, total=len(tasks), desc="Processing samples"): | |
| results_list.append(result) | |
| # 4. 汇总结果 | |
| # 将结果列表转换为numpy数组,方便进行列操作 | |
| results_array = np.array(results_list) | |
| # 使用 np.nanmean 来计算平均值,它会自动忽略所有NaN值(即处理失败的样本) | |
| # 这比手动过滤更简洁、更安全 | |
| mean_mse = np.nanmean(results_array[:, 0]) | |
| mean_mad = np.nanmean(results_array[:, 1]) | |
| mean_sad = np.nanmean(results_array[:, 2]) | |
| mean_grad = np.nanmean(results_array[:, 3]) | |
| mean_conn = np.nanmean(results_array[:, 4]) | |
| # 统计有效结果的数量 | |
| valid_results_count = np.sum(~np.isnan(results_array[:, 0])) | |
| # 5. 打印结果 (与原代码格式一致) | |
| final_results_str = "{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}".format( | |
| mean_mse, mean_mad, mean_sad, mean_grad, mean_conn | |
| ) | |
| print('Done.') | |
| print(f'Valid results: {valid_results_count}/{len(tasks)}') | |
| print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('MSE', 'MAD', 'SAD', 'Grad', 'Conn')) | |
| print(final_results_str) | |
| return final_results_str | |
| # 主程序入口保持不变 | |
| if __name__ == '__main__': | |
| parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@') | |
| parser.add_argument('--pred_path', type=str, help='path to the prediction results in png') | |
| parser.add_argument('--gt_path', type=str, help='root path to the groundtruth data' ) | |
| parser.add_argument('--dataset', type=str, help='dataset to test on, e.g., p3m, comp, etc.', default='p3m') | |
| parser.add_argument('--max_samples', type=int, help='maximum number of samples to test', default=None) | |
| args = parser.parse_args() | |
| test(args) |