rogermt commited on
Commit
5604a82
·
verified ·
1 Parent(s): 6eb4759

v5 refactor: add validators.py

Browse files
Files changed (1) hide show
  1. neurogolf_solver/validators.py +55 -0
neurogolf_solver/validators.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Model validation utilities."""
3
+
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from .data_loader import to_onehot
7
+ from .constants import MAX_ARCGEN_VALIDATE
8
+
9
+
10
+ def validate(path, td, providers):
11
+ """Validate ONNX model against task data."""
12
+ try:
13
+ opts = ort.SessionOptions()
14
+ opts.log_severity_level = 3
15
+ sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
16
+ except:
17
+ return False
18
+ examples = td['train'] + td['test']
19
+ if 'arc-gen' in td:
20
+ examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
21
+ for ex in examples:
22
+ inp = to_onehot(ex['input'])
23
+ exp = to_onehot(ex['output'])
24
+ try:
25
+ out = sess.run(['output'], {'input': inp})[0]
26
+ out = (out > 0.0).astype(np.float32)
27
+ except:
28
+ return False
29
+ if not np.array_equal(out, exp):
30
+ return False
31
+ return True
32
+
33
+
34
+ def validate_raw(raw_bytes, td, providers):
35
+ """Validate ONNX model from raw bytes."""
36
+ try:
37
+ opts = ort.SessionOptions()
38
+ opts.log_severity_level = 3
39
+ sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=providers)
40
+ except:
41
+ return False
42
+ examples = td['train'] + td['test']
43
+ if 'arc-gen' in td:
44
+ examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
45
+ for ex in examples:
46
+ inp = to_onehot(ex['input'])
47
+ exp = to_onehot(ex['output'])
48
+ try:
49
+ out = sess.run(['output'], {'input': inp})[0]
50
+ out = (out > 0.0).astype(np.float32)
51
+ except:
52
+ return False
53
+ if not np.array_equal(out, exp):
54
+ return False
55
+ return True