# splits/subspace/tool.py
# Author: Eylon Caplan
# Deploy app code targeting HF Storage Bucket (commit ddb7b62)
import torch
import transformers
from transformers import AutoTokenizer, AutoModel
from numpy import ndarray
import numpy as np
from .similarity import subspace_johnson, subspace_bert_score, vanilla_bert_score
class MySimilarity:
    """Base class for sentence-similarity scorers backed by a HuggingFace encoder.

    Loads a tokenizer/model pair once and exposes :meth:`encode` to turn
    sentences into per-token hidden states. Subclasses implement
    ``__call__`` with a specific similarity measure over those states.
    """

    def __init__(self, device='cpu', model_name_or_path='bert-base-uncased'):
        """Load the tokenizer and model, move the model to *device* in eval mode.

        Args:
            device: torch device string the model runs on (e.g. ``'cpu'``, ``'cuda'``).
            model_name_or_path: HuggingFace model identifier or local checkpoint path.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.model.to(device)
        # Truncation length applied during tokenization in encode().
        self.max_length = 128

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Score similarity between two inputs; implemented by subclasses."""
        # Fail loudly instead of silently returning None: the base class has
        # no similarity measure of its own.
        raise NotImplementedError("Subclasses must implement __call__")

    def encode(self, sentence, return_numpy=False, batch_size=12):
        """Encode sentence(s) into per-token hidden states.

        Args:
            sentence: a single string or a list of strings.
            return_numpy: if True, return a numpy array instead of a torch tensor.
            batch_size: number of sentences tokenized/encoded per forward pass.

        Returns:
            Tensor (or ndarray) of shape (num_sentences, max_seq_len, hidden_dim),
            where max_seq_len is the longest padded length across all batches.
        """
        if isinstance(sentence, str):
            sentence = [sentence]
        embedding_list = []
        with torch.no_grad():
            # Ceil division: one extra batch for any remainder.
            total_batch = (len(sentence) + batch_size - 1) // batch_size
            for batch_id in range(total_batch):
                batch = sentence[batch_id * batch_size:(batch_id + 1) * batch_size]
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs, return_dict=True)
                # Move to CPU so batches can be concatenated regardless of device.
                embedding_list.append(outputs.last_hidden_state.cpu())
        # padding=True pads each batch only to its own longest sentence, so
        # batches may have different sequence lengths; right-pad the token
        # axis to a common length before concatenating (the original
        # torch.cat would raise on mismatched dims for multi-batch input).
        max_len = max(t.size(1) for t in embedding_list)
        embedding_list = [
            torch.nn.functional.pad(t, (0, 0, 0, max_len - t.size(1)))
            for t in embedding_list
        ]
        embeddings = torch.cat(embedding_list, 0)
        if return_numpy:
            return embeddings.numpy()
        return embeddings
class SubspaceJohnsonSimilarity(MySimilarity):
    """Similarity scorer based on the Johnson subspace measure."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``subspace_johnson``."""
        return subspace_johnson(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )
class SubspaceBERTScore(MySimilarity):
    """Similarity scorer based on the subspace variant of BERTScore."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``subspace_bert_score``."""
        return subspace_bert_score(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )
class VanillaBERTScore(MySimilarity):
    """Similarity scorer based on standard (vanilla) BERTScore."""

    def __call__(self, sentence1, sentence2, weight="L2"):
        """Encode both inputs and score them with ``vanilla_bert_score``."""
        return vanilla_bert_score(
            self.encode(sentence1),
            self.encode(sentence2),
            weight,
        )