# NOTE(review): "Spaces: Sleeping" below was a Hugging Face Spaces status
# banner captured during extraction — it is not part of the module source.
| # agent_llm.py | |
| """ | |
| - Uses LLM (requirement satisfied) | |
| - Robust (fallback present) | |
| - Structured output (strict JSON) | |
| - No hallucination risk | |
| - Reproducible | |
| """ | |
| import os | |
| import json | |
| import time | |
| #from groq import Groq | |
| #from openai import OpenAI | |
| import random | |
| from app.env import CustomerSupportEnv | |
| #from dotenv import load_dotenv | |
| #load_dotenv() | |
| #client = Groq(api_key=os.getenv("GROQ_API_KEY")) | |
# =========================
# PURPOSE: Safe OpenAI client init
# =========================
# Optional dependency: the OpenAI SDK may be missing in some deployments.
# Binding the name to None lets get_llm_client() detect that and return
# None instead of raising at import time.
try:
    from openai import OpenAI
except ImportError:
    OpenAI = None
# Optional dependency: load a local .env file when python-dotenv is
# installed; silently skip otherwise (env vars may come from the host).
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
| # ========================= | |
| # CONFIG - CLIENT-SAFE | |
| # ========================= | |
def get_llm_client():
    """Build an OpenAI-compatible client, or return None when unavailable.

    Returns None when the SDK is not installed or when neither API_KEY nor
    GROQ_API_KEY is set — the caller then uses the rule-based fallback.
    """
    if OpenAI is None:
        return None
    api_key = os.getenv("API_KEY")
    if not api_key:
        api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        return None
    base_url = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
    return OpenAI(base_url=base_url, api_key=api_key)
| client = get_llm_client() | |
| # ========================= | |
| # PURPOSE: Prompt - Strict + Minimal - encourages uncertainty-aware reasoning | |
| # ========================= | |
def build_prompt(obs):
    """Render the strict-JSON prompt for the support-agent LLM.

    `obs` is the environment observation dict; only the keys
    "customer_message", "known_info" and "required" are read
    (missing keys render as None rather than raising).
    """
    return f"""
You are a customer support agent.
Customer message:
{obs.get("customer_message")}
Known info:
{obs.get("known_info")}
Required fields:
{obs.get("required")}
Your goal is to resolve the ticket efficiently.
Think carefully:
- You may revise earlier decisions
- Do not commit too early
- Ask missing info if unsure
- The message may be ambiguous
- Do not assume category prematurely
- Ask only necessary questions
- Avoid redundant actions
Return JSON:
{{"action": {{...}}}}
"""
| # ========================= | |
| # LLM CALL (SAFE) | |
| # ========================= | |
def call_llm(prompt):
    """Send `prompt` to the configured model; return the raw response text.

    Returns None — which tells the caller to use the rule-based fallback —
    when no client is configured or when the API call fails in any way.
    """
    if client is None:
        return None
    try:
        response = client.chat.completions.create(
            model=os.getenv("MODEL_NAME", "unknown-model"),
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
            response_format={"type": "json_object"},
        )
        return response.choices[0].message.content.strip()
    except Exception:
        # Any failure (network, auth, malformed response) → fallback.
        return None
| # ========================= | |
| # PARSER (STRICT) | |
| # ========================= | |
def parse_output(text):
    """Extract the {"action": {...}} payload from raw LLM output.

    Tolerates surrounding prose by slicing from the first "{" to the last
    "}". Returns the action dict, or None when the text is empty/None,
    contains no JSON object, is malformed JSON, or lacks a dict action
    with a "type" key.

    Fixes vs. the previous version: the bare `except:` is narrowed, the
    "no brace found" case is handled explicitly instead of relying on a
    `find() == -1` slice failing inside json.loads, and a non-dict
    "action" value (e.g. a string that happens to contain "type") can no
    longer be returned as if it were an action.
    """
    if not text:
        return None
    start = text.find("{")
    end = text.rfind("}") + 1
    if start == -1 or end == 0:
        return None  # no JSON object present at all
    try:
        parsed = json.loads(text[start:end])
    except ValueError:
        return None  # malformed JSON between the braces
    if not isinstance(parsed, dict):
        return None
    action = parsed.get("action")
    # Only a dict carrying the "type" discriminator is a usable action.
    if not isinstance(action, dict) or "type" not in action:
        return None
    return action
| # ========================= | |
| # PURPOSE: Fallback is intentionally imperfect | |
| # ========================= | |
def fallback(obs):
    """Rule-based policy used whenever the LLM is unavailable or invalid.

    Intentionally imperfect: with probability 0.3 it re-classifies a
    ticket that already has a category, so the policy is not purely
    greedy. Otherwise it asks for the first missing required field, and
    resolves once nothing is missing.
    """
    known = obs.get("known_info", {})
    required = obs.get("required", [])
    # Short-circuit order matters: random() is only consumed when a
    # category is already known (keeps RNG stream identical).
    if "category" not in known or random.random() < 0.3:
        return {
            "type": "classify",
            "category": "technical",
            "priority": "medium",
        }
    for field in required:
        if field not in known:
            return {"type": "ask_info", "field": field}
    return {"type": "resolve"}
| # ========================= | |
| # VALIDATION | |
| # ========================= | |
def is_valid_action(action, valid_actions):
    """Check a proposed action against the environment's legal actions.

    Rules:
    - the action must be a truthy dict with a "type" among the valid types
    - "ask_info" must target a field offered by a valid ask_info action
    - "classify" must carry both "category" and "priority"
    """
    if not action or "type" not in action:
        return False
    kind = action["type"]
    if all(choice["type"] != kind for choice in valid_actions):
        return False
    if kind == "ask_info":
        allowed_fields = {
            choice["field"]
            for choice in valid_actions
            if choice["type"] == "ask_info"
        }
        return action.get("field") in allowed_fields
    if kind == "classify":
        return "category" in action and "priority" in action
    return True
| # ========================= | |
| # PURPOSE: Hybrid control (LLM + adaptive fallback) | |
| # ========================= | |
def get_action(obs, valid_actions):
    """Hybrid policy: try the LLM first, fall back to rules otherwise.

    Fixes vs. the previous version:
    - delegates the API call to call_llm() and parsing to parse_output()
      instead of duplicating both inline (one code path, one model
      default, one temperature, no bare `except:`);
    - actually validates the LLM's proposal with is_valid_action(),
      which previously existed but was never applied — LLM output was
      accepted unchecked.

    Returns an action dict that is always legal for the environment.
    """
    response_text = call_llm(build_prompt(obs))
    if response_text is not None:
        action = parse_output(response_text)
        if is_valid_action(action, valid_actions):
            return action
    # No client, API failure, unparsable output, or illegal action.
    return fallback(obs)
| # ========================= | |
| # RUN | |
| # ========================= | |
def run_agent():
    """Run one full support episode against CustomerSupportEnv.

    Steps the environment with hybrid LLM/rule actions until the episode
    reports done. The action menu is invariant, so it is built once
    before the loop (previously it was rebuilt on every step); dead
    commented-out debug prints were removed.
    """
    env = CustomerSupportEnv()
    obs = env.reset()
    # Invariant action menu — hoisted out of the step loop.
    valid_actions = [
        {"type": "ask_info", "field": "order_id"},
        {"type": "ask_info", "field": "account_email"},
        {"type": "ask_info", "field": "device_type"},
        {"type": "ask_info", "field": "browser"},
        {"type": "resolve"},
        {"type": "classify"},
    ]
    done = False
    while not done:
        action = get_action(obs, valid_actions)
        obs, reward, done, info = env.step(action)
# Script entry point: run a single episode when executed directly.
if __name__ == "__main__":
    run_agent()