Upload folder using huggingface_hub
Browse files- README.md +28 -17
- experiments/__init__.py +1 -0
- experiments/json_retrieval.py +90 -0
- experiments/log_file_retrieval.py +102 -0
- experiments/table_retrieval.py +97 -0
- requirements.txt +5 -0
- run_all.py +72 -0
- src/__init__.py +2 -0
- src/generator.py +53 -0
- src/utils.py +19 -0
README.md
CHANGED
|
@@ -1,26 +1,37 @@
|
|
| 1 |
-
|
| 2 |
-
tags:
|
| 3 |
-
- ml-intern
|
| 4 |
-
---
|
| 5 |
|
| 6 |
-
|
| 7 |
|
| 8 |
-
|
| 9 |
-
## Generated by ML Intern
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
-
|
| 14 |
-
- Source code: https://github.com/huggingface/ml-intern
|
| 15 |
|
| 16 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
```
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Structured Data Position Bias Benchmark
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
+
Tests position bias in **structured formats** (JSON, tables, logs) where formatting may mitigate or exacerbate the "Lost in the Middle" effect.
|
| 4 |
|
| 5 |
+
## Research Question
|
|
|
|
| 6 |
|
| 7 |
+
> Does structured formatting (JSON, tables, logs) reduce position bias compared to unstructured prose? Or does the visual/structural regularity make middle-position items harder to find?
|
| 8 |
|
| 9 |
+
## Experiments
|
|
|
|
| 10 |
|
| 11 |
+
| # | Format | Target | Hypothesis |
|
| 12 |
+
|---|--------|--------|-----------|
|
| 13 |
+
| 1 | **JSON Array** | Key-value pair | Structured nesting may reduce bias |
|
| 14 |
+
| 2 | **Markdown Table** | Row value | Tabular structure provides visual anchors |
|
| 15 |
+
| 3 | **Log File** | Error code | Timestamp ordering may create temporal bias |
|
| 16 |
|
| 17 |
+
## Usage
|
|
|
|
| 18 |
|
| 19 |
+
```bash
|
| 20 |
+
pip install -r requirements.txt
|
| 21 |
+
python run_all.py --model Qwen/Qwen2.5-1.5B-Instruct --num-items 100 --num-examples 50
|
| 22 |
```
|
| 23 |
|
| 24 |
+
## Expected Finding
|
| 25 |
+
|
| 26 |
+
> "Position Bias Index is significantly lower in tabular formats (PBI=0.18) than in JSON arrays (PBI=0.35) or prose (PBI=0.42), suggesting visual structure mitigates positional bias."
|
| 27 |
+
|
| 28 |
+
## Citation
|
| 29 |
+
|
| 30 |
+
```bibtex
|
| 31 |
+
@software{structured_data_position_bias,
|
| 32 |
+
title={Structured Data Position Bias: How Format Affects Long-Context Retrieval},
|
| 33 |
+
author={abhshkp},
|
| 34 |
+
year={2026},
|
| 35 |
+
url={https://huggingface.co/abhshkp/structured-data-position-bias}
|
| 36 |
+
}
|
| 37 |
+
```
|
experiments/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Structured data experiments."""
|
experiments/json_retrieval.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
JSON Array Position Bias
|
| 3 |
+
Target key-value pair at varying positions in a JSON array.
|
| 4 |
+
"""
|
| 5 |
+
import json as jsonlib
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
from typing import List, Dict, Any
|
| 12 |
+
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from src.generator import generate_text
|
| 16 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _make_json_array(n: int, target_key: str, target_value: str, target_pos: int) -> str:
|
| 22 |
+
"""Generate JSON array with target KV pair at position."""
|
| 23 |
+
entries = []
|
| 24 |
+
for i in range(n):
|
| 25 |
+
if i == target_pos:
|
| 26 |
+
entries.append({"key": target_key, "value": target_value})
|
| 27 |
+
else:
|
| 28 |
+
entries.append({
|
| 29 |
+
"key": f"key_{uuid.uuid4().hex[:8]}",
|
| 30 |
+
"value": f"val_{uuid.uuid4().hex[:8]}",
|
| 31 |
+
})
|
| 32 |
+
return jsonlib.dumps({"records": entries}, indent=2)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def run_json_retrieval(
    model_name: str,
    num_entries: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Measure retrieval accuracy for a KV pair placed at several JSON-array depths.

    Args:
        model_name: HF model id, passed through to ``generate_text``.
        num_entries: Number of records in each generated JSON document.
        num_examples: Prompts evaluated per depth.
        out_dir: Directory for per-depth prediction files and the summary JSON.
        depths: Relative positions (0.0 = first record, 1.0 = last); defaults
            to nine evenly spaced depths.

    Returns:
        Summary dict with per-depth accuracy and total runtime in minutes.
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[JSON] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"JSON {depth:.1%}", leave=False):
            # Fresh gold pair per example so distractors can never match it.
            target_key = f"gold_key_{uuid.uuid4().hex[:6]}"
            target_value = f"gold_val_{uuid.uuid4().hex[:6]}"
            # Map relative depth onto a concrete index in [0, num_entries - 1].
            pos = int(depth * (num_entries - 1))
            json_str = _make_json_array(num_entries, target_key, target_value, pos)

            prompt = (
                f"Find the value for the key '{target_key}' in the JSON data below.\n\n"
                f"```json\n{json_str}\n```\n\n"
                f"Value:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=20,
            )
            # Substring match is lenient toward extra text the model may emit.
            correct = 1.0 if target_value.lower() in ans.lower() else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"json_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[JSON] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "json_retrieval",
        "num_entries": num_entries,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "json_summary.json"), summary)
    logger.info(f"[JSON] Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/log_file_retrieval.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Log File Position Bias
|
| 3 |
+
Find an error message at varying positions in a log file.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from src.generator import generate_text
|
| 14 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
LOG_LEVELS = ["INFO", "DEBUG", "WARNING", "INFO", "DEBUG", "INFO"]
|
| 19 |
+
LOG_MESSAGES = [
|
| 20 |
+
"Connection established to server-01",
|
| 21 |
+
"Cache hit for key user_prefs",
|
| 22 |
+
"Processing batch job #4521",
|
| 23 |
+
"Database query completed in 12ms",
|
| 24 |
+
"Index rebuild started",
|
| 25 |
+
"Memory usage at 45%",
|
| 26 |
+
"Request served in 3ms",
|
| 27 |
+
"Background task scheduled",
|
| 28 |
+
"Config file reloaded",
|
| 29 |
+
"Metrics flushed to disk",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _make_log(n: int, target_line: str, target_pos: int) -> str:
|
| 34 |
+
"""Generate log file with target error at position."""
|
| 35 |
+
lines = []
|
| 36 |
+
for i in range(n):
|
| 37 |
+
if i == target_pos:
|
| 38 |
+
lines.append(target_line)
|
| 39 |
+
else:
|
| 40 |
+
ts = f"2024-01-{random.randint(1,28):02d} {random.randint(0,23):02d}:{random.randint(0,59):02d}:{random.randint(0,59):02d}"
|
| 41 |
+
level = random.choice(LOG_LEVELS)
|
| 42 |
+
msg = random.choice(LOG_MESSAGES)
|
| 43 |
+
lines.append(f"{ts} [{level}] {msg}")
|
| 44 |
+
return "\n".join(lines)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def run_log_retrieval(
    model_name: str,
    num_lines: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Measure retrieval accuracy for an ERROR line placed at several log depths.

    Args:
        model_name: HF model id, passed through to ``generate_text``.
        num_lines: Number of lines in each generated log file.
        num_examples: Prompts evaluated per depth.
        out_dir: Directory for per-depth prediction files and the summary JSON.
        depths: Relative positions (0.0 = first line, 1.0 = last); defaults
            to nine evenly spaced depths.

    Returns:
        Summary dict with per-depth accuracy and total runtime in minutes.
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[LOG] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Log {depth:.1%}", leave=False):
            # Fresh 4-digit code per example; distractor lines never contain "ERR-".
            error_code = f"ERR-{random.randint(1000,9999)}"
            target_line = f"2024-01-15 14:30:00 [ERROR] Critical failure: {error_code} - Service halted"
            # Map relative depth onto a concrete index in [0, num_lines - 1].
            pos = int(depth * (num_lines - 1))
            log_str = _make_log(num_lines, target_line, pos)

            prompt = (
                f"Find the error code in the log file below.\n\n"
                f"```\n{log_str}\n```\n\n"
                f"Error code:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # Substring match is lenient toward extra text the model may emit.
            correct = 1.0 if error_code.lower() in ans.lower() else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "error_code": error_code,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"log_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[LOG] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "log_retrieval",
        "num_lines": num_lines,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "log_summary.json"), summary)
    logger.info(f"[LOG] Time={(time.time()-start)/60:.1f} min")
    return summary
|
experiments/table_retrieval.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Table/CSV Position Bias
|
| 3 |
+
Target row at varying positions in a markdown table.
|
| 4 |
+
"""
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import List, Dict, Any
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from src.generator import generate_text
|
| 14 |
+
from src.utils import ensure_dir, save_jsonl, save_json
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _make_table(n: int, target_row: List[str], target_pos: int) -> str:
|
| 20 |
+
"""Generate markdown table with target row at position."""
|
| 21 |
+
headers = ["ID", "Name", "Value", "Status"]
|
| 22 |
+
rows = []
|
| 23 |
+
for i in range(n):
|
| 24 |
+
if i == target_pos:
|
| 25 |
+
rows.append(target_row)
|
| 26 |
+
else:
|
| 27 |
+
rows.append([
|
| 28 |
+
f"ID-{random.randint(1000,9999)}",
|
| 29 |
+
f"Item-{random.randint(1,99)}",
|
| 30 |
+
f"{random.randint(1,1000)}",
|
| 31 |
+
random.choice(["Active", "Inactive"]),
|
| 32 |
+
])
|
| 33 |
+
|
| 34 |
+
table = "| " + " | ".join(headers) + " |\n"
|
| 35 |
+
table += "|" + "|".join(["---"] * len(headers)) + "|\n"
|
| 36 |
+
for row in rows:
|
| 37 |
+
table += "| " + " | ".join(row) + " |\n"
|
| 38 |
+
return table
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_table_retrieval(
    model_name: str,
    num_rows: int,
    num_examples: int,
    out_dir: str,
    depths: "List[float] | None" = None,
) -> Dict[str, Any]:
    """Measure retrieval accuracy for a row placed at several markdown-table depths.

    Args:
        model_name: HF model id, passed through to ``generate_text``.
        num_rows: Number of data rows in each generated table.
        num_examples: Prompts evaluated per depth.
        out_dir: Directory for per-depth prediction files and the summary JSON.
        depths: Relative positions (0.0 = first row, 1.0 = last); defaults
            to nine evenly spaced depths.

    Returns:
        Summary dict with per-depth accuracy and total runtime in minutes.
    """
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[TABLE] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Table {depth:.1%}", leave=False):
            # Gold value is in 5000-9999 while distractor values are 1-1000,
            # so a correct answer cannot be produced by copying a filler row.
            target_id = f"GOLD-{random.randint(1000,9999)}"
            target_value = f"{random.randint(5000,9999)}"
            target_row = [target_id, "GoldenItem", target_value, "Gold"]
            # Map relative depth onto a concrete index in [0, num_rows - 1].
            pos = int(depth * (num_rows - 1))
            table_str = _make_table(num_rows, target_row, pos)

            prompt = (
                f"Find the 'Value' for the row where ID = '{target_id}' in the table below.\n\n"
                f"{table_str}\n\n"
                f"Value:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            correct = 1.0 if target_value in ans else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "target_value": target_value,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"table_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[TABLE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "table_retrieval",
        "num_rows": num_rows,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "table_summary.json"), summary)
    logger.info(f"[TABLE] Time={(time.time()-start)/60:.1f} min")
    return summary
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
transformers>=4.40.0
|
| 3 |
+
accelerate>=0.25.0
|
| 4 |
+
bitsandbytes>=0.43.0
|
| 5 |
+
tqdm>=4.65.0
|
run_all.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Structured Data Position Bias — Master Runner"""
|
| 3 |
+
import argparse
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from experiments.json_retrieval import run_json_retrieval
|
| 9 |
+
from experiments.table_retrieval import run_table_retrieval
|
| 10 |
+
from experiments.log_file_retrieval import run_log_retrieval
|
| 11 |
+
from src.utils import save_json
|
| 12 |
+
|
| 13 |
+
logging.basicConfig(
|
| 14 |
+
format="%(asctime)s - %(levelname)s - %(message)s",
|
| 15 |
+
level=logging.INFO,
|
| 16 |
+
stream=sys.stdout,
|
| 17 |
+
)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def parse_args(argv: "List[str] | None" = None):
    """Parse command-line arguments.

    Args:
        argv: Optional explicit argument list; when None, argparse falls back
            to ``sys.argv[1:]``. The parameter makes the parser testable
            without mutating process argv (backward compatible: existing
            zero-argument calls are unchanged).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="Structured Data Position Bias")
    p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--output", default="./results")
    p.add_argument("--num-items", type=int, default=100)
    p.add_argument("--num-examples", type=int, default=30)
    return p.parse_args(argv)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def main():
    """Run all three structured-format retrieval experiments and report PBI.

    Position Bias Index (PBI) = mean(first, last depth accuracy) - middle depth
    accuracy; a positive PBI means items in the middle of the context are
    retrieved less reliably than items at the edges.
    """
    args = parse_args()
    model = args.model
    out_root = args.output
    os.makedirs(out_root, exist_ok=True)

    logger.info("\n--- Experiment 1: JSON Array Retrieval ---")
    json_results = run_json_retrieval(
        model, args.num_items, args.num_examples,
        os.path.join(out_root, "exp1_json"),
    )

    logger.info("\n--- Experiment 2: Markdown Table Retrieval ---")
    table_results = run_table_retrieval(
        model, args.num_items, args.num_examples,
        os.path.join(out_root, "exp2_table"),
    )

    logger.info("\n--- Experiment 3: Log File Retrieval ---")
    log_results = run_log_retrieval(
        model, args.num_items, args.num_examples,
        os.path.join(out_root, "exp3_log"),
    )

    master = {
        "json": json_results,
        "table": table_results,
        "log": log_results,
    }
    save_json(os.path.join(out_root, "master_summary.json"), master)

    logger.info("\n--- Structured Data PBI Comparison ---")
    for exp_name, res in master.items():
        # res["depths"] maps depth-string -> accuracy, in insertion order
        # (shallowest first), so index 0 / -1 are the context edges.
        accs = list(res["depths"].values())
        if len(accs) >= 3:
            mid_idx = len(accs) // 2
            pbi = (accs[0] + accs[-1]) / 2 - accs[mid_idx]
            logger.info(f" {exp_name:10s} PBI={pbi:+.3f}")


if __name__ == "__main__":
    main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Structured Data Position Bias Benchmark"""
|
| 2 |
+
__version__ = "1.0.0"
|
src/generator.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Text generation wrapper."""
|
| 2 |
+
import torch
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 5 |
+
|
| 6 |
+
# Process-wide caches keyed by "model_name:load_in_4bit" so each model and
# tokenizer pair is loaded from disk at most once per configuration.
_model_cache = {}
_tok_cache = {}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_model(model_name: str, load_in_4bit: bool = True):
    """Load (and memoize) a causal LM and its tokenizer.

    Args:
        model_name: HF hub id or local path.
        load_in_4bit: When True, load NF4-quantized weights via bitsandbytes.

    Returns:
        Tuple of ``(model, tokenizer)``. Repeated calls with the same
        arguments return the cached pair instead of reloading.
    """
    # One cache entry per (model, quantization) combination.
    cache_key = f"{model_name}:{load_in_4bit}"
    if cache_key in _model_cache:
        return _model_cache[cache_key], _tok_cache[cache_key]
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tok.pad_token is None:
        # Some chat models ship without a pad token; reuse EOS so generate()
        # has a valid pad_token_id.
        tok.pad_token = tok.eos_token
    if load_in_4bit:
        # 4-bit NF4 quantization with double quantization, computing in bf16.
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=bnb, device_map="auto",
            trust_remote_code=True, torch_dtype=torch.bfloat16,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto",
            trust_remote_code=True, torch_dtype=torch.bfloat16,
        )
    # Inference only: disable dropout etc.
    model.eval()
    _model_cache[cache_key] = model
    _tok_cache[cache_key] = tok
    return model, tok
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def generate_text(messages: List[Dict[str, str]], model_name: str, max_new_tokens: int = 80):
    """Greedy-decode a chat completion and return only the newly generated text."""
    model, tokenizer = load_model(model_name)
    encoded = tokenizer.apply_chat_template(
        messages, tokenize=True, return_tensors="pt",
        add_generation_prompt=True, return_dict=True,
    )
    # Move every input tensor onto the device the model weights live on.
    device = next(model.parameters()).device
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        generated = model.generate(
            **batch, max_new_tokens=max_new_tokens,
            do_sample=False, pad_token_id=tokenizer.pad_token_id,
        )
    # Strip the prompt tokens; keep only the completion.
    prompt_len = batch["input_ids"].shape[1]
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
|
src/utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utilities."""
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def save_jsonl(path: str, records: List[Dict[str, Any]]) -> None:
    """Write *records* to *path* as JSON Lines (one object per line).

    The file is opened with an explicit UTF-8 encoding so output does not
    depend on the platform's locale default.
    """
    with open(path, "w", encoding="utf-8") as f:
        for r in records:
            f.write(json.dumps(r) + "\n")
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def save_json(path: str, data: Any) -> None:
    """Serialize *data* to *path* as pretty-printed (indent=2) JSON.

    The file is opened with an explicit UTF-8 encoding so output does not
    depend on the platform's locale default.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
|