# ARC-AGI / itt_solver / gates.py
# Author: rogermt
# Commit b48dd06 ("Upload 47 files", verified)
from .solver_core import tile_transform
import numpy as np
def gate_boundary_respect(phi_candidate, phi_in, phi_target, boundary_mask=None):
    """Gate A: candidate must agree with the target on locked boundary cells.

    A cell is "locked" where the boundary mask is truthy. When
    ``boundary_mask`` is None the mask is inferred as ``phi_in != 0``. When
    the mask shape differs from the candidate shape it is resized to the
    candidate shape with ``tile_transform``; if that raises, plain NumPy
    broadcasting is tried instead.

    Args:
        phi_candidate: Candidate grid (array-like).
        phi_in: Input grid; only used to infer the mask when none is given.
        phi_target: Target grid; expected to have the candidate's shape.
        boundary_mask: Optional array-like of locked-cell flags. Any dtype is
            accepted and interpreted by truthiness.

    Returns:
        True if ``phi_candidate`` equals ``phi_target`` at every locked cell
        (trivially True when no cell is locked), otherwise False.

    Raises:
        ValueError: If the mask can be neither tiled nor broadcast to the
            candidate shape.
        IndexError: If the (resized) mask does not match the shape of
            ``phi_candidate`` or ``phi_target`` (NumPy boolean indexing).
    """
    candidate = np.asarray(phi_candidate)
    target = np.asarray(phi_target)
    if boundary_mask is None:
        boundary_mask = np.asarray(phi_in) != 0
    # Coerce to bool up front: an integer 0/1 mask would otherwise be treated
    # as integer (fancy) indices below and select whole rows, not cells.
    mask = np.asarray(boundary_mask, dtype=bool)
    if mask.shape != candidate.shape:
        try:
            # Round-trip through int for tile_transform, then back to bool.
            # NOTE(review): broad except kept because tile_transform's failure
            # modes are not visible from this module.
            mask = np.asarray(
                tile_transform(mask.astype(int), candidate.shape)
            ).astype(bool)
        except Exception:
            # Fallback: broadcast along axes if possible (raises ValueError
            # when the shapes are incompatible).
            mask = np.broadcast_to(mask, candidate.shape)
    # Reject if any locked cell of the candidate differs from the target.
    return not np.any(candidate[mask] != target[mask])
def gate_sigma_localization(phi_candidate, phi_target, max_fraction=0.5):
    """Gate B: ensure the residue (candidate/target mismatch) is localized.

    Computes the fraction of cells where the candidate differs from the
    target and requires it to be at most ``max_fraction``.

    Args:
        phi_candidate: Candidate grid (array-like).
        phi_target: Target grid (array-like), compared elementwise under
            NumPy broadcasting rules.
        max_fraction: Largest accepted fraction of differing cells.

    Returns:
        True if the differing fraction is ``<= max_fraction``, else False.
        Empty grids have no differing cells and count as a fraction of 0.0.
    """
    diff = np.asarray(np.asarray(phi_candidate) != np.asarray(phi_target))
    # Guard the 0/0 division for empty grids (would yield NaN and always fail).
    frac = np.count_nonzero(diff) / diff.size if diff.size else 0.0
    return bool(frac <= max_fraction)
def gate_quantization(phi_candidate, allowed_symbols=None, tol=1e-6):
    """Gate C: candidate values must quantize to ``allowed_symbols``.

    Each value must lie within ``tol`` of an integer, and that integer must
    be one of ``allowed_symbols`` (the integers 0-9 by default). NaN and
    infinite values never quantize.

    Args:
        phi_candidate: Candidate grid (array-like of numbers).
        allowed_symbols: Iterable of permitted integer symbols; None means
            ``range(10)``.
        tol: Maximum absolute distance from the nearest integer.

    Returns:
        True if every value quantizes to an allowed symbol (vacuously True
        for an empty grid), otherwise False.
    """
    if allowed_symbols is None:
        allowed_symbols = list(range(10))
    values = np.asarray(phi_candidate)
    # Reject NaN/inf before rounding: casting them to int is not meaningful.
    if not np.all(np.isfinite(values)):
        return False
    nearest = np.rint(values)
    # Enforce the documented tolerance; without it any value within 0.5 of
    # an allowed symbol (e.g. 3.4) would silently pass.
    if np.any(np.abs(values - nearest) > tol):
        return False
    return bool(np.all(np.isin(nearest.astype(int), allowed_symbols)))
def validate_gates(
    phi_candidate,
    phi_in,
    phi_target,
    boundary_mask=None,
    max_fraction=0.5,
    allowed_symbols=None,
    tol=1e-6,
):
    """Run gates A (boundary), B (localization) and C (quantization).

    All three gates are always evaluated, so the result reports each one
    even when an earlier gate fails.

    Args:
        phi_candidate: Candidate grid; checked by all three gates.
        phi_in: Input grid; forwarded to ``gate_boundary_respect``.
        phi_target: Target grid; used by gates A and B.
        boundary_mask: Optional locked-cell mask for ``gate_boundary_respect``.
        max_fraction: Forwarded to ``gate_sigma_localization``.
        allowed_symbols: Forwarded to ``gate_quantization``.
        tol: Forwarded to ``gate_quantization``; defaults to that gate's own
            default, so existing calls behave the same.

    Returns:
        dict with plain-bool values under "A_boundary", "B_localization",
        "C_quantization", and "passed" (True only if all three pass).
    """
    # bool() normalizes np.bool_ results so the dict is JSON-serializable.
    a = bool(gate_boundary_respect(phi_candidate, phi_in, phi_target, boundary_mask))
    b = bool(gate_sigma_localization(phi_candidate, phi_target, max_fraction=max_fraction))
    c = bool(gate_quantization(phi_candidate, allowed_symbols=allowed_symbols, tol=tol))
    return {"A_boundary": a, "B_localization": b, "C_quantization": c, "passed": a and b and c}