File size: 11,412 Bytes
b48dd06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""
Utility tests and debug helpers for the itt_solver beam runs.

Save this file into the `itt_solver/` package and import the helpers from your notebook.
Example usage (in a notebook cell after a run that produced `states, sigmas, logs, phi_target`):

    from itt_solver import tests
    tests.print_depth0_logs(logs)
    tests.check_first_accepted_score(logs, lock_coeff=0.01)
    tests.gate_failure_summary(logs)
    tests.plot_layer1_mask(states[0], l1_module)   # l1_module: the imported layer_minus_one module, e.g. `import itt_solver.layer_minus_one as l1_module`

Additional helpers:
- assert_states_shape(states, phi_target): checks that every state has the same shape as phi_target.
- transform_effect_test(transform, phi): returns the number of changed cells and a boolean diff map.
- sigma_decrease_smoke_test(beam_func, phi_in, phi_target, atomic_library): runs a relaxed beam and reports the sigma trace and whether sigma decreased.
- run_all_quick_checks(...): convenience runner for quick local verification.
- run_atomic_effects(task_input=None, params=None, target_shape=(9,9)): applies each transform of the default atomic library to an input and reports output shapes and changed-cell counts.
"""

from pprint import pprint
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np

def print_depth0_logs(logs):
    """Pretty-print the depth-0 log entries and a summary of accepted candidates.

    Parameters
    ----------
    logs : sequence of sequences of dict, or None
        Per-depth log records from a beam run; only ``logs[0]`` is inspected.
        Each record is read with ``.get`` for the keys ``'accepted'``,
        ``'atomic'``, ``'score'`` and ``'gates'``.

    Returns
    -------
    None
        All output goes to stdout.
    """
    # `not logs` already covers both None and an empty sequence, so a separate
    # length check afterwards could never trigger.
    if not logs:
        print("No logs available.")
        return

    depth0 = logs[0]
    print("Depth 0 log entries:", len(depth0))
    pprint(depth0)

    accepted = [r for r in depth0 if r.get('accepted')]
    print("Accepted count at depth 0:", len(accepted))
    for i, r in enumerate(accepted):
        print(i, r.get('atomic'), "score", r.get('score'), "gates", r.get('gates'))


def check_first_accepted_score(logs, lock_coeff=0.01, tolerance=1e-8):
    """
    Verify the scoring formula on the first accepted depth-0 candidate.

    The logged ``score`` must equal ``residue + lock_coeff * energy`` to within
    the absolute ``tolerance``. Messages describing the outcome are printed.

    Returns True when the check passes; False when it fails or when the data
    needed for it (logs, an accepted candidate, the three numeric fields) is
    missing.
    """
    if not logs or not len(logs):
        print("No logs to check.")
        return False

    candidate = None
    for record in logs[0]:
        if record.get('accepted'):
            candidate = record
            break
    if candidate is None:
        print("No accepted candidates at depth 0. See logs for gate failures.")
        return False

    residue, energy, logged = (candidate.get(key) for key in ('residue', 'energy', 'score'))
    if any(value is None for value in (residue, energy, logged)):
        print("Missing numeric fields in the candidate log.")
        return False

    expected = residue + lock_coeff * energy
    passed = abs(logged - expected) < tolerance
    if not passed:
        print("Score check FAILED.")
        print(f"Logged score: {logged}")
        print(f"Computed  : {expected}")
    else:
        print("Score check passed for first accepted candidate.")
    return passed


def gate_failure_summary(logs):
    """
    Tally gate outcomes for the depth-0 candidates and print the tally.

    Possible keys of the returned Counter:
      - 'no_gate_info'            record has a missing/empty 'gates' entry
      - 'A_boundary_failed'       gates['A_boundary'] is falsy
      - 'B_localization_failed'   gates['B_localization'] is falsy
      - 'C_quantization_failed'   gates['C_quantization'] is falsy
      - 'total_rejected'          gates['passed'] is exactly False
      - 'total_accepted'          any other gates['passed'] value, including absent
    A per-gate key that is absent counts as a pass. Empty or None ``logs``
    yields an empty Counter.
    """
    if not logs or not len(logs):
        print("No logs to summarize.")
        return Counter()

    # (gate key in the record, label to increment when that gate failed)
    gate_checks = (
        ('A_boundary', 'A_boundary_failed'),
        ('B_localization', 'B_localization_failed'),
        ('C_quantization', 'C_quantization_failed'),
    )
    tally = Counter()
    for record in logs[0]:
        gates = record.get('gates')
        if not gates:
            tally['no_gate_info'] += 1
            continue
        for gate_key, failure_label in gate_checks:
            if not gates.get(gate_key, True):
                tally[failure_label] += 1
        outcome = 'total_rejected' if gates.get('passed') is False else 'total_accepted'
        tally[outcome] += 1
    print("Gate failure summary (depth 0):", dict(tally))
    return tally


def plot_layer1_mask(state, l1_module, imag_grad_threshold=None, figsize=(4,4)):
    """
    Show the Layer-1 admissible edit mask and its companion magnitude map.

    Parameters
    ----------
    state : array-like or None
        The state to analyse (e.g. a resized candidate / phi field). When None,
        a message is printed and nothing is plotted.
    l1_module : module
        Object exposing ``admissible_edit_mask(state, imag_grad_threshold)``
        that returns ``(mask, mag)``, e.g. the module from
        ``import itt_solver.layer_minus_one as l1``.
    imag_grad_threshold : optional
        Forwarded unchanged as the second positional argument of
        ``admissible_edit_mask``.
    figsize : tuple
        Figure size used for both plots.

    Any exception from ``admissible_edit_mask`` (or from unpacking its result)
    is printed rather than raised. Returns None.
    """
    if state is None:
        print("No state provided.")
        return

    try:
        mask, mag = l1_module.admissible_edit_mask(state, imag_grad_threshold)
    except Exception as err:
        print("Error computing Layer-1 mask:", err)
        return

    # (image, colormap, title, draw a colorbar?)
    panels = (
        (mask, 'gray', 'Layer-1 admissible edit mask', False),
        (mag, 'magma', '||∇Im(Φ_c)|| magnitude', True),
    )
    for image, cmap, title, with_colorbar in panels:
        plt.figure(figsize=figsize)
        plt.imshow(image, cmap=cmap)
        plt.title(title)
        if with_colorbar:
            plt.colorbar()
        plt.axis('off')
        plt.show()


def assert_states_shape(states, phi_target):
    """
    Check that every state has exactly the shape of ``phi_target``.

    On the first mismatch, prints its index and both shapes and returns False.
    When every state matches, prints the shared shape and returns True. An
    empty or None ``states`` prints a message and returns False.
    """
    if not states:
        print("No states provided.")
        return False

    expected = tuple(phi_target.shape)
    mismatch = next(
        ((idx, st) for idx, st in enumerate(states) if tuple(st.shape) != expected),
        None,
    )
    if mismatch is not None:
        idx, st = mismatch
        print(f"State {idx} shape mismatch: {st.shape} != {expected}")
        return False

    print("All states match target shape:", expected)
    return True


# --- New tests added below ---------------------------------------------------

def transform_effect_test(transform, phi, show_diff=False):
    """
    Apply a Transform-like object to ``phi`` and report which cells changed.

    Parameters
    ----------
    transform : object
        Must expose ``apply(phi)``. It receives a float copy of ``phi``, so the
        caller's data is never mutated. Its return value may be an array or
        any array-like (e.g. nested lists).
    phi : array-like
        Input grid; converted to a float NumPy array.
    show_diff : bool
        When True, also plot the boolean diff map.

    Returns
    -------
    (changed_count, diff_map)
        changed_count : int, number of cells that differ after the transform.
        diff_map : boolean array, True where a cell changed.

    Raises
    ------
    ValueError
        If ``transform`` has no ``apply`` method, if ``apply`` returns None,
        or if its output can be neither tiled nor broadcast to ``phi``'s shape.
    """
    if not hasattr(transform, 'apply'):
        raise ValueError("transform must have an apply(phi) method")
    phi = np.array(phi, dtype=float)
    phi_after = transform.apply(phi.copy())
    if phi_after is None:
        raise ValueError("transform.apply(phi) returned None; expected an array")
    # Accept list-like outputs too, so `.shape` below is always available.
    phi_after = np.asarray(phi_after)
    if phi_after.shape != phi.shape:
        # Resize phi_after to phi's shape: first try the solver's tiling helper,
        # then plain NumPy broadcasting. The import sits inside the try so an
        # unavailable solver_core falls through to broadcasting instead of
        # aborting the whole test.
        try:
            from .solver_core import tile_transform
            phi_after = tile_transform(phi_after, phi.shape)
        except Exception:
            try:
                phi_after = np.broadcast_to(phi_after, phi.shape)
            except ValueError as err:
                raise ValueError(
                    f"transform output shape {phi_after.shape} cannot be "
                    f"resized to input shape {phi.shape}"
                ) from err
    diff_map = (phi_after != phi)
    changed_count = int(np.sum(diff_map))
    if show_diff:
        plt.figure(figsize=(4,4))
        plt.imshow(diff_map, cmap='gray')
        plt.title(f'Transform effect diff (changed={changed_count})')
        plt.axis('off')
        plt.show()
    return changed_count, diff_map


def sigma_decrease_smoke_test(beam_func, phi_in, phi_target, atomic_library,
                              beam_kwargs=None):
    """
    Run a relaxed beam search and report whether sigma decreased.

    Relaxed defaults are filled in only for keys the caller did not supply in
    ``beam_kwargs``; explicit caller values take precedence. The defaults are
    lock_coeff=0.0, max_fraction=1.0, enable_layer_minus_one=True,
    boundary_source='target' and allowed_symbols=[0, ..., 9]. The caller's dict
    is copied, never mutated.

    beam_func: callable invoked as beam_func(phi_in, phi_target, atomic_library, **kwargs).
        Index 3 of its return value is taken as the sigma trace, e.g. for a
        (T_best, phi_best, states, sigmas, logs) tuple.

    Returns a dict with keys:
      - 'sigmas': the sigma trace (list or 1-D array), or None when the result
        is empty/None or has fewer than 4 elements
      - 'decreased': True if the final sigma is strictly below the initial one
        (requires at least two sigma values)
      - 'result': whatever beam_func returned
    """
    beam_kwargs = dict(beam_kwargs or {})
    # Relaxed defaults for the smoke test; setdefault keeps caller overrides.
    beam_kwargs.setdefault('lock_coeff', 0.0)
    beam_kwargs.setdefault('max_fraction', 1.0)
    beam_kwargs.setdefault('enable_layer_minus_one', True)
    beam_kwargs.setdefault('boundary_source', 'target')
    # allow all quantized symbols for this test
    beam_kwargs.setdefault('allowed_symbols', list(range(10)))

    result = beam_func(phi_in, phi_target, atomic_library, **beam_kwargs)
    if not result or len(result) < 4:
        return {'sigmas': None, 'decreased': False, 'result': result}

    sigmas = result[3]
    decreased = False
    # The sigma trace may be a NumPy array, whose truth value is ambiguous for
    # more than one element, so test for None and length explicitly.
    if sigmas is not None and len(sigmas) >= 2:
        decreased = float(sigmas[-1]) < float(sigmas[0])
    return {'sigmas': sigmas, 'decreased': decreased, 'result': result}


def run_all_quick_checks(states, logs, phi_target, l1_module=None, lock_coeff=0.01,
                         beam_smoke_runner=None, phi_in=None, atomic_library=None):
    """
    Run the basic post-run checks, plus the relaxed sigma smoke test when
    ``beam_smoke_runner``, ``phi_in`` and ``atomic_library`` are all supplied.

    Returns a dict that may contain:
      - 'shape_ok': result of assert_states_shape(states, phi_target)
      - 'print_logs': True if print_depth0_logs(logs) completed without raising
      - 'first_accepted_score_ok': result of check_first_accepted_score
      - 'gate_summary': Counter returned by gate_failure_summary
      - 'layer1_plotted': present only when l1_module is given; True if
        plot_layer1_mask(states[0], l1_module) completed without raising
      - 'smoke_sigmas', 'smoke_decreased': present only when the smoke test runs

    Exceptions raised by the log-printing and plotting steps are reported on
    stdout and recorded as False instead of propagating, so one failing step
    does not hide the results of the others.
    """
    results = {}
    results['shape_ok'] = assert_states_shape(states, phi_target)
    try:
        print_depth0_logs(logs)
        results['print_logs'] = True
    except Exception as e:
        # Best-effort step: report why it failed instead of hiding the error.
        print("print_depth0_logs failed:", repr(e))
        results['print_logs'] = False

    results['first_accepted_score_ok'] = check_first_accepted_score(logs, lock_coeff=lock_coeff)
    results['gate_summary'] = gate_failure_summary(logs)
    if l1_module is not None:
        try:
            print("Plotting Layer-1 mask for states[0]...")
            # states[0] raises when states is empty or None; that is reported below.
            plot_layer1_mask(states[0], l1_module)
            results['layer1_plotted'] = True
        except Exception as e:
            print("plot_layer1_mask failed:", repr(e))
            results['layer1_plotted'] = False

    if beam_smoke_runner is not None and phi_in is not None and atomic_library is not None:
        print("Running sigma decrease smoke test (relaxed beam)...")
        smoke = sigma_decrease_smoke_test(beam_smoke_runner, phi_in, phi_target, atomic_library)
        results['smoke_sigmas'] = smoke.get('sigmas')
        results['smoke_decreased'] = smoke.get('decreased')
    return results

def run_atomic_effects(task_input=None, params=None, target_shape=(9,9)):
    """
    Build the default atomic library and report how each transform changes the
    provided task_input. Prints, per transform, its output shape and
    changed-cell count.

    - task_input: either a NumPy array or a small grid (list of lists). If None, a default 3x3 example is used.
    - params: dict of parameters passed to default_atomic_factory (optional).
    - target_shape: tuple used to construct the atomic library via default_atomic_factory.

    Returns a list with one dict per transform:
      - {'transform': repr(T), 'error': str(e)} when T.apply raised, or
      - {'transform': repr(T), 'out_shape': shape or None, 'changed_cells': int or None}
        where 'changed_cells' is None when apply() returned None or its output
        could be neither tiled nor broadcast to the input shape.
    """
    # lazy imports to avoid top-level dependency issues
    from .experiment_driver import default_atomic_factory
    from .solver_core import initialize_potential, tile_transform

    if task_input is None:
        task_input = [[0,7,7],[7,7,7],[0,7,7]]
    phi_in = initialize_potential(task_input)

    if params is None:
        params = {'beam_width':6,'max_depth':3,'lock_coeff':0.0,'max_fraction':1.0,'enable_layer_minus_one':True,'boundary_source':'target'}

    task_stub = {'target_shape': target_shape}
    atomic_library = default_atomic_factory(params, task_stub)

    print("Testing atomic library transforms on input shape", phi_in.shape)
    results = []
    for T in atomic_library:
        try:
            phi_after = T.apply(phi_in.copy())
        except Exception as e:
            print(f"{repr(T)} raised exception during apply(): {e}")
            results.append({'transform': repr(T), 'error': str(e)})
            continue

        if phi_after is None:
            # No output to compare (checked before any `.shape` access so this
            # case is reported instead of raising AttributeError).
            out_shape = None
            changed = None
        else:
            # Accept list-like outputs as well as arrays.
            phi_after = np.asarray(phi_after)
            out_shape = phi_after.shape
            # If shapes differ, try to tile (then broadcast) phi_after to
            # phi_in's shape so the two can be compared cell by cell.
            if phi_after.shape != phi_in.shape:
                try:
                    phi_after_resized = tile_transform(phi_after, phi_in.shape)
                except Exception:
                    try:
                        phi_after_resized = np.broadcast_to(phi_after, phi_in.shape)
                    except Exception:
                        phi_after_resized = None
            else:
                phi_after_resized = phi_after

            if phi_after_resized is None:
                changed = None
            else:
                changed = int(np.sum(phi_after_resized != phi_in))

        print(repr(T), "-> out shape", out_shape, "changed cells:", changed)
        results.append({'transform': repr(T), 'out_shape': out_shape, 'changed_cells': changed})

    return results