rogermt commited on
Commit
717cc1f
·
verified ·
1 Parent(s): 2c8b675

Move own-solver/neurogolf_solver/solvers/solver_registry.py to own-solver/

Browse files
own-solver/neurogolf_solver/solvers/solver_registry.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Solver registry and task solving orchestration."""
3
+
4
+ import os
5
+ import time
6
+ import onnx
7
+ from .analytical import s_identity, s_constant, s_color_map, s_transpose
8
+ from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop, s_gravity
9
+ from .tiling import (s_tile, s_upscale, s_kronecker, s_nonuniform_scale, s_diagonal_tile,
10
+ s_mirror_h, s_mirror_v, s_quad_mirror, s_concat, s_concat_enhanced,
11
+ s_spatial_gather, s_varshape_spatial_gather)
12
+ from .gravity import s_gravity_unrolled
13
+ from .edge import s_edge_detect
14
+ from .mode import s_mode_fill
15
+ from .wave1 import (s_downsample_stride, s_symmetry_complete, s_extract_inner,
16
+ s_add_border, s_sparse_fill, s_channel_filter)
17
+ from .conv import solve_conv_fixed, solve_conv_variable, solve_conv_diffshape, solve_conv_var_diff
18
+ from ..data_loader import get_exs, fixed_shapes
19
+ from ..validators import validate
20
+ from ..profiler import score_network
21
+ from ..constants import EXCLUDED_TASKS, MAX_ONNX_FILESIZE
22
+
23
+ # Analytical solvers registry — order matters (cheaper first)
24
+ ANALYTICAL_SOLVERS = [
25
+ ('identity', s_identity),
26
+ ('constant', s_constant),
27
+ ('color_map', s_color_map),
28
+ ('transpose', s_transpose),
29
+ ('flip', s_flip),
30
+ ('rotate', s_rotate),
31
+ ('shift', s_shift),
32
+ ('tile', s_tile),
33
+ ('upscale', s_upscale),
34
+ ('kronecker', s_kronecker),
35
+ ('nonuniform_scale', s_nonuniform_scale),
36
+ ('mirror_h', s_mirror_h),
37
+ ('mirror_v', s_mirror_v),
38
+ ('quad_mirror', s_quad_mirror),
39
+ ('concat', s_concat),
40
+ ('concat_enhanced', s_concat_enhanced),
41
+ ('diagonal_tile', s_diagonal_tile),
42
+ ('fixed_crop', s_fixed_crop),
43
+ ('spatial_gather', s_spatial_gather),
44
+ ('varshape_spatial_gather', s_varshape_spatial_gather),
45
+ ('gravity_unrolled', s_gravity_unrolled),
46
+ ('edge_detect', s_edge_detect),
47
+ ('mode_fill', s_mode_fill),
48
+ ('downsample_stride', s_downsample_stride),
49
+ ('symmetry_complete', s_symmetry_complete),
50
+ ('extract_inner', s_extract_inner),
51
+ ('add_border', s_add_border),
52
+ ('sparse_fill', s_sparse_fill),
53
+ ('channel_filter', s_channel_filter),
54
+ ]
55
+
56
+
57
+ def _check_size(path):
58
+ """Return True if file is within 1.44MB limit."""
59
+ try:
60
+ return os.path.getsize(path) <= MAX_ONNX_FILESIZE
61
+ except OSError:
62
+ return False
63
+
64
+
65
+ def _check_scoreable(path):
66
+ """Return True if score_network returns valid scores (not None).
67
+ A model that can't be scored will be REJECTED by Kaggle."""
68
+ macs, memory, params = score_network(path)
69
+ if macs is None or memory is None or params is None:
70
+ return False
71
+ return True
72
+
73
+
74
+ def _accept_model(path, td, providers):
75
+ """Full acceptance check: size + validate (outputs) + scoreable.
76
+ Returns True only if model would be accepted by Kaggle."""
77
+ if not _check_size(path):
78
+ return False
79
+ if not validate(path, td, providers):
80
+ return False
81
+ if not _check_scoreable(path):
82
+ return False
83
+ return True
84
+
85
+
86
+ def _cleanup_failed(path):
87
+ """Delete leftover .onnx file from failed solve attempts.
88
+ Prevents bad files from ending up in submission zip."""
89
+ try:
90
+ if os.path.exists(path):
91
+ os.remove(path)
92
+ except OSError:
93
+ pass
94
+
95
+
96
+ def solve_task(tn, td, outdir, providers, conv_budget=30.0, excluded_tasks=None):
97
+ """Solve a single ARC-AGI task.
98
+
99
+ Returns: (ok, solver_name, file_size, elapsed, model_path)
100
+ If unsolved, deletes any leftover .onnx file.
101
+ """
102
+ if excluded_tasks is None:
103
+ excluded_tasks = EXCLUDED_TASKS
104
+
105
+ t_start = time.time()
106
+ os.makedirs(outdir, exist_ok=True)
107
+ path = os.path.join(outdir, f"task{tn:03d}.onnx")
108
+
109
+ if tn in excluded_tasks:
110
+ return False, 'excluded', None, time.time() - t_start, path
111
+
112
+ # 1. Try analytical solvers (fast, tiny models)
113
+ for sname, sfn in ANALYTICAL_SOLVERS:
114
+ try:
115
+ model = sfn(td)
116
+ if model is None:
117
+ continue
118
+ onnx.save(model, path)
119
+ if _accept_model(path, td, providers):
120
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
121
+ except:
122
+ pass
123
+
124
+ # 2. Determine task shape category and try conv solvers
125
+ exs = get_exs(td)
126
+ same_shape = all(inp.shape == out.shape for inp, out in exs)
127
+ shapes = set(inp.shape for inp, _ in exs)
128
+ fixed_in = len(shapes) == 1
129
+
130
+ conv_time = conv_budget
131
+
132
+ if same_shape:
133
+ if fixed_in:
134
+ result = solve_conv_fixed(td, path, providers, time_budget=conv_time / 2)
135
+ if result is not None:
136
+ if _check_size(path) and _check_scoreable(path):
137
+ sname, model = result
138
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
139
+ result = solve_conv_variable(td, path, providers, time_budget=conv_time)
140
+ if result is not None:
141
+ if _check_size(path) and _check_scoreable(path):
142
+ sname, model = result
143
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
144
+ else:
145
+ sp = fixed_shapes(td)
146
+ if sp is not None:
147
+ (IH, IW), (OH, OW) = sp
148
+ if OH <= IH and OW <= IW:
149
+ result = solve_conv_diffshape(td, path, providers, time_budget=conv_time)
150
+ if result is not None:
151
+ if _check_size(path) and _check_scoreable(path):
152
+ sname, model = result
153
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
154
+
155
+ result = solve_conv_var_diff(td, path, providers, time_budget=conv_time)
156
+ if result is not None:
157
+ if _check_size(path) and _check_scoreable(path):
158
+ sname, model = result
159
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
160
+
161
+ # All solvers failed — delete leftover .onnx file
162
+ _cleanup_failed(path)
163
+ return False, None, None, time.time() - t_start, path