# ARC-AGI / pemf / tests / test_transforms.py
# Roger MT
# move files into pemf folder
# feb08d1
"""
Unit tests for all transforms in itt_solver.transforms.
Usage:
python tests/test_transforms.py
40 tests covering: Kronecker, mirror tiles, upscale, downscale, stack,
rotate, reflect, color ops, gravity, crop, transpose, shifted tile,
fill enclosed.
"""
import numpy as np
from itt_solver import transforms as tr
# Shared 3x3 fixture grid (float dtype) used by most sections below.
INP = np.array(
    [
        [0, 7, 7],
        [7, 7, 7],
        [0, 7, 7],
    ],
    dtype=float,
)

# Running tallies, incremented by check() and printed in the final summary.
tests_passed = tests_failed = 0
def check(name, condition):
    """Print a pass/fail line for one named assertion and bump the tally.

    Args:
        name: Label printed next to the status icon.
        condition: Any value; it is judged by ordinary truthiness
            (a Python bool, a NumPy bool scalar, ...).

    Side effects:
        Increments the module-level ``tests_passed`` or ``tests_failed``.
    """
    global tests_passed, tests_failed
    if not condition:
        print(f" ❌ {name}")
        tests_failed += 1
        return
    print(f" ✅ {name}")
    tests_passed += 1
print("=== Kronecker Self-Similar ===")
T = tr.KroneckerSelfSimilar()
out = T.apply(INP)
check("Output shape is 9x9", out.shape == (9, 9))
# Reference output: kron of the nonzero mask with INP puts a full copy of INP
# wherever INP is nonzero and an all-zero 3x3 block wherever it is zero.
nonzero_mask = (INP != 0).astype(float)
check("σ=0 vs known target", np.array_equal(out, np.kron(nonzero_mask, INP)))

print("\n=== KroneckerSelfSimilarInv ===")
T = tr.KroneckerSelfSimilarInv()
out = T.apply(INP)
# Only the output footprint is asserted for the inverse variant.
check("Output shape is 9x9", out.shape == (9, 9))
print("\n=== MirrorTileH ===")
T = tr.MirrorTileH()
out = T.apply(INP)
check("Shape is 3x6", out.shape == (3, 6))
# Columns 0-2 should hold the grid itself, columns 3-5 its left-right mirror.
check("Left half is input", np.array_equal(out[:, :3], INP))
check("Right half is fliplr(input)", np.array_equal(out[:, 3:], INP[:, ::-1]))

print("\n=== MirrorTileV ===")
T = tr.MirrorTileV()
out = T.apply(INP)
check("Shape is 6x3", out.shape == (6, 3))
# Rows 0-2 should hold the grid itself, rows 3-5 its top-bottom mirror.
check("Top half is input", np.array_equal(out[:3, :], INP))
check("Bottom half is flipud(input)", np.array_equal(out[3:, :], INP[::-1, :]))

print("\n=== MirrorTile4Way ===")
T = tr.MirrorTile4Way()
out = T.apply(INP)
# Only the 6x6 footprint is asserted; quadrant contents are not inspected.
check("Shape is 6x6", out.shape == (6, 6))
# Upscale(k) is expected to do nearest-neighbour replication: every cell of
# the input becomes a k x k block, i.e. the output equals
# np.kron(INP, ones((k, k))) -- the same construction the Downscale section
# below uses to build its "upscaled" input.
for factor in (2, 3):
    print(f"\n=== Upscale {factor}x ===")
    T = tr.Upscale(factor)
    out = T.apply(INP)
    side = INP.shape[0] * factor
    check(f"Shape is {side}x{side}", out.shape == (side, side))
    # INP[0, 0] is 0, so checking only the top-left block would also pass for
    # an all-zero output; compare every block against its source cell.
    check(
        f"Every {factor}x{factor} block equals its source cell",
        np.array_equal(out, np.kron(INP, np.ones((factor, factor)))),
    )

print("\n=== Downscale 2x ===")
T = tr.Downscale(2)
# Each cell of INP blown up into a 2x2 block; downscaling must undo this.
big = np.kron(INP, np.ones((2, 2)))
out = T.apply(big)
check("Downscale of upscaled recovers original", np.array_equal(out, INP))
print("\n=== StackH 3 ===")
T = tr.StackH(3)
out = T.apply(INP)
check("Shape is 3x9", out.shape == (3, 9))
# Only the leading copy is pinned; the other two thirds are not inspected.
check("First third is input", np.array_equal(out[:, :INP.shape[1]], INP))

print("\n=== StackV 3 ===")
T = tr.StackV(3)
out = T.apply(INP)
check("Shape is 9x3", out.shape == (9, 3))
# Only the leading copy is pinned; the other two thirds are not inspected.
check("First third is input", np.array_equal(out[:INP.shape[0], :], INP))
print("\n=== Rotate 90/180/270 ===")
# k counts counter-clockwise quarter turns, matching np.rot90's convention.
for k in range(1, 4):
    T = tr.Rotate(k)
    out = T.apply(INP)
    expected = np.rot90(INP, k)
    check(f"Rotate_{90*k} matches np.rot90", np.array_equal(out, expected))

print("\n=== Reflect h/v ===")
# As asserted below, Reflect('h') reverses row order (flipud) and
# Reflect('v') reverses column order (fliplr).
T = tr.Reflect('h')
check("Reflect_h matches flipud", np.array_equal(T.apply(INP), INP[::-1, :]))
T = tr.Reflect('v')
check("Reflect_v matches fliplr", np.array_equal(T.apply(INP), INP[:, ::-1]))
# Boolean masks over the fixture grid, shared by the three colour sections.
seven_mask = INP == 7
zero_mask = INP == 0

print("\n=== RetainColor ===")
T = tr.RetainColor(7)
out = T.apply(INP)
check("Only 7s remain", np.all(out[seven_mask] == 7))
check("Non-7 positions are 0", np.all(out[~seven_mask] == 0))

print("\n=== RemoveColor ===")
T = tr.RemoveColor(7)
out = T.apply(INP)
check("7s are removed", np.all(out[seven_mask] == 0))
check("0s stay 0", np.all(out[zero_mask] == 0))

print("\n=== InvertColors ===")
T = tr.InvertColors()
out = T.apply(INP)
# INP contains only 0 and 7, so inversion should swap exactly those two.
check("0→7 swap", np.all(out[zero_mask] == 7))
check("7→0 swap", np.all(out[seven_mask] == 0))
print("\n=== GravityDown ===")
T = tr.GravityDown()
# One 7 per column, each at a different height: column 0's 7 already rests on
# the bottom row, column 1's starts at the top, column 2's in the middle.
# (col_in is reused by the GravityUp section below.)
col_in = np.array([[0,7,0],[0,0,7],[7,0,0]], dtype=float)
out = T.apply(col_in)
# Column 0 cannot show that anything moved (an identity transform would pass
# a column-0 check), so pin the columns whose 7s must actually fall, then
# the whole grid.
check(
    "Cols 1-2: 7s fall to bottom",
    np.array_equal(out[:, 1:], [[0, 0], [0, 0], [7, 7]]),
)
check(
    "All 7s settle in bottom row",
    np.array_equal(out, [[0, 0, 0], [0, 0, 0], [7, 7, 7]]),
)

print("\n=== GravityUp ===")
T = tr.GravityUp()
out = T.apply(col_in)
# Column 0's 7 starts on the bottom row, so reaching row 0 proves it moved.
check("Col 0: 7 at top", out[0, 0] == 7 and out[1, 0] == 0 and out[2, 0] == 0)
print("\n=== CropToContent ===")
T = tr.CropToContent()
# 4x4 float grid: a solid 2x2 block of 7s framed by a one-cell border of 0s.
padded = np.zeros((4, 4))
padded[1:3, 1:3] = 7
out = T.apply(padded)
check("Crops to 2x2", out.shape == (2, 2))
check("All 7s", np.all(out == 7))
print("\n=== Transpose ===")
T = tr.Transpose()
# INP is square, so its shape is unchanged by transposition and a shape check
# on it can never fail. Use a 2x3 grid so both checks can detect a missing or
# wrong transform (e.g. a rotation also yields 3x2 but different values).
rect = np.array([[0, 7, 7], [7, 0, 0]], dtype=float)
out = T.apply(rect)
check("Shape is transposed", out.shape == rect.T.shape)
check("Values match transpose", np.array_equal(out, rect.T))
print("\n=== ShiftedTile ===")
T = tr.tile_to_target_shifted(shift=(1, 1), tile_factor=3)
out = T.apply(INP)
check("Shape is 9x9", out.shape == (9, 9))
# A plain, unshifted 3x3 repetition of INP -- the result must differ from it.
plain_tile = np.tile(INP, (3, 3))
check("Differs from vanilla tile", not np.array_equal(out, plain_tile))
print("\n=== FillEnclosedHarmonic ===")
T = tr.FillEnclosedHarmonic()
# 3x3 float ring of 7s with a single 0 hole in the centre.
enclosed = np.full((3, 3), 7.0)
enclosed[1, 1] = 0
out = T.apply(enclosed)
check("Center hole filled", out[1, 1] == 7)
print(f"\n{'='*50}")
print(f"Results: {tests_passed} passed, {tests_failed} failed")
# Report failure through the exit status so CI / shell callers notice;
# when every check passed the script simply ends normally (status 0).
if tests_failed:
    raise SystemExit(1)