Spaces:
Sleeping
Sleeping
| import argparse | |
| import matplotlib.pyplot as plt | |
| from colorizers import * | |
| # --- GUI Imports --- | |
| import tkinter as tk | |
| from tkinter import filedialog, messagebox | |
| from PIL import Image, ImageTk | |
| import os | |
# Loaded colorizer models keyed by ``use_gpu`` so repeated calls (e.g. several
# clicks of the GUI button) do not rebuild the networks and reload their
# pretrained weights every time.
_COLORIZER_CACHE = {}


def _get_colorizers(use_gpu):
    """Return ``(eccv16_model, siggraph17_model)`` in eval mode, loading them once per device."""
    key = bool(use_gpu)
    models = _COLORIZER_CACHE.get(key)
    if models is None:
        models = (eccv16(pretrained=True).eval(), siggraph17(pretrained=True).eval())
        if key:
            for model in models:
                model.cuda()
        _COLORIZER_CACHE[key] = models
    return models


def colorize_image(img_path, use_gpu=False, save_prefix='saved'):
    """Colorize one image with the ECCV16 and SIGGRAPH17 colorizers and save both results.

    The outputs are written with ``plt.imsave`` to ``f'{save_prefix}_eccv16.png'``
    and ``f'{save_prefix}_siggraph17.png'``.

    Args:
        img_path: path of the image to colorize, read via ``load_img``.
        use_gpu: if true, the models and the resized model input are moved to
            CUDA before inference.
        save_prefix: filename prefix for the two saved PNG files.

    Returns:
        ``(img, img_bw, out_img_eccv16, out_img_siggraph17, eccv16_path,
        siggraph17_path)``. These are the loaded image, the input postprocessed
        with both predicted channels set to zero (the grayscale view), the two
        colorized images, and the two paths that were written.
    """
    colorizer_eccv16, colorizer_siggraph17 = _get_colorizers(use_gpu)

    img = load_img(img_path)
    # Presumably the L channel at the original size (used to rebuild the
    # output) and resized to 256x256 (fed to the models) -- see preprocess_img.
    tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
    if use_gpu:
        tens_l_rs = tens_l_rs.cuda()

    # Two all-zero channels shaped like the L tensor (presumably the ab
    # channels), giving a grayscale rendering of the input.
    zeros = 0 * tens_l_orig
    img_bw = postprocess_tens(tens_l_orig, torch.cat((zeros, zeros), dim=1))

    # Pure inference: disable autograd so no computation graph is retained.
    with torch.no_grad():
        pred_eccv16 = colorizer_eccv16(tens_l_rs).cpu()
        pred_siggraph17 = colorizer_siggraph17(tens_l_rs).cpu()
    out_img_eccv16 = postprocess_tens(tens_l_orig, pred_eccv16)
    out_img_siggraph17 = postprocess_tens(tens_l_orig, pred_siggraph17)

    eccv16_path = f'{save_prefix}_eccv16.png'
    siggraph17_path = f'{save_prefix}_siggraph17.png'
    plt.imsave(eccv16_path, out_img_eccv16)
    plt.imsave(siggraph17_path, out_img_siggraph17)
    return img, img_bw, out_img_eccv16, out_img_siggraph17, eccv16_path, siggraph17_path
def run_cli():
    """Command-line entry: colorize one image and plot the four stages.

    Options are read from ``sys.argv`` via argparse. The image is colorized
    with ``colorize_image``, which also saves the two outputs. The original,
    the grayscale input and both colorized results are then shown in a 2x2
    matplotlib figure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--img_path', type=str, default='imgs/ansel_adams3.jpg')
    parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
    parser.add_argument(
        '-o', '--save_prefix', type=str, default='saved',
        help='will save into this file with {eccv16.png, siggraph17.png} suffixes',
    )
    options = parser.parse_args()

    original, grayscale, eccv16_result, siggraph17_result, _, _ = colorize_image(
        options.img_path, options.use_gpu, options.save_prefix
    )

    # (image, caption) for each subplot, in row-major order of the 2x2 grid.
    panels = (
        (original, 'Original'),
        (grayscale, 'Input'),
        (eccv16_result, 'Output (ECCV 16)'),
        (siggraph17_result, 'Output (SIGGRAPH 17)'),
    )

    plt.figure(figsize=(12, 8))
    for slot, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(2, 2, slot)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')
    plt.show()
| # --- GUI Implementation --- | |
# Longest side, in pixels, of each preview shown in the results window.
_PREVIEW_SIZE = 200


def _to_pil(im):
    """Convert an image array (or an existing PIL image) into a PIL image for display.

    Float arrays whose maximum is <= 1.0 are treated as [0, 1] data and scaled
    to 0..255. Other float arrays are taken as already in 0..255. Floats are
    rounded and clipped. Integer arrays are cast to uint8 without rescaling.
    """
    if isinstance(im, Image.Image):
        return im
    if im.dtype.kind == 'f':
        scale = 255.0 if im.max() <= 1.0 else 1.0
        arr = (im * scale).clip(0, 255).round().astype('uint8')
    else:
        arr = im.astype('uint8')
    # A 2-D uint8 array is inferred as mode 'L'; passing ``mode`` explicitly
    # is deprecated in recent Pillow releases.
    return Image.fromarray(arr)


def _fit_preview(pil_img, size=_PREVIEW_SIZE):
    """Resize *pil_img* to fit a ``size`` x ``size`` box while keeping its aspect ratio."""
    width, height = pil_img.size
    ratio = min(size / width, size / height)
    return pil_img.resize((max(1, round(width * ratio)), max(1, round(height * ratio))))


def run_gui():
    """Open the Tk colorization window and run its event loop.

    The window has an image-path entry with a Browse button, a save-prefix
    entry, a "Use GPU" checkbox and a "Colorize Image" button. The button
    calls ``colorize_image``, reports the saved paths and opens a second
    window with the original, grayscale and both colorized images. Returns
    when ``root.mainloop()`` exits.
    """
    root = tk.Tk()
    root.title('Image Colorization Demo')
    root.geometry('600x400')

    img_path_var = tk.StringVar()
    save_prefix_var = tk.StringVar(value='saved')
    use_gpu_var = tk.BooleanVar(value=False)

    def select_image():
        # Give each pattern as its own entry. A single ';'-joined pattern only
        # works in the native Windows dialog, not in Tk's dialog on Linux/macOS.
        file_path = filedialog.askopenfilename(
            filetypes=[
                ('Image Files', ('*.jpg', '*.jpeg', '*.png', '*.bmp')),
                ('All Files', '*'),
            ]
        )
        if file_path:
            img_path_var.set(file_path)

    def show_all_images(img, img_bw, out_img_eccv16, out_img_siggraph17):
        top = tk.Toplevel(root)
        top.title('Input and Output Images')
        titles = ['Original', 'Grayscale', 'ECCV16', 'SIGGRAPH17']
        images = [img, img_bw, out_img_eccv16, out_img_siggraph17]
        for i, (title, image) in enumerate(zip(titles, images)):
            img_tk = ImageTk.PhotoImage(_fit_preview(_to_pil(image)))
            row, col = divmod(i, 2)
            label = tk.Label(top, image=img_tk)
            # Tk keeps no Python reference to the PhotoImage; hold one on the
            # label so the image is not garbage-collected and shown blank.
            label.image = img_tk
            label.grid(row=row * 2, column=col, padx=10, pady=5)
            tk.Label(top, text=title).grid(row=row * 2 + 1, column=col)

    def process_image():
        img_path = img_path_var.get().strip()
        save_prefix = save_prefix_var.get()
        use_gpu = use_gpu_var.get()
        if not img_path:
            messagebox.showerror('Error', 'Please select an image file.')
            return
        if not os.path.isfile(img_path):
            messagebox.showerror('Error', f'Image file not found:\n{img_path}')
            return
        try:
            # Colorization runs on the Tk thread and blocks the UI, so show a
            # busy cursor for its duration.
            root.config(cursor='watch')
            root.update_idletasks()
            try:
                results = colorize_image(img_path, use_gpu, save_prefix)
            finally:
                root.config(cursor='')
            img, img_bw, out_img_eccv16, out_img_siggraph17, out_eccv16_path, out_siggraph17_path = results
            messagebox.showinfo('Success', f'Colorized images saved as:\n{out_eccv16_path}\n{out_siggraph17_path}')
            show_all_images(img, img_bw, out_img_eccv16, out_img_siggraph17)
        except Exception as e:
            # GUI boundary: surface any failure (unreadable image, CUDA not
            # available, ...) in a dialog instead of only on stderr.
            messagebox.showerror('Error', str(e))

    tk.Label(root, text='Select Image:').pack(pady=10)
    tk.Entry(root, textvariable=img_path_var, width=50).pack()
    tk.Button(root, text='Browse', command=select_image).pack(pady=5)
    tk.Label(root, text='Save Prefix:').pack(pady=10)
    tk.Entry(root, textvariable=save_prefix_var, width=20).pack()
    tk.Checkbutton(root, text='Use GPU', variable=use_gpu_var).pack(pady=5)
    tk.Button(root, text='Colorize Image', command=process_image, bg='lightblue').pack(pady=20)
    root.mainloop()
| # --- Entry Point --- | |
if __name__ == '__main__':
    import sys

    # Any command-line argument means argparse-style CLI usage; launching the
    # script bare opens the Tk GUI instead.
    entry_point = run_cli if sys.argv[1:] else run_gui
    entry_point()