prashantmatlani committed on
Commit
d6a76d5
·
0 Parent(s):

fresh clean commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .env
2
+ __pycache__/
3
+ *.pyc
4
+ csvenv/
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ COPY . .
10
+
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ CMD ["python", "inference.py"]
README.md ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Customer Support Agent
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ tags:
8
+ - openenv
9
+ ---
10
+
11
+ # Customer Support RL + LLM Agent — Overview
12
+
13
+ ## Overview
14
+ This project implements a hybrid agent for customer support automation.
15
+
16
+ The agent:
17
+ 1. Classifies customer queries
18
+ 2. Collects required information
19
+ 3. Resolves efficiently
20
+
21
+ ---
22
+
23
+ ## Environment
24
+
25
+ The environment simulates customer support tickets with:
26
+ - Customer message
27
+ - Required information fields
28
+ - Ground truth classification
29
+
30
+ The agent uses a hybrid approach:
31
+ - LLM for classification
32
+ - deterministic policy for information gathering
33
+ - reward-shaped environment for optimization
34
+
35
+ 🎯 Objective
36
+
37
+ Build an intelligent agent that:
38
+
39
+ - Classifies customer issues
40
+ - Collects required information
41
+ - Resolves efficiently
42
+
43
+ 🏗 Architecture
44
+
45
+ 1. Environment (env.py)
46
+
47
+ Simulates customer support workflow.
48
+
49
+ State:
50
+
51
+ customer_message
52
+ known_info
53
+ required fields
54
+ progress
55
+
56
+ Actions:
57
+
58
+ classify
59
+ ask_info
60
+ resolve
61
+
62
+ 2. Reward Design
63
+
64
+ Action Reward
65
+ Correct classify +0.5
66
+ Ask required info +0.3
67
+ Repeat ask -0.3
68
+ Step penalty -0.05
69
+ Successful resolve +1.0
70
+
71
+ 3. Observation Design
72
+
73
+ {
74
+ "customer_message": str,
75
+ "known_info": dict,
76
+ "required": list # full schema
77
+ }
78
+
79
+ 4. Agent Types
80
+
81
+ Rule Agent (agent.py)
82
+ . Deterministic
83
+ . Uses required fields
84
+ . Computes missing info
85
+
86
+ LLM Agent (agent_llm.py)
87
+ . Uses prompt reasoning
88
+ . Strict JSON output
89
+ . Retry + fallback
90
+
91
+ 5. Core Logic
92
+
93
+ if not classified:
94
+ classify
95
+ elif missing fields:
96
+ ask_info
97
+ else:
98
+ resolve
99
+
100
+ 6. Key Improvements Made
101
+
102
+ - Removed ground-truth leakage
103
+ - Added reward shaping
104
+ - Added efficiency scoring
105
+ - Added schema-based reasoning
106
+ - Added fallback policy
107
+ - Added metrics tracking
108
+
109
+ 7. Metrics
110
+
111
+ {
112
+ success_rate,
113
+ avg_steps,
114
+ avg_reward,
115
+ info_efficiency
116
+ }
117
+
118
+ 8. Inference
119
+
120
+ python inference.py
121
+
122
+ 9. Deployment
123
+
124
+ docker build -t support-agent .
125
+ docker run support-agent
agent.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # agent.py
3
+
4
+ import sys
5
+ from unicodedata import category
6
+ import requests
7
+ import os
8
+ import time
9
+ import json
10
+ import random
11
+
12
+ from dotenv import load_dotenv
13
+ # from openai import OpenAI
14
+ from groq import Groq
15
+
16
+ from app.env import CustomerSupportEnv
17
+
18
+ # load_dotenv()
19
+
20
+ # client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
21
+
22
+ # BASE_URL = "http://127.0.0.1:8001"
23
+ #load_dotenv("/home/pb/projects/openenv-customer-support/.env")
24
+
25
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ ENV_PATH = os.path.join(BASE_DIR, ".env")
27
+
28
+ load_dotenv(ENV_PATH)
29
+ print(f"\nCWD: {os.getcwd()}")
30
+
31
+ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
32
+ #client = os.getenv("GROQ_API_KEY")
33
+
34
+ #print(f"\nENV PATH: {ENV_PATH}")
35
+ #print(f"\ngroq api key: {client}")
36
+ ##print(f"\ngroq api key: {os.getenv('GROQ_API_KEY')}")
37
+ ##print("KEY:", os.getenv("GROQ_API_KEY"))
38
+ #print(f"\nmodel name: {os.getenv('MODEL_NAME')}")
39
+
40
+ print("Sending request...")
41
+
42
+ #sys.exit()
43
+
44
+ # =========================
45
+ # Smarter, mapped ask_info - boosts info_progress speed, reward per episode
46
+ # =========================
47
def pick_field(category, known):
    """Return the name of the field to request for a ticket category.

    ``billing`` and ``delivery`` map to ``"order_id"``; every other value
    (``technical``, unknown categories, ``None``) maps to ``"account_email"``.

    Args:
        category: Category label of the ticket, or ``None``.
        known: Information collected so far. Accepted for interface
            compatibility; it is not consulted.

    Returns:
        The field name as a string.
    """
    order_based = ("billing", "delivery")
    return "order_id" if category in order_based else "account_email"
58
+
59
+ # =========================
60
+ # CLASSIFIER TO REDUCE LLM RELIANCE
61
+ # =========================
62
def smart_classify(message):
    """Classify a customer message with keyword rules instead of an LLM.

    The lower-cased message is checked against the billing keywords first,
    then the technical keywords; the first rule with a matching substring
    wins.

    Args:
        message: Raw customer message text.

    Returns:
        A dict with ``"category"`` and ``"priority"`` keys. Messages matching
        no rule are reported as ``general`` / ``medium``.
    """
    text = message.lower()
    keyword_rules = (
        ("billing", ("refund", "cancel", "subscription", "charge")),
        ("technical", ("crash", "bug", "error", "slow")),
    )

    for label, keywords in keyword_rules:
        if any(keyword in text for keyword in keywords):
            return {"category": label, "priority": "high"}

    return {"category": "general", "priority": "medium"}
72
+
73
+
74
def override_classify(message):
    """Build a ``classify`` action for a message using keyword rules.

    Rules are tried in order (billing, technical, delivery) against the
    lower-cased message; the first rule with a matching substring decides
    both category and priority.

    Args:
        message: Raw customer message text.

    Returns:
        A dict ``{"type": "classify", "category": ..., "priority": ...}``.
        Messages matching no rule yield ``general`` / ``medium``.
    """
    text = message.lower()
    rules = (
        ("billing", "high",
         ("charged", "refund", "billing", "cancel", "subscription")),
        ("technical", "high",
         ("checkout", "crash", "bug", "error", "not loading", "login")),
        ("delivery", "medium",
         ("delivery", "order not arrived", "shipping")),
    )

    for label, priority, keywords in rules:
        if any(keyword in text for keyword in keywords):
            return {"type": "classify", "category": label, "priority": priority}

    return {"type": "classify", "category": "general", "priority": "medium"}
87
+
88
+
89
+
90
def is_ready_to_resolve(category, known):
    """Tell whether the key field for *category* has already been collected.

    Args:
        category: Category label of the ticket, or ``None``.
        known: Mapping (or other container) of collected field names.

    Returns:
        ``True`` when the field tied to the category is present in *known*
        (``order_id`` for billing and delivery, ``account_email`` for
        technical); ``False`` for any other category.
    """
    key_field_by_category = (
        ("billing", "order_id"),
        ("technical", "account_email"),
        ("delivery", "order_id"),
    )

    for name, field in key_field_by_category:
        if category == name:
            return field in known

    return False
101
+
102
+ # =========================
103
+ # POLICY ENFORCEMENT INSTEAD OF LLM DECISION
104
+ # =========================
105
def enforce_policy(obs, action):
    """Rewrite a proposed action so it follows the fixed support workflow.

    Rules, applied in order:

    * A ``classify`` proposal is replaced by an ``ask_info`` action once a
      category is already known (never re-classify).
    * An ``ask_info`` proposal always asks for the field chosen by
      ``pick_field``; if that field is already known the action becomes
      ``resolve`` instead of repeating the question.
    * A ``resolve`` proposal is replaced by ``ask_info`` while
      ``is_ready_to_resolve`` reports that the ticket is not ready.
    * Anything else is returned unchanged.

    The caller's ``action`` dict is never modified; when the field has to be
    overridden a new dict is returned.

    Args:
        obs: Observation dict; ``obs["known_info"]`` must be a mapping.
        action: Proposed action dict with at least a ``"type"`` key.

    Returns:
        The action dict to execute.

    Raises:
        KeyError: If ``obs`` lacks ``"known_info"`` or ``action`` lacks
            ``"type"``.
    """
    known = obs["known_info"]
    category = known.get("category")
    kind = action["type"]

    # Never re-classify
    if kind == "classify" and category:
        return {"type": "ask_info", "field": pick_field(category, known)}

    if kind == "ask_info":
        # Force the policy's field choice rather than trusting the proposal.
        field = pick_field(category, known)
        if field in known:
            # Already collected: resolve instead of asking again.
            return {"type": "resolve"}
        # Copy so the caller's dict is left untouched.
        return {**action, "field": field}

    # Only resolve when ready
    if kind == "resolve" and not is_ready_to_resolve(category, known):
        return {"type": "ask_info", "field": pick_field(category, known)}

    return action
128
+
129
+ # =========================
130
+ # PROMPT
131
+ # =========================
132
+
133
def build_prompt(obs, valid_actions):
    """Render the decision prompt sent to the LLM for one observation.

    Args:
        obs: Observation dict; must contain ``"customer_message"``,
            ``"known_info"`` and ``"info_progress"``.
        valid_actions: Action space, interpolated with its ``str()`` form.

    Returns:
        The prompt text, ending with the JSON format the model must follow.

    Raises:
        KeyError: If one of the required observation keys is missing.
    """
    customer_message = obs["customer_message"]
    known_info = obs["known_info"]
    info_progress = obs["info_progress"]

    return f"""
You are a customer support decision agent.

Return ONLY valid JSON.

IMPORTANT DECISION RULES:

1. DO NOT ask for unnecessary information
2. If the issue is clear (e.g., password reset, login failure), resolve directly
3. Only ask for information that is REQUIRED to solve the issue
4. NEVER ask for order_id in login/password issues
5. If sufficient information is already available, choose "resolve"
6. Avoid repeating the same question

Customer message:
{customer_message}

Known info:
{known_info}

Progress:
{info_progress}

VALID ACTIONS:
{valid_actions}

RULES:
- ONLY pick from VALID ACTIONS
- "charged", "refund" → billing
- "slow", "crash" → technical
- Do NOT hallucinate

CRITICAL DECISION RULE:

Only choose "resolve" IF:
1. You have correctly classified the issue
2. You have collected ALL required fields
3. You are confident you can solve the user's problem

If ANY doubt remains → ask_info

NEVER resolve early.

CLASSIFICATION RULES (STRICT):

You MUST classify into ONLY ONE of:
- billing
- technical
- delivery

NEVER output "general" or any other category.

---

BILLING:
charged, refund, payment, invoice, subscription, billing issues

TECHNICAL:
login issues, account problems, crashes, errors, bugs, slow performance, app issues

IMPORTANT:
ANY issue related to app behavior (slow, crash, not working, locked account)
→ ALWAYS technical

---

DELIVERY:
shipping, delivery delay, order not received

---

PRIORITY RULE:
If message involves money → billing (even if order mentioned)

Example:
"I was charged twice for my order"
→ billing

FORMAT:
{{
"thought": "...",
"action": {{ ... }}
}}
"""
218
+
219
+
220
+
221
+ # =========================
222
+ # LLM CALL
223
+ # =========================
224
+
225
def call_llm(prompt):
    """Send *prompt* as a single user message and return the reply text.

    Uses the module-level ``client``. The model name is pinned here; the
    ``MODEL_NAME`` environment variable is not consulted. The request asks
    for a JSON-object response at temperature 0.2.

    Args:
        prompt: Full prompt text.

    Returns:
        Content of the first choice, stripped of surrounding whitespace.
    """
    request = {
        "model": "llama-3.1-8b-instant",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.2,
        "response_format": {"type": "json_object"},
    }
    completion = client.chat.completions.create(**request)

    reply = completion.choices[0].message.content
    return reply.strip()
235
+
236
+ # =========================
237
+ # PARSER (MANDATORY)
238
+ # =========================
239
+
240
def parse_output(text):
    """Extract the ``action`` object from raw LLM output.

    The text may be wrapped in a Markdown code fence and may carry extra
    prose around the JSON; the span from the first ``{`` to the last ``}``
    is what gets decoded.

    Args:
        text: Raw model output.

    Returns:
        The ``action`` dict (guaranteed to be a dict containing ``"type"``),
        or ``None`` when the output cannot be parsed or has the wrong shape.
        Failures are reported on stdout together with the offending text.
    """
    try:
        # Keep only the fenced part when the model wrapped its answer.
        if "```" in text:
            text = text.split("```")[1]

        start = text.find("{")
        end = text.rfind("}") + 1
        text = text[start:end]

        parsed = json.loads(text)

        action = parsed.get("action")

        # A bare string such as "type" would satisfy a substring test, so
        # require a mapping before looking for the "type" key.
        if not isinstance(action, dict) or "type" not in action:
            raise ValueError("Invalid action format")

        return action

    except Exception as e:
        # Best-effort parser: report and let the caller fall back.
        print("❌ PARSE ERROR:", e)
        print("RAW:", text)
        return None
262
+
263
+ # =========================
264
+ # VALIDATION
265
+ # =========================
266
+
267
def is_valid_action(action, valid_actions):
    """Check a proposed action against the allowed action space.

    Args:
        action: Proposed action dict (may be ``None`` or empty).
        valid_actions: List of allowed action dicts, each with a ``"type"``
            key; ``ask_info`` entries also carry a ``"field"`` key.

    Returns:
        ``True`` when the action's type is allowed and, for ``ask_info``,
        its field is one of the allowed fields, or, for ``classify``, both
        ``"category"`` and ``"priority"`` are present. ``resolve`` and any
        other allowed type need nothing more.
    """
    if not action or "type" not in action:
        return False

    kind = action["type"]

    allowed_kinds = tuple(entry["type"] for entry in valid_actions)
    if kind not in allowed_kinds:
        return False

    # classify only needs the two keys to be present (values are not checked)
    if kind == "classify":
        return "category" in action and "priority" in action

    if kind != "ask_info":
        return True

    askable_fields = tuple(
        entry["field"] for entry in valid_actions if entry["type"] == "ask_info"
    )
    return action.get("field") in askable_fields
289
+
290
+ # =========================
291
+ # VALID ACTION SPACE
292
+ # =========================
293
+
294
def get_valid_actions():
    """Return the action space as a fresh list of action dicts.

    The list holds one ``ask_info`` entry per askable field, followed by
    ``resolve`` and finally a ``classify`` entry without fixed
    category/priority, so any classification is allowed.
    """
    askable_fields = ("order_id", "account_email", "device_type", "browser")

    actions = [{"type": "ask_info", "field": name} for name in askable_fields]
    actions.append({"type": "resolve"})
    actions.append({"type": "classify"})
    return actions
307
+
308
+ # =========================
309
+ # ACTION PIPELINE
310
+ # =========================
311
def get_action(obs):
    """Choose the next action with a deterministic classify/ask/resolve policy.

    Args:
        obs: Observation dict. ``"customer_message"`` is required;
            ``"known_info"`` and ``"required"`` default to empty.

    Returns:
        * a ``classify`` action while ``known_info`` has no ``"category"``
          (keyword rules tried in order billing, delivery, technical, with
          ``technical`` / ``medium`` as the default),
        * otherwise an ``ask_info`` action for the first required field that
          is not yet known,
        * otherwise ``{"type": "resolve"}``.
    """
    msg = obs["customer_message"].lower()

    # Structure provided by the environment
    known = obs.get("known_info", {})
    required = obs.get("required", [])

    # 1. Classify (only once)
    if "category" not in known:
        keyword_rules = (
            ("billing",
             ("charged", "refund", "billed", "payment", "invoice", "cancel")),
            ("delivery",
             ("delivery", "delivered", "not received", "shipment", "order")),
            ("technical",
             ("login", "password", "error", "crash", "bug", "checkout")),
        )
        for label, keywords in keyword_rules:
            if any(keyword in msg for keyword in keywords):
                return {"type": "classify", "category": label, "priority": "high"}

        return {"type": "classify", "category": "technical", "priority": "medium"}

    # 2./3. Ask for the first required field that is still missing
    for field in required:
        if field not in known:
            return {"type": "ask_info", "field": field}

    # 4. Nothing missing: resolve
    return {"type": "resolve"}
355
+
356
+
357
+ # =========================
358
+ # RUN
359
+ # =========================
360
+ def run_agent():
+ # Run one episode of the rule-based agent against CustomerSupportEnv.
+ #
+ # Creates and resets the environment, then repeatedly picks an action with
+ # get_action(obs) and applies it with env.step(action) until `done` is
+ # truthy. Each step is recorded in `trajectory` as a dict with the keys
+ # "state", "action" and "reward".
+ #
+ # Returns a dict with "final_score" (info.get("final_score", 0) from the
+ # last step) and "trajectory".
+ #
+ # NOTE(review): indentation was lost in this capture, so it is unclear
+ # whether the plain OBS/ACTION/REWARD/DONE prints below run on every step
+ # or once after the loop - confirm against the repository.
+ # NOTE(review): the FINAL print tolerates a falsy `info`, but the return
+ # statement calls info.get(...) unconditionally; if env.step can return
+ # None for `info` this raises AttributeError - confirm.
361
+
362
+ print("🚀 Starting agent...")
363
+ env = CustomerSupportEnv()
364
+ obs = env.reset()
365
+
366
+ done = False
367
+ trajectory = []
368
+
369
+ while not done:
370
+ print("\n📥 OBS:", obs)
371
+
372
+ action = get_action(obs)
373
+ print("🧠 ACTION:", action)
374
+
375
+ next_obs, reward, done, info = env.step(action)
376
+
377
+ print("🎯 REWARD:", reward)
378
+ print("✅ DONE:", done)
379
+
380
+ trajectory.append({
381
+ "state": obs,
382
+ "action": action,
383
+ "reward": reward
384
+ })
385
+
386
+ obs = next_obs
387
+
388
+ print("OBS:", obs)
389
+ print("ACTION:", action)
390
+ print("REWARD:", reward)
391
+ print("DONE:", done)
392
+
393
+ #print("\n🏁 FINAL INFO:", info)
394
+ print("FINAL:", info if info else "No info returned")
395
+
396
+
397
+ return {
398
+ "final_score": info.get("final_score", 0),
399
+ "trajectory": trajectory
400
+ }
401
+
402
+
403
def run_multiple(n=3):
    """Run *n* episodes and report the mean final score.

    Args:
        n: Number of episodes to run. Values of zero or less run nothing.

    Returns:
        The average of the episodes' ``final_score`` values, or ``0.0`` when
        no episode was run. The same value is printed.
    """
    scores = []

    for i in range(n):
        print(f"\n===== EPISODE {i+1} =====")
        result = run_agent()
        scores.append(result["final_score"])

    # Guard the empty case: n <= 0 would otherwise divide by zero.
    avg = sum(scores) / len(scores) if scores else 0.0
    print("\n📊 AVERAGE SCORE:", avg)
    return avg
413
+
414
+
415
+ if __name__ == "__main__":
416
+ run_multiple(3)
agent_llm.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # agent_llm.py
3
+
4
+ """
5
+ - Uses LLM (requirement satisfied)
6
+ - Robust (fallback present)
7
+ - Structured output (strict JSON)
8
+ - Hallucination-resistant (LLM actions are validated; invalid ones fall back to rules)
9
+ - Reproducible
10
+ """
11
+
12
+
13
+ import os
14
+ import json
15
+ import time
16
+ from dotenv import load_dotenv
17
+ from groq import Groq
18
+
19
+ from app.env import CustomerSupportEnv
20
+
21
+ load_dotenv()
22
+
23
+ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
24
+
25
+ # =========================
26
+ # PROMPT (STRICT + MINIMAL)
27
+ # =========================
28
def build_prompt(obs, valid_actions):
    """Render the minimal, strict decision prompt for one observation.

    Args:
        obs: Observation dict; ``"customer_message"`` and ``"known_info"``
            are required, ``"required"`` defaults to an empty list.
        valid_actions: Accepted for interface compatibility; the prompt
            lists the action types itself and does not interpolate this.

    Returns:
        The prompt text, ending with the JSON format the model must follow.

    Raises:
        KeyError: If a required observation key is missing.
    """
    customer_message = obs["customer_message"]
    known_info = obs["known_info"]
    required_fields = obs.get("required", [])

    return f"""
You are a decision agent for customer support.

Return ONLY JSON.

INPUT:
Customer message: {customer_message}
Known info: {known_info}
Required fields: {required_fields}

RULES:
1. First classify (billing / technical / delivery)
2. Then collect ALL required fields
3. Then resolve
4. NEVER resolve early
5. DO NOT ask for fields already known

VALID ACTION TYPES:
- classify
- ask_info
- resolve

FORMAT:
{{
"action": {{
"type": "...",
"category": "...",
"priority": "...",
"field": "..."
}}
}}
"""
61
+
62
+
63
+ # =========================
64
+ # LLM CALL
65
+ # =========================
66
def call_llm(prompt):
    """Send *prompt* as a single user message and return the reply text.

    Uses the module-level ``client``. The model comes from the
    ``MODEL_NAME`` environment variable; when that is unset or empty the
    call falls back to ``llama-3.1-8b-instant`` (the model agent.py pins)
    instead of sending ``model=None``.

    Args:
        prompt: Full prompt text.

    Returns:
        Content of the first choice, stripped of surrounding whitespace.
        An empty string is returned when the reply has no content.
    """
    completion = client.chat.completions.create(
        model=os.getenv("MODEL_NAME") or "llama-3.1-8b-instant",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.2,
        response_format={"type": "json_object"}
    )
    # Message content may be None; treat that as an empty reply.
    content = completion.choices[0].message.content
    return (content or "").strip()
74
+
75
+
76
+ # =========================
77
+ # PARSER (STRICT)
78
+ # =========================
79
def parse_output(text):
    """Extract the ``action`` object from raw LLM output.

    The span from the first ``{`` to the last ``}`` is decoded as JSON, so
    prose around the object is tolerated.

    Args:
        text: Raw model output.

    Returns:
        The ``action`` dict (guaranteed to be a dict containing ``"type"``),
        or ``None`` when the text is not a string, is not valid JSON, or
        does not have the expected shape.
    """
    try:
        start = text.find("{")
        end = text.rfind("}") + 1
        parsed = json.loads(text[start:end])
    except (AttributeError, TypeError, ValueError):
        # Non-string input or undecodable JSON (JSONDecodeError is a
        # ValueError); anything else is a real bug and should surface.
        return None

    if not isinstance(parsed, dict):
        return None

    action = parsed.get("action")

    # Require a mapping: a bare string containing "type" must not pass.
    if not isinstance(action, dict) or "type" not in action:
        return None

    return action
94
+
95
+
96
+ # =========================
97
+ # FALLBACK (CRITICAL)
98
+ # =========================
99
def fallback_policy(obs):
    """Pick an action with fixed rules when the LLM gives nothing usable.

    Args:
        obs: Observation dict. ``"customer_message"`` is required;
            ``"known_info"`` and ``"required"`` default to empty.

    Returns:
        A ``classify`` action while no category is known (billing for
        refund/charged, delivery for delivery/order, otherwise technical
        with medium priority); then ``ask_info`` for the first required
        field still missing; otherwise ``{"type": "resolve"}``.
    """
    msg = obs["customer_message"].lower()
    known = obs.get("known_info", {})
    required = obs.get("required", [])

    # classify once
    if "category" not in known:
        if any(word in msg for word in ("refund", "charged")):
            label, priority = "billing", "high"
        elif any(word in msg for word in ("delivery", "order")):
            label, priority = "delivery", "high"
        else:
            label, priority = "technical", "medium"
        return {"type": "classify", "category": label, "priority": priority}

    # ask for the first missing required field
    for field in required:
        if field not in known:
            return {"type": "ask_info", "field": field}

    return {"type": "resolve"}
118
+
119
+
120
+ # =========================
121
+ # VALIDATION
122
+ # =========================
123
def is_valid_action(action, valid_actions):
    """Check a proposed action against the allowed action space.

    Args:
        action: Proposed action dict (may be ``None`` or empty).
        valid_actions: List of allowed action dicts, each with a ``"type"``
            key; ``ask_info`` entries also carry a ``"field"`` key.

    Returns:
        ``True`` when the type is allowed and the type-specific check
        passes: ``ask_info`` needs an allowed field, ``classify`` needs both
        ``"category"`` and ``"priority"``; other allowed types need nothing
        more.
    """
    if not action or "type" not in action:
        return False

    requested = action["type"]
    permitted = tuple(option["type"] for option in valid_actions)

    if requested not in permitted:
        return False

    if requested == "classify":
        return "category" in action and "priority" in action

    if requested == "ask_info":
        fields = tuple(
            option["field"]
            for option in valid_actions
            if option["type"] == "ask_info"
        )
        return action.get("field") in fields

    return True
140
+
141
+
142
+ # =========================
143
+ # ACTION SELECTOR
144
+ # =========================
145
def get_action(obs, valid_actions):
    """Choose the next action, consulting the LLM only for classification.

    Once ``known_info`` contains a ``"category"`` the choice is fully
    deterministic: ask for the first missing required field, or resolve.
    Before that, the LLM is queried up to two times; the first reply that
    parses and validates is returned. If both attempts fail, the rule-based
    ``fallback_policy`` decides.

    Args:
        obs: Observation dict (``"known_info"`` and ``"required"`` default
            to empty).
        valid_actions: Allowed action dicts, used for prompting and
            validation.

    Returns:
        The action dict to execute.
    """
    known = obs.get("known_info", {})
    required = obs.get("required", [])

    missing = [field for field in required if field not in known]

    # HARD OVERRIDE (prevents LLM mistakes): after classification the LLM
    # is bypassed entirely.
    if "category" in known:
        if not missing:
            return {"type": "resolve"}
        return {"type": "ask_info", "field": missing[0]}

    prompt = build_prompt(obs, valid_actions)

    attempts_left = 2  # retry budget
    while attempts_left > 0:
        attempts_left -= 1
        try:
            candidate = parse_output(call_llm(prompt))
            if is_valid_action(candidate, valid_actions):
                return candidate
        except Exception:
            # Pause briefly before the next attempt.
            time.sleep(0.5)

    # fallback if LLM fails
    return fallback_policy(obs)
181
+
182
+
183
+ # =========================
184
+ # RUN
185
+ # =========================
186
+ def run_agent():
+ # Run one episode of the LLM-backed agent against CustomerSupportEnv.
+ #
+ # Creates and resets the environment, then repeatedly picks an action with
+ # get_action(obs, valid_actions) and applies it with env.step(action)
+ # until `done` is truthy. Prints the last observation/action/reward/done
+ # values, the final `info` and env.get_metrics(). Nothing is returned.
+ #
+ # `valid_actions` is rebuilt with identical content on every iteration.
+ #
+ # NOTE(review): indentation was lost in this capture, so it is unclear
+ # whether the OBS/ACTION/REWARD/DONE prints run on every step or once
+ # after the loop - confirm against the repository.
187
+ env = CustomerSupportEnv()
188
+ obs = env.reset()
189
+
190
+ done = False
191
+
192
+ while not done:
193
+ valid_actions = [
194
+ {"type": "ask_info", "field": "order_id"},
195
+ {"type": "ask_info", "field": "account_email"},
196
+ {"type": "ask_info", "field": "device_type"},
197
+ {"type": "ask_info", "field": "browser"},
198
+ {"type": "resolve"},
199
+ {"type": "classify"},
200
+ ]
201
+
202
+ action = get_action(obs, valid_actions)
203
+
204
+ obs, reward, done, info = env.step(action)
205
+
206
+
207
+ print(f"\nOBS: {obs}")
208
+ print(f"\nACTION: {action}")
209
+ print(f"\nREWARD: {reward}")
210
+ print(f"\nDONE: {done}")
211
+
212
+
213
+ #print("FINAL:", info)
214
+ print(f"\nFINAL: {info if info else 'No info returned'}")
215
+
216
+ print(f"\nMETRICS: {env.get_metrics()}")
217
+
218
+
219
+ if __name__ == "__main__":
220
+ run_agent()
agent_py_output.txt ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #agent.py output
2
+
3
+ output with prompt - 04032026:
4
+
5
+ prompt = f"""
6
+ You are a customer support agent.
7
+
8
+ STRICT RULES:
9
+ - If any required info is missing → use ask_info
10
+ - Only resolve AFTER all required info is collected
11
+
12
+ Return ONLY JSON.
13
+
14
+ Actions:
15
+ 1. ask_info → {{"type": "ask_info", "field": "..."}}
16
+ 2. resolve → {{"type": "resolve"}}
17
+
18
+ Allowed fields: account_email, order_id, device_type, browser
19
+
20
+ Observation:
21
+ {observation}
22
+ """
23
+
24
+ Sending request...
25
+ 📡 Calling Groq...
26
+
27
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
28
+
29
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {}, 'missing_info': ['order_id'], 'status': 'open', 'step_count': 0, 'remaining_steps': 10}
30
+ 📡 Calling Groq...
31
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'order_id'}
32
+ 🎯 REWARD: 0.3
33
+ ✅ DONE: False
34
+
35
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 1, 'remaining_steps': 9}
36
+ 📡 Calling Groq...
37
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
38
+ 🎯 REWARD: -0.1
39
+ ✅ DONE: False
40
+
41
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 2, 'remaining_steps': 8}
42
+ 📡 Calling Groq...
43
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
44
+ 🎯 REWARD: -0.1
45
+ ✅ DONE: False
46
+
47
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 3, 'remaining_steps': 7}
48
+ 📡 Calling Groq...
49
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
50
+ 🎯 REWARD: -0.1
51
+ ✅ DONE: False
52
+
53
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 4, 'remaining_steps': 6}
54
+ 📡 Calling Groq...
55
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
56
+ 🎯 REWARD: -0.1
57
+ ✅ DONE: False
58
+
59
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 5, 'remaining_steps': 5}
60
+ 📡 Calling Groq...
61
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
62
+ 🎯 REWARD: -0.1
63
+ ✅ DONE: False
64
+
65
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 6, 'remaining_steps': 4}
66
+ 📡 Calling Groq...
67
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
68
+ 🎯 REWARD: -0.1
69
+ ✅ DONE: False
70
+
71
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 7, 'remaining_steps': 3}
72
+ 📡 Calling Groq...
73
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
74
+ 🎯 REWARD: -0.1
75
+ ✅ DONE: False
76
+
77
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 8, 'remaining_steps': 2}
78
+ 📡 Calling Groq...
79
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
80
+ 🎯 REWARD: -0.1
81
+ ✅ DONE: False
82
+
83
+ 📥 OBS: {'ticket_id': 'T11', 'customer_message': "I didn't receive my order but it shows delivered.", 'history': [], 'known_info': {'order_id': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 9, 'remaining_steps': 1}
84
+ 📡 Calling Groq...
85
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
86
+ 🎯 REWARD: -0.8
87
+ ✅ DONE: True
88
+
89
+ 🏁 FINAL INFO: {'final_score': 0.3}
90
+
91
+
92
+ output with prompt - 04032026:
93
+
94
+
95
+ prompt = f"""
96
+ You are a customer support agent.
97
+
98
+ STRICT RULES:
99
+ 1. If missing_info list is NOT empty → you MUST ask for ONE of those fields
100
+ 2. If missing_info list is EMPTY → you MUST resolve
101
+ 3. NEVER ask for a field that is NOT in missing_info
102
+ 4. NEVER repeat asking for the same field
103
+
104
+ Return ONLY JSON.
105
+
106
+ Actions:
107
+ - ask_info → {{"type": "ask_info", "field": "..."}}
108
+ - resolve → {{"type": "resolve"}}
109
+
110
+ Observation:
111
+ {observation}
112
+ """
113
+
114
+
115
+
116
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
117
+
118
+ 📥 OBS: {'ticket_id': 'T10', 'customer_message': 'Something is wrong with my account.', 'history': [], 'known_info': {}, 'missing_info': ['account_email'], 'status': 'open', 'step_count': 0, 'remaining_steps': 10}
119
+ 📡 Calling Groq...
120
+ 🧠 ACTION: {'type': 'ask_info', 'field': 'account_email'}
121
+ 🎯 REWARD: 0.3
122
+ ✅ DONE: False
123
+
124
+ 📥 OBS: {'ticket_id': 'T10', 'customer_message': 'Something is wrong with my account.', 'history': [], 'known_info': {'account_email': 'sample_value'}, 'missing_info': [], 'status': 'open', 'step_count': 1, 'remaining_steps': 9}
125
+ 📡 Calling Groq...
126
+ 🧠 ACTION: {'type': 'resolve'}
127
+ 🎯 REWARD: 1.7
128
+ ✅ DONE: True
129
+
130
+ 🏁 FINAL INFO: {'final_score': 0.7}
agent_rule_based.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # agent_rule_based.py
4
+
5
def get_action(obs):
    """Choose the next action with a fixed classify/ask/resolve order.

    Args:
        obs: Observation dict. ``"known_info"`` defaults to empty. The full
            list of required fields is read from ``"required_info_full"``
            when present, otherwise from ``"required"``, the key the
            environment's observation publishes.

    Returns:
        ``{"type": "classify"}`` until both ``"category"`` and
        ``"priority"`` are known; then ``ask_info`` for the first required
        field that is missing; otherwise ``{"type": "resolve"}``.
    """
    known = obs.get("known_info", {})

    # Prefer the key this function originally read, but fall back to
    # "required" so that missing fields are actually detected when only
    # that key is supplied.
    required_full = obs.get("required_info_full")
    if required_full is None:
        required_full = obs.get("required", [])

    # 1. classify first
    if "category" not in known or "priority" not in known:
        return {"type": "classify"}

    # 2. collect missing info
    for field in required_full:
        if field not in known:
            return {"type": "ask_info", "field": field}

    # 3. resolve only when complete
    return {"type": "resolve"}
app/__init__.py ADDED
File without changes
app/dataset.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # app/dataset.py
4
+
5
# Ticket fixtures, one row per ticket:
# (ticket_id, customer_message, category, priority, required_info)
_TICKET_ROWS = [
    # Billing Issues
    ("T1", "I was charged twice for my order #1234. Please refund.",
     "billing", "high", ["order_id"]),
    ("T2", "I want to cancel my subscription and get a refund.",
     "billing", "medium", ["account_email"]),
    ("T3", "Why was I billed after cancelling my plan?",
     "billing", "high", ["account_email"]),
    ("T20", "I was charged twice and want a refund.",
     "billing", "high", ["order_id", "account_email"]),

    # Technical Issues
    ("T4", "I can't log into my account. It says invalid credentials.",
     "technical", "high", ["account_email"]),
    ("T5", "The app crashes every time I upload a file.",
     "technical", "medium", ["device_type"]),
    ("T6", "Page not loading on checkout.",
     "technical", "high", ["browser"]),
    ("T21", "App crashes when I try to checkout.",
     "technical", "high", ["device_type", "browser"]),
    ("T12", "App is very slow lately.",
     "technical", "low", ["device_type"]),

    # Account Issues
    ("T7", "I forgot my password and can't reset it.",
     "account", "medium", ["account_email"]),
    ("T8", "My account got locked for no reason.",
     "account", "high", ["account_email"]),
    ("T9", "How do I change my registered email address?",
     "account", "low", ["account_email"]),

    # Edge Cases
    ("T10", "Something is wrong with my account.",
     "other", "medium", ["account_email"]),
    ("T11", "I didn't receive my order but it shows delivered.",
     "other", "high", ["order_id"]),
]

# Public dataset: a list of ticket dicts with the keys ticket_id,
# customer_message, category, priority and required_info (in that order).
TICKETS = [
    {
        "ticket_id": ticket_id,
        "customer_message": message,
        "category": label,
        "priority": priority,
        "required_info": required,
    }
    for ticket_id, message, label, priority, required in _TICKET_ROWS
]
app/env.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # app/env.py
3
+
4
+ from typing import Tuple, Dict, Any
5
+ from app.models import Observation, Action, Reward
6
+ from app.dataset import TICKETS
7
+ import random
8
+
9
+ import sys
10
+
11
+ class CustomerSupportEnv:
12
+
13
+ # INTERNAL STATE REPRESENTATION -
14
+ def _get_observation(self):
15
+
16
+ total_required = len(self.ticket.get("required_info", []))
17
+ collected_required = sum(
18
+ 1 for f in self.ticket.get("required_info", [])
19
+ if f in self.state_data["collected_info"]
20
+ )
21
+
22
+ info_progress = collected_required / max(1, total_required)
23
+
24
+ return {
25
+ "ticket_id": self.ticket["ticket_id"],
26
+ "customer_message": self.ticket["customer_message"],
27
+ "history": [],
28
+ "known_info": self.state_data["collected_info"],
29
+ "required": self.ticket.get("required_info", []), # FULL requirement space (agent uses this)
30
+ #"remaining_required": self.state_data["required_info"], # OPTIONAL (env/debug/analysis); agent_llm shouldn't use this directly - it should infer from known_info + customer_message
31
+ "missing_required": [
32
+ f for f in self.ticket.get("required_info", [])
33
+ if f not in self.state_data["collected_info"]
34
+ ],
35
+ #"info_progress": len(self.state_data["collected_info"]) / 3,
36
+ "info_progress": info_progress,
37
+ "status": self.state_data["status"],
38
+ "step_count": self.state_data["steps_taken"],
39
+ "remaining_steps": self.max_steps - self.state_data["steps_taken"],
40
+ }
41
+
42
+ def __init__(self):
43
+ self.state_data = None
44
+ self.max_steps = 10
45
+ self.last_action = None
46
+
47
+ # ✅ METRICS TRACKING
48
+ self.episode_stats = []
49
+
50
+ def reset(self):
51
+
52
+ self.last_action = None
53
+
54
+ # ✅ episode tracking
55
+ self.current_episode_reward = 0.0
56
+ self.current_steps = 0
57
+ self.success = False
58
+
59
+ self.ticket = random.choice(TICKETS)
60
+
61
+ self.state_data = {
62
+ "ticket_id": self.ticket["ticket_id"],
63
+ "customer_message": self.ticket["customer_message"],
64
+ "history": [],
65
+ "status": "open",
66
+ "priority": None,
67
+ "category": None,
68
+ "required_info": self.ticket["required_info"].copy(),
69
+ "collected_info": {},
70
+ "steps_taken": 0,
71
+ "max_steps": self.max_steps,
72
+ "ground_truth": self.ticket
73
+ }
74
+
75
+ return self._get_observation()
76
+
77
+
78
+ def step(self, action: dict):
79
+
80
+ reward = 0.0
81
+ done = False
82
+ #info = {}
83
+ info = {
84
+ "final_score": self._compute_final_score() if done else None
85
+ }
86
+
87
+ collected = self.state_data["collected_info"]
88
+ required = self.state_data["required_info"]
89
+ gt = self.ticket
90
+
91
+ # -----------------------
92
+ # STEP PENALTY
93
+ # -----------------------
94
+ reward -= 0.05
95
+
96
+ action_type = action.get("type")
97
+
98
+ # -----------------------
99
+ # REPEAT PENALTY
100
+ # -----------------------
101
+ if self.last_action == action:
102
+ reward -= 0.2
103
+
104
+ # -----------------------
105
+ # CLASSIFY
106
+ # -----------------------
107
+ if action_type == "classify":
108
+
109
+ collected["category"] = gt["category"]
110
+ collected["priority"] = gt["priority"]
111
+
112
+ reward += 0.2
113
+
114
+ # -----------------------
115
+ # ASK INFO
116
+ # -----------------------
117
+ elif action_type == "ask_info":
118
+
119
+ field = action.get("field")
120
+
121
+ if field not in collected:
122
+ collected[field] = "sample_value"
123
+ reward += 0.3
124
+
125
+ if field in required:
126
+ required.remove(field)
127
+ else:
128
+ reward -= 0.3
129
+
130
+ # -----------------------
131
+ # RESOLVE
132
+ # -----------------------
133
+ elif action_type == "resolve":
134
+
135
+ done = True
136
+ final_score = 0.0
137
+
138
+ # classification
139
+ if collected.get("category") == gt.get("category"):
140
+ final_score += 0.3
141
+
142
+ if collected.get("priority") == gt.get("priority"):
143
+ final_score += 0.2
144
+
145
+ # required info
146
+ required_fields = gt.get("required_info", [])
147
+ if all(f in collected for f in required_fields):
148
+ final_score += 0.3
149
+ self.success = True
150
+ else:
151
+ reward -= 0.5
152
+
153
+ # resolve bonus
154
+ final_score += 0.2
155
+
156
+ reward += final_score
157
+
158
+ # efficiency bonus
159
+ optimal_steps = len(required_fields) + 1
160
+ if self.state_data["steps_taken"] <= optimal_steps:
161
+ reward += 0.3
162
+
163
+ # episode stats
164
+ collected_required = sum(1 for f in required_fields if f in collected)
165
+
166
+ episode_data = {
167
+ "success": self.success,
168
+ "steps": self.state_data["steps_taken"],
169
+ "reward": reward,
170
+ "info_efficiency": collected_required / max(1, len(required_fields))
171
+ }
172
+
173
+ self.episode_stats.append(episode_data)
174
+
175
+ info = {
176
+ "final_score": final_score,
177
+ "task_success": self.success,
178
+ "collected_info": collected
179
+ }
180
+
181
+ self.last_action = action
182
+ return self._get_observation(), reward, done, info
183
+
184
+ # -----------------------
185
+ # INVALID
186
+ # -----------------------
187
+ else:
188
+ reward -= 0.3
189
+
190
+ # -----------------------
191
+ # STEP UPDATE
192
+ # -----------------------
193
+ self.state_data["steps_taken"] += 1
194
+ self.current_steps += 1
195
+
196
+ # -----------------------
197
+ # MAX STEP TERMINATION
198
+ # -----------------------
199
+ if self.state_data["steps_taken"] >= self.state_data["max_steps"]:
200
+ done = True
201
+ reward -= 2.0
202
+
203
+ # record failure episode
204
+ self.episode_stats.append({
205
+ "success": False,
206
+ "steps": self.state_data["steps_taken"],
207
+ "reward": reward,
208
+ "info_efficiency": 0
209
+ })
210
+
211
+ info = {
212
+ "final_score": 0.0,
213
+ "task_success": False
214
+ }
215
+
216
+ # -----------------------
217
+ # SAVE STATE
218
+ # -----------------------
219
+ self.last_action = action
220
+ self.current_episode_reward += reward
221
+
222
+ return self._get_observation(), reward, done, info
223
+
224
+ def state(self) -> Dict:
225
+ return self.state_data
226
+
227
+ def get_metrics(self):
228
+
229
+ if not self.episode_stats:
230
+ return {}
231
+
232
+ total = len(self.episode_stats)
233
+
234
+ success_rate = sum(e["success"] for e in self.episode_stats) / total
235
+ avg_steps = sum(e["steps"] for e in self.episode_stats) / total
236
+ avg_reward = sum(e["reward"] for e in self.episode_stats) / total
237
+ info_eff = sum(e["info_efficiency"] for e in self.episode_stats) / total
238
+
239
+ return {
240
+ "success_rate": round(success_rate, 3),
241
+ "avg_steps": round(avg_steps, 3),
242
+ "avg_reward": round(avg_reward, 3),
243
+ "info_efficiency": round(info_eff, 3)
244
+ }
app/graders.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # app/graders.py
4
+
5
+ def grade_task1(state):
6
+ score = 0.0
7
+ gt = state["ground_truth"]
8
+
9
+ if state["category"] == gt["category"]:
10
+ score += 0.5
11
+ if state["priority"] == gt["priority"]:
12
+ score += 0.5
13
+
14
+ return score
15
+
16
+
17
+ def grade_task2(state):
18
+ required = set(state["ground_truth"]["required_info"])
19
+ collected = set(state["collected_info"].keys())
20
+
21
+ if not required:
22
+ return 1.0
23
+
24
+ return len(collected & required) / len(required)
25
+
26
+
27
+ def grade_task3(state):
28
+ score = 0.0
29
+ gt = state["ground_truth"]
30
+
31
+ # classification
32
+ if state["category"] == gt["category"]:
33
+ score += 0.3
34
+
35
+ # info collection
36
+ required = set(gt["required_info"])
37
+ collected = set(state["collected_info"].keys())
38
+ if required:
39
+ score += 0.3 * (len(collected & required) / len(required))
40
+
41
+ # resolution
42
+ if state["status"] == "resolved":
43
+ score += 0.4
44
+
45
+ return score
app/models.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #app/models.py
3
+
4
+ from pydantic import BaseModel
5
+ from typing import List, Dict, Optional
6
+
7
+ class Observation(BaseModel):
8
+ ticket_id: str
9
+ customer_message: str
10
+ history: List[str]
11
+ known_info: Dict
12
+ #missing_info: List[str]
13
+ status: str
14
+ step_count: int
15
+ remaining_steps: int
16
+
17
+
18
+ class Action(BaseModel):
19
+ action_type: str
20
+ content: Optional[str] = ""
21
+ metadata: Optional[Dict] = {}
22
+
23
+
24
+ class Reward(BaseModel):
25
+ value: float
26
+ reason: str
app/tasks.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # app/tasks.py
4
+
5
+ TASKS = {
6
+ "task1": {
7
+ "description": "Classify ticket category and priority",
8
+ "max_steps": 2
9
+ },
10
+ "task2": {
11
+ "description": "Gather required information",
12
+ "max_steps": 5
13
+ },
14
+ "task3": {
15
+ "description": "Full resolution workflow",
16
+ "max_steps": 10
17
+ }
18
+ }
app/test_env.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ // Testing, conceptually
4
+
5
+ Test What it verifies
6
+ ask_info info collection logic
7
+ resolve (after) success path
8
+ resolve (before) penalty logic
9
+ reward values correctness of shaping
10
+ done flag termination logic
11
+
12
+ > Detailed test flow
13
+
14
+ - ask_info
15
+
16
+ Conceptually, checks whether the agent can reduce uncertainty by asking the correct question
17
+
18
+ -- The environment is partially observable — the agent doesn’t know everything upfront --
19
+
20
+ Real-world analogy:
21
+ Support agent asking the client for their email
22
+
23
+
24
+ - resolve (after)
25
+
26
+ Conceptually, checks:
27
+
28
+ “Can the agent complete the task after gathering required info?”
29
+
30
+ This is goal completion
31
+
32
+ - resolve (before)
33
+
34
+ Conceptually, checks:
35
+
36
+ “Does the system penalize shortcut / lazy behavior?”
37
+
38
+ Without this:
39
+ Agent would always jump to resolve
40
+
41
+ - Reward values
42
+
43
+ Conceptually, checks:
44
+
45
+ “Is the agent receiving useful learning signals?”
46
+
47
+ With the reward-mechanism implemented:
48
+
49
+ Behavior Reward
50
+ correct info +0.3
51
+ correct resolution +1.0
52
+ final score +0.0 → +1.0
53
+ wrong action negative
54
+
55
+ Technically, we validate:
56
+
57
+ reward accumulation works
58
+ no random jumps
59
+ consistent scaling
60
+
61
+ This is critical, because:
62
+
63
+ . Bad reward = bad agent/system
64
+ . Good reward = learnable system
65
+
66
+ - done flag
67
+
68
+ Conceptually, checks:
69
+
70
+ “Does the environment know when the episode ends?”
71
+
72
+ - no score field in /reset, since at reset:
73
+
74
+ Episode has not happened yet
75
+ → No performance → No score
76
+
77
+
78
+ These tests collectively validate:
79
+
80
+ MDP (Markov Decision Process) -> (State, Action, Reward, Transition, Termination) -> Thorough RL Environment
81
+
82
+ Component Verified by
83
+ State reset
84
+ Action ask_info / resolve
85
+ Reward reward tests
86
+ Transition state updates
87
+ Termination done flag
88
+
89
+
90
+
91
+ // Expected behavior
92
+
93
+ Good Agent Flow:
94
+ Reset
95
+ → ask_info (+0.3)
96
+ → resolve (+1.0 + bonus)
97
+
98
+ Bad Agent Flow:
99
+ Reset
100
+ → resolve (-0.3)
101
+ → ask random info (-0.1)
102
+ → timeout (-1.0)
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+ """
111
+
112
+ import requests
113
+
114
+ BASE = "http://127.0.0.1:8001"
115
+
116
+ # Reset
117
+ r = requests.get(f"{BASE}/reset")
118
+ print(f"\nRESET: \n\n{r.json()}")
119
+
120
+
121
+ # Ask info
122
+ r = requests.post(f"{BASE}/step", json={
123
+ "type": "ask_info",
124
+ "field": "account_email"
125
+ })
126
+ #print("ASK INFO:", r.json())
127
+ print(f"\nASK INFO: \n\n{r.json()}")
128
+
129
+ # Resolve
130
+ r = requests.post(f"{BASE}/step", json={
131
+ "type": "resolve"
132
+ })
133
+ print(f"\nRESOLVE: \n\n{r.json()}")
134
+ #print(f"\n"RESOLVE:", {r.json()})
inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # inference.py
3
+
4
+ import os
5
+ from agent_llm import get_action
6
+ from app.env import CustomerSupportEnv
7
+
8
+
9
+ def format_action(action: dict) -> str:
10
+ """Convert action dict → string"""
11
+ if not action:
12
+ return "null"
13
+ return str(action).replace("\n", "").replace(" ", " ")
14
+
15
+
16
+ def main():
17
+
18
+ env = CustomerSupportEnv()
19
+ obs = env.reset()
20
+
21
+ #model_name = os.getenv("MODEL_NAME", "unknown-model")
22
+ model_name="llama-3.1-8b-instant"
23
+
24
+ task_name = "customer-support"
25
+ benchmark = "openenv"
26
+
27
+ step_count = 0
28
+ rewards = []
29
+ success = False
30
+
31
+ # =========================
32
+ # START
33
+ # =========================
34
+ print(f"[START] task={task_name} env={benchmark} model={model_name}")
35
+
36
+ try:
37
+ done = False
38
+
39
+ while not done:
40
+
41
+ valid_actions = [
42
+ {"type": "ask_info", "field": "order_id"},
43
+ {"type": "ask_info", "field": "account_email"},
44
+ {"type": "ask_info", "field": "device_type"},
45
+ {"type": "ask_info", "field": "browser"},
46
+ {"type": "resolve"},
47
+ {"type": "classify"},
48
+ ]
49
+
50
+ action = get_action(obs, valid_actions)
51
+
52
+ next_obs, reward, done, info = env.step(action)
53
+
54
+ step_count += 1
55
+ rewards.append(reward)
56
+
57
+ # =========================
58
+ # STEP
59
+ # =========================
60
+ print(
61
+ f"[STEP] step={step_count} "
62
+ f"action={format_action(action)} "
63
+ f"reward={reward:.2f} "
64
+ f"done={'true' if done else 'false'} "
65
+ f"error=null"
66
+ )
67
+
68
+ obs = next_obs
69
+
70
+ # success from env
71
+ success = info.get("task_success", False)
72
+
73
+ except Exception as e:
74
+ # still must print END
75
+ print(
76
+ f"[STEP] step={step_count+1} "
77
+ f"action=null reward=0.00 done=true error={str(e)}"
78
+ )
79
+
80
+ finally:
81
+ # =========================
82
+ # END
83
+ # =========================
84
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
85
+
86
+ print(
87
+ f"[END] success={'true' if success else 'false'} "
88
+ f"steps={step_count} "
89
+ f"rewards={rewards_str}"
90
+ )
91
+
92
+
93
+ if __name__ == "__main__":
94
+ main()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ pydantic
4
+ openai
5
+ groq
6
+ python-dotenv
7
+ pyyaml
8
+ requests
server/__init__.py ADDED
File without changes
server/__init__.py:Zone.Identifier ADDED
Binary file (25 Bytes). View file
 
server/main.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # server/main.py
3
+
4
+ from fastapi import FastAPI
5
+ from app.env import CustomerSupportEnv
6
+
7
+ import sys
8
+ import os
9
+
10
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
11
+
12
+ app = FastAPI()
13
+ env = CustomerSupportEnv()
14
+
15
+ @app.get("/reset")
16
+ def reset():
17
+ return env.reset()
18
+
19
+ """
20
+ @app.post("/step")
21
+ def step(action: dict):
22
+ return env.step(action)
23
+ """
24
+ @app.post("/step")
25
+ def step(action: dict):
26
+ obs, reward, done, info = env.step(action)
27
+
28
+ return {
29
+ "observation": obs,
30
+ "reward": reward,
31
+ "done": done,
32
+ "info": info
33
+ }
34
+
35
+ @app.get("/state")
36
+ def state():
37
+ return env.state()
38
+
39
+ @app.get("/health")
40
+ def health():
41
+ return {"status": "ok"}
test_rule_agent.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # test_rule_agent.py
3
+
4
+ from app.env import CustomerSupportEnv
5
+ from agent_rule_based import get_action
6
+
7
+ env = CustomerSupportEnv()
8
+
9
+ for i in range(5):
10
+ obs = env.reset()
11
+ done = False
12
+
13
+ print(f"\n===== EPISODE {i+1} =====")
14
+
15
+ while not done:
16
+ action = get_action(obs)
17
+ obs, reward, done, info = env.step(action)
18
+
19
+ print("FINAL:", info)
20
+ print(env.get_metrics())