# Diffusers · Safetensors
# EvalMDE / baselines / fe2e.py
# Uploaded by zeyuren2002: "Add files using upload-large-folder tool" (commit 4165f20, verified).
# Reference: https://github.com/AMAP-ML/FE2E
# Strictly follows the official `infer/inference.py::ImageGenerator` API.
# Distributed entrypoint `evaluation.py` is bypassed; we use the same `ImageGenerator`
# class on a single GPU. README usage (for depth):
# python -u evaluation.py \
# --model_path ./pretrain --lora ./lora/LDRN.safetensors --single_denoise \
# --prompt_type empty --norm_type ln --task_name depth ...
#
# Important calling convention (from inference.py):
# ImageGenerator.__init__ requires an `args` namespace with at least:
# args.prompt_type (='empty' to skip Qwen),
# args.single_denoise (sets num_steps = 1 via FE2E's parse_args),
# args.empty_prompt_cache (path to latent/no_info.npz),
# `generate_image(prompt, negative_prompt, ref_images=PIL_or_tensor, num_steps,
# cfg_guidance, seed, ..., args=args)` returns (images_list, Lpred, Rpred).
# - Lpred: float tensor in [0, 1] (the `* 0.5 + 0.5` rescale is already applied); corresponds
#   to the *edited* frame (= depth output for FE2E's depth LoRA).
# - Rpred: float tensor in [-1, 1] — the reconstructed reference RGB.
#
# Output key: `depth_affine_invariant` (Lpred mean over channels).
# NEEDS_VERIFICATION: the Lpred-vs-Rpred assignment is inferred from generate_image semantics
# (Lpred = left/first frame = denoised target = depth) and is consistent with inner_evaluation.py,
# which uses `Lpred` for depth metrics. Pass `--use_rpred` if a sanity run shows the
# depth is actually carried by Rpred.
import os
import sys
from typing import *
from pathlib import Path
from types import SimpleNamespace
import click
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from moge.test.baseline import MGEBaselineInterface
class Baseline(MGEBaselineInterface):
    """MoGe evaluation baseline wrapping FE2E's official ``ImageGenerator``.

    The FE2E repository is imported from ``repo_path`` at construction time and
    run on a single device; its predicted frame is returned as
    ``depth_affine_invariant``.
    """

    def __init__(self, repo_path: str, model_path: str, lora_path: str,
                 qwen2vl_path: Optional[str], empty_prompt_cache: Optional[str],
                 num_steps: int, cfg_guidance: float, size_level: int,
                 prompt_type: str, single_denoise: bool, seed: int,
                 quantized: bool, offload: bool, use_rpred: bool,
                 device: Union[torch.device, str]):
        """Import FE2E, validate its artifacts and build the ``ImageGenerator``.

        Relative ``model_path``, ``lora_path``, ``qwen2vl_path`` and
        ``empty_prompt_cache`` are resolved against ``repo_path``.

        Raises:
            FileNotFoundError: if the FE2E repo, the VAE, the DiT checkpoint,
                the LoRA or the empty-prompt cache cannot be found.
        """
        repo_path = os.path.abspath(repo_path)
        if not Path(repo_path).exists():
            raise FileNotFoundError(
                f"Cannot find FE2E repo at {repo_path}. Clone https://github.com/AMAP-ML/FE2E."
            )
        # FE2E is used as a source checkout, so make its `infer` package importable.
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
        from infer.inference import ImageGenerator
        from infer.seed_all import seed_all
        seed_all(seed)

        def _resolve(p):
            # Relative paths are interpreted relative to the FE2E repo root.
            return p if os.path.isabs(p) else os.path.join(repo_path, p)

        model_path = _resolve(model_path)
        lora_path = _resolve(lora_path)
        if qwen2vl_path is not None:
            qwen2vl_path = _resolve(qwen2vl_path)
        else:
            qwen2vl_path = os.path.join(repo_path, "Qwen")  # FE2E DEFAULT_QWEN_DIR
        if empty_prompt_cache is not None:
            empty_prompt_cache = _resolve(empty_prompt_cache)
        else:
            empty_prompt_cache = os.path.join(repo_path, "latent", "no_info.npz")

        # ImageGenerator reads several attrs off args (prompt_type, single_denoise, empty_prompt_cache).
        ig_args = SimpleNamespace(
            prompt_type=prompt_type,
            single_denoise=single_denoise,
            empty_prompt_cache=empty_prompt_cache,
        )
        device = torch.device(device)

        ae_path = os.path.join(model_path, "vae.safetensors")
        dit_basename = "step1x-edit-i1258-FP8.safetensors" if quantized else "step1x-edit-i1258.safetensors"
        dit_path = os.path.join(model_path, dit_basename)
        # Fail fast with a clear message instead of a deep loader error.
        for p in (ae_path, dit_path, lora_path, empty_prompt_cache):
            if not os.path.exists(p):
                raise FileNotFoundError(f"Missing required FE2E artifact: {p}")

        self.image_gen = ImageGenerator(
            ae_path=ae_path,
            dit_path=dit_path,
            qwen2vl_model_path=qwen2vl_path,
            max_length=640,
            quantized=quantized,
            offload=offload,
            lora=lora_path,
            device=str(device),
            args=ig_args,
        )
        self.device = device
        # Single-step denoising overrides the requested step count.
        self.num_steps = 1 if single_denoise else num_steps
        self.cfg_guidance = cfg_guidance
        self.size_level = size_level
        self.seed = seed
        self.ig_args = ig_args
        self.use_rpred = use_rpred

    @click.command()
    @click.option('--repo', 'repo_path', type=click.Path(), default='../FE2E',
                  help='Path to the AMAP-ML/FE2E repository.')
    @click.option('--model_path', type=click.Path(), default='pretrain',
                  help='Pretrain dir holding vae.safetensors + step1x-edit-i1258*.safetensors '
                       '(relative to --repo if not absolute).')
    @click.option('--lora_path', type=click.Path(), default='lora/LDRN.safetensors',
                  help='FE2E LoRA checkpoint (relative to --repo if not absolute).')
    @click.option('--qwen2vl_path', type=click.Path(), default=None,
                  help='Qwen2.5-VL dir (only required when prompt_type != empty).')
    @click.option('--empty_prompt_cache', type=click.Path(), default=None,
                  help='Path to latent/no_info.npz; defaults to <repo>/latent/no_info.npz.')
    @click.option('--num_steps', type=int, default=28,
                  help='Diffusion steps; ignored if --single_denoise is set (becomes 1).')
    @click.option('--cfg_guidance', type=float, default=6.0,
                  help='CFG guidance strength (FE2E default 6.0).')
    @click.option('--size_level', type=int, default=768,
                  help='Inference resolution hint (passed through to generate_image).')
    @click.option('--prompt_type', type=str, default='empty',
                  help='FE2E flag; "empty" skips Qwen loading and uses cached empty-prompt latent.')
    # Native on/off switch: a separate `is_flag=True, default=True` option would
    # get flag_value=False (so passing the flag disabled it), and a second option
    # sharing the same destination could overwrite the default.
    @click.option('--single_denoise/--no_single_denoise', default=True,
                  help='Use single-step denoising (README recommended for depth eval); '
                       '--no_single_denoise switches to multi-step.')
    @click.option('--seed', type=int, default=1234)
    @click.option('--quantized', is_flag=True,
                  help='Use FP8 DiT (step1x-edit-i1258-FP8.safetensors).')
    @click.option('--offload', is_flag=True, help='CPU offload to save VRAM.')
    @click.option('--use_rpred', is_flag=True,
                  help='[Sanity-check] Use Rpred instead of Lpred as the depth output.')
    @click.option('--device', type=str, default='cuda')
    @staticmethod
    def load(repo_path: str, model_path: str, lora_path: str,
             qwen2vl_path: Optional[str], empty_prompt_cache: Optional[str],
             num_steps: int, cfg_guidance: float, size_level: int,
             prompt_type: str, single_denoise: bool, seed: int,
             quantized: bool, offload: bool, use_rpred: bool,
             device: str = 'cuda'):
        """Build an FE2E ``Baseline`` from command-line options."""
        return Baseline(repo_path, model_path, lora_path, qwen2vl_path,
                        empty_prompt_cache, num_steps, cfg_guidance, size_level,
                        prompt_type, single_denoise, seed, quantized, offload,
                        use_rpred, device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Predict affine-invariant depth for a single RGB image.

        Args:
            image: RGB tensor of shape (3, H, W) or (1, 3, H, W); values are
                clamped to [0, 1] before conversion to 8-bit.
            intrinsics: Unused; accepted for interface compatibility.

        Returns:
            ``{'depth_affine_invariant': depth}`` where depth is (H, W) for an
            unbatched input and (1, H, W) for a batched one.
        """
        omit_batch = image.ndim == 3
        if omit_batch:
            image = image.unsqueeze(0)
        assert image.shape[0] == 1, "FE2E baseline only supports batch size 1"
        _, _, H, W = image.shape

        # generate_image accepts PIL or torch.Tensor for ref_images.
        # `.float()` first: numpy cannot represent bfloat16 tensors.
        arr = (image[0].float().cpu().permute(1, 2, 0).clamp(0, 1).numpy() * 255).astype(np.uint8)
        pil = Image.fromarray(arr)

        images_list, Lpred, Rpred = self.image_gen.generate_image(
            prompt='',
            negative_prompt='',
            ref_images=pil,
            num_samples=1,
            num_steps=self.num_steps,
            cfg_guidance=self.cfg_guidance,
            seed=self.seed,
            show_progress=False,
            size_level=self.size_level,
            args=self.ig_args,
        )

        # Lpred: [1, 3, h', w'] in [0, 1]; Rpred: [1, 3, h', w'] in [-1, 1].
        if self.use_rpred:
            pred = Rpred.clamp(-1, 1).mul(0.5).add(0.5)
        else:
            pred = Lpred  # already in [0, 1]

        # Mean over the channel dim to get scalar depth (same convention as
        # Marigold / Lotus / DepthMaster). The eval pipeline aligns affine afterwards.
        depth = pred[0].mean(dim=0).to(self.device).float()
        # generate_image may run at a different resolution; map back to the input size.
        if depth.shape != (H, W):
            depth = F.interpolate(depth[None, None], size=(H, W),
                                  mode='bilinear', align_corners=False)[0, 0]

        # FE2E predicts affine-invariant depth via Step1X-Edit + LDRN LoRA.
        # Emit only this physical key.
        if not omit_batch:
            depth = depth.unsqueeze(0)
        return {'depth_affine_invariant': depth}