| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
def model_fn(model_dir):
    """Load the sequence-classification model and its tokenizer from *model_dir*.

    SageMaker calls this once at container start-up; the returned pair is
    passed unchanged to ``predict_fn`` on every request.

    Args:
        model_dir: Path to the directory holding the saved HF artifacts.

    Returns:
        Tuple of ``(model, tokenizer)``.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return loaded_model, loaded_tokenizer
def predict_fn(data, model_and_tokenizer):
    """Run a classification forward pass on an inference request.

    Args:
        data: Request payload. Expected to be a dict with an ``"inputs"`` key
            holding a string (or list of strings for a batch); if the key is
            absent the payload itself is used as the text input.
        model_and_tokenizer: The ``(model, tokenizer)`` pair from ``model_fn``.

    Returns:
        ``{"predicted_class": int}`` for a single input, or
        ``{"predicted_class": list[int]}`` when a batch was supplied.
    """
    model, tokenizer = model_and_tokenizer
    # Use get() rather than pop(): pop() mutated the caller's request dict,
    # which is an unexpected side effect for an inference handler.
    inputs = data.get("inputs", data) if isinstance(data, dict) else data
    # padding=True lets a list of strings be batched; truncation keeps long
    # texts within the model's max sequence length.
    tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        output = model(**tokenized)
    # argmax over the class axis; -1 is the logits' last dimension.
    predicted = torch.argmax(output.logits, dim=-1)
    if predicted.numel() == 1:
        # Single input: keep the original scalar contract (.item() would
        # raise on a batch, which is why the batch branch below exists).
        return {"predicted_class": predicted.item()}
    return {"predicted_class": predicted.tolist()}