Add predicate enumeration engine: neighborhood rules + enclosed fill + object predicates — 70/400 (17.5%)
Browse files- itt_solver/predicate_engine.py +643 -0
itt_solver/predicate_engine.py
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Object Predicate + Action Enumeration Engine
|
| 3 |
+
=============================================
|
| 4 |
+
|
| 5 |
+
For each ARC task, enumerate combinations of:
|
| 6 |
+
(object_abstraction) x (predicate) x (action)
|
| 7 |
+
|
| 8 |
+
Test each rule against ALL training pairs. If a rule produces
|
| 9 |
+
exact output for every pair, use it.
|
| 10 |
+
|
| 11 |
+
This is the GPAR approach simplified to pure Python — no PDDL,
|
| 12 |
+
no planner. Just brute-force enumeration of ~600 rule templates.
|
| 13 |
+
|
| 14 |
+
Covers:
|
| 15 |
+
- Fill miss (35% of unsolved): enclosed_by, neighbor_count conditions
|
| 16 |
+
- Recolor miss (24%): object attribute conditions (size, color, position)
|
| 17 |
+
- Shape change (25%): extract by predicate
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
from collections import Counter, deque
|
| 22 |
+
from typing import Dict, List, Tuple, Optional, Set, Callable
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# =============================================================================
|
| 26 |
+
# Object Extraction (robust, multiple abstractions)
|
| 27 |
+
# =============================================================================
|
| 28 |
+
|
| 29 |
+
def _flood(grid, r, c, visited, color, connectivity=4):
|
| 30 |
+
"""BFS flood fill for a single color component."""
|
| 31 |
+
h, w = grid.shape
|
| 32 |
+
cells = set()
|
| 33 |
+
queue = deque([(r, c)])
|
| 34 |
+
visited[r, c] = True
|
| 35 |
+
deltas = [(-1,0),(1,0),(0,-1),(0,1)]
|
| 36 |
+
if connectivity == 8:
|
| 37 |
+
deltas += [(-1,-1),(-1,1),(1,-1),(1,1)]
|
| 38 |
+
while queue:
|
| 39 |
+
cr, cc = queue.popleft()
|
| 40 |
+
cells.add((cr, cc))
|
| 41 |
+
for dr, dc in deltas:
|
| 42 |
+
nr, nc = cr + dr, cc + dc
|
| 43 |
+
if 0 <= nr < h and 0 <= nc < w and not visited[nr, nc] and grid[nr, nc] == color:
|
| 44 |
+
visited[nr, nc] = True
|
| 45 |
+
queue.append((nr, nc))
|
| 46 |
+
return cells
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def extract_objects_multi(grid, connectivity=4):
|
| 50 |
+
"""Extract all non-background connected components.
|
| 51 |
+
Returns list of dicts with color, cells, mask, bbox, size, touches_border."""
|
| 52 |
+
grid = np.array(grid, dtype=int)
|
| 53 |
+
h, w = grid.shape
|
| 54 |
+
bg = Counter(grid.flatten().tolist()).most_common(1)[0][0]
|
| 55 |
+
visited = np.zeros((h, w), dtype=bool)
|
| 56 |
+
objects = []
|
| 57 |
+
|
| 58 |
+
for r in range(h):
|
| 59 |
+
for c in range(w):
|
| 60 |
+
if visited[r, c] or grid[r, c] == bg:
|
| 61 |
+
visited[r, c] = True
|
| 62 |
+
continue
|
| 63 |
+
color = int(grid[r, c])
|
| 64 |
+
cells = _flood(grid, r, c, visited, color, connectivity)
|
| 65 |
+
if not cells:
|
| 66 |
+
continue
|
| 67 |
+
rows = [cr for cr, _ in cells]
|
| 68 |
+
cols = [cc for _, cc in cells]
|
| 69 |
+
rmin, rmax = min(rows), max(rows)
|
| 70 |
+
cmin, cmax = min(cols), max(cols)
|
| 71 |
+
mask = np.zeros((h, w), dtype=bool)
|
| 72 |
+
for cr, cc in cells:
|
| 73 |
+
mask[cr, cc] = True
|
| 74 |
+
touches = any(cr == 0 or cr == h-1 or cc == 0 or cc == w-1 for cr, cc in cells)
|
| 75 |
+
objects.append({
|
| 76 |
+
'color': color,
|
| 77 |
+
'cells': cells,
|
| 78 |
+
'mask': mask,
|
| 79 |
+
'bbox': (rmin, cmin, rmax, cmax),
|
| 80 |
+
'size': len(cells),
|
| 81 |
+
'touches_border': touches,
|
| 82 |
+
'height': rmax - rmin + 1,
|
| 83 |
+
'width': cmax - cmin + 1,
|
| 84 |
+
'center_r': sum(rows) / len(rows),
|
| 85 |
+
'center_c': sum(cols) / len(cols),
|
| 86 |
+
})
|
| 87 |
+
|
| 88 |
+
return objects, bg
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_enclosed_bg_regions(grid, bg):
|
| 92 |
+
"""Find background regions NOT reachable from grid border."""
|
| 93 |
+
grid = np.array(grid, dtype=int)
|
| 94 |
+
h, w = grid.shape
|
| 95 |
+
visited = np.zeros((h, w), dtype=bool)
|
| 96 |
+
queue = deque()
|
| 97 |
+
|
| 98 |
+
# Flood from all border bg cells
|
| 99 |
+
for r in range(h):
|
| 100 |
+
for c in range(w):
|
| 101 |
+
if (r == 0 or r == h-1 or c == 0 or c == w-1) and grid[r, c] == bg:
|
| 102 |
+
if not visited[r, c]:
|
| 103 |
+
visited[r, c] = True
|
| 104 |
+
queue.append((r, c))
|
| 105 |
+
|
| 106 |
+
while queue:
|
| 107 |
+
r, c = queue.popleft()
|
| 108 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 109 |
+
nr, nc = r + dr, c + dc
|
| 110 |
+
if 0 <= nr < h and 0 <= nc < w and not visited[nr, nc] and grid[nr, nc] == bg:
|
| 111 |
+
visited[nr, nc] = True
|
| 112 |
+
queue.append((nr, nc))
|
| 113 |
+
|
| 114 |
+
# Enclosed = bg cells not visited
|
| 115 |
+
enclosed = (grid == bg) & ~visited
|
| 116 |
+
return enclosed
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_neighbor_colors(grid, r, c, bg=0):
|
| 120 |
+
"""Get non-bg neighbor colors (4-connectivity)."""
|
| 121 |
+
h, w = grid.shape
|
| 122 |
+
colors = []
|
| 123 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 124 |
+
nr, nc = r + dr, c + dc
|
| 125 |
+
if 0 <= nr < h and 0 <= nc < w and grid[nr, nc] != bg:
|
| 126 |
+
colors.append(int(grid[nr, nc]))
|
| 127 |
+
return colors
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
# =============================================================================
|
| 131 |
+
# Object Predicates
|
| 132 |
+
# =============================================================================
|
| 133 |
+
|
| 134 |
+
def _build_predicates(objects, bg):
|
| 135 |
+
"""Build predicate functions that test object properties."""
|
| 136 |
+
if not objects:
|
| 137 |
+
return {}
|
| 138 |
+
|
| 139 |
+
sizes = [o['size'] for o in objects]
|
| 140 |
+
max_size = max(sizes)
|
| 141 |
+
min_size = min(sizes)
|
| 142 |
+
colors_list = [o['color'] for o in objects]
|
| 143 |
+
color_counts = Counter(colors_list)
|
| 144 |
+
most_common_color = color_counts.most_common(1)[0][0]
|
| 145 |
+
least_common_color = color_counts.most_common()[-1][0]
|
| 146 |
+
|
| 147 |
+
predicates = {
|
| 148 |
+
'is_largest': lambda o: o['size'] == max_size,
|
| 149 |
+
'is_smallest': lambda o: o['size'] == min_size,
|
| 150 |
+
'touches_border': lambda o: o['touches_border'],
|
| 151 |
+
'not_touches_border': lambda o: not o['touches_border'],
|
| 152 |
+
'is_most_common_color': lambda o: o['color'] == most_common_color,
|
| 153 |
+
'is_least_common_color': lambda o: o['color'] == least_common_color,
|
| 154 |
+
'always_true': lambda o: True,
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
# Add per-color predicates
|
| 158 |
+
for color in set(colors_list):
|
| 159 |
+
predicates[f'color_is_{color}'] = (lambda c: lambda o: o['color'] == c)(color)
|
| 160 |
+
|
| 161 |
+
# Size-based
|
| 162 |
+
if len(set(sizes)) > 1:
|
| 163 |
+
median_size = sorted(sizes)[len(sizes) // 2]
|
| 164 |
+
predicates['size_above_median'] = lambda o: o['size'] > median_size
|
| 165 |
+
predicates['size_below_median'] = lambda o: o['size'] < median_size
|
| 166 |
+
|
| 167 |
+
return predicates
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# =============================================================================
|
| 171 |
+
# Actions
|
| 172 |
+
# =============================================================================
|
| 173 |
+
|
| 174 |
+
def _build_actions(objects, bg, grid_shape):
|
| 175 |
+
"""Build action functions that transform a grid based on selected objects."""
|
| 176 |
+
|
| 177 |
+
all_colors = set(o['color'] for o in objects) | {bg}
|
| 178 |
+
|
| 179 |
+
actions = {}
|
| 180 |
+
|
| 181 |
+
# Recolor: change matching objects to a specific color
|
| 182 |
+
for target_color in range(10):
|
| 183 |
+
if target_color == bg:
|
| 184 |
+
continue
|
| 185 |
+
actions[f'recolor_to_{target_color}'] = (
|
| 186 |
+
lambda tc: lambda grid, selected_masks: _apply_recolor(grid, selected_masks, tc)
|
| 187 |
+
)(target_color)
|
| 188 |
+
|
| 189 |
+
# Fill enclosed regions of matching objects
|
| 190 |
+
actions['fill_enclosed'] = lambda grid, selected_masks: _apply_fill_enclosed(grid, selected_masks, bg)
|
| 191 |
+
|
| 192 |
+
# Fill interior (bbox minus object cells)
|
| 193 |
+
actions['fill_interior'] = lambda grid, selected_masks: _apply_fill_interior(grid, selected_masks, bg)
|
| 194 |
+
|
| 195 |
+
# Remove (set to bg)
|
| 196 |
+
actions['remove'] = lambda grid, selected_masks: _apply_remove(grid, selected_masks, bg)
|
| 197 |
+
|
| 198 |
+
# Extract (keep only selected, clear rest)
|
| 199 |
+
actions['extract'] = lambda grid, selected_masks: _apply_extract(grid, selected_masks, bg)
|
| 200 |
+
|
| 201 |
+
return actions
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _apply_recolor(grid, selected_masks, target_color):
|
| 205 |
+
result = grid.copy()
|
| 206 |
+
for mask in selected_masks:
|
| 207 |
+
result[mask] = target_color
|
| 208 |
+
return result
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _apply_fill_enclosed(grid, selected_masks, bg):
|
| 212 |
+
"""Fill enclosed background regions that are bounded by selected objects."""
|
| 213 |
+
result = grid.copy()
|
| 214 |
+
h, w = grid.shape
|
| 215 |
+
|
| 216 |
+
for mask in selected_masks:
|
| 217 |
+
color = int(grid[mask][0]) if np.any(mask) else 0
|
| 218 |
+
if color == 0:
|
| 219 |
+
continue
|
| 220 |
+
# Find bbox of this object
|
| 221 |
+
rows, cols = np.where(mask)
|
| 222 |
+
if len(rows) == 0:
|
| 223 |
+
continue
|
| 224 |
+
rmin, rmax = rows.min(), rows.max()
|
| 225 |
+
cmin, cmax = cols.min(), cols.max()
|
| 226 |
+
|
| 227 |
+
# Within bbox, find bg cells enclosed by this object
|
| 228 |
+
for r in range(rmin, rmax + 1):
|
| 229 |
+
for c in range(cmin, cmax + 1):
|
| 230 |
+
if result[r, c] == bg:
|
| 231 |
+
# Check if this bg cell is "inside" the object
|
| 232 |
+
# Simple test: surrounded on all 4 cardinal directions by object cells
|
| 233 |
+
inside = True
|
| 234 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 235 |
+
found = False
|
| 236 |
+
nr, nc = r + dr, c + dc
|
| 237 |
+
while 0 <= nr < h and 0 <= nc < w:
|
| 238 |
+
if mask[nr, nc]:
|
| 239 |
+
found = True
|
| 240 |
+
break
|
| 241 |
+
nr += dr
|
| 242 |
+
nc += dc
|
| 243 |
+
if not found:
|
| 244 |
+
inside = False
|
| 245 |
+
break
|
| 246 |
+
if inside:
|
| 247 |
+
result[r, c] = color
|
| 248 |
+
return result
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def _apply_fill_interior(grid, selected_masks, bg):
|
| 252 |
+
"""Fill the bounding box interior of selected objects with the object's color."""
|
| 253 |
+
result = grid.copy()
|
| 254 |
+
for mask in selected_masks:
|
| 255 |
+
color = int(grid[mask][0]) if np.any(mask) else 0
|
| 256 |
+
if color == 0:
|
| 257 |
+
continue
|
| 258 |
+
rows, cols = np.where(mask)
|
| 259 |
+
if len(rows) == 0:
|
| 260 |
+
continue
|
| 261 |
+
rmin, rmax = rows.min(), rows.max()
|
| 262 |
+
cmin, cmax = cols.min(), cols.max()
|
| 263 |
+
for r in range(rmin, rmax + 1):
|
| 264 |
+
for c in range(cmin, cmax + 1):
|
| 265 |
+
if result[r, c] == bg:
|
| 266 |
+
result[r, c] = color
|
| 267 |
+
return result
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _apply_remove(grid, selected_masks, bg):
|
| 271 |
+
result = grid.copy()
|
| 272 |
+
for mask in selected_masks:
|
| 273 |
+
result[mask] = bg
|
| 274 |
+
return result
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _apply_extract(grid, selected_masks, bg):
|
| 278 |
+
result = np.full_like(grid, bg)
|
| 279 |
+
for mask in selected_masks:
|
| 280 |
+
result[mask] = grid[mask]
|
| 281 |
+
return result
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# =============================================================================
|
| 285 |
+
# Neighborhood Rule Table (CA-style)
|
| 286 |
+
# =============================================================================
|
| 287 |
+
|
| 288 |
+
def learn_neighborhood_rule(train_pairs):
|
| 289 |
+
"""
|
| 290 |
+
For same-shape tasks: build a lookup table
|
| 291 |
+
(center_color, sorted_neighbor_colors) -> output_color
|
| 292 |
+
If consistent across all training pairs, return the rule.
|
| 293 |
+
"""
|
| 294 |
+
# Check all same shape
|
| 295 |
+
for pair in train_pairs:
|
| 296 |
+
inp = np.array(pair['input'])
|
| 297 |
+
out = np.array(pair['output'])
|
| 298 |
+
if inp.shape != out.shape:
|
| 299 |
+
return None
|
| 300 |
+
|
| 301 |
+
rule_table = {} # (center, neighbor_sig) -> output_color
|
| 302 |
+
conflicts = False
|
| 303 |
+
|
| 304 |
+
for pair in train_pairs:
|
| 305 |
+
inp = np.array(pair['input'], dtype=int)
|
| 306 |
+
out = np.array(pair['output'], dtype=int)
|
| 307 |
+
h, w = inp.shape
|
| 308 |
+
|
| 309 |
+
for r in range(h):
|
| 310 |
+
for c in range(w):
|
| 311 |
+
center = int(inp[r, c])
|
| 312 |
+
out_val = int(out[r, c])
|
| 313 |
+
|
| 314 |
+
# Get 4-neighbor colors
|
| 315 |
+
neighbors = []
|
| 316 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 317 |
+
nr, nc = r + dr, c + dc
|
| 318 |
+
if 0 <= nr < h and 0 <= nc < w:
|
| 319 |
+
neighbors.append(int(inp[nr, nc]))
|
| 320 |
+
else:
|
| 321 |
+
neighbors.append(-1) # border sentinel
|
| 322 |
+
|
| 323 |
+
key = (center, tuple(sorted(neighbors)))
|
| 324 |
+
|
| 325 |
+
if key in rule_table:
|
| 326 |
+
if rule_table[key] != out_val:
|
| 327 |
+
conflicts = True
|
| 328 |
+
break
|
| 329 |
+
else:
|
| 330 |
+
rule_table[key] = out_val
|
| 331 |
+
|
| 332 |
+
if conflicts:
|
| 333 |
+
break
|
| 334 |
+
if conflicts:
|
| 335 |
+
break
|
| 336 |
+
|
| 337 |
+
if conflicts:
|
| 338 |
+
return None
|
| 339 |
+
|
| 340 |
+
return rule_table
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def apply_neighborhood_rule(grid, rule_table):
|
| 344 |
+
"""Apply a learned neighborhood rule table to a grid."""
|
| 345 |
+
grid = np.array(grid, dtype=int)
|
| 346 |
+
h, w = grid.shape
|
| 347 |
+
result = grid.copy()
|
| 348 |
+
|
| 349 |
+
for r in range(h):
|
| 350 |
+
for c in range(w):
|
| 351 |
+
center = int(grid[r, c])
|
| 352 |
+
neighbors = []
|
| 353 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 354 |
+
nr, nc = r + dr, c + dc
|
| 355 |
+
if 0 <= nr < h and 0 <= nc < w:
|
| 356 |
+
neighbors.append(int(grid[nr, nc]))
|
| 357 |
+
else:
|
| 358 |
+
neighbors.append(-1)
|
| 359 |
+
|
| 360 |
+
key = (center, tuple(sorted(neighbors)))
|
| 361 |
+
if key in rule_table:
|
| 362 |
+
result[r, c] = rule_table[key]
|
| 363 |
+
|
| 364 |
+
return result
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# =============================================================================
|
| 368 |
+
# Global Fill Rules (not object-specific)
|
| 369 |
+
# =============================================================================
|
| 370 |
+
|
| 371 |
+
def try_global_enclosed_fill(train_pairs):
|
| 372 |
+
"""
|
| 373 |
+
Try: fill all enclosed bg regions with a consistent color.
|
| 374 |
+
Learn the fill color from training pairs.
|
| 375 |
+
"""
|
| 376 |
+
fill_colors = []
|
| 377 |
+
|
| 378 |
+
for pair in train_pairs:
|
| 379 |
+
inp = np.array(pair['input'], dtype=int)
|
| 380 |
+
out = np.array(pair['output'], dtype=int)
|
| 381 |
+
if inp.shape != out.shape:
|
| 382 |
+
return None
|
| 383 |
+
|
| 384 |
+
bg = Counter(inp.flatten().tolist()).most_common(1)[0][0]
|
| 385 |
+
enclosed = get_enclosed_bg_regions(inp, bg)
|
| 386 |
+
|
| 387 |
+
if not np.any(enclosed):
|
| 388 |
+
continue
|
| 389 |
+
|
| 390 |
+
# What color fills the enclosed region in output?
|
| 391 |
+
fill_vals = out[enclosed]
|
| 392 |
+
unique = np.unique(fill_vals)
|
| 393 |
+
non_bg = unique[unique != bg]
|
| 394 |
+
if len(non_bg) == 1:
|
| 395 |
+
fill_colors.append(int(non_bg[0]))
|
| 396 |
+
elif len(non_bg) > 1:
|
| 397 |
+
return None # multiple colors fill enclosed — too complex
|
| 398 |
+
|
| 399 |
+
if not fill_colors:
|
| 400 |
+
return None
|
| 401 |
+
|
| 402 |
+
# Check consistency
|
| 403 |
+
if len(set(fill_colors)) != 1:
|
| 404 |
+
return None
|
| 405 |
+
|
| 406 |
+
fill_color = fill_colors[0]
|
| 407 |
+
|
| 408 |
+
# Validate on all pairs
|
| 409 |
+
for pair in train_pairs:
|
| 410 |
+
inp = np.array(pair['input'], dtype=int)
|
| 411 |
+
out = np.array(pair['output'], dtype=int)
|
| 412 |
+
bg = Counter(inp.flatten().tolist()).most_common(1)[0][0]
|
| 413 |
+
|
| 414 |
+
result = inp.copy()
|
| 415 |
+
enclosed = get_enclosed_bg_regions(inp, bg)
|
| 416 |
+
result[enclosed] = fill_color
|
| 417 |
+
|
| 418 |
+
if not np.array_equal(result, out):
|
| 419 |
+
return None
|
| 420 |
+
|
| 421 |
+
return fill_color
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def try_per_object_enclosed_fill(train_pairs):
|
| 425 |
+
"""
|
| 426 |
+
Try: for each object, fill its enclosed interior with its own color.
|
| 427 |
+
"""
|
| 428 |
+
for pair in train_pairs:
|
| 429 |
+
inp = np.array(pair['input'], dtype=int)
|
| 430 |
+
out = np.array(pair['output'], dtype=int)
|
| 431 |
+
if inp.shape != out.shape:
|
| 432 |
+
return False
|
| 433 |
+
|
| 434 |
+
objects, bg = extract_objects_multi(inp, connectivity=4)
|
| 435 |
+
result = inp.copy()
|
| 436 |
+
|
| 437 |
+
for obj in objects:
|
| 438 |
+
mask = obj['mask']
|
| 439 |
+
color = obj['color']
|
| 440 |
+
rmin, cmin, rmax, cmax = obj['bbox']
|
| 441 |
+
h, w = inp.shape
|
| 442 |
+
|
| 443 |
+
for r in range(rmin, rmax + 1):
|
| 444 |
+
for c in range(cmin, cmax + 1):
|
| 445 |
+
if result[r, c] == bg:
|
| 446 |
+
# Ray-cast: is this cell inside the object?
|
| 447 |
+
inside = True
|
| 448 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 449 |
+
found = False
|
| 450 |
+
nr, nc = r + dr, c + dc
|
| 451 |
+
while 0 <= nr < h and 0 <= nc < w:
|
| 452 |
+
if mask[nr, nc]:
|
| 453 |
+
found = True
|
| 454 |
+
break
|
| 455 |
+
nr += dr
|
| 456 |
+
nc += dc
|
| 457 |
+
if not found:
|
| 458 |
+
inside = False
|
| 459 |
+
break
|
| 460 |
+
if inside:
|
| 461 |
+
result[r, c] = color
|
| 462 |
+
|
| 463 |
+
if not np.array_equal(result, out):
|
| 464 |
+
return False
|
| 465 |
+
|
| 466 |
+
return True
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def apply_per_object_enclosed_fill(grid):
|
| 470 |
+
"""Apply per-object enclosed fill."""
|
| 471 |
+
grid = np.array(grid, dtype=int)
|
| 472 |
+
objects, bg = extract_objects_multi(grid, connectivity=4)
|
| 473 |
+
result = grid.copy()
|
| 474 |
+
h, w = grid.shape
|
| 475 |
+
|
| 476 |
+
for obj in objects:
|
| 477 |
+
mask = obj['mask']
|
| 478 |
+
color = obj['color']
|
| 479 |
+
rmin, cmin, rmax, cmax = obj['bbox']
|
| 480 |
+
|
| 481 |
+
for r in range(rmin, rmax + 1):
|
| 482 |
+
for c in range(cmin, cmax + 1):
|
| 483 |
+
if result[r, c] == bg:
|
| 484 |
+
inside = True
|
| 485 |
+
for dr, dc in [(-1,0),(1,0),(0,-1),(0,1)]:
|
| 486 |
+
found = False
|
| 487 |
+
nr, nc = r + dr, c + dc
|
| 488 |
+
while 0 <= nr < h and 0 <= nc < w:
|
| 489 |
+
if mask[nr, nc]:
|
| 490 |
+
found = True
|
| 491 |
+
break
|
| 492 |
+
nr += dr
|
| 493 |
+
nc += dc
|
| 494 |
+
if not found:
|
| 495 |
+
inside = False
|
| 496 |
+
break
|
| 497 |
+
if inside:
|
| 498 |
+
result[r, c] = color
|
| 499 |
+
return result
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# =============================================================================
|
| 503 |
+
# Main Enumeration Engine
|
| 504 |
+
# =============================================================================
|
| 505 |
+
|
| 506 |
+
def enumerate_rules(train_pairs, max_time_ms=5000):
|
| 507 |
+
"""
|
| 508 |
+
Try all rule strategies on a task. Return the first that passes
|
| 509 |
+
all training pairs, or None.
|
| 510 |
+
|
| 511 |
+
Strategies (in order):
|
| 512 |
+
1. Global enclosed fill (single color)
|
| 513 |
+
2. Per-object enclosed fill
|
| 514 |
+
3. Neighborhood rule table (CA-style)
|
| 515 |
+
4. Object predicate × action enumeration
|
| 516 |
+
"""
|
| 517 |
+
import time
|
| 518 |
+
start = time.time()
|
| 519 |
+
|
| 520 |
+
# Check same shape
|
| 521 |
+
all_same_shape = all(
|
| 522 |
+
np.array(p['input']).shape == np.array(p['output']).shape
|
| 523 |
+
for p in train_pairs
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
# === Strategy 1: Global enclosed fill ===
|
| 527 |
+
fill_color = try_global_enclosed_fill(train_pairs)
|
| 528 |
+
if fill_color is not None:
|
| 529 |
+
def rule_fn(grid, _fc=fill_color):
|
| 530 |
+
g = np.array(grid, dtype=int)
|
| 531 |
+
bg = Counter(g.flatten().tolist()).most_common(1)[0][0]
|
| 532 |
+
enclosed = get_enclosed_bg_regions(g, bg)
|
| 533 |
+
result = g.copy()
|
| 534 |
+
result[enclosed] = _fc
|
| 535 |
+
return result
|
| 536 |
+
return ('global_enclosed_fill', rule_fn)
|
| 537 |
+
|
| 538 |
+
# === Strategy 2: Per-object enclosed fill ===
|
| 539 |
+
if all_same_shape and try_per_object_enclosed_fill(train_pairs):
|
| 540 |
+
return ('per_object_enclosed_fill', apply_per_object_enclosed_fill)
|
| 541 |
+
|
| 542 |
+
# === Strategy 3: Neighborhood rule table ===
|
| 543 |
+
if all_same_shape:
|
| 544 |
+
rule_table = learn_neighborhood_rule(train_pairs)
|
| 545 |
+
if rule_table is not None:
|
| 546 |
+
# Validate
|
| 547 |
+
valid = True
|
| 548 |
+
for pair in train_pairs:
|
| 549 |
+
pred = apply_neighborhood_rule(pair['input'], rule_table)
|
| 550 |
+
if not np.array_equal(pred, np.array(pair['output'], dtype=int)):
|
| 551 |
+
valid = False
|
| 552 |
+
break
|
| 553 |
+
if valid:
|
| 554 |
+
def rule_fn(grid, _rt=rule_table):
|
| 555 |
+
return apply_neighborhood_rule(grid, _rt)
|
| 556 |
+
return ('neighborhood_rule', rule_fn)
|
| 557 |
+
|
| 558 |
+
# === Strategy 4: Object predicate × action enumeration ===
|
| 559 |
+
if all_same_shape and (time.time() - start) * 1000 < max_time_ms:
|
| 560 |
+
result = _enumerate_predicate_actions(train_pairs)
|
| 561 |
+
if result is not None:
|
| 562 |
+
return result
|
| 563 |
+
|
| 564 |
+
return None
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def _enumerate_predicate_actions(train_pairs):
|
| 568 |
+
"""Enumerate (connectivity × predicate × action) combinations."""
|
| 569 |
+
for connectivity in [4, 8]:
|
| 570 |
+
# Extract objects for all pairs
|
| 571 |
+
pair_data = []
|
| 572 |
+
for pair in train_pairs:
|
| 573 |
+
inp = np.array(pair['input'], dtype=int)
|
| 574 |
+
out = np.array(pair['output'], dtype=int)
|
| 575 |
+
objects, bg = extract_objects_multi(inp, connectivity)
|
| 576 |
+
pair_data.append((inp, out, objects, bg))
|
| 577 |
+
|
| 578 |
+
if not pair_data or not pair_data[0][2]:
|
| 579 |
+
continue
|
| 580 |
+
|
| 581 |
+
# Build predicates from first pair
|
| 582 |
+
first_objects = pair_data[0][2]
|
| 583 |
+
first_bg = pair_data[0][3]
|
| 584 |
+
predicates = _build_predicates(first_objects, first_bg)
|
| 585 |
+
actions = _build_actions(first_objects, first_bg, pair_data[0][0].shape)
|
| 586 |
+
|
| 587 |
+
# Enumerate
|
| 588 |
+
for pred_name, pred_fn in predicates.items():
|
| 589 |
+
for act_name, act_fn in actions.items():
|
| 590 |
+
# Test on all pairs
|
| 591 |
+
all_pass = True
|
| 592 |
+
for inp, out, objects, bg in pair_data:
|
| 593 |
+
if not objects:
|
| 594 |
+
all_pass = False
|
| 595 |
+
break
|
| 596 |
+
|
| 597 |
+
# Rebuild predicates for this pair's objects
|
| 598 |
+
local_preds = _build_predicates(objects, bg)
|
| 599 |
+
local_pred = local_preds.get(pred_name)
|
| 600 |
+
if local_pred is None:
|
| 601 |
+
all_pass = False
|
| 602 |
+
break
|
| 603 |
+
|
| 604 |
+
# Select objects matching predicate
|
| 605 |
+
selected = [o for o in objects if local_pred(o)]
|
| 606 |
+
if not selected and pred_name != 'always_true':
|
| 607 |
+
all_pass = False
|
| 608 |
+
break
|
| 609 |
+
|
| 610 |
+
selected_masks = [o['mask'] for o in selected]
|
| 611 |
+
|
| 612 |
+
try:
|
| 613 |
+
result = act_fn(inp, selected_masks)
|
| 614 |
+
if not np.array_equal(result, out):
|
| 615 |
+
all_pass = False
|
| 616 |
+
break
|
| 617 |
+
except Exception:
|
| 618 |
+
all_pass = False
|
| 619 |
+
break
|
| 620 |
+
|
| 621 |
+
if all_pass:
|
| 622 |
+
# Build a reusable rule function
|
| 623 |
+
def make_rule(pn, an, conn):
|
| 624 |
+
def rule_fn(grid):
|
| 625 |
+
g = np.array(grid, dtype=int)
|
| 626 |
+
objs, bg = extract_objects_multi(g, conn)
|
| 627 |
+
if not objs:
|
| 628 |
+
return g
|
| 629 |
+
preds = _build_predicates(objs, bg)
|
| 630 |
+
pred = preds.get(pn, lambda o: False)
|
| 631 |
+
acts = _build_actions(objs, bg, g.shape)
|
| 632 |
+
act = acts.get(an)
|
| 633 |
+
if act is None:
|
| 634 |
+
return g
|
| 635 |
+
selected = [o for o in objs if pred(o)]
|
| 636 |
+
masks = [o['mask'] for o in selected]
|
| 637 |
+
return act(g, masks)
|
| 638 |
+
return rule_fn
|
| 639 |
+
|
| 640 |
+
return (f'predicate_{pred_name}_action_{act_name}_conn{connectivity}',
|
| 641 |
+
make_rule(pred_name, act_name, connectivity))
|
| 642 |
+
|
| 643 |
+
return None
|