| import json |
| from pathlib import Path |
| from typing import Literal, Optional |
|
|
| import torch |
| from modules.autoencoder import AutoEncoder, AutoEncoderParams |
| from modules.conditioner import HFEmbedder |
| from modules.flux_model import Flux, FluxParams |
| from safetensors.torch import load_file as load_sft |
|
|
try:
    from enum import StrEnum
except ImportError:
    # ``enum.StrEnum`` is missing on older interpreters; fall back to the
    # classic ``str`` + ``Enum`` mixin so members still compare equal to
    # their string values. Only ImportError is caught so that unrelated
    # failures (KeyboardInterrupt, SystemExit, ...) are not swallowed.
    from enum import Enum

    class StrEnum(str, Enum):
        """Stand-in for ``enum.StrEnum`` when the stdlib does not provide it."""
|
|
|
|
| from pydantic import BaseModel, ConfigDict |
| from loguru import logger |
|
|
|
|
class ModelVersion(StrEnum):
    """String identifiers for the Flux variants this module can configure.

    ``load_config`` branches on these values to choose the repository id,
    checkpoint file name, guidance embedding flag and text-encoder max length.
    """

    flux_dev = "flux-dev"
    flux_schnell = "flux-schnell"
|
|
|
|
class QuantizationDtype(StrEnum):
    """Names accepted by the ``*_quantization_dtype`` fields of ``ModelSpec``.

    The values are plain strings; how each one is applied is decided by the
    code that consumes the config, not by this enum.
    """

    qfloat8 = "qfloat8"
    qint2 = "qint2"
    qint4 = "qint4"
    qint8 = "qint8"
    bfloat16 = "bfloat16"
    float16 = "float16"
|
|
|
|
class ModelSpec(BaseModel):
    """Complete configuration for loading the flow model, autoencoder and text encoders.

    Fields without a default are required when constructing the spec directly
    (for example from a JSON file via ``load_config_from_path``).
    """

    # --- model identity and architecture ---------------------------------
    version: ModelVersion
    params: FluxParams
    # Passed to ``AutoEncoder`` by ``load_autoencoder``.
    ae_params: AutoEncoderParams
    # Safetensors checkpoint for the flow model; ``None`` skips weight loading.
    ckpt_path: str | None

    # First argument given to the CLIP ``HFEmbedder``.
    clip_path: str | None = "openai/clip-vit-large-patch14"
    # Safetensors checkpoint for the autoencoder; ``None`` skips weight loading.
    ae_path: str | None
    # Populated by ``load_config``; not read by the loaders in this module.
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None

    # --- text encoder -------------------------------------------------------
    # ``max_length`` for the T5 embedder (the CLIP embedder always uses 77).
    text_enc_max_length: int = 512
    # First argument given to the T5 ``HFEmbedder``.
    text_enc_path: str | None

    # --- placement and precision ---------------------------------------------
    text_enc_device: str | torch.device | None = "cuda:0"
    ae_device: str | torch.device | None = "cuda:0"
    flux_device: str | torch.device | None = "cuda:0"
    # Dtype names understood by ``into_dtype``.
    flow_dtype: str = "float16"
    ae_dtype: str = "bfloat16"
    text_enc_dtype: str = "bfloat16"

    # --- quantization / compilation / offload --------------------------------
    # The options below that are not mentioned by a loader in this module are
    # consumed elsewhere.
    num_to_quant: int | None = 20
    quantize_extras: bool = False
    compile_extras: bool = False
    compile_blocks: bool = False
    flow_quantization_dtype: QuantizationDtype | None = QuantizationDtype.qfloat8
    text_enc_quantization_dtype: QuantizationDtype | None = QuantizationDtype.qfloat8
    # Any non-None value makes ``load_autoencoder`` swap the linear layers.
    ae_quantization_dtype: QuantizationDtype | None = None
    clip_quantization_dtype: QuantizationDtype | None = None
    offload_text_encoder: bool = False
    # When set, ``load_autoencoder`` moves the autoencoder to the CPU.
    offload_vae: bool = False
    offload_flow: bool = False
    # When set, ``load_flow_model`` leaves the flow model's dtype untouched.
    prequantized_flow: bool = False

    quantize_modulation: bool = True

    quantize_flow_embedder_layers: bool = False

    model_config: ConfigDict = ConfigDict(
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
|
|
|
|
def load_models(config: ModelSpec) -> tuple[Flux, AutoEncoder, HFEmbedder, HFEmbedder]:
    """Load every model described by ``config``.

    The flow model is loaded first, then the autoencoder, then the two text
    encoders.

    Returns:
        ``(flow, ae, clip, t5)``.
    """
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    text_encoders = load_text_encoders(config)
    return (flow_model, autoencoder, *text_encoders)
|
|
|
|
| def parse_device(device: str | torch.device | None) -> torch.device: |
| if isinstance(device, str): |
| return torch.device(device) |
| elif isinstance(device, torch.device): |
| return device |
| else: |
| return torch.device("cuda:0") |
|
|
|
|
def into_dtype(dtype: str) -> torch.dtype:
    """Translate a dtype name into the corresponding ``torch.dtype``.

    Args:
        dtype: One of ``"float16"``, ``"bfloat16"`` or ``"float32"``. A value
            that is already a ``torch.dtype`` is returned unchanged.

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    supported = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
    )
    for dtype_name, torch_dtype in supported:
        if dtype == dtype_name:
            return torch_dtype
    raise ValueError(f"Invalid dtype: {dtype}")
|
|
|
|
| def into_device(device: str | torch.device | None) -> torch.device: |
| if isinstance(device, str): |
| return torch.device(device) |
| elif isinstance(device, torch.device): |
| return device |
| elif isinstance(device, int): |
| return torch.device(f"cuda:{device}") |
| else: |
| return torch.device("cuda:0") |
|
|
|
|
def load_config(
    name: ModelVersion = ModelVersion.flux_dev,
    flux_path: str | None = None,
    ae_path: str | None = None,
    text_enc_path: str | None = None,
    text_enc_device: str | torch.device | None = None,
    ae_device: str | torch.device | None = None,
    flux_device: str | torch.device | None = None,
    flow_dtype: str = "float16",
    ae_dtype: str = "bfloat16",
    text_enc_dtype: str = "bfloat16",
    num_to_quant: Optional[int] = 20,
    compile_extras: bool = False,
    compile_blocks: bool = False,
    offload_text_enc: bool = False,
    offload_ae: bool = False,
    offload_flow: bool = False,
    quant_text_enc: Optional[Literal["float8", "qint2", "qint4", "qint8"]] = None,
    quant_ae: bool = False,
    prequantized_flow: bool = False,
    quantize_modulation: bool = True,
    quantize_flow_embedder_layers: bool = False,
) -> ModelSpec:
    """Build a ``ModelSpec`` from individual keyword options.

    Device arguments are normalized through ``parse_device`` and stored as
    strings. The repository id, checkpoint file name, guidance embedding flag
    and text-encoder max length are derived from ``name``.

    ``quant_text_enc`` is looked up in a fixed name table; a value that is not
    in the table (including ``None``) results in no text-encoder quantization.
    ``quant_ae=True`` selects ``qfloat8`` for the autoencoder.
    """
    is_dev = name == ModelVersion.flux_dev

    text_enc_quant_by_name = {
        "float8": QuantizationDtype.qfloat8,
        "qint2": QuantizationDtype.qint2,
        "qint4": QuantizationDtype.qint4,
        "qint8": QuantizationDtype.qint8,
    }

    flux_params = FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=is_dev,
    )
    autoencoder_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    if is_dev:
        repo_id = "black-forest-labs/FLUX.1-dev"
        repo_flow = "flux1-dev.sft"
        max_length = 512
    else:
        repo_id = "black-forest-labs/FLUX.1-schnell"
        repo_flow = "flux1-schnell.sft"
        max_length = 256

    return ModelSpec(
        version=name,
        repo_id=repo_id,
        repo_flow=repo_flow,
        repo_ae="ae.sft",
        ckpt_path=flux_path,
        params=flux_params,
        ae_path=ae_path,
        ae_params=autoencoder_params,
        text_enc_path=text_enc_path,
        text_enc_device=str(parse_device(text_enc_device)),
        ae_device=str(parse_device(ae_device)),
        flux_device=str(parse_device(flux_device)),
        flow_dtype=flow_dtype,
        ae_dtype=ae_dtype,
        text_enc_dtype=text_enc_dtype,
        text_enc_max_length=max_length,
        num_to_quant=num_to_quant,
        compile_extras=compile_extras,
        compile_blocks=compile_blocks,
        offload_flow=offload_flow,
        offload_text_encoder=offload_text_enc,
        offload_vae=offload_ae,
        text_enc_quantization_dtype=text_enc_quant_by_name.get(quant_text_enc),
        ae_quantization_dtype=QuantizationDtype.qfloat8 if quant_ae else None,
        prequantized_flow=prequantized_flow,
        quantize_modulation=quantize_modulation,
        quantize_flow_embedder_layers=quantize_flow_embedder_layers,
    )
|
|
|
|
def load_config_from_path(path: str) -> ModelSpec:
    """Read a JSON file and build a ``ModelSpec`` from its top-level keys.

    Args:
        path: Location of the JSON configuration file.

    Returns:
        The ``ModelSpec`` constructed from the decoded JSON object.

    Raises:
        ValueError: If ``path`` does not exist or is not a regular file.
    """
    config_file = Path(path)
    if not config_file.exists():
        raise ValueError(f"Path {path} does not exist")
    if not config_file.is_file():
        raise ValueError(f"Path {path} is not a file")
    # JSON is UTF-8 by specification; decode explicitly instead of relying on
    # the platform's locale encoding, which may not be UTF-8.
    raw_config = json.loads(config_file.read_text(encoding="utf-8"))
    return ModelSpec(**raw_config)
|
|
|
|
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Log warnings for state-dict keys that were missing or unexpected.

    Missing keys are reported first, then unexpected keys. When both lists
    are non-empty a separator line is logged between the two reports.
    Nothing is logged when both lists are empty.
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        logger.warning(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        logger.warning("\n" + "-" * 79 + "\n")
    if has_unexpected:
        logger.warning(
            f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)
        )
|
|
|
|
def load_flow_model(config: ModelSpec) -> Flux:
    """Build the Flux model and, when a checkpoint is configured, load its weights.

    The model is constructed under the ``meta`` device. If
    ``config.ckpt_path`` is set, the safetensors file is read onto the CPU and
    loaded non-strictly with ``assign=True``; missing or unexpected keys are
    logged rather than raised. Unless ``config.prequantized_flow`` is set, the
    model is cast to ``config.flow_dtype`` after construction and again after
    loading.

    This function does not move the model to ``config.flux_device``.
    """
    flow_dtype = into_dtype(config.flow_dtype)
    needs_cast = not config.prequantized_flow

    with torch.device("meta"):
        model = Flux(config, dtype=flow_dtype)
        if needs_cast:
            model.type(flow_dtype)

    if config.ckpt_path is None:
        return model

    # The checkpoint is always materialized on the CPU here.
    state_dict = load_sft(config.ckpt_path, device="cpu")
    missing, unexpected = model.load_state_dict(
        state_dict, strict=False, assign=True
    )
    print_load_warning(missing, unexpected)
    if needs_cast:
        model.type(flow_dtype)
    return model
|
|
|
|
def load_text_encoders(config: ModelSpec) -> tuple[HFEmbedder, HFEmbedder]:
    """Create the CLIP and T5 embedders described by ``config``.

    Both embedders share the text-encoder dtype and device index. The device
    index is the ``index`` of the resolved ``text_enc_device``, or ``0`` when
    that device has no index.

    Returns:
        ``(clip, t5)``.
    """
    encoder_dtype = into_dtype(config.text_enc_dtype)
    device_index = into_device(config.text_enc_device).index or 0

    clip_embedder = HFEmbedder(
        config.clip_path,
        max_length=77,
        torch_dtype=encoder_dtype,
        device=device_index,
        is_clip=True,
        quantization_dtype=config.clip_quantization_dtype,
    )
    t5_embedder = HFEmbedder(
        config.text_enc_path,
        max_length=config.text_enc_max_length,
        torch_dtype=encoder_dtype,
        device=device_index,
        quantization_dtype=config.text_enc_quantization_dtype,
    )
    return clip_embedder, t5_embedder
|
|
|
|
def load_autoencoder(config: ModelSpec) -> AutoEncoder:
    """Build the autoencoder, optionally load its checkpoint, and place it.

    With a checkpoint (``config.ae_path``) the module is constructed under the
    ``meta`` device and its weights are loaded non-strictly with
    ``assign=True``; without one it is constructed directly on the target
    device. The module is then moved to the target device and dtype.

    If ``config.ae_quantization_dtype`` is not ``None`` (whatever its value),
    ``recursive_swap_linears`` from ``float8_quantize`` is applied. If
    ``config.offload_vae`` is set, the module is moved to the CPU and the CUDA
    cache is emptied.
    """
    ckpt_path = config.ae_path
    # Resolve the device once through ``into_device`` so a ``None`` ae_device
    # falls back to cuda:0 everywhere, instead of reaching ``torch.device`` or
    # ``str()`` unresolved.
    device = into_device(config.ae_device)
    dtype = into_dtype(config.ae_dtype)

    with torch.device("meta" if ckpt_path is not None else device):
        ae = AutoEncoder(config.ae_params).to(dtype)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    ae.to(device=device, dtype=dtype)

    if config.ae_quantization_dtype is not None:
        # Imported lazily so the dependency is only needed when quantizing.
        from float8_quantize import recursive_swap_linears

        recursive_swap_linears(ae)

    if config.offload_vae:
        ae.to("cpu")
        torch.cuda.empty_cache()
    return ae
|
|
|
|
class LoadedModels(BaseModel):
    """Bundle of the loaded models together with the config that produced them."""

    flow: Flux
    ae: AutoEncoder
    clip: HFEmbedder
    t5: HFEmbedder
    config: ModelSpec

    # The model fields are not pydantic types, so arbitrary types must be allowed.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        use_enum_values=True,
    )
|
|
|
|
def load_models_from_config_path(
    path: str,
) -> LoadedModels:
    """Read the JSON config at ``path`` and load every model it describes.

    The text encoders are loaded first, followed by the flow model and the
    autoencoder.
    """
    return load_models_from_config(load_config_from_path(path))
|
|
|
|
def load_models_from_config(config: ModelSpec) -> LoadedModels:
    """Load every model described by ``config`` and bundle them.

    The text encoders are loaded first, followed by the flow model and the
    autoencoder.
    """
    clip_embedder, t5_embedder = load_text_encoders(config)
    flow_model = load_flow_model(config)
    autoencoder = load_autoencoder(config)
    return LoadedModels(
        flow=flow_model,
        ae=autoencoder,
        clip=clip_embedder,
        t5=t5_embedder,
        config=config,
    )
|
|