| import os |
| import cv2 |
| import numpy as np |
| from PIL import Image, ImageEnhance, ImageFilter |
| import time |
|
|
| try: |
| from modelscope.pipelines import pipeline |
| from modelscope.utils.constant import Tasks |
| from modelscope.outputs import OutputKeys |
| HAS_MODELSCOPE = True |
| except ImportError: |
| HAS_MODELSCOPE = False |
|
|
| try: |
| import torch |
| except ImportError: |
| torch = None |
|
|
class MockPipeline:
    """Fallback colorization pipeline used when ModelScope is unavailable.

    Mimics the real pipeline's interface: callable on an HxWx3 RGB (or HxW
    grayscale) uint8 numpy image, returning ``{'output_img': bgr_array}``
    with a mild warm tint so downstream code can be exercised end to end.
    """

    def __call__(self, image):
        h, w = image.shape[:2]
        # Simulate model latency proportional to pixel count.
        time.sleep((h * w) / 10_000_000.0)

        # Robustness fix: a 2-D grayscale input would have crashed the
        # original cv2 RGB->BGR conversion; promote it to 3 channels first.
        if image.ndim == 2:
            image = np.stack([image, image, image], axis=-1)

        # RGB -> BGR is a pure channel reversal — no cv2 dependency needed.
        output = image[:, :, ::-1].astype(np.float64)

        # Slight warm tint, in BGR order: damp blue/green, boost red.
        output[:, :, 0] *= 0.9
        output[:, :, 1] *= 0.95
        output[:, :, 2] *= 1.1
        # clip then cast matches the original's truncating uint8 assignment.
        output = np.clip(output, 0, 255).astype(np.uint8)

        return {'output_img': output}
|
|
class Colorizer:
    """Image colorization wrapper around the ModelScope DDColor pipeline.

    Falls back to :class:`MockPipeline` when ModelScope is not installed or
    the real model fails to load, so callers always get a working pipeline.
    """

    def __init__(self, model_id="iic/cv_ddcolor_image-colorization", device="cpu"):
        self.model_id = model_id
        # NOTE(review): `device` only gates CPU quantization below; it is
        # never forwarded to pipeline(...) — confirm whether it should be.
        self.device = device
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Load the real ModelScope pipeline, falling back to the mock on
        any failure (or when ModelScope is not importable)."""
        if not HAS_MODELSCOPE:
            print("ModelScope not found. Using Mock.")
            self.pipeline = MockPipeline()
            return
        try:
            print(f"Loading model {self.model_id}...")
            self.pipeline = pipeline(
                Tasks.image_colorization,
                model=self.model_id,
            )
            print("Model loaded.")

            # Dynamic int8 quantization of Linear layers speeds up CPU
            # inference; failures are non-fatal (best-effort optimization).
            if self.device == 'cpu' and torch is not None and hasattr(self.pipeline, 'model'):
                try:
                    print("Applying dynamic quantization...")
                    self.pipeline.model = torch.quantization.quantize_dynamic(
                        self.pipeline.model, {torch.nn.Linear}, dtype=torch.qint8
                    )
                    print("Quantization applied.")
                except Exception as qe:
                    print(f"Quantization failed: {qe}")
        except Exception as e:
            print(f"Failed to load real model: {e}. Using mock.")
            self.pipeline = MockPipeline()

    @staticmethod
    def _extract_bgr(output):
        """Pull the BGR uint8 image out of the pipeline's raw output.

        The real ModelScope pipeline returns a dict keyed by
        OutputKeys.OUTPUT_IMG; the mock uses the literal 'output_img' key.
        """
        if isinstance(output, dict):
            key = OutputKeys.OUTPUT_IMG if HAS_MODELSCOPE else 'output_img'
            result_bgr = output[key]
        else:
            result_bgr = output
        return result_bgr.astype(np.uint8)

    @staticmethod
    def _merge_luma(result_bgr, img_pil, w_orig, h_orig):
        """Upscale low-res colorized chroma to full resolution and merge it
        with the original image's luma channel.

        Keeps the original L (detail/sharpness) and takes only the a/b
        chroma planes from the model output. Returns a full-res RGB array.
        """
        result_lab = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2LAB)

        orig_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
        orig_lab = cv2.cvtColor(orig_bgr, cv2.COLOR_BGR2LAB)

        # cv2.resize takes (width, height); cubic for smooth chroma upscale.
        result_lab_up = cv2.resize(result_lab, (w_orig, h_orig), interpolation=cv2.INTER_CUBIC)

        merged_lab = np.empty_like(orig_lab)
        merged_lab[:, :, 0] = orig_lab[:, :, 0]       # original luma
        merged_lab[:, :, 1:] = result_lab_up[:, :, 1:]  # model chroma

        merged_bgr = cv2.cvtColor(merged_lab, cv2.COLOR_LAB2BGR)
        return cv2.cvtColor(merged_bgr, cv2.COLOR_BGR2RGB)

    def process(self, img_pil: Image.Image, brightness: float = 1.0, contrast: float = 1.0, edge_enhance: bool = False, adaptive_resolution: int = 512) -> Image.Image:
        """
        Process a PIL Image: Colorize -> Enhance.

        Args:
            img_pil: Input image (PIL); non-RGB modes (e.g. 'L', 'RGBA')
                are converted to RGB automatically.
            brightness: Brightness factor (1.0 = unchanged)
            contrast: Contrast factor (1.0 = unchanged)
            edge_enhance: Apply edge enhancement
            adaptive_resolution: Max dimension for inference.
                If image is larger, it's resized for colorization,
                then upscaled and merged with original Luma.
                Set to 0 to disable.

        Returns a PIL Image.

        Raises:
            Exception: re-raises any error from the underlying pipeline.
        """
        # Fix: grayscale/RGBA inputs would crash the RGB<->BGR conversions
        # downstream (they require 3 channels); normalize to RGB up front.
        if img_pil.mode != "RGB":
            img_pil = img_pil.convert("RGB")

        w_orig, h_orig = img_pil.size
        use_adaptive = adaptive_resolution > 0 and max(w_orig, h_orig) > adaptive_resolution

        if use_adaptive:
            scale = adaptive_resolution / max(w_orig, h_orig)
            # max(1, ...) guards against a zero-sized dimension on extreme
            # aspect ratios.
            new_w = max(1, int(w_orig * scale))
            new_h = max(1, int(h_orig * scale))
            img_input = img_pil.resize((new_w, new_h), Image.BILINEAR)
        else:
            img_input = img_pil

        img_np = np.array(img_input)

        try:
            output = self.pipeline(img_np)
        except Exception as e:
            print(f"Inference error: {e}")
            raise  # bare raise preserves the original traceback

        result_bgr = self._extract_bgr(output)

        if use_adaptive:
            result_rgb = self._merge_luma(result_bgr, img_pil, w_orig, h_orig)
        else:
            result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)

        out_pil = Image.fromarray(result_rgb)

        # Post-processing enhancements (no-ops at their default values).
        if brightness != 1.0:
            out_pil = ImageEnhance.Brightness(out_pil).enhance(brightness)
        if contrast != 1.0:
            out_pil = ImageEnhance.Contrast(out_pil).enhance(contrast)
        if edge_enhance:
            out_pil = out_pil.filter(ImageFilter.EDGE_ENHANCE)

        return out_pil
|
|