prashantmatlani committed on
Commit
6819726
·
1 Parent(s): bca2b27

modified inference, agent, agent_llm with inclusion of open ai api

Browse files
Files changed (4) hide show
  1. agent.py +28 -8
  2. agent_llm.py +16 -1
  3. inference.py +6 -0
  4. requirements.txt +2 -1
agent.py CHANGED
@@ -10,7 +10,7 @@ import json
10
  import random
11
 
12
  from dotenv import load_dotenv
13
- # from openai import OpenAI
14
  from groq import Groq
15
 
16
  from app.env import CustomerSupportEnv
@@ -28,7 +28,7 @@ ENV_PATH = os.path.join(BASE_DIR, ".env")
28
  load_dotenv(ENV_PATH)
29
  print(f"\nCWD: {os.getcwd()}")
30
 
31
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
32
  #client = os.getenv("GROQ_API_KEY")
33
 
34
  #print(f"\nENV PATH: {ENV_PATH}")
@@ -37,10 +37,28 @@ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
37
  ##print("KEY:", os.getenv("GROQ_API_KEY"))
38
  #print(f"\nmodel name: {os.getenv('MODEL_NAME')}")
39
 
40
- print("Sending request...")
41
 
42
  #sys.exit()
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  # =========================
45
  # Smarter, mapped ask_info - boosts info_progress speed, reward per episode
46
  # =========================
@@ -70,7 +88,6 @@ def smart_classify(message):
70
 
71
  return {"category": "general", "priority": "medium"}
72
 
73
-
74
  def override_classify(message):
75
  msg = message.lower()
76
 
@@ -86,7 +103,6 @@ def override_classify(message):
86
  return {"type": "classify", "category": "general", "priority": "medium"}
87
 
88
 
89
-
90
  def is_ready_to_resolve(category, known):
91
  if category == "billing":
92
  return "order_id" in known
@@ -100,7 +116,7 @@ def is_ready_to_resolve(category, known):
100
  return False
101
 
102
  # =========================
103
- # POLICY ENFORCEMENT INTEAD OF LLM DECISION
104
  # =========================
105
  def enforce_policy(obs, action):
106
  known = obs["known_info"]
@@ -224,7 +240,7 @@ FORMAT:
224
 
225
  def call_llm(prompt):
226
  completion = client.chat.completions.create(
227
- model=os.getenv("MODEL_NAME"),
228
  #model="llama-3.1-8b-instant",
229
  messages=[{"role": "user", "content": prompt}],
230
  temperature=0.2,
@@ -338,7 +354,7 @@ def get_action(obs):
338
  return {"type": "classify", "category": "technical", "priority": "medium"}
339
 
340
  # =====================
341
- # 2. COMPUTE MISSING INFO (🔥 KEY CHANGE)
342
  # =====================
343
  missing = [f for f in required if f not in known]
344
 
@@ -410,6 +426,10 @@ def run_multiple(n=3):
410
 
411
  avg = sum(scores) / len(scores)
412
  print("\n📊 AVERAGE SCORE:", avg)
 
 
 
 
413
 
414
 
415
  if __name__ == "__main__":
 
10
  import random
11
 
12
  from dotenv import load_dotenv
13
+ from openai import OpenAI
14
  from groq import Groq
15
 
16
  from app.env import CustomerSupportEnv
 
28
  load_dotenv(ENV_PATH)
29
  print(f"\nCWD: {os.getcwd()}")
30
 
31
+ #client = Groq(api_key=os.getenv("GROQ_API_KEY"))
32
  #client = os.getenv("GROQ_API_KEY")
33
 
34
  #print(f"\nENV PATH: {ENV_PATH}")
 
37
  ##print("KEY:", os.getenv("GROQ_API_KEY"))
38
  #print(f"\nmodel name: {os.getenv('MODEL_NAME')}")
39
 
40
+ #print("Sending request...")
41
 
42
  #sys.exit()
43
 
44
+
45
+ # =========================
46
+ # CONFIG (NEW - VENDOR NEUTRAL)
47
+ # =========================
48
+ def get_llm_client():
49
+ return OpenAI(
50
+ base_url=os.getenv(
51
+ "API_BASE_URL",
52
+ "https://router.huggingface.co/v1"
53
+ ),
54
+ api_key=os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
55
+ )
56
+
57
+ client = get_llm_client()
58
+
59
+ print(f"[CONFIG] API_BASE_URL={os.getenv('API_BASE_URL', 'https://router.huggingface.co/v1')}")
60
+ print("Sending request...")
61
+
62
  # =========================
63
  # Smarter, mapped ask_info - boosts info_progress speed, reward per episode
64
  # =========================
 
88
 
89
  return {"category": "general", "priority": "medium"}
90
 
 
91
  def override_classify(message):
92
  msg = message.lower()
93
 
 
103
  return {"type": "classify", "category": "general", "priority": "medium"}
104
 
105
 
 
106
  def is_ready_to_resolve(category, known):
107
  if category == "billing":
108
  return "order_id" in known
 
116
  return False
117
 
118
  # =========================
119
+ # POLICY ENFORCEMENT INSTEAD OF LLM DECISION
120
  # =========================
121
  def enforce_policy(obs, action):
122
  known = obs["known_info"]
 
240
 
241
  def call_llm(prompt):
242
  completion = client.chat.completions.create(
243
+ model=os.getenv("MODEL_NAME", "unknown-model"),
244
  #model="llama-3.1-8b-instant",
245
  messages=[{"role": "user", "content": prompt}],
246
  temperature=0.2,
 
354
  return {"type": "classify", "category": "technical", "priority": "medium"}
355
 
356
  # =====================
357
+ # 2. COMPUTE MISSING INFO
358
  # =====================
359
  missing = [f for f in required if f not in known]
360
 
 
426
 
427
  avg = sum(scores) / len(scores)
428
  print("\n📊 AVERAGE SCORE:", avg)
429
+ #print("\n📊 scores:", scores)
430
+ #print("\n📊 sum scores:", sum(scores))
431
+ #print("\n📊 len scores:", len(scores))
432
+
433
 
434
 
435
  if __name__ == "__main__":
agent_llm.py CHANGED
@@ -15,12 +15,27 @@ import json
15
  import time
16
  from dotenv import load_dotenv
17
  from groq import Groq
 
18
 
19
  from app.env import CustomerSupportEnv
20
 
21
  load_dotenv()
22
 
23
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # =========================
26
  # PROMPT (STRICT + MINIMAL)
 
15
  import time
16
  from dotenv import load_dotenv
17
  from groq import Groq
18
+ from openai import OpenAI
19
 
20
  from app.env import CustomerSupportEnv
21
 
22
  load_dotenv()
23
 
24
+ #client = Groq(api_key=os.getenv("GROQ_API_KEY"))
25
+
26
+ # =========================
27
+ # CONFIG (NEW)
28
+ # =========================
29
+ def get_llm_client():
30
+ return OpenAI(
31
+ base_url=os.getenv(
32
+ "API_BASE_URL",
33
+ "https://router.huggingface.co/v1"
34
+ ),
35
+ api_key=os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
36
+ )
37
+
38
+ client = get_llm_client()
39
 
40
  # =========================
41
  # PROMPT (STRICT + MINIMAL)
inference.py CHANGED
@@ -37,6 +37,12 @@ def main():
37
  model_name = os.getenv("MODEL_NAME", "unknown-model")
38
  #model_name="llama-3.1-8b-instant"
39
 
 
 
 
 
 
 
40
  task_name = "customer-support"
41
  benchmark = "openenv"
42
 
 
37
  model_name = os.getenv("MODEL_NAME", "unknown-model")
38
  #model_name="llama-3.1-8b-instant"
39
 
40
+ api_base_url = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
41
+
42
+ #api_base_url = os.getenv("API_BASE_URL")
43
+
44
+ print(f"[CONFIG] api_base_url={api_base_url}")
45
+
46
  task_name = "customer-support"
47
  benchmark = "openenv"
48
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ openai
5
  groq
6
  python-dotenv
7
  pyyaml
8
- requests
 
 
5
  groq
6
  python-dotenv
7
  pyyaml
8
+ requests
9
+ openai>=1.0.0