rogermt's picture
Move own-solver/neurogolf_solver/validators.py to own-solver/
43007c3 verified
#!/usr/bin/env python3
"""Model validation utilities.
Validation order (matches Kaggle's verify_network):
1. File size ≤ 1.44MB
2. onnx.checker.check_model() — catches malformed graphs
3. No banned ops (UPPERCASE check to match Kaggle)
4. Graph input/output shapes are static (no symbolic or unknown dims)
5. onnxruntime.InferenceSession loads successfully
6. Correct outputs on train + test + arc-gen
"""
import os
import numpy as np
import onnx
import onnxruntime as ort
from .data_loader import to_onehot
from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE, BANNED_OPS
# Upper-cased copy of BANNED_OPS, so node.op_type can be compared case-insensitively.
_BANNED_OPS_UPPER = set(map(str.upper, BANNED_OPS))
def _first_dynamic_dim(value_infos, kind):
    """Return an error message for the first tensor with a non-static shape, else None.

    A dimension counts as static only when its ``dim_value`` field is set
    (an explicit ``dim_value`` of 0 is accepted). A symbolic ``dim_param``
    or a dimension with neither field set is dynamic. Entries that are not
    tensors, or whose tensor type carries no shape field, are skipped.
    """
    for vi in value_infos:
        if not vi.type.HasField('tensor_type'):
            continue
        tensor_type = vi.type.tensor_type
        if not tensor_type.HasField('shape'):
            continue
        dims = tensor_type.shape.dim
        # `value` is the oneof holding either dim_value or dim_param.
        if any(d.WhichOneof('value') != 'dim_value' for d in dims):
            described = [
                d.dim_value if d.WhichOneof('value') == 'dim_value' else (d.dim_param or '?')
                for d in dims
            ]
            return f"Dynamic shape on {kind} '{vi.name}': {described}"
    return None


def check_model_structure(path):
    """Check ONNX model structure: size, valid graph, no banned ops, static shapes.

    Checks run in this order and stop at the first failure:
      1. file size <= MAX_ONNX_FILESIZE
      2. onnx.load + onnx.checker.check_model succeed
      3. no top-level graph node has an op_type in BANNED_OPS (case-insensitive)
      4. every graph input/output tensor shape has only concrete dims

    Args:
        path: Filesystem path to the .onnx file.

    Returns:
        (ok, error_message): (True, None) on success, otherwise
        (False, <human-readable reason>).
    """
    # 1. File size
    try:
        fsize = os.path.getsize(path)
    except OSError:
        return False, f"File not found: {path}"
    if fsize > MAX_ONNX_FILESIZE:
        return False, f"File size {fsize} exceeds {MAX_ONNX_FILESIZE} ({fsize/1024:.1f} KB)"

    # 2. ONNX checker (also covers protobuf decode errors from onnx.load)
    try:
        model = onnx.load(path)
        onnx.checker.check_model(model)
    except Exception as e:
        return False, f"onnx.checker failed: {e}"

    # 3. Banned ops (UPPERCASE comparison)
    for node in model.graph.node:
        if node.op_type.upper() in _BANNED_OPS_UPPER:
            return False, f"Banned op: {node.op_type}"

    # 4. Static shapes on graph inputs and outputs. The previous condition
    #    (`not v and v != 0`) was always False, so nothing was ever rejected.
    for value_infos, kind in ((model.graph.input, 'input'), (model.graph.output, 'output')):
        err = _first_dynamic_dim(value_infos, kind)
        if err is not None:
            return False, err

    return True, None
def validate(path, td, providers):
    """Full validation: structure check + correct outputs on all splits.

    Args:
        path: Path to the .onnx file.
        td: Task dict with 'train' and 'test' lists of examples (each a dict
            with 'input' and 'output' grids) and an optional 'arc-gen' list,
            of which only the first MAX_ARCGEN_VALIDATE entries are used.
        providers: onnxruntime execution providers for InferenceSession.

    Returns:
        True only if check_model_structure passes, the session loads, and the
        thresholded output (> 0) equals the one-hot target on every example.
        Returns False at the first failure.
    """
    # Structure checks first; the error message is not surfaced by this API.
    ok, _ = check_model_structure(path)
    if not ok:
        return False

    # Load the model. Narrow except so KeyboardInterrupt/SystemExit still propagate.
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # silence ORT warnings; errors only
        sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
    except Exception:
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]

    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarise: any positive activation counts as "on".
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True
def validate_raw(raw_bytes, td, providers, *, check_structure=True):
    """Validate an ONNX model given as serialized bytes.

    Mirrors validate(): size limit, structural checks, session load, then an
    exact match of the thresholded output on train + test (+ capped arc-gen)
    examples.

    Args:
        raw_bytes: Serialized ONNX model.
        td: Task dict with 'train' and 'test' example lists (each example a
            dict with 'input' and 'output') and an optional 'arc-gen' list.
        providers: onnxruntime execution providers for InferenceSession.
        check_structure: When True (default), run check_model_structure
            (checker, banned ops, static input/output shapes) before inference,
            exactly as validate() does. Pass False to enforce only the size
            limit (the old behaviour).

    Returns:
        True if every check passes and all outputs match, else False.
    """
    import tempfile  # used only for the structural-check round-trip

    if len(raw_bytes) > MAX_ONNX_FILESIZE:
        return False

    if check_structure:
        # check_model_structure operates on a path, so write the bytes to a
        # private temp dir instead of duplicating its logic here.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, 'model.onnx')
            with open(tmp_path, 'wb') as fh:
                fh.write(raw_bytes)
            ok, _ = check_model_structure(tmp_path)
        if not ok:
            return False

    # Narrow except so KeyboardInterrupt/SystemExit still propagate.
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # silence ORT warnings; errors only
        sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=providers)
    except Exception:
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]

    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarise: any positive activation counts as "on".
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True