| |
| |
| |
| |
| |
|
|
| import os |
| import gradio as gr |
| import spaces |
| import torch |
| import numpy as np |
| from PIL import Image |
| import tempfile |
| import gc |
| from datetime import datetime |
|
|
| from addit_flux_pipeline import AdditFluxPipeline |
| from addit_flux_transformer import AdditFluxTransformer2DModel |
| from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler |
| from addit_methods import add_object_generated, add_object_real |
|
|
| |
# Module-level state shared by the request handlers defined below.
# `pipe` stays None when initialisation fails; the handlers and
# create_interface() test `pipe is None` to detect an unusable model.
pipe = None
device = None
# Size of the most recently uploaded image; written by handle_image_upload().
original_image_size = None

# Load the model once at import time so every request reuses the same pipeline.
print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Project-specific transformer loaded from the FLUX.1-dev checkpoint.
    my_transformer = AdditFluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16
    )

    # Build the pipeline around that transformer and move it to the device.
    pipe = AdditFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=my_transformer,
        torch_dtype=torch.bfloat16
    ).to(device)

    # Swap in the project scheduler, reusing the stock scheduler's config.
    pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)

    print("Model initialized successfully!")

except Exception as e:
    # Best-effort start-up: keep the UI alive, but never leave a
    # half-configured pipeline behind (e.g. when only the scheduler swap
    # failed), otherwise the app would report "Model ready!" after logging
    # that the model is unavailable.
    pipe = None
    print(f"Error initializing model: {str(e)}")
    print("The application will start but model functionality will be unavailable.")
|
|
def validate_inputs(prompt_source, prompt_target, subject_token):
    """Check the three text fields submitted by the user.

    Args:
        prompt_source: Prompt describing the source image.
        prompt_target: Prompt describing the edited image.
        subject_token: Word naming the object to add; must occur verbatim
            (substring match, whitespace included) in ``prompt_target``.

    Returns:
        A human-readable error message for the first failing check, or
        ``None`` when all checks pass.
    """
    required_fields = (
        (prompt_source, "Source prompt cannot be empty"),
        (prompt_target, "Target prompt cannot be empty"),
        (subject_token, "Subject token cannot be empty"),
    )
    for value, message in required_fields:
        # A missing value (None) is treated like an empty string instead of
        # crashing on `.strip()`.
        if value is None or not value.strip():
            return message

    if subject_token not in prompt_target:
        return f"Subject token '{subject_token}' must appear in the target prompt"
    return None
|
|
def resize_and_crop_image(image):
    """Scale an image so its short side is 1024 px, then center crop to 1024x1024.

    Args:
        image: Object exposing ``size``, ``resize`` and ``crop`` the way a
            PIL image does, or ``None``.

    Returns:
        A ``(processed_image, message, original_size)`` tuple.
        ``(None, "", None)`` when ``image`` is ``None``. When the image is
        already 1024x1024 it is returned unchanged with an empty message.
        Otherwise ``message`` is an HTML banner describing what was done and
        ``original_size`` is the ``(width, height)`` before processing.
    """
    if image is None:
        return None, "", None

    target = 1024
    original_width, original_height = image.size
    original_size = (original_width, original_height)

    # Nothing to do for images that already have the target size.
    if original_size == (target, target):
        return image, "", original_size

    # Scale with exact integer arithmetic. Truncating the float product
    # `side * (1024 / short_side)` can land one pixel short for some sizes
    # (for example a 784-pixel short side), which would make the short side
    # 1023 px and push the crop box past the image edge.
    short_side = min(original_width, original_height)
    new_width = original_width * target // short_side
    new_height = original_height * target // short_side

    resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Center the 1024x1024 crop window inside the resized image.
    left = (new_width - target) // 2
    top = (new_height - target) // 2
    cropped_image = resized_image.crop((left, top, left + target, top + target))

    # Square inputs are only resized; everything else is also cropped.
    if new_width == target and new_height == target:
        message = "<div style='background-color: #e8f5e8; border: 1px solid #4caf50; border-radius: 5px; padding: 8px; margin-bottom: 10px;'><span style='color: #2e7d32; font-weight: bold;'>✅ Image resized to 1024×1024</span></div>"
    else:
        message = "<div style='background-color: #e8f5e8; border: 1px solid #4caf50; border-radius: 5px; padding: 8px; margin-bottom: 10px;'><span style='color: #2e7d32; font-weight: bold;'>✅ Image resized and center cropped to 1024×1024</span></div>"

    return cropped_image, message, original_size
|
|
def handle_image_upload(image):
    """Normalise a freshly uploaded image and remember its original size.

    The module-level ``original_image_size`` is set to ``image.size`` (or
    reset to ``None`` when the upload is cleared) before the image is passed
    to ``resize_and_crop_image``.

    Returns:
        ``(processed_image, message)``; ``(None, "")`` when ``image`` is
        ``None``.
    """
    global original_image_size

    # Record the size first so it reflects the image exactly as uploaded.
    original_image_size = None if image is None else image.size
    if image is None:
        return None, ""

    # The third element (original size) is already stored globally above.
    resized, notice, _ = resize_and_crop_image(image)
    return resized, notice
|
|
@spaces.GPU
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a source image and an edited version with the object added.

    Validates the prompts, logs the request parameters, parses
    ``blend_steps`` (comma-separated integers, may be blank) and forwards
    everything to ``add_object_generated``.

    Returns:
        ``(source_image, edited_image, status_message)``. On any failure the
        two images are ``None`` and the message describes the problem.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    # Reject bad prompts before doing any expensive work.
    validation_error = validate_inputs(prompt_source, prompt_target, subject_token)
    if validation_error:
        return None, None, validation_error

    # Log the full request for later inspection.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    request_log = (
        f"\n[{timestamp}] Starting Generated Image Processing",
        f"Source Prompt: '{prompt_source}'",
        f"Target Prompt: '{prompt_target}'",
        f"Subject Token: '{subject_token}'",
        f"Source Seed: {seed_src}, Object Seed: {seed_obj}",
        f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}",
        f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'",
    )
    for entry in request_log:
        print(entry)

    try:
        # A blank field means "no blending", i.e. an empty list.
        blend_steps_list = (
            [int(part.strip()) for part in blend_steps.split(',') if part.strip()]
            if blend_steps.strip()
            else []
        )

        generated = add_object_generated(
            pipe=pipe,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            show_attention=False,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            display_output=False
        )
        src_image, edited_image = generated
        return src_image, edited_image, "Images generated successfully!"

    except Exception as e:
        # Surface the failure in the status box instead of crashing the UI.
        failure = f"Error generating images: {str(e)}"
        print(failure)
        return None, None, failure
|
|
@spaces.GPU
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True)
):
    """Add an object to a user-supplied image.

    Validates the inputs, logs the request parameters, brings the image to
    1024x1024, parses ``blend_steps`` (comma-separated integers, may be
    blank) and forwards everything to ``add_object_real``.
    ``disable_inversion`` is passed on inverted as ``use_inversion``.

    Returns:
        ``(source_image, edited_image, status_message)``. On any failure the
        two images are ``None`` and the message describes the problem.
    """
    global pipe

    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    if source_image is None:
        return None, None, "Please upload a source image"

    # Reject bad prompts before doing any expensive work.
    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Log the full request for later inspection.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Real Image Processing")
    if original_image_size:
        print(f"Original uploaded image size: {original_image_size[0]}×{original_image_size[1]}")
    print(f"Source Image Size: {source_image.size}")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")
    print(f"Use Offset: {use_offset}, Disable Inversion: {disable_inversion}")

    try:
        # Images that bypass the upload handler (for instance the bundled
        # examples) may not be 1024x1024 yet. Resize + center crop them, as
        # the UI note promises, instead of stretching them with a plain
        # resize. Images that are already 1024x1024 pass through unchanged.
        source_image, _, _ = resize_and_crop_image(source_image)

        # A blank field means "no blending", i.e. an empty list.
        if blend_steps.strip():
            blend_steps_list = [int(x.strip()) for x in blend_steps.split(',') if x.strip()]
        else:
            blend_steps_list = []

        src_image, edited_image = add_object_real(
            pipe=pipe,
            source_image=source_image,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            use_offset=use_offset,
            show_attention=False,
            use_inversion=not disable_inversion,
            display_output=False
        )

        return src_image, edited_image, "Image edited successfully!"

    except Exception as e:
        # Surface the failure in the status box instead of crashing the UI.
        error_msg = f"Error processing image: {str(e)}"
        print(error_msg)
        return None, None, error_msg
|
|
def create_interface():
    """Build the Gradio Blocks UI for the Add-it demo and return it.

    The layout has a header with the model status, a "Generated Images" tab
    wired to ``process_generated_image``, a "Real Images" tab wired to
    ``handle_image_upload`` and ``process_real_image``, and a collapsible
    tips section.
    """
    # Header status depends on whether the module-level pipeline was loaded.
    if pipe is not None:
        model_status = "Model ready!"
        status_color = "green"
    else:
        model_status = "Model initialization failed - functionality unavailable"
        status_color = "red"

    header_html = f"""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🎨 Add-it: Training-Free Object Insertion</h1>
        <p>Add objects to images using pretrained diffusion models</p>
        <p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a> |
        <a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a> |
        <a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
        <p style="color: {status_color}; font-weight: bold;">Status: {model_status}</p>
    </div>
    """

    # Help texts shared by both tabs.
    subject_token_info = "Single token representing the object to add **(must appear in target prompt)**"
    blend_steps_info = "Comma-separated list of steps (e.g., '15,20') or empty for no blending"

    # Same options in both tabs, listed with each tab's default first.
    gen_localization_choices = [
        "attention_points_sam",
        "attention",
        "attention_box_sam",
        "attention_mask_sam",
        "grounding_sam",
    ]
    real_localization_choices = [
        "attention",
        "attention_points_sam",
        "attention_box_sam",
        "attention_mask_sam",
        "grounding_sam",
    ]

    gen_examples = [
        ["An empty throne", "A king sitting on a throne", "king"],
        ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
        ["A photo of a cat sitting on the couch", "A photo of a cat wearing a blue hat sitting on the couch", "hat"],
        ["A car driving through an empty street", "A pink car driving through an empty street", "car"],
    ]
    real_examples = [
        [
            "images/bed_dark_room.jpg",
            "A photo of a bed in a dark room",
            "A photo of a dog lying on a bed in a dark room",
            "dog",
        ],
        [
            "images/flower.jpg",
            "A photo of a flower",
            "A bee standing on a flower",
            "bee",
        ],
    ]

    def refresh_upload_banner(status):
        # Show the banner only when the upload handler produced a message.
        return gr.update(visible=bool(status.strip()), value=status)

    with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
        gr.HTML(header_html)

        with gr.Tabs():
            # ---- Tab 1: generate a base image, then add an object ----
            with gr.TabItem("🎭 Generated Images"):
                gr.Markdown("### Generate a base image and add objects to it")

                with gr.Row():
                    with gr.Column(scale=1):
                        gen_prompt_source = gr.Textbox(
                            label="Source Prompt",
                            placeholder="A photo of a cat sitting on the couch",
                            value="A photo of a cat sitting on the couch",
                        )
                        gen_prompt_target = gr.Textbox(
                            label="Target Prompt",
                            placeholder="A photo of a cat wearing a blue hat sitting on the couch",
                            value="A photo of a cat wearing a blue hat sitting on the couch",
                        )
                        gen_subject_token = gr.Textbox(
                            label="Subject Token",
                            placeholder="hat",
                            value="hat",
                            info=subject_token_info,
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            gen_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                            gen_seed_obj = gr.Number(label="Object Seed", value=42, precision=0)
                            gen_extended_scale = gr.Slider(
                                label="Extended Scale",
                                minimum=1.0,
                                maximum=1.3,
                                value=1.05,
                                step=0.01,
                            )
                            gen_structure_transfer_step = gr.Slider(
                                label="Structure Transfer Step",
                                minimum=0,
                                maximum=10,
                                value=2,
                                step=1,
                            )
                            gen_blend_steps = gr.Textbox(
                                label="Blend Steps",
                                value="15",
                                info=blend_steps_info,
                            )
                            gen_localization_model = gr.Dropdown(
                                label="Localization Model",
                                choices=gen_localization_choices,
                                value="attention_points_sam",
                            )

                        gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")

                    with gr.Column(scale=2):
                        with gr.Row():
                            gen_src_output = gr.Image(label="Generated Source Image", type="pil")
                            gen_edited_output = gr.Image(label="Edited Image", type="pil")
                        gen_status = gr.Textbox(label="Status", interactive=False)

                gen_inputs = [
                    gen_prompt_source, gen_prompt_target, gen_subject_token,
                    gen_seed_src, gen_seed_obj, gen_extended_scale,
                    gen_structure_transfer_step, gen_blend_steps,
                    gen_localization_model,
                ]
                gen_submit_btn.click(
                    fn=process_generated_image,
                    inputs=gen_inputs,
                    outputs=[gen_src_output, gen_edited_output, gen_status],
                )

                gr.Examples(
                    examples=gen_examples,
                    inputs=[gen_prompt_source, gen_prompt_target, gen_subject_token],
                    label="Example Prompts",
                )

            # ---- Tab 2: add an object to an uploaded image ----
            with gr.TabItem("📸 Real Images"):
                gr.Markdown("### Upload an image and add objects to it")
                gr.HTML("<p style='color: orange; font-weight: bold; margin: -15px -10px;'>Note: Images will be automatically resized and center cropped to 1024×1024 pixels.</p>")

                with gr.Row():
                    with gr.Column(scale=1):
                        # Hidden until an upload produces a resize/crop notice.
                        real_image_status = gr.HTML(visible=False)
                        real_source_image = gr.Image(label="Source Image", type="pil")
                        real_prompt_source = gr.Textbox(
                            label="Source Prompt",
                            placeholder="A photo of a bed in a dark room",
                            value="A photo of a bed in a dark room",
                        )
                        real_prompt_target = gr.Textbox(
                            label="Target Prompt",
                            placeholder="A photo of a dog lying on a bed in a dark room",
                            value="A photo of a dog lying on a bed in a dark room",
                        )
                        real_subject_token = gr.Textbox(
                            label="Subject Token",
                            placeholder="dog",
                            value="dog",
                            info=subject_token_info,
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            real_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                            real_seed_obj = gr.Number(label="Object Seed", value=0, precision=0)
                            real_extended_scale = gr.Slider(
                                label="Extended Scale",
                                minimum=1.0,
                                maximum=1.3,
                                value=1.1,
                                step=0.01,
                            )
                            real_structure_transfer_step = gr.Slider(
                                label="Structure Transfer Step",
                                minimum=0,
                                maximum=10,
                                value=4,
                                step=1,
                            )
                            real_blend_steps = gr.Textbox(
                                label="Blend Steps",
                                value="18",
                                info=blend_steps_info,
                            )
                            real_localization_model = gr.Dropdown(
                                label="Localization Model",
                                choices=real_localization_choices,
                                value="attention",
                            )
                            real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                            real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)

                        real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")

                    with gr.Column(scale=2):
                        with gr.Row():
                            real_src_output = gr.Image(label="Source Image", type="pil")
                            real_edited_output = gr.Image(label="Edited Image", type="pil")
                        real_status = gr.Textbox(label="Status", interactive=False)

                # Preprocess on upload, then reveal or hide the notice banner.
                upload_event = real_source_image.upload(
                    fn=handle_image_upload,
                    inputs=[real_source_image],
                    outputs=[real_source_image, real_image_status],
                )
                upload_event.then(
                    fn=refresh_upload_banner,
                    inputs=[real_image_status],
                    outputs=[real_image_status],
                )

                real_inputs = [
                    real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
                    real_seed_src, real_seed_obj, real_extended_scale,
                    real_structure_transfer_step, real_blend_steps,
                    real_localization_model, real_use_offset,
                    real_disable_inversion,
                ]
                real_submit_btn.click(
                    fn=process_real_image,
                    inputs=real_inputs,
                    outputs=[real_src_output, real_edited_output, real_status],
                )

                gr.Examples(
                    examples=real_examples,
                    inputs=[
                        real_source_image, real_prompt_source, real_prompt_target, real_subject_token
                    ],
                    label="Example Images & Prompts",
                )

        # Usage tips below the tabs.
        with gr.Accordion("💡 Tips for Better Results", open=False):
            gr.Markdown("""
            - **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
            - **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
            - **Localization Models**: The most effective options are `attention_points_sam` and `attention`. Use Show Attention to visualize localization performance
            - **Object Placement Issues**: If the object is not added to the image:
                - Try **decreasing** Structure Transfer Step
                - Try **increasing** Extended Scale
            - **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to send an empty list
            """)

    return demo
|
|
# Build the UI once at import time and serve it on all interfaces,
# port 7860, with a public share link requested.
demo = create_interface()
demo.launch(share=True, server_name="0.0.0.0", server_port=7860)