# Source: VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_adapter_dataset_alignment.py
# Uploaded by lsnu ("Add files using upload-large-folder tool", commit 31ade1f, verified).
from pathlib import Path
from omegaconf import OmegaConf
import train.run_experiment as run_experiment
def test_adapter_dataset_rebuilds_when_existing_bundle_lacks_proposal_targets(monkeypatch, tmp_path):
    """Rebuild a cached train bundle whose samples carry no proposal targets.

    The stubbed loader returns a single ``"bag"`` sample with no
    ``proposal_target_*`` fields. With ``require_proposal_targets=True`` and a
    builder supplied, the test expects two things:

    - ``_build_dataset_from_config`` returns a bundle that satisfies
      ``_bundle_has_proposal_targets``, i.e. the one produced by the stubbed
      ``collect_teacher_dataset``.
    - The builder object is forwarded to that collector unchanged.
    """
    # Presumably the file must exist for the cached-dataset path to be taken;
    # its contents are irrelevant because loading is stubbed below.
    cached_path = tmp_path / "proxy_train.pt"
    cached_path.write_text("stub", encoding="utf-8")

    sentinel_builder = object()
    seen_kwargs: dict[str, object] = {}

    def _load_stale_bundle(path):
        # Cached bundle lacking every proposal_target_* field.
        return {"samples": [{"task_name": "bag"}]}

    def _collect_aligned_bundle(**kwargs):
        # Record how the collector was invoked, then hand back an aligned bundle.
        seen_kwargs.update(kwargs)
        aligned_sample = {
            "task_name": "bag",
            "proposal_target_action_chunks": 1,
            "proposal_target_retrieval_success": 1,
            "proposal_target_risk": 1,
            "proposal_target_utility": 1,
        }
        return {"samples": [aligned_sample]}

    def _skip_save(path, bundle):
        # Nothing is written to disk; the path is simply echoed back.
        return Path(path)

    patches = (
        ("load_teacher_dataset", _load_stale_bundle),
        ("collect_teacher_dataset", _collect_aligned_bundle),
        ("save_teacher_dataset", _skip_save),
    )
    for attr_name, replacement in patches:
        monkeypatch.setattr(run_experiment, attr_name, replacement)

    cfg_fields = {
        "proxies": ["bag_proxy"],
        "resolution": 16,
        "seed": 17,
        "chunk_horizon": 2,
        "rollout_horizon": 2,
        "history_steps": 1,
        "planner_candidates": 2,
        "dataset_version": "proxy_test",
        "train_episodes_per_proxy": 1,
        "train_dataset_path": str(cached_path),
        "rebuild_dataset": False,
    }
    data_cfg = OmegaConf.create(cfg_fields)

    rebuilt = run_experiment._build_dataset_from_config(
        data_cfg,
        "train",
        proposal_target_builder=sentinel_builder,
        require_proposal_targets=True,
    )

    assert run_experiment._bundle_has_proposal_targets(rebuilt)
    assert seen_kwargs["proposal_target_builder"] is sentinel_builder
def test_adapter_dataset_missing_proposal_targets_raises_without_builder(monkeypatch, tmp_path):
    """Refuse an unaligned cached bundle when no proposal-target builder is given.

    The stubbed loader returns a sample with no ``proposal_target_*`` fields.
    Proposal targets are required, but ``proposal_target_builder`` is
    ``None``. In that case ``_build_dataset_from_config`` is expected to raise
    ``RuntimeError`` whose message mentions ``"proposal-aligned targets"``.
    """
    cached_path = tmp_path / "proxy_train.pt"
    cached_path.write_text("stub", encoding="utf-8")

    def _load_stale_bundle(path):
        # Cached bundle lacking every proposal_target_* field.
        return {"samples": [{"task_name": "bag"}]}

    monkeypatch.setattr(run_experiment, "load_teacher_dataset", _load_stale_bundle)
    # NOTE(review): collect_teacher_dataset / save_teacher_dataset are not
    # stubbed in this test; if the code under test ever attempts a rebuild on
    # this path, the real collector would run. Consider stubbing them.

    cfg_fields = {
        "proxies": ["bag_proxy"],
        "resolution": 16,
        "seed": 17,
        "chunk_horizon": 2,
        "rollout_horizon": 2,
        "history_steps": 1,
        "planner_candidates": 2,
        "dataset_version": "proxy_test",
        "train_episodes_per_proxy": 1,
        "train_dataset_path": str(cached_path),
        "rebuild_dataset": False,
    }
    data_cfg = OmegaConf.create(cfg_fields)

    failure = None
    try:
        run_experiment._build_dataset_from_config(
            data_cfg,
            "train",
            proposal_target_builder=None,
            require_proposal_targets=True,
        )
    except RuntimeError as exc:
        failure = exc

    if failure is None:
        raise AssertionError("Expected a RuntimeError for unaligned adapter dataset.")
    assert "proposal-aligned targets" in str(failure)
def test_adapter_dataset_rebuilds_when_transition_rollout_targets_are_missing(monkeypatch, tmp_path):
    """Rebuild a cached bundle that has proposal targets but no rollout targets.

    The stubbed loader returns a sample carrying the four basic
    ``proposal_target_*`` fields but none of the
    ``proposal_target_rollout_*`` fields. With both
    ``require_proposal_targets`` and ``require_proposal_rollout_targets`` set,
    the test expects two things:

    - The returned bundle satisfies ``_bundle_has_proposal_rollout_targets``,
      i.e. it is the one produced by the stubbed ``collect_teacher_dataset``.
    - The builder object is forwarded to that collector unchanged.
    """
    cached_path = tmp_path / "proxy_train.pt"
    cached_path.write_text("stub", encoding="utf-8")

    sentinel_builder = object()
    seen_kwargs: dict[str, object] = {}

    def _load_partially_aligned_bundle(path):
        # Basic proposal targets present; rollout targets absent.
        partial_sample = {
            "task_name": "bag",
            "proposal_target_action_chunks": 1,
            "proposal_target_retrieval_success": 1,
            "proposal_target_risk": 1,
            "proposal_target_utility": 1,
        }
        return {"samples": [partial_sample]}

    def _collect_fully_aligned_bundle(**kwargs):
        # Record how the collector was invoked, then return basic + rollout targets.
        seen_kwargs.update(kwargs)
        full_sample = {
            "task_name": "bag",
            "proposal_target_action_chunks": 1,
            "proposal_target_retrieval_success": 1,
            "proposal_target_risk": 1,
            "proposal_target_utility": 1,
            "proposal_target_rollout_support_mode": 1,
            "proposal_target_rollout_corridor_feasible": 1,
            "proposal_target_rollout_persistence_horizon": 1,
            "proposal_target_rollout_disturbance_cost": 1,
        }
        return {"samples": [full_sample]}

    def _skip_save(path, bundle):
        # Nothing is written to disk; the path is simply echoed back.
        return Path(path)

    patches = (
        ("load_teacher_dataset", _load_partially_aligned_bundle),
        ("collect_teacher_dataset", _collect_fully_aligned_bundle),
        ("save_teacher_dataset", _skip_save),
    )
    for attr_name, replacement in patches:
        monkeypatch.setattr(run_experiment, attr_name, replacement)

    cfg_fields = {
        "proxies": ["bag_proxy"],
        "resolution": 16,
        "seed": 17,
        "chunk_horizon": 2,
        "rollout_horizon": 2,
        "history_steps": 1,
        "planner_candidates": 2,
        "dataset_version": "proxy_test",
        "train_episodes_per_proxy": 1,
        "train_dataset_path": str(cached_path),
        "rebuild_dataset": False,
    }
    data_cfg = OmegaConf.create(cfg_fields)

    rebuilt = run_experiment._build_dataset_from_config(
        data_cfg,
        "train",
        proposal_target_builder=sentinel_builder,
        require_proposal_targets=True,
        require_proposal_rollout_targets=True,
    )

    assert run_experiment._bundle_has_proposal_rollout_targets(rebuilt)
    assert seen_kwargs["proposal_target_builder"] is sentinel_builder