Imaginephoenix commited on
Commit
d259149
·
verified ·
1 Parent(s): 1611ba6

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -392
inference.py DELETED
@@ -1,392 +0,0 @@
1
- """Inference script for OpenEnv email triage with strict stdout event format."""
2
-
3
- import argparse
4
- import json
5
- import os
6
- import re
7
- import time
8
- from typing import Any
9
-
10
- from openai import OpenAI
11
-
12
- from environment import EmailTriageEnv
13
- from models import EmailObservation, TriageAction
14
-
15
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
16
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
17
- HF_TOKEN = os.getenv("HF_TOKEN")
18
- API_KEY = HF_TOKEN or os.getenv("API_KEY")
19
- LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
20
-
21
- BENCHMARK = "openenv-email-triage"
22
- MAX_STEPS = 30
23
- TEMPERATURE = 0.2
24
- MAX_TOKENS = 200
25
- SUCCESS_SCORE_THRESHOLD = 0.5
26
- LOG_SCORE_EPSILON = 1e-2
27
- DEFAULT_RUNTIME_BUDGET_SECONDS = int(os.getenv("INFERENCE_RUNTIME_BUDGET_SECONDS", "1140"))
28
- DEFAULT_REQUEST_TIMEOUT_SECONDS = float(os.getenv("INFERENCE_REQUEST_TIMEOUT_SECONDS", "12"))
29
-
30
- SYSTEM_PROMPT = (
31
- "You are an email triage assistant. For each email, prioritize risk/time impact, "
32
- "categorize with one label (urgent|normal|spam|archive), route to the best team, "
33
- "and summarize the key evidence. Return one JSON object with keys label, summary, route_to."
34
- )
35
-
36
- FALLBACK_ACTION = {
37
- "label": "normal",
38
- "summary": "Unable to parse response",
39
- "route_to": "general",
40
- }
41
-
42
- TASK_MAP = {
43
- "1": "task_easy",
44
- "2": "task_medium",
45
- "3": "task_hard",
46
- "4": "task_production",
47
- }
48
-
49
-
50
- def parse_args() -> argparse.Namespace:
51
- """Parse command-line arguments for task and optional model override."""
52
- parser = argparse.ArgumentParser(description="Run OpenEnv email triage inference.")
53
- parser.add_argument(
54
- "--task",
55
- default="all",
56
- choices=["1", "2", "3", "4", "all"],
57
- help="Task selection: 1, 2, 3, 4, or all.",
58
- )
59
- parser.add_argument(
60
- "--model",
61
- default=None,
62
- help="Optional model override. Falls back to MODEL_NAME environment variable.",
63
- )
64
- parser.add_argument(
65
- "--split",
66
- default=os.getenv("OPENENV_EVAL_SPLIT", "public"),
67
- choices=["public", "private_eval"],
68
- help="Scenario split to evaluate.",
69
- )
70
- parser.add_argument(
71
- "--episodes-per-task",
72
- default=1,
73
- type=int,
74
- help="Number of deterministic scenarios to evaluate per task.",
75
- )
76
- parser.add_argument(
77
- "--runtime-budget-seconds",
78
- default=DEFAULT_RUNTIME_BUDGET_SECONDS,
79
- type=int,
80
- help="Global wall-clock budget for the full script run.",
81
- )
82
- parser.add_argument(
83
- "--request-timeout-seconds",
84
- default=DEFAULT_REQUEST_TIMEOUT_SECONDS,
85
- type=float,
86
- help="Timeout per LLM request.",
87
- )
88
- parser.add_argument(
89
- "--production-profile",
90
- default="standard",
91
- choices=["light", "standard", "heavy"],
92
- help="Runtime workload profile used for task 4 episodes.",
93
- )
94
- parser.add_argument(
95
- "--business-hours-mode",
96
- action="store_true",
97
- help="If set, task 4 timestamps focus on business-hours windows.",
98
- )
99
- parser.add_argument(
100
- "--escalation-mode",
101
- default="normal",
102
- choices=["low", "normal", "high"],
103
- help="Escalation strictness for task 4 follow-up generation.",
104
- )
105
- return parser.parse_args()
106
-
107
-
108
- def validate_runtime_config(model_name: str | None) -> str:
109
- """Validate required runtime settings and return effective model name."""
110
- if not API_KEY:
111
- raise ValueError("Missing HF_TOKEN or API_KEY environment variable.")
112
-
113
- effective_model = model_name or MODEL_NAME
114
- return effective_model
115
-
116
-
117
- def log_start(task_name: str, benchmark_name: str, model_name: str) -> None:
118
- """Emit mandatory START line."""
119
- print(
120
- f"[START] task={task_name} env={benchmark_name} model={model_name}",
121
- flush=True,
122
- )
123
-
124
-
125
- def _format_open_score(value: float) -> str:
126
- """Format scores in strict-open range while preserving .2f log contract."""
127
- clamped = max(LOG_SCORE_EPSILON, min(1.0 - LOG_SCORE_EPSILON, float(value)))
128
- return f"{clamped:.2f}"
129
-
130
- def _strict_task_score(raw_score: float) -> float:
131
- """Return task score in strict-open interval for evaluator compatibility."""
132
- return max(LOG_SCORE_EPSILON, min(1.0 - LOG_SCORE_EPSILON, float(raw_score)))
133
-
134
- def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
135
- """Emit mandatory STEP line."""
136
- error_value = error if error else "null"
137
- done_value = str(done).lower()
138
- print(
139
- f"[STEP] step={step} action={action_str} reward={_format_open_score(reward)} "
140
- f"done={done_value} error={error_value}",
141
- flush=True,
142
- )
143
-
144
-
145
- def log_end(success: bool, steps: int, rewards: list[float]) -> None:
146
- """Emit mandatory END line."""
147
- rewards_str = ",".join(_format_open_score(reward) for reward in rewards)
148
- print(
149
- f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}",
150
- flush=True,
151
- )
152
-
153
-
154
- def build_user_prompt(observation: EmailObservation, history: list[str]) -> str:
155
- """Build model prompt from current observation and recent history."""
156
- recent_history = "\n".join(history[-5:]) if history else "None"
157
- return (
158
- f"email_id: {observation.email_id}\n"
159
- f"subject: {observation.subject}\n"
160
- f"sender: {observation.sender}\n"
161
- f"timestamp: {observation.timestamp}\n"
162
- f"body: {observation.body}\n"
163
- f"thread_history: {observation.thread_history}\n"
164
- f"task_id: {observation.task_id}\n"
165
- f"step_number: {observation.step_number}\n"
166
- f"total_emails: {observation.total_emails}\n\n"
167
- f"recent_history:\n{recent_history}\n\n"
168
- "Return exactly one JSON object with label, summary, route_to."
169
- )
170
-
171
-
172
- def strip_action_prefixes(response_text: str) -> str:
173
- """Remove common formatting wrappers before parsing model output."""
174
- cleaned = response_text.strip()
175
- cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
176
- cleaned = re.sub(r"```$", "", cleaned).strip()
177
- cleaned = re.sub(r"^(next\s+action|action)\s*:\s*", "", cleaned, flags=re.IGNORECASE)
178
- return cleaned.strip()
179
-
180
-
181
- def parse_text_action(cleaned_text: str) -> dict[str, str]:
182
- """Parse action from free-form text with deterministic regex fallback."""
183
- result: dict[str, str] = {}
184
-
185
- label_match = re.search(
186
- r"(?:\"label\"|label)\s*[:=]\s*\"?(urgent|normal|spam|archive)\"?",
187
- cleaned_text,
188
- flags=re.IGNORECASE,
189
- )
190
- if label_match:
191
- result["label"] = label_match.group(1).lower()
192
-
193
- route_match = re.search(
194
- r"(?:\"route_to\"|route_to|route)\s*[:=]\s*\"?([a-zA-Z0-9_\-/ ]+)\"?",
195
- cleaned_text,
196
- flags=re.IGNORECASE,
197
- )
198
- if route_match:
199
- result["route_to"] = route_match.group(1).strip().lower()
200
-
201
- summary_match = re.search(
202
- r"(?:\"summary\"|summary)\s*[:=]\s*\"?([^\"\n]+)\"?",
203
- cleaned_text,
204
- flags=re.IGNORECASE,
205
- )
206
- if summary_match:
207
- result["summary"] = summary_match.group(1).strip()
208
-
209
- return result
210
-
211
-
212
- def parse_action_response(response_text: str) -> TriageAction:
213
- """Parse model response into a valid TriageAction with fallback behavior."""
214
- cleaned_text = strip_action_prefixes(response_text)
215
- parsed_payload: dict[str, Any] = {}
216
-
217
- json_start = cleaned_text.find("{")
218
- json_end = cleaned_text.rfind("}")
219
- if json_start != -1 and json_end != -1 and json_end > json_start:
220
- candidate = cleaned_text[json_start : json_end + 1]
221
- try:
222
- loaded = json.loads(candidate)
223
- if isinstance(loaded, dict):
224
- parsed_payload = loaded
225
- except json.JSONDecodeError:
226
- parsed_payload = {}
227
-
228
- if not parsed_payload:
229
- parsed_payload = parse_text_action(cleaned_text)
230
-
231
- fallback_copy = dict(FALLBACK_ACTION)
232
- fallback_copy.update(parsed_payload)
233
-
234
- try:
235
- return TriageAction.model_validate(fallback_copy)
236
- except Exception:
237
- return TriageAction.model_validate(FALLBACK_ACTION)
238
-
239
-
240
- def action_to_log_string(action: TriageAction) -> str:
241
- """Return single-line action string for required STEP logging."""
242
- return json.dumps(action.model_dump(), separators=(",", ":"), ensure_ascii=True)
243
-
244
-
245
- def run_episode(
246
- client: OpenAI,
247
- model_name: str,
248
- task_id: str,
249
- scenario_index: int,
250
- eval_split: str,
251
- deadline: float,
252
- request_timeout_seconds: float,
253
- runtime_options: dict[str, Any] | None = None,
254
- ) -> None:
255
- """Run one episode and emit strict START/STEP/END lines."""
256
- rewards: list[float] = []
257
- steps_taken = 0
258
- success = False
259
- env: EmailTriageEnv | None = None
260
-
261
- log_start(task_name=task_id, benchmark_name=BENCHMARK, model_name=model_name)
262
-
263
- try:
264
- env = EmailTriageEnv(
265
- task_id=task_id,
266
- scenario_index=scenario_index,
267
- split=eval_split,
268
- runtime_options=runtime_options,
269
- )
270
- reset_result = env.reset()
271
- observation = reset_result.observation
272
- history: list[str] = []
273
-
274
- for step in range(1, MAX_STEPS + 1):
275
- if time.monotonic() >= deadline:
276
- break
277
-
278
- prompt = build_user_prompt(observation, history)
279
-
280
- response_text = ""
281
- try:
282
- remaining = max(1.0, deadline - time.monotonic())
283
- timeout_seconds = max(
284
- 1.0,
285
- min(float(request_timeout_seconds), float(remaining)),
286
- )
287
- completion = client.chat.completions.create(
288
- model=model_name,
289
- messages=[
290
- {"role": "system", "content": SYSTEM_PROMPT},
291
- {"role": "user", "content": prompt},
292
- ],
293
- temperature=TEMPERATURE,
294
- max_tokens=MAX_TOKENS,
295
- stream=False,
296
- timeout=timeout_seconds,
297
- )
298
- response_text = completion.choices[0].message.content or ""
299
- except Exception:
300
- response_text = ""
301
-
302
- action = parse_action_response(response_text)
303
- step_result = env.step(action)
304
-
305
- reward = float(step_result.reward)
306
- done = bool(step_result.done)
307
- error_raw = step_result.info.get("validation_error")
308
- error = str(error_raw) if isinstance(error_raw, str) else None
309
-
310
- rewards.append(reward)
311
- steps_taken = step
312
-
313
- log_step(
314
- step=step,
315
- action_str=action_to_log_string(action),
316
- reward=reward,
317
- done=done,
318
- error=error,
319
- )
320
-
321
- history.append(
322
- f"step={step} action={action.label}/{action.route_to} reward={_format_open_score(reward)}"
323
- )
324
- observation = step_result.observation
325
-
326
- if done:
327
- break
328
-
329
- if not rewards:
330
- rewards.append(LOG_SCORE_EPSILON)
331
-
332
- avg_reward = _strict_task_score(sum(rewards) / len(rewards))
333
- success = avg_reward >= SUCCESS_SCORE_THRESHOLD
334
- except Exception:
335
- if not rewards:
336
- rewards.append(LOG_SCORE_EPSILON)
337
- success = False
338
- finally:
339
- if env is not None:
340
- close_method = getattr(env, "close", None)
341
- if callable(close_method):
342
- try:
343
- close_method()
344
- except Exception:
345
- pass
346
-
347
- log_end(success=success, steps=steps_taken, rewards=rewards)
348
-
349
-
350
- def main() -> None:
351
- """Entrypoint for running one or many tasks with strict stdout logs."""
352
- args = parse_args()
353
- deadline = time.monotonic() + max(args.runtime_budget_seconds, 1)
354
- request_timeout_seconds = max(float(args.request_timeout_seconds), 1.0)
355
-
356
- try:
357
- effective_model = validate_runtime_config(args.model)
358
- except ValueError as error:
359
- print(str(error), flush=True)
360
- raise SystemExit(1) from error
361
-
362
- _ = LOCAL_IMAGE_NAME
363
-
364
- client = OpenAI(
365
- base_url=API_BASE_URL,
366
- api_key=API_KEY,
367
- )
368
-
369
- task_ids = [TASK_MAP[args.task]] if args.task in TASK_MAP else list(TASK_MAP.values())
370
- for task_id in task_ids:
371
- runtime_options = None
372
- if task_id == "task_production":
373
- runtime_options = {
374
- "production_profile": args.production_profile,
375
- "business_hours_mode": args.business_hours_mode,
376
- "escalation_mode": args.escalation_mode,
377
- }
378
- for scenario_index in range(max(args.episodes_per_task, 1)):
379
- run_episode(
380
- client=client,
381
- model_name=effective_model,
382
- task_id=task_id,
383
- scenario_index=scenario_index,
384
- eval_split=args.split,
385
- deadline=deadline,
386
- request_timeout_seconds=request_timeout_seconds,
387
- runtime_options=runtime_options,
388
- )
389
-
390
-
391
- if __name__ == "__main__":
392
- main()