| import argparse |
| import markdown2 |
| import sys |
| import uvicorn |
|
|
| from pathlib import Path |
| from typing import Union, Optional |
|
|
| from fastapi import FastAPI |
| from pydantic import BaseModel, Field |
| from fastapi.responses import HTMLResponse |
| from tclogger import logger, OSEnver |
|
|
| from transforms.embed import JinaAIEmbedder, JinaAIOnnxEmbedder |
| from configs.constants import AVAILABLE_MODELS |
|
|
# Service settings (app name, version, server, port) are read from
# <parent of this file's directory>/configs/info.json.
info_path = Path(__file__).parents[1].joinpath("configs", "info.json")
ENVER = OSEnver(info_path)
|
|
|
|
class EmbeddingApp:
    """FastAPI application that serves text embeddings over HTTP.

    Constructing an instance creates the FastAPI app (interactive docs served
    at ``/``), instantiates the embedder, and registers the ``/models``,
    ``/embedding`` and ``/readme`` routes.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title=ENVER["app_name"],
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=ENVER["version"],
        )
        self.embedder = JinaAIOnnxEmbedder()
        self.setup_routes()

    def get_available_models(self):
        """Return the model names listed in ``AVAILABLE_MODELS``."""
        return AVAILABLE_MODELS

    def get_readme(self):
        """Read the project's ``README.md`` and return it rendered as HTML."""
        readme_path = Path(__file__).parents[1] / "README.md"
        readme_str = readme_path.read_text(encoding="utf-8")
        # The markdown2 extra is named "tables"; the previous "table" spelling
        # is not a known extra, so Markdown tables were left unrendered.
        return markdown2.markdown(
            readme_str,
            extras=["tables", "fenced-code-blocks", "highlightjs-lang"],
        )

    class CalcEmbeddingPostItem(BaseModel):
        """Request body of ``POST /embedding``."""

        # NOTE(review): the annotation does not allow None but the default is
        # None, so a request without "text" hands None to the embedder --
        # confirm whether this field should be required instead.
        text: Union[str, list[str]] = Field(
            default=None,
            # `description` is the Field keyword for schema text; `summary`
            # is not a Field parameter.
            description="Input text(s) to embed",
        )
        model: Optional[str] = Field(
            default=AVAILABLE_MODELS[0],
            description="Embedding model name",
        )

    def calc_embedding(self, item: CalcEmbeddingPostItem):
        """Embed ``item.text`` with the app's embedder.

        Returns:
            The single embedding vector when exactly one embedding is
            produced, otherwise the list of embedding vectors (an empty list
            when the embedder produced none).
        """
        logger.note(f"> Encoding text: [{item.text}]", end=" ")
        # NOTE(review): item.model is accepted but never consulted here; the
        # embedder created in __init__ is always used -- confirm intended.
        embeddings = self.embedder.encode(item.text).tolist()
        # Guard the dimension lookup so an empty result does not raise
        # IndexError before a response can be returned.
        embedding_dim = len(embeddings[0]) if embeddings else 0
        logger.success(f"[{embedding_dim}]")
        if len(embeddings) == 1:
            return embeddings[0]
        return embeddings

    def setup_routes(self):
        """Register the HTTP routes on ``self.app``."""
        self.app.get(
            "/models",
            summary="Get available models",
        )(self.get_available_models)

        self.app.post(
            "/embedding",
            summary="Calculate embedding for input text(s)",
        )(self.calc_embedding)

        # The README page is served as HTML and kept out of the OpenAPI schema.
        self.app.get(
            "/readme",
            summary="README of Embed API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)
|
|
|
|
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the Embed API server.

    Registers ``-s/--server`` and ``-p/--port`` with defaults taken from
    ``ENVER`` and parses ``sys.argv[1:]`` during construction; the parsed
    namespace is stored on ``self.args``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        default_server = ENVER["server"]
        default_port = ENVER["port"]

        self.add_argument(
            "-s",
            "--server",
            type=str,
            default=default_server,
            help=f"Server IP ({default_server}) for Embed API",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=default_port,
            help=f"Server Port ({default_port}) for Embed API",
        )

        # Parse immediately so callers can simply read `ArgParser().args`.
        cli_argv = sys.argv[1:]
        self.args = self.parse_args(cli_argv)
|
|
|
|
# Module-level application object, built at import time; the "__main__:app"
# import string below refers to this name.
app = EmbeddingApp().app


if __name__ == "__main__":
    # Run as a script: take host/port from the command line (defaults come
    # from ENVER) and start uvicorn.
    args = ArgParser().args
    server_options = {"host": args.server, "port": args.port}
    uvicorn.run("__main__:app", **server_options)
|
|
| |
|
|