| from transformers import T5Tokenizer, MT5ForConditionalGeneration |
| from simpletransformers.t5 import T5Model |
| import datetime |
| import logging |
| import os |
|
|
|
|
class Inference:
    """Generate paraphrases of a sentence with an MT5 sequence-to-sequence model."""

    def _discard_recommendations(self, original, proposal):
        """Return True if *proposal* should be discarded as a paraphrase of *original*.

        A proposal is discarded when it equals the original case-insensitively,
        either verbatim or after removing punctuation and spaces — i.e. when it
        is only a trivial rewording.
        """
        proposal = proposal.lower()
        original = original.lower()
        if proposal == original:
            return True

        # Compare again with punctuation and spaces stripped from both sides.
        chars = [".", "!", " ", "?", ","]
        for char in chars:
            proposal = proposal.replace(char, "")
            original = original.replace(char, "")

        return proposal == original

    def get_paraphrases(
        self,
        model_name,
        sentence,
        temperature,
        prefix="paraphrase: ",
        n_predictions=2,
        top_k=120,
        max_length=256,
        device="cpu",
    ):
        """Return up to *n_predictions* paraphrases of *sentence*.

        Args:
            model_name: path or hub id of the MT5 checkpoint to load.
            sentence: source sentence to paraphrase.
            temperature: sampling temperature; sampling is enabled only when > 0.
            prefix: task prefix prepended to the input text.
            n_predictions: maximum number of paraphrases returned.
            top_k: top-k cutoff passed to ``generate``.
            max_length: maximum generated sequence length.
            device: torch device the encoded inputs are moved to (e.g. "cpu").

        Returns:
            List of generated paraphrases; may be shorter than *n_predictions*
            when too many candidates are discarded as trivial or duplicated.
        """
        # NOTE(review): the model and tokenizer are reloaded on every call;
        # consider caching them on the instance if this is called repeatedly.
        model = MT5ForConditionalGeneration.from_pretrained(model_name)
        tokenizer = T5Tokenizer.from_pretrained(model_name)

        discarded = 0
        # NOTE(review): recent tokenizers append the EOS token themselves, and
        # pad_to_max_length is deprecated in favour of padding= — confirm
        # against the installed transformers version before changing.
        text = prefix + sentence + " </s>"
        encoding = tokenizer.encode_plus(
            text, pad_to_max_length=True, return_tensors="pt"
        )
        input_ids = encoding["input_ids"].to(device)
        attention_masks = encoding["attention_mask"].to(device)

        # Sampling is only meaningful with a positive temperature.
        do_sample = temperature > 0
        logging.debug(f"do_sample: {do_sample}, temperature: {temperature}")

        # Over-generate (2x) so trivial/duplicate candidates can be skipped.
        model_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            do_sample=do_sample,
            max_length=max_length,
            top_k=top_k,
            num_beams=n_predictions * 2,
            top_p=0.98,
            temperature=temperature,
            early_stopping=True,
            num_return_sequences=n_predictions * 2,
        )
        logging.debug(f"{len(model_output)} predictions for {sentence}")

        outputs = []
        for output in model_output:
            generated_sent = tokenizer.decode(
                output, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            if (
                not self._discard_recommendations(sentence, generated_sent)
                and generated_sent not in outputs
            ):
                # Normalise typographic apostrophes for downstream consumers.
                generated_sent = generated_sent.replace("’", "'")
                outputs.append(generated_sent)
            else:
                logging.debug(f"Discarded: {generated_sent} - source:{sentence}")
                # Bug fix: was `discaded = +1`, which pinned the counter at 1.
                discarded += 1

            if len(outputs) == n_predictions:
                break

        logging.debug(f"{discarded} candidates discarded for {sentence}")
        return outputs
|
|
|
|
def main():
    """Demo entry point: paraphrase a sample Catalan sentence with a model
    located in the current working directory."""
    engine = Inference()
    sentence = "Aquesta és una associació sense ànim de lucre amb la missió de fomentar la presència i l'ús del català."
    model_path = os.getcwd()
    paraphrases = engine.get_paraphrases(model_path, sentence, 1.0)
    print(f"original: {sentence}")
    for paraphrase in paraphrases:
        print(f" {paraphrase}")


if __name__ == "__main__":
    main()
|
|