from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
import torch
from typing import List, Dict, Any, Optional, Tuple
from shapely.geometry import box
from shapely.geometry.polygon import Polygon
from .utils import x1y1x2y2_to_xywh
import numpy as np
from numpy.typing import NDArray


class Magiv2Processor:
| """ |
| Procesor danych dla modelu Magiv2 - obsługuje preprocessing i postprocessing. |
| |
| Klasa odpowiedzialna za przygotowanie danych wejściowych dla różnych modułów |
| Magiv2 (detekcja, OCR, embeddingi) oraz przetwarzanie outputów. Zawiera również |
| metody pomocnicze do filtrowania detekcji i konwersji formatów anotacji. |
| |
| Attributes: |
| config: Konfiguracja modelu Magiv2 |
| detection_image_preprocessor: Preprocessor dla obrazów do detekcji obiektów |
| ocr_preprocessor: Preprocessor dla obrazów do OCR |
| crop_embedding_image_preprocessor: Preprocessor dla wyciętych fragmentów obrazu |
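
    Example (illustrative sketch; ``Magiv2Config`` and ``page`` are assumed to
    come from the surrounding Magiv2 codebase and are not defined here):
        processor = Magiv2Processor(Magiv2Config())
        inputs = processor.preprocess_inputs_for_detection([page])
        # inputs["pixel_values"] and inputs["pixel_mask"] feed the detection module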
| """ |

    def __init__(self, config: Any) -> None:
        """
        Initializes the processor with the given configuration.

        Creates preprocessors for the modules that are enabled in the configuration:
        - Object detection: ConditionalDetrImageProcessor
        - OCR: TrOCRProcessor
        - Crop embeddings: ViTImageProcessor

        Args:
            config: Magiv2Config object with preprocessing parameters
        """
        self.config: Any = config
        self.detection_image_preprocessor: Optional[ConditionalDetrImageProcessor] = None
        self.ocr_preprocessor: Optional[TrOCRProcessor] = None
        self.crop_embedding_image_preprocessor: Optional[ViTImageProcessor] = None

        if not config.disable_detections:
            assert config.detection_image_preprocessing_config is not None
            self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
                config.detection_image_preprocessing_config)

        if not config.disable_ocr:
            assert config.ocr_pretrained_processor_path is not None
            self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
                config.ocr_pretrained_processor_path)

        if not config.disable_crop_embeddings:
            assert config.crop_embedding_image_preprocessing_config is not None
            self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
                config.crop_embedding_image_preprocessing_config)

    def preprocess_inputs_for_detection(
        self,
        images: List[NDArray[np.uint8]],
        annotations: Optional[List[Dict[str, Any]]] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Preprocesses images into the format required by the object-detection module.

        Normalizes, resizes and pads the images. If annotations are given, they are
        converted to COCO format and the bbox coordinates are rescaled to match the
        resized images.

        Args:
            images: List of images as numpy arrays (HWC format)
            annotations: Optional ground-truth annotations in the format:
                [{"image_id": int, "bboxes_as_x1y1x2y2": List, "labels": List}]

        Returns:
            Dictionary with the keys:
            - "pixel_values": torch.Tensor with the preprocessed images
            - "pixel_mask": torch.Tensor with the padding mask
            - "labels": List[Dict] with the processed annotations (if given)
| """ |
| images_list: List[NDArray[np.uint8]] = list(images) |
| assert isinstance(images_list[0], np.ndarray) |
| |
| coco_annotations: Optional[List[Dict[str, Any]] |
| ] = self._convert_annotations_to_coco_format(annotations) |
| |
| inputs: Dict[str, torch.Tensor] = self.detection_image_preprocessor( |
| images_list, annotations=coco_annotations, return_tensors="pt") |
| return inputs |
|
|

    def preprocess_inputs_for_ocr(self, images: List[NDArray[np.uint8]]) -> torch.Tensor:
        """
        Preprocesses images into the format required by the OCR module.

        Normalizes and resizes text-region images for the TrOCR model.

        Args:
            images: List of images as numpy arrays (text crops)

        Returns:
            Tensor with the preprocessed images [batch, channels, height, width]
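
        Example (sketch; ``text_crop`` is a hypothetical numpy array of a text region):
            pixel_values = processor.preprocess_inputs_for_ocr([text_crop])
            # pixel_values: [1, 3, H, W], sized per the TrOCR processor config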
| """ |
| images_list: List[NDArray[np.uint8]] = list(images) |
| assert isinstance(images_list[0], np.ndarray) |
| return self.ocr_preprocessor(images_list, return_tensors="pt").pixel_values |
|
|
| def preprocess_inputs_for_crop_embeddings(self, images: List[NDArray[np.uint8]]) -> torch.Tensor: |
| """ |
| Preprocessuje wycięte fragmenty obrazów dla modułu embeddingów. |
| |
| Wykonuje normalizację i resize crops dla modelu ViT-MAE. |
| |
| Args: |
| images: Lista wyciętych fragmentów obrazów jako numpy arrays |
| |
| Returns: |
| Tensor z preprocessowanymi crops [batch, channels, height, width] |
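
        Example (sketch; ``char_crop`` is a hypothetical character crop, e.g. from ``crop_image``):
            pixel_values = processor.preprocess_inputs_for_crop_embeddings([char_crop])
            # pixel_values: [1, 3, H, W], sized per the ViT processor config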
| """ |
| images_list: List[NDArray[np.uint8]] = list(images) |
| assert isinstance(images_list[0], np.ndarray) |
| return self.crop_embedding_image_preprocessor(images_list, return_tensors="pt").pixel_values |
|
|
| def postprocess_ocr_tokens( |
| self, |
| generated_ids: torch.Tensor, |
| skip_special_tokens: bool = True |
| ) -> List[str]: |
| """ |
| Dekoduje tokeny wygenerowane przez model OCR na tekst. |
| |
| Args: |
| generated_ids: Tensor z ID tokenów wygenerowanych przez decoder OCR |
| skip_special_tokens: Czy pomijać specjalne tokeny (PAD, BOS, EOS) w wyniku |
| |
| Returns: |
| Lista stringów z rozpoznanym tekstem |
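
        Example (sketch; ``ocr_model`` is a hypothetical generative TrOCR-style model):
            pixel_values = processor.preprocess_inputs_for_ocr([text_crop])
            generated_ids = ocr_model.generate(pixel_values)
            texts = processor.postprocess_ocr_tokens(generated_ids)  # e.g. ["Hello!"]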
| """ |
| return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens) |
|
|
| def crop_image( |
| self, |
| image: NDArray[np.uint8], |
| bboxes: List[List[float]] |
| ) -> List[NDArray[np.uint8]]: |
| """ |
| Wycina fragmenty obrazu zgodnie z podanymi bounding boxami. |
| |
| Metoda automatycznie naprawia nieprawidłowe bounding boxy: |
| - Ogranicza współrzędne do granic obrazu |
| - Zapewnia minimalny rozmiar 10x10 pikseli |
| - Zamienia współrzędne jeśli są w nieprawidłowej kolejności |
| |
| Args: |
| image: Obraz źródłowy jako numpy array (format HWC) |
| bboxes: Lista bounding boxów w formacie [x1, y1, x2, y2] |
| |
| Returns: |
| Lista wyciętych fragmentów obrazu (każdy jako numpy array) |
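
        Example (sketch; the second box is deliberately degenerate and gets
        repaired to the minimum 10x10 size):
            page = np.zeros((100, 100, 3), dtype=np.uint8)
            crops = processor.crop_image(page, [[10, 10, 40, 30], [50, 50, 52, 52]])
            # crops[0].shape == (20, 30, 3); crops[1].shape == (10, 10, 3)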
| """ |
| crops_for_image: List[NDArray[np.uint8]] = [] |
| for bbox in bboxes: |
| x1: float |
| y1: float |
| x2: float |
| y2: float |
| x1, y1, x2, y2 = bbox |
|
|
| |
| |
| x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) |
| |
| x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2) |
| |
| x1, y1 = max(0, x1), max(0, y1) |
| x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1) |
| |
| x2, y2 = max(0, x2), max(0, y2) |
| x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2) |
|
|
| |
| if x2 - x1 < 10: |
| if image.shape[1] - x1 > 10: |
| x2 = x1 + 10 |
| else: |
| x1 = x2 - 10 |
|
|
| |
| if y2 - y1 < 10: |
| if image.shape[0] - y1 > 10: |
| y2 = y1 + 10 |
| else: |
| y1 = y2 - 10 |
|
|
| |
| crop: NDArray[np.uint8] = image[y1:y2, x1:x2] |
| crops_for_image.append(crop) |
| return crops_for_image |

    def _get_indices_of_characters_to_keep(
        self,
        batch_scores: torch.Tensor,
        batch_labels: torch.Tensor,
        batch_bboxes: torch.Tensor,
        character_detection_threshold: float
    ) -> List[torch.Tensor]:
        """
        Filters character detections by a probability threshold.

        Keeps only detections with label 0 (character) and a score above the threshold.

        Args:
            batch_scores: Tensor with probability scores [batch, num_queries]
            batch_labels: Tensor with class labels [batch, num_queries]
            batch_bboxes: Tensor with bounding boxes [batch, num_queries, 4]
            character_detection_threshold: Minimum score for a detection to be kept (0-1)

        Returns:
            List of tensors with the indices of characters to keep, one per image
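
        Example (sketch with a single image and threshold 0.3; only query 0
        passes both the label and the score test):
            scores = torch.tensor([[0.9, 0.2, 0.8]])
            labels = torch.tensor([[0, 0, 1]])
            bboxes = torch.zeros(1, 3, 4)
            kept = processor._get_indices_of_characters_to_keep(scores, labels, bboxes, 0.3)
            # kept == [tensor([0])]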
| """ |
| indices_of_characters_to_keep: List[torch.Tensor] = [] |
| for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes): |
| |
| indices: torch.Tensor = torch.where((labels == 0) & ( |
| scores > character_detection_threshold))[0] |
| indices_of_characters_to_keep.append(indices) |
| return indices_of_characters_to_keep |

    def _get_indices_of_panels_to_keep(
        self,
        batch_scores: torch.Tensor,
        batch_labels: torch.Tensor,
        batch_bboxes: torch.Tensor,
        panel_detection_threshold: float
    ) -> List[List[int]]:
        """
        Filters panel detections with non-maximum suppression (NMS).

        Keeps only panels with label 2 and a score above the threshold. Additionally
        applies NMS to remove overlapping panels: a new panel is rejected if more
        than 50% of its area is covered by already-accepted panels.

        Args:
            batch_scores: Tensor with scores [batch, num_queries]
            batch_labels: Tensor with labels [batch, num_queries]
            batch_bboxes: Tensor with bboxes [batch, num_queries, 4]
            panel_detection_threshold: Minimum score for a panel to be kept

        Returns:
            List of lists with the indices of panels to keep (after NMS), one per image
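
        Example (the coverage criterion worked through, not runnable as-is):
        given an accepted panel (0, 0, 100, 100) and a candidate (50, 0, 150, 100),
        the intersection area is 50*100 = 5000 and the candidate area is
        100*100 = 10000, so coverage is 0.5, which is not > 0.5, and the
        candidate is kept.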
| """ |
| indices_of_panels_to_keep: List[List[int]] = [] |
| for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes): |
| |
| indices: torch.Tensor = torch.where(labels == 2)[0] |
| bboxes = bboxes[indices] |
| scores = scores[indices] |
| labels = labels[indices] |
| if len(indices) == 0: |
| indices_of_panels_to_keep.append([]) |
| continue |
|
|
| |
| scores, labels, indices, bboxes = zip( |
| *sorted(zip(scores, labels, indices, bboxes), reverse=True)) |
|
|
| panels_to_keep: List[Tuple[torch.Tensor, |
| torch.Tensor, torch.Tensor, torch.Tensor]] = [] |
| |
| union_of_panels_so_far: Polygon = box(0, 0, 0, 0) |
|
|
| for ps, pb, pl, pi in zip(scores, bboxes, labels, indices): |
| |
| panel_polygon: Polygon = box(pb[0], pb[1], pb[2], pb[3]) |
|
|
| |
| if ps < panel_detection_threshold: |
| continue |
|
|
| |
| if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5: |
| continue |
|
|
| |
| panels_to_keep.append((ps, pl, pb, pi)) |
| |
| union_of_panels_so_far = union_of_panels_so_far.union( |
| panel_polygon) |
|
|
| |
| indices_of_panels_to_keep.append( |
| [p[3].item() for p in panels_to_keep]) |
| return indices_of_panels_to_keep |

    def _get_indices_of_texts_to_keep(
        self,
        batch_scores: torch.Tensor,
        batch_labels: torch.Tensor,
        batch_bboxes: torch.Tensor,
        text_detection_threshold: float
    ) -> List[List[int]]:
        """
        Filters text detections with non-maximum suppression (NMS).

        Keeps only texts with label 1 and a score above the threshold. Applies NMS
        to remove duplicates: a new text box is rejected if its IoU with an
        already-accepted text box exceeds 0.5.

        Args:
            batch_scores: Tensor with scores [batch, num_queries]
            batch_labels: Tensor with labels [batch, num_queries]
            batch_bboxes: Tensor with bboxes [batch, num_queries, 4]
            text_detection_threshold: Minimum score for a text box to be kept

        Returns:
            List of lists with the indices of texts to keep (after NMS), one per image
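
        Example (the IoU criterion worked through, not runnable as-is): for an
        accepted box (0, 0, 100, 100) and a candidate (10, 10, 110, 110), the
        intersection is 90*90 = 8100 and the union is 2*10000 - 8100 = 11900,
        so IoU is about 0.68 > 0.5 and the candidate is rejected as a duplicate.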
| """ |
| indices_of_texts_to_keep: List[List[int]] = [] |
| for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes): |
| |
| indices: torch.Tensor = torch.where((labels == 1) & ( |
| scores > text_detection_threshold))[0] |
| bboxes = bboxes[indices] |
| scores = scores[indices] |
| labels = labels[indices] |
| if len(indices) == 0: |
| indices_of_texts_to_keep.append([]) |
| continue |
|
|
| |
| scores, labels, indices, bboxes = zip( |
| *sorted(zip(scores, labels, indices, bboxes), reverse=True)) |
|
|
| texts_to_keep: List[Tuple[torch.Tensor, |
| torch.Tensor, torch.Tensor, torch.Tensor]] = [] |
| |
| texts_to_keep_as_shapely_objects: List[Polygon] = [] |
|
|
| for ts, tb, tl, ti in zip(scores, bboxes, labels, indices): |
| |
| text_polygon: Polygon = box(tb[0], tb[1], tb[2], tb[3]) |
| should_append: bool = True |
|
|
| |
| for t in texts_to_keep_as_shapely_objects: |
| |
| if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5: |
| should_append = False |
| break |
|
|
| if should_append: |
| texts_to_keep.append((ts, tl, tb, ti)) |
| texts_to_keep_as_shapely_objects.append(text_polygon) |
|
|
| |
| indices_of_texts_to_keep.append( |
| [t[3].item() for t in texts_to_keep]) |
| return indices_of_texts_to_keep |

    def _get_indices_of_tails_to_keep(
        self,
        batch_scores: torch.Tensor,
        batch_labels: torch.Tensor,
        batch_bboxes: torch.Tensor,
        text_detection_threshold: float
    ) -> List[List[int]]:
        """
        Filters speech-bubble tail detections with non-maximum suppression (NMS).

        Keeps only tails with label 3 and a score above the threshold. Applies NMS
        to remove duplicates: a new tail is rejected if its IoU with an
        already-accepted tail exceeds 0.5.

        Args:
            batch_scores: Tensor with scores [batch, num_queries]
            batch_labels: Tensor with labels [batch, num_queries]
            batch_bboxes: Tensor with bboxes [batch, num_queries, 4]
            text_detection_threshold: Minimum score for a tail to be kept

        Returns:
            List of lists with the indices of tails to keep (after NMS), one per image
        """
        indices_of_tails_to_keep: List[List[int]] = []
        for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
            # Label 3 = tail; keep only confident detections.
            indices: torch.Tensor = torch.where(
                (labels == 3) & (scores > text_detection_threshold))[0]
            bboxes = bboxes[indices]
            scores = scores[indices]
            labels = labels[indices]
            if len(indices) == 0:
                indices_of_tails_to_keep.append([])
                continue

            # Process detections in order of decreasing score.
            scores, labels, indices, bboxes = zip(
                *sorted(zip(scores, labels, indices, bboxes), reverse=True))

            tails_to_keep: List[Tuple[torch.Tensor,
                                      torch.Tensor, torch.Tensor, torch.Tensor]] = []
            tails_to_keep_as_shapely_objects: List[Polygon] = []

            for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
                tail_polygon: Polygon = box(tb[0], tb[1], tb[2], tb[3])
                should_append: bool = True

                # Reject the tail if it duplicates an already-accepted one.
                for t in tails_to_keep_as_shapely_objects:
                    if t.intersection(tail_polygon).area / t.union(tail_polygon).area > 0.5:
                        should_append = False
                        break

                if should_append:
                    tails_to_keep.append((ts, tl, tb, ti))
                    tails_to_keep_as_shapely_objects.append(tail_polygon)

            indices_of_tails_to_keep.append(
                [t[3].item() for t in tails_to_keep])
        return indices_of_tails_to_keep

    def _convert_annotations_to_coco_format(
        self,
        annotations: Optional[List[Dict[str, Any]]]
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Converts annotations from x1y1x2y2 format to COCO format (xywh).

        COCO stores a bbox as [x, y, width, height] instead of [x1, y1, x2, y2].
        The area of each bbox is computed as well.

        Args:
            annotations: List of annotations in the format:
                [{"image_id": int, "bboxes_as_x1y1x2y2": List, "labels": List}]
                or None

        Returns:
            List of annotations in COCO format, or None if the input was None
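
        Example (sketch, assuming x1y1x2y2_to_xywh maps [x1, y1, x2, y2] to
        [x1, y1, width, height]): [{"image_id": 0, "bboxes_as_x1y1x2y2":
        [[10, 20, 30, 60]], "labels": [2]}] becomes [{"image_id": 0,
        "annotations": [{"bbox": [10, 20, 20, 40], "category_id": 2,
        "area": 800}]}], since width = 30 - 10 = 20 and height = 60 - 20 = 40.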
| """ |
| if annotations is None: |
| return None |
| |
| self._verify_annotations_are_in_correct_format(annotations) |
|
|
| coco_annotations: List[Dict[str, Any]] = [] |
| for annotation in annotations: |
| coco_annotation: Dict[str, Any] = { |
| "image_id": annotation["image_id"], |
| "annotations": [], |
| } |
| |
| for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]): |
| coco_annotation["annotations"].append({ |
| |
| "bbox": x1y1x2y2_to_xywh(bbox), |
| "category_id": label, |
| |
| "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), |
| }) |
| coco_annotations.append(coco_annotation) |
| return coco_annotations |

    def _verify_annotations_are_in_correct_format(self, annotations: Optional[List[Dict[str, Any]]]) -> None:
        """
        Verifies that the annotations are in the expected format.

        Checks that the annotations are:
        - A list/tuple of dictionaries
        - Each dictionary contains the keys "image_id", "bboxes_as_x1y1x2y2" and "labels"
        - Labels: 0 = character, 1 = text, 2 = panel, 3 = tail

        Args:
            annotations: Annotations to verify, or None

        Raises:
            ValueError: If the annotation format is invalid
        """
        error_msg: str = """
        Annotations must be in the following format:
        [
            {
                "image_id": 0,
                "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
                "labels": [0, 1, 2],
            },
            ...
        ]
        Labels: 0 for characters, 1 for text, 2 for panels, 3 for tails.
        """
        if annotations is None:
            return

        if not isinstance(annotations, (list, tuple)):
            raise ValueError(
                f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
            )

        if len(annotations) == 0:
            return

        if not isinstance(annotations[0], dict):
            raise ValueError(
                f"{error_msg} Expected a List[Dict], found {type(annotations[0])}."
            )

        if "image_id" not in annotations[0]:
            raise ValueError(
                f"{error_msg} Dict must contain 'image_id'."
            )
        if "bboxes_as_x1y1x2y2" not in annotations[0]:
            raise ValueError(
                f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
            )
        if "labels" not in annotations[0]:
            raise ValueError(
                f"{error_msg} Dict must contain 'labels'."
            )