shreyask committed on
Commit
fffa30b
·
verified ·
1 Parent(s): 76fda9f

Upload verify_port_parity.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. verify_port_parity.py +256 -0
verify_port_parity.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Numerical parity check between the Flax model and its PyTorch port.

Two comparisons are made, each required to stay below 1e-3 max-abs-diff:

  1. Encoder output for a fixed input_ids tensor.
  2. Decoder logits at step 0 (empty past_kv) for a fixed decoder_input_id,
     fed with the encoder output from comparison 1.

Pass criterion: max(abs(flax_out - pt_out)) < 1e-3

The Flax side is run in float32 so that bfloat16 precision noise cannot mask
real porting bugs.
"""

import sys
from pathlib import Path
import pickle

import numpy as np
import torch

# Directory containing this script; every other path is derived from it.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Make the Cactus Flax package importable.
sys.path.insert(0, str(_SCRIPT_DIR.parent / "external" / "needle"))
# Make the PyTorch port importable (inserted last, so it is searched first).
sys.path.insert(0, str(_SCRIPT_DIR))

# These imports must come after the sys.path changes above.
import jax
import jax.numpy as jnp

from needle.model.architecture import SimpleAttentionNetwork, TransformerConfig as FlaxTransformerConfig
from needle_torch import NeedleModel, TransformerConfig

# Directory holding the checkpoint files read by the load helpers.
ART = _SCRIPT_DIR / "artifacts"

# Upper bound (exclusive) on the max-abs-diff for a check to pass.
TOLERANCE = 1e-3
34
+
35
+
36
+ # ---------------------------------------------------------------------------
37
+ # Load helpers
38
+ # ---------------------------------------------------------------------------
39
+
40
def load_flax_checkpoint():
    """Read the pickled Flax checkpoint ``needle.pkl`` from the ``ART`` directory.

    Returns:
        A ``(params, config)`` pair taken from the checkpoint's ``"params"``
        and ``"config"`` entries.
    """
    ckpt_path = ART / "needle.pkl"
    print(f"Loading Flax checkpoint from {ckpt_path} ...", flush=True)
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # checkpoint files you trust.
    with ckpt_path.open("rb") as handle:
        checkpoint = pickle.load(handle)
    params, config = checkpoint["params"], checkpoint["config"]
    return params, config
47
+
48
+
49
def cast_params_to_f32(params):
    """Recursively cast every leaf of a nested param tree to a float32 JAX array.

    Args:
        params: Either a mapping whose values are themselves param trees, or
            a single array-like leaf. Any ``collections.abc.Mapping`` is
            treated as a container (not only plain ``dict``), so immutable
            mapping types such as Flax's ``FrozenDict`` are traversed instead
            of being handed to ``np.asarray`` as if they were a leaf.

    Returns:
        A plain ``dict`` with the same nesting and keys as ``params`` and
        float32 ``jnp`` arrays as leaves; if ``params`` is itself a leaf, a
        single float32 ``jnp`` array.
    """
    # Local import keeps this function self-contained.
    from collections.abc import Mapping

    if isinstance(params, Mapping):
        return {key: cast_params_to_f32(value) for key, value in params.items()}
    # Go through NumPy first so the cast happens on the host array.
    host_array = np.asarray(params).astype(np.float32)
    return jnp.array(host_array)
55
+
56
+
57
def load_pt_model(config_dict):
    """Instantiate the PyTorch port and load its converted weights.

    Args:
        config_dict: Config mapping from the checkpoint. Keys that are not
            dataclass fields of ``TransformerConfig`` are ignored.

    Returns:
        ``(model, cfg)``: the ``NeedleModel`` in eval mode with the state dict
        from ``needle_torch.pt`` loaded strictly, and the config it was built
        from.
    """
    known_fields = TransformerConfig.__dataclass_fields__
    cfg_kwargs = {
        name: value
        for name, value in config_dict.items()
        if name in known_fields
    }
    cfg = TransformerConfig(**cfg_kwargs)

    model = NeedleModel(cfg)
    model.eval()

    # weights_only=True restricts unpickling to tensor data.
    weights = torch.load(ART / "needle_torch.pt", map_location="cpu", weights_only=True)
    # strict=True: any missing or unexpected key raises instead of being skipped.
    model.load_state_dict(weights, strict=True)
    return model, cfg
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # Bisection helper
69
+ # ---------------------------------------------------------------------------
70
+
71
def bisect_encoder(flax_model, flax_params_f32, pt_model, ids_np):
    """Dump encoder intermediates from both models to help locate a divergence.

    Prints the tree of Flax intermediates captured while running
    ``flax_model.encode_text``, then runs the PyTorch encoder with forward
    hooks on every entry of ``pt_model.encoder.layers`` and on
    ``pt_model.encoder.final_norm`` and prints the names of the captured
    outputs. The two sides are only printed, not compared numerically.

    Args:
        flax_model: Flax module exposing ``apply`` and ``encode_text``.
        flax_params_f32: Parameter tree passed to Flax as ``params``.
        pt_model: PyTorch model exposing ``encoder.layers`` and
            ``encoder.final_norm``.
        ids_np: NumPy array of input token ids; converted to int64 for PyTorch.

    Raises:
        Whatever the Flax or PyTorch forward pass raises. The PyTorch hooks
        are removed in that case too.
    """
    # Local import keeps this function self-contained.
    from collections.abc import Mapping

    print("\n--- Encoder bisection ---", flush=True)
    ids_jax = jnp.asarray(ids_np)

    # Flax intermediates via capture_intermediates
    _, state = flax_model.apply(
        {'params': flax_params_f32},
        ids_jax,
        capture_intermediates=True,
        method=flax_model.encode_text,
    )

    print("Flax intermediates structure (top level):")

    def print_tree(node, depth=0):
        # Stop at non-mapping leaves and cap the recursion at six levels.
        if depth > 5 or not isinstance(node, Mapping):
            return
        indent = ' ' * depth
        for key, value in node.items():
            if isinstance(value, Mapping):
                print(f"{indent}{key}/")
                print_tree(value, depth + 1)
            else:
                shape = getattr(value, 'shape', '?')
                print(f"{indent}{key}: {shape}")

    print_tree(state['intermediates'])

    # PyTorch intermediates via hooks
    pt_intermediates = {}

    def make_hook(name):
        # Bind the name now so each hook writes to its own key.
        def hook(module, inp, output):
            pt_intermediates[name] = output.detach().cpu().numpy()
        return hook

    hooks = []
    try:
        for idx, layer in enumerate(pt_model.encoder.layers):
            hooks.append(layer.register_forward_hook(make_hook(f'encoder_layer_{idx}')))
        hooks.append(
            pt_model.encoder.final_norm.register_forward_hook(make_hook('encoder_final_norm'))
        )

        with torch.no_grad():
            pt_model.encoder(torch.from_numpy(ids_np.astype(np.int64)))
    finally:
        # Always detach the hooks so a failing forward pass does not leave
        # them registered on the model.
        for handle in hooks:
            handle.remove()

    print(f"PyTorch intermediates captured: {list(pt_intermediates.keys())}", flush=True)
119
+
120
+
121
def bisect_decoder_step0(flax_model, flax_params_f32, pt_model, dec_id_np, flax_enc_out, pt_enc_out):
    """Print the top-level Flax intermediates captured during a decoder step-0 pass.

    Runs ``flax_model.decode`` on ``dec_id_np`` and ``flax_enc_out`` with
    ``capture_intermediates=True`` and lists the top-level keys of the
    captured intermediates.

    ``pt_model`` and ``pt_enc_out`` are accepted but not used by this function.
    """
    print("\n--- Decoder step-0 bisection ---", flush=True)

    decoder_ids = jnp.asarray(dec_id_np)
    variables = {'params': flax_params_f32}
    _, captured = flax_model.apply(
        variables,
        decoder_ids,
        flax_enc_out,
        capture_intermediates=True,
        method=flax_model.decode,
    )

    top_level_names = list(captured['intermediates'].keys())
    print("Flax decoder intermediates (top-level):", top_level_names, flush=True)
134
+
135
+
136
+ # ---------------------------------------------------------------------------
137
+ # Main
138
+ # ---------------------------------------------------------------------------
139
+
140
def _abs_diff_stats(reference, candidate):
    """Return ``(max, mean)`` of the element-wise absolute difference as floats."""
    abs_diff = np.abs(reference - candidate)
    return float(np.max(abs_diff)), float(np.mean(abs_diff))


def main():
    """Run the encoder and decoder step-0 parity checks and print a summary.

    Loads the Flax checkpoint, casts its params to float32, builds both
    models, and compares their outputs on fixed inputs. A check passes when
    its max-abs-diff is strictly below ``TOLERANCE``.

    Raises:
        SystemExit: with exit code 1 (via ``sys.exit``) as soon as one check
            fails, after the matching bisection helper has printed its
            diagnostics.
    """
    flax_params, config_dict = load_flax_checkpoint()
    print(f"Config: {config_dict}", flush=True)

    # Cast Flax params to float32 to avoid bfloat16 precision differences
    print("Casting Flax params to float32 ...", flush=True)
    flax_params_f32 = cast_params_to_f32(flax_params)

    # Build Flax model with float32 dtype
    config_dict_f32 = dict(config_dict, dtype="float32")
    flax_cfg = FlaxTransformerConfig(**config_dict_f32)
    flax_model = SimpleAttentionNetwork(flax_cfg)

    # Load PyTorch model (the returned config is not needed here)
    pt_model, _ = load_pt_model(config_dict)

    # Fixed input token sequence. The seed is kept from the original script;
    # nothing in this function draws random numbers.
    np.random.seed(0)
    ids_np = np.array(
        [[2, 100, 200, 300, 400, 500, 5, 600, 700, 800, 900, 1000, 1100, 1200, 1300, 1]],
        dtype=np.int32,
    )
    ids_jax = jnp.asarray(ids_np)

    # -- Check 1: Encoder ----------------------------------------------------
    print("\n=== Check 1: Encoder ===", flush=True)

    # Flax encode returns (encoder_out, mask); the mask is not used below.
    flax_enc_out, _flax_enc_mask = flax_model.apply(
        {'params': flax_params_f32},
        ids_jax,
        method=flax_model.encode,
    )
    flax_enc_np = np.asarray(flax_enc_out).astype(np.float32)

    with torch.no_grad():
        pt_enc_out = pt_model.encoder(
            torch.from_numpy(ids_np.astype(np.int64))
        ).cpu().numpy()

    print(f"Flax encoder output shape: {flax_enc_np.shape}, stats: "
          f"min={flax_enc_np.min():.4f} max={flax_enc_np.max():.4f} "
          f"mean={flax_enc_np.mean():.4f}", flush=True)
    print(f"PT encoder output shape: {pt_enc_out.shape}, stats: "
          f"min={pt_enc_out.min():.4f} max={pt_enc_out.max():.4f} "
          f"mean={pt_enc_out.mean():.4f}", flush=True)

    enc_diff, enc_mean_diff = _abs_diff_stats(flax_enc_np, pt_enc_out)
    print(f"\nencoder max-abs-diff: {enc_diff:.6f}", flush=True)
    print(f"encoder mean-abs-diff: {enc_mean_diff:.6f}", flush=True)

    # A NaN diff compares False here and is therefore reported as a failure.
    if not enc_diff < TOLERANCE:
        print(f"encoder parity FAILED (diff={enc_diff:.6f} >= {TOLERANCE}) -- bisecting ...", flush=True)
        bisect_encoder(flax_model, flax_params_f32, pt_model, ids_np)
        sys.exit(1)
    print(f"encoder parity OK (diff={enc_diff:.6f} < {TOLERANCE})", flush=True)

    # -- Check 2: Decoder step 0 ---------------------------------------------
    print("\n=== Check 2: Decoder step 0 ===", flush=True)

    dec_id_np = np.array([[1]], dtype=np.int32)
    dec_id_jax = jnp.asarray(dec_id_np)

    # Flax: decode(tgt, encoder_out) -> logits (B, T_dec, vocab_size)
    flax_logits = flax_model.apply(
        {'params': flax_params_f32},
        dec_id_jax,
        flax_enc_out,
        method=flax_model.decode,
    )
    flax_logits_np = np.asarray(flax_logits).astype(np.float32)

    with torch.no_grad():
        past_kv = pt_model.decoder.initial_past_kv(batch=1)
        # The PyTorch decoder is fed the PyTorch encoder output, not Flax's.
        pt_logits, _ = pt_model.decoder.step(
            torch.from_numpy(dec_id_np.astype(np.int64)),
            torch.from_numpy(pt_enc_out),
            past_kv,
        )
    pt_logits_np = pt_logits.cpu().numpy()

    print(f"Flax logits shape: {flax_logits_np.shape}, stats: "
          f"min={flax_logits_np.min():.4f} max={flax_logits_np.max():.4f}", flush=True)
    print(f"PT logits shape: {pt_logits_np.shape}, stats: "
          f"min={pt_logits_np.min():.4f} max={pt_logits_np.max():.4f}", flush=True)

    logits_diff, logits_mean_diff = _abs_diff_stats(flax_logits_np, pt_logits_np)
    print(f"\ndecoder step-0 logits max-abs-diff: {logits_diff:.6f}", flush=True)
    print(f"decoder step-0 logits mean-abs-diff: {logits_mean_diff:.6f}", flush=True)

    if not logits_diff < TOLERANCE:
        print(f"decoder parity FAILED (diff={logits_diff:.6f} >= {TOLERANCE}) -- bisecting ...", flush=True)
        bisect_decoder_step0(flax_model, flax_params_f32, pt_model, dec_id_np, flax_enc_out, pt_enc_out)
        sys.exit(1)
    print(f"decoder step-0 parity OK (diff={logits_diff:.6f} < {TOLERANCE})", flush=True)

    # -- Summary -------------------------------------------------------------
    print("\n" + "="*60, flush=True)
    # Report the tolerance actually applied instead of a hard-coded "1e-3".
    print(f"port parity OK (< {TOLERANCE})", flush=True)
    print(f" encoder max-abs-diff: {enc_diff:.6f}", flush=True)
    print(f" decoder step-0 max-abs-diff: {logits_diff:.6f}", flush=True)

    flax_argmax = int(np.argmax(flax_logits_np[0, 0]))
    pt_argmax = int(np.argmax(pt_logits_np[0, 0]))
    print(f" Flax argmax token: {flax_argmax}", flush=True)
    print(f" PT argmax token: {pt_argmax}", flush=True)
    print("="*60, flush=True)
253
+
254
+
255
# Run the parity check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()