prashantmatlani commited on
Commit
0894e25
·
1 Parent(s): 8e156dc

implemented agents' self-learning, self-correcting without explicit training

Browse files
Files changed (6) hide show
  1. agent_llm.py +65 -91
  2. app/dataset.py +201 -88
  3. app/env.py +110 -148
  4. graders.py:Zone.Identifier +0 -0
  5. inference.py +2 -1
  6. tasks.py:Zone.Identifier +0 -0
agent_llm.py CHANGED
@@ -15,6 +15,7 @@ import json
15
  import time
16
  #from groq import Groq
17
  #from openai import OpenAI
 
18
 
19
  from app.env import CustomerSupportEnv
20
 
@@ -25,7 +26,7 @@ from app.env import CustomerSupportEnv
25
  #client = Groq(api_key=os.getenv("GROQ_API_KEY"))
26
 
27
  # =========================
28
- # OPTIONAL IMPORTS (SAFE)
29
  # =========================
30
  try:
31
  from openai import OpenAI
@@ -45,61 +46,47 @@ def get_llm_client():
45
  if OpenAI is None:
46
  return None
47
 
48
- api_key = os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
49
-
50
- if not api_key:
51
- return None # 🔥 critical
52
-
53
- try:
54
- return OpenAI(
55
- base_url=os.getenv(
56
- "API_BASE_URL",
57
- "https://router.huggingface.co/v1"
58
- ),
59
- api_key=api_key
60
- )
61
- except Exception:
62
  return None
63
 
 
 
 
 
64
 
65
  client = get_llm_client()
66
 
67
  # =========================
68
- # PROMPT (STRICT + MINIMAL)
69
  # =========================
70
- def build_prompt(obs, valid_actions):
71
  return f"""
72
- You are a decision agent for customer support.
73
-
74
- Return ONLY JSON.
75
-
76
- INPUT:
77
- Customer message: {obs["customer_message"]}
78
- Known info: {obs["known_info"]}
79
- Required fields: {obs.get("required", [])}
80
-
81
- RULES:
82
- 1. First classify (billing / technical / delivery)
83
- 2. Then collect ALL required fields
84
- 3. Then resolve
85
- 4. NEVER resolve early
86
- 5. DO NOT ask for fields already known
87
-
88
- VALID ACTION TYPES:
89
- - classify
90
- - ask_info
91
- - resolve
92
-
93
- FORMAT:
94
- {{
95
- "action": {{
96
- "type": "...",
97
- "category": "...",
98
- "priority": "...",
99
- "field": "..."
100
- }}
101
- }}
102
- """
103
 
104
 
105
  # =========================
@@ -107,20 +94,20 @@ FORMAT:
107
  # =========================
108
  def call_llm(prompt):
109
  if client is None:
110
- return None # 🔥 triggers fallback
111
 
112
  try:
113
  completion = client.chat.completions.create(
114
  model=os.getenv("MODEL_NAME", "unknown-model"),
115
  messages=[{"role": "user", "content": prompt}],
116
- temperature=0.2,
117
  response_format={"type": "json_object"}
118
  )
119
 
120
  return completion.choices[0].message.content.strip()
121
 
122
  except Exception:
123
- return None # 🔥 triggers fallback
124
 
125
 
126
  # =========================
@@ -144,22 +131,21 @@ def parse_output(text):
144
 
145
 
146
  # =========================
147
- # FALLBACK (CRITICAL)
148
  # =========================
149
- def fallback_policy(obs):
150
- msg = obs["customer_message"].lower()
151
  known = obs.get("known_info", {})
152
  required = obs.get("required", [])
153
 
154
- # classify once
155
- if "category" not in known:
156
- if "refund" in msg or "charged" in msg:
157
- return {"type": "classify", "category": "billing", "priority": "high"}
158
- if "delivery" in msg or "order" in msg:
159
- return {"type": "classify", "category": "delivery", "priority": "high"}
160
- return {"type": "classify", "category": "technical", "priority": "medium"}
161
 
162
- # ask missing (🔥 critical)
163
  missing = [f for f in required if f not in known]
164
  if missing:
165
  return {"type": "ask_info", "field": missing[0]}
@@ -188,46 +174,34 @@ def is_valid_action(action, valid_actions):
188
 
189
  return True
190
 
191
-
192
  # =========================
193
- # ACTION SELECTOR
194
  # =========================
195
  def get_action(obs, valid_actions):
196
 
197
- #known = obs.get("known_info", {})
198
-
199
- # HARD GUARD: prevent re-classification
200
- #if "category" in known:
201
- # valid_actions = [a for a in valid_actions if a["type"] != "classify"]
202
-
203
- known = obs.get("known_info", {})
204
- required = obs.get("required", [])
205
-
206
- missing = [f for f in required if f not in known]
207
-
208
- # HARD OVERRIDE (prevents LLM mistakes)
209
- if "category" in known:
210
- if missing:
211
- return {"type": "ask_info", "field": missing[0]}
212
- else:
213
- return {"type": "resolve"}
214
 
 
 
 
 
 
 
 
 
215
 
216
- prompt = build_prompt(obs, valid_actions)
 
217
 
218
- for _ in range(2): # retry loop
219
- try:
220
- output = call_llm(prompt)
221
- action = parse_output(output)
222
 
223
- if is_valid_action(action, valid_actions):
224
  return action
225
 
226
- except Exception:
227
- time.sleep(0.5)
228
 
229
- # fallback if LLM fails
230
- return fallback_policy(obs)
231
 
232
 
233
  # =========================
 
15
  import time
16
  #from groq import Groq
17
  #from openai import OpenAI
18
+ import random
19
 
20
  from app.env import CustomerSupportEnv
21
 
 
26
  #client = Groq(api_key=os.getenv("GROQ_API_KEY"))
27
 
28
  # =========================
29
+ # PURPOSE: Safe OpenAI client init
30
  # =========================
31
  try:
32
  from openai import OpenAI
 
46
  if OpenAI is None:
47
  return None
48
 
49
+ key = os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
50
+ if not key:
 
 
 
 
 
 
 
 
 
 
 
 
51
  return None
52
 
53
+ return OpenAI(
54
+ base_url=os.getenv("API_BASE_URL", "https://router.huggingface.co/v1"),
55
+ api_key=key
56
+ )
57
 
58
  client = get_llm_client()
59
 
60
  # =========================
61
+ # PURPOSE: Prompt - Strict + Minimal - encourages uncertainty-aware reasoning
62
  # =========================
63
+ def build_prompt(obs):
64
  return f"""
65
+ You are a customer support agent.
66
+
67
+ Customer message:
68
+ {obs.get("customer_message")}
69
+
70
+ Known info:
71
+ {obs.get("known_info")}
72
+
73
+ Required fields:
74
+ {obs.get("required")}
75
+
76
+ Your goal is to resolve the ticket efficiently.
77
+
78
+ Think carefully:
79
+ - You may revise earlier decisions
80
+ - Do not commit too early
81
+ - Ask missing info if unsure
82
+ - The message may be ambiguous
83
+ - Do not assume category prematurely
84
+ - Ask only necessary questions
85
+ - Avoid redundant actions
86
+
87
+ Return JSON:
88
+ {{"action": {{...}}}}
89
+ """
 
 
 
 
 
 
90
 
91
 
92
  # =========================
 
94
  # =========================
95
  def call_llm(prompt):
96
  if client is None:
97
+ return None # triggers fallback
98
 
99
  try:
100
  completion = client.chat.completions.create(
101
  model=os.getenv("MODEL_NAME", "unknown-model"),
102
  messages=[{"role": "user", "content": prompt}],
103
+ temperature=0.3,
104
  response_format={"type": "json_object"}
105
  )
106
 
107
  return completion.choices[0].message.content.strip()
108
 
109
  except Exception:
110
+ return None # triggers fallback
111
 
112
 
113
  # =========================
 
131
 
132
 
133
  # =========================
134
+ # PURPOSE: Fallback is intentionally imperfect
135
  # =========================
136
+ def fallback(obs):
137
+
138
  known = obs.get("known_info", {})
139
  required = obs.get("required", [])
140
 
141
+ # allow reclassification even if already classified
142
+ if "category" not in known or random.random() < 0.3:
143
+ return {
144
+ "type": "classify",
145
+ "category": "technical",
146
+ "priority": "medium"
147
+ }
148
 
 
149
  missing = [f for f in required if f not in known]
150
  if missing:
151
  return {"type": "ask_info", "field": missing[0]}
 
174
 
175
  return True
176
 
 
177
  # =========================
178
+ # PURPOSE: Hybrid control (LLM + adaptive fallback)
179
  # =========================
180
  def get_action(obs, valid_actions):
181
 
182
+ prompt = build_prompt(obs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ if client:
185
+ try:
186
+ resp = client.chat.completions.create(
187
+ model=os.getenv("MODEL_NAME"),
188
+ messages=[{"role": "user", "content": prompt}],
189
+ temperature=0.4,
190
+ response_format={"type": "json_object"}
191
+ )
192
 
193
+ text = resp.choices[0].message.content
194
+ parsed = json.loads(text)
195
 
196
+ action = parsed.get("action")
 
 
 
197
 
198
+ if action and "type" in action:
199
  return action
200
 
201
+ except:
202
+ pass
203
 
204
+ return fallback(obs)
 
205
 
206
 
207
  # =========================
app/dataset.py CHANGED
@@ -2,113 +2,226 @@
2
 
3
  # app/dataset.py
4
 
 
 
 
 
 
 
 
5
  TICKETS = [
6
 
7
- # Billing Issues
8
- {
9
- "ticket_id": "T1",
10
- "customer_message": "I was charged twice for my order #1234. Please refund.",
11
- "category": "billing",
12
- "priority": "high",
13
- "required_info": ["order_id"]
14
- },
15
- {
16
- "ticket_id": "T2",
17
- "customer_message": "I want to cancel my subscription and get a refund.",
18
- "category": "billing",
19
- "priority": "medium",
20
- "required_info": ["account_email"]
21
- },
22
- {
23
- "ticket_id": "T3",
24
- "customer_message": "Why was I billed after cancelling my plan?",
25
- "category": "billing",
26
- "priority": "high",
27
- "required_info": ["account_email"]
28
- },
29
  {
30
- "ticket_id": "T20",
31
- "customer_message": "I was charged twice and want a refund.",
32
- "category": "billing",
33
- "priority": "high",
34
- "required_info": ["order_id", "account_email"]
35
- },
36
 
37
- # Technical Issues
38
- {
39
- "ticket_id": "T4",
40
- "customer_message": "I can't log into my account. It says invalid credentials.",
41
- "category": "technical",
42
- "priority": "high",
43
- "required_info": ["account_email"]
44
- },
45
- {
46
- "ticket_id": "T5",
47
- "customer_message": "The app crashes every time I upload a file.",
48
- "category": "technical",
49
- "priority": "medium",
50
- "required_info": ["device_type"]
51
- },
52
- {
53
- "ticket_id": "T6",
54
- "customer_message": "Page not loading on checkout.",
55
- "category": "technical",
56
- "priority": "high",
57
- "required_info": ["browser"]
58
  },
 
 
 
 
59
  {
60
- "ticket_id": "T21",
61
- "customer_message": "App crashes when I try to checkout.",
62
- "category": "technical",
63
- "priority": "high",
64
- "required_info": ["device_type", "browser"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  },
 
 
 
 
66
  {
67
- "ticket_id": "T12",
68
- "customer_message": "App is very slow lately.",
69
- "category": "technical",
70
- "priority": "low",
71
- "required_info": ["device_type"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  },
73
 
74
- # Account Issues
 
 
75
  {
76
- "ticket_id": "T7",
77
- "customer_message": "I forgot my password and can't reset it.",
78
- "category": "account",
79
- "priority": "medium",
80
- "required_info": ["account_email"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  },
 
 
 
 
82
  {
83
- "ticket_id": "T8",
84
- "customer_message": "My account got locked for no reason.",
85
- "category": "account",
86
- "priority": "high",
87
- "required_info": ["account_email"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  },
 
 
 
 
89
  {
90
- "ticket_id": "T9",
91
- "customer_message": "How do I change my registered email address?",
92
- "category": "account",
93
- "priority": "low",
94
- "required_info": ["account_email"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  },
96
 
97
- # Edge Cases
 
 
98
  {
99
- "ticket_id": "T10",
100
- "customer_message": "Something is wrong with my account.",
101
- "category": "other",
102
- "priority": "medium",
103
- "required_info": ["account_email"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  },
 
 
 
 
105
  {
106
- "ticket_id": "T11",
107
- "customer_message": "I didn't receive my order but it shows delivered.",
108
- "category": "other",
109
- "priority": "high",
110
- "required_info": ["order_id"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  }
112
-
113
 
114
- ]
 
2
 
3
  # app/dataset.py
4
 
5
+ """
6
+ PURPOSE: Production-grade multi-intent dataset
7
+ - Introduces ambiguity (multiple valid interpretations)
8
+ - Separates perceived vs true intent
9
+ - Supports stochastic + difficulty-aware environments
10
+ """
11
+
12
  TICKETS = [
13
 
14
+ # =========================
15
+ # 1. BILLING vs DELIVERY
16
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  {
18
+ "ticket_id": "T001",
 
 
 
 
 
19
 
20
+ "variants": [
21
+ "I was charged but didn’t receive my order",
22
+ "Payment went through but nothing arrived",
23
+ "Got billed but package is missing"
24
+ ],
25
+
26
+ "noise": [
27
+ "pls check asap",
28
+ "this is urgent",
29
+ ""
30
+ ],
31
+
32
+ # AGENT CONFUSION SPACE
33
+ "possible_categories": ["billing", "delivery"],
34
+
35
+ "ground_truth": {
36
+ "category": "delivery",
37
+ "priority": "high",
38
+ "required_info": ["order_id", "account_email"]
39
+ }
 
40
  },
41
+
42
+ # =========================
43
+ # 2. TECH vs ACCOUNT
44
+ # =========================
45
  {
46
+ "ticket_id": "T002",
47
+
48
+ "variants": [
49
+ "I can’t log into my account",
50
+ "Login keeps failing with error",
51
+ "Account not accessible"
52
+ ],
53
+
54
+ "noise": [
55
+ "tried multiple times",
56
+ "not sure what's wrong",
57
+ ""
58
+ ],
59
+
60
+ "possible_categories": ["technical", "account"],
61
+
62
+ "ground_truth": {
63
+ "category": "account",
64
+ "priority": "medium",
65
+ "required_info": ["account_email", "device_type"]
66
+ }
67
  },
68
+
69
+ # =========================
70
+ # 3. BILLING vs TECH
71
+ # =========================
72
  {
73
+ "ticket_id": "T003",
74
+
75
+ "variants": [
76
+ "I got charged twice for the same order",
77
+ "Duplicate charge happened",
78
+ "Payment processed twice"
79
+ ],
80
+
81
+ "noise": [
82
+ "this is frustrating",
83
+ "",
84
+ ],
85
+
86
+ "possible_categories": ["billing", "technical"],
87
+
88
+ "ground_truth": {
89
+ "category": "billing",
90
+ "priority": "high",
91
+ "required_info": ["order_id", "account_email"]
92
+ }
93
  },
94
 
95
+ # =========================
96
+ # 4. DELIVERY (CLEAR)
97
+ # =========================
98
  {
99
+ "ticket_id": "T004",
100
+
101
+ "variants": [
102
+ "My order hasn’t arrived yet",
103
+ "Delivery is delayed",
104
+ "Still waiting for package"
105
+ ],
106
+
107
+ "noise": [
108
+ "been 5 days",
109
+ "",
110
+ ],
111
+
112
+ "possible_categories": ["delivery"],
113
+
114
+ "ground_truth": {
115
+ "category": "delivery",
116
+ "priority": "medium",
117
+ "required_info": ["order_id"]
118
+ }
119
  },
120
+
121
+ # =========================
122
+ # 5. TECH (AMBIGUOUS UI ISSUE)
123
+ # =========================
124
  {
125
+ "ticket_id": "T005",
126
+
127
+ "variants": [
128
+ "App crashes when I open it",
129
+ "Screen goes blank after launch",
130
+ "Something is wrong with the app"
131
+ ],
132
+
133
+ "noise": [
134
+ "happens randomly",
135
+ "",
136
+ ],
137
+
138
+ "possible_categories": ["technical"],
139
+
140
+ "ground_truth": {
141
+ "category": "technical",
142
+ "priority": "high",
143
+ "required_info": ["device_type", "browser"]
144
+ }
145
  },
146
+
147
+ # =========================
148
+ # 6. ACCOUNT vs BILLING
149
+ # =========================
150
  {
151
+ "ticket_id": "T006",
152
+
153
+ "variants": [
154
+ "My subscription is active but I can’t use features",
155
+ "Paid but features locked",
156
+ "Account says active but not working"
157
+ ],
158
+
159
+ "noise": [
160
+ "pls fix",
161
+ "",
162
+ ],
163
+
164
+ "possible_categories": ["account", "billing"],
165
+
166
+ "ground_truth": {
167
+ "category": "account",
168
+ "priority": "high",
169
+ "required_info": ["account_email"]
170
+ }
171
  },
172
 
173
+ # =========================
174
+ # 7. HARD: MULTI-LAYER ISSUE
175
+ # =========================
176
  {
177
+ "ticket_id": "T007",
178
+
179
+ "variants": [
180
+ "Order delayed and I was charged twice",
181
+ "Late delivery and duplicate payment issue",
182
+ "Package not here and billing looks wrong"
183
+ ],
184
+
185
+ "noise": [
186
+ "very frustrating",
187
+ "please resolve quickly",
188
+ ""
189
+ ],
190
+
191
+ "possible_categories": ["billing", "delivery"],
192
+
193
+ "ground_truth": {
194
+ "category": "billing", # root cause focus
195
+ "priority": "high",
196
+ "required_info": ["order_id", "account_email"]
197
+ }
198
  },
199
+
200
+ # =========================
201
+ # 8. HARD: VAGUE + NOISY
202
+ # =========================
203
  {
204
+ "ticket_id": "T008",
205
+
206
+ "variants": [
207
+ "Something is wrong with my account",
208
+ "Not working properly",
209
+ "Issue with my profile"
210
+ ],
211
+
212
+ "noise": [
213
+ "not sure what exactly",
214
+ "pls help",
215
+ ""
216
+ ],
217
+
218
+ "possible_categories": ["technical", "account"],
219
+
220
+ "ground_truth": {
221
+ "category": "technical",
222
+ "priority": "medium",
223
+ "required_info": ["device_type"]
224
+ }
225
  }
 
226
 
227
+ ]
app/env.py CHANGED
@@ -6,10 +6,15 @@ from app.models import Observation, Action, Reward
6
  from app.dataset import TICKETS
7
  import random
8
  from graders import grade_easy, grade_medium, grade_hard
9
- from tasks import TASKS
10
 
11
  import sys
12
 
 
 
 
 
 
13
  DIFFICULTY_CONFIG = {
14
  "easy": {
15
  "max_steps": 8,
@@ -28,9 +33,9 @@ DIFFICULTY_CONFIG = {
28
  }
29
  }
30
 
31
- # --- TASKS ---
32
- #AVAILABLE_TASKS = TASKS
33
-
34
  AVAILABLE_TASKS = [
35
  {
36
  "id": "easy-info-collection",
@@ -74,37 +79,33 @@ class CustomerSupportEnv:
74
  },
75
  ]
76
 
77
- # INTERNAL STATE REPRESENTATION
 
 
78
  def _get_observation(self):
79
 
80
- total_required = len(self.ticket.get("required_info", []))
81
- collected_required = sum(
82
- 1 for f in self.ticket.get("required_info", [])
83
- if f in self.state_data["collected_info"]
84
- )
85
 
86
- info_progress = collected_required / max(1, total_required)
87
-
88
  return {
89
- "ticket_id": self.ticket["ticket_id"],
90
- "customer_message": self.ticket["customer_message"],
91
- "history": [],
92
- "known_info": self.state_data["collected_info"],
93
- "required": self.ticket.get("required_info", []), # FULL requirement space (agent uses this)
94
- #"remaining_required": self.state_data["required_info"], # OPTIONAL (env/debug/analysis); agent_llm shouldn't use this directly - it should infer from known_info + customer_message
95
- "missing_required": [
96
- f for f in self.ticket.get("required_info", [])
97
- if f not in self.state_data["collected_info"]
98
- ],
99
- #"info_progress": len(self.state_data["collected_info"]) / 3,
100
- "info_progress": info_progress,
101
- "status": self.state_data["status"],
102
- "step_count": self.state_data["steps_taken"],
103
- "remaining_steps": self.max_steps - self.state_data["steps_taken"],
104
- "difficulty": self.difficulty # difficulty awareness
105
  }
106
 
107
-
 
 
108
  def __init__(self, difficulty="medium", seed=None):
109
 
110
  self.difficulty = difficulty
@@ -117,82 +118,86 @@ class CustomerSupportEnv:
117
  self.max_steps = self.config["max_steps"]
118
  self.last_action = None
119
 
 
 
 
120
  # METRICS TRACKING
121
  self.episode_stats = []
122
 
123
  def list_tasks(self):
124
  return self.tasks
125
 
126
-
127
  def reset(self):
128
 
129
  self.last_action = None
130
- self.current_episode_reward = 0.0
131
  self.current_steps = 0
132
  self.success = False
133
 
134
- # 🎯 Controlled ticket sampling
135
  self.ticket = random.choice(TICKETS)
 
136
 
137
- # 🎯 Inject stochasticity (controlled)
138
- noisy_message = self._inject_noise(self.ticket["customer_message"])
 
 
139
 
140
  self.state_data = {
141
  "ticket_id": self.ticket["ticket_id"],
142
- "customer_message": noisy_message,
143
- "history": [],
144
  "status": "open",
145
- "priority": None,
146
  "category": None,
147
- "required_info": self._mask_required_info(self.ticket["required_info"]),
 
148
  "collected_info": {},
149
  "steps_taken": 0,
150
- "max_steps": self.max_steps,
151
- "ground_truth": self.ticket
152
  }
153
 
154
  return self._get_observation()
155
 
 
 
 
156
  def step(self, action: dict):
157
-
158
- # SAFETY: ensure environment initialized
159
  if self.state_data is None:
160
- print("step() called before reset — auto-resetting", flush=True)
161
  self.reset()
162
 
163
- reward = 0.0
164
  done = False
165
  info = {}
166
- #info = {
167
- #"final_score": self._compute_final_score() if done else None
168
- #}
169
 
170
  collected = self.state_data["collected_info"]
171
- required = self.state_data["required_info"]
172
- gt = self.ticket
 
173
 
174
  # -----------------------
175
- # STEP PENALTY
176
  # -----------------------
177
- reward -= 0.05
178
 
179
- action_type = action.get("type")
 
180
 
181
- # -----------------------
182
- # REPEAT PENALTY
183
- # -----------------------
184
- if self.last_action == action:
185
- reward -= 0.2
186
 
187
- # -----------------------
188
- # CLASSIFY
189
- # -----------------------
190
- if action_type == "classify":
 
191
 
192
- collected["category"] = gt["category"]
193
- collected["priority"] = gt["priority"]
 
194
 
195
- reward += 0.2
 
 
 
196
 
197
  # -----------------------
198
  # ASK INFO
@@ -202,13 +207,10 @@ class CustomerSupportEnv:
202
  field = action.get("field")
203
 
204
  if field not in collected:
205
- collected[field] = "sample_value"
206
- reward += 0.3
207
-
208
- if field in required:
209
- required.remove(field)
210
  else:
211
- reward -= 0.3
212
 
213
  # -----------------------
214
  # RESOLVE
@@ -216,57 +218,26 @@ class CustomerSupportEnv:
216
  elif action_type == "resolve":
217
 
218
  done = True
219
- final_score = 0.0
220
 
221
- # classification
222
- if collected.get("category") == gt.get("category"):
223
- final_score += 0.3
224
 
225
- if collected.get("priority") == gt.get("priority"):
226
- final_score += 0.2
227
 
228
- # required info
229
- required_fields = gt.get("required_info", [])
230
- if all(f in collected for f in required_fields):
231
- final_score += 0.3
232
- self.success = True
233
- else:
234
- reward -= 0.5
235
-
236
- # resolve bonus
237
- final_score += 0.2
238
 
239
- reward += final_score
240
-
241
- # efficiency bonus
242
- optimal_steps = len(required_fields) + 1
243
- if self.state_data["steps_taken"] <= optimal_steps:
244
  reward += 0.3
245
 
246
- # episode stats
247
- collected_required = sum(1 for f in required_fields if f in collected)
248
-
249
- episode_data = {
250
- "success": self.success,
251
- "steps": self.state_data["steps_taken"],
252
- "reward": reward,
253
- "info_efficiency": collected_required / max(1, len(required_fields))
254
- }
255
-
256
- self.episode_stats.append(episode_data)
257
-
258
- info = {
259
- "final_score": final_score,
260
- "task_success": self.success,
261
- "collected_info": collected
262
- }
263
 
264
- self.last_action = action
265
- return self._get_observation(), reward, done, info
266
 
267
- # -----------------------
268
- # INVALID
269
- # -----------------------
270
  else:
271
  reward -= 0.3
272
 
@@ -274,35 +245,14 @@ class CustomerSupportEnv:
274
  # STEP UPDATE
275
  # -----------------------
276
  self.state_data["steps_taken"] += 1
277
- self.current_steps += 1
278
 
279
- # -----------------------
280
- # MAX STEP TERMINATION
281
- # -----------------------
282
- if self.state_data["steps_taken"] >= self.state_data["max_steps"]:
283
  done = True
284
- reward -= 2.0
285
-
286
- # record failure episode
287
- self.episode_stats.append({
288
- "success": False,
289
- "steps": self.state_data["steps_taken"],
290
- "reward": reward,
291
- "info_efficiency": 0
292
- })
293
 
294
- info = {
295
- "final_score": 0.0,
296
- "task_success": False
297
- }
298
-
299
- # -----------------------
300
- # SAVE STATE
301
- # -----------------------
302
- self.last_action = action
303
- self.current_episode_reward += reward
304
-
305
- return self._get_observation(), reward, done, info
306
 
307
  def state(self) -> Dict:
308
  return self.state_data
@@ -326,21 +276,32 @@ class CustomerSupportEnv:
326
  "info_efficiency": round(info_eff, 3)
327
  }
328
 
329
-
 
 
330
  def _inject_noise(self, message):
331
-
332
  if random.random() < self.config["noise_prob"]:
333
- noise_phrases = [
334
  "pls help asap",
335
- "this is urgent",
336
  "not sure what's wrong",
337
- "it’s been days"
338
- ]
339
- return message + " " + random.choice(noise_phrases)
340
-
341
  return message
342
 
343
 
 
 
 
 
 
 
 
 
 
 
 
344
  def _mask_required_info(self, required_fields):
345
 
346
  masked = []
@@ -350,4 +311,5 @@ class CustomerSupportEnv:
350
  masked.append(field)
351
 
352
  # ensure at least 1 required field remains
353
- return masked if masked else required_fields
 
 
6
  from app.dataset import TICKETS
7
  import random
8
  from graders import grade_easy, grade_medium, grade_hard
9
+ #from tasks import TASKS
10
 
11
  import sys
12
 
13
+ # =========================
14
+ # PURPOSE: Controls difficulty-driven stochasticity
15
+ # - noise_prob → message distortion
16
+ # - missing_info_prob → partial observability
17
+ # =========================
18
  DIFFICULTY_CONFIG = {
19
  "easy": {
20
  "max_steps": 8,
 
33
  }
34
  }
35
 
36
+ # =========================
37
+ # PURPOSE: Defines tasks exposed to validator
38
+ # =========================
39
  AVAILABLE_TASKS = [
40
  {
41
  "id": "easy-info-collection",
 
79
  },
80
  ]
81
 
82
+ # =========================
83
+ # PURPOSE: Build observation exposed to agent
84
+ # =========================
85
  def _get_observation(self):
86
 
87
+ required = self.state_data["required_info"]
88
+ collected = self.state_data["collected_info"]
89
+
90
+ total = len(required)
91
+ collected_count = sum(1 for f in required if f in collected)
92
 
 
 
93
  return {
94
+ "ticket_id": self.ticket["ticket_id"],
95
+ "customer_message": self.state_data["customer_message"],
96
+ "known_info": collected,
97
+ "required": required,
98
+ "missing_required": [f for f in required if f not in collected],
99
+ "info_progress": collected_count / max(1, total),
100
+ "status": self.state_data["status"],
101
+ "step_count": self.state_data["steps_taken"],
102
+ "remaining_steps": self.max_steps - self.state_data["steps_taken"],
103
+ "difficulty": self.difficulty # difficulty awareness
 
 
 
 
 
 
104
  }
105
 
106
+ # =========================
107
+ # PURPOSE: Initialize environment with difficulty & randomness
108
+ # =========================
109
  def __init__(self, difficulty="medium", seed=None):
110
 
111
  self.difficulty = difficulty
 
118
  self.max_steps = self.config["max_steps"]
119
  self.last_action = None
120
 
121
+ # self-correction tracking
122
+ self.classification_history = []
123
+
124
  # METRICS TRACKING
125
  self.episode_stats = []
126
 
127
  def list_tasks(self):
128
  return self.tasks
129
 
 
130
  def reset(self):
131
 
132
  self.last_action = None
133
+ #self.current_episode_reward = 0.0
134
  self.current_steps = 0
135
  self.success = False
136
 
 
137
  self.ticket = random.choice(TICKETS)
138
+ gt = self.ticket["ground_truth"]
139
 
140
+ msg = random.choice(self.ticket["variants"])
141
+ msg = self._inject_noise(msg)
142
+
143
+ masked_required = self._mask_required_info(gt["required_info"])
144
 
145
  self.state_data = {
146
  "ticket_id": self.ticket["ticket_id"],
147
+ "customer_message": msg,
 
148
  "status": "open",
 
149
  "category": None,
150
+ "priority": None,
151
+ "required_info": masked_required,
152
  "collected_info": {},
153
  "steps_taken": 0,
154
+ "ground_truth": gt
 
155
  }
156
 
157
  return self._get_observation()
158
 
159
+ # =========================
160
+ # PURPOSE: Core transition function with self-correction logic
161
+ # =========================
162
  def step(self, action: dict):
163
+
 
164
  if self.state_data is None:
 
165
  self.reset()
166
 
167
+ reward = -0.05
168
  done = False
169
  info = {}
 
 
 
170
 
171
  collected = self.state_data["collected_info"]
172
+ gt = self.ticket["ground_truth"]
173
+
174
+ action_type = action.get("type") if isinstance(action, dict) else None
175
 
176
  # -----------------------
177
+ # CLASSIFY (SELF-CORRECTION ENABLED)
178
  # -----------------------
179
+ if action_type == "classify":
180
 
181
+ new_cat = action.get("category")
182
+ prev_cat = collected.get("category")
183
 
184
+ collected["category"] = new_cat
185
+ collected["priority"] = action.get("priority")
 
 
 
186
 
187
+ self.classification_history.append(new_cat)
188
+
189
+ # correct classification
190
+ if new_cat == gt["category"]:
191
+ reward += 0.3
192
 
193
+ # self-correction bonus
194
+ if prev_cat and prev_cat != gt["category"] and new_cat == gt["category"]:
195
+ reward += 0.5 # major reward
196
 
197
+ # flip-flop penalty
198
+ if len(self.classification_history) >= 3:
199
+ if len(set(self.classification_history[-3:])) > 2:
200
+ reward -= 0.3
201
 
202
  # -----------------------
203
  # ASK INFO
 
207
  field = action.get("field")
208
 
209
  if field not in collected:
210
+ collected[field] = "value"
211
+ reward += 0.25
 
 
 
212
  else:
213
+ reward -= 0.2
214
 
215
  # -----------------------
216
  # RESOLVE
 
218
  elif action_type == "resolve":
219
 
220
  done = True
 
221
 
222
+ required = gt["required_info"]
223
+ all_info = all(f in collected for f in required)
 
224
 
225
+ correct_cat = collected.get("category") == gt["category"]
 
226
 
227
+ # 🔥 premature penalty
228
+ if not all_info:
229
+ reward -= 0.7
 
 
 
 
 
 
 
230
 
231
+ # scoring
232
+ if correct_cat:
 
 
 
233
  reward += 0.3
234
 
235
+ if all_info:
236
+ reward += 0.3
237
+ self.success = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
+ reward += 0.2 # completion bonus
 
240
 
 
 
 
241
  else:
242
  reward -= 0.3
243
 
 
245
  # STEP UPDATE
246
  # -----------------------
247
  self.state_data["steps_taken"] += 1
 
248
 
249
+ if self.state_data["steps_taken"] >= self.max_steps:
 
 
 
250
  done = True
251
+ reward -= 1.5
 
 
 
 
 
 
 
 
252
 
253
+ return self._get_observation(), reward, done, {
254
+ "task_success": self.success
255
+ }
 
 
 
 
 
 
 
 
 
256
 
257
  def state(self) -> Dict:
258
  return self.state_data
 
276
  "info_efficiency": round(info_eff, 3)
277
  }
278
 
279
+ # =========================
280
+ # PURPOSE: Apply noise to simulate real-world messy input
281
+ # =========================
282
  def _inject_noise(self, message):
 
283
  if random.random() < self.config["noise_prob"]:
284
+ noise = random.choice([
285
  "pls help asap",
 
286
  "not sure what's wrong",
287
+ "this is urgent",
288
+ "been days"
289
+ ])
290
+ return message + " " + noise
291
  return message
292
 
293
 
294
+ # =========================
295
+ # PURPOSE: Mask required fields → partial observability
296
+ # =========================
297
+ def _mask_required_info(self, required_fields):
298
+ masked = [
299
+ f for f in required_fields
300
+ if random.random() > self.config["missing_info_prob"]
301
+ ]
302
+ return masked if masked else required_fields
303
+
304
+ """
305
  def _mask_required_info(self, required_fields):
306
 
307
  masked = []
 
311
  masked.append(field)
312
 
313
  # ensure at least 1 required field remains
314
+ return masked if masked else required_fields
315
+ """
graders.py:Zone.Identifier ADDED
Binary file (25 Bytes). View file
 
inference.py CHANGED
@@ -54,7 +54,8 @@ def run_single_task(task):
54
  task_name = task["id"]
55
  task_type = task["difficulty"]
56
 
57
- env = CustomerSupportEnv()
 
58
  obs = env.reset()
59
 
60
  step_count = 0
 
54
  task_name = task["id"]
55
  task_type = task["difficulty"]
56
 
57
+ #env = CustomerSupportEnv()
58
+ env = CustomerSupportEnv(difficulty=task["difficulty"])
59
  obs = env.reset()
60
 
61
  step_count = 0
tasks.py:Zone.Identifier ADDED
Binary file (25 Bytes). View file