Spaces:
Running
Running
"""
UndertriAI — Adaptive Episode Selector (Theme 4: Self-Improvement)

Wraps the existing BailDataset to provide performance-aware episode
selection when adaptive mode is enabled. When adaptive mode is disabled
(adaptive=False), selection falls back to uniform random sampling,
identical to the existing BailDataset behavior.
"""
import logging
import random
from typing import Any, Dict, List, Optional

from .performance_tracker import PerformanceTracker
class AdaptiveSelector:
    """
    Performance-aware episode selector.

    Selection strategy (defaults shown; configurable in ``__init__``):
        60%: sample from the weakest crime-type domain in current_stage
        30%: replay cases where recent performance was poor (reward < 0.40)
        10%: uniform random from current_stage (exploration)

    A targeted strategy that finds no matching episode falls through to the
    next one (weakest domain -> failure replay -> uniform). The realised mix
    can therefore differ from the nominal percentages.

    Errors raised by the targeted strategies (tracker calls, dataset
    lookups) are logged and answered with a uniform sample. Only an error
    from the uniform sampler itself (``dataset.sample_episode``) propagates.
    """

    # Logger for errors swallowed by the best-effort selection paths.
    _logger = logging.getLogger(__name__)

    # Inclusive range of stage ids the stage fallback chain may visit.
    _MIN_STAGE = 1
    _MAX_STAGE = 4

    def __init__(
        self,
        dataset,
        tracker: "PerformanceTracker",
        *,
        weakest_domain_prob: float = 0.60,
        failure_replay_prob: float = 0.30,
        failure_threshold: float = 0.40,
        rng: Optional[random.Random] = None,
    ):
        """
        Args:
            dataset: BailDataset instance (has _episodes, sample_episode)
            tracker: PerformanceTracker instance driving selection
            weakest_domain_prob: Share of rolls that try the weakest-domain
                strategy first.
            failure_replay_prob: Share of rolls that try failure replay
                first. The remainder (1 - both) is uniform exploration.
            failure_threshold: Passed to
                ``tracker.get_recent_failures(threshold=...)`` to pick
                replay candidates.
            rng: Source of randomness; must provide ``random()`` and
                ``choice()``. Defaults to the ``random`` module, so a global
                ``random.seed(...)`` still controls selection.

        Raises:
            ValueError: If a probability is negative or the two
                probabilities sum to more than 1.
        """
        if weakest_domain_prob < 0 or failure_replay_prob < 0:
            raise ValueError("selection probabilities must be non-negative")
        replay_cutoff = weakest_domain_prob + failure_replay_prob
        # Tolerance absorbs float rounding in the sum.
        if replay_cutoff > 1.0 + 1e-9:
            raise ValueError(
                "weakest_domain_prob + failure_replay_prob must not exceed 1.0"
            )
        self.dataset = dataset
        self.tracker = tracker
        self._weakest_domain_prob = weakest_domain_prob
        self._replay_cutoff = replay_cutoff
        self._failure_threshold = failure_threshold
        # The ``random`` module exposes random()/choice() on its global generator.
        self._rng: Any = random if rng is None else rng

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def select_episode(self, current_stage: int) -> Dict[str, Any]:
        """
        Performance-aware selection for adaptive mode.

        One roll decides where to start:
        weakest domain -> failure replay -> uniform exploration.
        A strategy that yields nothing falls through to the next. Any
        exception in the targeted strategies is logged and the call falls
        back to ``select_episode_uniform``.

        Args:
            current_stage: Curriculum stage to sample from.

        Returns:
            An episode dict.
        """
        try:
            roll = self._rng.random()
            if roll < self._weakest_domain_prob:
                episode = self._select_weakest_domain(current_stage)
                if episode is not None:
                    return episode
            # Reached by rolls in the replay band and by weakest-domain
            # rolls that found nothing: replay is the next fallback.
            if roll < self._replay_cutoff:
                episode = self._select_failure_replay(current_stage)
                if episode is not None:
                    return episode
        except Exception:
            # Best-effort: a tracker/dataset bug must not stop training, but
            # it should not vanish silently either.
            self._logger.warning(
                "Adaptive selection failed for stage %s; "
                "falling back to uniform sampling",
                current_stage,
                exc_info=True,
            )
        # Exploration band, or no targeted strategy produced an episode.
        # Called once, outside the try, so a failing sampler is not retried.
        return self.select_episode_uniform(current_stage)

    def select_episode_uniform(self, current_stage: int) -> Dict[str, Any]:
        """
        Pure random selection from current_stage.
        Identical to existing BailDataset.sample_episode() behavior.
        """
        return self.dataset.sample_episode(stage=current_stage)

    # ------------------------------------------------------------------
    # Internal strategies
    # ------------------------------------------------------------------
    def _select_weakest_domain(
        self, current_stage: int
    ) -> Optional[Dict[str, Any]]:
        """
        Select an episode from the weakest crime-type domain.

        Returns None if the tracker reports no weak domain or no episode in
        the (possibly fallback) stage has that ``crime_type``.
        """
        weak_domain = self.tracker.weakest_domain()
        if weak_domain is None:
            return None
        # Normalise the tracker label exactly like the episode field, so
        # stray whitespace or a non-str label still matches.
        target = str(weak_domain).strip()
        matches = [
            ep for ep in self._get_stage_episodes(current_stage)
            if str(ep.get("crime_type", "")).strip() == target
        ]
        return self._rng.choice(matches) if matches else None

    def _select_failure_replay(
        self, current_stage: int
    ) -> Optional[Dict[str, Any]]:
        """
        Replay a case where the agent recently scored below the failure
        threshold (0.40 by default).

        Returns None if there are no recent failures or none of them is
        present in the (possibly fallback) stage.
        """
        failed_ids = self.tracker.get_recent_failures(
            threshold=self._failure_threshold
        )
        if not failed_ids:
            return None
        # Set gives O(1) membership for the scan over all stage episodes.
        failed = set(failed_ids)
        matches = [
            ep for ep in self._get_stage_episodes(current_stage)
            if ep.get("case_id", "") in failed
        ]
        return self._rng.choice(matches) if matches else None

    def _get_stage_episodes(self, stage: int) -> List[Dict[str, Any]]:
        """
        Get all episodes for a given stage from the dataset.

        If the requested stage is empty, try stage-1, stage+1, then stages
        1..4 in order, skipping ids outside that range. The original
        comment describes this chain as matching BailDataset.sample_episode
        (not verifiable from this file). Returns [] if nothing is found or
        the dataset lookup fails.
        """
        try:
            by_stage = self.dataset._episodes
            episodes = by_stage.get(stage, [])
            if episodes:
                return episodes
            fallback = (
                stage - 1,
                stage + 1,
                *range(self._MIN_STAGE, self._MAX_STAGE + 1),
            )
            for candidate in fallback:
                if self._MIN_STAGE <= candidate <= self._MAX_STAGE:
                    episodes = by_stage.get(candidate, [])
                    if episodes:
                        return episodes
        except Exception:
            # Best-effort: callers treat [] as "no targeted episode".
            self._logger.debug(
                "Episode lookup failed for stage %s", stage, exc_info=True
            )
        return []