| import os, sys |
| import cv2 |
| import time |
| import datetime, pytz |
| import gradio as gr |
| import torch |
| import numpy as np |
| from torchvision.utils import save_image |
| import json |
| import threading |
| from queue import Queue |
| from pathlib import Path |
| import shutil |
|
|
| |
# Make the project-local ``test_code`` package importable before importing from it.
# NOTE(review): this uses the current working directory, not this file's directory,
# so the app presumably has to be launched from the repository root — confirm.
root_path = os.path.abspath('.')
sys.path.append(root_path)
# Project helpers: the super-resolution entry point and per-architecture model loaders.
from test_code.inference import super_resolve_img
from test_code.test_utils import load_grl, load_rrdb, load_dat
|
|
| |
# Where results and bookkeeping files live (relative to the working directory).
OUTPUT_DIR = "outputs"
HISTORY_FILE = "history.json"
# NOTE(review): not referenced anywhere in the visible code.
VIDEO_QUEUE_FILE = "video_queue.json"

# Work queue consumed by the background video worker, and a
# task_id -> status-dict map that the UI reads when refreshing the queue view.
video_queue = Queue()
processing_status = {}

# Create the output tree up front so later saves can write straight into it.
for _path in (
    OUTPUT_DIR,
    os.path.join(OUTPUT_DIR, "images"),
    os.path.join(OUTPUT_DIR, "videos"),
):
    os.makedirs(_path, exist_ok=True)
del _path
|
|
|
|
| def auto_download_if_needed(weight_path): |
| if os.path.exists(weight_path): |
| return |
| |
| if not os.path.exists("pretrained"): |
| os.makedirs("pretrained") |
| |
| if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth": |
| os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth") |
| os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained") |
| |
| if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth": |
| os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth") |
| os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained") |
| |
| if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth": |
| os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth") |
| os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained") |
| |
| if weight_path == "pretrained/4x_APISR_DAT_GAN_generator.pth": |
| os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.3.0/4x_APISR_DAT_GAN_generator.pth") |
| os.system("mv 4x_APISR_DAT_GAN_generator.pth pretrained") |
|
|
|
|
def load_history(path=None):
    """Load the processing history list from a JSON file.

    Args:
        path: File to read; defaults to ``HISTORY_FILE``.

    Returns:
        The list of history records, or ``[]`` when the file is missing,
        unreadable, not valid JSON, or does not contain a JSON list. A corrupt
        file is reported on stdout instead of crashing the caller.
    """
    if path is None:
        path = HISTORY_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: could not read history file {path!r}: {err}")
        return []
    # Callers insert into the result, so anything other than a list is unusable.
    return data if isinstance(data, list) else []
|
|
|
|
def save_history(history, path=None):
    """Write the processing history list to disk as indented JSON.

    The JSON is written to a sibling temp file first and then moved into place
    with ``os.replace``. A serialisation error therefore leaves the previous
    file intact, and a concurrent reader never sees a half-written file (the
    rename is atomic on POSIX).

    Args:
        history: JSON-serialisable list of history records.
        path: Destination file; defaults to ``HISTORY_FILE``.
    """
    if path is None:
        path = HISTORY_FILE
    # Unique per process and thread, so concurrent writers never share a temp file.
    tmp_path = f"{path}.{os.getpid()}.{threading.get_ident()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
|
|
|
# Serialises the load -> prepend -> save cycle on the history file. The
# background video worker thread and Gradio request handlers both append
# records. Without a lock, two concurrent appends can read the same old list,
# and the second save would silently drop the first record.
_HISTORY_LOCK = threading.Lock()


def add_to_history(input_path, output_path, model_name, process_type, status="completed"):
    """Prepend one processing record to the persisted history (newest first).

    Args:
        input_path: Source file that was processed.
        output_path: Where the result was written.
        model_name: Model used, e.g. "4xGRL".
        process_type: Kind of job; this file's callers pass "image" or "video".
        status: Free-form status string stored with the record.
    """
    with _HISTORY_LOCK:
        history = load_history()
        history.insert(0, {
            # Local, timezone-naive ISO-8601 timestamp.
            "timestamp": datetime.datetime.now().isoformat(),
            "input_path": input_path,
            "output_path": output_path,
            "model_name": model_name,
            "process_type": process_type,
            "status": status,
        })
        save_history(history)
|
|
|
|
def load_generator(model_name):
    """Build the requested APISR generator on the CPU.

    The weight file is downloaded first when it is not present locally.

    Args:
        model_name: One of "4xGRL", "4xRRDB", "2xRRDB", "4xDAT".

    Returns:
        The generator module, moved to the CPU.

    Raises:
        ValueError: if ``model_name`` is not one of the names above.
    """
    # name -> (weight file, builder). The lambdas defer the loader lookup until a
    # model is actually chosen.
    builders = {
        "4xGRL": ("pretrained/4x_APISR_GRL_GAN_generator.pth",
                  lambda weights: load_grl(weights, scale=4)),
        "4xRRDB": ("pretrained/4x_APISR_RRDB_GAN_generator.pth",
                   lambda weights: load_rrdb(weights, scale=4)),
        "2xRRDB": ("pretrained/2x_APISR_RRDB_GAN_generator.pth",
                   lambda weights: load_rrdb(weights, scale=2)),
        "4xDAT": ("pretrained/4x_APISR_DAT_GAN_generator.pth",
                  lambda weights: load_dat(weights, scale=4)),
    }

    for candidate, (weights, build) in builders.items():
        if model_name == candidate:
            auto_download_if_needed(weights)
            return build(weights).to(device='cpu')

    raise ValueError(f"Model {model_name} not supported")
|
|
|
|
def inference_image(img_path, model_name):
    """Super-resolve one uploaded image and save it under ``OUTPUT_DIR/images``.

    Args:
        img_path: Path of the input image (the UI uses ``gr.Image(type="filepath")``),
            or None when nothing was uploaded.
        model_name: A name accepted by ``load_generator``.

    Returns:
        ``(rgb_array, status_message)`` on success, or ``(None, error_message)``
        on failure. Errors are reported to the UI rather than raised.
    """
    import traceback

    if img_path is None:
        return None, "❌ Please upload an image first"

    try:
        generator = load_generator(model_name)

        print("Processing image:", img_path)
        print("Time:", datetime.datetime.now(pytz.timezone('US/Eastern')))

        # Per the UI note, inputs whose short side exceeds 720 px are downsampled first.
        super_resolved_img = super_resolve_img(
            generator, img_path, output_path=None,
            downsample_threshold=720, crop_for_4x=True
        )

        timestamp = int(time.time() * 1000)
        output_path = os.path.join(OUTPUT_DIR, "images", f"image_{timestamp}.png")
        save_image(super_resolved_img, output_path)

        # Read the saved PNG back as an RGB numpy array for the gr.Image(type="numpy") output.
        outputs = cv2.imread(output_path)
        if outputs is None:
            raise OSError(f"Could not read back saved image: {output_path}")
        outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB)

        add_to_history(img_path, output_path, model_name, "image")

        return outputs, f"✅ Saved to: {output_path}"

    except Exception as error:
        # UI boundary: keep the traceback in the server log, show a short message to the user.
        traceback.print_exc()
        return None, f"❌ Error: {error}"
|
|
|
|
def _encode_frames_to_video(frame_dir, fps, output_path):
    """Encode ``frame_dir/output_%06d.png`` into an H.264 MP4 with ffmpeg.

    Args:
        frame_dir: Directory holding sequentially numbered ``output_XXXXXX.png`` frames.
        fps: Output frame rate.
        output_path: Destination ``.mp4`` path (overwritten if it exists).

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    import subprocess

    # Argument list without a shell: paths containing spaces are passed verbatim.
    # "-y" keeps ffmpeg from blocking on an interactive overwrite prompt.
    cmd = [
        "ffmpeg", "-y",
        "-framerate", str(fps),
        "-i", os.path.join(frame_dir, "output_%06d.png"),
        "-c:v", "libx264",
        "-pix_fmt", "yuv420p",
        output_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    if result.returncode != 0:
        stderr_tail = " | ".join(result.stderr.strip().splitlines()[-5:])
        raise RuntimeError(f"ffmpeg exited with code {result.returncode}: {stderr_tail}")


def process_video_frame_by_frame(video_path, model_name, task_id):
    """Super-resolve every frame of a video and re-encode the frames as an MP4.

    Run by the background queue worker. Progress and the final outcome are
    published in ``processing_status[task_id]``. Errors are caught and recorded
    there with status ``"error"`` rather than raised. On success the video is
    written to ``OUTPUT_DIR/videos`` and added to the history. Only frames are
    encoded; the source audio is not carried over.

    Args:
        video_path: Path of the input video.
        model_name: A name accepted by ``load_generator``.
        task_id: Key under which status is reported in ``processing_status``.
    """
    import tempfile

    cap = None
    temp_dir = None
    try:
        processing_status[task_id] = {"status": "processing", "progress": 0}

        generator = load_generator(model_name)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Cannot open video file")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps and fps > 0):
            # Some containers report 0 (or NaN) fps, but ffmpeg needs a positive rate.
            # NOTE: 30 fps is an assumed fallback.
            fps = 30.0
        # Container metadata: it may be 0 or inaccurate, so it is only used for progress.
        reported_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        total_frames = int(reported_frames) if reported_frames > 0 else 0

        timestamp = int(time.time() * 1000)
        output_path = os.path.join(OUTPUT_DIR, "videos", f"video_{timestamp}.mp4")

        # Unique scratch directory for per-frame PNGs, created in the working directory.
        temp_dir = tempfile.mkdtemp(prefix=f"temp_frames_{timestamp}_", dir=".")

        frame_count = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break

            # super_resolve_img is called with a file path here, so each decoded
            # frame is written to disk first.
            frame_path = os.path.join(temp_dir, f"frame_{frame_count:06d}.png")
            if not cv2.imwrite(frame_path, frame):
                raise OSError(f"Could not write temporary frame: {frame_path}")

            super_resolved_img = super_resolve_img(
                generator, frame_path, output_path=None,
                downsample_threshold=720, crop_for_4x=True
            )
            save_image(super_resolved_img, os.path.join(temp_dir, f"output_{frame_count:06d}.png"))
            os.remove(frame_path)  # The input frame is no longer needed; keep scratch usage down.

            frame_count += 1
            progress = min(100, frame_count * 100 // total_frames) if total_frames > 0 else 0
            processing_status[task_id] = {"status": "processing", "progress": progress}

            print(f"Task {task_id}: Processed frame {frame_count}/{total_frames} ({progress}%)")

        if frame_count == 0:
            raise ValueError("No frames could be read from the video")

        print(f"Task {task_id}: Combining frames into video...")
        processing_status[task_id] = {"status": "encoding", "progress": 100}
        _encode_frames_to_video(temp_dir, fps, output_path)

        processing_status[task_id] = {"status": "completed", "progress": 100, "output": output_path}
        add_to_history(video_path, output_path, model_name, "video")

        print(f"Task {task_id}: Completed! Output: {output_path}")

    except Exception as error:
        processing_status[task_id] = {"status": "error", "error": str(error)}
        print(f"Task {task_id}: Error - {error}")
    finally:
        # Always release the decoder and delete scratch frames, including on failure.
        if cap is not None:
            cap.release()
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
|
|
|
|
def video_queue_worker():
    """Consume ``(task_id, video_path, model_name)`` items from ``video_queue``.

    Each item is handed to ``process_video_frame_by_frame``. A ``None`` item
    stops the loop. Exceptions are printed so that one bad item does not end
    the worker. ``task_done`` is called for every item taken off the queue,
    including the ``None`` sentinel.
    """
    print("Video queue worker started...")
    running = True
    while running:
        try:
            job = video_queue.get()
            if job is None:
                running = False
            else:
                job_id, source_path, chosen_model = job
                print(f"Starting task {job_id}...")
                process_video_frame_by_frame(source_path, chosen_model, job_id)
        except Exception as exc:
            print(f"Worker error: {exc}")
        finally:
            video_queue.task_done()
|
|
|
|
def submit_video(video_path, model_name):
    """Add a video to the background processing queue.

    Args:
        video_path: Path of the uploaded video, or None when nothing was uploaded.
        model_name: A name accepted by ``load_generator``.

    Returns:
        ``(None, message)``. The ``None`` is routed to the video input in the UI,
        which clears it after submission.
    """
    if video_path is None:
        return None, "❌ Please upload a video first"

    task_id = f"task_{int(time.time() * 1000)}"
    # Register the task before enqueueing it. Otherwise the worker could pick it
    # up and set "processing" (or "error") first, and this assignment would then
    # overwrite that state with "queued".
    processing_status[task_id] = {"status": "queued", "progress": 0}
    video_queue.put((task_id, video_path, model_name))

    return None, f"✅ Video submitted to queue! Task ID: {task_id}\nCheck status in the monitoring section."
|
|
|
|
def get_queue_status():
    """Build the text shown in the "Queue Status" box.

    Reports how many videos are still waiting in ``video_queue`` and, for each
    task in ``processing_status``, its status and progress, plus its output
    path or error message when present.
    """
    header = "π **Queue Status**\n\n"
    header += f"Videos in queue: {video_queue.qsize()}\n\n"

    # Iterate over a snapshot. submit_video (another request thread) may insert
    # a task while this runs, which would otherwise raise
    # "RuntimeError: dictionary changed size during iteration".
    tasks = list(processing_status.items())
    if not tasks:
        return header + "No active tasks"

    parts = [header, "**Active Tasks:**\n"]
    for task_id, status in tasks:
        parts.append(f"\n㪠{task_id}:\n")
        parts.append(f" Status: {status['status']}\n")
        parts.append(f" Progress: {status.get('progress', 0)}%\n")
        if 'output' in status:
            parts.append(f" Output: {status['output']}\n")
        if 'error' in status:
            parts.append(f" Error: {status['error']}\n")
    return "".join(parts)
|
|
|
|
def get_history_display():
    """Format the first 50 history records as text for the History tab.

    Records are shown in stored order. ``add_to_history`` prepends, so that is
    newest first. Returns a placeholder string when there is no history.
    """
    records = load_history()
    if not records:
        return "No history available"

    chunks = ["π **Processing History**\n\n"]
    for number, entry in enumerate(records[:50], start=1):
        chunks.append(
            f"**{number}. {entry['process_type'].upper()}** - {entry['timestamp']}\n"
            f" Model: {entry['model_name']}\n"
            f" Status: {entry['status']}\n"
            f" Output: {entry['output_path']}\n\n"
        )
    return "".join(chunks)
|
|
|
|
def clear_history():
    """Delete the persisted history file.

    Returns:
        ``(status_message, history_text)`` for the History tab's status box and
        history box. ``history_text`` is the refreshed, now-empty display.
    """
    # EAFP: no window between an existence check and the removal.
    try:
        os.remove(HISTORY_FILE)
    except FileNotFoundError:
        pass  # Nothing to clear.
    return "✅ History cleared!", get_history_display()
|
|
|
|
if __name__ == '__main__':

    # Start the single background consumer of ``video_queue``. daemon=True means
    # this thread does not keep the interpreter alive once the main thread exits.
    worker_thread = threading.Thread(target=video_queue_worker, daemon=True)
    worker_thread.start()

    # Header text rendered at the top of the page.
    MARKDOWN = """
# APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)

[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)

APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.

### β οΈ Note: Images with short side > 720px will be downsampled to 720px (e.g., 1920x1080 β 1280x720)
### πΉ New: Video processing runs in background queue - you can close the browser and it continues!
"""

    # UI layout: three tabs (image, video queue, history). Component creation
    # order inside each ``with`` block determines on-screen placement.
    with gr.Blocks(title="APISR - Anime Super Resolution") as demo:

        gr.Markdown(MARKDOWN)

        with gr.Tabs():

            # --- Image tab: synchronous single-image super-resolution ---
            with gr.Tab("πΌοΈ Image Processing"):
                with gr.Row():
                    with gr.Column(scale=2):
                        # type="filepath": inference_image receives a path string.
                        input_image = gr.Image(type="filepath", label="Input Image")
                        image_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        image_btn = gr.Button("π Process Image", variant="primary")

                    with gr.Column(scale=3):
                        # type="numpy": matches the RGB array inference_image returns.
                        output_image = gr.Image(type="numpy", label="Output Image")
                        image_status = gr.Textbox(label="Status", lines=2)

                with gr.Row():
                    # Example inputs; paths are relative to the working directory.
                    gr.Examples(
                        examples=[
                            ["__assets__/lr_inputs/image-00277.png"],
                            ["__assets__/lr_inputs/image-00542.png"],
                            ["__assets__/lr_inputs/41.png"],
                            ["__assets__/lr_inputs/f91.jpg"],
                        ],
                        inputs=[input_image],
                    )

                image_btn.click(
                    fn=inference_image,
                    inputs=[input_image, image_model],
                    outputs=[output_image, image_status]
                )

            # --- Video tab: submit to the background queue, poll status manually ---
            with gr.Tab("π¬ Video Processing"):
                gr.Markdown("""
### Video Processing Queue
Videos are processed in the background. You can submit multiple videos and close the browser - processing continues!
""")

                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(label="Input Video")
                        video_model = gr.Dropdown(
                            choices=["2xRRDB", "4xRRDB", "4xGRL", "4xDAT"],
                            value="4xGRL",
                            label="Model"
                        )
                        video_btn = gr.Button("π€ Submit to Queue", variant="primary")
                        video_status = gr.Textbox(label="Submission Status", lines=3)

                    with gr.Column():
                        gr.Markdown("### π Queue Monitor")
                        queue_status = gr.Textbox(label="Queue Status", lines=15, interactive=False)
                        refresh_btn = gr.Button("π Refresh Status")

                # submit_video returns None first, which clears input_video after submission.
                video_btn.click(
                    fn=submit_video,
                    inputs=[input_video, video_model],
                    outputs=[input_video, video_status]
                )

                # Status is only refreshed on demand (no automatic polling).
                refresh_btn.click(
                    fn=get_queue_status,
                    outputs=[queue_status]
                )

            # --- History tab: view / clear the persisted JSON history ---
            with gr.Tab("π History"):
                gr.Markdown("### Processing History")

                with gr.Row():
                    refresh_history_btn = gr.Button("π Refresh History")
                    clear_history_btn = gr.Button("ποΈ Clear History", variant="stop")

                history_display = gr.Textbox(label="History", lines=20, interactive=False)
                clear_status = gr.Textbox(label="Status", lines=1, visible=True)

                refresh_history_btn.click(
                    fn=get_history_display,
                    outputs=[history_display]
                )

                clear_history_btn.click(
                    fn=clear_history,
                    outputs=[clear_status, history_display]
                )

    # Enable Gradio's own event queue. It is separate from ``video_queue``;
    # presumably max_size bounds pending UI events — confirm in the Gradio docs.
    demo.queue(max_size=20)
    # Listen on all interfaces, port 7860, without a public share link.
    # NOTE(review): in Gradio 4.x ``theme`` is a ``gr.Blocks(...)`` argument;
    # passing it to ``launch()`` needs a newer Gradio — confirm the pinned version.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=False,
        theme=gr.themes.Soft()
    )