rogermt commited on
Commit
1b065e6
·
verified ·
1 Parent(s): 36815b6

v5 refactor: add solvers/solver_registry.py

Browse files
neurogolf_solver/solvers/solver_registry.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Solver registry and task solving orchestration."""
3
+
4
+ import os
5
+ import time
6
+ import onnx
7
+ from .analytical import s_identity, s_constant, s_color_map, s_transpose
8
+ from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop, s_gravity
9
+ from .tiling import (s_tile, s_upscale, s_kronecker, s_nonuniform_scale, s_diagonal_tile,
10
+ s_mirror_h, s_mirror_v, s_quad_mirror, s_concat, s_concat_enhanced,
11
+ s_spatial_gather, s_varshape_spatial_gather)
12
+ from .conv import solve_conv_fixed, solve_conv_variable, solve_conv_diffshape, solve_conv_var_diff
13
+ from ..data_loader import get_exs, fixed_shapes
14
+ from ..validators import validate
15
+ from ..constants import EXCLUDED_TASKS
16
+
17
+ # Analytical solvers registry — order matters (cheaper first)
18
+ ANALYTICAL_SOLVERS = [
19
+ ('identity', s_identity),
20
+ ('constant', s_constant),
21
+ ('color_map', s_color_map),
22
+ ('transpose', s_transpose),
23
+ ('flip', s_flip),
24
+ ('rotate', s_rotate),
25
+ ('shift', s_shift),
26
+ ('tile', s_tile),
27
+ ('upscale', s_upscale),
28
+ ('kronecker', s_kronecker),
29
+ ('nonuniform_scale', s_nonuniform_scale),
30
+ ('mirror_h', s_mirror_h),
31
+ ('mirror_v', s_mirror_v),
32
+ ('quad_mirror', s_quad_mirror),
33
+ ('concat', s_concat),
34
+ ('concat_enhanced', s_concat_enhanced),
35
+ ('diagonal_tile', s_diagonal_tile),
36
+ ('fixed_crop', s_fixed_crop),
37
+ ('spatial_gather', s_spatial_gather),
38
+ ('varshape_spatial_gather', s_varshape_spatial_gather),
39
+ ]
40
+
41
+
42
+ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None):
43
+ """Solve a single ARC-AGI task.
44
+
45
+ Returns: (ok, solver_name, file_size, elapsed, model_path)
46
+ """
47
+ if excluded_tasks is None:
48
+ excluded_tasks = EXCLUDED_TASKS
49
+
50
+ t_start = time.time()
51
+ os.makedirs(outdir, exist_ok=True)
52
+ path = os.path.join(outdir, f"task{tn:03d}.onnx")
53
+
54
+ if tn in excluded_tasks:
55
+ return False, 'excluded', None, time.time() - t_start, path
56
+
57
+ # 1. Try analytical solvers (fast, tiny models)
58
+ for sname, sfn in ANALYTICAL_SOLVERS:
59
+ try:
60
+ model = sfn(td)
61
+ if model is None:
62
+ continue
63
+ onnx.save(model, path)
64
+ if validate(path, td, providers):
65
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
66
+ except:
67
+ pass
68
+
69
+ # 2. Determine task shape category and try conv solvers
70
+ exs = get_exs(td)
71
+ same_shape = all(inp.shape == out.shape for inp, out in exs)
72
+ shapes = set(inp.shape for inp, _ in exs)
73
+ fixed_in = len(shapes) == 1
74
+
75
+ conv_time = conv_budget
76
+
77
+ if same_shape:
78
+ if fixed_in:
79
+ result = solve_conv_fixed(td, path, providers, time_budget=conv_time / 2)
80
+ if result is not None:
81
+ sname, model = result
82
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
83
+ result = solve_conv_variable(td, path, providers, time_budget=conv_time)
84
+ if result is not None:
85
+ sname, model = result
86
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
87
+ else:
88
+ sp = fixed_shapes(td)
89
+ if sp is not None:
90
+ (IH, IW), (OH, OW) = sp
91
+ if OH <= IH and OW <= IW:
92
+ result = solve_conv_diffshape(td, path, providers, time_budget=conv_time)
93
+ if result is not None:
94
+ sname, model = result
95
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
96
+
97
+ result = solve_conv_var_diff(td, path, providers, time_budget=conv_time)
98
+ if result is not None:
99
+ sname, model = result
100
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
101
+
102
+ return False, None, None, time.time() - t_start, path