File size: 4,329 Bytes
43007c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
"""Model validation utilities.

Validation order (matches Kaggle's verify_network):
1. File size ≤ 1.44MB
2. onnx.checker.check_model() — catches malformed graphs
3. No banned ops (UPPERCASE check to match Kaggle)
4. All tensor shapes are static (no dynamic dims)
5. onnxruntime.InferenceSession loads successfully
6. Correct outputs on train + test + arc-gen
"""

import os
import numpy as np
import onnx
import onnxruntime as ort
from .data_loader import to_onehot
from .constants import MAX_ARCGEN_VALIDATE, MAX_ONNX_FILESIZE, BANNED_OPS

# Upper-cased copy of BANNED_OPS, built once at import time so that op-type
# lookups can be done case-insensitively with a single set membership test.
_BANNED_OPS_UPPER = {banned_name.upper() for banned_name in BANNED_OPS}


def _find_dynamic_tensor(value_infos):
    """Return ``(name, dims)`` for the first tensor with a non-static dim.

    A dimension counts as static only when its ``dim_value`` field is
    explicitly set; a symbolic ``dim_param`` or an unset dimension is dynamic.
    Entries that are not tensors are skipped. A tensor that records no
    dimensions has nothing to inspect and is not flagged.

    Returns ``None`` when every inspected dimension is static.
    """
    for info in value_infos:
        if not info.type.HasField('tensor_type'):
            continue
        dims = info.type.tensor_type.shape.dim
        if all(d.HasField('dim_value') for d in dims):
            continue
        rendered = [
            d.dim_value if d.HasField('dim_value') else (d.dim_param or '?')
            for d in dims
        ]
        return info.name, rendered
    return None


def check_model_structure(path):
    """Check ONNX model structure: size, valid graph, no banned ops, static shapes.

    The checks run in order and stop at the first failure:

    1. The file exists and is at most ``MAX_ONNX_FILESIZE`` bytes.
    2. ``onnx.load`` and ``onnx.checker.check_model`` succeed.
    3. No node of the top-level graph has an op type in the banned set
       (compared case-insensitively).
    4. Every dimension of every graph input and output tensor is a concrete
       ``dim_value`` (no symbolic or unset dimensions).

    Args:
        path: Filesystem path of the ONNX model file.

    Returns:
        ``(ok, error_message)``: ``(True, None)`` when all checks pass,
        otherwise ``(False, <description of the first failed check>)``.
    """
    # 1. File size
    try:
        fsize = os.path.getsize(path)
    except OSError:
        return False, f"File not found: {path}"
    if fsize > MAX_ONNX_FILESIZE:
        return False, f"File size {fsize} exceeds {MAX_ONNX_FILESIZE} ({fsize/1024:.1f} KB)"

    # 2. ONNX checker
    try:
        model = onnx.load(path)
        onnx.checker.check_model(model)
    except Exception as e:
        return False, f"onnx.checker failed: {e}"

    # 3. Banned ops (UPPERCASE comparison — matches Kaggle)
    for node in model.graph.node:
        if node.op_type.upper() in _BANNED_OPS_UPPER:
            return False, f"Banned op: {node.op_type}"

    # 4. Static shapes — graph inputs are checked before outputs.
    # The previous test (`not dim_value and dim_value != 0`) could never be
    # true, so dynamic dimensions were silently accepted.
    for kind, value_infos in (('input', model.graph.input),
                              ('output', model.graph.output)):
        offender = _find_dynamic_tensor(value_infos)
        if offender is not None:
            name, dims = offender
            return False, f"Dynamic shape on {kind} '{name}': {dims}"

    return True, None


def validate(path, td, providers):
    """Full validation: structure check + correct outputs on all splits.

    Runs ``check_model_structure`` on the file, then loads it with
    onnxruntime and evaluates every example in ``td['train']`` and
    ``td['test']``, plus the first ``MAX_ARCGEN_VALIDATE`` examples of
    ``td['arc-gen']`` when that key is present.

    Args:
        path: Filesystem path of the ONNX model file.
        td: Task dict with 'train' and 'test' example lists and optionally
            'arc-gen'; each example has 'input' and 'output' entries that
            are passed to ``to_onehot``.
        providers: Execution providers forwarded to
            ``ort.InferenceSession``.

    Returns:
        True only if every check and every example passes; False as soon as
        any check fails, the session cannot be created, inference raises, or
        an output differs from the expected one-hot array.
    """
    # Structure checks first; the error message is not needed here.
    ok, _ = check_model_structure(path)
    if not ok:
        return False

    # Load the model. Only ordinary errors mean "invalid model";
    # KeyboardInterrupt / SystemExit must propagate to the caller.
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3
        sess = ort.InferenceSession(path, sess_options=opts, providers=providers)
    except Exception:
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarize at zero before comparing with the one-hot target.
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True


def validate_raw(raw_bytes, td, providers):
    """Validate an ONNX model given as raw serialized bytes.

    Only the byte length is checked against ``MAX_ONNX_FILESIZE`` before the
    model is loaded with onnxruntime; the onnx checker, banned-op and
    static-shape checks are not run here. Every example in ``td['train']``
    and ``td['test']`` is evaluated, plus the first ``MAX_ARCGEN_VALIDATE``
    examples of ``td['arc-gen']`` when that key is present.

    Args:
        raw_bytes: Serialized ONNX model.
        td: Task dict with 'train' and 'test' example lists and optionally
            'arc-gen'; each example has 'input' and 'output' entries that
            are passed to ``to_onehot``.
        providers: Execution providers forwarded to
            ``ort.InferenceSession``.

    Returns:
        True only if the model loads and every example's thresholded output
        equals the expected one-hot array; False otherwise.
    """
    if len(raw_bytes) > MAX_ONNX_FILESIZE:
        return False

    # Only ordinary errors mean "invalid model"; KeyboardInterrupt and
    # SystemExit must propagate to the caller.
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3
        sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=providers)
    except Exception:
        return False

    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarize at zero before comparing with the one-hot target.
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True