"""Tool-driven inference for the metric tracker RL environment."""

from __future__ import annotations

import asyncio
import json
import os
import sys
import textwrap
from dataclasses import dataclass, field
from typing import Any

from openenv.core.containers.runtime.providers import LocalDockerProvider
from openai import APIStatusError, OpenAI
from websockets.exceptions import ConnectionClosedError

from metric_tracker_rl import DEFAULT_TASK_ORDER, MetricTrackerRlAction, MetricTrackerRlEnv, get_task_spec
from metric_tracker_rl.analysis_tools import available_analysis_methods
from metric_tracker_rl.models import (
    MetricSubmissionRow,
    MetricTrackerRlObservation,
    PayloadGeneratorMethod,
)


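# Runtime configuration, resolved from environment variables with fallbacks.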
IMAGE_NAME = (os.getenv("IMAGE_NAME") or "metric_tracker_rl:latest").strip()
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
API_BASE_URL = (
    os.getenv("API_BASE_URL")
    or os.getenv("OPENAI_BASE_URL")
    or "https://router.huggingface.co/v1"
)
MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL") or "Qwen/Qwen2.5-72B-Instruct"
BASE_URL = os.getenv("BASE_URL")
TASK_NAME = os.getenv("MetricTrackerRl_TASK", "multi_task_agent_baseline")
BENCHMARK = os.getenv("MetricTrackerRl_BENCHMARK", "metric_tracker_rl")
TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
MAX_TOKENS = min(int(os.getenv("MAX_TOKENS", "1000")), 4096)
MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "16"))
CONNECT_TIMEOUT_S = float(os.getenv("OPENENV_CONNECT_TIMEOUT_S", "30"))
MESSAGE_TIMEOUT_S = float(os.getenv("OPENENV_MESSAGE_TIMEOUT_S", "180"))
DOCKER_WAIT_TIMEOUT_S = float(os.getenv("OPENENV_DOCKER_WAIT_TIMEOUT_S", "120"))
TASK_RETRY_COUNT = int(os.getenv("OPENENV_TASK_RETRY_COUNT", "1"))
SCORE_EPSILON = float(os.getenv("OPENENV_SCORE_EPSILON", "0.000001"))

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are solving a multi-anomaly analytics benchmark with tool use.

    Rules:
    - Use only the shared safe analysis methods.
    - Do not request full hidden answers or assume direct access to ground truth.
    - Prefer declarative payload generators over manual row construction.
    - Start from the default reset observation only.
    - Start by trying `get_rate_spike_from_median_rows` across different conversion metrics to learn which metrics produce useful anomaly rows.
    - Compare candidate metrics, then refine with raw-data inspection and median/std methods only when needed.
    - Prefer: task_overview -> get_rate_spike_from_median_rows on several conversion metrics -> compare useful results -> payload_generator -> submit_payload_generator.
    - Keep notes brief and factual.
    """
).strip()


@dataclass
class ToolRuntimeState:
    """Mutable state shared across tool calls."""

    method_log: list[dict[str, Any]] = field(default_factory=list)
    last_preview: dict[str, Any] | None = None
    rewards: list[float] = field(default_factory=list)
    steps: int = 0


def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_method(tool_name: str, arguments: dict[str, Any], note: str) -> None:
    """No-op hook for verbose per-tool-call logging; swap in a real logger to trace calls."""
    return None


def log_payload_generator_methods(tool_name: str, generator_methods: list[dict[str, Any]]) -> None:
    """No-op hook for logging generator methods passed to payload tools."""
    return None


def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
    error_val = error if error else "null"
    clipped_reward = bounded_log_reward(reward)
    print(
        f"[STEP] step={step} action={action} reward={clipped_reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )


def bounded_task_score(score: float) -> float:
    """Clamp task scores to the open interval (0, 1)."""
    return min(1.0 - SCORE_EPSILON, max(SCORE_EPSILON, score))


def bounded_log_reward(reward: float) -> float:
    """Clamp logged scores and rewards to [0.01, 0.99]."""
    return min(0.99, max(0.01, reward))


def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    clipped_score = bounded_log_reward(score)
    rewards_str = ",".join(f"{bounded_log_reward(reward):.2f}" for reward in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} score={clipped_score:.2f} rewards={rewards_str}",
        flush=True,
    )


def tool_schemas() -> list[dict[str, Any]]:
    """OpenAI-compatible tool definitions."""
    shared_schemas = []
    for spec in available_analysis_methods():
        properties = {}
        required = []
        if spec.name in {"rows_for_date", "hourly_rows_for_date", "detect_funnel_break", "check_impossible_counts"}:
            properties = {"date": {"type": "string"}}
            required = ["date"]
        elif spec.name in {"compare_rate_to_median", "compare_count_to_median"}:
            properties = {
                "date": {"type": "string"},
                "entity_name": {"type": "string"},
            }
            required = ["date", "entity_name"]
        elif spec.name == "list_suspicious_dates":
            properties = {"limit": {"type": "integer", "default": 10}}
        elif spec.name == "preview_submission":
            properties = {
                "rows": {
                    "type": "array",
                    "items": {"type": "object"},
                }
            }
        elif spec.name == "show_raw_data":
            properties = {"limit": {"type": "integer", "default": 5}}
        elif spec.name in {"get_metric_median", "get_metric_std_dev_from_median"}:
            properties = {
                "metric_name": {"type": "string"},
                "metric_names": {"type": "array", "items": {"type": "string"}},
            }
        elif spec.name == "get_rows_with_abs_diff_from_median_gt":
            properties = {
                "metric_name": {"type": "string"},
                "metric_names": {"type": "array", "items": {"type": "string"}},
                "threshold": {"type": "number"},
            }
            required = ["threshold"]
        elif spec.name in {
            "get_median_filter_rows",
            "get_rate_drop_from_median_rows",
            "get_rate_spike_from_median_rows",
            "get_absolute_drop_in_event_count_rows",
            "get_absolute_spike_in_event_count_rows",
        }:
            properties = {
                "metric_name": {"type": "string"},
                "metric_names": {"type": "array", "items": {"type": "string"}},
                "threshold_multiplier": {"type": "number"},
            }
            required = ["threshold_multiplier"]
        elif spec.name in {
            "get_funnel_break_rows",
            "get_hourly_traffic_mix_shift_rows",
            "get_instrumentation_data_quality_issue_rows",
        }:
            properties = {
                "threshold_multiplier": {"type": "number"},
            }
            required = ["threshold_multiplier"]
        elif spec.name == "payload_generator":
            properties = {
                "generator_methods": {
                    "type": "array",
                    "items": {"type": "object"},
                }
            }
            required = ["generator_methods"]
        shared_schemas.append(
            {
                "type": "function",
                "function": {
                    "name": spec.name,
                    "description": spec.description,
                    "parameters": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                        "additionalProperties": False,
                    },
                },
            }
        )
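    # The two submission tools are graded actions rather than shared analysis
    # methods, so their schemas are appended explicitly.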
    shared_schemas.append(
        {
            "type": "function",
            "function": {
                "name": "submit_payload_generator",
                "description": "Submit declarative payload generator methods for environment-side payload generation and grading.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "generator_methods": {
                            "type": "array",
                            "items": {"type": "object"},
                        }
                    },
                    "required": ["generator_methods"],
                    "additionalProperties": False,
                },
            },
        }
    )
    shared_schemas.append(
        {
            "type": "function",
            "function": {
                "name": "submit_solution",
                "description": "Submit the final anomaly payload to the environment.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "rows": {
                            "type": "array",
                            "items": {"type": "object"},
                        }
                    },
                    "required": ["rows"],
                    "additionalProperties": False,
                },
            },
        }
    )
    return shared_schemas


def build_initial_user_prompt(observation: MetricTrackerRlObservation) -> str:
    return textwrap.dedent(
        f"""
        Solve the RL environment with tools.

        Initial observation:
        {json.dumps(observation.model_dump(exclude={"debug"}), indent=2)}

        Prefer building a payload generator first, then submit it.
        Start by calling `get_rate_spike_from_median_rows` on several conversion metrics and see which ones return useful anomaly rows.
        If a metric returns nothing or low-signal rows, try another metric.
        For funnel, hourly mix, or data-quality tasks, use the family-specific generator methods instead.

        Final payload rows use:
        `date`, `entity_type`, `entity_name`, `anomaly_type`, `detection_method`,
        `baseline_value`, `observed_value`, `delta_value`, `severity`.

        Supported generator method example:
        `{{"method_name":"get_rate_spike_from_median_rows","threshold_multiplier":2.0}}`
        or
        `{{"method_name":"get_rate_spike_from_median_rows","metric_names":["app_open_to_order_placed","app_open_to_payment_successful"],"threshold_multiplier":2.0}}`

        Use shared analysis methods only. Prefer `submit_payload_generator` over `submit_solution`.
        """
    ).strip()


def create_chat_completion(client: OpenAI, **kwargs):
    """Call chat completions, translating an HTTP 402 into an actionable error."""
    try:
        return client.chat.completions.create(**kwargs)
    except APIStatusError as exc:
        if exc.status_code == 402:
            raise RuntimeError(
                "The configured inference provider rejected the request with HTTP 402. "
                "Your Hugging Face router credits are depleted. Update `.env.inference` "
                "with a working provider/key, or switch `API_BASE_URL`/`MODEL_NAME`."
            ) from exc
        raise


def decode_arguments(raw_arguments: str | None) -> dict[str, Any]:
    if not raw_arguments:
        return {}
    try:
        return json.loads(raw_arguments)
    except json.JSONDecodeError:
        # Malformed tool-call arguments from the model: treat as no arguments.
        return {}


def preview_text(text: str, limit: int = 220) -> str:
    return text.replace("\n", " ")[:limit]


def format_action(tool_name: str, arguments: dict[str, Any]) -> str:
    if not arguments:
        return f"{tool_name}()"
    return preview_text(f"{tool_name}({json.dumps(arguments, sort_keys=True)})")


def step_error(result: Any) -> str | None:
    message = getattr(result.observation, "message", None)
    return message if result.observation.status == "error" and message else None


def record_step(
    runtime_state: ToolRuntimeState,
    action: str,
    result: Any,
) -> None:
    reward = float(result.reward or 0.0)
    runtime_state.steps += 1
    runtime_state.rewards.append(reward)
    log_step(
        step=runtime_state.steps,
        action=action,
        reward=reward,
        done=bool(result.done),
        error=step_error(result),
    )


async def connect_env() -> MetricTrackerRlEnv:
    """Connect to a remote env when BASE_URL is set; otherwise start a local Docker container."""
    if BASE_URL:
        client = MetricTrackerRlEnv(
            base_url=BASE_URL,
            connect_timeout_s=CONNECT_TIMEOUT_S,
            message_timeout_s=MESSAGE_TIMEOUT_S,
        )
        return await client.connect()
    provider = LocalDockerProvider()
    base_url = provider.start_container(IMAGE_NAME)
    provider.wait_for_ready(base_url, timeout_s=DOCKER_WAIT_TIMEOUT_S)
    client = MetricTrackerRlEnv(
        base_url=base_url,
        connect_timeout_s=CONNECT_TIMEOUT_S,
        message_timeout_s=MESSAGE_TIMEOUT_S,
        provider=provider,
    )
    return await client.connect()


async def execute_tool_call(
    env: MetricTrackerRlEnv,
    observation: MetricTrackerRlObservation,
    runtime_state: ToolRuntimeState,
    tool_name: str,
    arguments: dict[str, Any],
) -> tuple[dict[str, Any], Any | None, MetricTrackerRlObservation]:
    """Execute one model-requested tool locally."""
    action = format_action(tool_name, arguments)
    if tool_name == "submit_payload_generator":
        methods = [
            PayloadGeneratorMethod(**item)
            for item in arguments.get("generator_methods", [])
        ]
        runtime_state.method_log.append(
            {
                "tool_name": tool_name,
                "arguments": arguments,
                "generator_methods": [item.model_dump() for item in methods],
                "note": _tool_note(tool_name, arguments),
            }
        )
        result = await env.step(MetricTrackerRlAction(payload_generators=methods))
        record_step(runtime_state, action, result)
        return (
            {
                "status": result.observation.status,
                "message": result.observation.message,
                "reward": result.reward,
                "done": result.done,
                "generated_rows": [row.model_dump() for row in result.observation.generated_rows],
                "submission_issues": [issue.model_dump() for issue in result.observation.submission_issues],
                "reward_breakdown": (
                    result.observation.reward_breakdown.model_dump()
                    if result.observation.reward_breakdown
                    else None
                ),
            },
            result,
            result.observation,
        )
    if tool_name == "submit_solution":
        rows = [MetricSubmissionRow(**row) for row in arguments.get("rows", [])]
        result = await env.step(MetricTrackerRlAction(classifications=rows))
        record_step(runtime_state, action, result)
        return (
            {
                "status": result.observation.status,
                "message": result.observation.message,
                "reward": result.reward,
                "done": result.done,
                "reward_breakdown": (
                    result.observation.reward_breakdown.model_dump()
                    if result.observation.reward_breakdown
                    else None
                ),
                "issue_count": len(result.observation.submission_issues),
                "correct_row_count": result.observation.correct_row_count,
            },
            result,
            result.observation,
        )

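    # Every other tool name maps directly onto a shared analysis method.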
    result = await env.step(
        MetricTrackerRlAction(
            analysis_method=tool_name,
            analysis_args=arguments,
        )
    )
    record_step(runtime_state, action, result)
    output = result.observation.analysis_result or {
        "method": tool_name,
        "arguments": arguments,
        "result": None,
    }
    log_arguments = {
        "tool_name": tool_name,
        "arguments": arguments,
        "note": _tool_note(tool_name, arguments),
    }
    if tool_name == "payload_generator":
        log_arguments["generator_methods"] = arguments.get("generator_methods", [])
    runtime_state.method_log.append(log_arguments)
    if tool_name == "preview_submission":
        runtime_state.last_preview = output
    return output, None, result.observation


def _tool_note(tool_name: str, arguments: dict[str, Any]) -> str:
    notes = {
        "task_overview": "bootstrap the task and payload schema",
        "list_dates": "confirm the date range",
        "list_entities": "confirm valid entities",
        "rows_for_date": "inspect daily counts on one date",
        "hourly_rows_for_date": "inspect hourly traffic shape",
        "compare_rate_to_median": "check a conversion-rate anomaly against median baseline",
        "compare_count_to_median": "check an absolute count anomaly against median baseline",
        "detect_funnel_break": "test whether a funnel step is broken",
        "check_impossible_counts": "test for instrumentation or impossible count issues",
        "list_suspicious_dates": "prioritize dates worth deeper inspection",
        "preview_submission": "validate payload structure before submit",
        "show_raw_data": "inspect daily aggregate rows in head() form",
        "get_metric_median": "measure a baseline median for one metric",
        "get_metric_std_dev_from_median": "measure metric spread around the median",
        "get_rows_with_abs_diff_from_median_gt": "inspect dates outside a chosen absolute-difference threshold",
        "get_median_filter_rows": "generate candidate anomaly rows using median and std-from-median filtering",
        "get_rate_drop_from_median_rows": "generate candidate conversion-rate drop rows using median and std-from-median filtering",
        "get_rate_spike_from_median_rows": "generate candidate conversion-rate spike rows using median and std-from-median filtering",
        "get_absolute_drop_in_event_count_rows": "generate candidate event-count drop rows using median and std-from-median filtering",
        "get_absolute_spike_in_event_count_rows": "generate candidate event-count spike rows using median and std-from-median filtering",
        "get_funnel_break_rows": "generate candidate funnel-break rows across funnel steps",
        "get_hourly_traffic_mix_shift_rows": "generate candidate hourly traffic mix shift rows across dates",
        "get_instrumentation_data_quality_issue_rows": "generate candidate impossible-count or instrumentation-issue rows across dates",
        "payload_generator": "merge multiple generator methods into one candidate payload",
        "submit_payload_generator": "submit generator methods for environment-side generation and grading",
    }
    return notes.get(tool_name, f"run {tool_name} with {arguments}")


async def run_agent_loop(
    client: OpenAI,
    env: MetricTrackerRlEnv,
    observation: MetricTrackerRlObservation,
) -> tuple[Any, str, int, list[dict[str, Any]], ToolRuntimeState]:
    """Run a tool-calling loop until the env is solved or the round limit is hit."""
    runtime_state = ToolRuntimeState()
    current_observation = observation
    messages: list[dict[str, Any]] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_initial_user_prompt(current_observation)},
    ]
    last_result = None
    final_text = ""
    tool_rounds = 0

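    # Each round: one model turn, then execute every tool call it requests.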
    for _ in range(MAX_TOOL_ROUNDS):
        completion = create_chat_completion(
            client,
            model=MODEL_NAME,
            messages=messages,
            tools=tool_schemas(),
            tool_choice="auto",
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        message = completion.choices[0].message
        assistant_payload: dict[str, Any] = {
            "role": "assistant",
            "content": message.content or "",
        }
        if message.tool_calls:
            assistant_payload["tool_calls"] = [
                {
                    "id": tool_call.id,
                    "type": tool_call.type,
                    "function": {
                        "name": tool_call.function.name,
                        "arguments": tool_call.function.arguments,
                    },
                }
                for tool_call in message.tool_calls
            ]
        messages.append(assistant_payload)

        if not message.tool_calls:
            final_text = (message.content or "").strip()
            break

        tool_rounds += 1
        for tool_call in message.tool_calls:
            tool_name = tool_call.function.name
            arguments = decode_arguments(tool_call.function.arguments)
            if tool_name != "submit_solution":
                log_method(tool_name, arguments, _tool_note(tool_name, arguments))
            if tool_name in {"payload_generator", "submit_payload_generator"}:
                log_payload_generator_methods(
                    tool_name,
                    arguments.get("generator_methods", []),
                )
            tool_output, maybe_result, current_observation = await execute_tool_call(
                env,
                current_observation,
                runtime_state,
                tool_name,
                arguments,
            )
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": json.dumps(tool_output),
                }
            )
            if maybe_result is not None:
                last_result = maybe_result

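        # After a graded submission, ask the model for a closing summary
        # (no tools offered) and end the loop.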
        if last_result is not None:
            completion = create_chat_completion(
                client,
                model=MODEL_NAME,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
                stream=False,
            )
            final_text = (completion.choices[0].message.content or "").strip()
            break

    return last_result, final_text, tool_rounds, runtime_state.method_log, runtime_state


async def run_single_task(
    client: OpenAI,
    env: MetricTrackerRlEnv,
    task_id: str,
) -> dict[str, Any]:
    """Run one named benchmark task and return a reproducible summary."""
    task_spec = get_task_spec(task_id)
    log_start(task=task_spec.task_id, env=BENCHMARK, model=MODEL_NAME)
    reset_result = await env.reset(task_id=task_spec.task_id)
    final_result, final_text, tool_rounds, method_log, runtime_state = await run_agent_loop(
        client,
        env,
        reset_result.observation,
    )
    if final_result is None:
        raise RuntimeError(f"The model never submitted a graded action for task `{task_spec.task_id}`.")

    reward = float(final_result.reward or 0.0)
    task_score = bounded_task_score(reward)
    success = bool(final_result.done and reward >= 0.999999)
    return {
        "task_id": task_spec.task_id,
        "difficulty": task_spec.difficulty,
        "objective": task_spec.objective,
        "grader_name": task_spec.grader_name,
        "score": task_score,
        "normalized_score": task_score,
        "done": final_result.done,
        "success": success,
        "final_status": final_result.observation.status,
        "final_message": final_result.observation.message,
        "issue_count": len(final_result.observation.submission_issues),
        "correct_row_count": final_result.observation.correct_row_count,
        "expected_row_count": final_result.observation.expected_row_count,
        "tool_rounds": tool_rounds,
        "assistant_summary": final_text,
        "steps": runtime_state.steps,
        "rewards": runtime_state.rewards,
        "reward_breakdown": (
            final_result.observation.reward_breakdown.model_dump()
            if final_result.observation.reward_breakdown
            else None
        ),
    }


async def run_single_task_with_retries(
    client: OpenAI,
    task_id: str,
) -> dict[str, Any]:
    """Run one task with a fresh env connection and bounded reconnect retries."""
    attempts = TASK_RETRY_COUNT + 1
    last_error: Exception | None = None

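    # Each attempt uses a fresh env connection; only transport-level failures
    # (dropped websocket, timeout, OS error) trigger a retry.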
    for attempt in range(1, attempts + 1):
        env = None
        success = False
        steps = 0
        score = 0.0
        rewards: list[float] = []
        try:
            env = await connect_env()
            summary = await run_single_task(client, env, task_id)
            success = bool(summary["success"])
            steps = int(summary["steps"])
            score = float(summary["score"])
            rewards = list(summary["rewards"])
            return summary
        except (ConnectionClosedError, ConnectionError, TimeoutError, OSError) as exc:
            last_error = exc
            print(
                (
                    f"[WARN] task_id={task_id} attempt={attempt}/{attempts} "
                    f"env_connection_error={type(exc).__name__}: {exc}"
                ),
                flush=True,
                file=sys.stderr,
            )
            if attempt >= attempts:
                raise
        finally:
            try:
                if env is not None:
                    await env.close()
            except Exception:
                pass
            if env is not None:
                log_end(success=success, steps=steps, score=score, rewards=rewards)

    assert last_error is not None
    raise last_error


async def main() -> None:
    if not API_KEY:
        raise RuntimeError("Set OPENAI_API_KEY, HF_TOKEN, or API_KEY.")

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    task_summaries: list[dict[str, Any]] = []

    for task_id in DEFAULT_TASK_ORDER:
        task_summaries.append(await run_single_task_with_retries(client, task_id))
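
    # Emit a machine-readable run summary once all tasks finish. Assumes the
    # summaries are JSON-serializable; default=str guards non-primitive values.
    print(json.dumps({"benchmark": BENCHMARK, "tasks": task_summaries}, indent=2, default=str), flush=True)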


if __name__ == "__main__":
    asyncio.run(main())