prashantmatlani committed on
Commit
ecf4764
·
1 Parent(s): e3b41fc

modified inference to include [START] at the beginning of each task

Browse files
Files changed (2) hide show
  1. app/env.py +84 -14
  2. inference.py +3 -42
app/env.py CHANGED
@@ -10,7 +10,44 @@ from tasks import TASKS
10
 
11
  import sys
12
 
13
- AVAILABLE_TASKS = TASKS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_tasks():
16
  return AVAILABLE_TASKS
@@ -64,41 +101,50 @@ class CustomerSupportEnv:
64
  "status": self.state_data["status"],
65
  "step_count": self.state_data["steps_taken"],
66
  "remaining_steps": self.max_steps - self.state_data["steps_taken"],
 
67
  }
68
 
69
- def __init__(self):
 
 
 
 
 
 
 
 
70
  self.state_data = None
71
- self.max_steps = 10
72
  self.last_action = None
73
 
74
  # METRICS TRACKING
75
  self.episode_stats = []
76
 
77
- # CLEAN ENVIRONMENT INITIALIZATION
78
- self.tasks = self.get_tasks()
79
-
80
  def list_tasks(self):
81
  return self.tasks
82
 
 
83
  def reset(self):
84
 
85
  self.last_action = None
86
-
87
- # ✅ episode tracking
88
  self.current_episode_reward = 0.0
89
  self.current_steps = 0
90
  self.success = False
91
 
 
92
  self.ticket = random.choice(TICKETS)
93
 
 
 
 
94
  self.state_data = {
95
  "ticket_id": self.ticket["ticket_id"],
96
- "customer_message": self.ticket["customer_message"],
97
  "history": [],
98
  "status": "open",
99
  "priority": None,
100
  "category": None,
101
- "required_info": self.ticket["required_info"].copy(),
102
  "collected_info": {},
103
  "steps_taken": 0,
104
  "max_steps": self.max_steps,
@@ -106,13 +152,12 @@ class CustomerSupportEnv:
106
  }
107
 
108
  return self._get_observation()
109
-
110
-
111
  def step(self, action: dict):
112
 
113
- # SAFETY: ensure environment initialized
114
  if self.state_data is None:
115
- print("⚠️ step() called before reset — auto-resetting", flush=True)
116
  self.reset()
117
 
118
  reward = 0.0
@@ -281,3 +326,28 @@ class CustomerSupportEnv:
281
  "info_efficiency": round(info_eff, 3)
282
  }
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  import sys
12
 
13
# Per-difficulty tuning knobs, keyed by difficulty name. Each entry carries:
#   max_steps          - step budget copied onto the environment in __init__
#   noise_prob         - chance that _inject_noise appends a filler phrase
#   missing_info_prob  - per-field chance that _mask_required_info drops a field
DIFFICULTY_CONFIG = {
    level: {
        "max_steps": step_budget,
        "noise_prob": noise,
        "missing_info_prob": missing,
    }
    for level, step_budget, noise, missing in (
        ("easy", 8, 0.0, 0.1),
        ("medium", 10, 0.2, 0.3),
        ("hard", 12, 0.4, 0.5),
    )
}
30
+
31
# --- TASKS ---
# Task registry: one entry per difficulty tier, pairing a task id with the
# grader callable used to score it. Defined inline here; it previously
# aliased the imported TASKS list.
# NOTE(review): grade_easy / grade_medium / grade_hard must already be in
# scope when this module is imported; their import is not visible in this
# part of the file, so confirm it exists.
AVAILABLE_TASKS = [
    {"id": task_id, "difficulty": level, "grader": grader}
    for task_id, level, grader in (
        ("easy-info-collection", "easy", grade_easy),
        ("medium-complete-info", "medium", grade_medium),
        ("hard-efficient-resolution", "hard", grade_hard),
    )
]
51
 
52
  def get_tasks():
53
  return AVAILABLE_TASKS
 
101
  "status": self.state_data["status"],
102
  "step_count": self.state_data["steps_taken"],
103
  "remaining_steps": self.max_steps - self.state_data["steps_taken"],
104
+ "difficulty": self.difficulty # difficulty awareness
105
  }
106
 
107
+
108
def __init__(self, difficulty="medium", seed=None):
    """Create an environment tuned for one difficulty tier.

    Args:
        difficulty: key into ``DIFFICULTY_CONFIG`` selecting the step
            budget and the noise / missing-info probabilities.
        seed: optional value handed to ``random.seed``. This reseeds the
            module-level RNG, so it affects every user of ``random`` in
            the process, not only this instance.

    Raises:
        KeyError: if ``difficulty`` is not a key of ``DIFFICULTY_CONFIG``.
    """
    try:
        config = DIFFICULTY_CONFIG[difficulty]
    except KeyError:
        # Same exception type as the bare lookup, but with a usable message.
        raise KeyError(
            f"unknown difficulty {difficulty!r}; "
            f"expected one of {sorted(DIFFICULTY_CONFIG)}"
        ) from None

    self.difficulty = difficulty
    self.config = config

    if seed is not None:
        random.seed(seed)

    self.state_data = None
    self.max_steps = self.config["max_steps"]
    self.last_action = None

    # METRICS TRACKING
    self.episode_stats = []

    # list_tasks() returns self.tasks, so the attribute has to exist on
    # every instance; without it list_tasks() raises AttributeError.
    self.tasks = get_tasks()
122
 
 
 
 
123
  def list_tasks(self):
124
  return self.tasks
125
 
126
+
127
  def reset(self):
128
 
129
  self.last_action = None
 
 
130
  self.current_episode_reward = 0.0
131
  self.current_steps = 0
132
  self.success = False
133
 
134
+ # 🎯 Controlled ticket sampling
135
  self.ticket = random.choice(TICKETS)
136
 
137
+ # 🎯 Inject stochasticity (controlled)
138
+ noisy_message = self._inject_noise(self.ticket["customer_message"])
139
+
140
  self.state_data = {
141
  "ticket_id": self.ticket["ticket_id"],
142
+ "customer_message": noisy_message,
143
  "history": [],
144
  "status": "open",
145
  "priority": None,
146
  "category": None,
147
+ "required_info": self._mask_required_info(self.ticket["required_info"]),
148
  "collected_info": {},
149
  "steps_taken": 0,
150
  "max_steps": self.max_steps,
 
152
  }
153
 
154
  return self._get_observation()
155
+
 
156
  def step(self, action: dict):
157
 
158
+ # SAFETY: ensure environment initialized
159
  if self.state_data is None:
160
+ print("step() called before reset — auto-resetting", flush=True)
161
  self.reset()
162
 
163
  reward = 0.0
 
326
  "info_efficiency": round(info_eff, 3)
327
  }
328
 
329
+
330
+ def _inject_noise(self, message):
331
+
332
+ if random.random() < self.config["noise_prob"]:
333
+ noise_phrases = [
334
+ "pls help asap",
335
+ "this is urgent",
336
+ "not sure what's wrong",
337
+ "it’s been days"
338
+ ]
339
+ return message + " " + random.choice(noise_phrases)
340
+
341
+ return message
342
+
343
+
344
+ def _mask_required_info(self, required_fields):
345
+
346
+ masked = []
347
+
348
+ for field in required_fields:
349
+ if random.random() > self.config["missing_info_prob"]:
350
+ masked.append(field)
351
+
352
+ # ensure at least 1 required field remains
353
+ return masked if masked else required_fields
inference.py CHANGED
@@ -6,51 +6,12 @@ import json
6
  from agent_llm import get_action
7
  from app.env import CustomerSupportEnv
8
  from graders import grade_easy, grade_medium, grade_hard
9
- from tasks import TASKS
 
10
 
11
  import sys
12
 
13
- """
14
- # =========================
15
- # TASK DEFINITIONS
16
- # =========================
17
- TASKS = [
18
- {"name": "easy-info-collection", "type": "easy"},
19
- {"name": "medium-complete-info", "type": "medium"},
20
- {"name": "hard-efficient-resolution", "type": "hard"},
21
- ]
22
- """
23
-
24
- """
25
- # =========================
26
- # GRADERS (DETERMINISTIC)
27
- # =========================
28
- def get_info_efficiency(env):
29
- if env.episode_stats:
30
- return env.episode_stats[-1].get("info_efficiency", 0)
31
- return 0
32
-
33
- def grade_easy(env, success, steps, rewards):
34
- # Reward asking at least something
35
- score = 0.3 + 0.1 * len(rewards)
36
- return max(0.01, min(0.99, score))
37
-
38
- def grade_medium(env, success, steps, rewards):
39
- info_eff = get_info_efficiency(env)
40
- score = 0.5 * info_eff
41
- return max(0.01, min(0.99, score))
42
-
43
- def grade_hard(env, success, steps, rewards):
44
- info_eff = get_info_efficiency(env)
45
-
46
- score = (
47
- 0.5 * (1 if success else 0) +
48
- 0.3 * info_eff +
49
- 0.2 * (1 / (1 + steps))
50
- )
51
-
52
- return max(0.01, min(0.99, score))
53
- """
54
 
55
  def compute_score(task_type, env, success, steps, rewards):
56
 
 
6
  from agent_llm import get_action
7
  from app.env import CustomerSupportEnv
8
  from graders import grade_easy, grade_medium, grade_hard
9
+ #from tasks import TASKS
10
+ from app.env import get_tasks
11
 
12
  import sys
13
 
14
+ TASKS = get_tasks()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def compute_score(task_type, env, success, steps, rewards):
17