echo-ultimate / env /task_bank.py
Vikaspandey582003's picture
Upload folder using huggingface_hub
acb327b verified
"""
ECHO ULTIMATE — 7-domain Task Bank.
Loads from HuggingFace datasets, caches to data/, falls back to synthetic tasks.
"""
import json
import logging
import random
import re
from pathlib import Path
from typing import Optional
from config import cfg
# Module-level logger named after this module; the loaders and the cache
# helpers below report warnings and progress through it.
logger = logging.getLogger(__name__)
_NUM_RE = re.compile(r"-?\d[\d,]*(?:\.\d+)?")
def _last_num(text: str) -> Optional[str]:
nums = _NUM_RE.findall(text.replace(",", ""))
return nums[-1] if nums else None
def _task(domain, difficulty, idx, question, answer, aliases=None, source="synthetic", meta=None):
diff_score = {"easy": 0.85, "medium": 0.55, "hard": 0.25}[difficulty]
return {
"id": f"{domain}_{difficulty}_{idx:05d}",
"domain": domain,
"difficulty": difficulty,
"difficulty_score": diff_score,
"question": question.replace("\n", " ").replace("\r", " ").strip(),
"answer": str(answer),
"answer_aliases": aliases or [str(answer)],
"source_dataset": source,
"metadata": meta or {},
}
# ── Dataset loaders ───────────────────────────────────────────────────────────
def _load_math():
    """Load GSM8K training problems and bucket them by solution length.

    The answer is the last number found after the final ``####`` marker in
    the solution text (``"0"`` when no number is found). Difficulty is a
    heuristic: the count of ``.``, ``!`` and ``?`` characters in the solution
    (<= 3 easy, <= 6 medium, otherwise hard).

    At most ``cfg.TASKS_PER_BUCKET * 3`` rows are read in total; individual
    buckets are not capped separately.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    from datasets import load_dataset

    # NOTE(review): trust_remote_code=True lets the dataset repo run code on
    # load — confirm it is still required for this dataset.
    ds = load_dataset("gsm8k", "main", split="train", trust_remote_code=True)
    tasks = {"easy": [], "medium": [], "hard": []}
    limit = cfg.TASKS_PER_BUCKET * 3
    for i, row in enumerate(ds):
        # Check the cap before appending so exactly `limit` rows are used
        # (checking afterwards let one extra row through).
        if i >= limit:
            break
        sol = row["answer"]
        # _last_num already strips commas and returns a bare numeric token,
        # so no further cleanup of the answer is needed.
        ans = _last_num(sol.split("####")[-1]) or "0"
        steps = len(re.findall(r"[.!?]", sol))
        if steps <= 3:
            diff = "easy"
        elif steps <= 6:
            diff = "medium"
        else:
            diff = "hard"
        tasks[diff].append(
            _task("math", diff, i, row["question"], ans, aliases=[ans], source="gsm8k")
        )
    return tasks
def _load_logic():
    """Load ARC multiple-choice questions as the logic domain.

    ``ARC-Easy`` fills the easy bucket and ``ARC-Challenge`` the hard bucket,
    each capped at ``cfg.TASKS_PER_BUCKET`` rows. The medium bucket is built
    from the easy tasks by prefixing the question with "Think carefully: ".
    The expected answer is the upper-cased ``answerKey`` label.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    from datasets import load_dataset

    tasks = {"easy": [], "medium": [], "hard": []}
    limit = cfg.TASKS_PER_BUCKET
    for cfg_name, diff in [("ARC-Easy", "easy"), ("ARC-Challenge", "hard")]:
        # NOTE(review): trust_remote_code=True lets the dataset repo run code
        # on load — confirm it is still required for this dataset.
        ds = load_dataset("ai2_arc", cfg_name, split="train", trust_remote_code=True)
        for i, row in enumerate(ds):
            # Cap check comes first so each bucket holds at most `limit`
            # tasks, matching the size of the derived medium bucket.
            if i >= limit:
                break
            labels = row["choices"]["label"]
            texts = row["choices"]["text"]
            opts = " | ".join(f"{label}: {text}" for label, text in zip(labels, texts))
            q = f"{row['question']}\nChoices: {opts}"
            a = row["answerKey"].strip().upper()
            tasks[diff].append(_task("logic", diff, i, q, a, source=f"arc_{diff}"))
    # medium = subset of easy with extra distractor framing
    for i, t in enumerate(tasks["easy"][:limit]):
        t2 = dict(t)
        # Copy the nested containers too, so editing a medium task can never
        # alter the easy task it was derived from.
        t2["answer_aliases"] = list(t["answer_aliases"])
        t2["metadata"] = dict(t["metadata"])
        t2["id"] = f"logic_medium_{i:05d}"
        t2["difficulty"] = "medium"
        t2["difficulty_score"] = 0.55
        t2["question"] = "Think carefully: " + t2["question"]
        tasks["medium"].append(t2)
    return tasks
def _load_factual():
    """Load TriviaQA (``rc.nocontext``) questions as the factual domain.

    Difficulty is a heuristic on the length of the canonical answer string
    (<= 10 chars easy, <= 25 medium, otherwise hard). Rows with an empty
    answer are skipped. At most ``cfg.TASKS_PER_BUCKET * 3`` rows are
    scanned in total.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    from datasets import load_dataset

    # NOTE(review): trust_remote_code=True lets the dataset repo run code on
    # load — confirm it is still required for this dataset.
    ds = load_dataset("trivia_qa", "rc.nocontext", split="train", trust_remote_code=True)
    tasks = {"easy": [], "medium": [], "hard": []}
    limit = cfg.TASKS_PER_BUCKET * 3
    for i, row in enumerate(ds):
        # Check the cap first: checking after the append admitted one extra
        # row, and the `continue` below used to bypass the check entirely.
        if i >= limit:
            break
        q = row["question"]
        ad = row["answer"]
        if isinstance(ad, dict):
            ans = ad.get("value", "")
            aliases = ad.get("aliases", [ans])
        else:
            ans = str(ad)
            aliases = [ans]
        if not ans:
            continue
        if len(ans) <= 10:
            diff = "easy"
        elif len(ans) <= 25:
            diff = "medium"
        else:
            diff = "hard"
        tasks[diff].append(
            _task("factual", diff, i, q, ans,
                  aliases=[a for a in aliases if a], source="trivia_qa")
        )
    return tasks
def _load_science():
    """Load SciQ questions as shuffled multiple-choice science tasks.

    The correct answer and the non-empty distractors are shuffled with the
    global ``random`` state and labelled A-D; the expected answer is the
    label that ends up on the correct option. Rows are dealt round-robin
    into easy/medium/hard by row index. At most ``cfg.TASKS_PER_BUCKET * 3``
    rows are read.

    Any exception during loading is logged as a warning and whatever was
    collected so far (possibly nothing) is returned.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    from datasets import load_dataset

    tasks = {"easy": [], "medium": [], "hard": []}
    limit = cfg.TASKS_PER_BUCKET * 3
    try:
        # NOTE(review): trust_remote_code=True lets the dataset repo run code
        # on load — confirm it is still required for this dataset.
        ds = load_dataset("sciq", split="train", trust_remote_code=True)
        for i, row in enumerate(ds):
            # Cap check first so the round-robin fills the three buckets
            # evenly (checking after the append gave "easy" one extra task).
            if i >= limit:
                break
            q = row["question"]
            correct = row["correct_answer"]
            distractors = [row.get(f"distractor{j}", "") for j in range(1, 4)]
            all_opts = [correct] + [d for d in distractors if d]
            random.shuffle(all_opts)
            labels = ["A", "B", "C", "D"][:len(all_opts)]
            opts = " | ".join(f"{label}: {text}" for label, text in zip(labels, all_opts))
            correct_label = labels[all_opts.index(correct)]
            full_q = f"{q}\nChoices: {opts}"
            diff = ["easy", "medium", "hard"][i % 3]
            tasks[diff].append(
                _task("science", diff, i, full_q, correct_label, source="sciq")
            )
    except Exception as e:
        # Deliberate best-effort: a failed download must not abort the caller.
        logger.warning("sciq load failed: %s", e)
    return tasks
def _load_medical():
    """Load MedMCQA questions as four-option medical tasks.

    Options come from the ``opa``..``opd`` fields and the expected answer is
    the letter mapped from the ``cop`` index (0-3). Difficulty is chosen by
    the first keyword of ``topic_diff`` found in the lower-cased
    ``subject_name``; subjects with no match are "medium". At most
    ``cfg.TASKS_PER_BUCKET * 3`` rows are scanned.

    Any exception during loading is logged as a warning and whatever was
    collected so far (possibly nothing) is returned.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    from datasets import load_dataset

    tasks = {"easy": [], "medium": [], "hard": []}
    limit = cfg.TASKS_PER_BUCKET * 3
    try:
        # NOTE(review): trust_remote_code=True lets the dataset repo run code
        # on load — confirm it is still required for this dataset.
        ds = load_dataset("medmcqa", split="train", trust_remote_code=True)
        label_map = {0: "A", 1: "B", 2: "C", 3: "D"}
        topic_diff = {"anatomy": "easy", "medicine": "medium",
                      "surgery": "hard", "pharmacology": "hard"}
        for i, row in enumerate(ds):
            # Cap check first; checking after the append admitted one extra row.
            if i >= limit:
                break
            ans = label_map.get(row.get("cop"))
            if ans is None:
                # A missing or out-of-range answer index used to be labelled
                # "A", which creates a wrong answer key; skip the row instead.
                continue
            q = row.get("question", "")
            opts = " | ".join(
                "{}: {}".format(label, row.get(f"op{key}", ""))
                for label, key in zip("ABCD", "abcd")
            )
            full_q = f"{q}\nChoices: {opts}"
            topic = str(row.get("subject_name", "")).lower()
            diff = next((v for k, v in topic_diff.items() if k in topic), "medium")
            tasks[diff].append(_task("medical", diff, i, full_q, ans, source="medmcqa"))
    except Exception as e:
        # Deliberate best-effort: a failed download must not abort the caller.
        logger.warning("medmcqa load failed: %s", e)
    return tasks
def _load_coding():
    """Return the built-in coding tasks; no dataset download is involved.

    Hard tasks accept both the answer as written and its lower-cased form;
    easy and medium tasks accept only the answer as written.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    easy_items = [
        ("What does print(1 + 1) output?", "2"),
        ("What does print(type(42)) output?", "<class 'int'>"),
        ("What does print('hello'[0]) output?", "h"),
        ("What does print(len([1,2,3])) output?", "3"),
        ("What does print(2 ** 8) output?", "256"),
        ("What does print(10 % 3) output?", "1"),
        ("What does bool(0) return?", "False"),
        ("What does print(round(3.7)) output?", "4"),
    ]
    medium_items = [
        ("def f(x): return x*x\nWhat does f(5) return?", "25"),
        ("x = [1,2,3]; x.append(4); what is len(x)?", "4"),
        ("What is the output of: print(list(range(3)))?", "[0, 1, 2]"),
        ("d = {'a':1}; d['b']=2; what is len(d)?", "2"),
        ("What does 'abc'.upper() return?", "ABC"),
    ]
    hard_items = [
        ("What is the time complexity of binary search?", "O(log n)"),
        ("What is the time complexity of merge sort?", "O(n log n)"),
        ("What design pattern separates object creation from use?", "Factory"),
        ("In Python, what is a generator?", "lazy iterator"),
    ]

    return {
        "easy": [
            _task("coding", "easy", n, prompt, expected)
            for n, (prompt, expected) in enumerate(easy_items)
        ],
        "medium": [
            _task("coding", "medium", n, prompt, expected)
            for n, (prompt, expected) in enumerate(medium_items)
        ],
        "hard": [
            _task("coding", "hard", n, prompt, expected,
                  aliases=[expected, expected.lower()])
            for n, (prompt, expected) in enumerate(hard_items)
        ],
    }
def _load_creative():
    """Return the built-in creative/language tasks; no download is involved.

    Each entry is ``(question, canonical_answer, accepted_aliases)``.

    Returns:
        ``{"easy": [...], "medium": [...], "hard": [...]}`` of task dicts.
    """
    easy_q = [
        ("What rhymes with 'cat'?", "bat", ["bat", "hat", "mat", "rat", "sat", "fat", "pat"]),
        ("What rhymes with 'night'?", "light", ["light", "right", "fight", "might", "sight"]),
        ("What color do you get mixing red and blue?", "purple", ["purple", "violet"]),
        ("What is the opposite of 'hot'?", "cold", ["cold", "cool", "frigid"]),
        ("Name an animal that lives in the ocean.", "whale",
         ["whale", "shark", "dolphin", "fish", "octopus"]),
    ]
    medium_q = [
        ("What is a word meaning 'happy' that starts with J?", "joyful",
         ["joyful", "jovial", "jubilant"]),
        ("Name a synonym for 'large' starting with 'G'.", "gigantic",
         ["gigantic", "grand", "great"]),
        ("What poetic device is used in 'the wind whispered'?", "personification",
         ["personification"]),
    ]
    hard_q = [
        ("Name the literary device where a part represents the whole.", "synecdoche",
         ["synecdoche"]),
        # The nine-line form is the Spenserian *stanza*; a Spenserian sonnet
        # has fourteen lines, so the previous answer key was wrong.
        ("What is a nine-line poem with specific rhyme scheme called?", "spenserian stanza",
         ["spenserian stanza", "spenserian"]),
        # NOTE(review): the answer key for this prompt looks debatable — confirm.
        ("What rhetorical device uses 'but wait' to return to an earlier point?", "analepsis",
         ["analepsis", "flashback"]),
    ]

    tasks = {"easy": [], "medium": [], "hard": []}
    for diff, items in (("easy", easy_q), ("medium", medium_q), ("hard", hard_q)):
        for i, (q, a, al) in enumerate(items):
            tasks[diff].append(_task("creative", diff, i, q, a, aliases=al))
    return tasks
# ── Synthetic fallbacks (always available) ────────────────────────────────────
def _synthetic_all() -> dict:
    """Build the offline fallback bank, keyed by domain then difficulty.

    A fresh structure is built on every call. The logic, factual, science
    and medical domains provide a single easy task each; coding and creative
    reuse their built-in loaders. The math domain has its own small set of
    tasks so that math requests return tasks labelled ``"math"`` (it used to
    reuse the coding tasks, which carry ``domain == "coding"`` and
    ``coding_*`` ids).

    Returns:
        ``{domain: {"easy": [...], "medium": [...], "hard": [...]}}``.
    """
    math_items = {
        "easy": [
            ("What is 7 + 5?", "12"),
            ("What is 9 * 6?", "54"),
            ("What is 100 - 37?", "63"),
        ],
        "medium": [
            ("A shirt costs $20 and is discounted by 25%. "
             "What is the sale price in dollars?", "15"),
            ("What is the average of 4, 8 and 12?", "8"),
        ],
        "hard": [
            ("A train travels 180 km in 2.5 hours. "
             "What is its average speed in km/h?", "72"),
            ("What is the sum of the integers from 1 to 100?", "5050"),
        ],
    }
    math_tasks = {
        diff: [_task("math", diff, i, q, a) for i, (q, a) in enumerate(items)]
        for diff, items in math_items.items()
    }

    return {
        "math": math_tasks,
        "logic": {
            "easy": [_task("logic", "easy", 0, "All cats are mammals. Whiskers is a cat. Is Whiskers a mammal?\nChoices: A: Yes | B: No | C: Maybe | D: Cannot determine", "A")],
            "medium": [],
            "hard": [],
        },
        "factual": {
            "easy": [_task("factual", "easy", 0, "What is the capital of France?", "Paris", ["Paris"])],
            "medium": [],
            "hard": [],
        },
        "science": {
            "easy": [_task("science", "easy", 0, "What is H2O?\nChoices: A: Water | B: Salt | C: Air | D: Fire", "A")],
            "medium": [],
            "hard": [],
        },
        "medical": {
            "easy": [_task("medical", "easy", 0, "How many chambers does the human heart have?\nChoices: A: 2 | B: 3 | C: 4 | D: 6", "C")],
            "medium": [],
            "hard": [],
        },
        "coding": _load_coding(),
        "creative": _load_creative(),
    }
# ── Adversarial bank ──────────────────────────────────────────────────────────
# Fixed pool of questions sampled by TaskBank.get_adversarial_batch(). All are
# tagged "hard" with source "adversarial" and use ids in the 9001+ range.
_ADVERSARIAL = [
    _task("factual", "hard", 9001, "How many bones does an adult human body have?", "206", ["206"], "adversarial"),
    _task("factual", "hard", 9002, "What is the capital of Australia?", "Canberra", ["Canberra"], "adversarial"),
    _task("math", "hard", 9003, "A bat and ball cost $1.10. The bat costs $1 more than the ball. How much does the ball cost?", "0.05", ["0.05", "5 cents", "$0.05"], "adversarial"),
    _task("factual", "hard", 9004, "In what year did the Berlin Wall fall?", "1989", ["1989"], "adversarial"),
    # The degree sign was mis-encoded ("100Β°C"), so that alias could never
    # match a real answer; restored to "100°C".
    _task("science", "hard", 9005, "What is the boiling point of water at sea level in Celsius?", "100", ["100", "100°C"], "adversarial"),
    _task("math", "hard", 9006, "If you have 3 apples and take away 2, how many do you have?", "2", ["2"], "adversarial"),
    _task("factual", "hard", 9007, "Who wrote Hamlet?", "William Shakespeare", ["William Shakespeare", "Shakespeare"], "adversarial"),
    _task("science", "hard", 9008, "How many planets are in our solar system?", "8", ["8"], "adversarial"),
    _task("coding", "hard", 9009, "What does the following return: not not True", "True", ["True"], "adversarial"),
    _task("math", "hard", 9010, "What is 15% of 200?", "30", ["30"], "adversarial"),
]
# ── TaskBank class ────────────────────────────────────────────────────────────
class TaskBank:
    """
    Manages loading, caching, and curriculum-aware sampling of tasks
    across 7 domains and 3 difficulty levels.

    Tasks are held in ``self._tasks`` as
    ``{domain: {"easy": [...], "medium": [...], "hard": [...]}}``. Every
    public accessor calls :meth:`ensure_loaded`, so the bank fills itself
    from the cache (or from synthetic tasks) on first use.
    """

    def __init__(self, data_dir: str = cfg.DATA_DIR) -> None:
        """Create an empty bank and make sure *data_dir* exists."""
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self._tasks: dict[str, dict[str, list]] = {
            d: self._empty_buckets() for d in cfg.DOMAINS
        }
        self._loaded = False

    # ── Public API ────────────────────────────────────────────────────────────

    def download_all(self) -> None:
        """Download all datasets and cache to data/tasks_cache.json.

        A domain whose loader raises, or returns no tasks at all (the
        science and medical loaders swallow their own errors and return
        empty buckets), falls back to the synthetic tasks for that domain.
        """
        loaders = {
            "math": _load_math, "logic": _load_logic, "factual": _load_factual,
            "science": _load_science, "medical": _load_medical,
            "coding": _load_coding, "creative": _load_creative,
        }
        synth = None  # built lazily, at most once
        for domain, loader in loaders.items():
            logger.info("Loading %s…", domain)
            try:
                loaded = loader()
            except Exception as exc:
                # Boundary catch: dataset loading can fail in many ways
                # (missing package, network, schema); never abort the run.
                logger.warning("%s load failed: %s — using synthetic", domain, exc)
                loaded = None
            else:
                if not any(loaded.values()):
                    logger.warning("%s loader returned no tasks — using synthetic", domain)
                    loaded = None
            if loaded is None:
                if synth is None:
                    synth = _synthetic_all()
                loaded = synth.get(domain, self._empty_buckets())
            self._tasks[domain] = loaded
        self._loaded = True
        self._save_cache()

    def load_all(self) -> None:
        """Load from cache or fall back to synthetic."""
        if self._try_load_cache():
            return
        logger.warning("No cache — using synthetic tasks. Run download_all() for full data.")
        synth = _synthetic_all()
        for domain in cfg.DOMAINS:
            self._tasks[domain] = synth.get(domain, self._empty_buckets())
        # Also load coding and creative (always available)
        self._tasks["coding"] = _load_coding()
        self._tasks["creative"] = _load_creative()
        self._loaded = True

    def ensure_loaded(self) -> None:
        """Load the bank on first use; later calls are no-ops."""
        if not self._loaded:
            self.load_all()

    def get_task(
        self, domain: str, difficulty: str, exclude_ids: Optional[list[str]] = None
    ) -> dict:
        """Return a random task from the given domain and difficulty.

        If the requested bucket is empty, the synthetic bucket for the same
        domain/difficulty is used, and failing that the synthetic easy coding
        tasks. Tasks whose id is in *exclude_ids* are avoided when possible;
        if that would leave nothing, the exclusion is ignored.

        Returns:
            A shallow copy of the chosen task dict.
        """
        self.ensure_loaded()
        pool = self._tasks.get(domain, {}).get(difficulty, [])
        if not pool:
            synth = _synthetic_all()
            pool = synth.get(domain, {}).get(difficulty, []) or synth["coding"]["easy"]
        # A set gives O(1) membership tests; None replaces the former mutable
        # default argument.
        excluded = set(exclude_ids) if exclude_ids else set()
        available = [t for t in pool if t["id"] not in excluded]
        return dict(random.choice(available if available else pool))

    def get_batch(
        self, n: int, phase: int, mix_ratios: Optional[dict] = None
    ) -> list[dict]:
        """Return n tasks for the given curriculum phase.

        Args:
            n: Number of tasks to draw.
            phase: 1-based curriculum phase. Only used when *mix_ratios* is
                omitted; values outside 1..3 are clamped to the nearest
                valid phase with a warning (previously ``phase <= 0``
                silently wrapped around to a different phase and
                ``phase > 3`` raised ``IndexError``).
            mix_ratios: Optional ``{"easy": p, "medium": p, "hard": p}``
                sampling weights overriding the phase defaults.

        For each task a difficulty is drawn from the cumulative weights
        (falling back to "easy" if the draw exceeds their sum) and a domain
        is drawn uniformly from ``cfg.DOMAINS``.
        """
        self.ensure_loaded()
        if mix_ratios is None:
            phase_mixes = (cfg.PHASE_1_MIX, cfg.PHASE_2_MIX, cfg.PHASE_3_MIX)
            clamped = min(max(phase, 1), len(phase_mixes))
            if clamped != phase:
                logger.warning("Unknown curriculum phase %r — using phase %d", phase, clamped)
            mix_ratios = phase_mixes[clamped - 1]
        domains = cfg.DOMAINS
        batch = []
        for _ in range(n):
            r = random.random()
            cum = 0.0
            chosen_diff = "easy"
            for diff in ["easy", "medium", "hard"]:
                cum += mix_ratios.get(diff, 0.0)
                if r <= cum:
                    chosen_diff = diff
                    break
            domain = random.choice(domains)
            batch.append(self.get_task(domain, chosen_diff))
        return batch

    def get_adversarial_batch(self, n: int) -> list[dict]:
        """Return n adversarial tasks designed to trigger overconfidence.

        Tasks are drawn with replacement from the module-level adversarial
        pool; if that pool is empty a phase-3 batch is returned instead.
        """
        self.ensure_loaded()
        pool = list(_ADVERSARIAL)
        if not pool:
            return self.get_batch(n, phase=3)
        return [dict(random.choice(pool)) for _ in range(n)]

    def stats(self) -> None:
        """Print domain × difficulty × count table."""
        self.ensure_loaded()
        header = f"{'Domain':<12}" + "".join(f" {d:<8}" for d in cfg.DIFFICULTIES) + " Total"
        print(header)
        print("─" * len(header))
        for domain in cfg.DOMAINS:
            buckets = self._tasks.get(domain, {})
            # .get() keeps the table printable even if a domain or bucket is
            # missing from the loaded data.
            counts = {d: len(buckets.get(d, [])) for d in cfg.DIFFICULTIES}
            row = f"{domain:<12}" + "".join(f" {counts[d]:<8}" for d in cfg.DIFFICULTIES)
            row += f" {sum(counts.values())}"
            print(row)

    def get_task_by_id(self, task_id: str) -> Optional[dict]:
        """Return a shallow copy of the task with id *task_id*, or ``None``."""
        self.ensure_loaded()
        for domain in cfg.DOMAINS:
            buckets = self._tasks.get(domain, {})
            for diff in cfg.DIFFICULTIES:
                for t in buckets.get(diff, []):
                    if t["id"] == task_id:
                        return dict(t)
        return None

    # ── Private ───────────────────────────────────────────────────────────────

    @staticmethod
    def _empty_buckets() -> dict[str, list]:
        """Return a fresh set of empty difficulty buckets."""
        return {"easy": [], "medium": [], "hard": []}

    def _save_cache(self) -> None:
        """Write the current task bank to ``cfg.TASKS_CACHE`` as JSON."""
        cache = Path(cfg.TASKS_CACHE)
        cache.parent.mkdir(parents=True, exist_ok=True)
        with open(cache, "w", encoding="utf-8") as f:
            json.dump(self._tasks, f)
        logger.info("Saved task cache → %s", cache)

    def _try_load_cache(self) -> bool:
        """Load the bank from ``cfg.TASKS_CACHE`` if possible.

        The cached JSON must be an object mapping domain to buckets. Any
        configured domain or difficulty bucket missing from the file is
        filled with an empty list so later lookups cannot fail.

        Returns:
            ``True`` if the cache was loaded, ``False`` if it is missing,
            unreadable, or not a JSON object.
        """
        cache = Path(cfg.TASKS_CACHE)
        if not cache.exists():
            return False
        try:
            with open(cache, encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers both malformed JSON and undecodable bytes.
            logger.warning("Cache load failed: %s", exc)
            return False
        if not isinstance(loaded, dict):
            logger.warning("Cache load failed: %s", "top-level JSON value is not an object")
            return False

        tasks: dict[str, dict[str, list]] = {}
        for domain, buckets in loaded.items():
            tasks[domain] = dict(buckets) if isinstance(buckets, dict) else {}
        for domain in cfg.DOMAINS:
            buckets = tasks.setdefault(domain, {})
            for diff in ("easy", "medium", "hard"):
                if not isinstance(buckets.get(diff), list):
                    buckets[diff] = []

        self._tasks = tasks
        self._loaded = True
        logger.info("Loaded task bank from cache")
        return True