| import cv2 |
| import numpy as np |
| import json |
| from PIL import Image, ImageDraw, ImageFont |
| from transformers import pipeline |
| from huggingface_hub import from_pretrained_keras |
|
|
|
|
def resize_image(img_in, input_height, input_width):
    """Resize *img_in* to (input_height, input_width) with nearest-neighbour interpolation."""
    # cv2.resize takes the target size as (width, height), i.e. reversed vs. numpy shape order.
    target_size = (input_width, input_height)
    return cv2.resize(img_in, target_size, interpolation=cv2.INTER_NEAREST)
|
|
def write_dict_to_json(dictionary, save_path, indent=4):
    """Serialize *dictionary* as pretty-printed JSON to the file at *save_path*."""
    with open(save_path, "w") as fh:
        json.dump(dictionary, fh, indent=indent)
|
|
def load_json_to_dict(load_path):
    """Read the JSON file at *load_path* and return its parsed content."""
    with open(load_path) as fh:
        parsed = json.load(fh)
    return parsed
|
|
|
|
class OCRD:
    """
    Optical Character Recognition and Document processing.

    Provides image preprocessing (scaling, binarization), text line segmentation,
    OCR on extracted lines, and visualization of model output / recognized text.
    Deep-learning models are pulled from the Hugging Face Hub on demand.

    Attributes:
        image (ndarray): The image loaded from the path given at construction,
            shared by the processing methods.
    """

    def __init__(self, img_path):
        # Load once as an RGB (or grayscale) numpy array; methods operate on copies/derivatives.
        self.image = np.array(Image.open(img_path))

    def scale_image(self, img):
        """
        Scale an image to a width suitable for model inference, preserving aspect ratio.

        Parameters:
        - img (ndarray): Image array with shape (height, width, ...).

        Behavior:
        - width < 1100 or width >= 2500  -> new width is 2000 px.
        - 1100 <= width < 2500           -> width is kept unchanged.
        In all cases the height is rescaled to preserve the aspect ratio.

        Returns:
        - ndarray: The resized image.
        """
        width_early = img.shape[1]

        if 1100 <= width_early < 2500:
            img_w_new = width_early
        else:
            # Both very narrow and very wide inputs are normalized to 2000 px.
            img_w_new = 2000
        img_h_new = int(img.shape[0] / float(img.shape[1]) * img_w_new)

        return resize_image(img, img_h_new, img_w_new)

    def predict(self, model, img):
        """
        Run patch-wise segmentation inference over a full image.

        The image is scaled to the model's expected width, padded up to at least the
        model's input size, then processed in overlapping patches. Patch margins are
        cropped away when stitching so that seams between patches are minimized.

        Parameters:
        - model (keras.Model): Segmentation model; its last layer defines the patch
          height/width expected as input.
        - img (ndarray): Image to segment.

        Returns:
        - ndarray: uint8 image of the same size as the scaled input, holding the
          per-pixel class label replicated across 3 channels.
        """
        img_height_model = model.layers[len(model.layers) - 1].output_shape[1]
        img_width_model = model.layers[len(model.layers) - 1].output_shape[2]

        img = self.scale_image(img)

        # Ensure the image is at least one patch big in each dimension.
        if img.shape[0] < img_height_model:
            img = resize_image(img, img_height_model, img.shape[1])
        if img.shape[1] < img_width_model:
            img = resize_image(img, img.shape[0], img_width_model)

        marginal_of_patch_percent = 0.1
        margin = int(marginal_of_patch_percent * img_height_model)
        # Effective stride: patches overlap by 2*margin so edges can be cropped away.
        width_mid = img_width_model - 2 * margin
        height_mid = img_height_model - 2 * margin
        img = (img / float(255.0)).astype(np.float16)
        img_h = img.shape[0]
        img_w = img.shape[1]
        prediction_true = np.zeros((img_h, img_w, 3))
        # Number of patches per axis (ceiling division).
        nxf = img_w / float(width_mid)
        nyf = img_h / float(height_mid)
        nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
        nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)

        for i in range(nxf):
            for j in range(nyf):
                index_x_d = i * width_mid
                index_x_u = index_x_d + img_width_model
                index_y_d = j * height_mid
                index_y_u = index_y_d + img_height_model
                # Clamp the last patch per axis so it ends exactly at the image border.
                if index_x_u > img_w:
                    index_x_u = img_w
                    index_x_d = img_w - img_width_model
                if index_y_u > img_h:
                    index_y_u = img_h
                    index_y_d = img_h - img_height_model

                img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
                label_p_pred = model.predict(
                    img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
                    verbose=0)

                seg = np.argmax(label_p_pred, axis=3)[0]
                seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)

                # Crop the overlap margin on every side except the outer image border.
                # (The scaled image always exceeds one patch per axis, so nxf, nyf >= 2
                # and the first/last patches are distinct.)
                top = 0 if j == 0 else margin
                bottom = 0 if j == nyf - 1 else margin
                left = 0 if i == 0 else margin
                right = 0 if i == nxf - 1 else margin

                seg_color = seg_color[top:seg_color.shape[0] - bottom,
                                      left:seg_color.shape[1] - right, :]
                prediction_true[index_y_d + top:index_y_u - bottom,
                                index_x_d + left:index_x_u - right, :] = seg_color

        return prediction_true.astype(np.uint8)

    def binarize_image(self, img, binarize_mode='detailed'):
        """
        Binarize an image according to the specified mode.

        Parameters:
        - img (ndarray): The input image to be binarized.
        - binarize_mode (str): One of:
            - 'detailed': pre-trained deep-learning binarization model (Hugging Face Hub).
            - 'fast': Otsu thresholding via OpenCV on the scaled grayscale image.
            - 'no': returns a copy of the original image.

        Returns:
        - ndarray: The binarized image (uint8, 3 channels for 'detailed'/'fast').

        Raises:
        - ValueError: If an unsupported binarize_mode is provided.
        """
        if binarize_mode == 'detailed':
            model_name = "SBB/eynollah-binarization"
            model = from_pretrained_keras(model_name)
            binarized = self.predict(model, img)

            # The model labels text pixels as 1 and background as 0; invert so text
            # is black (0) on white (255).
            binarized = binarized.astype(np.int8)
            binarized = -binarized + 1
            binarized = (binarized * 255).astype(np.uint8)

        elif binarize_mode == 'fast':
            # BUG FIX: scale_image takes a single image argument; the previous call
            # passed a second argument and raised TypeError.
            binarized = self.scale_image(img)
            binarized = cv2.cvtColor(binarized, cv2.COLOR_BGR2GRAY)
            _, binarized = cv2.threshold(binarized, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
            binarized = np.repeat(binarized[:, :, np.newaxis], 3, axis=2)

        elif binarize_mode == 'no':
            binarized = img.copy()

        else:
            accepted_values = ['detailed', 'fast', 'no']
            raise ValueError(f"Invalid value provided: {binarize_mode}. Accepted values are: {accepted_values}")

        return binarized.astype(np.uint8)

    def segment_textlines(self, img):
        """
        Segment text lines in *img* using the pretrained "SBB/eynollah-textline" model.

        Parameters:
        - img (ndarray): Image (ideally binarized) to segment.

        Returns:
        - ndarray: Per-pixel text line segmentation as produced by `predict`.
        """
        model_name = "SBB/eynollah-textline"
        model = from_pretrained_keras(model_name)
        return self.predict(model, img)

    def extract_filter_and_deskew_textlines(self, img, textline_mask, min_pixel_sum=20, median_bounds=(.5, 20)):
        """
        Extract and deskew text lines from an image based on a textline mask.

        Connected components of the mask are filtered by absolute size
        (*min_pixel_sum*) and by relative size w.r.t. the median component size
        (*median_bounds*). Each surviving component's minimum-area rectangle is
        warped to a horizontal crop; crops taller than wide are rotated 90°.

        Parameters:
        - img (numpy.ndarray): Original image (3D array) to crop text lines from.
        - textline_mask (numpy.ndarray): Binary 2D segmentation mask.
        - min_pixel_sum (int, optional): Minimum pixel count for a component to be
          kept. None disables this filter.
        - median_bounds (tuple, optional): (lower, upper) multipliers of the median
          component size; components outside the bounds are dropped. Either entry
          may be None to disable that bound.

        Returns:
        - tuple:
            - dict: keys 'array', 'center', 'left', 'height', 'width',
              'rotation_angle', each holding a list with one entry per kept line.
            - numpy.ndarray: 3-channel visualization mask of the segments removed
              by filtering, for debugging.

        Notes:
        - cv2 errors during the perspective transform are reported and the affected
          line skipped; processing continues.
        """
        num_labels, labels_im = cv2.connectedComponents(textline_mask)

        MIN_PIXEL_SUM = min_pixel_sum
        MEDIAN_LOWER_BOUND = median_bounds[0]
        MEDIAN_UPPER_BOUND = median_bounds[1]

        # First pass: drop components below the absolute size threshold.
        cc_sizes = []
        masks = []
        labels_im_filtered = labels_im > 0
        for label in range(1, num_labels):
            mask = np.where(labels_im == label, True, False)
            if MIN_PIXEL_SUM is None:
                is_above_min_pixel_sum = True
            else:
                is_above_min_pixel_sum = mask.sum() > MIN_PIXEL_SUM
            if is_above_min_pixel_sum:
                cc_sizes.append(mask.sum())
                masks.append(mask)

        # Second pass: filter by relative size and collect min-area rectangles.
        # Guard against an empty list: np.median([]) would return nan with a warning.
        median = np.median(cc_sizes) if cc_sizes else 0
        rectangles = []
        for mask in masks:
            mask_sum = mask.sum()
            if MEDIAN_LOWER_BOUND is None:
                is_above_lower_media_bound = True
            else:
                is_above_lower_media_bound = mask_sum > median * MEDIAN_LOWER_BOUND
            if MEDIAN_UPPER_BOUND is None:
                is_below_upper_median_bound = True
            else:
                is_below_upper_median_bound = mask_sum < median * MEDIAN_UPPER_BOUND
            if is_above_lower_media_bound and is_below_upper_median_bound:
                # Kept segments are cleared from the filtered mask, leaving only
                # the removed segments for visualization.
                labels_im_filtered[mask] = False
                mask = (mask * 255).astype(np.uint8)
                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                rect = cv2.minAreaRect(contours[0])
                # Skip degenerate (zero-area) rectangles.
                if np.prod(rect[1]) > 0:
                    rectangles.append(rect)

        # BUG FIX: textline_images is initialized unconditionally so that an empty
        # rectangle list no longer raises NameError below.
        textline_images = []
        for rect in rectangles:
            width, height = rect[1]
            rotation_angle = rect[2]
            width = int(width)
            height = int(height)

            # Warp the rotated rectangle into an axis-aligned crop.
            box = cv2.boxPoints(rect)
            box = np.intp(box)
            src_pts = box.astype("float32")
            dst_pts = np.array([[0, height - 1],
                                [0, 0],
                                [width - 1, 0],
                                [width - 1, height - 1]], dtype="float32")
            try:
                M = cv2.getPerspectiveTransform(src_pts, dst_pts)
                warped = cv2.warpPerspective(img, M, (width, height))
                # Vertical crops are most likely 90°-rotated text lines.
                if height > width:
                    warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE)
                    width, height = height, width
                    rotation_angle = 90 - rotation_angle
                center = rect[0]
                left = center[0] - width // 2
                textline_images.append((warped, center, left, height, width, rotation_angle))
            except cv2.error as e:
                print(f"Error with warpPerspective: {e}")

        # Re-shape the list of tuples into a dict of parallel lists.
        keys = ['array', 'center', 'left', 'height', 'width', 'rotation_angle']
        textline_images = {key: [tup[i] for tup in textline_images] for i, key in enumerate(keys)}
        num_labels_filtered = len(textline_images['array'])
        labels_im_filtered = np.repeat(labels_im_filtered[:, :, np.newaxis], 3, axis=2).astype(np.uint8)
        print(f'Kept {num_labels_filtered} of {num_labels} text segments after filtering.')
        print(f'All segments deleted smaller than {MIN_PIXEL_SUM} pixels (absolute min size).')
        if MEDIAN_LOWER_BOUND is not None:
            print(f'All segments deleted smaller than {median*MEDIAN_LOWER_BOUND} pixels (lower median bound).')
        if MEDIAN_UPPER_BOUND is not None:
            print(f'All segments deleted bigger than {median*MEDIAN_UPPER_BOUND} pixels (upper median bound).')
        if MEDIAN_LOWER_BOUND is not None or MEDIAN_UPPER_BOUND is not None:
            print(f'Median segment size (pixel sum) used for filtering: {int(median)}.')

        return textline_images, labels_im_filtered

    def ocr_on_textlines(self, textline_images, model_name="microsoft/trocr-base-handwritten"):
        """
        Run an OCR model over a list of text line image arrays.

        Parameters:
        - textline_images (dict): Must contain key 'array' with a list of image
          arrays, one per text line.
        - model_name (str): Hugging Face model id trained for single-line OCR.

        Returns:
        - dict: {'preds': [str, ...]} with one text prediction per input line.

        Note:
        - Requires the `transformers` pipeline API; progress is printed every
          10 lines.
        """
        pipe = pipeline("image-to-text", model=model_name)

        textline_preds = []
        len_array = len(textline_images['array'])
        for i, textline in enumerate(textline_images['array'][:]):
            if i % 10 == 1:
                print(f'Processing textline no. {i} of {len_array}')
            textline = Image.fromarray(textline)
            textline_preds.append(pipe(textline))

        preds = [pred[0]['generated_text'] for pred in textline_preds]
        return {'preds': preds}

    def adjust_font_size(self, draw, text, box_width):
        """
        Find a font size at which *text* fits inside *box_width* pixels.

        Grows the font from size 1 upward; as soon as the rendered text would
        exceed the box, steps back 10 sizes (minimum 5) and returns that font.

        Parameters:
        - draw (ImageDraw.Draw): Draw instance used to measure the text.
        - text (str): The text string to be rendered.
        - box_width (int): Maximum width in pixels the text may occupy.

        Returns:
        - ImageFont: Font sized to fit the text within the given width (capped
          at size 199 if the text never overflows).
        """
        font = None
        for font_size in range(1, 200):
            font = ImageFont.load_default(font_size)
            if draw.textlength(text, font=font) > box_width:
                # Back off to leave some slack, but never below size 5.
                return ImageFont.load_default(max(5, int(font_size - 10)))
        # Text fits even at the largest size tried; return the last font.
        return font

    def create_text_overlay_image(self, textline_images, textline_preds, img_shape, font_size=-1):
        """
        Draw recognized text onto a blank white image at the original line positions.

        Parameters:
        - textline_images (dict): Bounding box data per text segment; needs keys
          'left', 'center' and 'width' with one entry per segment.
        - textline_preds (dict): {'preds': [str, ...]} aligned with textline_images.
        - img_shape (tuple): (height, width) of the output image.
        - font_size (int, optional): Fixed font size, or -1 (default) to fit each
          line to its bounding box width via `adjust_font_size`.

        Returns:
        - Image: PIL image with the text drawn on a white background.

        Raises:
        - AssertionError: If box lists and prediction list lengths differ.
        """
        for key in textline_images.keys():
            assert len(textline_images[key]) == len(textline_preds['preds']), f'Length of {key} and preds doesnt correspond'

        img_gen = Image.new('RGB', (img_shape[1], img_shape[0]), color=(255, 255, 255))
        draw = ImageDraw.Draw(img_gen)

        for i in range(len(textline_preds['preds'])):
            left_x = textline_images['left'][i]
            center_y = textline_images['center'][i][1]
            width = textline_images['width'][i]
            text = textline_preds['preds'][i]

            if font_size == -1:
                font = self.adjust_font_size(draw, text, width)
            else:
                font = ImageFont.load_default(font_size)
            draw.text((left_x, center_y), text, fill=(0, 0, 0), font=font, align='left')

        return img_gen

    def visualize_model_output(self, prediction, img):
        """
        Overlay class predictions onto the original image with distinct colors.

        Parameters:
        - prediction (ndarray): 3D array whose first channel holds class labels
          (supported labels: 0-15).
        - img (ndarray): Original image; resized to the prediction's shape if needed.

        Returns:
        - ndarray: Blend (80% image, 20% colored prediction) via cv2.addWeighted.
        """
        unique_classes = np.unique(prediction[:, :, 0])
        # Fixed palette: one RGB color per class label.
        rgb_colors = {'0': [255, 255, 255],
                      '1': [255, 0, 0],
                      '2': [255, 125, 0],
                      '3': [255, 0, 125],
                      '4': [125, 125, 125],
                      '5': [125, 125, 0],
                      '6': [0, 125, 255],
                      '7': [0, 125, 0],
                      '8': [125, 125, 125],
                      '9': [0, 125, 255],
                      '10': [125, 0, 125],
                      '11': [0, 255, 0],
                      '12': [0, 0, 255],
                      '13': [0, 255, 255],
                      '14': [255, 125, 125],
                      '15': [255, 0, 255]}

        output = np.zeros(prediction.shape)
        for unq_class in unique_classes:
            rgb_class_unique = rgb_colors[str(int(unq_class))]
            output[:, :, 0][prediction[:, :, 0] == unq_class] = rgb_class_unique[0]
            output[:, :, 1][prediction[:, :, 0] == unq_class] = rgb_class_unique[1]
            output[:, :, 2][prediction[:, :, 0] == unq_class] = rgb_class_unique[2]

        img = resize_image(img, output.shape[0], output.shape[1])

        output = output.astype(np.int32)
        img = img.astype(np.int32)

        return cv2.addWeighted(img, 0.8, output, 0.2, 10)