# SDXL-Model-Merger / src/exporter.py
# Author: Kyle Pearson
# Commit b89e643: Replace dependency, add quantizer support, fix safetensors
# export, improve error handling, update UI docs
"""Model export functionality for SDXL Model Merger."""
import torch
from safetensors.torch import save_file
from . import config
from .config import SCRIPT_DIR
from .gpu_decorator import GPU
def _quantize_model(model, qtype: str):
"""Apply torchao quantization to a model using quantize_."""
from torchao.quantization import quantize_
if qtype == "int8":
from torchao.quantization import Int8WeightOnlyConfig
print(" ⚙️ Quantizing with int8_weight_only...")
config = Int8WeightOnlyConfig()
quantize_(model, config)
elif qtype == "int4":
from torchao.quantization import Int4WeightOnlyConfig
print(" ⚙️ Quantizing with int4_weight_only (group_size=32)...")
config = Int4WeightOnlyConfig(group_size=32)
quantize_(model, config)
elif qtype == "float8":
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
print(" ⚙️ Quantizing with float8_dynamic_activation_float8_weight...")
config = Float8DynamicActivationFloat8WeightConfig()
quantize_(model, config)
else:
raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8")
@GPU(duration=180)
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
    """GPU-decorated helper that extracts weights and saves the merged model.

    Args:
        pipe: Loaded SDXL pipeline exposing unet, text_encoder,
            text_encoder_2 and vae components.
        include_lora: If True, unload LoRA adapters before extraction.
            NOTE(review): this assumes the caller already fused the LoRAs
            into the base weights, so unloading leaves them baked in —
            confirm upstream.
        quantize: Whether to quantize the UNet before extraction.
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'.
        save_format: 'safetensors' or 'bin'. Ignored when quantizing:
            torchao quantized tensors are not safetensors-compatible, so
            quantized checkpoints are always written with torch.save as .pt.

    Returns:
        pathlib.Path to the written checkpoint file.
    """
    if include_lora:
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            # Best-effort: pipelines without loaded adapters may raise here.
            print(f" ℹ️ Could not unload LoRAs: {e}")
    # Quantize components in-place before extracting state dicts.
    quantized = quantize and qtype != "none"
    if quantized:
        _quantize_model(pipe.unet, qtype)
    merged_state_dict = {}

    def _collect(component, prefix):
        """Copy component's state dict into merged_state_dict under prefix."""
        if component is None:
            return
        for key, tensor in component.state_dict().items():
            if hasattr(tensor, 'dequantize'):
                # torchao quantized tensor: keep as-is for smaller file size.
                merged_state_dict[f"{prefix}.{key}"] = tensor
            else:
                merged_state_dict[f"{prefix}.{key}"] = tensor.contiguous().half()

    _collect(pipe.unet, "unet")
    _collect(pipe.text_encoder, "text_encoder")
    _collect(pipe.text_encoder_2, "text_encoder_2")
    # The VAE uses the legacy single-file checkpoint key.
    _collect(pipe.vae, "first_stage_model")
    # Save model.
    if quantized:
        # torchao quantized tensors are not compatible with safetensors;
        # torch.save preserves the quantization format.
        out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
        torch.save(merged_state_dict, str(out_path))
    elif save_format == "bin":
        out_path = SCRIPT_DIR / "merged_checkpoint.bin"
        torch.save(merged_state_dict, str(out_path))
    else:
        out_path = SCRIPT_DIR / "merged_checkpoint.safetensors"
        save_file(merged_state_dict, str(out_path))
    return out_path
def export_merged_model(
    include_lora: bool,
    quantize: bool,
    qtype: str,
    save_format: str = "safetensors",
):
    """
    Export the merged pipeline model with optional LoRA baking and quantization.

    Args:
        include_lora: Whether to include fused LoRAs in export
        quantize: Whether to apply quantization
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
        save_format: Output format - 'safetensors' or 'bin'

    Returns:
        Tuple of (output_path or None, status message)
    """
    # Look up the pipeline at call time rather than import time, so we
    # always see the most recently loaded one.
    pipe = config.get_pipe()
    if not pipe:
        return None, "⚠️ Please load a pipeline first."
    try:
        # Reject unknown quantization types up front.
        valid_qtypes = ("none", "int8", "int4", "float8")
        if qtype not in valid_qtypes:
            return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
        out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
        size_gb = out_path.stat().st_size / 1024**3
        if quantize and qtype != "none":
            msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        else:
            msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        return str(out_path), msg
    except ImportError as e:
        # Typically torchao missing when quantization was requested.
        return None, f"❌ Missing dependency: {str(e)}"
    except Exception as e:
        # Surface the full traceback in the log but return a short message
        # suitable for the UI.
        import traceback
        print(traceback.format_exc())
        return None, f"❌ Export failed: {str(e)}"
def get_export_status() -> str:
    """Get current export capability status."""
    try:
        # Probe for the exact symbols the exporter uses, so an outdated
        # torchao missing any of them also reports as unavailable.
        from torchao.quantization import (  # noqa: F401
            Float8DynamicActivationFloat8WeightConfig,
            Int4WeightOnlyConfig,
            Int8WeightOnlyConfig,
            quantize_,
        )
    except ImportError:
        return "ℹ️ Install torchao for quantization support: pip install torchao"
    return "✅ torchao available for quantization"