# Reference: https://github.com/AMAP-ML/FE2E
# Strictly follows the official `infer/inference.py::ImageGenerator` API.
# The distributed entrypoint `evaluation.py` is bypassed; we use the same `ImageGenerator`
# class on a single GPU. README usage (for depth):
#   python -u evaluation.py \
#     --model_path ./pretrain --lora ./lora/LDRN.safetensors --single_denoise \
#     --prompt_type empty --norm_type ln --task_name depth ...
#
# Important calling convention (from inference.py):
#   ImageGenerator.__init__ requires an `args` namespace with at least:
#     args.prompt_type         (='empty' to skip Qwen),
#     args.single_denoise      (sets num_steps = 1 via FE2E's parse_args),
#     args.empty_prompt_cache  (path to latent/no_info.npz).
#   `generate_image(prompt, negative_prompt, ref_images=PIL_or_tensor, num_steps,
#   cfg_guidance, seed, ..., args=args)` returns (images_list, Lpred, Rpred).
#   - Lpred: float tensor in [0, 1] (`* 0.5 + 0.5` already applied) — corresponds to the
#     *edited* frame (= depth output for FE2E's depth LoRA).
#   - Rpred: float tensor in [-1, 1] — the reconstructed reference RGB.
#
# Output key: `depth_affine_invariant` (Lpred averaged over channels).
# NEEDS_VERIFICATION: the Lpred vs. Rpred meaning is inferred from generate_image semantics
# (Lpred = left/first frame = denoised target = depth) and confirmed by inner_evaluation.py
# using `Lpred` for depth metrics. Switch via `--use_rpred` if a sanity run shows that the
# depth is actually carried by Rpred.
import os
import sys
from typing import *
from pathlib import Path
from types import SimpleNamespace

import click
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

from moge.test.baseline import MGEBaselineInterface


class Baseline(MGEBaselineInterface):
    """MoGe evaluation baseline that runs FE2E's ``ImageGenerator`` for depth.

    The FE2E repository is imported from ``repo_path`` at construction time.
    Each call to :meth:`infer` pushes one RGB image through
    ``ImageGenerator.generate_image`` with an empty prompt, averages the
    selected prediction (``Lpred`` by default, ``Rpred`` with ``use_rpred``)
    over channels and returns it as ``depth_affine_invariant``.
    """

    def __init__(
        self,
        repo_path: str,
        model_path: str,
        lora_path: str,
        qwen2vl_path: Optional[str],
        empty_prompt_cache: Optional[str],
        num_steps: int,
        cfg_guidance: float,
        size_level: int,
        prompt_type: str,
        single_denoise: bool,
        seed: int,
        quantized: bool,
        offload: bool,
        use_rpred: bool,
        device: Union[torch.device, str],
    ):
        """Load FE2E weights and build the ``ImageGenerator``.

        Args:
            repo_path: Path to a checkout of AMAP-ML/FE2E (added to ``sys.path``).
            model_path: Dir holding ``vae.safetensors`` and the Step1X-Edit DiT
                weights; relative paths are resolved against ``repo_path``.
            lora_path: FE2E LoRA checkpoint; resolved like ``model_path``.
            qwen2vl_path: Qwen2.5-VL dir; defaults to ``<repo_path>/Qwen``.
            empty_prompt_cache: Cached empty-prompt file; defaults to
                ``<repo_path>/latent/no_info.npz``.
            num_steps: Diffusion steps (forced to 1 when ``single_denoise``).
            cfg_guidance: Guidance strength forwarded to ``generate_image``.
            size_level: Resolution hint forwarded to ``generate_image``.
            prompt_type: Stored on the FE2E ``args`` namespace.
            single_denoise: Use one denoising step.
            seed: Seed for ``seed_all`` and for every ``generate_image`` call.
            quantized: Load the FP8 DiT checkpoint instead of the full one.
            offload: Forwarded to ``ImageGenerator`` (CPU offload).
            use_rpred: Return ``Rpred`` instead of ``Lpred`` as depth.
            device: Torch device for the model and the returned depth.

        Raises:
            FileNotFoundError: If the repo or a required artifact is missing.
        """
        repo_path = os.path.abspath(os.path.expanduser(repo_path))
        if not os.path.isdir(repo_path):
            raise FileNotFoundError(
                f"Cannot find FE2E repo at {repo_path}. Clone https://github.com/AMAP-ML/FE2E."
            )
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)
        # FE2E is not an installed package; these imports resolve against repo_path.
        from infer.inference import ImageGenerator
        from infer.seed_all import seed_all

        seed_all(seed)

        def _resolve(p: str) -> str:
            # Relative paths are interpreted relative to the FE2E repo, not the CWD.
            p = os.path.expanduser(p)
            return p if os.path.isabs(p) else os.path.join(repo_path, p)

        model_path = _resolve(model_path)
        lora_path = _resolve(lora_path)
        if qwen2vl_path is not None:
            qwen2vl_path = _resolve(qwen2vl_path)
        else:
            qwen2vl_path = os.path.join(repo_path, "Qwen")  # FE2E DEFAULT_QWEN_DIR
        if empty_prompt_cache is not None:
            empty_prompt_cache = _resolve(empty_prompt_cache)
        else:
            empty_prompt_cache = os.path.join(repo_path, "latent", "no_info.npz")

        # ImageGenerator reads several attrs off args (prompt_type, single_denoise, empty_prompt_cache).
        ig_args = SimpleNamespace(
            prompt_type=prompt_type,
            single_denoise=single_denoise,
            empty_prompt_cache=empty_prompt_cache,
        )

        device = torch.device(device)
        ae_path = os.path.join(model_path, "vae.safetensors")
        dit_basename = "step1x-edit-i1258-FP8.safetensors" if quantized else "step1x-edit-i1258.safetensors"
        dit_path = os.path.join(model_path, dit_basename)
        # Fail fast with a clear message instead of deep inside FE2E's loaders.
        for p in (ae_path, dit_path, lora_path, empty_prompt_cache):
            if not os.path.exists(p):
                raise FileNotFoundError(f"Missing required FE2E artifact: {p}")

        self.image_gen = ImageGenerator(
            ae_path=ae_path,
            dit_path=dit_path,
            qwen2vl_model_path=qwen2vl_path,
            max_length=640,
            quantized=quantized,
            offload=offload,
            lora=lora_path,
            device=str(device),
            args=ig_args,
        )
        self.device = device
        self.num_steps = 1 if single_denoise else num_steps
        self.cfg_guidance = cfg_guidance
        self.size_level = size_level
        self.seed = seed
        self.ig_args = ig_args
        self.use_rpred = use_rpred

    @click.command()
    @click.option('--repo', 'repo_path', type=click.Path(), default='../FE2E',
                  help='Path to the AMAP-ML/FE2E repository.')
    @click.option('--model_path', type=click.Path(), default='pretrain',
                  help='Pretrain dir holding vae.safetensors + step1x-edit-i1258*.safetensors '
                       '(relative to --repo if not absolute).')
    @click.option('--lora_path', type=click.Path(), default='lora/LDRN.safetensors',
                  help='FE2E LoRA checkpoint (relative to --repo if not absolute).')
    @click.option('--qwen2vl_path', type=click.Path(), default=None,
                  help='Qwen2.5-VL dir (only required when prompt_type != empty; '
                       'defaults to Qwen inside --repo).')
    @click.option('--empty_prompt_cache', type=click.Path(), default=None,
                  help='Path to latent/no_info.npz (relative to --repo if not absolute); '
                       'defaults to latent/no_info.npz inside --repo.')
    @click.option('--num_steps', type=int, default=28,
                  help='Diffusion steps; ignored unless --no_single_denoise is given '
                       '(single-step mode forces 1).')
    @click.option('--cfg_guidance', type=float, default=6.0,
                  help='CFG guidance strength (FE2E default 6.0).')
    @click.option('--size_level', type=int, default=768,
                  help='Inference resolution hint (passed through to generate_image).')
    @click.option('--prompt_type', type=str, default='empty',
                  help='FE2E flag; "empty" skips Qwen loading and uses cached empty-prompt latent.')
    # A single on/off switch: two separate same-dest bool options are fragile in
    # click (an is_flag option with default=True gets flag_value=False, and the
    # second option's implicit False default can override the True default).
    @click.option('--single_denoise/--no_single_denoise', 'single_denoise', default=True,
                  help='Single-step denoising (README recommended for depth eval; on by default). '
                       'Use --no_single_denoise for multi-step sampling.')
    @click.option('--seed', type=int, default=1234)
    @click.option('--quantized', is_flag=True, help='Use FP8 DiT (step1x-edit-i1258-FP8.safetensors).')
    @click.option('--offload', is_flag=True, help='CPU offload to save VRAM.')
    @click.option('--use_rpred', is_flag=True,
                  help='[Sanity-check] Use Rpred instead of Lpred as the depth output.')
    @click.option('--device', type=str, default='cuda')
    @staticmethod
    def load(repo_path: str, model_path: str, lora_path: str, qwen2vl_path: Optional[str],
             empty_prompt_cache: Optional[str], num_steps: int, cfg_guidance: float,
             size_level: int, prompt_type: str, single_denoise: bool, seed: int,
             quantized: bool, offload: bool, use_rpred: bool, device: str = 'cuda'):
        """Build the FE2E baseline from command-line options."""
        return Baseline(repo_path, model_path, lora_path, qwen2vl_path, empty_prompt_cache,
                        num_steps, cfg_guidance, size_level, prompt_type, single_denoise,
                        seed, quantized, offload, use_rpred, device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Predict affine-invariant depth for a single image.

        Args:
            image: RGB image as ``(3, H, W)`` or ``(1, 3, H, W)``; values are
                clamped to [0, 1] before conversion to 8-bit.
            intrinsics: Unused; accepted for interface compatibility.

        Returns:
            ``{'depth_affine_invariant': depth}`` where ``depth`` is ``(H, W)``
            for unbatched input or ``(1, H, W)`` for batched input, float32 on
            ``self.device``.
        """
        omit_batch = image.ndim == 3
        if omit_batch:
            image = image.unsqueeze(0)
        assert image.shape[0] == 1, "FE2E baseline only supports batch size 1"
        _, _, H, W = image.shape

        # generate_image accepts PIL or torch.Tensor for ref_images.
        # Round (not truncate) when quantizing, and go through float32 so that
        # half/bfloat16 inputs can be converted to numpy.
        rgb_u8 = image[0].float().clamp(0, 1).mul(255).round().to(torch.uint8)
        pil = Image.fromarray(rgb_u8.permute(1, 2, 0).contiguous().cpu().numpy())

        _images, Lpred, Rpred = self.image_gen.generate_image(
            prompt='',
            negative_prompt='',
            ref_images=pil,
            num_samples=1,
            num_steps=self.num_steps,
            cfg_guidance=self.cfg_guidance,
            seed=self.seed,
            show_progress=False,
            size_level=self.size_level,
            args=self.ig_args,
        )

        # Lpred is documented as [1, 3, h', w'] in [0, 1]; Rpred as [1, 3, h', w'] in [-1, 1].
        if self.use_rpred:
            pred = Rpred.clamp(-1, 1).mul(0.5).add(0.5)
        else:
            pred = Lpred  # already in [0, 1]
        pred = pred.float()
        # NOTE(review): the 4-D layout above is expected; 3-D (C, h, w) and 2-D
        # (h, w) outputs are handled defensively in case that assumption is off.
        if pred.ndim == 4:
            pred = pred[0]
        # Mean over the channel dim to get scalar depth (same convention as
        # Marigold / Lotus / DepthMaster). The eval pipeline aligns affine afterwards.
        depth = pred.mean(dim=0) if pred.ndim == 3 else pred
        depth = depth.to(self.device)
        if depth.shape != (H, W):
            depth = F.interpolate(depth[None, None], size=(H, W), mode='bilinear',
                                  align_corners=False)[0, 0]

        # FE2E (Step1X-Edit + LDRN LoRA) output is treated as affine-invariant depth;
        # emit only this physical key.
        if not omit_batch:
            depth = depth.unsqueeze(0)
        return {'depth_affine_invariant': depth}