Spaces:
Sleeping
Sleeping
Commit ·
6819726
1
Parent(s): bca2b27
modified inference, agent, agent_llm with inclusion of open ai api
Browse files- agent.py +28 -8
- agent_llm.py +16 -1
- inference.py +6 -0
- requirements.txt +2 -1
agent.py
CHANGED
|
@@ -10,7 +10,7 @@ import json
|
|
| 10 |
import random
|
| 11 |
|
| 12 |
from dotenv import load_dotenv
|
| 13 |
-
|
| 14 |
from groq import Groq
|
| 15 |
|
| 16 |
from app.env import CustomerSupportEnv
|
|
@@ -28,7 +28,7 @@ ENV_PATH = os.path.join(BASE_DIR, ".env")
|
|
| 28 |
load_dotenv(ENV_PATH)
|
| 29 |
print(f"\nCWD: {os.getcwd()}")
|
| 30 |
|
| 31 |
-
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
| 32 |
#client = os.getenv("GROQ_API_KEY")
|
| 33 |
|
| 34 |
#print(f"\nENV PATH: {ENV_PATH}")
|
|
@@ -37,10 +37,28 @@ client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
|
| 37 |
##print("KEY:", os.getenv("GROQ_API_KEY"))
|
| 38 |
#print(f"\nmodel name: {os.getenv('MODEL_NAME')}")
|
| 39 |
|
| 40 |
-
print("Sending request...")
|
| 41 |
|
| 42 |
#sys.exit()
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
# =========================
|
| 45 |
# Smarter, mapped ask_info - boosts info_progress speed, reward per episode
|
| 46 |
# =========================
|
|
@@ -70,7 +88,6 @@ def smart_classify(message):
|
|
| 70 |
|
| 71 |
return {"category": "general", "priority": "medium"}
|
| 72 |
|
| 73 |
-
|
| 74 |
def override_classify(message):
|
| 75 |
msg = message.lower()
|
| 76 |
|
|
@@ -86,7 +103,6 @@ def override_classify(message):
|
|
| 86 |
return {"type": "classify", "category": "general", "priority": "medium"}
|
| 87 |
|
| 88 |
|
| 89 |
-
|
| 90 |
def is_ready_to_resolve(category, known):
|
| 91 |
if category == "billing":
|
| 92 |
return "order_id" in known
|
|
@@ -100,7 +116,7 @@ def is_ready_to_resolve(category, known):
|
|
| 100 |
return False
|
| 101 |
|
| 102 |
# =========================
|
| 103 |
-
# POLICY ENFORCEMENT
|
| 104 |
# =========================
|
| 105 |
def enforce_policy(obs, action):
|
| 106 |
known = obs["known_info"]
|
|
@@ -224,7 +240,7 @@ FORMAT:
|
|
| 224 |
|
| 225 |
def call_llm(prompt):
|
| 226 |
completion = client.chat.completions.create(
|
| 227 |
-
model=os.getenv("MODEL_NAME"),
|
| 228 |
#model="llama-3.1-8b-instant",
|
| 229 |
messages=[{"role": "user", "content": prompt}],
|
| 230 |
temperature=0.2,
|
|
@@ -338,7 +354,7 @@ def get_action(obs):
|
|
| 338 |
return {"type": "classify", "category": "technical", "priority": "medium"}
|
| 339 |
|
| 340 |
# =====================
|
| 341 |
-
# 2. COMPUTE MISSING INFO
|
| 342 |
# =====================
|
| 343 |
missing = [f for f in required if f not in known]
|
| 344 |
|
|
@@ -410,6 +426,10 @@ def run_multiple(n=3):
|
|
| 410 |
|
| 411 |
avg = sum(scores) / len(scores)
|
| 412 |
print("\n📊 AVERAGE SCORE:", avg)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
|
| 415 |
if __name__ == "__main__":
|
|
|
|
| 10 |
import random
|
| 11 |
|
| 12 |
from dotenv import load_dotenv
|
| 13 |
+
from openai import OpenAI
|
| 14 |
from groq import Groq
|
| 15 |
|
| 16 |
from app.env import CustomerSupportEnv
|
|
|
|
| 28 |
load_dotenv(ENV_PATH)
|
| 29 |
print(f"\nCWD: {os.getcwd()}")
|
| 30 |
|
| 31 |
+
#client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
| 32 |
#client = os.getenv("GROQ_API_KEY")
|
| 33 |
|
| 34 |
#print(f"\nENV PATH: {ENV_PATH}")
|
|
|
|
| 37 |
##print("KEY:", os.getenv("GROQ_API_KEY"))
|
| 38 |
#print(f"\nmodel name: {os.getenv('MODEL_NAME')}")
|
| 39 |
|
| 40 |
+
#print("Sending request...")
|
| 41 |
|
| 42 |
#sys.exit()
|
| 43 |
|
| 44 |
+
|
| 45 |
+
# =========================
|
| 46 |
+
# CONFIG (NEW - VENDOR NEUTRAL)
|
| 47 |
+
# =========================
|
| 48 |
+
def get_llm_client():
|
| 49 |
+
return OpenAI(
|
| 50 |
+
base_url=os.getenv(
|
| 51 |
+
"API_BASE_URL",
|
| 52 |
+
"https://router.huggingface.co/v1"
|
| 53 |
+
),
|
| 54 |
+
api_key=os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
client = get_llm_client()
|
| 58 |
+
|
| 59 |
+
print(f"[CONFIG] API_BASE_URL={os.getenv('API_BASE_URL', 'https://router.huggingface.co/v1')}")
|
| 60 |
+
print("Sending request...")
|
| 61 |
+
|
| 62 |
# =========================
|
| 63 |
# Smarter, mapped ask_info - boosts info_progress speed, reward per episode
|
| 64 |
# =========================
|
|
|
|
| 88 |
|
| 89 |
return {"category": "general", "priority": "medium"}
|
| 90 |
|
|
|
|
| 91 |
def override_classify(message):
|
| 92 |
msg = message.lower()
|
| 93 |
|
|
|
|
| 103 |
return {"type": "classify", "category": "general", "priority": "medium"}
|
| 104 |
|
| 105 |
|
|
|
|
| 106 |
def is_ready_to_resolve(category, known):
|
| 107 |
if category == "billing":
|
| 108 |
return "order_id" in known
|
|
|
|
| 116 |
return False
|
| 117 |
|
| 118 |
# =========================
|
| 119 |
+
# POLICY ENFORCEMENT INSTEAD OF LLM DECISION
|
| 120 |
# =========================
|
| 121 |
def enforce_policy(obs, action):
|
| 122 |
known = obs["known_info"]
|
|
|
|
| 240 |
|
| 241 |
def call_llm(prompt):
|
| 242 |
completion = client.chat.completions.create(
|
| 243 |
+
model=os.getenv("MODEL_NAME", "unknown-model"),
|
| 244 |
#model="llama-3.1-8b-instant",
|
| 245 |
messages=[{"role": "user", "content": prompt}],
|
| 246 |
temperature=0.2,
|
|
|
|
| 354 |
return {"type": "classify", "category": "technical", "priority": "medium"}
|
| 355 |
|
| 356 |
# =====================
|
| 357 |
+
# 2. COMPUTE MISSING INFO
|
| 358 |
# =====================
|
| 359 |
missing = [f for f in required if f not in known]
|
| 360 |
|
|
|
|
| 426 |
|
| 427 |
avg = sum(scores) / len(scores)
|
| 428 |
print("\n📊 AVERAGE SCORE:", avg)
|
| 429 |
+
#print("\n📊 scores:", scores)
|
| 430 |
+
#print("\n📊 sum scores:", sum(scores))
|
| 431 |
+
#print("\n📊 len scores:", len(scores))
|
| 432 |
+
|
| 433 |
|
| 434 |
|
| 435 |
if __name__ == "__main__":
|
agent_llm.py
CHANGED
|
@@ -15,12 +15,27 @@ import json
|
|
| 15 |
import time
|
| 16 |
from dotenv import load_dotenv
|
| 17 |
from groq import Groq
|
|
|
|
| 18 |
|
| 19 |
from app.env import CustomerSupportEnv
|
| 20 |
|
| 21 |
load_dotenv()
|
| 22 |
|
| 23 |
-
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# =========================
|
| 26 |
# PROMPT (STRICT + MINIMAL)
|
|
|
|
| 15 |
import time
|
| 16 |
from dotenv import load_dotenv
|
| 17 |
from groq import Groq
|
| 18 |
+
from openai import OpenAI
|
| 19 |
|
| 20 |
from app.env import CustomerSupportEnv
|
| 21 |
|
| 22 |
load_dotenv()
|
| 23 |
|
| 24 |
+
#client = Groq(api_key=os.getenv("GROQ_API_KEY"))
|
| 25 |
+
|
| 26 |
+
# =========================
|
| 27 |
+
# CONFIG (NEW)
|
| 28 |
+
# =========================
|
| 29 |
+
def get_llm_client():
|
| 30 |
+
return OpenAI(
|
| 31 |
+
base_url=os.getenv(
|
| 32 |
+
"API_BASE_URL",
|
| 33 |
+
"https://router.huggingface.co/v1"
|
| 34 |
+
),
|
| 35 |
+
api_key=os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
client = get_llm_client()
|
| 39 |
|
| 40 |
# =========================
|
| 41 |
# PROMPT (STRICT + MINIMAL)
|
inference.py
CHANGED
|
@@ -37,6 +37,12 @@ def main():
|
|
| 37 |
model_name = os.getenv("MODEL_NAME", "unknown-model")
|
| 38 |
#model_name="llama-3.1-8b-instant"
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
task_name = "customer-support"
|
| 41 |
benchmark = "openenv"
|
| 42 |
|
|
|
|
| 37 |
model_name = os.getenv("MODEL_NAME", "unknown-model")
|
| 38 |
#model_name="llama-3.1-8b-instant"
|
| 39 |
|
| 40 |
+
api_base_url = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 41 |
+
|
| 42 |
+
#api_base_url = os.getenv("API_BASE_URL")
|
| 43 |
+
|
| 44 |
+
print(f"[CONFIG] api_base_url={api_base_url}")
|
| 45 |
+
|
| 46 |
task_name = "customer-support"
|
| 47 |
benchmark = "openenv"
|
| 48 |
|
requirements.txt
CHANGED
|
@@ -5,4 +5,5 @@ openai
|
|
| 5 |
groq
|
| 6 |
python-dotenv
|
| 7 |
pyyaml
|
| 8 |
-
requests
|
|
|
|
|
|
| 5 |
groq
|
| 6 |
python-dotenv
|
| 7 |
pyyaml
|
| 8 |
+
requests
|
| 9 |
+
openai>=1.0.0
|