File size: 3,233 Bytes
7f173cd ba3985e 7f173cd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | from pathlib import Path
import argparse
import json
import pickle
import sys
import numpy as np
# Prepend the directory two levels above this file (PROJECT_ROOT) to sys.path
# so the project-local `rr_label_study` import that follows can be resolved
# when this script is run directly. The check keeps repeated imports from
# inserting duplicate entries.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))
from rr_label_study.oven_study import (
BimanualTakeTrayOutOfOven,
ReplayCache,
_launch_replay_env,
_load_demo,
_pregrasp_progress_and_distance,
_pregrasp_score_and_success,
)
# Emit a JSON progress line every this many completed frames (and after the last one).
_PROGRESS_EVERY = 8


def _row_path(output_dir: Path, frame_index: int) -> Path:
    """Return the per-frame JSON path; shared by the resume check and the writer."""
    return output_dir.joinpath(f"frame_{frame_index:04d}.json")


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command line (reads ``sys.argv``).

    Exits with status 2 via ``parser.error`` when any frame index is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--output-dir", required=True)
    args = parser.parse_args()
    negative = sorted({index for index in args.frame_indices if index < 0})
    if negative:
        # A negative index would be written as e.g. "frame_-001.json" and handed
        # to ReplayCache.step_to after the simulator launched; it is presumably
        # not a valid replay step, so reject it before any expensive setup.
        parser.error(f"--frame-indices must be non-negative; got {negative}")
    return args


def _write_json_atomic(path: Path, payload: dict) -> None:
    """Write *payload* as JSON to *path* via a temp file and an atomic rename.

    The temp file is flushed and fsync'ed before the rename, so a crash cannot
    leave a truncated *path* behind. That matters because an existing frame file
    is treated as "done" on the next run. On any failure the temp file is
    removed and the exception re-raised.
    """
    import os

    tmp_path = path.with_suffix(".json.tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle)
            handle.flush()
            os.fsync(handle.fileno())
        tmp_path.replace(path)
    except BaseException:
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise


def _compute_row(task, state, templates, frame_index: int) -> dict:
    """Build the JSON-serializable pregrasp metrics row for one replayed frame.

    *state* is the object returned by ``ReplayCache.current_state()``. Its
    ``left_gripper_pose`` and ``tray_pose`` are converted to float64 arrays.
    """
    pregrasp_progress, pregrasp_distance = _pregrasp_progress_and_distance(
        np.asarray(state.left_gripper_pose, dtype=np.float64),
        np.asarray(state.tray_pose, dtype=np.float64),
        templates,
    )
    p_pre, y_pre, _ = _pregrasp_score_and_success(task, templates)
    y_pre_value = float(bool(y_pre))
    return {
        "frame_index": int(frame_index),
        "pregrasp_progress": float(pregrasp_progress),
        "pregrasp_distance": float(pregrasp_distance),
        "p_pre": float(p_pre),
        # Both keys hold the same value here. NOTE(review): presumably kept so
        # consumers expecting a raw/processed pair keep working -- confirm.
        "y_pre_raw": y_pre_value,
        "y_pre": y_pre_value,
    }


def main() -> int:
    """Replay one episode and write pregrasp metrics for the requested frames.

    Each frame's row is written to ``<output-dir>/frame_NNNN.json``. Frames whose
    file already exists are skipped, so rerunning the command resumes an
    interrupted job. When nothing is pending, the script returns before loading
    templates or launching the replay environment. The environment is always
    shut down, even on error.

    Returns:
        0 on completion (including when every frame was already done).
    """
    args = _parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Deduplicate and sort so the replay only ever steps forward.
    pending_frame_indices = [
        frame_index
        for frame_index in sorted(set(args.frame_indices))
        if not _row_path(output_dir, frame_index).exists()
    ]
    if not pending_frame_indices:
        return 0
    episode_dir = Path(args.episode_dir)
    # NOTE(review): pickle.load can execute arbitrary code; only pass trusted
    # template files via --templates-pkl.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)
    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()
        total = len(pending_frame_indices)
        for completed, frame_index in enumerate(pending_frame_indices, start=1):
            cache.step_to(frame_index)
            row = _compute_row(task, cache.current_state(), templates, frame_index)
            _write_json_atomic(_row_path(output_dir, frame_index), row)
            if completed == total or completed % _PROGRESS_EVERY == 0:
                print(
                    json.dumps(
                        {
                            "done": completed,
                            "total": total,
                            "frame_index": int(frame_index),
                        }
                    ),
                    flush=True,
                )
    finally:
        env.shutdown()
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())
|