#!/usr/bin/env python3
"""Model validation utilities.

Validation order (matches Kaggle's verify_network):
  1. File size <= 1.44MB
  2. onnx.checker.check_model() — catches malformed graphs
  3. No banned ops (UPPERCASE check to match Kaggle)
  4. All graph input/output shapes are static (no dynamic dims)
  5. onnxruntime.InferenceSession loads successfully
  6. Correct outputs on train + test + arc-gen
"""
import os

import numpy as np
import onnx
import onnxruntime as ort

from .data_loader import to_onehot
from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE, BANNED_OPS

# Pre-uppercased so op_type comparison is case-insensitive, as Kaggle does.
_BANNED_OPS_UPPER = {op.upper() for op in BANNED_OPS}


def _first_dynamic_dim(value_infos):
    """Find the first tensor value_info whose shape contains a non-static dim.

    A dimension counts as static only when it carries a concrete ``dim_value``
    (0 included). A symbolic ``dim_param`` or a dimension with neither field
    set is treated as dynamic.

    Returns:
        ``(name, dims)`` for the offending value_info, where ``dims`` lists
        each dimension's ``dim_value`` or ``dim_param``; ``None`` if every
        tensor-typed entry is fully static.
    """
    for vi in value_infos:
        if not vi.type.HasField('tensor_type'):
            continue
        dims = vi.type.tensor_type.shape.dim
        if any(not d.HasField('dim_value') for d in dims):
            return vi.name, [d.dim_value or d.dim_param for d in dims]
    return None


def check_model_structure(path):
    """Check ONNX model structure: size, valid graph, no banned ops, static shapes.

    Args:
        path: Filesystem path to the ``.onnx`` file.

    Returns:
        ``(ok, error_message)``. ``error_message`` is ``None`` when ``ok`` is
        True, otherwise a human-readable reason for the first failed check.
    """
    # 1. File size
    try:
        fsize = os.path.getsize(path)
    except OSError:
        return False, f"File not found: {path}"
    if fsize > MAX_ONNX_FILESIZE:
        return False, f"File size {fsize} exceeds {MAX_ONNX_FILESIZE} ({fsize/1024:.1f} KB)"

    # 2. ONNX checker (load failures are reported under the same message)
    try:
        model = onnx.load(path)
        onnx.checker.check_model(model)
    except Exception as e:
        return False, f"onnx.checker failed: {e}"

    # 3. Banned ops (UPPERCASE comparison — matches Kaggle)
    for node in model.graph.node:
        if node.op_type.upper() in _BANNED_OPS_UPPER:
            return False, f"Banned op: {node.op_type}"

    # 4. Static shapes on graph inputs and outputs. Intermediate value_info
    #    entries are not inspected here.
    for kind, value_infos in (('input', model.graph.input), ('output', model.graph.output)):
        found = _first_dynamic_dim(value_infos)
        if found is not None:
            name, dims = found
            return False, f"Dynamic shape on {kind} '{name}': {dims}"

    return True, None


def _make_session(model, providers):
    """Create an onnxruntime session from a path or serialized bytes.

    Logging is limited to severity 3 and above to keep validation quiet.
    Returns ``None`` if the session cannot be created.
    """
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3
        return ort.InferenceSession(model, sess_options=opts, providers=providers)
    except Exception:
        return None


def _outputs_match(sess, td):
    """Return True if ``sess`` reproduces every expected output in ``td``.

    Examples are ``td['train'] + td['test']`` plus, when present, the first
    ``MAX_ARCGEN_VALIDATE`` entries of ``td['arc-gen']``. Each example is a
    mapping with ``'input'`` and ``'output'`` grids passed through
    ``to_onehot``. Raw model outputs are thresholded at 0 (strictly positive
    -> 1.0, else 0.0) before an exact comparison.
    """
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]

    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)
        except Exception:
            # Any inference/conversion failure means the model is invalid.
            return False
        if not np.array_equal(out, exp):
            return False
    return True


def validate(path, td, providers):
    """Full validation: structure check + correct outputs on all splits.

    Args:
        path: Filesystem path to the ``.onnx`` file.
        td: Task data with ``'train'`` and ``'test'`` example lists and an
            optional ``'arc-gen'`` list.
        providers: onnxruntime execution providers for the session.

    Returns:
        True only if every check passes; False at the first failure.
    """
    ok, _ = check_model_structure(path)
    if not ok:
        return False

    sess = _make_session(path, providers)
    if sess is None:
        return False
    return _outputs_match(sess, td)


def validate_raw(raw_bytes, td, providers):
    """Validate an ONNX model given as serialized bytes.

    Unlike ``validate``, this only enforces the file-size limit before
    loading. It does not run ``onnx.checker``, the banned-op check, or the
    static-shape check.

    Returns:
        True if the bytes are within the size limit, load in onnxruntime,
        and reproduce every expected output; otherwise False.
    """
    if len(raw_bytes) > MAX_ONNX_FILESIZE:
        return False

    sess = _make_session(raw_bytes, providers)
    if sess is None:
        return False
    return _outputs_match(sess, td)