Bremin commited on
Commit
095c3d7
·
verified ·
1 Parent(s): 8fdba4d

Add polyglot parity check script

Browse files
Files changed (1) hide show
  1. check_polyglot_parity.py +250 -0
check_polyglot_parity.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parity check: does the trajectory-eval shallow clone produce the same
3
+ polyglot-parsed graph + BERT features as the pre-existing big-machine
4
+ clone (data_multilang), for the same commit?
5
+
6
+ Runs on big where both clones exist. For each common (repo, commit) pair
7
+ it encounters, it snapshots the working tree from *both* clones, canonical-
8
+ hashes the graph structure + feature tensors, and reports match/mismatch.
9
+
10
+ A match confirms:
11
+ - git clone --filter=blob:none + checkout fetches the same file content
12
+ as the original full clone
13
+ - parse_repo_polyglot is deterministic w.r.t. the file tree (modulo
14
+ rglob ordering — we sort before hashing)
15
+ - BertTokenEmbedder is deterministic
16
+
17
+ Usage (on big):
18
+ python -m graphjepa.check_polyglot_parity \\
19
+ --traj-repos ./outputs/traj_real/repos \\
20
+ --multi-repos /raid/train/datasets/code-graph-v7/data_multilang \\
21
+ --n-pairs 4
22
+
23
+ If it picks a commit not present in the shallow clone's blobless ref set
24
+ (some base_commits may need lazy blob fetch), the script does the fetch
25
+ automatically via checkout.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import argparse
31
+ import hashlib
32
+ import json
33
+ import subprocess
34
+ import sys
35
+ from pathlib import Path
36
+ from typing import List, Optional, Tuple
37
+
38
+
39
+ def run(cmd: List[str], cwd: Optional[Path] = None, check: bool = True):
40
+ r = subprocess.run(cmd, cwd=str(cwd) if cwd else None,
41
+ capture_output=True, text=True)
42
+ if check and r.returncode != 0:
43
+ raise RuntimeError(f'{" ".join(cmd)} failed: {r.stderr[-400:]}')
44
+ return r
45
+
46
+
47
+ def list_commits(repo_dir: Path, n: int = 20) -> List[str]:
48
+ r = run(['git', 'log', '--format=%H', '-n', str(n)], cwd=repo_dir)
49
+ return r.stdout.split()
50
+
51
+
52
+ def checkout(repo_dir: Path, sha: str) -> bool:
53
+ run(['git', 'reset', '--hard', '-q'], cwd=repo_dir, check=False)
54
+ run(['git', 'clean', '-fdx', '-q'], cwd=repo_dir, check=False)
55
+ r = run(['git', 'checkout', '-q', '--detach', sha], cwd=repo_dir, check=False)
56
+ if r.returncode != 0:
57
+ # Try fetching the ref
58
+ run(['git', 'fetch', '-q', 'origin', sha], cwd=repo_dir, check=False)
59
+ r = run(['git', 'checkout', '-q', '--detach', sha], cwd=repo_dir,
60
+ check=False)
61
+ return r.returncode == 0
62
+
63
+
64
+ def canonical_hash(graph, features) -> Tuple[str, str, dict]:
65
+ """Deterministic hash of (graph structure, feature tensors).
66
+ Sorts node IDs so walk order doesn't matter.
67
+ Returns (graph_hash, feature_hash, stats_dict).
68
+ """
69
+ import torch
70
+
71
+ # Nodes: sort by id, hash (id, kind, content, type_description).
72
+ h_nodes = hashlib.sha256()
73
+ node_items = sorted(graph.nodes.items())
74
+ for nid, n in node_items:
75
+ h_nodes.update(nid.encode())
76
+ h_nodes.update(b'\x00')
77
+ h_nodes.update(getattr(n.kind, 'value', str(n.kind)).encode())
78
+ h_nodes.update(b'\x00')
79
+ h_nodes.update((n.content or '').encode())
80
+ h_nodes.update(b'\x00')
81
+ h_nodes.update((n.type_description or '').encode())
82
+ h_nodes.update(b'\x01')
83
+
84
+ # Edges: sort by (src, dst, kind).
85
+ edge_keys = sorted(
86
+ (e.src, e.dst, getattr(e.kind, 'value', str(e.kind)))
87
+ for e in graph.edges.values()
88
+ )
89
+ for src, dst, k in edge_keys:
90
+ h_nodes.update(f'E|{src}|{dst}|{k}|'.encode())
91
+
92
+ graph_hash = h_nodes.hexdigest()
93
+
94
+ # Feature tensors: for each kind in deterministic order, hash
95
+ # (sorted_ids, content_sum, type_sum, content_first_vec, type_first_vec).
96
+ h_feats = hashlib.sha256()
97
+ for kind, d in sorted((k, v) for k, v in features.items() if v is not None):
98
+ kind_str = getattr(kind, 'value', str(kind))
99
+ h_feats.update(kind_str.encode())
100
+ h_feats.update(b'\x00')
101
+ ids = list(d['ids'])
102
+ sort_idx = sorted(range(len(ids)), key=lambda i: ids[i])
103
+ content = d['content'][sort_idx] if sort_idx else d['content']
104
+ typev = d['type'][sort_idx] if sort_idx else d['type']
105
+ sorted_ids = [ids[i] for i in sort_idx]
106
+ for sid in sorted_ids:
107
+ h_feats.update(sid.encode()); h_feats.update(b'\x00')
108
+ # Digest feature tensors numerically with fixed precision so
109
+ # hashes match across float ops that might differ in trailing ULP.
110
+ content_q = (content * 1e5).round().to(torch.int64)
111
+ typev_q = (typev * 1e5).round().to(torch.int64)
112
+ h_feats.update(content_q.cpu().numpy().tobytes())
113
+ h_feats.update(typev_q.cpu().numpy().tobytes())
114
+
115
+ feat_hash = h_feats.hexdigest()
116
+
117
+ stats = {
118
+ 'n_nodes': len(graph.nodes),
119
+ 'n_edges': len(graph.edges),
120
+ 'n_feat_kinds': sum(1 for v in features.values() if v is not None),
121
+ 'feat_dim': next((v['content'].shape[1] for v in features.values()
122
+ if v is not None), None),
123
+ }
124
+ return graph_hash, feat_hash, stats
125
+
126
+
127
+ def snapshot(repo_dir: Path, embedder) -> Tuple[str, str, dict]:
128
+ from graphjepa.trajectory_pipeline import snapshot_working_tree
129
+ g, feats = snapshot_working_tree(repo_dir, embedder, verbose=False)
130
+ return canonical_hash(g, feats)
131
+
132
+
133
+ # Mapping from trajectory-eval repo dirname → data_multilang subpath.
134
+ # traj repos: django__django; data_multilang: python/django
135
+ _REPO_DIR_MAP = {
136
+ 'django__django': ('python', 'django'),
137
+ 'sympy__sympy': ('python', 'sympy'),
138
+ 'sphinx-doc__sphinx': ('python', 'sphinx'),
139
+ 'matplotlib__matplotlib': ('python', 'matplotlib'),
140
+ 'scikit-learn__scikit-learn': ('python', 'scikit-learn'),
141
+ 'astropy__astropy': ('python', 'astropy'),
142
+ 'pydata__xarray': ('python', 'xarray'),
143
+ 'pytest-dev__pytest': ('python', 'pytest'),
144
+ 'pylint-dev__pylint': ('python', 'pylint'),
145
+ 'psf__requests': ('python', 'requests'),
146
+ 'mwaskom__seaborn': ('python', 'seaborn'),
147
+ 'pallets__flask': ('python', 'flask'),
148
+ }
149
+
150
+
151
+ def find_pairs(traj_root: Path, multi_root: Path) -> List[Tuple[str, Path, Path]]:
152
+ pairs = []
153
+ if not traj_root.is_dir():
154
+ return pairs
155
+ for name, (lang, mname) in _REPO_DIR_MAP.items():
156
+ tpath = traj_root / name
157
+ mpath = multi_root / lang / mname
158
+ if tpath.is_dir() and mpath.is_dir():
159
+ pairs.append((name, tpath, mpath))
160
+ return pairs
161
+
162
+
163
+ def main():
164
+ p = argparse.ArgumentParser()
165
+ p.add_argument('--traj-repos', required=True,
166
+ help='outputs/traj_real/repos dir from the transfer bundle')
167
+ p.add_argument('--multi-repos', required=True,
168
+ help='data_multilang dir used to build cache_v7')
169
+ p.add_argument('--n-pairs', type=int, default=3,
170
+ help='Number of (repo, commit) pairs to test')
171
+ p.add_argument('--output', default=None,
172
+ help='Write a JSON report here')
173
+ args = p.parse_args()
174
+
175
+ traj_root = Path(args.traj_repos)
176
+ multi_root = Path(args.multi_repos)
177
+
178
+ pairs = find_pairs(traj_root, multi_root)
179
+ if not pairs:
180
+ print(f'[parity] no common repos found under {traj_root} and '
181
+ f'{multi_root}'); sys.exit(1)
182
+ print(f'[parity] {len(pairs)} repo pairs available:')
183
+ for n, t, m in pairs:
184
+ print(f' {n:30s} traj={t} multi={m}')
185
+
186
+ # For each pair, pick a commit that exists in both. HEAD of the
187
+ # multi clone is a safe default since that clone has full history.
188
+ tests = []
189
+ for name, tpath, mpath in pairs[:args.n_pairs]:
190
+ mcommits = list_commits(mpath, n=5)
191
+ if not mcommits:
192
+ print(f'[parity] {name}: no commits in multi clone, skip')
193
+ continue
194
+ tests.append((name, tpath, mpath, mcommits[0]))
195
+
196
+ # Import embedder once — BERT load is slow.
197
+ from graphjepa.features import BertTokenEmbedder
198
+ print('\n[parity] loading BERT embedder ...')
199
+ embedder = BertTokenEmbedder(device='cpu')
200
+
201
+ results = []
202
+ for name, tpath, mpath, sha in tests:
203
+ print(f'\n[parity] === {name} @ {sha[:10]} ===')
204
+
205
+ print(f' checkout traj clone ...')
206
+ if not checkout(tpath, sha):
207
+ print(f' [parity] traj clone cannot reach {sha[:10]}; skip')
208
+ results.append({'repo': name, 'sha': sha, 'error': 'traj_checkout_failed'})
209
+ continue
210
+ print(f' checkout multi clone ...')
211
+ if not checkout(mpath, sha):
212
+ print(f' [parity] multi clone cannot reach {sha[:10]}; skip')
213
+ results.append({'repo': name, 'sha': sha, 'error': 'multi_checkout_failed'})
214
+ continue
215
+
216
+ print(f' snapshotting traj clone ...')
217
+ tg, tf, tstats = snapshot(tpath, embedder)
218
+ print(f' snapshotting multi clone ...')
219
+ mg, mf, mstats = snapshot(mpath, embedder)
220
+
221
+ match_g = tg == mg
222
+ match_f = tf == mf
223
+ print(f' graph hash traj={tg[:12]} multi={mg[:12]} '
224
+ f'{"MATCH" if match_g else "MISMATCH"}')
225
+ print(f' feature hash traj={tf[:12]} multi={mf[:12]} '
226
+ f'{"MATCH" if match_f else "MISMATCH"}')
227
+ print(f' stats traj={tstats} multi={mstats}')
228
+ results.append({
229
+ 'repo': name, 'sha': sha,
230
+ 'graph_match': match_g, 'feature_match': match_f,
231
+ 'traj_stats': tstats, 'multi_stats': mstats,
232
+ })
233
+
234
+ print('\n' + '=' * 60)
235
+ n_g = sum(1 for r in results if r.get('graph_match'))
236
+ n_f = sum(1 for r in results if r.get('feature_match'))
237
+ print(f'graph parity: {n_g}/{len(results)} matched')
238
+ print(f'feature parity: {n_f}/{len(results)} matched')
239
+ print('=' * 60)
240
+
241
+ if args.output:
242
+ with open(args.output, 'w') as f:
243
+ json.dump(results, f, indent=2)
244
+ print(f'[parity] report saved: {args.output}')
245
+
246
+ sys.exit(0 if (n_g == n_f == len(results) and results) else 1)
247
+
248
+
249
+ if __name__ == '__main__':
250
+ main()