Instructions to use zeyuren2002/EvalMDE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use zeyuren2002/EvalMDE with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for Apple devices
pipe = DiffusionPipeline.from_pretrained("zeyuren2002/EvalMDE", dtype=torch.bfloat16, device_map="cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """ | |
| 简化版:一次评估多个数据集,支持批量推理 | |
| 用法示例: | |
| python eval_multiple_datasets.py \ | |
| --model-root ./FLUX.1-Kontext-dev \ | |
| --state-dict models/train/kontext/bs64_mask/step-3200.safetensors \ | |
| --datasets scannet,nyuv2 \ | |
| --max-samples 800 \ | |
| --batch-size 4 | |
| 或者不传 --datasets 则评估 DATASETS 中列出的所有数据集(按顺序)。 | |
| """ | |
import argparse
import os

# Must be set before any tokenizers-based library (pulled in by the project
# modules below) is imported.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from models.unified_dataset import UnifiedDataset, gen_mask, gen_bbox, gen_points, gen_trimap
from models.utils import load_state_dict, parse_flux_model_configs
from pipelines.flux_image_new import FluxImagePipeline, ModelConfig
from utils.eval_depth import test as eval_depth
from utils.eval_matting import test as eval_matting
from utils.eval_normal import test as eval_normal
@dataclass
class DatasetConfig:
    """Describe one evaluation dataset: its split file, ground truth and output location.

    Instances are created with keyword arguments (see the DATASETS_* tables), which
    requires the ``@dataclass`` decorator to generate ``__init__``.
    """

    # Short identifier, matched against the comma-separated --datasets option.
    name: str
    # Split file; each line is usually "rel_path [other cols]" or a single path column.
    file_list: str
    # Ground-truth root directory; relative paths in file_list are resolved against it
    # and it is passed to the evaluator as gt_path.
    gt_path: str
    # Directory where predicted .npy files are written (the evaluator's pred_path).
    output_dir: str
    # Dataset name passed to the evaluator (e.g. "scannet", "nyu", "kitti", "eth3d").
    dataset_arg: str
    # Extend with more evaluator parameters as needed.
# ====== Add or edit the datasets to evaluate here ======
# Depth benchmarks. Each row is (name, file_list, gt_path, output_dir, dataset_arg),
# mapped field-for-field onto DatasetConfig.
DATASETS_DEPTH: List[DatasetConfig] = [
    DatasetConfig(
        name=name,
        file_list=file_list,
        gt_path=gt_path,
        output_dir=output_dir,
        dataset_arg=dataset_arg,
    )
    for name, file_list, gt_path, output_dir, dataset_arg in (
        ("nyuv2",
         "./data_split/nyu_depth/labeled/filename_list_test.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/depth/nyuv2/",
         "result/nyuv2/",
         "nyu"),
        ("kitti",
         "./data_split/kitti_depth/eigen_test_files_with_gt.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/depth/kitti/",
         "result/kitti/",
         "kitti"),
        ("eth3d",
         "./data_split/eth3d_depth/eth3d_filename_list.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/depth/eth3d/",
         "result/eth3d/",
         "eth3d"),
        # NOTE(review): unlike the other rows, this gt_path has no trailing slash.
        ("scannet",
         "./data_split/scannet_depth/scannet_val_sampled_list_800_1.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/depth/scannet",
         "result/scannet/",
         "scannet"),
        ("diode",
         "./data_split/diode_depth/diode_val_all_filename_list.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/depth/diode/",
         "result/diode/",
         "diode"),
    )
]
# Surface-normal benchmarks. Each row is (name, file_list, gt_path, output_dir,
# dataset_arg), mapped field-for-field onto DatasetConfig.
DATASETS_NORMAL: List[DatasetConfig] = [
    DatasetConfig(
        name=name,
        file_list=file_list,
        gt_path=gt_path,
        output_dir=output_dir,
        dataset_arg=dataset_arg,
    )
    for name, file_list, gt_path, output_dir, dataset_arg in (
        ("nyuv2",
         "./data_split/nyu_normals/nyuv2_test2.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/normal/nyuv2/",
         "result/nyuv2_normal/",
         "nyu"),
        ("scannet",
         "./data_split/scannet_normals/scannet_test.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/normal/scannet/",
         "result/scannet_normal/",
         "scannet"),
        ("ibims",
         "./data_split/ibims_normals/ibims_test2.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/normal/ibims/",
         "result/ibims_normal/",
         "ibims"),
        ("diode",
         "./data_split/diode_normals/diode_test.txt",
         "/mnt/nfs/workspace/syq/dataset/Eval/normal/diode/",
         "result/diode_normal/",
         "diode"),
        # Disabled:
        # ("oasis",
        #  "./data_split/oasis_normals/oasis_test2.txt",
        #  "/mnt/nfs/workspace/syq/dataset/Eval/normal/oasis/",
        #  "result/oasis_normal/",
        #  "oasis"),
    )
]
# Alpha-matting benchmarks. Each row is (name, file_list, gt_path, output_dir,
# dataset_arg), mapped field-for-field onto DatasetConfig.
# NOTE(review): "p3m-np" writes to result/p3m_matting/, the same directory as the
# disabled "p3m" row; re-enabling both would mix their predictions.
DATASETS_MATTING: List[DatasetConfig] = [
    DatasetConfig(
        name=name,
        file_list=file_list,
        gt_path=gt_path,
        output_dir=output_dir,
        dataset_arg=dataset_arg,
    )
    for name, file_list, gt_path, output_dir, dataset_arg in (
        # Disabled:
        # ("comp",
        #  "./data_split/comp_matting/filenames_test.txt",
        #  "/mnt/nfs/workspace/syq/dataset/matting/composition-1k",
        #  "result/comp_matting/",
        #  "comp"),
        # ("p3m",
        #  "./data_split/P3M_matting/filenames_val_P.txt",
        #  "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k",
        #  "result/p3m_matting/",
        #  "p3m"),
        ("aim",
         "./data_split/AIM_matting/filenames_val.txt",
         "/mnt/nfs/workspace/syq/dataset/matting/AIM-500",
         "result/aim_matting/",
         "aim"),
        ("p3m-np",
         "./data_split/P3M_matting/filenames_val_NP.txt",
         "/mnt/nfs/workspace/syq/dataset/matting/P3M-10k",
         "result/p3m_matting/",
         "p3m-np"),
        ("am",
         "./data_split/AM_matting/filenames_val.txt",
         "/mnt/nfs/workspace/syq/dataset/matting/AM-2k",
         "result/am_matting/",
         "am"),
    )
]
def read_file_list(
    file_list_path: str, base_dir: str, extra_cols: int = 0
) -> Union[List[str], Tuple[List[str], List[List[str]]]]:
    """Read a split file and return paths joined onto ``base_dir``.

    Each non-empty line is split on whitespace; its first column is a path
    relative to ``base_dir``. A line whose second column is the literal string
    "None" is skipped entirely (behavior kept from the original pipeline).

    Args:
        file_list_path: Text file to read (decoded as UTF-8).
        base_dir: Root directory joined with every relative path.
        extra_cols: Number of additional path columns (after the first) to collect.

    Returns:
        If ``extra_cols == 0``: the list of joined first-column paths.
        Otherwise ``(files, extras)``, where ``extras[i]`` holds column ``i + 1``
        of every kept line joined onto ``base_dir``. A missing column or a literal
        "None" is stored as the string "None".
    """
    base = Path(base_dir)
    files: List[str] = []
    extras: List[List[str]] = [[] for _ in range(extra_cols)]
    with open(file_list_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            # split() with no argument also strips, and yields [] for blank lines
            parts = raw_line.split()
            if not parts:
                continue
            if len(parts) > 1 and parts[1] == "None":
                continue
            files.append(str(base / parts[0]))
            for col_idx, column in enumerate(extras, start=1):
                rel = parts[col_idx] if len(parts) > col_idx else "None"
                column.append(str(base / rel) if rel != "None" else "None")
    if extra_cols > 0:
        return files, extras
    return files
def create_batches(files: List[str], batch_size: int) -> List[List[str]]:
    """Split ``files`` into consecutive chunks of at most ``batch_size`` items.

    The last chunk may be shorter; an empty input yields an empty list.
    """
    chunk_starts = range(0, len(files), batch_size)
    return [files[start:start + batch_size] for start in chunk_starts]
def load_batch_images(file_batch: List[str], transform=None) -> Tuple[List, List[str]]:
    """Load every file in ``file_batch``, skipping (and warning about) failures.

    Args:
        file_batch: Paths to load.
        transform: Optional callable mapping a path to a loaded image. When
            ``None``, the file is read with PIL and converted with ``np.array``.

    Returns:
        ``(images, valid_files)``: the successfully loaded images and the paths
        they came from, in input order. Failed files are omitted from both lists,
        so the two stay aligned with each other but not necessarily with
        ``file_batch``.
    """

    def _read_with_pil(path):
        # The context manager closes the underlying file handle deterministically.
        with Image.open(path) as im:
            return np.array(im)

    loader = transform if transform is not None else _read_with_pil
    images = []
    valid_files = []
    for file in file_batch:
        try:
            img = loader(file)
        except Exception as e:  # best-effort: one unreadable file must not abort the batch
            print(f"Warning: Failed to load {file}: {e}")
            continue
        images.append(img)
        valid_files.append(file)
    return images, valid_files
def _task_prompt(task: str) -> str:
    """Return the text instruction sent to the pipeline for ``task``."""
    return f"Transform to {task} map while maintaining original composition"


def _eval_resolution(args) -> Tuple[int, int]:
    """Return ``(height, width)`` parsed from ``args.hw`` ("HxW"); 768x768 if absent."""
    hw = getattr(args, "hw", None)
    if not hw:
        return 768, 768
    height, width = (int(v) for v in hw.lower().split("x"))
    return height, width


def _base_pipe_kwargs(task: str, inference_kwargs: dict, height: int, width: int) -> dict:
    """Return the pipeline keyword arguments shared by every call in this run."""
    return dict(
        height=height,
        width=width,
        embedded_guidance=inference_kwargs.get("embedded_guidance", 4),
        num_inference_steps=inference_kwargs.get("num_inference_steps", 4),
        seed=inference_kwargs.get("seed", 42),
        output_type="np",
        rand_device=inference_kwargs.get("rand_device", "cuda"),
        task=task,
    )


def _build_visual_prompt(alpha, matting_prompt, height, width, dtype):
    """Derive a visual prompt and its coordinates from a ground-truth alpha matte.

    Returns ``(prompt, coords)``. ``prompt`` is a 3-channel tensor rescaled with
    ``x * 2 - 1``, or ``None`` when no prompt was requested or the generator
    produced none. ``coords`` is whatever the ``gen_*`` helper returned (may be
    ``None``).
    """
    if matting_prompt is None:
        return None, None
    if alpha.ndim == 3:
        alpha = alpha[:, :, 0]  # keep only the first channel of multi-channel mattes
    if alpha.dtype != np.float32:
        # assumes 8-bit mattes -- TODO confirm for 16-bit inputs
        alpha = alpha.astype(np.float32) / 255.0
    alpha = np.clip(alpha, 0.0, 1.0)

    if matting_prompt == "trimap":
        vp, coords = gen_trimap((alpha * 255).astype(np.uint8))
        if vp is not None:
            vp = (vp / 255.0).astype(np.float32)
    elif matting_prompt == "mask":
        vp, coords = gen_mask(alpha)
    elif matting_prompt == "bbox":
        vp, coords = gen_bbox(alpha, 0.01)
    elif matting_prompt == "points":
        vp, coords = gen_points(alpha)
    else:
        raise ValueError(f"Unsupported matting_prompt {matting_prompt}")

    if vp is None:
        return None, coords
    if isinstance(vp, np.ndarray):
        if vp.shape != (height, width):
            import cv2  # local import: only needed when a prompt must be resized
            vp = cv2.resize(vp, (width, height), interpolation=cv2.INTER_LINEAR)
        prompt = torch.from_numpy(vp).unsqueeze(0).to(dtype)
    else:
        prompt = vp
    if prompt.dim() == 2:
        prompt = prompt.unsqueeze(0)
    # presumably maps [0, 1] -> [-1, 1]; replicated to 3 channels like an RGB input
    return (prompt * 2 - 1).repeat(3, 1, 1), coords


def _prepare_matting_batch(batch_files, alpha_files, transform, matting_prompt,
                           height, width, failures):
    """Load one matting batch and build its visual prompts.

    Each image is loaded together with its own alpha matte, so a failure on one
    file cannot shift the image/alpha pairing. A sample is appended to
    ``failures`` and dropped if its image or alpha cannot be loaded, or if a
    requested prompt could not be generated.

    Returns ``(files, images, prompts, coords_list)``. ``prompts`` is empty when
    ``matting_prompt`` is None; otherwise it has one entry per kept image.
    """
    files, images, prompts, coords_list = [], [], [], []
    for image_path, alpha_path in zip(batch_files, alpha_files):
        loaded_images, _ = load_batch_images([image_path], transform)
        if not loaded_images:
            failures.append((image_path, "Failed to load image"))
            continue
        loaded_alphas, _ = load_batch_images([alpha_path])
        if not loaded_alphas:
            failures.append((image_path, f"Failed to load alpha {alpha_path}"))
            continue
        image = loaded_images[0]
        prompt, coords = _build_visual_prompt(
            loaded_alphas[0], matting_prompt, height, width, image.dtype
        )
        if matting_prompt is not None and prompt is None:
            failures.append((image_path, f"Failed to build {matting_prompt} prompt"))
            continue
        files.append(image_path)
        images.append(image)
        if prompt is not None:
            prompts.append(prompt)
        # Default [0, 0, 1, 1] (presumably a normalized full-image box) when the
        # generator gives no coordinates.
        coords_list.append(
            np.asarray(coords, dtype=np.float32) if coords is not None
            else np.array([0, 0, 1, 1], dtype=np.float32)
        )
    return files, images, prompts, coords_list


def _coords_tensor(coords_list, device):
    """Stack per-sample prompt coordinates into a float32 tensor on ``device``.

    A single sample keeps the original single-image shaping: a 1-D vector
    becomes ``(1, k)``, and other shapes are passed through unchanged.
    """
    if len(coords_list) == 1:
        coords = torch.from_numpy(coords_list[0]).to(torch.float32)
        if coords.dim() == 1:
            coords = coords.unsqueeze(0)
    else:
        coords = torch.from_numpy(np.stack(coords_list, axis=0)).to(torch.float32)
    # Always move to the pipeline device; the single-sample path previously left
    # multi-dimensional coordinates on the CPU.
    return coords.to(device)


def _infer(pipe, images, visual_prompts, coords, base_kwargs, task, device):
    """Run one pipeline call over ``images`` and return the per-sample outputs.

    A single image is sent un-batched (string prompt, un-stacked tensor); larger
    batches are stacked and moved to ``device``.
    """
    kwargs = dict(base_kwargs)
    text = _task_prompt(task)
    if coords is not None:
        kwargs["visual_prompt_coords"] = coords
    if len(images) == 1:
        kwargs["prompt"] = text
        kwargs["kontext_images"] = [images[0], visual_prompts[0]] if visual_prompts else images[0]
        outputs = pipe(**kwargs)
        # the pipeline may return a bare (non-list) output for one image
        return outputs if isinstance(outputs, list) else [outputs]
    stacked = torch.stack(images).to(device)
    kwargs["prompt"] = [text] * len(images)
    if visual_prompts:
        kwargs["kontext_images"] = [stacked, torch.stack(visual_prompts).to(device)]
    else:
        kwargs["kontext_images"] = stacked
    return pipe(**kwargs)


def _save_predictions(valid_files, outputs, gt_root: Path, out_base: Path) -> int:
    """Save each output as ``out_base/<path relative to gt_root>.npy``; return the count saved."""
    saved = 0
    for file, out_np in zip(valid_files, outputs):
        save_to = (out_base / Path(file).relative_to(gt_root)).with_suffix(".npy")
        save_to.parent.mkdir(parents=True, exist_ok=True)
        np.save(str(save_to), out_np)
        saved += 1
    return saved


def _run_evaluation(ds_cfg, args):
    """Run the task-specific evaluator on ``ds_cfg.output_dir``; errors are reported, not raised."""
    if args.task == "depth":
        is_kitti = ds_cfg.dataset_arg == "kitti"
        state_dict = args.state_dict
        max_depth_eval = {
            "scannet": 10.0, "nyu": 10.0, "kitti": 80.0, "eth3d": 99999, "diode": 80.0,
        }.get(ds_cfg.dataset_arg, 80.0)
        eval_args = argparse.Namespace(
            pred_path=ds_cfg.output_dir,
            gt_path=ds_cfg.gt_path,
            dataset=ds_cfg.dataset_arg,
            eigen_crop=is_kitti,
            garg_crop=False,
            min_depth_eval=1e-3,
            max_depth_eval=max_depth_eval,
            do_kb_crop=is_kitti,
            no_verbose=False,
            # the prediction parameterization is inferred from the checkpoint path
            using_log=("log" in state_dict),
            using_disp=("disp" in state_dict or "inverse" in state_dict),
            using_sqrt=("sqrt" in state_dict),
            using_sqrt_disp=("sqrt_disp" in state_dict),
            using_pdf=("pdf" in state_dict),
        )
        evaluator = eval_depth
    elif args.task in ("normal", "matting"):
        eval_args = argparse.Namespace(
            pred_path=ds_cfg.output_dir,
            gt_path=ds_cfg.gt_path,
            dataset=ds_cfg.dataset_arg,
        )
        evaluator = eval_normal if args.task == "normal" else eval_matting
    else:
        return
    print(eval_args)
    try:
        evaluator(eval_args)
    except Exception as e:  # keep going so the remaining datasets still get evaluated
        import traceback
        print(f"Evaluation failed for {ds_cfg.name}: {e}")
        traceback.print_exc()


def evaluate_dataset(pipe,
                     transform,
                     ds_cfg: DatasetConfig,
                     batch_size: int = 4,
                     max_samples: Optional[int] = None,
                     inference_kwargs: Optional[dict] = None,
                     cur_step: int = 3000,
                     args=None,):
    """Generate predictions for one dataset with ``pipe``, then score them.

    Reads the split file and runs the pipeline in batches of ``batch_size``. Each
    prediction is saved as ``<output_dir>/<path relative to gt_path>.npy``. Finally
    the task's evaluator (depth / normal / matting) is run on ``output_dir``.

    Args:
        pipe: Loaded image pipeline, called with keyword arguments.
        transform: Callable mapping an image path to the model input tensor.
        ds_cfg: Dataset description.
        batch_size: Number of images per pipeline call.
        max_samples: If set, only the first ``max_samples`` entries are used.
        inference_kwargs: Optional overrides for embedded_guidance,
            num_inference_steps, seed and rand_device.
        cur_step: Currently unused; kept for interface compatibility.
        args: Parsed CLI namespace. ``task`` (and ``state_dict`` for depth) is
            required; ``device``, ``hw``, ``matting_prompt`` and
            ``use_coor_input`` are read when present.

    Raises:
        ValueError: If ``args`` is missing or ``args.task`` is unsupported.
    """
    if args is None:
        raise ValueError("evaluate_dataset requires the parsed CLI args (task, state_dict, ...)")
    if args.task not in ("depth", "normal", "matting"):
        raise ValueError(f"Unsupported task {args.task!r}")
    inference_kwargs = inference_kwargs or {}
    device = getattr(args, "device", "cuda")
    matting_prompt = getattr(args, "matting_prompt", None)
    use_coords = getattr(args, "use_coor_input", False)
    height, width = _eval_resolution(args)

    print(f"\n=== Evaluate dataset: {ds_cfg.name} (batch_size={batch_size}) ===")
    alphas = None
    if args.task == "matting":
        # columns: image, trimap, alpha (the trimap column is not used here)
        files, (_trimaps, alphas) = read_file_list(ds_cfg.file_list, ds_cfg.gt_path, extra_cols=2)
    else:
        files = read_file_list(ds_cfg.file_list, ds_cfg.gt_path)
    if max_samples is not None:
        files = files[:max_samples]
        if alphas is not None:
            alphas = alphas[:max_samples]
    print(f"Total {len(files)} files for eval (dataset={ds_cfg.name})")

    out_base = Path(ds_cfg.output_dir)
    out_base.mkdir(parents=True, exist_ok=True)
    gt_root = Path(ds_cfg.gt_path)
    base_kwargs = _base_pipe_kwargs(args.task, inference_kwargs, height, width)

    file_batches = create_batches(files, batch_size)
    if alphas is not None:
        alpha_batches = create_batches(alphas, batch_size)
        print(f"  using alphas: {len(alphas)}")
    else:
        alpha_batches = [None] * len(file_batches)

    failures = []
    total_processed = 0
    for batch_files, alpha_files in tqdm(zip(file_batches, alpha_batches),
                                         desc=f"generating {ds_cfg.name}",
                                         total=len(file_batches), disable=True):
        coords = None
        if alpha_files is None:
            images, valid_files = load_batch_images(batch_files, transform)
            loaded = set(valid_files)
            failures.extend((f, "Failed to load image") for f in batch_files if f not in loaded)
            prompts = []
        else:
            valid_files, images, prompts, coords_list = _prepare_matting_batch(
                batch_files, alpha_files, transform, matting_prompt, height, width, failures
            )
            if use_coords and coords_list:
                coords = _coords_tensor(coords_list, device)
        if not images:
            continue
        outputs = _infer(pipe, images, prompts, coords, base_kwargs, args.task, device)
        total_processed += _save_predictions(valid_files, outputs, gt_root, out_base)

    print(f"Generation finished for {ds_cfg.name}. Processed: {total_processed}, Failures: {len(failures)}")
    if failures:
        print("Some failures (showing up to 5):")
        for f, err in failures[:5]:
            print(f"  - {f}: {err}")

    _run_evaluation(ds_cfg, args)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Hyphenated aliases (--model-root, --state-dict, --max-samples, --batch-size)
    are accepted so that the usage example in the module docstring works as
    written. The underscore spellings remain the canonical ``dest`` names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cur_step", type=int, default=5700, help="当前评估的训练步数,用于结果目录命名")
    parser.add_argument("--model_root", "--model-root", type=str, default="./FLUX.1-Kontext-dev", help="Flux model root directory")
    # An "@@" in --state_dict is replaced with --cur_step by main().
    # Other checkpoints evaluated previously:
    #   models/train/kontext/bs64_mask/step-@@.safetensors
    #   models/train/kontext/bs64_log_cons/step-@@.safetensors
    #   models/train/kontext/bs64_sqrt_cons/step-@@.safetensors
    #   models/train/kontext/bs64_sqrt_deter_zero/step-@@.safetensors
    #   models/train/kontext_normal/bs16_cons/step-@@.safetensors
    #   models/train/kontext_normal/bs16_deter_zero/step-@@.safetensors
    #   models/train/kontext_matting/bs16_cons_mixSDMatte_points/step-@@.safetensors
    #   models/train/kontext_matting/bs16_cons_mixSDMatte_empty/step-@@.safetensors
    parser.add_argument("--state_dict", "--state-dict", type=str, default="models/train/kontext_normal/bs16_flux_cons/step-@@.safetensors", help="训练好的 state_dict path to load")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
    parser.add_argument("--datasets", type=str, default=None, help="逗号分隔的数据集名(使用 DATASETS 中定义的 name 字段)")
    parser.add_argument("--max_samples", "--max-samples", type=int, default=None, help="便捷:只跑每个数据集的前 N 张做快速测试")
    parser.add_argument("--batch_size", "--batch-size", type=int, default=1, help="批量推理的批次大小")
    parser.add_argument("--embedded_guidance", type=float, default=1, help="推理时的 embedded_guidance 强度")
    parser.add_argument("--task", type=str, default="depth", choices=["depth", "normal", "matting"], help="评估任务类型,决定用哪个 DATASETS 列表")
    parser.add_argument("--matting_prompt", type=str, default=None, choices=["trimap", "mask", "bbox", "points"], help="(matting) 视觉提示类型")
    parser.add_argument("--use_coor_input", action="store_true", help="是否传入 visual_prompt_coords (bbox/points 等坐标 embedding)")
    parser.add_argument("--hw", type=str, default="768x768")
    parser.add_argument("--inference_steps", type=int, default=1, help="推理时的采样步数")
    return parser


def main():
    """Parse CLI args, load the pipeline once, then generate and evaluate each selected dataset."""
    args = _build_arg_parser().parse_args()
    # The checkpoint path may carry an "@@" placeholder for the training step; the
    # default does, and without substitution loading it would always fail.
    args.state_dict = args.state_dict.replace("@@", str(args.cur_step))

    # Pick the dataset table for the task (choices= guarantees a valid key).
    all_datasets = {
        "depth": DATASETS_DEPTH,
        "normal": DATASETS_NORMAL,
        "matting": DATASETS_MATTING,
    }[args.task]
    if args.datasets:
        wanted = {n.strip() for n in args.datasets.split(",") if n.strip()}
        datasets = [d for d in all_datasets if d.name in wanted]
        print(f"Selected datasets: {[d.name for d in datasets]}")
        if not datasets:
            raise ValueError(f"No matching dataset in DATASETS for names: {wanted}")
        unknown = wanted - {d.name for d in datasets}
        if unknown:
            print(f"Warning: ignoring unknown dataset names: {sorted(unknown)}")
    else:
        datasets = all_datasets
        print(f"No --datasets specified, will evaluate all: {[d.name for d in datasets]}")

    # Load the pipeline once and reuse it for every dataset.
    torch_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.torch_dtype]
    print("Loading FluxImagePipeline ...")
    pipe = FluxImagePipeline.from_pretrained(
        torch_dtype=torch_dtype,
        device=args.device,
        model_configs=parse_flux_model_configs(args.model_root)
    )
    state_dict = load_state_dict(args.state_dict)
    pipe.dit.load_state_dict(state_dict)

    # The input transform is shared across datasets.
    h, w = map(int, args.hw.split("x"))
    print(f"Eval with {h}x{w}")
    transform = UnifiedDataset.default_image_operator(height=h, width=w)
    inference_kwargs = dict(
        num_inference_steps=args.inference_steps,
        seed=42,
        # strip any device index, e.g. "cuda:1" -> "cuda"
        rand_device=args.device.split(":")[0],
        embedded_guidance=args.embedded_guidance
    )
    for ds in datasets:
        evaluate_dataset(
            pipe,
            transform,
            ds,
            batch_size=args.batch_size,
            max_samples=args.max_samples,
            inference_kwargs=inference_kwargs,
            cur_step=args.cur_step,
            args=args,
        )
    print("\nAll done.")


if __name__ == "__main__":
    main()