"""Replay selected frames of an episode and dump per-frame artifacts as JSON.

For every requested frame index the script steps a ``ReplayCache`` to that
frame, calls ``_build_frame_artifacts`` on the resulting state, and writes two
files into the output directory:

* ``frame_<index>.json``        -- the ``row`` returned by the builder
* ``frame_<index>.debug.json``  -- the ``debug`` payload returned by the builder

``<index>`` is zero-padded to at least four digits.
"""

import argparse
import json
import pickle
import sys
from pathlib import Path
from typing import Any, Optional, Sequence

# Make the directory one level above this file's directory importable before
# the project import below runs (presumably that is where ``rr_label_study``
# lives -- confirm against the repository layout).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (  # noqa: E402  (needs the sys.path tweak above)
    BimanualTakeTrayOutOfOven,
    ReplayCache,
    _build_frame_artifacts,
    _launch_replay_env,
    _load_demo,
)


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command line; ``argv=None`` makes argparse read ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--independent-replay", action="store_true")
    return parser.parse_args(argv)


def _dump_json(path: Path, payload: Any) -> None:
    """Serialize *payload* as JSON into *path* (UTF-8, overwriting any existing file)."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle)


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Run the frame export.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) uses ``sys.argv[1:]``, which is what the ``__main__``
            entry point relies on.

    Returns:
        ``0`` when every requested frame has been written.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments.
        Exception: Anything raised while loading inputs, replaying, or writing
            output propagates to the caller; the replay environment is still
            shut down first.
    """
    args = _parse_args(argv)

    episode_dir = Path(args.episode_dir)
    # NOTE(security): unpickling executes arbitrary code embedded in the file.
    # Only point --templates-pkl at files from a trusted source.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Duplicates are dropped and frames are visited in ascending order.
    # argparse has already converted every index to int (type=int).
    frame_indices = sorted(set(args.frame_indices))

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()

        # In independent mode every frame is replayed from this initial
        # snapshot instead of continuing from the previously visited frame.
        initial_snapshot = cache.snapshot() if args.independent_replay else None

        for frame_index in frame_indices:
            if args.independent_replay:
                cache.restore_to_index(initial_snapshot, 0)
            cache.step_to(frame_index)

            # In sequential mode, remember the state right after stepping so
            # it can be restored once the artifacts are built (presumably
            # because building them may disturb the replayed state -- confirm
            # against _build_frame_artifacts).
            frame_snapshot = None if args.independent_replay else cache.snapshot()

            state = cache.current_state()
            row, debug = _build_frame_artifacts(
                episode_dir=episode_dir,
                demo=demo,
                templates=templates,
                task=task,
                state=state,
            )

            _dump_json(output_dir.joinpath(f"frame_{frame_index:04d}.json"), row)
            _dump_json(
                output_dir.joinpath(f"frame_{frame_index:04d}.debug.json"), debug
            )

            if frame_snapshot is not None:
                cache.restore(frame_snapshot)
    finally:
        # Always release the replay environment, even when a frame fails.
        env.shutdown()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())