import gzip
import io
import os
from typing import Optional

import torch
from cryptography.hazmat.primitives.ciphers.aead import AESGCM

# Length of the nonce stored at the very start of every encrypted file.
_NONCE_SIZE = 12
# AESGCM.decrypt expects the 16-byte authentication tag appended to the
# ciphertext (the format produced by cryptography's AESGCM.encrypt), so any
# valid file is at least nonce + tag bytes long.
_TAG_SIZE = 16
# Required key length in bytes (AES-256).
_KEY_SIZE = 32


def _parse_key(key_str: str) -> bytes:
    """Convert a textual key into exactly 32 raw key bytes.

    Surrounding whitespace is stripped first (so e.g. a trailing newline from
    an environment variable does not change the key). The text is then
    interpreted as:

    1. a hex string, if it decodes to exactly 32 bytes (64 hex digits); else
    2. its UTF-8 encoding, which must be exactly 32 bytes long.

    Note that a 32-character string consisting only of hex digits decodes to
    16 bytes in step 1, so it falls through and is used as a raw key.

    Raises:
        ValueError: if neither interpretation yields a 32-byte key.
    """
    key_str = key_str.strip()
    try:
        key = bytes.fromhex(key_str)
    except ValueError:
        pass  # Not valid hex; try the raw-string interpretation below.
    else:
        if len(key) == _KEY_SIZE:
            return key
    key = key_str.encode("utf-8")
    if len(key) == _KEY_SIZE:
        return key
    raise ValueError("Key must be either a 64-character hex string or a 32-character raw string.")


def _get_key(key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Return the 32-byte decryption key.

    An explicit *key* takes precedence; otherwise the value of the
    environment variable *env_var* is used.

    Raises:
        RuntimeError: if *key* is None and *env_var* is unset or empty.
        ValueError: if the key text is malformed (see ``_parse_key``).
    """
    if key is not None:
        return _parse_key(key)
    env_value = os.environ.get(env_var)
    if not env_value:
        raise RuntimeError(f"Missing key. Provide key=... or set environment variable {env_var}.")
    return _parse_key(env_value)


def decrypt_and_decompress_to_bytes(path: str, key: Optional[str] = None, env_var: str = "MODEL_KEY") -> bytes:
    """Read *path*, authenticate and decrypt it with AES-GCM, then gunzip it.

    Expected file layout: the first 12 bytes are the nonce; the remainder is
    the AES-GCM ciphertext with its 16-byte tag appended. No associated data
    is used.

    Args:
        path: Path to the encrypted file.
        key: Optional key text; overrides the environment variable.
        env_var: Environment variable consulted when *key* is None.

    Returns:
        The decompressed plaintext bytes.

    Raises:
        RuntimeError, ValueError: from key lookup/parsing (see ``_get_key``).
        ValueError: if the file is shorter than nonce + tag.
        cryptography.exceptions.InvalidTag: if the key is wrong or the file
            was modified/corrupted.
        Exceptions from ``gzip.decompress`` (e.g. ``gzip.BadGzipFile``) if the
            decrypted payload is not valid gzip data.
    """
    key_bytes = _get_key(key=key, env_var=env_var)
    aesgcm = AESGCM(key_bytes)
    with open(path, "rb") as f:
        data = f.read()
    # Anything shorter than nonce + tag cannot be a valid AES-GCM message.
    if len(data) < _NONCE_SIZE + _TAG_SIZE:
        raise ValueError("Encrypted file is too short or invalid.")
    nonce = data[:_NONCE_SIZE]
    ciphertext = data[_NONCE_SIZE:]
    # Release the full-file copy before allocating the decrypted buffer to
    # keep peak memory down for large checkpoints.
    del data
    # NOTE(review): some cryptography releases cap AESGCM input at 2**31 - 1
    # bytes; confirm this is acceptable for the largest checkpoints.
    compressed = aesgcm.decrypt(nonce, ciphertext, None)
    del ciphertext
    return gzip.decompress(compressed)


def secure_torch_load(path: str, *args, key: Optional[str] = None, env_var: str = "MODEL_KEY", **kwargs):
    """Decrypt and decompress *path*, then ``torch.load`` it from memory.

    The plaintext never touches disk; it is wrapped in an in-memory buffer.
    Extra positional *args* and keyword *kwargs* are forwarded unchanged to
    ``torch.load`` (e.g. ``map_location``).

    Security note: ``torch.load`` can unpickle arbitrary objects unless
    ``weights_only=True`` is in effect. AES-GCM authentication only proves the
    file was produced by someone holding the key; pass ``weights_only=True``
    when the checkpoint contains only tensors/state dicts.

    Raises:
        Everything ``decrypt_and_decompress_to_bytes`` raises, plus any error
        from ``torch.load``.
    """
    plaintext = decrypt_and_decompress_to_bytes(path, key=key, env_var=env_var)
    return torch.load(io.BytesIO(plaintext), *args, **kwargs)