Imaginephoenix committed on
Commit
33f0af6
·
verified ·
1 Parent(s): b22aacc

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +473 -0
inference.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference script for OpenEnv email triage with strict stdout event format."""
2
+
3
import argparse
import json
import os
import re
import sys
import time
from dataclasses import dataclass
from typing import Any

from openai import OpenAI

from environment import EmailTriageEnv
from models import EmailObservation, TriageAction
15
+
16
# --- Endpoint / credentials (all read from the process environment) ---------
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
# HF_TOKEN wins when both are set; API_KEY is only the fallback.
API_KEY = HF_TOKEN or os.getenv("API_KEY")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")

# --- Run parameters ----------------------------------------------------------
BENCHMARK = "openenv-email-triage"
MAX_STEPS = 30
TEMPERATURE = 0.2
MAX_TOKENS = 200
SUCCESS_SCORE_THRESHOLD = 0.5
# Scores are clamped into [EPSILON, 1 - EPSILON] before being logged.
LOG_SCORE_EPSILON = 1e-2
# NOTE: a non-numeric value in either variable raises ValueError at import time.
DEFAULT_RUNTIME_BUDGET_SECONDS = int(os.getenv("INFERENCE_RUNTIME_BUDGET_SECONDS", "1140"))
DEFAULT_REQUEST_TIMEOUT_SECONDS = float(os.getenv("INFERENCE_REQUEST_TIMEOUT_SECONDS", "12"))

# System message sent with every chat completion request.
SYSTEM_PROMPT = (
    "You are an email triage assistant. For each email, prioritize risk/time impact, "
    "categorize with one label (urgent|normal|spam|archive), route to the best team, "
    "and summarize the key evidence. Return one JSON object with keys label, summary, route_to."
)

# Defaults used for any field the model response does not supply.
FALLBACK_ACTION = {
    "label": "normal",
    "summary": "Unable to parse response",
    "route_to": "general",
}

# Maps the --task CLI choice to the environment task identifier.
TASK_MAP = {
    "1": "task_easy",
    "2": "task_medium",
    "3": "task_hard",
    "4": "task_production",
}
49
+
50
+
51
@dataclass(frozen=True)
class EpisodeResult:
    """Represents one task episode outcome emitted by inference.

    Instances are immutable (``frozen=True``).
    """

    # Environment task identifier, e.g. "task_easy".
    task_id: str
    # Zero-based index of the scenario that was run for this task.
    scenario_index: int
    # Final episode score as reported on the END / TASK_SCORE lines.
    score: float
    # Number of environment steps actually executed.
    steps: int
    # Whether the episode finished and met the success threshold.
    success: bool
60
+
61
+
62
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for task and optional model override.

    Returns:
        The parsed namespace with attributes ``task``, ``model``, ``split``,
        ``episodes_per_task``, ``runtime_budget_seconds``,
        ``request_timeout_seconds``, ``production_profile``,
        ``business_hours_mode`` and ``escalation_mode``.
    """
    parser = argparse.ArgumentParser(description="Run OpenEnv email triage inference.")
    parser.add_argument(
        "--task",
        default="all",
        choices=["1", "2", "3", "4", "all"],
        help="Task selection: 1, 2, 3, 4, or all.",
    )
    parser.add_argument(
        "--model",
        default=None,
        help="Optional model override. Falls back to MODEL_NAME environment variable.",
    )
    parser.add_argument(
        "--split",
        # The environment variable only supplies the default; argparse applies
        # `choices` to values given on the command line.
        default=os.getenv("OPENENV_EVAL_SPLIT", "public"),
        choices=["public", "private_eval"],
        help="Scenario split to evaluate.",
    )
    parser.add_argument(
        "--episodes-per-task",
        default=1,
        type=int,
        help="Number of deterministic scenarios to evaluate per task.",
    )
    parser.add_argument(
        "--runtime-budget-seconds",
        default=DEFAULT_RUNTIME_BUDGET_SECONDS,
        type=int,
        help="Global wall-clock budget for the full script run.",
    )
    parser.add_argument(
        "--request-timeout-seconds",
        default=DEFAULT_REQUEST_TIMEOUT_SECONDS,
        type=float,
        help="Timeout per LLM request.",
    )
    parser.add_argument(
        "--production-profile",
        default="standard",
        choices=["light", "standard", "heavy"],
        help="Runtime workload profile used for task 4 episodes.",
    )
    parser.add_argument(
        "--business-hours-mode",
        action="store_true",
        help="If set, task 4 timestamps focus on business-hours windows.",
    )
    parser.add_argument(
        "--escalation-mode",
        default="normal",
        choices=["low", "normal", "high"],
        help="Escalation strictness for task 4 follow-up generation.",
    )
    return parser.parse_args()
118
+
119
+
120
def validate_runtime_config(model_name: str | None) -> str:
    """Validate required runtime settings and return effective model name.

    Args:
        model_name: Optional override; when falsy, ``MODEL_NAME`` is used.

    Returns:
        The model name to send with requests.

    Raises:
        ValueError: If neither HF_TOKEN nor API_KEY produced a credential.
    """
    if not API_KEY:
        raise ValueError("Missing HF_TOKEN or API_KEY environment variable.")

    return model_name or MODEL_NAME
127
+
128
+
129
def log_start(task_name: str, benchmark_name: str, model_name: str) -> None:
    """Emit mandatory START line.

    Writes one flushed ``[START] task=... env=... model=...`` line to stdout.
    """
    print(
        f"[START] task={task_name} env={benchmark_name} model={model_name}",
        flush=True,
    )
135
+
136
+
137
def _format_open_score(value: float) -> str:
    """Format scores in strict-open range while preserving .2f log contract.

    The clamp is delegated to ``_strict_task_score`` so both helpers share a
    single definition of the allowed interval.
    """
    return f"{_strict_task_score(value):.2f}"
141
+
142
+
143
def _strict_task_score(raw_score: float) -> float:
    """Return task score in strict-open interval for evaluator compatibility.

    The value is clamped into ``[LOG_SCORE_EPSILON, 1 - LOG_SCORE_EPSILON]``.
    NaN maps to the lower bound: every comparison against NaN is false, so a
    plain ``max(min(...))`` clamp would silently turn it into the upper bound.
    """
    score = float(raw_score)
    if score != score:  # True only for NaN.
        return LOG_SCORE_EPSILON
    return max(LOG_SCORE_EPSILON, min(1.0 - LOG_SCORE_EPSILON, score))
146
+
147
+
148
def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
    """Emit mandatory STEP line.

    Args:
        step: Step number within the episode.
        action_str: Single-line rendering of the action taken.
        reward: Step reward; clamped and formatted to two decimals.
        done: Whether the step ended the episode (logged as true/false).
        error: Optional error text; logged as ``null`` when empty or None.
    """
    # Collapse line breaks so an error message cannot split the STEP record
    # across several stdout lines.
    error_value = " ".join(error.splitlines()) if error else ""
    error_value = error_value or "null"
    done_value = str(done).lower()
    print(
        f"[STEP] step={step} action={action_str} reward={_format_open_score(reward)} "
        f"done={done_value} error={error_value}",
        flush=True,
    )
157
+
158
+
159
def log_end(success: bool, steps: int, rewards: list[float], task_score: float) -> None:
    """Emit mandatory END line.

    Args:
        success: Episode success flag (logged as true/false).
        steps: Number of steps executed.
        rewards: Per-step rewards, logged as a comma-separated list.
        task_score: Episode score; the same formatted value is written to both
            the ``score`` and ``task_score`` fields.
    """
    rewards_str = ",".join(_format_open_score(reward) for reward in rewards)
    score_str = _format_open_score(_strict_task_score(task_score))
    print(
        f"[END] score={score_str} task_score={score_str} "
        f"success={str(success).lower()} steps={steps} rewards={rewards_str}",
        flush=True,
    )
169
+
170
+
171
def log_task_score(result: EpisodeResult) -> None:
    """Emit explicit per-task score line for downstream validators.

    Writes one flushed ``[TASK_SCORE]`` line with the task id, scenario index,
    formatted score, step count and success flag of ``result``.
    """
    print(
        "[TASK_SCORE] "
        f"task={result.task_id} scenario={result.scenario_index} "
        f"score={_format_open_score(result.score)} steps={result.steps} "
        f"success={str(result.success).lower()}",
        flush=True,
    )
180
+
181
+
182
def log_score_table(results: list[EpisodeResult], ordered_task_ids: list[str]) -> None:
    """Emit a parse-friendly score table across tasks.

    Prints nothing when ``results`` is empty. Otherwise prints a header, one
    tab-separated row plus one ``[TASK_AGG]`` line per task that has results
    (in ``ordered_task_ids`` order), then the mean of the per-task means.

    Args:
        results: Episode outcomes to aggregate.
        ordered_task_ids: Task ids in the order rows should be printed.
    """
    if not results:
        return

    print("=== SCORE TABLE ===", flush=True)
    print("Task\tScore\tEpisodes", flush=True)

    # Group scores once instead of rescanning `results` for every task.
    scores_by_task: dict[str, list[float]] = {}
    for item in results:
        scores_by_task.setdefault(item.task_id, []).append(item.score)

    task_means: list[float] = []
    for task_id in ordered_task_ids:
        task_scores = scores_by_task.get(task_id)
        if not task_scores:
            continue

        task_mean = _strict_task_score(sum(task_scores) / len(task_scores))
        task_means.append(task_mean)

        print(
            f"{task_id}\t{_format_open_score(task_mean)}\t{len(task_scores)}",
            flush=True,
        )
        print(
            f"[TASK_AGG] task={task_id} score={_format_open_score(task_mean)}",
            flush=True,
        )

    if task_means:
        # Unweighted mean of task means: each task counts equally regardless
        # of how many episodes it ran.
        mean_score = _strict_task_score(sum(task_means) / len(task_means))
        print(f"Mean\t{_format_open_score(mean_score)}\t-", flush=True)
        print(f"[MEAN_SCORE] score={_format_open_score(mean_score)}", flush=True)
212
+
213
+
214
def build_user_prompt(observation: EmailObservation, history: list[str]) -> str:
    """Build model prompt from current observation and recent history.

    Args:
        observation: Current email observation; its fields are rendered one
            per line.
        history: Summaries of earlier steps; only the last five are included,
            or the literal ``None`` when the list is empty.

    Returns:
        The user-message text for the chat completion request.
    """
    recent_history = "\n".join(history[-5:]) if history else "None"
    return (
        f"email_id: {observation.email_id}\n"
        f"subject: {observation.subject}\n"
        f"sender: {observation.sender}\n"
        f"timestamp: {observation.timestamp}\n"
        f"body: {observation.body}\n"
        f"thread_history: {observation.thread_history}\n"
        f"task_id: {observation.task_id}\n"
        f"step_number: {observation.step_number}\n"
        f"total_emails: {observation.total_emails}\n\n"
        f"recent_history:\n{recent_history}\n\n"
        "Return exactly one JSON object with label, summary, route_to."
    )
230
+
231
+
232
def strip_action_prefixes(response_text: str) -> str:
    """Remove common formatting wrappers before parsing model output.

    Strips, in order: surrounding whitespace, a leading Markdown code fence
    (optionally tagged ``json``), a trailing code fence, and a leading
    ``action:`` / ``next action:`` prefix (case-insensitive).
    """
    cleaned = response_text.strip()
    cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
    cleaned = re.sub(r"```$", "", cleaned).strip()
    cleaned = re.sub(r"^(next\s+action|action)\s*:\s*", "", cleaned, flags=re.IGNORECASE)
    return cleaned.strip()
239
+
240
+
241
def parse_text_action(cleaned_text: str) -> dict[str, str]:
    """Parse action from free-form text with deterministic regex fallback.

    Looks for ``label``, ``route_to`` (or ``route``) and ``summary`` written
    as ``key: value`` or ``key = value``, with or without double quotes.

    Returns:
        A dict containing only the fields that were found. ``label`` and
        ``route_to`` are lower-cased; ``summary`` keeps its original case.
    """
    result: dict[str, str] = {}

    label_match = re.search(
        r"(?:\"label\"|label)\s*[:=]\s*\"?(urgent|normal|spam|archive)\"?",
        cleaned_text,
        flags=re.IGNORECASE,
    )
    if label_match:
        result["label"] = label_match.group(1).lower()

    route_match = re.search(
        r"(?:\"route_to\"|route_to|route)\s*[:=]\s*\"?([a-zA-Z0-9_\-/ ]+)\"?",
        cleaned_text,
        flags=re.IGNORECASE,
    )
    if route_match:
        result["route_to"] = route_match.group(1).strip().lower()

    summary_match = re.search(
        r"(?:\"summary\"|summary)\s*[:=]\s*\"?([^\"\n]+)\"?",
        cleaned_text,
        flags=re.IGNORECASE,
    )
    if summary_match:
        result["summary"] = summary_match.group(1).strip()

    return result
270
+
271
+
272
def parse_action_response(response_text: str) -> TriageAction:
    """Parse model response into a valid TriageAction with fallback behavior.

    Parsing order:
      1. Strip formatting wrappers, then decode the first JSON object that
         starts at the first ``{``. Text after that object is ignored.
      2. If no non-empty JSON object is found, fall back to regex extraction.
      3. Fields still missing are filled from ``FALLBACK_ACTION``.
      4. If the merged payload fails validation, return the pure fallback.

    Args:
        response_text: Raw model output; may be empty.

    Returns:
        A validated ``TriageAction``; this function does not raise on bad
        model output.
    """
    cleaned_text = strip_action_prefixes(response_text)
    parsed_payload: dict[str, Any] = {}

    json_start = cleaned_text.find("{")
    if json_start != -1:
        try:
            # raw_decode stops at the end of the first complete JSON value, so
            # trailing prose or a second object no longer defeats the parse.
            loaded, _ = json.JSONDecoder().raw_decode(cleaned_text[json_start:])
        except json.JSONDecodeError:
            loaded = None
        if isinstance(loaded, dict):
            parsed_payload = loaded

    if not parsed_payload:
        parsed_payload = parse_text_action(cleaned_text)

    merged = dict(FALLBACK_ACTION)
    merged.update(parsed_payload)

    try:
        return TriageAction.model_validate(merged)
    except Exception:
        # The validation error type comes from the models module, which is not
        # visible here, hence the broad catch.
        return TriageAction.model_validate(FALLBACK_ACTION)
298
+
299
+
300
def action_to_log_string(action: TriageAction) -> str:
    """Return single-line action string for required STEP logging.

    The action is dumped to compact JSON (no spaces after separators) with
    non-ASCII characters escaped.
    """
    return json.dumps(action.model_dump(), separators=(",", ":"), ensure_ascii=True)
303
+
304
+
305
def run_episode(
    client: OpenAI,
    model_name: str,
    task_id: str,
    scenario_index: int,
    eval_split: str,
    deadline: float,
    request_timeout_seconds: float,
    runtime_options: dict[str, Any] | None = None,
) -> EpisodeResult:
    """Run one episode and emit strict START/STEP/END lines.

    stdout carries only the START/STEP/END records. Diagnostics for failures
    that are absorbed (LLM request errors, environment errors, close errors)
    go to stderr so the stdout contract is unaffected.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model_name: Model identifier sent with each request.
        task_id: Environment task identifier.
        scenario_index: Scenario index passed to the environment.
        eval_split: Scenario split passed to the environment.
        deadline: ``time.monotonic()`` value after which no new step starts.
        request_timeout_seconds: Upper bound for a single LLM request.
        runtime_options: Optional extra options forwarded to the environment.

    Returns:
        The episode outcome. The score is the mean of the clamped step rewards;
        if no step ran, it is the minimum loggable score. ``success`` is False
        whenever the episode aborted with an exception.
    """
    rewards: list[float] = []
    steps_taken = 0
    completed = False
    env: EmailTriageEnv | None = None

    log_start(task_name=task_id, benchmark_name=BENCHMARK, model_name=model_name)

    try:
        env = EmailTriageEnv(
            task_id=task_id,
            scenario_index=scenario_index,
            split=eval_split,
            runtime_options=runtime_options,
        )
        observation = env.reset().observation
        history: list[str] = []

        for step in range(1, MAX_STEPS + 1):
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break

            prompt = build_user_prompt(observation, history)

            # Never wait past the global deadline, but give each request at
            # least one second.
            timeout_seconds = max(1.0, min(float(request_timeout_seconds), remaining))
            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=TEMPERATURE,
                    max_tokens=MAX_TOKENS,
                    stream=False,
                    timeout=timeout_seconds,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as request_error:
                # Best effort: an empty response makes the parser return the
                # fallback action, so the episode continues.
                print(
                    f"[WARN] task={task_id} step={step} LLM request failed: "
                    f"{type(request_error).__name__}: {request_error}",
                    file=sys.stderr,
                    flush=True,
                )
                response_text = ""

            action = parse_action_response(response_text)
            step_result = env.step(action)

            reward = _strict_task_score(float(step_result.reward))
            done = bool(step_result.done)
            # Tolerate a missing or None info mapping instead of aborting.
            info = getattr(step_result, "info", None) or {}
            error_raw = info.get("validation_error")
            error = str(error_raw) if isinstance(error_raw, str) else None

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action_str=action_to_log_string(action),
                reward=reward,
                done=done,
                error=error,
            )

            history.append(
                f"step={step} action={action.label}/{action.route_to} reward={_format_open_score(reward)}"
            )
            observation = step_result.observation

            if done:
                break

        completed = True
    except Exception as episode_error:
        # Keep the rewards gathered so far; the episode counts as failed.
        print(
            f"[WARN] task={task_id} scenario={scenario_index} episode aborted: "
            f"{type(episode_error).__name__}: {episode_error}",
            file=sys.stderr,
            flush=True,
        )
    finally:
        if env is not None:
            close_method = getattr(env, "close", None)
            if callable(close_method):
                try:
                    close_method()
                except Exception as close_error:
                    print(
                        f"[WARN] task={task_id} env.close() failed: "
                        f"{type(close_error).__name__}: {close_error}",
                        file=sys.stderr,
                        flush=True,
                    )

    # The END line always carries at least one reward.
    if not rewards:
        rewards.append(LOG_SCORE_EPSILON)

    final_task_score = _strict_task_score(sum(rewards) / len(rewards))
    success = completed and final_task_score >= SUCCESS_SCORE_THRESHOLD

    log_end(
        success=success,
        steps=steps_taken,
        rewards=rewards,
        task_score=final_task_score,
    )

    return EpisodeResult(
        task_id=task_id,
        scenario_index=scenario_index,
        score=_strict_task_score(final_task_score),
        steps=steps_taken,
        success=success,
    )
423
+
424
+
425
def main() -> None:
    """Entrypoint for running one or many tasks with strict stdout logs.

    Parses CLI arguments, validates credentials, runs the selected tasks (one
    or more scenarios each) against a single shared wall-clock deadline, and
    prints the score table.

    Raises:
        SystemExit: With status 1 when no API credential is configured.
    """
    args = parse_args()
    # Budget and timeout are floored at one second.
    deadline = time.monotonic() + max(args.runtime_budget_seconds, 1)
    request_timeout_seconds = max(float(args.request_timeout_seconds), 1.0)

    try:
        effective_model = validate_runtime_config(args.model)
    except ValueError as error:
        print(str(error), flush=True)
        raise SystemExit(1) from error

    # NOTE(review): LOCAL_IMAGE_NAME is read but otherwise unused; presumably
    # kept so the variable is referenced — confirm with the harness owner.
    _ = LOCAL_IMAGE_NAME

    client = OpenAI(
        base_url=API_BASE_URL,
        api_key=API_KEY,
    )

    # "all" is not a TASK_MAP key, so it selects every task in map order.
    task_ids = [TASK_MAP[args.task]] if args.task in TASK_MAP else list(TASK_MAP.values())
    episode_results: list[EpisodeResult] = []

    for task_id in task_ids:
        runtime_options = None
        if task_id == "task_production":
            runtime_options = {
                "production_profile": args.production_profile,
                "business_hours_mode": args.business_hours_mode,
                "escalation_mode": args.escalation_mode,
            }
        # At least one episode per task, even if the flag is zero or negative.
        for scenario_index in range(max(args.episodes_per_task, 1)):
            result = run_episode(
                client=client,
                model_name=effective_model,
                task_id=task_id,
                scenario_index=scenario_index,
                eval_split=args.split,
                deadline=deadline,
                request_timeout_seconds=request_timeout_seconds,
                runtime_options=runtime_options,
            )
            episode_results.append(result)
            log_task_score(result)

    log_score_table(episode_results, task_ids)
471
+
472
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()