VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_dataset_hard_negative_presence.py
| from sim_reveal.dataset import collect_teacher_dataset | |
| from sim_reveal.procedural_envs import available_proxy_names | |
def test_dataset_hard_negative_presence():
    """Check that teacher-dataset collection produces hard-negative candidates.

    Collects one episode for every name returned by ``available_proxy_names()``
    and asserts that:

    * at least one candidate across all samples is flagged in
      ``candidate_is_hard_negative``, and
    * the labels in ``candidate_negative_families`` (ignoring ``"teacher"``
      and ``"positive"``) include both ``"premature_retrieve"`` and
      ``"reveal_with_release"``.
    """
    dataset_bundle = collect_teacher_dataset(
        proxy_names=available_proxy_names(),
        episodes_per_proxy=1,
        resolution=32,
        seed=3,
        chunk_horizon=4,
        rollout_horizon=4,
        planner_candidates=6,
    )
    samples = dataset_bundle["samples"]
    # Fail with an explicit message instead of a misleading count failure below.
    assert samples, "collect_teacher_dataset returned no samples"

    # Family labels excluded from the negative-family set (presumably the
    # positive / teacher candidates rather than negatives).
    excluded_labels = {"teacher", "positive"}
    required_families = {"premature_retrieve", "reveal_with_release"}

    negative_families = set()
    hard_negative_count = 0
    for sample in samples:
        # int() normalizes the sum in case the flags are array/tensor elements.
        hard_negative_count += int(sum(sample["candidate_is_hard_negative"]))
        negative_families.update(
            family
            for family in sample["candidate_negative_families"]
            if family not in excluded_labels
        )

    assert hard_negative_count > 0, "no candidate was flagged as a hard negative"
    missing = required_families - negative_families
    assert not missing, (
        f"missing negative families {sorted(missing)}; "
        f"observed {sorted(negative_families, key=str)}"
    )