Imaginephoenix committed on
Commit
b22aacc
·
verified ·
1 Parent(s): e721fd9

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -403
inference.py DELETED
@@ -1,403 +0,0 @@
1
- """Inference script for OpenEnv email triage with strict stdout event format."""
2
-
3
- import argparse
4
- import json
5
- import os
6
- import re
7
- import time
8
- from typing import Any
9
-
10
- from openai import OpenAI
11
-
12
- from environment import EmailTriageEnv
13
- from models import EmailObservation, TriageAction
14
-
15
# --- Endpoint and credentials (each overridable through the environment) ---
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
# HF_TOKEN takes precedence when set and non-empty; otherwise API_KEY is used.
API_KEY = HF_TOKEN if HF_TOKEN else os.environ.get("API_KEY")
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")

# --- Run parameters ---
BENCHMARK = "openenv-email-triage"
MAX_STEPS = 30  # upper bound on environment steps per episode
TEMPERATURE = 0.2
MAX_TOKENS = 200
SUCCESS_SCORE_THRESHOLD = 0.5  # an episode succeeds at or above this mean score
LOG_SCORE_EPSILON = 1e-2  # scores are clamped into [epsilon, 1 - epsilon]
# NOTE: a non-numeric value in either variable raises ValueError at import time.
DEFAULT_RUNTIME_BUDGET_SECONDS = int(
    os.environ.get("INFERENCE_RUNTIME_BUDGET_SECONDS", "1140")
)
DEFAULT_REQUEST_TIMEOUT_SECONDS = float(
    os.environ.get("INFERENCE_REQUEST_TIMEOUT_SECONDS", "12")
)

# System message sent with every chat-completion request.
SYSTEM_PROMPT = (
    "You are an email triage assistant. For each email, prioritize risk/time impact, "
    "categorize with one label (urgent|normal|spam|archive), route to the best team, "
    "and summarize the key evidence. Return one JSON object with keys label, summary, route_to."
)

# Field defaults used whenever the model output cannot be parsed or validated.
FALLBACK_ACTION = dict(
    label="normal",
    summary="Unable to parse response",
    route_to="general",
)

# CLI task selector ("1".."4") -> environment task identifier.
TASK_MAP = {
    str(number): task_name
    for number, task_name in enumerate(
        ("task_easy", "task_medium", "task_hard", "task_production"),
        start=1,
    )
}
48
-
49
-
50
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Options are declared in a table and registered in order, so ``--help``
    lists them in the same sequence as the table.

    Returns:
        The parsed namespace with ``task``, ``model``, ``split``,
        ``episodes_per_task``, ``runtime_budget_seconds``,
        ``request_timeout_seconds``, ``production_profile``,
        ``business_hours_mode`` and ``escalation_mode``.
    """
    option_table = (
        (
            "--task",
            {
                "default": "all",
                "choices": ["1", "2", "3", "4", "all"],
                "help": "Task selection: 1, 2, 3, 4, or all.",
            },
        ),
        (
            "--model",
            {
                "default": None,
                "help": "Optional model override. Falls back to MODEL_NAME environment variable.",
            },
        ),
        (
            "--split",
            {
                "default": os.getenv("OPENENV_EVAL_SPLIT", "public"),
                "choices": ["public", "private_eval"],
                "help": "Scenario split to evaluate.",
            },
        ),
        (
            "--episodes-per-task",
            {
                "default": 1,
                "type": int,
                "help": "Number of deterministic scenarios to evaluate per task.",
            },
        ),
        (
            "--runtime-budget-seconds",
            {
                "default": DEFAULT_RUNTIME_BUDGET_SECONDS,
                "type": int,
                "help": "Global wall-clock budget for the full script run.",
            },
        ),
        (
            "--request-timeout-seconds",
            {
                "default": DEFAULT_REQUEST_TIMEOUT_SECONDS,
                "type": float,
                "help": "Timeout per LLM request.",
            },
        ),
        (
            "--production-profile",
            {
                "default": "standard",
                "choices": ["light", "standard", "heavy"],
                "help": "Runtime workload profile used for task 4 episodes.",
            },
        ),
        (
            "--business-hours-mode",
            {
                "action": "store_true",
                "help": "If set, task 4 timestamps focus on business-hours windows.",
            },
        ),
        (
            "--escalation-mode",
            {
                "default": "normal",
                "choices": ["low", "normal", "high"],
                "help": "Escalation strictness for task 4 follow-up generation.",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Run OpenEnv email triage inference.")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
106
-
107
-
108
def validate_runtime_config(model_name: str | None) -> str:
    """Check that credentials exist and choose the model to use.

    Args:
        model_name: Optional override, typically from ``--model``.

    Returns:
        ``model_name`` when it is truthy, otherwise the module-level
        ``MODEL_NAME``.

    Raises:
        ValueError: If the module-level ``API_KEY`` is empty or unset.
    """
    if API_KEY:
        return model_name if model_name else MODEL_NAME
    raise ValueError("Missing HF_TOKEN or API_KEY environment variable.")
115
-
116
-
117
def log_start(task_name: str, benchmark_name: str, model_name: str) -> None:
    """Print the ``[START]`` event line for an episode and flush stdout.

    Args:
        task_name: Task identifier written after ``task=``.
        benchmark_name: Benchmark name written after ``env=``.
        model_name: Model identifier written after ``model=``.
    """
    start_line = f"[START] task={task_name} env={benchmark_name} model={model_name}"
    print(start_line, flush=True)
123
-
124
-
125
def _format_open_score(value: float) -> str:
    """Format a score for the logs with two decimals after clamping.

    Delegates the clamping to ``_strict_task_score`` so the two functions
    cannot drift apart.

    Args:
        value: Raw score or reward.

    Returns:
        The clamped score rendered with ``.2f``.
    """
    return f"{_strict_task_score(value):.2f}"


def _strict_task_score(raw_score: float) -> float:
    """Clamp a score into ``[LOG_SCORE_EPSILON, 1 - LOG_SCORE_EPSILON]``.

    Args:
        raw_score: Raw score or reward; converted with ``float()``.

    Returns:
        The clamped score. NaN maps to ``LOG_SCORE_EPSILON``.
    """
    score = float(raw_score)
    # NaN compares false against everything, so min()/max() would let it
    # through as the upper bound. Treat an undefined score as the lowest one.
    if score != score:
        return LOG_SCORE_EPSILON
    return max(LOG_SCORE_EPSILON, min(1.0 - LOG_SCORE_EPSILON, score))
134
-
135
-
136
def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
    """Print the ``[STEP]`` event line for one environment step.

    The event must stay on a single line, so line breaks inside ``error``
    are collapsed into single spaces before printing.

    Args:
        step: 1-based step number.
        action_str: Serialized action, expected to be single-line already.
        reward: Step reward; clamped and formatted by ``_format_open_score``.
        done: Whether the episode finished on this step.
        error: Optional error text. ``None``, empty or whitespace-only text
            is written as ``null``.
    """
    error_value = "null"
    if error:
        # Multi-line messages would split the event across several lines.
        flattened = " ".join(
            piece.strip() for piece in error.splitlines() if piece.strip()
        )
        if flattened:
            error_value = flattened

    done_value = str(done).lower()
    print(
        f"[STEP] step={step} action={action_str} reward={_format_open_score(reward)} "
        f"done={done_value} error={error_value}",
        flush=True,
    )
145
-
146
-
147
def log_end(success: bool, steps: int, rewards: list[float], task_score: float) -> None:
    """Print the ``[END]`` event line that closes an episode.

    Args:
        success: Episode outcome, written lower-cased after ``success=``.
        steps: Number of steps taken.
        rewards: Per-step rewards, written as a comma-separated list of
            clamped two-decimal values.
        task_score: Aggregate score; clamped before formatting.
    """
    score_text = _format_open_score(_strict_task_score(task_score))
    reward_texts = [_format_open_score(item) for item in rewards]
    success_text = str(success).lower()
    rewards_text = ",".join(reward_texts)
    print(
        f"[END] task_score={score_text} "
        f"success={success_text} steps={steps} rewards={rewards_text}",
        flush=True,
    )
156
-
157
-
158
def build_user_prompt(observation: EmailObservation, history: list[str]) -> str:
    """Render the user message for one triage step.

    Args:
        observation: Current email observation. The attributes read are
            ``email_id``, ``subject``, ``sender``, ``timestamp``, ``body``,
            ``thread_history``, ``task_id``, ``step_number`` and
            ``total_emails``.
        history: Summaries of earlier steps; only the last five are included.

    Returns:
        One ``name: value`` line per observation field, then the recent
        history (or ``None`` when empty), then the output-format reminder.
    """
    history_text = "\n".join(history[-5:]) if history else "None"

    field_names = (
        "email_id",
        "subject",
        "sender",
        "timestamp",
        "body",
        "thread_history",
        "task_id",
        "step_number",
        "total_emails",
    )
    # Every field line ends with a newline; the f-string below adds the blank
    # line that separates the fields from the history section.
    field_block = "".join(
        f"{name}: {getattr(observation, name)}\n" for name in field_names
    )

    return (
        f"{field_block}\n"
        f"recent_history:\n{history_text}\n\n"
        "Return exactly one JSON object with label, summary, route_to."
    )
174
-
175
-
176
def strip_action_prefixes(response_text: str) -> str:
    """Remove common wrappers around a model reply before parsing it.

    Strips surrounding whitespace, a leading ``Action:`` / ``Next action:``
    label, and a Markdown code fence (optionally tagged ``json``). The label
    is removed both before and after the fence, so ``Action: <fenced>`` and
    a fence containing ``action: ...`` are both handled.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The unwrapped text with no leading or trailing whitespace.
    """
    label_prefix = r"^(next\s+action|action)\s*:\s*"

    cleaned = response_text.strip()
    # Label in front of the fence: without this pass the opening fence would
    # not sit at the start of the string and would survive.
    cleaned = re.sub(label_prefix, "", cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
    cleaned = re.sub(r"```$", "", cleaned).strip()
    # Label inside the fence.
    cleaned = re.sub(label_prefix, "", cleaned, flags=re.IGNORECASE)
    return cleaned.strip()
183
-
184
-
185
def parse_text_action(cleaned_text: str) -> dict[str, str]:
    """Extract triage fields from free-form text with regular expressions.

    This is the fallback used when the reply is not valid JSON. Each field
    is searched for independently as ``key: value`` or ``key = value``, with
    the key and value optionally double-quoted and matched case-insensitively.

    A value that opens with a double quote runs to the closing quote. An
    unquoted value additionally stops at the next field key, so a one-line
    reply such as ``label: urgent route_to: security summary: ...`` gives
    ``route_to == "security"`` rather than ``"security summary"``.

    Args:
        cleaned_text: Model reply after wrapper removal.

    Returns:
        A dict holding any of ``label`` (lower-cased), ``route_to``
        (stripped and lower-cased) and ``summary`` (stripped). Fields that
        are not found are omitted.
    """
    result: dict[str, str] = {}

    label_match = re.search(
        r"(?:\"label\"|label)\s*[:=]\s*\"?(urgent|normal|spam|archive)\"?",
        cleaned_text,
        flags=re.IGNORECASE,
    )
    if label_match:
        result["label"] = label_match.group(1).lower()

    route_match = re.search(
        r"(?:\"route_to\"|route_to|route)\s*[:=]\s*"
        # Quoted value: everything allowed up to the closing quote.
        r"(?:\"([a-zA-Z0-9_\-/ ]+)\"?"
        # Unquoted value: shortest run that is followed by the end of the
        # text, a character outside the allowed set, or another field key.
        r"|([a-zA-Z0-9_\-/ ]+?)"
        r"(?=\s*(?:$|[^a-zA-Z0-9_\-/ ]|\b(?:label|summary)\b\s*[:=])))",
        cleaned_text,
        flags=re.IGNORECASE,
    )
    if route_match:
        route_value = route_match.group(1) or route_match.group(2)
        result["route_to"] = route_value.strip().lower()

    summary_match = re.search(
        r"(?:\"summary\"|summary)\s*[:=]\s*"
        # Quoted value: up to the closing quote or the end of the line.
        r"(?:\"([^\"\n]+)\"?"
        # Unquoted value: stops at a quote, a newline, the end of the text,
        # or another field key (optionally preceded by ',' or ';').
        r"|([^\"\n]+?)"
        r"(?=\s*(?:$|[\"\n]|[,;]?\s*\b(?:label|route_to|route)\b\s*[:=])))",
        cleaned_text,
        flags=re.IGNORECASE,
    )
    if summary_match:
        summary_value = summary_match.group(1) or summary_match.group(2)
        result["summary"] = summary_value.strip()

    return result
214
-
215
-
216
def parse_action_response(response_text: str) -> TriageAction:
    """Turn a raw model reply into a validated ``TriageAction``.

    Parsing order:
      1. Remove wrappers with ``strip_action_prefixes``.
      2. Try to decode the span from the first ``{`` to the last ``}`` as a
         JSON object.
      3. If that gives nothing, fall back to ``parse_text_action``.
      4. Overlay the parsed fields on ``FALLBACK_ACTION`` and validate.

    Before validation, fields whose value is ``None`` are dropped so the
    fallback default survives, and a string ``label`` is stripped and
    lower-cased as the regex fallback already does. Without this a single
    ``null`` or a label such as ``"Urgent"`` could make validation fail and
    discard the whole answer.

    Args:
        response_text: Raw text returned by the model (may be empty).

    Returns:
        A validated action. If the merged payload is rejected, the action
        built from ``FALLBACK_ACTION`` alone is returned.
    """
    cleaned_text = strip_action_prefixes(response_text)
    parsed_payload: dict[str, Any] = {}

    json_start = cleaned_text.find("{")
    json_end = cleaned_text.rfind("}")
    if json_start != -1 and json_end > json_start:
        candidate = cleaned_text[json_start : json_end + 1]
        try:
            loaded = json.loads(candidate)
        except json.JSONDecodeError:
            loaded = None
        # Only a JSON object is usable; lists and scalars are ignored.
        if isinstance(loaded, dict):
            parsed_payload = loaded

    if not parsed_payload:
        parsed_payload = parse_text_action(cleaned_text)

    merged = dict(FALLBACK_ACTION)
    merged.update(
        {key: value for key, value in parsed_payload.items() if value is not None}
    )

    label = merged.get("label")
    if isinstance(label, str):
        merged["label"] = label.strip().lower()

    try:
        return TriageAction.model_validate(merged)
    except Exception:
        # Deliberately best-effort: any validation failure yields the
        # fallback action so the episode can continue.
        return TriageAction.model_validate(FALLBACK_ACTION)
242
-
243
-
244
def action_to_log_string(action: TriageAction) -> str:
    """Serialize an action as compact, ASCII-only JSON for the STEP line.

    Args:
        action: The action to serialize; its ``model_dump()`` output is encoded.

    Returns:
        JSON text with no spaces after separators and non-ASCII characters
        escaped.
    """
    payload = action.model_dump()
    return json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
247
-
248
-
249
def _request_action_text(
    client: OpenAI,
    model_name: str,
    prompt: str,
    deadline: float,
    request_timeout_seconds: float,
) -> str:
    """Request one chat completion and return its text, or "" on failure.

    The per-request timeout is ``request_timeout_seconds`` capped by the time
    left until ``deadline``, and never below one second. Any exception is
    reported on stderr and converted to an empty reply, so the caller falls
    back to the default action and stdout keeps its event format.
    """
    import sys  # local import: only needed for the stderr diagnostic

    try:
        remaining = max(1.0, deadline - time.monotonic())
        timeout_seconds = max(
            1.0,
            min(float(request_timeout_seconds), float(remaining)),
        )
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
            timeout=timeout_seconds,
        )
        return completion.choices[0].message.content or ""
    except Exception as exc:
        # repr() keeps the diagnostic on one line.
        print(f"WARN inference: LLM request failed: {exc!r}", file=sys.stderr, flush=True)
        return ""


def run_episode(
    client: OpenAI,
    model_name: str,
    task_id: str,
    scenario_index: int,
    eval_split: str,
    deadline: float,
    request_timeout_seconds: float,
    runtime_options: dict[str, Any] | None = None,
) -> None:
    """Run one episode and emit the START / STEP / END lines on stdout.

    The START line is printed before the environment is created and the END
    line is printed from ``finally``, so both appear even if the episode
    fails. Failures are reported on stderr only.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model_name: Model identifier sent with each request.
        task_id: Environment task identifier.
        scenario_index: Scenario index within the task.
        eval_split: Scenario split passed to the environment.
        deadline: ``time.monotonic()`` value after which no new step starts.
        request_timeout_seconds: Upper bound for a single LLM request.
        runtime_options: Optional extra options forwarded to the environment.
    """
    import sys  # local import: only needed for stderr diagnostics

    rewards: list[float] = []
    steps_taken = 0
    success = False
    final_task_score = LOG_SCORE_EPSILON
    env: EmailTriageEnv | None = None

    log_start(task_name=task_id, benchmark_name=BENCHMARK, model_name=model_name)

    try:
        env = EmailTriageEnv(
            task_id=task_id,
            scenario_index=scenario_index,
            split=eval_split,
            runtime_options=runtime_options,
        )
        reset_result = env.reset()
        observation = reset_result.observation
        history: list[str] = []

        for step in range(1, MAX_STEPS + 1):
            # Stop before starting a step once the global budget is spent.
            if time.monotonic() >= deadline:
                break

            prompt = build_user_prompt(observation, history)
            response_text = _request_action_text(
                client=client,
                model_name=model_name,
                prompt=prompt,
                deadline=deadline,
                request_timeout_seconds=request_timeout_seconds,
            )

            # An empty or unparsable reply becomes the fallback action.
            action = parse_action_response(response_text)
            step_result = env.step(action)

            reward = _strict_task_score(float(step_result.reward))
            done = bool(step_result.done)
            error_raw = step_result.info.get("validation_error")
            error = str(error_raw) if isinstance(error_raw, str) else None

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action_str=action_to_log_string(action),
                reward=reward,
                done=done,
                error=error,
            )

            history.append(
                f"step={step} action={action.label}/{action.route_to} reward={_format_open_score(reward)}"
            )
            observation = step_result.observation

            if done:
                break

        # Guarantee a non-empty reward list so the mean is defined.
        if not rewards:
            rewards.append(LOG_SCORE_EPSILON)

        final_task_score = _strict_task_score(sum(rewards) / len(rewards))
        success = final_task_score >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        print(
            f"WARN inference: episode for task {task_id!r} aborted: {exc!r}",
            file=sys.stderr,
            flush=True,
        )
        if not rewards:
            rewards.append(LOG_SCORE_EPSILON)
        # Score whatever was collected, but never report success.
        final_task_score = _strict_task_score(sum(rewards) / len(rewards))
        success = False
    finally:
        if env is not None:
            # close() is optional on the environment; call it when present.
            close_method = getattr(env, "close", None)
            if callable(close_method):
                try:
                    close_method()
                except Exception as exc:
                    print(
                        f"WARN inference: environment close failed: {exc!r}",
                        file=sys.stderr,
                        flush=True,
                    )

        log_end(
            success=success,
            steps=steps_taken,
            rewards=rewards,
            task_score=final_task_score,
        )
359
-
360
-
361
def main() -> None:
    """Parse the CLI, validate configuration and run every selected episode.

    One wall-clock deadline is computed up front and shared by all episodes.
    A missing credential prints the error message and exits with status 1.
    """
    args = parse_args()
    budget_seconds = max(args.runtime_budget_seconds, 1)
    deadline = time.monotonic() + budget_seconds
    request_timeout_seconds = max(float(args.request_timeout_seconds), 1.0)

    try:
        effective_model = validate_runtime_config(args.model)
    except ValueError as error:
        print(str(error), flush=True)
        raise SystemExit(1) from error

    # Read but otherwise unused by this script.
    _ = LOCAL_IMAGE_NAME

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # A specific selector runs one task; "all" runs every mapped task in order.
    selected_task = TASK_MAP.get(args.task)
    if selected_task is not None:
        task_ids = [selected_task]
    else:
        task_ids = list(TASK_MAP.values())

    episode_count = max(args.episodes_per_task, 1)

    for task_id in task_ids:
        # Only the production task receives the workload options.
        if task_id == "task_production":
            runtime_options = {
                "production_profile": args.production_profile,
                "business_hours_mode": args.business_hours_mode,
                "escalation_mode": args.escalation_mode,
            }
        else:
            runtime_options = None

        for scenario_index in range(episode_count):
            run_episode(
                client=client,
                model_name=effective_model,
                task_id=task_id,
                scenario_index=scenario_index,
                eval_split=args.split,
                deadline=deadline,
                request_timeout_seconds=request_timeout_seconds,
                runtime_options=runtime_options,
            )


if __name__ == "__main__":
    main()