File size: 5,604 Bytes
d6a76d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5683339
 
0894e25
d6a76d5
 
 
5683339
 
 
d6a76d5
6819726
 
 
0894e25
5683339
 
 
 
 
 
 
 
 
 
 
 
 
 
6819726
 
5683339
 
 
0894e25
 
5683339
 
0894e25
 
 
 
6819726
 
d6a76d5
 
0894e25
d6a76d5
0894e25
d6a76d5
0894e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6a76d5
 
 
5683339
d6a76d5
 
5683339
0894e25
5683339
 
 
 
 
0894e25
5683339
 
 
 
 
 
0894e25
d6a76d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0894e25
d6a76d5
0894e25
 
d6a76d5
 
 
0894e25
 
 
 
 
 
 
d6a76d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0894e25
d6a76d5
 
 
0894e25
d6a76d5
0894e25
 
 
 
 
 
 
 
d6a76d5
0894e25
 
d6a76d5
0894e25
d6a76d5
0894e25
d6a76d5
 
0894e25
 
d6a76d5
0894e25
d6a76d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12a8a0f
 
 
 
d6a76d5
 
 
12a8a0f
d6a76d5
12a8a0f
d6a76d5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244

# agent_llm.py

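"""
LLM agent for CustomerSupportEnv.

- Uses an LLM through an OpenAI-compatible client (requirement satisfied).
- Robust: falls back to a rule-based policy when no client is configured
  or when the LLM call or its output is unusable.
- Structured output: the LLM is asked for a JSON object of the form
  {"action": {...}}.
- Hallucination risk is reduced, not eliminated: malformed replies are
  discarded in favor of the fallback policy.
- Not fully reproducible: the LLM is sampled with temperature > 0 and the
  fallback re-classifies at random.
"""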


import os
import json
import time
#from groq import Groq
#from openai import OpenAI
import random

from app.env import CustomerSupportEnv

#from dotenv import load_dotenv
#load_dotenv()


#client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# =========================
# PURPOSE: Safe OpenAI client init
# =========================
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None

try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# =========================
# CONFIG - CLIENT-SAFE
# =========================
def get_llm_client():
    """Build an OpenAI-compatible chat client, or return None if unavailable.

    The API key is read from API_KEY, falling back to GROQ_API_KEY. The
    endpoint is read from API_BASE_URL and defaults to the Hugging Face
    router.

    Returns:
        An ``OpenAI`` client instance. Returns None when the ``openai``
        package could not be imported or when no API key is set. Callers
        treat None as "use the fallback policy".
    """
    api_key = os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
    if OpenAI is None or not api_key:
        return None

    # NOTE(review): a GROQ_API_KEY is paired with the Hugging Face router
    # default URL unless API_BASE_URL is also set. Confirm that is intended.
    endpoint = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
    return OpenAI(base_url=endpoint, api_key=api_key)


# Shared module-level client used by the LLM helpers; None means "fallback only".
client = get_llm_client()

# =========================
# PURPOSE: Prompt - Strict + Minimal - encourages uncertainty-aware reasoning
# =========================
def build_prompt(obs):
    """Render the single-turn user prompt sent to the LLM.

    Each value is formatted with ``str()`` semantics (f-string interpolation).
    Missing keys therefore render as ``None``.

    Args:
        obs: Observation mapping. Reads the "customer_message", "known_info"
            and "required" keys.

    Returns:
        The prompt text. Every content line is indented by four spaces and
        blank separator lines are empty. The text starts with a newline and
        ends with four trailing spaces.
    """
    pad = " " * 4
    message = obs.get("customer_message")
    known = obs.get("known_info")
    required = obs.get("required")

    # None marks an empty separator line; every other entry gets the pad.
    sections = [
        None,
        "You are a customer support agent.",
        None,
        "Customer message:",
        f"{message}",
        None,
        "Known info:",
        f"{known}",
        None,
        "Required fields:",
        f"{required}",
        None,
        "Your goal is to resolve the ticket efficiently.",
        None,
        "Think carefully:",
        "- You may revise earlier decisions",
        "- Do not commit too early",
        "- Ask missing info if unsure",
        "- The message may be ambiguous",
        "- Do not assume category prematurely",
        "- Ask only necessary questions",
        "- Avoid redundant actions",
        None,
        "Return JSON:",
        '{"action": {...}}',
        "",
    ]
    return "\n".join("" if line is None else pad + line for line in sections)


# =========================
# LLM CALL (SAFE)
# =========================
def call_llm(prompt):
    """Send ``prompt`` to the chat-completions endpoint and return the reply.

    The request asks for a JSON-object response at temperature 0.3 from the
    model named by MODEL_NAME (default "unknown-model").

    Args:
        prompt: User message content.

    Returns:
        The reply text with surrounding whitespace stripped. Returns None when
        no client is configured or when building the request, calling the API
        or reading the reply raises. Callers treat None as "use the fallback".
    """
    if client is None:
        return None

    try:
        request = {
            "model": os.getenv("MODEL_NAME", "unknown-model"),
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
            "response_format": {"type": "json_object"},
        }
        completion = client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content.strip()
    except Exception:
        # Deliberate best-effort: any failure routes the caller to the fallback.
        return None


# =========================
# PARSER (STRICT)
# =========================
def parse_output(text):
    """Extract the ``action`` dict from a raw LLM reply.

    The span from the first "{" to the last "}" is decoded as JSON, so prose
    or code fences around the object are tolerated.

    Args:
        text: Raw model reply. May be None or a non-string.

    Returns:
        The value under the "action" key, but only if it is a dict containing
        a "type" key. Returns None for non-string input, missing braces,
        invalid JSON, or a missing or malformed action. A string action such
        as "prototype" is rejected rather than passing a substring check.
    """
    if not isinstance(text, str):
        return None

    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return None

    try:
        parsed = json.loads(text[start:end + 1])
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        return None

    if not isinstance(parsed, dict):
        return None

    action = parsed.get("action")
    # isinstance guard: `"type" in action` on a str/list is a substring or
    # element test and would let non-dict "actions" through.
    if not isinstance(action, dict) or "type" not in action:
        return None
    return action


# =========================
# PURPOSE: Fallback is intentionally imperfect
# =========================
def fallback(obs, rng=None):
    """Rule-based policy used when the LLM is unavailable or its reply is unusable.

    The policy is deliberately imperfect:

    1. It classifies as technical/medium when no "category" is known.
    2. Even after classification, it re-classifies about 30% of the time.
    3. Otherwise it asks for the first required field not yet known.
    4. When nothing is missing, it resolves.

    Args:
        obs: Observation mapping. Reads "known_info", a container of
            gathered info, presumably keyed by field name with "category"
            once classified (TODO confirm against CustomerSupportEnv). Also
            reads "required", an iterable of field names. Missing or None
            values are treated as empty.
        rng: Optional object with a ``random()`` method, such as a seeded
            ``random.Random``, used for the re-classification draw. Defaults
            to the module-level ``random`` (original behavior).

    Returns:
        A classify, ask_info or resolve action dict.
    """
    rng = random if rng is None else rng
    # `or` also covers keys that are present but set to None.
    known = obs.get("known_info") or {}
    required = obs.get("required") or []

    # Allow re-classification even if already classified. The draw happens
    # only when a category is already known (short-circuit).
    if "category" not in known or rng.random() < 0.3:
        return {
            "type": "classify",
            "category": "technical",
            "priority": "medium"
        }

    first_missing = next((field for field in required if field not in known), None)
    if first_missing is not None:
        return {"type": "ask_info", "field": first_missing}

    return {"type": "resolve"}


# =========================
# VALIDATION
# =========================
def is_valid_action(action, valid_actions):
    """Check that ``action`` matches one of the action templates on offer.

    Args:
        action: Candidate action, normally the dict returned by the LLM
            parser. Any other type is rejected.
        valid_actions: Iterable of template dicts, each with a "type" key.
            "ask_info" templates also carry the allowed "field". Templates
            without a "field" are ignored for the field check.

    Returns:
        True when ``action`` is a dict whose "type" appears among the
        templates and which meets the per-type requirements:

        - "ask_info" must name an offered field.
        - "classify" must supply both "category" and "priority".

        False otherwise.
    """
    # Non-dict input (e.g. the string "prototype") would otherwise pass the
    # `in` check below and crash on action["type"].
    if not isinstance(action, dict) or "type" not in action:
        return False

    templates = [t for t in (valid_actions or ()) if isinstance(t, dict)]
    kind = action["type"]

    # Equality scan rather than a set: the LLM may emit an unhashable "type".
    if not any(t.get("type") == kind for t in templates):
        return False

    if kind == "ask_info":
        field = action.get("field")
        return any(
            t.get("type") == "ask_info" and "field" in t and t["field"] == field
            for t in templates
        )

    if kind == "classify":
        return "category" in action and "priority" in action

    return True

# =========================
# PURPOSE: Hybrid control (LLM + adaptive fallback)
# =========================
def get_action(obs, valid_actions):
    """Choose the next action: LLM first, rule-based fallback otherwise.

    The prompt built from ``obs`` is sent through ``call_llm`` and the reply
    is parsed with ``parse_output``. The parsed action is returned only if it
    is a dict that ``is_valid_action`` accepts against ``valid_actions``.
    Otherwise the function returns ``fallback(obs)``. That covers:

    - no client configured,
    - the request raising,
    - an unparsable reply,
    - an action that was not offered.

    Args:
        obs: Observation mapping from the environment.
        valid_actions: Action templates the agent may choose from. Each is a
            dict with a "type" key; "ask_info" templates also have "field".

    Returns:
        An action dict to pass to ``env.step``.
    """
    # Reuse the shared helpers so model/temperature/parsing stay in one place.
    text = call_llm(build_prompt(obs))
    if text is not None:
        action = parse_output(text)
        # Shape guard first so validation never indexes into a non-dict.
        if isinstance(action, dict) and is_valid_action(action, valid_actions):
            return action

    return fallback(obs)


# =========================
# RUN
# =========================
def run_agent():
    """Run one episode of CustomerSupportEnv with the hybrid agent.

    The environment is reset, then an action is picked with ``get_action``
    and stepped repeatedly until the environment reports ``done``. Nothing is
    returned; rewards and the final info are discarded.
    """
    env = CustomerSupportEnv()
    observation = env.reset()

    # The action menu never changes between steps, so it is built once.
    # Order: the four ask_info fields, then resolve, then classify.
    askable = ("order_id", "account_email", "device_type", "browser")
    menu = [{"type": "ask_info", "field": name} for name in askable]
    menu += [{"type": "resolve"}, {"type": "classify"}]

    finished = False
    while not finished:
        chosen = get_action(observation, menu)
        observation, _reward, finished, _info = env.step(chosen)
        # For debugging, print observation / chosen / _reward / finished here,
        # and _info or env.get_metrics() after the loop.


if __name__ == "__main__":
    run_agent()