Diffusers
Safetensors
EvalMDE / Edit2Perceive /utils /eval_matting.py
zeyuren2002's picture
Add files using upload-large-folder tool
7f921f4 verified
from utils.metric import compute_matting_metrics
from PIL import Image
import numpy as np
import os
import argparse
import torch
import torch.nn.functional as F
from tqdm import tqdm # 使用tqdm来显示进度条,对并行任务尤其友好
import multiprocessing as mp # 核心库
# Helper functions (resize_tensor, load_image_rgb_or_grayscale) 保持不变
def resize_tensor(input_tensor, target_height, target_width):
"""
使用双线性插值调整深度图像大小
"""
input_tensor = torch.from_numpy(input_tensor).unsqueeze(0).unsqueeze(0).float()
resized_tensor = F.interpolate(
input_tensor,
size=(target_height, target_width),
mode='bilinear',
align_corners=True
)
return resized_tensor.squeeze().numpy()
def load_image_rgb_or_grayscale(image_path):
"""
加载图像,支持RGB和灰度图像,统一转换为numpy数组
"""
try:
if image_path.endswith('.npy'):
img_array = np.load(image_path)
else:
img = Image.open(image_path)
img_array = np.array(img)
if len(img_array.shape) == 3 and img_array.shape[2] == 4:
img_array = img_array[:, :, :3]
return np.mean(img_array, axis=2) if len(img_array.shape) == 3 else img_array
except Exception as e:
# 增加错误日志,方便调试
print(f"Warning: Failed to load image from {image_path}. Error: {e}")
return None
# ------------------- 新增:独立的工作函数 -------------------
# 这个函数包含了处理单个样本的所有逻辑,它将在一个独立的进程中被调用。
def process_sample(task_args):
"""
处理单个样本:加载数据、预处理、计算指标。
Args:
task_args (tuple): 包含单个任务所需所有路径的元组 (pred_path, alpha_path, trimap_path)
Returns:
tuple: 包含 (mse, mad, sad, grad, conn) 的元组,如果出错则返回五个 np.nan
"""
pred_path, alpha_path, trimap_path = task_args
try:
# 1. 加载数据
pred_depth = load_image_rgb_or_grayscale(pred_path)
alpha = load_image_rgb_or_grayscale(alpha_path)
trimap = load_image_rgb_or_grayscale(trimap_path)
# 如果任何一个文件加载失败,则跳过
if pred_depth is None or alpha is None or trimap is None:
return (np.nan, np.nan, np.nan, np.nan, np.nan)
# 2. 预处理 (与原eval函数中的逻辑相同)
gt_depth = alpha
gt_depth[np.isinf(gt_depth)] = 0
gt_depth[np.isnan(gt_depth)] = 0
pred_depth[np.isinf(pred_depth)] = 0
pred_depth[np.isnan(pred_depth)] = 0
if pred_depth.shape != gt_depth.shape:
pred_depth = resize_tensor(pred_depth, gt_depth.shape[0], gt_depth.shape[1])
# pred_depth = pred_depth > 0.5
gt_depth = gt_depth.astype(np.float32) / 255.0
# 3. 计算指标
mse, mad, sad, grad, conn = compute_matting_metrics(pred_depth, gt_depth, trimap, whole=True)
# 检查结果是否有效
if not all(v >= 0 for v in [mse, mad, sad, grad, conn]):
return (np.nan, np.nan, np.nan, np.nan, np.nan)
return (mse, mad, sad, grad, conn)
except Exception as e:
print(f'Error processing {os.path.basename(pred_path)}: {e}')
# 返回NaN,方便后续统计时忽略错误样本
return (np.nan, np.nan, np.nan, np.nan, np.nan)
# ------------------- 重构后的 test 函数 (接口不变) -------------------
def test(args):
# 1. 准备任务列表 (与原代码相同)
# 这部分逻辑不变,只是现在我们不加载图像,而是收集文件路径来创建任务
dataset_paths = {
"comp": "./data_split/comp_matting/filenames_test.txt",
"aim": "./data_split/AIM_matting/filenames_val.txt",
"p3m": "./data_split/P3M_matting/filenames_val_P.txt",
"p3m-np": "./data_split/P3M_matting/filenames_val_NP.txt",
"am": "./data_split/AM_matting/filenames_val.txt",
}
filenames_path = dataset_paths.get(args.dataset)
if not filenames_path:
raise ValueError(f"Unknown dataset: {args.dataset}")
with open(filenames_path, "r") as f:
test_filenames = [line.strip() for line in f.readlines()]
# 在非full模式下截断文件名列表
if "full" not in args.pred_path and "result" not in args.pred_path and args.dataset in ["comp", "p3m", "p3m-np", "am"]:
test_filenames = test_filenames[:10]
if getattr(args, 'max_samples', None) is not None:
test_filenames = test_filenames[:args.max_samples]
# 2. 创建任务列表,每个任务是一个包含所需文件路径的元组
tasks = []
for line in test_filenames:
items = line.split()
# 兼容不同格式的 filenames.txt
merged_path, trimap_path, alpha_path = items[0], items[1], items[-1]
pred_full_path = os.path.join(args.pred_path, merged_path.replace(".jpg",".npy").replace(".png",".npy"))
alpha_full_path = os.path.join(args.gt_path, alpha_path)
trimap_full_path = os.path.join(args.gt_path, trimap_path)
tasks.append((pred_full_path, alpha_full_path, trimap_full_path))
print(f'### Starting evaluation for {len(tasks)} files on dataset "{args.dataset}"...')
mp_context = mp.get_context('spawn')
# 3. 使用多进程池执行任务
# os.cpu_count() 可以获取机器的CPU核心数,自动利用所有资源
# `with`语句可以确保进程池在使用后被正确关闭
results_list = []
with mp_context.Pool(processes=32) as pool:
# pool.imap_unordered 比 map 更高效,因为它不保证返回结果的顺序
# 使用tqdm来包装,可以实时看到处理进度
results_iterator = pool.imap_unordered(process_sample, tasks)
for result in tqdm(results_iterator, total=len(tasks), desc="Processing samples"):
results_list.append(result)
# 4. 汇总结果
# 将结果列表转换为numpy数组,方便进行列操作
results_array = np.array(results_list)
# 使用 np.nanmean 来计算平均值,它会自动忽略所有NaN值(即处理失败的样本)
# 这比手动过滤更简洁、更安全
mean_mse = np.nanmean(results_array[:, 0])
mean_mad = np.nanmean(results_array[:, 1])
mean_sad = np.nanmean(results_array[:, 2])
mean_grad = np.nanmean(results_array[:, 3])
mean_conn = np.nanmean(results_array[:, 4])
# 统计有效结果的数量
valid_results_count = np.sum(~np.isnan(results_array[:, 0]))
# 5. 打印结果 (与原代码格式一致)
final_results_str = "{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}".format(
mean_mse, mean_mad, mean_sad, mean_grad, mean_conn
)
print('Done.')
print(f'Valid results: {valid_results_count}/{len(tasks)}')
print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('MSE', 'MAD', 'SAD', 'Grad', 'Conn'))
print(final_results_str)
return final_results_str
# 主程序入口保持不变
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='BTS TensorFlow implementation.', fromfile_prefix_chars='@')
parser.add_argument('--pred_path', type=str, help='path to the prediction results in png')
parser.add_argument('--gt_path', type=str, help='root path to the groundtruth data' )
parser.add_argument('--dataset', type=str, help='dataset to test on, e.g., p3m, comp, etc.', default='p3m')
parser.add_argument('--max_samples', type=int, help='maximum number of samples to test', default=None)
args = parser.parse_args()
test(args)