| """Task sampler: loads the seed corpus and samples Tasks by difficulty. | |
| Difficulty is auto-derived from script line count. Category is auto-detected | |
| from script content (text_classification, ner, translation, etc.). | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from pathlib import Path | |
| from typing import Optional | |
| from forgeenv.tasks.models import Task | |
| def _detect_category(content: str) -> str: | |
| cl = content.lower() | |
| if "sequenceclassification" in cl or "sentiment" in cl or "ag_news" in cl or "sst2" in cl: | |
| return "text_classification" | |
| if "tokenclassification" in cl or "ner" in cl or "conll" in cl: | |
| return "ner" | |
| if "seq2seq" in cl or "translation" in cl or "summariz" in cl or "t5" in cl: | |
| return "seq2seq" | |
| if "causallm" in cl or "gpt2" in cl or "wikitext" in cl: | |
| return "text_generation" | |
| if "imageclassification" in cl or "vit" in cl or "cifar" in cl or "mnist" in cl: | |
| return "image_classification" | |
| if "questionanswering" in cl or "squad" in cl: | |
| return "qa" | |
| if "logisticregression" in cl or "make_classification" in cl: | |
| return "tabular" | |
| if "regression" in cl: | |
| return "regression" | |
| return "general" | |
| def _derive_difficulty(content: str) -> str: | |
| lines = len(content.splitlines()) | |
| if lines < 30: | |
| return "easy" | |
| if lines < 60: | |
| return "medium" | |
| return "hard" | |
class TaskSampler:
    """Holds the tasks built from a seed corpus directory and samples from them."""

    def __init__(self, seed_dir: Optional[str] = None) -> None:
        """Load every seed script found in ``seed_dir``.

        Args:
            seed_dir: Directory containing the seed ``*.py`` scripts. When
                omitted, the ``seed_corpus`` directory next to this module
                is used.
        """
        self.tasks: list[Task] = []
        if seed_dir is not None:
            corpus_dir = seed_dir
        else:
            corpus_dir = str(Path(__file__).parent / "seed_corpus")
        self._load_corpus(corpus_dir)

    def _load_corpus(self, seed_dir: str) -> None:
        """Append one Task per ``*.py`` file in ``seed_dir`` to ``self.tasks``.

        Files are visited in sorted order and names starting with ``__`` are
        skipped. A missing directory leaves ``self.tasks`` untouched.
        """
        root = Path(seed_dir)
        if not root.exists():
            return

        scripts = (
            path for path in sorted(root.glob("*.py")) if not path.name.startswith("__")
        )
        for script in scripts:
            source = script.read_text(encoding="utf-8")
            name = script.stem
            level = _derive_difficulty(source)
            kind = _detect_category(source)

            # Use a leading triple-quoted docstring, if properly closed, as
            # the task description.
            summary = ""
            if source.startswith('"""'):
                head, closer, _ = source[3:].partition('"""')
                if closer:
                    summary = head.strip()

            task = Task(
                task_id=name,
                description=summary or f"Training script: {name}",
                script_content=source,
                difficulty=level,
                category=kind,
            )
            self.tasks.append(task)

    def sample(self, difficulty: Optional[str] = None) -> Optional[Task]:
        """Return one random task, preferring those of the given difficulty.

        When ``difficulty`` is given but no task has it, the draw is made
        from all tasks instead. Returns ``None`` when no tasks are loaded.
        """
        pool = self.tasks
        if difficulty is not None:
            matching = [task for task in pool if task.difficulty == difficulty]
            # An empty match falls back to the full task list.
            pool = matching or pool
        if not pool:
            return None
        return random.choice(pool)

    def sample_batch(
        self, n: int, difficulty: Optional[str] = None
    ) -> list[Task]:
        """Draw ``n`` independent samples, dropping any ``None`` results."""
        batch: list[Task] = []
        for _ in range(n):
            picked = self.sample(difficulty)
            if picked is not None:
                batch.append(picked)
        return batch

    def get_all_categories(self) -> list[str]:
        """Return the distinct categories of the loaded tasks, sorted."""
        seen = set()
        for task in self.tasks:
            seen.add(task.category)
        return sorted(seen)

    def get_by_id(self, task_id: str) -> Optional[Task]:
        """Return the first task whose ``task_id`` matches, or ``None``."""
        return next((task for task in self.tasks if task.task_id == task_id), None)