# NOTE(review): removed non-code page residue ("Spaces: / Sleeping / Sleeping")
# that was captured above the module during extraction.
| """Baseline inference script for the SupportDesk OpenEnv submission.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from statistics import mean | |
| try: | |
| from openai import OpenAI | |
| except ImportError: # pragma: no cover - local fallback mode | |
| OpenAI = None # type: ignore[assignment] | |
| from supportdesk_env.graders import grade_case | |
| from supportdesk_env.models import SupportDeskAction, SupportDeskObservation | |
| from supportdesk_env.policies import heuristic_action | |
| from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment | |
| from supportdesk_env.tasks import get_task, list_task_ids | |
# System message for the chat model: instructs it to emit exactly one JSON
# object matching the SupportDeskAction schema consumed by _model_action.
SYSTEM_PROMPT = """You are a support operations agent solving one triage ticket.
Return exactly one JSON object with this schema:
{
"operation": "classify|request_info|draft_reply|add_internal_note|submit",
"queue": string or null,
"priority": string or null,
"issue_type": string or null,
"status": string or null,
"resolution_code": string or null,
"requested_fields": [string],
"reply": string or null,
"internal_note": string or null
}
Use the policy snippets in the observation. Keep customer replies short, precise, and professional.
"""
# Runtime configuration — every value is overridable via environment variables.
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")  # chat model to query
API_BASE_URL = os.getenv("API_BASE_URL")  # optional OpenAI-compatible endpoint
# "not-set" is the sentinel _build_client uses to mean "no credentials; run heuristics".
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or "not-set"
MAX_STEPS = int(os.getenv("MAX_STEPS", "6"))  # per-episode step budget
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))  # 0 = deterministic sampling
def _build_client() -> OpenAI | None:
    """Return a configured OpenAI client, or ``None`` to force heuristic mode.

    ``None`` is returned when the SDK failed to import or when no API key
    was supplied (the ``"not-set"`` sentinel).
    """
    if OpenAI is None or API_KEY == "not-set":
        return None
    if API_BASE_URL:
        return OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    return OpenAI(api_key=API_KEY)
| def _extract_json(text: str) -> dict: | |
| try: | |
| return json.loads(text) | |
| except json.JSONDecodeError: | |
| match = re.search(r"\{.*\}", text, flags=re.DOTALL) | |
| if not match: | |
| raise | |
| return json.loads(match.group(0)) | |
| def _observation_prompt(observation: SupportDeskObservation) -> str: | |
| kb_lines = "\n".join( | |
| f"- {snippet.article_id}: {snippet.title}: {snippet.content}" for snippet in observation.knowledge_base | |
| ) | |
| history_lines = "\n".join( | |
| f"- step {entry.step}: {entry.summary} ({entry.reward_delta:+.2f})" | |
| for entry in observation.action_history | |
| ) or "- none" | |
| return f"""Task: {observation.task_id} ({observation.difficulty}) | |
| Objective: {observation.objective} | |
| Ticket subject: {observation.ticket.subject} | |
| Ticket body: {observation.ticket.body} | |
| Customer tier: {observation.ticket.customer_tier} | |
| Region: {observation.ticket.region} | |
| Affected users: {observation.ticket.affected_users} | |
| SLA minutes remaining: {observation.ticket.sla_minutes_remaining} | |
| Business impact: {observation.ticket.business_impact} | |
| Secondary concerns: {observation.ticket.secondary_concerns} | |
| Knowledge base: | |
| {kb_lines} | |
| Current case state: | |
| - queue: {observation.case.queue} | |
| - priority: {observation.case.priority} | |
| - issue_type: {observation.case.issue_type} | |
| - status: {observation.case.status} | |
| - resolution_code: {observation.case.resolution_code} | |
| - requested_fields: {observation.case.requested_fields} | |
| - reply: {observation.case.reply} | |
| - internal_note: {observation.case.internal_note} | |
| Feedback: {observation.feedback} | |
| Remaining steps: {observation.remaining_steps} | |
| History: | |
| {history_lines} | |
| """ | |
def _model_action(client: OpenAI | None, observation: SupportDeskObservation) -> SupportDeskAction:
    """Produce the next action for *observation*.

    With a client, asks the chat model for a JSON action. Any failure —
    API error, malformed JSON, invalid action fields — falls back to the
    deterministic heuristic policy, as does the no-client case.
    """
    if client is None:
        return heuristic_action(observation)
    try:
        # Prompt building stays inside the try so a rendering error also
        # falls back rather than aborting the episode.
        response = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=TEMPERATURE,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _observation_prompt(observation)},
            ],
        )
        raw = response.choices[0].message.content or ""
        return SupportDeskAction(**_extract_json(raw))
    except Exception:  # deliberate best-effort: never crash an episode
        return heuristic_action(observation)
def run_task(task_id: str, client: OpenAI | None) -> float:
    """Play one episode of *task_id* and return its final graded score.

    Steps the environment until the episode reports done or the MAX_STEPS
    budget runs out, prints a per-task summary line, and always closes the
    environment — even if stepping or grading raises.
    """
    env = SupportDeskEnvironment(task_id=task_id)
    observation = env.reset()
    try:
        steps_taken = 0
        while steps_taken < MAX_STEPS:
            observation = env.step(_model_action(client, observation))
            steps_taken += 1
            if observation.done:
                break
        final_grade = grade_case(get_task(task_id), env.state.case)
        print(f"{task_id}: score={final_grade.total_score:.2f} reward={env.state.reward:.2f}")
        return final_grade.total_score
    finally:
        env.close()
def main() -> None:
    """Evaluate every registered task with one shared client and print the mean score."""
    client = _build_client()
    scores = []
    for task_id in list_task_ids():
        scores.append(run_task(task_id, client))
    print(f"average_score={mean(scores):.3f}")
# Script entry point: run the full baseline evaluation when executed directly.
if __name__ == "__main__":
    main()