from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
class EndpointHandler():
    """Inference endpoint handler: returns mask-weighted mean GPT-2 logits per document.

    Loads the public ``gpt2`` checkpoint in float16 and, for each input text,
    computes the mean of the output logits over non-padding token positions.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is ignored; the handler always loads the public
        # "gpt2" checkpoint — confirm that is intentional for this endpoint.
        # Fall back to CPU so the handler also works on hosts without a GPU
        # (the original hard-coded .cuda() and crashed there).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            "gpt2", torch_dtype=torch.float16, output_hidden_states=True
        ).to(self.device)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, list]:
        """Run GPT-2 over each document and average its logits.

        Args:
            data: request payload; ``data["inputs"]`` is an iterable of text
                documents (falls back to iterating ``data`` itself when the
                key is absent, matching the original behavior).
        Return:
            ``{"logits": [...]}`` — one entry per document: the token-mask
            weighted mean of the model's logits over the sequence dimension,
            converted to nested Python lists.
        """
        inputs = data.pop("inputs", data)
        all_logits = []

        for doc in inputs:
            # BUG FIX: the original tokenized the whole `inputs` collection on
            # every iteration instead of the current document, so every entry
            # of the result was identical (and work was duplicated N times).
            tokenized = self.tokenizer(
                doc,
                return_tensors="pt",
                truncation=True,
                max_length=512,
            )
            token_ids = tokenized.input_ids.to(self.device)
            token_mask = tokenized.attention_mask.to(self.device)
            with torch.no_grad():
                out = self.model(token_ids, attention_mask=token_mask)
            # Mean over sequence positions, counting only non-padding tokens.
            meaned_logits = (out.logits * token_mask.unsqueeze(-1)).sum(1) / token_mask.sum(
                1
            ).unsqueeze(-1)
            # (Removed dead code: a sorted-logits mean was computed here but
            # never used or returned.)
            all_logits.append(meaned_logits.cpu().numpy().tolist())

        return {"logits": all_logits}