File size: 2,522 Bytes
712dc89
 
 
 
 
 
 
 
 
 
 
7f173cd
ba3985e
 
 
 
7f173cd
 
712dc89
 
 
 
 
 
 
 
 
7f173cd
712dc89
 
 
 
 
 
 
 
ba3985e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712dc89
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from pathlib import Path
import argparse
import json
import pickle
import sys


# Repository root: two levels above this script (parents[1] of its resolved
# path). Prepended to sys.path, once, so the local ``rr_label_study`` package
# below is importable without installation.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_PROJECT_ROOT_STR = str(PROJECT_ROOT)
if _PROJECT_ROOT_STR not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT_STR)

from rr_label_study.oven_study import (
    BimanualTakeTrayOutOfOven,
    ReplayCache,
    _build_frame_artifacts,
    _launch_replay_env,
    _load_demo,
)


def _write_json(path: Path, payload) -> None:
    """Write *payload* to *path* as UTF-8 JSON, replacing any existing file."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle)


def main() -> int:
    """Replay one demo episode and dump per-frame artifacts as JSON files.

    Requested frame indices are deduplicated and processed in ascending order.
    For each one the demo is replayed up to that frame, and
    ``_build_frame_artifacts`` is called on the resulting state. Two files are
    written into ``--output-dir``:

    - ``frame_NNNN.json`` holds the returned row.
    - ``frame_NNNN.debug.json`` holds the returned debug payload.

    Replay modes:

    - Default: frames are reached incrementally. The state is snapshotted
      before artifact extraction and restored after it, so the next frame
      continues from the un-perturbed replay state.
    - ``--independent-replay``: each frame is replayed from a snapshot taken
      right after ``cache.reset()``, rewound to index 0.

    Returns:
        0 on success. Invalid arguments, including negative frame indices,
        make ``argparse`` exit with status 2 before any template loading or
        environment launch happens.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--independent-replay", action="store_true")
    args = parser.parse_args()

    # type=int already converted the values; dedupe and order them once.
    frame_indices = sorted(set(args.frame_indices))
    # Reject negative indices up front, before the costly env launch.
    # type=int accepts them, but they are presumably not valid replay steps,
    # and they would produce file names like "frame_-001.json".
    negative = [index for index in frame_indices if index < 0]
    if negative:
        parser.error(f"--frame-indices must be non-negative, got {negative}")

    episode_dir = Path(args.episode_dir)
    # NOTE: pickle.load can execute arbitrary code; only pass trusted files.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()
        # Only independent replay needs the post-reset baseline snapshot.
        initial_snapshot = cache.snapshot() if args.independent_replay else None
        for frame_index in frame_indices:
            if args.independent_replay:
                # Rewind to the post-reset baseline (index 0) for every frame.
                cache.restore_to_index(initial_snapshot, 0)
            cache.step_to(frame_index)
            # In incremental mode, remember the replay state so it can be
            # restored after artifact extraction (which may touch the sim).
            frame_snapshot = None if args.independent_replay else cache.snapshot()
            state = cache.current_state()
            row, debug = _build_frame_artifacts(
                episode_dir=episode_dir,
                demo=demo,
                templates=templates,
                task=task,
                state=state,
            )
            stem = f"frame_{frame_index:04d}"
            _write_json(output_dir / f"{stem}.json", row)
            _write_json(output_dir / f"{stem}.debug.json", debug)
            if frame_snapshot is not None:
                cache.restore(frame_snapshot)
    finally:
        # Always release the simulator, even if a frame fails.
        env.shutdown()
    return 0


# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())