Diffusers
Safetensors
EvalMDE / Edit2Perceive /inference.py
zeyuren2002's picture
Add files using upload-large-folder tool
7f921f4 verified
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from typing import Optional, Sequence, Union
import torch
import numpy as np
from PIL import Image
from pipelines.flux_image_new import FluxImagePipeline
from models.utils import load_state_dict,parse_flux_model_configs
from models.unified_dataset import UnifiedDataset, gen_points
from models.flux_dit import FluxDiTStateDictConverter
import matplotlib.pyplot as plt
converter = FluxDiTStateDictConverter()
# 模型配置
MODEL_CONFIGS = {
"depth": "ckpts/depth.safetensors",
"normal": "ckpts/normal.safetensors",
"matting": "ckpts/matting.safetensors",
"depth_lora": "ckpts/depth_lora.safetensors",
"normal_lora": "ckpts/normal_lora.safetensors",
"matting_lora": "ckpts/matting_lora.safetensors",
}
def inference(
model_root: str,
task: str,
input_paths: Union[str, Sequence[str]],
resolution: int = 768,
num_inference_steps: int = 8,
seed: int = 42,
output_path: Optional[str] = None,
torch_dtype: torch.dtype = torch.bfloat16,
device: str = "cuda",
) -> None:
"""Run inference on one or more images."""
# 获取模型路径
if task not in MODEL_CONFIGS:
raise ValueError(f"task must be one of {list(MODEL_CONFIGS.keys())}")
state_dict_path = MODEL_CONFIGS[task]
# 加载模型
print("Loading FluxImagePipeline ...")
pipe = FluxImagePipeline.from_pretrained(
torch_dtype=torch_dtype,
device=device,
model_configs=parse_flux_model_configs(model_root),
)
if "lora" in task:
pipe.load_lora(pipe.dit, state_dict_path, hotload=False)
else:
state_dict = load_state_dict(state_dict_path)
pipe.dit.load_state_dict(state_dict)
print("Model loaded successfully!")
# 准备图片列表
if isinstance(input_paths, str):
images = [path for path in input_paths.split(",") if path]
else:
images = list(input_paths)
transform = UnifiedDataset.default_image_operator(height=resolution, width=resolution)
task = task.replace("_lora", "")
# 处理每张图片
for image_path in images:
if task in {"depth", "normal"}:
out_np = pipe(
prompt=f"Transform to {task} map while maintaining original composition",
kontext_images=transform(image_path),
height=768,
width=768,
embedded_guidance=1,
num_inference_steps=num_inference_steps,
seed=seed,
output_type="np",
rand_device=device,
task=task,
)
else: # matting
alpha = np.zeros((resolution, resolution), dtype=np.float32)
alpha[resolution // 4 : resolution * 3 // 4, resolution // 4 : resolution * 3 // 4] = 1.0
points, coords = gen_points(alpha, num_points=10, radius=30)
img = Image.open(image_path).convert("RGB").resize((resolution, resolution))
for i in range(coords.shape[0] // 2):
x, y = coords[i * 2], coords[i * 2 + 1]
for dx in range(-9, 10):
for dy in range(-9, 10):
if dx * dx + dy * dy <= 81:
px = int(x * resolution) + dx
py = int(y * resolution) + dy
if 0 <= px < resolution and 0 <= py < resolution:
img.putpixel((px, py), (255, 0, 0))
img.save(image_path.replace(".jpg", ".png").replace(".png", "_debug_points.png"))
points = torch.from_numpy(points * 2 - 1).repeat(3, 1, 1).to(device)
out_np = pipe(
prompt="Extract the foreground object and generate alpha matte",
kontext_images=[transform(image_path), points],
height=768,
width=768,
embedded_guidance=1,
num_inference_steps=10,
seed=seed,
output_type="np",
rand_device=device,
task=task,
)
# 后处理
if task == "depth":
if out_np.ndim == 3:
out_np = np.mean(out_np, axis=2)
cmap = plt.get_cmap("Spectral")
out_np = cmap(out_np)[:, :, :3]
out_np = (out_np * 255).astype(np.uint8)
elif task == "normal":
out_np = (out_np + 1) / 2 * 255.0
out_np = out_np.astype(np.uint8)
else: # matting
out_np = (out_np * 255.0).astype(np.uint8)
# 保存结果
if output_path is None:
save_path = image_path.replace(".jpg", ".png").replace(".png", f"_{task}.png")
else:
save_path = output_path
Image.fromarray(out_np).save(save_path)
print(f"Saved output to {save_path}")
if __name__ == "__main__":
# Please Change the model root path below to your own model directory
model_root = "./FLUX.1-Kontext-dev"
inference(
model_root=model_root,
task="depth", # Options: "depth", "normal", "matting"
input_paths="samples/cutecat.jpg"
)