""" JSON Array Position Bias Target key-value pair at varying positions in a JSON array. """ import json as jsonlib import logging import os import random import time import uuid from typing import List, Dict, Any from tqdm import tqdm from src.generator import generate_text from src.utils import ensure_dir, save_jsonl, save_json logger = logging.getLogger(__name__) def _make_json_array(n: int, target_key: str, target_value: str, target_pos: int) -> str: """Generate JSON array with target KV pair at position.""" entries = [] for i in range(n): if i == target_pos: entries.append({"key": target_key, "value": target_value}) else: entries.append({ "key": f"key_{uuid.uuid4().hex[:8]}", "value": f"val_{uuid.uuid4().hex[:8]}", }) return jsonlib.dumps({"records": entries}, indent=2) def run_json_retrieval( model_name: str, num_entries: int, num_examples: int, out_dir: str, depths: List[float] = None, ) -> Dict[str, Any]: ensure_dir(out_dir) if depths is None: depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0] results = {} start = time.time() for depth in depths: logger.info(f"[JSON] Depth {depth:.1%}") preds = [] for _ in tqdm(range(num_examples), desc=f"JSON {depth:.1%}", leave=False): target_key = f"gold_key_{uuid.uuid4().hex[:6]}" target_value = f"gold_val_{uuid.uuid4().hex[:6]}" pos = int(depth * (num_entries - 1)) json_str = _make_json_array(num_entries, target_key, target_value, pos) prompt = ( f"Find the value for the key '{target_key}' in the JSON data below.\n\n" f"```json\n{json_str}\n```\n\n" f"Value:" ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=20, ) correct = 1.0 if target_value.lower() in ans.lower() else 0.0 preds.append({ "model_answer": ans, "correct": correct, "target_value": target_value, "depth": depth, }) save_jsonl(os.path.join(out_dir, f"json_depth_{depth}.jsonl"), preds) acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0 results[depth] = {"accuracy": acc, "predictions": preds} logger.info(f"[JSON] Depth {depth:.1%}: acc={acc:.3f}") summary = { "experiment": "json_retrieval", "num_entries": num_entries, "num_examples": num_examples, "depths": {str(d): results[d]["accuracy"] for d in depths}, "time_minutes": (time.time() - start) / 60, } save_json(os.path.join(out_dir, "json_summary.json"), summary) logger.info(f"[JSON] Time={(time.time()-start)/60:.1f} min") return summary