# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_anybimanual_overlap_eval_summary.py
# Uploaded by lsnu ("Add files using upload-large-folder tool"), commit b14c4b7.
from __future__ import annotations
from eval.summarize_anybimanual_overlap_eval import (
_best_overlap_step,
_delta,
_last_overlap_step,
_merge_rows_by_step,
)
def _row(step: int, **values: str) -> dict[str, str]:
base = {"step": str(step)}
base.update(values)
return base
def test_merge_rows_by_step_fills_missing_values() -> None:
    """Partial rows that share a step collapse into one row with every column."""
    expected = {
        "eval_envs/return/coordinated_push_box": "10",
        "eval_envs/return/coordinated_lift_ball": "4",
        "eval_envs/return/dual_push_buttons": "20",
    }
    # One partial row per metric column, all logged at the same step.
    partial_rows = [
        _row(600, **{column: value}) for column, value in expected.items()
    ]
    merged = _merge_rows_by_step(partial_rows)
    assert len(merged) == 1
    merged_row = merged[0]
    for column, value in expected.items():
        assert merged_row[column] == value
def test_overlap_summary_picks_last_local_and_best_public() -> None:
    """Local runs are summarized at their last step, public runs at their best.

    Each fixture contains a step that the *other* selection policy would pick:
    an earlier local step that beats the final one on every task, and a later
    public step that is worse than the best one on every task. A selector that
    just returns the final row, or swapped "last"/"best" selectors, therefore
    fails this test.
    """

    def _task_returns(
        push_box: str, lift_ball: str, push_buttons: str
    ) -> dict[str, str]:
        # Return columns for the three overlap tasks, in a fixed order.
        return {
            "eval_envs/return/coordinated_push_box": push_box,
            "eval_envs/return/coordinated_lift_ball": lift_ball,
            "eval_envs/return/dual_push_buttons": push_buttons,
        }

    local_rows = _merge_rows_by_step(
        [
            _row(200, **_task_returns("0", "0", "0")),
            # Strictly better than the final local step on every task; a
            # "last" selector must still ignore it.
            _row(600, **_task_returns("18", "10", "22")),
            _row(1000, **_task_returns("15", "8", "20")),
        ]
    )
    public_rows = _merge_rows_by_step(
        [
            _row(50000, **_task_returns("18", "6", "20")),
            _row(60000, **_task_returns("20", "8", "24")),
            # Later than the best public step but strictly worse on every
            # task; a "best" selector must not pick it.
            _row(70000, **_task_returns("16", "4", "20")),
        ]
    )
    # NOTE(review): 25 is passed through to both selectors unchanged;
    # presumably the per-task episode count — confirm in the summarizer.
    local_last = _last_overlap_step(local_rows, 25)
    public_best = _best_overlap_step(public_rows, 25)
    assert local_last["step"] == 1000
    assert public_best["step"] == 60000
    assert public_best["mean_success"] > local_last["mean_success"]
    delta = _delta(local_last, public_best)
    # Local trails public here, so the reported delta must be negative.
    assert delta["mean_success_delta"] < 0.0