""" Cross-Task Position Bias Runner Runs multiple task types on multiple models, computes PBI taxonomy. """ import json import logging import os import random import re import time import uuid from typing import Dict, List, Any from tqdm import tqdm from src.generator import generate_text from src.taxonomy import classify_bias, position_bias_index, weighted_position_bias_index from src.utils import ensure_dir, save_json logger = logging.getLogger(__name__) FILLERS = [ "The history of pottery spans thousands of years.", "Marine biologists study coral reef ecosystems.", "The periodic table arranges elements by number.", "Neural networks are inspired by biological brains.", "Light speed is 299,792,458 meters per second.", "GPS uses triangulation from satellites.", "Cryptography secures digital communication.", ] def run_task( model_name: str, task_name: str, num_examples: int, depths: List[float], config: Dict[str, Any], ) -> Dict[str, Any]: """ Generic task runner that dispatches to specific task implementations. Returns: {"accuracies": [...], "pbi": float, "classification": str, "task": str} """ task_impl = { "kv_retrieval": _run_kv, "needle_haystack": _run_needle, "reasoning": _run_reasoning, "summarization": _run_summarization, "translation": _run_translation, } if task_name not in task_impl: raise ValueError(f"Unknown task: {task_name}") return task_impl[task_name](model_name, num_examples, depths, config) def _run_kv(model_name: str, num_examples: int, depths: List[float], config: Dict) -> Dict[str, Any]: """Key-value retrieval task.""" num_keys = config.get("num_keys", 100) accuracies = [] for depth in depths: correct_count = 0 for _ in range(num_examples): kv = {} while len(kv) != num_keys: kv[str(uuid.uuid4())[:8]] = str(uuid.uuid4())[:8] ordered = list(kv.items()) gold_k, gold_v = random.choice(ordered) pos = int(depth * (num_keys - 1)) # Reorder: move gold to pos gi = next(i for i, (k, _) in enumerate(ordered) if k == gold_k) new_ordered = ordered[:gi] + ordered[gi+1:] new_ordered = new_ordered[:pos] + [(gold_k, gold_v)] + new_ordered[pos:] formatted = "{\n" for i, (k, v) in enumerate(new_ordered): comma = "," if i < len(new_ordered) - 1 else "" formatted += f' "{k}": "{v}"{comma}\n' formatted += "}" prompt = ( f"Extract the value corresponding to the specified key in the JSON object below.\n\n" f"JSON data:\n{formatted}\n\n" f"Key: \"{gold_k}\"\nCorresponding value:" ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=20, ) if gold_v.lower() in ans.lower(): correct_count += 1 accuracies.append(correct_count / num_examples) pbi = position_bias_index(depths, accuracies) w_pbi = weighted_position_bias_index(depths, accuracies) return { "task": "kv_retrieval", "accuracies": accuracies, "depths": depths, "pbi": pbi, "weighted_pbi": w_pbi, "classification": classify_bias(depths, accuracies), "num_keys": num_keys, } def _run_needle(model_name: str, num_examples: int, depths: List[float], config: Dict) -> Dict[str, Any]: """Needle in haystack task.""" num_sentences = config.get("num_sentences", 300) accuracies = [] for depth in depths: correct_count = 0 for _ in range(num_examples): text = " ".join(random.choice(FILLERS) + f" [{i+1}]." for i in range(num_sentences)) sents = [s.strip() + "." for s in text.split(".") if s.strip()] code = f"CODE-{random.randint(1000, 9999)}" needle = f"The secret code is {code}." idx = int(depth * len(sents)) sents.insert(idx, needle) doc = " ".join(sents) prompt = ( f"Read the text and find the secret code.\n\n{doc}\n\n" f"What is the secret code? Answer with only the code." ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=15, ) if code.lower() in ans.lower(): correct_count += 1 accuracies.append(correct_count / num_examples) pbi = position_bias_index(depths, accuracies) w_pbi = weighted_position_bias_index(depths, accuracies) return { "task": "needle_haystack", "accuracies": accuracies, "depths": depths, "pbi": pbi, "weighted_pbi": w_pbi, "classification": classify_bias(depths, accuracies), "num_sentences": num_sentences, } def _run_reasoning(model_name: str, num_examples: int, depths: List[float], config: Dict) -> Dict[str, Any]: """Fact-dependent reasoning (math) task.""" num_sentences = config.get("num_sentences", 300) distractors = [ "The museum opens at 9 AM.", "Temperature is recorded hourly.", "Solar panels generate 45 kWh daily.", "The database has four million records.", ] accuracies = [] for depth in depths: correct_count = 0 for _ in range(num_examples): sents = [random.choice(distractors) + f" [Doc {i+1}]" for i in range(num_sentences)] price = random.randint(2, 15) qty = random.randint(3, 20) discount = random.randint(5, 30) answer = round(price * qty * (1 - discount / 100), 2) fact = f"For this order, apples cost ${price}/kg with a {discount}% discount." idx = int(depth * len(sents)) sents.insert(idx, fact) doc = " ".join(sents) prompt = ( f"Use ONLY the document below.\n\n{doc}\n\n" f"Question: I buy {qty} kg of apples. What is my total cost? " f"Answer with only the dollar amount." ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=15, ) nums = re.findall(r"[\d,]+\.?\d*", ans.replace(",", "")) if nums: pred = float(nums[0]) if abs(pred - answer) < 0.5: correct_count += 1 accuracies.append(correct_count / num_examples) pbi = position_bias_index(depths, accuracies) w_pbi = weighted_position_bias_index(depths, accuracies) return { "task": "reasoning", "accuracies": accuracies, "depths": depths, "pbi": pbi, "weighted_pbi": w_pbi, "classification": classify_bias(depths, accuracies), "num_sentences": num_sentences, } def _run_summarization(model_name: str, num_examples: int, depths: List[float], config: Dict) -> Dict[str, Any]: """ Summarization task: key fact at depth D, check if summary includes it. Simplified: check if answer contains key phrase. """ num_sentences = config.get("num_sentences", 300) key_phrases = [ "the golden statue", "the secret treaty", "the new bridge", "the ancient library", "the solar eclipse", ] accuracies = [] for depth in depths: correct_count = 0 for _ in range(num_examples): sents = [random.choice(FILLERS) for _ in range(num_sentences)] key_phrase = random.choice(key_phrases) idx = int(depth * len(sents)) sents.insert(idx, f"Everyone admired {key_phrase} in the town square.") doc = " ".join(sents) prompt = ( f"Summarize the following text in one sentence.\n\n{doc}\n\n" f"Summary:" ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=40, ) # Check if key phrase concept is in summary if key_phrase.split()[-1] in ans.lower() or key_phrase in ans.lower(): correct_count += 1 accuracies.append(correct_count / num_examples) pbi = position_bias_index(depths, accuracies) w_pbi = weighted_position_bias_index(depths, accuracies) return { "task": "summarization", "accuracies": accuracies, "depths": depths, "pbi": pbi, "weighted_pbi": w_pbi, "classification": classify_bias(depths, accuracies), "num_sentences": num_sentences, } def _run_translation(model_name: str, num_examples: int, depths: List[float], config: Dict) -> Dict[str, Any]: """ Translation task: key sentence at depth D that must be translated. Check if target phrase appears in translation. """ num_sentences = config.get("num_sentences", 300) # English -> French translation pairs pairs = [ ("The cat sleeps on the mat.", "chat"), ("She walked to the market.", "marché"), ("The sun rises in the east.", "soleil"), ("Books are gateways to knowledge.", "connaissance"), ("Music heals the soul.", "musique"), ] accuracies = [] for depth in depths: correct_count = 0 for _ in range(num_examples): sents = [random.choice(FILLERS) for _ in range(num_sentences)] en_sent, target_word = random.choice(pairs) idx = int(depth * len(sents)) sents.insert(idx, en_sent) doc = " ".join(sents) prompt = ( f"Translate the following text into French:\n\n{doc}\n\n" f"French translation:" ) ans = generate_text( [{"role": "user", "content": prompt}], model_name=model_name, max_new_tokens=80, ) if target_word.lower() in ans.lower(): correct_count += 1 accuracies.append(correct_count / num_examples) pbi = position_bias_index(depths, accuracies) w_pbi = weighted_position_bias_index(depths, accuracies) return { "task": "translation", "accuracies": accuracies, "depths": depths, "pbi": pbi, "weighted_pbi": w_pbi, "classification": classify_bias(depths, accuracies), "num_sentences": num_sentences, } def run_cross_task_evaluation( model_name: str, tasks: Dict[str, Dict[str, Any]], num_examples: int = 30, depths: List[float] = None, out_dir: str = "./results", ) -> Dict[str, Any]: """ Run all tasks for a single model and compute taxonomy. Args: model_name: HF model identifier tasks: Dict of {task_name: config_dict} num_examples: Examples per depth depths: Position depths to test out_dir: Output directory Returns: Full results with taxonomy classification """ ensure_dir(out_dir) if depths is None: depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0] results = {"model": model_name, "tasks": {}, "taxonomy": {}} start = time.time() for task_name, config in tasks.items(): logger.info(f"Running {task_name} on {model_name}...") task_result = run_task(model_name, task_name, num_examples, depths, config) results["tasks"][task_name] = task_result logger.info( f" {task_name}: PBI={task_result['pbi']:.3f} " f"class={task_result['classification']}" ) # Compute cross-task correlations pbi_values = {t: r["pbi"] for t, r in results["tasks"].items()} results["taxonomy"]["pbi_per_task"] = pbi_values results["taxonomy"]["mean_pbi"] = sum(pbi_values.values()) / len(pbi_values) if pbi_values else 0.0 results["taxonomy"]["std_pbi"] = ( (sum((v - results["taxonomy"]["mean_pbi"]) ** 2 for v in pbi_values.values()) / len(pbi_values)) ** 0.5 if pbi_values else 0.0 ) # Rank tasks by bias ranked = sorted(pbi_values.items(), key=lambda x: x[1]) results["taxonomy"]["least_biased_task"] = ranked[0][0] if ranked else None results["taxonomy"]["most_biased_task"] = ranked[-1][0] if ranked else None results["time_minutes"] = (time.time() - start) / 60 save_json(os.path.join(out_dir, f"taxonomy_{model_name.replace('/', '_')}.json"), results) logger.info( f"Completed {model_name}: mean_PBI={results['taxonomy']['mean_pbi']:.3f} " f"std={results['taxonomy']['std_pbi']:.3f}" ) return results