# File size: 9,722 Bytes
# 095c3d7
"""
Parity check: does the trajectory-eval shallow clone produce the same
polyglot-parsed graph + BERT features as the pre-existing big-machine
clone (data_multilang), for the same commit?

Runs on big where both clones exist. For each common (repo, commit) pair
it encounters, it snapshots the working tree from *both* clones, canonical-
hashes the graph structure + feature tensors, and reports match/mismatch.

A match confirms:
  - git clone --filter=blob:none + checkout fetches the same file content
    as the original full clone
  - parse_repo_polyglot is deterministic w.r.t. the file tree (modulo
    rglob ordering — we sort before hashing)
  - BertTokenEmbedder is deterministic

Usage (on big):
    python -m graphjepa.check_polyglot_parity \\
        --traj-repos  ./outputs/traj_real/repos \\
        --multi-repos /raid/train/datasets/code-graph-v7/data_multilang \\
        --n-pairs 4

If it picks a commit not present in the shallow clone's blobless ref set
(some base_commits may need lazy blob fetch), the script does the fetch
automatically via checkout.
"""

from __future__ import annotations

import argparse
import hashlib
import json
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Tuple


def run(cmd: List[str], cwd: Optional[Path] = None, check: bool = True):
    """Execute ``cmd`` as a subprocess, capturing stdout/stderr as text.

    Args:
        cmd: argv list handed to ``subprocess.run`` (no shell involved).
        cwd: working directory for the child; a falsy value means the
            current process directory is used.
        check: when true, a non-zero exit status is turned into an error.

    Returns:
        The ``subprocess.CompletedProcess`` for the finished command.

    Raises:
        RuntimeError: if ``check`` is true and the command exits non-zero.
            The message carries the joined command line and the last 400
            characters of stderr.
    """
    workdir = None
    if cwd:
        workdir = str(cwd)

    proc = subprocess.run(cmd, cwd=workdir, capture_output=True, text=True)

    # Success, or the caller asked to inspect the return code itself.
    if proc.returncode == 0 or not check:
        return proc

    command_line = " ".join(cmd)
    stderr_tail = proc.stderr[-400:]
    raise RuntimeError(f'{command_line} failed: {stderr_tail}')


def list_commits(repo_dir: Path, n: int = 20) -> List[str]:
    """Return up to ``n`` full commit SHAs from ``git log`` in ``repo_dir``.

    The SHAs are returned in the order ``git log`` prints them. Because the
    command is run with error checking enabled, a failing ``git log``
    surfaces as ``RuntimeError`` from ``run``.
    """
    log_cmd = ['git', 'log', '--format=%H', '-n', str(n)]
    log_output = run(log_cmd, cwd=repo_dir).stdout
    # One SHA per line; split() also discards the trailing newline.
    return log_output.split()


def checkout(repo_dir: Path, sha: str) -> bool:
    """Put ``repo_dir`` on a detached HEAD at ``sha``; report success.

    The working tree is first hard-reset and cleaned (``git clean -fdx``),
    so local modifications and untracked/ignored files are discarded. Those
    two steps are best-effort: their exit status is ignored. If the first
    checkout attempt fails, ``sha`` is fetched from ``origin`` and the
    checkout is retried once.

    Returns:
        True if the (possibly retried) checkout exited with status 0.
    """
    def _detach() -> bool:
        # Single checkout attempt; never raises on a non-zero exit.
        attempt = run(['git', 'checkout', '-q', '--detach', sha],
                      cwd=repo_dir, check=False)
        return attempt.returncode == 0

    for cleanup_cmd in (['git', 'reset', '--hard', '-q'],
                        ['git', 'clean', '-fdx', '-q']):
        run(cleanup_cmd, cwd=repo_dir, check=False)

    if _detach():
        return True

    # Commit not available locally: try fetching it explicitly, then retry.
    run(['git', 'fetch', '-q', 'origin', sha], cwd=repo_dir, check=False)
    return _detach()


def canonical_hash(graph, features) -> Tuple[str, str, dict]:
    """Deterministic hash of (graph structure, feature tensors).

    Nodes, edges, feature kinds and the rows inside each feature block are
    all brought into a sorted order before hashing, so the walk / insertion
    order that produced ``graph`` and ``features`` does not matter.

    Args:
        graph: object exposing ``nodes`` (mapping of str node id -> node
            with ``kind``, ``content`` and ``type_description``) and
            ``edges`` (mapping whose values carry ``src``, ``dst`` and
            ``kind``).
        features: mapping of kind -> ``None`` or a dict with ``'ids'``
            (str ids), ``'content'`` and ``'type'`` (torch tensors whose
            rows line up with ``'ids'``). ``None`` entries are skipped.

    Returns:
        ``(graph_hash, feature_hash, stats_dict)`` where both hashes are
        SHA-256 hex digests and ``stats_dict`` holds ``n_nodes``,
        ``n_edges``, ``n_feat_kinds`` and ``feat_dim`` (second dimension of
        the first non-``None`` ``'content'`` tensor, or ``None``).
    """
    import torch

    def _kind_name(kind):
        # Enum-like kinds hash by their ``value``; anything else by str().
        return getattr(kind, 'value', str(kind))

    # Nodes: sort by id, hash (id, kind, content, type_description).
    h_graph = hashlib.sha256()
    for nid, node in sorted(graph.nodes.items(), key=lambda item: item[0]):
        h_graph.update(nid.encode())
        h_graph.update(b'\x00')
        h_graph.update(_kind_name(node.kind).encode())
        h_graph.update(b'\x00')
        h_graph.update((node.content or '').encode())
        h_graph.update(b'\x00')
        h_graph.update((node.type_description or '').encode())
        h_graph.update(b'\x01')

    # Edges: sort by (src, dst, kind) and fold into the same digest.
    edge_keys = sorted(
        (e.src, e.dst, _kind_name(e.kind))
        for e in graph.edges.values()
    )
    for src, dst, kind_name in edge_keys:
        h_graph.update(f'E|{src}|{dst}|{kind_name}|'.encode())

    graph_hash = h_graph.hexdigest()

    # Feature tensors: visit kinds ordered by their string name. Sorting the
    # (kind, dict) pairs directly would require the kind objects themselves
    # to be orderable, which plain Enum members are not.
    present = [(kind, d) for kind, d in features.items() if d is not None]
    present.sort(key=lambda pair: _kind_name(pair[0]))

    h_feats = hashlib.sha256()
    for kind, d in present:
        h_feats.update(_kind_name(kind).encode())
        h_feats.update(b'\x00')

        ids = list(d['ids'])
        # Row permutation that puts the ids in sorted order.
        sort_idx = sorted(range(len(ids)), key=lambda i: ids[i])
        content = d['content'][sort_idx] if sort_idx else d['content']
        typev = d['type'][sort_idx] if sort_idx else d['type']
        for i in sort_idx:
            h_feats.update(ids[i].encode())
            h_feats.update(b'\x00')

        # Digest feature tensors numerically with fixed precision (scaled by
        # 1e5 and rounded to int64) so hashes match across float ops that
        # might differ in trailing ULP.
        content_q = (content * 1e5).round().to(torch.int64)
        typev_q = (typev * 1e5).round().to(torch.int64)
        h_feats.update(content_q.cpu().numpy().tobytes())
        h_feats.update(typev_q.cpu().numpy().tobytes())

    feat_hash = h_feats.hexdigest()

    stats = {
        'n_nodes': len(graph.nodes),
        'n_edges': len(graph.edges),
        'n_feat_kinds': len(present),
        'feat_dim': next((v['content'].shape[1] for v in features.values()
                          if v is not None), None),
    }
    return graph_hash, feat_hash, stats


def snapshot(repo_dir: Path, embedder) -> Tuple[str, str, dict]:
    """Snapshot the working tree at ``repo_dir`` and return its canonical hashes.

    Calls ``snapshot_working_tree`` (quietly, ``verbose=False``) with the
    given ``embedder`` and feeds the resulting graph and features to
    ``canonical_hash``.

    Returns:
        ``(graph_hash, feature_hash, stats_dict)`` as produced by
        ``canonical_hash``.
    """
    # Imported lazily so the module can be loaded without the pipeline.
    from graphjepa.trajectory_pipeline import snapshot_working_tree

    parsed = snapshot_working_tree(repo_dir, embedder, verbose=False)
    graph, features = parsed
    return canonical_hash(graph, features)


# Lookup table: trajectory-eval repo directory name -> (language, repo name)
# path components under the data_multilang root.
# Example: traj repos use ``django__django``; data_multilang uses
# ``python/django``.
_REPO_DIR_MAP = {
    traj_name: (lang, multi_name)
    for traj_name, lang, multi_name in (
        ('django__django', 'python', 'django'),
        ('sympy__sympy', 'python', 'sympy'),
        ('sphinx-doc__sphinx', 'python', 'sphinx'),
        ('matplotlib__matplotlib', 'python', 'matplotlib'),
        ('scikit-learn__scikit-learn', 'python', 'scikit-learn'),
        ('astropy__astropy', 'python', 'astropy'),
        ('pydata__xarray', 'python', 'xarray'),
        ('pytest-dev__pytest', 'python', 'pytest'),
        ('pylint-dev__pylint', 'python', 'pylint'),
        ('psf__requests', 'python', 'requests'),
        ('mwaskom__seaborn', 'python', 'seaborn'),
        ('pallets__flask', 'python', 'flask'),
    )
}


def find_pairs(traj_root: Path, multi_root: Path) -> List[Tuple[str, Path, Path]]:
    """List the repos that exist as directories under both clone roots.

    Args:
        traj_root: directory holding the trajectory-eval clones, one
            sub-directory per ``_REPO_DIR_MAP`` key.
        multi_root: data_multilang root, laid out as ``<lang>/<repo>``.

    Returns:
        ``(traj_name, traj_path, multi_path)`` tuples in ``_REPO_DIR_MAP``
        order; empty if ``traj_root`` is not a directory or nothing matches.
    """
    if not traj_root.is_dir():
        return []

    candidates = (
        (name, traj_root / name, multi_root / lang / mname)
        for name, (lang, mname) in _REPO_DIR_MAP.items()
    )
    # Keep only the repos whose directory is present on both sides.
    return [
        (name, tpath, mpath)
        for name, tpath, mpath in candidates
        if tpath.is_dir() and mpath.is_dir()
    ]


def main():
    """CLI entry point: compare graph/feature hashes between the two clones.

    Parses the command line, pairs up repos present under both roots, checks
    each pair out at the multi clone's first listed commit, snapshots both
    working trees and prints a per-repo and overall match report. Optionally
    writes the per-repo results as JSON to ``--output``.

    Always terminates via ``sys.exit``: status 0 only when at least one
    result was recorded and every recorded result matched on both graph and
    feature hashes; status 1 otherwise (including when no common repos are
    found).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--traj-repos', required=True,
                        help='outputs/traj_real/repos dir from the transfer bundle')
    parser.add_argument('--multi-repos', required=True,
                        help='data_multilang dir used to build cache_v7')
    parser.add_argument('--n-pairs', type=int, default=3,
                        help='Number of (repo, commit) pairs to test')
    parser.add_argument('--output', default=None,
                        help='Write a JSON report here')
    opts = parser.parse_args()

    traj_root = Path(opts.traj_repos)
    multi_root = Path(opts.multi_repos)

    repo_pairs = find_pairs(traj_root, multi_root)
    if not repo_pairs:
        print(f'[parity] no common repos found under {traj_root} and '
              f'{multi_root}')
        sys.exit(1)
    print(f'[parity] {len(repo_pairs)} repo pairs available:')
    for repo_name, traj_dir, multi_dir in repo_pairs:
        print(f'  {repo_name:30s}  traj={traj_dir}  multi={multi_dir}')

    # Choose the commit to compare for each pair: the first commit listed
    # for the multi clone, which is expected to hold the full history.
    planned = []
    for repo_name, traj_dir, multi_dir in repo_pairs[:opts.n_pairs]:
        recent = list_commits(multi_dir, n=5)
        if recent:
            planned.append((repo_name, traj_dir, multi_dir, recent[0]))
        else:
            print(f'[parity] {repo_name}: no commits in multi clone, skip')

    # Build the embedder a single time and share it across all snapshots.
    from graphjepa.features import BertTokenEmbedder
    print('\n[parity] loading BERT embedder ...')
    embedder = BertTokenEmbedder(device='cpu')

    report = []
    for repo_name, traj_dir, multi_dir, sha in planned:
        short_sha = sha[:10]
        print(f'\n[parity] === {repo_name} @ {short_sha} ===')

        # Check out traj first, then multi; stop at the first side that
        # cannot reach the commit.
        failed_side = None
        for side, clone_dir in (('traj', traj_dir), ('multi', multi_dir)):
            print(f'  checkout {side} clone ...')
            if not checkout(clone_dir, sha):
                failed_side = side
                break
        if failed_side is not None:
            print(f'  [parity] {failed_side} clone cannot reach {short_sha}; skip')
            report.append({'repo': repo_name, 'sha': sha,
                           'error': f'{failed_side}_checkout_failed'})
            continue

        print('  snapshotting traj clone ...')
        traj_graph, traj_feat, traj_stats = snapshot(traj_dir, embedder)
        print('  snapshotting multi clone ...')
        multi_graph, multi_feat, multi_stats = snapshot(multi_dir, embedder)

        graph_ok = traj_graph == multi_graph
        feat_ok = traj_feat == multi_feat
        graph_verdict = 'MATCH' if graph_ok else 'MISMATCH'
        feat_verdict = 'MATCH' if feat_ok else 'MISMATCH'
        print(f'  graph hash   traj={traj_graph[:12]}  multi={multi_graph[:12]}  '
              f'{graph_verdict}')
        print(f'  feature hash traj={traj_feat[:12]}  multi={multi_feat[:12]}  '
              f'{feat_verdict}')
        print(f'  stats traj={traj_stats}  multi={multi_stats}')
        report.append({
            'repo': repo_name, 'sha': sha,
            'graph_match': graph_ok, 'feature_match': feat_ok,
            'traj_stats': traj_stats, 'multi_stats': multi_stats,
        })

    rule = '=' * 60
    total = len(report)
    # Entries recorded for failed checkouts have no *_match keys and so
    # count toward the total but never toward the matches.
    graph_hits = sum(1 for entry in report if entry.get('graph_match'))
    feat_hits = sum(1 for entry in report if entry.get('feature_match'))
    print('\n' + rule)
    print(f'graph parity:   {graph_hits}/{total} matched')
    print(f'feature parity: {feat_hits}/{total} matched')
    print(rule)

    if opts.output:
        with open(opts.output, 'w') as fh:
            json.dump(report, fh, indent=2)
        print(f'[parity] report saved: {opts.output}')

    all_matched = bool(report) and graph_hits == feat_hits == total
    sys.exit(0 if all_matched else 1)


# Run the parity check only when this file is executed as a script
# (e.g. ``python -m graphjepa.check_polyglot_parity``), not on import.
if __name__ == '__main__':
    main()