"""Model export functionality for SDXL Model Merger."""

import traceback

import torch
from safetensors.torch import save_file

from . import config
from .config import SCRIPT_DIR
from .gpu_decorator import GPU

# Quantization types accepted by export_merged_model ("none" disables quantization).
_VALID_QTYPES = ("none", "int8", "int4", "float8")


def _quantize_model(model, qtype: str) -> None:
    """Apply torchao weight quantization to ``model`` in place via ``quantize_``.

    Args:
        model: Module to quantize; it is modified in place.
        qtype: One of ``"int8"``, ``"int4"`` or ``"float8"``.

    Raises:
        ImportError: If torchao (or the requested config class) is not installed.
        ValueError: If ``qtype`` is not a supported quantization type.
    """
    # Imported lazily so this module can be imported without torchao installed.
    from torchao.quantization import quantize_

    # Named ``quant_config`` so it does not shadow the module-level ``config`` import.
    if qtype == "int8":
        from torchao.quantization import Int8WeightOnlyConfig

        print(" ⚙️ Quantizing with int8_weight_only...")
        quant_config = Int8WeightOnlyConfig()
    elif qtype == "int4":
        from torchao.quantization import Int4WeightOnlyConfig

        print(" ⚙️ Quantizing with int4_weight_only (group_size=32)...")
        quant_config = Int4WeightOnlyConfig(group_size=32)
    elif qtype == "float8":
        from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

        print(" ⚙️ Quantizing with float8_dynamic_activation_float8_weight...")
        quant_config = Float8DynamicActivationFloat8WeightConfig()
    else:
        raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8")

    quantize_(model, quant_config)


def _is_quantized_tensor(tensor) -> bool:
    """Return True if ``tensor`` must be saved untouched rather than cast to fp16.

    ``hasattr(tensor, "dequantize")`` cannot be used for this: every
    ``torch.Tensor`` has a ``dequantize`` method. Instead, treat any
    ``torch.Tensor`` subclass (torchao represents quantized weights this way)
    and any native PyTorch quantized tensor (``is_quantized``) as quantized.
    """
    return type(tensor) not in (torch.Tensor, torch.nn.Parameter) or tensor.is_quantized


def _prepare_tensor(tensor):
    """Return ``tensor`` in the form it should be serialized.

    Quantized tensors are returned unchanged to keep their compact format.
    Plain tensors are made contiguous (safetensors rejects non-contiguous
    tensors). Floating-point ones are cast to fp16; integer/bool tensors keep
    their dtype, because casting them to half would not preserve their values.
    """
    if _is_quantized_tensor(tensor):
        return tensor
    tensor = tensor.contiguous()
    return tensor.half() if tensor.is_floating_point() else tensor


def _collect_weights(module, prefix: str, out: dict) -> None:
    """Copy ``module``'s state dict into ``out``, prefixing each key with ``prefix.``."""
    for key, tensor in module.state_dict().items():
        out[f"{prefix}.{key}"] = _prepare_tensor(tensor)


@GPU(duration=180)
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
    """Extract the pipeline's weights and write them to a single checkpoint file.

    Wrapped by the project's ``GPU`` decorator with ``duration=180``
    (presumably a GPU-time allocation in seconds — confirm in gpu_decorator).

    Args:
        pipe: Loaded pipeline exposing ``unet``, ``text_encoder``,
            ``text_encoder_2`` and ``vae`` attributes.
        include_lora: If true, call ``pipe.unload_lora_weights()`` before export.
        quantize: Whether to quantize the UNet (ignored when ``qtype == "none"``).
        qtype: torchao quantization type, or ``"none"``.
        save_format: ``"bin"`` for ``torch.save``; any other value selects
            safetensors. Ignored for quantized exports, which are always ``.pt``.

    Returns:
        Path of the written checkpoint inside ``SCRIPT_DIR``.
    """
    if include_lora:
        # NOTE(review): this mutates the shared pipeline. Presumably LoRAs were
        # fused beforehand, so their effect is already in the base weights and only
        # the adapter modules are removed here — confirm.
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            # Best effort: export continues even if unloading fails.
            print(f" ℹ️ Could not unload LoRAs: {e}")

    quantized = quantize and qtype != "none"
    if quantized:
        # Quantizes the shared pipeline's UNet in place; it remains quantized
        # after this export returns.
        _quantize_model(pipe.unet, qtype)

    merged_state_dict = {}
    _collect_weights(pipe.unet, "unet", merged_state_dict)
    if pipe.text_encoder is not None:
        _collect_weights(pipe.text_encoder, "text_encoder", merged_state_dict)
    if pipe.text_encoder_2 is not None:
        _collect_weights(pipe.text_encoder_2, "text_encoder_2", merged_state_dict)
    if pipe.vae is not None:
        # VAE weights are stored under the "first_stage_model." key prefix.
        _collect_weights(pipe.vae, "first_stage_model", merged_state_dict)

    if quantized:
        # torchao quantized tensors cannot be written by safetensors; torch.save
        # keeps the quantized format (and the smaller file size).
        out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
        torch.save(merged_state_dict, str(out_path))
    elif save_format == "bin":
        out_path = SCRIPT_DIR / "merged_checkpoint.bin"
        torch.save(merged_state_dict, str(out_path))
    else:
        out_path = SCRIPT_DIR / "merged_checkpoint.safetensors"
        save_file(merged_state_dict, str(out_path))
    return out_path


def export_merged_model(
    include_lora: bool,
    quantize: bool,
    qtype: str,
    save_format: str = "safetensors",
):
    """
    Export the merged pipeline model with optional LoRA baking and quantization.

    Args:
        include_lora: Whether to include fused LoRAs in export
        quantize: Whether to apply quantization
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
        save_format: Output format - 'safetensors' or 'bin'

    Returns:
        Tuple of (output path as str, or None on failure; status message)
    """
    # Fetch the pipeline at call time — avoids the stale import-by-value problem.
    pipe = config.get_pipe()
    if not pipe:
        return None, "⚠️ Please load a pipeline first."

    # Validate before entering the GPU-decorated helper.
    if qtype not in _VALID_QTYPES:
        return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {_VALID_QTYPES}"

    try:
        out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
        size_gb = out_path.stat().st_size / 1024**3
    except ImportError as e:
        # Most likely torchao missing when quantization was requested.
        return None, f"❌ Missing dependency: {e}"
    except Exception as e:
        print(traceback.format_exc())
        return None, f"❌ Export failed: {e}"

    if quantize and qtype != "none":
        msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
    else:
        msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
    return str(out_path), msg


def get_export_status() -> str:
    """Report whether the torchao quantization APIs used by this module can be imported."""
    try:
        from torchao.quantization import (  # noqa: F401 - availability probe only
            Float8DynamicActivationFloat8WeightConfig,
            Int4WeightOnlyConfig,
            Int8WeightOnlyConfig,
            quantize_,
        )
    except ImportError:
        return "ℹ️ Install torchao for quantization support: pip install torchao"
    return "✅ torchao available for quantization"