Spaces:
Sleeping
Sleeping
Kyle Pearson committed on
Commit ·
b89e643
1
Parent(s): d723e62
Replace dependency, add quantizer support, fix safetensors export, improve error handling, update UI docs
Browse files
- requirements.txt +1 -1
- src/exporter.py +69 -54
- src/ui/exporter_tab.py +3 -2
requirements.txt
CHANGED
|
@@ -25,7 +25,7 @@ huggingface-hub>=0.23.0
|
|
| 25 |
psutil>=5.9.0
|
| 26 |
|
| 27 |
# Optional: quantization support
|
| 28 |
-
|
| 29 |
|
| 30 |
# ZeroGPU support for HuggingFace Spaces
|
| 31 |
spaces
|
|
|
|
| 25 |
psutil>=5.9.0
|
| 26 |
|
| 27 |
# Optional: quantization support
|
| 28 |
+
torchao>=0.4.0
|
| 29 |
|
| 30 |
# ZeroGPU support for HuggingFace Spaces
|
| 31 |
spaces
|
src/exporter.py
CHANGED
|
@@ -8,6 +8,35 @@ from .config import SCRIPT_DIR
|
|
| 8 |
from .gpu_decorator import GPU
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
@GPU(duration=180)
|
| 12 |
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
|
| 13 |
"""GPU-decorated helper that extracts weights and saves the model."""
|
|
@@ -17,64 +46,57 @@ def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
|
|
| 17 |
except Exception as e:
|
| 18 |
print(f" ℹ️ Could not unload LoRAs: {e}")
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
merged_state_dict = {}
|
| 21 |
|
| 22 |
# Extract UNet weights
|
| 23 |
for k, v in pipe.unet.state_dict().items():
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Extract text encoder weights
|
| 27 |
if pipe.text_encoder is not None:
|
| 28 |
for k, v in pipe.text_encoder.state_dict().items():
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
| 30 |
if pipe.text_encoder_2 is not None:
|
| 31 |
for k, v in pipe.text_encoder_2.state_dict().items():
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# Extract VAE weights
|
| 35 |
if pipe.vae is not None:
|
| 36 |
for k, v in pipe.vae.state_dict().items():
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
from optimum.quanto import quantize as quanto_quantize, QTensor
|
| 42 |
-
QUANTO_AVAILABLE = True
|
| 43 |
-
except ImportError:
|
| 44 |
-
QUANTO_AVAILABLE = False
|
| 45 |
-
|
| 46 |
-
if quantize and qtype != "none" and QUANTO_AVAILABLE:
|
| 47 |
-
class FakeModel(torch.nn.Module):
|
| 48 |
-
pass
|
| 49 |
-
|
| 50 |
-
fake_model = FakeModel()
|
| 51 |
-
fake_model.__dict__.update(merged_state_dict)
|
| 52 |
-
|
| 53 |
-
if qtype == "int8":
|
| 54 |
-
from optimum.quanto import int8_weight_only
|
| 55 |
-
quanto_quantize(fake_model, int8_weight_only())
|
| 56 |
-
elif qtype == "int4":
|
| 57 |
-
from optimum.quanto import int4_weight_only
|
| 58 |
-
quanto_quantize(fake_model, int4_weight_only())
|
| 59 |
-
elif qtype == "float8":
|
| 60 |
-
from optimum.quanto import float8_dynamic_activation_float8_weight
|
| 61 |
-
quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
|
| 62 |
-
else:
|
| 63 |
-
raise ValueError(f"Unsupported qtype: {qtype}")
|
| 64 |
-
|
| 65 |
-
merged_state_dict = {
|
| 66 |
-
k: v.dequantize().half() if isinstance(v, QTensor) else v
|
| 67 |
-
for k, v in fake_model.state_dict().items()
|
| 68 |
-
}
|
| 69 |
-
elif quantize and not QUANTO_AVAILABLE:
|
| 70 |
-
raise ImportError("optimum.quanto not installed. Install with: pip install optimum-quanto")
|
| 71 |
|
| 72 |
# Save model
|
| 73 |
ext = ".bin" if save_format == "bin" else ".safetensors"
|
| 74 |
prefix = f"{qtype}_" if quantize and qtype != "none" else ""
|
| 75 |
out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
|
| 76 |
|
| 77 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
torch.save(merged_state_dict, str(out_path))
|
| 79 |
else:
|
| 80 |
save_file(merged_state_dict, str(out_path))
|
|
@@ -97,27 +119,20 @@ def export_merged_model(
|
|
| 97 |
qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
|
| 98 |
save_format: Output format - 'safetensors' or 'bin'
|
| 99 |
|
| 100 |
-
Yields:
|
| 101 |
-
Tuple of (status_message, progress_text) at each export stage.
|
| 102 |
-
|
| 103 |
Returns:
|
| 104 |
-
|
| 105 |
"""
|
| 106 |
# Fetch the pipeline at call time — avoids the stale import-by-value problem.
|
| 107 |
pipe = config.get_pipe()
|
| 108 |
|
| 109 |
if not pipe:
|
| 110 |
-
|
| 111 |
-
return
|
| 112 |
|
| 113 |
try:
|
| 114 |
# Validate quantization type
|
| 115 |
valid_qtypes = ("none", "int8", "int4", "float8")
|
| 116 |
if qtype not in valid_qtypes:
|
| 117 |
-
|
| 118 |
-
return
|
| 119 |
-
|
| 120 |
-
yield "💾 Exporting model...", "Extracting and saving weights..."
|
| 121 |
|
| 122 |
out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
|
| 123 |
|
|
@@ -128,20 +143,20 @@ def export_merged_model(
|
|
| 128 |
else:
|
| 129 |
msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
|
| 130 |
|
| 131 |
-
|
| 132 |
|
| 133 |
except ImportError as e:
|
| 134 |
-
|
| 135 |
except Exception as e:
|
| 136 |
import traceback
|
| 137 |
print(traceback.format_exc())
|
| 138 |
-
|
| 139 |
|
| 140 |
|
| 141 |
def get_export_status() -> str:
|
| 142 |
"""Get current export capability status."""
|
| 143 |
try:
|
| 144 |
-
from
|
| 145 |
-
return "✅
|
| 146 |
except ImportError:
|
| 147 |
-
return "ℹ️ Install
|
|
|
|
| 8 |
from .gpu_decorator import GPU
|
| 9 |
|
| 10 |
|
| 11 |
+
def _quantize_model(model, qtype: str):
|
| 12 |
+
"""Apply torchao quantization to a model using quantize_."""
|
| 13 |
+
from torchao.quantization import quantize_
|
| 14 |
+
|
| 15 |
+
if qtype == "int8":
|
| 16 |
+
from torchao.quantization import Int8WeightOnlyConfig
|
| 17 |
+
|
| 18 |
+
print(" ⚙️ Quantizing with int8_weight_only...")
|
| 19 |
+
config = Int8WeightOnlyConfig()
|
| 20 |
+
quantize_(model, config)
|
| 21 |
+
|
| 22 |
+
elif qtype == "int4":
|
| 23 |
+
from torchao.quantization import Int4WeightOnlyConfig
|
| 24 |
+
|
| 25 |
+
print(" ⚙️ Quantizing with int4_weight_only (group_size=32)...")
|
| 26 |
+
config = Int4WeightOnlyConfig(group_size=32)
|
| 27 |
+
quantize_(model, config)
|
| 28 |
+
|
| 29 |
+
elif qtype == "float8":
|
| 30 |
+
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
| 31 |
+
|
| 32 |
+
print(" ⚙️ Quantizing with float8_dynamic_activation_float8_weight...")
|
| 33 |
+
config = Float8DynamicActivationFloat8WeightConfig()
|
| 34 |
+
quantize_(model, config)
|
| 35 |
+
|
| 36 |
+
else:
|
| 37 |
+
raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
@GPU(duration=180)
|
| 41 |
def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
|
| 42 |
"""GPU-decorated helper that extracts weights and saves the model."""
|
|
|
|
| 46 |
except Exception as e:
|
| 47 |
print(f" ℹ️ Could not unload LoRAs: {e}")
|
| 48 |
|
| 49 |
+
# Quantize components in-place before extracting state dicts
|
| 50 |
+
if quantize and qtype != "none":
|
| 51 |
+
_quantize_model(pipe.unet, qtype)
|
| 52 |
+
# torchao quantized tensors cannot be saved with safetensors, use torch.save instead
|
| 53 |
+
# Don't dequantize - keep the quantized format for smaller file size
|
| 54 |
+
|
| 55 |
merged_state_dict = {}
|
| 56 |
|
| 57 |
# Extract UNet weights
|
| 58 |
for k, v in pipe.unet.state_dict().items():
|
| 59 |
+
# For quantized tensors, save directly; otherwise convert to half
|
| 60 |
+
if hasattr(v, 'dequantize'):
|
| 61 |
+
# Keep quantized tensor as-is for smaller file size
|
| 62 |
+
merged_state_dict[f"unet.{k}"] = v
|
| 63 |
+
else:
|
| 64 |
+
merged_state_dict[f"unet.{k}"] = v.contiguous().half()
|
| 65 |
|
| 66 |
# Extract text encoder weights
|
| 67 |
if pipe.text_encoder is not None:
|
| 68 |
for k, v in pipe.text_encoder.state_dict().items():
|
| 69 |
+
if hasattr(v, 'dequantize'):
|
| 70 |
+
merged_state_dict[f"text_encoder.{k}"] = v
|
| 71 |
+
else:
|
| 72 |
+
merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
|
| 73 |
if pipe.text_encoder_2 is not None:
|
| 74 |
for k, v in pipe.text_encoder_2.state_dict().items():
|
| 75 |
+
if hasattr(v, 'dequantize'):
|
| 76 |
+
merged_state_dict[f"text_encoder_2.{k}"] = v
|
| 77 |
+
else:
|
| 78 |
+
merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
|
| 79 |
|
| 80 |
# Extract VAE weights
|
| 81 |
if pipe.vae is not None:
|
| 82 |
for k, v in pipe.vae.state_dict().items():
|
| 83 |
+
if hasattr(v, 'dequantize'):
|
| 84 |
+
merged_state_dict[f"first_stage_model.{k}"] = v
|
| 85 |
+
else:
|
| 86 |
+
merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
# Save model
|
| 89 |
ext = ".bin" if save_format == "bin" else ".safetensors"
|
| 90 |
prefix = f"{qtype}_" if quantize and qtype != "none" else ""
|
| 91 |
out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
|
| 92 |
|
| 93 |
+
if quantize and qtype != "none":
|
| 94 |
+
# torchao quantized tensors are not compatible with safetensors
|
| 95 |
+
# Use torch.save instead which preserves the quantization format
|
| 96 |
+
ext = ".pt"
|
| 97 |
+
out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
|
| 98 |
+
torch.save(merged_state_dict, str(out_path))
|
| 99 |
+
elif ext == ".bin":
|
| 100 |
torch.save(merged_state_dict, str(out_path))
|
| 101 |
else:
|
| 102 |
save_file(merged_state_dict, str(out_path))
|
|
|
|
| 119 |
qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
|
| 120 |
save_format: Output format - 'safetensors' or 'bin'
|
| 121 |
|
|
|
|
|
|
|
|
|
|
| 122 |
Returns:
|
| 123 |
+
Tuple of (output_path or None, status message)
|
| 124 |
"""
|
| 125 |
# Fetch the pipeline at call time — avoids the stale import-by-value problem.
|
| 126 |
pipe = config.get_pipe()
|
| 127 |
|
| 128 |
if not pipe:
|
| 129 |
+
return None, "⚠️ Please load a pipeline first."
|
|
|
|
| 130 |
|
| 131 |
try:
|
| 132 |
# Validate quantization type
|
| 133 |
valid_qtypes = ("none", "int8", "int4", "float8")
|
| 134 |
if qtype not in valid_qtypes:
|
| 135 |
+
return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
|
| 138 |
|
|
|
|
| 143 |
else:
|
| 144 |
msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
|
| 145 |
|
| 146 |
+
return str(out_path), msg
|
| 147 |
|
| 148 |
except ImportError as e:
|
| 149 |
+
return None, f"❌ Missing dependency: {str(e)}"
|
| 150 |
except Exception as e:
|
| 151 |
import traceback
|
| 152 |
print(traceback.format_exc())
|
| 153 |
+
return None, f"❌ Export failed: {str(e)}"
|
| 154 |
|
| 155 |
|
| 156 |
def get_export_status() -> str:
    """Get current export capability status.

    Returns:
        A one-line status string: a success message when torchao and every
        name the exporter needs from ``torchao.quantization`` are available,
        an install hint when torchao cannot be imported, or an upgrade hint
        when torchao imports but lacks some of those names.
    """
    # Names the quantization helper pulls from torchao.quantization.
    required_names = (
        "quantize_",
        "Int4WeightOnlyConfig",
        "Int8WeightOnlyConfig",
        "Float8DynamicActivationFloat8WeightConfig",
    )

    try:
        import torchao.quantization as torchao_quantization
    except ImportError:
        return "ℹ️ Install torchao for quantization support: pip install torchao"

    missing = [name for name in required_names if not hasattr(torchao_quantization, name)]
    if missing:
        # torchao is present but lacks the config-class API; installing it
        # again would not help, so point at an upgrade instead.
        return (
            "ℹ️ Installed torchao is missing "
            f"{', '.join(missing)}; upgrade with: pip install --upgrade torchao"
        )

    return "✅ torchao available for quantization"
|
src/ui/exporter_tab.py
CHANGED
|
@@ -62,8 +62,9 @@ def create_exporter_tab():
|
|
| 62 |
<div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
|
| 63 |
<strong>ℹ️ About Quantization:</strong>
|
| 64 |
<p style="font-size: 0.9em; margin: 8px 0;">
|
| 65 |
-
Reduces model size by lowering precision
|
| 66 |
-
lossless for inference while cutting size in half.
|
|
|
|
| 67 |
</p>
|
| 68 |
</div>
|
| 69 |
""")
|
|
|
|
| 62 |
<div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
|
| 63 |
<strong>ℹ️ About Quantization:</strong>
|
| 64 |
<p style="font-size: 0.9em; margin: 8px 0;">
|
| 65 |
+
Reduces model size by lowering precision using torchao.
|
| 66 |
+
Int8 is typically lossless for inference while cutting size in half.
|
| 67 |
+
Int4 provides maximum compression with minimal quality loss.
|
| 68 |
</p>
|
| 69 |
</div>
|
| 70 |
""")
|