# File-viewer page residue (not part of the script): Spaces: Sleeping; File size: 4,869 Bytes; 181758b
"""Baseline inference script for the SupportDesk OpenEnv submission."""
from __future__ import annotations
import json
import os
import re
from statistics import mean
try:
from openai import OpenAI
except ImportError: # pragma: no cover - local fallback mode
OpenAI = None # type: ignore[assignment]
from supportdesk_env.graders import grade_case
from supportdesk_env.models import SupportDeskAction, SupportDeskObservation
from supportdesk_env.policies import heuristic_action
from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
from supportdesk_env.tasks import get_task, list_task_ids
# System message sent with every chat-completion request made by
# `_model_action`. It instructs the model to answer with a single JSON object;
# that object is parsed by `_extract_json` and its keys are passed as keyword
# arguments to `SupportDeskAction`, so the schema below must stay in sync with
# that model's fields.
SYSTEM_PROMPT = """You are a support operations agent solving one triage ticket.
Return exactly one JSON object with this schema:
{
"operation": "classify|request_info|draft_reply|add_internal_note|submit",
"queue": string or null,
"priority": string or null,
"issue_type": string or null,
"status": string or null,
"resolution_code": string or null,
"requested_fields": [string],
"reply": string or null,
"internal_note": string or null
}
Use the policy snippets in the observation. Keep customer replies short, precise, and professional.
"""
# Runtime configuration, read once from the environment at import time.
# Model identifier passed to the chat-completions call.
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
# Optional alternative endpoint; forwarded as `base_url` only when set.
API_BASE_URL = os.getenv("API_BASE_URL")
# First non-empty of OPENAI_API_KEY / HF_TOKEN. The "not-set" sentinel makes
# `_build_client` return None, which switches the run to the heuristic policy.
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or "not-set"
# Upper bound on environment steps taken per task in `run_task`.
MAX_STEPS = int(os.getenv("MAX_STEPS", "6"))
# Sampling temperature passed to the chat-completions call.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
def _build_client() -> OpenAI | None:
    """Create the OpenAI client, or return ``None`` when one cannot be built.

    ``None`` is returned when the ``openai`` package could not be imported or
    when no API key is configured (``API_KEY`` still holds ``"not-set"``).
    ``API_BASE_URL`` is forwarded as ``base_url`` only when it is truthy.
    """
    client_unavailable = OpenAI is None or API_KEY == "not-set"
    if client_unavailable:
        return None

    if API_BASE_URL:
        return OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    return OpenAI(api_key=API_KEY)
def _extract_json(text: str) -> dict:
try:
return json.loads(text)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", text, flags=re.DOTALL)
if not match:
raise
return json.loads(match.group(0))
def _observation_prompt(observation: SupportDeskObservation) -> str:
kb_lines = "\n".join(
f"- {snippet.article_id}: {snippet.title}: {snippet.content}" for snippet in observation.knowledge_base
)
history_lines = "\n".join(
f"- step {entry.step}: {entry.summary} ({entry.reward_delta:+.2f})"
for entry in observation.action_history
) or "- none"
return f"""Task: {observation.task_id} ({observation.difficulty})
Objective: {observation.objective}
Ticket subject: {observation.ticket.subject}
Ticket body: {observation.ticket.body}
Customer tier: {observation.ticket.customer_tier}
Region: {observation.ticket.region}
Affected users: {observation.ticket.affected_users}
SLA minutes remaining: {observation.ticket.sla_minutes_remaining}
Business impact: {observation.ticket.business_impact}
Secondary concerns: {observation.ticket.secondary_concerns}
Knowledge base:
{kb_lines}
Current case state:
- queue: {observation.case.queue}
- priority: {observation.case.priority}
- issue_type: {observation.case.issue_type}
- status: {observation.case.status}
- resolution_code: {observation.case.resolution_code}
- requested_fields: {observation.case.requested_fields}
- reply: {observation.case.reply}
- internal_note: {observation.case.internal_note}
Feedback: {observation.feedback}
Remaining steps: {observation.remaining_steps}
History:
{history_lines}
"""
def _model_action(client: OpenAI | None, observation: SupportDeskObservation) -> SupportDeskAction:
    """Choose the next action for ``observation``.

    When ``client`` is ``None`` the heuristic policy is used directly.
    Otherwise the model is queried with ``SYSTEM_PROMPT`` and the rendered
    observation, and its JSON reply is turned into a ``SupportDeskAction``.

    Any failure while querying the model or building the action (API error,
    unparseable reply, invalid fields) falls back to the heuristic policy so
    the episode can continue. The failure is reported on stderr so that a
    misconfigured endpoint is not mistaken for genuine model results.

    Args:
        client: OpenAI client, or ``None`` to use the heuristic policy.
        observation: Current environment observation.

    Returns:
        The action to submit to the environment.
    """
    import sys

    if client is None:
        return heuristic_action(observation)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=TEMPERATURE,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _observation_prompt(observation)},
            ],
        )
        # `content` may be None; treat that as an empty (unparseable) reply.
        content = completion.choices[0].message.content or ""
        payload = _extract_json(content)
        return SupportDeskAction(**payload)
    except Exception as exc:
        # Deliberate best-effort boundary: keep the run going, but say so.
        print(
            f"warning: model action failed ({type(exc).__name__}: {exc}); "
            "falling back to heuristic policy",
            file=sys.stderr,
        )
        return heuristic_action(observation)
def run_task(task_id: str, client: OpenAI | None) -> float:
    """Run one episode of ``task_id`` and return its graded score.

    The environment is stepped with actions from ``_model_action`` until the
    observation reports ``done`` or ``MAX_STEPS`` actions have been taken.
    The final case state is then graded, and one summary line with the score
    and environment reward is printed.

    Args:
        task_id: Identifier of the task to run.
        client: OpenAI client, or ``None`` to use the heuristic policy.

    Returns:
        The ``total_score`` of the final grade.
    """
    env = SupportDeskEnvironment(task_id=task_id)
    try:
        # Reset inside the try so the environment is closed even if it fails.
        observation = env.reset()
        for _ in range(MAX_STEPS):
            action = _model_action(client, observation)
            observation = env.step(action)
            if observation.done:
                break
        final_grade = grade_case(get_task(task_id), env.state.case)
        print(f"{task_id}: score={final_grade.total_score:.2f} reward={env.state.reward:.2f}")
        return final_grade.total_score
    finally:
        env.close()
def main() -> None:
    """Run every registered task and print the average score.

    Raises:
        SystemExit: If no tasks are registered, since an average over zero
            scores is undefined.
    """
    client = _build_client()
    scores = [run_task(task_id, client) for task_id in list_task_ids()]
    if not scores:
        # mean([]) would raise StatisticsError; fail with a clear message.
        raise SystemExit("No tasks available to evaluate.")
    print(f"average_score={mean(scores):.3f}")


if __name__ == "__main__":
    main()
|