"""Static profiling for ONNX models.

Uses neurogolf_utils.score_network() (onnx_tool) when available — this is
the ONLY scoring that matches Kaggle. The static fallback is approximate
and prints a WARNING. If onnx_tool returns (None, None, None), the model
is REJECTED — do not submit it.
"""
|
|
| import onnx |
| from onnx import numpy_helper |
| from .constants import BANNED_OPS, GH, GW |
|
|
# Optional dependency: the official NeuroGolf scorer (backed by onnx_tool).
# HAS_ONNX_TOOL records whether it could be imported in this environment.
try:
    from neurogolf_utils import score_network as _score_network_official
except ImportError:
    HAS_ONNX_TOOL = False
else:
    HAS_ONNX_TOOL = True


# Flipped to True the first time the "onnx_tool not installed" banner is
# printed, so score_network() emits that banner at most once per process.
_WARNED_NO_ONNX_TOOL = False
|
|
|
|
def score_network(path):
    """Return ``(macs, memory, params)`` for the ONNX model at *path*.

    When onnx_tool (via neurogolf_utils) is importable, the official scorer
    is used; a result of ``(None, None, None)`` then means the model is
    REJECTED. If the official scorer raises, a WARNING is printed and
    ``(None, None, None)`` is returned.

    When onnx_tool is NOT importable, an approximate static profile is
    returned instead, and a three-line WARNING banner is printed on the
    first such call only.
    """
    global _WARNED_NO_ONNX_TOOL

    if not HAS_ONNX_TOOL:
        # Fallback path: warn once, then use the approximate static profiler.
        if not _WARNED_NO_ONNX_TOOL:
            banner = (
                "WARNING: onnx_tool not installed. Scores are APPROXIMATE and may not match Kaggle.",
                "WARNING: Models that fail onnx_tool profiling will be REJECTED on Kaggle.",
                "WARNING: Run neurogolf_utils.verify_network() in a Kaggle notebook before submitting.",
            )
            for message in banner:
                print(message)
            _WARNED_NO_ONNX_TOOL = True
        return _static_profile(path)

    # Official path: any scorer failure is reported and treated as a rejection.
    try:
        return _score_network_official(path)
    except Exception as exc:
        print(f"WARNING: onnx_tool score_network failed on {path}: {exc}")
        return None, None, None
|
|
|
|
def _static_profile(path):
    """Approximate ``(macs, memory, params)`` for an ONNX model without onnx_tool.

    APPROXIMATE — does not match Kaggle scoring. Only used when onnx_tool is
    not installed.

    * params / memory: element count and byte size of every graph
      initializer, plus every decodable tensor held by a ``Constant`` node.
    * macs: only ``Conv`` nodes whose weight is a statically known 4-D tensor
      are counted, as ``co * ci * kh * kw * GH * GW``, i.e. every such Conv
      is assumed to produce a GH x GW output. NOTE(review): this presumes
      stride 1 / "same"-style padding — confirm for the models profiled.
      Other ops contribute no MACs here.

    Args:
        path: Filesystem path of the ``.onnx`` model.

    Returns:
        ``(macs, memory_bytes, params)`` as ints, or ``(None, None, None)``
        when the model cannot be loaded or contains a banned op.
    """
    try:
        model = onnx.load(path)
    except Exception as e:
        # `Exception` rather than a bare except so Ctrl-C / SystemExit still
        # propagate; report the failure like the official-scorer path does.
        print(f"WARNING: failed to load ONNX model {path}: {e}")
        return None, None, None

    # Case-insensitive banned-op lookup, built once instead of once per node.
    banned_ops = {op.upper() for op in BANNED_OPS}

    tensors = {}  # tensor name -> numpy array, for statically known tensors
    params = 0
    nbytes = 0
    macs = 0

    for init in model.graph.initializer:
        arr = numpy_helper.to_array(init)
        tensors[init.name] = arr
        params += arr.size
        nbytes += arr.nbytes

    for nd in model.graph.node:
        if nd.op_type == 'Constant':
            for attr in nd.attribute:
                # Non-tensor attributes carry an empty TensorProto; skip them.
                if not (attr.t and attr.t.ByteSize() > 0):
                    continue
                try:
                    arr = numpy_helper.to_array(attr.t)
                except Exception as e:
                    # Best-effort: an undecodable constant is skipped, but no
                    # longer silently.
                    print(
                        f"WARNING: could not decode Constant node "
                        f"'{nd.name}' in {path}: {e}"
                    )
                    continue
                if nd.output:
                    # Record it so a downstream Conv can find its weight.
                    tensors[nd.output[0]] = arr
                params += arr.size
                nbytes += arr.nbytes

        if nd.op_type.upper() in banned_ops:
            print(f"WARNING: Banned op '{nd.op_type}' found in {path}")
            return None, None, None

        if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
            w = tensors[nd.input[1]]
            if w.ndim == 4:
                # Weight layout (out_ch, in_ch_per_group, kh, kw): per-pixel MACs
                # times the assumed GH x GW output grid.
                co, ci, kh, kw = w.shape
                macs += co * ci * kh * kw * GH * GW

    return int(macs), int(nbytes), int(params)
|
|