HyperBrickCaseOps / inference.py
modelbuilderhq's picture
Upload folder using huggingface_hub
181758b verified
raw
history blame
4.87 kB
"""Baseline inference script for the SupportDesk OpenEnv submission."""
from __future__ import annotations

import json
import logging
import os
import re
from statistics import mean

try:
    from openai import OpenAI
except ImportError:  # pragma: no cover - local fallback mode
    OpenAI = None  # type: ignore[assignment]

from supportdesk_env.graders import grade_case
from supportdesk_env.models import SupportDeskAction, SupportDeskObservation
from supportdesk_env.policies import heuristic_action
from supportdesk_env.server.supportdesk_environment import SupportDeskEnvironment
from supportdesk_env.tasks import get_task, list_task_ids
# System message sent on every model call. It asks the model for exactly one JSON
# object; the reply is later unpacked as SupportDeskAction(**payload), so these keys
# presumably mirror SupportDeskAction's fields (verify against supportdesk_env.models).
SYSTEM_PROMPT = """You are a support operations agent solving one triage ticket.
Return exactly one JSON object with this schema:
{
"operation": "classify|request_info|draft_reply|add_internal_note|submit",
"queue": string or null,
"priority": string or null,
"issue_type": string or null,
"status": string or null,
"resolution_code": string or null,
"requested_fields": [string],
"reply": string or null,
"internal_note": string or null
}
Use the policy snippets in the observation. Keep customer replies short, precise, and professional.
"""
# Runtime configuration, read from the environment once at import time.
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
# Optional override for OpenAI-compatible endpoints; None means the SDK default.
API_BASE_URL = os.environ.get("API_BASE_URL")
# First non-empty credential wins. "not-set" is the sentinel _build_client checks
# to decide that no model client should be created.
API_KEY = next(
    (secret for secret in (os.environ.get("OPENAI_API_KEY"), os.environ.get("HF_TOKEN")) if secret),
    "not-set",
)
# Upper bound on actions per episode.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "6"))
# Sampling temperature passed to the chat completion call.
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0"))
def _build_client() -> OpenAI | None:
    """Create the OpenAI-compatible client used for model calls, if possible.

    Returns ``None`` when the ``openai`` package could not be imported or when
    no API key was configured (``API_KEY`` still holds the ``"not-set"``
    sentinel). When ``API_BASE_URL`` is set, it is passed as ``base_url``.
    """
    if OpenAI is None or API_KEY == "not-set":
        return None
    client_options = {"api_key": API_KEY}
    if API_BASE_URL:
        client_options["base_url"] = API_BASE_URL
    return OpenAI(**client_options)
def _extract_json(text: str) -> dict:
    """Return the first JSON object found in a model completion.

    The whole text is tried first. If that is not a JSON object (prose around
    it, a Markdown code fence, a list wrapper, trailing commentary or a second
    object), every ``{`` is tried in order with ``JSONDecoder.raw_decode``.
    That call parses one complete value and ignores whatever follows it, unlike
    a greedy ``{...}`` regex, which would span to the last ``}`` in the text.

    Args:
        text: Raw completion text from the model.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If no position in ``text`` starts a valid JSON object.
    """
    try:
        payload = json.loads(text)
    except json.JSONDecodeError as exc:
        error = exc
    else:
        if isinstance(payload, dict):
            return payload
        # Valid JSON but not an object (e.g. a list): still look for an embedded object.
        error = json.JSONDecodeError("Expecting a JSON object", text, 0)

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)
    raise error
def _observation_prompt(observation: SupportDeskObservation) -> str:
    """Render an observation as the plain-text user message for the model.

    The message lists, in order:

    * the task and the ticket fields;
    * one line per knowledge-base snippet;
    * the current case fields;
    * the latest feedback;
    * the remaining step budget;
    * one line per past action, or ``- none`` when there is no history yet.

    Args:
        observation: The current environment observation.

    Returns:
        The prompt text, terminated by a newline.
    """
    ticket = observation.ticket
    case_state = observation.case
    kb_lines = "\n".join(
        f"- {article.article_id}: {article.title}: {article.content}"
        for article in observation.knowledge_base
    )
    past_steps = [
        f"- step {item.step}: {item.summary} ({item.reward_delta:+.2f})"
        for item in observation.action_history
    ]
    history_lines = "\n".join(past_steps) if past_steps else "- none"
    return f"""Task: {observation.task_id} ({observation.difficulty})
Objective: {observation.objective}
Ticket subject: {ticket.subject}
Ticket body: {ticket.body}
Customer tier: {ticket.customer_tier}
Region: {ticket.region}
Affected users: {ticket.affected_users}
SLA minutes remaining: {ticket.sla_minutes_remaining}
Business impact: {ticket.business_impact}
Secondary concerns: {ticket.secondary_concerns}
Knowledge base:
{kb_lines}
Current case state:
- queue: {case_state.queue}
- priority: {case_state.priority}
- issue_type: {case_state.issue_type}
- status: {case_state.status}
- resolution_code: {case_state.resolution_code}
- requested_fields: {case_state.requested_fields}
- reply: {case_state.reply}
- internal_note: {case_state.internal_note}
Feedback: {observation.feedback}
Remaining steps: {observation.remaining_steps}
History:
{history_lines}
"""
def _model_action(client: OpenAI | None, observation: SupportDeskObservation) -> SupportDeskAction:
    """Pick the next action for ``observation``.

    With no client, the heuristic policy decides. Otherwise the model is
    prompted with ``SYSTEM_PROMPT`` plus the rendered observation and its
    reply is parsed into a ``SupportDeskAction``.

    Any failure falls back to the heuristic policy, so a bad completion never
    aborts the episode. That covers the API call, JSON extraction and action
    validation. Fallbacks are logged as warnings so a misconfigured key or
    endpoint does not silently turn the run into a heuristic-only baseline.

    Args:
        client: OpenAI-compatible client, or ``None`` for heuristic-only mode.
        observation: Current environment observation.

    Returns:
        The action to submit to the environment.
    """
    if client is None:
        return heuristic_action(observation)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            temperature=TEMPERATURE,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": _observation_prompt(observation)},
            ],
        )
        content = completion.choices[0].message.content or ""
        payload = _extract_json(content)
        return SupportDeskAction(**payload)
    except Exception as exc:  # deliberate best-effort boundary; logged, then heuristic fallback
        logging.getLogger(__name__).warning(
            "Model action failed (%s: %s); falling back to heuristic policy.",
            type(exc).__name__,
            exc,
        )
        return heuristic_action(observation)
def run_task(task_id: str, client: OpenAI | None) -> float:
    """Run one episode of ``task_id`` and return its final grade.

    The episode stops at the first observation with ``done`` set, or after
    ``MAX_STEPS`` actions, whichever comes first. The final case state is
    graded against the task definition, a one-line summary is printed, and the
    environment is always closed.

    Args:
        task_id: Identifier of the task to load into the environment.
        client: OpenAI-compatible client passed to ``_model_action``, or ``None``.

    Returns:
        The grader's ``total_score`` for the final case state.
    """
    env = SupportDeskEnvironment(task_id=task_id)
    try:
        # reset() is inside the try so the environment is closed even if reset fails.
        observation = env.reset()
        for _ in range(MAX_STEPS):
            action = _model_action(client, observation)
            observation = env.step(action)
            if observation.done:
                break
        final_grade = grade_case(get_task(task_id), env.state.case)
        print(f"{task_id}: score={final_grade.total_score:.2f} reward={env.state.reward:.2f}")
        return final_grade.total_score
    finally:
        env.close()
def main() -> None:
    """Evaluate every registered task with one shared client and print the mean score.

    Per-task lines are printed by ``run_task``. The final line has the form
    ``average_score=<mean>`` with three decimals.
    """
    client = _build_client()
    task_ids = list_task_ids()
    per_task_scores = [run_task(task_id, client) for task_id in task_ids]
    # NOTE(review): statistics.mean raises StatisticsError if no task ids are registered.
    overall = mean(per_task_scores)
    print(f"average_score={overall:.3f}")
# Script entry point: run the baseline over all tasks only when executed directly.
if __name__ == "__main__":
    main()