File size: 3,183 Bytes
959dfe5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Table/CSV Position Bias
Target row at varying positions in a markdown table.
"""
import logging
import os
import random
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)


def _make_table(n: int, target_row: List[str], target_pos: int) -> str:
    """Build a markdown table of ``n`` rows with ``target_row`` at ``target_pos``.

    The table uses the fixed header ``ID | Name | Value | Status``. Every data
    row other than the target is a decoy whose cells are drawn from the
    module-level ``random`` generator.

    Args:
        n: Total number of data rows, the target row included.
        target_row: Cells of the row to plant; they are emitted verbatim.
        target_pos: Zero-based index of the target among the data rows.

    Returns:
        The table as one string: header line, separator line, then ``n`` data
        lines, each terminated by a newline.

    Raises:
        ValueError: If ``target_pos`` is not in ``range(n)``. Without this
            check the target row would silently be left out of the table.
    """
    if not 0 <= target_pos < n:
        raise ValueError(
            f"target_pos must be in range(0, {n}), got {target_pos}"
        )

    headers = ["ID", "Name", "Value", "Status"]
    lines = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(["---"] * len(headers)) + "|",
    ]
    for i in range(n):
        if i == target_pos:
            row = target_row
        else:
            # Cells are generated in ID, Name, Value, Status order so a seeded
            # run consumes the random stream in a fixed sequence.
            row = [
                f"ID-{random.randint(1000, 9999)}",
                f"Item-{random.randint(1, 99)}",
                f"{random.randint(1, 1000)}",
                random.choice(["Active", "Inactive"]),
            ]
        lines.append("| " + " | ".join(row) + " |")

    # Collect lines and join once instead of growing a string with += per row.
    return "\n".join(lines) + "\n"


def run_table_retrieval(
    model_name: str,
    num_rows: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """Measure how retrieval accuracy varies with the target row's position.

    For every depth, ``num_examples`` tables of ``num_rows`` rows are built
    with a "GOLD-..." row planted at index ``int(depth * (num_rows - 1))``.
    The model is asked for that row's ``Value`` through ``generate_text``, and
    an answer counts as correct when the target value occurs anywhere in the
    returned text.

    Per-depth predictions are written to ``table_depth_{depth}.jsonl`` and the
    overall summary to ``table_summary.json``, both inside ``out_dir``.

    Args:
        model_name: Passed through to ``generate_text`` as ``model_name``.
        num_rows: Number of data rows per table; must be at least 1.
        num_examples: Number of tables generated per depth.
        out_dir: Directory that receives the output files.
        depths: Relative target positions, each in ``[0.0, 1.0]``. Defaults to
            nine evenly spaced depths from 0.0 to 1.0.

    Returns:
        The summary dict with keys ``experiment``, ``num_rows``,
        ``num_examples``, ``depths`` (``str(depth)`` -> accuracy) and
        ``time_minutes``.

    Raises:
        ValueError: If ``num_rows`` is below 1 or any depth lies outside
            ``[0.0, 1.0]``. Either would put the target index outside the
            table, so the scored answer could never be present.
    """
    # Validate before any model call so a bad configuration fails immediately
    # instead of producing a full run of meaningless accuracies.
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}")
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    out_of_range = [d for d in depths if not 0.0 <= d <= 1.0]
    if out_of_range:
        raise ValueError(f"depths must lie in [0.0, 1.0], got {out_of_range}")

    ensure_dir(out_dir)

    accuracy_by_depth: Dict[float, float] = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[TABLE] Depth {depth:.1%}")
        # The target index depends only on the depth, so compute it once per
        # depth rather than once per example.
        pos = int(depth * (num_rows - 1))
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Table {depth:.1%}", leave=False):
            target_id = f"GOLD-{random.randint(1000,9999)}"
            target_value = f"{random.randint(5000,9999)}"
            target_row = [target_id, "GoldenItem", target_value, "Gold"]
            table_str = _make_table(num_rows, target_row, pos)

            prompt = (
                f"Find the 'Value' for the row where ID = '{target_id}' in the table below.\n\n"
                f"{table_str}\n\n"
                f"Value:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # Substring match: the answer may carry extra text around the value.
            correct = 1.0 if target_value in ans else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"table_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        # Predictions are already on disk; keep only the accuracy in memory.
        accuracy_by_depth[depth] = acc
        logger.info(f"[TABLE] Depth {depth:.1%}: acc={acc:.3f}")

    # Take the elapsed time once so the summary and the log line agree.
    elapsed_minutes = (time.time() - start) / 60
    summary = {
        "experiment": "table_retrieval",
        "num_rows": num_rows,
        "num_examples": num_examples,
        "depths": {str(d): accuracy_by_depth[d] for d in depths},
        "time_minutes": elapsed_minutes,
    }
    save_json(os.path.join(out_dir, "table_summary.json"), summary)
    logger.info(f"[TABLE] Time={elapsed_minutes:.1f} min")
    return summary