rogermt committed on
Commit
1050864
·
verified ·
1 Parent(s): 597dca6

Add merge_best_of_both.py: picks lowest-cost model from 5743 vs V90 per task (42 tasks improved, 55% size reduction)

Browse files
Files changed (1) hide show
  1. own-solver/merge_best_of_both.py +200 -0
own-solver/merge_best_of_both.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Best-of-both merger: Pick the better model from 5743 and V90 for each task.
4
+
5
+ Under new formula (score = 25 - ln(memory + params)):
6
+ - memory = sum of ALL intermediate tensor bytes
7
+ - Each node output that isn't 'output' costs memory
8
+ - FEWER NODES = LESS MEMORY = HIGHER SCORE
9
+
10
+ Strategy: For each task, pick the model with the lower estimated cost
11
+ (weight bytes + ~20000 bytes per intermediate tensor + parameter count); ties go to V90.
12
+ Only the 5743 candidate is validated against train+test+arc-gen; if it fails (or the task file is missing) V90 is used unvalidated.
13
+
14
+ Usage:
15
+ python merge_best_of_both.py \
16
+ --sub_5743 submission-5743.zip \
17
+ --sub_v90 submission-6043.zip \
18
+ --data_dir ./tasks \
19
+ --output_zip submission_merged.zip
20
+ """
21
+
22
+ import json
23
+ import math
24
+ import os
25
+ import zipfile
26
+
27
+ import numpy as np
28
+ import onnx
29
+ import onnxruntime as ort
30
+ from onnx import helper, TensorProto, numpy_helper
31
+
32
+
33
def encode_grid(grid):
    """One-hot encode a 2-D grid of colour indices.

    Returns a float32 array of shape (1, 10, 30, 30) where entry
    ``[0, v, r, c]`` is 1.0 when ``grid[r][c] == v``. Cells whose value is
    outside ``0..9`` leave every channel at zero for that position.
    """
    cells = np.array(grid, dtype=np.int32)
    # Unpacking enforces a 2-D input (anything else raises ValueError).
    height, width = cells.shape
    onehot = np.zeros((1, 10, 30, 30), dtype=np.float32)
    # Select only the cells holding a valid colour, then set them in one shot.
    rows, cols = np.nonzero((cells >= 0) & (cells < 10))
    onehot[0, cells[rows, cols], rows, cols] = 1.0
    return onehot
43
+
44
+
45
def validate_model(model_bytes, examples, max_check=30):
    """Return True when the model reproduces every checked example.

    A CPU-only ONNX Runtime session is built from ``model_bytes`` and run on
    at most the first ``max_check`` entries of ``examples``. Each example is
    a mapping with ``'input'`` and ``'output'`` grids. The model's
    ``'output'`` tensor is binarised with ``> 0.0`` and must equal the
    encoded expected grid exactly. Any failure (session creation, inference
    error, mismatch) yields False; an empty selection yields True.
    """
    try:
        session_options = ort.SessionOptions()
        # Severity 3 keeps ONNX Runtime quiet except for errors.
        session_options.log_severity_level = 3
        session = ort.InferenceSession(
            model_bytes,
            sess_options=session_options,
            providers=['CPUExecutionProvider'],
        )
    except Exception:
        return False

    def _reproduces(example):
        # Best-effort: any error while encoding or running counts as a miss.
        try:
            feed = {'input': encode_grid(example['input'])}
            predicted = session.run(['output'], feed)[0]
            target = encode_grid(example['output'])
            binarised = (predicted > 0.0).astype(np.float32)
            return bool(np.array_equal(binarised, target))
        except Exception:
            return False

    return all(_reproduces(example) for example in examples[:max_check])
63
+
64
+
65
def count_intermediates(model_bytes):
    """Count intermediate tensors (proxy for runtime memory cost).

    Every non-empty node output name other than ``'output'`` counts once.
    Returns the sentinel 999999 when the model cannot be parsed or walked.
    """
    try:
        graph = onnx.load_from_string(model_bytes).graph
        return sum(
            1
            for node in graph.node
            for tensor_name in node.output
            if tensor_name and tensor_name != 'output'
        )
    except Exception:
        return 999999
77
+
78
+
79
def estimate_cost(model_bytes, avg_intermediate_bytes=20000):
    """Estimate cost under new formula: memory + params.

    The estimate is ``weight_bytes + intermediates * avg_intermediate_bytes
    + param_count`` where:

    * weights are the graph initializers plus the ``value`` tensors of
      ``Constant`` nodes;
    * intermediates are all non-empty node outputs not named ``'output'``.

    Args:
        model_bytes: Serialized ONNX model.
        avg_intermediate_bytes: Assumed average size in bytes of one
            intermediate tensor (a rough heuristic, not a measurement).

    Returns:
        The estimated cost, or ``float('inf')`` if the model cannot be
        parsed, so an unreadable model never wins a comparison.
    """
    try:
        model = onnx.load_from_string(model_bytes)
    except Exception:
        return float('inf')

    weight_memory = 0
    params = 0
    for init in model.graph.initializer:
        arr = numpy_helper.to_array(init)
        weight_memory += arr.nbytes
        params += arr.size

    for node in model.graph.node:
        if node.op_type != 'Constant':
            continue
        for attr in node.attribute:
            if attr.name != 'value' or attr.t.ByteSize() <= 0:
                continue
            try:
                arr = numpy_helper.to_array(attr.t)
            except Exception:
                # Undecodable constant: still charge one parameter. A bare
                # ``except:`` here would also swallow KeyboardInterrupt.
                params += 1
            else:
                weight_memory += arr.nbytes
                params += arr.size

    intermediates = sum(
        1
        for node in model.graph.node
        for out in node.output
        if out and out != 'output'
    )

    intermediate_memory = intermediates * avg_intermediate_bytes
    return weight_memory + intermediate_memory + params
112
+
113
+
114
def _load_models(zip_path, task_ids):
    """Read the ``taskNNN.onnx`` members of *zip_path* into ``{task_id: bytes}``."""
    models = {}
    with zipfile.ZipFile(zip_path, 'r') as zf:
        # Build the member set once instead of calling namelist() per task.
        members = set(zf.namelist())
        for tid in task_ids:
            fname = f'task{tid:03d}.onnx'
            if fname in members:
                models[tid] = zf.read(fname)
    return models


def _load_examples(data_dir, tid, max_arcgen=30):
    """Return train + test + first *max_arcgen* arc-gen examples, or None if the task file is absent."""
    task_path = os.path.join(data_dir, f'task{tid:03d}.json')
    try:
        with open(task_path, encoding='utf-8') as f:
            task_data = json.load(f)
    except FileNotFoundError:
        return None
    examples = task_data.get('train', []) + task_data.get('test', [])
    return examples + task_data.get('arc-gen', [])[:max_arcgen]


def main():
    """Merge two submission zips, keeping the cheaper model per task.

    For each task id 1..400:

    * if only one submission has a model, it is taken as-is;
    * if both do and the 5743 model has a strictly lower ``estimate_cost``,
      it is used only when it passes ``validate_model`` on every loaded
      example (train + test + up to 30 arc-gen);
    * otherwise the V90 model is used (without validation).

    The chosen models are written to ``--output_zip`` and a summary is
    printed.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--sub_5743', required=True, help='5743 submission zip')
    parser.add_argument('--sub_v90', required=True, help='V90 submission zip')
    parser.add_argument('--data_dir', required=True, help='Task JSON directory')
    parser.add_argument('--output_zip', required=True, help='Output merged zip')
    args = parser.parse_args()

    task_ids = range(1, 401)
    models_5743 = _load_models(args.sub_5743, task_ids)
    models_v90 = _load_models(args.sub_v90, task_ids)

    print(f"Loaded {len(models_5743)} from 5743, {len(models_v90)} from V90")

    merged = {}
    stats = {'5743': 0, 'v90': 0, 'validated_5743': 0}

    for tid in task_ids:
        b5 = models_5743.get(tid)
        bv = models_v90.get(tid)

        if not b5 and not bv:
            continue
        if not b5:
            merged[tid] = bv
            stats['v90'] += 1
            continue
        if not bv:
            merged[tid] = b5
            stats['5743'] += 1
            continue

        if estimate_cost(b5) < estimate_cost(bv):
            examples = _load_examples(args.data_dir, tid)
            # Require at least one example (an empty list would validate
            # vacuously) and check all of them: validate_model's default
            # max_check=30 would silently drop the trailing arc-gen examples.
            if examples and validate_model(b5, examples, max_check=len(examples)):
                merged[tid] = b5
                stats['validated_5743'] += 1
                inter5 = count_intermediates(b5)
                interv = count_intermediates(bv)
                # Only report the clearly significant wins.
                if interv - inter5 > 5:
                    print(f"  Task {tid:3d}: USE 5743 ({inter5} vs {interv} intermediates)")
                continue

        # 5743 was not cheaper, could not be validated, or failed validation.
        merged[tid] = bv
        stats['v90'] += 1

    print(f"\nUsing 5743: {stats['5743']} + {stats['validated_5743']} validated")
    print(f"Using V90: {stats['v90']}")
    print(f"Total: {len(merged)}")

    with zipfile.ZipFile(args.output_zip, 'w', zipfile.ZIP_DEFLATED) as zf:
        for tid in sorted(merged):
            zf.writestr(f'task{tid:03d}.onnx', merged[tid])

    total_size = sum(len(v) for v in merged.values())
    print(f"Output: {args.output_zip} ({total_size:,} bytes)")
198
+
199
# Run the merger only when executed as a script, not when imported.
if __name__ == '__main__':
    main()