| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| |
| from pathlib import Path |
| import torch |
| from eo_pl import TashkeelModel as TashkeelModelEO |
| from ed_pl import TashkeelModel as TashkeelModelED |
| from tashkeel_tokenizer import TashkeelTokenizer |
| from utils import remove_non_arabic |
|
|
app = FastAPI()

# Shared tokenizer and runtime configuration for both models.
tokenizer = TashkeelTokenizer()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_seq_len = 1024

# Checkpoints live next to this file under models/.
_models_dir = Path(__file__).parent / 'models'
eo_ckpt_path = _models_dir / 'best_eo_mlm_ns_epoch_193.pt'
ed_ckpt_path = _models_dir / 'best_ed_mlm_ns_epoch_178.pt'


def _load_tashkeel_model(model_cls, ckpt_path, n_layers):
    """Build *model_cls*, load its checkpoint, and put it in eval mode on `device`."""
    model = model_cls(tokenizer, max_seq_len=max_seq_len,
                      n_layers=n_layers, learnable_pos_emb=False)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval().to(device)
    return model


# Encoder-only (6 layers) and encoder-decoder (3 layers) tashkeel models.
eo_model = _load_tashkeel_model(TashkeelModelEO, eo_ckpt_path, 6)
ed_model = _load_tashkeel_model(TashkeelModelED, ed_ckpt_path, 3)
|
|
|
|
class CattRequest(BaseModel):
    """Request body for the POST /catt diacritization endpoint."""

    # Input text; the handler strips non-Arabic characters before inference.
    text: str
    # Model selector: 'Encoder-Only' routes to the EO model; any other value
    # falls through to the encoder-decoder model (see infer_catt).
    model_type: str
|
|
|
|
| |
|
|
|
|
| @app.post("/catt") |
| def infer_catt(request: CattRequest): |
| try: |
| input_text = remove_non_arabic(request.text) |
| batch_size = 16 |
| verbose = True |
| |
| if request.model_type == 'Encoder-Only': |
| output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose) |
| else: |
| output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose) |
| |
| return {"result": output_text[0]} |
| except Exception as e: |
| raise HTTPException(status_code=500, detail=str(e)) |
|
|
|
|