Diffusers
Safetensors
EvalMDE / Edit2Perceive /utils /eval_multiple_datasets.py
zeyuren2002's picture
Add files using upload-large-folder tool
7f921f4 verified
#!/usr/bin/env python3
"""
简化版:一次评估多个数据集,支持批量推理
用法示例:
python eval_multiple_datasets.py \
--model-root ./FLUX.1-Kontext-dev \
--state-dict models/train/kontext/bs64_mask/step-3200.safetensors \
--datasets scannet,nyuv2 \
--max-samples 800 \
--batch-size 4
或者不传 --datasets 则评估 DATASETS 中列出的所有数据集(按顺序)。
"""
import argparse
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

# Kept ahead of the third-party / project imports, as in the original file.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import numpy as np
from tqdm import tqdm
from PIL import Image

from pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from models.utils import load_state_dict, parse_flux_model_configs
from models.unified_dataset import UnifiedDataset, gen_mask, gen_bbox, gen_points, gen_trimap
from utils.eval_depth import test as eval_depth
from utils.eval_normal import test as eval_normal
from utils.eval_matting import test as eval_matting
@dataclass
class DatasetConfig:
    """Static description of one dataset to evaluate.

    ``evaluate_dataset`` reads the sample list from ``file_list``, resolves
    each entry against ``gt_path``, writes predictions under ``output_dir``
    and forwards ``gt_path`` / ``output_dir`` / ``dataset_arg`` to the
    task-specific evaluator.
    """

    # Identifier matched against the comma-separated ``--datasets`` option.
    name: str
    # Sample list; each line is usually "rel_path [other cols]" or one path.
    file_list: str
    # Ground-truth root: list entries are joined to it, and it is handed to
    # the evaluator as ``gt_path``.
    gt_path: str
    # Directory receiving the predicted ``.npy`` files (evaluator ``pred_path``).
    output_dir: str
    # Dataset name handed to the evaluator as ``dataset``
    # (e.g. "scannet", "nyu", "kitti", "eth3d").
    dataset_arg: str
    # Further evaluation parameters can be added here when needed.
# ====== Add / edit the datasets you want to evaluate below ======
# Depth benchmarks; each spec dict is expanded into a DatasetConfig.
DATASETS_DEPTH: List[DatasetConfig] = [
    DatasetConfig(**spec)
    for spec in (
        {
            "name": "nyuv2",
            "file_list": "./data_split/nyu_depth/labeled/filename_list_test.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/depth/nyuv2/",
            "output_dir": "result/nyuv2/",
            "dataset_arg": "nyu",
        },
        {
            "name": "kitti",
            "file_list": "./data_split/kitti_depth/eigen_test_files_with_gt.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/depth/kitti/",
            "output_dir": "result/kitti/",
            "dataset_arg": "kitti",
        },
        {
            "name": "eth3d",
            "file_list": "./data_split/eth3d_depth/eth3d_filename_list.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/depth/eth3d/",
            "output_dir": "result/eth3d/",
            "dataset_arg": "eth3d",
        },
        {
            "name": "scannet",
            "file_list": "./data_split/scannet_depth/scannet_val_sampled_list_800_1.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/depth/scannet",
            "output_dir": "result/scannet/",
            "dataset_arg": "scannet",
        },
        {
            "name": "diode",
            "file_list": "./data_split/diode_depth/diode_val_all_filename_list.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/depth/diode/",
            "output_dir": "result/diode/",
            "dataset_arg": "diode",
        },
    )
]
# Surface-normal benchmarks; each spec dict is expanded into a DatasetConfig.
DATASETS_NORMAL: List[DatasetConfig] = [
    DatasetConfig(**spec)
    for spec in (
        {
            "name": "nyuv2",
            "file_list": "./data_split/nyu_normals/nyuv2_test2.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/normal/nyuv2/",
            "output_dir": "result/nyuv2_normal/",
            "dataset_arg": "nyu",
        },
        {
            "name": "scannet",
            "file_list": "./data_split/scannet_normals/scannet_test.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/normal/scannet/",
            "output_dir": "result/scannet_normal/",
            "dataset_arg": "scannet",
        },
        {
            "name": "ibims",
            "file_list": "./data_split/ibims_normals/ibims_test2.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/normal/ibims/",
            "output_dir": "result/ibims_normal/",
            "dataset_arg": "ibims",
        },
        {
            "name": "diode",
            "file_list": "./data_split/diode_normals/diode_test.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/normal/diode/",
            "output_dir": "result/diode_normal/",
            "dataset_arg": "diode",
        },
        # Disabled entry, kept for reference:
        # {
        #     "name": "oasis",
        #     "file_list": "./data_split/oasis_normals/oasis_test2.txt",
        #     "gt_path": "/mnt/nfs/workspace/syq/dataset/Eval/normal/oasis/",
        #     "output_dir": "result/oasis_normal/",
        #     "dataset_arg": "oasis",
        # },
    )
]
# Matting benchmarks; each spec dict is expanded into a DatasetConfig.
DATASETS_MATTING: List[DatasetConfig] = [
    DatasetConfig(**spec)
    for spec in (
        # Disabled entries, kept for reference:
        # {
        #     "name": "comp",
        #     "file_list": "./data_split/comp_matting/filenames_test.txt",
        #     "gt_path": "/mnt/nfs/workspace/syq/dataset/matting/composition-1k",
        #     "output_dir": "result/comp_matting/",
        #     "dataset_arg": "comp",
        # },
        # {
        #     "name": "p3m",
        #     "file_list": "./data_split/P3M_matting/filenames_val_P.txt",
        #     "gt_path": "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k",
        #     "output_dir": "result/p3m_matting/",
        #     "dataset_arg": "p3m",
        # },
        {
            "name": "aim",
            "file_list": "./data_split/AIM_matting/filenames_val.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/matting/AIM-500",
            "output_dir": "result/aim_matting/",
            "dataset_arg": "aim",
        },
        {
            "name": "p3m-np",
            "file_list": "./data_split/P3M_matting/filenames_val_NP.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k",
            "output_dir": "result/p3m_matting/",
            "dataset_arg": "p3m-np",
        },
        {
            "name": "am",
            "file_list": "./data_split/AM_matting/filenames_val.txt",
            "gt_path": "/mnt/nfs/workspace/syq/dataset/matting/AM-2k",
            "output_dir": "result/am_matting/",
            "dataset_arg": "am",
        },
    )
]
def read_file_list(
    file_list_path: str, base_dir: str, extra_cols=0
) -> Union[List[str], Tuple[List[str], List[List[str]]]]:
    """Read a whitespace-separated file list and resolve paths against *base_dir*.

    Each non-empty line is split on whitespace. The first column is the
    primary relative path. A line whose second column is the literal
    ``"None"`` is skipped entirely.

    Args:
        file_list_path: Text file with one sample per line.
        base_dir: Directory every relative path is joined to.
        extra_cols: Number of additional columns (2nd, 3rd, ...) to return.

    Returns:
        With ``extra_cols == 0`` (default): the list of resolved primary paths.
        With ``extra_cols > 0``: a tuple ``(files, files_extra)`` where
        ``files_extra[i]`` holds the resolved path from column ``i + 1`` of
        every kept line, or the string ``"None"`` when that column is missing
        or is itself ``"None"``.
    """
    base = Path(base_dir)
    files: List[str] = []
    files_extra: List[List[str]] = [[] for _ in range(extra_cols)]
    with open(file_list_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            parts = raw_line.split()
            if not parts:
                # Blank / whitespace-only line.
                continue
            if len(parts) > 1 and parts[1] == "None":
                # Second column marked as unavailable: drop the sample.
                continue
            files.append(str(base / parts[0]))
            # Column i of the line feeds files_extra[i - 1].
            for col, bucket in enumerate(files_extra, start=1):
                rel = parts[col] if len(parts) > col else "None"
                bucket.append("None" if rel == "None" else str(base / rel))
    if extra_cols > 0:
        return files, files_extra
    return files
def create_batches(files: List[str], batch_size: int) -> List[List[str]]:
    """Split *files* into consecutive chunks of at most *batch_size* items.

    Args:
        files: Items to split; order is preserved.
        batch_size: Maximum chunk length; must be >= 1.

    Returns:
        A list of chunks. The last chunk may be shorter. An empty input gives
        an empty list.

    Raises:
        ValueError: If *batch_size* is smaller than 1. A negative value would
            otherwise silently produce no batches at all.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    return [files[start:start + batch_size] for start in range(0, len(files), batch_size)]
def load_batch_images(file_batch: List[str], transform=None) -> Tuple[List, List[str]]:
    """Load every file of a batch, skipping the ones that fail.

    Args:
        file_batch: Paths to load.
        transform: Optional callable taking a path and returning the loaded
            sample. When omitted, the file is opened with PIL and converted
            to a NumPy array.

    Returns:
        ``(images, valid_files)``: the loaded samples and, index-aligned with
        them, the paths that loaded successfully. Failures are reported with
        a printed warning and omitted from both lists.
    """
    images = []
    valid_files: List[str] = []
    for file in file_batch:
        try:
            if transform is not None:
                img = transform(file)
            else:
                # The context manager releases the file handle deterministically.
                with Image.open(file) as handle:
                    img = np.array(handle)
        except Exception as e:
            # Deliberately best-effort: one bad file must not abort the batch.
            print(f"Warning: Failed to load {file}: {e}")
            continue
        images.append(img)
        valid_files.append(file)
    return images, valid_files
# Output resolution requested from the pipeline; visual prompts are resized to it.
# NOTE(review): main() builds the image transform from --hw while this value
# stays fixed at 768 — confirm before evaluating with a different --hw.
_EVAL_RESOLUTION = 768


def _normalize_alpha(alpha: np.ndarray) -> np.ndarray:
    """Return *alpha* as a single-channel float32 map clipped to [0, 1]."""
    if alpha.ndim == 3:
        alpha = alpha[:, :, 0]
    if alpha.dtype != np.float32:
        # Anything that is not float32 is rescaled by 1/255.
        alpha = alpha.astype(np.float32) / 255.0
    return np.clip(alpha, 0.0, 1.0)


def _generate_visual_prompt(alpha: np.ndarray, prompt_type: str, dtype):
    """Build the 3-channel visual-prompt tensor (scaled to [-1, 1]) and its coords.

    Returns ``(prompt_tensor, coords)``; ``prompt_tensor`` is ``None`` when the
    generator produced no prompt.
    """
    if prompt_type == "trimap":
        vp, vp_coords = gen_trimap((alpha * 255).astype(np.uint8))
        if vp is not None:
            vp = (vp / 255.0).astype(np.float32)
    elif prompt_type == "mask":
        vp, vp_coords = gen_mask(alpha)
    elif prompt_type == "bbox":
        vp, vp_coords = gen_bbox(alpha, 0.01)
    elif prompt_type == "points":
        vp, vp_coords = gen_points(alpha)
    else:
        raise ValueError(f"Unsupported matting_prompt {prompt_type}")

    if vp is None:
        return None, vp_coords

    if isinstance(vp, np.ndarray):
        if vp.shape != (_EVAL_RESOLUTION, _EVAL_RESOLUTION):
            import cv2  # local import: only needed when a resize is required

            vp = cv2.resize(vp, (_EVAL_RESOLUTION, _EVAL_RESOLUTION), interpolation=cv2.INTER_LINEAR)
        vp_tensor = torch.from_numpy(vp).unsqueeze(0).to(dtype)
    else:
        vp_tensor = vp
    if vp_tensor.dim() == 2:
        vp_tensor = vp_tensor.unsqueeze(0)
    return (vp_tensor * 2 - 1).repeat(3, 1, 1), vp_coords


def _load_matting_batch(image_paths, alpha_paths, transform, prompt_type, failures):
    """Load one matting batch sample by sample so every output list stays aligned.

    A sample is dropped (and appended to *failures*) when its image, its alpha
    or its visual prompt cannot be produced. The alpha is only read when a
    prompt type is requested.

    Returns ``(images, kept_paths, visual_prompts, coords)``. ``visual_prompts``
    is empty when *prompt_type* is ``None``; ``coords`` always has one entry
    per kept sample.
    """
    images, kept_paths, visual_prompts, coords = [], [], [], []
    for image_path, alpha_path in zip(image_paths, alpha_paths):
        loaded, _ = load_batch_images([image_path], transform)
        if not loaded:
            failures.append((image_path, "Failed to load image"))
            continue
        img_tensor = loaded[0]

        vp_tensor, vp_coords = None, None
        if prompt_type is not None:
            alpha_loaded, _ = load_batch_images([alpha_path])
            if not alpha_loaded:
                failures.append((image_path, "Failed to load alpha"))
                continue
            vp_tensor, vp_coords = _generate_visual_prompt(
                _normalize_alpha(alpha_loaded[0]), prompt_type, img_tensor.dtype
            )
            if vp_tensor is None:
                failures.append((image_path, "Failed to generate visual prompt"))
                continue
            visual_prompts.append(vp_tensor)

        images.append(img_tensor)
        kept_paths.append(image_path)
        if vp_coords is not None:
            coords.append(vp_coords.astype(np.float32))
        else:
            # No coordinates available: fall back to the default [0, 0, 1, 1].
            coords.append(np.array([0, 0, 1, 1], dtype=np.float32))
    return images, kept_paths, visual_prompts, coords


def _run_pipeline(pipe, images, task, inference_kwargs, visual_prompts=None, coords=None, device=None):
    """Call *pipe* on one batch and return its per-sample outputs.

    A single sample is passed un-stacked with a scalar prompt; several samples
    are stacked and given a list of prompts. When *device* is given, stacked
    inputs and coordinates are moved to it.
    """
    text = f"Transform to {task} map while maintaining original composition"
    single = len(images) == 1

    if single:
        image_input = images[0]
        prompt_input = visual_prompts[0] if visual_prompts else None
    else:
        image_input = torch.stack(images)
        prompt_input = torch.stack(visual_prompts) if visual_prompts else None
        if device is not None:
            image_input = image_input.to(device)
            if prompt_input is not None:
                prompt_input = prompt_input.to(device)
    kontext_images = image_input if prompt_input is None else [image_input, prompt_input]

    pipe_kwargs = dict(
        prompt=text if single else [text] * len(images),
        kontext_images=kontext_images,
        height=_EVAL_RESOLUTION, width=_EVAL_RESOLUTION,
        embedded_guidance=inference_kwargs.get("embedded_guidance", 4),
        num_inference_steps=inference_kwargs.get("num_inference_steps", 4),
        seed=inference_kwargs.get("seed", 42),
        output_type="np",
        rand_device=inference_kwargs.get("rand_device", "cuda"),
        task=task,
    )

    if coords:
        if single:
            vpc_tensor = torch.from_numpy(coords[0]).to(torch.float32)
            if vpc_tensor.dim() == 1:
                vpc_tensor = vpc_tensor.unsqueeze(0)
        else:
            vpc_tensor = torch.from_numpy(np.stack(coords, axis=0)).to(torch.float32)
        if device is not None:
            vpc_tensor = vpc_tensor.to(device)
        pipe_kwargs["visual_prompt_coords"] = vpc_tensor

    outputs = pipe(**pipe_kwargs)
    # Only the single-sample call is wrapped; batched output is iterated as returned.
    if single and not isinstance(outputs, list):
        outputs = [outputs]
    return outputs


def _save_predictions(files, outputs, gt_root: Path, out_base: Path) -> int:
    """Save each output as ``.npy`` under *out_base*, mirroring the path below *gt_root*."""
    saved = 0
    for file, out_np in zip(files, outputs):
        save_to = (out_base / Path(file).relative_to(gt_root)).with_suffix(".npy")
        save_to.parent.mkdir(parents=True, exist_ok=True)
        np.save(str(save_to), out_np)
        saved += 1
    return saved


def _run_evaluation(ds_cfg, args) -> None:
    """Build the evaluator namespace for ``args.task`` and run the matching evaluator."""
    common = dict(
        pred_path=ds_cfg.output_dir,
        gt_path=ds_cfg.gt_path,
        dataset=ds_cfg.dataset_arg,
    )
    if args.task == "depth":
        is_kitti = ds_cfg.dataset_arg == "kitti"
        max_depth_eval = {"scannet": 10.0, "nyu": 10.0, "kitti": 80.0, "eth3d": 99999, "diode": 80.0}.get(ds_cfg.dataset_arg, 80.0)
        # Depth-encoding flags are inferred from substrings of the checkpoint path.
        eval_args = argparse.Namespace(
            **common,
            eigen_crop=is_kitti,
            garg_crop=False,
            min_depth_eval=1e-3,
            max_depth_eval=max_depth_eval,
            do_kb_crop=is_kitti,
            no_verbose=False,
            using_log=("log" in args.state_dict),
            using_disp=("disp" in args.state_dict or "inverse" in args.state_dict),
            using_sqrt=("sqrt" in args.state_dict),
            using_sqrt_disp=("sqrt_disp" in args.state_dict),
            using_pdf=("pdf" in args.state_dict),
        )
        evaluator = eval_depth
    elif args.task == "normal":
        eval_args = argparse.Namespace(**common)
        evaluator = eval_normal
    elif args.task == "matting":
        eval_args = argparse.Namespace(**common)
        evaluator = eval_matting
    else:
        return
    print(eval_args)
    try:
        evaluator(eval_args)
    except Exception as e:
        # Best-effort: a failing evaluator must not stop the remaining datasets.
        print(f"Evaluation failed for {ds_cfg.name}: {e}")


def evaluate_dataset(pipe,
                     transform,
                     ds_cfg: DatasetConfig,
                     batch_size: int = 4,
                     max_samples: Optional[int] = None,
                     inference_kwargs: Optional[dict] = None,
                     cur_step: int = 3000,
                     args=None,):
    """Generate predictions for one dataset, save them as ``.npy`` and evaluate.

    Args:
        pipe: Callable pipeline; invoked with ``prompt``, ``kontext_images``,
            ``height``, ``width``, ``embedded_guidance``,
            ``num_inference_steps``, ``seed``, ``output_type``,
            ``rand_device``, ``task`` and, optionally, ``visual_prompt_coords``.
        transform: Callable mapping an image path to the model input.
        ds_cfg: Dataset description (file list, GT root, output dir, name).
        batch_size: Number of samples per pipeline call.
        max_samples: When set, only the first N samples are processed.
        inference_kwargs: Optional overrides for ``embedded_guidance``,
            ``num_inference_steps``, ``seed`` and ``rand_device``.
        cur_step: Currently unused; kept for interface compatibility.
        args: Parsed CLI namespace. ``task`` and ``state_dict`` are required;
            ``matting_prompt`` and ``use_coor_input`` are read for matting;
            ``device`` is optional (defaults to ``"cuda"``).

    Raises:
        ValueError: If *args* is missing or ``args.task`` is not one of
            ``depth`` / ``normal`` / ``matting``.
    """
    if args is None:
        raise ValueError("evaluate_dataset requires `args` (at least `task` and `state_dict`)")
    inference_kwargs = {} if inference_kwargs is None else inference_kwargs

    print(f"\n=== Evaluate dataset: {ds_cfg.name} (batch_size={batch_size}) ===")
    is_matting = args.task == "matting"
    alphas = None
    if args.task in ("depth", "normal"):
        files = read_file_list(ds_cfg.file_list, ds_cfg.gt_path)
    elif is_matting:
        files, extras = read_file_list(ds_cfg.file_list, ds_cfg.gt_path, extra_cols=2)
        # The trimap column is read but not used: prompts are generated from alpha.
        _trimaps, alphas = extras
    else:
        raise ValueError(f"Unsupported task {args.task}")

    if max_samples is not None:
        files = files[:max_samples]
        if alphas is not None:
            # Keep alphas index-aligned with the truncated file list.
            alphas = alphas[:max_samples]
    print(f"Total {len(files)} files for eval (dataset={ds_cfg.name})")

    out_base = Path(ds_cfg.output_dir)
    out_base.mkdir(parents=True, exist_ok=True)
    gt_root = Path(ds_cfg.gt_path)

    batches = create_batches(files, batch_size)
    failures = []
    total_processed = 0

    if is_matting:
        # Matting: optional visual prompt (trimap/mask/bbox/points) + optional coords.
        device = getattr(args, "device", "cuda")
        alpha_batches = create_batches(alphas, batch_size)
        print(f" using alphas: {len(alphas)}")
        for batch_files, alpha_files in tqdm(zip(batches, alpha_batches), desc=f"generating {ds_cfg.name}", total=len(batches), disable=True):
            images, valid_files, visual_prompts, coords = _load_matting_batch(
                batch_files, alpha_files, transform, args.matting_prompt, failures
            )
            if not images:
                continue
            outputs = _run_pipeline(
                pipe, images, args.task, inference_kwargs,
                visual_prompts=visual_prompts,
                coords=coords if args.use_coor_input else None,
                device=device,
            )
            total_processed += _save_predictions(valid_files, outputs, gt_root, out_base)
    else:
        for batch_files in tqdm(batches, desc=f"generating {ds_cfg.name}", disable=True):
            images, valid_files = load_batch_images(batch_files, transform)
            loaded = set(valid_files)
            failures.extend((f, "Failed to load image") for f in batch_files if f not in loaded)
            if not images:
                continue
            # NOTE: assumes the pipeline accepts batched input.
            outputs = _run_pipeline(pipe, images, args.task, inference_kwargs)
            total_processed += _save_predictions(valid_files, outputs, gt_root, out_base)

    print(f"Generation finished for {ds_cfg.name}. Processed: {total_processed}, Failures: {len(failures)}")
    if failures:
        # Show only a few examples to keep the log short.
        print("Some failures (showing up to 5):")
        for f, err in failures[:5]:
            print(f" - {f}: {err}")

    _run_evaluation(ds_cfg, args)
def main():
    """Parse CLI options, load the pipeline once and evaluate each selected dataset.

    ``--task`` selects the dataset table (DATASETS_DEPTH / DATASETS_NORMAL /
    DATASETS_MATTING); ``--datasets`` optionally restricts it by name. A
    literal ``@@`` in ``--state_dict`` is replaced by ``--cur_step``.

    Raises:
        ValueError: If ``--datasets`` matches no dataset of the selected task.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cur_step", type=int, default=5700, help="当前评估的训练步数,用于结果目录命名")
    parser.add_argument("--model_root", type=str, default="./FLUX.1-Kontext-dev", help="Flux model root directory")
    # Other checkpoint templates used previously (pass them via --state_dict):
    #   models/train/kontext/bs64_mask/step-@@.safetensors
    #   models/train/kontext/bs64_log_cons/step-@@.safetensors
    #   models/train/kontext/bs64_sqrt_cons/step-@@.safetensors
    #   models/train/kontext/bs64_sqrt_deter_zero/step-@@.safetensors
    #   models/train/kontext_normal/bs16_cons/step-@@.safetensors
    #   models/train/kontext_normal/bs16_deter_zero/step-@@.safetensors
    #   models/train/kontext_matting/bs16_cons_mixSDMatte_points/step-@@.safetensors
    #   models/train/kontext_matting/bs16_cons_mixSDMatte_empty/step-@@.safetensors
    parser.add_argument("--state_dict", type=str, default="models/train/kontext_normal/bs16_flux_cons/step-@@.safetensors", help="训练好的 state_dict path to load")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--datasets", type=str, default=None, help="逗号分隔的数据集名(使用 DATASETS 中定义的 name 字段)")
    parser.add_argument("--max_samples", type=int, default=None, help="便捷:只跑每个数据集的前 N 张做快速测试")
    parser.add_argument("--batch_size", type=int, default=1, help="批量推理的批次大小")
    parser.add_argument("--embedded_guidance", type=float, default=1, help="推理时的 embedded_guidance 强度")
    parser.add_argument("--task", type=str, default="depth", choices=["depth", "normal","matting"], help="评估任务类型,决定用哪个 DATASETS 列表")
    parser.add_argument("--matting_prompt", type=str, default=None, choices=["trimap", "mask", "bbox", "points"], help="(matting) 视觉提示类型")
    parser.add_argument("--use_coor_input", action="store_true", help="是否传入 visual_prompt_coords (bbox/points 等坐标 embedding)")
    parser.add_argument("--hw", type=str, default="768x768")
    parser.add_argument("--inference_steps", type=int, default=1, help="推理时的采样步数")
    args = parser.parse_args()

    # Fail fast on a malformed resolution, before the model is loaded.
    try:
        h, w = (int(v) for v in args.hw.lower().split("x"))
    except ValueError:
        parser.error(f"--hw must look like HEIGHTxWIDTH (e.g. 768x768), got {args.hw!r}")

    # Fill the step placeholder of the checkpoint template; paths without
    # "@@" are left untouched.
    if "@@" in args.state_dict:
        args.state_dict = args.state_dict.replace("@@", str(args.cur_step))
        print(f"Resolved state_dict path: {args.state_dict}")

    # Pick the dataset table for the task, then optionally filter it by name.
    task_datasets = {
        "depth": DATASETS_DEPTH,
        "normal": DATASETS_NORMAL,
        "matting": DATASETS_MATTING,
    }[args.task]
    if args.datasets:
        wanted = {n.strip() for n in args.datasets.split(",") if n.strip()}
        datasets = [d for d in task_datasets if d.name in wanted]
        print(f"Selected datasets: {[d.name for d in datasets]}")
        if not datasets:
            raise ValueError(f"No matching dataset in DATASETS for names: {wanted}")
        unknown = sorted(wanted - {d.name for d in task_datasets})
        if unknown:
            print(f"Warning: unknown dataset names ignored for task '{args.task}': {unknown}")
    else:
        datasets = task_datasets
        print(f"No --datasets specified, will evaluate all: {[d.name for d in datasets]}")

    # Load the pipeline once and reuse it for every dataset.
    torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.torch_dtype]
    print("Loading FluxImagePipeline ...")
    pipe = FluxImagePipeline.from_pretrained(
        torch_dtype=torch_dtype,
        device=args.device,
        model_configs=parse_flux_model_configs(args.model_root)
    )
    state_dict = load_state_dict(args.state_dict)
    pipe.dit.load_state_dict(state_dict)

    # The image transform is shared by all datasets.
    print(f"Eval with {h}x{w}")
    transform = UnifiedDataset.default_image_operator(height=h, width=w)
    inference_kwargs = dict(
        num_inference_steps=args.inference_steps,
        seed=42,
        # Strip a device index ("cuda:1" -> "cuda").
        rand_device=args.device.split(":")[0],
        embedded_guidance=args.embedded_guidance
    )
    for ds in datasets:
        evaluate_dataset(
            pipe,
            transform,
            ds,
            batch_size=args.batch_size,
            max_samples=args.max_samples,
            inference_kwargs=inference_kwargs,
            cur_step=args.cur_step,
            args=args,
        )
    print("\nAll done.")


if __name__ == "__main__":
    main()