| from PIL import Image |
| import numpy as np |
| from rembg import remove |
| import cv2 |
| import os |
| from torchvision.transforms import GaussianBlur |
| import gradio as gr |
| import replicate |
| import requests |
| from io import BytesIO |
|
|
def create_mask(input, strength=1.0):
    """Cut the product out of *input* and build a blurred inpainting mask.

    Side effects: writes three files to the current working directory
    (``generate_image`` reads the last two back):

    * ``input.png`` - the uploaded image as received;
    * ``bg_removed.png`` - the cut-out product centred on a white
      512x512 square canvas;
    * ``blured_mask.png`` - the blurred mask that is also returned.

    Mask convention: background (transparent after background removal,
    including the padding added to make the image square) is 255/white;
    product pixels are ``int(255 * (1 - strength))``, i.e. 0/black for the
    default strength. White is presumably the area the inpainting model
    regenerates (usual Stable Diffusion inpainting convention).

    Args:
        input: The product photo as a PIL image. The name shadows the builtin
            but is kept for backward compatibility with keyword callers.
        strength: How strongly product pixels are protected, in [0, 1].
            1 makes them 0; lower values lift them towards 255.

    Returns:
        A 512x512 single-channel PIL image (the blurred mask).

    Raises:
        ValueError: If *strength* is outside [0, 1].
    """
    if not 0 <= strength <= 1:
        raise ValueError(f"strength must be in [0, 1], got {strength!r}")

    input_path = 'input.png'
    bg_removed_path = 'bg_removed.png'
    mask_name = 'blured_mask.png'
    size = 512

    input.save(input_path)
    # rembg returns the cut-out with the removed background made transparent.
    bg_removed = remove(input).convert('RGBA')

    # Centre the cut-out on a square canvas so the later resize to
    # size x size does not distort the aspect ratio.
    width, height = bg_removed.size
    max_dim = max(width, height)
    paste_pos = ((max_dim - width) // 2, (max_dim - height) // 2)

    square_img = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
    # Using the cut-out as its own paste mask composites by alpha, so removed
    # background becomes white like the padding. (A plain paste copies the RGB
    # bands and ignores alpha, leaving black where the background was.)
    square_img.paste(bg_removed, paste_pos, bg_removed)
    square_img = square_img.resize((size, size))
    square_img.save(bg_removed_path)

    # Derive the mask from transparency, not brightness. A grayscale == 0
    # test treats the white padding as product and dark product pixels as
    # background. The padding starts fully transparent (alpha 0).
    square_alpha = Image.new('L', (max_dim, max_dim), 0)
    square_alpha.paste(bg_removed.getchannel('A'), paste_pos)
    alpha = np.array(square_alpha.resize((size, size)))
    is_background = alpha == 0

    # np.where avoids the in-place uint8 multiply/add of the old code.
    product_value = int(255 * (1 - strength))
    mask_array = np.where(is_background, 255, product_value).astype(np.uint8)
    mask = Image.fromarray(mask_array)

    # Feather the mask edge (kernel 11, sigma 20) so the generated
    # background blends into the product outline.
    mask = GaussianBlur(11, 20)(mask)
    mask.save(mask_name)

    # Return the in-memory image; re-opening the file would hand callers a
    # lazily-loaded image holding an open file handle.
    return mask
|
|
|
|
def generate_image(image, product_name, target_name):
    """Render the uploaded product into a scene using Replicate inpainting.

    ``create_mask`` is called for its side effect of writing
    ``bg_removed.png`` and ``blured_mask.png`` to the working directory.
    Those two files are uploaded to the Replicate model
    ``cjwbw/stable-diffusion-v2-inpainting`` together with a text prompt.

    Args:
        image: The uploaded product photo as a PIL image, or None when
            nothing was uploaded.
        product_name: Short description of the product.
        target_name: Description of where the product should be placed.

    Returns:
        The first image produced by the model, as a PIL image.

    Raises:
        ValueError: If no image was uploaded or the model returned no output.
        requests.HTTPError: If downloading the generated image fails.
    """
    if image is None:
        raise ValueError("Please upload a product photo first.")

    # Writes bg_removed.png and blured_mask.png, which are read below.
    create_mask(image)

    # The space after "of" was missing before, producing e.g. "photo ofsneakers".
    prompt = f'a product photography photo of {product_name.strip()} on {target_name.strip()}'

    model = replicate.models.get("cjwbw/stable-diffusion-v2-inpainting")
    version = model.versions.get("f9bb0632bfdceb83196e85521b9b55895f8ff3d1d3b487fd1973210c0eb30bec")
    # Context managers close the uploaded files once predict() returns.
    with open("bg_removed.png", "rb") as image_file, open("blured_mask.png", "rb") as mask_file:
        output = version.predict(prompt=prompt, image=image_file, mask=mask_file)

    if not output:
        raise ValueError("The inpainting model returned no images.")

    # output[0] is assumed to be a URL to the generated image.
    response = requests.get(output[0], timeout=60)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
|
|
# UI: a two-column layout with inputs on the left and the generated image
# on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Advertise better with AI")

    with gr.Row():
        # Left column: the product photo plus the two prompt fragments that
        # generate_image combines into its text prompt.
        with gr.Column():
            input_image = gr.Image(type = 'pil', label = "Upload your product's photo")
            product_name = gr.Textbox(label="Describe your product")
            target_name = gr.Textbox(label="Where do you want to put your product?")
            image_button = gr.Button("Generate")

        # Right column: displays whatever generate_image returns.
        with gr.Column():
            image_output = gr.Image()

    # The input order must match generate_image(image, product_name,
    # target_name). api_name names this event's API endpoint.
    image_button.click(
        fn=generate_image,
        inputs=[input_image, product_name, target_name],
        outputs=image_output,
        api_name='test',
    )

demo.launch()