Kyle Pearson committed on
Commit
b89e643
·
1 Parent(s): d723e62

Replace dependency, add quantizer support, fix safetensors export, improve error handling, update UI docs

Browse files
Files changed (3) hide show
  1. requirements.txt +1 -1
  2. src/exporter.py +69 -54
  3. src/ui/exporter_tab.py +3 -2
requirements.txt CHANGED
@@ -25,7 +25,7 @@ huggingface-hub>=0.23.0
25
  psutil>=5.9.0
26
 
27
  # Optional: quantization support
28
- optimum-quanto>=0.2.0
29
 
30
  # ZeroGPU support for HuggingFace Spaces
31
  spaces
 
25
  psutil>=5.9.0
26
 
27
  # Optional: quantization support
28
+ torchao>=0.4.0
29
 
30
  # ZeroGPU support for HuggingFace Spaces
31
  spaces
src/exporter.py CHANGED
@@ -8,6 +8,35 @@ from .config import SCRIPT_DIR
8
  from .gpu_decorator import GPU
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  @GPU(duration=180)
12
  def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
13
  """GPU-decorated helper that extracts weights and saves the model."""
@@ -17,64 +46,57 @@ def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
17
  except Exception as e:
18
  print(f" ℹ️ Could not unload LoRAs: {e}")
19
 
 
 
 
 
 
 
20
  merged_state_dict = {}
21
 
22
  # Extract UNet weights
23
  for k, v in pipe.unet.state_dict().items():
24
- merged_state_dict[f"unet.{k}"] = v.contiguous().half()
 
 
 
 
 
25
 
26
  # Extract text encoder weights
27
  if pipe.text_encoder is not None:
28
  for k, v in pipe.text_encoder.state_dict().items():
29
- merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
 
 
 
30
  if pipe.text_encoder_2 is not None:
31
  for k, v in pipe.text_encoder_2.state_dict().items():
32
- merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
 
 
 
33
 
34
  # Extract VAE weights
35
  if pipe.vae is not None:
36
  for k, v in pipe.vae.state_dict().items():
37
- merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
38
-
39
- # Quantize if requested
40
- try:
41
- from optimum.quanto import quantize as quanto_quantize, QTensor
42
- QUANTO_AVAILABLE = True
43
- except ImportError:
44
- QUANTO_AVAILABLE = False
45
-
46
- if quantize and qtype != "none" and QUANTO_AVAILABLE:
47
- class FakeModel(torch.nn.Module):
48
- pass
49
-
50
- fake_model = FakeModel()
51
- fake_model.__dict__.update(merged_state_dict)
52
-
53
- if qtype == "int8":
54
- from optimum.quanto import int8_weight_only
55
- quanto_quantize(fake_model, int8_weight_only())
56
- elif qtype == "int4":
57
- from optimum.quanto import int4_weight_only
58
- quanto_quantize(fake_model, int4_weight_only())
59
- elif qtype == "float8":
60
- from optimum.quanto import float8_dynamic_activation_float8_weight
61
- quanto_quantize(fake_model, float8_dynamic_activation_float8_weight())
62
- else:
63
- raise ValueError(f"Unsupported qtype: {qtype}")
64
-
65
- merged_state_dict = {
66
- k: v.dequantize().half() if isinstance(v, QTensor) else v
67
- for k, v in fake_model.state_dict().items()
68
- }
69
- elif quantize and not QUANTO_AVAILABLE:
70
- raise ImportError("optimum.quanto not installed. Install with: pip install optimum-quanto")
71
 
72
  # Save model
73
  ext = ".bin" if save_format == "bin" else ".safetensors"
74
  prefix = f"{qtype}_" if quantize and qtype != "none" else ""
75
  out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
76
 
77
- if ext == ".bin":
 
 
 
 
 
 
78
  torch.save(merged_state_dict, str(out_path))
79
  else:
80
  save_file(merged_state_dict, str(out_path))
@@ -97,27 +119,20 @@ def export_merged_model(
97
  qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
98
  save_format: Output format - 'safetensors' or 'bin'
99
 
100
- Yields:
101
- Tuple of (status_message, progress_text) at each export stage.
102
-
103
  Returns:
104
- Final yielded tuple of (output_path or None, status message)
105
  """
106
  # Fetch the pipeline at call time — avoids the stale import-by-value problem.
107
  pipe = config.get_pipe()
108
 
109
  if not pipe:
110
- yield None, "⚠️ Please load a pipeline first."
111
- return
112
 
113
  try:
114
  # Validate quantization type
115
  valid_qtypes = ("none", "int8", "int4", "float8")
116
  if qtype not in valid_qtypes:
117
- yield None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
118
- return
119
-
120
- yield "💾 Exporting model...", "Extracting and saving weights..."
121
 
122
  out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
123
 
@@ -128,20 +143,20 @@ def export_merged_model(
128
  else:
129
  msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
130
 
131
- yield str(out_path), msg
132
 
133
  except ImportError as e:
134
- yield None, f"❌ Missing dependency: {str(e)}"
135
  except Exception as e:
136
  import traceback
137
  print(traceback.format_exc())
138
- yield None, f"❌ Export failed: {str(e)}"
139
 
140
 
141
  def get_export_status() -> str:
142
  """Get current export capability status."""
143
  try:
144
- from optimum.quanto import quantize
145
- return "✅ optimum.quanto available for quantization"
146
  except ImportError:
147
- return "ℹ️ Install optimum-quanto for quantization support"
 
8
  from .gpu_decorator import GPU
9
 
10
 
11
+ def _quantize_model(model, qtype: str):
12
+ """Apply torchao quantization to a model using quantize_."""
13
+ from torchao.quantization import quantize_
14
+
15
+ if qtype == "int8":
16
+ from torchao.quantization import Int8WeightOnlyConfig
17
+
18
+ print(" ⚙️ Quantizing with int8_weight_only...")
19
+ config = Int8WeightOnlyConfig()
20
+ quantize_(model, config)
21
+
22
+ elif qtype == "int4":
23
+ from torchao.quantization import Int4WeightOnlyConfig
24
+
25
+ print(" ⚙️ Quantizing with int4_weight_only (group_size=32)...")
26
+ config = Int4WeightOnlyConfig(group_size=32)
27
+ quantize_(model, config)
28
+
29
+ elif qtype == "float8":
30
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
31
+
32
+ print(" ⚙️ Quantizing with float8_dynamic_activation_float8_weight...")
33
+ config = Float8DynamicActivationFloat8WeightConfig()
34
+ quantize_(model, config)
35
+
36
+ else:
37
+ raise ValueError(f"Unsupported qtype: {qtype}. Must be one of: int8, int4, float8")
38
+
39
+
40
  @GPU(duration=180)
41
  def _extract_and_save(pipe, include_lora, quantize, qtype, save_format):
42
  """GPU-decorated helper that extracts weights and saves the model."""
 
46
  except Exception as e:
47
  print(f" ℹ️ Could not unload LoRAs: {e}")
48
 
49
+ # Quantize components in-place before extracting state dicts
50
+ if quantize and qtype != "none":
51
+ _quantize_model(pipe.unet, qtype)
52
+ # torchao quantized tensors cannot be saved with safetensors, use torch.save instead
53
+ # Don't dequantize - keep the quantized format for smaller file size
54
+
55
  merged_state_dict = {}
56
 
57
  # Extract UNet weights
58
  for k, v in pipe.unet.state_dict().items():
59
+ # For quantized tensors, save directly; otherwise convert to half
60
+ if hasattr(v, 'dequantize'):
61
+ # Keep quantized tensor as-is for smaller file size
62
+ merged_state_dict[f"unet.{k}"] = v
63
+ else:
64
+ merged_state_dict[f"unet.{k}"] = v.contiguous().half()
65
 
66
  # Extract text encoder weights
67
  if pipe.text_encoder is not None:
68
  for k, v in pipe.text_encoder.state_dict().items():
69
+ if hasattr(v, 'dequantize'):
70
+ merged_state_dict[f"text_encoder.{k}"] = v
71
+ else:
72
+ merged_state_dict[f"text_encoder.{k}"] = v.contiguous().half()
73
  if pipe.text_encoder_2 is not None:
74
  for k, v in pipe.text_encoder_2.state_dict().items():
75
+ if hasattr(v, 'dequantize'):
76
+ merged_state_dict[f"text_encoder_2.{k}"] = v
77
+ else:
78
+ merged_state_dict[f"text_encoder_2.{k}"] = v.contiguous().half()
79
 
80
  # Extract VAE weights
81
  if pipe.vae is not None:
82
  for k, v in pipe.vae.state_dict().items():
83
+ if hasattr(v, 'dequantize'):
84
+ merged_state_dict[f"first_stage_model.{k}"] = v
85
+ else:
86
+ merged_state_dict[f"first_stage_model.{k}"] = v.contiguous().half()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Save model
89
  ext = ".bin" if save_format == "bin" else ".safetensors"
90
  prefix = f"{qtype}_" if quantize and qtype != "none" else ""
91
  out_path = SCRIPT_DIR / f"merged_{prefix}checkpoint{ext}"
92
 
93
+ if quantize and qtype != "none":
94
+ # torchao quantized tensors are not compatible with safetensors
95
+ # Use torch.save instead which preserves the quantization format
96
+ ext = ".pt"
97
+ out_path = SCRIPT_DIR / f"merged_{qtype}_checkpoint.pt"
98
+ torch.save(merged_state_dict, str(out_path))
99
+ elif ext == ".bin":
100
  torch.save(merged_state_dict, str(out_path))
101
  else:
102
  save_file(merged_state_dict, str(out_path))
 
119
  qtype: Quantization type - 'none', 'int8', 'int4', or 'float8'
120
  save_format: Output format - 'safetensors' or 'bin'
121
 
 
 
 
122
  Returns:
123
+ Tuple of (output_path or None, status message)
124
  """
125
  # Fetch the pipeline at call time — avoids the stale import-by-value problem.
126
  pipe = config.get_pipe()
127
 
128
  if not pipe:
129
+ return None, "⚠️ Please load a pipeline first."
 
130
 
131
  try:
132
  # Validate quantization type
133
  valid_qtypes = ("none", "int8", "int4", "float8")
134
  if qtype not in valid_qtypes:
135
+ return None, f"❌ Invalid quantization type: {qtype}. Must be one of: {valid_qtypes}"
 
 
 
136
 
137
  out_path = _extract_and_save(pipe, include_lora, quantize, qtype, save_format)
138
 
 
143
  else:
144
  msg = f"✅ Merged checkpoint saved: `{out_path}` ({size_gb:.2f} GB)"
145
 
146
+ return str(out_path), msg
147
 
148
  except ImportError as e:
149
+ return None, f"❌ Missing dependency: {str(e)}"
150
  except Exception as e:
151
  import traceback
152
  print(traceback.format_exc())
153
+ return None, f"❌ Export failed: {str(e)}"
154
 
155
 
156
  def get_export_status() -> str:
157
  """Get current export capability status."""
158
  try:
159
+ from torchao.quantization import quantize_, Int4WeightOnlyConfig, Int8WeightOnlyConfig, Float8DynamicActivationFloat8WeightConfig
160
+ return "✅ torchao available for quantization"
161
  except ImportError:
162
+ return "ℹ️ Install torchao for quantization support: pip install torchao"
src/ui/exporter_tab.py CHANGED
@@ -62,8 +62,9 @@ def create_exporter_tab():
62
  <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
63
  <strong>ℹ️ About Quantization:</strong>
64
  <p style="font-size: 0.9em; margin: 8px 0;">
65
- Reduces model size by lowering precision. Int8 is typically
66
- lossless for inference while cutting size in half.
 
67
  </p>
68
  </div>
69
  """)
 
62
  <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
63
  <strong>ℹ️ About Quantization:</strong>
64
  <p style="font-size: 0.9em; margin: 8px 0;">
65
+ Reduces model size by lowering precision using torchao.
66
+ Int8 is typically lossless for inference while cutting size in half.
67
+ Int4 provides maximum compression with minimal quality loss.
68
  </p>
69
  </div>
70
  """)