rogermt commited on
Commit
a5e7732
·
verified ·
1 Parent(s): d7531b4

Add object layer: connected components, color splitting, list reducers, overlay/paint/underpaint

Browse files
Files changed (1) hide show
  1. itt_solver/object_layer.py +309 -0
itt_solver/object_layer.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Object extraction and manipulation primitives for ARC-AGI tasks.
3
+
4
+ Provides connected-component extraction, color-based splitting,
5
+ list reduction (largest/smallest/most_common), spatial queries,
6
+ and composition operations (overlay/paint/underpaint).
7
+ """
8
+ import numpy as np
9
+ from collections import Counter, deque
10
+
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Connected component extraction
14
+ # ---------------------------------------------------------------------------
15
+
16
+ def _flood_fill(grid, start, visited, connectivity=4, univalued=True):
17
+ """BFS flood fill from start. Returns set of (color, (r, c)) cells."""
18
+ h, w = grid.shape
19
+ r0, c0 = start
20
+ seed_color = int(grid[r0, c0])
21
+ comp = set()
22
+ queue = deque([(r0, c0)])
23
+ visited[r0, c0] = True
24
+
25
+ if connectivity == 8:
26
+ deltas = [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]
27
+ else:
28
+ deltas = [(-1,0),(1,0),(0,-1),(0,1)]
29
+
30
+ while queue:
31
+ r, c = queue.popleft()
32
+ val = int(grid[r, c])
33
+ if univalued and val != seed_color:
34
+ continue
35
+ comp.add((val, (r, c)))
36
+ for dr, dc in deltas:
37
+ nr, nc = r + dr, c + dc
38
+ if 0 <= nr < h and 0 <= nc < w and not visited[nr, nc]:
39
+ nval = int(grid[nr, nc])
40
+ if univalued:
41
+ if nval == seed_color:
42
+ visited[nr, nc] = True
43
+ queue.append((nr, nc))
44
+ else:
45
+ visited[nr, nc] = True
46
+ queue.append((nr, nc))
47
+ return comp
48
+
49
+
50
+ def extract_objects(grid, univalued=True, connectivity=4, without_bg=True):
51
+ """Extract connected components from grid.
52
+
53
+ Args:
54
+ grid: 2D numpy array (int)
55
+ univalued: if True, each component is single-color
56
+ connectivity: 4 or 8
57
+ without_bg: if True, skip the most common color (background)
58
+
59
+ Returns:
60
+ list of objects, each object is a set of (color, (row, col))
61
+ sorted by size descending
62
+ """
63
+ grid = np.array(grid, dtype=int)
64
+ h, w = grid.shape
65
+ bg = most_common_color(grid) if without_bg else -1
66
+ visited = np.zeros((h, w), dtype=bool)
67
+ objects = []
68
+
69
+ for r in range(h):
70
+ for c in range(w):
71
+ if visited[r, c]:
72
+ continue
73
+ val = int(grid[r, c])
74
+ if val == bg:
75
+ visited[r, c] = True
76
+ continue
77
+ comp = _flood_fill(grid, (r, c), visited, connectivity, univalued)
78
+ if comp:
79
+ objects.append(comp)
80
+
81
+ objects.sort(key=len, reverse=True)
82
+ return objects
83
+
84
+
85
+ def split_by_color(grid, without_bg=True):
86
+ """Split grid into per-color masks. Returns list of (color, grid) pairs
87
+ where each grid has only that color's pixels (rest = 0)."""
88
+ grid = np.array(grid, dtype=int)
89
+ bg = most_common_color(grid) if without_bg else -1
90
+ colors = sorted(set(grid.flatten()) - {bg})
91
+ result = []
92
+ for c in colors:
93
+ mask_grid = np.zeros_like(grid)
94
+ mask_grid[grid == c] = c
95
+ result.append((c, mask_grid))
96
+ return result
97
+
98
+
99
+ # ---------------------------------------------------------------------------
100
+ # Object to grid conversion
101
+ # ---------------------------------------------------------------------------
102
+
103
+ def object_to_grid(obj, shape, bg=0):
104
+ """Render an object (set of (color, (r,c))) onto a grid of given shape."""
105
+ grid = np.full(shape, bg, dtype=int)
106
+ for color, (r, c) in obj:
107
+ if 0 <= r < shape[0] and 0 <= c < shape[1]:
108
+ grid[r, c] = color
109
+ return grid
110
+
111
+
112
+ def object_to_cropped_grid(obj, bg=0):
113
+ """Render object cropped to its bounding box."""
114
+ if not obj:
115
+ return np.array([[bg]], dtype=int)
116
+ rows = [r for _, (r, c) in obj]
117
+ cols = [c for _, (r, c) in obj]
118
+ rmin, rmax = min(rows), max(rows)
119
+ cmin, cmax = min(cols), max(cols)
120
+ h, w = rmax - rmin + 1, cmax - cmin + 1
121
+ grid = np.full((h, w), bg, dtype=int)
122
+ for color, (r, c) in obj:
123
+ grid[r - rmin, c - cmin] = color
124
+ return grid
125
+
126
+
127
+ def normalize_object(obj):
128
+ """Shift object so its top-left corner is at (0, 0)."""
129
+ if not obj:
130
+ return obj
131
+ rows = [r for _, (r, c) in obj]
132
+ cols = [c for _, (r, c) in obj]
133
+ rmin, cmin = min(rows), min(cols)
134
+ return {(color, (r - rmin, c - cmin)) for color, (r, c) in obj}
135
+
136
+
137
+ def shift_object(obj, dr, dc):
138
+ """Shift all cells by (dr, dc)."""
139
+ return {(color, (r + dr, c + dc)) for color, (r, c) in obj}
140
+
141
+
142
+ # ---------------------------------------------------------------------------
143
+ # Object queries
144
+ # ---------------------------------------------------------------------------
145
+
146
+ def object_color(obj):
147
+ """Color of a univalued object."""
148
+ colors = {c for c, _ in obj}
149
+ if len(colors) == 1:
150
+ return colors.pop()
151
+ return max(colors, key=lambda c: sum(1 for cc, _ in obj if cc == c))
152
+
153
+
154
+ def object_size(obj):
155
+ return len(obj)
156
+
157
+
158
+ def object_bbox(obj):
159
+ """Returns (rmin, cmin, rmax, cmax)."""
160
+ rows = [r for _, (r, c) in obj]
161
+ cols = [c for _, (r, c) in obj]
162
+ return min(rows), min(cols), max(rows), max(cols)
163
+
164
+
165
+ def object_height(obj):
166
+ rmin, _, rmax, _ = object_bbox(obj)
167
+ return rmax - rmin + 1
168
+
169
+
170
+ def object_width(obj):
171
+ _, cmin, _, cmax = object_bbox(obj)
172
+ return cmax - cmin + 1
173
+
174
+
175
+ def object_center(obj):
176
+ rows = [r for _, (r, c) in obj]
177
+ cols = [c for _, (r, c) in obj]
178
+ return (sum(rows) / len(rows), sum(cols) / len(cols))
179
+
180
+
181
+ # ---------------------------------------------------------------------------
182
+ # List reducers
183
+ # ---------------------------------------------------------------------------
184
+
185
+ def largest_object(objects):
186
+ """Return the largest object by cell count."""
187
+ return max(objects, key=len) if objects else None
188
+
189
+
190
+ def smallest_object(objects):
191
+ """Return the smallest object by cell count."""
192
+ return min(objects, key=len) if objects else None
193
+
194
+
195
+ def most_common_object(objects):
196
+ """Return the object whose normalized shape appears most frequently."""
197
+ if not objects:
198
+ return None
199
+ normed = [frozenset(normalize_object(o)) for o in objects]
200
+ counter = Counter(normed)
201
+ most_common_shape = counter.most_common(1)[0][0]
202
+ for o, n in zip(objects, normed):
203
+ if n == most_common_shape:
204
+ return o
205
+ return objects[0]
206
+
207
+
208
+ def unique_object(objects):
209
+ """If exactly one unique normalized shape exists, return it. Else None."""
210
+ normed = [frozenset(normalize_object(o)) for o in objects]
211
+ counter = Counter(normed)
212
+ uniques = [shape for shape, count in counter.items() if count == 1]
213
+ if len(uniques) == 1:
214
+ for o, n in zip(objects, normed):
215
+ if n == uniques[0]:
216
+ return o
217
+ return None
218
+
219
+
220
+ def filter_by_color(objects, color):
221
+ """Keep only objects of the given color."""
222
+ return [o for o in objects if object_color(o) == color]
223
+
224
+
225
+ def filter_by_size(objects, size):
226
+ """Keep only objects of the given size."""
227
+ return [o for o in objects if len(o) == size]
228
+
229
+
230
+ # ---------------------------------------------------------------------------
231
+ # Color utilities
232
+ # ---------------------------------------------------------------------------
233
+
234
+ def most_common_color(grid):
235
+ """Most frequent color in the grid (= background)."""
236
+ grid = np.array(grid, dtype=int)
237
+ counts = Counter(grid.flatten().tolist())
238
+ return counts.most_common(1)[0][0]
239
+
240
+
241
+ def least_common_color(grid):
242
+ """Least frequent color in the grid."""
243
+ grid = np.array(grid, dtype=int)
244
+ counts = Counter(grid.flatten().tolist())
245
+ return counts.most_common()[-1][0]
246
+
247
+
248
+ def palette(grid):
249
+ """Set of all colors in grid."""
250
+ return set(np.array(grid, dtype=int).flatten().tolist())
251
+
252
+
253
+ def color_normalize(grid):
254
+ """Remap colors by frequency: most common -> 0, next -> 1, etc."""
255
+ grid = np.array(grid, dtype=int)
256
+ counts = Counter(grid.flatten().tolist())
257
+ ranked = [c for c, _ in counts.most_common()]
258
+ remap = {c: i for i, c in enumerate(ranked)}
259
+ return np.vectorize(remap.get)(grid)
260
+
261
+
262
+ # ---------------------------------------------------------------------------
263
+ # Composition / overlay
264
+ # ---------------------------------------------------------------------------
265
+
266
+ def paint(grid, obj):
267
+ """Paint object onto grid. Object cells OVERWRITE grid cells."""
268
+ result = np.array(grid, dtype=int).copy()
269
+ for color, (r, c) in obj:
270
+ if 0 <= r < result.shape[0] and 0 <= c < result.shape[1]:
271
+ result[r, c] = color
272
+ return result
273
+
274
+
275
+ def underpaint(grid, obj):
276
+ """Paint object onto grid, but ONLY where grid has background color."""
277
+ result = np.array(grid, dtype=int).copy()
278
+ bg = most_common_color(result)
279
+ for color, (r, c) in obj:
280
+ if 0 <= r < result.shape[0] and 0 <= c < result.shape[1]:
281
+ if result[r, c] == bg:
282
+ result[r, c] = color
283
+ return result
284
+
285
+
286
+ def overlay_grids(base, foreground):
287
+ """Overlay foreground onto base. Foreground non-zero pixels overwrite."""
288
+ base = np.array(base, dtype=int).copy()
289
+ fg = np.array(foreground, dtype=int)
290
+ h = min(base.shape[0], fg.shape[0])
291
+ w = min(base.shape[1], fg.shape[1])
292
+ mask = fg[:h, :w] != 0
293
+ base[:h, :w][mask] = fg[:h, :w][mask]
294
+ return base
295
+
296
+
297
+ def cover(grid, obj):
298
+ """Erase object from grid (replace with background color)."""
299
+ result = np.array(grid, dtype=int).copy()
300
+ bg = most_common_color(result)
301
+ for _, (r, c) in obj:
302
+ if 0 <= r < result.shape[0] and 0 <= c < result.shape[1]:
303
+ result[r, c] = bg
304
+ return result
305
+
306
+
307
+ def canvas(bg_color, shape):
308
+ """Create a blank grid filled with bg_color."""
309
+ return np.full(shape, bg_color, dtype=int)