| import cv2 |
| import io |
| import numpy as np |
| from PIL import Image |
|
|
| import pytesseract |
|
|
| from fastapi import FastAPI, UploadFile, File |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from mltu.inferenceModel import OnnxInferenceModel |
| from mltu.utils.text_utils import ctc_decoder |
| from mltu.transformers import ImageResizer |
| from mltu.configs import BaseModelConfigs |
|
|
| from textblob import TextBlob |
| from happytransformer import HappyTextToText, TTSettings |
|
|
|
|
| from transformers import AutoTokenizer, T5ForConditionalGeneration |
| from pydantic import BaseModel |
|
|
# CoEdIT checkpoint shared by the tokenizer and the seq2seq model that the
# /chat_prompt/ endpoint uses; from_pretrained resolves it via the local
# Hugging Face cache or the Hub.
_COEDIT_CHECKPOINT = "grammarly/coedit-large"
tokenizer = AutoTokenizer.from_pretrained(_COEDIT_CHECKPOINT)
chatModel = T5ForConditionalGeneration.from_pretrained(_COEDIT_CHECKPOINT)

# Handwriting-model settings; `model_path` and `vocab` are read from it when
# the ImageToWordModel is built.
# NOTE(review): "./configs.yaml" is presumably resolved against the process
# working directory, not this file's directory — confirm how the app is launched.
configs = BaseModelConfigs.load("./configs.yaml")

# NOTE(review): no endpoint in this module references beam_settings; kept in
# case other code imports it.
beam_settings = TTSettings(max_length=100, min_length=1, num_beams=5)
|
|
app = FastAPI()

# NOTE(review): the CORS spec does not allow a wildcard origin together with
# credentials, and FastAPI's docs say origins must be listed explicitly when
# allow_credentials=True. Replace "*" with the real front-end origin(s)
# before exposing this service publicly.
origins = ["*"]

# Cross-origin policy applied to every route.
_cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
class ImageToWordModel(OnnxInferenceModel):
    """ONNX handwriting-recognition model that decodes an image into text.

    All constructor arguments except ``char_list`` (e.g. ``model_path=``) are
    forwarded to ``OnnxInferenceModel``. ``char_list`` is the vocabulary passed
    to ``ctc_decoder`` when turning model output into a string.
    """

    def __init__(self, char_list, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.char_list = char_list

    def predict(self, image: np.ndarray):
        """Return the text recognised in a single image.

        Args:
            image: Image array; the endpoint in this module passes the result
                of ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

        Returns:
            The first string produced by ``ctc_decoder`` for the one-image batch.
        """
        # NOTE(review): assumes input_shape excludes the batch axis and starts
        # with (height, width); the resizer is given them as (width, height).
        height, width = self.input_shape[:2]
        resized = ImageResizer.resize_maintaining_aspect_ratio(image, width, height)

        # Prepend a batch axis of size 1 and cast to float32 for the session.
        batch = np.expand_dims(resized, axis=0).astype(np.float32)

        # Run the session on its single named input and decode the first output.
        outputs = self.model.run(None, {self.input_name: batch})
        return ctc_decoder(outputs[0], self.char_list)[0]
|
|
|
|
# Handwriting recognizer built from the model path and vocabulary in configs.yaml.
model = ImageToWordModel(char_list=configs.vocab, model_path=configs.model_path)

# Most recent text produced by either extraction endpoint; /chat_prompt/ reads it.
# NOTE(review): this is a single process-wide value shared by every client, so
# concurrent users overwrite each other's text. Consider per-session storage.
extracted_text = ""
|
|
@app.post("/extract_handwritten_text/")
async def predict_text(image: UploadFile):
    """Recognise handwritten text in an uploaded image with the ONNX model.

    On success the result is also stored in the module-level ``extracted_text``
    so that ``/chat_prompt/`` can edit it afterwards.

    Args:
        image: Uploaded file containing encoded image bytes (PNG, JPEG, ...).

    Returns:
        ``{"text": <recognised text>}`` on success, or the same
        ``{"error": ...}`` payload that ``/extract_text/`` returns when the
        upload is empty or cannot be decoded as an image. ``extracted_text`` is
        left unchanged in the error case.
    """
    global extracted_text

    raw = await image.read()
    # cv2.imdecode asserts on an empty buffer and returns None for bytes it
    # cannot decode; previously both surfaced as a 500 from model.predict().
    if not raw:
        return {"error": "Invalid file format. Please upload an image."}
    decoded = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if decoded is None:
        return {"error": "Invalid file format. Please upload an image."}

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop for its duration.
    extracted_text = model.predict(decoded)
    return {"text": extracted_text}
|
|
|
|
@app.post("/extract_text/")
async def extract_text_from_image(image: UploadFile):
    """Run Tesseract OCR on an uploaded image of printed text.

    On success the result is also stored in the module-level ``extracted_text``
    so that ``/chat_prompt/`` can edit it afterwards.

    Args:
        image: Uploaded file; it must declare an ``image/*`` content type and
            contain data Pillow can decode.

    Returns:
        ``{"text": <ocr result>}`` on success, otherwise
        ``{"error": "Invalid file format. Please upload an image."}``.
        ``extracted_text`` is left unchanged in the error case.
    """
    global extracted_text

    invalid_format = {"error": "Invalid file format. Please upload an image."}

    # content_type is optional on an upload; a missing header must be treated
    # as "not an image" rather than raising AttributeError on None.
    if not (image.content_type or "").startswith("image/"):
        return invalid_format

    image_bytes = await image.read()
    try:
        img = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so corrupt or truncated data is
        # reported as a bad upload instead of failing later inside pytesseract.
        img.load()
    except OSError:
        # Includes PIL.UnidentifiedImageError (an OSError subclass), raised for
        # empty or unrecognised bytes despite an image/* content type.
        return invalid_format

    extracted_text = pytesseract.image_to_string(img)
    return {"text": extracted_text}
|
|
class ChatPrompt(BaseModel):
    # Request body for /chat_prompt/. Documented with comments rather than a
    # docstring because pydantic would publish a docstring as the OpenAPI
    # schema description.
    # prompt: text placed before ": " and the most recently extracted text to
    # form the model input — presumably an editing instruction for the CoEdIT
    # model (confirm with the front end).
    prompt: str
|
|
@app.post("/chat_prompt/")
async def chat_prompt(request: ChatPrompt):
    """Edit the most recently extracted text with the CoEdIT model.

    The model input is ``"<prompt>: <extracted_text>"``, where ``extracted_text``
    is whatever either extraction endpoint last stored (initially ``""``).

    Args:
        request: Body whose ``prompt`` is prepended to the extracted text.

    Returns:
        ``{"edited_text": <decoded model output>}``; generation is capped at
        256 tokens.
    """
    import logging  # local import keeps this handler self-contained

    # Only reading the module-level name here, so no ``global`` is needed.
    input_text = f"{request.prompt}: {extracted_text}"
    # Log at DEBUG instead of print() so user-supplied text is not written to
    # stdout unconditionally; enable via logging configuration when debugging.
    logging.getLogger(__name__).debug("chat_prompt input: %s", input_text)

    # NOTE(review): tokenization and generation run synchronously inside this
    # async handler and block the event loop until the model finishes.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    outputs = chatModel.generate(input_ids, max_length=256)
    edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"edited_text": edited_text}
|
|