Vaibhav Gaikwad committed on
Commit
a80a32e
·
1 Parent(s): 31e30cc

deploy audiolens backend — dit + easyocr + kokoro

Browse files
Files changed (3) hide show
  1. app.py +308 -4
  2. j2_preprocess.py +127 -0
  3. requirements.txt +13 -0
app.py CHANGED
@@ -1,7 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ """
2
+ audiolens — app.py
3
+ huggingface space backend (zerogpu + fastapi + gradio)
4
+
5
+ endpoints:
6
+ POST /classify — document type classification (dit-base)
7
+ POST /ocr — text extraction (easyocr)
8
+ POST /speak — text to speech (kokoro)
9
+
10
+ preprocessing (opencv) runs inline — no separate endpoint needed.
11
+ llm extraction (gemini) is called directly from the pwa — not here.
12
+
13
+ models load once at startup into cpu ram (except easyocr which
14
+ lazy-inits inside the gpu function so it can bind to cuda).
15
+ gpu is grabbed per-request via @spaces.GPU and released immediately after.
16
+ """
17
+
18
+ import io
19
+ import os
20
+ import tempfile
21
+ import warnings
22
+ warnings.filterwarnings('ignore')
23
+
24
+ import numpy as np
25
+ import cv2
26
+ from PIL import Image
27
+
28
+ import torch
29
+ import spaces
30
  import gradio as gr
31
+ from fastapi import FastAPI, File, UploadFile, HTTPException
32
+ from fastapi.middleware.cors import CORSMiddleware
33
+ from fastapi.responses import JSONResponse, FileResponse
34
+ from pydantic import BaseModel
35
+ from starlette.background import BackgroundTask
36
+
37
+ from j2_preprocess import preprocess
38
+
39
+
40
# ============================================================
# -- app setup --
# ============================================================

app = FastAPI(title='audiolens api')

# allow pwa to call from any origin
# NOTE(review): allow_origins=['*'] is wide open — acceptable for a public
# demo space, but consider pinning to the pwa's origin in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_methods=['*'],
    allow_headers=['*'],
)

# dit maps its 16 rvl-cdip classes to audiolens categories
# indices must match the 9 classes we trained with in j1
DIT_CLASS_MAP = {
    0: 'letter',
    1: 'form',
    2: 'email',
    3: 'handwritten',
    4: 'advertisement',
    7: 'specification',
    9: 'news_article',
    10: 'budget',
    11: 'invoice',
}
# rvl-cdip indices we keep, in dict insertion order — used to slice logits
SELECTED_RVL_IDX = list(DIT_CLASS_MAP.keys())
68
+
69
+
70
# ============================================================
# -- model loading (runs once at startup, cpu ram) --
# ============================================================

print('loading models...')

# -- classifier: dit-base (loads to cpu at startup) --
from transformers import AutoImageProcessor, AutoModelForImageClassification

dit_processor = AutoImageProcessor.from_pretrained('microsoft/dit-base-finetuned-rvlcdip')
dit_model = AutoModelForImageClassification.from_pretrained('microsoft/dit-base-finetuned-rvlcdip')
dit_model.eval()  # inference mode — disables dropout
print('dit-base loaded.')

# -- ocr: easyocr (lazy-init inside gpu function so it binds to cuda) --
ocr_reader = None  # populated on first /ocr request inside _run_ocr_gpu
print('easyocr will lazy-init on first ocr request.')

# -- tts: kokoro (loads to cpu at startup) --
import soundfile as sf
from kokoro import KPipeline
kokoro_pipeline = KPipeline(lang_code='b')  # b = british english
print('kokoro loaded.')

print('all models ready.')
95
+
96
+
97
# ============================================================
# -- request schemas --
# ============================================================

class SpeakRequest(BaseModel):
    """json body for POST /speak."""
    # text to synthesise — emptiness is validated in the endpoint, not here
    text: str
    # kokoro voice id — default 'bf_emma' is the british-female voice
    voice: str = 'bf_emma'
104
+
105
+
106
+ # ============================================================
107
+ # -- helpers --
108
+ # ============================================================
109
+
110
def bytes_to_pil(image_bytes):
    """decodes raw image bytes into an rgb pil image."""
    buffer = io.BytesIO(image_bytes)
    decoded = Image.open(buffer)
    return decoded.convert('RGB')
113
+
114
+
115
def bytes_to_cv2(image_bytes):
    """decodes raw image bytes into a bgr numpy array for opencv."""
    raw = np.frombuffer(image_bytes, np.uint8)
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if decoded is None:
        raise ValueError('could not decode image — check the file format')
    return decoded
122
+
123
+
124
+ # ============================================================
125
+ # -- endpoint: health check --
126
+ # ============================================================
127
+
128
@app.get('/health')
async def health():
    """simple ping to check if the space is warm."""
    payload = {'status': 'ok', 'models': ['dit-base', 'easyocr', 'kokoro']}
    return payload
132
+
133
+
134
+ # ============================================================
135
+ # -- endpoint: classify --
136
+ # classifies the document type from the uploaded image.
137
+ # uses zerogpu for inference, releases gpu immediately after.
138
+ # ============================================================
139
+
140
@spaces.GPU
def _run_classify(pil_image):
    """runs dit-base inference on gpu. called inside classify endpoint."""
    dit_model.to('cuda')
    batch = dit_processor(images=pil_image, return_tensors='pt').to('cuda')

    with torch.no_grad():
        output = dit_model(**batch)

    # no need to move back to cpu — zerogpu reclaims on function exit

    # restrict scoring to the 9 rvl-cdip classes audiolens uses
    scores = output.logits[0, SELECTED_RVL_IDX]
    winner = scores.argmax().item()
    probs = torch.softmax(scores, dim=0)
    label = DIT_CLASS_MAP[SELECTED_RVL_IDX[winner]]
    return label, round(probs[winner].item(), 4)
157
+
158
+
159
@app.post('/classify')
async def classify(file: UploadFile = File(...)):
    """
    classifies a document image into one of 9 categories.

    returns:
        doc_type — e.g. 'invoice', 'letter', 'form'
        confidence — float 0–1
    """
    try:
        payload = await file.read()
        if not payload:
            raise HTTPException(status_code=400, detail='empty file uploaded')

        label, score = _run_classify(bytes_to_pil(payload))
        return JSONResponse({'doc_type': label, 'confidence': score})

    except HTTPException:
        # re-raise our own 4xx untouched
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
181
+
182
+
183
+ # ============================================================
184
+ # -- endpoint: ocr --
185
+ # preprocesses the image (cpu, outside gpu) then runs easyocr
186
+ # on gpu via zerogpu. easyocr lazy-inits on first call so it
187
+ # binds to the cuda device provided by zerogpu.
188
+ # ============================================================
189
+
190
@spaces.GPU
def _run_ocr_gpu(clean_image):
    """runs easyocr inference on gpu. reader lazy-inits on first call."""
    global ocr_reader

    # first call constructs the reader inside the gpu context so it
    # binds to the cuda device zerogpu hands us
    if ocr_reader is None:
        import easyocr
        ocr_reader = easyocr.Reader(['en'], gpu=True, verbose=False)
        print('easyocr initialised on gpu.')

    # detail=0 returns plain strings only (no boxes/confidences)
    lines = ocr_reader.readtext(clean_image, detail=0)
    return ' '.join(lines)
201
+
202
+
203
@app.post('/ocr')
async def ocr(file: UploadFile = File(...)):
    """
    extracts all text from a document image.
    preprocessing (deskew, denoise, contrast, binarise) is applied first.

    returns:
        text — raw extracted text string
    """
    try:
        payload = await file.read()
        if not payload:
            raise HTTPException(status_code=400, detail='empty file uploaded')

        # preprocessing runs on cpu — outside the gpu-decorated function
        cleaned = preprocess(bytes_to_cv2(payload))

        # ocr inference on gpu
        extracted = _run_ocr_gpu(cleaned)
        return JSONResponse({'text': extracted})

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
230
+
231
+
232
+ # ============================================================
233
+ # -- endpoint: speak --
234
+ # converts text to speech using kokoro and returns a wav file.
235
+ # kokoro runs on gpu via zerogpu.
236
+ # temp wav file is cleaned up after the response is sent.
237
+ # ============================================================
238
+
239
@spaces.GPU(duration=30)
def _run_tts(text, voice='bf_emma'):
    """runs kokoro tts on gpu. called inside speak endpoint."""
    # the pipeline yields (graphemes, phonemes, audio) per chunk;
    # we only keep the audio arrays
    pieces = [audio for _, _, audio in kokoro_pipeline(text, voice=voice, speed=1.0)]
    if not pieces:
        return None
    return np.concatenate(pieces)
248
+
249
+
250
@app.post('/speak')
async def speak(req: SpeakRequest):
    """
    converts text to speech using kokoro.

    json body:
        text — the text to synthesise
        voice — kokoro voice id (default: bf_emma — british female)

    returns:
        audio/wav file
    """
    try:
        if not req.text or not req.text.strip():
            raise HTTPException(status_code=400, detail='text cannot be empty')

        samples = _run_tts(req.text, req.voice)
        if samples is None:
            raise HTTPException(status_code=500, detail='tts produced no audio')

        # write wav to a temp file — 24000 hz is kokoro's output rate
        wav_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        sf.write(wav_file.name, samples, 24000)
        wav_file.close()

        # return the file; the background task deletes it once the
        # response has been fully sent
        return FileResponse(
            wav_file.name,
            media_type='audio/wav',
            filename='audiolens_output.wav',
            background=BackgroundTask(os.unlink, wav_file.name),
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
288
+
289
+
290
# ============================================================
# -- gradio ui --
# minimal gradio interface — required by zerogpu.
# pwa users never see this. it just satisfies the hf spaces sdk.
# ============================================================

with gr.Blocks() as gradio_ui:
    gr.Markdown("""
    ## AudioLens API
    **This space provides the AudioLens backend API.**
    Use the endpoints below from the AudioLens PWA:
    - `POST /classify` — document type classification
    - `POST /ocr` — text extraction
    - `POST /speak` — text to speech
    - `GET /health` — check if space is warm
    """)

    gr.Markdown("_This UI is for reference only. The AudioLens PWA calls the API directly._")


# mount fastapi on gradio — zerogpu requires gradio sdk
# NOTE(review): this rebinds `app`; the fastapi routes above remain mounted,
# gradio itself is served under /gradio.
app = gr.mount_gradio_app(app, gradio_ui, path='/gradio')
j2_preprocess.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ audiolens — j2 image preprocessing
3
+
4
+ prepares a raw phone-captured document image for ocr.
5
+ each preprocessing step is a separate function so they can be
6
+ tested, tuned, or swapped out individually as needed.
7
+
8
+ pipeline order:
9
+ 1. to_grayscale — converts colour input to grayscale
10
+ 2. deskew — corrects tilt from phone capture angle
11
+ 3. denoise — removes grain and compression artifacts
12
+ 4. enhance_contrast — applies clahe for local contrast improvement
13
+ 5. binarise — converts to clean black/white via otsu threshold
14
+ 6. preprocess — runs all steps in order (main entry point)
15
+
16
+ no downloads needed. import preprocess() directly into the pipeline.
17
+ """
18
+
19
+ import numpy as np
20
+ import cv2
21
+
22
+
23
def to_grayscale(image):
    """
    converts a colour image to grayscale.

    handles both 3-channel (bgr) and 4-channel (bgra, e.g. a png with an
    alpha channel) inputs — previously a 4-channel image fell through
    unconverted and broke the downstream `h, w = gray.shape` unpack.
    an already-grayscale image is returned as a copy so callers can
    safely mutate the result.
    """
    if image.ndim == 3:
        if image.shape[2] == 4:
            # drop the alpha channel while converting
            return cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        if image.shape[2] == 3:
            return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return image.copy()
31
+
32
+
33
def deskew(gray):
    """
    detects and corrects the dominant tilt angle of the document.
    common when a user photographs a document at a slight angle.

    uses the minimum area bounding box of dark pixel clusters to
    estimate the skew angle, then rotates to correct it.
    angles under 0.5 degrees are ignored to avoid introducing
    unnecessary interpolation artifacts on near-straight images.
    """
    # (row, col) positions of dark pixels. minAreaRect requires a
    # float32/int32 point array — np.where produces int64, which
    # opencv rejects with a type error, so cast explicitly.
    coords = np.column_stack(np.where(gray < 128)).astype(np.float32)

    # not enough dark pixels to estimate angle reliably
    if len(coords) < 50:
        return gray

    angle = cv2.minAreaRect(coords)[-1]

    # normalise to [-45, 45]. opencv < 4.5 reports angles in [-90, 0);
    # opencv >= 4.5 (we pin >= 4.8) reports (0, 90], so a near-straight
    # page comes back as ~90 — without the second branch it would be
    # rotated a full quarter turn. handle both conventions.
    if angle < -45:
        angle += 90
    elif angle > 45:
        angle -= 90

    # skip tiny corrections
    if abs(angle) < 0.5:
        return gray

    h, w = gray.shape
    center = (w // 2, h // 2)
    matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    rotated = cv2.warpAffine(
        gray, matrix, (w, h),
        flags=cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_REPLICATE,
    )
    return rotated
68
+
69
+
70
def denoise(gray):
    """
    removes noise, grain, and jpeg compression artifacts from the image.
    opencv's non-local means denoising works well on document scans and
    phone camera captures without blurring text edges.

    a filter strength of 10 is conservative — enough to clean grain but
    not so aggressive that it softens thin strokes in small text.
    """
    strength = 10
    return cv2.fastNlMeansDenoising(gray, h=strength)
80
+
81
+
82
def enhance_contrast(gray):
    """
    applies clahe (contrast limited adaptive histogram equalisation).
    clahe equalises small tiles independently, so it copes with uneven
    lighting — e.g. a shadow across part of a medicine label or a
    receipt photographed in dim light — where global equalisation fails.

    clipLimit=2.0 stops noise in flat regions being over-amplified;
    tileGridSize=(8, 8) balances local against global correction.
    """
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return equaliser.apply(gray)
94
+
95
+
96
def binarise(gray):
    """
    converts the grayscale image to clean black and white.
    otsu's method picks the threshold automatically from the intensity
    histogram, so no manual tuning is needed.

    binarisation strips remaining grey tones, producing the
    high-contrast input that ocr models perform best on.
    """
    mode = cv2.THRESH_BINARY + cv2.THRESH_OTSU
    _, black_and_white = cv2.threshold(gray, 0, 255, mode)
    return black_and_white
110
+
111
+
112
def preprocess(image):
    """
    runs the full preprocessing pipeline on a raw document image.
    this is the main entry point called from the audiolens pipeline.

    input: numpy array — bgr colour or grayscale, any resolution
    output: numpy array — grayscale binarised image, same resolution

    pipeline: grayscale → deskew → denoise → enhance_contrast → binarise
    """
    steps = (to_grayscale, deskew, denoise, enhance_contrast, binarise)
    result = image
    for step in steps:
        result = step(result)
    return result
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0,<1.0
2
+ gradio>=6.9.0,<7.0
3
+ spaces>=0.28.0
4
+ uvicorn>=0.27.0,<1.0
5
+ python-multipart>=0.0.7
6
+ transformers>=4.35.0,<4.50
7
+ torch>=2.1.0,<2.5
8
+ easyocr>=1.7.0,<1.8
9
+ kokoro>=0.9.0
10
+ soundfile>=0.12.0
11
+ opencv-python-headless>=4.8.0,<4.11
12
+ numpy>=1.24.0,<2.0
13
+ Pillow>=10.0.0,<11.0