Upload model artifacts and classifier scripts
Browse files- .gitattributes +6 -35
- README.md +79 -0
- classifiers/__init__.py +144 -0
- classifiers/base_classifier.py +205 -0
- classifiers/classifier_onnx.py +90 -0
- classifiers/classifier_ov.py +114 -0
- classifiers/classifier_torch.py +159 -0
- classifiers/models.py +209 -0
- config.json +128 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,6 @@
|
|
| 1 |
-
*.
|
| 2 |
-
*.
|
| 3 |
-
*.
|
| 4 |
-
*.
|
| 5 |
-
*.
|
| 6 |
-
*.
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.xml filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
pipeline_tag: image-classification
|
| 4 |
+
tags:
|
| 5 |
+
- image-classification
|
| 6 |
+
- multi-label-classification
|
| 7 |
+
- onnx
|
| 8 |
+
- openvino
|
| 9 |
+
- pdf
|
| 10 |
+
- document-understanding
|
| 11 |
+
- rag
|
| 12 |
+
datasets:
|
| 13 |
+
- Wikit/PdfVisClassif
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# PDF Page Classifier
|
| 17 |
+
|
| 18 |
+
Multi-label classifier for PDF page images. Determines whether a PDF page
|
| 19 |
+
requires image embedding (vs. text-only) in RAG pipelines.
|
| 20 |
+
|
| 21 |
+
Backbone: EfficientNet-Lite0. Exported to ONNX and OpenVINO INT8 via
|
| 22 |
+
Quantization-Aware Training (QAT). **No PyTorch required at inference time.**
|
| 23 |
+
|
| 24 |
+
## Classes
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
Pages matching any of the following classes should trigger image embedding:
|
| 29 |
+
|
| 30 |
+
- `Visual Essential`
|
| 31 |
+
- `Complex Table`
|
| 32 |
+
|
| 33 |
+
Default threshold: `0.5`
|
| 34 |
+
|
| 35 |
+
## Usage
|
| 36 |
+
|
| 37 |
+
### With [chunknorris](https://github.com/wikit-ai/chunknorris) (recommended)
|
| 38 |
+
|
| 39 |
+
```bash
|
| 40 |
+
pip install "chunknorris[ml-onnx]" # ONNX backend
|
| 41 |
+
pip install "chunknorris[ml-openvino]" # OpenVINO INT8, fastest on CPU
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
from chunknorris.ml import load_classifier
|
| 46 |
+
|
| 47 |
+
clf = load_classifier("Wikit/pdf-pages-classifier") # auto-selects best available backend
|
| 48 |
+
result = clf.predict("page.png")
|
| 49 |
+
# {"needs_image_embedding": True, "predicted_classes": [...], "probabilities": {...}}
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Standalone (no chunknorris)
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
git clone https://huggingface.co/Wikit/pdf-pages-classifier
|
| 56 |
+
cd pdf-pages-classifier
|
| 57 |
+
pip install onnxruntime Pillow numpy # or: openvino Pillow numpy
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
```python
|
| 61 |
+
from classifiers import load_classifier
|
| 62 |
+
|
| 63 |
+
clf = load_classifier(".") # auto-selects available backend
|
| 64 |
+
result = clf.predict("page.png")
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Files
|
| 68 |
+
|
| 69 |
+
| File | Format | Notes |
|
| 70 |
+
|------|--------|-------|
|
| 71 |
+
| `model.onnx` | ONNX FP32 | Cross-platform CPU/GPU inference |
|
| 72 |
+
| `openvino_model.xml/.bin` | OpenVINO INT8 | Fastest CPU inference (QAT) |
|
| 73 |
+
| `pytorch_model.bin` | PyTorch | Raw checkpoint; requires `torch` + `timm` |
|
| 74 |
+
| `config.json` | JSON | Preprocessing config and class names |
|
| 75 |
+
| `classifiers/` | Python | Standalone inference scripts (no chunknorris needed) |
|
| 76 |
+
|
| 77 |
+
## Dataset
|
| 78 |
+
|
| 79 |
+
Trained on [Wikit/PdfVisClassif](https://huggingface.co/datasets/Wikit/PdfVisClassif).
|
classifiers/__init__.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PDF page classifier — public factory with HuggingFace auto-download.
|
| 2 |
+
|
| 3 |
+
Standalone usage (files downloaded from HF repo):
|
| 4 |
+
from classifiers import load_classifier
|
| 5 |
+
clf = load_classifier(".") # local directory with model files
|
| 6 |
+
result = clf.predict("page.png")
|
| 7 |
+
|
| 8 |
+
HuggingFace usage:
|
| 9 |
+
from classifiers import load_classifier
|
| 10 |
+
clf = load_classifier("Wikit/pdf-pages-classifier")
|
| 11 |
+
result = clf.predict("page.png")
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# INT8 preferred over FP32 for both backends — matches classifier lookup order
# (classifier_onnx tries "model_int8.onnx" before "model.onnx"; classifier_ov
# tries "openvino_model_int8.xml" before "openvino_model.xml").
# config.json is included in every set because the classifiers read all
# preprocessing parameters and class names from it.
_HF_ONNX_INT8_FILES = ["model_int8.onnx", "config.json"]
_HF_ONNX_FP32_FILES = ["model.onnx", "config.json"]

# OpenVINO IR models are a pair: .xml (topology) + .bin (weights).
_HF_OV_INT8_FILES = ["openvino_model_int8.xml", "openvino_model_int8.bin", "config.json"]
_HF_OV_FP32_FILES = ["openvino_model.xml", "openvino_model.bin", "config.json"]
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _is_hf_repo_id(path: str) -> bool:
    """Heuristically decide whether *path* names a HuggingFace repo.

    A repo ID has the form ``owner/repo``: exactly one forward slash with a
    non-blank segment on each side. Anything that exists on the local
    filesystem, or that starts like a filesystem path (``.``, ``/``, ``~``),
    is treated as a local directory instead.
    """
    if os.path.exists(path):
        return False
    candidate = path.replace("\\", "/")
    if candidate.startswith((".", "/", "~")):
        return False
    owner, sep, repo = candidate.partition("/")
    # Exactly one separator, and neither side may be empty or whitespace-only.
    return bool(sep) and "/" not in repo and bool(owner.strip()) and bool(repo.strip())
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _download_from_hf(repo_id: str, filenames: list[str], cache_dir: str | None) -> Path:
    """Download specific files from a HF repo and return the local snapshot directory.

    Args:
        repo_id: HuggingFace repo ID, e.g. ``"Wikit/pdf-pages-classifier"``.
        filenames: Non-empty list of file names to fetch from the repo.
        cache_dir: Custom cache directory for downloads, or ``None`` for the
            huggingface_hub default.

    Returns:
        The local snapshot directory containing the downloaded files (parent
        of the last downloaded file; all files of one repo revision land in
        the same snapshot directory).

    Raises:
        ValueError: If ``filenames`` is empty.
        ImportError: If huggingface_hub is not installed.
    """
    # Explicit check instead of a trailing `assert`: asserts are stripped
    # under `python -O`, which would turn an empty list into an obscure
    # AttributeError on `None.parent`.
    if not filenames:
        raise ValueError("filenames must not be empty")

    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to load from a HuggingFace repo.\n"
            "Install with: pip install huggingface-hub"
        ) from e

    last: Path | None = None
    for filename in filenames:
        last = Path(hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir))

    return last.parent
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _download_with_int8_fallback(
    repo_id: str,
    int8_files: list[str],
    fp32_files: list[str],
    cache_dir: str | None,
) -> Path:
    """Download model files from HF, preferring the INT8 set over FP32.

    Tries to fetch ``int8_files`` first; if the repo does not ship the INT8
    artifacts, falls back to ``fp32_files``. Any error other than a missing
    repo entry propagates to the caller.
    """
    try:
        from huggingface_hub import EntryNotFoundError
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to load from a HuggingFace repo.\n"
            "Install with: pip install huggingface-hub"
        ) from e

    try:
        return _download_from_hf(repo_id, int8_files, cache_dir)
    except EntryNotFoundError:
        # Repo only ships the FP32 variant — fall through to it below.
        pass
    return _download_from_hf(repo_id, fp32_files, cache_dir)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def load_classifier(
    repo_or_dir: str = "Wikit/pdf-pages-classifier",
    backend: str = "auto",
    device: str = "CPU",
    cache_dir: str | None = None,
) -> Any:
    """Load a PDF page classifier, picking the best available backend.

    Args:
        repo_or_dir: Either a HuggingFace repo ID (``"owner/repo"`` form,
            e.g. ``"Wikit/pdf-pages-classifier"``) or a local directory that
            holds ``config.json`` plus the model files.
        backend: ``"auto"`` attempts OpenVINO first and falls back to ONNX
            when OpenVINO is unavailable or its model files are missing;
            ``"openvino"`` or ``"onnx"`` forces that backend.
        device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``);
            ignored by the ONNX backend.
        cache_dir: Optional custom cache directory for HuggingFace downloads.

    Returns:
        A classifier instance exposing ``predict(images)``.

    Raises:
        ValueError: If ``backend`` is not one of the three accepted values.

    Example::

        clf = load_classifier("Wikit/pdf-pages-classifier")
        result = clf.predict("page.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """
    if backend not in ("auto", "onnx", "openvino"):
        raise ValueError(f"Unknown backend {backend!r}. Choose 'auto', 'onnx', or 'openvino'.")

    is_hf = _is_hf_repo_id(repo_or_dir)

    # Forced ONNX: skip the OpenVINO attempt entirely.
    if backend == "onnx":
        return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf)

    try:
        return _load_openvino(repo_or_dir, device=device, cache_dir=cache_dir, is_hf=is_hf)
    except (ImportError, FileNotFoundError):
        # In "auto" mode a missing openvino install or missing IR files just
        # means we fall back; a forced "openvino" backend surfaces the error.
        if backend == "openvino":
            raise

    return _load_onnx(repo_or_dir, cache_dir=cache_dir, is_hf=is_hf)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _load_onnx(repo_or_dir: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Resolve model files (downloading from HF when needed) and build the ONNX classifier."""
    # Works both as a package module (relative import) and as a loose script
    # cloned straight from the HF repo (flat import).
    try:
        from .classifier_onnx import PDFPageClassifierONNX
    except ImportError:
        from classifier_onnx import PDFPageClassifierONNX  # type: ignore[no-redef]

    if is_hf:
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_ONNX_INT8_FILES, _HF_ONNX_FP32_FILES, cache_dir
        )
    else:
        model_dir = Path(repo_or_dir)

    return PDFPageClassifierONNX.from_pretrained(str(model_dir))
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _load_openvino(repo_or_dir: str, device: str, cache_dir: str | None, is_hf: bool) -> Any:
    """Resolve model files (downloading from HF when needed) and build the OpenVINO classifier."""
    # Works both as a package module (relative import) and as a loose script
    # cloned straight from the HF repo (flat import).
    try:
        from .classifier_ov import PDFPageClassifierOV
    except ImportError:
        from classifier_ov import PDFPageClassifierOV  # type: ignore[no-redef]

    if is_hf:
        model_dir = _download_with_int8_fallback(
            repo_or_dir, _HF_OV_INT8_FILES, _HF_OV_FP32_FILES, cache_dir
        )
    else:
        model_dir = Path(repo_or_dir)

    return PDFPageClassifierOV.from_pretrained(str(model_dir), device=device)
|
classifiers/base_classifier.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod, ABC
|
| 2 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
+
from typing import Any, Union
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
import numpy.typing as npt
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class _BasePDFPageClassifier(ABC):
    """Shared preprocessing, formatting, and predict logic.

    Subclasses must implement ``_run_batch`` to perform backend-specific
    inference on a (N, C, H, W) float32 numpy array.

    NOTE(review): the preprocessing here (center-crop, bicubic resize,
    header whiteout, mean/std normalization) presumably mirrors the training
    pipeline exactly — keep any change in sync with training.
    """

    def __init__(self, config: dict[str, Any]) -> None:
        """Read preprocessing and labeling parameters from a deployment config.

        Required keys: ``image_size``, ``mean``, ``std``, ``class_names``.
        Optional keys (defaults): ``center_crop_shortest`` (True),
        ``whiteout_header`` (False), ``whiteout_fraction`` (0.15),
        ``threshold`` (0.5), ``image_required_classes`` ([]).
        """
        # Square side length (pixels) the model input expects.
        self._image_size: int = config["image_size"]
        # Per-channel normalization stats, broadcast over (H, W, C) later.
        self._mean = np.array(config["mean"], dtype=np.float32)
        self._std = np.array(config["std"], dtype=np.float32)
        self._center_crop: bool = config.get("center_crop_shortest", True)
        self._whiteout: bool = config.get("whiteout_header", False)
        # Number of top pixel rows to blank when header whiteout is enabled.
        self._whiteout_cutoff: int = int(
            self._image_size * config.get("whiteout_fraction", 0.15)
        )
        self._class_names: list[str] = config["class_names"]
        # Default probability cutoff; overridable per predict() call.
        self._threshold: float = float(config.get("threshold", 0.5))
        # Classes whose prediction flags the page as needing image embedding.
        self._image_required_classes: set[str] = set(
            config.get("image_required_classes", [])
        )

    @abstractmethod
    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run inference on a (N, C, H, W) float32 batch.

        Returns:
            (N, num_classes) float32 array of probabilities.
        """

    @staticmethod
    def _load_image(item: Any) -> "Image.Image":
        """Load an image from a file path or PIL image and convert to RGB.

        Args:
            item: File path string or PIL image (any mode).

        Returns:
            RGB PIL image.

        Raises:
            TypeError: If ``item`` is neither a str nor a PIL.Image.
        """
        if isinstance(item, str):
            return Image.open(item).convert("RGB")
        if isinstance(item, Image.Image):
            # convert() returns a copy, so a caller's image is never mutated.
            return item.convert("RGB")
        raise TypeError(f"Expected str or PIL.Image, got {type(item).__name__}")

    def _pil_to_array(self, img: "Image.Image") -> "npt.NDArray[np.float32]":
        """Apply spatial transforms and return a (H, W, C) float32 array in [0, 1].

        Normalization and the channel transpose are intentionally deferred so
        they can be applied in a single vectorised pass over the whole batch in
        ``_normalize_batch``.

        Steps:
            1. Center-crop to square (shortest side), if enabled.
            2. Resize to (image_size, image_size) with bicubic interpolation.
            3. Scale pixel values to [0, 1].
            4. White out top header rows, if enabled.

        Args:
            img: RGB PIL image.

        Returns:
            Float32 array of shape (image_size, image_size, 3).
        """
        if self._center_crop:
            w, h = img.size
            sq = min(w, h)
            # Symmetric crop box centred on the image; integer division keeps
            # the box inside the image for odd differences.
            img = img.crop(
                ((w - sq) // 2, (h - sq) // 2, (w + sq) // 2, (h + sq) // 2)
            )

        img = img.resize((self._image_size, self._image_size), Image.Resampling.BICUBIC)
        arr = np.asarray(img, dtype=np.float32) * (1.0 / 255.0)  # (H, W, C)

        if self._whiteout:
            # Blank the top rows (1.0 == white after the [0, 1] rescale).
            arr[: self._whiteout_cutoff] = 1.0

        return arr

    def _normalize_batch(
        self, arrays: list["npt.NDArray[np.float32]"]
    ) -> "npt.NDArray[np.float32]":
        """Stack a list of (H, W, C) arrays and apply ImageNet normalization.

        Args:
            arrays: List of float32 arrays, each of shape (H, W, C) in [0, 1].

        Returns:
            Float32 array of shape (N, C, H, W), normalized with ImageNet stats.
        """
        batch = np.stack(arrays, axis=0)  # (N, H, W, C)
        batch = (batch - self._mean) / self._std  # broadcast over (H, W, C)
        return batch.transpose(0, 3, 1, 2)  # (N, C, H, W)

    def _format(
        self,
        probabilities: "npt.NDArray[np.float32]",
        threshold: float,
    ) -> dict[str, Any]:
        """Format model output probabilities into a result dict.

        Args:
            probabilities: 1-D float32 array of per-class probabilities.
            threshold: Probability cutoff for a positive prediction.

        Returns:
            Dict with keys ``needs_image_embedding``, ``predicted_classes``,
            and ``probabilities``.
        """
        predicted_classes = [
            name
            for name, prob in zip(self._class_names, probabilities)
            if prob >= threshold
        ]
        return {
            # True when any predicted class is in the image-required set.
            "needs_image_embedding": any(
                c in self._image_required_classes for c in predicted_classes
            ),
            "predicted_classes": predicted_classes,
            # Cast numpy scalars to plain floats so the dict is JSON-friendly.
            "probabilities": {
                name: float(prob)
                for name, prob in zip(self._class_names, probabilities)
            },
        }

    def predict(
        self,
        images: Union[str, "Image.Image", list[Any]],
        threshold: float | None = None,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> Union[dict[str, Any], list[dict[str, Any]]]:
        """Classify one or more PDF page images.

        Args:
            images: A single image (file path string or PIL.Image) or a list
                of images.
            threshold: Override the default probability threshold from config.
                The override is local to this call and does not mutate the
                classifier instance.
            batch_size: Number of images to process per inference call.
            num_workers: Number of threads for parallel image loading and
                preprocessing. Set to 1 to disable threading (must be >= 1;
                ThreadPoolExecutor rejects 0).

        Returns:
            A single result dict when ``images`` is not a list, or a list of
            result dicts otherwise. Each dict contains:
                - ``needs_image_embedding`` (bool)
                - ``predicted_classes`` (list[str])
                - ``probabilities`` (dict[str, float])
        """
        effective_threshold = self._threshold if threshold is None else threshold

        is_single = not isinstance(images, list)
        image_list: list[Any] = [images] if is_single else images

        all_results: list[dict[str, Any]] = []

        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            for batch_start in range(0, len(image_list), batch_size):
                batch_items = image_list[batch_start : batch_start + batch_size]

                # Load (file I/O + RGB conversion) in parallel, then free after use.
                loaded: list[Image.Image] = list(
                    executor.map(self._load_image, batch_items)
                )
                # PIL transforms (crop + bicubic resize) in parallel.
                arrays: list[npt.NDArray[np.float32]] = list(
                    executor.map(self._pil_to_array, loaded)
                )

                # Vectorised normalization + transpose, then inference.
                batch_input = self._normalize_batch(arrays)  # (N, C, H, W)
                probs_batch: npt.NDArray[np.float32] = self._run_batch(batch_input)

                all_results.extend(
                    self._format(probs, effective_threshold) for probs in probs_batch
                )

        return all_results[0] if is_single else all_results

    def __call__(
        self,
        images: Union[str, "Image.Image", list[Any]],
        threshold: float | None = None,
        batch_size: int = 32,
        num_workers: int = 4,
    ) -> Union[dict[str, Any], list[dict[str, Any]]]:
        """Delegate to predict(). See predict() for full documentation."""
        return self.predict(
            images, threshold=threshold, batch_size=batch_size, num_workers=num_workers
        )
|
classifiers/classifier_onnx.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PDF page classifier for production inference."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import numpy.typing as npt
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from .base_classifier import _BasePDFPageClassifier
|
| 12 |
+
except ImportError:
|
| 13 |
+
from base_classifier import _BasePDFPageClassifier # standalone / HF usage
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import onnxruntime as ort
|
| 17 |
+
except ImportError as _e:
|
| 18 |
+
raise ImportError(
|
| 19 |
+
"onnxruntime is required for inference.\n"
|
| 20 |
+
"Install with: pip install onnxruntime"
|
| 21 |
+
) from _e
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and exposes a simple ``predict``
    interface. All preprocessing (center-crop, resize, normalization) is
    performed in pure PIL + numpy, matching the pipeline used during training.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(self, model_path: str, config: dict[str, Any]) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment).
        """
        super().__init__(config)
        self._session = ort.InferenceSession(model_path)
        # Cache the graph's (single) input name for the feed dict in _run_batch.
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(cls, model_dir: str) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:
            - ``model.onnx`` or ``model_int8.onnx`` — exported by save_for_deployment
            - ``config.json`` — written by save_for_deployment

        The INT8 model (``model_int8.onnx``) is preferred when present.

        Args:
            model_dir: Path to the deployment directory.

        Returns:
            Initialised PDFPageClassifierONNX.

        Raises:
            FileNotFoundError: If config.json or an ONNX model file is missing.
        """
        path = Path(model_dir)
        config_path = path / "config.json"

        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        # Prefer INT8 (QAT export) over FP32 when both are present
        candidates = ["model_int8.onnx", "model.onnx"]
        for candidate in candidates:
            if (path / candidate).exists():
                model_path = path / candidate
                break
        else:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        return cls(str(model_path), config)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        # First graph output holds the per-class probabilities.
        return self._session.run(None, {self._input_name: batch_input})[0]
|
classifiers/classifier_ov.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenVINO-based PDF page classifier for production inference."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import numpy.typing as npt
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from openvino import Core
|
| 12 |
+
from openvino import CompiledModel
|
| 13 |
+
except ImportError as _e:
|
| 14 |
+
raise ImportError(
|
| 15 |
+
"openvino is required for OpenVINO inference.\n"
|
| 16 |
+
"Install with: pip install openvino"
|
| 17 |
+
) from _e
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from .base_classifier import _BasePDFPageClassifier
|
| 21 |
+
except ImportError:
|
| 22 |
+
from base_classifier import _BasePDFPageClassifier # standalone / HF usage
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PDFPageClassifierOV(_BasePDFPageClassifier):
|
| 26 |
+
"""Classify PDF pages using a deployed OpenVINO IR model.
|
| 27 |
+
|
| 28 |
+
Loads a self-contained deployment directory produced by
|
| 29 |
+
``export_onnx.save_for_deployment`` (with ``export_openvino=True``) and
|
| 30 |
+
exposes the same ``predict`` interface as ``PDFPageClassifier``.
|
| 31 |
+
|
| 32 |
+
Automatically selects the INT8 variant (``model_ov_int8.xml``) when it
|
| 33 |
+
exists alongside the FP32 model, falling back to ``model_ov.xml``.
|
| 34 |
+
|
| 35 |
+
Example::
|
| 36 |
+
|
| 37 |
+
clf = PDFPageClassifierOV.from_pretrained("outputs/run-42/deployment")
|
| 38 |
+
result = clf.predict("page_001.png")
|
| 39 |
+
print(result["needs_image_embedding"], result["predicted_classes"])
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
model_path: str,
|
| 45 |
+
config: dict[str, Any],
|
| 46 |
+
device: str = "CPU",
|
| 47 |
+
) -> None:
|
| 48 |
+
"""Initialise the classifier.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
model_path: Path to the OpenVINO IR ``.xml`` file.
|
| 52 |
+
config: Deployment config dict (same schema as config.json written
|
| 53 |
+
by save_for_deployment).
|
| 54 |
+
device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
|
| 55 |
+
"""
|
| 56 |
+
super().__init__(config)
|
| 57 |
+
compiled: CompiledModel = Core().compile_model(model_path, device)
|
| 58 |
+
self._session: CompiledModel = compiled
|
| 59 |
+
self._input_name: str = compiled.input(0).get_any_name()
|
| 60 |
+
self._output = compiled.output(0)
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def from_pretrained(
|
| 64 |
+
cls,
|
| 65 |
+
model_dir: str,
|
| 66 |
+
device: str = "CPU",
|
| 67 |
+
) -> "PDFPageClassifierOV":
|
| 68 |
+
"""Load a classifier from a deployment directory.
|
| 69 |
+
|
| 70 |
+
The directory must contain:
|
| 71 |
+
- ``model_ov.xml`` / ``model_ov_int8.xml`` — exported by
|
| 72 |
+
save_for_deployment with ``export_openvino=True``
|
| 73 |
+
- ``config.json`` — written by save_for_deployment
|
| 74 |
+
|
| 75 |
+
The INT8 model (``model_ov_int8.xml``) is preferred when present.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
model_dir: Path to the deployment directory.
|
| 79 |
+
device: OpenVINO device string (``"CPU"``, ``"GPU"``, ``"AUTO"``).
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Initialised PDFPageClassifierOV.
|
| 83 |
+
"""
|
| 84 |
+
path = Path(model_dir)
|
| 85 |
+
config_path = path / "config.json"
|
| 86 |
+
|
| 87 |
+
if not config_path.exists():
|
| 88 |
+
raise FileNotFoundError(f"config.json not found in {model_dir}")
|
| 89 |
+
|
| 90 |
+
# Search order: prefer INT8 over FP32, HF/Optimum names over legacy names
|
| 91 |
+
candidates = [
|
| 92 |
+
"openvino_model_int8.xml", # HF-style INT8 (preferred)
|
| 93 |
+
"openvino_model.xml", # HF-style FP32
|
| 94 |
+
"model_ov_int8.xml", # legacy local INT8
|
| 95 |
+
"model_ov.xml", # legacy local FP32
|
| 96 |
+
]
|
| 97 |
+
for candidate in candidates:
|
| 98 |
+
if (path / candidate).exists():
|
| 99 |
+
model_path = path / candidate
|
| 100 |
+
break
|
| 101 |
+
else:
|
| 102 |
+
raise FileNotFoundError(
|
| 103 |
+
f"No OpenVINO model found in {model_dir}. "
|
| 104 |
+
f"Expected one of: {', '.join(candidates)}. "
|
| 105 |
+
"Export with save_for_deployment(..., export_openvino=True)."
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
with open(config_path, encoding="utf-8") as f:
|
| 109 |
+
config = json.load(f)
|
| 110 |
+
|
| 111 |
+
return cls(str(model_path), config, device=device)
|
| 112 |
+
|
| 113 |
+
def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
    """Forward one preprocessed batch through the compiled OpenVINO model."""
    feeds = {self._input_name: batch_input}
    results = self._session(feeds)
    return results[self._output]
|
classifiers/classifier_torch.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch-based PDF page classifier for native inference."""

import json
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt

try:
    import torch
except ImportError as _e:
    raise ImportError(
        "torch is required for PyTorch inference.\n"
        "Install with: pip install torch"
    ) from _e

# Use the package-qualified path, consistent with the base_classifier import
# above; the bare `from models import ...` only resolved when classifiers/
# itself happened to be on sys.path.
from classifiers.base_classifier import _BasePDFPageClassifier
from classifiers.models import create_model
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class PDFPageClassifierTorch(_BasePDFPageClassifier):
    """Classify PDF pages with a native PyTorch checkpoint.

    Exposes the same ``predict`` interface as the ONNX and OpenVINO
    classifiers. Preprocessing (center-crop, resize, normalization) lives in
    the shared base class, so this class only owns model loading and the
    forward pass.

    Example::

        clf = PDFPageClassifierTorch.from_checkpoint("outputs/run-42/best_model.pt")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model: "torch.nn.Module",
        config: dict[str, Any],
        device: "torch.device | str" = "cpu",
    ) -> None:
        """Wrap an already-loaded model for inference.

        Args:
            model: PyTorch model with weights loaded; it is moved to
                ``device`` and switched to eval mode here.
            config: Flat config dict matching the base-classifier schema.
            device: Torch device for inference (``"cpu"``, ``"cuda"``, etc.).
        """
        super().__init__(config)
        self._device = torch.device(device)
        self._model = model.to(self._device)
        self._model.eval()

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str,
        device: "torch.device | str" = "cpu",
        image_required_classes: list[str] | None = None,
        threshold: float = 0.5,
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a training checkpoint.

        The checkpoint must contain ``model_state_dict`` (weights),
        ``config`` (training config with ``model`` and ``data`` keys) and
        ``class_names`` (ordered list of class names).

        Args:
            checkpoint_path: Path to the ``.pt`` checkpoint file.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).
            image_required_classes: Class names that trigger image embedding.
                Defaults to an empty list when not provided.
            threshold: Default prediction threshold (can be overridden per call).

        Returns:
            Initialised PDFPageClassifierTorch.
        """
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

        labels: list[str] = checkpoint["class_names"]
        train_cfg = checkpoint["config"]
        model_cfg = train_cfg["model"]
        data_cfg = train_cfg["data"]

        net = create_model(
            model_name=model_cfg["name"],
            num_classes=len(labels),
            pretrained=False,
            dropout=model_cfg["dropout"],
            use_spatial_pooling=model_cfg.get("use_spatial_pooling", False),
        )
        net.load_state_dict(checkpoint["model_state_dict"])

        # Flatten the training config into the schema the base class expects.
        flat_config: dict[str, Any] = {
            "image_size": data_cfg["image_size"],
            "mean": data_cfg.get("mean", [0.485, 0.456, 0.406]),
            "std": data_cfg.get("std", [0.229, 0.224, 0.225]),
            "center_crop_shortest": data_cfg.get("center_crop_shortest", True),
            "whiteout_header": data_cfg.get("whiteout_header", False),
            "whiteout_fraction": data_cfg.get("whiteout_fraction", 0.15),
            "class_names": labels,
            "threshold": threshold,
            "image_required_classes": image_required_classes or [],
        }

        return cls(net, flat_config, device=device)

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        device: "torch.device | str" = "cpu",
    ) -> "PDFPageClassifierTorch":
        """Load a classifier from a deployment directory.

        The directory must contain ``model.pt`` (PyTorch checkpoint) and
        ``config.json`` (deployment config), both written by
        save_for_deployment.

        Args:
            model_dir: Path to the deployment directory.
            device: Torch device string (``"cpu"``, ``"cuda"``, ``"mps"``).

        Returns:
            Initialised PDFPageClassifierTorch.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is missing.
        """
        root = Path(model_dir)
        config_file = root / "config.json"
        weights_file = root / "model.pt"

        if not config_file.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        if not weights_file.exists():
            raise FileNotFoundError(f"model.pt not found in {model_dir}")

        config: dict[str, Any] = json.loads(config_file.read_text(encoding="utf-8"))

        checkpoint = torch.load(str(weights_file), map_location="cpu", weights_only=False)

        net = create_model(
            model_name=config["model_name"],
            num_classes=len(config["class_names"]),
            pretrained=False,
            dropout=config.get("dropout", 0.2),
            use_spatial_pooling=config.get("use_spatial_pooling", False),
        )
        # Accept both a wrapped training checkpoint and a bare state dict.
        net.load_state_dict(checkpoint.get("model_state_dict", checkpoint))

        return cls(net, config, device=device)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one preprocessed batch and return per-class sigmoid probabilities."""
        with torch.no_grad():
            inputs = torch.from_numpy(batch_input).to(self._device)  # type: ignore
            scores = torch.sigmoid(self._model(inputs))
        return scores.cpu().numpy()  # type: ignore
|
classifiers/models.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model definitions for PDF page classification."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import timm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MultiLabelClassifier(nn.Module):
    """Multi-label image classifier built on a configurable timm backbone.

    Args:
        model_name: Name of the timm model to use as backbone
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability before final layer
        use_spatial_pooling: If True, classify per spatial location and
            max-pool over positions (CAM-style) instead of global pooling
    """

    def __init__(
        self,
        model_name: str,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.2,
        use_spatial_pooling: bool = False
    ):
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.use_spatial_pooling = use_spatial_pooling

        # Headless timm backbone: keep the spatial map for CAM-style
        # pooling, otherwise let timm apply global average pooling.
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # strip the classification head
            global_pool='' if use_spatial_pooling else 'avg',
        )

        # Probe the backbone once to discover its output width.
        with torch.no_grad():
            probe = self.backbone(torch.randn(1, 3, 224, 224))

        # Channel dim is index 1 for both [B, C] and [B, C, H, W] outputs.
        self.feature_dim = probe.shape[1]
        if use_spatial_pooling:
            print(f"Spatial pooling enabled - feature map shape: {probe.shape}")

        if use_spatial_pooling:
            # 1x1 conv scores every spatial position; Dropout2d drops
            # whole channels (spatial dropout).
            self.classifier = nn.Sequential(
                nn.Dropout2d(p=dropout),
                nn.Conv2d(self.feature_dim, num_classes, kernel_size=1),
            )
        else:
            # Plain dropout + linear head on pooled features.
            self.classifier = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(self.feature_dim, num_classes),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Logits of shape (batch_size, num_classes)
        """
        feats = self.backbone(x)
        if not self.use_spatial_pooling:
            # feats: [B, C] -> logits: [B, num_classes]
            return self.classifier(feats)
        # feats: [B, C, H, W]; per-position logits max-pooled over space.
        return torch.amax(self.classifier(feats), dim=(2, 3))

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without the classification head.

        Useful for feature visualization or transfer learning.

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Features of shape (batch_size, feature_dim) or (batch_size, feature_dim, H, W)
        """
        return self.backbone(x)

    def get_activation_maps(self, x: torch.Tensor) -> torch.Tensor:
        """Get spatial activation maps (only for spatial pooling mode).

        Args:
            x: Input tensor of shape (batch_size, 3, H, W)

        Returns:
            Activation maps of shape (batch_size, num_classes, H, W)

        Raises:
            ValueError: If spatial pooling is not enabled
        """
        if not self.use_spatial_pooling:
            raise ValueError("Activation maps only available with spatial pooling enabled")

        return self.classifier(self.backbone(x))
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def create_model(
    model_name: str,
    num_classes: int,
    pretrained: bool = True,
    dropout: float = 0.2,
    use_spatial_pooling: bool = False
) -> MultiLabelClassifier:
    """Factory function to create a model.

    Args:
        model_name: Name of the model architecture. Example : mobilenetv3_small_100
        num_classes: Number of output classes
        pretrained: Whether to use pretrained weights
        dropout: Dropout probability
        use_spatial_pooling: If True, use spatial max pooling (CAM-style)

    Returns:
        Initialized model

    Raises:
        ValueError: If ``model_name`` is not a known timm architecture.
    """
    # timm.list_models(name) returns names matching the pattern; an empty
    # result means the architecture does not exist.
    available_models = timm.list_models(model_name)
    if not available_models:
        # Bug fix: the two f-string fragments previously joined with no
        # separator, producing "...not found in timm.Available options:...".
        raise ValueError(
            f"Model '{model_name}' not found in timm. "
            f"Available options: {timm.list_models()}"
        )

    model = MultiLabelClassifier(
        model_name=model_name,
        num_classes=num_classes,
        pretrained=pretrained,
        dropout=dropout,
        use_spatial_pooling=use_spatial_pooling
    )

    return model
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def count_parameters(model: nn.Module) -> dict[str, int | float]:
|
| 174 |
+
"""Count model parameters.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
model: PyTorch model
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Dictionary with parameter counts
|
| 181 |
+
"""
|
| 182 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 183 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 184 |
+
|
| 185 |
+
return {
|
| 186 |
+
'total': total_params,
|
| 187 |
+
'trainable': trainable_params,
|
| 188 |
+
'non_trainable': total_params - trainable_params,
|
| 189 |
+
'total_millions': total_params / 1e6,
|
| 190 |
+
'trainable_millions': trainable_params / 1e6
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def print_model_info(model: nn.Module, model_name: str = "Model"):
    """Print model information.

    Args:
        model: PyTorch model
        model_name: Name to display
    """
    stats = count_parameters(model)
    rule = '=' * 60

    print(f"\n{rule}")
    print(f"{model_name} Information")
    print(rule)
    print(f"Total parameters: {stats['total']:,} ({stats['total_millions']:.2f}M)")
    print(f"Trainable parameters: {stats['trainable']:,} ({stats['trainable_millions']:.2f}M)")
    print(f"Non-trainable params: {stats['non_trainable']:,}")
    print(f"{rule}\n")
|
config.json
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": {
|
| 3 |
+
"name": "efficientnet_lite0",
|
| 4 |
+
"pretrained": true,
|
| 5 |
+
"dropout": 0,
|
| 6 |
+
"use_spatial_pooling": false
|
| 7 |
+
},
|
| 8 |
+
"classes": [
|
| 9 |
+
"Visual - Essential",
|
| 10 |
+
"Simple Table",
|
| 11 |
+
"Chart/Graph",
|
| 12 |
+
"Visual - Supportive",
|
| 13 |
+
"Annotated figure",
|
| 14 |
+
"No Specific Feature",
|
| 15 |
+
"Diagram/Flowchart",
|
| 16 |
+
"Visual - Decorative",
|
| 17 |
+
"Complex Table",
|
| 18 |
+
"Infographic",
|
| 19 |
+
"Form",
|
| 20 |
+
"Text to OCR"
|
| 21 |
+
],
|
| 22 |
+
"class_mapping": {
|
| 23 |
+
"Form": null,
|
| 24 |
+
"No Specific Feature": null,
|
| 25 |
+
"Text to OCR": null,
|
| 26 |
+
"Visual - Decorative": null,
|
| 27 |
+
"Infographic": null,
|
| 28 |
+
"Chart/Graph": null,
|
| 29 |
+
"Annotated figure": null,
|
| 30 |
+
"Diagram/Flowchart": null
|
| 31 |
+
},
|
| 32 |
+
"image_required_classes": [
|
| 33 |
+
"Visual - Essential",
|
| 34 |
+
"Complex Table"
|
| 35 |
+
],
|
| 36 |
+
"data": {
|
| 37 |
+
"train_split": 0.8,
|
| 38 |
+
"val_split": 0.1,
|
| 39 |
+
"test_split": 0.1,
|
| 40 |
+
"image_size": 224,
|
| 41 |
+
"batch_size": 32,
|
| 42 |
+
"num_workers": 4,
|
| 43 |
+
"seed": 42
|
| 44 |
+
},
|
| 45 |
+
"augmentation": {
|
| 46 |
+
"center_crop_shortest": true,
|
| 47 |
+
"whiteout_header": false,
|
| 48 |
+
"whiteout_fraction": 0.15,
|
| 49 |
+
"train": {
|
| 50 |
+
"horizontal_flip": 0.5,
|
| 51 |
+
"rotation_degrees": 5,
|
| 52 |
+
"color_jitter": {
|
| 53 |
+
"brightness": 0.2,
|
| 54 |
+
"contrast": 0.2,
|
| 55 |
+
"saturation": 0.1,
|
| 56 |
+
"hue": 0.05
|
| 57 |
+
},
|
| 58 |
+
"random_erasing": 0.1
|
| 59 |
+
},
|
| 60 |
+
"val": {
|
| 61 |
+
"enabled": false
|
| 62 |
+
}
|
| 63 |
+
},
|
| 64 |
+
"training": {
|
| 65 |
+
"epochs": 40,
|
| 66 |
+
"learning_rate": 0.0001,
|
| 67 |
+
"weight_decay": 0.0001,
|
| 68 |
+
"optimizer": "adamw",
|
| 69 |
+
"scheduler": "cosine",
|
| 70 |
+
"warmup_epochs": 5,
|
| 71 |
+
"label_smoothing": 0.0,
|
| 72 |
+
"gradient_clip_norm": 1.0,
|
| 73 |
+
"pos_weight": [
|
| 74 |
+
3.6715595722198486,
|
| 75 |
+
6.668674468994141,
|
| 76 |
+
2.3281044960021973,
|
| 77 |
+
6.0722222328186035
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
"monitoring": {
|
| 81 |
+
"metric": "val_f1",
|
| 82 |
+
"mode": "max"
|
| 83 |
+
},
|
| 84 |
+
"early_stopping": {
|
| 85 |
+
"enabled": true,
|
| 86 |
+
"patience": 20
|
| 87 |
+
},
|
| 88 |
+
"evaluation": {
|
| 89 |
+
"threshold": 0.5,
|
| 90 |
+
"save_confusion_matrix": true,
|
| 91 |
+
"save_per_class_metrics": true
|
| 92 |
+
},
|
| 93 |
+
"checkpointing": {
|
| 94 |
+
"save_best_only": true,
|
| 95 |
+
"save_last": true
|
| 96 |
+
},
|
| 97 |
+
"paths": {
|
| 98 |
+
"data_dir": "data",
|
| 99 |
+
"output_dir": "outputs",
|
| 100 |
+
"checkpoint_dir": "checkpoints",
|
| 101 |
+
"logs_dir": "logs"
|
| 102 |
+
},
|
| 103 |
+
"logging": {
|
| 104 |
+
"use_tensorboard": false,
|
| 105 |
+
"use_wandb": true,
|
| 106 |
+
"wandb_project": "pdf-page-classifier",
|
| 107 |
+
"log_interval": 10,
|
| 108 |
+
"wandb_run_name": "silver-line-69"
|
| 109 |
+
},
|
| 110 |
+
"qat": {
|
| 111 |
+
"enabled": true,
|
| 112 |
+
"epochs": 5,
|
| 113 |
+
"learning_rate": "1e-5",
|
| 114 |
+
"preset": "mixed",
|
| 115 |
+
"num_init_samples": 300
|
| 116 |
+
},
|
| 117 |
+
"onnx": {
|
| 118 |
+
"opset_version": 14,
|
| 119 |
+
"dynamic_axes": true,
|
| 120 |
+
"simplify": true,
|
| 121 |
+
"input_names": [
|
| 122 |
+
"image"
|
| 123 |
+
],
|
| 124 |
+
"output_names": [
|
| 125 |
+
"probabilities"
|
| 126 |
+
]
|
| 127 |
+
}
|
| 128 |
+
}
|