| import torch | |
| from train.losses import LossWeights, compute_total_loss | |
| def _model_output(planner_scores: torch.Tensor) -> dict[str, torch.Tensor]: | |
| shape = planner_scores.shape | |
| zeros = torch.zeros(shape, dtype=planner_scores.dtype, device=planner_scores.device) | |
| return { | |
| "action_mean": torch.zeros(shape[0], 1, 14, dtype=planner_scores.dtype, device=planner_scores.device), | |
| "planner_scores": planner_scores, | |
| "planner_success_logits": zeros, | |
| "planner_risk_values": zeros, | |
| } | |
def test_candidate_ranking_loss_prefers_oracle_order():
    """Planner ranking loss should be lower when scores follow the oracle order."""
    weights = LossWeights(action=0.0, planner_success=0.0, planner_risk=0.0, planner_ranking=1.0)
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.5, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.3]]),
        "candidate_utility": torch.tensor([[1.0, 0.4, -0.5]]),
    }
    # Scores monotonically aligned with candidate utility vs. fully reversed.
    matching = compute_total_loss(_model_output(torch.tensor([[2.0, 0.5, -1.0]])), batch, weights=weights)
    mismatched = compute_total_loss(_model_output(torch.tensor([[-1.0, 0.5, 2.0]])), batch, weights=weights)
    assert float(matching["planner_ranking"]) < float(mismatched["planner_ranking"])
| def _proposal_model_output(proposal_logits: torch.Tensor) -> dict[str, torch.Tensor]: | |
| return { | |
| "action_mean": torch.zeros(proposal_logits.shape[0], 1, 14, dtype=proposal_logits.dtype, device=proposal_logits.device), | |
| "proposal_logits": proposal_logits, | |
| "proposal_mode_logits": torch.zeros(proposal_logits.shape[0], 2, dtype=proposal_logits.dtype, device=proposal_logits.device), | |
| "proposal_mode_assignments": torch.tensor([0, 1] * ((proposal_logits.shape[1] + 1) // 2), dtype=torch.long, device=proposal_logits.device)[: proposal_logits.shape[1]], | |
| "proposal_candidates": torch.zeros( | |
| proposal_logits.shape[0], | |
| proposal_logits.shape[1], | |
| 1, | |
| 14, | |
| dtype=proposal_logits.dtype, | |
| device=proposal_logits.device, | |
| ), | |
| } | |
def test_proposal_ranking_uses_aligned_targets_when_present():
    """When proposal_target_* fields exist, ranking should follow their utility order."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_action_chunks": torch.zeros(1, 3, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_utility": torch.tensor([[1.0, -0.2, -0.4]]),
        "proposal_target_action_chunks": torch.zeros(1, 3, 1, 14),
        "proposal_target_retrieval_success": torch.tensor([[0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[-0.3, 1.0, -0.1]]),
        "proposal_target_risk": torch.tensor([[0.8, 0.0, 0.2]]),
    }
    weights = LossWeights(action=0.0, proposal_reconstruction=0.0, proposal_success=0.0, proposal_ranking=1.0)
    # Slot 1 has the best target utility; logits peaking there should score better.
    target_matched = compute_total_loss(_proposal_model_output(torch.tensor([[0.0, 2.0, -1.0]])), batch, weights=weights)
    target_mismatched = compute_total_loss(_proposal_model_output(torch.tensor([[2.0, -1.0, 0.0]])), batch, weights=weights)
    assert float(target_matched["proposal_ranking"]) < float(target_mismatched["proposal_ranking"])
def test_proposal_reconstruction_uses_order_invariant_teacher_family():
    """Reconstruction must match proposals to teachers irrespective of slot order."""
    candidates = torch.zeros(1, 2, 1, 14)
    candidates[0, 0, 0, 0] = 1.0
    candidates[0, 1, 0, 1] = 1.0
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        # Same chunks as the model's proposals, but in reversed slot order.
        "candidate_action_chunks": candidates.flip(1).clone(),
        # A far-off fallback target that should be ignored in favor of the match.
        "proposal_target_action_chunks": torch.full_like(candidates, 5.0),
    }
    model_output = _proposal_model_output(torch.zeros(1, 2))
    model_output["proposal_candidates"] = candidates
    weights = LossWeights(action=0.0, proposal_reconstruction=1.0, proposal_success=0.0, proposal_ranking=0.0, proposal_diversity=0.0)
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6
def test_proposal_reconstruction_prefers_high_utility_teacher_subset():
    """Only the high-utility teacher candidates should anchor the reconstruction."""
    proposals = torch.zeros(1, 4, 1, 14)
    for slot, dim in enumerate((0, 1, 0, 1)):
        proposals[0, slot, 0, dim] = 1.0
    teachers = torch.zeros(1, 4, 1, 14)
    # First two teachers mirror the proposals (swapped dims); last two are far away.
    for slot, (dim, value) in enumerate(((1, 1.0), (0, 1.0), (2, 5.0), (3, 5.0))):
        teachers[0, slot, 0, dim] = value
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_action_chunks": teachers,
        # High utility on the matching teachers, negative on the distant ones.
        "candidate_utility": torch.tensor([[1.0, 0.9, -1.0, -2.0]]),
    }
    model_output = _proposal_model_output(torch.zeros(1, 4))
    model_output["proposal_candidates"] = proposals
    weights = LossWeights(action=0.0, proposal_reconstruction=1.0, proposal_success=0.0, proposal_ranking=0.0, proposal_diversity=0.0)
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6
def test_bag_proposal_reconstruction_anchors_to_fallback_targets():
    """For the bag task, reconstruction should use the fallback targets, not teachers."""
    proposals = torch.zeros(1, 4, 1, 14)
    for slot, dim in enumerate((0, 1, 0, 1)):
        proposals[0, slot, 0, dim] = 1.0
    # Every teacher candidate is far from the proposals.
    teachers = torch.zeros(1, 4, 1, 14)
    for slot, dim in enumerate((2, 3, 4, 5)):
        teachers[0, slot, 0, dim] = 5.0
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "task_name": ["bag"],
        "candidate_action_chunks": teachers,
        "candidate_utility": torch.tensor([[1.0, 0.9, -1.0, -2.0]]),
        # Fallback targets equal the proposals, so anchoring to them gives zero loss.
        "proposal_target_action_chunks": proposals.clone(),
    }
    model_output = _proposal_model_output(torch.zeros(1, 4))
    model_output["proposal_candidates"] = proposals
    weights = LossWeights(action=0.0, proposal_reconstruction=1.0, proposal_success=0.0, proposal_ranking=0.0, proposal_diversity=0.0)
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6
def test_proposal_mode_loss_prefers_mode_with_highest_utility():
    """Mode loss should reward logits that select the higher-utility mode."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "proposal_target_retrieval_success": torch.tensor([[1.0, 0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_diversity=0.0,
    )

    def _output_with_mode_logits(mode_logits: torch.Tensor) -> dict[str, torch.Tensor]:
        # Mode 0 holds the high-utility slots (0 and 2), mode 1 the poor ones.
        output = _proposal_model_output(torch.zeros(1, 4))
        output["proposal_mode_logits"] = mode_logits
        output["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)
        return output

    favors_good_mode = compute_total_loss(_output_with_mode_logits(torch.tensor([[2.0, -1.0]])), batch, weights=weights)
    favors_bad_mode = compute_total_loss(_output_with_mode_logits(torch.tensor([[-1.0, 2.0]])), batch, weights=weights)
    assert float(favors_good_mode["proposal_mode"]) < float(favors_bad_mode["proposal_mode"])
def test_proposal_mode_loss_can_focus_on_cloth_only():
    """With proposal_mode_cloth_only, non-cloth rows must not change the mode loss."""
    batch = {
        "action_chunk": torch.zeros(2, 1, 14),
        "task_name": ["cloth", "foliage"],
        "proposal_target_retrieval_success": torch.tensor([[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3], [0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7], [0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_mode_cloth_only=True,
        proposal_diversity=0.0,
    )

    def _output_with_mode_logits(mode_logits: torch.Tensor) -> dict[str, torch.Tensor]:
        output = _proposal_model_output(torch.zeros(2, 4))
        output["proposal_mode_logits"] = mode_logits
        output["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)
        return output

    # Only the foliage row's mode logits differ between the two outputs.
    both_aligned = compute_total_loss(_output_with_mode_logits(torch.tensor([[2.0, -1.0], [2.0, -1.0]])), batch, weights=weights)
    foliage_flipped = compute_total_loss(_output_with_mode_logits(torch.tensor([[2.0, -1.0], [-1.0, 2.0]])), batch, weights=weights)
    assert torch.isclose(both_aligned["proposal_mode"], foliage_flipped["proposal_mode"])
def test_proposal_mode_loss_can_focus_on_selected_tasks():
    """proposal_mode_task_filter should restrict the mode loss to the listed tasks."""
    batch = {
        "action_chunk": torch.zeros(2, 1, 14),
        "task_name": ["bag", "foliage"],
        "proposal_target_retrieval_success": torch.tensor([[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3], [0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7], [0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_mode_task_filter=["bag"],
        proposal_diversity=0.0,
    )

    def _output_with_mode_logits(mode_logits: torch.Tensor) -> dict[str, torch.Tensor]:
        output = _proposal_model_output(torch.zeros(2, 4))
        output["proposal_mode_logits"] = mode_logits
        output["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)
        return output

    # Only the filtered-out foliage row's mode logits differ between the outputs.
    both_aligned = compute_total_loss(_output_with_mode_logits(torch.tensor([[2.0, -1.0], [2.0, -1.0]])), batch, weights=weights)
    foliage_flipped = compute_total_loss(_output_with_mode_logits(torch.tensor([[2.0, -1.0], [-1.0, 2.0]])), batch, weights=weights)
    assert torch.isclose(both_aligned["proposal_mode"], foliage_flipped["proposal_mode"])