Imaginephoenix committed on
Commit
e721fd9
·
verified ·
1 Parent(s): 7e241be

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +403 -0
  2. models.py +89 -0
inference.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference script for OpenEnv email triage with strict stdout event format."""
2
+
3
+ import argparse
4
+ import json
5
+ import os
6
+ import re
7
+ import time
8
+ from typing import Any
9
+
10
+ from openai import OpenAI
11
+
12
+ from environment import EmailTriageEnv
13
+ from models import EmailObservation, TriageAction
14
+
15
# --- Endpoint / credential configuration, read once from the process environment ---
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")
# HF_TOKEN wins when set (and non-empty); API_KEY is the fallback credential.
API_KEY = HF_TOKEN or os.environ.get("API_KEY")
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")

# --- Episode / sampling settings ---
BENCHMARK = "openenv-email-triage"
MAX_STEPS = 30
TEMPERATURE = 0.2
MAX_TOKENS = 200
SUCCESS_SCORE_THRESHOLD = 0.5
# Scores are clamped to [LOG_SCORE_EPSILON, 1 - LOG_SCORE_EPSILON] before logging.
LOG_SCORE_EPSILON = 1e-2
DEFAULT_RUNTIME_BUDGET_SECONDS = int(
    os.environ.get("INFERENCE_RUNTIME_BUDGET_SECONDS", "1140")
)
DEFAULT_REQUEST_TIMEOUT_SECONDS = float(
    os.environ.get("INFERENCE_REQUEST_TIMEOUT_SECONDS", "12")
)

# System message sent with every chat completion request.
SYSTEM_PROMPT = " ".join(
    (
        "You are an email triage assistant. For each email, prioritize risk/time impact,",
        "categorize with one label (urgent|normal|spam|archive), route to the best team,",
        "and summarize the key evidence. Return one JSON object with keys label, summary, route_to.",
    )
)

# Action used whenever the model output cannot be turned into a valid action.
FALLBACK_ACTION = dict(
    label="normal",
    summary="Unable to parse response",
    route_to="general",
)

# CLI task selector ("1".."4") -> environment task identifier.
TASK_MAP = {
    str(position): task_name
    for position, task_name in enumerate(
        ("task_easy", "task_medium", "task_hard", "task_production"),
        start=1,
    )
}
48
+
49
+
50
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Options are declared in a table and registered in order, so ``--help``
    lists them in the same sequence as the table below.
    """
    option_table = (
        (
            "--task",
            {
                "default": "all",
                "choices": ["1", "2", "3", "4", "all"],
                "help": "Task selection: 1, 2, 3, 4, or all.",
            },
        ),
        (
            "--model",
            {
                "default": None,
                "help": "Optional model override. Falls back to MODEL_NAME environment variable.",
            },
        ),
        (
            "--split",
            {
                "default": os.getenv("OPENENV_EVAL_SPLIT", "public"),
                "choices": ["public", "private_eval"],
                "help": "Scenario split to evaluate.",
            },
        ),
        (
            "--episodes-per-task",
            {
                "default": 1,
                "type": int,
                "help": "Number of deterministic scenarios to evaluate per task.",
            },
        ),
        (
            "--runtime-budget-seconds",
            {
                "default": DEFAULT_RUNTIME_BUDGET_SECONDS,
                "type": int,
                "help": "Global wall-clock budget for the full script run.",
            },
        ),
        (
            "--request-timeout-seconds",
            {
                "default": DEFAULT_REQUEST_TIMEOUT_SECONDS,
                "type": float,
                "help": "Timeout per LLM request.",
            },
        ),
        (
            "--production-profile",
            {
                "default": "standard",
                "choices": ["light", "standard", "heavy"],
                "help": "Runtime workload profile used for task 4 episodes.",
            },
        ),
        (
            "--business-hours-mode",
            {
                "action": "store_true",
                "help": "If set, task 4 timestamps focus on business-hours windows.",
            },
        ),
        (
            "--escalation-mode",
            {
                "default": "normal",
                "choices": ["low", "normal", "high"],
                "help": "Escalation strictness for task 4 follow-up generation.",
            },
        ),
    )

    cli_parser = argparse.ArgumentParser(description="Run OpenEnv email triage inference.")
    for flag, settings in option_table:
        cli_parser.add_argument(flag, **settings)
    return cli_parser.parse_args()
106
+
107
+
108
def validate_runtime_config(model_name: str | None) -> str:
    """Check that a credential is configured and pick the model to use.

    Returns ``model_name`` when it is truthy, otherwise the module-level
    ``MODEL_NAME``.

    Raises:
        ValueError: if ``API_KEY`` is empty or unset.
    """
    if API_KEY:
        return model_name if model_name else MODEL_NAME
    raise ValueError("Missing HF_TOKEN or API_KEY environment variable.")
115
+
116
+
117
def log_start(task_name: str, benchmark_name: str, model_name: str) -> None:
    """Print the ``[START]`` event line for an episode and flush stdout."""
    pairs = (
        ("task", task_name),
        ("env", benchmark_name),
        ("model", model_name),
    )
    line = "[START] " + " ".join(f"{key}={value}" for key, value in pairs)
    print(line, flush=True)
123
+
124
+
125
def _format_open_score(value: float) -> str:
    """Format a score with two decimals after clamping it away from 0 and 1.

    The value is clamped to ``[LOG_SCORE_EPSILON, 1 - LOG_SCORE_EPSILON]`` and
    rendered with ``.2f``.

    NaN is mapped to ``LOG_SCORE_EPSILON``. Without that guard,
    ``min(upper, nan)`` returns ``upper`` (every comparison with NaN is
    false), so an undefined score would be logged as the best possible one.
    """
    import math  # function-scope import: only this NaN guard needs it

    numeric = float(value)
    if math.isnan(numeric):
        numeric = LOG_SCORE_EPSILON
    clamped = max(LOG_SCORE_EPSILON, min(1.0 - LOG_SCORE_EPSILON, numeric))
    return f"{clamped:.2f}"
129
+
130
+
131
def _strict_task_score(raw_score: float) -> float:
    """Clamp a score to ``[LOG_SCORE_EPSILON, 1 - LOG_SCORE_EPSILON]``.

    NaN is mapped to ``LOG_SCORE_EPSILON``. Without that guard,
    ``min(upper, nan)`` returns ``upper`` (every comparison with NaN is
    false), so an undefined score would silently become the maximum.
    """
    import math  # function-scope import: only this NaN guard needs it

    numeric = float(raw_score)
    if math.isnan(numeric):
        return LOG_SCORE_EPSILON
    return max(LOG_SCORE_EPSILON, min(1.0 - LOG_SCORE_EPSILON, numeric))
134
+
135
+
136
def log_step(step: int, action_str: str, reward: float, done: bool, error: str | None) -> None:
    """Print the ``[STEP]`` event line for one environment step and flush stdout.

    Args:
        step: Step number to report.
        action_str: Already-serialized action text.
        reward: Step reward; formatted through ``_format_open_score``.
        done: Whether the episode finished; printed as ``true``/``false``.
        error: Optional error text; ``null`` is printed when it is empty/None.
    """
    # An error message containing line breaks would split the event across
    # several stdout lines, so fold it onto one line before printing.
    flattened = " ".join(error.splitlines()).strip() if error else ""
    error_value = flattened or "null"
    done_value = str(done).lower()
    print(
        f"[STEP] step={step} action={action_str} reward={_format_open_score(reward)} "
        f"done={done_value} error={error_value}",
        flush=True,
    )
145
+
146
+
147
def log_end(success: bool, steps: int, rewards: list[float], task_score: float) -> None:
    """Print the ``[END]`` event line summarising an episode and flush stdout.

    Every reward and the task score are rendered through
    ``_format_open_score``; the task score is first clamped by
    ``_strict_task_score``.
    """
    formatted_rewards = [_format_open_score(item) for item in rewards]
    score_text = _format_open_score(_strict_task_score(task_score))
    segments = [
        f"task_score={score_text}",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"rewards={','.join(formatted_rewards)}",
    ]
    print("[END] " + " ".join(segments), flush=True)
156
+
157
+
158
def build_user_prompt(observation: EmailObservation, history: list[str]) -> str:
    """Render the user message for the current observation.

    The prompt lists the observation fields one per line, then the last five
    ``history`` entries (or ``None`` when there is no history), then the
    closing output instruction.
    """
    history_block = "None" if not history else "\n".join(history[-5:])
    field_names = (
        "email_id",
        "subject",
        "sender",
        "timestamp",
        "body",
        "thread_history",
        "task_id",
        "step_number",
        "total_emails",
    )
    field_lines = [f"{name}: {getattr(observation, name)}" for name in field_names]
    return (
        "\n".join(field_lines)
        + "\n\n"
        + f"recent_history:\n{history_block}\n\n"
        + "Return exactly one JSON object with label, summary, route_to."
    )
174
+
175
+
176
def strip_action_prefixes(response_text: str) -> str:
    """Strip code fences and a leading ``action:`` style prefix from model output.

    The substitutions run in order: an opening ``` fence (optionally tagged
    ``json``), a closing ``` fence, then a leading ``action:`` or
    ``next action:`` prefix. Whitespace is trimmed after each one.
    """
    substitutions = (
        (r"^```(?:json)?", re.IGNORECASE),
        (r"```$", 0),
        (r"^(next\s+action|action)\s*:\s*", re.IGNORECASE),
    )
    text = response_text.strip()
    for pattern, flags in substitutions:
        text = re.sub(pattern, "", text, flags=flags).strip()
    return text
183
+
184
+
185
def parse_text_action(cleaned_text: str) -> dict[str, str]:
    """Extract ``label``, ``route_to`` and ``summary`` from free-form text.

    Each field has its own case-insensitive regex. Only fields whose pattern
    matches appear in the returned dict; labels and routes are lower-cased,
    routes and summaries are stripped.
    """
    extractors = (
        (
            "label",
            r"(?:\"label\"|label)\s*[:=]\s*\"?(urgent|normal|spam|archive)\"?",
            lambda captured: captured.lower(),
        ),
        (
            "route_to",
            r"(?:\"route_to\"|route_to|route)\s*[:=]\s*\"?([a-zA-Z0-9_\-/ ]+)\"?",
            lambda captured: captured.strip().lower(),
        ),
        (
            "summary",
            r"(?:\"summary\"|summary)\s*[:=]\s*\"?([^\"\n]+)\"?",
            lambda captured: captured.strip(),
        ),
    )

    found: dict[str, str] = {}
    for field_name, pattern, normalize in extractors:
        hit = re.search(pattern, cleaned_text, flags=re.IGNORECASE)
        if hit is not None:
            found[field_name] = normalize(hit.group(1))
    return found
214
+
215
+
216
def parse_action_response(response_text: str) -> TriageAction:
    """Turn raw model output into a ``TriageAction``, never raising on bad text.

    Parsing order:
      1. Strip formatting wrappers with ``strip_action_prefixes``.
      2. Try to decode the span from the first ``{`` to the last ``}`` as JSON.
      3. If that yields no usable dict, fall back to ``parse_text_action``.

    Fields are then merged over ``FALLBACK_ACTION`` one at a time, so one
    unusable field (for example a non-string summary) no longer throws away
    the other fields the model got right. The label is stripped and
    lower-cased, matching what the regex path already does, so JSON such as
    ``{"label": "Urgent"}`` is accepted instead of falling back.

    Args:
        response_text: Raw completion text (may be empty).

    Returns:
        A validated ``TriageAction``. If the label is still rejected, the
        fallback label is used with the remaining parsed fields; the full
        ``FALLBACK_ACTION`` is the last resort.
    """
    cleaned_text = strip_action_prefixes(response_text)
    parsed_payload: dict[str, Any] = {}

    json_start = cleaned_text.find("{")
    json_end = cleaned_text.rfind("}")
    # json_end > json_start already implies both indices were found.
    if json_start != -1 and json_end > json_start:
        try:
            loaded = json.loads(cleaned_text[json_start : json_end + 1])
        except json.JSONDecodeError:
            loaded = None
        if isinstance(loaded, dict):
            parsed_payload = loaded

    if not parsed_payload:
        parsed_payload = parse_text_action(cleaned_text)

    # Keep only the known string fields; anything else keeps its fallback.
    merged = dict(FALLBACK_ACTION)
    for key in FALLBACK_ACTION:
        value = parsed_payload.get(key)
        if isinstance(value, str):
            merged[key] = value
    merged["label"] = merged["label"].strip().lower()

    try:
        return TriageAction.model_validate(merged)
    except Exception:
        # Most likely an unsupported label: retry with the fallback label so
        # the parsed summary and route survive.
        merged["label"] = FALLBACK_ACTION["label"]

    try:
        return TriageAction.model_validate(merged)
    except Exception:
        # Deliberate catch-all: this function must always return an action.
        return TriageAction.model_validate(FALLBACK_ACTION)
242
+
243
+
244
def action_to_log_string(action: TriageAction) -> str:
    """Serialize an action as compact, ASCII-only JSON for the STEP log line."""
    payload = action.model_dump()
    return json.dumps(payload, ensure_ascii=True, separators=(",", ":"))
247
+
248
+
249
def run_episode(
    client: OpenAI,
    model_name: str,
    task_id: str,
    scenario_index: int,
    eval_split: str,
    deadline: float,
    request_timeout_seconds: float,
    runtime_options: dict[str, Any] | None = None,
) -> None:
    """Run one episode and emit the START / STEP / END event lines on stdout.

    The function never raises for LLM or environment failures. A failed LLM
    call results in the fallback action for that step, and an environment
    failure ends the episode with ``success=false``. Each swallowed exception
    is reported on stderr so failures can be diagnosed, while stdout carries
    only the event lines.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model_name: Model identifier passed to the completion request.
        task_id: Environment task identifier.
        scenario_index: Scenario index forwarded to the environment.
        eval_split: Scenario split forwarded to the environment.
        deadline: ``time.monotonic()`` value after which no new step starts.
        request_timeout_seconds: Upper bound for each LLM request timeout.
        runtime_options: Optional options forwarded to the environment.
    """
    import sys  # function-scope import: diagnostics go to stderr only

    rewards: list[float] = []
    steps_taken = 0
    success = False
    final_task_score = LOG_SCORE_EPSILON
    env: EmailTriageEnv | None = None

    log_start(task_name=task_id, benchmark_name=BENCHMARK, model_name=model_name)

    try:
        env = EmailTriageEnv(
            task_id=task_id,
            scenario_index=scenario_index,
            split=eval_split,
            runtime_options=runtime_options,
        )
        reset_result = env.reset()
        observation = reset_result.observation
        history: list[str] = []

        for step in range(1, MAX_STEPS + 1):
            # Stop starting new steps once the global budget is spent.
            if time.monotonic() >= deadline:
                break

            prompt = build_user_prompt(observation, history)

            response_text = ""
            try:
                # Never exceed the remaining budget, but keep at least 1s so
                # the request has a chance to complete.
                timeout_seconds = max(
                    1.0,
                    min(float(request_timeout_seconds), deadline - time.monotonic()),
                )
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=TEMPERATURE,
                    max_tokens=MAX_TOKENS,
                    stream=False,
                    timeout=timeout_seconds,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as llm_error:
                # Best effort: an empty response parses to the fallback action.
                print(
                    f"[run_episode] task={task_id} step={step} "
                    f"LLM request failed: {llm_error!r}",
                    file=sys.stderr,
                    flush=True,
                )
                response_text = ""

            action = parse_action_response(response_text)
            step_result = env.step(action)

            reward = _strict_task_score(float(step_result.reward))
            done = bool(step_result.done)
            error_raw = step_result.info.get("validation_error")
            error = str(error_raw) if isinstance(error_raw, str) else None

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action_str=action_to_log_string(action),
                reward=reward,
                done=done,
                error=error,
            )

            history.append(
                f"step={step} action={action.label}/{action.route_to} reward={_format_open_score(reward)}"
            )
            observation = step_result.observation

            if done:
                break

        # Guarantee a non-empty rewards list so the mean below is defined.
        if not rewards:
            rewards.append(LOG_SCORE_EPSILON)

        final_task_score = _strict_task_score(sum(rewards) / len(rewards))
        success = final_task_score >= SUCCESS_SCORE_THRESHOLD
    except Exception as episode_error:
        print(
            f"[run_episode] task={task_id} scenario={scenario_index} "
            f"episode aborted: {episode_error!r}",
            file=sys.stderr,
            flush=True,
        )
        if not rewards:
            rewards.append(LOG_SCORE_EPSILON)
        final_task_score = _strict_task_score(sum(rewards) / len(rewards))
        success = False
    finally:
        if env is not None:
            # The environment may or may not expose close(); call it if present.
            close_method = getattr(env, "close", None)
            if callable(close_method):
                try:
                    close_method()
                except Exception as close_error:
                    # Cleanup failures must not prevent the END line.
                    print(
                        f"[run_episode] task={task_id} env.close() failed: {close_error!r}",
                        file=sys.stderr,
                        flush=True,
                    )

        log_end(
            success=success,
            steps=steps_taken,
            rewards=rewards,
            task_score=final_task_score,
        )
359
+
360
+
361
def main() -> None:
    """Parse CLI options, validate configuration and run the selected episodes.

    The wall-clock deadline is fixed immediately after argument parsing and
    shared by every episode. When the credential check fails, the error text
    is printed and the process exits with status 1.
    """
    cli = parse_args()
    run_deadline = time.monotonic() + max(cli.runtime_budget_seconds, 1)
    per_request_timeout = max(float(cli.request_timeout_seconds), 1.0)

    try:
        effective_model = validate_runtime_config(cli.model)
    except ValueError as error:
        print(str(error), flush=True)
        raise SystemExit(1) from error

    # Read but otherwise unused in this script.
    _ = LOCAL_IMAGE_NAME

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # A numeric selector picks a single task; anything else ("all") runs every task.
    selected_task = TASK_MAP.get(cli.task)
    task_ids = list(TASK_MAP.values()) if selected_task is None else [selected_task]
    episode_count = max(cli.episodes_per_task, 1)

    # Only the production task receives runtime options.
    production_options = {
        "production_profile": cli.production_profile,
        "business_hours_mode": cli.business_hours_mode,
        "escalation_mode": cli.escalation_mode,
    }

    for task_id in task_ids:
        runtime_options = production_options if task_id == "task_production" else None
        for scenario_index in range(episode_count):
            run_episode(
                client=client,
                model_name=effective_model,
                task_id=task_id,
                scenario_index=scenario_index,
                eval_split=cli.split,
                deadline=run_deadline,
                request_timeout_seconds=per_request_timeout,
                runtime_options=runtime_options,
            )


if __name__ == "__main__":
    main()
models.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Data models for the OpenEnv email triage environment."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel, field_validator
6
+
7
# Margin kept between clamped values and the interval end points 0 and 1.
OPEN_INTERVAL_EPSILON = 1e-2


def _strict_open_unit_interval(raw_value: float) -> float:
    """Clamp numeric values to the strict open interval (0, 1).

    Values ``<= 0`` become ``OPEN_INTERVAL_EPSILON``, values ``>= 1`` become
    ``1 - OPEN_INTERVAL_EPSILON``, and values already inside (0, 1) are
    returned unchanged.

    NaN is mapped to ``OPEN_INTERVAL_EPSILON``. NaN fails both range
    comparisons, so without the explicit check it would be returned as-is and
    escape the interval guarantee.
    """
    import math  # function-scope import: only this NaN guard needs it

    numeric_value = float(raw_value)
    if math.isnan(numeric_value) or numeric_value <= 0.0:
        return OPEN_INTERVAL_EPSILON
    if numeric_value >= 1.0:
        return 1.0 - OPEN_INTERVAL_EPSILON
    return numeric_value
18
+
19
+
20
class EmailObservation(BaseModel):
    """Email context handed to the agent at each step.

    Field order is significant: it determines the key order of ``model_dump``.
    """

    # --- Email content ---
    email_id: str
    subject: str
    body: str
    sender: str
    timestamp: str  # kept as a plain string; format is not enforced here
    thread_history: list[str]

    # --- Episode bookkeeping ---
    task_id: str
    step_number: int
    total_emails: int
32
+
33
+
34
class TriageAction(BaseModel):
    """Agent decision for a single email.

    Field order is significant: it determines the key order of ``model_dump``.
    """

    # Only these four lower-case labels validate.
    label: Literal["urgent", "normal", "spam", "archive"]
    summary: str
    route_to: str
40
+
41
+
42
class RewardResult(BaseModel):
    """Deterministic grading output before reward shaping."""

    score: float
    breakdown: dict[str, float]
    feedback: str

    @field_validator("score")
    @classmethod
    def _validate_score(cls, value: float) -> float:
        """Pass the score through ``_strict_open_unit_interval``."""
        bounded_score = _strict_open_unit_interval(value)
        return bounded_score
53
+
54
+
55
class EnvironmentState(BaseModel):
    """Full internal environment state, for debugging and evaluation."""

    task_id: str
    current_step: int
    total_steps: int
    done: bool
    action_history: list[TriageAction]
    reward_history: list[float]

    @field_validator("reward_history")
    @classmethod
    def _validate_reward_history(cls, values: list[float]) -> list[float]:
        """Pass every recorded reward through ``_strict_open_unit_interval``."""
        return list(map(_strict_open_unit_interval, values))
69
+
70
+
71
class StepResult(BaseModel):
    """Standardized output of an environment ``step`` call."""

    observation: EmailObservation
    reward: float
    done: bool
    # Free-form metadata restricted to scalar values.
    info: dict[str, str | int | float | bool]

    @field_validator("reward")
    @classmethod
    def _validate_reward(cls, value: float) -> float:
        """Pass the reward through ``_strict_open_unit_interval``."""
        bounded_reward = _strict_open_unit_interval(value)
        return bounded_reward
83
+
84
+
85
class ResetResult(BaseModel):
    """Standardized output of an environment ``reset`` call."""

    observation: EmailObservation
    # Free-form metadata restricted to scalar values.
    info: dict[str, str | int | float | bool]