"""
Table/CSV Position Bias

Target row at varying positions in a markdown table.
"""
import logging
import os
import random
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_json, save_jsonl

logger = logging.getLogger(__name__)


def _make_table(
    n: int,
    target_row: List[str],
    target_pos: int,
    rng: Optional[random.Random] = None,
) -> str:
    """Build a markdown table of ``n`` rows with ``target_row`` at ``target_pos``.

    Every other row is random filler: an ``ID-NNNN`` id, an ``Item-N`` name,
    a value in 1..1000 and an Active/Inactive status.

    Args:
        n: Number of data rows (header and separator not counted).
        target_row: The four cells (ID, Name, Value, Status) to place at
            ``target_pos``.
        target_pos: Zero-based data-row index for ``target_row``; must be in
            ``[0, n)``.
        rng: Random source for the filler rows. When ``None`` the module-level
            ``random`` functions are used, so ``random.seed`` still applies.

    Returns:
        The table as markdown text, each line terminated by ``\\n``.

    Raises:
        ValueError: If ``target_pos`` is outside ``[0, n)``; otherwise the
            target row would be missing and the lookup question unanswerable.
    """
    if not 0 <= target_pos < n:
        raise ValueError(
            f"target_pos={target_pos} is outside a table of {n} rows"
        )
    r = random if rng is None else rng

    headers = ["ID", "Name", "Value", "Status"]
    rows = []
    for i in range(n):
        if i == target_pos:
            rows.append(target_row)
        else:
            rows.append([
                f"ID-{r.randint(1000, 9999)}",
                f"Item-{r.randint(1, 99)}",
                f"{r.randint(1, 1000)}",
                r.choice(["Active", "Inactive"]),
            ])

    lines = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(["---"] * len(headers)) + "|",
    ]
    lines.extend("| " + " | ".join(row) + " |" for row in rows)
    return "\n".join(lines) + "\n"


def run_table_retrieval(
    model_name: str,
    num_rows: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """Measure how table-lookup accuracy depends on the target row's position.

    For each depth, ``num_examples`` fresh tables of ``num_rows`` rows are
    generated with one "gold" row at index ``int(depth * (num_rows - 1))``.
    The model is asked for that row's Value. An answer counts as correct when
    the gold value occurs anywhere in the model output (substring match).

    Args:
        model_name: Passed through to ``generate_text``.
        num_rows: Data rows per table; must be at least 1.
        num_examples: Number of prompts per depth.
        out_dir: Directory receiving ``table_depth_{depth}.jsonl`` per depth
            and ``table_summary.json``.
        depths: Relative positions in ``[0, 1]`` (0 = first row, 1 = last).
            Defaults to nine evenly spaced values from 0.0 to 1.0.
        seed: If given, gold ids/values and filler rows are drawn from a
            dedicated ``random.Random(seed)`` for reproducible runs. If
            ``None``, the module-level ``random`` functions are used.

    Returns:
        Summary dict with the experiment name, ``num_rows``, ``num_examples``,
        per-depth accuracy keyed by ``str(depth)``, and elapsed minutes. The
        same dict is written to ``table_summary.json``.

    Raises:
        ValueError: If ``num_rows`` < 1 or any depth lies outside ``[0, 1]``.
    """
    # Validate before touching the filesystem. An out-of-range position would
    # otherwise silently omit the gold row and score every example as wrong.
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}")
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    bad_depths = [d for d in depths if not 0.0 <= d <= 1.0]
    if bad_depths:
        raise ValueError(f"depths must lie in [0, 1], got {bad_depths}")

    ensure_dir(out_dir)

    rng: Optional[random.Random] = None if seed is None else random.Random(seed)
    draw = random if rng is None else rng

    results: Dict[float, Dict[str, Any]] = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[TABLE] Depth {depth:.1%}")
        # Same row index for every example at this depth. int() truncates, so
        # fractional positions land on the earlier row.
        pos = int(depth * (num_rows - 1))
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Table {depth:.1%}", leave=False):
            # Gold ids use a GOLD- prefix and gold values come from 5000..9999.
            # Filler rows use ID- and 1..1000, so neither can collide with them.
            target_id = f"GOLD-{draw.randint(1000, 9999)}"
            target_value = f"{draw.randint(5000, 9999)}"
            target_row = [target_id, "GoldenItem", target_value, "Gold"]
            table_str = _make_table(num_rows, target_row, pos, rng=rng)

            prompt = (
                f"Find the 'Value' for the row where ID = '{target_id}' in the table below.\n\n"
                f"{table_str}\n\n"
                f"Value:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # Lenient substring match: any output containing the gold value is
            # counted as correct.
            correct = 1.0 if target_value in ans else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"table_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[TABLE] Depth {depth:.1%}: acc={acc:.3f}")

    # Measure once so the saved summary and the log line report the same time.
    elapsed_min = (time.time() - start) / 60
    summary = {
        "experiment": "table_retrieval",
        "num_rows": num_rows,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": elapsed_min,
    }
    save_json(os.path.join(out_dir, "table_summary.json"), summary)
    logger.info(f"[TABLE] Time={elapsed_min:.1f} min")
    return summary