#!/usr/bin/env python3
"""
AEGIS Dataset Auditor
Usage: python scripts/audit_dataset.py <dataset.json>
Exits with code 1 if critical issues are found.
"""
import json
import sys
import random
import hashlib
from collections import Counter, defaultdict
def compute_signature(entry: dict) -> str:
    """Return an MD5 hex digest identifying an entry's semantic content.

    The digest covers the CoT trace, worker output, decision, and violation
    type, joined with a ``||`` separator; a missing field contributes "".
    Used for duplicate detection, not security.
    """
    fields = ("worker_cot_trace", "worker_output", "decision", "violation_type")
    payload = "||".join(entry.get(name, "") for name in fields)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()
def audit(path: str, *, signature_fn=None) -> int:
    """Audit an AEGIS dataset JSON file and print a report to stdout.

    Checks performed: total row count, decision-label distribution (flagging
    entirely missing classes), duplicate rows detected via content
    signatures, uniqueness of CoT traces / worker outputs, train/eval
    leakage under the canonical seed-42 80/20 shuffle, and violation-type /
    level distributions.

    Args:
        path: Path to a JSON file holding a list of entry dicts.
        signature_fn: Optional override for the per-entry content-signature
            function; defaults to compute_signature. Keyword-only, so
            existing positional callers are unaffected.

    Returns:
        1 if any critical issue is found (ESCALATE class entirely missing,
        duplicate rate above 30%, or train/eval signature overlap above
        50%), else 0.
    """
    if signature_fn is None:
        signature_fn = compute_signature

    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    total = len(data)

    def pct(part: int, whole: int) -> float:
        # Guard against ZeroDivisionError on an empty dataset; the original
        # guarded only some of its percentage computations.
        return part / whole * 100 if whole else 0.0

    print("=" * 60)
    print(f"AEGIS DATASET AUDIT: {path}")
    print("=" * 60)

    # 1. Total row count
    print(f"\n[1] TOTAL ROWS: {total}")

    # 2. Label distribution. Use .get (consistent with compute_signature) so
    # a malformed row without "decision" is tallied under "" instead of
    # raising KeyError.
    decision_counts = Counter(d.get("decision", "") for d in data)
    print("\n[2] LABEL DISTRIBUTION")
    all_labels = ["ALLOW", "BLOCK", "ESCALATE"]
    for label in all_labels:
        count = decision_counts.get(label, 0)
        print(f" {label:10s}: {count:5d} ({pct(count, total):.1f}%)")

    # 3. Flag missing classes
    missing_classes = [lbl for lbl in all_labels if decision_counts.get(lbl, 0) == 0]
    if missing_classes:
        print(f"\n *** CRITICAL: Missing label class(es): {', '.join(missing_classes)} ***")

    # 4 & 5. Signatures and duplicates
    sigs = [signature_fn(d) for d in data]
    sig_counts = Counter(sigs)
    dup_sigs = {s: c for s, c in sig_counts.items() if c > 1}
    # Each group of c identical rows contributes c-1 "extra" copies.
    dup_row_count = sum(c - 1 for c in dup_sigs.values())
    dup_pct = pct(dup_row_count, total)
    print("\n[4-5] DUPLICATE ANALYSIS")
    print(f" Duplicate rows (extra copies): {dup_row_count} ({dup_pct:.1f}%)")
    print(f" Unique signatures: {len(sig_counts)}")
    top5_groups = sorted(dup_sigs.values(), reverse=True)[:5]
    if top5_groups:
        print(f" Top-5 duplicate group sizes: {top5_groups}")
    else:
        print(" No duplicate groups found.")

    # 6. Unique cot_trace and worker_output (missing keys count as "")
    unique_cots = len(set(d.get("worker_cot_trace", "") for d in data))
    unique_outputs = len(set(d.get("worker_output", "") for d in data))
    print("\n[6] UNIQUENESS")
    print(f" Unique worker_cot_trace : {unique_cots} / {total} ({pct(unique_cots, total):.1f}%)")
    print(f" Unique worker_output : {unique_outputs} / {total} ({pct(unique_outputs, total):.1f}%)")

    # 7. Train/eval split leakage (seed=42, 80/20). A dedicated Random
    # instance reproduces the canonical split exactly (same Mersenne Twister
    # stream as seed(42)+shuffle) without clobbering global random state.
    indices = list(range(total))
    random.Random(42).shuffle(indices)
    train_end = int(total * 0.8)
    train_idx = set(indices[:train_end])
    eval_idx = set(indices[train_end:])
    train_sigs = set(sigs[i] for i in train_idx)
    eval_sigs = [sigs[i] for i in eval_idx]
    leaked = sum(1 for s in eval_sigs if s in train_sigs)
    overlap_pct = pct(leaked, len(eval_sigs))
    print("\n[7] TRAIN/EVAL SPLIT LEAKAGE (seed=42, 80/20)")
    print(f" Train rows : {len(train_idx)}")
    print(f" Eval rows : {len(eval_sigs)}")
    print(f" Eval rows whose signature appears in train: {leaked} ({overlap_pct:.1f}%)")

    # 8. Violation type distribution, most common first
    vtype_counts = Counter(d.get("violation_type", "unknown") for d in data)
    print("\n[8] VIOLATION TYPE DISTRIBUTION")
    for vt, cnt in sorted(vtype_counts.items(), key=lambda x: -x[1]):
        print(f" {vt:35s}: {cnt:5d} ({pct(cnt, total):.1f}%)")

    # 9. Level distribution. Sort strings after non-strings so a mix of int
    # levels and the "?" placeholder cannot raise TypeError.
    level_counts = Counter(d.get("level", "?") for d in data)
    print("\n[9] LEVEL DISTRIBUTION")
    for lvl, cnt in sorted(level_counts.items(), key=lambda kv: (isinstance(kv[0], str), kv[0])):
        print(f" Level {lvl}: {cnt:5d} ({pct(cnt, total):.1f}%)")

    # 10. Critical checks — any of these makes the dataset unusable as-is.
    critical_issues = []
    if "ESCALATE" in missing_classes:
        critical_issues.append("ESCALATE class is entirely missing — objective mismatch with 3-class model")
    if dup_pct > 30.0:
        critical_issues.append(f"Duplicate rate {dup_pct:.1f}% exceeds 30% threshold")
    if overlap_pct > 50.0:
        critical_issues.append(f"Train/eval overlap {overlap_pct:.1f}% exceeds 50% — severe data leakage")

    print("\n" + "=" * 60)
    if critical_issues:
        print("CRITICAL ISSUES FOUND:")
        for issue in critical_issues:
            print(f" [CRITICAL] {issue}")
        print("=" * 60)
        return 1
    print("No critical issues found.")
    print("=" * 60)
    return 0
def main():
    """CLI entry point: validate arguments, run the audit, exit with its code."""
    args = sys.argv
    if len(args) < 2:
        print("Usage: python scripts/audit_dataset.py <dataset.json>")
        sys.exit(1)
    sys.exit(audit(args[1]))
# Run the auditor only when executed as a script (not on import).
if __name__ == "__main__":
    main()