| from transformers import AutoTokenizer, AutoModelForMaskedLM, BertTokenizerFast, DataCollatorForLanguageModeling |
| import torch |
|
|
|
|
def mask_and_unmask(
    text: str,
    tokenizer: AutoTokenizer | BertTokenizerFast,
    model: AutoModelForMaskedLM,
    data_collator: DataCollatorForLanguageModeling,
) -> dict[str, str | list[str]]:
    """Randomly mask tokens in *text* and predict top-5 replacements per mask.

    The text is tokenized and run through the MLM data collator, which
    replaces a random subset of tokens with ``[MASK]``. The masked text is
    then fed to the masked-LM model and, for each mask position, the five
    highest-scoring vocabulary tokens are collected.

    Args:
        text: Raw input string to mask.
        tokenizer: Tokenizer matching *model*.
        model: Masked language model returning per-token logits.
        data_collator: Collator configured for MLM masking (``mlm=True``).

    Returns:
        A dict containing ``"masked_text"`` (the decoded masked string) plus
        one entry per masked position mapping the decoded original token to
        its top-5 predicted tokens. NOTE: if the same token string is masked
        at two positions, the later entry overwrites the earlier one.
    """
    # The collator needs "labels" present to apply random MLM masking;
    # it keeps the original ids there while masking "input_ids".
    collator_input = tokenizer(text)
    collator_input["labels"] = collator_input["input_ids"].copy()
    collator_output = data_collator([collator_input])
    masked_text = tokenizer.decode(collator_output["input_ids"][0])

    pred_dict: dict[str, str | list[str]] = {"masked_text": masked_text}

    # Re-tokenize the decoded masked text for the model. NOTE(review): the
    # decode above keeps special tokens as literal text, so this second pass
    # presumably prepends one extra special token — the `- 1` offset below
    # compensates when indexing back into the collator's labels. Confirm
    # against the tokenizer's actual special-token behavior.
    inputs = tokenizer(masked_text, return_tensors="pt", padding="max_length", truncation=True)
    token_logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits
    all_masked_token_index = torch.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)
    if all_masked_token_index.numel() != 0:
        # argwhere yields (row, col) pairs; column 0 is the batch index
        # (always 0 for this single-sequence batch), column 1 the position.
        for masked_index_token in all_masked_token_index[:, 1]:
            masked_token_logits = token_logits[0, masked_index_token, :]
            # topk is the direct idiom (and cheaper than a full argsort).
            top_5_tokens = torch.topk(masked_token_logits, 5).indices.tolist()
            # Original (pre-mask) token, aligned via the -1 offset noted above.
            value = tokenizer.decode(collator_output["labels"][0, masked_index_token - 1])
            pred_dict[value] = [tokenizer.decode(token) for token in top_5_tokens]

    return pred_dict
|
|