rogermt committed on
Commit
43007c3
·
verified ·
1 Parent(s): 2199265

Move own-solver/neurogolf_solver/validators.py to own-solver/

Browse files
own-solver/neurogolf_solver/validators.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Model validation utilities.
3
+
4
+ Validation order (matches Kaggle's verify_network):
5
+ 1. File size ≤ 1.44MB
6
+ 2. onnx.checker.check_model() — catches malformed graphs
7
+ 3. No banned ops (UPPERCASE check to match Kaggle)
8
+ 4. All tensor shapes are static (no dynamic dims)
9
+ 5. onnxruntime.InferenceSession loads successfully
10
+ 6. Correct outputs on train + test + arc-gen
11
+ """
12
+
13
+ import os
14
+ import numpy as np
15
+ import onnx
16
+ import onnxruntime as ort
17
+ from .data_loader import to_onehot
18
+ from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE, BANNED_OPS
19
+
20
# Upper-cased copy of BANNED_OPS, built once at import time so that node
# op types can be matched case-insensitively.
_BANNED_OPS_UPPER = {op_name.upper() for op_name in BANNED_OPS}
21
+
22
+
23
def check_model_structure(path):
    """Check the structural constraints of an ONNX model file.

    The checks run in this order and stop at the first failure:
      1. The file size is at most MAX_ONNX_FILESIZE bytes.
      2. The file loads and passes onnx.checker.check_model().
      3. No node's op type (compared upper-cased) is in _BANNED_OPS_UPPER.
      4. Every dimension of every graph input/output tensor carries a
         concrete integer ``dim_value`` (symbolic ``dim_param`` dimensions
         and dimensions with no value set are rejected).

    Args:
        path: Filesystem path of the ONNX model.

    Returns:
        (ok, error_message): ``(True, None)`` when all checks pass, otherwise
        ``(False, <description of the first failed check>)``.
    """
    # 1. File size
    try:
        fsize = os.path.getsize(path)
    except OSError:
        return False, f"File not found: {path}"
    if fsize > MAX_ONNX_FILESIZE:
        return False, f"File size {fsize} exceeds {MAX_ONNX_FILESIZE} ({fsize/1024:.1f} KB)"

    # 2. ONNX checker
    try:
        model = onnx.load(path)
        onnx.checker.check_model(model)
    except Exception as e:
        return False, f"onnx.checker failed: {e}"

    # 3. Banned ops (UPPERCASE comparison)
    for node in model.graph.node:
        if node.op_type.upper() in _BANNED_OPS_UPPER:
            return False, f"Banned op: {node.op_type}"

    # 4. Static shapes on graph inputs and outputs.
    # The previous test (`not dim.dim_value and dim.dim_value != 0`) could
    # never be true, so dynamic dimensions were silently accepted. A dimension
    # is static only when the `dim_value` member of its oneof is actually set;
    # an explicitly set value of 0 is still accepted, as before.
    for kind, value_infos in (("input", model.graph.input),
                              ("output", model.graph.output)):
        for value_info in value_infos:
            if not value_info.type.HasField('tensor_type'):
                continue
            dims = value_info.type.tensor_type.shape.dim
            if any(not dim.HasField('dim_value') for dim in dims):
                return False, f"Dynamic shape on {kind} '{value_info.name}': {[d.dim_value or d.dim_param for d in dims]}"

    return True, None
68
+
69
+
70
def validate(path, td, providers):
    """Fully validate an ONNX model file against a task's examples.

    Runs check_model_structure() first, then loads the model with
    onnxruntime and compares its thresholded output with the expected
    output for every example.

    Args:
        path: Filesystem path of the ONNX model.
        td: Task mapping with 'train' and 'test' example lists and an
            optional 'arc-gen' list (only the first MAX_ARCGEN_VALIDATE
            entries are used). Each example is indexed by 'input' and
            'output'.
        providers: Execution providers passed to ort.InferenceSession.

    Returns:
        True only if the structure check passes, the session loads, and every
        example's output (binarised with ``> 0.0``) equals the encoded
        expected output. False as soon as any step fails.
    """
    # Structure checks first; the error text is not needed here.
    ok, _err = check_model_structure(path)
    if not ok:
        return False

    # Load the model. Only ordinary errors mean "invalid model";
    # KeyboardInterrupt / SystemExit must propagate instead of being
    # reported as a validation failure (the old bare `except:` hid them).
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3
        sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
    except Exception:
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]

    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            # The model is expected to expose tensors named 'input'/'output'.
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True
100
+
101
+
102
def validate_raw(raw_bytes, td, providers):
    """Validate an in-memory (serialized) ONNX model against a task's examples.

    Unlike a full structural check, this only enforces the size limit before
    running inference; it does not run the ONNX checker, banned-op or
    static-shape checks.

    Args:
        raw_bytes: Serialized ONNX model bytes.
        td: Task mapping with 'train' and 'test' example lists and an
            optional 'arc-gen' list (only the first MAX_ARCGEN_VALIDATE
            entries are used). Each example is indexed by 'input' and
            'output'.
        providers: Execution providers passed to ort.InferenceSession.

    Returns:
        True only if the model is within MAX_ONNX_FILESIZE, the session
        loads, and every example's output (binarised with ``> 0.0``) equals
        the encoded expected output. False as soon as any step fails.
    """
    if len(raw_bytes) > MAX_ONNX_FILESIZE:
        return False

    # Only ordinary errors mean "invalid model"; KeyboardInterrupt /
    # SystemExit must propagate (the old bare `except:` swallowed them).
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3
        sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=providers)
    except Exception:
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]

    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            # The model is expected to expose tensors named 'input'/'output'.
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True