# VLAdaptorBench/code/scripts/run_oven_frame_batch.py
# Committed by: lsnu
# Commit ba3985e (verified): "Add iter29 single-pass logging and episode0 debug-aware GIF suite"
from pathlib import Path
import argparse
import json
import pickle
import sys
# Directory two levels above this script file (the script's grandparent dir).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Put that directory at the front of the import path so the project-local
# package imported below (``rr_label_study``) resolves when this script is run
# directly; presumably that package lives under PROJECT_ROOT — confirm.
# The membership check avoids inserting a duplicate entry.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from rr_label_study.oven_study import (
BimanualTakeTrayOutOfOven,
ReplayCache,
_build_frame_artifacts,
_launch_replay_env,
_load_demo,
)
def _parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options (``sys.argv[1:]`` when *argv* is ``None``).

    Exits via ``parser.error`` (status 2) if any ``--frame-indices`` value is
    negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episode-dir", required=True)
    parser.add_argument("--templates-pkl", required=True)
    parser.add_argument("--frame-indices", nargs="+", type=int, required=True)
    parser.add_argument("--checkpoint-stride", type=int, default=16)
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--independent-replay", action="store_true")
    args = parser.parse_args(argv)
    # A negative index would be forwarded to cache.step_to() and would also
    # produce file names like "frame_-001.json"; reject it up front.
    negative = sorted({index for index in args.frame_indices if index < 0})
    if negative:
        parser.error(f"--frame-indices must be non-negative, got {negative}")
    return args


def _write_json(path: Path, payload) -> None:
    """Serialize *payload* as JSON into *path* (UTF-8, overwriting)."""
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle)


def main(argv: "list[str] | None" = None) -> int:
    """Replay one episode and dump per-frame artifacts as JSON files.

    For every distinct frame index, visited in ascending order, the replay
    cache is advanced to that frame and ``_build_frame_artifacts`` is called.
    Its two results are written to ``<output-dir>/frame_NNNN.json`` and
    ``<output-dir>/frame_NNNN.debug.json``.

    Replay modes:
      * default: the cache keeps advancing from the previous frame. After a
        frame's outputs are written, the cache is restored to the snapshot
        taken right after stepping to that frame (presumably so state changes
        made while building artifacts do not carry into later frames).
      * ``--independent-replay``: every frame is replayed from the snapshot
        taken right after ``cache.reset()``, via ``restore_to_index(..., 0)``.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.

    Returns:
        0 on success. Exceptions from loading or replay propagate, but the
        replay environment is always shut down.
    """
    args = _parse_args(argv)
    episode_dir = Path(args.episode_dir)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only pass template pickles produced by a trusted pipeline.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Deduplicate and sort so that sequential replay only ever steps forward.
    frame_indices = sorted(set(args.frame_indices))

    env = _launch_replay_env()
    try:
        task = env.get_task(BimanualTakeTrayOutOfOven)
        cache = ReplayCache(task, demo, checkpoint_stride=args.checkpoint_stride)
        cache.reset()
        initial_snapshot = cache.snapshot() if args.independent_replay else None
        for frame_index in frame_indices:
            if args.independent_replay:
                cache.restore_to_index(initial_snapshot, 0)
            cache.step_to(frame_index)
            frame_snapshot = None if args.independent_replay else cache.snapshot()
            state = cache.current_state()
            row, debug = _build_frame_artifacts(
                episode_dir=episode_dir,
                demo=demo,
                templates=templates,
                task=task,
                state=state,
            )
            stem = f"frame_{frame_index:04d}"
            _write_json(output_dir / f"{stem}.json", row)
            _write_json(output_dir / f"{stem}.debug.json", debug)
            if frame_snapshot is not None:
                cache.restore(frame_snapshot)
    finally:
        env.shutdown()
    return 0
if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status.
    sys.exit(main())