# Hugging Face Space ("Running on Zero") — Kyle Pearson
# Commit: Replace dependency, add quantizer support, fix safetensors export,
# improve error handling, update UI docs
b89e643 | """Model export functionality for SDXL Model Merger.""" | |
| import torch | |
| from safetensors.torch import save_file | |
| from . import config | |
| from .config import SCRIPT_DIR | |
| from .gpu_decorator import GPU | |
| def _quantize_model(model, qtype: str): | |
| """Apply torchao quantization to a model using quantize_.""" | |
| from torchao.quantization import quantize_ | |
| if qtype == "int8": | |
| from torchao.quantization import Int8WeightOnlyConfig | |
| print(" ⚙️ Quantizing with int8_weight_only...") | |
| config = Int8WeightOnlyConfig() | |
| quantize_(model, config) | |
| elif qtype == "int4": | |
| from torchao.quantization import Int4WeightOnlyConfig | |
| print(" ⚙️ Quantizing with int4_weight_only (group_size=32)...") | |
| config = Int4WeightOnlyConfig(group_size=32) | |
| quantize_(model, config) | |
| elif qtype == "float8": | |
| from torchao.quantization import Float8DynamicActivationFloat8WeightConfig | |
| print(" ⚙️ Quantizing with float8_dynamic_activation_float8_weight...") | |
| config = Float8DynamicActivationFloat8WeightConfig() | |
| quantize_(model, config) | |
| else: | |
| raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8") | |
def _collect_weights(dest: dict, component, prefix: str) -> None:
    """Copy *component*'s state dict into *dest* with keys ``{prefix}.{name}``.

    Tensors that expose ``dequantize`` (torchao quantized tensors) are kept
    as-is to preserve the smaller quantized footprint; everything else is
    stored as contiguous fp16. A ``None`` component is skipped.
    """
    if component is None:
        return
    for name, tensor in component.state_dict().items():
        if hasattr(tensor, "dequantize"):
            dest[f"{prefix}.{name}"] = tensor
        else:
            dest[f"{prefix}.{name}"] = tensor.contiguous().half()


def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
    """Extract merged weights from *pipe* and save them to disk.

    NOTE(review): the original docstring called this a "GPU-decorated helper",
    but no ``@GPU`` decorator is applied here even though ``GPU`` is imported
    at module level — confirm whether it should wrap this function.

    Args:
        pipe: Loaded SDXL pipeline (unet / text_encoder / text_encoder_2 / vae).
        include_lora: If True, fuse LoRAs by unloading the adapter weights first.
        quantize: Whether to quantize the UNet before extraction.
        qtype: Quantization type ('none', 'int8', 'int4', 'float8').
        save_format: 'safetensors' or 'bin' (ignored when quantized — see below).

    Returns:
        pathlib.Path of the written checkpoint.
    """
    if include_lora:
        try:
            pipe.unload_lora_weights()
        except Exception as e:
            # Best-effort: nothing fatal about LoRAs already being unloaded.
            print(f" ℹ️ Could not unload LoRAs: {e}")

    # Quantize the UNet in place *before* extracting state dicts.
    quantized = quantize and qtype != "none"
    if quantized:
        _quantize_model(pipe.unet, qtype)

    merged_state_dict: dict = {}
    _collect_weights(merged_state_dict, pipe.unet, "unet")
    _collect_weights(merged_state_dict, pipe.text_encoder, "text_encoder")
    _collect_weights(merged_state_dict, pipe.text_encoder_2, "text_encoder_2")
    # VAE uses the legacy checkpoint key prefix.
    _collect_weights(merged_state_dict, pipe.vae, "first_stage_model")

    if quantized:
        # torchao quantized tensors are not compatible with safetensors;
        # torch.save preserves the quantization format for a smaller file.
        out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
        torch.save(merged_state_dict, str(out_path))
    elif save_format == "bin":
        out_path = SCRIPT_DIR / "merged_checkpoint.bin"
        torch.save(merged_state_dict, str(out_path))
    else:
        out_path = SCRIPT_DIR / "merged_checkpoint.safetensors"
        save_file(merged_state_dict, str(out_path))
    return out_path
def export_merged_model(
    include_lora: bool,
    quantize: bool,
    qtype: str,
    save_format: str = "safetensors",
):
    """
    Export the merged pipeline model with optional LoRA baking and quantization.

    Args:
        include_lora: Whether to include fused LoRAs in export
        quantize: Whether to apply quantization
        qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
        save_format: Output format - 'safetensors' or 'bin'

    Returns:
        Tuple of (output_path or None, status message)
    """
    # Fetch the pipeline at call time — avoids the stale import-by-value problem.
    pipe = config.get_pipe()
    # `is None` rather than truthiness: a loaded pipeline object could in
    # principle be falsy, and get_pipe() signals "not loaded" with None.
    if pipe is None:
        return None, "⚠️ Please load a pipeline first."

    # Validate quantization type up front — no cleanup needed on this path.
    valid_qtypes = ("none", "int8", "int4", "float8")
    if qtype not in valid_qtypes:
        return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"

    try:
        out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
        size_gb = out_path.stat().st_size / 1024**3
        if quantize and qtype != "none":
            msg = f"✅ Quantized checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        else:
            msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
        return str(out_path), msg
    except ImportError as e:
        # Typically torchao missing when quantization was requested.
        return None, f"❌ Missing dependency: {str(e)}"
    except Exception as e:
        import traceback
        print(traceback.format_exc())
        return None, f"❌ Export failed: {str(e)}"
def get_export_status() -> str:
    """Get current export capability status."""
    try:
        # Probe every torchao symbol the export path relies on.
        from torchao.quantization import (
            quantize_,
            Int4WeightOnlyConfig,
            Int8WeightOnlyConfig,
            Float8DynamicActivationFloat8WeightConfig,
        )
    except ImportError:
        return "ℹ️ Install torchao for quantization support: pip install torchao"
    return "✅ torchao available for quantization"