csa01 / agent.py
prashantmatlani's picture
modified agent_llm, agent, inference to use base api url
5683339
# agent.py
import sys
from unicodedata import category
import requests
import os
import time
import json
import random
from openai import OpenAI
#from groq import Groq
from app.env import CustomerSupportEnv
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
#from dotenv import load_dotenv
#ENV_PATH = os.path.join(BASE_DIR, ".env")
#load_dotenv(ENV_PATH)
# load_dotenv()
# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# BASE_URL = "http://127.0.0.1:8001"
#load_dotenv("/home/pb/projects/openenv-customer-support/.env")
try:
from dotenv import load_dotenv
ENV_PATH = os.path.join(BASE_DIR, ".env")
load_dotenv(ENV_PATH)
except ImportError:
# dotenv not available in validator environment
pass
print(f"\nCWD: {os.getcwd()}")
#client = Groq(api_key=os.getenv("GROQ_API_KEY"))
#client = os.getenv("GROQ_API_KEY")
#print(f"\nENV PATH: {ENV_PATH}")
#print(f"\ngroq api key: {client}")
##print(f"\ngroq api key: {os.getenv('GROQ_API_KEY')}")
##print("KEY:", os.getenv("GROQ_API_KEY"))
#print(f"\nmodel name: {os.getenv('MODEL_NAME')}")
#print("Sending request...")
#sys.exit()
# =========================
# CONFIG (NEW - VENDOR NEUTRAL)
# =========================
def get_llm_client():
if OpenAI is None:
return None
api_key = os.getenv("API_KEY") or os.getenv("GROQ_API_KEY")
if not api_key:
return None # πŸ”₯ critical
try:
return OpenAI(
base_url=os.getenv(
"API_BASE_URL",
"https://router.huggingface.co/v1"
),
api_key=api_key
)
except Exception:
return None
client = get_llm_client()
print(f"[CONFIG] API_BASE_URL={os.getenv('API_BASE_URL', 'https://router.huggingface.co/v1')}")
print("Sending request...")
# =========================
# Smarter, mapped ask_info - boosts info_progress speed, reward per episode
# =========================
def pick_field(category, known):
if category == "billing":
return "order_id"
if category == "technical":
return "account_email"
if category == "delivery":
return "order_id"
return "account_email"
# =========================
# CLASSIFIER TO REDUCE LLM RELIANCE
# =========================
def smart_classify(message):
msg = message.lower()
if any(x in msg for x in ["refund", "cancel", "subscription", "charge"]):
return {"category": "billing", "priority": "high"}
if any(x in msg for x in ["crash", "bug", "error", "slow"]):
return {"category": "technical", "priority": "high"}
return {"category": "general", "priority": "medium"}
def override_classify(message):
msg = message.lower()
if any(x in msg for x in ["charged", "refund", "billing", "cancel", "subscription"]):
return {"type": "classify", "category": "billing", "priority": "high"}
if any(x in msg for x in ["checkout", "crash", "bug", "error", "not loading", "login"]):
return {"type": "classify", "category": "technical", "priority": "high"}
if any(x in msg for x in ["delivery", "order not arrived", "shipping"]):
return {"type": "classify", "category": "delivery", "priority": "medium"}
return {"type": "classify", "category": "general", "priority": "medium"}
def is_ready_to_resolve(category, known):
if category == "billing":
return "order_id" in known
if category == "technical":
return "account_email" in known
if category == "delivery":
return "order_id" in known
return False
# =========================
# POLICY ENFORCEMENT INSTEAD OF LLM DECISION
# =========================
def enforce_policy(obs, action):
known = obs["known_info"]
category = known.get("category")
# Never re-classify
if action["type"] == "classify" and category:
return {"type": "ask_info", "field": pick_field(category, known)}
# Force correct ask_info
if action["type"] == "ask_info":
action["field"] = pick_field(category, known)
# if already asked β†’ resolve instead of repeating
if action["type"] == "ask_info":
if action["field"] in known:
return {"type": "resolve"}
# Only resolve when ready
if action["type"] == "resolve":
if not is_ready_to_resolve(category, known):
return {"type": "ask_info", "field": pick_field(category, known)}
return action
# =========================
# PROMPT
# =========================
def build_prompt(obs, valid_actions):
return f"""
You are a customer support decision agent.
Return ONLY valid JSON.
IMPORTANT DECISION RULES:
1. DO NOT ask for unnecessary information
2. If the issue is clear (e.g., password reset, login failure), resolve directly
3. Only ask for information that is REQUIRED to solve the issue
4. NEVER ask for order_id in login/password issues
5. If sufficient information is already available, choose "resolve"
6. Avoid repeating the same question
Customer message:
{obs["customer_message"]}
Known info:
{obs["known_info"]}
Progress:
{obs["info_progress"]}
VALID ACTIONS:
{valid_actions}
RULES:
- ONLY pick from VALID ACTIONS
- "charged", "refund" β†’ billing
- "slow", "crash" β†’ technical
- Do NOT hallucinate
CRITICAL DECISION RULE:
Only choose "resolve" IF:
1. You have correctly classified the issue
2. You have collected ALL required fields
3. You are confident you can solve the user's problem
If ANY doubt remains β†’ ask_info
NEVER resolve early.
CLASSIFICATION RULES (STRICT):
You MUST classify into ONLY ONE of:
- billing
- technical
- delivery
NEVER output "general" or any other category.
---
BILLING:
charged, refund, payment, invoice, subscription, billing issues
TECHNICAL:
login issues, account problems, crashes, errors, bugs, slow performance, app issues
IMPORTANT:
ANY issue related to app behavior (slow, crash, not working, locked account)
β†’ ALWAYS technical
---
DELIVERY:
shipping, delivery delay, order not received
---
PRIORITY RULE:
If message involves money β†’ billing (even if order mentioned)
Example:
"I was charged twice for my order"
β†’ billing
FORMAT:
{{
"thought": "...",
"action": {{ ... }}
}}
"""
# =========================
# LLM CALL
# =========================
def call_llm(prompt):
completion = client.chat.completions.create(
model=os.getenv("MODEL_NAME", "unknown-model"),
#model="llama-3.1-8b-instant",
messages=[{"role": "user", "content": prompt}],
temperature=0.2,
response_format={"type": "json_object"}
)
return completion.choices[0].message.content.strip()
# =========================
# PARSER (MANDATORY)
# =========================
def parse_output(text):
try:
if "```" in text:
text = text.split("```")[1]
start = text.find("{")
end = text.rfind("}") + 1
text = text[start:end]
parsed = json.loads(text)
action = parsed.get("action")
if not action or "type" not in action:
raise ValueError("Invalid action format")
return action
except Exception as e:
print("❌ PARSE ERROR:", e)
print("RAW:", text)
return None
# =========================
# VALIDATION
# =========================
def is_valid_action(action, valid_actions):
if not action or "type" not in action:
return False
action_type = action["type"]
# βœ… check type exists
valid_types = [a["type"] for a in valid_actions]
if action_type not in valid_types:
return False
# βœ… ask_info must match field
if action_type == "ask_info":
valid_fields = [a["field"] for a in valid_actions if a["type"] == "ask_info"]
return action.get("field") in valid_fields
# βœ… classify must have required keys (NOT exact match)
if action_type == "classify":
return "category" in action and "priority" in action
# resolve always valid
return True
# =========================
# VALID ACTION SPACE
# =========================
def get_valid_actions():
actions = [
{"type": "ask_info", "field": "order_id"},
{"type": "ask_info", "field": "account_email"},
{"type": "ask_info", "field": "device_type"},
{"type": "ask_info", "field": "browser"},
{"type": "resolve"},
]
# βœ… allow flexible classification
actions.append({"type": "classify"})
return actions
# =========================
# ACTION PIPELINE
# =========================
def get_action(obs):
msg = obs["customer_message"].lower()
# βœ… NEW: use env-provided structure
known = obs.get("known_info", {})
required = obs.get("required", [])
# =====================
# 1. CLASSIFY (only once)
# =====================
if "category" not in known:
if any(x in msg for x in [
"charged", "refund", "billed", "payment", "invoice", "cancel"
]):
return {"type": "classify", "category": "billing", "priority": "high"}
if any(x in msg for x in [
"delivery", "delivered", "not received", "shipment", "order"
]):
return {"type": "classify", "category": "delivery", "priority": "high"}
if any(x in msg for x in [
"login", "password", "error", "crash", "bug", "checkout"
]):
return {"type": "classify", "category": "technical", "priority": "high"}
return {"type": "classify", "category": "technical", "priority": "medium"}
# =====================
# 2. COMPUTE MISSING INFO
# =====================
missing = [f for f in required if f not in known]
# =====================
# 3. ASK FOR NEXT FIELD
# =====================
if missing:
return {"type": "ask_info", "field": missing[0]}
# =====================
# 4. RESOLVE
# =====================
return {"type": "resolve"}
# =========================
# RUN
# =========================
def run_agent():
print("πŸš€ Starting agent...")
env = CustomerSupportEnv()
obs = env.reset()
done = False
trajectory = []
while not done:
print("\nπŸ“₯ OBS:", obs)
action = get_action(obs)
print("🧠 ACTION:", action)
next_obs, reward, done, info = env.step(action)
print("🎯 REWARD:", reward)
print("βœ… DONE:", done)
trajectory.append({
"state": obs,
"action": action,
"reward": reward
})
obs = next_obs
print("OBS:", obs)
print("ACTION:", action)
print("REWARD:", reward)
print("DONE:", done)
#print("\n🏁 FINAL INFO:", info)
print("FINAL:", info if info else "No info returned")
return {
"final_score": info.get("final_score", 0),
"trajectory": trajectory
}
def run_multiple(n=3):
scores = []
for i in range(n):
print(f"\n===== EPISODE {i+1} =====")
result = run_agent()
scores.append(result["final_score"])
avg = sum(scores) / len(scores)
print("\nπŸ“Š AVERAGE SCORE:", avg)
#print("\nπŸ“Š scores:", scores)
#print("\nπŸ“Š sum scores:", sum(scores))
#print("\nπŸ“Š len scores:", len(scores))
if __name__ == "__main__":
run_multiple(3)