import numpy as np

from .solver_core import tile_transform


def gate_boundary_respect(phi_candidate, phi_in, phi_target, boundary_mask=None):
    """Gate A: candidate must not change locked boundary cells.

    A cell is "locked" where the boundary mask is True. At every locked cell
    the candidate must equal the target.

    Args:
        phi_candidate: Candidate grid (array-like).
        phi_in: Input grid. Used only to infer the mask when ``boundary_mask``
            is None (locked cells are those where ``phi_in != 0``).
        phi_target: Target grid. It is indexed with the candidate-shaped mask,
            so presumably it has the same shape as ``phi_candidate``.
        boundary_mask: Optional array-like mask (bool or 0/1 integers). If its
            shape differs from the candidate's, it is resized with
            ``tile_transform``, falling back to NumPy broadcasting.

    Returns:
        bool: True if no locked cell differs between candidate and target.

    Raises:
        ValueError: If the mask can neither be tiled nor broadcast to the
            candidate's shape (raised by ``np.broadcast_to``).
    """
    phi_candidate = np.asarray(phi_candidate)
    phi_target = np.asarray(phi_target)

    if boundary_mask is None:
        boundary_mask = np.asarray(phi_in) != 0

    # Always index with a genuine boolean array. An integer 0/1 mask used as an
    # index triggers NumPy fancy indexing (it selects positions 0 and 1) rather
    # than masking cells.
    mask = np.asarray(boundary_mask).astype(bool)

    if mask.shape != phi_candidate.shape:
        try:
            # tile_transform is fed an int grid; convert its result back to bool.
            mask = np.asarray(
                tile_transform(mask.astype(int), phi_candidate.shape)
            ).astype(bool)
        except Exception:
            # Best-effort fallback kept from the original design: tile_transform's
            # failure modes are not visible here, so any failure falls back to
            # broadcasting, which raises ValueError if the shapes are incompatible.
            mask = np.broadcast_to(mask, phi_candidate.shape)

    changed = np.any(phi_candidate[mask] != phi_target[mask])
    return not bool(changed)


def gate_sigma_localization(phi_candidate, phi_target, max_fraction=0.5):
    """Gate B: ensure the residue is localized.

    Computes the fraction of cells where candidate and target differ and
    accepts the candidate when that fraction is at most ``max_fraction``.

    Args:
        phi_candidate: Candidate grid (array-like).
        phi_target: Target grid (array-like, comparable elementwise with the
            candidate).
        max_fraction: Largest accepted fraction of differing cells.

    Returns:
        bool: True if the differing fraction is <= ``max_fraction``. Empty
        grids have no differing cells and therefore pass.
    """
    diff = np.asarray(np.asarray(phi_candidate) != np.asarray(phi_target))
    if diff.size == 0:
        # Avoid 0/0 -> nan, which would silently reject an empty grid.
        return True
    frac = np.count_nonzero(diff) / diff.size
    return bool(frac <= max_fraction)


def gate_quantization(phi_candidate, allowed_symbols=None, tol=1e-6):
    """Gate C: candidate values must quantize to ``allowed_symbols``.

    Every value must lie within ``tol`` of an integer, and that integer must
    be one of ``allowed_symbols`` (integers 0-9 by default).

    Args:
        phi_candidate: Candidate grid (array-like of numbers).
        allowed_symbols: Iterable of accepted symbol values. Defaults to 0..9.
        tol: Maximum absolute distance from the nearest integer that is still
            treated as numerical noise.

    Returns:
        bool: True if every value quantizes to an allowed symbol. An empty
        grid passes vacuously.
    """
    if allowed_symbols is None:
        allowed_symbols = list(range(10))

    values = np.asarray(phi_candidate, dtype=float)
    # NaN/inf never quantize to a symbol, and casting them to int is undefined.
    if not np.all(np.isfinite(values)):
        return False

    rounded = np.rint(values)
    # Without this check any value (e.g. 3.4) would silently round onto a symbol.
    near_integer = np.abs(values - rounded) <= tol
    in_alphabet = np.isin(rounded.astype(int), allowed_symbols)
    return bool(np.all(near_integer & in_alphabet))


def validate_gates(phi_candidate, phi_in, phi_target, boundary_mask=None,
                   max_fraction=0.5, allowed_symbols=None, tol=1e-6):
    """Run all gates and report each result plus an overall verdict.

    Every gate is evaluated, even when an earlier one fails, so the report
    is complete.

    Args:
        phi_candidate: Candidate grid.
        phi_in: Input grid (used to infer the boundary mask if none given).
        phi_target: Target grid.
        boundary_mask: Optional locked-cell mask for gate A.
        max_fraction: Localization threshold for gate B.
        allowed_symbols: Symbol alphabet for gate C (defaults to 0..9).
        tol: Quantization tolerance for gate C.

    Returns:
        dict: ``{"A_boundary", "B_localization", "C_quantization", "passed"}``
        mapped to booleans. ``"passed"`` is True only if all gates pass.
    """
    a = gate_boundary_respect(phi_candidate, phi_in, phi_target, boundary_mask)
    b = gate_sigma_localization(phi_candidate, phi_target, max_fraction=max_fraction)
    c = gate_quantization(phi_candidate, allowed_symbols=allowed_symbols, tol=tol)
    passed = a and b and c
    return {"A_boundary": a, "B_localization": b, "C_quantization": c, "passed": passed}