| from simpletransformers.classification import ClassificationModel, ClassificationArgs |
| from typing import Dict, List, Any |
| import pandas as pd |
| import webvtt |
| from datetime import datetime |
| import torch |
| import spacy |
|
|
# Token limit of 200 -- presumably a cap on utterance length; nothing in the
# visible code reads it (TODO confirm where it is consumed).
token_limit = 200

# Load the small English spaCy pipeline once, at import time, and expose its
# tokenizer under a module-level name.
nlp = spacy.load("en_core_web_sm")
tokenizer = nlp.tokenizer
|
|
class EndpointHandler:
    """Callable handler that classifies the captions of a WebVTT transcript.

    The handler wraps a ``simpletransformers`` RoBERTa ``ClassificationModel``
    and, when called with the path of a ``.vtt`` file, returns the model's
    predictions for the captions in that file.
    """

    def __init__(self, path="."):
        """Load the classification model.

        Args:
            path: Model location handed to ``ClassificationModel`` (defaults
                to the current directory).
        """
        print("Loading models...")
        # Run on the GPU only when one is actually available.
        cuda_available = torch.cuda.is_available()
        self.model = ClassificationModel(
            "roberta", path, use_cuda=cuda_available
        )

    def __call__(self, data_file: str) -> List[Dict[str, Any]]:
        """Classify every caption found in a WebVTT file.

        Args:
            data_file: Path of the ``.vtt`` file to read.

        Returns:
            The predictions returned by ``self.model.predict``, one per
            caption and in file order, or an empty list when the file holds
            no captions.
            NOTE(review): the annotation promises a list of dicts, but the
            value is passed through from ``predict`` unchanged -- confirm the
            payload shape callers expect.
        """
        # One utterance per caption cue. Multi-line cues are collapsed onto a
        # single line so the classifier sees plain running text; captions are
        # not filtered, which keeps predictions aligned 1:1 with the cues.
        utterances_list = [
            " ".join(caption.text.split())
            for caption in webvtt.read(data_file)
        ]

        # Nothing to classify: skip the model call rather than sending it an
        # empty batch.
        if not utterances_list:
            return []

        predictions, _ = self.model.predict(utterances_list)
        return predictions
|
|