rogermt commited on
Commit
8cf6976
·
verified ·
1 Parent(s): 6cae82d

Add predicate enumeration engine: neighborhood rules + enclosed fill + object predicates — 70/400 (17.5%)

Browse files
Files changed (1) hide show
  1. itt_solver/predicate_engine.py +643 -0
itt_solver/predicate_engine.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Object Predicate + Action Enumeration Engine
3
+ =============================================
4
+
5
+ For each ARC task, enumerate combinations of:
6
+ (object_abstraction) x (predicate) x (action)
7
+
8
+ Test each rule against ALL training pairs. If a rule produces
9
+ exact output for every pair, use it.
10
+
11
+ This is the GPAR approach simplified to pure Python — no PDDL,
12
+ no planner. Just brute-force enumeration of ~600 rule templates.
13
+
14
+ Covers:
15
+ - Fill miss (35% of unsolved): enclosed_by, neighbor_count conditions
16
+ - Recolor miss (24%): object attribute conditions (size, color, position)
17
+ - Shape change (25%): extract by predicate
18
+ """
19
+
20
+ import numpy as np
21
+ from collections import Counter, deque
22
+ from typing import Dict, List, Tuple, Optional, Set, Callable
23
+
24
+
25
+ # =============================================================================
26
+ # Object Extraction (robust, multiple abstractions)
27
+ # =============================================================================
28
+
29
+ def _flood(grid, r, c, visited, color, connectivity=4):
30
+ """BFS flood fill for a single color component."""
31
+ h, w = grid.shape
32
+ cells = set()
33
+ queue = deque([(r, c)])
34
+ visited[r, c] = True
35
+ deltas = [(-1,0),(1,0),(0,-1),(0,1)]
36
+ if connectivity == 8:
37
+ deltas += [(-1,-1),(-1,1),(1,-1),(1,1)]
38
+ while queue:
39
+ cr, cc = queue.popleft()
40
+ cells.add((cr, cc))
41
+ for dr, dc in deltas:
42
+ nr, nc = cr + dr, cc + dc
43
+ if 0 <= nr < h and 0 <= nc < w and not visited[nr, nc] and grid[nr, nc] == color:
44
+ visited[nr, nc] = True
45
+ queue.append((nr, nc))
46
+ return cells
47
+
48
+
49
+ def extract_objects_multi(grid, connectivity=4):
50
+ """Extract all non-background connected components.
51
+ Returns list of dicts with color, cells, mask, bbox, size, touches_border."""
52
+ grid = np.array(grid, dtype=int)
53
+ h, w = grid.shape
54
+ bg = Counter(grid.flatten().tolist()).most_common(1)[0][0]
55
+ visited = np.zeros((h, w), dtype=bool)
56
+ objects = []
57
+
58
+ for r in range(h):
59
+ for c in range(w):
60
+ if visited[r, c] or grid[r, c] == bg:
61
+ visited[r, c] = True
62
+ continue
63
+ color = int(grid[r, c])
64
+ cells = _flood(grid, r, c, visited, color, connectivity)
65
+ if not cells:
66
+ continue
67
+ rows = [cr for cr, _ in cells]
68
+ cols = [cc for _, cc in cells]
69
+ rmin, rmax = min(rows), max(rows)
70
+ cmin, cmax = min(cols), max(cols)
71
+ mask = np.zeros((h, w), dtype=bool)
72
+ for cr, cc in cells:
73
+ mask[cr, cc] = True
74
+ touches = any(cr == 0 or cr == h-1 or cc == 0 or cc == w-1 for cr, cc in cells)
75
+ objects.append({
76
+ 'color': color,
77
+ 'cells': cells,
78
+ 'mask': mask,
79
+ 'bbox': (rmin, cmin, rmax, cmax),
80
+ 'size': len(cells),
81
+ 'touches_border': touches,
82
+ 'height': rmax - rmin + 1,
83
+ 'width': cmax - cmin + 1,
84
+ 'center_r': sum(rows) / len(rows),
85
+ 'center_c': sum(cols) / len(cols),
86
+ })
87
+
88
+ return objects, bg
89
+
90
+
91
+ def get_enclosed_bg_regions(grid, bg):
92
+ """Find background regions NOT reachable from grid border."""
93
+ grid = np.array(grid, dtype=int)
94
+ h, w = grid.shape
95
+ visited = np.zeros((h, w), dtype=bool)
96
+ queue = deque()
97
+
98
+ # Flood from all border bg cells
99
+ for r in range(h):
100
+ for c in range(w):
101
+ if (r == 0 or r == h-1 or c == 0 or c == w-1) and grid[r, c] == bg:
102
+ if not visited[r, c]:
103
+ visited[r, c] = True
104
+ queue.append((r, c))
105
+
106
+ while queue:
107
+ r, c = queue.popleft()
108
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
109
+ nr, nc = r + dr, c + dc
110
+ if 0 <= nr < h and 0 <= nc < w and not visited[nr, nc] and grid[nr, nc] == bg:
111
+ visited[nr, nc] = True
112
+ queue.append((nr, nc))
113
+
114
+ # Enclosed = bg cells not visited
115
+ enclosed = (grid == bg) & ~visited
116
+ return enclosed
117
+
118
+
119
+ def get_neighbor_colors(grid, r, c, bg=0):
120
+ """Get non-bg neighbor colors (4-connectivity)."""
121
+ h, w = grid.shape
122
+ colors = []
123
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
124
+ nr, nc = r + dr, c + dc
125
+ if 0 <= nr < h and 0 <= nc < w and grid[nr, nc] != bg:
126
+ colors.append(int(grid[nr, nc]))
127
+ return colors
128
+
129
+
130
+ # =============================================================================
131
+ # Object Predicates
132
+ # =============================================================================
133
+
134
+ def _build_predicates(objects, bg):
135
+ """Build predicate functions that test object properties."""
136
+ if not objects:
137
+ return {}
138
+
139
+ sizes = [o['size'] for o in objects]
140
+ max_size = max(sizes)
141
+ min_size = min(sizes)
142
+ colors_list = [o['color'] for o in objects]
143
+ color_counts = Counter(colors_list)
144
+ most_common_color = color_counts.most_common(1)[0][0]
145
+ least_common_color = color_counts.most_common()[-1][0]
146
+
147
+ predicates = {
148
+ 'is_largest': lambda o: o['size'] == max_size,
149
+ 'is_smallest': lambda o: o['size'] == min_size,
150
+ 'touches_border': lambda o: o['touches_border'],
151
+ 'not_touches_border': lambda o: not o['touches_border'],
152
+ 'is_most_common_color': lambda o: o['color'] == most_common_color,
153
+ 'is_least_common_color': lambda o: o['color'] == least_common_color,
154
+ 'always_true': lambda o: True,
155
+ }
156
+
157
+ # Add per-color predicates
158
+ for color in set(colors_list):
159
+ predicates[f'color_is_{color}'] = (lambda c: lambda o: o['color'] == c)(color)
160
+
161
+ # Size-based
162
+ if len(set(sizes)) > 1:
163
+ median_size = sorted(sizes)[len(sizes) // 2]
164
+ predicates['size_above_median'] = lambda o: o['size'] > median_size
165
+ predicates['size_below_median'] = lambda o: o['size'] < median_size
166
+
167
+ return predicates
168
+
169
+
170
+ # =============================================================================
171
+ # Actions
172
+ # =============================================================================
173
+
174
+ def _build_actions(objects, bg, grid_shape):
175
+ """Build action functions that transform a grid based on selected objects."""
176
+
177
+ all_colors = set(o['color'] for o in objects) | {bg}
178
+
179
+ actions = {}
180
+
181
+ # Recolor: change matching objects to a specific color
182
+ for target_color in range(10):
183
+ if target_color == bg:
184
+ continue
185
+ actions[f'recolor_to_{target_color}'] = (
186
+ lambda tc: lambda grid, selected_masks: _apply_recolor(grid, selected_masks, tc)
187
+ )(target_color)
188
+
189
+ # Fill enclosed regions of matching objects
190
+ actions['fill_enclosed'] = lambda grid, selected_masks: _apply_fill_enclosed(grid, selected_masks, bg)
191
+
192
+ # Fill interior (bbox minus object cells)
193
+ actions['fill_interior'] = lambda grid, selected_masks: _apply_fill_interior(grid, selected_masks, bg)
194
+
195
+ # Remove (set to bg)
196
+ actions['remove'] = lambda grid, selected_masks: _apply_remove(grid, selected_masks, bg)
197
+
198
+ # Extract (keep only selected, clear rest)
199
+ actions['extract'] = lambda grid, selected_masks: _apply_extract(grid, selected_masks, bg)
200
+
201
+ return actions
202
+
203
+
204
+ def _apply_recolor(grid, selected_masks, target_color):
205
+ result = grid.copy()
206
+ for mask in selected_masks:
207
+ result[mask] = target_color
208
+ return result
209
+
210
+
211
+ def _apply_fill_enclosed(grid, selected_masks, bg):
212
+ """Fill enclosed background regions that are bounded by selected objects."""
213
+ result = grid.copy()
214
+ h, w = grid.shape
215
+
216
+ for mask in selected_masks:
217
+ color = int(grid[mask][0]) if np.any(mask) else 0
218
+ if color == 0:
219
+ continue
220
+ # Find bbox of this object
221
+ rows, cols = np.where(mask)
222
+ if len(rows) == 0:
223
+ continue
224
+ rmin, rmax = rows.min(), rows.max()
225
+ cmin, cmax = cols.min(), cols.max()
226
+
227
+ # Within bbox, find bg cells enclosed by this object
228
+ for r in range(rmin, rmax + 1):
229
+ for c in range(cmin, cmax + 1):
230
+ if result[r, c] == bg:
231
+ # Check if this bg cell is "inside" the object
232
+ # Simple test: surrounded on all 4 cardinal directions by object cells
233
+ inside = True
234
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
235
+ found = False
236
+ nr, nc = r + dr, c + dc
237
+ while 0 <= nr < h and 0 <= nc < w:
238
+ if mask[nr, nc]:
239
+ found = True
240
+ break
241
+ nr += dr
242
+ nc += dc
243
+ if not found:
244
+ inside = False
245
+ break
246
+ if inside:
247
+ result[r, c] = color
248
+ return result
249
+
250
+
251
+ def _apply_fill_interior(grid, selected_masks, bg):
252
+ """Fill the bounding box interior of selected objects with the object's color."""
253
+ result = grid.copy()
254
+ for mask in selected_masks:
255
+ color = int(grid[mask][0]) if np.any(mask) else 0
256
+ if color == 0:
257
+ continue
258
+ rows, cols = np.where(mask)
259
+ if len(rows) == 0:
260
+ continue
261
+ rmin, rmax = rows.min(), rows.max()
262
+ cmin, cmax = cols.min(), cols.max()
263
+ for r in range(rmin, rmax + 1):
264
+ for c in range(cmin, cmax + 1):
265
+ if result[r, c] == bg:
266
+ result[r, c] = color
267
+ return result
268
+
269
+
270
+ def _apply_remove(grid, selected_masks, bg):
271
+ result = grid.copy()
272
+ for mask in selected_masks:
273
+ result[mask] = bg
274
+ return result
275
+
276
+
277
+ def _apply_extract(grid, selected_masks, bg):
278
+ result = np.full_like(grid, bg)
279
+ for mask in selected_masks:
280
+ result[mask] = grid[mask]
281
+ return result
282
+
283
+
284
+ # =============================================================================
285
+ # Neighborhood Rule Table (CA-style)
286
+ # =============================================================================
287
+
288
+ def learn_neighborhood_rule(train_pairs):
289
+ """
290
+ For same-shape tasks: build a lookup table
291
+ (center_color, sorted_neighbor_colors) -> output_color
292
+ If consistent across all training pairs, return the rule.
293
+ """
294
+ # Check all same shape
295
+ for pair in train_pairs:
296
+ inp = np.array(pair['input'])
297
+ out = np.array(pair['output'])
298
+ if inp.shape != out.shape:
299
+ return None
300
+
301
+ rule_table = {} # (center, neighbor_sig) -> output_color
302
+ conflicts = False
303
+
304
+ for pair in train_pairs:
305
+ inp = np.array(pair['input'], dtype=int)
306
+ out = np.array(pair['output'], dtype=int)
307
+ h, w = inp.shape
308
+
309
+ for r in range(h):
310
+ for c in range(w):
311
+ center = int(inp[r, c])
312
+ out_val = int(out[r, c])
313
+
314
+ # Get 4-neighbor colors
315
+ neighbors = []
316
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
317
+ nr, nc = r + dr, c + dc
318
+ if 0 <= nr < h and 0 <= nc < w:
319
+ neighbors.append(int(inp[nr, nc]))
320
+ else:
321
+ neighbors.append(-1) # border sentinel
322
+
323
+ key = (center, tuple(sorted(neighbors)))
324
+
325
+ if key in rule_table:
326
+ if rule_table[key] != out_val:
327
+ conflicts = True
328
+ break
329
+ else:
330
+ rule_table[key] = out_val
331
+
332
+ if conflicts:
333
+ break
334
+ if conflicts:
335
+ break
336
+
337
+ if conflicts:
338
+ return None
339
+
340
+ return rule_table
341
+
342
+
343
+ def apply_neighborhood_rule(grid, rule_table):
344
+ """Apply a learned neighborhood rule table to a grid."""
345
+ grid = np.array(grid, dtype=int)
346
+ h, w = grid.shape
347
+ result = grid.copy()
348
+
349
+ for r in range(h):
350
+ for c in range(w):
351
+ center = int(grid[r, c])
352
+ neighbors = []
353
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
354
+ nr, nc = r + dr, c + dc
355
+ if 0 <= nr < h and 0 <= nc < w:
356
+ neighbors.append(int(grid[nr, nc]))
357
+ else:
358
+ neighbors.append(-1)
359
+
360
+ key = (center, tuple(sorted(neighbors)))
361
+ if key in rule_table:
362
+ result[r, c] = rule_table[key]
363
+
364
+ return result
365
+
366
+
367
+ # =============================================================================
368
+ # Global Fill Rules (not object-specific)
369
+ # =============================================================================
370
+
371
+ def try_global_enclosed_fill(train_pairs):
372
+ """
373
+ Try: fill all enclosed bg regions with a consistent color.
374
+ Learn the fill color from training pairs.
375
+ """
376
+ fill_colors = []
377
+
378
+ for pair in train_pairs:
379
+ inp = np.array(pair['input'], dtype=int)
380
+ out = np.array(pair['output'], dtype=int)
381
+ if inp.shape != out.shape:
382
+ return None
383
+
384
+ bg = Counter(inp.flatten().tolist()).most_common(1)[0][0]
385
+ enclosed = get_enclosed_bg_regions(inp, bg)
386
+
387
+ if not np.any(enclosed):
388
+ continue
389
+
390
+ # What color fills the enclosed region in output?
391
+ fill_vals = out[enclosed]
392
+ unique = np.unique(fill_vals)
393
+ non_bg = unique[unique != bg]
394
+ if len(non_bg) == 1:
395
+ fill_colors.append(int(non_bg[0]))
396
+ elif len(non_bg) > 1:
397
+ return None # multiple colors fill enclosed — too complex
398
+
399
+ if not fill_colors:
400
+ return None
401
+
402
+ # Check consistency
403
+ if len(set(fill_colors)) != 1:
404
+ return None
405
+
406
+ fill_color = fill_colors[0]
407
+
408
+ # Validate on all pairs
409
+ for pair in train_pairs:
410
+ inp = np.array(pair['input'], dtype=int)
411
+ out = np.array(pair['output'], dtype=int)
412
+ bg = Counter(inp.flatten().tolist()).most_common(1)[0][0]
413
+
414
+ result = inp.copy()
415
+ enclosed = get_enclosed_bg_regions(inp, bg)
416
+ result[enclosed] = fill_color
417
+
418
+ if not np.array_equal(result, out):
419
+ return None
420
+
421
+ return fill_color
422
+
423
+
424
+ def try_per_object_enclosed_fill(train_pairs):
425
+ """
426
+ Try: for each object, fill its enclosed interior with its own color.
427
+ """
428
+ for pair in train_pairs:
429
+ inp = np.array(pair['input'], dtype=int)
430
+ out = np.array(pair['output'], dtype=int)
431
+ if inp.shape != out.shape:
432
+ return False
433
+
434
+ objects, bg = extract_objects_multi(inp, connectivity=4)
435
+ result = inp.copy()
436
+
437
+ for obj in objects:
438
+ mask = obj['mask']
439
+ color = obj['color']
440
+ rmin, cmin, rmax, cmax = obj['bbox']
441
+ h, w = inp.shape
442
+
443
+ for r in range(rmin, rmax + 1):
444
+ for c in range(cmin, cmax + 1):
445
+ if result[r, c] == bg:
446
+ # Ray-cast: is this cell inside the object?
447
+ inside = True
448
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
449
+ found = False
450
+ nr, nc = r + dr, c + dc
451
+ while 0 <= nr < h and 0 <= nc < w:
452
+ if mask[nr, nc]:
453
+ found = True
454
+ break
455
+ nr += dr
456
+ nc += dc
457
+ if not found:
458
+ inside = False
459
+ break
460
+ if inside:
461
+ result[r, c] = color
462
+
463
+ if not np.array_equal(result, out):
464
+ return False
465
+
466
+ return True
467
+
468
+
469
+ def apply_per_object_enclosed_fill(grid):
470
+ """Apply per-object enclosed fill."""
471
+ grid = np.array(grid, dtype=int)
472
+ objects, bg = extract_objects_multi(grid, connectivity=4)
473
+ result = grid.copy()
474
+ h, w = grid.shape
475
+
476
+ for obj in objects:
477
+ mask = obj['mask']
478
+ color = obj['color']
479
+ rmin, cmin, rmax, cmax = obj['bbox']
480
+
481
+ for r in range(rmin, rmax + 1):
482
+ for c in range(cmin, cmax + 1):
483
+ if result[r, c] == bg:
484
+ inside = True
485
+ for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
486
+ found = False
487
+ nr, nc = r + dr, c + dc
488
+ while 0 <= nr < h and 0 <= nc < w:
489
+ if mask[nr, nc]:
490
+ found = True
491
+ break
492
+ nr += dr
493
+ nc += dc
494
+ if not found:
495
+ inside = False
496
+ break
497
+ if inside:
498
+ result[r, c] = color
499
+ return result
500
+
501
+
502
+ # =============================================================================
503
+ # Main Enumeration Engine
504
+ # =============================================================================
505
+
506
+ def enumerate_rules(train_pairs, max_time_ms=5000):
507
+ """
508
+ Try all rule strategies on a task. Return the first that passes
509
+ all training pairs, or None.
510
+
511
+ Strategies (in order):
512
+ 1. Global enclosed fill (single color)
513
+ 2. Per-object enclosed fill
514
+ 3. Neighborhood rule table (CA-style)
515
+ 4. Object predicate × action enumeration
516
+ """
517
+ import time
518
+ start = time.time()
519
+
520
+ # Check same shape
521
+ all_same_shape = all(
522
+ np.array(p['input']).shape == np.array(p['output']).shape
523
+ for p in train_pairs
524
+ )
525
+
526
+ # === Strategy 1: Global enclosed fill ===
527
+ fill_color = try_global_enclosed_fill(train_pairs)
528
+ if fill_color is not None:
529
+ def rule_fn(grid, _fc=fill_color):
530
+ g = np.array(grid, dtype=int)
531
+ bg = Counter(g.flatten().tolist()).most_common(1)[0][0]
532
+ enclosed = get_enclosed_bg_regions(g, bg)
533
+ result = g.copy()
534
+ result[enclosed] = _fc
535
+ return result
536
+ return ('global_enclosed_fill', rule_fn)
537
+
538
+ # === Strategy 2: Per-object enclosed fill ===
539
+ if all_same_shape and try_per_object_enclosed_fill(train_pairs):
540
+ return ('per_object_enclosed_fill', apply_per_object_enclosed_fill)
541
+
542
+ # === Strategy 3: Neighborhood rule table ===
543
+ if all_same_shape:
544
+ rule_table = learn_neighborhood_rule(train_pairs)
545
+ if rule_table is not None:
546
+ # Validate
547
+ valid = True
548
+ for pair in train_pairs:
549
+ pred = apply_neighborhood_rule(pair['input'], rule_table)
550
+ if not np.array_equal(pred, np.array(pair['output'], dtype=int)):
551
+ valid = False
552
+ break
553
+ if valid:
554
+ def rule_fn(grid, _rt=rule_table):
555
+ return apply_neighborhood_rule(grid, _rt)
556
+ return ('neighborhood_rule', rule_fn)
557
+
558
+ # === Strategy 4: Object predicate × action enumeration ===
559
+ if all_same_shape and (time.time() - start) * 1000 < max_time_ms:
560
+ result = _enumerate_predicate_actions(train_pairs)
561
+ if result is not None:
562
+ return result
563
+
564
+ return None
565
+
566
+
567
+ def _enumerate_predicate_actions(train_pairs):
568
+ """Enumerate (connectivity × predicate × action) combinations."""
569
+ for connectivity in [4, 8]:
570
+ # Extract objects for all pairs
571
+ pair_data = []
572
+ for pair in train_pairs:
573
+ inp = np.array(pair['input'], dtype=int)
574
+ out = np.array(pair['output'], dtype=int)
575
+ objects, bg = extract_objects_multi(inp, connectivity)
576
+ pair_data.append((inp, out, objects, bg))
577
+
578
+ if not pair_data or not pair_data[0][2]:
579
+ continue
580
+
581
+ # Build predicates from first pair
582
+ first_objects = pair_data[0][2]
583
+ first_bg = pair_data[0][3]
584
+ predicates = _build_predicates(first_objects, first_bg)
585
+ actions = _build_actions(first_objects, first_bg, pair_data[0][0].shape)
586
+
587
+ # Enumerate
588
+ for pred_name, pred_fn in predicates.items():
589
+ for act_name, act_fn in actions.items():
590
+ # Test on all pairs
591
+ all_pass = True
592
+ for inp, out, objects, bg in pair_data:
593
+ if not objects:
594
+ all_pass = False
595
+ break
596
+
597
+ # Rebuild predicates for this pair's objects
598
+ local_preds = _build_predicates(objects, bg)
599
+ local_pred = local_preds.get(pred_name)
600
+ if local_pred is None:
601
+ all_pass = False
602
+ break
603
+
604
+ # Select objects matching predicate
605
+ selected = [o for o in objects if local_pred(o)]
606
+ if not selected and pred_name != 'always_true':
607
+ all_pass = False
608
+ break
609
+
610
+ selected_masks = [o['mask'] for o in selected]
611
+
612
+ try:
613
+ result = act_fn(inp, selected_masks)
614
+ if not np.array_equal(result, out):
615
+ all_pass = False
616
+ break
617
+ except Exception:
618
+ all_pass = False
619
+ break
620
+
621
+ if all_pass:
622
+ # Build a reusable rule function
623
+ def make_rule(pn, an, conn):
624
+ def rule_fn(grid):
625
+ g = np.array(grid, dtype=int)
626
+ objs, bg = extract_objects_multi(g, conn)
627
+ if not objs:
628
+ return g
629
+ preds = _build_predicates(objs, bg)
630
+ pred = preds.get(pn, lambda o: False)
631
+ acts = _build_actions(objs, bg, g.shape)
632
+ act = acts.get(an)
633
+ if act is None:
634
+ return g
635
+ selected = [o for o in objs if pred(o)]
636
+ masks = [o['mask'] for o in selected]
637
+ return act(g, masks)
638
+ return rule_fn
639
+
640
+ return (f'predicate_{pred_name}_action_{act_name}_conn{connectivity}',
641
+ make_rule(pred_name, act_name, connectivity))
642
+
643
+ return None