rogermt's picture
v5 refactor: add validators.py
5604a82 verified
raw
history blame
1.7 kB
#!/usr/bin/env python3
"""Model validation utilities."""
import numpy as np
import onnxruntime as ort
from .data_loader import to_onehot
from .constants import MAX_ARCGEN_VALIDATE
def _open_session(model, providers):
    """Return an onnxruntime session for *model*, or None if it cannot be loaded."""
    try:
        opts = ort.SessionOptions()
        # Severity 3 (ERROR in onnxruntime's scale, per the SessionOptions docs):
        # keeps load-time warnings from candidate models out of the log.
        opts.log_severity_level = 3
        return ort.InferenceSession(model, sess_options=opts, providers=providers)
    except Exception:
        # An unloadable model is simply an invalid model for our purposes.
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still work.
        return None


def _validation_examples(td):
    """Return the examples to check: train + test, then a capped arc-gen slice."""
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        # arc-gen can be large; only the first MAX_ARCGEN_VALIDATE are checked.
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    return examples


def validate(path, td, providers):
    """Validate ONNX model against task data.

    The model must reproduce every selected example exactly. Selected examples
    are all of ``td['train']`` and ``td['test']``, plus at most
    ``MAX_ARCGEN_VALIDATE`` entries of the optional ``td['arc-gen']`` list.

    For each example, ``ex['input']`` is encoded with ``to_onehot`` and fed
    to the graph input named ``'input'``. The graph output named ``'output'``
    is binarised: values strictly greater than 0.0 become 1.0, all others
    become 0.0. The binarised output must equal ``to_onehot(ex['output'])``
    element-wise.

    Args:
        path: Model handed directly to ``ort.InferenceSession``. This is either
            a filesystem path or the serialized model bytes, as
            ``validate_raw`` passes to the same constructor.
        td: Task dict with ``'train'`` and ``'test'`` example lists and an
            optional ``'arc-gen'`` list. Each example is a dict with
            ``'input'`` and ``'output'`` entries (presumably ARC grids —
            whatever ``to_onehot`` accepts).
        providers: Execution-provider list forwarded to onnxruntime.

    Returns:
        True if every selected example matches. False if the session cannot
        be created, inference (or binarising its result) raises, or any
        prediction differs.

    Note:
        Exceptions from ``to_onehot`` or from missing keys in ``td`` are not
        caught. They signal bad task data rather than a bad model.
    """
    sess = _open_session(path, providers)
    if sess is None:
        return False
    for ex in _validation_examples(td):
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)
        except Exception:
            # Any runtime failure (shape mismatch, unsupported op, ...) means
            # the model does not solve the task.
            return False
        if not np.array_equal(out, exp):
            return False
    return True
def validate_raw(raw_bytes, td, providers):
    """Validate ONNX model from raw bytes.

    ``ort.InferenceSession`` accepts serialized model bytes through the same
    argument as a file path. This function therefore delegates to
    ``validate`` instead of duplicating its body, so the two validators
    cannot drift apart.

    Args:
        raw_bytes: Serialized ONNX model.
        td: Task dict with ``'train'`` and ``'test'`` example lists and an
            optional ``'arc-gen'`` list of ``{'input', 'output'}`` examples.
        providers: Execution-provider list forwarded to onnxruntime.

    Returns:
        True if the model reproduces every checked example exactly. False if
        the session cannot be created, inference fails, or any prediction
        differs.
    """
    return validate(raw_bytes, td, providers)