solver_registry: check file size after onnx.save, before validate — skip oversized immediately
Browse files
neurogolf_solver/solvers/solver_registry.py
CHANGED
|
@@ -15,7 +15,7 @@ from .mode import s_mode_fill
|
|
| 15 |
from .conv import solve_conv_fixed, solve_conv_variable, solve_conv_diffshape, solve_conv_var_diff
|
| 16 |
from ..data_loader import get_exs, fixed_shapes
|
| 17 |
from ..validators import validate
|
| 18 |
-
from ..constants import EXCLUDED_TASKS
|
| 19 |
|
| 20 |
# Analytical solvers registry — order matters (cheaper first)
|
| 21 |
ANALYTICAL_SOLVERS = [
|
|
@@ -45,6 +45,14 @@ ANALYTICAL_SOLVERS = [
|
|
| 45 |
]
|
| 46 |
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None):
|
| 49 |
"""Solve a single ARC-AGI task.
|
| 50 |
|
|
@@ -67,6 +75,8 @@ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None)
|
|
| 67 |
if model is None:
|
| 68 |
continue
|
| 69 |
onnx.save(model, path)
|
|
|
|
|
|
|
| 70 |
if validate(path, td, providers):
|
| 71 |
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 72 |
except:
|
|
@@ -84,12 +94,14 @@ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None)
|
|
| 84 |
if fixed_in:
|
| 85 |
result = solve_conv_fixed(td, path, providers, time_budget=conv_time / 2)
|
| 86 |
if result is not None:
|
| 87 |
-
|
| 88 |
-
|
|
|
|
| 89 |
result = solve_conv_variable(td, path, providers, time_budget=conv_time)
|
| 90 |
if result is not None:
|
| 91 |
-
|
| 92 |
-
|
|
|
|
| 93 |
else:
|
| 94 |
sp = fixed_shapes(td)
|
| 95 |
if sp is not None:
|
|
@@ -97,12 +109,14 @@ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None)
|
|
| 97 |
if OH <= IH and OW <= IW:
|
| 98 |
result = solve_conv_diffshape(td, path, providers, time_budget=conv_time)
|
| 99 |
if result is not None:
|
| 100 |
-
|
| 101 |
-
|
|
|
|
| 102 |
|
| 103 |
result = solve_conv_var_diff(td, path, providers, time_budget=conv_time)
|
| 104 |
if result is not None:
|
| 105 |
-
|
| 106 |
-
|
|
|
|
| 107 |
|
| 108 |
return False, None, None, time.time() - t_start, path
|
|
|
|
| 15 |
from .conv import solve_conv_fixed, solve_conv_variable, solve_conv_diffshape, solve_conv_var_diff
|
| 16 |
from ..data_loader import get_exs, fixed_shapes
|
| 17 |
from ..validators import validate
|
| 18 |
+
from ..constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
|
| 19 |
|
| 20 |
# Analytical solvers registry — order matters (cheaper first)
|
| 21 |
ANALYTICAL_SOLVERS = [
|
|
|
|
| 45 |
]
|
| 46 |
|
| 47 |
|
| 48 |
+
def _check_size(path):
|
| 49 |
+
"""Return True if file is within 1.44MB limit."""
|
| 50 |
+
try:
|
| 51 |
+
return os.path.getsize(path) <= MAX_ONNX_FILESIZE
|
| 52 |
+
except OSError:
|
| 53 |
+
return False
|
| 54 |
+
|
| 55 |
+
|
| 56 |
def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None):
|
| 57 |
"""Solve a single ARC-AGI task.
|
| 58 |
|
|
|
|
| 75 |
if model is None:
|
| 76 |
continue
|
| 77 |
onnx.save(model, path)
|
| 78 |
+
if not _check_size(path):
|
| 79 |
+
continue # oversized, skip to next solver
|
| 80 |
if validate(path, td, providers):
|
| 81 |
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 82 |
except:
|
|
|
|
| 94 |
if fixed_in:
|
| 95 |
result = solve_conv_fixed(td, path, providers, time_budget=conv_time / 2)
|
| 96 |
if result is not None:
|
| 97 |
+
if _check_size(path):
|
| 98 |
+
sname, model = result
|
| 99 |
+
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 100 |
result = solve_conv_variable(td, path, providers, time_budget=conv_time)
|
| 101 |
if result is not None:
|
| 102 |
+
if _check_size(path):
|
| 103 |
+
sname, model = result
|
| 104 |
+
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 105 |
else:
|
| 106 |
sp = fixed_shapes(td)
|
| 107 |
if sp is not None:
|
|
|
|
| 109 |
if OH <= IH and OW <= IW:
|
| 110 |
result = solve_conv_diffshape(td, path, providers, time_budget=conv_time)
|
| 111 |
if result is not None:
|
| 112 |
+
if _check_size(path):
|
| 113 |
+
sname, model = result
|
| 114 |
+
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 115 |
|
| 116 |
result = solve_conv_var_diff(td, path, providers, time_budget=conv_time)
|
| 117 |
if result is not None:
|
| 118 |
+
if _check_size(path):
|
| 119 |
+
sname, model = result
|
| 120 |
+
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 121 |
|
| 122 |
return False, None, None, time.time() - t_start, path
|