| |
| |
import functools
import os
import subprocess
import sys

import cv2
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler
from diffusers.utils import load_image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
|
|
|
|
# The UI built in create_demo() uses `.style()` and `gr.Image(source=...)`;
# pin the Gradio version this demo was written against.
REQUIRED_GRADIO_VERSION = "3.28.3"

if gr.__version__ != REQUIRED_GRADIO_VERSION:
    # Run pip through the current interpreter so the package is installed into
    # this environment, not into whichever `pip` happens to be first on PATH.
    # check=False keeps the original best-effort behaviour (failures are not fatal).
    subprocess.run([sys.executable, "-m", "pip", "uninstall", "-y", "gradio"], check=False)
    subprocess.run(
        [sys.executable, "-m", "pip", "install", f"gradio=={REQUIRED_GRADIO_VERSION}"],
        check=False,
    )
    # Reinstalling does not reload the `gradio` module already imported above;
    # the pinned version is only picked up after the process restarts.
    print(
        f"Requested gradio=={REQUIRED_GRADIO_VERSION} (currently running {gr.__version__}); "
        "restart the app to use it."
    )
|
|
# Markdown snippets rendered by create_demo(). They are user-facing text.
title_description = """
# Unlimited Controlled Domain Randomization Network for Bridging the Sim2Real Gap in Robotics

"""


description = """
While existing ControlNet and public diffusion models are predominantly geared towards high-resolution images (512x512 or above) and intricate artistic detail generation, these models also have untapped potential for Automatic Data Augmentation (ADA).
By harnessing the inherent variance in prompt-conditioned generated images, we can significantly boost the visual diversity of training samples for computer vision pipelines.
This is particularly relevant in the field of robotics, where deep learning is increasingly playing a pivotal role in training policies for robotic manipulation from images.

In this HuggingFace sprint, we present UCDR-Net (Unlimited Controlled Domain Randomization Network), a novel CannyEdge mini-ControlNet trained on Stable Diffusion 1.5 with mixed datasets.
Our model generates photorealistic and varied renderings from simplistic robotic simulation images, enabling real-time data augmentation for robotic vision training.

We specifically designed UCDR-Net to be fast and composition-preserving, with an emphasis on lower resolution images (128x128) for online data augmentation in typical preprocessing pipelines.
Our choice of the Canny Edge version of ControlNet ensures shape and structure preservation in the image, which is crucial for visuomotor policy learning.

We trained ControlNet from scratch using only 128x128 images, preprocessing the training datasets and extracting Canny Edge maps.
We then trained four ControlNets with different mixtures of two datasets (Coyo-700M and Bridge Data) and showcase the results below.
* [Coyo-700M](https://github.com/kakaobrain/coyo-dataset)
* [Bridge](https://sites.google.com/view/bridgedata)

Model Description and Training Process: please refer to the README in the model repository.

Model Repository: [ControlNet repo](https://huggingface.co/Baptlem/UCDR-Net_models)

"""


traj_description = """
To demonstrate UCDR-Net's capabilities, we generated a trajectory of our simulated robotic environment and present the resulting video for each model.
We batched the frames for each video and performed independent inference for each frame, which explains the "wobbling" effect.
Prompt used for every video: "A robotic arm with a gripper and a small cube on a table, super realistic, industrial background"

"""


perfo_description = """
Our model has been benchmarked on a node of 8 A100 80GB GPUs, achieving an impressive 170 FPS image generation rate!

To run the benchmark, we loaded one of our models on every GPU of the node. We then retrieved an episode of our simulation.
For every frame of the episode, we preprocessed the image (resize, canny, …) and processed the Canny image on the GPUs.
We repeated this procedure for different batch sizes (BS).

We can see that the larger the BS, the higher the FPS. By increasing the BS, we take advantage of the parallelization of the GPUs.
"""


conclusion_description = """
UCDR-Net stands as a natural development in bridging the Sim2Real gap in robotics by providing real-time data augmentation for training visual policies.
We are excited to share our work with the HuggingFace community and contribute to the advancement of robotic vision training techniques.

"""
|
|
def create_key(seed=0):
    """Return a JAX ``PRNGKey`` built from ``seed`` (0 when omitted)."""
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
|
|
def load_controlnet(controlnet_version):
    """Load one UCDR-Net ControlNet checkpoint from the Hugging Face hub.

    Args:
        controlnet_version: sub-folder of the ``Baptlem/UCDR-Net_models``
            repository to load (e.g. ``"coyo-500k"``).

    Returns:
        ``(controlnet, controlnet_params)`` as unpacked from
        ``FlaxControlNetModel.from_pretrained``, with weights in float32.
    """
    load_kwargs = dict(
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    model, model_params = FlaxControlNetModel.from_pretrained("Baptlem/UCDR-Net_models", **load_kwargs)
    return model, model_params
|
|
|
|
@functools.lru_cache(maxsize=1)
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Build the Stable Diffusion + ControlNet Flax pipeline for one checkpoint.

    The result is memoized with a single-entry LRU cache. pipe_inference()
    calls this on every request, and without the cache the full pipeline
    would be re-read from disk each time. Asking for a different
    ``controlnet_version``/``sb_path`` evicts the previous pipeline.
    Because the same objects are handed out on every call, callers must
    treat the returned ``params`` as read-only.

    Args:
        controlnet_version: UCDR-Net sub-folder, forwarded to load_controlnet().
        sb_path: hub id of the base Stable Diffusion checkpoint.

    Returns:
        ``(pipe, params)``: the pipeline with the DPM-Solver multistep scheduler
        installed, and its parameter dict with ``"controlnet"`` and
        ``"scheduler"`` entries filled in.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )

    # Base weights come from the "flax" revision and are loaded in bfloat16.
    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16,
    )

    # Swap in the multistep scheduler and attach the matching parameter trees.
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params
|
|
| |
|
|
# Hub repository and default checkpoint name for the UCDR-Net ControlNets.
# NOTE(review): neither name appears to be read elsewhere in this file.
controlnet_path = "Baptlem/UCDR-Net_models"
controlnet_version = "coyo-500k"

# Lower / upper hysteresis thresholds passed to cv2.Canny in preprocess_canny().
low_threshold, high_threshold = 100, 200
|
|
# Startup diagnostics: working directory, its contents, and the Gradio version.
print(os.path.abspath('.'))
print(os.listdir("."))
print("Gradio version:", gr.__version__)

# No model is loaded at import time: pipe_inference() loads the pipeline for
# the selected model when a request arrives.
print("Models will be loaded on first inference...")
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    model="coyo-500k",
    seed=0,
    negative_prompt="",
):
    """Generate ``num_samples`` images from ``prompt``, conditioned on ``image``.

    The conditioning image is resized with resize_image(). Unless ``is_canny``
    is set, it is then turned into a Canny edge map with preprocess_canny().
    Generation is sharded across all local JAX devices.

    Args:
        image: input image (numpy array or anything ``np.array`` accepts).
        prompt: text prompt, repeated for every sample.
        is_canny: if True, the (resized) image is used directly as the control map.
        num_samples: number of images to return.
        resolution: target size forwarded to resize_image().
        num_inference_steps: number of scheduler steps.
        guidance_scale: classifier-free guidance scale.
        model: ControlNet sub-folder name forwarded to load_sb_pipe().
        seed: seed for the JAX PRNG.
        negative_prompt: negative prompt, repeated for every sample.

    Returns:
        A list of ``num_samples`` PIL images.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload an input image.")

    # UI sliders may deliver ints or floats. List repetition, cv2 sizes and the
    # JAX seed need ints. The pipeline only converts a *float* guidance scale
    # into a per-device array when jitting.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    guidance_scale = float(guidance_scale)

    print("Loading pipe")
    pipe, params = load_sb_pipe(model)

    if not isinstance(image, np.ndarray):
        image = np.array(image)

    control_image = resize_image(image, resolution)
    if not is_canny:
        _, control_image = preprocess_canny(control_image, resolution)

    # `shard` reshapes the leading axis to (num_devices, -1), so the batch must
    # divide evenly. Round it up here and drop the surplus images at the end.
    num_devices = jax.local_device_count()
    batch_size = -(-num_samples // num_devices) * num_devices

    rng = jax.random.split(create_key(seed), num_devices)

    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * batch_size))
    negative_prompt_ids = shard(pipe.prepare_text_inputs([negative_prompt] * batch_size))
    control_images = shard(pipe.prepare_image_inputs([control_image] * batch_size))
    p_params = replicate(params)

    print("Inference...")
    output = pipe(
        prompt_ids=prompt_ids,
        image=control_images,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images
    print("Finished inference...")

    # Flatten the leading sharding axes into one batch axis. The trailing
    # three axes are kept unchanged.
    images = np.asarray(output.reshape((batch_size,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images[:num_samples])
|
|
def resize_image(image, resolution, interpolation=None):
    """Resize ``image`` so that its shorter side is ``resolution`` pixels.

    The aspect ratio is preserved. The longer side is truncated to an int.

    Args:
        image: numpy array or anything ``np.array`` accepts (e.g. a PIL image).
        resolution: target length of the shorter side; cast to int.
        interpolation: OpenCV interpolation flag. Defaults to
            ``cv2.INTER_NEAREST``, the original behaviour.

    Returns:
        The resized image as a PIL ``Image``.

    Raises:
        ValueError: if the image has zero height or width.
    """
    if interpolation is None:
        interpolation = cv2.INTER_NEAREST
    if not isinstance(image, np.ndarray):
        image = np.array(image)
    # cv2.resize needs integer sizes; callers may pass the slider value as a float.
    resolution = int(resolution)

    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot resize an empty image of shape {image.shape}")

    ratio = w / h
    if ratio > 1:  # wider than tall: height becomes `resolution`
        size = (int(resolution * ratio), resolution)
    elif ratio < 1:  # taller than wide: width becomes `resolution`
        size = (resolution, int(resolution / ratio))
    else:
        size = (resolution, resolution)

    # cv2.resize takes its target size as (width, height).
    return Image.fromarray(cv2.resize(image, size, interpolation=interpolation))
| |
| |
def preprocess_canny(image, resolution=128):
    """Compute a three-channel Canny edge map of ``image``.

    Edges come from ``cv2.Canny`` with the module-level ``low_threshold`` and
    ``high_threshold`` values.

    Args:
        image: numpy array or anything ``np.array`` accepts (e.g. a PIL image).
        resolution: accepted for signature compatibility; not used here.

    Returns:
        ``(image, edges)`` as PIL images: the input itself, and the edge map
        copied into three identical channels.
    """
    pixels = image if isinstance(image, np.ndarray) else np.array(image)

    edges = cv2.Canny(pixels, low_threshold, high_threshold)
    # Copy the single-channel edge map into three identical channels.
    edges_rgb = np.repeat(edges[..., np.newaxis], 3, axis=-1)

    return Image.fromarray(pixels), Image.fromarray(edges_rgb)
|
|
|
|
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI of the UCDR-Net demo.

    Args:
        process: callback wired to the Run button and to submitting the
            prompt. It receives the values of ``inputs`` below, positionally:
            image, prompt, is_canny, num_samples, resolution, steps, guidance
            scale, model, seed, negative prompt. Its return value fills the
            output gallery.
        max_images: maximum of the "Images" slider.
        default_num_images: initial value of the "Images" slider.

    Returns:
        The ``gr.Blocks`` demo; the caller launches it.
    """
    # ControlNet checkpoints offered in the dropdown. Each also has a
    # pre-rendered trajectory video shown further down the page.
    model_names = ["coyo-500k", "bridge-2M", "coyo1M-bridge2M", "coyo2M-bridge325k"]

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(title_description)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                # A Button's caption is its `value`; `label` does not set it.
                run_button = gr.Button(value='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Canny thresholds are not exposed in the UI: preprocess_canny()
                    # reads the module-level low_threshold / high_threshold values.
                    # minimum == maximum pins the resolution to 128.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    model = gr.Dropdown(choices=model_names,
                                        value="coyo-500k",
                                        label="Model used for inference",
                                        info="Find every models at https://huggingface.co/Baptlem/UCDR-Net_models")
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality',
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2, height='auto')

        with gr.Row():
            gr.Video("./trajectory_hf/trajectory_coyo2M-bridge325k_64.avi",
                     format="avi",
                     interactive=False).style(height=512, width=512)

        with gr.Row():
            gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                gr.Markdown(traj_description)
            with gr.Column():
                gr.Video("./trajectory_hf/trajectory.avi",
                         format="avi",
                         interactive=False)

        # One row per model: caption on the left, its processed trajectory on the right.
        for model_name in model_names:
            with gr.Row():
                with gr.Column():
                    gr.Markdown(f"Trajectory processed with {model_name} model :")
                with gr.Column():
                    gr.Video(f"./trajectory_hf/trajectory_{model_name}.avi",
                             format="avi",
                             interactive=False)

        with gr.Row():
            with gr.Column():
                gr.Markdown(perfo_description)
            with gr.Column():
                gr.Image("./perfo_rtx.png",
                         interactive=False)

        with gr.Row():
            gr.Markdown(conclusion_description)

        # Order must match the positional parameters of `process`.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            model,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')

    return demo
|
|
if __name__ == '__main__':
    # Build the UI around the inference callback, then serve it with Gradio's
    # request queue enabled.
    demo = create_demo(pipe_inference)
    demo.queue().launch()
| |
| |