#!/usr/bin/env python3
"""Model validation utilities."""
import numpy as np
import onnxruntime as ort

from .data_loader import to_onehot
from .constants import MAX_ARCGEN_VALIDATE


def _make_session(model, providers):
    """Build an ONNX Runtime session for *model*, or return None if loading fails.

    *model* is passed straight to ``ort.InferenceSession``, so it may be a
    file path or the serialized model bytes.
    """
    try:
        opts = ort.SessionOptions()
        # Raise the log threshold so loading does not emit low-severity
        # log chatter (ORT severity 3 is ERROR per its docs).
        opts.log_severity_level = 3
        return ort.InferenceSession(model, sess_options=opts, providers=providers)
    except Exception:
        # Any failure to load/parse the model means it is not a valid candidate.
        # Catch Exception (not bare except) so Ctrl-C / SystemExit still propagate.
        return None


def _validation_examples(td):
    """Return the examples a model must reproduce.

    These are all of ``td['train']`` and ``td['test']``, plus the first
    ``MAX_ARCGEN_VALIDATE`` entries of ``td['arc-gen']`` when that key exists.
    """
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    return examples


def _passes_all(sess, td):
    """Return True iff *sess* reproduces every validation example of *td* exactly.

    Each example is a dict with 'input' and 'output' keys; both are converted
    with ``to_onehot``. The session is run with feed name 'input' and output
    name 'output'.
    """
    for ex in _validation_examples(td):
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Threshold at 0: strictly positive values become 1.0, everything
            # else 0.0, so the result can be compared to the one-hot target.
            out = (out > 0.0).astype(np.float32)
        except Exception:
            # A runtime error on any example fails the model. Catch Exception
            # rather than everything so interrupts are not swallowed.
            return False
        if not np.array_equal(out, exp):
            return False
    return True


def validate(path, td, providers):
    """Validate an ONNX model file against task data.

    Args:
        path: Filesystem path of the ONNX model.
        td: Task dict with 'train' and 'test' example lists and an optional
            'arc-gen' list.
        providers: Execution providers forwarded to ``ort.InferenceSession``.

    Returns:
        True if the model loads and matches every selected example exactly,
        otherwise False.
    """
    sess = _make_session(path, providers)
    if sess is None:
        return False
    return _passes_all(sess, td)


def validate_raw(raw_bytes, td, providers):
    """Validate an ONNX model given as serialized bytes against task data.

    Same contract as ``validate``, except the model is supplied in memory
    instead of as a path.
    """
    sess = _make_session(raw_bytes, providers)
    if sess is None:
        return False
    return _passes_all(sess, td)