# ImageColoriser / demo_release.py
# Divya-A's picture
# ImageColoriser: Flask app and colorization models for Hugging Face Space
# e6e4be7
import argparse
import matplotlib.pyplot as plt
from colorizers import *
# --- GUI Imports ---
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import os
def colorize_image(img_path, use_gpu=False, save_prefix='saved'):
    """Colorize one image with both pretrained models and save the results.

    Args:
        img_path: Path of the image to load via ``load_img``.
        use_gpu: When true, both models and the resized input tensor are
            moved to CUDA before inference.
        save_prefix: Prefix of the output files. Results are written to
            ``{save_prefix}_eccv16.png`` and ``{save_prefix}_siggraph17.png``.
            If the prefix contains a directory part, that directory is
            created when missing.

    Returns:
        A 6-tuple ``(img, img_bw, out_img_eccv16, out_img_siggraph17,
        eccv16_path, siggraph17_path)``: the loaded image, the image
        rebuilt from the original-size L tensor with both colour channels
        zeroed, the two model outputs, and the two paths that were written.
    """
    # Build each output path once; the same strings are saved to and returned.
    eccv16_path = f'{save_prefix}_eccv16.png'
    siggraph17_path = f'{save_prefix}_siggraph17.png'

    # load colorizers
    colorizer_eccv16 = eccv16(pretrained=True).eval()
    colorizer_siggraph17 = siggraph17(pretrained=True).eval()
    if use_gpu:
        colorizer_eccv16.cuda()
        colorizer_siggraph17.cuda()

    img = load_img(img_path)
    (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256, 256))
    if use_gpu:
        tens_l_rs = tens_l_rs.cuda()

    # Pure inference: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        zero_ab = torch.cat((0 * tens_l_orig, 0 * tens_l_orig), dim=1)
        img_bw = postprocess_tens(tens_l_orig, zero_ab)
        out_img_eccv16 = postprocess_tens(
            tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
        out_img_siggraph17 = postprocess_tens(
            tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())

    # A prefix such as "out/run1" must not fail just because "out" is missing.
    out_dir = os.path.dirname(save_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    plt.imsave(eccv16_path, out_img_eccv16)
    plt.imsave(siggraph17_path, out_img_siggraph17)
    return (img, img_bw, out_img_eccv16, out_img_siggraph17,
            eccv16_path, siggraph17_path)
def run_cli():
    """Command-line entry: colorize one image and plot the results.

    Parses ``-i/--img_path``, ``--use_gpu`` and ``-o/--save_prefix`` from the
    command line, runs ``colorize_image`` with them, then shows a 2x2
    matplotlib figure with the original, the input, and both model outputs.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-i', '--img_path', type=str,
                     default='imgs/ansel_adams3.jpg')
    cli.add_argument('--use_gpu', action='store_true',
                     help='whether to use GPU')
    cli.add_argument('-o', '--save_prefix', type=str, default='saved',
                     help='will save into this file with {eccv16.png, siggraph17.png} suffixes')
    args = cli.parse_args()

    # The two saved-file paths are returned as well but are not needed here.
    original, grayscale, eccv16_out, siggraph17_out, _, _ = colorize_image(
        args.img_path, args.use_gpu, args.save_prefix)

    panels = (
        (original, 'Original'),
        (grayscale, 'Input'),
        (eccv16_out, 'Output (ECCV 16)'),
        (siggraph17_out, 'Output (SIGGRAPH 17)'),
    )

    plt.figure(figsize=(12, 8))
    # Subplot positions are 1-based, filled row by row.
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis('off')
    plt.show()
# --- GUI Implementation ---
def run_gui():
    """Launch the Tkinter front end for the colorization demo.

    Builds a window with an image-path entry, a Browse button, a save-prefix
    entry, a "Use GPU" checkbox and a "Colorize Image" button. Pressing the
    button runs ``colorize_image`` and opens a second window showing the
    original, grayscale and both colorized images. Blocks in the Tk main
    loop until the window is closed.
    """
    root = tk.Tk()
    root.title('Image Colorization Demo')
    root.geometry('600x400')

    img_path_var = tk.StringVar()
    save_prefix_var = tk.StringVar(value='saved')
    use_gpu_var = tk.BooleanVar(value=False)

    def select_image():
        """Ask for an image file and store the chosen path in the entry."""
        # Patterns are given as separate items; a single ';'-joined string
        # is not portable across platforms.
        file_path = filedialog.askopenfilename(
            filetypes=[('Image Files', ('*.jpg', '*.jpeg', '*.png', '*.bmp'))])
        if file_path:
            img_path_var.set(file_path)

    def process_image():
        """Run the colorizers on the selected image and display the results."""
        img_path = img_path_var.get()
        save_prefix = save_prefix_var.get()
        use_gpu = use_gpu_var.get()
        if not img_path:
            messagebox.showerror('Error', 'Please select an image file.')
            return
        # GUI boundary: report any failure in a dialog instead of letting it
        # escape into the Tk callback machinery.
        try:
            (img, img_bw, out_img_eccv16, out_img_siggraph17,
             out_eccv16_path, out_siggraph17_path) = colorize_image(
                img_path, use_gpu, save_prefix)
            messagebox.showinfo(
                'Success',
                f'Colorized images saved as:\n{out_eccv16_path}\n{out_siggraph17_path}')
            show_all_images(img, img_bw, out_img_eccv16, out_img_siggraph17)
        except Exception as e:
            messagebox.showerror('Error', str(e))

    def show_all_images(img, img_bw, out_img_eccv16, out_img_siggraph17):
        """Open a window showing the four images in a 2x2 grid with titles."""
        top = tk.Toplevel(root)
        top.title('Input and Output Images')

        def to_pil(im):
            """Return *im* as a PIL image, converting array-likes to uint8."""
            if isinstance(im, Image.Image):
                return im
            import numpy as np
            arr = np.asarray(im)
            if arr.dtype == bool or np.issubdtype(arr.dtype, np.floating):
                # Float/bool data whose maximum is <= 1 is treated as
                # normalised to [0, 1]; anything else is taken as 0-255.
                scale = 255.0 if arr.max() <= 1.0 else 1.0
                arr = arr * scale
            # Clip before casting so out-of-range values saturate rather
            # than wrap around in uint8.
            arr = np.clip(arr, 0, 255).astype('uint8')
            # The image mode is inferred from the array shape.
            return Image.fromarray(arr)

        pil_imgs = [to_pil(img), to_pil(img_bw),
                    to_pil(out_img_eccv16), to_pil(out_img_siggraph17)]
        titles = ['Original', 'Grayscale', 'ECCV16', 'SIGGRAPH17']
        for i, (pil_img, title) in enumerate(zip(pil_imgs, titles)):
            img_tk = ImageTk.PhotoImage(pil_img.resize((200, 200)))
            row, col = divmod(i, 2)
            label = tk.Label(top, image=img_tk)
            # Keep a reference on the widget so the PhotoImage is not
            # garbage-collected while it is displayed.
            label.image = img_tk
            # Each image takes an even grid row; its title goes just below.
            label.grid(row=row * 2, column=col, padx=10, pady=5)
            title_label = tk.Label(top, text=title)
            title_label.grid(row=row * 2 + 1, column=col)

    tk.Label(root, text='Select Image:').pack(pady=10)
    tk.Entry(root, textvariable=img_path_var, width=50).pack()
    tk.Button(root, text='Browse', command=select_image).pack(pady=5)
    tk.Label(root, text='Save Prefix:').pack(pady=10)
    tk.Entry(root, textvariable=save_prefix_var, width=20).pack()
    tk.Checkbutton(root, text='Use GPU', variable=use_gpu_var).pack(pady=5)
    tk.Button(root, text='Colorize Image', command=process_image,
              bg='lightblue').pack(pady=20)
    root.mainloop()
# --- Entry Point ---
if __name__ == '__main__':
    import sys

    # Any command-line argument selects the CLI; a bare invocation opens
    # the GUI.
    entry_point = run_cli if len(sys.argv) > 1 else run_gui
    entry_point()