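"""Lightweight sanity check that compares raw logits dumped from the original
PyTorch model and the converted llama.cpp model, as a fast gate before the
full NMSE comparison."""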
import numpy as np
import sys
import os
from pathlib import Path


def quick_logits_check(pytorch_file, llamacpp_file):
| """Lightweight sanity check before NMSE""" |
|
|
| try: |
| pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32) |
| llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32) |
| except Exception as e: |
| print(f"β NOK: Failed to load files - {e}") |
| return False |
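
    # np.fromfile returns flat 1-D arrays, so the shape comparison below is
    # effectively a check that both dumps contain the same number of logits.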
    if pytorch_logits.shape != llamacpp_logits.shape:
        print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
        return False

    # Element-wise difference between the two logit vectors.
    diff = pytorch_logits - llamacpp_logits
    abs_diff = np.abs(diff)
    max_diff = np.max(abs_diff)

    # Highest-scoring token indices from each runtime, best first.
    pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
    llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]
    print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}")
    print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}")
    print(f"Max absolute difference: {max_diff:.4f}")

    # Greedy generation diverges as soon as the top-ranked tokens disagree,
    # so the top-10 indices from both runtimes must match.
    if not np.array_equal(pytorch_top10, llamacpp_top10):
        print("❌ NOK: Top 10 predicted tokens don't match")
        return False

    if max_diff > 1.0:
        print(f"❌ NOK: Large differences detected - max diff: {max_diff:.4f}")
        return False

    return True
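

# A minimal sketch of the NMSE metric this check gates on (assumed
# definition: squared error normalized by the energy of the reference
# logits; the actual NMSE script may compute it differently).
def nmse_sketch(reference, candidate):
    """Hypothetical helper illustrating the follow-up NMSE computation."""
    return float(np.sum((reference - candidate) ** 2) / np.sum(reference ** 2))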


def main():
    model_path = os.getenv('MODEL_PATH')
    if not model_path:
        print("Error: MODEL_PATH environment variable not set")
        sys.exit(1)

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)

    model_name = os.path.basename(model_path)
    data_dir = Path("data")

    pytorch_file = data_dir / f"pytorch-{model_name}.bin"
    llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"
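    # e.g. MODEL_PATH=models/Llama-3.2-1B (a hypothetical name) resolves to
    # data/pytorch-Llama-3.2-1B.bin and data/llamacpp-Llama-3.2-1B.bin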

    if not pytorch_file.exists():
        print(f"Error: PyTorch logits file not found: {pytorch_file}")
        print("Please run scripts/run-org-model.sh first to generate this file.")
        sys.exit(1)

    if not llamacpp_file.exists():
        print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
        print("Please run scripts/run-converted-model.sh first to generate this file.")
        sys.exit(1)

    print("All required files found. Proceeding...\n")

    print(f"🔍 GGML Model Validation for model {model_name}")
    print("=" * 40)
    print(f"PyTorch logits : {pytorch_file}")
    print(f"llama.cpp logits: {llamacpp_file}")
    print()

    success = quick_logits_check(pytorch_file, llamacpp_file)

    if success:
        print("✅ OK: Lightweight model check successful!")
        print("   Ok to proceed with NMSE check...")
        sys.exit(0)
    else:
        print("❌ NOK: Lightweight model check failed - generation will differ")
        sys.exit(1)


if __name__ == "__main__":
    main()
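
# Example invocation (the script filename here is a placeholder):
#   MODEL_PATH=models/Llama-3.2-1B python quick-logits-check.py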