dev-models committed
Commit 0d9f6c2 · 1 Parent(s): 4e459af

initial commit
Dockerfile ADDED
@@ -0,0 +1,29 @@

FROM pytorch/pytorch:latest


ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1

WORKDIR /app

# System deps for OpenCV / OCR
RUN apt-get update && apt-get install -y \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade pip

# Install Python dependencies
COPY requirements.txt .
RUN pip install --prefer-binary -r requirements.txt

# Copy backend + frontend
COPY backend ./backend
COPY frontend ./frontend

EXPOSE 7860

CMD ["uvicorn", "backend.main.app:app", "--host", "0.0.0.0", "--port", "7860"]
README copy.md ADDED
@@ -0,0 +1,204 @@
# End-to-End OCR with CRNN, CTC, and FastAPI

This project provides a complete pipeline for Optical Character Recognition (OCR). It features a **Convolutional Recurrent Neural Network (CRNN)** trained with **Connectionist Temporal Classification (CTC)** loss for text recognition. The trained model is served via a **FastAPI** application, which uses the **CRAFT** model for initial text detection.

The entire workflow, from data preparation and model training to deployment as a web service, is documented and implemented in this repository.

## Table of Contents

- [Workflow Overview](#workflow-overview)
- [Features](#features)
- [Project Structure](#project-structure)
- [Technical Details](#technical-details)
  - [Text Detection: CRAFT](#text-detection-craft)
  - [Text Recognition: CRNN](#text-recognition-crnn)
- [Getting Started](#getting-started)
  - [Prerequisites](#prerequisites)
  - [Installation](#installation)
- [Model Training and Export](#model-training-and-export)
  - [Dataset](#dataset)
  - [Running the Training Notebook](#running-the-training-notebook)
  - [Exporting to ONNX](#exporting-to-onnx)
- [Inference](#inference)
  - [Notebook Inference](#notebook-inference)
- [API Server (FastAPI)](#api-server-fastapi)
  - [Running the Server](#running-the-server)
  - [API Endpoints](#api-endpoints)
    - [HTTP POST Request](#http-post-request)
    - [WebSocket Connection](#websocket-connection)
- [Deployment with Docker](#deployment-with-docker)

## Workflow Overview

The project follows a step-by-step process from model creation to deployment:

1. **Train the Model**: The `backend/notebook/ocr.ipynb` notebook trains the CRNN text recognition model on the MJSynth dataset.
2. **Export the Model**: The trained PyTorch model is exported to the ONNX format (`model.onnx`) for efficient inference.
3. **Serve the Model**: The FastAPI application (`backend/main/app.py`) loads the ONNX model and the CRAFT text detection model to provide OCR capabilities through a web API.
4. **Deploy**: The entire application is containerized with Docker for easy, reproducible deployment.

## Features

- **Two-Stage OCR**: Uses CRAFT for accurate text detection and a CRNN for robust text recognition.
- **Deep Learning Model**: An optimized CRNN architecture implemented in PyTorch.
- **Efficient Inference**: The model is exported to ONNX for fast performance.
- **Web API**: A FastAPI server with both REST and WebSocket endpoints.
- **Reproducible Environment**: Comes with a `Dockerfile` for easy setup and deployment.
- **Complete Workflow**: Includes all steps from training to deployment.

## Project Structure

```
.
├── .gitignore                  # Files ignored by Git
├── Dockerfile                  # Docker configuration for the API
├── backend
│   ├── main
│   │   ├── app.py              # Main FastAPI application
│   │   └── core                # Modularized application logic
│   │       ├── __init__.py
│   │       ├── config.py       # Configuration variables
│   │       ├── models.py       # Model loading (CRAFT)
│   │       ├── ocr.py          # Core OCR pipeline
│   │       └── utils.py        # Utility functions
│   ├── models
│   │   ├── final.pth           # Final trained PyTorch model weights
│   │   └── model.onnx          # Trained model in ONNX format
│   └── notebook
│       └── ocr.ipynb           # Jupyter Notebook for training and export
├── frontend                    # Static frontend served by the API
└── requirements.txt            # Python dependencies
```

## Technical Details

### Text Detection: CRAFT

The FastAPI application first uses the **CRAFT (Character Region Awareness for Text detection)** model to detect text regions in the input image. It identifies bounding boxes around words or lines of text. This project uses the `hezarai/CRAFT` implementation.

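In the API code (`backend/main/core/ocr.py`), the boxes returned by CRAFT are cropped from the original image and passed one at a time to the recognizer. A minimal sketch of that wiring, with the detector injected so it can be stubbed (`detect_and_recognize` and `recognize_fn` are illustrative names, not part of the repository):

```python
# Sketch of the two-stage pipeline wiring. The detector is any object with
# the hezar-style interface predict(image) -> [{"boxes": [[x, y, w, h], ...]}],
# and recognize_fn maps an image crop to a text string.
def detect_and_recognize(image, detector, recognize_fn):
    boxes = detector.predict(image)[0]["boxes"]
    results = []
    for x, y, w, h in boxes:
        # CRAFT boxes are (x, y, width, height); crop expects (left, top, right, bottom)
        crop = image.crop((x, y, x + w, y + h))
        results.append({"bbox": [x, y, x + w, y + h], "text": recognize_fn(crop)})
    return results
```

Because both stages are injected, the wiring can be exercised without loading any model weights.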
### Text Recognition: CRNN

For each bounding box detected by CRAFT, a **Convolutional Recurrent Neural Network (CRNN)** recognizes the text within that region.

- **Convolutional Layers (CNN)**: Serve as a feature extractor, processing the image patch and outputting a sequence of feature vectors.
- **Recurrent Layers (RNN)**: A bidirectional LSTM processes the feature sequence, capturing contextual dependencies between characters.
- **CTC Loss**: The model is trained with Connectionist Temporal Classification (CTC) loss, which removes the need for character-level alignment between the input image and the output text, making it well suited to OCR.

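At inference time, CTC output is decoded by taking the most likely class at each timestep, collapsing repeats, and dropping blanks. A small self-contained sketch of greedy decoding, following the conventions used in `backend/main/core/utils.py` (blank at index 0, alphabet starting at index 1); the repository's version additionally tracks per-character confidences:

```python
import numpy as np

ALPHABET = "0123456789abcdefghijklmnopqrstuvwxyz"
BLANK = 0  # index 0 is reserved for the CTC blank symbol

def ctc_greedy_decode(probs):
    """probs: (T, C) array of per-timestep class probabilities."""
    best = np.argmax(probs, axis=1)
    out, prev = [], None
    for idx in best:
        # collapse repeated predictions, then drop blanks
        if idx != BLANK and idx != prev:
            out.append(ALPHABET[idx - 1])
        prev = idx
    return "".join(out)
```

For example, the per-timestep argmax sequence `c, c, blank, a, t` decodes to `"cat"`.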
## Getting Started

### Prerequisites

- Python 3.9 or higher
- An NVIDIA GPU with CUDA (highly recommended for training)
- Docker (for containerized deployment)

### Installation

1. **Clone the repository:**
   ```bash
   git clone <repository-url>
   cd crnn-ctc-ocr
   ```

2. **Set up a virtual environment (recommended):**
   ```bash
   python -m venv venv
   source venv/bin/activate  # On Windows use `venv\Scripts\activate`
   ```

3. **Install dependencies:**
   ```bash
   pip install -r requirements.txt
   ```

## Model Training and Export

The `backend/notebook/ocr.ipynb` notebook contains the complete code for training and exporting the model.

### Dataset

The model is trained on **MJSynth**, a large-scale synthetic dataset for text recognition. The notebook downloads it automatically with the Hugging Face `datasets` library (`priyank-m/MJSynth_text_recognition`).

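Before computing CTC loss, labels from the dataset are mapped to integer sequences, with index 0 reserved for the blank symbol. A condensed version of the encoding used in the notebook:

```python
# Character encoding as in the training notebook: characters map to
# indices 1..36; index 0 is the CTC blank.
CHARS = "0123456789abcdefghijklmnopqrstuvwxyz"
BLANK = 0
char2idx = {c: i + 1 for i, c in enumerate(CHARS)}

def encode(text):
    # characters outside the alphabet are silently dropped, as in the notebook
    return [char2idx[c] for c in text.lower() if c in char2idx]
```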
### Running the Training Notebook

1. **Launch Jupyter:**
   ```bash
   jupyter notebook
   ```
2. Open `backend/notebook/ocr.ipynb`.
3. Run all cells to execute the full pipeline. The `DEMO` flag is `True` by default, which trains on a smaller subset for a quick run. Set it to `False` for full training.
4. The notebook will:
   - Load and preprocess the dataset.
   - Define the CRNN model, loss function, and optimizer.
   - Run the training loop, reporting progress and validation metrics (Character Error Rate, Word Error Rate).
   - Save the best model to `checkpoints/best.pth` and the final model to `checkpoints/final.pth`.

### Exporting to ONNX

After training, the last cells of the notebook handle the export to ONNX:

- They take the trained CRNN model.
- They export it to `export/model.onnx`. **This file is required by the FastAPI application.** The repository already includes a pre-trained model at `backend/models/model.onnx`.

## Inference

### Notebook Inference

The notebook includes helper functions to test the model directly.

**1. Using the PyTorch model (`.pth`):**
```python
# predict with the best saved PyTorch model
predict('path/to/your/image.png', model_path='checkpoints/best.pth')
```

**2. Using the ONNX model:**
```python
# predict with the exported ONNX model
predict_onnx('path/to/your/image.png', onnx_path='export/model.onnx')
```

## API Server (FastAPI)

The application provides a web server to perform OCR on uploaded images.

### Running the Server

To run the API server locally, launch it from the repository root (so the `backend` package resolves):
```bash
python -m backend.main.app
```
The server starts on `http://localhost:8000`.

### API Endpoints

#### HTTP POST Request

- **Endpoint**: `POST /predict/image`
- **Description**: Upload an image and receive the OCR results in JSON format.
- **Example using `curl`**:
  ```bash
  curl -X POST -F "file=@/path/to/your/image.jpg" http://localhost:8000/predict/image
  ```
- **Response**: A JSON object containing the detected paragraph, lines, words, and a base64-encoded image with bounding boxes drawn on it.

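The response shape follows the dictionary returned by `run_ocr` in `backend/main/core/ocr.py`; the values below are illustrative, not real output:

```json
{
  "paragraph": "hello world",
  "lines": [
    { "text": "hello world", "confidence": 0.93 }
  ],
  "words": [
    { "text": "hello", "confidence": 0.95, "bbox": [12, 8, 86, 40] },
    { "text": "world", "confidence": 0.91, "bbox": [94, 8, 170, 40] }
  ],
  "image": "data:image/png;base64,iVBORw0..."
}
```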
#### WebSocket Connection

- **Endpoint**: `ws://localhost:8000/ws/predict`
- **Description**: A WebSocket endpoint for real-time OCR. Send an image as bytes, and the server returns the OCR result as a JSON message. Useful for streaming or interactive applications.

## Deployment with Docker

The project includes a `Dockerfile` to containerize and deploy the application.

1. **Prerequisite**: Ensure `backend/models/model.onnx` exists, since this is the path the application loads. If you've trained your own model, move your exported `export/model.onnx` to `backend/models/model.onnx`, or update the path in `backend/main/app.py`.

2. **Build the Docker image:**
   ```bash
   docker build -t ocr-api .
   ```

3. **Run the Docker container:**
   ```bash
   docker run -p 7860:7860 ocr-api
   ```
   The container serves the app on port 7860, so it will be accessible at `http://localhost:7860`.
backend/__init__.py ADDED
File without changes
backend/main/app.py ADDED
@@ -0,0 +1,39 @@
import io
from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from PIL import Image
import onnxruntime as ort

from backend.main.core.ocr import run_ocr
from backend.main.core.models import load_models_on_startup

app = FastAPI(title="OCR API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Reuse a single ONNX Runtime session instead of creating one per request
ort_session = None

@app.on_event("startup")
async def startup_event():
    global ort_session
    ort_session = ort.InferenceSession("./backend/models/model.onnx")
    await load_models_on_startup()

@app.post("/predict/image", tags=["Image Prediction"])
async def predict_image(file: UploadFile = File(...)):
    image = Image.open(io.BytesIO(await file.read())).convert("RGB")
    return run_ocr(image, ort_session)

# Serve React
app.mount(
    "/",
    StaticFiles(directory="frontend", html=True),
    name="react"
)

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
backend/main/core/__init__.py ADDED
File without changes
backend/main/core/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (130 Bytes). View file
 
backend/main/core/__pycache__/config.cpython-310.pyc ADDED
Binary file (225 Bytes). View file
 
backend/main/core/__pycache__/models.cpython-310.pyc ADDED
Binary file (465 Bytes). View file
 
backend/main/core/__pycache__/ocr.cpython-310.pyc ADDED
Binary file (2.6 kB). View file
 
backend/main/core/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.02 kB). View file
 
backend/main/core/config.py ADDED
@@ -0,0 +1,4 @@
# -------------------- Config --------------------
alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
blank_idx = 0
TARGET_HEIGHT = 32
backend/main/core/models.py ADDED
@@ -0,0 +1,9 @@
from hezar.models import Model

models = {}
dynamic_dictionary = set()

async def load_models_on_startup():
    print("Loading CRAFT...")
    models["craft"] = Model.load("hezarai/CRAFT", device="cpu")
    print("CRAFT loaded")
backend/main/core/ocr.py ADDED
@@ -0,0 +1,83 @@
import io
import base64
import numpy as np
from PIL import ImageDraw, ImageFont

from . import utils
from . import models

def predict_word(crop, ort_session, dictionary=None):
    if crop.width < 5 or crop.height < 5:
        return "", 0.0

    tensor = utils.preprocess_crop(crop)
    logits = ort_session.run(None, {ort_session.get_inputs()[0].name: tensor})[0][:, 0, :]
    # Softmax over the vocabulary, stabilized by subtracting the row max
    logits -= np.max(logits, axis=1, keepdims=True)
    probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)

    text, conf = utils.ctc_greedy_decode_with_confidence(probs)

    # Fall back to beam search when greedy decoding is not confident
    if conf < 0.85:
        bt, bc = utils.ctc_beam_search_decode(probs)
        if bt:
            text, conf = bt, max(conf, bc)

    # Correct using the dynamic dictionary if it has entries
    if dictionary and conf < 0.85:
        text = utils.correct_word_dynamic(text, dictionary)

    # Update the dictionary with high-confidence words
    # (check "is not None" so the initially empty set can still grow)
    if dictionary is not None and conf >= 0.9 and text:
        dictionary.add(text.lower())

    return (text if conf >= 0.6 else ""), conf

def run_ocr(image, ort_session):
    craft = models.models["craft"]
    boxes = craft.predict(image)[0]["boxes"]

    draw_img = image.copy()
    draw = ImageDraw.Draw(draw_img)
    font = ImageFont.load_default()

    words = []
    for box in boxes:
        x, y, w, h = map(int, box)
        crop = utils.safe_crop(image, x, y, x + w, y + h)
        text, conf = predict_word(crop, ort_session, models.dynamic_dictionary)

        bbox = [x, y, x + w, y + h]
        words.append({"text": text, "confidence": conf, "bbox": bbox})

        # Color-code boxes by confidence
        color = "green" if conf > 0.9 else "orange" if conf > 0.75 else "red"
        draw.rectangle(bbox, outline=color, width=2)
        if text:
            draw.text((x, max(0, y - 12)), f"{text} {conf:.2f}", fill=color, font=font)

    lines = utils.order_words_reading_order(words)

    paragraph_lines = []
    for line in lines:
        texts = [w["text"] for w in line if w["text"]]
        if not texts:
            continue
        paragraph_lines.append({
            "text": " ".join(texts),
            "confidence": float(np.mean([w["confidence"] for w in line]))
        })

    paragraph = "\n".join(l["text"] for l in paragraph_lines)

    # -------- Encode annotated image --------
    buf = io.BytesIO()
    draw_img.save(buf, format="PNG")
    img_b64 = base64.b64encode(buf.getvalue()).decode()

    return {
        "paragraph": paragraph,
        "lines": paragraph_lines,
        "words": words,
        "image": f"data:image/png;base64,{img_b64}"
    }
backend/main/core/utils.py ADDED
@@ -0,0 +1,97 @@
import numpy as np
from PIL import Image
from rapidfuzz import process
from .config import alphabet, blank_idx, TARGET_HEIGHT


# -------------------- CTC Decoding --------------------
def ctc_greedy_decode_with_confidence(probs):
    pred = np.argmax(probs, axis=1)
    confs = np.max(probs, axis=1)

    chars, scores = [], []
    prev = None
    for p, c in zip(pred, confs):
        if p != blank_idx and p != prev:
            chars.append(alphabet[p - 1])
            scores.append(float(c))
        prev = p

    return "".join(chars), float(np.mean(scores)) if scores else 0.0


def ctc_beam_search_decode(probs, beam_width=10):
    # Simplified beam search: repeated characters are not collapsed as in
    # full CTC prefix decoding.
    T, C = probs.shape
    beams = [("", 1.0)]

    for t in range(T):
        new_beams = {}
        for prefix, score in beams:
            for c in range(C):
                p = probs[t, c]
                if p < 1e-4:  # prune negligible probabilities
                    continue
                new_prefix = prefix if c == blank_idx else prefix + alphabet[c - 1]
                new_beams[new_prefix] = max(
                    new_beams.get(new_prefix, 0.0),
                    score * float(p),
                )
        beams = sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width]

    return beams[0][0], float(beams[0][1])


# -------------------- Utils --------------------
def safe_crop(img, x1, y1, x2, y2):
    x1 = max(0, min(x1, img.width - 1))
    y1 = max(0, min(y1, img.height - 1))
    x2 = max(x1 + 1, min(x2, img.width))
    y2 = max(y1 + 1, min(y2, img.height))
    return img.crop((x1, y1, x2, y2))


def preprocess_crop(crop):
    crop = crop.convert("L")
    w, h = crop.size
    new_w = max(int(w * TARGET_HEIGHT / h), 32)
    crop = crop.resize((new_w, TARGET_HEIGHT), Image.BILINEAR)
    img = np.array(crop, dtype=np.float32) / 255.0
    img = (img - 0.5) / 0.5
    return img[np.newaxis, np.newaxis, :, :]


def correct_word_dynamic(word, dictionary, threshold=80):
    if not dictionary:
        return word
    # rapidfuzz's extractOne returns a (match, score, index) triple
    match, score, _ = process.extractOne(word.lower(), dictionary)
    if score >= threshold:
        return match
    return word


# -------------------- Reading Order --------------------
def order_words_reading_order(words):
    if not words:
        return []

    words = sorted(words, key=lambda w: w["bbox"][1])
    lines = [[words[0]]]

    for w in words[1:]:
        prev = lines[-1][-1]

        y1, y2 = w["bbox"][1], w["bbox"][3]
        py1, py2 = prev["bbox"][1], prev["bbox"][3]

        # Same line if vertical overlap exceeds 40% of the shorter box
        overlap = min(y2, py2) - max(y1, py1)
        min_h = min(y2 - y1, py2 - py1)

        if overlap / max(min_h, 1) > 0.4:
            lines[-1].append(w)
        else:
            lines.append([w])

    for line in lines:
        line.sort(key=lambda w: w["bbox"][0])

    return lines
backend/models/final.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3ecd6bd68ed3c31f879ebda713dfb9a25b0835cfdd277fd410e50cd1167300e
size 25477944
backend/models/model.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4544b05178a744791a205d8f583b8828b7a9127cb32b295dc22fa59e59580f0e
size 25454716
backend/notebook/ocr.ipynb ADDED
@@ -0,0 +1,952 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Optimized OCR Pipeline (CRNN + CTC)\n",
    "Efficient implementation for RTX 4060 (8GB) with 24GB RAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\OCR-MODEL\\myenv\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device: cuda\n",
      "GPU: NVIDIA GeForce RTX 4060 Laptop GPU\n"
     ]
    }
   ],
   "source": [
    "# Setup\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torchvision.transforms as T\n",
    "from PIL import Image\n",
    "from tqdm.auto import tqdm\n",
    "from datasets import load_dataset\n",
    "import Levenshtein\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Device: {device}\")\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    torch.backends.cudnn.benchmark = True\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Character encoding\n",
    "CHARS = \"0123456789abcdefghijklmnopqrstuvwxyz\"\n",
    "BLANK = 0\n",
    "char2idx = {c: i+1 for i, c in enumerate(CHARS)}\n",
    "idx2char = {i+1: c for i, c in enumerate(CHARS)}\n",
    "idx2char[BLANK] = \"\"\n",
    "VOCAB_SIZE = len(CHARS) + 1\n",
    "\n",
    "def encode(text):\n",
    "    return [char2idx[c] for c in text.lower() if c in char2idx]\n",
    "\n",
    "def decode(indices):\n",
    "    chars, last = [], None\n",
    "    for i in indices:\n",
    "        if i != BLANK and i != last:\n",
    "            chars.append(idx2char.get(i, ''))\n",
    "        last = i\n",
    "    return ''.join(chars)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Efficient Dataset\n",
    "class OCRDataset(Dataset):\n",
    "    def __init__(self, data, h=32, train=False):\n",
    "        self.data = data\n",
    "        self.h = h\n",
    "        self.train = train\n",
    "        self.normalize = T.Normalize(0.5, 0.5)\n",
    "        if train:\n",
    "            self.aug = T.Compose([\n",
    "                T.RandomRotation(2, fill=255),\n",
    "                T.ColorJitter(0.2, 0.2)\n",
    "            ])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        img = self.data[i]['image'].convert('L')\n",
    "        if self.train and self.aug:\n",
    "            img = self.aug(img)\n",
    "        w, h = img.size\n",
    "        img = img.resize((int(w * self.h / h), self.h), Image.BILINEAR)\n",
    "        img = self.normalize(T.ToTensor()(img))\n",
    "        return img, self.data[i]['label']\n",
    "\n",
    "def collate(batch):\n",
    "    imgs, texts = zip(*batch)\n",
    "    enc = [encode(t) for t in texts]\n",
    "    valid = [i for i, e in enumerate(enc) if e]\n",
    "    if not valid:\n",
    "        return None\n",
    "    imgs = [imgs[i] for i in valid]\n",
    "    enc = [enc[i] for i in valid]\n",
    "    texts = [texts[i] for i in valid]\n",
    "\n",
    "    max_w = max(img.shape[2] for img in imgs)\n",
    "    padded = torch.stack([nn.functional.pad(img, (0, max_w - img.shape[2])) for img in imgs])\n",
    "    targets = torch.IntTensor([c for seq in enc for c in seq])\n",
    "    lengths = torch.IntTensor([len(seq) for seq in enc])\n",
    "    return padded, targets, lengths, texts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimized CRNN Model\n",
    "class CRNN(nn.Module):\n",
    "    def __init__(self, vocab_size=VOCAB_SIZE, hidden=256):\n",
    "        super().__init__()\n",
    "        # Efficient CNN backbone\n",
    "        self.cnn = nn.Sequential(\n",
    "            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(True),\n",
    "            nn.MaxPool2d(2, 2),  # 16x\n",
    "\n",
    "            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True),\n",
    "            nn.MaxPool2d(2, 2),  # 8x\n",
    "\n",
    "            nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True),\n",
    "            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True),\n",
    "            nn.MaxPool2d((2, 1)),  # 4x\n",
    "\n",
    "            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True),\n",
    "            nn.MaxPool2d((2, 1)),  # 2x\n",
    "\n",
    "            nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(True)\n",
    "        )\n",
    "        self.rnn = nn.LSTM(512, hidden, 2, bidirectional=True, batch_first=False)\n",
    "        self.fc = nn.Linear(hidden * 2, vocab_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.cnn(x)                    # (B, 512, 1, W)\n",
    "        x = x.squeeze(2).permute(2, 0, 1)  # (W, B, 512)\n",
    "        x, _ = self.rnn(x)                 # (W, B, 512)\n",
    "        return self.fc(x)                  # (W, B, vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading dataset...\n",
      "Demo mode: 50000 samples\n",
      "Train batches: 391, Val batches: 79\n"
     ]
    }
   ],
   "source": [
    "# Load data efficiently\n",
    "DEMO = True  # Set False for full training\n",
    "SUBSET = 50000 if DEMO else None\n",
    "\n",
    "print(\"Loading dataset...\")\n",
    "ds = load_dataset(\"priyank-m/MJSynth_text_recognition\")\n",
    "\n",
    "if SUBSET:\n",
    "    print(f\"Demo mode: {SUBSET} samples\")\n",
    "    train_ds = OCRDataset(ds['train'].select(range(min(SUBSET, len(ds['train'])))), train=True)\n",
    "    val_ds = OCRDataset(ds['val'].select(range(min(SUBSET//5, len(ds['val'])))))\n",
    "else:\n",
    "    train_ds = OCRDataset(ds['train'], train=True)\n",
    "    val_ds = OCRDataset(ds['val'])\n",
    "\n",
    "# Optimal batch size for 8GB GPU\n",
    "BS = 128\n",
    "train_loader = DataLoader(train_ds, BS, shuffle=True, collate_fn=collate, num_workers=0, pin_memory=True)\n",
    "val_loader = DataLoader(val_ds, BS, shuffle=False, collate_fn=collate, num_workers=0, pin_memory=True)\n",
    "\n",
    "print(f\"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training setup\n",
    "model = CRNN().to(device)\n",
    "criterion = nn.CTCLoss(blank=BLANK, zero_infinity=True)\n",
    "optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)\n",
    "scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,\n",
    "                                          steps_per_epoch=len(train_loader),\n",
    "                                          epochs=30 if DEMO else 50)\n",
    "scaler = torch.amp.GradScaler('cuda')\n",
    "\n",
    "def train_epoch(model, loader, opt, crit, sched):\n",
    "    model.train()\n",
    "    total = 0\n",
    "    for batch in tqdm(loader, leave=False):\n",
    "        if batch is None:\n",
    "            continue\n",
    "        imgs, targets, t_lens, _ = batch\n",
    "        imgs = imgs.to(device)\n",
    "\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.amp.autocast('cuda'):\n",
    "            preds = model(imgs)\n",
    "            i_lens = torch.full((imgs.size(0),), preds.size(0), dtype=torch.long)\n",
    "            loss = crit(nn.functional.log_softmax(preds, 2), targets, i_lens, t_lens)\n",
    "\n",
    "        if not torch.isnan(loss):\n",
    "            scaler.scale(loss).backward()\n",
    "            scaler.unscale_(opt)\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)\n",
    "            scaler.step(opt)\n",
    "            scaler.update()\n",
    "            sched.step()\n",
    "            total += loss.item()\n",
    "    return total / len(loader)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate(model, loader):\n",
    "    model.eval()\n",
    "    cer, wer, chars, words = 0, 0, 0, 1\n",
    "    for batch in tqdm(loader, leave=False):\n",
    "        if batch is None:\n",
    "            continue\n",
    "        imgs, _, _, texts = batch\n",
    "        preds = model(imgs.to(device)).argmax(2).T.cpu()  # (B, W)\n",
    "\n",
    "        for pred, true in zip(preds, texts):\n",
    "            pred_str = decode(pred.tolist())\n",
    "            cer += Levenshtein.distance(pred_str, true)\n",
    "            wer += pred_str != true\n",
    "            chars += len(true)\n",
    "            words += 1\n",
    "    return cer/chars, wer/words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%| | 0/391 [00:00<?, ?it/s]d:\\OCR-MODEL\\myenv\\lib\\site-packages\\torch\\optim\\lr_scheduler.py:224: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate\n",
      "  warnings.warn(\n",
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30 | Loss: 5.8365 | CER: 1.0000 | WER: 0.9999\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2/30 | Loss: 3.0576 | CER: 0.9601 | WER: 0.9999\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3/30 | Loss: 2.0099 | CER: 0.5324 | WER: 0.8891\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4/30 | Loss: 0.6543 | CER: 0.4724 | WER: 0.8187\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5/30 | Loss: 0.4517 | CER: 0.4502 | WER: 0.7940\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6/30 | Loss: 0.3696 | CER: 0.4515 | WER: 0.7932\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7/30 | Loss: 0.3155 | CER: 0.4296 | WER: 0.7697\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8/30 | Loss: 0.2886 | CER: 0.4299 | WER: 0.7703\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9/30 | Loss: 0.2423 | CER: 0.4226 | WER: 0.7601\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10/30 | Loss: 0.2049 | CER: 0.4224 | WER: 0.7588\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11/30 | Loss: 0.1812 | CER: 0.4172 | WER: 0.7523\n",
      " ✓ Best model saved\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " \r"
     ]
    },
440
+ {
441
+ "name": "stdout",
442
+ "output_type": "stream",
443
+ "text": [
444
+ "Epoch 12/30 | Loss: 0.1600 | CER: 0.4197 | WER: 0.7504\n"
445
+ ]
446
+ },
447
+ {
448
+ "name": "stderr",
449
+ "output_type": "stream",
450
+ "text": [
451
+ " \r"
452
+ ]
453
+ },
454
+ {
455
+ "name": "stdout",
456
+ "output_type": "stream",
457
+ "text": [
458
+ "Epoch 13/30 | Loss: 0.1414 | CER: 0.4154 | WER: 0.7454\n",
459
+ " ✓ Best model saved\n"
460
+ ]
461
+ },
462
+ {
463
+ "name": "stderr",
464
+ "output_type": "stream",
465
+ "text": [
466
+ " \r"
467
+ ]
468
+ },
469
+ {
470
+ "name": "stdout",
471
+ "output_type": "stream",
472
+ "text": [
473
+ "Epoch 14/30 | Loss: 0.1234 | CER: 0.4139 | WER: 0.7449\n",
474
+ " ✓ Best model saved\n"
475
+ ]
476
+ },
477
+ {
478
+ "name": "stderr",
479
+ "output_type": "stream",
480
+ "text": [
481
+ " \r"
482
+ ]
483
+ },
484
+ {
485
+ "name": "stdout",
486
+ "output_type": "stream",
487
+ "text": [
488
+ "Epoch 15/30 | Loss: 0.1082 | CER: 0.4121 | WER: 0.7411\n",
489
+ " ✓ Best model saved\n"
490
+ ]
491
+ },
492
+ {
493
+ "name": "stderr",
494
+ "output_type": "stream",
495
+ "text": [
496
+ " \r"
497
+ ]
498
+ },
499
+ {
500
+ "name": "stdout",
501
+ "output_type": "stream",
502
+ "text": [
503
+ "Epoch 16/30 | Loss: 0.0928 | CER: 0.4104 | WER: 0.7389\n",
504
+ " ✓ Best model saved\n"
505
+ ]
506
+ },
507
+ {
508
+ "name": "stderr",
509
+ "output_type": "stream",
510
+ "text": [
511
+ " \r"
512
+ ]
513
+ },
514
+ {
515
+ "name": "stdout",
516
+ "output_type": "stream",
517
+ "text": [
518
+ "Epoch 17/30 | Loss: 0.0807 | CER: 0.4114 | WER: 0.7395\n"
519
+ ]
520
+ },
521
+ {
522
+ "name": "stderr",
523
+ "output_type": "stream",
524
+ "text": [
525
+ " \r"
526
+ ]
527
+ },
528
+ {
529
+ "name": "stdout",
530
+ "output_type": "stream",
531
+ "text": [
532
+ "Epoch 18/30 | Loss: 0.0672 | CER: 0.4092 | WER: 0.7384\n",
533
+ " ✓ Best model saved\n"
534
+ ]
535
+ },
536
+ {
537
+ "name": "stderr",
538
+ "output_type": "stream",
539
+ "text": [
540
+ " \r"
541
+ ]
542
+ },
543
+ {
544
+ "name": "stdout",
545
+ "output_type": "stream",
546
+ "text": [
547
+ "Epoch 19/30 | Loss: 0.0549 | CER: 0.4097 | WER: 0.7363\n"
548
+ ]
549
+ },
550
+ {
551
+ "name": "stderr",
552
+ "output_type": "stream",
553
+ "text": [
554
+ " \r"
555
+ ]
556
+ },
557
+ {
558
+ "name": "stdout",
559
+ "output_type": "stream",
560
+ "text": [
561
+ "Epoch 20/30 | Loss: 0.0466 | CER: 0.4091 | WER: 0.7371\n",
562
+ " ✓ Best model saved\n"
563
+ ]
564
+ },
565
+ {
566
+ "name": "stderr",
567
+ "output_type": "stream",
568
+ "text": [
569
+ " \r"
570
+ ]
571
+ },
572
+ {
573
+ "name": "stdout",
574
+ "output_type": "stream",
575
+ "text": [
576
+ "Epoch 21/30 | Loss: 0.0389 | CER: 0.4074 | WER: 0.7356\n",
577
+ " ✓ Best model saved\n"
578
+ ]
579
+ },
580
+ {
581
+ "name": "stderr",
582
+ "output_type": "stream",
583
+ "text": [
584
+ " \r"
585
+ ]
586
+ },
587
+ {
588
+ "name": "stdout",
589
+ "output_type": "stream",
590
+ "text": [
591
+ "Epoch 22/30 | Loss: 0.0311 | CER: 0.4072 | WER: 0.7332\n",
592
+ " ✓ Best model saved\n"
593
+ ]
594
+ },
595
+ {
596
+ "name": "stderr",
597
+ "output_type": "stream",
598
+ "text": [
599
+ " \r"
600
+ ]
601
+ },
602
+ {
603
+ "name": "stdout",
604
+ "output_type": "stream",
605
+ "text": [
606
+ "Epoch 23/30 | Loss: 0.0272 | CER: 0.4073 | WER: 0.7348\n"
607
+ ]
608
+ },
609
+ {
610
+ "name": "stderr",
611
+ "output_type": "stream",
612
+ "text": [
613
+ " \r"
614
+ ]
615
+ },
616
+ {
617
+ "name": "stdout",
618
+ "output_type": "stream",
619
+ "text": [
620
+ "Epoch 24/30 | Loss: 0.0207 | CER: 0.4063 | WER: 0.7328\n",
621
+ " ✓ Best model saved\n"
622
+ ]
623
+ },
624
+ {
625
+ "name": "stderr",
626
+ "output_type": "stream",
627
+ "text": [
628
+ " \r"
629
+ ]
630
+ },
631
+ {
632
+ "name": "stdout",
633
+ "output_type": "stream",
634
+ "text": [
635
+ "Epoch 25/30 | Loss: 0.0179 | CER: 0.4055 | WER: 0.7300\n",
636
+ " ✓ Best model saved\n"
637
+ ]
638
+ },
639
+ {
640
+ "name": "stderr",
641
+ "output_type": "stream",
642
+ "text": [
643
+ " \r"
644
+ ]
645
+ },
646
+ {
647
+ "name": "stdout",
648
+ "output_type": "stream",
649
+ "text": [
650
+ "Epoch 26/30 | Loss: 0.0148 | CER: 0.4052 | WER: 0.7290\n",
651
+ " ✓ Best model saved\n"
652
+ ]
653
+ },
654
+ {
655
+ "name": "stderr",
656
+ "output_type": "stream",
657
+ "text": [
658
+ " \r"
659
+ ]
660
+ },
661
+ {
662
+ "name": "stdout",
663
+ "output_type": "stream",
664
+ "text": [
665
+ "Epoch 27/30 | Loss: 0.0140 | CER: 0.4050 | WER: 0.7298\n",
666
+ " ✓ Best model saved\n"
667
+ ]
668
+ },
669
+ {
670
+ "name": "stderr",
671
+ "output_type": "stream",
672
+ "text": [
673
+ " \r"
674
+ ]
675
+ },
676
+ {
677
+ "name": "stdout",
678
+ "output_type": "stream",
679
+ "text": [
680
+ "Epoch 28/30 | Loss: 0.0125 | CER: 0.4050 | WER: 0.7303\n"
681
+ ]
682
+ },
683
+ {
684
+ "name": "stderr",
685
+ "output_type": "stream",
686
+ "text": [
687
+ " \r"
688
+ ]
689
+ },
690
+ {
691
+ "name": "stdout",
692
+ "output_type": "stream",
693
+ "text": [
694
+ "Epoch 29/30 | Loss: 0.0122 | CER: 0.4046 | WER: 0.7296\n",
695
+ " ✓ Best model saved\n"
696
+ ]
697
+ },
698
+ {
699
+ "name": "stderr",
700
+ "output_type": "stream",
701
+ "text": [
702
+ " "
703
+ ]
704
+ },
705
+ {
706
+ "name": "stdout",
707
+ "output_type": "stream",
708
+ "text": [
709
+ "Epoch 30/30 | Loss: 0.0122 | CER: 0.4047 | WER: 0.7292\n",
710
+ "\n",
711
+ "Training complete! Best CER: 0.4046\n"
712
+ ]
713
+ },
714
+ {
715
+ "name": "stderr",
716
+ "output_type": "stream",
717
+ "text": [
718
+ "\r"
719
+ ]
720
+ }
721
+ ],
722
+ "source": [
723
+ "# Train\n",
724
+ "EPOCHS = 30 if DEMO else 50\n",
725
+ "best_cer = float('inf')\n",
726
+ "os.makedirs('checkpoints', exist_ok=True)\n",
727
+ "\n",
728
+ "for ep in range(EPOCHS):\n",
729
+ " loss = train_epoch(model, train_loader, optimizer, criterion, scheduler)\n",
730
+ " cer, wer = evaluate(model, val_loader)\n",
731
+ " \n",
732
+ " print(f\"Epoch {ep+1}/{EPOCHS} | Loss: {loss:.4f} | CER: {cer:.4f} | WER: {wer:.4f}\")\n",
733
+ " \n",
734
+ " if cer < best_cer:\n",
735
+ " best_cer = cer\n",
736
+ " torch.save(model.state_dict(), 'checkpoints/best.pth')\n",
737
+ " print(\" ✓ Best model saved\")\n",
738
+ "\n",
739
+ "torch.save(model.state_dict(), 'checkpoints/final.pth')\n",
740
+ "print(f\"\\nTraining complete! Best CER: {best_cer:.4f}\")"
741
+ ]
742
+ },
743
+ {
744
+ "cell_type": "code",
745
+ "execution_count": 16,
746
+ "metadata": {},
747
+ "outputs": [
748
+ {
749
+ "ename": "NameError",
750
+ "evalue": "name 'model' is not defined",
751
+ "output_type": "error",
752
+ "traceback": [
753
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
754
+ "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
755
+ "Cell \u001b[1;32mIn[16], line 4\u001b[0m\n\u001b[0;32m 2\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mexport\u001b[39m\u001b[38;5;124m'\u001b[39m, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m 3\u001b[0m dummy \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandn(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m32\u001b[39m, \u001b[38;5;241m128\u001b[39m)\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m----> 4\u001b[0m torch\u001b[38;5;241m.\u001b[39monnx\u001b[38;5;241m.\u001b[39mexport(\u001b[43mmodel\u001b[49m, dummy, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mexport/model.onnx\u001b[39m\u001b[38;5;124m'\u001b[39m, \n\u001b[0;32m 5\u001b[0m input_names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput\u001b[39m\u001b[38;5;124m'\u001b[39m], output_names\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[0;32m 6\u001b[0m dynamic_axes\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m'\u001b[39m\u001b[38;5;124minput\u001b[39m\u001b[38;5;124m'\u001b[39m: {\u001b[38;5;241m0\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m3\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mwidth\u001b[39m\u001b[38;5;124m'\u001b[39m}, \n\u001b[0;32m 7\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124moutput\u001b[39m\u001b[38;5;124m'\u001b[39m: {\u001b[38;5;241m0\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mseq\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m1\u001b[39m: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m'\u001b[39m}},\n\u001b[0;32m 8\u001b[0m opset_version\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m14\u001b[39m)\n\u001b[0;32m 9\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel exported to 
export/model.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
756
+ "\u001b[1;31mNameError\u001b[0m: name 'model' is not defined"
757
+ ]
758
+ }
759
+ ],
760
+ "source": [
761
+ "# Export to ONNX\n",
762
+ "os.makedirs('export', exist_ok=True)\n",
763
+ "dummy = torch.randn(1, 1, 32, 128).to(device)\n",
764
+ "torch.onnx.export(model, dummy, 'export/model.onnx', \n",
765
+ " input_names=['input'], output_names=['output'],\n",
766
+ " dynamic_axes={'input': {0: 'batch', 3: 'width'}, \n",
767
+ " 'output': {0: 'seq', 1: 'batch'}},\n",
768
+ " opset_version=14)\n",
769
+ "print(\"Model exported to export/model.onnx\")"
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": 15,
775
+ "metadata": {},
776
+ "outputs": [
777
+ {
778
+ "name": "stderr",
779
+ "output_type": "stream",
780
+ "text": [
781
+ "C:\\Users\\Jagadeesh\\AppData\\Local\\Temp\\ipykernel_15540\\4199438618.py:7: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
782
+ " model.load_state_dict(torch.load(model_path, map_location=device))\n"
783
+ ]
784
+ },
785
+ {
786
+ "data": {
787
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfoAAADECAYAAAB3EuMgAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAFGhJREFUeJzt3XuQT/X/wPH3Wpa1uy7JrvvaYiOXWHe55tIUGpVJUalJEQlNU1ETpfmaDBEp1SC0Jf6Q6OISFQpjXApZYt2HXewiu5bl85vX+fXZ2c/n/V6fY/eza/e9z8eMWee1788553POZz+v8z7ndd4nxOPxeBQAALBSmVu9AgAAoPCQ6AEAsBiJHgAAi5HoAQCwGIkeAACLkegBALAYiR4AAIuR6AEAsBiJHgAAi5HogSCqX7++euaZZ3Kmf/nlFxUSEuL8DBaZ38SJE1Vx0a1bN+ef1+HDh511/OKLL4K2jMLYjiVpGwMFQaKHNSSxyBe091+FChVUfHy8eumll9Tp06dVSfLDDz+QaAL46quv1IwZM4pmhwAlWNlbvQJAsL377rsqLi5OXb58WW3cuFF98sknTuLcvXu3qlixYpFu8C5duqjMzEwVFhZ2U6+T9Z09e7Yx2cv8ypYtvn+6sbGxzjqWK1euULejJHrZp2PGjAnacgAbFd9vCyCfHnjgAdW6dWvn/0OHDlXVqlVTH3zwgVq+fLl64oknjK+5dOmSioiICPo2L1OmjHNmIZiCPb9g855NKe7bESgtOHUP6913333Oz+TkZOenXEOPjIxUBw8eVA8++KCKiopSgwcPdn53/fp153RwkyZNnMQSExOjhg0bptLS0nzmKQ99fO+991SdOnWcswTdu3dXe/bscX1tecuWLc6yq1at6hxgNG/eXH344Yc56ye9eZH7UsSNrh/v2LHDOcCpVKmS89569OihNm/ebLy0sWnTJvXKK6+o6tWrO8t++OGHVWpqqk/b8+fPq3379jk/b5bpGr13mx89elT17dvX+X/t2rVz3udff/3l7CdZHzkjIL31G21HqQn4/vvv1ZEjR3K2j9RHeGVlZakJEyaoBg0aqPLly6u6deuq1157zYnnJtNjx451toV8Dh566CF1/Pjxm37PQHFGjx7Wk4QupGfvlZ2dre6//37VqVMnNXXq1JxT+pLUJUE9++yz6uWXX3YODj766CMnkUqC9J6Ofvvtt51EL8la/m3fvl317t1bXblyJeD6rFmzxkl2NWvWVKNHj1Y1atRQf//9t1q5cqUzLetw8uRJp92iRYsCzk8OMDp37uwkeUlmso6ffvqpkwx//fVX1a5dO5/2o0aNcg4wJBFKUpYDG6lj+Oabb3LaLFu2zNkG8+fP9ykuLIhr1645ByNyGn7KlCkqMTHRWa4k9zfffNM52HrkkUfUnDlz1NNPP606dOjgXIIxkfZyECJJefr06U5MDh68B2uSsOWyzQsvvKAaN27sHEhIu/3796tvv/02Zz5yxufLL79UgwYNUh07dlTr1q1Tffr0Ccr7BYoNeR49YIP58+d75CO9du1aT2pqqufYsWOexYsXe6pVq+YJDw/3HD9+3Gk3ZMgQp90bb7zh8/oNGzY48cTERJ/4Tz/95BNPSUnxhIWFefr06eO5fv16Trvx48c77WT+XuvXr3di8lNkZ2d74uLiPLGxsZ60tDSf5eSe18iRI53XmUh8woQJOdP9+/d31ufgwYM5sZMnT3qioqI8Xbp00bZPz549fZY1duxYT2hoqCc9PV1rKz8D6dq1q/PPKzk5WXutd5v/73//y4nJ+5f9EhIS4uwnr3379mnv0X87Ctn+sh39LVq0yFOmTBlnf+Y2Z84cZx6bNm1ypnfu3OlMjxgxwqfdoEGDtOUDJRmn7mGdnj17Oqdi5XTt448/7vT0pIcqp4pze/HFF32mly5dqipXrqx69eqlzpw5k/OvVatWzjzWr1/vtFu7dq3Tc5eece5T
6m6KwuTMgJwlkLZVqlTx+V3ued1ML3n16tWqf//+6o477siJy9kC6aVKr/bChQs+r5Febu5lydkAmY+cBveSXrwcUwSrN5+7B+0l7/+uu+5yevSPPfZYTlxi8rtDhw7laxmyH6UX36hRI5/96L2E492PUvAo5MxNbhT3wTacuod15Lqv3FYnlelyjV0ShxRz5Sa/k+vruR04cMA5HRwdHW2cb0pKivPTmxAbNmzo83s5uJBT4m4uIzRt2lQFg1xbz8jIcN6jP0l2chr72LFjTs2BV7169XzaedfZvw4h2KTmQbZRbnJgJfvB/yBH4vldH9mPcinEf1mm/SifizvvvNPn96ZtCZRkJHpYp23btjlV93mRAi3/5C9JUZK8XDs2yStxlDShoaHG+P9fFSj65QZ7fWQ/NmvWzLnTwkTO9AClCYke+I/07OS0/L333qvCw8Pz3C5SFe7tOeY+XS6960C9UG/vUe7/lksMeXF7Gl8OPqSQMCkpSfudVM3LwYytiS2vbSTbeNeuXc6dBzfajrIf5aBAzrLk7sWbtiVQknGNHviPXCeWa9WTJk3StolU6aenpzv/lwQtle2zZs3y6XW6GaUtISHBqSSXtt75eeWel/eefv82pt6wVPvLGAFSQe8lIwHKLWpyV4FU49+sgtxeV1RkG5nWT/bjiRMn1Oeff679TgbdkTEThNwBIGbOnOnThtH2YBt69MB/unbt6tzaNnnyZLVz504ngUpCl567FHjJfe4DBgxwetGvvvqq005uk5Pb66TI7scff1S33377Dben9LBlpL5+/fqpFi1aOLewSeGcJFW5TW7VqlVOOykA9BaKyW2AktClsNBEbvOTW/EkqY8YMcKpP5Db6+QecbmNLT8K4/a6YJNtJLcEypgAbdq0cQomZbs+9dRTasmSJWr48OFO4Z2coZEDONnGEpdtLJd2ZPvLAEoff/yxc8Agt9f9/PPP6p9//rnVbw0IKhI9kIvcwy0JRBLl+PHjnaQpA7E8+eSTTsLInVyluEzaSzKRe9Wl+t3NPdiSuOU177zzjpo2bZpz+lhONz///PM5beR+cqnqX7x4sXOft/T280r0Umi3YcMGNW7cOOfgQ+Yn6yOv87+H3iZyUCMHZHIwIvfIy6l4SfRyMCX3ykts4cKFzkGLXN6QyywyToEUanrNmzfPOXCTugx5jVTmy0A8tl7uQOkUIvfY3eqVAAAAhYNr9AAAWIxEDwCAxUj0AABYjEQPAIDFSPQAAFiMRA8AgMVI9AAAWIxEDwCAxUj0AABYjEQPAIDFSPQAAFiMRA8AgMVI9AAAWIxEDwCAxUj0AABYjEQPAIDFSPQAAFiMRA8AgMVI9AAAWIxEDwCAxUj0AABYjEQPAIDFSPQAAFiMRA8AgMVI9AAAWIxEDwCAxUj0AABYjEQPAIDFSPQAAFiMRA8AgMVI9AAAWIxEDwCAxUj0AABYjEQPAIDFSPQAAFiMRA8AgMXKqmLs6tWrWmzlypVa7NSpUwHn1blzZy3WtGlTVVxlZWX5TK9YsUJrk5qaqsVCQkK0WLdu3bRYo0aNtNi5c+d8pr/77jutTWZmpsqv6tWr+0z369dPa1O+fHkVTEX9nkzvq1y5clqb1atXa7Hk5GRXy0xISPCZbteundZmy5YtWmz79u3KNrVq1dJiffv2dfVdsnz58ht+Voq70NBQLda7d2+f6fr167ua17p167RYUlKSKkxly+rpp1KlSlqsbt26PtNxcXFam5iYGC1Wpgz9WC+2BAAAFiPRAwBgMRI9AAAWI9EDAGCxYl2M51+QJqZOnarFfv/994AFabNnzy5RxXiXLl3ymX7//fe1Ntu2bXNVoDN37lxXxXgnT570mX799de1NikpKSq/YmNjfaabN2+utYmPj1fB5P/ZGDlypNYmIyMj3/Nv3bq1FuvRo4fPdEREhNbms88+02LLli1ztcy33nrLZ7pt27YBC83E5MmTlW26d++uxXr16qXF/v33
Xy323nvv+Uzv3r1blSSmwtWlS5cGLMbzeDxabMGCBVps4cKFKlhM38luC/QqVqx4w+8R8eijj2qxoUOHuireLA3o0QMAYDESPQAAFiPRAwBgsWJ9jR4lg+n6m+k6oP/ARrt27QrqNXrTMv2v0Zuux7u9fmiaf3FlqsEwDSTjlv+gQn/88YfWxrRtTYOWtGnTxtXAQ240a9bMVZ1Kfpk+G61atdJiderUUUXNNBBTjRo1inw96tWr5zPdpEkTV387phqsY8eOabEjR474TP/5559amz179mixzZs3u6rVijMMwGMbevQAAFiMRA8AgMVI9AAAWIxEDwCAxSjGQ4FFR0e7Ksy6ePGiz/TGjRtdDXzh9ilUpkFRTEVjbgrBrly5osXS09NVSTF48GAtNnDgwHzPz79IyjQozeHDh10VjPkP+JPX/NwwfTZMy8wv0wAuY8eO1WIDBgxQxYFpfQub/xPzZs2a5aqo0fREwTNnzgQcBGjKlCmuXrdq1SpXg0bNnDlTi1WoUEHZhB49AAAWI9EDAGAxEj0AABYj0QMAYDGK8VBgdevWdVXM5j+i1datW7U258+f12JVq1Z1tR6mYrCkpKSAr2vRooWr15WkYjzT6HAFGTEuLCysgGt043mZnsRWXJmK/YK5fUoa/4JI0740FeOZ2kVGRmqxMWPGBCzimzhxohYztVtueKrjsGHDXI1+WJLRowcAwGIkegAALEaiBwDAYiR6AAAsRjEeCiwqKkqLxcbGBizGO3jwoNYmOTk538V427dvDzhilqmQqn379lps//79rpYJoHD5/82aRnicN2+eFjN9v5wxjKDn/yhrQTEeAAAoMTh1DwCAxUj0AABYjEQPAIDFKMZDoejUqZMWS0xM9Jk+d+6cq4K6hIQELXbt2jUtZnrsrf/oWDExMVqbli1barEFCxZoMSCv0RtTUlIKdeNUqVJFi5XW0fhq166txeLj410V412/fl2L7d27V4t5PJ6AI/uVJPToAQCwGIkeAACLkegBALAYiR4AAItRjIcC8y9cEffcc48Wi46O9pk+ceKEq4K6IUOGuHpkrKmQz1/Dhg21WL169VwV+6H0yc7O1mKTJk3SYjNmzAjaMsPDw7XYzJkztViHDh1UaWR6vK2pyNats2fPBtzvphE1SxJ69AAAWIxEDwCAxUj0AABYjGv0KDDTIBSmQS2aNGkS8Br9zp07tVhaWpoWO3LkiKsn3/lr166dq6fvAXnVn5w+fVqLmQZ/CuY1+szMTHbIDRRk8KBsQx2Gab+XZPToAQCwGIkeAACLkegBALAYiR4AAItRjIdCERkZqcU6duzoM71mzRqtzdGjR109hWr37t2unipWoUKFgE/VMz2ZyrZiHORPaGioFhs3bpwW69KlS6Eus1mzZkGbf0ln+tvMyMjI9/wq+H1H5LUPSjJ69AAAWIxEDwCAxUj0AABYjEQPAIDFKMZDoYyMZ+JfjBcREaG1uXDhghbbtm2bFtuxY4er9ahRo0bAp+rxpDrkpUwZvS909913a7Hu3buzEYuIqfDONMqmW7UNo3ia9ntJZte7AQAAPkj0AABYjEQPAIDFSPQAAFiMYjwUmcaNG/tM16tXT2uzd+9eLbZu3TpXI+iZ+Bff1axZU2tz/PhxV/MCcOsdOnRIiyUlJbl6bbly5bSYqUDXNFpmSUaPHgAAi5HoAQCwGIkeAACLkegBALAYxXgoMjExMT7TCQkJrorxfvvtNy2WlZXlajSrzp07B3wkJYDi6+LFiz7Tc+bM0dqcOnXK1bxiY2O1WPv27ZXt6NEDAGAxEj0AABYj0QMAYLFSc40+LS0t34OuBJPpOnJ0dLQqDfwHq+jSpYvW5uuvv9Zi586dczX/KlWqBHxiHnAzTE9EPHPmTLH4LjGpXLnyDaeL4hq6aUAb0/fe1atXtdjhw4e1WGJios/0kiVLXO2n8uXLa7HnnntOi8XFxSnb0aMHAMBiJHoAACxGogcAwGIkegAA
LGZlMZ7H49Fi06dP12Lz5s1TRS0qKkqLzZ07V4vVr19f2a5NmzZa7LbbbtNiqamprubXoEEDLdawYcN8fV5MMZQ+piKviRMnarFp06ap4mD06NE+06NGjSr0Za5YscJnetOmTa5el52drcUuXLigxS5duhTwb9NUdDh06FAtNnz4cC0WGhqqbEePHgAAi5HoAQCwGIkeAACLkegBALBYsS7GM42mZCpmM42I5qbw4+zZs6qomdbj2rVrWiwkJCRf77tsWX2XhoWFuVo3/6KUSpUqaW2uXLmixSIjIwOuv4lpRKqWLVtqsa1btyo3/J9UJ6pWrZqvz5npvWdkZGgx035x894jIiLy9TkW4eHhqqj5byNT8ZNp/U2jk/mPkFgUTPvEfx+73f6mv4Fb8V1iek+mz6gbFStW1GJut4ebgjq362/6e/J/4pzpO2LgwIFarGfPnlqsQil9eiU9egAALEaiBwDAYiR6AAAsRqIHAMBiIZ5iPASYaVSq/fv3Bxw5qTgzjcJkGtHNv+DK9L7dFt6Yit5MI9BdvnzZZzopKclVMaGpsCw+Pj5gQZfpo2d6xGV6erpyo06dOlosJiYm4OuysrK0mOm9mx6raSpi8n/vpmK/grzPmjVr+kzXqlVLFTb/AjTT9jEVqZkKrkyjFZqKsILJtO8OHDjgM52ZmalKOv/Pgv9nJS+mx8MWdoGh6bNhKparVq1awAJbtwXHpRU9egAALEaiBwDAYiR6AAAsRqIHAMBixboYDwAAFAw9egAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAMBiJHoAACxGogcAwGIkegAALEaiBwDAYiR6AAAsRqIHAEDZ6/8AeD4DQXfr8wcAAAAASUVORK5CYII=",
788
+ "text/plain": [
789
+ "<Figure size 1000x200 with 1 Axes>"
790
+ ]
791
+ },
792
+ "metadata": {},
793
+ "output_type": "display_data"
794
+ },
795
+ {
796
+ "data": {
797
+ "text/plain": [
798
+ "'limited'"
799
+ ]
800
+ },
801
+ "execution_count": 15,
802
+ "metadata": {},
803
+ "output_type": "execute_result"
804
+ }
805
+ ],
806
+ "source": [
807
+ "# Inference function\n",
808
+ "import matplotlib.pyplot as plt\n",
809
+ "\n",
810
+ "@torch.no_grad()\n",
811
+ "def predict(img_path, model_path='checkpoints/best.pth'):\n",
812
+ " model = CRNN().to(device)\n",
813
+ " model.load_state_dict(torch.load(model_path, map_location=device))\n",
814
+ " model.eval()\n",
815
+ " \n",
816
+ " img = Image.open(img_path).convert('L')\n",
817
+ " w, h = img.size\n",
818
+ " img = img.resize((int(w * 32 / h), 32), Image.BILINEAR)\n",
819
+ " \n",
820
+ " tensor = T.Normalize(0.5, 0.5)(T.ToTensor()(img)).unsqueeze(0).to(device)\n",
821
+ " pred = model(tensor).argmax(2).squeeze(1).cpu()\n",
822
+ " text = decode(pred.tolist())\n",
823
+ " \n",
824
+ " plt.figure(figsize=(10, 2))\n",
825
+ " plt.imshow(img, cmap='gray')\n",
826
+ " plt.title(f\"Prediction: {text}\")\n",
827
+ " plt.axis('off')\n",
828
+ " plt.show()\n",
829
+ " \n",
830
+ " return text\n",
831
+ "\n",
832
+ "# Example usage:\n",
833
+ "predict('image copy 3.png')"
834
+ ]
835
+ },
836
+ {
837
+ "cell_type": "code",
838
+ "execution_count": 25,
839
+ "metadata": {},
840
+ "outputs": [],
841
+ "source": [
842
+ "import onnxruntime as ort\n",
843
+ "import numpy as np\n",
844
+ "from PIL import Image\n",
845
+ "import torchvision.transforms as T\n",
846
+ "import matplotlib.pyplot as plt\n",
847
+ "\n",
848
+ "# Define your alphabet exactly as in training\n",
849
+ "alphabet = \"0123456789abcdefghijklmnopqrstuvwxyz\" # replace with your chars\n",
850
+ "blank_idx = 0 # Usually 0 in CTC training\n",
851
+ "\n",
852
+ "def ctc_greedy_decode(pred_indices, blank=blank_idx):\n",
853
+ " \"\"\"Convert CTC output indices to string.\"\"\"\n",
854
+ " prev_idx = None\n",
855
+ " result = []\n",
856
+ "\n",
857
+ " for idx in pred_indices:\n",
858
+ " if idx != blank and idx != prev_idx:\n",
859
+ " result.append(idx)\n",
860
+ " prev_idx = idx\n",
861
+ "\n",
862
+ " text = ''.join([alphabet[i - 1] for i in result]) # subtract 1 if alphabet doesn't include blank\n",
863
+ " return text\n",
864
+ " \n",
865
+ "def predict_onnx(img_path, onnx_path='crnn.onnx'):\n",
866
+ " # Load ONNX model\n",
867
+ " ort_session = ort.InferenceSession(onnx_path)\n",
868
+ "\n",
869
+ " # Load and preprocess image\n",
870
+ " img = Image.open(img_path).convert('L')\n",
871
+ " w, h = img.size\n",
872
+ " new_w = max(int(w * 32 / h), 32)\n",
873
+ " img = img.resize((new_w, 32), Image.BILINEAR)\n",
874
+ "\n",
875
+ " tensor = T.ToTensor()(img)\n",
876
+ " tensor = T.Normalize((0.5,), (0.5,))(tensor)\n",
877
+ " tensor = tensor.unsqueeze(0).numpy() # batch dimension\n",
878
+ "\n",
879
+ " # Run inference\n",
880
+ " ort_inputs = {ort_session.get_inputs()[0].name: tensor}\n",
881
+ " preds = ort_session.run(None, ort_inputs)[0] # shape: (seq_len, batch, num_classes)\n",
882
+ "\n",
883
+ " # Remove batch dimension\n",
884
+ " preds = preds[:, 0, :] # shape: (seq_len, num_classes)\n",
885
+ "\n",
886
+ " # Greedy decode\n",
887
+ " pred_indices = np.argmax(preds, axis=1) # shape: (seq_len,)\n",
888
+ " text = ctc_greedy_decode(pred_indices.tolist()) # collapse repeats & remove blank\n",
889
+ "\n",
890
+ " # Plot\n",
891
+ " plt.figure(figsize=(10, 2))\n",
892
+ " plt.imshow(img, cmap='gray')\n",
893
+ " plt.title(f\"Prediction: {text}\")\n",
894
+ " plt.axis('off')\n",
895
+ " plt.show()\n",
896
+ "\n",
897
+ " return text\n"
898
+ ]
899
+ },
900
+ {
901
+ "cell_type": "code",
902
+ "execution_count": 29,
903
+ "metadata": {},
904
+ "outputs": [
905
+ {
906
+ "data": {
907
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAACVCAYAAADfTozCAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJKBJREFUeJzt3Qe0FEXWwPFGMgoSnhJEsgSRrBJUUKIEERAQEyqioAIr4Lquu6iwunhAd1fXCKuCCiJIEhGQLKjkjIAkMeAKSjAgAvq+c3vP9He76BrmvVcvjPx/53CoeV3T09Pd09M1detWrtTU1FQPAAAAABw6w+XKAAAAAICGBgAAAIBMQY8GAAAAAOdoaAAAAABwjoYGAAAAAOdoaAAAAABwjoYGAAAAAOdoaAAAAABwjoYGAAAAAOdoaABAgipUqODddtttweNFixZ5uXLl8v93Rdb36KOPJs0x+eyzz/xtfvLJJ52tc8yYMf46Zd3p8frrr3vVq1f38ubN6xUtWtT/25VXXun/AwBkHRoaAJJC7OYz9q9AgQJe1apVvX79+nnffPONl0zee++9pGpMJJOtW7f6jcHKlSt7o0eP9kaNGpXdmwQAp6082b0BAJAWw4YN8ypWrOgdPXrUW7p0qffCCy/4N+6bNm3yChUqlKU7s2nTpt7PP//s5cuXL03Pk+197rnnIhsbsr48eU7vS/Mtt9zi9ejRw8ufP3+anyu9S7/99pv39NNPe1WqVAn+/v777zveSgDAqZze32YAkk7btm29iy++2C/37t3bK1GihPePf/zDmz59unfDDTdEPuenn37yzjzzTOfbcsYZZ/g9Ky65Xl8yyp07t/8vPfbt2+f/HwuZiklrYxAAkHGETgFIas2bN/f/3717t/+/hM2cddZZ3s6dO7127dp5hQsX9m666SZ/mfzS/a9//curWbOmf0NfsmRJr0+fPt7BgwdD60xNTfUee+wxr2zZsn4vyVVXXeVt3rz5pNe2jdFYvny5/9rFihXzGzi1a9f2f2GPbZ/0ZggdChZvjMbatWv9BlaRIkX899aiRQtv2bJlkaFlH374oTdo0CDvnHPO8V+7c+fO3v79+0N1Dx8+7IcYyf+nsmrVKq9NmzZeSkqKV7BgQb83qVevXpF1//nPf3rly5f36zVr1szvZdI2bNjgv/9KlSr5+79UqVL+ur777rtTjtGQhmT79u29MmXK+D0dEhr1t7/9zfv1119DY2geeeQRvyzvX+9Lc4xG7NhNnDjRe/zxx/1jLdsk+3bHjh2h7VmyZInXrVs3r1y5cv5rn3/++d7AgQP93ictdu599dVXXqdOnfyybMf9998f2k4R63WpVauW/7pS7+qrr/b3t/bGG294DRo08Pdp8eLF/Z6eL7744hRHDQByBno0ACQ1aVAI6dmIOXHihH9zfPnll/uDlGMhVdKokJvY22+/3RswYIDfOHn22Wf9G3m5QZfBw+Lhhx/2GxrSWJB/a9as8Vq3bu0dO3bslNszd+5cr0OHDl7p0qW9P/zhD/7N9JYtW7x3333XfyzbsHfvXr+eDFo+FWngXHHFFX4j44EHHvC38aWXXvJvmhcvXuw1bNgwVL9///5+A0duuOVGXRpWMo7lrbfeCupMnTrV3wevvvpqaHB7VO+AvG+5CX7wwQf9XgJZ55QpU06q+9prr3k//PCDd++99/phbXITLY3AjRs3+g262L7ZtWuX/9qyX+S9yRgK+V8aTrrBZZLjJjfu0oiS/xcsWOAfp++//94bOXKkX0feq2yHvD8JqZN60siL54knnvB7pqQxIA2vESNG+A1TaSzGTJo0yTty5Ih39913++fZihUrvH//+9/el19+6S/TpEEh554cFzn35s2b5z311FN+w0ieH3PHHXf470kakNIzJ+esNGhkP8R67KQBNGTIEK979+5+HWkwyutKyJ6cs2avDQDkOKkAkAReffXVVLlkzZs3L3X//v2p
X3zxReqECRNSS5QokVqwYMHUL7/80q936623+vUefPDB0POXLFni/33cuHGhv8+ePTv093379qXmy5cvtX379qm//fZbUO+hhx7y68n6YxYuXOj/Tf4XJ06cSK1YsWJq+fLlUw8ePBh6Hb2ue++9139eFPn7I488Ejzu1KmTvz07d+4M/rZ3797UwoULpzZt2vSk/dOyZcvQaw0cODA1d+7cqYcOHTqprvwfz9SpU/16K1eutNbZvXu3X0cfA7F8+XL/7/L6MUeOHDnp+W+++aZf74MPPjhp+2Td8Z7bp0+f1EKFCqUePXo0+JvsO3munCNas2bN/H/msatRo0bqL7/8Evz96aef9v++cePGuK89fPjw1Fy5cqXu2bMn+Fvs3Bs2bFiobr169VIbNGgQPF6wYIFfb8CAASetN3bsPvvsM/+4Pf7446Hlsl158uQ56e8AkBMROgUgqbRs2dL/hV3CVySMRH61ll+wzzvvvFA9/euxkF+ezz77bK9Vq1bet99+G/yTsBRZx8KFC/168gu09FxIz4D+hf2+++475bbJr8zSSyJ1zV+b4/1abyO/jssgZgnDkXCjGOktufHGG/3B8PKLvnbXXXeFXkt6Q2Q9e/bsCf4mvRjSponXmyFi70F6Y44fPx63rmyjPgaXXnqp/6u+DHyPkfCfGOn1kP3fqFEj/7H0GsWjnys9J/JceW/S0yBhYOklvSt6/IasU0jPS9Rry3gfee0mTZr4+1COualv376hx7JOvb7Jkyf7xygW5qXFjp30Gkl4lfRm6PNVeoIuuOCC4HwFgJyM0CkASUXGN0haW8nMJCE51apV80NfNFkmMffa9u3b/dCYc889N+4g4tgNudzMadK4kZCkRMK4LrroIs8FCZWRG2l5j6YaNWr4N6ISry9jTmJkHIEW22ZzHEoiZJzFdddd5w0dOtQffyHhWtKgkEaOmRHK3F9CjpOMgYg5cOCAv64JEyYE+zvmVONFJLzqr3/9qx8yZTauEhlrYpPI/vr888/9MK133nnnpP1ovnZsvIW5Tv08OU9krImMubCR81UaMlH7VcTC/AAgJ6OhASCpyC/lsRh2G7kJNhsfclMujYxx48ZFPse8OUxWtmxN/4vKShv5df3tt9/2xw3MmDHDmzNnjj94W8YcyN+kJygt5Nf5jz76yPvjH//o1a1b13++HBcZBC3/2xw6dMhv9Mg4FUlvLOMd5IZeekH+9Kc/xX1uRveX9AZJL5g0kuS1ZCJAGWQvA76lR8h87fRmyzLJemX/z5o1K3Kdad33AJAdaGgAOC3IzamERV122WWhUBiTZE2K/aKsw5Wkd+FUvQLyGkKyLUmIl02iYVTS+JGB7Nu2bTtpmYQLSWNKQsgym4Q3yT8ZnDx+/Hh/sLT0SsgA5RjZX6ZPP/3UzwQlZN/Nnz/f79GQ3oF4zzNJhijJTCXhRDIQOiaWaSwzyWB2eR9jx471evbsGfxdBranl5wn0miTxoutV0PqSGNHsnxJzxAAJCPGaAA4Lciv6fLrtKRENUnGH/nVXEgDQcJSJLuP7gWQjEanUr9+ff/GUOrG1hej1xWb08OsY5JfsiXrk6R21aleZSZ0ueGXrFryK39aJZreVhoHZk+I9ESIX375JfT3adOm+b/yx0hmJsncJFmVYu9FmOtLZL9GPVfG0Tz//PNeZot6bSnH0hWnh4SjyTqk0WWKvU6XLl3815Y65j6Tx2ZKYADIiejRAHBakNAbSS07fPhwb926df4NvDQo5Bd1GSguN45du3YN5j2QepKmVtLbyoBfCWGRuSTikR4GSat6zTXX+DfkMtBYBm7LTb2MMZBfsYUMQBeSYldSocoNpQxsjyJpduXXc2lU3HPPPf74E0lvKzf6koo1PRJNbyu/4svNvMzFIb+wyyDs0aNH+40b2S+azMIt2yiD8GXbpAEhqWAlJa+Q50hvhGyzDCyXgeMy0D2RXgkZeC3jHG699VZ/n0mPkKQGTk84WFpJ
qJS8dzknpCEl70MGc6dnzEuMzMsis58/88wz/vkXCx2T9LayTNIRy2vKsf/zn//sNzJlbIzMCSP7S46fDPqXbQKAnIyGBoDTxosvvujf5MuN+kMPPeTftEtoz8033+yHVMXIDZ6MAZD6kt1HsifJTbFMGHcq0nCQ58gv0TKWQW4g5abxzjvvDOrIr9WS1UrCj2RCNrlhtjU0ZKC33IDKDac0fmR9sj3yPHMOjcxonEnPhGyn9KJI1i4ZIyPjXKTnRpOwImloSQNDBnpLPZmjRBpaMdILI+9bBvTLe5bGnjTgZGB0PNJgkcxXgwcP9geES6NDjplMrif7OzNJY1TGp0gDR/a/nBfS8JLGQJ06ddK9XmnkyRwfL7/8sj9mRfatjD2SRlWMzF0iYVMyED/W+yGhcrLfOnbs6OT9AUBmyiU5bjP1FQAASAO5+ZbxH5JRy8weBgBIHozRAADkKF9//bUfHhUv/SsAIOcjdAoAkCNIeJak05WQtcaNG/sZtwAAyYseDQBAjrBlyxZ/vIIMLB8zZkx2bw4AIIMYowEAAADAOXo0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAczQ0AAAAADhHQwMAAACAc3ncrxIAAADImNTUVOuyXLlysXuTAD0aAAAAAJyjoQEAAADAORoaAAAAAJxjjAZOae/evaHHW7dudbrXcufOHZTz588fWnb22WcH5ZSUlKBcvHhx6zqy0qFDh4Ly+vXrQ8t+/fXXhNZx7rnnBuWaNWvm6PjTX375JfR47dq1QfnIkSPW55UvXz4oV65c2ctMhw8fjjwmJ06cSOj5+jwTtWrVytHH5NixY5HHQ/z000+Rzylbtmzo8QUXXJAt71GfT2vWrAkt+/nnn73spj+PomTJkpGx4xs3bgzV+/bbbyPXV6JECeu5dcYZ//+73759+0L1Nm/eHPm68ZQpUyb0uFq1amk+xr/99ltQ3rBhQ2jZgQMHIp9TtWrVuOdaRultMvf7d999l9A68uT5/1ufevXqBeXChQt7Lpnngd7eRI9jqVKlQo9r1KiR5uOYnnPVvE7ra3h66e3Yv39/UF63bl2o3ieffGL9LNjO8dq1a0d+rkSxYsW8jPrxxx8jr7PHjx/3coKCBQsG5fr164eWmfdVWY0eDQAAAADO0dAAAAAA4Fyu1ET773Daeu2110KP+/Xrl+bwoHh0yIDu0haFChUKyqVLlw7KTZo0CdXr2bNnZLehXndmGDNmTFAeMGBAaFmi+6Zhw4ZBefLkyU67ezM7jK5du3ZBefv27dbnDR48OCgPHTo0U8N0Jk6cGJT79OkTGWIUT926dUOPp02bFpTPOeccL6fRIQjXXHNNaJkZJhFz9913hx6PGDEiyz4ztvOpbdu2oWU7duzwsoN+/6NGjQotu+GGGyLPp5tvvjlUb+bMmZHrbtWqVejxhAkTgnKBAgWC8tSpU0P1br/99jSHaujrinj77betoac2OhyyR48eoWXz58+P/Bw/8cQT1u8LF7755pug3KFDB2vITTw6lEQf465du3ouvffee6HH+jwxw1BtzDAYfRx1KF88+rVuueWWhM5VfZ0W999/v5eRcCMxadKkyP2+adMma8hnvFtU/VktUqRIUL7kkktC9QYNGhSUW7RoEZTz5s2b0Pswt1Gfd/r6m50qVaoUlGfNmpWp4YtpRY8GAAAAAOdoaAAAAABwjqxTOCUzBEh3p5vLdFdm0aJFrV2UujtUZwMyMxfpDEI6zMLMUKO7qF944QVrqIILuhv6nXfeCco//PBDqF6+fPkiM0Lo92Rm3NChLk2bNvVyGp3xxcwMFC/rVGZm5jBD
ombMmBGZFcw8B88888zIema4kc4w0rp1ay+n0Z8lM1OT7ZgkGkaWleeTua22bdfHTZx11llOt0lfwxINrTh69GhC227WszEzpOn1JfpZWr58uTXUqVu3bl5amaE+ept06FSi2d3Sa/Xq1dZQKb1NOvT0+++/t9bToUMdO3a0XsNdf3cmGjq1cuXK0OM5c+ZEhgwnKtFzNb3XbH0tffjhh0PLXnnllcjwKDNrpM5wVaFCBeu5tW3btsjsWXPnzg3V0xnTnnzyyaB84403hurFCxu1XWfjfe/prJmFVBh4ZnB9HXSJHg0AAAAAztHQAAAAAOAcDQ0AAAAAzjFGA07pmMTRo0cH5erVq1ufo2NBv/7669CyKVOmBOVx48ZZ4yJ37twZlJ977rmgfNlll4XquYiT1K9lxs/a0s01atQoKI8dO9Ya0zp79uygfPnll4fqZWXa0WTy+eefhx4vXbo0st55550XeqzH77z88svWlIx6/I9OjZhds9Gf7sw0qzp1pQt6vIE5u3YyMa+R+rpz9dVXZ9ps2Jk9lufdd9+1vkd9jdTniR63Jb788svI68UXX3wRd3bs7GCO5dDp5vWYEj0mMqvpsSh6jKSZHlq/F53OuW/fvqF699xzT2RaezPVrR7f+MADDwTlZcuWWVMiDxkyJChfdNFFcVObp4c+B3Va927pGBeVFno8WaJpj7MKdy4AAAAAnKOhAQAAAMA5Qqfg9oRSM3vr0KELL7wwoefXqVMn9FiHD+3bty8oT58+3boOnfLOTDnrInRq0aJF1lAvrXbt2kG5U6dOkbOjmmn+dFq+gQMHhurlxFmpcwIzVMqcvdzWTa6PiZ5N3Ew/rI+37oJP5rCaZGbOap3oteX3yLwm6LTF5nmsPycfffRRUG7Tpo2X0+nP9OLFixMK3e3cubN1lnkdOqVDLz/++ONsC51KSUmJTOOqQ2vNtMV6X1x77bVedtmyZUtQfumll6xhXzossXv37kF52LBhoXqJhvNdccUVQfmxxx6zhlfq1Ld79uyJDEMTtWrVchoaW6pUqYTCx3/v6NEAAAAA4BwNDQAAAADOETqFHE3PAqxnCE00fMtFpiYzC9GsWbMis22YXa067KtBgwZB+fzzzw/V27p1a2TYl86okVmznCcrPbutzgplho/o429m8apXr17kubV+/fpQPR12oWclJnQK2c0M7dGZh3QGO3N2bJ2BqlmzZqF6OhuQ/vzosJestmrVqqC8a9cuaz0drqvDcM3P/rx58yKvFzqjlZkpKH/+/F5m0rNhly1b1homrL+P9HHUGfHMmaL1sXPxnWhmf9LfiWbmLk3P1t67d2+nmc8aNmwYlLt27Rpapr9XdbhhuXLlQvV0yBpZBd2hRwMAAACAczQ0AAAAADhHQwMAAACAc4zRQI5ixn7qVHQ6rZ9Jx6BefPHFkekO08tMjajj9G3xp+Zs4HqmTj02wByjoeNvddyrGYN7us8SruO0450XOvbXnCX+3HPPDcqXXnqpdYyGTj+sx4Po2ZXNmVmBrKDHo5mx6R988EFomZ5Fe/78+daxYPq6pWXlGI3jx4+HHs+cOTNyfJZJf45LlCgRmQbVHPunr7nmjNL6+6dq1apeVh3LLl26RB4rc3t1ets1a9aE6jVt2jTTjqOZtnbJkiWRs7jHG0NTs2ZNzyU9JmXkyJHWenqsjTkO43T/Xs0s7FUAAAAAztHQAAAAAOAcoVPIkrSjmzZtsj5Hp5TTXdVmusG1a9daUw02btw4KA8aNCgo58uXz8toCJfZda1nh9bMrvUqVapEhtU0b948VO+tt96K7HY2X3f//v2RoVinIx0y8NVXXyWU/tOcmVV3m1955ZWRKSPN9JfxZoU3UyUic5gzv5uzOaeHTlWs04zmdGaYik5VW7t27dAyHRakryXjx48P1dOpuHWYTVam+zQ/02YYmO17QH+O9faas8frdNb6u0nPGG6+bmaHTulU6U2aNAnK9evXt27TwYMH
g/Ibb7xhTffqOk3xDz/8EHocL6WtLXRKhzq5lpnrTsu9gz63ZhvpptNDhwOa50VOTsdLjwYAAAAA52hoAAAAAHCOhgYAAAAA5xijAad07OaQIUMSigvVMY1mzLF+no6d7ty5c6he7969g3KNGjW8jNIpBOfMmRNaprdRb5+Oq42XWlenYBQpKSlBed++fda0uitXrgzKHTp08E43Oj2nHv+jY5vj7evixYtb6+mUyDrtrRm3vXv3bmtaXcZoZI1JkyaFHk+fPj3D6xw4cGBQHjp0aLakdE0P83qpY7i7detmvX7oz8yMGTNC9fr27Rs5xikr94U57sYcuxdTunTp0GMzdbjts6/HL+g4evNaotPq3nzzzUG5QIECXmYey6JFiwblHj16WPeNTgNspkPfsmVLZCpZF7H8P//8s/X7Mh59fpqpmX8v9P3MCy+8EJRHjRqV4XVfddVVQXnixIk5ZlzKqdCjAQAAAMA5GhoAAAAAnPt99l0h2+g0emXLlk0ozazuajS7YHX6Pp3y8PXXXw/V07Pb9u/f3xpilOjszXq2bnPmXK1gwYKRqSXjzTJqps/UaVd16JQOFTLDhdq2bZsUae1c0qFkttnZzXNNdzXHm/VVn6vmjLU6dErPiKvDKkSnTp2CMrOEZx7zOKY3hbWWrGEc+tpp6tixY+ixDuPQn6XPP/88VG/atGlBefDgwVl2ndGfLX2tM5dptWrVCj0+//zzI+uZn0edBld/l+hU1ma42a5du6zpcl3QYVv6uLZr1y5U79lnnw3Kn3zyiTXt8+TJkyO3Nztnvz7dZt7W4Uxnqtno00uH1CXTvkyeLQUAAACQNGhoAAAAAHAuOfuLkWPprr3Ro0cH5YsuusiaYUOXDx8+HKq3fv36yK5/MyvJwoULI7uTzZljzW5oTXdXz507Nyh/++231ufo7lBzxnDbTKBmpphixYp5aZ0NW89KrcN+fs/0MbbNzi4KFSoUlA8cOJDmmVlt2cJMS5cutYZYVaxYMaF1IO2uv/760OMBAwZkeDeWKlUqaTJNaWaWJH0NM89BHdr31FNPWa9HOqvXTTfdlGWhGnp2afOzpenjU6RIkdCyRYsWJfRaOiRXh7+aoVP//e9/g/KCBQsyNXRKHztdNsPBunbtGpQfe+wx63GcMmVKUO7Vq5fTEDjzezXRLFw6K2W8sL9kpj8ngwYNsl630kOfq7qc09GjAQAAAMA5GhoAAAAAnKOhAQAAAMA5xmjAKR3/qWdYLlOmTLrWp9MX1q9f3zoz+KeffhoZvz9mzJhQvRYtWljjTPX4ED3LarxYUj1+49577w0tSzTWO97M1ppOr7hixYpsGaNhxgEnuu3pSfdqpjrWKS/N7bAdx/vuuy/TjoeOKRcffvhhtozR0Nt74sSJdKVzTaZxCebM7fq6AM8ai69nmB43blzkeC9zRul58+ZlWcpmPS5DpzI36euxOTuyTukaj16HOS7D9tnS15/bbrstVM/FrMy2a5o5NkbP+K6/38w0xdu3b48cn+biOJqpWkuWLBk503q8a6ZO3164cGHPJfO7Qx9HPbbBvA66Hoek98sFF1zgna7o0QAAAADgHA0NAAAAAM4ROoWkUaVKFWt6QR06pe3evTv0WHfXmqFTuss3Xvev7l5NSUlx2iWt0y6aM4MfPXo0KL/77rtB+ZprrgnVy8wQh++//z5uF7UtFKdEiRLWZTbbtm0LPV67dm1Cr6WPiYtZow8dOhSUf/rpJ2vIhT4m3bt3d7oNiaaMNI+PbT/pfWQuQ/Iww23ihRTqFONt2rQJymPHjrXOwq1Dk1zMbGy7nomZM2cG5ePHj1ufp7dDp1NPL/05NlOZ6xCrNWvWWL9vXITv2dLbmqpVqxaU27dvH5RffPHFUD29D/VxNEMP00OnEBd169aNTANsvg+93/bs2WNNf58eOjxq5MiRoWU6BFCHM9WpUydUr3///kG5ePHiGd4m/A89GgAAAACco6EBAAAAwDlC
p5A0dDYhPWNrPPGySpjduu+//37ka5lKly4dGXZgzuCaKB3uMGLEiKD86quvWp+zZMkSa4aWChUqJNQFn55wmXXr1llDvTRz1tLKlSsntH69vbq7W3z33XeRzznnnHNCj/V+0+F2iTLDT5599tmg/Pzzz0duq1i2bFlQ/uyzz4Jy1apVQ/VcH5MNGzYkNIu9DuEytwm/fzpU9JZbbgnK06dPt36mdXY7PXu6C2ZYq/78xHPXXXcF5b59+2Z4O3SI5h133BFatn///siyDg8S9erVy7IwRB0ae+ONN0bO6G5eC1avXh2Uy5Url+FtMLMztW3bNij/5z//sX6P6gxnent1OFh6w391WNaECROsy+Jd6zM7zPV0RY8GAAAAAOdoaAAAAABwjoYGAAAAAOcYowGndPy5Tn2qU3DGi5PUqRXN2U5Hjx59ylSnZoysjp01Z3A1xxfMnTs38n2YateuHZQbNmwYue706tChQ1B+8803rekg9QyrH330kXWMhp6tWuzbty8ot27dOqHUlTt37gzKL730UmiZbVbdSpUqWdMfxqPTs86ZMyeh1J01a9YMPb7sssuC8tlnn+1lVLt27SJn4tWpbsXevXuD8gcffGAdD7Fq1SprnLp+rXjHRH8uRo0aZU0ZqpUvXz4oN2jQwFoPv3+XXnppUL7yyitDy6ZOnRqZ2jle6uT00OPMomYo1/RnQad0dTHWSKd7NccK6HEZ+vqjZwkXd955p9NrTqJ0Wt1WrVqFlunvD/39q2d+d0V/D+prmDlWQu9DnY7XHMPXpUuXyGNvfi/r6+Cjjz4alHfs2GHdVj2m7+677w4tc/Edrunr8Y+WVPCZwRwjmTt3bi870aMBAAAAwDkaGgAAAACcI3QKTul0dv369UsoDETP6GmGWOlQH53e1OxC1eFSOpSmT58+1i5EM1WrrUvZTOXXrFmzTJstV89Uet5551lDmHSImZ6RWnTq1Ckojx8/PrRMP77iiiuC8iWXXGLt8tVpf+OFrBUoUCAo9+7dO7SsTJkyXiI2b94clNevX2+tp4/35ZdfHlpWpEgRzyUdKqdTGG/dutU6E68OrejRo0eonp6l1wxF02FfjRo1ivyMmMdEh2LFS2l6++23O01xieSlQ0RuvfVWawipDveIN+t4oo4cOWK9bpnnuC0c9MILL/Rc0qFO+ppohp7q75yNGzdavzv05zaz6Rm6e/bsGVqmr0H6e9nFcYx3Pg0ZMsQaGrp8+fLI7/YBAwZYr5G1atWKPH/EokWLgvKmTZus71HPIP+Xv/wlKDdv3txzTZ8nzzzzTOR7yuxzYfjw4c5nrs8IejQAAAAAOEdDAwAAAIBzNDQAAAAAOMcYDTil49TjxdgnEntvjo/QKdt0SkJx1VVXRcZ76vh6M3Zz1qxZoWW29I1mzH+TJk2s25tRZcuWtabm1WM0tI8//jj0WKf2M8ey6PevY6TNeOlEFS9ePDLFY69eveKOc7Ftn44P16k1TXpsjBlX7fqY6HNNpwU1x2hoK1asCMrbtm0LLTtx4oR1v+iUvrNnz07ztupYZDP+vm/fvjkm3SHcMGPR46XlttFjzsxUpfPnz/dc0tem1atXJ/w8PYYsJSXF6Tbp64W5L3SMvU5nrccLmtctvf/SIj3Hzja+yxy7NnPmTC+rVK9ePSi/8soroWV///vfI8eQmKnm9fYmuu358uWzjuMZPHhwUO7WrVtQzps3r+eaPo76fN8RJ+WuC/o7UY/JyQno0QAAAADgHA0NAAAAAM7lSs1ofx1+9zZs2BB6rEOOXJ8+efLksaYe1KlFzZlEdbpOndIzXmjXlClTQsvMVHy2cJQbbrghcvtc07NLR80Abuv+1bOqFitWzLqOBQsWWMN7dFpLnbbWnDlXz9KrQ5j0c+IxU1pOnz49KH/66afW5+lwNjN9rA7nck2HqS1evDih87hjx46hZaVLl45cn5g3b15kaJaZ9lmf43p2ZD0rr2jatGlk+sOcQp9n
ZirmAwcORD7HDE3RYZPZFbZkXktsYRKVKlUKPb7uuusiw9nMsLwZM2ZEfmZKliwZqnf99ddn+HjrGbt1etdEtWzZMvT44osvjvxM68/6qdLb6tnLMzN9rHnO6XMy3szO+jN47bXXWr9Tpk2bFhlCKUqUKBH5HZPe2aqXLVsWmQY2UWYYWePGjb2M0vtwzZo11hA9nT5Yh9Ca9wc61Fh//7Ro0SJUT987pDe0Vs8Sr2ddN1PuZpe86j6ga9euoWXly5f3shM9GgAAAACco6EBAAAAwDlCp4DTlA790CFl5mPdXa0ze8TLJgW3x8Q8Pnq/6zA1jgeAZGOGYOvr3bFjx6zP05koyaSXc3GXAAAAAMA5GhoAAAAAnKOhAQAAAMA5xmgAAAAAcI4eDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAADO0dAAAAAA4BwNDQAAAACea/8HZrJ+O5pikPQAAAAASUVORK5CYII=",
+ "text/plain": [
+ "<Figure size 1000x200 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'sbajafinance'"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict_onnx('image copy 2.png', 'export/model.onnx')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "myenv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
frontend/assets/index-CPxYcdRp.js ADDED
The diff for this file is too large to render. See raw diff
 
frontend/index.html ADDED
@@ -0,0 +1,101 @@
+ <!DOCTYPE html>
+ <html lang="en" class="light">
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <title>Lumina AI OCR</title>
+ <script src="https://cdn.tailwindcss.com"></script>
+ <script>
+ tailwind.config = {
+ darkMode: 'class',
+ theme: {
+ extend: {
+ colors: {
+ primary: {
+ 50: '#eef2ff',
+ 100: '#e0e7ff',
+ 200: '#c7d2fe',
+ 300: '#a5b4fc',
+ 400: '#818cf8',
+ 500: '#6366f1', // Indigo-500
+ 600: '#4f46e5', // Indigo-600
+ 700: '#4338ca', // Indigo-700
+ 800: '#3730a3',
+ 900: '#312e81',
+ },
+ dark: {
+ 800: '#1e293b',
+ 900: '#0f172a',
+ 950: '#020617',
+ }
+ },
+ animation: {
+ 'fade-in': 'fadeIn 0.6s cubic-bezier(0.16, 1, 0.3, 1)',
+ 'slide-up': 'slideUp 0.5s cubic-bezier(0.16, 1, 0.3, 1)',
+ 'scan': 'scan 2.5s ease-in-out infinite',
+ 'pulse-glow': 'pulseGlow 2s infinite',
+ },
+ keyframes: {
+ fadeIn: {
+ '0%': { opacity: '0', transform: 'scale(0.98)' },
+ '100%': { opacity: '1', transform: 'scale(1)' },
+ },
+ slideUp: {
+ '0%': { transform: 'translateY(20px)', opacity: '0' },
+ '100%': { transform: 'translateY(0)', opacity: '1' },
+ },
+ scan: {
+ '0%, 100%': { top: '0%' },
+ '50%': { top: '100%' },
+ },
+ pulseGlow: {
+ '0%, 100%': { opacity: '0.6', transform: 'scale(1)' },
+ '50%': { opacity: '1', transform: 'scale(1.05)' },
+ }
+ }
+ },
+ },
+ }
+ </script>
+ <style>
+ /* Custom scrollbar for text areas */
+ .scrollbar-thin::-webkit-scrollbar {
+ width: 6px;
+ height: 6px;
+ }
+ .scrollbar-thin::-webkit-scrollbar-track {
+ background: transparent;
+ }
+ .scrollbar-thin::-webkit-scrollbar-thumb {
+ background-color: #cbd5e1;
+ border-radius: 20px;
+ }
+ .dark .scrollbar-thin::-webkit-scrollbar-thumb {
+ background-color: #475569;
+ }
+
+ /* Scanning line glow */
+ .scan-line {
+ background: linear-gradient(to right, transparent, #6366f1, transparent);
+ box-shadow: 0 0 15px #6366f1, 0 0 30px #818cf8;
+ }
+ </style>
+
+ <link rel="stylesheet" href="/index.css">
+ <script type="importmap">
+ {
+ "imports": {
+ "react/": "https://esm.sh/react@^19.2.3/",
+ "react": "https://esm.sh/react@^19.2.3",
+ "react-dom/": "https://esm.sh/react-dom@^19.2.3/",
+ "lucide-react": "https://esm.sh/lucide-react@^0.562.0",
+ "axios": "https://esm.sh/axios@^1.13.2"
+ }
+ }
+ </script>
+ <script type="module" crossorigin src="/assets/index-CPxYcdRp.js"></script>
+ </head>
+ <body class="bg-gray-50 text-slate-900 dark:bg-dark-950 dark:text-slate-100 transition-colors duration-300 min-h-screen font-sans selection:bg-primary-200 selection:text-primary-900">
+ <div id="root"></div>
+ </body>
+ </html>
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ # Core
+ numpy<2
+ pillow
+ rapidfuzz
+
+ # ONNX
+ onnx
+ onnxruntime
+
+ # Web
+ fastapi
+ uvicorn
+ python-multipart
+
+ # OCR
+ hezar
+ opencv-python-headless
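The notebook cell earlier in this commit calls `predict_onnx('image copy 2.png', 'export/model.onnx')` and decodes the CRNN's per-timestep logits into the string `'sbajafinance'`. The decoding step behind that call is a greedy CTC decode: take the argmax at each timestep, collapse consecutive repeats, and drop the blank symbol. A minimal sketch of that step (the function name, the charset layout, and blank-at-index-0 are assumptions for illustration, not the repository's actual API):

```python
import numpy as np

def ctc_greedy_decode(logits, charset, blank=0):
    """Greedy CTC decoding: argmax per timestep, collapse repeats, drop blanks.

    logits  : (T, C) array of per-timestep class scores
    charset : string of characters; class i (i > 0) maps to charset[i - 1],
              since index 0 is reserved for the CTC blank (an assumption here)
    """
    best = logits.argmax(axis=-1)  # best class index at each timestep
    out = []
    prev = blank
    for idx in best:
        # Emit a character only on a transition to a new, non-blank class
        if idx != prev and idx != blank:
            out.append(charset[idx - 1])
        prev = idx
    return "".join(out)
```

For example, with charset `"ab"` a timestep sequence of classes `[1, 1, 0, 2, 2]` decodes to `"ab"`: the repeated `1`s collapse to one `'a'`, the blank is dropped, and the repeated `2`s collapse to one `'b'`.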