import base64
import io
import logging
import os
import pathlib
import typing
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI, Request, UploadFile, File, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field, ValidationError, model_validator, ValidationInfo
from starlette.websockets import WebSocketDisconnect, WebSocketState
|
|
|
|
@asynccontextmanager
async def register_init(app: FastAPI):
    """
    Application lifespan handler: eagerly load the ASR model at startup.

    :param app: the FastAPI application being started.
    :return: async context manager consumed by FastAPI's lifespan machinery.
    """
    # Use logging (not print) so the startup message goes through the same
    # sink as the rest of the application's messages.
    logging.info('Loading ASR model...')
    setup_asr_model()

    yield
|
|
|
|
def register_middleware(app: FastAPI):
    """Attach response compression and permissive CORS middleware to *app*."""
    # Compress large responses transparently.
    app.add_middleware(GZipMiddleware)

    # Wide-open CORS: any origin, any method, any header, credentials allowed.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=['*'],
        allow_methods=['*'],
        allow_headers=['*'],
        allow_credentials=True,
    )
|
|
|
|
def create_app():
    """Build the FastAPI application with lifespan and middleware wired up."""
    application = FastAPI(lifespan=register_init)
    register_middleware(application)
    return application
|
|
|
|
app = create_app()

# The historical variable name 'WHISPER-MODEL-SIZE' contains a dash, which
# POSIX shells cannot export; prefer the portable 'WHISPER_MODEL_SIZE' but
# keep the old name as a backward-compatible fallback.
model_size = os.environ.get(
    'WHISPER_MODEL_SIZE',
    os.environ.get('WHISPER-MODEL-SIZE', 'large-v3'),
)

# Lazily-initialised singleton; populated by setup_asr_model() at startup.
asr_model: typing.Optional[WhisperModel] = None
|
|
|
|
def setup_asr_model():
    """Lazily create the module-level WhisperModel singleton and return it.

    Subsequent calls return the already-loaded model without reloading.
    """
    global asr_model
    if asr_model is not None:
        return asr_model

    logging.info('Loading ASR model...')
    asr_model = WhisperModel(model_size, device='cuda', compute_type='float16')
    logging.info('Load ASR model finished.')
    return asr_model
|
|
|
|
class TranscribeRequestParams(BaseModel):
    """Payload accepted by the /transcribe HTTP and websocket endpoints."""

    uuid: str = Field(title='Request Unique Id.')
    # Either a filesystem path, or base64-encoded audio bytes when
    # using_file_content is true.
    audio_file: str
    language: typing.Literal['en', 'zh',]
    using_file_content: bool

    @model_validator(mode='after')
    def check_audio_file(self):
        """Ensure audio_file points at an existing file when it is a path.

        Raises ValueError (wrapped by pydantic into a ValidationError)
        when the path does not exist.
        """
        if self.using_file_content:
            return self

        if not pathlib.Path(self.audio_file).exists():
            # Raise ValueError (not FileNotFoundError) so pydantic reports a
            # normal ValidationError instead of letting the exception escape
            # to the caller unvalidated.
            raise ValueError('Audio file not exists.')
        # An 'after' model validator must return the model instance; the
        # original fell off the end here and returned None.
        return self
|
|
|
|
@app.post('/transcribe')
async def transcribe_api(
    request: Request,
    obj: TranscribeRequestParams
):
    """Transcribe audio supplied as a filesystem path or base64 content.

    Returns a JSON body echoing the request uuid with either the
    transcribed text or an error message.
    """
    try:
        # Decode base64 payloads into an in-memory stream; otherwise hand the
        # filesystem path straight to the model.
        source = (
            io.BytesIO(base64.b64decode(obj.audio_file))
            if obj.using_file_content
            else obj.audio_file
        )

        segments, _ = asr_model.transcribe(source, language=obj.language)

        # Only the first segment's text is kept; the rest are discarded.
        # NOTE(review): confirm truncating to the first segment is intended.
        first = next(iter(segments), None)
        transcribed_text = first.text if first is not None else ''
    except Exception as exc:
        logging.exception(exc)
        return {
            "if_success": False,
            'uuid': obj.uuid,
            'msg': f'{exc}'
        }

    return {
        "if_success": True,
        'uuid': obj.uuid,
        'transcribed_text': transcribed_text
    }
|
|
|
|
@app.post('/transcribe-file')
async def transcribe_file_api(
    request: Request,
    uuid: str,
    audio_file: typing.Annotated[UploadFile, File()],
    language: typing.Literal['en', 'zh']
):
    """Transcribe an uploaded audio file (multipart form data).

    Returns a JSON body echoing the request uuid with either the
    transcribed text or an error message.
    """
    try:
        # The underlying file-like object of the upload is passed directly
        # to the model.
        segments, _ = asr_model.transcribe(audio_file.file, language=language)

        # Only the first segment's text is kept; the rest are discarded.
        # NOTE(review): confirm truncating to the first segment is intended.
        first = next(iter(segments), None)
        transcribed_text = first.text if first is not None else ''
    except Exception as exc:
        logging.exception(exc)
        return {
            "if_success": False,
            'uuid': uuid,
            'msg': f'{exc}'
        }

    return {
        "if_success": True,
        'uuid': uuid,
        'transcribed_text': transcribed_text
    }
|
|
|
|
@app.websocket('/transcribe')
async def transcribe_ws_api(
    websocket: WebSocket
):
    """Websocket endpoint: receive JSON requests shaped like
    TranscribeRequestParams and send back one JSON response per request.
    """
    await websocket.accept()

    while websocket.client_state == WebSocketState.CONNECTED:
        try:
            request_params = await websocket.receive_json()
        except WebSocketDisconnect:
            # The client can disconnect between the loop-condition check and
            # the receive; exit cleanly instead of letting the exception
            # escape the handler.
            break

        try:
            form = TranscribeRequestParams.model_validate(request_params)
        except ValidationError as exc:
            logging.exception(exc)
            await websocket.send_json({
                "if_success": False,
                'uuid': request_params.get('uuid', ''),
                'msg': f'{exc}'
            })
            continue

        try:
            # Decode base64 payloads into an in-memory stream; otherwise the
            # value is treated as a filesystem path.
            audio_file = form.audio_file
            if form.using_file_content:
                audio_file = io.BytesIO(base64.b64decode(form.audio_file))

            # NOTE(review): this is a blocking call inside an async handler —
            # it stalls the event loop for the duration of transcription.
            segments, _ = asr_model.transcribe(audio_file, language=form.language)

            # Only the first segment's text is kept; the rest are discarded.
            # NOTE(review): confirm truncating to the first segment is intended.
            transcribed_text = ''
            for segment in segments:
                transcribed_text = segment.text
                break
        except Exception as exc:
            logging.exception(exc)
            response_body = {
                "if_success": False,
                'uuid': form.uuid,
                'msg': f'{exc}'
            }
        else:
            response_body = {
                "if_success": True,
                'uuid': form.uuid,
                'transcribed_text': transcribed_text
            }

        await websocket.send_json(response_body)
|
|
|
|
if __name__ == '__main__':
    # Direct-execution entry point; host/port are environment-configurable
    # with sensible defaults.
    listen_host = os.environ.get('HOST', '0.0.0.0')
    listen_port = int(os.environ.get('PORT', 8080))
    uvicorn.run(app, host=listen_host, port=listen_port)
|
|