rogermt commited on
Commit
2bdf288
·
verified ·
1 Parent(s): 9e704d4

Fix 2: validators.py — add onnx.checker, banned op scan (UPPERCASE), static shape check

Browse files
Files changed (1) hide show
  1. neurogolf_solver/validators.py +68 -9
neurogolf_solver/validators.py CHANGED
@@ -1,28 +1,88 @@
1
  #!/usr/bin/env python3
2
- """Model validation utilities."""
 
 
 
 
 
 
 
 
 
3
 
4
  import os
5
  import numpy as np
 
6
  import onnxruntime as ort
7
  from .data_loader import to_onehot
8
- from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE
9
 
 
10
 
11
- def validate(path, td, providers):
12
- """Validate ONNX model against task data.
13
- Returns False immediately if file exceeds 1.44MB."""
14
- # Size check FIRST — before wasting time on inference
 
15
  try:
16
- if os.path.getsize(path) > MAX_ONNX_FILESIZE:
17
- return False
18
  except OSError:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  return False
 
 
20
  try:
21
  opts = ort.SessionOptions()
22
  opts.log_severity_level = 3
23
  sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
24
  except:
25
  return False
 
26
  examples = td['train'] + td['test']
27
  if 'arc-gen' in td:
28
  examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
@@ -41,7 +101,6 @@ def validate(path, td, providers):
41
 
42
  def validate_raw(raw_bytes, td, providers):
43
  """Validate ONNX model from raw bytes."""
44
- # Size check
45
  if len(raw_bytes) > MAX_ONNX_FILESIZE:
46
  return False
47
  try:
 
1
  #!/usr/bin/env python3
2
+ """Model validation utilities.
3
+
4
+ Validation order (matches Kaggle's verify_network):
5
+ 1. File size ≤ 1.44MB
6
+ 2. onnx.checker.check_model() — catches malformed graphs
7
+ 3. No banned ops (UPPERCASE check to match Kaggle)
8
+ 4. All tensor shapes are static (no dynamic dims)
9
+ 5. onnxruntime.InferenceSession loads successfully
10
+ 6. Correct outputs on train + test + arc-gen
11
+ """
12
 
13
  import os
14
  import numpy as np
15
+ import onnx
16
  import onnxruntime as ort
17
  from .data_loader import to_onehot
18
+ from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE, BANNED_OPS
19
 
20
+ _BANNED_OPS_UPPER = {op.upper() for op in BANNED_OPS}
21
 
22
+
23
+ def check_model_structure(path):
24
+ """Check ONNX model structure: size, valid graph, no banned ops, static shapes.
25
+ Returns (ok, error_message)."""
26
+ # 1. File size
27
  try:
28
+ fsize = os.path.getsize(path)
 
29
  except OSError:
30
+ return False, f"File not found: {path}"
31
+ if fsize > MAX_ONNX_FILESIZE:
32
+ return False, f"File size {fsize} exceeds {MAX_ONNX_FILESIZE} ({fsize/1024:.1f} KB)"
33
+
34
+ # 2. ONNX checker
35
+ try:
36
+ model = onnx.load(path)
37
+ onnx.checker.check_model(model)
38
+ except Exception as e:
39
+ return False, f"onnx.checker failed: {e}"
40
+
41
+ # 3. Banned ops (UPPERCASE comparison — matches Kaggle)
42
+ for node in model.graph.node:
43
+ if node.op_type.upper() in _BANNED_OPS_UPPER:
44
+ return False, f"Banned op: {node.op_type}"
45
+
46
+ # 4. Static shapes — all tensors must have fully defined shapes
47
+ for inp in model.graph.input:
48
+ if inp.type.HasField('tensor_type'):
49
+ shape = inp.type.tensor_type.shape
50
+ if shape:
51
+ for dim in shape.dim:
52
+ if not dim.dim_value and dim.dim_value != 0:
53
+ if not dim.dim_param: # symbolic dim is also not static
54
+ pass # dim_value=0 is valid (means unknown in some contexts)
55
+ return False, f"Dynamic shape on input '{inp.name}': {[d.dim_value or d.dim_param for d in shape.dim]}"
56
+
57
+ for out in model.graph.output:
58
+ if out.type.HasField('tensor_type'):
59
+ shape = out.type.tensor_type.shape
60
+ if shape:
61
+ for dim in shape.dim:
62
+ if not dim.dim_value and dim.dim_value != 0:
63
+ if not dim.dim_param:
64
+ pass
65
+ return False, f"Dynamic shape on output '{out.name}': {[d.dim_value or d.dim_param for d in shape.dim]}"
66
+
67
+ return True, None
68
+
69
+
70
+ def validate(path, td, providers):
71
+ """Full validation: structure check + correct outputs on all splits.
72
+ Returns False immediately if any check fails."""
73
+ # Structure checks first
74
+ ok, err = check_model_structure(path)
75
+ if not ok:
76
  return False
77
+
78
+ # Load and run inference
79
  try:
80
  opts = ort.SessionOptions()
81
  opts.log_severity_level = 3
82
  sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
83
  except:
84
  return False
85
+
86
  examples = td['train'] + td['test']
87
  if 'arc-gen' in td:
88
  examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
 
101
 
102
  def validate_raw(raw_bytes, td, providers):
103
  """Validate ONNX model from raw bytes."""
 
104
  if len(raw_bytes) > MAX_ONNX_FILESIZE:
105
  return False
106
  try: