abhshkp's picture
Upload folder using huggingface_hub
959dfe5 verified
"""
JSON Array Position Bias

Place a target key-value pair at varying positions (depths) in a JSON array
and measure how well a model retrieves its value.
"""
import json as jsonlib
import logging
import os
import random
import time
import uuid
from typing import List, Dict, Any
from tqdm import tqdm
from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json
# Module logger: run_json_retrieval reports per-depth progress, accuracy and
# total runtime through it.
logger = logging.getLogger(__name__)
def _make_json_array(
    n: int,
    target_key: str,
    target_value: str,
    target_pos: int,
    rng: random.Random = None,
) -> str:
    """Build a JSON document whose ``records`` array hides one target entry.

    The result is ``{"records": [...]}`` serialized with ``indent=2``; every
    record is a ``{"key": ..., "value": ...}`` object. Record ``target_pos``
    holds the target pair and every other record is a random filler pair of
    the form ``key_<8 hex chars>`` / ``val_<8 hex chars>``.

    Args:
        n: Total number of records in the array.
        target_key: Key of the record the model must find.
        target_value: Value stored alongside ``target_key``.
        target_pos: Zero-based index of the target record; must satisfy
            ``0 <= target_pos < n``.
        rng: Optional random source for the filler tokens. When given, the
            output is reproducible for a fixed seed; when omitted, filler
            tokens come from ``uuid.uuid4`` as before.

    Returns:
        The pretty-printed JSON string.

    Raises:
        ValueError: If ``target_pos`` is outside ``[0, n)``. Without this
            check the target would be silently omitted, producing a prompt
            whose answer is not in the data.
    """
    if not 0 <= target_pos < n:
        raise ValueError(
            f"target_pos={target_pos} is outside the array of {n} records"
        )

    def _hex8() -> str:
        # 8 lowercase hex characters, same shape as uuid4().hex[:8].
        if rng is None:
            return uuid.uuid4().hex[:8]
        return f"{rng.getrandbits(32):08x}"

    entries = []
    for i in range(n):
        if i == target_pos:
            entries.append({"key": target_key, "value": target_value})
        else:
            # Key and value are drawn independently, key first.
            entries.append({"key": f"key_{_hex8()}", "value": f"val_{_hex8()}"})
    return jsonlib.dumps({"records": entries}, indent=2)
def run_json_retrieval(
    model_name: str,
    num_entries: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Measure key-value retrieval accuracy as a function of target depth.

    For each depth, ``num_examples`` prompts are built. Each prompt embeds a
    JSON array of ``num_entries`` records, with a freshly generated target
    key/value pair at index ``int(depth * (num_entries - 1))``. The model is
    asked for the target's value, and an answer counts as correct when it
    contains the target value (case-insensitive substring match).

    Per-depth predictions are written to ``<out_dir>/json_depth_<depth>.jsonl``
    and the summary to ``<out_dir>/json_summary.json``.

    Args:
        model_name: Passed through to ``generate_text``.
        num_entries: Number of records per JSON array; must be >= 1.
        num_examples: Number of prompts evaluated per depth.
        out_dir: Output directory, passed to ``ensure_dir`` before any file
            is written.
        depths: Relative target positions in ``[0, 1]``. Defaults to nine
            evenly spaced values from 0.0 to 1.0.

    Returns:
        Summary dict with the experiment name, sizes, per-depth accuracy
        (keyed by ``str(depth)``) and elapsed wall-clock minutes.

    Raises:
        ValueError: If ``num_entries < 1`` or any depth lies outside
            ``[0, 1]``. Such inputs would place the target outside the
            array and silently score every example as wrong.
    """
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    # Validate before touching the filesystem or the model.
    if num_entries < 1:
        raise ValueError(f"num_entries must be >= 1, got {num_entries}")
    bad_depths = [d for d in depths if not 0.0 <= d <= 1.0]
    if bad_depths:
        raise ValueError(f"depths must lie in [0, 1], got {bad_depths}")

    ensure_dir(out_dir)
    results = {}
    start = time.time()
    for depth in depths:
        logger.info(f"[JSON] Depth {depth:.1%}")
        # Constant for a given depth. int() truncates, so non-integral
        # positions round toward the front of the array.
        pos = int(depth * (num_entries - 1))
        preds = []
        for _ in tqdm(range(num_examples), desc=f"JSON {depth:.1%}", leave=False):
            target_key = f"gold_key_{uuid.uuid4().hex[:6]}"
            target_value = f"gold_val_{uuid.uuid4().hex[:6]}"
            json_str = _make_json_array(num_entries, target_key, target_value, pos)
            prompt = (
                f"Find the value for the key '{target_key}' in the JSON data below.\n\n"
                f"```json\n{json_str}\n```\n\n"
                f"Value:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = 1.0 if target_value.lower() in ans.lower() else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
            })
        save_jsonl(os.path.join(out_dir, f"json_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[JSON] Depth {depth:.1%}: acc={acc:.3f}")

    # Measure once so the saved summary and the log line report the same value.
    elapsed_min = (time.time() - start) / 60
    summary = {
        "experiment": "json_retrieval",
        "num_entries": num_entries,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": elapsed_min,
    }
    save_json(os.path.join(out_dir, "json_summary.json"), summary)
    logger.info(f"[JSON] Time={elapsed_min:.1f} min")
    return summary