| """ |
| Tests the correct computation of evaluation scores from BinaryClassificationEvaluator |
| """ |
| from sentence_transformers import SentenceTransformer, evaluation, util, losses, LoggingHandler |
| import logging |
| import unittest |
| from sklearn.metrics import f1_score, accuracy_score |
| import numpy as np |
| import gzip |
| import csv |
| from sentence_transformers import InputExample |
| from torch.utils.data import DataLoader |
| import os |
|
|
class EvaluatorTest(unittest.TestCase):
    """Checks that evaluator scores from sentence_transformers match
    independently computed reference values (sklearn) or sensible bounds."""

    def test_BinaryClassificationEvaluator_find_best_f1_and_threshold(self):
        """The F1 score reported for the chosen threshold must equal sklearn's
        f1_score when that same threshold is applied to the predictions."""
        # Local, seeded generator: reproducible runs without mutating
        # numpy's global RNG state.
        rng = np.random.default_rng(42)
        y_true = rng.integers(0, 2, 1000)
        y_pred_cosine = rng.standard_normal(1000)
        best_f1, best_precision, best_recall, threshold = evaluation.BinaryClassificationEvaluator.find_best_f1_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True
        )
        # Re-derive hard labels from the returned threshold and compare
        # against sklearn's reference implementation.
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_f1score = f1_score(y_true, y_pred_labels)
        self.assertAlmostEqual(best_f1, sklearn_f1score, delta=1e-6)

    def test_BinaryClassificationEvaluator_find_best_accuracy_and_threshold(self):
        """The accuracy reported for the chosen threshold must equal sklearn's
        accuracy_score when that same threshold is applied."""
        rng = np.random.default_rng(42)
        y_true = rng.integers(0, 2, 1000)
        y_pred_cosine = rng.standard_normal(1000)
        max_acc, threshold = evaluation.BinaryClassificationEvaluator.find_best_acc_and_threshold(
            y_pred_cosine, y_true, high_score_more_similar=True
        )
        y_pred_labels = [1 if pred >= threshold else 0 for pred in y_pred_cosine]
        sklearn_acc = accuracy_score(y_true, y_pred_labels)
        self.assertAlmostEqual(max_acc, sklearn_acc, delta=1e-6)

    def test_LabelAccuracyEvaluator(self):
        """LabelAccuracyEvaluator runs end-to-end on 100 AllNLI train pairs
        and yields above-chance accuracy (chance is ~1/3 for three labels).

        NOTE(review): downloads a pretrained model and the AllNLI dataset on
        first run, so this test requires network access.
        """
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')

        nli_dataset_path = 'datasets/AllNLI.tsv.gz'
        if not os.path.exists(nli_dataset_path):
            util.http_get('https://sbert.net/datasets/AllNLI.tsv.gz', nli_dataset_path)

        label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
        dev_samples = []
        # Take the first 100 'train' pairs as a small evaluation set.
        with gzip.open(nli_dataset_path, 'rt', encoding='utf8') as fIn:
            reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
            for row in reader:
                if row['split'] == 'train':
                    label_id = label2int[row['label']]
                    dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label_id))
                    if len(dev_samples) >= 100:
                        break

        # LabelAccuracyEvaluator needs the softmax classification head, which
        # lives inside the SoftmaxLoss wrapper rather than in the model.
        train_loss = losses.SoftmaxLoss(
            model=model,
            sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
            num_labels=len(label2int),
        )

        dev_dataloader = DataLoader(dev_samples, shuffle=False, batch_size=16)
        evaluator = evaluation.LabelAccuracyEvaluator(dev_dataloader, softmax_model=train_loss)
        acc = evaluator(model)
        self.assertGreater(acc, 0.2)

    def test_ParaphraseMiningEvaluator(self):
        """ParaphraseMiningEvaluator must recover the two known paraphrase
        pairs (near-identical sentences) with average precision close to 1.

        NOTE(review): downloads a pretrained model on first run, so this
        test requires network access.
        """
        model = SentenceTransformer('paraphrase-distilroberta-base-v1')
        sentences = {0: "Hello World", 1: "Hello World!", 2: "The cat is on the table", 3: "On the table the cat is"}
        data_eval = evaluation.ParaphraseMiningEvaluator(sentences, [(0, 1), (2, 3)])
        score = data_eval(model)
        self.assertGreater(score, 0.99)