| """Task sampler: loads the seed corpus and samples Tasks by difficulty. |
| |
| Difficulty is auto-derived from script line count. Category is auto-detected |
| from script content (text_classification, ner, seq2seq, etc.). |
| """ |
| from __future__ import annotations |
|
|
| import random |
| from pathlib import Path |
| from typing import Optional |
|
|
| from forgeenv.tasks.models import Task |
|
|
|
|
| def _detect_category(content: str) -> str: |
| cl = content.lower() |
| if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl: |
| return "text_classification" |
| if "tokenclassification" in cl or "ner" in cl or "conll" in cl: |
| return "ner" |
| if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl: |
| return "seq2seq" |
| if "causallm" in cl or "gpt2" in cl or "wikitext" in cl: |
| return "text_generation" |
| if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl: |
| return "image_classification" |
| if "questionanswering" in cl or "squad" in cl: |
| return "qa" |
| if "logisticregression" in cl or "make_classification" in cl: |
| return "tabular" |
| if "regression" in cl: |
| return "regression" |
| return "general" |
|
|
|
|
| def _derive_difficulty(content: str) -> str: |
| lines = len(content.splitlines()) |
| if lines < 30: |
| return "easy" |
| if lines < 60: |
| return "medium" |
| return "hard" |
|
|
|
|
class TaskSampler:
    """Builds Task objects from a directory of seed scripts and samples them."""

    def __init__(self, seed_dir: Optional[str] = None) -> None:
        """Load every seed script found in *seed_dir*.

        Args:
            seed_dir: Directory holding the ``*.py`` seed scripts. When
                omitted, the ``seed_corpus`` folder next to this module
                is used.
        """
        self.tasks: list[Task] = []
        corpus_dir = (
            seed_dir
            if seed_dir is not None
            else str(Path(__file__).parent / "seed_corpus")
        )
        self._load_corpus(corpus_dir)

    def _load_corpus(self, seed_dir: str) -> None:
        """Append one Task per ``*.py`` file in *seed_dir* to ``self.tasks``.

        Files are visited in sorted order and files whose name starts with
        ``__`` are skipped. A missing directory leaves ``self.tasks``
        untouched.
        """
        root = Path(seed_dir)
        if not root.exists():
            return

        scripts = (
            path
            for path in sorted(root.glob("*.py"))
            if not path.name.startswith("__")
        )
        for script in scripts:
            source = script.read_text(encoding="utf-8")
            name = script.stem

            # Use the text of a leading triple-quoted string, if it is
            # closed, as the description.
            summary = ""
            if source.startswith('"""'):
                body, closing, _ = source[3:].partition('"""')
                if closing:
                    summary = body.strip()
            if not summary:
                summary = f"Training script: {name}"

            task = Task(
                task_id=name,
                description=summary,
                script_content=source,
                difficulty=_derive_difficulty(source),
                category=_detect_category(source),
            )
            self.tasks.append(task)

    def sample(self, difficulty: Optional[str] = None) -> Optional[Task]:
        """Return one randomly chosen task, or None when no tasks are loaded.

        When *difficulty* is given and at least one task has that
        difficulty, the choice is restricted to those tasks; otherwise the
        choice is made from all loaded tasks.
        """
        pool = self.tasks
        if difficulty is not None:
            matching = [task for task in pool if task.difficulty == difficulty]
            # No match for the requested difficulty: fall back to everything.
            pool = matching or pool
        if not pool:
            return None
        return random.choice(pool)

    def sample_batch(
        self, n: int, difficulty: Optional[str] = None
    ) -> list[Task]:
        """Call ``sample`` *n* times and return the non-None results."""
        picked = []
        for _ in range(n):
            task = self.sample(difficulty)
            if task is not None:
                picked.append(task)
        return picked

    def get_all_categories(self) -> list[str]:
        """Return the distinct categories of the loaded tasks, sorted."""
        seen = set()
        for task in self.tasks:
            seen.add(task.category)
        return sorted(seen)

    def get_by_id(self, task_id: str) -> Optional[Task]:
        """Return the first loaded task whose ``task_id`` matches, else None."""
        return next(
            (task for task in self.tasks if task.task_id == task_id), None
        )
|
|