File size: 8,928 Bytes
c357a18
 
 
 
 
 
 
 
 
 
 
 
 
0e31d8c
c357a18
d0b62f7
22c0d63
c357a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e651edc
 
 
 
 
c357a18
 
 
e651edc
 
c357a18
 
 
 
 
 
 
 
 
 
 
d0b62f7
 
 
 
c357a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
from dotenv import load_dotenv

load_dotenv()

import os
import json
import sys
from typing import Any, Dict, List, Literal, Optional
import httpx
from openai import OpenAI
from pydantic import BaseModel, ConfigDict
from logger import log_start, log_step, log_end

API_BASE_URL: str = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME: str = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-7B-Instruct"
# HF router and similar setups use HF_TOKEN; OpenAI uses OPENAI_API_KEY.
API_KEY: str = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN", "")
# `or` (not a getenv default) so an empty ENV_BASE_URL also falls back, matching
# API_BASE_URL / MODEL_NAME above; an empty base URL would leave the relative
# requests ("/tasks", "/reset", "/step", "/grader") without a host.
ENV_BASE_URL: str = os.getenv("ENV_BASE_URL") or "http://localhost:7860"
# Tasks run in this order by main(); each id is sent as `task_id` to /reset.
TASKS = ["email_classification", "response_drafting", "support_session"]
# Step budget used when /tasks cannot supply a task's max_steps.
MAX_STEPS_FALLBACK = 60
# Benchmark name passed to log_start().
BENCHMARK = "Sieve"


class ClassifyOutput(BaseModel):
    # Structured LLM reply for the "email_classification" task; llm_action turns
    # it into a {"action_type": "classify", ...} action.
    # NOTE(review): documented with comments rather than a docstring because
    # pydantic can copy a class docstring into model_json_schema(), which is
    # sent to the model as the response schema.
    model_config = ConfigDict(extra="forbid")  # reject unknown keys when validating the reply
    # Field order is also the order of the schema's "required" list (see strict_schema).
    category: Literal[
        "billing", "technical", "general", "spam", "account", "feature_request"
    ]
    urgency: Literal["high", "medium", "low"]


class RespondOutput(BaseModel):
    # Structured LLM reply for the "response_drafting" task; llm_action sends
    # `response_text` as a {"action_type": "respond", ...} action.
    # NOTE(review): comments rather than a docstring, since pydantic can copy a
    # class docstring into the JSON schema that is sent to the model.
    model_config = ConfigDict(extra="forbid")  # reject unknown keys when validating the reply
    response_text: str


class SessionOutput(BaseModel):
    # Structured LLM reply for the "support_session" task: the single email the
    # model picked from the queue and what to do with it.
    # NOTE(review): comments rather than a docstring, since pydantic can copy a
    # class docstring into the JSON schema that is sent to the model.
    model_config = ConfigDict(extra="forbid")  # reject unknown keys when validating the reply
    email_id: str  # id of the chosen queue entry
    action_type: Literal["respond", "escalate", "archive"]
    category: Literal[
        "billing", "technical", "general", "spam", "account", "feature_request"
    ]
    urgency: Literal["high", "medium", "low"]
    # Optional extras; llm_action copies them into the action only when truthy.
    # strict_schema still lists them as required, so the model must emit them
    # (presumably as null when unused).
    response_text: Optional[str] = None
    escalation_reason: Optional[str] = None


def strict_schema(output_cls: type[BaseModel]) -> dict:
    """Return ``output_cls``'s JSON schema tightened for strict structured output.

    Every top-level property is marked as required (including fields that have
    a default) and extra top-level properties are disallowed. A schema without
    a ``properties`` key is returned exactly as pydantic generated it.
    """
    generated = output_cls.model_json_schema()
    if "properties" not in generated:
        return generated
    generated["required"] = [*generated["properties"]]
    generated["additionalProperties"] = False
    return generated


def structured_call(client: OpenAI, messages: list, output_cls: type[BaseModel]):
    """Ask the model for a JSON reply matching ``output_cls`` and parse it.

    Sends ``messages`` to MODEL_NAME at temperature 0 with a strict
    ``json_schema`` response format built by strict_schema().

    Returns:
        An ``output_cls`` instance validated from the reply text.

    Raises:
        ValueError: if the response has no choices or the reply has no content
            (e.g. a refusal); the refusal text is included when available.
            pydantic's ValidationError (also a ValueError) if the content does
            not match the schema.
        Any error raised by the OpenAI client itself.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": output_cls.__name__,
                "schema": strict_schema(output_cls),
                "strict": True,
            },
        },
        temperature=0,
    )
    if not response.choices:
        raise ValueError(f"{output_cls.__name__}: model returned no choices")

    message = response.choices[0].message
    content = message.content
    if not content:
        # With strict structured outputs the model may refuse instead of answering;
        # passing None to model_validate_json would only give an opaque type error.
        refusal = getattr(message, "refusal", None)
        detail = f": {refusal}" if refusal else ""
        raise ValueError(f"{output_cls.__name__}: model returned no content{detail}")
    return output_cls.model_validate_json(content)


def _email_prompt(email: Dict[str, Any]) -> str:
    """Render one email's subject and body as the user message."""
    return f"Subject: {email.get('subject', '')}\nBody: {email.get('body', '')}"


def _classify_action(client: OpenAI, email: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the LLM for category/urgency and build a ``classify`` action."""
    result = structured_call(
        client,
        messages=[
            {
                "role": "system",
                "content": "Classify customer support emails by category and urgency.",
            },
            {"role": "user", "content": _email_prompt(email)},
        ],
        output_cls=ClassifyOutput,
    )
    return {
        "action_type": "classify",
        "category": result.category,
        "urgency": result.urgency,
    }


def _respond_action(client: OpenAI, email: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the LLM to draft a reply and build a ``respond`` action."""
    result = structured_call(
        client,
        messages=[
            {
                "role": "system",
                "content": (
                    "Write a professional, empathetic customer support response. "
                    "Address the issue directly, include relevant details (timelines, links, case refs). "
                    "Minimum 80 words."
                ),
            },
            {"role": "user", "content": _email_prompt(email)},
        ],
        output_cls=RespondOutput,
    )
    return {"action_type": "respond", "response_text": result.response_text}


def _session_action(client: OpenAI, queue: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Ask the LLM to pick and act on the highest-priority email among the first 15 queued."""
    # One line per email: id, VIP flag, subject truncated to 70 chars. A missing
    # or null subject renders as empty instead of crashing the prompt build.
    queue_str = "\n".join(
        f"{e['id']} | VIP={e.get('sender_tier') == 'vip'} | {(e.get('subject') or '')[:70]}"
        for e in queue[:15]
    )
    result = structured_call(
        client,
        messages=[
            {
                "role": "system",
                "content": (
                    "You manage a customer support queue. Pick the single highest-priority email.\n"
                    "Priority: 1) VIP=True first  2) security/breach → escalate  "
                    "3) billing/technical → respond  4) feature requests/spam → archive"
                ),
            },
            {"role": "user", "content": queue_str},
        ],
        output_cls=SessionOutput,
    )
    action: Dict[str, Any] = {
        "action_type": result.action_type,
        "email_id": result.email_id,
        "category": result.category,
        "urgency": result.urgency,
    }
    # Optional fields are only sent when the model actually filled them in.
    if result.response_text:
        action["response_text"] = result.response_text
    if result.escalation_reason:
        action["escalation_reason"] = result.escalation_reason
    return action


def llm_action(
    observation: Dict[str, Any], task_id: str, client: OpenAI
) -> Dict[str, Any]:
    """Choose the next environment action for ``task_id`` using the LLM.

    Args:
        observation: Latest observation from the environment. ``current_email``
            is read for the single-email tasks and ``email_queue`` for
            ``support_session``; missing or null values are treated as empty.
        task_id: One of "email_classification", "response_drafting",
            "support_session".
        client: OpenAI-compatible client used for the structured LLM call.

    Returns:
        The action payload to POST to ``/step``.

    Raises:
        ValueError: if ``task_id`` is not one of the supported tasks.
        RuntimeError: if building the prompt, calling the LLM, or parsing its
            reply fails; the underlying exception is chained.
    """
    current = observation.get("current_email") or {}
    queue: List[Dict] = observation.get("email_queue") or []

    handlers = {
        "email_classification": lambda: _classify_action(client, current),
        "response_drafting": lambda: _respond_action(client, current),
        "support_session": lambda: _session_action(client, queue),
    }
    handler = handlers.get(task_id)
    if handler is None:
        # Fail loudly rather than returning None, which would be POSTed to /step.
        raise ValueError(
            f"Unknown task_id {task_id!r}; expected one of {sorted(handlers)}"
        )

    try:
        return handler()
    except Exception as exc:
        raise RuntimeError(f"LLM call failed for task '{task_id}': {exc}") from exc


def get_task_max_steps(http: httpx.Client, task_id: str) -> int:
    """Look up ``task_id``'s step budget from the environment's ``/tasks`` endpoint.

    Best-effort by design: any failure (HTTP error status, non-JSON body,
    unexpected payload shape, non-numeric ``max_steps``) or an unlisted task is
    reported on stderr and MAX_STEPS_FALLBACK is returned, so a flaky ``/tasks``
    endpoint never aborts an episode.

    Args:
        http: Client whose base URL points at the environment.
        task_id: Task whose ``max_steps`` should be returned.

    Returns:
        The task's ``max_steps`` as an int, or MAX_STEPS_FALLBACK.
    """
    try:
        resp = http.get("/tasks")
        resp.raise_for_status()
        for t in resp.json().get("tasks", []):
            if t.get("id") == task_id:
                return int(t["max_steps"])
    except Exception as exc:
        print(
            f"Could not read max_steps for {task_id} from /tasks ({exc}); "
            f"using {MAX_STEPS_FALLBACK}",
            file=sys.stderr,
        )
        return MAX_STEPS_FALLBACK

    print(
        f"Task {task_id} not listed by /tasks; using {MAX_STEPS_FALLBACK}",
        file=sys.stderr,
    )
    return MAX_STEPS_FALLBACK


def run_task(task_id: str, client: OpenAI) -> Dict[str, Any]:
    """Play one episode of ``task_id`` against the environment and grade it.

    Resets the environment, then repeatedly asks llm_action() for an action and
    POSTs it to ``/step`` until the episode is done or the step budget from
    get_task_max_steps() is used up, and finally reads ``/grader``. Every
    log_start() is paired with exactly one log_end(): errors during reset,
    stepping or grading are caught, reported on stderr, and logged with
    ``success=False``.

    Returns:
        Dict with ``task_id``, ``score`` (clamped to [0.001, 0.999]; 0.001 on
        failure), ``steps`` taken, and ``total_reward`` (sum of per-step rewards
        rounded to 4 decimals).
    """
    steps = 0
    rewards: List[float] = []

    # Context manager so the connection pool is closed even if the episode fails.
    with httpx.Client(base_url=ENV_BASE_URL, timeout=30.0) as http:
        max_steps = get_task_max_steps(http, task_id)

        log_start(task_id, BENCHMARK, MODEL_NAME)

        try:
            # Reset inside the try so a failure still produces a matching log_end.
            resp = http.post("/reset", params={"task_id": task_id})
            resp.raise_for_status()
            obs = resp.json()

            done = False
            while not done and steps < max_steps:
                action = llm_action(obs, task_id, client)

                resp = http.post("/step", json=action)
                resp.raise_for_status()
                result = resp.json()

                obs = result["observation"]
                reward = result["reward"]["value"]
                done = result["done"]
                # `info` may be absent or null; normalise falsy errors to None.
                error = (result.get("info") or {}).get("error") or None
                steps += 1
                rewards.append(reward)

                log_step(steps, action, reward, done, error)

            grader_resp = http.get("/grader")
            grader_resp.raise_for_status()
            raw_score = grader_resp.json().get("score") or 0.0
            # Clamp strictly inside (0, 1): logger uses :.4f, so bounds must be
            # representable without rounding to 0.0000 or 1.0000.
            score = min(max(raw_score, 0.001), 0.999)
            log_end(success=True, steps=steps, score=score, rewards=rewards)

        except Exception as exc:
            print(f"Episode error ({task_id}): {exc}", file=sys.stderr)
            score = 0.001
            log_end(success=False, steps=steps, score=score, rewards=rewards)

    return {
        "task_id": task_id,
        "score": score,
        "steps": steps,
        "total_reward": round(sum(rewards), 4),
    }


def main() -> None:
    """Evaluate MODEL_NAME on every task in TASKS and print a JSON summary.

    Exits with status 1 when no API key is configured or the OpenAI client
    cannot be created. A task whose run raises is recorded with score 0.0 and
    an ``error`` message rather than aborting the remaining tasks. Everything
    printed here goes to stderr.
    """
    if not API_KEY:
        print(
            "Error: set HF_TOKEN or OPENAI_API_KEY for the LLM API key.",
            file=sys.stderr,
        )
        sys.exit(1)

    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception as exc:
        print(f"Error: could not initialise OpenAI client: {exc}", file=sys.stderr)
        sys.exit(1)

    banner = (
        f"Agent    : {MODEL_NAME}",
        f"API base : {API_BASE_URL}",
        f"Env URL  : {ENV_BASE_URL}",
    )
    for banner_line in banner:
        print(banner_line, file=sys.stderr)

    results: List[Dict[str, Any]] = []
    for task_id in TASKS:
        print(f"Running {task_id} ...", file=sys.stderr)
        try:
            outcome = run_task(task_id, client)
            results.append(outcome)
            print(
                f"  score={outcome['score']}  steps={outcome['steps']}",
                file=sys.stderr,
            )
        except Exception as exc:
            print(f"  ERROR: {exc}", file=sys.stderr)
            failure = {"task_id": task_id, "score": 0.0, "steps": 0, "error": str(exc)}
            results.append(failure)

    scores = [entry.get("score", 0.0) for entry in results]
    summary = {
        "agent": MODEL_NAME,
        "results": results,
        "average_score": round(sum(scores) / len(scores), 3),
    }
    print(json.dumps(summary, indent=2), file=sys.stderr)


if __name__ == "__main__":
    # Script entry point: run all tasks only when executed directly, not on import.
    main()