File size: 1,334 Bytes
150d02a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import argparse
import json
import os
import pickle
import sys
from pathlib import Path


# Put the directory two levels above this file (the parent of the script's own
# directory) at the front of the import path, so the project-local
# `rr_label_study` package imported below resolves when this file is run
# directly. The check avoids inserting a duplicate entry.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import _compute_frame_row_isolated, _load_demo


def _write_json_atomic(payload: object, output_path: Path) -> None:
    """Serialize *payload* as JSON and write it to *output_path* atomically.

    Serialization happens before any file is touched. The text is then written
    to a sibling temporary file, which is renamed over *output_path*. A value
    that ``json`` cannot encode, or a write that fails part-way, therefore never
    leaves a truncated JSON file at the destination. Any file already there is
    left unchanged on failure. Missing parent directories are created.
    """
    serialized = json.dumps(payload)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # The PID suffix keeps concurrent writers aimed at the same output from
    # sharing (and clobbering) one temporary file.
    tmp_path = output_path.with_name(f".{output_path.name}.{os.getpid()}.tmp")
    try:
        tmp_path.write_text(serialized, encoding="utf-8")
        # Same-directory rename: atomic on POSIX, and replaces an existing
        # destination on every platform.
        tmp_path.replace(output_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise


def main(argv: "list[str] | None" = None) -> int:
    """Compute the row for one frame of an episode and write it as JSON.

    Loads pickled templates and the episode's demo (via ``_load_demo``), then
    calls ``_compute_frame_row_isolated`` for ``--frame-index``. The returned
    row is written to ``--output-json``, replaced atomically so a failed run
    never leaves a truncated file.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which matches the original no-argument call.

    Returns:
        ``0`` on success. Invalid arguments exit through argparse's
        ``SystemExit``; any other failure propagates as an exception.
    """
    parser = argparse.ArgumentParser(
        description="Compute a single frame row for one episode and save it as JSON.",
    )
    parser.add_argument(
        "--episode-dir", required=True, help="Episode directory to load the demo from."
    )
    parser.add_argument(
        "--templates-pkl", required=True, help="Pickle file containing the templates."
    )
    parser.add_argument(
        "--frame-index", type=int, required=True, help="Index of the frame to compute."
    )
    parser.add_argument(
        "--checkpoint-stride",
        type=int,
        default=16,
        help="Forwarded as checkpoint_stride (default: %(default)s).",
    )
    parser.add_argument(
        "--output-json", required=True, help="Destination path for the JSON row."
    )
    args = parser.parse_args(argv)

    episode_dir = Path(args.episode_dir)
    # SECURITY: unpickling can execute arbitrary code. Only pass a templates
    # file produced by a trusted process.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)
    row = _compute_frame_row_isolated(
        episode_dir=episode_dir,
        demo=demo,
        templates=templates,
        checkpoint_stride=args.checkpoint_stride,
        frame_index=args.frame_index,
    )
    _write_json_atomic(row, Path(args.output_json))
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())