import argparse
import logging
import os
import sys

import torch
from diffusers import FluxTransformer2DModel
from diffusers.configuration_utils import FrozenDict

# Put the project root (parent of this file's directory) on sys.path so the
# project-local ``models`` package imported below can be resolved.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Hide sub-ERROR messages from transformers' tokenization_utils_base logger.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
# Default to GPU 0 unless the caller already selected devices. Only effective
# if CUDA has not yet been initialised in this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

from models.mmdit import CustomFluxTransformer2DModel  # noqa: E402
from models.multiLayer_adapter import MultiLayerAdapter  # noqa: E402
from models.pipeline import CustomFluxPipeline, CustomFluxPipelineCfgLayer  # noqa: E402
from models.transp_vae import AutoencoderKLTransformerTraining as CustomVAE  # noqa: E402


def resolve_local_config_path(path: str | None) -> str | None:
    """Resolve project-relative asset paths while leaving absolute paths untouched.

    Args:
        path: A filesystem path, or ``None``/empty string.

    Returns:
        ``path`` unchanged if it is falsy or absolute; otherwise ``path``
        joined onto ``PROJECT_ROOT``.
    """
    if not path:
        return path
    if os.path.isabs(path):
        return path
    return os.path.join(PROJECT_ROOT, path)


def scale_box_xyxy(box, source_size: int, target_size: int) -> tuple:
    """Scale an xyxy box from a ``source_size`` canvas to a ``target_size`` canvas.

    Each coordinate is multiplied by ``target_size / source_size`` and
    truncated with ``int()``. The top-left corner is then clamped to be
    >= 0 and the bottom-right corner to be <= ``target_size``.

    Args:
        box: Four numbers ``(x0, y0, x1, y1)``.
        source_size: Side length of the canvas ``box`` is expressed in.
        target_size: Side length of the canvas to scale into.

    Returns:
        The scaled ``(x0, y0, x1, y1)`` tuple of ints.
    """
    scale = target_size / source_size
    x0, y0, x1, y1 = (int(coord * scale) for coord in box)
    return (
        max(0, x0),
        max(0, y0),
        min(target_size, x1),
        min(target_size, y1),
    )


def quantize_box_16(box: tuple, target_size: int) -> tuple:
    """Quantize an xyxy box to the 16-pixel latent grid.

    The top-left corner is floored and the bottom-right corner is ceiled to
    multiples of 16, so the result covers the input box. The result is then
    clamped to ``[0, target_size]``.

    Args:
        box: Integer ``(x0, y0, x1, y1)``.
        target_size: Side length of the canvas, used as the upper clamp.

    Returns:
        The quantized ``(x0, y0, x1, y1)`` tuple.
    """
    x0, y0, x1, y1 = box
    x0_q = (x0 // 16) * 16
    y0_q = (y0 // 16) * 16
    x1_q = ((x1 + 15) // 16) * 16  # ceil to the next multiple of 16
    y1_q = ((y1 + 15) // 16) * 16
    return (
        max(0, x0_q),
        max(0, y0_q),
        min(target_size, x1_q),
        min(target_size, y1_q),
    )


def get_layer_boxes(layers: list, source_size: int, target_size: int) -> list:
    """Extract, scale, and quantize prism layer boxes.

    Args:
        layers: Mapping-like layer records. The optional ``"box"`` key holds
            an xyxy box in ``source_size`` coordinates; when it is absent,
            the full canvas is used.
        source_size: Side length of the canvas the boxes are expressed in.
        target_size: Side length of the canvas to scale into.

    Returns:
        One quantized ``(x0, y0, x1, y1)`` tuple per layer, in input order.
    """
    full_canvas = [0, 0, source_size, source_size]
    return [
        quantize_box_16(
            scale_box_xyxy(layer.get("box", full_canvas), source_size, target_size),
            target_size,
        )
        for layer in layers
    ]


def _fuse_lora(model, lora_dir: str, label: str) -> None:
    """Load the LoRA weights in ``lora_dir`` into ``model`` and fuse them.

    Calls ``fuse_lora(safe_fusing=True)`` and then ``unload_lora()`` on
    ``model``, so the adapter weights end up merged into the base weights.
    ``label`` is used only in the progress message.
    """
    print(f"[INFO] Loading {label} LoRA weights...", flush=True)
    lora_state_dict = CustomFluxPipeline.lora_state_dict(lora_dir)
    CustomFluxPipeline.load_lora_into_transformer(lora_state_dict, None, model)
    model.fuse_lora(safe_fusing=True)
    model.unload_lora()


def initialize_pipeline(config):
    """Initialize the SynLayers decomposition pipeline.

    Config keys read:
        Required:
            ``pretrained_model_name_or_path``, ``max_layer_num``,
            ``pretrained_adapter_path``, ``transp_vae_path``.
        Optional, resolved against PROJECT_ROOT when relative:
            ``pretrained_lora_dir``, ``artplus_lora_dir``, ``layer_ckpt``,
            ``adapter_lora_dir``, ``lora_ckpt``.
        Optional transparent-VAE settings:
            ``max_layers``, ``decoder_arch``, ``pos_embedding``,
            ``layer_embedding``, ``single_layer_decoder``.
        Optional loading settings:
            ``transformer_varient``, ``revision``, ``variant``, ``cache_dir``.

    Returns:
        ``(pipeline, transp_vae)``, both moved to CUDA.

    Raises:
        ValueError: if ``transp_vae_path`` is missing.
        KeyError: if a required config key other than ``transp_vae_path``
            is missing.
    """
    device = torch.device("cuda")

    transp_vae_path = resolve_local_config_path(config.get("transp_vae_path"))
    pretrained_lora_dir = resolve_local_config_path(config.get("pretrained_lora_dir"))
    artplus_lora_dir = resolve_local_config_path(config.get("artplus_lora_dir"))
    layer_ckpt = resolve_local_config_path(config.get("layer_ckpt"))
    adapter_lora_dir = resolve_local_config_path(config.get("adapter_lora_dir"))
    lora_ckpt = resolve_local_config_path(config.get("lora_ckpt"))

    # --- Transparent VAE -------------------------------------------------
    print("[INFO] Loading Transparent VAE...", flush=True)
    if not transp_vae_path:
        raise ValueError("config['transp_vae_path'] is required to load the transparent VAE")
    vae_args = argparse.Namespace(
        max_layers=config.get("max_layers", 48),
        decoder_arch=config.get("decoder_arch", "vit"),
        pos_embedding=config.get("pos_embedding", "rope"),
        layer_embedding=config.get("layer_embedding", "rope"),
        single_layer_decoder=config.get("single_layer_decoder", None),
    )
    transp_vae = CustomVAE(vae_args)
    # Load the checkpoint on CPU. load_state_dict copies the tensors into the
    # module's own parameters, so a GPU copy of the raw checkpoint is never
    # needed; drop it right away instead of holding it for the whole function.
    checkpoint = torch.load(transp_vae_path, map_location="cpu")
    missing_keys, unexpected_keys = transp_vae.load_state_dict(
        checkpoint["model"], strict=False
    )
    del checkpoint
    if missing_keys:
        print(f"ViT Encoder Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"ViT Encoder Unexpected keys: {unexpected_keys}")
    transp_vae.eval()
    transp_vae = transp_vae.to(device)
    print("[INFO] Transparent VAE loaded.", flush=True)

    # --- Transformer -----------------------------------------------------
    print("[INFO] Loading pretrained Transformer model...", flush=True)
    # NOTE: "transformer_varient" (sic) is the config key callers provide.
    # When it is set, it is loaded with no subfolder, presumably because it
    # points straight at a transformer checkpoint.
    transformer_orig = FluxTransformer2DModel.from_pretrained(
        config.get("transformer_varient", config["pretrained_model_name_or_path"]),
        subfolder="" if "transformer_varient" in config else "transformer",
        revision=config.get("revision", None),
        variant=config.get("variant", None),
        torch_dtype=torch.bfloat16,
        cache_dir=config.get("cache_dir", None),
    )
    mmdit_config = dict(transformer_orig.config)
    # NOTE(review): the recorded class name says SD3 although the model built
    # below is CustomFluxTransformer2DModel; kept as-is, confirm intent.
    mmdit_config["_class_name"] = "CustomSD3Transformer2DModel"
    mmdit_config["max_layer_num"] = config["max_layer_num"]
    transformer = CustomFluxTransformer2DModel.from_config(FrozenDict(mmdit_config)).to(
        dtype=torch.bfloat16
    )
    # strict=False: presumably because the custom model has extra parameters
    # (e.g. layer_pe) that the stock Flux checkpoint lacks.
    transformer.load_state_dict(transformer_orig.state_dict(), strict=False)
    # The weights have been copied; free the stock model before the full
    # pipeline is loaded to lower peak memory.
    del transformer_orig

    if pretrained_lora_dir:
        _fuse_lora(transformer, pretrained_lora_dir, "pretrained")
    if artplus_lora_dir:
        _fuse_lora(transformer, artplus_lora_dir, "artplus")

    if layer_ckpt:
        layer_pe_path = os.path.join(layer_ckpt, "layer_pe.pth")
        if os.path.exists(layer_pe_path):
            print(f"[INFO] Loading layer_pe from {layer_pe_path}...", flush=True)
            # map_location="cpu": the transformer is still on CPU here, and this
            # avoids failures when the file was saved on an unavailable device.
            layer_pe = torch.load(layer_pe_path, map_location="cpu")
            transformer.load_state_dict(layer_pe, strict=False)
        else:
            print(
                f"[WARN] layer_ckpt is set but {layer_pe_path} does not exist; "
                "skipping layer_pe.",
                flush=True,
            )

    # --- MultiLayer adapter ----------------------------------------------
    print("[INFO] Loading MultiLayer-Adapter...", flush=True)
    multiLayer_adapter = (
        MultiLayerAdapter.from_pretrained(config["pretrained_adapter_path"])
        .to(torch.bfloat16)
        .to(device)
    )
    if adapter_lora_dir:
        _fuse_lora(multiLayer_adapter, adapter_lora_dir, "adapter")
    # Give the adapter the transformer's layer position embedding.
    multiLayer_adapter.set_layerPE(transformer.layer_pe, transformer.max_layer_num)

    # --- Pipeline --------------------------------------------------------
    pipeline = CustomFluxPipelineCfgLayer.from_pretrained(
        config["pretrained_model_name_or_path"],
        transformer=transformer,
        revision=config.get("revision", None),
        variant=config.get("variant", None),
        torch_dtype=torch.bfloat16,
        cache_dir=config.get("cache_dir", None),
    ).to(device)
    pipeline.set_multiLayerAdapter(multiLayer_adapter)
    if lora_ckpt:
        print(f"[INFO] Loading trained LoRA from {lora_ckpt}...", flush=True)
        pipeline.load_lora_weights(lora_ckpt, adapter_name="layer")

    return pipeline, transp_vae