Commit ·
45542e5
1
Parent(s): 365abf6
replace main with tashkeel model
Browse files
main.py
CHANGED
|
@@ -1,7 +1,57 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
app = FastAPI()
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
#from shakkala import Shakkala
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import torch
|
| 6 |
+
from eo_pl import TashkeelModel as TashkeelModelEO
|
| 7 |
+
from ed_pl import TashkeelModel as TashkeelModelED
|
| 8 |
+
from tashkeel_tokenizer import TashkeelTokenizer
|
| 9 |
+
from utils import remove_non_arabic
|
| 10 |
|
| 11 |
app = FastAPI()
|
| 12 |
|
| 13 |
+
# Load CaTT models
|
| 14 |
+
tokenizer = TashkeelTokenizer()
|
| 15 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 16 |
+
max_seq_len = 1024
|
| 17 |
+
|
| 18 |
+
eo_ckpt_path = Path(__file__).parent / 'models/best_eo_mlm_ns_epoch_193.pt'
|
| 19 |
+
ed_ckpt_path = Path(__file__).parent / 'models/best_ed_mlm_ns_epoch_178.pt'
|
| 20 |
+
|
| 21 |
+
# Load Encoder-Only model
|
| 22 |
+
eo_model = TashkeelModelEO(tokenizer, max_seq_len=max_seq_len,
|
| 23 |
+
n_layers=6, learnable_pos_emb=False)
|
| 24 |
+
eo_model.load_state_dict(torch.load(eo_ckpt_path, map_location=device))
|
| 25 |
+
eo_model.eval().to(device)
|
| 26 |
+
|
| 27 |
+
# Load Encoder-Decoder model
|
| 28 |
+
ed_model = TashkeelModelED(tokenizer, max_seq_len=max_seq_len,
|
| 29 |
+
n_layers=3, learnable_pos_emb=False)
|
| 30 |
+
ed_model.load_state_dict(torch.load(ed_ckpt_path, map_location=device))
|
| 31 |
+
ed_model.eval().to(device)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CattRequest(BaseModel):
|
| 35 |
+
text: str
|
| 36 |
+
model_type: str # "Encoder-Only" or "Encoder-Decoder"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@app.post("/catt")
|
| 43 |
+
def infer_catt(request: CattRequest):
|
| 44 |
+
try:
|
| 45 |
+
input_text = remove_non_arabic(request.text)
|
| 46 |
+
batch_size = 16
|
| 47 |
+
verbose = True
|
| 48 |
+
|
| 49 |
+
if request.model_type == 'Encoder-Only':
|
| 50 |
+
output_text = eo_model.do_tashkeel_batch([input_text], batch_size, verbose)
|
| 51 |
+
else:
|
| 52 |
+
output_text = ed_model.do_tashkeel_batch([input_text], batch_size, verbose)
|
| 53 |
+
|
| 54 |
+
return {"result": output_text[0]}
|
| 55 |
+
except Exception as e:
|
| 56 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 57 |
+
|