# anhth
# Update better UI
# 89fffaa
import torch
from torchvision.transforms import v2
import gradio as gr
from PIL import Image
from colorizer import ColorComicNet, MODEL_CFG
from utils import smart_padding, remove_padding
# Input pipeline shared by every request: wrap the picture as a tensor image,
# cast to float32 with value scaling, then normalize with mean 0.5 / std 0.5.
TRANSFORM = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.5], std=[0.5]),
    ]
)
# Image preprocessing and postprocessing functions
def preprocess_image(image: Image.Image, divisor=16):
    """Turn a PIL image into a batched, padded tensor for the model.

    The image is converted to RGB, run through ``TRANSFORM``, given a leading
    batch axis, and then handed to ``smart_padding`` together with
    ``divisor``.

    Returns:
        The two values produced by ``smart_padding``: the padded tensor and
        the padding description needed later to undo the padding.
    """
    rgb = image.convert('RGB')
    # Add the batch axis the model expects -> (1, 3, H, W)
    batched = TRANSFORM(rgb).unsqueeze(0)
    padded, pad_info = smart_padding(batched, divisor=divisor)
    return padded, pad_info
def postprocess_output(output_tensor, padding):
    """Convert a model output tensor into an image array.

    Args:
        output_tensor: Batched model output. The leading axis is squeezed,
            so a single-image batch is expected.
        padding: Padding description previously returned alongside the padded
            input; forwarded unchanged to ``remove_padding``.

    Returns:
        A NumPy array laid out as (H, W, C) with values clamped to [0, 1].
    """
    output_tensor = remove_padding(output_tensor, padding)
    # Undo the mean 0.5 / std 0.5 normalization: [-1, 1] -> [0, 1]
    output_tensor = (output_tensor + 1) / 2
    image_chw = output_tensor.clamp(0, 1).squeeze(0)
    # ``Tensor.numpy()`` refuses tensors that require grad or live off-CPU;
    # detach and move explicitly so this helper does not depend on the caller
    # having wrapped inference in ``torch.no_grad()`` or on the model device.
    image_chw = image_chw.detach().cpu()
    return image_chw.permute(1, 2, 0).numpy()  # Shape: (H, W, C)
# Define the colorization function
def colorize_image(gray_image: "Image.Image | None"):
    """Colorize a single grayscale image using the model.

    Args:
        gray_image: PIL image supplied by the Gradio input component, or
            ``None`` when the button is pressed before anything is uploaded.

    Returns:
        The image array produced by ``postprocess_output``, or ``None`` when
        no input image was provided (which leaves the output component empty).
    """
    # Gradio passes None for an empty image input; bail out instead of
    # failing with AttributeError inside preprocessing.
    if gray_image is None:
        return None

    with torch.no_grad():
        # Preprocess (pad so both spatial sizes are multiples of 64)
        input_tensor, padding = preprocess_image(gray_image, divisor=64)
        # Inference
        output = model(input_tensor)
        # Postprocess
        output_image = postprocess_output(output, padding)
    return output_image
# Initialize the model: build the network from MODEL_CFG and load the
# pretrained weights onto the CPU.
model = ColorComicNet(**MODEL_CFG)
model.load_state_dict(
    torch.load(
        "./weights/colorizer.pth",
        map_location=torch.device('cpu'),
        # The checkpoint is consumed as a plain state dict, so restrict
        # unpickling to tensors/containers instead of arbitrary objects.
        weights_only=True,
    )
)
# NOTE(review): fuse() is defined by ColorComicNet; presumably it folds layers
# for faster inference -- confirm against colorizer.py.
model.fuse()
# Switch to evaluation mode before serving requests.
model.eval()
# Build the Gradio UI: upload + button on the left, result on the right,
# followed by clickable example images.
with gr.Blocks() as demo:
    # Title and tagline
    gr.Markdown("# 🎨 Comic Colorization")
    gr.Markdown("Bring your grayscale comics to life with **ColorComicNet**")

    with gr.Row(equal_height=True):
        # Left column: source image and the trigger button
        with gr.Column(scale=1):
            input_image = gr.Image(label="📥 Upload Grayscale Image", type="pil")
            colorize_button = gr.Button(
                "✨ Colorize Image", elem_classes="button-primary"
            )

        # Right column: colorized output
        with gr.Column(scale=1):
            output_image = gr.Image(label="📤 Colorized Result", type="numpy")

    # Bundled sample inputs that populate the upload component when clicked
    gr.Markdown("### 🖼️ Try an example")
    examples = gr.Examples(
        examples=[
            ["./examples/gray.jpg"],
            ["./examples/gray_2.jpg"],
            ["./examples/gray_4.jpg"],
        ],
        inputs=input_image,
    )

    # Horizontal rule closing the page
    gr.Markdown("---")

    # Wire the button: uploaded image in, colorized image out
    colorize_button.click(fn=colorize_image, inputs=input_image, outputs=output_image)

# Start the web server for the app
demo.launch()