File size: 9,722 Bytes
095c3d7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 | """
Parity check: does the trajectory-eval shallow clone produce the same
polyglot-parsed graph + BERT features as the pre-existing big-machine
clone (data_multilang), for the same commit?
Runs on the `big` machine, where both clones exist. For each repo present
in both clones it picks one commit (HEAD of the data_multilang clone),
checks that commit out in *both* clones, snapshots each working tree,
canonically hashes the graph structure + feature tensors, and reports
match/mismatch.
A match confirms:
- git clone --filter=blob:none + checkout fetches the same file content
as the original full clone
- parse_repo_polyglot is deterministic w.r.t. the file tree (modulo
rglob ordering — we sort before hashing)
- BertTokenEmbedder is deterministic
Usage (on big):
python -m graphjepa.check_polyglot_parity \\
--traj-repos ./outputs/traj_real/repos \\
--multi-repos /raid/train/datasets/code-graph-v7/data_multilang \\
--n-pairs 4
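If the chosen commit is not yet available in the trajectory-eval clone
(some base_commits may need a lazy blob fetch), checkout() retries after
a `git fetch origin <sha>`.
WARNING: checkout() runs `git reset --hard` and `git clean -fdx` in BOTH
clones before switching commits. This discards local changes and untracked
or ignored files, and leaves each clone on a detached HEAD.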
"""
from __future__ import annotations
import argparse
import hashlib
import json
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Tuple
def run(cmd: List[str], cwd: Optional[Path] = None, check: bool = True):
    """Execute *cmd* (an argv list, no shell) and return the CompletedProcess.

    stdout/stderr are captured as text. When ``check`` is true, a non-zero
    exit status raises RuntimeError carrying the command line and the last
    400 characters of stderr.
    """
    workdir = None if not cwd else str(cwd)
    proc = subprocess.run(cmd, cwd=workdir, capture_output=True, text=True)
    if proc.returncode and check:
        tail = proc.stderr[-400:]
        raise RuntimeError(f'{" ".join(cmd)} failed: {tail}')
    return proc
def list_commits(repo_dir: Path, n: int = 20) -> List[str]:
    """Return up to *n* commit SHAs reachable from HEAD, newest first.

    If ``git log`` fails (for example, a repository with no commits yet),
    a warning is printed to stderr and an empty list is returned instead of
    raising. Callers can then skip the repo rather than aborting the run.
    """
    r = run(['git', 'log', '--format=%H', '-n', str(n)], cwd=repo_dir,
            check=False)
    if r.returncode != 0:
        print(f'[parity] git log failed in {repo_dir}: '
              f'{r.stderr.strip()[-200:]}', file=sys.stderr)
        return []
    return r.stdout.split()
def checkout(repo_dir: Path, sha: str) -> bool:
    """Force *repo_dir* onto a detached HEAD at *sha*; return True on success.

    WARNING: destructive. Local modifications are discarded
    (``git reset --hard``) and untracked *and ignored* files are removed
    (``git clean -fdx``) before switching. If the first checkout fails,
    ``git fetch origin <sha>`` is attempted once and the checkout is retried.
    """
    # Best-effort cleanup: exit statuses of these two commands are ignored.
    for cleanup in (['git', 'reset', '--hard', '-q'],
                    ['git', 'clean', '-fdx', '-q']):
        run(cleanup, cwd=repo_dir, check=False)

    checkout_cmd = ['git', 'checkout', '-q', '--detach', sha]
    if run(checkout_cmd, cwd=repo_dir, check=False).returncode == 0:
        return True
    # The commit may not be in the local object store yet (e.g. a
    # partial/blobless clone), so fetch it explicitly and try once more.
    run(['git', 'fetch', '-q', 'origin', sha], cwd=repo_dir, check=False)
    return run(checkout_cmd, cwd=repo_dir, check=False).returncode == 0
def _kind_str(kind) -> str:
    """Return a stable string form of a node/edge/feature *kind*.

    Enum-like objects contribute their ``.value``; anything else goes
    through ``str()``. Non-string enum values are stringified instead of
    failing on ``.encode()``.
    """
    return str(getattr(kind, 'value', kind))


def _quantize(t):
    """Round a float tensor to 1e-5 resolution and return it as int64.

    Scaling happens in float64. Multiplying a float16 tensor by 1e5 in its
    own dtype overflows to inf (fp16 max is 65504), which would make
    distinct feature values hash identically.
    """
    import torch
    return (t.detach().to(torch.float64) * 1e5).round().to(torch.int64)


def canonical_hash(graph, features) -> Tuple[str, str, dict]:
    """Deterministic hash of (graph structure, feature tensors).

    Node IDs, edges, feature kinds and per-kind feature rows are all sorted
    before hashing, so walk order and dict order don't matter.

    Args:
        graph: object exposing ``.nodes`` (mapping id -> node with ``kind``,
            ``content`` and ``type_description``) and ``.edges`` (mapping
            whose values have ``src``, ``dst`` and ``kind``).
        features: mapping kind -> ``None`` or a dict with ``'ids'``,
            ``'content'`` and ``'type'`` tensors. The tensors are assumed
            to be row-aligned with ``ids`` (TODO confirm against
            snapshot_working_tree).

    Returns:
        ``(graph_hash, feature_hash, stats_dict)``; both hashes are SHA-256
        hex digests.
    """
    # Nodes: sort by id, hash (id, kind, content, type_description).
    h_nodes = hashlib.sha256()
    for nid, n in sorted(graph.nodes.items(), key=lambda item: item[0]):
        h_nodes.update(nid.encode())
        h_nodes.update(b'\x00')
        h_nodes.update(_kind_str(n.kind).encode())
        h_nodes.update(b'\x00')
        h_nodes.update((n.content or '').encode())
        h_nodes.update(b'\x00')
        h_nodes.update((n.type_description or '').encode())
        h_nodes.update(b'\x01')
    # Edges: sort by (src, dst, kind); folded into the same digest.
    edge_keys = sorted(
        (e.src, e.dst, _kind_str(e.kind)) for e in graph.edges.values()
    )
    for src, dst, k in edge_keys:
        h_nodes.update(f'E|{src}|{dst}|{k}|'.encode())
    graph_hash = h_nodes.hexdigest()

    # Features: order kinds by their string form. Sorting the raw
    # (key, dict) pairs raises TypeError for plain Enum keys, which define
    # no ordering.
    present = sorted(
        ((k, v) for k, v in features.items() if v is not None),
        key=lambda item: _kind_str(item[0]),
    )
    h_feats = hashlib.sha256()
    for kind, d in present:
        h_feats.update(_kind_str(kind).encode())
        h_feats.update(b'\x00')
        ids = list(d['ids'])
        order = sorted(range(len(ids)), key=ids.__getitem__)
        content = d['content'][order] if order else d['content']
        typev = d['type'][order] if order else d['type']
        for i in order:
            h_feats.update(ids[i].encode())
            h_feats.update(b'\x00')
        # Fixed-precision digest, so trailing-ULP float differences usually
        # don't flip the hash (values straddling a rounding boundary still
        # can).
        h_feats.update(_quantize(content).cpu().numpy().tobytes())
        h_feats.update(_quantize(typev).cpu().numpy().tobytes())
    feat_hash = h_feats.hexdigest()

    stats = {
        'n_nodes': len(graph.nodes),
        'n_edges': len(graph.edges),
        'n_feat_kinds': len(present),
        'feat_dim': next((v['content'].shape[1] for v in features.values()
                          if v is not None), None),
    }
    return graph_hash, feat_hash, stats
def snapshot(repo_dir: Path, embedder) -> Tuple[str, str, dict]:
    """Parse and embed the current working tree of *repo_dir*, then hash it.

    Builds the graph and feature tensors via ``snapshot_working_tree``
    (with verbose output disabled) and returns ``canonical_hash`` of them:
    ``(graph_hash, feature_hash, stats_dict)``.
    """
    from graphjepa.trajectory_pipeline import snapshot_working_tree
    graph, feature_map = snapshot_working_tree(repo_dir, embedder,
                                               verbose=False)
    digests = canonical_hash(graph, feature_map)
    return digests
# Trajectory-eval repo dirname -> data_multilang subpath (language, repo).
# Traj clones use "<owner>__<repo>" names (e.g. django__django), while the
# multi-language clones live under <language>/<repo> (e.g. python/django).
# Insertion order matters: find_pairs() walks this mapping in order, which
# decides which repos are tested first.
_REPO_DIR_MAP = {
    traj_name: ('python', multi_name)
    for traj_name, multi_name in (
        ('django__django', 'django'),
        ('sympy__sympy', 'sympy'),
        ('sphinx-doc__sphinx', 'sphinx'),
        ('matplotlib__matplotlib', 'matplotlib'),
        ('scikit-learn__scikit-learn', 'scikit-learn'),
        ('astropy__astropy', 'astropy'),
        ('pydata__xarray', 'xarray'),
        ('pytest-dev__pytest', 'pytest'),
        ('pylint-dev__pylint', 'pylint'),
        ('psf__requests', 'requests'),
        ('mwaskom__seaborn', 'seaborn'),
        ('pallets__flask', 'flask'),
    )
}
def find_pairs(traj_root: Path, multi_root: Path) -> List[Tuple[str, Path, Path]]:
    """List repos that are checked out under both roots.

    Returns ``(traj_dirname, traj_path, multi_path)`` tuples in
    ``_REPO_DIR_MAP`` order, keeping only entries whose directory exists on
    both sides. Returns an empty list when *traj_root* is not a directory.
    """
    if not traj_root.is_dir():
        return []
    candidates = (
        (name, traj_root / name, multi_root / lang / sub)
        for name, (lang, sub) in _REPO_DIR_MAP.items()
    )
    return [(name, t_dir, m_dir) for name, t_dir, m_dir in candidates
            if t_dir.is_dir() and m_dir.is_dir()]
def _compare_one(name: str, tpath: Path, mpath: Path, sha: str,
                 embedder) -> dict:
    """Check out *sha* in both clones, hash both snapshots, and compare them.

    Returns one JSON-serialisable result record. When either checkout
    fails, the record carries an ``error`` key and no match flags, so it
    counts as a non-match in the summary.
    """
    print(f'\n[parity] === {name} @ {sha[:10]} ===')
    print(' checkout traj clone ...')
    if not checkout(tpath, sha):
        print(f' [parity] traj clone cannot reach {sha[:10]}; skip')
        return {'repo': name, 'sha': sha, 'error': 'traj_checkout_failed'}
    print(' checkout multi clone ...')
    if not checkout(mpath, sha):
        print(f' [parity] multi clone cannot reach {sha[:10]}; skip')
        return {'repo': name, 'sha': sha, 'error': 'multi_checkout_failed'}
    print(' snapshotting traj clone ...')
    tg, tf, tstats = snapshot(tpath, embedder)
    print(' snapshotting multi clone ...')
    mg, mf, mstats = snapshot(mpath, embedder)
    match_g = tg == mg
    match_f = tf == mf
    print(f' graph hash traj={tg[:12]} multi={mg[:12]} '
          f'{"MATCH" if match_g else "MISMATCH"}')
    print(f' feature hash traj={tf[:12]} multi={mf[:12]} '
          f'{"MATCH" if match_f else "MISMATCH"}')
    print(f' stats traj={tstats} multi={mstats}')
    return {
        'repo': name, 'sha': sha,
        'graph_match': match_g, 'feature_match': match_f,
        'traj_stats': tstats, 'multi_stats': mstats,
        # Full digests, so a saved report shows which side changed.
        'traj_graph_hash': tg, 'multi_graph_hash': mg,
        'traj_feature_hash': tf, 'multi_feature_hash': mf,
    }


def main():
    """CLI entry point: run the parity check and exit with its status.

    Exits 0 only if at least one (repo, commit) pair was tested and every
    tested pair matched on both the graph hash and the feature hash.
    Otherwise exits 1 (argparse exits 2 on invalid arguments).
    """
    p = argparse.ArgumentParser()
    p.add_argument('--traj-repos', required=True,
                   help='outputs/traj_real/repos dir from the transfer bundle')
    p.add_argument('--multi-repos', required=True,
                   help='data_multilang dir used to build cache_v7')
    p.add_argument('--n-pairs', type=int, default=3,
                   help='Number of (repo, commit) pairs to test')
    p.add_argument('--output', default=None,
                   help='Write a JSON report here')
    args = p.parse_args()
    if args.n_pairs < 1:
        p.error('--n-pairs must be >= 1')

    traj_root = Path(args.traj_repos)
    multi_root = Path(args.multi_repos)
    pairs = find_pairs(traj_root, multi_root)
    if not pairs:
        print(f'[parity] no common repos found under {traj_root} and '
              f'{multi_root}')
        sys.exit(1)
    print(f'[parity] {len(pairs)} repo pairs available:')
    for n, t, m in pairs:
        print(f' {n:30s} traj={t} multi={m}')

    # For each repo, test HEAD of the multi clone (a safe default since that
    # clone has full history). Repos whose multi clone yields no commit are
    # skipped without consuming the --n-pairs budget.
    tests = []
    for name, tpath, mpath in pairs:
        if len(tests) >= args.n_pairs:
            break
        mcommits = list_commits(mpath, n=1)
        if not mcommits:
            print(f'[parity] {name}: no commits in multi clone, skip')
            continue
        tests.append((name, tpath, mpath, mcommits[0]))
    if not tests:
        # With nothing to test the run cannot pass, so skip the slow BERT load.
        print('[parity] no testable (repo, commit) pairs found')
        sys.exit(1)

    # Import embedder once — BERT load is slow.
    from graphjepa.features import BertTokenEmbedder
    print('\n[parity] loading BERT embedder ...')
    embedder = BertTokenEmbedder(device='cpu')

    results = []
    for name, tpath, mpath, sha in tests:
        results.append(_compare_one(name, tpath, mpath, sha, embedder))

    print('\n' + '=' * 60)
    n_g = sum(1 for r in results if r.get('graph_match'))
    n_f = sum(1 for r in results if r.get('feature_match'))
    print(f'graph parity: {n_g}/{len(results)} matched')
    print(f'feature parity: {n_f}/{len(results)} matched')
    print('=' * 60)
    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f'[parity] report saved: {args.output}')
    all_ok = bool(results) and n_g == n_f == len(results)
    sys.exit(0 if all_ok else 1)


if __name__ == '__main__':
    main()
|