Omkar1806 commited on
Commit
c5fdb67
Β·
verified Β·
1 Parent(s): 1d620e0

Update env.py

Browse files
Files changed (1) hide show
  1. env.py +31 -28
env.py CHANGED
@@ -1,5 +1,4 @@
1
  from typing import List, Dict, Any, Tuple
2
- import random
3
 
4
  URGENCY_LABELS = ["low", "medium", "high"]
5
  ROUTING_LABELS = ["general", "support", "security"]
@@ -13,42 +12,50 @@ class EmailTriageEnv:
13
  self._index = 0
14
  self._done = False
15
 
16
- # βœ… TASK-WISE DATA (required for grader)
17
  def _generate_emails(self) -> List[Dict]:
18
- task_data = {
19
- "easy": [
20
  {"description": "Password reset not working", "label": [2, 1, 2]},
21
  {"description": "Billing refund request", "label": [1, 2, 2]},
22
- {"description": "App is slow and buggy", "label": [0, 1, 1]},
23
- ],
24
- "medium": [
 
 
25
  {"description": "Password reset not working", "label": [2, 1, 2]},
26
  {"description": "Billing refund request", "label": [1, 2, 2]},
27
- {"description": "App is slow and buggy", "label": [0, 1, 1]},
28
  {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
29
- {"description": "Invoice mismatch and payment issue", "label": [1, 2, 2]},
30
- ],
31
- "hard": [
 
 
32
  {"description": "Password reset not working", "label": [2, 1, 2]},
33
  {"description": "Billing refund request", "label": [1, 2, 2]},
34
- {"description": "App is slow and buggy", "label": [0, 1, 1]},
35
  {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
36
- {"description": "Invoice mismatch and payment issue", "label": [1, 2, 2]},
37
- {"description": "Ransomware attack suspected on system", "label": [2, 2, 2]},
38
- {"description": "User reports data breach and performance issues", "label": [2, 2, 2]},
39
- ],
40
- }
41
 
42
- emails = task_data.get(self.task, task_data["easy"])
43
- random.shuffle(emails)
44
- return emails
45
 
46
- # βœ… RESET
47
  def reset(self) -> Dict[str, Any]:
48
  self._queue = self._generate_emails()
49
  self._index = 0
50
  self._done = False
51
- return self.state()
 
 
 
 
 
 
52
 
53
  # βœ… STATE
54
  def state(self) -> Dict[str, Any]:
@@ -63,21 +70,17 @@ class EmailTriageEnv:
63
  "done": False
64
  }
65
 
66
- # βœ… STEP (GRADER LOGIC)
67
  def step(self, action: List[int]) -> Tuple[Dict, float, bool, Dict, Dict]:
68
  if self._done:
69
  return self.state(), 0.0, True, {}, {}
70
 
71
  correct = self._queue[self._index]["label"]
72
 
73
- # 🎯 PARTIAL REWARD (important)
74
  matches = sum(1 for a, b in zip(action, correct) if a == b)
75
  reward = matches / 3.0 # normalized [0,1]
76
 
77
- # πŸ”₯ BONUS for perfect prediction
78
- if matches == 3:
79
- reward = 1.0
80
-
81
  self._index += 1
82
 
83
  if self._index >= len(self._queue):
 
1
  from typing import List, Dict, Any, Tuple
 
2
 
3
  URGENCY_LABELS = ["low", "medium", "high"]
4
  ROUTING_LABELS = ["general", "support", "security"]
 
12
  self._index = 0
13
  self._done = False
14
 
15
+ # βœ… EXPLICIT TASK DATA (NO RANDOMNESS)
16
  def _generate_emails(self) -> List[Dict]:
17
+ if self.task == "easy":
18
+ return [
19
  {"description": "Password reset not working", "label": [2, 1, 2]},
20
  {"description": "Billing refund request", "label": [1, 2, 2]},
21
+ {"description": "App is slow", "label": [0, 1, 1]},
22
+ ]
23
+
24
+ elif self.task == "medium":
25
+ return [
26
  {"description": "Password reset not working", "label": [2, 1, 2]},
27
  {"description": "Billing refund request", "label": [1, 2, 2]},
28
+ {"description": "App is slow", "label": [0, 1, 1]},
29
  {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
30
+ {"description": "Invoice mismatch issue", "label": [1, 2, 2]},
31
+ ]
32
+
33
+ elif self.task == "hard":
34
+ return [
35
  {"description": "Password reset not working", "label": [2, 1, 2]},
36
  {"description": "Billing refund request", "label": [1, 2, 2]},
37
+ {"description": "App is slow", "label": [0, 1, 1]},
38
  {"description": "Possible phishing attempt detected", "label": [2, 2, 2]},
39
+ {"description": "Invoice mismatch issue", "label": [1, 2, 2]},
40
+ {"description": "Ransomware attack suspected", "label": [2, 2, 2]},
41
+ {"description": "Data breach reported", "label": [2, 2, 2]},
42
+ ]
 
43
 
44
+ else:
45
+ return []
 
46
 
47
+ # βœ… RESET (DETERMINISTIC)
48
  def reset(self) -> Dict[str, Any]:
49
  self._queue = self._generate_emails()
50
  self._index = 0
51
  self._done = False
52
+
53
+ return {
54
+ "description": self._queue[self._index]["description"],
55
+ "step": 0,
56
+ "remaining": len(self._queue),
57
+ "done": False
58
+ }
59
 
60
  # βœ… STATE
61
  def state(self) -> Dict[str, Any]:
 
70
  "done": False
71
  }
72
 
73
+ # βœ… STEP (CLEAR GRADER)
74
  def step(self, action: List[int]) -> Tuple[Dict, float, bool, Dict, Dict]:
75
  if self._done:
76
  return self.state(), 0.0, True, {}, {}
77
 
78
  correct = self._queue[self._index]["label"]
79
 
80
+ # 🎯 GRADER (CLEAR + NORMALIZED)
81
  matches = sum(1 for a, b in zip(action, correct) if a == b)
82
  reward = matches / 3.0 # normalized [0,1]
83
 
 
 
 
 
84
  self._index += 1
85
 
86
  if self._index >= len(self._queue):