| """ |
| This script contains an example how to extend an existent sentence embedding model to new languages. |
| |
| Given a (monolingual) teacher model you would like to extend to new languages, which is specified in the teacher_model_name |
| variable. We train a multilingual student model to imitate the teacher model (variable student_model_name) |
| on multiple languages. |
| |
| For training, you need parallel sentence data (machine translation training data). You need tab-seperated files (.tsv) |
| with the first column a sentence in a language understood by the teacher model, e.g. English, |
| and the further columns contain the according translations for languages you want to extend to. |
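
For example, one line of an English-German training file could look like this, where \t denotes a tab character
(the sentences are made up for illustration):
This is an example sentence.\tDies ist ein Beispielsatz.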

See get_parallel_data_[opus/tatoeba/ted2020].py for automatic download of parallel sentence datasets.

Note: See make_multilingual.py for a fully automated script that downloads the necessary data and trains the model.
This script only trains the model and assumes that you already have parallel data in the right format.


Further information can be found in our paper:
Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation
https://arxiv.org/abs/2004.09813


Usage:
python make_multilingual_sys.py train1.tsv.gz train2.tsv.gz train3.tsv.gz --dev dev1.tsv.gz dev2.tsv.gz

For example:
python make_multilingual_sys.py parallel-sentences/TED2020-en-de-train.tsv.gz --dev parallel-sentences/TED2020-en-de-dev.tsv.gz

To load all training & dev files from a folder (Linux):
python make_multilingual_sys.py parallel-sentences/*-train.tsv.gz --dev parallel-sentences/*-dev.tsv.gz
"""
|
|
from sentence_transformers import SentenceTransformer, LoggingHandler, models, evaluation, losses
from torch.utils.data import DataLoader
from sentence_transformers.datasets import ParallelSentencesDataset
from datetime import datetime

import os
import logging
import gzip
import numpy as np
import sys
|
|
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)
|
|
|
|
teacher_model_name = 'paraphrase-distilroberta-base-v2'   # Monolingual teacher model we want to extend to multiple languages
student_model_name = 'xlm-roberta-base'                   # Multilingual base model used to imitate the teacher model
|
|
max_seq_length = 128                  # Student model max. sequence length (number of word pieces)
train_batch_size = 64                 # Batch size for training
inference_batch_size = 64             # Batch size at inference
max_sentences_per_trainfile = 500000  # Maximum number of parallel sentences per training file
train_max_sentence_length = 250       # Maximum sentence length (characters) for training sentences

num_epochs = 5                        # Number of training epochs
num_warmup_steps = 10000              # Number of learning-rate warmup steps

num_evaluation_steps = 1000           # Run the evaluators after every x training steps
|
|
|
|
|
|
output_path = "output/make-multilingual-sys-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
|
|
|
# Read the train and dev files passed on the command line
|
|
|
|
train_files = []
dev_files = []
is_dev_file = False
for arg in sys.argv[1:]:
    if arg.lower() == '--dev':
        is_dev_file = True
    else:
        if not os.path.exists(arg):
            print("File could not be found:", arg)
            exit()

        if is_dev_file:
            dev_files.append(arg)
        else:
            train_files.append(arg)
|
|
if len(train_files) == 0:
    print("Please pass at least one train file")
    print("Usage: python make_multilingual_sys.py file1.tsv.gz file2.tsv.gz --dev dev1.tsv.gz dev2.tsv.gz")
    exit()
|
|
|
|
logger.info("Train files: {}".format(", ".join(train_files)))
logger.info("Dev files: {}".format(", ".join(dev_files)))
|
|
######## Start the extension of the teacher model to multiple languages ########
logger.info("Load teacher model")
teacher_model = SentenceTransformer(teacher_model_name)
|
|
|
|
logger.info("Create student model from scratch")
word_embedding_model = models.Transformer(student_model_name, max_seq_length=max_seq_length)
# Apply mean pooling to get one fixed-size sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
student_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
|
|
|
|
###### Read parallel sentences dataset ######
train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model, batch_size=inference_batch_size, use_embedding_cache=True)
for train_file in train_files:
    train_data.load_data(train_file, max_sentences=max_sentences_per_trainfile, max_sentence_length=train_max_sentence_length)
|
|
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)
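
# Distillation objective: ParallelSentencesDataset pairs each sentence (the source and
# every translation) with the teacher's embedding of the source sentence, and MSELoss
# regresses the student's embeddings onto these teacher embeddings, i.e. it minimizes
# ||teacher(src) - student(src)||^2 and ||teacher(src) - student(trg)||^2.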
|
|
|
|
|
|
#### Evaluate cross-lingual performance on different tasks ####
evaluators = []  # Evaluators that are run periodically during training
|
|
for dev_file in dev_files:
    logger.info("Create evaluator for " + dev_file)
    src_sentences = []
    trg_sentences = []
    with gzip.open(dev_file, 'rt', encoding='utf8') if dev_file.endswith('.gz') else open(dev_file, encoding='utf8') as fIn:
        for line in fIn:
            splits = line.strip().split('\t')
            if len(splits) >= 2 and splits[0] != "" and splits[1] != "":  # Skip empty or malformed lines
                src_sentences.append(splits[0])
                trg_sentences.append(splits[1])
|
|
|
|
    # Mean Squared Error (MSE) measures the (euclidean) distance between teacher and student embeddings
    dev_mse = evaluation.MSEEvaluator(src_sentences, trg_sentences, name=os.path.basename(dev_file), teacher_model=teacher_model, batch_size=inference_batch_size)
    evaluators.append(dev_mse)
|
|
    # TranslationEvaluator computes the embeddings of all parallel sentences and checks whether
    # the embedding of source[i] is the closest to target[i] among all target sentences
    dev_trans_acc = evaluation.TranslationEvaluator(src_sentences, trg_sentences, name=os.path.basename(dev_file), batch_size=inference_batch_size)
    evaluators.append(dev_trans_acc)
|
|
|
|
|
|
# Train the student model to imitate the teacher model
student_model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores)),
                  epochs=num_epochs,
                  warmup_steps=num_warmup_steps,
                  evaluation_steps=num_evaluation_steps,
                  output_path=output_path,
                  save_best_model=True,
                  optimizer_params={'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False}
                  )
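
# A minimal usage sketch (assuming training has finished and the best checkpoint was
# saved to output_path; the example sentences are made up):
#
#   model = SentenceTransformer(output_path)
#   embeddings = model.encode(["This is an English sentence.", "Dies ist ein deutscher Satz."])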
|
|