"""Model validation utilities.

Validation order (matches Kaggle's verify_network):
1. File size ≤ 1.44MB
2. onnx.checker.check_model() — catches malformed graphs
3. No banned ops (UPPERCASE check to match Kaggle)
4. All tensor shapes are static (no dynamic dims)
5. onnxruntime.InferenceSession loads successfully
6. Correct outputs on train + test + arc-gen
"""
|
|
| import os |
| import numpy as np |
| import onnx |
| import onnxruntime as ort |
| from .data_loader import to_onehot |
| from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE, BANNED_OPS |
|
|
# Kaggle compares op names case-insensitively, so pre-uppercase the ban list
# once at import time; check_model_structure then uppercases each node's
# op_type before membership testing.  BANNED_OPS comes from .constants.
_BANNED_OPS_UPPER = {op.upper() for op in BANNED_OPS}
|
|
|
|
def _dynamic_dim_error(kind, value_info):
    """Return an error string if *value_info* has a non-static tensor dim, else None.

    A dimension is static only when its concrete ``dim_value`` field is set
    and positive; a symbolic ``dim_param`` (or a dim with neither field set)
    makes the shape dynamic.  *kind* is 'input' or 'output' and only affects
    the message text.
    """
    if not value_info.type.HasField('tensor_type'):
        return None
    shape = value_info.type.tensor_type.shape
    for dim in shape.dim:
        if dim.dim_param or not dim.HasField('dim_value') or dim.dim_value <= 0:
            dims = [d.dim_value or d.dim_param for d in shape.dim]
            return f"Dynamic shape on {kind} '{value_info.name}': {dims}"
    return None


def check_model_structure(path):
    """Check ONNX model structure: size, valid graph, no banned ops, static shapes.

    Validation order matches Kaggle's verify_network:
      1. file size <= MAX_ONNX_FILESIZE
      2. onnx.checker.check_model() — catches malformed graphs
      3. no banned ops (compared case-insensitively, matching Kaggle)
      4. every graph input/output tensor dimension is static

    Returns (ok, error_message): (True, None) on success, otherwise
    (False, <human-readable reason>).
    """
    try:
        fsize = os.path.getsize(path)
    except OSError:
        return False, f"File not found: {path}"
    if fsize > MAX_ONNX_FILESIZE:
        return False, f"File size {fsize} exceeds {MAX_ONNX_FILESIZE} ({fsize/1024:.1f} KB)"

    try:
        model = onnx.load(path)
        onnx.checker.check_model(model)
    except Exception as e:
        return False, f"onnx.checker failed: {e}"

    for node in model.graph.node:
        if node.op_type.upper() in _BANNED_OPS_UPPER:
            return False, f"Banned op: {node.op_type}"

    # BUG FIX: the original test `not dim.dim_value and dim.dim_value != 0`
    # is logically always False for an int field (the first clause implies
    # dim_value == 0, contradicting the second), so the dynamic-shape
    # rejection was dead code and dynamic dims were never caught.  The
    # helper above implements the check correctly.
    for inp in model.graph.input:
        err = _dynamic_dim_error('input', inp)
        if err:
            return False, err
    for out in model.graph.output:
        err = _dynamic_dim_error('output', out)
        if err:
            return False, err

    return True, None
|
|
|
|
def validate(path, td, providers):
    """Full validation: structure check + correct outputs on all splits.

    Args:
        path: filesystem path to the .onnx model file.
        td: task dict with 'train' and 'test' example lists (and optionally
            'arc-gen'); each example has 'input' and 'output' grids.
        providers: onnxruntime execution-provider list for InferenceSession.

    Returns:
        True iff the model passes check_model_structure, loads into an
        InferenceSession, and reproduces every expected output exactly
        on train + test + (up to MAX_ARCGEN_VALIDATE) arc-gen examples.
        Returns False immediately on the first failure.
    """
    ok, _err = check_model_structure(path)
    if not ok:
        return False

    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # errors only; suppress ORT warnings
        sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
    except Exception:  # was a bare except: — would swallow KeyboardInterrupt too
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarize logits: strictly positive -> 1.0, else 0.0,
            # matching the one-hot target encoding.
            out = (out > 0.0).astype(np.float32)
        except Exception:  # was a bare except: — narrow to runtime failures
            return False
        if not np.array_equal(out, exp):
            return False
    return True
|
|
|
|
def validate_raw(raw_bytes, td, providers):
    """Validate ONNX model from raw bytes.

    Unlike validate(), this runs only the size check — not the full
    check_model_structure() pass (banned ops / static shapes), since those
    operate on a file path.  It then verifies exact outputs on all splits.

    Args:
        raw_bytes: serialized ONNX model.
        td: task dict with 'train'/'test' (optionally 'arc-gen') examples.
        providers: onnxruntime execution-provider list.

    Returns:
        True iff the bytes fit the size limit, load into an InferenceSession,
        and reproduce every expected output exactly; False otherwise.
    """
    if len(raw_bytes) > MAX_ONNX_FILESIZE:
        return False
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # errors only; suppress ORT warnings
        sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=providers)
    except Exception:  # was a bare except: — would swallow KeyboardInterrupt too
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarize logits: strictly positive -> 1.0, else 0.0,
            # matching the one-hot target encoding.
            out = (out > 0.0).astype(np.float32)
        except Exception:  # was a bare except: — narrow to runtime failures
            return False
        if not np.array_equal(out, exp):
            return False
    return True
|
|