# Path: VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_dataset_hard_negative_presence.py
# Repository page metadata: lsnu -- "Add files using upload-large-folder tool" -- b14c4b7 (verified)
from sim_reveal.dataset import collect_teacher_dataset
from sim_reveal.procedural_envs import available_proxy_names
def test_dataset_hard_negative_presence():
    """Check that a small teacher dataset contains hard-negative candidates.

    One episode is collected for every name returned by
    ``available_proxy_names()``. The test then requires that:

    * at least one candidate across all samples is flagged in
      ``candidate_is_hard_negative``; and
    * the ``candidate_negative_families`` labels, once "teacher" and
      "positive" are excluded, include both "premature_retrieve" and
      "reveal_with_release".
    """
    collection_kwargs = dict(
        proxy_names=available_proxy_names(),
        episodes_per_proxy=1,
        resolution=32,
        seed=3,
        chunk_horizon=4,
        rollout_horizon=4,
        planner_candidates=6,
    )
    bundle = collect_teacher_dataset(**collection_kwargs)

    # Labels that do not count as a negative family for this check.
    non_negative_labels = {"teacher", "positive"}
    required_families = {"premature_retrieve", "reveal_with_release"}

    total_hard_negatives = 0
    observed_families = set()
    for entry in bundle["samples"]:
        # Flags are summed per sample and converted to int before being
        # added to the running total.
        flags = entry["candidate_is_hard_negative"]
        total_hard_negatives += int(sum(flags))
        observed_families |= set(entry["candidate_negative_families"])

    # Drop the excluded labels once, after every sample has been scanned.
    observed_families -= non_negative_labels

    assert total_hard_negatives > 0
    assert required_families <= observed_families