import torch
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
# ASGI application object; endpoints are registered on it via decorators.
app = FastAPI()

# Hugging Face Hub checkpoint served by this app.
model_name: str = "grammarly/coedit-xl"

# Tokenizer and seq2seq weights are loaded once, at import time, and then
# reused by every request. Importing this module therefore loads the
# checkpoint (downloading it first if it is not already cached locally).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
class InputText(BaseModel):
    """JSON request body for ``POST /correct``: ``{"text": "..."}``."""

    # Raw text that is fed, unmodified, to the tokenizer.
    text: str
|
|
def _generate_correction(input_text: str) -> str:
    """Tokenize *input_text*, run ``model.generate`` and decode the first output.

    This is synchronous and compute-bound, so it must run off the event loop.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


@app.post("/correct")
async def correct_text(data: InputText):
    """Return the model's edited version of ``data.text``.

    Generation is capped at 256 new tokens, and special tokens are stripped
    from the decoded output.

    Returns:
        ``{"corrected": <decoded text>}``
    """
    # Tokenization and generation are blocking calls. If they ran directly
    # inside this ``async def`` handler, they would freeze the event loop and
    # no other request could be served until generation finished. Hand them
    # to the worker thread pool instead (the same pool FastAPI uses for plain
    # ``def`` routes).
    # NOTE(review): concurrent requests now generate in parallel threads that
    # share one ``model``/``tokenizer``; cap concurrency if memory is tight.
    result = await run_in_threadpool(_generate_correction, data.text)
    return {"corrected": result}