from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch


def model_fn(model_dir):
    """Load the tokenizer and sequence-classification model stored in *model_dir*.

    Args:
        model_dir: Directory (or hub id) accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple, in the order ``predict_fn`` unpacks it.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    """Classify the text(s) in *data* and return the predicted class index.

    Args:
        data: Either a dict holding the text under the ``"inputs"`` key, or the
            text itself (a string or a list of strings).
        model_and_tokenizer: The ``(model, tokenizer)`` tuple from ``model_fn``.

    Returns:
        ``{"predicted_class": int}`` when exactly one text is classified, or
        ``{"predicted_class": [int, ...]}`` for a batch of several texts.

    Raises:
        ValueError: If *data* is a dict that has no ``"inputs"`` key.
    """
    model, tokenizer = model_and_tokenizer

    # Read the payload without ``pop`` so the caller's dict is not mutated,
    # and accept a bare string / list of strings as the inputs directly.
    if isinstance(data, dict):
        if "inputs" not in data:
            raise ValueError('Request payload must contain an "inputs" key')
        inputs = data["inputs"]
    else:
        inputs = data

    tokenized = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(**tokenized)

    # One class index per input row; calling ``.item()`` directly on the
    # argmax would raise whenever more than one text is sent.
    predictions = torch.argmax(output.logits, dim=1).tolist()
    # Keep the scalar result for single-text requests so existing callers
    # that expect an int continue to work.
    predicted_class = predictions[0] if len(predictions) == 1 else predictions
    return {"predicted_class": predicted_class}