import sys
from pathlib import Path

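# Make the current working directory importable so the local `model`
# module resolves when the script is run from the repository root.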
sys.path.append(str(Path.cwd()))

from model import load_model
from mlx.utils import tree_flatten


def run_diagnostic_checks():
    """Run the verification checks outlined in the review."""
    print("--- Running Diagnostic Checks ---")

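    # Load the model definition and weights from the current directory.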
    try:
        model = load_model(".")
        print("Successfully loaded model definition.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

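    # Count parameters by flattening the nested parameter tree into
    # (name, array) pairs and summing element counts.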
    try:
        params = model.parameters()
        num_params = sum(p.size for _, p in tree_flatten(params))
        print(f"Total number of parameters: {num_params / 1e6:.2f}M")
    except Exception as e:
        print(f"Error calculating parameters: {e}")

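    # Inspect the first transformer block only; this assumes every block
    # uses the same MLP layout.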
| print("--- Verifying MLP Weight Shapes ---") |
| try: |
| first_block = model.layers[0] |
| args = model.args |
| print(f"use_dual_mlp detected: {args.use_dual_mlp}") |
|
|
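        # MLX Linear layers store weights as (output_dims, input_dims),
        # so each up-projection should be (intermediate, hidden).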
        if args.use_dual_mlp:
            g_up_shape = first_block.feed_forward.g_up.weight.shape
            p_up_shape = first_block.feed_forward.p_up.weight.shape
            print(f"Gated MLP branch (g_up) weight shape: {g_up_shape}")
            print(f"Plain MLP branch (p_up) weight shape: {p_up_shape}")
            assert g_up_shape == (args.intermediate_size, args.hidden_size)
            assert p_up_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("DualMLP weight shapes are correct.")
        else:
            gate_proj_shape = first_block.feed_forward.gate_proj.weight.shape
            up_proj_shape = first_block.feed_forward.up_proj.weight.shape
            print(f"SwiGLUMLP gate_proj weight shape: {gate_proj_shape}")
            print(f"SwiGLUMLP up_proj weight shape: {up_proj_shape}")
            assert gate_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            assert up_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("SwiGLUMLP weight shapes are correct.")
    except AttributeError as e:
        print(
            f"Error accessing MLP weights; the model structure is not as expected: {e}"
        )
    except AssertionError:
        print("Error: MLP weight shapes do not match the configuration.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying shapes: {e}")

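    # The token-embedding table should match (vocab_size, hidden_size).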
| print("--- Verifying Embedding Shape ---") |
| try: |
| embedding_shape = model.tok_embeddings.weight.shape |
| print(f"Embedding weight shape: {embedding_shape}") |
|
|
| args = model.args |
| print(f"Expected embedding shape: ({args.vocab_size}, {args.hidden_size})") |
|
|
| assert embedding_shape == (args.vocab_size, args.hidden_size) |
| print("Embedding shape is correct.") |
| except Exception as e: |
| print(f"An unexpected error occurred while verifying embedding shape: {e}") |
|
|
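    # Access each expected projection; a missing weight raises AttributeError.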
| print("--- Sanity Checking Loaded Weights ---") |
| try: |
| |
| if model.args.use_dual_mlp: |
| _ = model.layers[0].feed_forward.g_gate.weight |
| _ = model.layers[0].feed_forward.g_up.weight |
| _ = model.layers[0].feed_forward.g_down.weight |
| _ = model.layers[0].feed_forward.p_up.weight |
| _ = model.layers[0].feed_forward.p_down.weight |
| print("Found dual-branch MLP weights in the model.") |
| else: |
| _ = model.layers[0].feed_forward.gate_proj.weight |
| _ = model.layers[0].feed_forward.up_proj.weight |
| _ = model.layers[0].feed_forward.down_proj.weight |
| print("Found SwiGLU MLP weights in the model.") |
| print("Weight presence sanity check passed.") |
| except Exception as e: |
| print(f"An error occurred during sanity check: {e}") |
|
|
| print("--- Diagnostic Checks Complete ---") |
|
|
|
|
| if __name__ == "__main__": |
| run_diagnostic_checks() |
|
|