"""Baseline inference script for the SupportDesk OpenEnv submission.""" from __future__ import annotations import json import os import re from statistics import mean try: from openai import OpenAI except ImportError: # pragma: no cover - local fallback mode OpenAI = None # type: ignore[assignment] from supportdesk_env.graders import grade_case from supportdesk_env.models import SupportDeskAction, SupportDeskObservation from supportdesk_env.policies import heuristic_action from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment from supportdesk_env.tasks import get_task, list_task_ids SYSTEM_PROMPT = """You are a support operations agent solving one triage ticket. Return exactly one JSON object with this schema: { "operation": "classify|request_info|draft_reply|add_internal_note|submit", "queue": string or null, "priority": string or null, "issue_type": string or null, "status": string or null, "resolution_code": string or null, "requested_fields": [string], "reply": string or null, "internal_note": string or null } Use the policy snippets in the observation. Keep customer replies short, precise, and professional. """ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini") API_BASE_URL = os.getenv("API_BASE_URL") API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or "not-set" MAX_STEPS = int(os.getenv("MAX_STEPS", "6")) TEMPERATURE = float(os.getenv("TEMPERATURE", "0")) def _build_client() -> OpenAI | None: if OpenAI is None: return None if API_KEY == "not-set": return None kwargs = {"api_key": API_KEY} if API_BASE_URL: kwargs["base_url"] = API_BASE_URL return OpenAI(**kwargs) def _extract_json(text: str) -> dict: try: return json.loads(text) except json.JSONDecodeError: match = re.search(r"\{.*\}", text, flags=re.DOTALL) if not match: raise return json.loads(match.group(0)) def _observation_prompt(observation: SupportDeskObservation) -> str: kb_lines = "\n".join( f"- {snippet.article_id}: {snippet.title}: {snippet.content}" for snippet in observation.knowledge_base ) history_lines = "\n".join( f"- step {entry.step}: {entry.summary} ({entry.reward_delta:+.2f})" for entry in observation.action_history ) or "- none" return f"""Task: {observation.task_id} ({observation.difficulty}) Objective: {observation.objective} Ticket subject: {observation.ticket.subject} Ticket body: {observation.ticket.body} Customer tier: {observation.ticket.customer_tier} Region: {observation.ticket.region} Affected users: {observation.ticket.affected_users} SLA minutes remaining: {observation.ticket.sla_minutes_remaining} Business impact: {observation.ticket.business_impact} Secondary concerns: {observation.ticket.secondary_concerns} Knowledge base: {kb_lines} Current case state: - queue: {observation.case.queue} - priority: {observation.case.priority} - issue_type: {observation.case.issue_type} - status: {observation.case.status} - resolution_code: {observation.case.resolution_code} - requested_fields: {observation.case.requested_fields} - reply: {observation.case.reply} - internal_note: {observation.case.internal_note} Feedback: {observation.feedback} Remaining steps: {observation.remaining_steps} History: {history_lines} """ def _model_action(client: OpenAI | None, observation: SupportDeskObservation) -> SupportDeskAction: if client is None: return heuristic_action(observation) try: completion = client.chat.completions.create( model=MODEL_NAME, temperature=TEMPERATURE, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": _observation_prompt(observation)}, ], ) content = completion.choices[0].message.content or "" payload = _extract_json(content) return SupportDeskAction(**payload) except Exception: return heuristic_action(observation) def run_task(task_id: str, client: OpenAI | None) -> float: env = SupportDeskEnvironment(task_id=task_id) observation = env.reset() try: for _ in range(MAX_STEPS): action = _model_action(client, observation) observation = env.step(action) if observation.done: break final_grade = grade_case(get_task(task_id), env.state.case) print(f"{task_id}: score={final_grade.total_score:.2f} reward={env.state.reward:.2f}") return final_grade.total_score finally: env.close() def main() -> None: client = _build_client() scores = [run_task(task_id, client) for task_id in list_task_ids()] print(f"average_score={mean(scores):.3f}") if __name__ == "__main__": main()