| import os |
| import torch |
| from transformers import AutoModel, AutoTokenizer |
| from sentence_transformers import SentenceTransformer |
| from sagemaker_inference import content_types, decoder, default_inference_handler, encoder |
|
|
def model_fn(model_dir):
    """Load the SentenceTransformer model persisted under *model_dir*.

    SageMaker calls this once at container start; the returned object is
    passed to every subsequent ``predict_fn`` invocation.
    """
    return SentenceTransformer(model_dir)
|
|
def input_fn(request_body, request_content_type):
    """Deserialize an incoming request body.

    Only ``application/json`` is accepted; any other content type raises
    ``ValueError`` so SageMaker returns a client error to the caller.
    """
    # Guard clause: reject anything that is not JSON up front.
    if request_content_type != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")
    return decoder.decode(request_body, content_types.JSON)
|
|
def predict_fn(input_data, model):
    """Run the deserialized input through the model's ``encode`` method.

    *model* is whatever ``model_fn`` returned (a SentenceTransformer here);
    the result is handed on to ``output_fn`` for serialization.
    """
    return model.encode(input_data)
|
|
def output_fn(prediction, accept):
    """Serialize *prediction* into the response body.

    Only ``application/json`` is supported as an Accept type; anything else
    raises ``ValueError`` so the request fails loudly rather than returning
    an unexpected format.
    """
    # Guard clause: fail fast on an unsupported Accept header.
    if accept != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
    return encoder.encode(prediction, content_types.JSON)
|
|
|
|