File size: 1,245 Bytes
c357a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e31d8c
c357a18
 
 
0e31d8c
 
c357a18
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Any, Dict, List, Optional


def log_start(task_id: str, benchmark: str, model_name: str) -> None:
    """Print the single-line ``[START]`` record for a run.

    Args:
        task_id: Identifier written after ``task=``.
        benchmark: Environment/benchmark name written after ``env=``.
        model_name: Model identifier written after ``model=``.

    The line is flushed immediately so it appears in order with other output.
    """
    header = " ".join(
        ("[START]", f"task={task_id}", f"env={benchmark}", f"model={model_name}")
    )
    print(header, flush=True)


def log_step(
    step: int, action: Dict[str, Any], reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one ``[STEP]`` record describing a single step.

    Args:
        step: Step number written after ``step=``.
        action: Action mapping, rendered via ``action_to_str``.
        reward: Reward value, printed with two decimals.
        done: Printed as the lowercase literal ``true`` / ``false``.
        error: Error text; any falsy value (``None`` or ``""``) prints as ``null``.

    The line is flushed immediately so it appears in order with other output.
    """
    fields = (
        f"step={step}",
        f"action={action_to_str(action)}",
        f"reward={reward:.2f}",
        f"done={'true' if done else 'false'}",
        f"error={error or 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` summary record.

    Args:
        success: Printed as the lowercase literal ``true`` / ``false``.
        steps: Step count written after ``steps=``.
        score: Final score, printed with four decimals.
        rewards: Per-step rewards, each printed with two decimals and
            comma-separated (an empty list yields an empty ``rewards=`` field).

    The line is flushed immediately so it appears in order with other output.
    """
    formatted_rewards = ",".join(map("{:.2f}".format, rewards))
    outcome = "true" if success else "false"
    summary = (
        f"[END] success={outcome} steps={steps} "
        f"score={score:.4f} rewards={formatted_rewards}"
    )
    print(summary, flush=True)


def action_to_str(action: Dict[str, Any]) -> str:
    """Render an action mapping as a compact colon-separated string.

    The result starts with ``action["action_type"]``. When that key is missing
    or ``None``, ``"skip"`` is used instead. It is followed by ``email_id``,
    ``category`` and ``urgency``, in that order. Each of those is included only
    when present and truthy.

    Every part is converted with ``str()``. Without this, a non-string value
    such as an integer ``email_id`` would make ``str.join`` raise ``TypeError``.

    Args:
        action: Mapping describing the action.

    Returns:
        The joined string, e.g. ``{"action_type": "label", "email_id": "e1"}``
        becomes ``"label:e1"``.
    """
    action_type = action.get("action_type")
    # An explicit None would otherwise end up in the join and raise.
    if action_type is None:
        action_type = "skip"
    parts = [str(action_type)]
    for key in ("email_id", "category", "urgency"):
        value = action.get(key)
        # Falsy values ("", None, 0, ...) are omitted, as in the original checks.
        if value:
            parts.append(str(value))
    return ":".join(parts)