Spaces:
Running
Running
| import torch | |
| from torchvision.transforms import v2 | |
| import gradio as gr | |
| from PIL import Image | |
| from colorizer import ColorComicNet, MODEL_CFG | |
| from utils import smart_padding, remove_padding | |
# Input pipeline used by preprocess_image: PIL image -> normalized float tensor.
TRANSFORM = v2.Compose(
    [
        v2.ToImage(),                           # PIL image -> image tensor (C, H, W)
        v2.ToDtype(torch.float32, scale=True),  # integer pixels -> float32 in [0, 1]
        v2.Normalize(mean=[0.5], std=[0.5]),    # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]
    ]
)
# Image preprocessing and postprocessing functions
def preprocess_image(image: Image.Image, divisor=16):
    """Prepare a PIL image as a padded, batched model input.

    Args:
        image: input picture in any PIL mode; it is converted to RGB first so
            the model always receives three channels.
        divisor: forwarded unchanged to ``smart_padding`` (presumably the
            padded height/width become multiples of it — TODO confirm in utils).

    Returns:
        ``(tensor, padding)``: the padded batch tensor and the padding info
        reported by ``smart_padding``, to be passed to ``remove_padding`` later.
    """
    rgb_image = image.convert('RGB')
    batch = TRANSFORM(rgb_image).unsqueeze(dim=0)  # (3, H, W) -> (1, 3, H, W)
    padded_batch, pad_info = smart_padding(batch, divisor=divisor)
    return padded_batch, pad_info
def postprocess_output(output_tensor, padding):
    """Convert a batched model output tensor back into an image array.

    Args:
        output_tensor: model output; assumed to be a single-image batch
            ``(1, C, H', W')`` with values in [-1, 1], mirroring the
            ``Normalize(0.5, 0.5)`` applied on input — TODO confirm.
        padding: padding info returned by ``preprocess_image`` /
            ``smart_padding``; it is undone with ``remove_padding``.

    Returns:
        numpy array of shape ``(H, W, C)`` with values clamped to [0, 1].
    """
    output_tensor = remove_padding(output_tensor, padding)
    output_tensor = (output_tensor + 1) / 2  # Scale back to [0, 1]
    # .numpy() refuses tensors that require grad or are not on the CPU, so
    # detach and move first; both are no-ops for a CPU tensor built under no_grad.
    output_tensor = output_tensor.detach().cpu().clamp(0, 1)
    return output_tensor.squeeze(0).permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
# Define the colorization function
def colorize_image(gray_image: Image.Image):
    """Colorize a single grayscale image using the module-level ``model``.

    Args:
        gray_image: PIL image from the Gradio input; Gradio passes ``None``
            when the button is clicked before any image was uploaded.

    Returns:
        ``(H, W, C)`` numpy array in [0, 1], as produced by ``postprocess_output``.

    Raises:
        gr.Error: if no image was provided, so the UI shows a readable message
            instead of an ``AttributeError`` traceback from ``image.convert``.
    """
    if gray_image is None:
        raise gr.Error("Please upload an image first.")
    with torch.no_grad():
        # Preprocess; divisor=64 is passed to smart_padding — presumably the
        # network's total downsampling factor (TODO confirm in colorizer.py).
        input_tensor, padding = preprocess_image(gray_image, divisor=64)
        # Inference
        output = model(input_tensor)
        # Postprocess
        return postprocess_output(output, padding)
# Initialize the model once at import time; weights are mapped onto the CPU.
model = ColorComicNet(**MODEL_CFG)
model.load_state_dict(
    torch.load(
        "./weights/colorizer.pth",
        map_location=torch.device('cpu'),
        # The file only has to hold a state dict (tensors in plain containers),
        # so restrict unpickling to that; this is the default from PyTorch 2.6 on.
        weights_only=True,
    )
)
model.fuse()  # defined on ColorComicNet; presumably merges layers for inference — TODO confirm
model.eval()  # switch layers such as dropout / batch-norm to inference behavior
# Create the Gradio interface
with gr.Blocks() as demo:
    # Header
    gr.Markdown("# 🎨 Comic Colorization")
    gr.Markdown("Bring your grayscale comics to life with **ColorComicNet**")

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="📥 Upload Grayscale Image",
                type="pil",
            )
            colorize_button = gr.Button(
                "✨ Colorize Image",
                # No custom CSS is registered for the elem class below, so use
                # Gradio's built-in primary styling to get the intended look.
                variant="primary",
                elem_classes="button-primary",
            )
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="📤 Colorized Result",
                type="numpy",
            )

    # Example section: clicking an example only fills the input image.
    gr.Markdown("### 🖼️ Try an example")
    examples = gr.Examples(
        examples=[
            ["./examples/gray.jpg"],
            ["./examples/gray_2.jpg"],
            ["./examples/gray_4.jpg"],
        ],
        inputs=input_image,
    )

    # Footer
    gr.Markdown("---")

    # Interaction: event listeners must be registered inside the Blocks context.
    colorize_button.click(
        fn=colorize_image,
        inputs=input_image,
        outputs=output_image,
    )

# Only start the server when run as a script (e.g. `python app.py` on Spaces),
# not when the module is imported.
if __name__ == "__main__":
    demo.launch()