# Bbox-caption-8b/infer/common_infer.py
import argparse
import logging
import os
import sys

# Select the GPU before torch can initialize a CUDA context (defaults to GPU 0).
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "0")

import torch
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import FrozenDict

# Make the project root importable so the local `models` package resolves.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Silence noisy warnings from the transformers tokenizer.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
from models.mmdit import CustomFluxTransformer2DModel
from models.multiLayer_adapter import MultiLayerAdapter
from models.pipeline import CustomFluxPipeline, CustomFluxPipelineCfgLayer
from models.transp_vae import AutoencoderKLTransformerTraining as CustomVAE
def resolve_local_config_path(path: str | None) -> str | None:
"""Resolve project-relative asset paths while leaving absolute paths untouched."""
if not path:
return path
if os.path.isabs(path):
return path
return os.path.join(PROJECT_ROOT, path)
def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
"""Scale an xyxy box from source_size to target_size."""
scale = target_size / source_size
x0, y0, x1, y1 = box
x0_s = int(x0 * scale)
y0_s = int(y0 * scale)
x1_s = int(x1 * scale)
y1_s = int(y1 * scale)
x0_s = max(0, x0_s)
y0_s = max(0, y0_s)
x1_s = min(target_size, x1_s)
y1_s = min(target_size, y1_s)
return (x0_s, y0_s, x1_s, y1_s)
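
# Boxes are snapped outward to a 16-pixel grid: Flux's VAE downsamples by 8x
# and the transformer packs 2x2 latent patches, so each latent token covers a
# 16x16 pixel area. Flooring the mins and ceiling the maxes keeps every layer
# region token-aligned without shrinking it.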
def quantize_box_16(box: tuple, target_size: int) -> tuple:
"""Quantize an xyxy box to the 16-pixel latent grid."""
x0, y0, x1, y1 = box
x0_q = (x0 // 16) * 16
y0_q = (y0 // 16) * 16
x1_q = ((x1 + 15) // 16) * 16
y1_q = ((y1 + 15) // 16) * 16
x0_q = max(0, x0_q)
y0_q = max(0, y0_q)
x1_q = min(target_size, x1_q)
y1_q = min(target_size, y1_q)
return (x0_q, y0_q, x1_q, y1_q)
def get_layer_boxes(layers: list, source_size: int, target_size: int) -> list:
"""Extract, scale, and quantize prism layer boxes."""
boxes = []
for layer in layers:
box = layer.get("box", [0, 0, source_size, source_size])
scaled_box = scale_box_xyxy(box, source_size, target_size)
boxes.append(quantize_box_16(scaled_box, target_size))
return boxes
def initialize_pipeline(config):
"""Initialize the SynLayers decomposition pipeline."""
transp_vae_path = resolve_local_config_path(config.get("transp_vae_path"))
pretrained_lora_dir = resolve_local_config_path(config.get("pretrained_lora_dir"))
artplus_lora_dir = resolve_local_config_path(config.get("artplus_lora_dir"))
layer_ckpt = resolve_local_config_path(config.get("layer_ckpt"))
adapter_lora_dir = resolve_local_config_path(config.get("adapter_lora_dir"))
lora_ckpt = resolve_local_config_path(config.get("lora_ckpt"))
print("[INFO] Loading Transparent VAE...", flush=True)
vae_args = argparse.Namespace(
max_layers=config.get("max_layers", 48),
decoder_arch=config.get("decoder_arch", "vit"),
pos_embedding=config.get("pos_embedding", "rope"),
layer_embedding=config.get("layer_embedding", "rope"),
single_layer_decoder=config.get("single_layer_decoder", None),
)
transp_vae = CustomVAE(vae_args)
transp_vae_weights = torch.load(transp_vae_path, map_location=torch.device("cuda"))
missing_keys, unexpected_keys = transp_vae.load_state_dict(
transp_vae_weights["model"], strict=False
)
if missing_keys:
print(f"ViT Encoder Missing keys: {missing_keys}")
if unexpected_keys:
print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
transp_vae.eval()
transp_vae = transp_vae.to(torch.device("cuda"))
print("[INFO] Transparent VAE loaded.", flush=True)
print("[INFO] Loading pretrained Transformer model...", flush=True)
transformer_orig = FluxTransformer2DModel.from_pretrained(
config.get("transformer_varient", config["pretrained_model_name_or_path"]),
subfolder="" if "transformer_varient" in config else "transformer",
revision=config.get("revision", None),
variant=config.get("variant", None),
torch_dtype=torch.bfloat16,
cache_dir=config.get("cache_dir", None),
)
mmdit_config = dict(transformer_orig.config)
mmdit_config["_class_name"] = "CustomSD3Transformer2DModel"
mmdit_config["max_layer_num"] = config["max_layer_num"]
mmdit_config = FrozenDict(mmdit_config)
transformer = CustomFluxTransformer2DModel.from_config(mmdit_config).to(
dtype=torch.bfloat16
)
transformer.load_state_dict(transformer_orig.state_dict(), strict=False)
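
    # Each optional LoRA below is fused into the base weights and then
    # detached, so successive adapters stack without keeping separate LoRA
    # modules alive.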
if pretrained_lora_dir:
print("[INFO] Loading pretrained LoRA weights...", flush=True)
lora_state_dict = CustomFluxPipeline.lora_state_dict(pretrained_lora_dir)
CustomFluxPipeline.load_lora_into_transformer(lora_state_dict, None, transformer)
transformer.fuse_lora(safe_fusing=True)
transformer.unload_lora()
if artplus_lora_dir:
print("[INFO] Loading artplus LoRA weights...", flush=True)
lora_state_dict = CustomFluxPipeline.lora_state_dict(artplus_lora_dir)
CustomFluxPipeline.load_lora_into_transformer(lora_state_dict, None, transformer)
transformer.fuse_lora(safe_fusing=True)
transformer.unload_lora()
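
    # Optionally restore the learned per-layer positional embedding from the
    # layer checkpoint directory.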
layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth") if layer_ckpt else ""
if os.path.exists(layer_pe_path):
print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
        layer_pe = torch.load(layer_pe_path, map_location="cpu")
transformer.load_state_dict(layer_pe, strict=False)
print("[INFO] Loading MultiLayer-Adapter...", flush=True)
multiLayer_adapter = MultiLayerAdapter.from_pretrained(
config["pretrained_adapter_path"]
).to(torch.bfloat16).to(torch.device("cuda"))
if adapter_lora_dir:
print("[INFO] Loading adapter LoRA weights...", flush=True)
lora_state_dict = CustomFluxPipeline.lora_state_dict(adapter_lora_dir)
CustomFluxPipeline.load_lora_into_transformer(
lora_state_dict, None, multiLayer_adapter
)
multiLayer_adapter.fuse_lora(safe_fusing=True)
multiLayer_adapter.unload_lora()
multiLayer_adapter.set_layerPE(transformer.layer_pe, transformer.max_layer_num)
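
    # Assemble the full pipeline around the customized transformer; the text
    # encoders and base VAE still come from the original Flux checkpoint.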
pipeline = CustomFluxPipelineCfgLayer.from_pretrained(
config["pretrained_model_name_or_path"],
transformer=transformer,
revision=config.get("revision", None),
variant=config.get("variant", None),
torch_dtype=torch.bfloat16,
cache_dir=config.get("cache_dir", None),
).to(torch.device("cuda"))
pipeline.set_multiLayerAdapter(multiLayer_adapter)
if lora_ckpt:
print(f"[INFO] Loading trained LoRA from {lora_ckpt}...", flush=True)
pipeline.load_lora_weights(lora_ckpt, adapter_name="layer")
return pipeline, transp_vae
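

# Minimal usage sketch, not part of the original entry points. The base model
# id and all checkpoint paths below are placeholders (assumptions), to be
# replaced with whatever this repo actually ships; only the config keys that
# initialize_pipeline requires are shown.
if __name__ == "__main__":
    example_config = {
        # assumed base model id; substitute the checkpoint this project targets
        "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
        # hypothetical local paths, resolved relative to PROJECT_ROOT
        "transp_vae_path": "checkpoints/transp_vae.pth",
        "pretrained_adapter_path": "checkpoints/multilayer_adapter",
        "max_layer_num": 48,
    }
    pipeline, transp_vae = initialize_pipeline(example_config)
    print(f"[INFO] Pipeline ready: {type(pipeline).__name__}", flush=True)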