Add Stage 1 trivial task optimizer (zero-intermediate architectures for max score)
Browse files
own-solver/stage1_trivial_optimizer.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Stage 1: Rebuild trivial tasks for maximum score under new formula.
|
| 4 |
+
|
| 5 |
+
New formula: score = max(1.0, 25.0 - ln(memory + params))
|
| 6 |
+
Where memory = sum of ALL intermediate tensor bytes (excluding 'input'/'output')
|
| 7 |
+
params = sum of all initializer element counts + Constant node values
|
| 8 |
+
|
| 9 |
+
KEY INSIGHT: If a node's output IS the graph output ('output'), it's NOT counted.
|
| 10 |
+
So: input β [single op] β output = ZERO intermediate memory cost.
|
| 11 |
+
Only weights/constants contribute.
|
| 12 |
+
|
| 13 |
+
This script rewrites existing ONNX models to use minimal architectures:
|
| 14 |
+
- Transpose: inputβTransposeβoutput (cost=0, score=25)
|
| 15 |
+
- Color permutation: inputβGather(axis=1)βoutput (cost=40+10=50, score=21)
|
| 16 |
+
- Color mapping: inputβConv1x1βoutput (cost=500, score=18.8)
|
| 17 |
+
- Identity: inputβIdentityβoutput (cost=0, score=25)
|
| 18 |
+
- Flips: inputβSlice(reverse)βoutput (cost=32+4=36, score=21.4)
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
python stage1_trivial_optimizer.py --input_zip submission.zip --data_dir ./tasks --output_zip submission_stage1.zip
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import json
|
| 25 |
+
import math
|
| 26 |
+
import os
|
| 27 |
+
import sys
|
| 28 |
+
import zipfile
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import onnx
|
| 33 |
+
import onnxruntime as ort
|
| 34 |
+
from onnx import helper, TensorProto, numpy_helper
|
| 35 |
+
|
| 36 |
+
# βββ Config βββ
|
| 37 |
+
GRID_SHAPE = [1, 10, 30, 30]
|
| 38 |
+
DT = TensorProto.FLOAT
|
| 39 |
+
IR = 8
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def make_model(nodes, inits=None, opset=17):
|
| 43 |
+
"""Create minimal ONNX model."""
|
| 44 |
+
x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
|
| 45 |
+
y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
|
| 46 |
+
g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
|
| 47 |
+
return helper.make_model(g, ir_version=IR, opset_imports=[helper.make_opsetid("", opset)])
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def encode_grid(grid):
|
| 51 |
+
"""Encode grid to one-hot tensor."""
|
| 52 |
+
arr = np.array(grid, dtype=np.int32)
|
| 53 |
+
h, w = arr.shape
|
| 54 |
+
t = np.zeros((1, 10, 30, 30), dtype=np.float32)
|
| 55 |
+
for r in range(h):
|
| 56 |
+
for c in range(w):
|
| 57 |
+
v = int(arr[r, c])
|
| 58 |
+
if 0 <= v < 10:
|
| 59 |
+
t[0, v, r, c] = 1.0
|
| 60 |
+
return t
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def validate_model(model_bytes, examples):
|
| 64 |
+
"""Validate model produces correct output on all examples."""
|
| 65 |
+
try:
|
| 66 |
+
opts = ort.SessionOptions()
|
| 67 |
+
opts.log_severity_level = 3
|
| 68 |
+
sess = ort.InferenceSession(model_bytes, sess_options=opts, providers=['CPUExecutionProvider'])
|
| 69 |
+
except Exception:
|
| 70 |
+
return False
|
| 71 |
+
for ex in examples:
|
| 72 |
+
try:
|
| 73 |
+
inp = encode_grid(ex['input'])
|
| 74 |
+
out = sess.run(['output'], {'input': inp})[0]
|
| 75 |
+
expected = encode_grid(ex['output'])
|
| 76 |
+
if not np.array_equal((out > 0.0).astype(float), (expected > 0.0).astype(float)):
|
| 77 |
+
return False
|
| 78 |
+
except Exception:
|
| 79 |
+
return False
|
| 80 |
+
return True
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# βββ Stage 1 Optimizers βββ
|
| 84 |
+
|
| 85 |
+
def optimize_transpose(model_bytes):
|
| 86 |
+
"""Rebuild as: input β Transpose β output (cost=0, score=25)."""
|
| 87 |
+
nodes = [helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])]
|
| 88 |
+
model = make_model(nodes)
|
| 89 |
+
return model.SerializeToString()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def optimize_identity(model_bytes):
|
| 93 |
+
"""Rebuild as: input β Identity β output (cost=0, score=25)."""
|
| 94 |
+
nodes = [helper.make_node('Identity', ['input'], ['output'])]
|
| 95 |
+
model = make_model(nodes)
|
| 96 |
+
return model.SerializeToString()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def optimize_color_permutation(model_bytes, perm):
|
| 100 |
+
"""Rebuild as: input β Gather(axis=1) β output.
|
| 101 |
+
perm[i] = source channel for output channel i.
|
| 102 |
+
Cost: 10 int32 elements = 40 bytes memory + 10 params = 50. Score β 21.1
|
| 103 |
+
"""
|
| 104 |
+
gi = np.array(perm, dtype=np.int32)
|
| 105 |
+
inits = [numpy_helper.from_array(gi, 'gi')]
|
| 106 |
+
nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
|
| 107 |
+
model = make_model(nodes, inits)
|
| 108 |
+
return model.SerializeToString()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def optimize_color_map_conv1x1(model_bytes, color_map):
|
| 112 |
+
"""Rebuild as: input β Conv(1x1) β output.
|
| 113 |
+
Cost: 100 float32 elements = 400 bytes + 100 params = 500. Score β 18.8
|
| 114 |
+
"""
|
| 115 |
+
W = np.zeros((10, 10, 1, 1), dtype=np.float32)
|
| 116 |
+
for ic in range(10):
|
| 117 |
+
oc = color_map.get(ic, ic)
|
| 118 |
+
if 0 <= oc < 10:
|
| 119 |
+
W[oc, ic, 0, 0] = 1.0
|
| 120 |
+
inits = [numpy_helper.from_array(W, 'W')]
|
| 121 |
+
nodes = [helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])]
|
| 122 |
+
model = make_model(nodes, inits)
|
| 123 |
+
return model.SerializeToString()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def optimize_flip(axis):
|
| 127 |
+
"""Rebuild as: input β Slice(reverse) β output.
|
| 128 |
+
Cost: 4 int64 scalars = 32 bytes + 4 params = 36. Score β 21.4
|
| 129 |
+
"""
|
| 130 |
+
starts = np.array([29], dtype=np.int64)
|
| 131 |
+
ends = np.array([np.iinfo(np.int64).min], dtype=np.int64)
|
| 132 |
+
axes = np.array([axis], dtype=np.int64)
|
| 133 |
+
steps = np.array([-1], dtype=np.int64)
|
| 134 |
+
inits = [
|
| 135 |
+
numpy_helper.from_array(starts, 'st'),
|
| 136 |
+
numpy_helper.from_array(ends, 'en'),
|
| 137 |
+
numpy_helper.from_array(axes, 'ax'),
|
| 138 |
+
numpy_helper.from_array(steps, 'sp'),
|
| 139 |
+
]
|
| 140 |
+
nodes = [helper.make_node('Slice', ['input', 'st', 'en', 'ax', 'sp'], ['output'])]
|
| 141 |
+
model = make_model(nodes, inits)
|
| 142 |
+
return model.SerializeToString()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def detect_and_optimize(task_id, model_bytes, examples):
|
| 146 |
+
"""Detect if task can be optimized with a simpler architecture.
|
| 147 |
+
Returns (optimized_bytes, name, estimated_score) or None.
|
| 148 |
+
"""
|
| 149 |
+
# Try identity
|
| 150 |
+
model_id = optimize_identity(model_bytes)
|
| 151 |
+
if validate_model(model_id, examples):
|
| 152 |
+
return model_id, "identity_direct", 25.0
|
| 153 |
+
|
| 154 |
+
# Try transpose
|
| 155 |
+
model_t = optimize_transpose(model_bytes)
|
| 156 |
+
if validate_model(model_t, examples):
|
| 157 |
+
return model_t, "transpose_direct", 25.0
|
| 158 |
+
|
| 159 |
+
# Try flips
|
| 160 |
+
for axis, name in [(3, 'flip_lr'), (2, 'flip_ud')]:
|
| 161 |
+
opt = optimize_flip(axis)
|
| 162 |
+
if validate_model(opt, examples):
|
| 163 |
+
return opt, f"{name}_direct", 21.4
|
| 164 |
+
|
| 165 |
+
# Try color map detection
|
| 166 |
+
cm = {}
|
| 167 |
+
is_color_map = True
|
| 168 |
+
for ex in examples:
|
| 169 |
+
inp, out = np.array(ex['input']), np.array(ex['output'])
|
| 170 |
+
if inp.shape != out.shape:
|
| 171 |
+
is_color_map = False
|
| 172 |
+
break
|
| 173 |
+
for iv, ov in zip(inp.flat, out.flat):
|
| 174 |
+
iv, ov = int(iv), int(ov)
|
| 175 |
+
if iv in cm and cm[iv] != ov:
|
| 176 |
+
is_color_map = False
|
| 177 |
+
break
|
| 178 |
+
cm[iv] = ov
|
| 179 |
+
if not is_color_map:
|
| 180 |
+
break
|
| 181 |
+
|
| 182 |
+
if is_color_map and cm:
|
| 183 |
+
is_perm = (set(cm.keys()) <= set(range(10)) and set(cm.values()) <= set(range(10)))
|
| 184 |
+
if is_perm:
|
| 185 |
+
gather_ch = list(range(10))
|
| 186 |
+
for src, dst in cm.items():
|
| 187 |
+
if 0 <= src < 10 and 0 <= dst < 10:
|
| 188 |
+
gather_ch[dst] = src
|
| 189 |
+
opt = optimize_color_permutation(model_bytes, gather_ch)
|
| 190 |
+
if validate_model(opt, examples):
|
| 191 |
+
return opt, "color_perm_direct", 21.1
|
| 192 |
+
|
| 193 |
+
opt = optimize_color_map_conv1x1(model_bytes, cm)
|
| 194 |
+
if validate_model(opt, examples):
|
| 195 |
+
return opt, "color_map_conv1x1_direct", 18.8
|
| 196 |
+
|
| 197 |
+
return None
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def main():
|
| 201 |
+
"""Process submission zip and optimize trivial tasks."""
|
| 202 |
+
import argparse
|
| 203 |
+
parser = argparse.ArgumentParser()
|
| 204 |
+
parser.add_argument('--input_zip', required=True, help='Input submission.zip')
|
| 205 |
+
parser.add_argument('--data_dir', required=True, help='Directory with taskNNN.json files')
|
| 206 |
+
parser.add_argument('--output_zip', required=True, help='Output optimized submission.zip')
|
| 207 |
+
args = parser.parse_args()
|
| 208 |
+
|
| 209 |
+
# Load all models from zip
|
| 210 |
+
models = {}
|
| 211 |
+
with zipfile.ZipFile(args.input_zip, 'r') as zf:
|
| 212 |
+
for tid in range(1, 401):
|
| 213 |
+
fname = f'task{tid:03d}.onnx'
|
| 214 |
+
if fname in zf.namelist():
|
| 215 |
+
models[tid] = zf.read(fname)
|
| 216 |
+
|
| 217 |
+
print(f"Loaded {len(models)} models from {args.input_zip}")
|
| 218 |
+
|
| 219 |
+
# Process each task
|
| 220 |
+
optimized = {}
|
| 221 |
+
total_score_gain = 0.0
|
| 222 |
+
|
| 223 |
+
for tid in sorted(models.keys()):
|
| 224 |
+
task_path = os.path.join(args.data_dir, f'task{tid:03d}.json')
|
| 225 |
+
if not os.path.exists(task_path):
|
| 226 |
+
continue
|
| 227 |
+
with open(task_path) as f:
|
| 228 |
+
task_data = json.load(f)
|
| 229 |
+
|
| 230 |
+
examples = task_data.get('train', []) + task_data.get('test', [])
|
| 231 |
+
arcgen = task_data.get('arc-gen', [])[:30]
|
| 232 |
+
all_examples = examples + arcgen
|
| 233 |
+
|
| 234 |
+
if not all_examples:
|
| 235 |
+
continue
|
| 236 |
+
|
| 237 |
+
result = detect_and_optimize(tid, models[tid], all_examples)
|
| 238 |
+
if result:
|
| 239 |
+
opt_bytes, opt_name, est_score = result
|
| 240 |
+
orig_size = len(models[tid])
|
| 241 |
+
opt_size = len(opt_bytes)
|
| 242 |
+
optimized[tid] = opt_bytes
|
| 243 |
+
print(f" Task {tid:3d}: {opt_name:30s} ({orig_size:>6,} β {opt_size:>6,} bytes) est_score={est_score:.1f}")
|
| 244 |
+
|
| 245 |
+
print(f"\nOptimized {len(optimized)} tasks (Stage 1: trivial rebuilds)")
|
| 246 |
+
|
| 247 |
+
# Write output zip
|
| 248 |
+
with zipfile.ZipFile(args.output_zip, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 249 |
+
for tid in range(1, 401):
|
| 250 |
+
fname = f'task{tid:03d}.onnx'
|
| 251 |
+
if tid in optimized:
|
| 252 |
+
zf.writestr(fname, optimized[tid])
|
| 253 |
+
elif tid in models:
|
| 254 |
+
zf.writestr(fname, models[tid])
|
| 255 |
+
|
| 256 |
+
print(f"Written to {args.output_zip}")
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
if __name__ == '__main__':
|
| 260 |
+
main()
|