rogermt committed on
Commit
f84db43
·
verified ·
1 Parent(s): 2dfec08

Add Stage 1 trivial task optimizer (zero-intermediate architectures for max score)

Browse files
own-solver/stage1_trivial_optimizer.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Stage 1: Rebuild trivial tasks for maximum score under new formula.
4
+
5
+ New formula: score = max(1.0, 25.0 - ln(memory + params))
6
+ Where memory = sum of ALL intermediate tensor bytes (excluding 'input'/'output')
7
+ params = sum of all initializer element counts + Constant node values
8
+
9
+ KEY INSIGHT: If a node's output IS the graph output ('output'), it's NOT counted.
10
+ So: input → [single op] → output = ZERO intermediate memory cost.
11
+ Only weights/constants contribute.
12
+
13
+ This script rewrites existing ONNX models to use minimal architectures:
14
+ - Transpose: input→Transpose→output (cost=0, score=25)
15
+ - Color permutation: input→Gather(axis=1)→output (cost=40+10=50, score=21)
16
+ - Color mapping: input→Conv1x1→output (cost=500, score=18.8)
17
+ - Identity: input→Identity→output (cost=0, score=25)
18
+ - Flips: input→Slice(reverse)→output (cost=32+4=36, score=21.4)
19
+
20
+ Usage:
21
+ python stage1_trivial_optimizer.py --input_zip submission.zip --data_dir ./tasks --output_zip submission_stage1.zip
22
+ """
23
+
24
+ import json
25
+ import math
26
+ import os
27
+ import sys
28
+ import zipfile
29
+ from pathlib import Path
30
+
31
+ import numpy as np
32
+ import onnx
33
+ import onnxruntime as ort
34
+ from onnx import helper, TensorProto, numpy_helper
35
+
36
# ─── Config ───
# ONNX IR version stamped on every model produced by make_model.
IR = 8
# Element type declared for the graph's "input" and "output" tensors.
DT = TensorProto.FLOAT
# Shape declared for the graph's "input" and "output" tensors.
GRID_SHAPE = [1, 10, 30, 30]
40
+
41
+
42
def make_model(nodes, inits=None, opset=17):
    """Wrap ``nodes`` in a one-input, one-output ONNX model.

    The graph is named "g". Its single input is called "input" and its single
    output "output"; both are declared with element type ``DT`` and shape
    ``GRID_SHAPE``.

    Args:
        nodes: Node protos that make up the graph body.
        inits: Optional initializer tensors; any falsy value means "none".
        opset: Version of the default ("") operator set to import.

    Returns:
        The assembled model proto, stamped with IR version ``IR``.
    """
    initializers = inits if inits else []
    graph_input = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    graph_output = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    graph = helper.make_graph(
        nodes,
        "g",
        [graph_input],
        [graph_output],
        initializer=initializers,
    )
    default_opset = helper.make_opsetid("", opset)
    return helper.make_model(graph, ir_version=IR, opset_imports=[default_opset])
48
+
49
+
50
def encode_grid(grid):
    """Encode a 2-D grid of colour indices as a one-hot float32 tensor.

    Args:
        grid: Rows of integer cell values (nested sequence or 2-D array).
            An empty grid is accepted and encodes to all zeros.

    Returns:
        A ``float32`` array of shape ``(1, 10, 30, 30)`` in which
        ``t[0, v, r, c] == 1.0`` exactly when ``grid[r][c] == v`` and
        ``0 <= v < 10``. Positions outside the grid, and cells whose value
        lies outside ``0..9``, are zero in every channel.

    Raises:
        ValueError: If ``grid`` is non-empty but not two-dimensional.
        IndexError: If a cell with a value in ``0..9`` sits at row or column
            index 30 or beyond.
    """
    t = np.zeros((1, 10, 30, 30), dtype=np.float32)
    arr = np.asarray(grid, dtype=np.int32)
    if arr.size == 0:
        # Nothing to mark; covers both [] and [[]].
        return t
    if arr.ndim != 2:
        raise ValueError(f"grid must be 2-D, got {arr.ndim}-D input")

    # Select only cells holding a valid colour, then set them all in one
    # vectorized assignment instead of a per-cell Python loop.
    rows, cols = np.nonzero((arr >= 0) & (arr < 10))
    t[0, arr[rows, cols], rows, cols] = 1.0
    return t
61
+
62
+
63
def validate_model(model_bytes, examples):
    """Check that a serialized model reproduces every example.

    Each example's 'input' grid is one-hot encoded, fed to the model as
    "input", and the "output" tensor is compared with the encoded 'output'
    grid after thresholding both at ``> 0.0``.

    Args:
        model_bytes: Serialized ONNX model.
        examples: Iterable of dicts with 'input' and 'output' grids.

    Returns:
        True when every example matches; False on the first mismatch, or if
        creating the session or processing an example raises.
    """
    try:
        session_options = ort.SessionOptions()
        # 3 is error level in ONNX Runtime's severity scale, which silences warnings.
        session_options.log_severity_level = 3
        session = ort.InferenceSession(
            model_bytes,
            sess_options=session_options,
            providers=['CPUExecutionProvider'],
        )
    except Exception:
        return False

    for example in examples:
        try:
            feed = {'input': encode_grid(example['input'])}
            predicted = session.run(['output'], feed)[0]
            target = encode_grid(example['output'])
            matches = np.array_equal(predicted > 0.0, target > 0.0)
        except Exception:
            return False
        if not matches:
            return False
    return True
81
+
82
+
83
+ # ─── Stage 1 Optimizers ───
84
+
85
def optimize_transpose(model_bytes):
    """Build the single-node model input → Transpose → output.

    The node uses ``perm=[0, 1, 3, 2]``, i.e. it swaps the last two axes.
    The module docstring rates this architecture at cost 0 / score 25.

    Args:
        model_bytes: Existing model bytes; not read by this function.

    Returns:
        The new model serialized to bytes.
    """
    swap_last_two_axes = helper.make_node(
        'Transpose',
        ['input'],
        ['output'],
        perm=[0, 1, 3, 2],
    )
    return make_model([swap_last_two_axes]).SerializeToString()
90
+
91
+
92
def optimize_identity(model_bytes):
    """Build the single-node model input → Identity → output.

    The module docstring rates this architecture at cost 0 / score 25.

    Args:
        model_bytes: Existing model bytes; not read by this function.

    Returns:
        The new model serialized to bytes.
    """
    passthrough = helper.make_node('Identity', ['input'], ['output'])
    return make_model([passthrough]).SerializeToString()
97
+
98
+
99
def optimize_color_permutation(model_bytes, perm):
    """Build the single-node model input → Gather(axis=1) → output.

    Output channel ``i`` is read from input channel ``perm[i]``. The index
    vector is stored as an int32 initializer named 'gi'; with 10 entries the
    original author's estimate is 40 bytes + 10 params = 50, score ≈ 21.1.

    Args:
        model_bytes: Existing model bytes; not read by this function.
        perm: Source channel index for each output channel.

    Returns:
        The new model serialized to bytes.
    """
    channel_index = numpy_helper.from_array(np.array(perm, dtype=np.int32), 'gi')
    gather = helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)
    return make_model([gather], [channel_index]).SerializeToString()
109
+
110
+
111
def optimize_color_map_conv1x1(model_bytes, color_map):
    """Build the single-node model input → Conv(1x1) → output.

    The (10, 10, 1, 1) float32 weight 'W' has a 1.0 at ``[dst, src]`` for
    every source colour ``src`` in 0..9, where ``dst`` is
    ``color_map.get(src, src)``; colours absent from the map therefore map to
    themselves, and targets outside 0..9 are dropped. The original author's
    estimate is 400 bytes + 100 params = 500, score ≈ 18.8.

    Args:
        model_bytes: Existing model bytes; not read by this function.
        color_map: Dict from input colour to output colour.

    Returns:
        The new model serialized to bytes.
    """
    weights = np.zeros((10, 10, 1, 1), dtype=np.float32)
    for src in range(10):
        dst = color_map.get(src, src)
        if not 0 <= dst < 10:
            continue
        weights[dst, src, 0, 0] = 1.0

    conv = helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])
    weight_init = numpy_helper.from_array(weights, 'W')
    return make_model([conv], [weight_init]).SerializeToString()
124
+
125
+
126
def optimize_flip(axis):
    """Build the single-node model input → Slice(step=-1) → output.

    The slice starts at index 29 and steps by -1 along ``axis``, with the
    end set to the smallest int64 so the walk runs through index 0. The four
    one-element int64 initializers are named 'st', 'en', 'ax' and 'sp'; the
    original author's estimate is 32 bytes + 4 params = 36, score ≈ 21.4.

    Args:
        axis: Axis of the input tensor to reverse.

    Returns:
        The new model serialized to bytes.
    """
    # (initializer name, value) in the order Slice expects its inputs:
    # starts, ends, axes, steps.
    slice_params = (
        ('st', 29),
        ('en', np.iinfo(np.int64).min),
        ('ax', axis),
        ('sp', -1),
    )
    inits = [
        numpy_helper.from_array(np.array([value], dtype=np.int64), name)
        for name, value in slice_params
    ]
    input_names = ['input'] + [name for name, _ in slice_params]
    reverse = helper.make_node('Slice', input_names, ['output'])
    return make_model([reverse], inits).SerializeToString()
143
+
144
+
145
def _infer_color_map(examples):
    """Return the cell-wise colour mapping shared by all examples, or None.

    None is returned when any example's input and output grids differ in
    shape, or when one input colour is seen mapping to two output colours.
    """
    mapping = {}
    for ex in examples:
        inp, out = np.array(ex['input']), np.array(ex['output'])
        if inp.shape != out.shape:
            return None
        for iv, ov in zip(inp.flat, out.flat):
            iv, ov = int(iv), int(ov)
            # setdefault records the first pairing and exposes any conflict.
            if mapping.setdefault(iv, ov) != ov:
                return None
    return mapping


def detect_and_optimize(task_id, model_bytes, examples):
    """Find the cheapest single-node architecture that solves a task.

    Candidates are tried in order: identity, transpose, left-right flip,
    up-down flip, colour permutation via Gather, colour map via 1x1 Conv.
    A candidate is accepted only if ``validate_model`` passes on every
    example.

    Args:
        task_id: Task number; not used by the detection logic.
        model_bytes: Existing model bytes, forwarded to the builders.
        examples: List of dicts with 'input' and 'output' grids.

    Returns:
        ``(optimized_bytes, name, estimated_score)`` for the first candidate
        that validates, or None when none does.
    """
    # Try identity
    model_id = optimize_identity(model_bytes)
    if validate_model(model_id, examples):
        return model_id, "identity_direct", 25.0

    # Try transpose
    model_t = optimize_transpose(model_bytes)
    if validate_model(model_t, examples):
        return model_t, "transpose_direct", 25.0

    # Try flips
    for axis, name in [(3, 'flip_lr'), (2, 'flip_ud')]:
        opt = optimize_flip(axis)
        if validate_model(opt, examples):
            return opt, f"{name}_direct", 21.4

    # Try colour-map based rewrites
    cm = _infer_color_map(examples)
    if not cm:
        return None

    valid_colors = set(range(10))
    targets = set(cm.values())
    in_range = set(cm) <= valid_colors and targets <= valid_colors
    # Gather reads each output channel from exactly one input channel, so a
    # many-to-one map can never validate; skip building that model.
    injective = len(targets) == len(cm)
    if in_range and injective:
        gather_ch = list(range(10))
        for src, dst in cm.items():
            gather_ch[dst] = src
        opt = optimize_color_permutation(model_bytes, gather_ch)
        if validate_model(opt, examples):
            return opt, "color_perm_direct", 21.1

    opt = optimize_color_map_conv1x1(model_bytes, cm)
    if validate_model(opt, examples):
        return opt, "color_map_conv1x1_direct", 18.8

    return None
198
+
199
+
200
def main():
    """Rewrite trivial tasks in a submission zip with single-node models.

    Reads ``taskNNN.onnx`` (NNN = 001..400) from ``--input_zip``, loads the
    matching ``taskNNN.json`` from ``--data_dir``, and runs
    ``detect_and_optimize`` on the task's 'train' and 'test' examples plus
    the first 30 'arc-gen' examples. Every model from the input zip is
    written to ``--output_zip``, replaced by its optimized version when one
    was found. Tasks without a JSON file or without examples are copied
    through unchanged.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_zip', required=True, help='Input submission.zip')
    parser.add_argument('--data_dir', required=True, help='Directory with taskNNN.json files')
    parser.add_argument('--output_zip', required=True, help='Output optimized submission.zip')
    args = parser.parse_args()

    # Load all models from zip
    models = {}
    with zipfile.ZipFile(args.input_zip, 'r') as zf:
        # Build the member set once instead of calling namelist() per task.
        available = set(zf.namelist())
        for tid in range(1, 401):
            fname = f'task{tid:03d}.onnx'
            if fname in available:
                models[tid] = zf.read(fname)

    print(f"Loaded {len(models)} models from {args.input_zip}")

    # Process each task
    optimized = {}
    for tid in sorted(models):
        task_path = os.path.join(args.data_dir, f'task{tid:03d}.json')
        if not os.path.exists(task_path):
            continue
        with open(task_path, encoding='utf-8') as f:
            task_data = json.load(f)

        examples = task_data.get('train', []) + task_data.get('test', [])
        arcgen = task_data.get('arc-gen', [])[:30]
        all_examples = examples + arcgen
        if not all_examples:
            continue

        result = detect_and_optimize(tid, models[tid], all_examples)
        if not result:
            continue
        opt_bytes, opt_name, est_score = result
        orig_size = len(models[tid])
        opt_size = len(opt_bytes)
        optimized[tid] = opt_bytes
        print(f"  Task {tid:3d}: {opt_name:30s} ({orig_size:>6,} → {opt_size:>6,} bytes) est_score={est_score:.1f}")

    print(f"\nOptimized {len(optimized)} tasks (Stage 1: trivial rebuilds)")

    # Write output zip: optimized model when available, original otherwise.
    with zipfile.ZipFile(args.output_zip, 'w', zipfile.ZIP_DEFLATED) as zf:
        for tid in sorted(models):
            zf.writestr(f'task{tid:03d}.onnx', optimized.get(tid, models[tid]))

    print(f"Written to {args.output_zip}")


if __name__ == '__main__':
    main()