Vaibhav Gaikwad committed on
Commit
374083e
·
1 Parent(s): a80a32e

fix: switch to gradio native api for zerogpu compatibility

Browse files
Files changed (2) hide show
  1. app.py +153 -180
  2. requirements.txt +1 -4
app.py CHANGED
@@ -1,23 +1,21 @@
1
  """
2
  audiolens β€” app.py
3
- huggingface space backend (zerogpu + fastapi + gradio)
4
 
5
- endpoints:
6
- POST /classify β€” document type classification (dit-base)
7
- POST /ocr β€” text extraction (easyocr)
8
- POST /speak β€” text to speech (kokoro)
 
9
 
10
- preprocessing (opencv) runs inline β€” no separate endpoint needed.
11
- llm extraction (gemini) is called directly from the pwa β€” not here.
 
12
 
13
- models load once at startup into cpu ram (except easyocr which
14
- lazy-inits inside the gpu function so it can bind to cuda).
15
- gpu is grabbed per-request via @spaces.GPU and released immediately after.
16
  """
17
 
18
  import io
19
- import os
20
- import tempfile
21
  import warnings
22
  warnings.filterwarnings('ignore')
23
 
@@ -28,31 +26,16 @@ from PIL import Image
28
  import torch
29
  import spaces
30
  import gradio as gr
31
- from fastapi import FastAPI, File, UploadFile, HTTPException
32
- from fastapi.middleware.cors import CORSMiddleware
33
- from fastapi.responses import JSONResponse, FileResponse
34
- from pydantic import BaseModel
35
- from starlette.background import BackgroundTask
36
 
37
  from j2_preprocess import preprocess
38
 
39
 
40
  # ============================================================
41
- # -- app setup --
42
  # ============================================================
43
 
44
- app = FastAPI(title='audiolens api')
45
-
46
- # allow pwa to call from any origin
47
- app.add_middleware(
48
- CORSMiddleware,
49
- allow_origins=['*'],
50
- allow_methods=['*'],
51
- allow_headers=['*'],
52
- )
53
-
54
  # dit maps its 16 rvl-cdip classes to audiolens categories
55
- # indices must match the 9 classes we trained with in j1
56
  DIT_CLASS_MAP = {
57
  0: 'letter',
58
  1: 'form',
@@ -73,7 +56,7 @@ SELECTED_RVL_IDX = list(DIT_CLASS_MAP.keys())
73
 
74
  print('loading models...')
75
 
76
- # -- classifier: dit-base (loads to cpu at startup) --
77
  from transformers import AutoImageProcessor, AutoModelForImageClassification
78
 
79
  dit_processor = AutoImageProcessor.from_pretrained('microsoft/dit-base-finetuned-rvlcdip')
@@ -85,7 +68,7 @@ print('dit-base loaded.')
85
  ocr_reader = None
86
  print('easyocr will lazy-init on first ocr request.')
87
 
88
- # -- tts: kokoro (loads to cpu at startup) --
89
  import soundfile as sf
90
  from kokoro import KPipeline
91
  kokoro_pipeline = KPipeline(lang_code='b') # b = british english
@@ -94,102 +77,57 @@ print('kokoro loaded.')
94
  print('all models ready.')
95
 
96
 
97
- # ============================================================
98
- # -- request schemas --
99
- # ============================================================
100
-
101
- class SpeakRequest(BaseModel):
102
- text: str
103
- voice: str = 'bf_emma'
104
-
105
-
106
  # ============================================================
107
  # -- helpers --
108
  # ============================================================
109
 
110
- def bytes_to_pil(image_bytes):
111
- """converts raw image bytes to a pil image."""
112
- return Image.open(io.BytesIO(image_bytes)).convert('RGB')
113
-
114
-
115
- def bytes_to_cv2(image_bytes):
116
- """converts raw image bytes to a bgr numpy array for opencv."""
117
- arr = np.frombuffer(image_bytes, np.uint8)
118
- img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
119
- if img is None:
120
- raise ValueError('could not decode image β€” check the file format')
121
- return img
122
-
123
-
124
- # ============================================================
125
- # -- endpoint: health check --
126
- # ============================================================
127
-
128
- @app.get('/health')
129
- async def health():
130
- """simple ping to check if the space is warm."""
131
- return {'status': 'ok', 'models': ['dit-base', 'easyocr', 'kokoro']}
132
 
133
 
134
  # ============================================================
135
- # -- endpoint: classify --
136
- # classifies the document type from the uploaded image.
137
- # uses zerogpu for inference, releases gpu immediately after.
138
  # ============================================================
139
 
140
  @spaces.GPU
141
- def _run_classify(pil_image):
142
- """runs dit-base inference on gpu. called inside classify endpoint."""
143
- dit_model.to('cuda')
144
- inputs = dit_processor(images=pil_image, return_tensors='pt').to('cuda')
145
-
146
- with torch.no_grad():
147
- logits = dit_model(**inputs).logits
148
-
149
- # no need to move back to cpu β€” zerogpu reclaims on function exit
150
-
151
- # slice to our 9 selected classes and get the winner
152
- selected_logits = logits[0, SELECTED_RVL_IDX]
153
- pred_idx = selected_logits.argmax().item()
154
- confidence = torch.softmax(selected_logits, dim=0)[pred_idx].item()
155
- doc_type = DIT_CLASS_MAP[SELECTED_RVL_IDX[pred_idx]]
156
- return doc_type, round(confidence, 4)
157
-
158
-
159
- @app.post('/classify')
160
- async def classify(file: UploadFile = File(...)):
161
  """
162
  classifies a document image into one of 9 categories.
 
163
 
164
- returns:
165
- doc_type β€” e.g. 'invoice', 'letter', 'form'
166
- confidence β€” float 0–1
167
  """
 
 
 
168
  try:
169
- image_bytes = await file.read()
170
- if not image_bytes:
171
- raise HTTPException(status_code=400, detail='empty file uploaded')
172
 
173
- pil_image = bytes_to_pil(image_bytes)
174
- doc_type, confidence = _run_classify(pil_image)
175
- return JSONResponse({'doc_type': doc_type, 'confidence': confidence})
176
 
177
- except HTTPException:
178
- raise
179
- except Exception as e:
180
- raise HTTPException(status_code=500, detail=str(e))
 
181
 
 
 
 
 
182
 
183
- # ============================================================
184
- # -- endpoint: ocr --
185
- # preprocesses the image (cpu, outside gpu) then runs easyocr
186
- # on gpu via zerogpu. easyocr lazy-inits on first call so it
187
- # binds to the cuda device provided by zerogpu.
188
- # ============================================================
189
 
190
  @spaces.GPU
191
- def _run_ocr_gpu(clean_image):
192
- """runs easyocr inference on gpu. reader lazy-inits on first call."""
 
 
 
193
  global ocr_reader
194
  if ocr_reader is None:
195
  import easyocr
@@ -200,112 +138,147 @@ def _run_ocr_gpu(clean_image):
200
  return ' '.join(results)
201
 
202
 
203
- @app.post('/ocr')
204
- async def ocr(file: UploadFile = File(...)):
205
  """
206
- extracts all text from a document image.
207
- preprocessing (deskew, denoise, contrast, binarise) is applied first.
208
 
209
- returns:
210
- text β€” raw extracted text string
 
 
 
211
  """
212
- try:
213
- image_bytes = await file.read()
214
- if not image_bytes:
215
- raise HTTPException(status_code=400, detail='empty file uploaded')
216
 
217
- cv2_image = bytes_to_cv2(image_bytes)
 
 
218
 
219
- # preprocessing runs on cpu β€” outside the gpu-decorated function
220
  clean = preprocess(cv2_image)
221
 
222
  # ocr inference on gpu
223
- text = _run_ocr_gpu(clean)
224
- return JSONResponse({'text': text})
225
 
226
- except HTTPException:
227
- raise
228
  except Exception as e:
229
- raise HTTPException(status_code=500, detail=str(e))
230
-
231
 
232
- # ============================================================
233
- # -- endpoint: speak --
234
- # converts text to speech using kokoro and returns a wav file.
235
- # kokoro runs on gpu via zerogpu.
236
- # temp wav file is cleaned up after the response is sent.
237
- # ============================================================
238
 
239
  @spaces.GPU(duration=30)
240
- def _run_tts(text, voice='bf_emma'):
241
- """runs kokoro tts on gpu. called inside speak endpoint."""
242
- chunks = []
243
- for _, _, audio in kokoro_pipeline(text, voice=voice, speed=1.0):
244
- chunks.append(audio)
245
- if not chunks:
246
- return None
247
- return np.concatenate(chunks)
248
-
249
-
250
- @app.post('/speak')
251
- async def speak(req: SpeakRequest):
252
  """
253
  converts text to speech using kokoro.
 
254
 
255
- json body:
256
- text β€” the text to synthesise
257
- voice β€” kokoro voice id (default: bf_emma β€” british female)
258
-
259
- returns:
260
- audio/wav file
261
  """
 
 
 
262
  try:
263
- if not req.text or not req.text.strip():
264
- raise HTTPException(status_code=400, detail='text cannot be empty')
265
 
266
- audio_array = _run_tts(req.text, req.voice)
 
 
267
 
268
- if audio_array is None:
269
- raise HTTPException(status_code=500, detail='tts produced no audio')
270
 
271
- # write wav to a temp file
272
- tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
273
- sf.write(tmp.name, audio_array, 24000)
274
- tmp.close()
275
 
276
- # return the file and clean up after response is sent
277
- return FileResponse(
278
- tmp.name,
279
- media_type='audio/wav',
280
- filename='audiolens_output.wav',
281
- background=BackgroundTask(os.unlink, tmp.name),
282
- )
283
 
284
- except HTTPException:
285
- raise
286
  except Exception as e:
287
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
288
 
289
 
290
  # ============================================================
291
- # -- gradio ui --
292
- # minimal gradio interface β€” required by zerogpu.
293
- # pwa users never see this. it just satisfies the hf spaces sdk.
294
  # ============================================================
295
 
296
- with gr.Blocks() as gradio_ui:
 
297
  gr.Markdown("""
298
  ## AudioLens API
299
- **This space provides the AudioLens backend API.**
300
- Use the endpoints below from the AudioLens PWA:
301
- - `POST /classify` β€” document type classification
302
- - `POST /ocr` β€” text extraction
303
- - `POST /speak` β€” text to speech
304
- - `GET /health` β€” check if space is warm
305
  """)
306
 
307
- gr.Markdown("_This UI is for reference only. The AudioLens PWA calls the API directly._")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
 
310
- # mount fastapi on gradio β€” zerogpu requires gradio sdk
311
- app = gr.mount_gradio_app(app, gradio_ui, path='/gradio')
 
 
1
  """
2
  audiolens β€” app.py
3
+ huggingface space backend (zerogpu + gradio native api)
4
 
5
+ api endpoints (via gradio):
6
+ /call/classify β€” document type classification (dit-base)
7
+ /call/ocr β€” text extraction (easyocr)
8
+ /call/speak β€” text to speech (kokoro)
9
+ /call/health β€” check if space is warm
10
 
11
+ the pwa calls these using the gradio js client (@gradio/client)
12
+ or via gradio's rest api. each function decorated with @spaces.GPU
13
+ gets a gpu allocation only for the duration of that call.
14
 
15
+ llm extraction (gemini) is called directly from the pwa β€” not here.
 
 
16
  """
17
 
18
  import io
 
 
19
  import warnings
20
  warnings.filterwarnings('ignore')
21
 
 
26
  import torch
27
  import spaces
28
  import gradio as gr
 
 
 
 
 
29
 
30
  from j2_preprocess import preprocess
31
 
32
 
33
  # ============================================================
34
+ # -- dit class mapping --
35
  # ============================================================
36
 
 
 
 
 
 
 
 
 
 
 
37
  # dit maps its 16 rvl-cdip classes to audiolens categories
38
+ # indices must match the 9 classes we selected in j1
39
  DIT_CLASS_MAP = {
40
  0: 'letter',
41
  1: 'form',
 
56
 
57
  print('loading models...')
58
 
59
+ # -- classifier: dit-base --
60
  from transformers import AutoImageProcessor, AutoModelForImageClassification
61
 
62
  dit_processor = AutoImageProcessor.from_pretrained('microsoft/dit-base-finetuned-rvlcdip')
 
68
  ocr_reader = None
69
  print('easyocr will lazy-init on first ocr request.')
70
 
71
+ # -- tts: kokoro --
72
  import soundfile as sf
73
  from kokoro import KPipeline
74
  kokoro_pipeline = KPipeline(lang_code='b') # b = british english
 
77
  print('all models ready.')
78
 
79
 
 
 
 
 
 
 
 
 
 
80
  # ============================================================
81
  # -- helpers --
82
  # ============================================================
83
 
84
def pil_to_cv2(pil_image):
    """Convert a PIL RGB image into a BGR numpy array for opencv.

    opencv expects channels in BGR order, so the RGB array produced by
    numpy from the PIL image is channel-swapped before returning.
    """
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
 
90
  # ============================================================
91
+ # -- gpu functions --
 
 
92
  # ============================================================
93
 
94
@spaces.GPU
def classify_fn(image):
    """
    classifies a document image into one of 9 categories.
    called via gradio api: /call/classify

    input: pil image (gradio Image component with type="pil")
    output: json dict with doc_type and confidence
    """
    # guard clause: the gradio client may send nothing at all
    if image is None:
        return {'error': 'no image provided'}

    try:
        # zerogpu grants cuda only for the duration of this call,
        # so the model is moved onto the device here, per request
        dit_model.to('cuda')
        batch = dit_processor(images=image, return_tensors='pt').to('cuda')

        with torch.no_grad():
            all_logits = dit_model(**batch).logits

        # restrict the 16 rvl-cdip logits to the 9 classes audiolens uses,
        # then pick the winner and its softmax probability over that subset
        subset = all_logits[0, SELECTED_RVL_IDX]
        winner = subset.argmax().item()
        score = torch.softmax(subset, dim=0)[winner].item()
        label = DIT_CLASS_MAP[SELECTED_RVL_IDX[winner]]
        return {'doc_type': label, 'confidence': round(score, 4)}

    except Exception as e:
        # surface the failure as json rather than crashing the gradio call
        return {'error': str(e)}
 
 
 
 
 
 
 
124
 
125
  @spaces.GPU
126
+ def ocr_gpu(clean_image):
127
+ """
128
+ runs easyocr on a preprocessed image.
129
+ easyocr lazy-inits on first call so it binds to cuda.
130
+ """
131
  global ocr_reader
132
  if ocr_reader is None:
133
  import easyocr
 
138
  return ' '.join(results)
139
 
140
 
141
def ocr_fn(image):
    """
    extracts text from a document image.
    called via gradio api: /call/ocr

    cpu-side preprocessing (deskew, denoise, contrast, binarise) runs
    first, so the zerogpu allocation is spent on inference only.

    input: pil image (gradio Image component with type="pil")
    output: extracted text string
    """
    if image is None:
        return 'error: no image provided'

    try:
        # pil -> opencv, then the full cpu preprocessing pipeline
        cleaned = preprocess(pil_to_cv2(image))
        # easyocr inference happens inside the @spaces.GPU function
        return ocr_gpu(cleaned)
    except Exception as e:
        # report errors in-band as a string the textbox can display
        return f'error: {str(e)}'
 
168
 
 
 
 
 
 
 
169
 
170
@spaces.GPU(duration=30)
def speak_fn(text, voice):
    """
    converts text to speech using kokoro.
    called via gradio api: /call/speak

    input: text string + voice id (falls back to 'bf_emma' when blank)
    output: (sample_rate, audio_array) tuple for the gradio Audio
            component, or None when there is nothing to synthesise
    """
    # nothing to say -> nothing to synthesise
    if not text or not text.strip():
        return None

    try:
        # blank voice id falls back to the default british female voice
        voice = voice if voice and voice.strip() else 'bf_emma'

        # kokoro yields (graphemes, phonemes, audio) per chunk;
        # we only keep the audio pieces
        segments = [audio for _, _, audio in kokoro_pipeline(text, voice=voice, speed=1.0)]
        if not segments:
            return None

        # gradio Audio expects (sample_rate, numpy_array); kokoro is 24 kHz
        return (24000, np.concatenate(segments))

    except Exception as e:
        print(f'tts error: {e}')
        return None
201
+
202
+
203
def health_fn():
    """
    reports that the space is warm and which models it serves.
    called via gradio api: /call/health
    """
    served = ['dit-base', 'easyocr', 'kokoro']
    return {'status': 'ok', 'models': served}
209
 
210
 
211
  # ============================================================
212
+ # -- gradio ui + api --
 
 
213
  # ============================================================
214
 
215
# each .click() below carries an api_name, which is what exposes the
# handler at /call/<api_name> for the gradio rest api / js client.
with gr.Blocks(title='AudioLens API') as demo:

    gr.Markdown("""
    ## AudioLens API
    **This space provides the AudioLens backend.**
    The AudioLens PWA calls the API endpoints below using the Gradio client.
    """)

    # -- classify tab --
    with gr.Tab('Classify'):
        classify_image = gr.Image(type='pil', label='document image')
        classify_btn = gr.Button('classify')
        classify_out = gr.JSON(label='result')
        classify_btn.click(
            fn=classify_fn,
            inputs=classify_image,
            outputs=classify_out,
            api_name='classify',
        )

    # -- ocr tab --
    with gr.Tab('OCR'):
        ocr_image = gr.Image(type='pil', label='document image')
        ocr_btn = gr.Button('extract text')
        ocr_out = gr.Textbox(label='extracted text', lines=10)
        ocr_btn.click(
            fn=ocr_fn,
            inputs=ocr_image,
            outputs=ocr_out,
            api_name='ocr',
        )

    # -- speak tab --
    with gr.Tab('Speak'):
        speak_text = gr.Textbox(label='text to speak', lines=5)
        speak_voice = gr.Textbox(label='voice id', value='bf_emma')
        speak_btn = gr.Button('generate speech')
        speak_out = gr.Audio(label='output audio')
        speak_btn.click(
            fn=speak_fn,
            inputs=[speak_text, speak_voice],
            outputs=speak_out,
            api_name='speak',
        )

    # -- health (hidden, api only) --
    # invisible widgets keep the ui clean while still registering
    # the /call/health endpoint with the gradio api
    health_btn = gr.Button('health', visible=False)
    health_out = gr.JSON(visible=False)
    health_btn.click(
        fn=health_fn,
        inputs=[],
        outputs=health_out,
        api_name='health',
    )

    gr.Markdown("""
    ---
    **API endpoints** (use via [@gradio/client](https://www.gradio.app/guides/getting-started-with-the-js-client)):
    - `/call/classify` — document type classification
    - `/call/ocr` — text extraction with preprocessing
    - `/call/speak` — text to speech
    - `/call/health` — check if space is warm

    _This UI is for testing. The AudioLens PWA calls the API directly._
    """)


# launch — hf spaces handles this automatically
if __name__ == '__main__':
    demo.launch(server_name='0.0.0.0', server_port=7860)
requirements.txt CHANGED
@@ -1,8 +1,5 @@
1
- fastapi>=0.110.0,<1.0
2
  gradio>=6.9.0,<7.0
3
  spaces>=0.28.0
4
- uvicorn>=0.27.0,<1.0
5
- python-multipart>=0.0.7
6
  transformers>=4.35.0,<4.50
7
  torch>=2.1.0,<2.5
8
  easyocr>=1.7.0,<1.8
@@ -10,4 +7,4 @@ kokoro>=0.9.0
10
  soundfile>=0.12.0
11
  opencv-python-headless>=4.8.0,<4.11
12
  numpy>=1.24.0,<2.0
13
- Pillow>=10.0.0,<11.0
 
 
1
  gradio>=6.9.0,<7.0
2
  spaces>=0.28.0
 
 
3
  transformers>=4.35.0,<4.50
4
  torch>=2.1.0,<2.5
5
  easyocr>=1.7.0,<1.8
 
7
  soundfile>=0.12.0
8
  opencv-python-headless>=4.8.0,<4.11
9
  numpy>=1.24.0,<2.0
10
+ Pillow>=10.0.0,<11.0