import argparse import os import sys import glob import time from pathlib import Path from PIL import Image import torch import torchvision.transforms as T # Output resolution is capped at 768px def parse_args(): parser = argparse.ArgumentParser(description="TorchScript Pipeline Inference for Watermark Removal") group = parser.add_mutually_exclusive_group(required=True) group.add_argument('-i', '--image', type=str, help="Path to single input watermarked image") group.add_argument('-f', '--folder', type=str, help="Path to folder containing watermarked images") parser.add_argument('-o', '--output_folder', type=str, default='tests', help="Output folder to save original and clean images") parser.add_argument('-m', '--model_path', type=str, default='model.ts', help="Path to TorchScript pipeline model (.ts file)") return parser.parse_args() def calculate_output_dimensions(orig_width, orig_height, max_size): """ Calculate output dimensions maintaining original aspect ratio. Caps at max_size (never upscale beyond processing size). """ # If image fits within max_size, keep original dimensions if orig_width <= max_size and orig_height <= max_size: return (orig_width, orig_height) # Scale down to fit within max_size, maintaining aspect ratio if orig_width >= orig_height: output_width = max_size output_height = int(orig_height * (max_size / orig_width)) else: output_height = max_size output_width = int(orig_width * (max_size / orig_height)) return (output_width, output_height) def load_torchscript_model(model_path): """Load TorchScript pipeline model.""" device = torch.device('cuda') print(f"Loading TorchScript pipeline from: {model_path}") model = torch.jit.load(model_path, map_location=device) model.eval() return model, device def process_image(img_path, model, device, output_folder=None): # Load image and get original size img = Image.open(img_path).convert('RGB') orig_width, orig_height = img.size base_name = os.path.basename(img_path) print(f" [{base_name}] Original: {orig_width}x{orig_height}", end="") # Convert to tensor [1, 3, H, W] in [0, 1] range img_tensor = T.ToTensor()(img).unsqueeze(0).to(device) # Inference with TorchScript pipeline # Pipeline handles: resize → normalize → model1 → model2 → denormalize → final resize with torch.no_grad(): pred_t = model(img_tensor) # Output: [1, 3, final_size, final_size] in [0, 1] # Get output size from pipeline _, _, pipeline_size, _ = pred_t.shape print(f" → Pipeline output: {pipeline_size}x{pipeline_size}", end="") # Convert tensor to PIL (square output at pipeline_size) pred_img = T.ToPILImage()(pred_t.squeeze(0).cpu()) # Resize back to original dimensions using PIL LANCZOS (capped at pipeline_size) output_width, output_height = calculate_output_dimensions(orig_width, orig_height, pipeline_size) pred_img = pred_img.resize((output_width, output_height), resample=Image.LANCZOS) print(f" → Resized: {output_width}x{output_height}", end="") output_width, output_height = pred_img.size print(f" → Output: {output_width}x{output_height}") # Determine save paths base_name = os.path.splitext(os.path.basename(img_path))[0] clean_name = f"{base_name}-clean.webp" # Create output folder and save both original and clean versions os.makedirs(output_folder, exist_ok=True) # Save original in output folder (keeps original extension) orig_save_path = os.path.join(output_folder, os.path.basename(img_path)) img.save(orig_save_path) # Save clean version (webp format with -clean suffix) clean_path = os.path.join(output_folder, clean_name) pred_img.save(clean_path, 'WEBP', quality=95) def main(): # Enable TensorFloat32 for faster matmul on Ampere+ GPUs torch.set_float32_matmul_precision('high') args = parse_args() # Verify TorchScript model exists if not os.path.exists(args.model_path): print(f"Error: TorchScript model not found: {args.model_path}") return print(f"TorchScript Pipeline Inference") print(f"Model: {args.model_path}") print() # Load TorchScript pipeline once model, device = load_torchscript_model(args.model_path) print(f"Pipeline loaded on {device}") print() num_images = 0 # Determine output folder based on processing mode if args.image: # Single image: save directly in output_folder output_path = args.output_folder # Start timing AFTER model loading start_time = time.time() process_image(args.image, model, device, output_path) num_images = 1 elif args.folder: # Folder processing: create subfolder {model_name}_{folder_name}_ts model_name = os.path.splitext(os.path.basename(args.model_path))[0] folder_name = os.path.basename(os.path.normpath(args.folder)) subfolder_name = f"{model_name}_{folder_name}_ts" output_path = os.path.join(args.output_folder, subfolder_name) print(f"Saving outputs to: {output_path}") print() # Process all JPG/WebP in folder patterns = ['*.jpg', '*.webp'] images = [] for pattern in patterns: images.extend(glob.glob(os.path.join(args.folder, pattern))) num_images = len(images) # Start timing AFTER model loading start_time = time.time() for img_path in sorted(images): process_image(img_path, model, device, output_path) # Print total processing time elapsed_time = time.time() - start_time print(f"\nProcessed {num_images} image{'s' if num_images != 1 else ''} in {elapsed_time:.2f} seconds ({elapsed_time/num_images:.2f}s per image)") if __name__ == '__main__': main()