HaoxingChen committed on
Commit
ae62f19
·
verified ·
1 Parent(s): ae429f4

update repo

Browse files
configuration_llada2uni_moe.py CHANGED
@@ -111,6 +111,9 @@ class LLaDA2MoeConfig(PretrainedConfig):
111
  self.moe_intermediate_size = moe_intermediate_size
112
  self.first_k_dense_replace = first_k_dense_replace
113
  self.output_router_logits = output_router_logits
 
 
 
114
 
115
  super().__init__(
116
  pad_token_id=pad_token_id,
 
111
  self.moe_intermediate_size = moe_intermediate_size
112
  self.first_k_dense_replace = first_k_dense_replace
113
  self.output_router_logits = output_router_logits
114
+
115
+ # FP8 quantization flag — set to True to use FP8Linear for experts
116
+ self.use_fp8_experts = kwargs.pop("use_fp8_experts", False)
117
 
118
  super().__init__(
119
  pad_token_id=pad_token_id,
convert_experts_to_fp8.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert MoE expert weights from bf16 to fp8 with block-wise quantization.
4
+
5
+ Professional FP8 quantization following the same approach as Qwen3.5-FP8:
6
+ - Block-wise quantization with per-block scale (weight_scale_inv)
7
+ - Only quantize expert Linear weight tensors (gate_proj, up_proj, down_proj)
8
+ - Keep all other weights in bf16: embedding, lm_head, routing gates, layernorms,
9
+ attention projections, shared experts
10
+ - Stores weight_scale_inv alongside each quantized weight
11
+
12
+ Usage:
13
+ python convert_experts_to_fp8.py \
14
+ --input_dir /path/to/UniLLaDA \
15
+ --output_dir /path/to/UniLLaDA-FP8
16
+
17
+ Then load with:
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ output_dir, device_map="cuda", torch_dtype="bfloat16", trust_remote_code=True
20
+ )
21
+ # config.json will have use_fp8_experts=true, so experts use FP8Linear automatically
22
+ """
23
+
24
+ import os
25
+ import re
26
+ import json
27
+ import argparse
28
+ from collections import OrderedDict
29
+
30
+ import torch
31
+ from safetensors.torch import load_file, save_file
32
+ from tqdm import tqdm
33
+
34
+ FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
35
+ DEFAULT_BLOCK_SIZE = 128
36
+
37
+
38
def is_expert_weight(name: str) -> bool:
    """Tell whether *name* is a routed-expert projection weight.

    Only keys of the form
    ``model.layers.<i>.mlp.experts.<j>.{gate_proj,up_proj,down_proj}.weight``
    qualify. Shared-expert weights and every other tensor name are rejected.
    """
    routed_expert_pattern = (
        r"model\.layers\.\d+\.mlp\.experts\.\d+\.(gate_proj|up_proj|down_proj)\.weight$"
    )
    # re.match anchors at the start of the key; the trailing "$" anchors the end.
    return re.match(routed_expert_pattern, name) is not None
44
+
45
+
46
def quantize_blockwise(tensor: torch.Tensor, block_size: int = DEFAULT_BLOCK_SIZE):
    """Quantize a 2D weight tensor to FP8 (e4m3fn) with one scale per block.

    The weight is tiled into ``block_size x block_size`` blocks (zero-padded
    on the bottom/right when a dimension is not a multiple of ``block_size``).
    Each block gets ``scale = absmax / FP8_MAX`` and is stored as
    ``weight / scale``; dequantization is ``fp8.to(dtype) * scale``.

    Args:
        tensor: Weight tensor of shape (out_features, in_features).
        block_size: Edge length of the square quantization blocks.

    Returns:
        fp8_tensor: Quantized tensor (float8_e4m3fn), same shape as the input.
        scale_inv: Per-block dequantization scale (bfloat16), shape
            (ceil(out / block_size), ceil(in / block_size)). Despite the
            ``_inv`` name this is the multiplier applied when dequantizing.

    Raises:
        ValueError: If ``tensor`` is not 2D or ``block_size`` is not positive.
    """
    if tensor.dim() != 2:
        raise ValueError(f"Expected a 2D weight tensor, got {tensor.dim()}D")
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    weight = tensor.float()
    out_f, in_f = weight.shape
    bs = block_size

    n_bo = (out_f + bs - 1) // bs
    n_bi = (in_f + bs - 1) // bs

    # Zero-pad to a whole number of blocks. F.pad keeps the result on the
    # input's device (a fresh torch.zeros() would silently land on CPU).
    pad_out = n_bo * bs - out_f
    pad_in = n_bi * bs - in_f
    if pad_out > 0 or pad_in > 0:
        padded = torch.nn.functional.pad(weight, (0, pad_in, 0, pad_out))
    else:
        padded = weight

    # (n_bo, bs, n_bi, bs) -> (n_bo, n_bi, bs, bs): one trailing tile per block.
    blocks = padded.reshape(n_bo, bs, n_bi, bs).permute(0, 2, 1, 3)

    # Per-block absmax -> scale; the clamp keeps all-zero blocks from
    # producing a zero divisor.
    absmax = blocks.abs().amax(dim=(-2, -1)).clamp_min(1e-12)  # (n_bo, n_bi)

    # Round the scale to its stored dtype *before* quantizing so the divisor
    # used here is exactly the multiplier used at dequantization time.
    scale_inv = (absmax / FP8_MAX).to(torch.bfloat16)
    scale_exp = scale_inv.float()[:, :, None, None]  # (n_bo, n_bi, 1, 1)

    # The clamp absorbs the tiny overshoot possible when bf16 rounding made
    # the scale slightly smaller than absmax / FP8_MAX.
    fp8_blocks = (blocks / scale_exp).clamp(-FP8_MAX, FP8_MAX)

    # Undo the tiling and trim the padding before the final cast.
    fp8_full = fp8_blocks.permute(0, 2, 1, 3).reshape(n_bo * bs, n_bi * bs)
    fp8_tensor = fp8_full[:out_f, :in_f].to(torch.float8_e4m3fn)

    return fp8_tensor, scale_inv
93
+
94
+
95
def main():
    """Convert the routed-expert weights of a sharded checkpoint to block-wise FP8.

    Steps:
      1. Read ``model.safetensors.index.json`` from ``--input_dir``.
      2. For every shard, quantize tensors accepted by ``is_expert_weight`` and
         store a ``*.weight_scale_inv`` tensor next to each; copy all other
         tensors unchanged.
      3. Write the updated index and a ``config.json`` with
         ``use_fp8_experts = True``.
      4. Symlink every remaining top-level entry of the input directory.
      5. Print a size summary.
    """
    parser = argparse.ArgumentParser(
        description="Convert UniLLaDA expert weights to FP8 (block-wise quantization)")
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Path to original bf16 model directory")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Path to output FP8 model directory")
    parser.add_argument("--block_size", type=int, default=DEFAULT_BLOCK_SIZE,
                        help=f"Quantization block size (default: {DEFAULT_BLOCK_SIZE})")
    args = parser.parse_args()

    input_dir = os.path.abspath(args.input_dir)
    output_dir = os.path.abspath(args.output_dir)
    block_size = args.block_size
    if block_size <= 0:
        parser.error("--block_size must be a positive integer")
    if block_size != DEFAULT_BLOCK_SIZE:
        # The modeling code builds FP8Linear with its default block size, so
        # scale tensors produced with another size will not match its buffers.
        print(f"WARNING: --block_size={block_size} differs from the default "
              f"{DEFAULT_BLOCK_SIZE}; make sure the modeling code uses the same block size.")
    os.makedirs(output_dir, exist_ok=True)

    # Load index
    with open(os.path.join(input_dir, "model.safetensors.index.json")) as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    shard_to_keys = {}
    for key, shard in weight_map.items():
        shard_to_keys.setdefault(shard, []).append(key)

    new_weight_map = OrderedDict()
    stats = {"expert": 0, "other": 0, "bytes_before": 0, "bytes_after": 0}

    # Process each shard independently so only one shard is held in memory.
    for shard_file in tqdm(sorted(shard_to_keys), desc="Converting shards"):
        tensors = load_file(os.path.join(input_dir, shard_file), device="cpu")
        new_tensors = OrderedDict()

        for key in sorted(tensors):
            tensor = tensors[key]
            old_bytes = tensor.nelement() * tensor.element_size()
            stats["bytes_before"] += old_bytes
            new_weight_map[key] = shard_file

            if is_expert_weight(key):
                fp8_tensor, scale_inv = quantize_blockwise(tensor, block_size)
                new_tensors[key] = fp8_tensor
                # is_expert_weight guarantees the key ends in ".weight", so
                # appending yields "...weight_scale_inv" without touching any
                # earlier ".weight" substring.
                scale_key = f"{key}_scale_inv"
                new_tensors[scale_key] = scale_inv
                new_weight_map[scale_key] = shard_file

                stats["bytes_after"] += (
                    fp8_tensor.nelement() * fp8_tensor.element_size()
                    + scale_inv.nelement() * scale_inv.element_size()
                )
                stats["expert"] += 1
            else:
                new_tensors[key] = tensor
                stats["bytes_after"] += old_bytes
                stats["other"] += 1

        save_file(new_tensors, os.path.join(output_dir, shard_file))
        del tensors, new_tensors

    # Save new index; refresh total_size (when present) so it describes the
    # converted tensors rather than the original ones.
    metadata = dict(index.get("metadata", {}))
    if "total_size" in metadata:
        metadata["total_size"] = stats["bytes_after"]
    new_index = {"metadata": metadata, "weight_map": dict(new_weight_map)}
    with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
        json.dump(new_index, f, indent=2)

    # Update config.json: add use_fp8_experts=true.
    with open(os.path.join(input_dir, "config.json")) as f:
        config = json.load(f)

    config["use_fp8_experts"] = True
    # Note: we do NOT add quantization_config here because transformers' built-in
    # FP8 quantizer would conflict with our custom FP8Linear class.
    # The use_fp8_experts flag is handled by our modeling code directly.

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    # Symlink everything else (decoder, vae, tokenizer, code files...)
    for fname in os.listdir(input_dir):
        if fname.startswith("model-") and fname.endswith(".safetensors"):
            continue
        if fname in ("model.safetensors.index.json", "config.json"):
            continue
        src = os.path.join(input_dir, fname)
        dst = os.path.join(output_dir, fname)
        # lexists also reports dangling symlinks, which os.path.exists misses
        # and which would make os.symlink raise FileExistsError.
        if os.path.lexists(dst):
            continue
        os.symlink(src, dst)

    gb_b = stats["bytes_before"] / 1024**3
    gb_a = stats["bytes_after"] / 1024**3
    # Guard the percentage against an empty checkpoint (0 bytes before).
    saved_pct = (1 - gb_a / gb_b) * 100 if gb_b else 0.0
    print(f"\n{'='*60}")
    print("✅ Block-wise FP8 conversion complete!")
    print(f" Block size: {block_size}x{block_size}")
    print(f" Expert tensors quantized: {stats['expert']}")
    print(f" Other tensors (kept bf16): {stats['other']}")
    print(f" Weights: {gb_b:.2f} GB → {gb_a:.2f} GB "
          f"(saved {gb_b-gb_a:.2f} GB, -{saved_pct:.1f}%)")
    print(f" Output: {output_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
convert_to_fp8_blockwise.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert UniLLaDA MoE backbone weights to FP8 (block-wise quantization).
4
+
5
+ Professional FP8 quantization following the same approach as Qwen3.5-FP8:
6
+ - Block-wise quantization with per-block scale (weight_scale_inv)
7
+ - Only quantize Linear weight tensors (experts, shared experts, attention projections)
8
+ - Keep sensitive layers in bf16: embedding, lm_head, routing gates, layernorms
9
+ - Store quantization_config in config.json for framework compatibility
10
+
11
+ Usage:
12
+ python convert_to_fp8_blockwise.py \
13
+ --input_dir /path/to/UniLLaDA \
14
+ --output_dir /path/to/UniLLaDA-FP8
15
+
16
+ The output can be loaded with the SAME modeling code (no changes needed):
17
+ model = AutoModelForCausalLM.from_pretrained(output_dir, ...)
18
+ """
19
+
20
+ import os
21
+ import re
22
+ import json
23
+ import argparse
24
+ from collections import OrderedDict
25
+
26
+ import torch
27
+ from safetensors.torch import load_file, save_file
28
+ from tqdm import tqdm
29
+
30
+ FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
31
+ BLOCK_SIZE = 128 # quantization block size (128x128)
32
+
33
+
34
def should_quantize(name: str) -> bool:
    """Decide from a tensor name whether it should be stored as FP8.

    A name qualifies only if it ends in ``.weight``, matches none of the
    keep-in-bf16 patterns (embedding, lm_head, routing gate, norms) and
    matches one of the quantizable projection patterns (routed experts,
    shared experts, attention projections, dense MLP).
    """
    keep_bf16_patterns = (
        r"word_embeddings\.weight$",            # embedding
        r"lm_head\.weight$",                    # output head
        r"\.gate\.weight$",                     # routing gate
        r"layernorm\.weight$",                  # QK layernorm
        r"input_layernorm\.weight$",            # layer norm
        r"post_attention_layernorm\.weight$",   # layer norm
        r"norm\.weight$",                       # final norm
    )
    fp8_patterns = (
        r"experts\.\d+\.(gate_proj|up_proj|down_proj)\.weight$",
        r"shared_experts\.(gate_proj|up_proj|down_proj)\.weight$",
        r"attention\.(query_key_value|dense)\.weight$",
        r"mlp\.(gate_proj|up_proj|down_proj)\.weight$",  # dense layer (layer 0)
    )

    # Biases, scales and buffers are never quantized.
    if not name.endswith(".weight"):
        return False
    # The skip list wins over the quantize list.
    if any(re.search(pattern, name) for pattern in keep_bf16_patterns):
        return False
    return any(re.search(pattern, name) for pattern in fp8_patterns)
70
+
71
+
72
def quantize_tensor_blockwise(tensor: torch.Tensor, block_size: int = BLOCK_SIZE):
    """Quantize a 2D weight tensor to FP8 (e4m3fn) with block-wise scaling.

    Convention used here:
        scale   = absmax(block) / FP8_MAX
        fp8     = clamp(weight / scale, -FP8_MAX, FP8_MAX)
        dequant = fp8.float() * scale
    The scale is returned (and later stored) under the name ``scale_inv``;
    despite that name it is the multiplier applied when dequantizing, not a
    reciprocal.

    Args:
        tensor: Weight tensor of shape (out_features, in_features).
        block_size: Edge length of the square quantization blocks.

    Returns:
        fp8_tensor: Quantized tensor (float8_e4m3fn), same shape as the input.
        scale_inv: Per-block dequantization scale (bfloat16), shape
            (ceil(out / block_size), ceil(in / block_size)).

    Raises:
        ValueError: If ``tensor`` is not 2D or ``block_size`` is not positive.
    """
    if tensor.dim() != 2:
        raise ValueError(f"Expected 2D tensor, got {tensor.dim()}D")
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    out_features, in_features = tensor.shape
    weight = tensor.float()

    # Zero-pad on the bottom/right up to a whole number of blocks.
    pad_out = (block_size - out_features % block_size) % block_size
    pad_in = (block_size - in_features % block_size) % block_size
    if pad_out > 0 or pad_in > 0:
        padded = torch.nn.functional.pad(weight, (0, pad_in, 0, pad_out))
    else:
        padded = weight

    n_blocks_out = padded.shape[0] // block_size
    n_blocks_in = padded.shape[1] // block_size

    # (n_out, block, n_in, block) -> (n_out, n_in, block, block)
    blocks = padded.reshape(n_blocks_out, block_size, n_blocks_in, block_size)
    blocks = blocks.permute(0, 2, 1, 3)

    # Per-block absmax -> scale; the clamp avoids a zero divisor for
    # all-zero blocks.
    absmax = blocks.abs().amax(dim=(-2, -1))  # (n_out, n_in)
    scale = (absmax / FP8_MAX).clamp_min(1e-12)

    # Round the scale to its stored dtype *before* quantizing so the divisor
    # used here is exactly the multiplier used at dequantization time.
    scale_inv = scale.to(torch.bfloat16)
    scale_expanded = scale_inv.float()[:, :, None, None]  # (n_out, n_in, 1, 1)

    # The clamp absorbs the tiny overshoot possible when bf16 rounding made
    # the scale slightly smaller than absmax / FP8_MAX.
    fp8_blocks = (blocks / scale_expanded).clamp(-FP8_MAX, FP8_MAX)

    # Undo the tiling, trim the padding, then cast to FP8.
    fp8_full = fp8_blocks.permute(0, 2, 1, 3).reshape(padded.shape[0], padded.shape[1])
    fp8_tensor = fp8_full[:out_features, :in_features].to(torch.float8_e4m3fn)

    return fp8_tensor, scale_inv
145
+
146
+
147
def build_modules_to_not_convert(weight_map: dict) -> list:
    """Return the sorted module names whose tensors ``should_quantize`` rejects.

    For keys ending in ``.weight`` the module name is the key without that
    suffix; for any other key it is everything before the last dot.
    """
    suffix = ".weight"
    kept_modules = {
        key[: -len(suffix)] if key.endswith(suffix) else key.rsplit(".", 1)[0]
        for key in weight_map
        if not should_quantize(key)
    }
    return sorted(kept_modules)
156
+
157
+
158
def main():
    """Convert a sharded checkpoint's Linear weights to block-wise FP8.

    Steps:
      1. Read ``model.safetensors.index.json`` from ``--input_dir``.
      2. For every shard, quantize the 2D tensors accepted by
         ``should_quantize`` and store a ``*.weight_scale_inv`` tensor next
         to each; copy all other tensors unchanged.
      3. Write the updated index and a ``config.json`` carrying a
         ``quantization_config`` entry.
      4. Symlink every remaining top-level entry of the input directory.
      5. Print a size summary.
    """
    parser = argparse.ArgumentParser(description="Convert UniLLaDA to FP8 (block-wise)")
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Path to original bf16 model directory")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Path to output FP8 model directory")
    parser.add_argument("--block_size", type=int, default=BLOCK_SIZE,
                        help=f"Quantization block size (default: {BLOCK_SIZE})")
    args = parser.parse_args()

    input_dir = os.path.abspath(args.input_dir)
    output_dir = os.path.abspath(args.output_dir)
    block_size = args.block_size
    if block_size <= 0:
        parser.error("--block_size must be a positive integer")
    os.makedirs(output_dir, exist_ok=True)

    # Load index
    with open(os.path.join(input_dir, "model.safetensors.index.json")) as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    shard_to_keys = {}
    for key, shard in weight_map.items():
        shard_to_keys.setdefault(shard, []).append(key)

    new_weight_map = OrderedDict()
    stats = {"quantized": 0, "kept_bf16": 0, "bytes_before": 0, "bytes_after": 0}
    # Modules whose name matches a quantize pattern but whose tensor was left
    # untouched because it is not 2D; they must be reported as not converted.
    skipped_modules = set()

    # Process each shard independently so only one shard is held in memory.
    for shard_file in tqdm(sorted(shard_to_keys), desc="Converting shards"):
        tensors = load_file(os.path.join(input_dir, shard_file), device="cpu")
        new_tensors = OrderedDict()

        for key in sorted(tensors):
            tensor = tensors[key]
            old_bytes = tensor.nelement() * tensor.element_size()
            stats["bytes_before"] += old_bytes
            new_weight_map[key] = shard_file

            wanted = should_quantize(key)
            if wanted and tensor.dim() == 2:
                fp8_tensor, scale_inv = quantize_tensor_blockwise(tensor, block_size)
                new_tensors[key] = fp8_tensor
                # should_quantize guarantees the key ends in ".weight", so
                # appending yields "...weight_scale_inv" without touching any
                # earlier ".weight" substring.
                scale_key = f"{key}_scale_inv"
                new_tensors[scale_key] = scale_inv
                new_weight_map[scale_key] = shard_file

                stats["bytes_after"] += (
                    fp8_tensor.nelement() * fp8_tensor.element_size()
                    + scale_inv.nelement() * scale_inv.element_size()
                )
                stats["quantized"] += 1
            else:
                if wanted:
                    skipped_modules.add(key[: -len(".weight")])
                new_tensors[key] = tensor
                stats["bytes_after"] += old_bytes
                stats["kept_bf16"] += 1

        save_file(new_tensors, os.path.join(output_dir, shard_file))
        del tensors, new_tensors

    # Save new index; refresh total_size (when present) so it describes the
    # converted tensors rather than the original ones.
    metadata = dict(index.get("metadata", {}))
    if "total_size" in metadata:
        metadata["total_size"] = stats["bytes_after"]
    new_index = {"metadata": metadata, "weight_map": dict(new_weight_map)}
    with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
        json.dump(new_index, f, indent=2)

    # Build quantization config: list every module that stayed unquantized.
    not_convert_modules = sorted(
        set(build_modules_to_not_convert(weight_map)) | skipped_modules
    )

    # Load and modify config.json
    with open(os.path.join(input_dir, "config.json")) as f:
        config = json.load(f)

    config["quantization_config"] = {
        "quant_method": "fp8",
        "activation_scheme": "dynamic",
        "weight_per_tensor": False,
        "act_per_tensor": False,
        "weight_block_size": [block_size, block_size],
        "modules_to_not_convert": not_convert_modules,
    }

    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    # Symlink everything else (code files, tokenizer, decoder, vae, etc.)
    for fname in os.listdir(input_dir):
        if fname.startswith("model-") and fname.endswith(".safetensors"):
            continue
        if fname in ("model.safetensors.index.json", "config.json"):
            continue
        src = os.path.join(input_dir, fname)
        dst = os.path.join(output_dir, fname)
        # lexists also reports dangling symlinks, which os.path.exists misses
        # and which would make os.symlink raise FileExistsError.
        if os.path.lexists(dst):
            continue
        os.symlink(src, dst)

    # Print summary
    gb_b = stats["bytes_before"] / 1024**3
    gb_a = stats["bytes_after"] / 1024**3
    # Guard the percentage against an empty checkpoint (0 bytes before).
    saved_pct = (1 - gb_a / gb_b) * 100 if gb_b else 0.0
    print(f"\n{'='*60}")
    print("✅ Block-wise FP8 conversion complete!")
    print(f" Block size: {block_size}x{block_size}")
    print(f" Quantized tensors: {stats['quantized']}")
    print(f" Kept bf16 tensors: {stats['kept_bf16']}")
    print(f" Weights: {gb_b:.2f} GB → {gb_a:.2f} GB (saved {gb_b-gb_a:.2f} GB, -{saved_pct:.1f}%)")
    print(f" Output: {output_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()
decoder-turbo/config.json CHANGED
@@ -13,11 +13,11 @@
13
  48
14
  ],
15
  "axes_lens": [
16
- 1536,
17
- 512,
18
- 512
19
  ],
20
- "cap_feat_dim": 2560,
21
  "dim": 3840,
22
  "in_channels": 16,
23
  "n_heads": 30,
 
13
  48
14
  ],
15
  "axes_lens": [
16
+ 32768,
17
+ 1024,
18
+ 1024
19
  ],
20
+ "cap_feat_dim": 4096,
21
  "dim": 3840,
22
  "in_channels": 16,
23
  "n_heads": 30,
decoder-turbo/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7589be2b548a3c1ef7e81431781e79c216ff3ae6e4f84c91397feeefef7d36dc
3
+ size 6160866440
decoder/config.json CHANGED
@@ -13,11 +13,11 @@
13
  48
14
  ],
15
  "axes_lens": [
16
- 1536,
17
- 512,
18
- 512
19
  ],
20
- "cap_feat_dim": 2560,
21
  "dim": 3840,
22
  "in_channels": 16,
23
  "n_heads": 30,
 
13
  48
14
  ],
15
  "axes_lens": [
16
+ 32768,
17
+ 1024,
18
+ 1024
19
  ],
20
+ "cap_feat_dim": 4096,
21
  "dim": 3840,
22
  "in_channels": 16,
23
  "n_heads": 30,
decoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ace13533ec063e0a1edb1b9819546e6b5bf79f23cf759aece3d0bccfd7f62933
3
+ size 6160866440
image_tokenizer/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0a11a82ad221ac1f3b917abfce31ffaaec3571200ae7ee5318a223ff2eedc49
3
+ size 2398968416
modeling_llada2uni_moe.py CHANGED
@@ -339,21 +339,129 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
339
  return q_embed, k_embed
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  class LLaDA2MoeMLP(nn.Module):
343
- def __init__(self, config: LLaDA2MoeConfig, intermediate_size: int):
344
  super().__init__()
345
  self.config = config
346
  self.hidden_size = config.hidden_size
347
  self.intermediate_size = intermediate_size
348
 
349
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
350
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
351
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
352
  self.act_fn = ACT2FN[config.hidden_act]
353
 
354
  def forward(self, x):
355
  return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
356
 
 
 
 
 
 
 
 
357
 
358
  class LLaDA2MoeGate(nn.Module):
359
  def __init__(self, config):
@@ -446,16 +554,24 @@ class LLaDA2MoeSparseMoeBlock(nn.Module):
446
  )
447
 
448
  def _setup_experts(self):
 
449
  self.experts = nn.ModuleList(
450
  [
451
  LLaDA2MoeMLP(
452
  config=self.config,
453
  intermediate_size=self.config.moe_intermediate_size,
 
454
  )
455
  for _ in range(self.config.num_experts)
456
  ]
457
  )
458
 
 
 
 
 
 
 
459
  def forward(self, hidden_states):
460
  identity = hidden_states
461
  bsz, seq_len, h = hidden_states.shape
@@ -1109,6 +1225,20 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
1109
  def set_decoder(self, decoder):
1110
  self.model = decoder
1111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1112
  def get_decoder(self):
1113
  return self.model
1114
 
 
339
  return q_embed, k_embed
340
 
341
 
342
+ class FP8Linear(nn.Module):
343
+ """Drop-in replacement for nn.Linear that stores weights in float8_e4m3fn.
344
+
345
+ The weight is kept as ``float8_e4m3fn`` on GPU. During ``forward`` it is
346
+ dequantized back to the compute dtype (bf16/fp16) on-the-fly.
347
+
348
+ Supports two modes:
349
+ - **Per-tensor** (legacy): no scale stored; direct cast ``fp8 → compute_dtype``.
350
+ Works when weight magnitudes are well within fp8 range (±448).
351
+ - **Block-wise** (recommended): a ``weight_scale_inv`` buffer of shape
352
+ ``(ceil(out/block), ceil(in/block))`` stores per-block scales.
353
+ Dequantization: ``real_weight = fp8_weight * scale_expanded``.
354
+
355
+ This halves the GPU memory for expert weights — no custom CUDA kernel needed.
356
+ """
357
+
358
+ def __init__(self, in_features: int, out_features: int, bias: bool = False,
359
+ block_size: int = 128):
360
+ super().__init__()
361
+ self.in_features = in_features
362
+ self.out_features = out_features
363
+ self.block_size = block_size
364
+ # Placeholder – will be overwritten by state-dict loading
365
+ self.weight = nn.Parameter(
366
+ torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn),
367
+ requires_grad=False,
368
+ )
369
+ # Optional block-wise scale — stored as a Parameter so from_pretrained can load it
370
+ n_bo = (out_features + block_size - 1) // block_size
371
+ n_bi = (in_features + block_size - 1) // block_size
372
+ self.weight_scale_inv = nn.Parameter(
373
+ torch.empty(n_bo, n_bi, dtype=torch.bfloat16), requires_grad=False
374
+ )
375
+ if bias:
376
+ self.bias = nn.Parameter(torch.zeros(out_features))
377
+ else:
378
+ self.bias = None
379
+
380
+ def _dequantize_weight(self, dtype: torch.dtype) -> torch.Tensor:
381
+ """Dequantize fp8 weight to the given compute dtype."""
382
+ w = self.weight.to(dtype)
383
+ # Block-wise dequantization
384
+ scale = self.weight_scale_inv.to(dtype) # (n_blocks_out, n_blocks_in)
385
+ bs = self.block_size
386
+ n_bo, n_bi = scale.shape
387
+ # Expand scale to match weight shape via repeat_interleave
388
+ scale_expanded = scale.repeat_interleave(bs, dim=0).repeat_interleave(bs, dim=1)
389
+ # Trim to actual weight shape (in case of padding during quantization)
390
+ scale_expanded = scale_expanded[:self.out_features, :self.in_features]
391
+ return w * scale_expanded
392
+
393
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
394
+ return F.linear(x, self._dequantize_weight(x.dtype), self.bias)
395
+
396
+ @classmethod
397
+ def from_linear(cls, linear: nn.Linear, block_size: int = 128) -> "FP8Linear":
398
+ """Convert a regular nn.Linear to FP8Linear with block-wise quantization."""
399
+ fp8_mod = cls(linear.in_features, linear.out_features,
400
+ bias=linear.bias is not None, block_size=block_size)
401
+ weight = linear.weight.data.float()
402
+ out_f, in_f = weight.shape
403
+ bs = block_size
404
+
405
+ # Compute block-wise scale
406
+ n_bo = (out_f + bs - 1) // bs
407
+ n_bi = (in_f + bs - 1) // bs
408
+ fp8_max = torch.finfo(torch.float8_e4m3fn).max
409
+
410
+ # Pad weight for even blocking
411
+ pad_out = n_bo * bs - out_f
412
+ pad_in = n_bi * bs - in_f
413
+ if pad_out > 0 or pad_in > 0:
414
+ padded = torch.zeros(n_bo * bs, n_bi * bs, dtype=torch.float32)
415
+ padded[:out_f, :in_f] = weight
416
+ else:
417
+ padded = weight
418
+
419
+ blocks = padded.reshape(n_bo, bs, n_bi, bs).permute(0, 2, 1, 3)
420
+ absmax = blocks.abs().amax(dim=(-2, -1)).clamp_min(1e-12) # (n_bo, n_bi)
421
+ scale = absmax / fp8_max
422
+
423
+ # Quantize
424
+ scale_exp = scale[:, :, None, None]
425
+ fp8_blocks = (blocks / scale_exp).clamp(-fp8_max, fp8_max)
426
+ fp8_full = fp8_blocks.permute(0, 2, 1, 3).reshape(n_bo * bs, n_bi * bs)
427
+ fp8_weight = fp8_full[:out_f, :in_f].to(torch.float8_e4m3fn)
428
+
429
+ fp8_mod.weight = nn.Parameter(fp8_weight, requires_grad=False)
430
+ fp8_mod.weight_scale_inv = nn.Parameter(scale.to(torch.bfloat16), requires_grad=False)
431
+ if linear.bias is not None:
432
+ fp8_mod.bias = nn.Parameter(linear.bias.data.clone())
433
+ return fp8_mod
434
+
435
+ def extra_repr(self) -> str:
436
+ has_scale = self.weight_scale_inv.numel() > 0
437
+ return (f"in_features={self.in_features}, out_features={self.out_features}, "
438
+ f"bias={self.bias is not None}, dtype=float8_e4m3fn, "
439
+ f"block_scale={'yes' if has_scale else 'no'}")
440
+
441
+
442
class LLaDA2MoeMLP(nn.Module):
    """Gated MLP: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    With ``use_fp8=True`` the three projections are built as ``FP8Linear``
    instead of ``nn.Linear``.
    """

    def __init__(self, config: LLaDA2MoeConfig, intermediate_size: int, use_fp8: bool = False):
        """Build the projections.

        Args:
            config: Model config; supplies ``hidden_size`` and ``hidden_act``.
            intermediate_size: Width of the gate/up projections.
            use_fp8: Use ``FP8Linear`` for the projections when True.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = intermediate_size

        linear_cls = FP8Linear if use_fp8 else nn.Linear
        self.gate_proj = linear_cls(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = linear_cls(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = linear_cls(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x``."""
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    def to_fp8(self) -> "LLaDA2MoeMLP":
        """Convert all Linear layers in this MLP to FP8Linear (in-place).

        Projections that already are ``FP8Linear`` (e.g. when the module was
        built with ``use_fp8=True``) are left untouched: re-quantizing their
        raw FP8 codes without the block scales would corrupt the weights.
        """
        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            layer = getattr(self, proj_name)
            if not isinstance(layer, FP8Linear):
                setattr(self, proj_name, FP8Linear.from_linear(layer))
        return self
464
+
465
 
466
  class LLaDA2MoeGate(nn.Module):
467
  def __init__(self, config):
 
554
  )
555
 
556
  def _setup_experts(self):
557
+ use_fp8 = getattr(self.config, "use_fp8_experts", False)
558
  self.experts = nn.ModuleList(
559
  [
560
  LLaDA2MoeMLP(
561
  config=self.config,
562
  intermediate_size=self.config.moe_intermediate_size,
563
+ use_fp8=use_fp8,
564
  )
565
  for _ in range(self.config.num_experts)
566
  ]
567
  )
568
 
569
+ def convert_experts_to_fp8(self):
570
+ """Convert all routed experts to FP8 in-place (call after loading bf16 weights)."""
571
+ for expert in self.experts:
572
+ expert.to_fp8()
573
+ return self
574
+
575
  def forward(self, hidden_states):
576
  identity = hidden_states
577
  bsz, seq_len, h = hidden_states.shape
 
1225
  def set_decoder(self, decoder):
1226
  self.model = decoder
1227
 
1228
+ def convert_experts_to_fp8(self):
1229
+ """Convert all routed MoE experts to FP8 storage (in-place).
1230
+
1231
+ Call this after ``from_pretrained`` to halve expert memory::
1232
+
1233
+ model = AutoModelForCausalLM.from_pretrained(...)
1234
+ model.convert_experts_to_fp8()
1235
+ """
1236
+ for layer in self.model.layers:
1237
+ if hasattr(layer.mlp, "convert_experts_to_fp8"):
1238
+ layer.mlp.convert_experts_to_fp8()
1239
+ torch.cuda.empty_cache()
1240
+ return self
1241
+
1242
  def get_decoder(self):
1243
  return self.model
1244