| """ |
| Example script for running dreaming on a dataset. |
| The idea is that there are ground_truth ("reference") video clips, and we dream the same clips given some initial context. |
| |
After dreaming, we have two sets of videos which, barring the intrinsic noise of the game environment (e.g., randomness of other players),
should be identical if the model were ideal.
| """ |
|
|
| import argparse |
| from pathlib import Path |
| import os |
| import subprocess |
|
|
| import cv2 |
| from tensordict import TensorDict |
| import torch as th |
| from tqdm import tqdm |
| import numpy as np |
| import ffmpegcv |
| from PIL import Image |
|
|
| import wham.utils as utils |
|
|
|
|
# Command-line interface for the dreaming run.
parser = argparse.ArgumentParser(description="Run dreaming.")
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
parser.add_argument("--data_path", type=str, required=True, help="Path to the directory that contains the ground truth data to dream for.")
parser.add_argument("--output", type=str, default="dreaming_output", help="Path to the directory where output should be put.")
parser.add_argument("--max_files", type=int, default=None, help="Maximum number of files to process.")
parser.add_argument("--metadata_config", type=str, default="configs/metadata_custom_tag.config", help="Path to metadata tag config for origin field.")


parser.add_argument(
    "--protocol",
    type=str,
    default="base",
    choices=["base", "comprehensive"],
    help="What protocol to use for the dreaming. base = action conditioned, comprehensive = dream actions as well.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for dreaming. Higher batch_size uses more VRAM but overall is faster.")
parser.add_argument("--context_length", type=int, default=10, help="Number of frames to use as the initial context.")
# NOTE: help text previously said "Batch size for dreaming." (copy-paste error).
parser.add_argument("--steps_to_dream", type=int, default=10, help="Number of steps to dream past the initial context.")


parser.add_argument("--sampling_temperature", type=float, default=0.9, help="Temperature for sampling from the model.")
parser.add_argument("--sampling_top_k", type=int, default=None, help="Top-k for sampling from the model.")
parser.add_argument("--sampling_top_p", type=float, default=None, help="Top-p for sampling from the model.")
|
|
|
|
def get_context_data(image_context, action_context, action_sequences):
    """Package the initial context arrays into a CUDA TensorDict for the model.

    Args:
        image_context: np.ndarray, (batch, time, C, H, W) with C == 3.
        action_context: np.ndarray, (batch, time, action_dim).
        action_sequences: unused here; kept for interface compatibility.
            (Previously it was moved to the GPU and then discarded — the
            caller encodes prompted action sequences separately.)

    Returns:
        TensorDict with "images" and "actions_output" tensors on the GPU,
        batch_size set to the leading (batch, time) dims of the images.
    """
    assert image_context.shape[-3] == 3, "Image context should be CHW"

    image_context = th.from_numpy(image_context).cuda()
    action_data = th.from_numpy(action_context).float().cuda()

    return TensorDict({"images": image_context, "actions_output": action_data}, batch_size=image_context.shape[:2])
|
|
|
|
def add_video_metadata(file_path, metadata_config):
    """Stamp `file_path` with a custom ProgramName tag using exiftool.

    Args:
        file_path: path to the video file; modified in place.
        metadata_config: path to the exiftool config declaring the custom tag.

    Requires the `exiftool` binary on PATH. Failures are printed, not raised
    (metadata stamping is best-effort).
    """
    cmd = [
        'exiftool',
        '-config', metadata_config,
        # subprocess gets a list (no shell), so the value must not be wrapped
        # in escaped quotes — they would become literal characters in the tag.
        f'-ProgramName={utils.PROGRAM_NAME}',
        '-overwrite_original',
        file_path
    ]

    try:
        subprocess.run(cmd, check=True)
        print("Metadata modified successfully.")

        # Echo the file's metadata back so it appears in the run logs.
        cmd_output = [
            'exiftool',
            file_path
        ]
        subprocess.run(cmd_output, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error modifying metadata: {e}")
|
|
|
|
@th.no_grad()
def do_dreaming(model, image_context, action_context, args, action_sequences=None):
    """
    Autoregressively dream `args.steps_to_dream` steps from an initial context.

    image_context and action_context provide the initial context for the model to dream from.

    If action_sequences (batch_size, args.steps_to_dream, action_dim) is provided, then model will be prompted with these actions.

    Returns a tuple of numpy arrays: (dreamed_images, actions_during_dream),
    each with shape (batch, steps_to_dream, ...).
    """
    context_data = get_context_data(image_context, action_context, action_sequences)
    encoded_context_data = model.encode_context(context_data)

    # Pre-encode the prompted action sequence once (if given) so each dream
    # step below can splice its step's action into the rolling context.
    encoded_action_sequences = None
    if action_sequences is not None:
        assert action_sequences.shape[1] == args.steps_to_dream, "action_sequences should have shape (batch_size, args.steps_to_dream, action_dim)"
        action_sequences = TensorDict({"actions_output": action_sequences}, batch_size=action_sequences.shape[:2]).cuda()
        encoded_action_sequences = model.encode_context(action_sequences)

    encoded_dreamt_steps = []

    for dream_step in range(args.steps_to_dream):
        encoded_predicted_step, _ = model.predictor.predict_next_step(
            encoded_context_data, temperature=args.sampling_temperature, top_k=args.sampling_top_k, top_p=args.sampling_top_p, min_tokens_to_keep=1
        )

        # Sliding window: drop the oldest step once the context is full.
        if encoded_context_data.shape[1] == args.context_length:
            encoded_context_data = encoded_context_data[:, 1:]

        # NOTE: append_step aliases encoded_predicted_step, so overwriting
        # "actions_output" here also replaces the action stored via
        # encoded_dreamt_steps below — the dreamt action is intentionally
        # discarded in favor of the prompted one.
        append_step = encoded_predicted_step
        if encoded_action_sequences is not None:
            append_step["actions_output"] = encoded_action_sequences["actions_output"][:, [dream_step], :]
        encoded_context_data = th.cat((encoded_context_data, append_step), dim=1)

        encoded_dreamt_steps.append(encoded_predicted_step)

    # Decode each dreamt step back into pixel images and raw actions.
    dreamed_images = []
    actions_during_dream = []
    for seq_i in range(args.steps_to_dream):
        decoded_step = model.decode_context(encoded_dreamt_steps[seq_i])
        dreamed_images.append(decoded_step["images"][:, [0]].cpu().numpy())
        actions_during_dream.append(decoded_step["actions_output"][:, [0]].cpu().numpy())

    # Stack the per-step results along the time axis.
    dreamed_images = np.concatenate(dreamed_images, axis=1)
    actions_during_dream = np.concatenate(actions_during_dream, axis=1)

    return dreamed_images, actions_during_dream
|
|
|
|
@th.no_grad()
def encode_decode_images(model, images):
    """
    Round-trip ground-truth images through the model's encode/decode pipeline.

    Each timestep is processed independently, then the reconstructed frames
    are concatenated back along the time axis.
    """
    batch_dims = images.shape[:2]
    context = TensorDict({"images": th.from_numpy(images).cuda()}, batch_size=batch_dims)
    reconstructed = [
        model.decode_context(model.encode_context(context[:, [step]]))["images"].cpu().numpy()
        for step in range(batch_dims[1])
    ]
    return np.concatenate(reconstructed, axis=1)
|
|
|
|
def main(args):
    """Dream continuations for every ground-truth .npz episode and save results.

    For each episode under args.data_path, the first args.context_length
    frames/actions are used as context and args.steps_to_dream further steps
    are dreamt. Each result is written under args.output (mirroring the input
    directory layout) as an .npz bundle plus a watermarked .mp4 video stamped
    with provenance metadata.
    """
    total_video_length = args.context_length + args.steps_to_dream

    # Load the model checkpoint onto the GPU.
    model_path = Path(args.model_path)
    assert model_path.is_file(), "Could not find the model!"
    model = utils.load_model_from_checkpoint(model_path).cuda()

    # Collect all ground-truth episodes (recursively).
    data_path = Path(args.data_path)
    ground_truth_files = list(data_path.rglob("*.npz"))
    num_dreams = len(ground_truth_files)

    if args.max_files is not None:
        # Sort first so --max_files selects a deterministic subset.
        ground_truth_files = sorted(ground_truth_files)
        ground_truth_files = ground_truth_files[: args.max_files]
        num_dreams = len(ground_truth_files)

    output_path = Path(args.output)
    os.makedirs(output_path, exist_ok=True)

    print("=" * 100)
    print(f"GENERATING DREAMS OF {num_dreams} SEGMENTS")
    print(f"WRITING TO {args.output}")
    print("=" * 100)

    dreams_created = 0
    with tqdm(total=num_dreams, desc="Dreams") as pbar:
        while ground_truth_files:
            # The final batch may be smaller than args.batch_size.
            batches = min(args.batch_size, len(ground_truth_files))
            batched_image_context = []
            batched_image_sequence = []
            batched_action_context = []
            batched_action_sequence = []
            episode_names = []
            for i in range(batches):
                episode = ground_truth_files.pop()
                try:
                    data = np.load(episode)
                    images = data["images"]
                    actions = data["actions"]
                except Exception:
                    print(f"Failed to load episode {episode} - skipping.")
                    continue

                if actions.shape[0] < total_video_length:
                    raise ValueError(f"Episode {episode} is too short to dream from. It has {actions.shape[0]} steps, but we need at least {total_video_length}.")
                # BUGFIX: record the episode name only after a successful load.
                # Previously the name was appended before the try-block, so a
                # failed load left a dangling name and misaligned every
                # subsequent output filename with its data in this batch.
                episode_names.append(episode)
                batched_image_context.append(images[: args.context_length])
                batched_image_sequence.append(images[args.context_length: total_video_length])
                batched_action_context.append(actions[: args.context_length])
                batched_action_sequence.append(actions[args.context_length: total_video_length])

            if not episode_names:
                # Every file in this batch failed to load; nothing to dream.
                pbar.update(batches)
                continue

            image_context = np.array(batched_image_context)
            image_sequences = np.array(batched_image_sequence)
            action_context = np.array(batched_action_context)
            action_sequences = np.array(batched_action_sequence)

            if args.protocol == "comprehensive":
                # Comprehensive protocol: the model dreams actions as well, so
                # do not prompt it with the ground-truth action sequence.
                action_sequences = None

            full_image_sequence = np.concatenate((image_context, image_sequences), axis=1)

            dreamt_images, actions_during_dream = do_dreaming(model, image_context, action_context, args, action_sequences=action_sequences)
            encoded_decoded_images_batch = encode_decode_images(model, full_image_sequence)

            pbar.update(batches)
            dreams_created += batches

            # Persist each dream: an .npz bundle for analysis plus a
            # watermarked .mp4 preview video.
            for i, dream in enumerate(dreamt_images):
                episode = episode_names[i]
                output_file = output_path / episode.relative_to(data_path)
                output_file.parent.mkdir(parents=True, exist_ok=True)
                np.savez(
                    output_file,
                    context_length=args.context_length,
                    steps_to_dream=args.steps_to_dream,
                    raw_context=image_context[i],
                    dreamt_images=dream,
                    all_actions=np.concatenate((action_context[i], actions_during_dream[i])),
                    encoded_decoded_ground_truth_images=encoded_decoded_images_batch[i],
                )

                video_file = str(output_file.with_suffix(".mp4"))
                writer = ffmpegcv.VideoWriter(video_file, None, utils.DREAMING_FPS)
                full_sequence = np.concatenate((image_context[i], dream), axis=0)
                for frame in full_sequence:
                    # CHW -> HWC uint8 for drawing/encoding.
                    img = frame.transpose(1, 2, 0).astype(np.uint8).copy()
                    # Draw the watermark in the bottom-right corner.
                    (text_width, _), _ = cv2.getTextSize(utils.WATERMARK_TEXT, utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_THICKNESS)
                    x = img.shape[1] - text_width - 10
                    y = img.shape[0] - 10
                    cv2.putText(img, utils.WATERMARK_TEXT, (x, y), utils.WATERMARK_FONT, utils.WATERMARK_FONT_SCALE, utils.WATERMARK_FONT_COLOR, utils.WATERMARK_FONT_THICKNESS)

                    # Attach provenance info to the PIL image (0x0131 looks
                    # like the EXIF/TIFF "Software" tag id — TODO confirm).
                    pil_image = Image.fromarray(img)
                    pil_image.info['Id'] = 0x0131
                    pil_image.info['Type'] = 2
                    pil_image.info['Value'] = utils.PROGRAM_NAME.encode("utf-8")
                    pil_image.info['Len'] = len(utils.PROGRAM_NAME) + 1

                    # PIL frames are RGB; OpenCV/ffmpeg expect BGR.
                    cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
                    writer.write(cv_image)
                writer.release()
                add_video_metadata(video_file, args.metadata_config)
|
|
if __name__ == "__main__":
    # Parse CLI arguments and run the dreaming pipeline.
    cli_args = parser.parse_args()
    main(cli_args)
|
|