File size: 3,233 Bytes
7f173cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba3985e
7f173cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from pathlib import Path
import argparse
import json
import pickle
import sys

import numpy as np


# Repository root: the parent of the directory that contains this script. It is
# prepended to ``sys.path`` (only if not already present) so that the
# project-local ``rr_label_study`` package imported below can be found.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (
    BimanualTakeTrayOutOfOven,
    ReplayCache,
    _launch_replay_env,
    _load_demo,
    _pregrasp_progress_and_distance,
    _pregrasp_score_and_success,
)


def _frame_output_path(output_dir: Path, frame_index: int) -> Path:
    """Return the per-frame result file path (``frame_XXXX.json``) in *output_dir*."""
    return output_dir.joinpath(f"frame_{frame_index:04d}.json")


def _write_json_atomically(path: Path, payload: dict) -> None:
    """Write *payload* as JSON to *path* through a temporary file.

    The data is first written to ``<stem>.json.tmp`` and then moved over
    *path* with :meth:`Path.replace` (``os.replace``, atomic on POSIX). A
    crash mid-write therefore never leaves a truncated ``.json`` file that the
    resume check in :func:`main` would mistake for a finished frame.
    """
    tmp_path = path.with_suffix(".json.tmp")
    with tmp_path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle)
    tmp_path.replace(path)


def _parse_args() -> argparse.Namespace:
    """Parse and validate this script's command-line arguments.

    Invalid values terminate the process through ``parser.error`` (exit
    status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument(
        "--progress-every",
        type=int,
        default=8,
        help="Print a JSON progress line every N scored frames (the last frame always reports).",
    )
    args = parser.parse_args()
    # Output files are named ``frame_{index:04d}.json``; a negative index would
    # yield names such as ``frame_-001.json`` that break that naming scheme.
    if any(frame_index < 0 for frame_index in args.frame_indices):
        parser.error("--frame-indices must be non-negative")
    # Used as a modulus in the progress check; zero would raise mid-run.
    if args.progress_every <= 0:
        parser.error("--progress-every must be a positive integer")
    return args


def _score_frame(cache, task, templates, frame_index: int) -> dict:
    """Advance the replay to *frame_index* and return its pre-grasp metrics row."""
    cache.step_to(frame_index)
    state = cache.current_state()
    pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance(
        np.asarray(state.left_gripper_pose, dtype=np.float64),
        np.asarray(state.tray_pose, dtype=np.float64),
        templates,
    )
    # Scored from the live task after stepping (presumably reflecting the
    # simulator state at ``frame_index`` -- confirm against ReplayCache).
    p_pre, y_pre, _ = _pregrasp_score_and_success(task, templates)
    return {
        "frame_index": int(frame_index),
        "pregrasp_progress": float(pregrasp_progress),
        "pregrasp_distance": float(pregrasp_distance),
        "p_pre": float(p_pre),
        # NOTE(review): ``y_pre_raw`` and ``y_pre`` are currently identical;
        # presumably ``y_pre`` is meant to become a post-processed label.
        "y_pre_raw": float(bool(y_pre)),
        "y_pre": float(bool(y_pre)),
    }


def main() -> int:
    """Compute pre-grasp metrics for selected frames of one replayed demo.

    Writes one ``frame_XXXX.json`` file per requested frame into
    ``--output-dir``. Frames whose file already exists are skipped, so an
    interrupted run can be resumed by re-invoking it with the same arguments.
    A JSON progress line is printed to stdout every ``--progress-every``
    frames and after the final frame.

    Returns:
        Process exit status, always ``0``; failures propagate as exceptions.
    """
    args = _parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Deduplicate and sort so frames are replayed in increasing order.
    frame_indices = sorted(set(args.frame_indices))
    # Resume support: only frames without an existing result file are scored.
    pending_frame_indices = [
        frame_index
        for frame_index in frame_indices
        if not _frame_output_path(output_dir, frame_index).exists()
    ]
    if not pending_frame_indices:
        return 0

    episode_dir = Path(args.episode_dir)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only pass template files that come from a trusted source.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()
        total = len(pending_frame_indices)
        for completed, frame_index in enumerate(pending_frame_indices, start=1):
            row = _score_frame(cache, task, templates, frame_index)
            _write_json_atomically(_frame_output_path(output_dir, frame_index), row)
            # Machine-readable progress line; flushed so a process reading
            # stdout sees it immediately rather than at buffer flush.
            if completed == total or completed % args.progress_every == 0:
                print(
                    json.dumps(
                        {
                            "done": completed,
                            "total": total,
                            "frame_index": int(frame_index),
                        }
                    ),
                    flush=True,
                )
    finally:
        # Release the replay environment even if scoring a frame fails.
        env.shutdown()
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())