mciancone committed on
Commit
bd27421
·
verified ·
1 Parent(s): ead54bf

Upload model artifacts and classifier scripts

Browse files
.gitattributes CHANGED
@@ -1,35 +1,6 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.bin filter=lfs diff=lfs merge=lfs -text
2
+ *.onnx filter=lfs diff=lfs merge=lfs -text
3
+ *.xml filter=lfs diff=lfs merge=lfs -text
4
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
5
+ *.zip filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ pipeline_tag: image-classification
4
+ tags:
5
+ - image-classification
6
+ - multi-label-classification
7
+ - onnx
8
+ - openvino
9
+ - pdf
10
+ - document-understanding
11
+ - rag
12
+ datasets:
13
+ - Wikit/PdfVisClassif
14
+ ---
15
+
16
+ # PDF Page Classifier
17
+
18
+ Multi-label classifier for PDF page images. Determines whether a PDF page
19
+ requires image embedding (vs. text-only) in RAG pipelines.
20
+
21
+ Backbone: EfficientNet-Lite0. Exported to ONNX and OpenVINO INT8 via
22
+ Quantization-Aware Training (QAT). **No PyTorch required at inference time.**
23
+
24
+ ## Classes
25
+
26
+
27
+
28
+ Pages matching any of the following classes should trigger image embedding:
29
+
30
+ - `Visual Essential`
31
+ - `Complex Table`
32
+
33
+ Default threshold: `0.5`
34
+
35
+ ## Usage
36
+
37
+ ### With [chunknorris](https://github.com/wikit-ai/chunknorris) (recommended)
38
+
39
+ ```bash
40
+ pip install "chunknorris[ml-onnx]" # ONNX backend
41
+ pip install "chunknorris[ml-openvino]" # OpenVINO INT8, fastest on CPU
42
+ ```
43
+
44
+ ```python
45
+ from chunknorris.ml import load_classifier
46
+
47
+ clf = load_classifier("Wikit/pdf-pages-classifier") # auto-selects best available backend
48
+ result = clf.predict("page.png")
49
+ # {"needs_image_embedding": True, "predicted_classes": [...], "probabilities": {...}}
50
+ ```
51
+
52
+ ### Standalone (no chunknorris)
53
+
54
+ ```bash
55
+ git clone https://huggingface.co/Wikit/pdf-pages-classifier
56
+ cd pdf-pages-classifier
57
+ pip install onnxruntime Pillow numpy # or: openvino Pillow numpy
58
+ ```
59
+
60
+ ```python
61
+ from classifiers import load_classifier
62
+
63
+ clf = load_classifier(".") # auto-selects available backend
64
+ result = clf.predict("page.png")
65
+ ```
66
+
67
+ ## Files
68
+
69
+ | File | Format | Notes |
70
+ |------|--------|-------|
71
+ | `model.onnx` | ONNX FP32 | Cross-platform CPU/GPU inference |
72
+ | `openvino_model.xml/.bin` | OpenVINO INT8 | Fastest CPU inference (QAT) |
73
+ | `pytorch_model.bin` | PyTorch | Raw checkpoint; requires `torch` + `timm` |
74
+ | `config.json` | JSON | Preprocessing config and class names |
75
+ | `classifiers/` | Python | Standalone inference scripts (no chunknorris needed) |
76
+
77
+ ## Dataset
78
+
79
+ Trained on [Wikit/PdfVisClassif](https://huggingface.co/datasets/Wikit/PdfVisClassif).
classifiers/__init__.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PDF page classifier — public factory with HuggingFace auto-download.
2
+
3
+ Standalone usage (files downloaded from HF repo):
4
+ from classifiers import load_classifier
5
+ clf = load_classifier(".") # local directory with model files
6
+ result = clf.predict("page.png")
7
+
8
+ HuggingFace usage:
9
+ from classifiers import load_classifier
10
+ clf = load_classifier("Wikit/pdf-pages-classifier")
11
+ result = clf.predict("page.png")
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import os
17
+ from pathlib import Path
18
+ from typing import Any
19
+
20
+
21
+ # INT8 preferred over FP32 for both backends — matches classifier lookup order
22
+ _HF_ONNX_INT8_FILES = ["model_int8.onnx", "config.json"]
23
+ _HF_ONNX_FP32_FILES = ["model.onnx", "config.json"]
24
+
25
+ _HF_OV_INT8_FILES = ["openvino_model_int8.xml", "openvino_model_int8.bin", "config.json"]
26
+ _HF_OV_FP32_FILES = ["openvino_model.xml", "openvino_model.bin", "config.json"]
27
+
28
+
29
def _is_hf_repo_id(path: str) -> bool:
    """Heuristically decide whether *path* names a HuggingFace repo.

    A repo ID has the shape ``owner/repo``: exactly one forward slash, both
    halves non-blank, no leading ``.``/``/``/``~``, and nothing by that name
    on the local filesystem.
    """
    if os.path.exists(path):
        # An existing file/directory always wins over the repo-ID heuristic.
        return False
    # Normalize Windows separators so the single-slash test is portable.
    candidate = path.replace("\\", "/")
    if candidate.startswith((".", "/", "~")):
        return False
    owner_repo = candidate.split("/")
    if len(owner_repo) != 2:
        return False
    return all(segment.strip() for segment in owner_repo)
39
+
40
+
41
def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path:
    """Fetch *filenames* from *repo_id* and return the local snapshot directory.

    Args:
        repo_id: HuggingFace repo ID (``owner/repo``).
        filenames: Files to download from the repo.
        cache_dir: Optional custom HuggingFace cache directory.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to load from a HuggingFace repo.\n"
            "Install with: pip install huggingface-hub"
        ) from e

    local_path: Path | None = None
    for name in filenames:
        local_path = Path(
            hf_hub_download(repo_id=repo_id, filename=name, cache_dir=cache_dir)
        )

    # Every file lands in the same snapshot dir, so any file's parent is it.
    assert local_path is not None
    return local_path.parent
57
+
58
+
59
def _download_with_int8_fallback(
    repo_id: str,
    int8_files: list[str],
    fp32_files: list[str],
    cache_dir: str | None,
) -> Path:
    """Download files from HF, preferring INT8 over FP32 when available.

    Args:
        repo_id: HuggingFace repo ID (``owner/repo``).
        int8_files: Filenames of the preferred INT8 artifacts.
        fp32_files: Filenames of the FP32 artifacts used as fallback.
        cache_dir: Optional custom HuggingFace cache directory.

    Returns:
        Local snapshot directory containing the downloaded files.

    Raises:
        ImportError: If ``huggingface_hub`` is not installed.
    """
    try:
        # EntryNotFoundError is documented under huggingface_hub.utils; newer
        # releases also re-export it at the package top level. Try both so the
        # INT8 -> FP32 fallback works across huggingface_hub versions.
        try:
            from huggingface_hub.utils import EntryNotFoundError
        except ImportError:
            from huggingface_hub import EntryNotFoundError
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to load from a HuggingFace repo.\n"
            "Install with: pip install huggingface-hub"
        ) from e

    try:
        return _download_from_hf(repo_id, int8_files, cache_dir)
    except EntryNotFoundError:
        # INT8 artifacts absent from the repo: fall back to the FP32 export.
        return _download_from_hf(repo_id, fp32_files, cache_dir)
78
+
79
+
80
def load_classifier(
    repo_or_dir: str = "Wikit/pdf-pages-classifier",
    backend: str = "auto",
    device: str = "CPU",
    cache_dir: str | None = None,
) -> Any:
    """Load a PDF page classifier with automatic backend selection.

    Args:
        repo_or_dir: HuggingFace repo ID (e.g. ``"Wikit/pdf-pages-classifier"``)
            or local directory containing ``config.json`` and model files.
        backend: ``"auto"`` tries OpenVINO first, falls back to ONNX.
            Pass ``"openvino"`` or ``"onnx"`` to force a specific backend.
        device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
            Ignored for ONNX.
        cache_dir: Custom cache directory for HuggingFace downloads.

    Returns:
        A classifier instance exposing ``predict(images)``.

    Raises:
        ValueError: If ``backend`` is not one of the supported names.

    Example::

        clf = load_classifier("Wikit/pdf-pages-classifier")
        result = clf.predict("page.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """
    if backend not in ("auto", "onnx", "openvino"):
        raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.")

    from_hub = _is_hf_repo_id(repo_or_dir)

    if backend in ("auto", "openvino"):
        try:
            return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=from_hub)
        except (ImportError, FileNotFoundError):
            # In "auto" mode a missing OpenVINO runtime or missing IR files is
            # not fatal; an explicit "openvino" request surfaces the error.
            if backend == "openvino":
                raise

    return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=from_hub)
119
+
120
+
121
def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Resolve model files (downloading from HF when needed) and build the ONNX classifier."""
    try:
        from .classifier_onnx import PDFPageClassifierONNX
    except ImportError:
        # Standalone usage: the scripts live next to each other, not in a package.
        from classifier_onnx import PDFPageClassifierONNX  # type: ignore[no-redef]

    if is_hf:
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir
        )
    else:
        model_dir = Path(repo_or_dir)

    return PDFPageClassifierONNX.from_pretrained(str(model_dir))
132
+
133
+
134
def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Resolve model files (downloading from HF when needed) and build the OpenVINO classifier."""
    try:
        from .classifier_ov import PDFPageClassifierOV
    except ImportError:
        # Standalone usage: the scripts live next to each other, not in a package.
        from classifier_ov import PDFPageClassifierOV  # type: ignore[no-redef]

    if is_hf:
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir
        )
    else:
        model_dir = Path(repo_or_dir)

    return PDFPageClassifierOV.from_pretrained(str(model_dir), device=device)
classifiers/base_classifier.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod, ABC
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from typing import Any, Union
4
+
5
+ from PIL import Image
6
+ import numpy as np
7
+ import numpy.typing as npt
8
+
9
+
10
class _BasePDFPageClassifier(ABC):
    """Backend-agnostic core of the PDF page classifier.

    Owns preprocessing (center-crop, resize, normalization), result
    formatting, and the batched ``predict`` loop. Concrete subclasses only
    implement ``_run_batch``, which performs backend-specific inference on a
    (N, C, H, W) float32 numpy array.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        """Read preprocessing and labelling options from a deployment config.

        Args:
            config: Deployment config dict. Required keys: ``image_size``,
                ``mean``, ``std``, ``class_names``. Optional keys:
                ``center_crop_shortest``, ``whiteout_header``,
                ``whiteout_fraction``, ``threshold``,
                ``image_required_classes``.
        """
        self._image_size: int = config["image_size"]
        self._mean = np.array(config["mean"], dtype=np.float32)
        self._std = np.array(config["std"], dtype=np.float32)
        self._center_crop: bool = config.get("center_crop_shortest", True)
        self._whiteout: bool = config.get("whiteout_header", False)
        # Number of top pixel rows (after resize) blanked when whiteout is on.
        self._whiteout_cutoff: int = int(
            self._image_size * config.get("whiteout_fraction", 0.15)
        )
        self._class_names: list[str] = config["class_names"]
        self._threshold: float = float(config.get("threshold", 0.5))
        self._image_required_classes: set[str] = set(
            config.get("image_required_classes", [])
        )

    @abstractmethod
    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run inference on a (N, C, H, W) float32 batch.

        Returns:
            (N, num_classes) float32 array of probabilities.
        """

    @staticmethod
    def _load_image(item: Any) -> "Image.Image":
        """Open *item* (file path string or PIL image) as an RGB PIL image.

        Args:
            item: File path string or PIL image (any mode).

        Returns:
            RGB PIL image.

        Raises:
            TypeError: If ``item`` is neither a str nor a PIL.Image.
        """
        if isinstance(item, str):
            return Image.open(item).convert("RGB")
        if isinstance(item, Image.Image):
            return item.convert("RGB")
        raise TypeError(f"Expected str or PIL.Image, got {type(item).__name__}")

    def _pil_to_array(self, img: "Image.Image") -> "npt.NDArray[np.float32]":
        """Spatially transform *img* into a (H, W, C) float32 array in [0, 1].

        Normalization and the NCHW transpose are deferred to
        ``_normalize_batch`` so they run once, vectorised over the batch.

        Steps:
            1. Center-crop to square (shortest side), if enabled.
            2. Bicubic resize to (image_size, image_size).
            3. Scale pixel values to [0, 1].
            4. White out top header rows, if enabled.

        Args:
            img: RGB PIL image.

        Returns:
            Float32 array of shape (image_size, image_size, 3).
        """
        if self._center_crop:
            width, height = img.size
            side = min(width, height)
            left = (width - side) // 2
            top = (height - side) // 2
            img = img.crop((left, top, left + side, top + side))

        img = img.resize((self._image_size, self._image_size), Image.Resampling.BICUBIC)
        pixels = np.asarray(img, dtype=np.float32) * (1.0 / 255.0)  # (H, W, C)

        if self._whiteout:
            pixels[: self._whiteout_cutoff] = 1.0

        return pixels

    def _normalize_batch(
        self, arrays: list["npt.NDArray[np.float32]"]
    ) -> "npt.NDArray[np.float32]":
        """Stack (H, W, C) arrays, apply ImageNet normalization, transpose to NCHW.

        Args:
            arrays: List of float32 arrays, each of shape (H, W, C) in [0, 1].

        Returns:
            Float32 array of shape (N, C, H, W), normalized with the
            configured mean/std.
        """
        stacked = np.stack(arrays, axis=0)            # (N, H, W, C)
        stacked = (stacked - self._mean) / self._std  # broadcast over (H, W, C)
        return stacked.transpose(0, 3, 1, 2)          # (N, C, H, W)

    def _format(
        self,
        probabilities: "npt.NDArray[np.float32]",
        threshold: float,
    ) -> dict[str, Any]:
        """Turn one row of per-class probabilities into a result dict.

        Args:
            probabilities: 1-D float32 array of per-class probabilities.
            threshold: Probability cutoff for a positive prediction.

        Returns:
            Dict with keys ``needs_image_embedding``, ``predicted_classes``,
            and ``probabilities``.
        """
        # Class order follows config["class_names"]; dict preserves it.
        per_class = {
            name: float(prob)
            for name, prob in zip(self._class_names, probabilities)
        }
        positives = [name for name, prob in per_class.items() if prob >= threshold]
        return {
            "needs_image_embedding": bool(
                self._image_required_classes.intersection(positives)
            ),
            "predicted_classes": positives,
            "probabilities": per_class,
        }

    def predict(
        self,
        images: Union[str, "Image.Image", list[Any]],
        threshold: float | None = None,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> Union[dict[str, Any], list[dict[str, Any]]]:
        """Classify one or more PDF page images.

        Args:
            images: A single image (file path string or PIL.Image) or a list
                of images.
            threshold: Override the default probability threshold from config.
                Local to this call; the instance is not mutated.
            batch_size: Number of images to process per inference call.
            num_workers: Threads used for parallel loading/preprocessing.
                Set to 1 to disable threading.

        Returns:
            A single result dict when ``images`` is not a list, or a list of
            result dicts otherwise. Each dict contains:
            - ``needs_image_embedding`` (bool)
            - ``predicted_classes`` (list[str])
            - ``probabilities`` (dict[str, float])
        """
        cutoff = self._threshold if threshold is None else threshold

        single = not isinstance(images, list)
        items: list[Any] = [images] if single else images

        results: list[dict[str, Any]] = []

        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            for start in range(0, len(items), batch_size):
                chunk = items[start : start + batch_size]

                # File I/O + RGB conversion in parallel.
                pil_images: list[Image.Image] = list(
                    pool.map(self._load_image, chunk)
                )
                # PIL transforms (crop + bicubic resize) in parallel.
                arrays: list[npt.NDArray[np.float32]] = list(
                    pool.map(self._pil_to_array, pil_images)
                )

                # Vectorised normalization + transpose, then inference.
                probs_batch: npt.NDArray[np.float32] = self._run_batch(
                    self._normalize_batch(arrays)
                )
                results.extend(self._format(row, cutoff) for row in probs_batch)

        return results[0] if single else results

    def __call__(
        self,
        images: Union[str, "Image.Image", list[Any]],
        threshold: float | None = None,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> Union[dict[str, Any], list[dict[str, Any]]]:
        """Delegate to predict(). See predict() for full documentation."""
        return self.predict(
            images, threshold=threshold, batch_size=batch_size, num_workers=num_workers
        )
classifiers/classifier_onnx.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PDF page classifier for production inference."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ try:
11
+ from .base_classifier import _BasePDFPageClassifier
12
+ except ImportError:
13
+ from base_classifier import _BasePDFPageClassifier # standalone / HF usage
14
+
15
+ try:
16
+ import onnxruntime as ort
17
+ except ImportError as _e:
18
+ raise ImportError(
19
+ "onnxruntime is required for inference.\n"
20
+ "Install with: pip install onnxruntime"
21
+ ) from _e
22
+
23
+
24
+
25
class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and exposes a simple ``predict``
    interface. All preprocessing (center-crop, resize, normalization) is
    performed in pure PIL + numpy, matching the pipeline used during training.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(self, model_path: str, config: dict[str, Any]) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment).
        """
        super().__init__(config)
        self._session = ort.InferenceSession(model_path)
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(cls, model_dir: str) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:
            - ``model.onnx`` (or ``model_int8.onnx``) — exported by
              save_for_deployment
            - ``config.json`` — written by save_for_deployment

        Args:
            model_dir: Path to the deployment directory.

        Returns:
            Initialised PDFPageClassifierONNX.

        Raises:
            FileNotFoundError: If ``config.json`` or an ONNX model file is
                missing from ``model_dir``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"

        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        # Prefer INT8 (QAT export) over FP32 when both are present
        candidates = ["model_int8.onnx", "model.onnx"]
        for candidate in candidates:
            if (path / candidate).exists():
                model_path = path / candidate
                break
        else:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        return cls(str(model_path), config)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Execute the ONNX session on one batch; first output holds probabilities."""
        return self._session.run(None, {self._input_name: batch_input})[0]
classifiers/classifier_ov.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenVINO-based PDF page classifier for production inference."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ try:
11
+ from openvino import Core
12
+ from openvino import CompiledModel
13
+ except ImportError as _e:
14
+ raise ImportError(
15
+ "openvino is required for OpenVINO inference.\n"
16
+ "Install with: pip install openvino"
17
+ ) from _e
18
+
19
+ try:
20
+ from .base_classifier import _BasePDFPageClassifier
21
+ except ImportError:
22
+ from base_classifier import _BasePDFPageClassifier # standalone / HF usage
23
+
24
+
25
class PDFPageClassifierOV(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed OpenVINO IR model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` (with ``export_openvino=True``) and
    exposes the same ``predict`` interface as the ONNX classifier.

    ``from_pretrained`` prefers INT8 IR files over FP32, and HF/Optimum-style
    filenames (``openvino_model*.xml``) over legacy local ones
    (``model_ov*.xml``).

    Example::

        clf = PDFPageClassifierOV.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model_path: str,
        config: dict[str, Any],
        device: str = "CPU",
    ) -> None:
        """Compile the IR model for *device* and set up I/O handles.

        Args:
            model_path: Path to the OpenVINO IR ``.xml`` file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment).
            device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
        """
        super().__init__(config)
        compiled: CompiledModel = Core().compile_model(model_path, device)
        self._session: CompiledModel = compiled
        self._input_name: str = compiled.input(0).get_any_name()
        self._output = compiled.output(0)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        device: str = "CPU",
    ) -> "PDFPageClassifierOV":
        """Load a classifier from a deployment directory.

        The directory must contain a ``config.json`` plus one of the IR
        variants listed below; INT8 is preferred over FP32 when present.

        Args:
            model_dir: Path to the deployment directory.
            device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).

        Returns:
            Initialised PDFPageClassifierOV.

        Raises:
            FileNotFoundError: If ``config.json`` or every IR candidate is
                missing from ``model_dir``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"

        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        # Search order: prefer INT8 over FP32, HF/Optimum names over legacy names
        candidates = [
            "openvino_model_int8.xml",  # HF-style INT8 (preferred)
            "openvino_model.xml",       # HF-style FP32
            "model_ov_int8.xml",        # legacy local INT8
            "model_ov.xml",             # legacy local FP32
        ]
        for candidate in candidates:
            if (path / candidate).exists():
                model_path = path / candidate
                break
        else:
            raise FileNotFoundError(
                f"No OpenVINO model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}. "
                "Export with save_for_deployment(..., export_openvino=True)."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        return cls(str(model_path), config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run the compiled model on one batch and pick the single output tensor."""
        return self._session({self._input_name: batch_input})[self._output]
classifiers/classifier_torch.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch-based PDF page classifier for native inference."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ try:
11
+ import torch
12
+ except ImportError as _e:
13
+ raise ImportError(
14
+ "torch is required for PyTorch inference.\n"
15
+ "Install with: pip install torch"
16
+ ) from _e
17
+
18
+ from classifiers.base_classifier import _BasePDFPageClassifier
19
+ from models import create_model
20
+
21
+
22
class PDFPageClassifierTorch(_BasePDFPageClassifier):
    """Classify PDF pages using a native PyTorch checkpoint.

    Mirrors the ``predict`` interface of the ONNX and OpenVINO classifiers;
    preprocessing lives in the shared base class, and this subclass only
    rebuilds the network and applies a sigmoid to its logits.

    Example::

        clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        config: dict[str, Any],
        device: "torch.device | str" = "cpu",
    ) -> None:
        """Wrap an already-loaded model for inference.

        Args:
            model: PyTorch model already loaded with weights.
            config: Flat config dict compatible with the base classifier schema.
            device: Torch device to run inference on (``"cpu"``, ``"cuda"``, etc.).
        """
        super().__init__(config)
        self._device = torch.device(device)
        self._model = model.to(self._device)
        self._model.eval()

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str,
        device: "torch.device | str" = "cpu",
        image_required_classes: list[str] | None = None,
        threshold: float = 0.5,
    ) -> "PDFPageClassifierTorch":
        """Build a classifier from a raw training checkpoint.

        The checkpoint must contain:
            - ``model_state_dict`` — model weights
            - ``config`` — training config with ``model`` and ``data`` keys
            - ``class_names`` — ordered list of class names

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).
            image_required_classes: Class names that trigger image embedding.
                Defaults to an empty list when not provided.
            threshold: Default prediction threshold (overridable per call).

        Returns:
            Initialised PDFPageClassifierTorch.
        """
        # weights_only=False: the checkpoint stores config dicts, not just tensors.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        train_cfg = ckpt["config"]
        labels: list[str] = ckpt["class_names"]
        data_cfg = train_cfg["data"]

        net = create_model(
            model_name=train_cfg["model"]["name"],
            num_classes=len(labels),
            pretrained=False,
            dropout=train_cfg["model"]["dropout"],
            use_spatial_pooling=train_cfg["model"].get("use_spatial_pooling", False),
        )
        net.load_state_dict(ckpt["model_state_dict"])

        # Flatten the nested training config into the base-class schema.
        flat_config: dict[str, Any] = {
            "image_size": data_cfg["image_size"],
            "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]),
            "std": data_cfg.get("std", [0.229, 0.224, 0.225]),
            "center_crop_shortest": data_cfg.get("center_crop_shortest", True),
            "whiteout_header": data_cfg.get("whiteout_header", False),
            "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15),
            "class_names": labels,
            "threshold": threshold,
            "image_required_classes": image_required_classes or [],
        }

        return cls(net, flat_config, device=device)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        device: "torch.device | str" = "cpu",
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a deployment directory.

        The directory must contain:
            - ``model.pt`` — PyTorch checkpoint written by save_for_deployment
            - ``config.json`` — deployment config written by save_for_deployment

        Args:
            model_dir: Path to the deployment directory.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is missing.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        weights_path = path / "model.pt"

        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        if not weights_path.exists():
            raise FileNotFoundError(f"model.pt not found in {model_dir}")

        with open(config_path, encoding="utf-8") as f:
            config: dict[str, Any] = json.load(f)

        ckpt = torch.load(str(weights_path), map_location="cpu", weights_only=False)

        net = create_model(
            model_name=config["model_name"],
            num_classes=len(config["class_names"]),
            pretrained=False,
            dropout=config.get("dropout", 0.2),
            use_spatial_pooling=config.get("use_spatial_pooling", False),
        )
        # The file may hold a full checkpoint dict or a bare state_dict.
        state_dict = ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        net.load_state_dict(state_dict)

        return cls(net, config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Forward pass without gradient tracking; sigmoid turns logits into probabilities."""
        tensor = torch.from_numpy(batch_input).to(self._device)  # type:ignore
        with torch.no_grad():
            logits = self._model(tensor)
        probabilities = torch.sigmoid(logits)
        return probabilities.cpu().numpy()  # type: ignore
classifiers/models.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model definitions for PDF page classification."""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import timm
6
+
7
+
8
class MultiLabelClassifier(nn.Module):
    """Multi-label image classifier with configurable backbone.

    Args:
        model_name: Name of the timm model to use as backbone
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability before final layer
        use_spatial_pooling: If True, use spatial max pooling (CAM-style) instead of global pooling
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.2,
        use_spatial_pooling: bool = False
    ):
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.use_spatial_pooling = use_spatial_pooling

        # Headless timm backbone. Keep the raw feature map when CAM-style
        # spatial pooling is requested; otherwise let timm average-pool.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # Remove classification head
            global_pool='' if use_spatial_pooling else 'avg',
        )

        # Probe the backbone once to discover the channel count of its
        # output ([B, C, H, W] spatial, [B, C] pooled — C is dim 1 either way).
        with torch.no_grad():
            probe = self.backbone(torch.randn(1, 3, 224, 224))
        self.feature_dim = probe.shape[1]
        if use_spatial_pooling:
            print(f"Spatial pooling enabled - feature map shape: {probe.shape}")

        # Classification head: 1x1 conv over the feature map in spatial
        # mode, plain dropout + linear otherwise.
        if use_spatial_pooling:
            self.classifier = nn.Sequential(
                nn.Dropout2d(p=dropout),  # Spatial dropout
                nn.Conv2d(self.feature_dim, num_classes, kernel_size=1),
            )
        else:
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(self.feature_dim, num_classes),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Logits of shape (batch_size, num_classes)
        """
        feats = self.backbone(x)
        if not self.use_spatial_pooling:
            # feats: [B, C] -> logits: [B, num_classes]
            return self.classifier(feats)
        # feats: [B, C, H, W]; per-class spatial logits reduced by a
        # global max over the spatial dims -> [B, num_classes].
        return torch.amax(self.classifier(feats), dim=(2, 3))

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without classification head.

        Useful for feature visualization or transfer learning.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Features of shape (batch_size, feature_dim) or (batch_size, feature_dim, H, W)
        """
        return self.backbone(x)

    def get_activation_maps(self, x: torch.Tensor) -> torch.Tensor:
        """Get spatial activation maps (only for spatial pooling mode).

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Activation maps of shape (batch_size, num_classes, H, W)

        Raises:
            ValueError: If spatial pooling is not enabled
        """
        if not self.use_spatial_pooling:
            raise ValueError("Activation maps only available with spatial pooling enabled")
        return self.classifier(self.backbone(x))
133
+
134
+
135
def create_model(
    model_name: str,
    num_classes: int,
    pretrained: bool = True,
    dropout: float = 0.2,
    use_spatial_pooling: bool = False
) -> MultiLabelClassifier:
    """Factory function to create a model.

    Args:
        model_name: Name of the model architecture. Example : mobilenetv3_small_100
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)

    Returns:
        Initialized model

    Raises:
        ValueError: If *model_name* is not a known timm architecture.
    """
    # Verify model exists in timm before constructing anything.
    available_models = timm.list_models(model_name)
    if not available_models:
        # Fix: the original message concatenated two f-strings without a
        # separating space ("...in timm.Available options:").
        raise ValueError(
            f"Model '{model_name}' not found in timm. "
            f"Available options: {timm.list_models()}"
        )

    model = MultiLabelClassifier(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        use_spatial_pooling=use_spatial_pooling
    )

    return model
171
+
172
+
173
+ def count_parameters(model: nn.Module) -> dict[str, int | float]:
174
+ """Count model parameters.
175
+
176
+ Args:
177
+ model: PyTorch model
178
+
179
+ Returns:
180
+ Dictionary with parameter counts
181
+ """
182
+ total_params = sum(p.numel() for p in model.parameters())
183
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
184
+
185
+ return {
186
+ 'total': total_params,
187
+ 'trainable': trainable_params,
188
+ 'non_trainable': total_params - trainable_params,
189
+ 'total_millions': total_params / 1e6,
190
+ 'trainable_millions': trainable_params / 1e6
191
+ }
192
+
193
+
194
def print_model_info(model: nn.Module, model_name: str = "Model"):
    """Print a formatted parameter summary for *model*.

    Args:
        model: PyTorch model
        model_name: Name to display
    """
    stats = count_parameters(model)
    rule = '=' * 60

    print(f"\n{rule}")
    print(f"{model_name} Information")
    print(f"{rule}")
    print(f"Total parameters: {stats['total']:,} ({stats['total_millions']:.2f}M)")
    print(f"Trainable parameters: {stats['trainable']:,} ({stats['trainable_millions']:.2f}M)")
    print(f"Non-trainable params: {stats['non_trainable']:,}")
    print(f"{rule}\n")
config.json ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": {
3
+ "name": "efficientnet_lite0",
4
+ "pretrained": true,
5
+ "dropout": 0,
6
+ "use_spatial_pooling": false
7
+ },
8
+ "classes": [
9
+ "Visual - Essential",
10
+ "Simple Table",
11
+ "Chart/Graph",
12
+ "Visual - Supportive",
13
+ "Annotated figure",
14
+ "No Specific Feature",
15
+ "Diagram/Flowchart",
16
+ "Visual - Decorative",
17
+ "Complex Table",
18
+ "Infographic",
19
+ "Form",
20
+ "Text to OCR"
21
+ ],
22
+ "class_mapping": {
23
+ "Form": null,
24
+ "No Specific Feature": null,
25
+ "Text to OCR": null,
26
+ "Visual - Decorative": null,
27
+ "Infographic": null,
28
+ "Chart/Graph": null,
29
+ "Annotated figure": null,
30
+ "Diagram/Flowchart": null
31
+ },
32
+ "image_required_classes": [
33
+ "Visual Essential",
34
+ "Complex Table"
35
+ ],
36
+ "data": {
37
+ "train_split": 0.8,
38
+ "val_split": 0.1,
39
+ "test_split": 0.1,
40
+ "image_size": 224,
41
+ "batch_size": 32,
42
+ "num_workers": 4,
43
+ "seed": 42
44
+ },
45
+ "augmentation": {
46
+ "center_crop_shortest": true,
47
+ "whiteout_header": false,
48
+ "whiteout_fraction": 0.15,
49
+ "train": {
50
+ "horizontal_flip": 0.5,
51
+ "rotation_degrees": 5,
52
+ "color_jitter": {
53
+ "brightness": 0.2,
54
+ "contrast": 0.2,
55
+ "saturation": 0.1,
56
+ "hue": 0.05
57
+ },
58
+ "random_erasing": 0.1
59
+ },
60
+ "val": {
61
+ "enabled": false
62
+ }
63
+ },
64
+ "training": {
65
+ "epochs": 40,
66
+ "learning_rate": 0.0001,
67
+ "weight_decay": 0.0001,
68
+ "optimizer": "adamw",
69
+ "scheduler": "cosine",
70
+ "warmup_epochs": 5,
71
+ "label_smoothing": 0.0,
72
+ "gradient_clip_norm": 1.0,
73
+ "pos_weight": [
74
+ 3.6715595722198486,
75
+ 6.668674468994141,
76
+ 2.3281044960021973,
77
+ 6.0722222328186035
78
+ ]
79
+ },
80
+ "monitoring": {
81
+ "metric": "val_f1",
82
+ "mode": "max"
83
+ },
84
+ "early_stopping": {
85
+ "enabled": true,
86
+ "patience": 20
87
+ },
88
+ "evaluation": {
89
+ "threshold": 0.5,
90
+ "save_confusion_matrix": true,
91
+ "save_per_class_metrics": true
92
+ },
93
+ "checkpointing": {
94
+ "save_best_only": true,
95
+ "save_last": true
96
+ },
97
+ "paths": {
98
+ "data_dir": "data",
99
+ "output_dir": "outputs",
100
+ "checkpoint_dir": "checkpoints",
101
+ "logs_dir": "logs"
102
+ },
103
+ "logging": {
104
+ "use_tensorboard": false,
105
+ "use_wandb": true,
106
+ "wandb_project": "pdf-page-classifier",
107
+ "log_interval": 10,
108
+ "wandb_run_name": "silver-line-69"
109
+ },
110
+ "qat": {
111
+ "enabled": true,
112
+ "epochs": 5,
113
+ "learning_rate": "1e-5",
114
+ "preset": "mixed",
115
+ "num_init_samples": 300
116
+ },
117
+ "onnx": {
118
+ "opset_version": 14,
119
+ "dynamic_axes": true,
120
+ "simplify": true,
121
+ "input_names": [
122
+ "image"
123
+ ],
124
+ "output_names": [
125
+ "probabilities"
126
+ ]
127
+ }
128
+ }