rogermt commited on
Commit
8c12a26
·
verified ·
1 Parent(s): 9862993

Fix profiler.py: return 3 values (macs=0, memory, params) for backward compat with solver_registry.py

Browse files
own-solver/neurogolf_solver/profiler.py CHANGED
@@ -8,6 +8,8 @@ ORT with profiling enabled, which is too heavy for local model generation.
8
  Strategy: Use static fallback for local scoring during model generation.
9
  Real scoring happens on Kaggle at submission time via the official utils.
10
  Models are NOT rejected locally — they're validated via inference correctness.
 
 
11
  """
12
 
13
  import math
@@ -17,21 +19,25 @@ from .constants import BANNED_OPS, GH, GW
17
 
18
 
19
  def score_network(path):
20
- """Score network locally. Returns (memory, params) or (None, None, None).
21
 
22
- Uses static estimation (sum of tensor sizes + param count).
23
- This is APPROXIMATE but sufficient for local development.
24
  Real scoring uses ORT profiler on Kaggle.
25
  """
26
- return _static_profile(path)
 
 
 
 
27
 
28
 
29
  def estimate_score(path):
30
  """Estimate score under new formula: 25 - ln(memory + params)."""
31
  result = _static_profile(path)
32
- if result is None or result[0] is None:
33
  return None
34
- memory, params = result[0], result[1] # memory in bytes, params in elements
35
  cost = memory + params
36
  if cost <= 0:
37
  return 25.0
@@ -44,12 +50,12 @@ def _static_profile(path):
44
  memory = sum of all initializer bytes + estimated intermediate tensor bytes
45
  params = sum of all initializer element counts + Constant node values
46
 
47
- Returns (memory, params) or (None, None, None) if model is invalid.
48
  """
49
  try:
50
  model = onnx.load(path)
51
  except Exception:
52
- return None, None, None
53
 
54
  params = 0
55
  memory = 0 # bytes
@@ -84,11 +90,9 @@ def _static_profile(path):
84
  # Banned op check
85
  if nd.op_type.upper() in {op.upper() for op in BANNED_OPS}:
86
  print(f"WARNING: Banned op '{nd.op_type}' found in {path}")
87
- return None, None, None
88
 
89
  # Estimate intermediate tensor memory (node outputs that aren't 'output')
90
- # Each intermediate tensor is approximately (1,10,30,30) float32 = 36,000 bytes
91
- # This is rough but gives directional guidance
92
  n_intermediates = 0
93
  for nd in model.graph.node:
94
  for out_name in nd.output:
 
8
  Strategy: Use static fallback for local scoring during model generation.
9
  Real scoring happens on Kaggle at submission time via the official utils.
10
  Models are NOT rejected locally — they're validated via inference correctness.
11
+
12
+ Returns (macs=0, memory, params) for backward compatibility with solver_registry.py.
13
  """
14
 
15
  import math
 
19
 
20
 
21
  def score_network(path):
22
+ """Score network locally. Returns (macs, memory, params) or (None, None, None).
23
 
24
+ macs is always 0 (no longer used in Kaggle scoring since May 4 2026).
25
+ memory and params are static estimates sufficient for local development.
26
  Real scoring uses ORT profiler on Kaggle.
27
  """
28
+ result = _static_profile(path)
29
+ if result is None:
30
+ return None, None, None
31
+ memory, params = result
32
+ return 0, memory, params
33
 
34
 
35
  def estimate_score(path):
36
  """Estimate score under new formula: 25 - ln(memory + params)."""
37
  result = _static_profile(path)
38
+ if result is None:
39
  return None
40
+ memory, params = result
41
  cost = memory + params
42
  if cost <= 0:
43
  return 25.0
 
50
  memory = sum of all initializer bytes + estimated intermediate tensor bytes
51
  params = sum of all initializer element counts + Constant node values
52
 
53
+ Returns (memory, params) or None if model is invalid.
54
  """
55
  try:
56
  model = onnx.load(path)
57
  except Exception:
58
+ return None
59
 
60
  params = 0
61
  memory = 0 # bytes
 
90
  # Banned op check
91
  if nd.op_type.upper() in {op.upper() for op in BANNED_OPS}:
92
  print(f"WARNING: Banned op '{nd.op_type}' found in {path}")
93
+ return None
94
 
95
  # Estimate intermediate tensor memory (node outputs that aren't 'output')
 
 
96
  n_intermediates = 0
97
  for nd in model.graph.node:
98
  for out_name in nd.output: