File size: 3,048 Bytes
b14c4b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from sim_reveal.procedural_envs import BAG_PROXY, CLOTH_PROXY, FOLIAGE_PROXY, make_proxy_env


def test_proxy_scripted_bench():
    """Scripted macro chunks should beat a contrasting macro on most seeds.

    For each of ten seeds, a fresh bag, foliage and cloth proxy env (in that
    order) rolls out a preferred macro chunk and then a contrasting one.

    A seed counts as a win for a proxy when the preferred chunk is no worse on
    both tracked metrics. "No worse" means ``>=`` for higher-is-better metrics
    and ``<=`` for lower-is-better metrics. Each proxy must win on at least
    eight of the ten seeds.
    """
    horizon = 4

    # Each entry: (proxy, preferred macro, contrasting macro, metric checks).
    # Metric checks are (metric key, higher_is_better) pairs, listed in the
    # order they are compared.
    scenarios = (
        (
            BAG_PROXY,
            "maintain_mouth",
            "premature_retrieve",
            (("hold_persistence", True), ("reocclusion_rate", False)),
        ),
        (
            FOLIAGE_PROXY,
            "pin_canopy",
            "foliage_immediate_reocclusion",
            (("reocclusion_rate", False), ("visibility_integral", True)),
        ),
        (
            CLOTH_PROXY,
            "stabilize_fold",
            "cloth_lift_high",
            (("candidate_fold_preservation", True), ("final_disturbance_cost", False)),
        ),
    )

    def rollout(env, macro_name):
        # Build the named macro chunk, then evaluate it on the same env.
        chunk = env.macro_action_chunk(macro_name, chunk_horizon=horizon)
        return env.evaluate_action_chunk(chunk, rollout_horizon=horizon)

    def no_worse(preferred, contrast, checks):
        # all() short-circuits, so metrics are compared in listed order and
        # stop at the first failing comparison.
        return all(
            preferred[key] >= contrast[key] if higher_is_better else preferred[key] <= contrast[key]
            for key, higher_is_better in checks
        )

    win_counts = [0] * len(scenarios)
    for seed in range(10):
        for slot, (proxy, preferred_macro, contrast_macro, checks) in enumerate(scenarios):
            env = make_proxy_env(proxy_name=proxy.name, resolution=32, seed=seed, rollout_horizon=horizon)
            env.reset(seed=seed)
            preferred = rollout(env, preferred_macro)
            contrast = rollout(env, contrast_macro)
            if no_worse(preferred, contrast, checks):
                win_counts[slot] += 1

    # Asserted in bag, foliage, cloth order.
    for wins in win_counts:
        assert wins >= 8


def test_cloth_candidate_utility_penalizes_fold_damage():
    """A gentle cloth reveal must outscore a destructive one.

    Both outcomes share retrieval success, reocclusion rate and
    layer-separation quality. The destructive outcome has higher disturbance
    cost and lift-too-much risk, and lower fold preservation and top-layer
    stability. ``candidate_outcome_utility`` must rank the gentle outcome
    strictly higher.
    """
    cloth_env = make_proxy_env(proxy_name=CLOTH_PROXY.name, resolution=32, seed=0, rollout_horizon=4)

    def reveal_outcome(disturbance, fold_preservation, top_layer_stability, lift_risk):
        # Key order matches the literal dicts this test has always used.
        # Only the damage-related entries vary between the two outcomes.
        return {
            "retrieval_success": 0.0,
            "final_disturbance_cost": disturbance,
            "reocclusion_rate": 0.02,
            "candidate_layer_separation_quality": 0.92,
            "candidate_fold_preservation": fold_preservation,
            "candidate_top_layer_stability": top_layer_stability,
            "candidate_lift_too_much_risk": lift_risk,
        }

    gentle_reveal = reveal_outcome(0.18, 0.78, 0.82, 0.08)
    destructive_reveal = reveal_outcome(0.52, 0.06, 0.18, 0.42)

    gentle_utility = cloth_env.candidate_outcome_utility(gentle_reveal)
    destructive_utility = cloth_env.candidate_outcome_utility(destructive_reveal)
    assert gentle_utility > destructive_utility