| import torch |
| from transformers import AutoModel |
|
|
|
|
def build_text_encoder(config):
    """Instantiate the pretrained text encoder selected by *config*.

    Args:
        config: configuration object exposing ``model_type`` (str) and
            ``pretrained_name_or_path`` (checkpoint name or local path)
            attributes.

    Returns:
        The loaded Hugging Face ``AutoModel`` instance.

    Raises:
        NotImplementedError: if ``config.model_type`` is not a supported
            encoder type (currently only ``"mpnet"``).
    """
    if config.model_type == "mpnet":
        model = AutoModel.from_pretrained(config.pretrained_name_or_path)
    else:
        # Name the rejected value so misconfigurations are easy to diagnose.
        raise NotImplementedError(
            f"Unsupported model_type: {config.model_type!r}"
        )
    return model
|
|
|
|
| |
def mean_pooling(model_output, attention_mask):
    """Mean-pool token embeddings over the sequence axis, ignoring padding.

    Args:
        model_output: encoder output; index 0 holds the token embeddings
            (assumed shape ``(batch, seq, dim)`` — standard HF convention).
        attention_mask: mask of shape ``(batch, seq)``; nonzero entries
            mark real (non-padding) tokens.

    Returns:
        Tensor of shape ``(batch, dim)`` with the masked mean embedding
        per sequence.
    """
    embeddings = model_output[0]
    # Broadcast the mask across the embedding dimension as floats.
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    summed = torch.sum(embeddings * mask, 1)
    # Clamp the token counts so fully-masked rows do not divide by zero.
    counts = torch.clamp(mask.sum(1), min=1e-9)
    return summed / counts
|
|