"""Async baseline inference runner for Supermail.""" from __future__ import annotations import asyncio import json import os from dataclasses import dataclass from typing import Any, List, Optional from openai import OpenAI try: from dotenv import load_dotenv except ImportError: # pragma: no cover def load_dotenv() -> bool: return False from client import SupermailEnv from models import SupportAction, SupportObservation from server.environment import SupermailEnvironment from sys_prompt import SYSTEM_PROMPT from tasks import ALL_TASKS, TASKS_BY_ID load_dotenv() API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") HF_TOKEN = os.getenv("HF_TOKEN") LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") BASE_URL = os.getenv("SUPERMAIL_BASE_URL") or os.getenv("SUPPORT_SIM_BASE_URL") TASK_NAME = os.getenv("SUPERMAIL_TASK") or os.getenv("SUPPORT_SIM_TASK", "all") BENCHMARK = os.getenv("SUPERMAIL_BENCHMARK") or os.getenv("SUPPORT_SIM_BENCHMARK", "supermail") MAX_STEPS = 12 TEMPERATURE = 0.4 MAX_TOKENS = 25000 SUCCESS_SCORE_THRESHOLD = 0.95 MIN_SCORE = 0.01 MAX_SCORE = 0.99 @dataclass class LocalStepResult: """Minimal local stand-in for OpenEnv StepResult.""" observation: SupportObservation reward: float done: bool class LocalSupermailSession: """Async adapter for direct local environment usage.""" def __init__(self, task_id: str): self._env = SupermailEnvironment(task_id=task_id) async def reset(self) -> LocalStepResult: observation = self._env.reset() return LocalStepResult( observation=observation, reward=observation.reward or 0.0, done=observation.done, ) async def step(self, action: SupportAction) -> LocalStepResult: observation = self._env.step(action) return LocalStepResult( observation=observation, reward=observation.reward or 0.0, done=observation.done, ) async def close(self) -> None: self._env.close() def sanitize(value: Any) -> str: """Keep log output on a single line.""" text = str(value) return " ".join(text.replace("\r", " ").replace("\n", " ").split()) def clamp_score(score: float) -> float: """Clamp score into the open interval (0, 1).""" return min(max(score, MIN_SCORE), MAX_SCORE) def compact_action(action: Optional[SupportAction]) -> str: """Serialize an action for the required log format.""" if action is None: return "null" payload = { field_name: getattr(action, field_name) for field_name in ("priority", "category", "action", "notes") if getattr(action, field_name, None) } return json.dumps(payload, separators=(",", ":"), sort_keys=True) def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step( *, step: int, action: Optional[SupportAction], reward: float, done: bool, error: Optional[str], ) -> None: error_text = error if error else "null" print( "[STEP] " f"step={step} " f"action={sanitize(compact_action(action))} " f"reward={reward:.2f} " f"done={'true' if done else 'false'} " f"error={sanitize(error_text)}", flush=True, ) def log_end(*, success: bool, steps: int, score: float, rewards: List[float]) -> None: reward_text = ",".join(f"{reward:.2f}" for reward in rewards) print( f"[END] success={'true' if success else 'false'} " f"steps={steps} score={score:.2f} rewards={reward_text}", flush=True, ) def build_client() -> Optional[OpenAI]: """Create an OpenAI client when credentials are available.""" if not HF_TOKEN: return None return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) def heuristic_action(observation: SupportObservation) -> SupportAction: """Deterministic fallback policy for the bundled tasks.""" text = f"{observation.email} {json.dumps(observation.context, sort_keys=True)}".lower() if any( token in text for token in ( "click here", "gift card", "crypto", "lottery", "unsubscribe", "bypass all metrics", "encrypted emergency", "decrypt tool", "emergency slot", "override the normal queue", "sender_verified\": \"false", "spoofed sender", ) ): priority = "spam" elif any( token in text for token in ( "today", "payroll closes", "500 error", "blocked", "backing up", "immediately", "double", "charged again", ) ): priority = "urgent" else: priority = "normal" if any(token in text for token in ("charge", "charged", "invoice", "refund", "billing", "subscription")): category = "billing" elif any(token in text for token in ("tracking", "shipment", "delivery", "delivered", "ship")): category = "delivery" elif any(token in text for token in ("error", "login", "outage", "crash", "bug", "sign in")): category = "technical" else: category = "general" if priority == "spam": next_action = "ignore" elif category == "technical": next_action = "assign_to_team" elif priority == "urgent": next_action = "respond_immediately" elif category == "delivery": next_action = "assign_to_team" else: next_action = "respond_immediately" payload: dict[str, str] = {} if "priority" in observation.required_fields: payload["priority"] = priority if "category" in observation.required_fields: payload["category"] = category if "action" in observation.required_fields: payload["action"] = next_action return SupportAction(**payload) def get_model_action( client: OpenAI, observation: SupportObservation, history: List[str], ) -> SupportAction: """Use the OpenAI client for the next action.""" prompt = { "task_id": observation.task_id, "benchmark": observation.benchmark, "objective": observation.objective, "required_fields": observation.required_fields, "allowed_values": observation.allowed_values, "email": observation.email, "context": observation.context, "history": history, "feedback": observation.feedback, } response = client.chat.completions.create( model=MODEL_NAME, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": json.dumps(prompt, ensure_ascii=True)}, ], ) content = (response.choices[0].message.content or "").strip() payload = json.loads(content) filtered_payload = { key: value for key, value in payload.items() if key in {"priority", "category", "action", "notes"} } return SupportAction(**filtered_payload) def choose_action( client: Optional[OpenAI], observation: SupportObservation, history: List[str], ) -> SupportAction: """Use the model when available, otherwise fall back to heuristics.""" if client is None: return heuristic_action(observation) try: return get_model_action(client, observation, history) except Exception: return heuristic_action(observation) async def create_env(task_id: str): """Create the environment session using docker, base URL, or local fallback.""" if LOCAL_IMAGE_NAME: return await SupermailEnv.from_docker_image( LOCAL_IMAGE_NAME, env_vars={"SUPERMAIL_TASK": task_id}, ) if BASE_URL: env = SupermailEnv(base_url=BASE_URL) await env.connect() return env return LocalSupermailSession(task_id=task_id) async def run_episode(task_id: str, client: Optional[OpenAI]) -> None: """Run a single task episode and emit the required logs.""" if task_id not in TASKS_BY_ID: raise ValueError(f"Unknown task: {task_id}") env = None history: List[str] = [] rewards: List[float] = [] steps_taken = 0 score = MIN_SCORE success = False action_for_log: Optional[SupportAction] = None log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) try: env = await create_env(task_id) result = await env.reset() observation = result.observation for step in range(1, MAX_STEPS + 1): if result.done: break action_for_log = choose_action(client, observation, history) result = await env.step(action_for_log) observation = result.observation reward = result.reward or 0.0 done = result.done error = observation.metadata.get("last_action_error") rewards.append(reward) steps_taken = step score = clamp_score(float(getattr(observation, "score", 0.0))) log_step( step=step, action=action_for_log, reward=reward, done=done, error=error, ) history.append( f"step={step} action={compact_action(action_for_log)} " f"reward={reward:.2f} score={score:.2f}" ) if done: break success = score >= SUCCESS_SCORE_THRESHOLD except Exception as exc: log_step( step=steps_taken, action=action_for_log, reward=0.0, done=True, error=str(exc), ) finally: if env is not None: try: await env.close() except Exception: pass log_end(success=success, steps=steps_taken, score=score, rewards=rewards) def task_sequence() -> List[str]: """Resolve the requested task selection.""" if TASK_NAME == "all": return [task.task_id for task in ALL_TASKS] return [TASK_NAME] async def main() -> None: """Run one or more task episodes.""" client = build_client() for task_id in task_sequence(): await run_episode(task_id, client) if __name__ == "__main__": asyncio.run(main())