File size: 2,986 Bytes
959dfe5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
JSON Array Position Bias
Target key-value pair at varying positions in a JSON array.
"""
import json as jsonlib
import logging
import os
import random
import time
import uuid
from typing import List, Dict, Any

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)


def _make_json_array(n: int, target_key: str, target_value: str, target_pos: int) -> str:
    """Serialize ``n`` key/value records with the target pair at ``target_pos``.

    Every slot other than ``target_pos`` is filled with a random distractor
    record (``key_<8 hex>`` / ``val_<8 hex>``). If ``target_pos`` is not in
    ``range(n)`` the target is simply not present in the output.

    Returns:
        A JSON string (2-space indent) of the form ``{"records": [...]}``.
    """

    def _distractor() -> dict:
        # Fresh random identifiers per record so distractors never repeat.
        return {
            "key": f"key_{uuid.uuid4().hex[:8]}",
            "value": f"val_{uuid.uuid4().hex[:8]}",
        }

    records = [
        {"key": target_key, "value": target_value} if idx == target_pos else _distractor()
        for idx in range(n)
    ]
    return jsonlib.dumps({"records": records}, indent=2)


def run_json_retrieval(
    model_name: str,
    num_entries: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    """Run the JSON key-lookup experiment across a set of relative depths.

    For each depth, ``num_examples`` prompts are built. Each prompt embeds a
    JSON document of ``num_entries`` records with one "gold" key/value pair
    placed at index ``int(depth * (num_entries - 1))``. The model is asked for
    the gold value, and an answer counts as correct when the gold value appears
    in it (case-insensitive substring match).

    Args:
        model_name: Forwarded to ``generate_text`` as ``model_name``.
        num_entries: Number of records per JSON document; must be >= 1.
        num_examples: Number of prompts generated per depth.
        out_dir: Directory receiving ``json_depth_<depth>.jsonl`` (one per
            depth) and ``json_summary.json``.
        depths: Relative positions in ``[0.0, 1.0]`` (0 = first record,
            1 = last record). ``None`` selects nine evenly spaced depths.

    Returns:
        The summary dict that is also written to ``json_summary.json``:
        experiment name, ``num_entries``, ``num_examples``, accuracy per depth
        (keyed by ``str(depth)``) and elapsed time in minutes.

    Raises:
        ValueError: If ``num_entries`` is below 1 or any depth lies outside
            ``[0.0, 1.0]``. In both cases the gold pair would never be placed
            in the document and the reported accuracy would be meaningless.
    """
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    else:
        # Iterated twice below (main loop + summary); materialize so that a
        # single-use iterable does not yield an empty summary.
        depths = list(depths)

    # Validate up front, before any directory creation or model calls.
    if num_entries < 1:
        raise ValueError(f"num_entries must be >= 1, got {num_entries}")
    invalid_depths = [d for d in depths if not 0.0 <= d <= 1.0]
    if invalid_depths:
        raise ValueError(f"depths must lie in [0.0, 1.0], got {invalid_depths}")

    ensure_dir(out_dir)

    accuracy_by_depth = {}
    # Monotonic clock: elapsed time is unaffected by system clock adjustments.
    start = time.monotonic()

    for depth in depths:
        logger.info(f"[JSON] Depth {depth:.1%}")
        # The target index depends only on the depth, so compute it once.
        pos = int(depth * (num_entries - 1))
        preds = []
        for _ in tqdm(range(num_examples), desc=f"JSON {depth:.1%}", leave=False):
            # Fresh random gold pair per example so answers cannot be memorized.
            target_key = f"gold_key_{uuid.uuid4().hex[:6]}"
            target_value = f"gold_val_{uuid.uuid4().hex[:6]}"
            json_str = _make_json_array(num_entries, target_key, target_value, pos)

            prompt = (
                f"Find the value for the key '{target_key}' in the JSON data below.\n\n"
                f"```json\n{json_str}\n```\n\n"
                f"Value:"
            )
            # NOTE(review): assumes generate_text returns a str -- confirm.
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            correct = 1.0 if target_value.lower() in ans.lower() else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"json_depth_{depth}.jsonl"), preds)
        # Guard against num_examples == 0 (no predictions to average).
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        # Predictions are already persisted; keep only the accuracy in memory.
        accuracy_by_depth[depth] = acc
        logger.info(f"[JSON] Depth {depth:.1%}: acc={acc:.3f}")

    # Compute once so the summary file and the log line report the same value.
    elapsed_minutes = (time.monotonic() - start) / 60
    summary = {
        "experiment": "json_retrieval",
        "num_entries": num_entries,
        "num_examples": num_examples,
        "depths": {str(d): accuracy_by_depth[d] for d in depths},
        "time_minutes": elapsed_minutes,
    }
    save_json(os.path.join(out_dir, "json_summary.json"), summary)
    logger.info(f"[JSON] Time={elapsed_minutes:.1f} min")
    return summary