Spaces:
Sleeping
Sleeping
| """Tests for the Forge module β pass@k formula, difficulty classification, scheduler.""" | |
| from __future__ import annotations | |
| import pytest | |
| from forge_arena.forge.estimator import pass_at_k | |
| from forge_arena.forge.generator import TaskGenerator | |
| from forge_arena.models.tasks import CorruptionType, DifficultyTier, Task, TaskDomain | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # pass@k unbiased estimator | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class TestPassAtK: | |
| def test_zero_correct(self): | |
| """0 correct answers β pass@k should be 0.0.""" | |
| assert pass_at_k(n=10, c=0, k=8) == pytest.approx(0.0) | |
| def test_all_correct(self): | |
| """All correct β pass@k should be 1.0.""" | |
| assert pass_at_k(n=10, c=10, k=8) == pytest.approx(1.0) | |
| def test_half_correct_k1(self): | |
| """pass@1 with half correct = 0.5.""" | |
| result = pass_at_k(n=10, c=5, k=1) | |
| assert pytest.approx(0.5, abs=0.01) == result | |
| def test_k_greater_than_n_raises(self): | |
| with pytest.raises(ValueError): | |
| pass_at_k(n=4, c=2, k=8) | |
| def test_c_greater_than_n_raises(self): | |
| with pytest.raises(ValueError): | |
| pass_at_k(n=5, c=6, k=4) | |
| def test_monotone_in_c(self): | |
| """Increasing correct answers should increase pass@k.""" | |
| results = [pass_at_k(n=10, c=c, k=8) for c in range(0, 11)] | |
| for i in range(len(results) - 1): | |
| assert results[i] <= results[i + 1] | |
| def test_output_in_unit_interval(self): | |
| for n in [8, 10, 16]: | |
| for c in range(0, n + 1, 2): | |
| val = pass_at_k(n=n, c=c, k=min(8, n)) | |
| assert 0.0 <= val <= 1.0, f"Out of range for n={n}, c={c}" | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Difficulty classification | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class TestDifficultyClassification: | |
| def setup_method(self): | |
| from unittest.mock import MagicMock | |
| from forge_arena.forge.estimator import DifficultyEstimator | |
| self.config = MagicMock() | |
| self.config.difficulty_thresholds.too_easy = 0.85 | |
| self.config.difficulty_thresholds.too_hard = 0.20 | |
| self.config.estimation_k = 8 | |
| self.config.estimation_n = 10 | |
| # DifficultyEstimator requires a shared mutable episode_counter list | |
| self.estimator = DifficultyEstimator(self.config, []) | |
| def test_too_easy(self): | |
| from forge_arena.forge.estimator import classify_difficulty | |
| tier = classify_difficulty(0.90, self.config) | |
| assert tier == DifficultyTier.TOO_EASY | |
| def test_too_hard(self): | |
| from forge_arena.forge.estimator import classify_difficulty | |
| tier = classify_difficulty(0.10, self.config) | |
| assert tier == DifficultyTier.TOO_HARD | |
| def test_learnable_midpoint(self): | |
| from forge_arena.forge.estimator import classify_difficulty | |
| tier = classify_difficulty(0.50, self.config) | |
| assert tier == DifficultyTier.LEARNABLE | |
| def test_boundary_too_easy(self): | |
| """Exactly at boundary β learnable (inclusive lower bound).""" | |
| from forge_arena.forge.estimator import classify_difficulty | |
| tier = classify_difficulty(0.85, self.config) | |
| assert tier == DifficultyTier.LEARNABLE | |
| def test_boundary_too_hard(self): | |
| from forge_arena.forge.estimator import classify_difficulty | |
| tier = classify_difficulty(0.20, self.config) | |
| assert tier == DifficultyTier.LEARNABLE | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # TaskScheduler queue management | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class TestTaskScheduler: | |
| def setup_method(self): | |
| from unittest.mock import MagicMock, AsyncMock | |
| from forge_arena.forge.scheduler import TaskScheduler | |
| config = MagicMock() | |
| config.queue_replenishment_threshold = 5 | |
| config.batch_reestimation_interval = 50 | |
| config.difficulty_thresholds.too_easy = 0.85 | |
| config.difficulty_thresholds.too_hard = 0.20 | |
| self.estimator = MagicMock() | |
| self.generator = MagicMock() | |
| self.scheduler = TaskScheduler(config, self.estimator, self.generator) | |
| def _make_task(self, task_id: str) -> object: | |
| from unittest.mock import MagicMock | |
| from forge_arena.models.tasks import TaskDomain, ObfuscationDepth | |
| task = MagicMock() | |
| task.id = task_id | |
| task.domain = TaskDomain.CUSTOMER_SUPPORT | |
| task.is_generated = False | |
| task.difficulty_tier = None | |
| return task | |
| def _make_snapshot(self, task_id: str, tier: DifficultyTier, pak: float): | |
| from unittest.mock import MagicMock | |
| snap = MagicMock() | |
| snap.task_id = task_id | |
| snap.difficulty_tier = tier | |
| snap.pass_at_k = pak | |
| return snap | |
| async def test_initialise_routes_learnable_to_active_queue(self): | |
| tasks = [self._make_task(f"t-{i}") for i in range(3)] | |
| snaps = [ | |
| self._make_snapshot(f"t-{i}", DifficultyTier.LEARNABLE, 0.50) | |
| for i in range(3) | |
| ] | |
| self.estimator.batch_estimate.return_value = snaps | |
| await self.scheduler.initialise(tasks, lambda t: False) | |
| state = self.scheduler.get_queue_state() | |
| assert state.learnable_count == 3 | |
| async def test_initialise_places_all_seed_tasks_in_active_queue(self): | |
| """Seed tasks bypass estimation and are placed directly in the learnable queue. | |
| Pre-estimating with a no-op policy (c=0) would classify every task | |
| as too-hard. Seed tasks are hand-authored for the learnable zone, so | |
| they skip estimation until real episodes are collected. | |
| """ | |
| tasks = [self._make_task("easy-task")] | |
| await self.scheduler.initialise(tasks, lambda t: False) | |
| state = self.scheduler.get_queue_state() | |
| # Task goes straight to active queue; estimator is never called. | |
| assert state.learnable_count == 1 | |
| assert state.too_easy_count == 0 | |
| self.estimator.batch_estimate.assert_not_called() | |
| def test_request_task_returns_task(self): | |
| task = self._make_task("t-1") | |
| self.scheduler._active_queue.append(task) | |
| result = self.scheduler.request_task() | |
| assert result.id == "t-1" | |
| def test_request_task_empty_raises(self): | |
| from forge_arena.forge.scheduler import QueueEmptyError | |
| # No initialise() called β both _active_queue and _seed_bank are empty. | |
| with pytest.raises(QueueEmptyError): | |
| self.scheduler.request_task() | |
| def test_request_task_cycles_when_queue_exhausted(self): | |
| """After all seed tasks are consumed the queue refills from the seed bank.""" | |
| import asyncio | |
| tasks = [self._make_task(f"seed-{i}") for i in range(3)] | |
| asyncio.run(self.scheduler.initialise(tasks, lambda t: False)) | |
| # Consume all 3 tasks | |
| for _ in range(3): | |
| self.scheduler.request_task() | |
| # 4th call must not raise β it should cycle back through the seed bank | |
| result = self.scheduler.request_task() | |
| seed_ids = {t.id for t in tasks} | |
| assert result.id in seed_ids | |
| def test_batch_reestimate_does_not_wipe_queue_with_zero_accuracy_policy(self): | |
| """_batch_reestimate with always-False policy must not leave the queue empty. | |
| This guards against the episode-50 bug where lambda t: False caused | |
| pass@k = 0.0 for every task, routing all tasks to the too-hard archive | |
| and wiping the active queue. | |
| """ | |
| import asyncio | |
| tasks = [self._make_task(f"t-{i}") for i in range(5)] | |
| snap_learnable = [ | |
| self._make_snapshot(f"t-{i}", DifficultyTier.LEARNABLE, 0.50) | |
| for i in range(5) | |
| ] | |
| # Simulate always-wrong policy: estimator returns too-hard for every task | |
| from forge_arena.models.tasks import DifficultyTier as DT | |
| snap_too_hard = [ | |
| self._make_snapshot(f"t-{i}", DT.TOO_HARD, 0.0) | |
| for i in range(5) | |
| ] | |
| self.estimator.batch_estimate.return_value = snap_too_hard | |
| asyncio.run(self.scheduler.initialise(tasks, lambda t: False)) | |
| # Trigger batch re-estimation with the always-wrong policy | |
| asyncio.run(self.scheduler._batch_reestimate(lambda t: False)) | |
| # Active queue must still be non-empty after the safety-net refill | |
| assert len(self.scheduler._active_queue) > 0 | |
| def test_difficulty_history_empty_after_initialise(self): | |
| """No snapshot history is recorded at init time. | |
| History starts accumulating only after real episodes are collected | |
| via update(). Seed tasks are placed in the queue without estimation. | |
| """ | |
| import asyncio | |
| tasks = [self._make_task("t-hist")] | |
| asyncio.run(self.scheduler.initialise(tasks, lambda t: False)) | |
| curve = self.scheduler.get_difficulty_curve() | |
| # No estimation ran, so history is empty. | |
| assert "t-hist" not in curve | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Generator quality validation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _make_task( | |
| corruption_type: CorruptionType | None = None, | |
| corruption_explanation: str = "", | |
| corrupted_worker_output: str = "", | |
| ) -> Task: | |
| """Minimal Task fixture for validation tests.""" | |
| return Task( | |
| id="test-gen-001", | |
| domain=TaskDomain.CUSTOMER_SUPPORT, | |
| corruption_type=corruption_type, | |
| task_description="Summarise the refund policy.", | |
| source_material="Policy text.", | |
| clean_worker_output="30-day returns accepted.", | |
| corrupted_worker_output=corrupted_worker_output or None, | |
| corruption_location=None, | |
| corruption_explanation=corruption_explanation or None, | |
| ground_truth_correction=None, | |
| is_generated=True, | |
| ) | |
| class TestGeneratorValidation: | |
| """Unit tests for TaskVariantGenerator._validate_generated_task().""" | |
| def test_clean_task_always_valid(self): | |
| task = _make_task(corruption_type=None) | |
| assert TaskGenerator._validate_generated_task(task) is True | |
| def test_corrupted_task_with_trigger_vocab_in_explanation_valid(self): | |
| # "omits" is in FACTUAL_OMISSION trigger_vocab | |
| task = _make_task( | |
| corruption_type=CorruptionType.FACTUAL_OMISSION, | |
| corruption_explanation="The Worker omits the 30-day deadline.", | |
| ) | |
| assert TaskGenerator._validate_generated_task(task) is True | |
| def test_corrupted_task_with_trigger_vocab_in_output_valid(self): | |
| # "consistently recommends" (exact phrase) is in BIAS_INJECTION trigger_vocab | |
| task = _make_task( | |
| corruption_type=CorruptionType.BIAS_INJECTION, | |
| corrupted_worker_output="The agent consistently recommends Product A for all clients.", | |
| ) | |
| assert TaskGenerator._validate_generated_task(task) is True | |
| def test_corrupted_task_with_no_trigger_vocab_invalid(self): | |
| # Neither explanation nor output contains any AUTHORITY_FABRICATION vocab | |
| task = _make_task( | |
| corruption_type=CorruptionType.AUTHORITY_FABRICATION, | |
| corruption_explanation="Something is wrong with the answer.", | |
| corrupted_worker_output="The policy says returns are accepted.", | |
| ) | |
| assert TaskGenerator._validate_generated_task(task) is False | |
| def test_corrupted_task_with_empty_fields_invalid(self): | |
| # Corrupt task with no text fields at all | |
| task = _make_task(corruption_type=CorruptionType.TEMPORAL_SHIFT) | |
| assert TaskGenerator._validate_generated_task(task) is False | |
| def test_trigger_vocab_detected_for_each_type(self, ctype): | |
| """Each corruption type's own trigger vocabulary should make validation pass.""" | |
| from forge_arena.arena.corruptions.types import CORRUPTION_REGISTRY | |
| meta = CORRUPTION_REGISTRY[ctype] | |
| trigger_phrase = meta.trigger_vocab[0] | |
| task = _make_task( | |
| corruption_type=ctype, | |
| corruption_explanation=f"The Worker {trigger_phrase} the correct value.", | |
| ) | |
| assert TaskGenerator._validate_generated_task(task) is True | |