Imaginephoenix committed on
Commit
97c9151
·
verified ·
1 Parent(s): 1f071d0

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -377
inference.py DELETED
@@ -1,377 +0,0 @@
1
- """Inference script for OpenEnv email triage with strict stdout event format."""
2
-
3
- import argparse
4
- import json
5
- import os
6
- import re
7
- import time
8
- from typing import Any
9
-
10
- from openai import OpenAI
11
-
12
- from environment import EmailTriageEnv
13
- from models import EmailObservation, TriageAction
14
-
15
# Endpoint, model, and credential settings; each one is read from the
# environment, with a built-in default where one is given.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
# HF_TOKEN takes precedence; API_KEY is consulted only when HF_TOKEN is unset or empty.
API_KEY = HF_TOKEN if HF_TOKEN else os.environ.get("API_KEY")
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")

# Fixed run parameters used by the episode loop and the logging helpers.
BENCHMARK = "openenv-email-triage"
MAX_STEPS = 30
TEMPERATURE = 0.2
MAX_TOKENS = 200
SUCCESS_SCORE_THRESHOLD = 0.5
# Defaults for the CLI budget/timeout flags; overridable through the environment.
DEFAULT_RUNTIME_BUDGET_SECONDS = int(
    os.environ.get("INFERENCE_RUNTIME_BUDGET_SECONDS", "1140")
)
DEFAULT_REQUEST_TIMEOUT_SECONDS = float(
    os.environ.get("INFERENCE_REQUEST_TIMEOUT_SECONDS", "12")
)

# System message sent with every chat completion request.
SYSTEM_PROMPT = " ".join(
    [
        "You are an email triage assistant. For each email, prioritize risk/time impact,",
        "categorize with one label (urgent|normal|spam|archive), route to the best team,",
        "and summarize the key evidence. Return one JSON object with keys label, summary, route_to.",
    ]
)

# Field values used for anything the model response does not supply.
FALLBACK_ACTION = dict(
    label="normal",
    summary="Unable to parse response",
    route_to="general",
)

# Maps the --task CLI choice ("1".."4") to the environment task identifier.
TASK_MAP = {
    str(position): f"task_{difficulty}"
    for position, difficulty in enumerate(
        ("easy", "medium", "hard", "production"), start=1
    )
}
47
-
48
-
49
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace with task selection, model override, evaluation
        split, episode count, time budgets, and the task-4 runtime options.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``); options are
    # registered in this order.
    option_table = [
        (
            "--task",
            {
                "default": "all",
                "choices": ["1", "2", "3", "4", "all"],
                "help": "Task selection: 1, 2, 3, 4, or all.",
            },
        ),
        (
            "--model",
            {
                "default": None,
                "help": "Optional model override. Falls back to MODEL_NAME environment variable.",
            },
        ),
        (
            "--split",
            {
                "default": os.getenv("OPENENV_EVAL_SPLIT", "public"),
                "choices": ["public", "private_eval"],
                "help": "Scenario split to evaluate.",
            },
        ),
        (
            "--episodes-per-task",
            {
                "default": 1,
                "type": int,
                "help": "Number of deterministic scenarios to evaluate per task.",
            },
        ),
        (
            "--runtime-budget-seconds",
            {
                "default": DEFAULT_RUNTIME_BUDGET_SECONDS,
                "type": int,
                "help": "Global wall-clock budget for the full script run.",
            },
        ),
        (
            "--request-timeout-seconds",
            {
                "default": DEFAULT_REQUEST_TIMEOUT_SECONDS,
                "type": float,
                "help": "Timeout per LLM request.",
            },
        ),
        (
            "--production-profile",
            {
                "default": "standard",
                "choices": ["light", "standard", "heavy"],
                "help": "Runtime workload profile used for task 4 episodes.",
            },
        ),
        (
            "--business-hours-mode",
            {
                "action": "store_true",
                "help": "If set, task 4 timestamps focus on business-hours windows.",
            },
        ),
        (
            "--escalation-mode",
            {
                "default": "normal",
                "choices": ["low", "normal", "high"],
                "help": "Escalation strictness for task 4 follow-up generation.",
            },
        ),
    ]

    parser = argparse.ArgumentParser(description="Run OpenEnv email triage inference.")
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
105
-
106
-
107
def validate_runtime_config(model_name: str | None) -> str:
    """Check that credentials are configured and pick the model to use.

    Args:
        model_name: Optional override; an empty or missing value selects
            the module-level ``MODEL_NAME``.

    Returns:
        The effective model name.

    Raises:
        ValueError: If neither HF_TOKEN nor API_KEY produced a usable key.
    """
    if API_KEY:
        return model_name if model_name else MODEL_NAME
    raise ValueError("Missing HF_TOKEN or API_KEY environment variable.")
114
-
115
-
116
def log_start(task_name: str, benchmark_name: str, model_name: str) -> None:
    """Print the mandatory ``[START]`` record for an episode to stdout.

    Args:
        task_name: Value written after ``task=``.
        benchmark_name: Value written after ``env=``.
        model_name: Value written after ``model=``.
    """
    parts = [
        "[START]",
        f"task={task_name}",
        f"env={benchmark_name}",
        f"model={model_name}",
    ]
    print(" ".join(parts), flush=True)
122
-
123
-
124
- def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
125
- """Emit mandatory STEP line."""
126
- error_value = error if error else "null"
127
- done_value = str(done).lower()
128
- print(
129
- f"[STEP] step={step} action={action_str} reward={reward:.2f} "
130
- f"done={done_value} error={error_value}",
131
- flush=True,
132
- )
133
-
134
-
135
def log_end(success: bool, steps: int, rewards: list[float]) -> None:
    """Print the mandatory ``[END]`` record for an episode to stdout.

    Args:
        success: Episode outcome, printed in lowercase.
        steps: Number of steps that were taken.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted_rewards = [format(value, ".2f") for value in rewards]
    outcome = str(success).lower()
    record = f"[END] success={outcome} steps={steps} rewards=" + ",".join(formatted_rewards)
    print(record, flush=True)
142
-
143
-
144
def build_user_prompt(observation: EmailObservation, history: list[str]) -> str:
    """Render the user message for one step.

    Args:
        observation: Current observation; its fields are written one per line.
        history: Log of earlier steps; only the last five entries are included,
            and an empty history is rendered as ``None``.

    Returns:
        The prompt text, ending with the JSON output instruction.
    """
    if history:
        recent_history = "\n".join(history[-5:])
    else:
        recent_history = "None"

    # Observation attributes, in the order they appear in the prompt.
    field_names = (
        "email_id",
        "subject",
        "sender",
        "timestamp",
        "body",
        "thread_history",
        "task_id",
        "step_number",
        "total_emails",
    )
    field_lines = [f"{name}: {getattr(observation, name)}" for name in field_names]

    return (
        "\n".join(field_lines)
        + "\n\n"
        + f"recent_history:\n{recent_history}\n\n"
        + "Return exactly one JSON object with label, summary, route_to."
    )
160
-
161
-
162
def strip_action_prefixes(response_text: str) -> str:
    """Strip code fences and an ``action:`` style lead-in from model output.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The text with surrounding whitespace, a leading code fence (optionally
        tagged ``json``), a trailing code fence, and a leading ``action:`` or
        ``next action:`` label removed.
    """
    # (pattern, flags) pairs, applied in order; whitespace is trimmed after each.
    wrapper_patterns = (
        (r"^```(?:json)?", re.IGNORECASE),
        (r"```$", 0),
        (r"^(next\s+action|action)\s*:\s*", re.IGNORECASE),
    )
    text = response_text.strip()
    for pattern, flags in wrapper_patterns:
        text = re.sub(pattern, "", text, flags=flags).strip()
    return text
169
-
170
-
171
def parse_text_action(cleaned_text: str) -> dict[str, str]:
    """Extract action fields from free-form text using regular expressions.

    Args:
        cleaned_text: Model output that has already had wrappers removed.

    Returns:
        A dict holding whichever of ``label``, ``route_to`` and ``summary``
        were found; fields that did not match are left out.
    """
    # (output key, pattern, normalizer) triples, evaluated in this order.
    extractors = (
        (
            "label",
            r"(?:\"label\"|label)\s*[:=]\s*\"?(urgent|normal|spam|archive)\"?",
            str.lower,
        ),
        (
            "route_to",
            r"(?:\"route_to\"|route_to|route)\s*[:=]\s*\"?([a-zA-Z0-9_\-/ ]+)\"?",
            lambda value: value.strip().lower(),
        ),
        (
            "summary",
            r"(?:\"summary\"|summary)\s*[:=]\s*\"?([^\"\n]+)\"?",
            str.strip,
        ),
    )

    fields: dict[str, str] = {}
    for key, pattern, normalize in extractors:
        found = re.search(pattern, cleaned_text, flags=re.IGNORECASE)
        if found is not None:
            fields[key] = normalize(found.group(1))
    return fields
200
-
201
-
202
def parse_action_response(response_text: str) -> TriageAction:
    """Turn raw model output into a validated ``TriageAction``.

    The outermost ``{...}`` span is tried as JSON first. If that does not give
    a non-empty dict, the text is parsed with ``parse_text_action``. Fields
    that are still missing come from ``FALLBACK_ACTION``, and a payload that
    fails validation is replaced by ``FALLBACK_ACTION`` entirely.

    Args:
        response_text: Raw text returned by the model (may be empty).

    Returns:
        A validated action.
    """
    text = strip_action_prefixes(response_text)
    payload: dict[str, Any] = {}

    open_at, close_at = text.find("{"), text.rfind("}")
    if -1 not in (open_at, close_at) and open_at < close_at:
        try:
            decoded = json.loads(text[open_at : close_at + 1])
        except json.JSONDecodeError:
            decoded = None
        if isinstance(decoded, dict):
            payload = decoded

    # No JSON object, or an empty one: fall back to regex extraction.
    if not payload:
        payload = parse_text_action(text)

    merged = {**FALLBACK_ACTION, **payload}
    try:
        return TriageAction.model_validate(merged)
    except Exception:
        return TriageAction.model_validate(FALLBACK_ACTION)
228
-
229
-
230
def action_to_log_string(action: TriageAction) -> str:
    """Serialize an action as compact, ASCII-only JSON for the STEP record."""
    payload = action.model_dump()
    return json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
233
-
234
-
235
def _request_action_text(
    client: OpenAI,
    model_name: str,
    prompt: str,
    deadline: float,
    request_timeout_seconds: float,
) -> str:
    """Ask the model for one action and return its raw text, or "" on failure."""
    import sys

    try:
        # Never wait past the global deadline, but always allow at least 1s.
        remaining = max(1.0, deadline - time.monotonic())
        timeout_seconds = max(
            1.0,
            min(float(request_timeout_seconds), float(remaining)),
        )
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
            timeout=timeout_seconds,
        )
        return completion.choices[0].message.content or ""
    except Exception as exc:
        # Best effort: the caller falls back to the default action. The failure
        # is reported on stderr so stdout keeps only START/STEP/END records.
        print(
            f"LLM request failed: {type(exc).__name__}: {exc}",
            file=sys.stderr,
            flush=True,
        )
        return ""


def run_episode(
    client: OpenAI,
    model_name: str,
    task_id: str,
    scenario_index: int,
    eval_split: str,
    deadline: float,
    request_timeout_seconds: float,
    runtime_options: dict[str, Any] | None = None,
) -> None:
    """Run one episode and emit strict START/STEP/END lines on stdout.

    Args:
        client: Chat-completions client used to query the model.
        model_name: Model identifier sent with each request.
        task_id: Task identifier passed to the environment.
        scenario_index: Scenario index passed to the environment.
        eval_split: Scenario split passed to the environment.
        deadline: ``time.monotonic()`` value after which no new step starts.
        request_timeout_seconds: Upper bound for a single LLM request.
        runtime_options: Optional extra options passed to the environment.

    The START and END records are always printed, even when the environment
    or the model fails. Failures mark the episode unsuccessful and are
    described on stderr; they are never raised to the caller.
    """
    import sys

    rewards: list[float] = []
    steps_taken = 0
    success = False
    env: EmailTriageEnv | None = None

    log_start(task_name=task_id, benchmark_name=BENCHMARK, model_name=model_name)

    try:
        env = EmailTriageEnv(
            task_id=task_id,
            scenario_index=scenario_index,
            split=eval_split,
            runtime_options=runtime_options,
        )
        reset_result = env.reset()
        observation = reset_result.observation
        history: list[str] = []

        for step in range(1, MAX_STEPS + 1):
            if time.monotonic() >= deadline:
                break

            prompt = build_user_prompt(observation, history)
            response_text = _request_action_text(
                client=client,
                model_name=model_name,
                prompt=prompt,
                deadline=deadline,
                request_timeout_seconds=request_timeout_seconds,
            )

            # An empty or unparseable response becomes the fallback action.
            action = parse_action_response(response_text)
            step_result = env.step(action)

            reward = float(step_result.reward)
            done = bool(step_result.done)
            error_raw = step_result.info.get("validation_error")
            error = str(error_raw) if isinstance(error_raw, str) else None

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action_str=action_to_log_string(action),
                reward=reward,
                done=done,
                error=error,
            )

            history.append(
                f"step={step} action={action.label}/{action.route_to} reward={reward:.2f}"
            )
            observation = step_result.observation

            if done:
                break

        # Success is judged on the mean reward over the steps actually taken.
        avg_reward = sum(rewards) / max(len(rewards), 1)
        success = avg_reward >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        success = False
        print(
            f"Episode {task_id}[{scenario_index}] aborted: {type(exc).__name__}: {exc}",
            file=sys.stderr,
            flush=True,
        )
    finally:
        if env is not None:
            close_method = getattr(env, "close", None)
            if callable(close_method):
                try:
                    close_method()
                except Exception as exc:
                    # Closing is best effort; report the failure and carry on.
                    print(
                        f"Environment close failed: {type(exc).__name__}: {exc}",
                        file=sys.stderr,
                        flush=True,
                    )

    log_end(success=success, steps=steps_taken, rewards=rewards)
333
-
334
-
335
def main() -> None:
    """Parse the CLI, validate configuration, and run every selected episode.

    Exits with status 1 (after printing the message) when no API key is
    configured. All episodes share one wall-clock deadline.
    """
    args = parse_args()
    budget_seconds = max(args.runtime_budget_seconds, 1)
    deadline = time.monotonic() + budget_seconds
    per_request_timeout = max(float(args.request_timeout_seconds), 1.0)

    try:
        effective_model = validate_runtime_config(args.model)
    except ValueError as error:
        print(str(error), flush=True)
        raise SystemExit(1) from error

    # The value is read and deliberately discarded.
    _ = LOCAL_IMAGE_NAME

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # A numeric --task picks one task; any other choice ("all") runs them all.
    if args.task in TASK_MAP:
        task_ids = [TASK_MAP[args.task]]
    else:
        task_ids = list(TASK_MAP.values())
    episode_count = max(args.episodes_per_task, 1)

    for task_id in task_ids:
        # Only the production task receives the extra runtime options.
        runtime_options = (
            {
                "production_profile": args.production_profile,
                "business_hours_mode": args.business_hours_mode,
                "escalation_mode": args.escalation_mode,
            }
            if task_id == "task_production"
            else None
        )
        for scenario_index in range(episode_count):
            run_episode(
                client=client,
                model_name=effective_model,
                task_id=task_id,
                scenario_index=scenario_index,
                eval_split=args.split,
                deadline=deadline,
                request_timeout_seconds=per_request_timeout,
                runtime_options=runtime_options,
            )


if __name__ == "__main__":
    main()