# Source: ARC-AGI / itt_solver / predicate_engine.py (uploaded by rogermt, commit 8cf6976, 22.4 kB)
# Commit message: Add predicate enumeration engine: neighborhood rules + enclosed fill
#                 + object predicates — 70/400 (17.5%)
"""
Object Predicate + Action Enumeration Engine
=============================================
For each ARC task, enumerate combinations of:
(object_abstraction) x (predicate) x (action)
Test each rule against ALL training pairs. If a rule produces
exact output for every pair, use it.
This is the GPAR approach simplified to pure Python — no PDDL,
no planner. Just brute-force enumeration of ~600 rule templates.
Covers:
- Fill miss (35% of unsolved): enclosed_by, neighbor_count conditions
- Recolor miss (24%): object attribute conditions (size, color, position)
- Shape change (25%): extract by predicate
"""
import numpy as np
from collections import Counter, deque
from typing import Dict, List, Tuple, Optional, Set, Callable
# =============================================================================
# Object Extraction (robust, multiple abstractions)
# =============================================================================
def _flood(grid, r, c, visited, color, connectivity=4):
    """Collect the connected component of ``color`` reachable from ``(r, c)``.

    Performs a breadth-first search over ``grid``. ``visited`` is updated in
    place: the start cell and every cell added to the component are marked
    ``True``. The start cell is always part of the result, whatever its value.

    Args:
        grid: 2-D array of cell colors.
        r, c: Row and column of the start cell.
        visited: 2-D boolean array of the same shape as ``grid`` (mutated).
        color: Color a neighbor must have to join the component.
        connectivity: ``8`` also follows diagonal neighbors; any other value
            uses the four orthogonal neighbors only.

    Returns:
        A set of ``(row, col)`` tuples belonging to the component.
    """
    n_rows, n_cols = grid.shape
    steps = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    if connectivity == 8:
        steps.extend([(-1, -1), (-1, 1), (1, -1), (1, 1)])

    component = set()
    visited[r, c] = True
    frontier = deque([(r, c)])
    while frontier:
        row, col = frontier.popleft()
        component.add((row, col))
        for d_row, d_col in steps:
            nxt_r = row + d_row
            nxt_c = col + d_col
            if not (0 <= nxt_r < n_rows and 0 <= nxt_c < n_cols):
                continue
            if visited[nxt_r, nxt_c] or grid[nxt_r, nxt_c] != color:
                continue
            # Mark on enqueue so a cell is never queued twice.
            visited[nxt_r, nxt_c] = True
            frontier.append((nxt_r, nxt_c))
    return component
def extract_objects_multi(grid, connectivity=4):
    """Split a grid into single-color connected components.

    The background is the most frequent value in the grid; every maximal
    same-color component of non-background cells becomes one object.

    Args:
        grid: 2-D grid (anything ``np.array(..., dtype=int)`` accepts).
        connectivity: Passed through to ``_flood`` (4 or 8).

    Returns:
        A tuple ``(objects, bg)``. ``objects`` is a list of dicts, in row-major
        order of each component's first cell, with keys ``color``, ``cells``
        (set of ``(row, col)``), ``mask`` (boolean array shaped like the grid),
        ``bbox`` (``(rmin, cmin, rmax, cmax)``, inclusive), ``size``,
        ``touches_border``, ``height``, ``width``, ``center_r`` and
        ``center_c``. ``bg`` is the background color.
    """
    grid = np.array(grid, dtype=int)
    n_rows, n_cols = grid.shape
    bg = Counter(grid.flatten().tolist()).most_common(1)[0][0]
    visited = np.zeros((n_rows, n_cols), dtype=bool)

    objects = []
    for r in range(n_rows):
        for c in range(n_cols):
            if visited[r, c] or grid[r, c] == bg:
                visited[r, c] = True
                continue

            color = int(grid[r, c])
            # _flood always returns at least the start cell.
            cells = _flood(grid, r, c, visited, color, connectivity)
            rows, cols = zip(*cells)
            top, bottom = min(rows), max(rows)
            left, right = min(cols), max(cols)

            mask = np.zeros((n_rows, n_cols), dtype=bool)
            mask[list(rows), list(cols)] = True

            # Some cell lies on the border exactly when the bbox reaches it.
            on_border = (
                top == 0 or bottom == n_rows - 1
                or left == 0 or right == n_cols - 1
            )

            objects.append({
                'color': color,
                'cells': cells,
                'mask': mask,
                'bbox': (top, left, bottom, right),
                'size': len(cells),
                'touches_border': on_border,
                'height': bottom - top + 1,
                'width': right - left + 1,
                'center_r': sum(rows) / len(rows),
                'center_c': sum(cols) / len(cols),
            })
    return objects, bg
def get_enclosed_bg_regions(grid, bg):
    """Return a mask of background cells that cannot reach the grid border.

    Every border cell equal to ``bg`` seeds a 4-connected flood through
    ``bg`` cells; whatever background the flood never reaches is enclosed.

    Args:
        grid: 2-D grid (anything ``np.array(..., dtype=int)`` accepts).
        bg: Background color.

    Returns:
        Boolean array shaped like ``grid``; ``True`` marks enclosed ``bg`` cells.
    """
    grid = np.array(grid, dtype=int)
    n_rows, n_cols = grid.shape
    reached = np.zeros((n_rows, n_cols), dtype=bool)

    # Seed the search with all background cells lying on the outer frame.
    seeds = [
        (r, c)
        for r in range(n_rows)
        for c in range(n_cols)
        if (r in (0, n_rows - 1) or c in (0, n_cols - 1)) and grid[r, c] == bg
    ]
    for r, c in seeds:
        reached[r, c] = True

    frontier = deque(seeds)
    while frontier:
        r, c = frontier.popleft()
        for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if not (0 <= nr < n_rows and 0 <= nc < n_cols):
                continue
            if reached[nr, nc] or grid[nr, nc] != bg:
                continue
            reached[nr, nc] = True
            frontier.append((nr, nc))

    return (grid == bg) & ~reached
def get_neighbor_colors(grid, r, c, bg=0):
    """List the colors of the non-background 4-neighbors of ``(r, c)``.

    Args:
        grid: 2-D array exposing ``.shape`` and ``grid[row, col]`` indexing.
        r, c: Row and column of the cell whose neighbors are inspected.
        bg: Color treated as background and left out of the result.

    Returns:
        Python ints in the order up, down, left, right; neighbors that fall
        outside the grid or equal ``bg`` are omitted.
    """
    n_rows, n_cols = grid.shape
    candidates = ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1))
    return [
        int(grid[nr, nc])
        for nr, nc in candidates
        if 0 <= nr < n_rows and 0 <= nc < n_cols and grid[nr, nc] != bg
    ]
# =============================================================================
# Object Predicates
# =============================================================================
def _build_predicates(objects, bg):
    """Create named boolean tests over the object dicts of one grid.

    Thresholds (largest/smallest size, most/least common color, median size)
    are computed from ``objects`` themselves, so the same predicate name can
    select different things on different grids.

    Args:
        objects: List of object dicts providing ``size``, ``color`` and
            ``touches_border``.
        bg: Background color (accepted for interface symmetry; not used).

    Returns:
        Dict mapping predicate name to a one-argument callable. Empty when
        ``objects`` is empty. The median predicates exist only when at least
        two distinct sizes are present.
    """
    if not objects:
        return {}

    def field_equals(key, expected):
        # Factory so that each predicate captures its own expected value.
        return lambda obj: obj[key] == expected

    sizes = [obj['size'] for obj in objects]
    colors_list = [obj['color'] for obj in objects]
    ranked_colors = Counter(colors_list).most_common()
    commonest = ranked_colors[0][0]
    rarest = ranked_colors[-1][0]

    predicates = {}
    predicates['is_largest'] = field_equals('size', max(sizes))
    predicates['is_smallest'] = field_equals('size', min(sizes))
    predicates['touches_border'] = lambda obj: obj['touches_border']
    predicates['not_touches_border'] = lambda obj: not obj['touches_border']
    predicates['is_most_common_color'] = field_equals('color', commonest)
    predicates['is_least_common_color'] = field_equals('color', rarest)
    predicates['always_true'] = lambda obj: True

    # One predicate per color that actually occurs among the objects.
    for color in set(colors_list):
        predicates[f'color_is_{color}'] = field_equals('color', color)

    # Median split only makes sense when sizes differ.
    if len(set(sizes)) > 1:
        median_size = sorted(sizes)[len(sizes) // 2]
        predicates['size_above_median'] = lambda obj: obj['size'] > median_size
        predicates['size_below_median'] = lambda obj: obj['size'] < median_size

    return predicates
# =============================================================================
# Actions
# =============================================================================
def _build_actions(objects, bg, grid_shape):
    """Create the named grid transformations available for one background.

    Every action is a callable ``action(grid, selected_masks)`` returning a
    new grid; ``selected_masks`` is a list of boolean masks, one per object
    chosen by a predicate.

    Args:
        objects: Object dicts of the grid (accepted for interface
            compatibility; not used).
        bg: Background color. It is bound into the fill/remove/extract actions
            and no ``recolor_to_<bg>`` action is generated.
        grid_shape: Shape of the grid (accepted for interface compatibility;
            not used).

    Returns:
        Dict mapping action name to callable, in insertion order:
        ``recolor_to_<c>`` for each ``c`` in 0..9 except ``bg``, then
        ``fill_enclosed``, ``fill_interior``, ``remove`` and ``extract``.
    """
    def make_recolor(target_color):
        # Factory so that each action captures its own target color.
        return lambda grid, selected_masks: _apply_recolor(
            grid, selected_masks, target_color)

    actions = {}

    # Recolor: paint the selected objects with a fixed color.
    for target_color in range(10):
        if target_color == bg:
            continue
        actions[f'recolor_to_{target_color}'] = make_recolor(target_color)

    # Fill background cells enclosed by each selected object.
    actions['fill_enclosed'] = (
        lambda grid, selected_masks: _apply_fill_enclosed(grid, selected_masks, bg))
    # Fill background cells inside each selected object's bounding box.
    actions['fill_interior'] = (
        lambda grid, selected_masks: _apply_fill_interior(grid, selected_masks, bg))
    # Erase the selected objects (set them to background).
    actions['remove'] = (
        lambda grid, selected_masks: _apply_remove(grid, selected_masks, bg))
    # Keep only the selected objects, clearing everything else.
    actions['extract'] = (
        lambda grid, selected_masks: _apply_extract(grid, selected_masks, bg))
    return actions
def _apply_recolor(grid, selected_masks, target_color):
    """Return a copy of ``grid`` with every selected mask painted ``target_color``.

    ``grid`` itself is not modified.
    """
    painted = grid.copy()
    for region in selected_masks:
        painted[region] = target_color
    return painted
def _apply_fill_enclosed(grid, selected_masks, bg):
    """Fill background cells that each selected object surrounds.

    A background cell counts as "inside" an object when, looking along its
    own row and column, a cell of that object is found in each of the four
    cardinal directions (a ray test, not a true flood fill). Such cells are
    painted with the object's color, taken from the first masked cell of
    ``grid``.

    Objects of color 0 are handled like any other color; only empty masks are
    skipped.

    Args:
        grid: 2-D integer array (not modified).
        selected_masks: Boolean masks shaped like ``grid``, one per object.
        bg: Background color; only cells currently equal to ``bg`` are filled.

    Returns:
        A new array with the enclosed background cells filled.
    """
    result = grid.copy()
    for mask in selected_masks:
        rows, cols = np.where(mask)
        if rows.size == 0:
            continue  # empty selection: nothing to fill
        color = int(grid[mask][0])
        top, bottom = int(rows.min()), int(rows.max())
        left, right = int(cols.min()), int(cols.max())

        # Only cells inside the bounding box can be hit from all four sides.
        for r in range(top, bottom + 1):
            for c in range(left, right + 1):
                if result[r, c] != bg:
                    continue
                enclosed = (
                    mask[:r, c].any()           # object cell above
                    and mask[r + 1:, c].any()   # below
                    and mask[r, :c].any()       # to the left
                    and mask[r, c + 1:].any()   # to the right
                )
                if enclosed:
                    result[r, c] = color
    return result
def _apply_fill_interior(grid, selected_masks, bg):
    """Fill background cells inside each selected object's bounding box.

    For every mask, all cells of its bounding box that currently equal ``bg``
    are painted with the object's color (taken from the first masked cell of
    ``grid``). Cells of other colors inside the box are left untouched.

    Objects of color 0 are handled like any other color; only empty masks are
    skipped.

    Args:
        grid: 2-D integer array (not modified).
        selected_masks: Boolean masks shaped like ``grid``, one per object.
        bg: Background color.

    Returns:
        A new array with the bounding-box background cells filled.
    """
    result = grid.copy()
    for mask in selected_masks:
        rows, cols = np.where(mask)
        if rows.size == 0:
            continue  # empty selection: nothing to fill
        color = int(grid[mask][0])
        # Basic slicing yields a view, so writing to ``box`` updates ``result``.
        box = result[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
        box[box == bg] = color
    return result
def _apply_remove(grid, selected_masks, bg):
    """Return a copy of ``grid`` with every selected mask reset to ``bg``.

    ``grid`` itself is not modified.
    """
    cleared = grid.copy()
    for region in selected_masks:
        cleared[region] = bg
    return cleared
def _apply_extract(grid, selected_masks, bg):
    """Return a ``bg``-filled grid onto which only the selected masks are copied.

    The result has the same shape and dtype as ``grid``; cells outside every
    mask are ``bg``. ``grid`` itself is not modified.
    """
    canvas = np.full_like(grid, bg)
    for region in selected_masks:
        canvas[region] = grid[region]
    return canvas
# =============================================================================
# Neighborhood Rule Table (CA-style)
# =============================================================================
def learn_neighborhood_rule(train_pairs):
    """Learn a cellular-automaton style lookup table from training pairs.

    Each input cell is described by the key
    ``(center_color, sorted 4-neighbor colors)``, where a neighbor outside the
    grid is recorded as ``-1``. The table maps that key to the color of the
    same cell in the output.

    Args:
        train_pairs: Iterable of dicts with ``'input'`` and ``'output'`` grids.

    Returns:
        The lookup dict, or ``None`` when any pair changes shape or when two
        cells with the same key require different output colors.
    """
    # The rule is only defined for tasks that preserve the grid shape.
    if any(
        np.array(pair['input']).shape != np.array(pair['output']).shape
        for pair in train_pairs
    ):
        return None

    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    table = {}
    for pair in train_pairs:
        src = np.array(pair['input'], dtype=int)
        dst = np.array(pair['output'], dtype=int)
        n_rows, n_cols = src.shape
        for r in range(n_rows):
            for c in range(n_cols):
                around = sorted(
                    int(src[r + dr, c + dc])
                    if 0 <= r + dr < n_rows and 0 <= c + dc < n_cols
                    else -1  # border sentinel
                    for dr, dc in offsets
                )
                key = (int(src[r, c]), tuple(around))
                target = int(dst[r, c])
                # First sighting records the target; later ones must agree.
                if table.setdefault(key, target) != target:
                    return None
    return table
def apply_neighborhood_rule(grid, rule_table):
    """Rewrite a grid using a table learned by ``learn_neighborhood_rule``.

    For every cell the key ``(center_color, sorted 4-neighbor colors)`` is
    built from the ORIGINAL grid (out-of-grid neighbors count as ``-1``).
    Cells whose key is in ``rule_table`` take the mapped color; all other
    cells keep their value.

    Args:
        grid: 2-D grid (anything ``np.array(..., dtype=int)`` accepts).
        rule_table: Dict mapping keys to output colors.

    Returns:
        A new integer array of the same shape as ``grid``.
    """
    grid = np.array(grid, dtype=int)
    n_rows, n_cols = grid.shape
    # A one-cell frame of -1 stands in for the border sentinel, so neighbor
    # lookups need no bounds checks.
    padded = np.pad(grid, 1, mode='constant', constant_values=-1)
    rewritten = grid.copy()
    for r in range(n_rows):
        for c in range(n_cols):
            pr, pc = r + 1, c + 1  # position of (r, c) inside ``padded``
            around = sorted(
                int(value)
                for value in (
                    padded[pr - 1, pc],
                    padded[pr + 1, pc],
                    padded[pr, pc - 1],
                    padded[pr, pc + 1],
                )
            )
            key = (int(grid[r, c]), tuple(around))
            if key in rule_table:
                rewritten[r, c] = rule_table[key]
    return rewritten
# =============================================================================
# Global Fill Rules (not object-specific)
# =============================================================================
def try_global_enclosed_fill(train_pairs):
    """Test the rule "paint every enclosed background region with one color".

    The fill color is read off the training outputs: within the enclosed
    background cells of each input, the output must show at most one
    non-background color, and that color must be the same across all pairs.
    The candidate is then verified by reproducing every training output
    exactly.

    Args:
        train_pairs: Iterable of dicts with ``'input'`` and ``'output'`` grids.

    Returns:
        The fill color as an int, or ``None`` when the rule does not explain
        all pairs (shape change, no enclosed region filled anywhere, several
        fill colors, or a mismatch during verification).
    """
    prepared = []  # (input, output, enclosed-mask) per pair, reused for validation
    learned = set()
    for pair in train_pairs:
        src = np.array(pair['input'], dtype=int)
        dst = np.array(pair['output'], dtype=int)
        if src.shape != dst.shape:
            return None
        bg = Counter(src.flatten().tolist()).most_common(1)[0][0]
        enclosed = get_enclosed_bg_regions(src, bg)
        prepared.append((src, dst, enclosed))
        if not np.any(enclosed):
            continue

        # Which non-background colors appear where the input was enclosed?
        seen = np.unique(dst[enclosed])
        candidates = seen[seen != bg]
        if len(candidates) > 1:
            return None  # several fill colors in one pair: out of scope here
        if len(candidates) == 1:
            learned.add(int(candidates[0]))

    # Exactly one fill color must have been observed over all pairs.
    if len(learned) != 1:
        return None
    fill_color = next(iter(learned))

    # Verify: filling the enclosed cells must reproduce each output exactly.
    for src, dst, enclosed in prepared:
        candidate = src.copy()
        candidate[enclosed] = fill_color
        if not np.array_equal(candidate, dst):
            return None
    return fill_color
def try_per_object_enclosed_fill(train_pairs):
    """Check whether "fill each object's interior with its own color" fits.

    Validation runs ``apply_per_object_enclosed_fill`` itself, so the rule
    that is checked here is exactly the rule that is later applied (the logic
    is not duplicated and cannot drift apart).

    Args:
        train_pairs: Iterable of dicts with ``'input'`` and ``'output'`` grids.

    Returns:
        ``True`` when every pair keeps its shape and the per-object fill of the
        input equals the output exactly; ``False`` otherwise. An empty
        ``train_pairs`` yields ``True``.
    """
    for pair in train_pairs:
        inp = np.array(pair['input'], dtype=int)
        out = np.array(pair['output'], dtype=int)
        if inp.shape != out.shape:
            return False
        if not np.array_equal(apply_per_object_enclosed_fill(inp), out):
            return False
    return True
def apply_per_object_enclosed_fill(grid):
    """Fill the background cells each object surrounds with that object's color.

    Objects are the 4-connected components from ``extract_objects_multi``. A
    background cell inside an object's bounding box is filled when a cell of
    that object lies in each of the four cardinal directions along its row
    and column (a ray test). Objects are processed in extraction order, and a
    cell filled for one object is no longer background for the next.

    Args:
        grid: 2-D grid (anything ``np.array(..., dtype=int)`` accepts).

    Returns:
        A new integer array with the enclosed cells filled.
    """
    grid = np.array(grid, dtype=int)
    objects, bg = extract_objects_multi(grid, connectivity=4)
    filled = grid.copy()
    for obj in objects:
        shape_mask = obj['mask']
        top, left, bottom, right = obj['bbox']
        for r in range(top, bottom + 1):
            for c in range(left, right + 1):
                if filled[r, c] != bg:
                    continue
                surrounded = (
                    shape_mask[:r, c].any()          # object cell above
                    and shape_mask[r + 1:, c].any()  # below
                    and shape_mask[r, :c].any()      # to the left
                    and shape_mask[r, c + 1:].any()  # to the right
                )
                if surrounded:
                    filled[r, c] = obj['color']
    return filled
# =============================================================================
# Main Enumeration Engine
# =============================================================================
def enumerate_rules(train_pairs, max_time_ms=5000):
    """Try each rule strategy in turn and return the first that fits.

    Strategies, in order:
        1. Global enclosed fill (single color)
        2. Per-object enclosed fill
        3. Neighborhood rule table (CA-style)
        4. Object predicate x action enumeration

    Args:
        train_pairs: List of dicts with ``'input'`` and ``'output'`` grids.
        max_time_ms: Time budget in milliseconds. It is checked once, before
            strategy 4 starts; strategies 1-3 always run.

    Returns:
        ``(name, rule_fn)`` where ``rule_fn(grid)`` applies the rule to a new
        grid, or ``None`` when no strategy reproduces every training output.
    """
    import time  # local import: the module has no file-level ``time`` import

    # Monotonic clock: elapsed time must not be affected by wall-clock changes.
    started = time.monotonic()

    # Strategies 2-4 only make sense when no pair changes the grid shape.
    all_same_shape = all(
        np.array(p['input']).shape == np.array(p['output']).shape
        for p in train_pairs
    )

    # === Strategy 1: Global enclosed fill ===
    # (Performs its own shape check and returns None on a mismatch.)
    fill_color = try_global_enclosed_fill(train_pairs)
    if fill_color is not None:
        def rule_fn(grid, _fc=fill_color):
            g = np.array(grid, dtype=int)
            bg = Counter(g.flatten().tolist()).most_common(1)[0][0]
            enclosed = get_enclosed_bg_regions(g, bg)
            result = g.copy()
            result[enclosed] = _fc
            return result
        return ('global_enclosed_fill', rule_fn)

    # === Strategy 2: Per-object enclosed fill ===
    if all_same_shape and try_per_object_enclosed_fill(train_pairs):
        return ('per_object_enclosed_fill', apply_per_object_enclosed_fill)

    # === Strategy 3: Neighborhood rule table ===
    if all_same_shape:
        rule_table = learn_neighborhood_rule(train_pairs)
        if rule_table is not None:
            # Re-apply the table to every training input as a sanity check.
            valid = all(
                np.array_equal(
                    apply_neighborhood_rule(pair['input'], rule_table),
                    np.array(pair['output'], dtype=int),
                )
                for pair in train_pairs
            )
            if valid:
                def rule_fn(grid, _rt=rule_table):
                    return apply_neighborhood_rule(grid, _rt)
                return ('neighborhood_rule', rule_fn)

    # === Strategy 4: Object predicate x action enumeration ===
    elapsed_ms = (time.monotonic() - started) * 1000
    if all_same_shape and elapsed_ms < max_time_ms:
        result = _enumerate_predicate_actions(train_pairs)
        if result is not None:
            return result

    return None
def _enumerate_predicate_actions(train_pairs):
    """Search (connectivity x predicate x action) for a rule fitting all pairs.

    For connectivity 4 and then 8, objects are extracted from every training
    input. Candidate predicate and action NAMES come from the first pair; each
    candidate is then checked on every pair using predicates and actions built
    from THAT pair's own objects and background. The returned rule rebuilds
    them the same way at inference, so validation and inference agree.

    Args:
        train_pairs: List of dicts with ``'input'`` and ``'output'`` grids of
            equal shape.

    Returns:
        ``(rule_name, rule_fn)`` for the first matching combination, or
        ``None`` when nothing matches.
    """
    for connectivity in [4, 8]:
        pair_data = []
        for pair in train_pairs:
            inp = np.array(pair['input'], dtype=int)
            out = np.array(pair['output'], dtype=int)
            objects, bg = extract_objects_multi(inp, connectivity)
            pair_data.append((inp, out, objects, bg))

        # A pair without objects can never satisfy an object-based rule, so
        # this connectivity cannot yield a result.
        if not pair_data or any(not objects for _, _, objects, _ in pair_data):
            continue

        # Build per-pair predicates and actions once, outside the
        # predicate x action loop.
        prepared = [
            (inp, out, objects,
             _build_predicates(objects, bg),
             _build_actions(objects, bg, inp.shape))
            for inp, out, objects, bg in pair_data
        ]

        # Candidate names (and their enumeration order) come from the first pair.
        first_predicates = prepared[0][3]
        first_actions = prepared[0][4]
        for pred_name in first_predicates:
            for act_name in first_actions:
                if _rule_matches_all_pairs(prepared, pred_name, act_name):
                    return (
                        f'predicate_{pred_name}_action_{act_name}_conn{connectivity}',
                        _make_predicate_action_rule(pred_name, act_name, connectivity),
                    )
    return None


def _rule_matches_all_pairs(prepared, pred_name, act_name):
    """Return True if predicate + action reproduces every training output."""
    for inp, out, objects, predicates, actions in prepared:
        predicate = predicates.get(pred_name)
        action = actions.get(act_name)
        # The predicate/action may not exist for this pair (e.g. a color that
        # is absent, or a recolor target equal to this pair's background).
        if predicate is None or action is None:
            return False

        selected = [o for o in objects if predicate(o)]
        if not selected and pred_name != 'always_true':
            return False

        masks = [o['mask'] for o in selected]
        try:
            produced = action(inp, masks)
        except Exception:
            # Deliberate best-effort: a candidate that cannot be applied is
            # treated as "does not fit" rather than aborting the search.
            return False
        if not np.array_equal(produced, out):
            return False
    return True


def _make_predicate_action_rule(pred_name, act_name, connectivity):
    """Build a reusable ``rule_fn(grid)`` for a predicate/action name pair."""
    def rule_fn(grid):
        g = np.array(grid, dtype=int)
        objs, bg = extract_objects_multi(g, connectivity)
        if not objs:
            return g
        # An unknown predicate selects nothing.
        predicate = _build_predicates(objs, bg).get(pred_name, lambda o: False)
        action = _build_actions(objs, bg, g.shape).get(act_name)
        if action is None:
            return g
        masks = [o['mask'] for o in objs if predicate(o)]
        return action(g, masks)
    return rule_fn