File size: 4,704 Bytes
d8f8a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
UndertriAI — Adaptive Episode Selector (Theme 4: Self-Improvement)

Wraps the existing BailDataset to provide performance-aware episode
selection when adaptive mode is enabled. Falls back to uniform random
(identical to existing behavior) when adaptive=False.
"""

import logging
import random
from typing import Any, Dict, List, Optional

from .performance_tracker import PerformanceTracker


class AdaptiveSelector:
    """
    Performance-aware episode selector.

    ``select_episode`` draws one uniform random number per call and, with the
    default probabilities, picks a strategy:
      60%: sample from the weakest crime-type domain in current_stage
      30%: replay cases where recent performance was poor (reward < 0.40)
      10%: uniform random from current_stage (exploration)

    A strategy that finds no candidate falls through to the next one, ending
    in uniform selection. ``select_episode_uniform`` is the non-adaptive path.

    Exceptions raised by the adaptive strategies are logged at DEBUG level and
    degrade to a single uniform selection; an exception raised by the
    dataset's own ``sample_episode`` is propagated to the caller.
    """

    # Default probability mass per adaptive strategy; whatever remains
    # (1 - weak_domain_prob - replay_prob) goes to uniform exploration.
    DEFAULT_WEAK_DOMAIN_PROB = 0.60
    DEFAULT_REPLAY_PROB = 0.30
    # Cases scoring below this reward are eligible for failure replay.
    DEFAULT_FAILURE_THRESHOLD = 0.40
    # Inclusive stage range searched by the fallback chain in
    # _get_stage_episodes.
    MIN_STAGE = 1
    MAX_STAGE = 4

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        dataset,
        tracker: "PerformanceTracker",
        *,
        weak_domain_prob: float = DEFAULT_WEAK_DOMAIN_PROB,
        replay_prob: float = DEFAULT_REPLAY_PROB,
        failure_threshold: float = DEFAULT_FAILURE_THRESHOLD,
    ):
        """
        Args:
            dataset: BailDataset instance (has _episodes, sample_episode).
                ``_episodes`` is read with ``.get(stage, [])``, so it is
                presumably a mapping from stage number to a list of episode
                dicts.
            tracker: PerformanceTracker instance driving selection; must
                provide ``weakest_domain()`` and
                ``get_recent_failures(threshold=...)``.
            weak_domain_prob: Probability of trying the weakest-domain
                strategy first.
            replay_prob: Probability of trying failure replay.
            failure_threshold: Reward threshold handed to the tracker when
                collecting recent failures.

        Raises:
            ValueError: If a probability lies outside [0, 1] or the two
                probabilities sum to more than 1.
        """
        if not (0.0 <= weak_domain_prob <= 1.0 and 0.0 <= replay_prob <= 1.0):
            raise ValueError("selection probabilities must lie in [0, 1]")
        # Small tolerance so e.g. 0.7 + 0.3 is not rejected by float rounding.
        if weak_domain_prob + replay_prob > 1.0 + 1e-9:
            raise ValueError("weak_domain_prob + replay_prob must not exceed 1")
        self.dataset = dataset
        self.tracker = tracker
        self.weak_domain_prob = weak_domain_prob
        self.replay_prob = replay_prob
        self.failure_threshold = failure_threshold

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def select_episode(self, current_stage: int) -> Dict[str, Any]:
        """
        Performance-aware selection for adaptive mode.

        Weakest domain -> failure replay -> exploration, with the bands set
        by ``weak_domain_prob`` / ``replay_prob``. Falls back to uniform
        selection when the adaptive strategies find nothing or raise.
        """
        try:
            ep = self._select_adaptive(current_stage)
        except Exception:
            # Best-effort: a tracker/dataset surprise must not crash the
            # episode loop, but it should not vanish without a trace either.
            self._logger.debug(
                "adaptive selection failed for stage %s; using uniform",
                current_stage,
                exc_info=True,
            )
            ep = None
        if ep is not None:
            return ep
        # Exploration band or fallback; called exactly once per selection.
        return self.select_episode_uniform(current_stage)

    def select_episode_uniform(self, current_stage: int) -> Dict[str, Any]:
        """
        Pure random selection from current_stage.
        Identical to existing BailDataset.sample_episode() behavior.
        """
        return self.dataset.sample_episode(stage=current_stage)

    # ------------------------------------------------------------------
    # Internal strategies
    # ------------------------------------------------------------------

    def _select_adaptive(self, current_stage: int) -> Optional[Dict[str, Any]]:
        """Roll once and run the adaptive strategies; None means 'go uniform'."""
        roll = random.random()

        if roll < self.weak_domain_prob:
            ep = self._select_weakest_domain(current_stage)
            if ep is not None:
                return ep

        # A weakest-domain miss deliberately falls through to replay.
        if roll < self.weak_domain_prob + self.replay_prob:
            return self._select_failure_replay(current_stage)

        return None

    def _select_weakest_domain(
        self, current_stage: int
    ) -> Optional[Dict[str, Any]]:
        """
        Select an episode from the weakest crime-type domain.
        Returns None if no weak domain identified or no matching episodes.
        """
        weak_domain = self.tracker.weakest_domain()
        if weak_domain is None:
            return None

        # Episode crime types are stringified and stripped before comparison.
        matches = [
            ep for ep in self._get_stage_episodes(current_stage)
            if str(ep.get("crime_type", "")).strip() == weak_domain
        ]
        return random.choice(matches) if matches else None

    def _select_failure_replay(
        self, current_stage: int
    ) -> Optional[Dict[str, Any]]:
        """
        Replay a case where the agent recently scored below the failure
        threshold (0.40 by default).
        Returns None if no recent failures or no matching episodes.
        """
        failed_ids = self.tracker.get_recent_failures(
            threshold=self.failure_threshold
        )
        if not failed_ids:
            return None

        # Build the lookup set once: O(1) membership per episode instead of
        # scanning the failure collection for every episode.
        failed = set(failed_ids)
        matches = [
            ep for ep in self._get_stage_episodes(current_stage)
            if ep.get("case_id", "") in failed
        ]
        return random.choice(matches) if matches else None

    def _get_stage_episodes(self, stage: int) -> List[Dict[str, Any]]:
        """
        Get all episodes for a given stage from the dataset.

        If the requested stage has none, tries stage-1, stage+1, then every
        stage from MIN_STAGE to MAX_STAGE in order (out-of-range candidates
        skipped). Returns [] if nothing is found or the dataset cannot be
        read.
        """
        try:
            by_stage = self.dataset._episodes
            eps = by_stage.get(stage, [])
            if eps:
                return eps
            # Fallback chain matching BailDataset.sample_episode
            candidates = (
                stage - 1,
                stage + 1,
                *range(self.MIN_STAGE, self.MAX_STAGE + 1),
            )
            for candidate in candidates:
                if self.MIN_STAGE <= candidate <= self.MAX_STAGE:
                    eps = by_stage.get(candidate, [])
                    if eps:
                        return eps
        except Exception:
            self._logger.debug(
                "could not read episodes for stage %s", stage, exc_info=True
            )
        return []