| |
| """Model validation utilities.""" |
|
|
| import numpy as np |
| import onnxruntime as ort |
| from .data_loader import to_onehot |
| from .constants import MAX_ARCGEN_VALIDATE |
|
|
|
|
def validate(path, td, providers):
    """Validate an ONNX model against task data.

    Runs the model on every train and test example in ``td``. When the
    ``'arc-gen'`` key is present, it also runs on up to
    ``MAX_ARCGEN_VALIDATE`` of those examples. For each example it checks
    that the thresholded model output exactly equals ``to_onehot`` of the
    expected output.

    Args:
        path: Model to load. It is passed unchanged as the first argument of
            ``ort.InferenceSession``.
        td: Task data with ``'train'`` and ``'test'`` example sequences and an
            optional ``'arc-gen'`` sequence. Each example is indexed with
            ``'input'`` and ``'output'``, and both values are fed to
            ``to_onehot``.
        providers: Execution providers forwarded to ``ort.InferenceSession``.

    Returns:
        True if the session is created and every example matches. False if
        session creation raises, inference raises, or any output differs.
        Exceptions raised by ``to_onehot`` are not caught and propagate.
    """
    try:
        opts = ort.SessionOptions()
        # Only log errors and above; keeps repeated validation runs quiet.
        opts.log_severity_level = 3
        sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
    except Exception:
        # A model that cannot be loaded counts as invalid. Catch Exception
        # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        # Cap the generated examples to bound validation time.
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]

    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarise: strictly positive values become 1.0, everything else
            # 0.0, so the result can be compared with the one-hot target.
            out = (out > 0.0).astype(np.float32)
        except Exception:
            # An inference failure on any example means the model is invalid.
            return False
        if not np.array_equal(out, exp):
            return False
    return True
|
|
|
|
def validate_raw(raw_bytes, td, providers):
    """Validate an ONNX model supplied as serialized bytes.

    ``ort.InferenceSession`` accepts either a file path or the serialized
    model bytes as its first argument. The checks are therefore identical to
    ``validate``, and this delegates to it instead of keeping a duplicate
    copy of the loop.

    Args:
        raw_bytes: Serialized ONNX model.
        td: Task data with ``'train'``/``'test'`` (and optional ``'arc-gen'``)
            example sequences. The format is the same as for ``validate``.
        providers: Execution providers forwarded to ``ort.InferenceSession``.

    Returns:
        True if the model loads and reproduces every example, otherwise False.
    """
    return validate(raw_bytes, td, providers)
|
|