rogermt commited on
Commit
2b04dc9
·
verified ·
1 Parent(s): 31248fa

solver_registry: check file size after onnx.save, before validate — skip oversized immediately

Browse files
neurogolf_solver/solvers/solver_registry.py CHANGED
@@ -15,7 +15,7 @@ from .mode import s_mode_fill
15
  from .conv import solve_conv_fixed, solve_conv_variable, solve_conv_diffshape, solve_conv_var_diff
16
  from ..data_loader import get_exs, fixed_shapes
17
  from ..validators import validate
18
- from ..constants import EXCLUDED_TASKS
19
 
20
  # Analytical solvers registry — order matters (cheaper first)
21
  ANALYTICAL_SOLVERS = [
@@ -45,6 +45,14 @@ ANALYTICAL_SOLVERS = [
45
  ]
46
 
47
 
 
 
 
 
 
 
 
 
48
  def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None):
49
  """Solve a single ARC-AGI task.
50
 
@@ -67,6 +75,8 @@ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None)
67
  if model is None:
68
  continue
69
  onnx.save(model, path)
 
 
70
  if validate(path, td, providers):
71
  return True, sname, os.path.getsize(path), time.time() - t_start, path
72
  except:
@@ -84,12 +94,14 @@ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None)
84
  if fixed_in:
85
  result = solve_conv_fixed(td, path, providers, time_budget=conv_time / 2)
86
  if result is not None:
87
- sname, model = result
88
- return True, sname, os.path.getsize(path), time.time() - t_start, path
 
89
  result = solve_conv_variable(td, path, providers, time_budget=conv_time)
90
  if result is not None:
91
- sname, model = result
92
- return True, sname, os.path.getsize(path), time.time() - t_start, path
 
93
  else:
94
  sp = fixed_shapes(td)
95
  if sp is not None:
@@ -97,12 +109,14 @@ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None)
97
  if OH <= IH and OW <= IW:
98
  result = solve_conv_diffshape(td, path, providers, time_budget=conv_time)
99
  if result is not None:
100
- sname, model = result
101
- return True, sname, os.path.getsize(path), time.time() - t_start, path
 
102
 
103
  result = solve_conv_var_diff(td, path, providers, time_budget=conv_time)
104
  if result is not None:
105
- sname, model = result
106
- return True, sname, os.path.getsize(path), time.time() - t_start, path
 
107
 
108
  return False, None, None, time.time() - t_start, path
 
15
  from .conv import solve_conv_fixed, solve_conv_variable, solve_conv_diffshape, solve_conv_var_diff
16
  from ..data_loader import get_exs, fixed_shapes
17
  from ..validators import validate
18
+ from ..constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
19
 
20
  # Analytical solvers registry — order matters (cheaper first)
21
  ANALYTICAL_SOLVERS = [
 
45
  ]
46
 
47
 
48
+ def _check_size(path):
49
+ """Return True if file is within 1.44MB limit."""
50
+ try:
51
+ return os.path.getsize(path) <= MAX_ONNX_FILESIZE
52
+ except OSError:
53
+ return False
54
+
55
+
56
  def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None):
57
  """Solve a single ARC-AGI task.
58
 
 
75
  if model is None:
76
  continue
77
  onnx.save(model, path)
78
+ if not _check_size(path):
79
+ continue # oversized, skip to next solver
80
  if validate(path, td, providers):
81
  return True, sname, os.path.getsize(path), time.time() - t_start, path
82
  except:
 
94
  if fixed_in:
95
  result = solve_conv_fixed(td, path, providers, time_budget=conv_time / 2)
96
  if result is not None:
97
+ if _check_size(path):
98
+ sname, model = result
99
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
100
  result = solve_conv_variable(td, path, providers, time_budget=conv_time)
101
  if result is not None:
102
+ if _check_size(path):
103
+ sname, model = result
104
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
105
  else:
106
  sp = fixed_shapes(td)
107
  if sp is not None:
 
109
  if OH <= IH and OW <= IW:
110
  result = solve_conv_diffshape(td, path, providers, time_budget=conv_time)
111
  if result is not None:
112
+ if _check_size(path):
113
+ sname, model = result
114
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
115
 
116
  result = solve_conv_var_diff(td, path, providers, time_budget=conv_time)
117
  if result is not None:
118
+ if _check_size(path):
119
+ sname, model = result
120
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
121
 
122
  return False, None, None, time.time() - t_start, path