"""Tests for how ``run_experiment._build_dataset_from_config`` handles a cached
teacher dataset whose samples lack proposal-aligned (or rollout) target fields.

``load_teacher_dataset``, ``collect_teacher_dataset`` and ``save_teacher_dataset``
are monkeypatched, so no real dataset is read, collected, or written.
"""

from pathlib import Path

from omegaconf import OmegaConf

import train.run_experiment as run_experiment

# Per-sample fields a bundle carries once proposal-aligned targets are built.
_PROPOSAL_TARGET_KEYS = (
    "proposal_target_action_chunks",
    "proposal_target_retrieval_success",
    "proposal_target_risk",
    "proposal_target_utility",
)

# Additional per-sample fields for transition-rollout targets.
_ROLLOUT_TARGET_KEYS = (
    "proposal_target_rollout_support_mode",
    "proposal_target_rollout_corridor_feasible",
    "proposal_target_rollout_persistence_horizon",
    "proposal_target_rollout_disturbance_cost",
)


def _write_stub_dataset(tmp_path):
    """Create a placeholder file at the configured dataset path and return it.

    The contents are irrelevant: every test monkeypatches ``load_teacher_dataset``.
    """
    dataset_path = tmp_path / "proxy_train.pt"
    dataset_path.write_text("stub", encoding="utf-8")
    return dataset_path


def _make_sample(*, with_proposal_targets=True, with_rollout_targets=False):
    """Return a fresh single-task sample dict carrying the requested target fields."""
    sample = {"task_name": "bag"}
    if with_proposal_targets:
        sample.update(dict.fromkeys(_PROPOSAL_TARGET_KEYS, 1))
    if with_rollout_targets:
        sample.update(dict.fromkeys(_ROLLOUT_TARGET_KEYS, 1))
    return sample


def _make_data_cfg(dataset_path):
    """Build the small proxy data config shared by every test in this module."""
    return OmegaConf.create(
        {
            "proxies": ["bag_proxy"],
            "resolution": 16,
            "seed": 17,
            "chunk_horizon": 2,
            "rollout_horizon": 2,
            "history_steps": 1,
            "planner_candidates": 2,
            "dataset_version": "proxy_test",
            "train_episodes_per_proxy": 1,
            "train_dataset_path": str(dataset_path),
            # Rebuilding must be triggered by the missing targets, not by config.
            "rebuild_dataset": False,
        }
    )


def _patch_rebuild(monkeypatch, rebuilt_sample):
    """Stub dataset collection/saving; return the dict that records collect kwargs.

    Each fake collection returns a bundle holding a fresh copy of ``rebuilt_sample``.
    """
    captured: dict[str, object] = {}

    def _fake_collect_teacher_dataset(**kwargs):
        captured.update(kwargs)
        return {"samples": [dict(rebuilt_sample)]}

    monkeypatch.setattr(run_experiment, "collect_teacher_dataset", _fake_collect_teacher_dataset)
    monkeypatch.setattr(run_experiment, "save_teacher_dataset", lambda path, bundle: Path(path))
    return captured


def test_adapter_dataset_rebuilds_when_existing_bundle_lacks_proposal_targets(monkeypatch, tmp_path):
    dataset_path = _write_stub_dataset(tmp_path)
    builder = object()  # identity sentinel: must be forwarded unchanged
    # The cached bundle has no proposal-target fields at all.
    monkeypatch.setattr(
        run_experiment,
        "load_teacher_dataset",
        lambda path: {"samples": [_make_sample(with_proposal_targets=False)]},
    )
    captured = _patch_rebuild(monkeypatch, _make_sample())
    data_cfg = _make_data_cfg(dataset_path)

    bundle = run_experiment._build_dataset_from_config(
        data_cfg,
        "train",
        proposal_target_builder=builder,
        require_proposal_targets=True,
    )

    assert captured, "expected the dataset to be rebuilt via collect_teacher_dataset"
    assert run_experiment._bundle_has_proposal_targets(bundle)
    assert captured["proposal_target_builder"] is builder


def test_adapter_dataset_missing_proposal_targets_raises_without_builder(monkeypatch, tmp_path):
    dataset_path = _write_stub_dataset(tmp_path)
    monkeypatch.setattr(
        run_experiment,
        "load_teacher_dataset",
        lambda path: {"samples": [_make_sample(with_proposal_targets=False)]},
    )
    data_cfg = _make_data_cfg(dataset_path)

    # Keep the try body to the single call under test.
    try:
        run_experiment._build_dataset_from_config(
            data_cfg,
            "train",
            proposal_target_builder=None,
            require_proposal_targets=True,
        )
    except RuntimeError as exc:
        assert "proposal-aligned targets" in str(exc)
    else:
        raise AssertionError("Expected a RuntimeError for unaligned adapter dataset.")


def test_adapter_dataset_rebuilds_when_transition_rollout_targets_are_missing(monkeypatch, tmp_path):
    dataset_path = _write_stub_dataset(tmp_path)
    builder = object()  # identity sentinel: must be forwarded unchanged
    # The cached bundle has proposal targets but no rollout-target fields.
    monkeypatch.setattr(
        run_experiment,
        "load_teacher_dataset",
        lambda path: {"samples": [_make_sample()]},
    )
    captured = _patch_rebuild(monkeypatch, _make_sample(with_rollout_targets=True))
    data_cfg = _make_data_cfg(dataset_path)

    bundle = run_experiment._build_dataset_from_config(
        data_cfg,
        "train",
        proposal_target_builder=builder,
        require_proposal_targets=True,
        require_proposal_rollout_targets=True,
    )

    assert captured, "expected the dataset to be rebuilt via collect_teacher_dataset"
    assert run_experiment._bundle_has_proposal_rollout_targets(bundle)
    assert captured["proposal_target_builder"] is builder