import asyncio
import logging
from typing import Optional

import torch
from safetensors.torch import load_file, save_file
|
|
| |
# Root logger: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Path of the safetensors checkpoint that main() writes and then reads back.
MODEL_CHECKPOINT = "model-3-of-10.safetensors"

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = ("cuda" if torch.cuda.is_available() else "cpu")
|
|
| |
async def load_model(filepath: str) -> dict:
    """Load a safetensors file onto ``DEVICE`` without blocking the event loop.

    ``load_file`` is a synchronous call, so it is dispatched to the running
    loop's default executor and awaited; other tasks scheduled on the loop
    can keep running while the file is being read.

    Args:
        filepath: Path of the safetensors file to read.

    Returns:
        The mapping produced by ``load_file(filepath, device=DEVICE)``.

    Raises:
        RuntimeError: If loading fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    logging.info(f"Loading model from {filepath} on {DEVICE}...")
    loop = asyncio.get_running_loop()
    try:
        # run_in_executor only forwards positional arguments, hence the
        # lambda to pass the ``device`` keyword.
        model_data = await loop.run_in_executor(
            None, lambda: load_file(filepath, device=DEVICE)
        )
    except Exception as e:
        logging.error(f"Error loading model: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    logging.info(f"Model {filepath} successfully loaded.")
    return model_data
|
|
| |
async def save_model(filepath: str, model_tensors: dict) -> None:
    """Write ``model_tensors`` to a safetensors file without blocking the event loop.

    ``save_file`` is a synchronous call, so it is dispatched to the running
    loop's default executor and awaited.

    Args:
        filepath: Destination path of the safetensors file.
        model_tensors: Mapping passed to ``save_file`` as the tensors to store.

    Raises:
        RuntimeError: If saving fails for any reason. The underlying
            exception is attached as ``__cause__``.
    """
    logging.info(f"Saving model to {filepath}...")
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, save_file, model_tensors, filepath)
    except Exception as e:
        logging.error(f"Error saving model: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"Error saving model: {str(e)}") from e

    logging.info(f"Model saved at {filepath}")
|
|
| |
def initialize_model(layers: Optional[list] = None, dtype: torch.dtype = torch.float16) -> dict:
    """Build a dict of random square tensors, one per requested layer size.

    Args:
        layers: Side lengths of the square tensors to create. When omitted,
            ``[4096, 8192, 16384]`` is used.
        dtype: Element type handed to ``torch.randn``.

    Returns:
        A dict mapping ``"layer_1"``, ``"layer_2"``, ... (in the order of
        ``layers``) to ``size x size`` tensors drawn with ``torch.randn`` on
        ``DEVICE``.
    """
    if layers is None:
        # Build the default per call rather than sharing one mutable list
        # object across every invocation.
        layers = [4096, 8192, 16384]

    model_tensors = {}
    for index, size in enumerate(layers, start=1):
        layer_name = f"layer_{index}"
        logging.info(f"Initializing {layer_name} with size {size}x{size} on {DEVICE}...")
        model_tensors[layer_name] = torch.randn(size, size, dtype=dtype, device=DEVICE)

    # NOTE(review): the original indentation was lost, so it is unclear whether
    # this ran per layer or once; it is issued once after the loop here.
    # Only meaningful when the tensors were allocated on the GPU.
    if DEVICE == "cuda":
        torch.cuda.empty_cache()

    logging.info("Model initialization completed.")
    return model_tensors
|
|
| |
async def main():
    """Create random tensors, save them, reload them, and log a per-tensor comparison.

    The tensors from ``initialize_model()`` are written to
    ``MODEL_CHECKPOINT``, read back with ``load_model``, and each reloaded
    tensor is compared to its source with ``torch.allclose(atol=1e-5)``.
    A mismatch is logged as a warning; it does not raise.
    """
    original = initialize_model()
    await save_model(MODEL_CHECKPOINT, original)
    restored = await load_model(MODEL_CHECKPOINT)

    for name, tensor in original.items():
        matches = torch.allclose(tensor, restored[name], atol=1e-5)
        if matches:
            logging.info(f"Tensor {name} verified successfully.")
        else:
            logging.warning(f"Tensor mismatch in {name}!")
|
|
| |
# Run the save/load round trip only when executed as a script, so that
# importing this module does not allocate tensors or write the checkpoint.
if __name__ == "__main__":
    asyncio.run(main())