import os
import shutil
import subprocess
import sys
import uuid

import gradio as gr
from PIL import Image
|
|
# Working directories, all relative to the current working directory.
UPLOAD_DIR = "./sessions"  # per-request scratch space (sessions/<id>/uploads)
RESULTS_DIR = "./results"  # searched for the model's *_fake.png outputs
CHECKPOINTS_DIR = "./checkpoints/SingleImageReflectionRemoval"
SAMPLE_DIR = "./sample_images"

# UPLOAD_DIR is intentionally not created here: each request creates its own
# nested session directory with makedirs when it needs one.
for _required_dir in (RESULTS_DIR, CHECKPOINTS_DIR, SAMPLE_DIR):
    os.makedirs(_required_dir, exist_ok=True)
|
|
from shutil import copyfile

from huggingface_hub import hf_hub_download

REPO_ID = "hasnafk/SingleImageReflectionRemoval"
MODEL_FILE = "310_net_G.pth"

# Flat location the generator weights are copied to. Presumably this is where
# test.py (run with --name SingleImageReflectionRemoval --epoch 310) loads the
# weights from -- confirm against test.py's checkpoint options.
expected_model_path = os.path.join(CHECKPOINTS_DIR, MODEL_FILE)

if os.path.exists(expected_model_path):
    # Weights are already in place. An existing file is never overwritten, so
    # skipping the Hub request cannot change which weights are used. It does
    # mean startup no longer needs network access.
    model_path = expected_model_path
else:
    # hf_hub_download stores the file in the Hub cache rooted at cache_dir
    # and returns the cached path; copy it to the flat name expected above.
    model_path = hf_hub_download(
        repo_id=REPO_ID, filename=MODEL_FILE, cache_dir=CHECKPOINTS_DIR
    )
    copyfile(model_path, expected_model_path)
|
|
def generate_session_id():
    """Return a fresh random (version 4) UUID in its canonical string form."""
    return f"{uuid.uuid4()}"
|
|
def randomize_file_name(original_name):
    """Return a random 32-hex-character name that keeps ``original_name``'s extension.

    Only the suffix reported by ``os.path.splitext`` (e.g. ``".png"``) is
    kept; the rest of the original name is discarded.
    """
    _stem, suffix = os.path.splitext(original_name)
    return uuid.uuid4().hex + suffix
|
|
def clear_session_files(session_id):
    """Delete the directory tree for ``session_id`` under ``UPLOAD_DIR``, if it exists."""
    target = os.path.join(UPLOAD_DIR, session_id)
    if not os.path.exists(target):
        return
    shutil.rmtree(target)
|
|
# Values accepted for test.py's --preprocess flag (same set the UI offers).
_VALID_PREPROCESS = (
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
)
# Modes after which the result is NOT resized back to the input's size.
_KEEP_OUTPUT_SIZE = ("crop", "none")


def _run_test_script(upload_dir, preprocess_type, fallback_attempts=2):
    """Run ``test.py`` on ``upload_dir`` and return the --preprocess value used.

    The first attempt passes ``--preprocess preprocess_type``. If it fails,
    up to ``fallback_attempts`` more attempts run without the flag, leaving
    the choice to test.py's own default.

    Returns:
        ``preprocess_type`` if the first attempt succeeded, ``None`` if a
        fallback attempt (test.py default preprocessing) succeeded.

    Raises:
        subprocess.CalledProcessError: if every attempt exits non-zero.
    """
    base_cmd = [
        # Run test.py with the interpreter serving this app, not whatever
        # "python" happens to resolve to on PATH.
        sys.executable or "python", "test.py",
        "--dataroot", upload_dir,
        "--name", "SingleImageReflectionRemoval",
        "--model", "test", "--netG", "unet_256",
        "--direction", "AtoB", "--dataset_mode", "single",
        "--norm", "batch", "--epoch", "310",
        "--num_test", "1",
        "--gpu_ids", "-1",
    ]
    commands = [base_cmd + ["--preprocess", preprocess_type]]
    commands += [base_cmd] * fallback_attempts
    for attempt, cmd in enumerate(commands):
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as err:
            print(f"test.py attempt {attempt + 1}/{len(commands)} failed: {err}")
            if attempt == len(commands) - 1:
                raise
        else:
            return preprocess_type if attempt == 0 else None


def _collect_result(input_stem, input_image, resize_to_input):
    """Load the ``*_fake.png`` generated for ``input_stem`` and delete its outputs.

    Walks ``RESULTS_DIR`` for files whose names start with ``input_stem``.
    A ``*_fake.png`` match is loaded (and resized to the size of the image at
    ``input_image`` when ``resize_to_input`` is true) and then deleted; a
    ``*_real.png`` match is deleted.

    Returns:
        The loaded ``PIL.Image.Image``, or ``None`` if no fake image was found.
    """
    output_image = None
    for root, _, files in os.walk(RESULTS_DIR):
        for name in files:
            if not name.startswith(input_stem):
                continue
            path = os.path.join(root, name)
            if name.endswith("_fake.png"):
                # Image.open is lazy; copy() forces the pixels into memory so
                # the file handle can be closed before the file is deleted.
                with Image.open(path) as generated:
                    output_image = generated.copy()
                if resize_to_input:
                    with Image.open(input_image) as source:
                        output_image = output_image.resize(source.size)
                os.remove(path)
            elif name.endswith("_real.png"):
                os.remove(path)
    return output_image


def reflection_removal(input_image, preprocess_type="resize_and_crop"):
    """Run the Pix2Pix reflection-removal model on a single image.

    The image is copied under a random name into a fresh session directory
    below ``UPLOAD_DIR``, ``test.py`` is run on that directory, and the
    generated ``*_fake.png`` is read back from ``RESULTS_DIR``. The session
    directory and the generated files are removed afterwards.

    Args:
        input_image: Path to the uploaded image file.
        preprocess_type: Value forwarded to ``test.py --preprocess``; must be
            one of ``_VALID_PREPROCESS``.

    Returns:
        The result as a ``PIL.Image.Image``, or a human-readable message
        string when the input is invalid or no result could be produced.
    """
    if preprocess_type not in _VALID_PREPROCESS:
        return "Invalid preprocessing type selected. Please choose a valid option."

    print("Preprocessing Type:", preprocess_type)
    print("Input Image:", input_image)

    # Validate before creating any directories so a rejected request does not
    # leave an empty session directory behind.
    if not input_image or not os.path.exists(input_image):
        return "No image was provided or file was cleared. Please upload a valid image."

    session_id = generate_session_id()
    upload_dir = os.path.join(UPLOAD_DIR, session_id, "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    try:
        # The random name's stem is what the output files are matched on, so
        # one request never picks up another request's results.
        randomized_name = randomize_file_name(os.path.basename(input_image))
        shutil.copy(input_image, os.path.join(upload_dir, randomized_name))
        input_stem = os.path.splitext(randomized_name)[0]

        try:
            used_preprocess = _run_test_script(upload_dir, preprocess_type)
        except subprocess.CalledProcessError:
            return "No results found. Please try again with a different image."

        # used_preprocess is None after a fallback run, where test.py applied
        # its own default preprocessing (presumably resize_and_crop -- confirm
        # in test.py's options); the result is then resized back as well.
        resize_to_input = used_preprocess not in _KEEP_OUTPUT_SIZE
        output_image = _collect_result(input_stem, input_image, resize_to_input)
    finally:
        # Remove the uploaded copy on every path, including failures.
        clear_session_files(session_id)

    if output_image is not None:
        return output_image
    return "No results found."
|
|
def use_sample_image(sample_image_name):
    """Resolve ``sample_image_name`` to a path inside ``SAMPLE_DIR``.

    Returns the joined path when that file exists, otherwise the message
    string ``"Sample image not found."``.
    """
    candidate = os.path.join(SAMPLE_DIR, sample_image_name)
    return candidate if os.path.exists(candidate) else "Sample image not found."
|
|
# Image files in SAMPLE_DIR, offered as clickable examples in the UI.
# Sorted because os.listdir returns entries in arbitrary order, and the
# extension test ignores case so files such as "IMG_01.JPG" are included.
sample_images = sorted(
    name for name in os.listdir(SAMPLE_DIR)
    if name.lower().endswith((".jpg", ".jpeg", ".png"))
)

# Choices for the preprocessing dropdown; the same set reflection_removal
# accepts for test.py's --preprocess flag.
preprocess_options = [
    "resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "none",
]
|
|
def _remove_reflections_ui(input_image, preprocess_type):
    """Gradio handler that surfaces reflection_removal's error messages.

    reflection_removal reports problems by returning a message string, but
    the output component is a gr.Image, and a plain string is not a
    displayable image. Raising gr.Error shows the message to the user instead.
    """
    result = reflection_removal(input_image, preprocess_type or "resize_and_crop")
    if isinstance(result, str):
        raise gr.Error(result)
    return result


iface = gr.Interface(
    fn=_remove_reflections_ui,
    inputs=[
        gr.Image(type="filepath", label="Upload Image (JPG/PNG)"),
        gr.Dropdown(choices=preprocess_options, label="Preprocessing Type", value="resize_and_crop"),
    ],
    outputs=gr.Image(label="Result after Reflection Removal"),
    examples=[
        [os.path.join(SAMPLE_DIR, img), "resize_and_crop"]
        for img in sample_images
    ],
    title="Reflection Remover with Pix2Pix",
    description="Upload images to remove reflections using a Pix2Pix model. You can also try the sample images below.",
)


if __name__ == "__main__":
    # Start the web server only when this file is executed as a script.
    iface.launch()
|
|