| """ |
| Given a tab-separated file (.tsv) with parallel sentences, where the second column is the translation of the sentence in the first column, for example, in the format: |
| src1 trg1 |
| src2 trg2 |
| ... |
| |
| where trg_i is the translation of src_i. |
| |
| Given src_i, the TranslationEvaluator checks which trg_j has the highest similarity using cosine similarity. If i == j, we assume |
| a match, i.e., the correct translation has been found for src_i out of all possible target sentences. |
| |
| It then computes an accuracy over all possible source sentences src_i. It likewise computes the accuracy for the other direction, i.e., finding src_i given trg_i. |
| |
| A high accuracy score indicates that the model is able to find the correct translation out of a large pool of sentences. |
| |
| Usage: |
| python <this_script>.py [model_name_or_path] [parallel-file1] [parallel-file2] ... |
| |
| For example: |
| python <this_script>.py distiluse-base-multilingual-cased TED2020-en-de.tsv.gz |
| |
| See the training_multilingual/get_parallel_data_...py scripts for getting parallel sentence data from different sources |
| """ |
|
|
| from sentence_transformers import SentenceTransformer, evaluation, LoggingHandler |
| import sys |
| import gzip |
| import os |
| import logging |
|
|
|
|
# Configure the root logger: timestamped records at INFO level and above,
# emitted through the LoggingHandler imported from sentence_transformers.
logging.basicConfig(
    handlers=[LoggingHandler()],
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(message)s',
)

# Module-level logger used for this script's own progress messages.
logger = logging.getLogger(__name__)
|
|
# Batch size handed to the evaluator when encoding sentences.
inference_batch_size = 32


def _open_text(filepath: str):
    """Open *filepath* for reading as UTF-8 text.

    Paths ending in ``.gz`` are opened through ``gzip``; everything else is
    opened as a plain text file.
    """
    if filepath.endswith('.gz'):
        return gzip.open(filepath, 'rt', encoding='utf8')
    return open(filepath, 'r', encoding='utf8')


def read_parallel_sentences(filepath: str):
    """Read aligned sentence pairs from a tab-separated file.

    The first column of each row is taken as the source sentence and the
    second column as its translation; any further columns are ignored.
    Rows with fewer than two columns, or where either of the first two
    columns is empty after trimming surrounding whitespace, are skipped.

    Args:
        filepath: Path to a ``.tsv`` file, optionally gzip-compressed
            (detected by a ``.gz`` suffix).

    Returns:
        A ``(src_sentences, trg_sentences)`` tuple of two equally long lists,
        where ``trg_sentences[i]`` is the translation of ``src_sentences[i]``.
    """
    src_sentences = []
    trg_sentences = []
    with _open_text(filepath) as f_in:
        for line in f_in:
            # Remove only the line terminator before splitting. Stripping all
            # surrounding whitespace first would swallow a leading tab and
            # shift the columns of a row whose source column is empty.
            fields = line.rstrip('\r\n').split('\t')
            if len(fields) < 2:
                continue
            src = fields[0].strip()
            trg = fields[1].strip()
            if not src or not trg:
                continue
            src_sentences.append(src)
            trg_sentences.append(trg)
    return src_sentences, trg_sentences


def main(argv=None):
    """Evaluate translation matching accuracy of a model on parallel files.

    Args:
        argv: Command-line arguments without the program name, in the form
            ``[model_name_or_path, parallel_file1, parallel_file2, ...]``.
            Defaults to ``sys.argv[1:]``.

    Returns:
        Process exit status: ``0`` after all files were processed, ``2`` when
        the model name or the parallel files are missing.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        # Fail early with a usage hint instead of an IndexError, and before
        # loading a model that would have nothing to evaluate.
        print(
            "Usage: python " + os.path.basename(sys.argv[0])
            + " [model_name_or_path] [parallel-file1] [parallel-file2] ...",
            file=sys.stderr,
        )
        return 2

    model_name, *filepaths = args
    model = SentenceTransformer(model_name)

    for filepath in filepaths:
        name = os.path.basename(filepath)
        src_sentences, trg_sentences = read_parallel_sentences(filepath)
        logger.info("%s: %d sentence pairs", name, len(src_sentences))
        if not src_sentences:
            # No usable pairs in this file, so there is nothing to score.
            logger.warning("%s: no sentence pairs found, skipping", name)
            continue
        dev_trans_acc = evaluation.TranslationEvaluator(
            src_sentences,
            trg_sentences,
            name=name,
            batch_size=inference_batch_size,
        )
        dev_trans_acc(model)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|
|
|
|
|