# VLAarchtests3/code/VLAarchtests2_code/VLAarchtests/tests/test_candidate_ranking_loss.py
# Uploaded by lsnu via the upload-large-folder tool (commit 31ade1f, verified).
import torch
from train.losses import LossWeights, compute_total_loss
def _model_output(planner_scores: torch.Tensor) -> dict[str, torch.Tensor]:
shape = planner_scores.shape
zeros = torch.zeros(shape, dtype=planner_scores.dtype, device=planner_scores.device)
return {
"action_mean": torch.zeros(shape[0], 1, 14, dtype=planner_scores.dtype, device=planner_scores.device),
"planner_scores": planner_scores,
"planner_success_logits": zeros,
"planner_risk_values": zeros,
}
def test_candidate_ranking_loss_prefers_oracle_order():
    """Ranking loss should be lower when planner scores follow the oracle utility order."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.5, 0.0]]),
        "candidate_final_disturbance_cost": torch.tensor([[0.0, 0.1, 0.2]]),
        "candidate_reocclusion_rate": torch.tensor([[0.0, 0.1, 0.3]]),
        "candidate_utility": torch.tensor([[1.0, 0.4, -0.5]]),
    }
    weights = LossWeights(action=0.0, planner_success=0.0, planner_risk=0.0, planner_ranking=1.0)

    def ranking_loss(scores: list[float]) -> float:
        output = _model_output(torch.tensor([scores]))
        return float(compute_total_loss(output, batch, weights=weights)["planner_ranking"])

    # Scores agreeing with descending utility must cost less than the reversed ordering.
    assert ranking_loss([2.0, 0.5, -1.0]) < ranking_loss([-1.0, 0.5, 2.0])
def _proposal_model_output(proposal_logits: torch.Tensor) -> dict[str, torch.Tensor]:
return {
"action_mean": torch.zeros(proposal_logits.shape[0], 1, 14, dtype=proposal_logits.dtype, device=proposal_logits.device),
"proposal_logits": proposal_logits,
"proposal_mode_logits": torch.zeros(proposal_logits.shape[0], 2, dtype=proposal_logits.dtype, device=proposal_logits.device),
"proposal_mode_assignments": torch.tensor([0, 1] * ((proposal_logits.shape[1] + 1) // 2), dtype=torch.long, device=proposal_logits.device)[: proposal_logits.shape[1]],
"proposal_candidates": torch.zeros(
proposal_logits.shape[0],
proposal_logits.shape[1],
1,
14,
dtype=proposal_logits.dtype,
device=proposal_logits.device,
),
}
def test_proposal_ranking_uses_aligned_targets_when_present():
    """Proposal ranking should follow the proposal_target_* keys over candidate_* keys."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_action_chunks": torch.zeros(1, 3, 1, 14),
        "candidate_retrieval_success": torch.tensor([[1.0, 0.0, 0.0]]),
        "candidate_utility": torch.tensor([[1.0, -0.2, -0.4]]),
        "proposal_target_action_chunks": torch.zeros(1, 3, 1, 14),
        "proposal_target_retrieval_success": torch.tensor([[0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[-0.3, 1.0, -0.1]]),
        "proposal_target_risk": torch.tensor([[0.8, 0.0, 0.2]]),
    }
    weights = LossWeights(action=0.0, proposal_reconstruction=0.0, proposal_success=0.0, proposal_ranking=1.0)

    def ranking_loss(logits: list[float]) -> float:
        output = _proposal_model_output(torch.tensor([logits]))
        return float(compute_total_loss(output, batch, weights=weights)["proposal_ranking"])

    # Logits favoring index 1 (the target-best proposal) must cost less than the reverse.
    assert ranking_loss([0.0, 2.0, -1.0]) < ranking_loss([2.0, -1.0, 0.0])
def test_proposal_reconstruction_uses_order_invariant_teacher_family():
    """Reconstruction should match the teacher candidate family regardless of ordering."""
    predicted = torch.zeros(1, 2, 1, 14)
    predicted[0, 0, 0, 0] = 1.0
    predicted[0, 1, 0, 1] = 1.0
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        # Same family as the predictions, but with the candidate order flipped.
        "candidate_action_chunks": predicted.flip(1).clone(),
        # Distant decoy targets: matching against them could not yield a near-zero loss.
        "proposal_target_action_chunks": torch.full_like(predicted, 5.0),
    }
    model_output = _proposal_model_output(torch.zeros(1, 2))
    model_output["proposal_candidates"] = predicted
    weights = LossWeights(action=0.0, proposal_reconstruction=1.0, proposal_success=0.0, proposal_ranking=0.0, proposal_diversity=0.0)
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6
def test_proposal_reconstruction_prefers_high_utility_teacher_subset():
    """Reconstruction should target the high-utility teacher subset when it suffices."""
    predicted = torch.zeros(1, 4, 1, 14)
    for slot, dim in enumerate((0, 1, 0, 1)):
        predicted[0, slot, 0, dim] = 1.0
    teachers = torch.zeros(1, 4, 1, 14)
    # High-utility teachers span the same unit vectors as the predictions (order swapped).
    teachers[0, 0, 0, 1] = 1.0
    teachers[0, 1, 0, 0] = 1.0
    # Low-utility teachers sit far away; matching them could not give a near-zero loss.
    teachers[0, 2, 0, 2] = 5.0
    teachers[0, 3, 0, 3] = 5.0
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "candidate_action_chunks": teachers,
        "candidate_utility": torch.tensor([[1.0, 0.9, -1.0, -2.0]]),
    }
    model_output = _proposal_model_output(torch.zeros(1, 4))
    model_output["proposal_candidates"] = predicted
    weights = LossWeights(action=0.0, proposal_reconstruction=1.0, proposal_success=0.0, proposal_ranking=0.0, proposal_diversity=0.0)
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6
def test_bag_proposal_reconstruction_anchors_to_fallback_targets():
    """For the bag task, reconstruction should anchor to proposal_target_action_chunks."""
    predicted = torch.zeros(1, 4, 1, 14)
    for slot, dim in enumerate((0, 1, 0, 1)):
        predicted[0, slot, 0, dim] = 1.0
    # Every retrieval teacher is far from the predictions.
    teachers = torch.zeros(1, 4, 1, 14)
    for slot, dim in enumerate((2, 3, 4, 5)):
        teachers[0, slot, 0, dim] = 5.0
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "task_name": ["bag"],
        "candidate_action_chunks": teachers,
        "candidate_utility": torch.tensor([[1.0, 0.9, -1.0, -2.0]]),
        # Fallback targets coincide with the predictions, so a near-zero loss implies they were used.
        "proposal_target_action_chunks": predicted.clone(),
    }
    model_output = _proposal_model_output(torch.zeros(1, 4))
    model_output["proposal_candidates"] = predicted
    weights = LossWeights(action=0.0, proposal_reconstruction=1.0, proposal_success=0.0, proposal_ranking=0.0, proposal_diversity=0.0)
    losses = compute_total_loss(model_output, batch, weights=weights)
    assert float(losses["proposal_reconstruction"]) < 1e-6
def test_proposal_mode_loss_prefers_mode_with_highest_utility():
    """Mode loss should favor logits that select the mode holding the better proposals."""
    batch = {
        "action_chunk": torch.zeros(1, 1, 14),
        "proposal_target_retrieval_success": torch.tensor([[1.0, 0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_diversity=0.0,
    )

    def mode_loss(mode_logits: list[float]) -> float:
        output = _proposal_model_output(torch.zeros(1, 4))
        output["proposal_mode_logits"] = torch.tensor([mode_logits])
        output["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)
        return float(compute_total_loss(output, batch, weights=weights)["proposal_mode"])

    # Even slots (mode 0) carry the high-utility, low-risk targets, so preferring
    # mode 0 must cost less than preferring mode 1.
    assert mode_loss([2.0, -1.0]) < mode_loss([-1.0, 2.0])
def test_proposal_mode_loss_can_focus_on_cloth_only():
    """With proposal_mode_cloth_only, non-cloth samples should not affect the mode loss."""
    per_sample_targets = [[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]
    batch = {
        "action_chunk": torch.zeros(2, 1, 14),
        "task_name": ["cloth", "foliage"],
        "proposal_target_retrieval_success": torch.tensor(per_sample_targets),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3], [0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7], [0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_mode_cloth_only=True,
        proposal_diversity=0.0,
    )

    def mode_loss(per_sample_logits: list[list[float]]) -> torch.Tensor:
        output = _proposal_model_output(torch.zeros(2, 4))
        output["proposal_mode_logits"] = torch.tensor(per_sample_logits)
        output["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)
        return compute_total_loss(output, batch, weights=weights)["proposal_mode"]

    aligned = mode_loss([[2.0, -1.0], [2.0, -1.0]])
    # Flipping only the foliage sample's mode logits must leave the loss unchanged.
    foliage_reversed = mode_loss([[2.0, -1.0], [-1.0, 2.0]])
    assert torch.isclose(aligned, foliage_reversed)
def test_proposal_mode_loss_can_focus_on_selected_tasks():
    """With proposal_mode_task_filter, samples outside the filter should not affect the loss."""
    batch = {
        "action_chunk": torch.zeros(2, 1, 14),
        "task_name": ["bag", "foliage"],
        "proposal_target_retrieval_success": torch.tensor([[1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 0.0]]),
        "proposal_target_utility": torch.tensor([[0.9, -0.4, 0.8, -0.3], [0.9, -0.4, 0.8, -0.3]]),
        "proposal_target_risk": torch.tensor([[0.0, 0.8, 0.1, 0.7], [0.0, 0.8, 0.1, 0.7]]),
    }
    weights = LossWeights(
        action=0.0,
        proposal_reconstruction=0.0,
        proposal_success=0.0,
        proposal_ranking=0.0,
        proposal_mode=1.0,
        proposal_mode_task_filter=["bag"],
        proposal_diversity=0.0,
    )

    def mode_loss(per_sample_logits: list[list[float]]) -> torch.Tensor:
        output = _proposal_model_output(torch.zeros(2, 4))
        output["proposal_mode_logits"] = torch.tensor(per_sample_logits)
        output["proposal_mode_assignments"] = torch.tensor([0, 1, 0, 1], dtype=torch.long)
        return compute_total_loss(output, batch, weights=weights)["proposal_mode"]

    aligned = mode_loss([[2.0, -1.0], [2.0, -1.0]])
    # Flipping only the (filtered-out) foliage sample's logits must leave the loss unchanged.
    foliage_reversed = mode_loss([[2.0, -1.0], [-1.0, 2.0]])
    assert torch.isclose(aligned, foliage_reversed)