prashantmatlani commited on
Commit
dc44de4
·
1 Parent(s): bf34481

updated environment to get tasks from graders

Browse files
Files changed (1) hide show
  1. app/env.py +29 -2
app/env.py CHANGED
@@ -5,12 +5,33 @@ from typing import Tuple, Dict, Any
5
  from app.models import Observation, Action, Reward
6
  from app.dataset import TICKETS
7
  import random
 
8
 
9
  import sys
10
 
11
  class CustomerSupportEnv:
12
 
13
- # INTERNAL STATE REPRESENTATION -
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def _get_observation(self):
15
 
16
  total_required = len(self.ticket.get("required_info", []))
@@ -44,9 +65,15 @@ class CustomerSupportEnv:
44
  self.max_steps = 10
45
  self.last_action = None
46
 
47
- # METRICS TRACKING
48
  self.episode_stats = []
49
 
 
 
 
 
 
 
50
  def reset(self):
51
 
52
  self.last_action = None
 
5
  from app.models import Observation, Action, Reward
6
  from app.dataset import TICKETS
7
  import random
8
+ from graders import grade_easy, grade_medium, grade_hard
9
 
10
  import sys
11
 
12
  class CustomerSupportEnv:
13
 
14
+ # OBTAIN TASKS FROM GRADERS.PY
15
+ def get_tasks(self):
16
+ return [
17
+ {
18
+ "id": "easy-info-collection",
19
+ "difficulty": "easy",
20
+ "grader": grade_easy,
21
+ },
22
+ {
23
+ "id": "medium-complete-info",
24
+ "difficulty": "medium",
25
+ "grader": grade_medium,
26
+ },
27
+ {
28
+ "id": "hard-efficient-resolution",
29
+ "difficulty": "hard",
30
+ "grader": grade_hard,
31
+ },
32
+ ]
33
+
34
+ # INTERNAL STATE REPRESENTATION
35
  def _get_observation(self):
36
 
37
  total_required = len(self.ticket.get("required_info", []))
 
65
  self.max_steps = 10
66
  self.last_action = None
67
 
68
+ # METRICS TRACKING
69
  self.episode_stats = []
70
 
71
+ # CLEAN ENVIRONMENT INITIALIZATION
72
+ self.tasks = self.get_tasks()
73
+
74
+ def list_tasks(self):
75
+ return self.tasks
76
+
77
  def reset(self):
78
 
79
  self.last_action = None