"""
Quick validation script for a converted HYV3 outer-format checkpoint.

Checks:
1. model.safetensors.index.json structure and completeness
2. All expected weight keys exist (dense layer 0, MoE layers 1-79)
3. Expert tensor shapes (fused 3D format)
4. All referenced shard files exist and are non-empty
5. Spot-check: load a few shards and verify tensor shapes/dtypes
6. No duplicate or orphan keys

Usage:
    python check_converted.py <output_dir> [--spot-check N]

Example:
    python check_converted.py pretrain_base/hf_outer
    python check_converted.py pretrain_base/hf_outer --spot-check 5
"""
|
|
import argparse
import json
import os
import sys
import time
from collections import defaultdict
|
|
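# Per-layer weight suffixes expected for every dense (non-MoE) layer.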
DENSE_SUFFIXES = [
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "self_attn.q_proj.weight",
    "self_attn.k_proj.weight",
    "self_attn.v_proj.weight",
    "self_attn.o_proj.weight",
    "self_attn.q_norm.weight",
    "self_attn.k_norm.weight",
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "mlp.down_proj.weight",
]
|
|
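# Per-layer weight suffixes expected for every MoE layer.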
MOE_SUFFIXES = [
    "input_layernorm.weight",
    "post_attention_layernorm.weight",
    "self_attn.q_proj.weight",
    "self_attn.k_proj.weight",
    "self_attn.v_proj.weight",
    "self_attn.o_proj.weight",
    "self_attn.q_norm.weight",
    "self_attn.k_norm.weight",
    # MoE-specific: router gate, fused expert tensors, shared experts.
    "mlp.gate.weight",
    "mlp.e_score_correction_bias",
    "mlp.experts.gate_up_proj",
    "mlp.experts.down_proj",
    "mlp.shared_experts.gate_proj.weight",
    "mlp.shared_experts.up_proj.weight",
    "mlp.shared_experts.down_proj.weight",
]
|
|
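# Extra suffixes that appear only on MTP (multi-token prediction) layers,
# which otherwise carry the same per-layer keys as MoE layers.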
MTP_EXTRA_SUFFIXES = [
    "eh_proj.weight",
    "enorm.weight",
    "final_layernorm.weight",
    "hnorm.weight",
]
|
|
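# Keys that occur once per model rather than once per layer.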
GLOBAL_KEYS = [
    "model.embed_tokens.weight",
    "model.norm.weight",
    "lm_head.weight",
]
|
|
|
|
def load_config(output_dir):
    """Load config.json and return it as a dict (None if missing)."""
    config_path = os.path.join(output_dir, "config.json")
    if not os.path.exists(config_path):
        print(f"[ERROR] config.json not found in {output_dir}")
        return None
    with open(config_path) as f:
        return json.load(f)
|
|
|
|
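# model.safetensors.index.json layout (HF sharded-checkpoint convention):
#   {"metadata": {"total_size": <bytes>},
#    "weight_map": {"<tensor name>": "<shard filename>", ...}}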
def check_index_json(output_dir):
    """Check model.safetensors.index.json for structure and completeness."""
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    if not os.path.exists(index_path):
        print("[ERROR] model.safetensors.index.json not found")
        return None, []

    with open(index_path) as f:
        index = json.load(f)

    errors = []

    # Metadata must carry the checkpoint's total byte size.
    if "metadata" not in index:
        errors.append("Missing 'metadata' in index.json")
    elif "total_size" not in index["metadata"]:
        errors.append("Missing 'total_size' in metadata")

    # Without a weight_map the remaining checks cannot run.
    if "weight_map" not in index:
        errors.append("Missing 'weight_map' in index.json")
        return index, errors

    weight_map = index["weight_map"]
    total_size = index.get("metadata", {}).get("total_size", 0)

    print(f" Index keys : {len(weight_map)}")
    print(f" Total size : {total_size / 1e9:.2f} GB")

    if len(weight_map) == 0:
        errors.append("weight_map is empty")

    return index, errors
|
|
|
|
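# Expected layout: layers [0, first_k_dense_replace) are dense, layers
# [first_k_dense_replace, num_hidden_layers) are MoE, and MTP layers (if any)
# are appended after the regular layers as model.layers.{num_hidden_layers + i}.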
def check_expected_keys(weight_map, config):
    """Check that all expected keys exist in the weight_map."""
    errors = []
    warnings = []

    num_layers = config.get("num_hidden_layers", 80)
    first_k_dense = config.get("first_k_dense_replace", 1)
    num_mtp_layers = config.get("num_nextn_predict_layers", 0)

    # Per-model keys.
    for key in GLOBAL_KEYS:
        if key not in weight_map:
            errors.append(f"Missing global key: {key}")

    # Per-layer keys, grouped by suffix so a projection missing from many
    # layers reports as one error instead of one per layer.
    missing_by_type = defaultdict(list)
    for layer_idx in range(num_layers):
        prefix = f"model.layers.{layer_idx}."
        if layer_idx < first_k_dense:
            # Dense layer.
            suffixes = DENSE_SUFFIXES
        else:
            # MoE layer.
            suffixes = MOE_SUFFIXES

        for suffix in suffixes:
            full_key = prefix + suffix
            if full_key not in weight_map:
                missing_by_type[suffix].append(layer_idx)

    # MTP layers: MoE suffixes plus the MTP-only projections/norms.
    mtp_missing_by_type = defaultdict(list)
    for mtp_idx in range(num_mtp_layers):
        layer_idx = num_layers + mtp_idx
        prefix = f"model.layers.{layer_idx}."
        mtp_suffixes = MOE_SUFFIXES + MTP_EXTRA_SUFFIXES
        for suffix in mtp_suffixes:
            full_key = prefix + suffix
            if full_key not in weight_map:
                mtp_missing_by_type[suffix].append(layer_idx)

    for suffix, layers in sorted(mtp_missing_by_type.items()):
        errors.append(f"Missing MTP key '{suffix}' in layers: {layers}")

    for suffix, layers in sorted(missing_by_type.items()):
        if len(layers) <= 5:
            layer_str = str(layers)
        else:
            layer_str = f"{layers[:3]}...({len(layers)} total)"
        errors.append(f"Missing '{suffix}' in layers: {layer_str}")

    # Warn about keys that fall outside every known prefix.
    known_prefixes = set()
    for layer_idx in range(num_layers + num_mtp_layers):
        known_prefixes.add(f"model.layers.{layer_idx}.")
    known_prefixes.add("model.embed_tokens.")
    known_prefixes.add("model.norm.")
    known_prefixes.add("lm_head.")
    known_prefixes.add("model.mtp_layers.")

    unexpected = []
    for key in weight_map:
        if not any(key.startswith(p) for p in known_prefixes):
            unexpected.append(key)

    if unexpected:
        if len(unexpected) <= 5:
            for k in unexpected:
                warnings.append(f"Unexpected key: {k}")
        else:
            warnings.append(
                f"{len(unexpected)} unexpected keys found (first 3: {unexpected[:3]})"
            )

    return errors, warnings
|
|
|
|
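# Both directions are checked: every shard the index references must exist
# and be non-empty, and every *.safetensors file on disk should be referenced
# by the index (or be empty residue from the cross-shard merge, which only
# warrants a warning).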
def check_shard_files(output_dir, weight_map):
    """Check that all referenced shard files exist and are non-empty."""
    errors = []
    warnings = []

    shard_files = sorted(set(weight_map.values()))
    print(f" Shard files : {len(shard_files)}")

    missing = []
    empty = []
    total_disk_size = 0

    for sf in shard_files:
        path = os.path.join(output_dir, sf)
        if not os.path.exists(path):
            missing.append(sf)
        else:
            size = os.path.getsize(path)
            if size == 0:
                empty.append(sf)
            total_disk_size += size

    print(f" Disk size : {total_disk_size / 1e9:.2f} GB")

    if missing:
        errors.append(f"Missing shard files ({len(missing)}): {missing[:5]}")
    if empty:
        errors.append(f"Empty shard files ({len(empty)}): {empty[:5]}")

    # Orphans: *.safetensors files on disk that the index never references.
    all_safetensors = set(
        f for f in os.listdir(output_dir)
        if f.endswith(".safetensors")
    )
    referenced = set(shard_files)
    orphans = all_safetensors - referenced
    if orphans:
        # Orphans at or below this size are assumed to hold no real weights:
        # empty residue left behind by the cross-shard merge, not lost data.
        EMPTY_SHARD_THRESHOLD = 128
        residue_orphans = []
        real_orphans = []
        for o in sorted(orphans):
            sz = os.path.getsize(os.path.join(output_dir, o))
            if sz <= EMPTY_SHARD_THRESHOLD:
                residue_orphans.append(o)
            else:
                real_orphans.append(o)

        if residue_orphans:
            warnings.append(
                f"{len(residue_orphans)} empty residue shard(s) from cross-shard merge "
                f"(<=128 bytes each, safe to delete)"
            )
        if real_orphans:
            errors.append(
                f"Orphan shard files with data (not in index): {real_orphans[:5]}"
            )

    return errors, warnings
|
|
|
|
def check_key_distribution(weight_map):
    """Check the distribution of keys across shards."""
    shard_key_count = defaultdict(int)
    for key, shard in weight_map.items():
        shard_key_count[shard] += 1

    counts = sorted(shard_key_count.values())
    if not counts:
        # An empty weight_map is already reported by check_index_json.
        return []
    print(f" Keys/shard : min={counts[0]}, max={counts[-1]}, "
          f"median={counts[len(counts)//2]}")

    # Defensive: counts derived from weight_map cannot actually reach zero,
    # but keep the check in case the counting logic changes.
    zero_shards = [s for s, c in shard_key_count.items() if c == 0]
    if zero_shards:
        return [f"Shards with 0 keys: {zero_shards}"]
    return []
|
|
|
|
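# Expected fused expert shapes (one 3D tensor per MoE layer, experts stacked
# along dim 0):
#   mlp.experts.gate_up_proj : (num_experts, 2 * expert_hidden, hidden_size)
#   mlp.experts.down_proj    : (num_experts, hidden_size, expert_hidden)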
def spot_check_shards(output_dir, weight_map, config, num_checks=3):
    """Spot-check a few shards by loading and verifying tensor shapes."""
    errors = []

    try:
        from safetensors import safe_open
    except ImportError:
        print(" [SKIP] safetensors not installed, skipping spot-check")
        return errors

    num_experts = config.get("num_experts", 192)
    expert_hidden = config.get("expert_hidden_dim", config.get("moe_intermediate_size", 1536))
    hidden_size = config.get("hidden_size", 4096)

    # Prefer shards that hold the fused expert tensors, since those are the
    # most conversion-sensitive weights.
    expert_shards = set()
    for key, shard in weight_map.items():
        if "experts.gate_up_proj" in key or "experts.down_proj" in key:
            expert_shards.add(shard)

    # Fall back to the first few shards if no expert tensors were found.
    check_shards = sorted(expert_shards)[:num_checks]
    if not check_shards:
        check_shards = sorted(set(weight_map.values()))[:num_checks]

    print(f"\n Spot-checking {len(check_shards)} shard(s)...")

    for shard_file in check_shards:
        shard_path = os.path.join(output_dir, shard_file)
        t0 = time.time()

        try:
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                keys_in_shard = list(f.keys())
                for key in keys_in_shard:
                    tensor = f.get_tensor(key)

                    # Fused expert tensors must match the expected 3D shapes.
                    if key.endswith("experts.gate_up_proj"):
                        expected_shape = (num_experts, expert_hidden * 2, hidden_size)
                        if tuple(tensor.shape) != expected_shape:
                            errors.append(
                                f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
                                f"!= expected {expected_shape}"
                            )
                    elif key.endswith("experts.down_proj"):
                        expected_shape = (num_experts, hidden_size, expert_hidden)
                        if tuple(tensor.shape) != expected_shape:
                            errors.append(
                                f"{shard_file}/{key}: shape {tuple(tensor.shape)} "
                                f"!= expected {expected_shape}"
                            )

                    # Any NaN/Inf in a floating-point tensor means a bad conversion.
                    if tensor.is_floating_point():
                        if tensor.isnan().any():
                            errors.append(f"{shard_file}/{key}: contains NaN values")
                        if tensor.isinf().any():
                            errors.append(f"{shard_file}/{key}: contains Inf values")

            elapsed = time.time() - t0
            print(f" {shard_file}: {len(keys_in_shard)} keys, OK ({elapsed:.1f}s)")

        except Exception as e:
            errors.append(f"Failed to load {shard_file}: {e}")

    return errors
|
|
|
|
def main():
    parser = argparse.ArgumentParser(
        description="Validate converted HYV3 outer-format checkpoint."
    )
    parser.add_argument(
        "output_dir", type=str,
        help="Path to the converted outer-format checkpoint directory.",
    )
    parser.add_argument(
        "--spot-check", type=int, default=3, dest="spot_check",
        help="Number of shards to spot-check by loading tensors (default: 3).",
    )
    args = parser.parse_args()

    output_dir = os.path.abspath(args.output_dir)
    print(f"Validating: {output_dir}\n")

    if not os.path.isdir(output_dir):
        print(f"[ERROR] Directory not found: {output_dir}")
        sys.exit(1)

    all_errors = []
    all_warnings = []

    # [1/5] Config.
    print("[1/5] Loading config.json...")
    config = load_config(output_dir)
    if config is None:
        print("[ERROR] Cannot proceed without config.json")
        sys.exit(1)

    num_layers = config.get("num_hidden_layers", 0)
    num_experts = config.get("num_experts", 0)
    first_k_dense = config.get("first_k_dense_replace", 0)
    num_mtp = config.get("num_nextn_predict_layers", 0)
    print(f" Layers : {num_layers} ({first_k_dense} dense, {num_layers - first_k_dense} MoE)")
    print(f" MTP layers : {num_mtp}")
    print(f" Experts/layer : {num_experts}")
    print(f" Hidden size : {config.get('hidden_size', '?')}")
    print(f" Expert hidden : {config.get('expert_hidden_dim', config.get('moe_intermediate_size', '?'))}")

    # [2/5] Index structure.
    print("\n[2/5] Checking model.safetensors.index.json...")
    index, idx_errors = check_index_json(output_dir)
    all_errors.extend(idx_errors)

    if index is None or "weight_map" not in index:
        print("[ERROR] Cannot proceed without valid index.json")
        sys.exit(1)

    weight_map = index["weight_map"]

    # [3/5] Key completeness and distribution across shards.
    print("\n[3/5] Checking expected keys...")
    key_errors, key_warnings = check_expected_keys(weight_map, config)
    all_errors.extend(key_errors)
    all_warnings.extend(key_warnings)

    dist_errors = check_key_distribution(weight_map)
    all_errors.extend(dist_errors)

    # [4/5] Shard files on disk.
    print("\n[4/5] Checking shard files on disk...")
    shard_errors, shard_warnings = check_shard_files(output_dir, weight_map)
    all_errors.extend(shard_errors)
    all_warnings.extend(shard_warnings)

    # [5/5] Optional tensor-level spot-check.
    if args.spot_check > 0:
        print(f"\n[5/5] Spot-checking tensors (loading {args.spot_check} shard(s))...")
        spot_errors = spot_check_shards(output_dir, weight_map, config, args.spot_check)
        all_errors.extend(spot_errors)
    else:
        print("\n[5/5] Spot-check skipped (--spot-check 0)")

    # Summary: warnings never fail the run; any error exits non-zero.
    print(f"\n{'=' * 60}")
    if all_warnings:
        print(f"WARNINGS ({len(all_warnings)}):")
        for w in all_warnings:
            print(f" [WARN] {w}")

    if all_errors:
        print(f"ERRORS ({len(all_errors)}):")
        for e in all_errors:
            print(f" [ERROR] {e}")
        print(f"\nResult: FAILED ({len(all_errors)} error(s), {len(all_warnings)} warning(s))")
        sys.exit(1)
    else:
        print(f"Result: PASSED (0 errors, {len(all_warnings)} warning(s))")
        print(f"{'=' * 60}")
        sys.exit(0)
|
|
|
|
if __name__ == "__main__":
    main()
|
|