| import torch |
| import transformers |
| from transformers import AutoTokenizer, AutoModel |
| from numpy import ndarray |
| import numpy as np |
| from .similarity import subspace_johnson, subspace_bert_score, vanilla_bert_score |
|
|
|
|
class MySimilarity:
    """Base class for sentence-similarity scorers backed by a HuggingFace encoder.

    Subclasses override ``__call__`` to turn the token-level hidden states
    produced by :meth:`encode` into a similarity score.
    """

    def __init__(self, device='cpu', model_name_or_path='bert-base-uncased'):
        """Load tokenizer and encoder and move the model to *device*.

        Args:
            device: torch device string, e.g. ``'cpu'`` or ``'cuda'``.
            model_name_or_path: HuggingFace model id or local checkpoint path.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()  # inference only; disables dropout
        self.model.to(device)
        self.max_length = 128  # truncation length passed to the tokenizer

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Score the similarity of two sentences (or sentence batches).

        Abstract: subclasses must override. The previous `pass` silently
        returned ``None``; raising makes misuse of the base class explicit.
        """
        raise NotImplementedError("Subclasses must implement __call__")

    def encode(self, sentence, return_numpy=False, batch_size=12):
        """Encode sentence(s) into token-level hidden states.

        Args:
            sentence: a single string or a list of strings.
            return_numpy: if True, return a numpy array instead of a tensor.
            batch_size: number of sentences encoded per forward pass.

        Returns:
            Tensor (or ndarray) of shape ``(num_sentences, seq_len, hidden_dim)``
            where ``seq_len`` is the longest padded length over all batches.
        """
        if isinstance(sentence, str):
            sentence = [sentence]

        embedding_list = []
        with torch.no_grad():
            for start in range(0, len(sentence), batch_size):
                inputs = self.tokenizer(
                    sentence[start:start + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                embedding_list.append(outputs.last_hidden_state.cpu())

        # Bug fix: ``padding=True`` pads each batch only to its own longest
        # sequence, so batches can disagree on seq_len and torch.cat would
        # raise a size-mismatch error. Zero-pad every batch up to the global
        # maximum before concatenating.
        # NOTE(review): padded positions are zero-filled; assumes downstream
        # scorers mask or tolerate padding positions — confirm.
        max_len = max(e.shape[1] for e in embedding_list)
        embedding_list = [
            torch.nn.functional.pad(e, (0, 0, 0, max_len - e.shape[1]))
            if e.shape[1] < max_len else e
            for e in embedding_list
        ]
        embeddings = torch.cat(embedding_list, 0)

        # embeddings is always a torch tensor here, so the old
        # ``isinstance(embeddings, ndarray)`` guard was dead code.
        if return_numpy:
            return embeddings.numpy()
        return embeddings
|
|
|
|
class SubspaceJohnsonSimilarity(MySimilarity):
    """Similarity scorer that delegates to ``subspace_johnson``."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``subspace_johnson``."""
        return subspace_johnson(
            self.encode(sentence1), self.encode(sentence2), weight
        )
| |
| |
class SubspaceBERTScore(MySimilarity):
    """Similarity scorer that delegates to ``subspace_bert_score``."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``subspace_bert_score``."""
        return subspace_bert_score(
            self.encode(sentence1), self.encode(sentence2), weight
        )
| |
| |
class VanillaBERTScore(MySimilarity):
    """Similarity scorer that delegates to ``vanilla_bert_score``."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``vanilla_bert_score``."""
        return vanilla_bert_score(
            self.encode(sentence1), self.encode(sentence2), weight
        )