# Provenance (Hugging Face Hub page residue): uploaded by abhshkp via
# "Upload folder using huggingface_hub", commit 959dfe5 (verified).
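"""
Table/CSV position-bias experiment.

Plants a uniquely identified target row at varying relative positions
(depths) in a markdown table and measures how often the model retrieves
that row's value.
"""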
import logging
import os
import random
import re
import time
from typing import List, Dict, Any, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
def _make_table(
    n: int,
    target_row: List[str],
    target_pos: int,
    rng: Optional[random.Random] = None,
) -> str:
    """Build a markdown table of ``n`` rows with ``target_row`` at ``target_pos``.

    Every row other than the target is a random distractor of the form
    ``ID-dddd | Item-d | value | Active/Inactive``.

    Args:
        n: Number of data rows (excluding the header and separator lines).
        target_row: Cells of the row to plant; must have exactly one cell per
            header column (ID, Name, Value, Status).
        target_pos: 0-based index of the planted row; must satisfy
            ``0 <= target_pos < n``.
        rng: Randomness source for distractor rows. Defaults to the
            module-level ``random`` functions (global state). Pass a seeded
            ``random.Random`` for reproducible tables.

    Returns:
        The table text: a header line, a ``|---|...`` separator line and one
        line per row, each line terminated by a newline.

    Raises:
        ValueError: If ``target_pos`` is outside ``[0, n)`` (the target row
            would otherwise be silently omitted) or ``target_row`` does not
            have one cell per column.
    """
    headers = ["ID", "Name", "Value", "Status"]
    if not 0 <= target_pos < n:
        raise ValueError(f"target_pos must be in [0, {n}), got {target_pos}")
    if len(target_row) != len(headers):
        raise ValueError(
            f"target_row must have {len(headers)} cells, got {len(target_row)}"
        )
    if rng is None:
        # The random module exposes the same randint/choice API as Random.
        rng = random

    def _distractor() -> List[str]:
        # Draw order (3x randint, then choice) is kept stable so a given
        # random state always yields the same table.
        return [
            f"ID-{rng.randint(1000, 9999)}",
            f"Item-{rng.randint(1, 99)}",
            f"{rng.randint(1, 1000)}",
            rng.choice(["Active", "Inactive"]),
        ]

    rows = [target_row if i == target_pos else _distractor() for i in range(n)]

    def _line(cells: List[str]) -> str:
        return "| " + " | ".join(cells) + " |"

    lines = [_line(headers), "|" + "|".join(["---"] * len(headers)) + "|"]
    lines.extend(_line(row) for row in rows)
    return "\n".join(lines) + "\n"
def run_table_retrieval(
    model_name: str,
    num_rows: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """Measure how row position affects value retrieval from a markdown table.

    For each relative depth, ``num_examples`` fresh tables of ``num_rows``
    rows are generated with a "GOLD-" row planted at row index
    ``int(depth * (num_rows - 1))``. The model is asked for that row's
    ``Value``. An example scores 1.0 when the expected number appears in the
    answer as a standalone number, i.e. not embedded in a longer digit run.

    Per-depth predictions are saved to ``table_depth_{depth}.jsonl`` and the
    summary to ``table_summary.json`` inside ``out_dir``.

    Args:
        model_name: Passed through to ``generate_text``.
        num_rows: Rows per table; must be >= 1.
        num_examples: Number of tables (and model calls) per depth.
        out_dir: Output directory, passed to ``ensure_dir`` before writing.
        depths: Relative positions in ``[0, 1]``; 0.0 is the first row and
            1.0 the last. Defaults to 0.0 through 1.0 in steps of 0.125.

    Returns:
        Summary dict with keys ``experiment``, ``num_rows``,
        ``num_examples``, ``depths`` (``str(depth)`` -> accuracy) and
        ``time_minutes``.

    Raises:
        ValueError: If ``num_rows < 1`` or any depth lies outside ``[0, 1]``.
            Either would place the gold row outside the table, so accuracy
            would silently be 0.
    """
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]
    # Validate before any filesystem side effects.
    if num_rows < 1:
        raise ValueError(f"num_rows must be >= 1, got {num_rows}")
    out_of_range = [d for d in depths if not 0.0 <= d <= 1.0]
    if out_of_range:
        raise ValueError(f"depths must lie in [0, 1], got {out_of_range}")

    ensure_dir(out_dir)
    results = {}
    start = time.time()
    for depth in depths:
        logger.info(f"[TABLE] Depth {depth:.1%}")
        # Relative depth -> 0-based row index (0 = first, num_rows - 1 = last).
        pos = int(depth * (num_rows - 1))
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Table {depth:.1%}", leave=False):
            target_id = f"GOLD-{random.randint(1000,9999)}"
            target_value = f"{random.randint(5000,9999)}"
            target_row = [target_id, "GoldenItem", target_value, "Gold"]
            table_str = _make_table(num_rows, target_row, pos)
            prompt = (
                f"Find the 'Value' for the row where ID = '{target_id}' in the table below.\n\n"
                f"{table_str}\n\n"
                f"Value:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # Require the value as a standalone number: a plain substring test
            # would also credit answers such as "56781" for target "5678".
            value_pattern = rf"(?<!\d){re.escape(target_value)}(?!\d)"
            correct = 1.0 if re.search(value_pattern, ans) else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
                "row_index": pos,
            })
        save_jsonl(os.path.join(out_dir, f"table_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[TABLE] Depth {depth:.1%}: acc={acc:.3f}")

    # Compute elapsed time once so the summary and the log line agree.
    elapsed_min = (time.time() - start) / 60
    summary = {
        "experiment": "table_retrieval",
        "num_rows": num_rows,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": elapsed_min,
    }
    save_json(os.path.join(out_dir, "table_summary.json"), summary)
    logger.info(f"[TABLE] Time={elapsed_min:.1f} min")
    return summary