# src/MLM/mask_and_unmask.py
from __future__ import annotations  # allows the `X | Y` annotation syntax on Python < 3.10

import torch
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertTokenizerFast,
    DataCollatorForLanguageModeling,
)


def mask_and_unmask(
    text: str,
    tokenizer: AutoTokenizer | BertTokenizerFast,
    model: AutoModelForMaskedLM,
    data_collator: DataCollatorForLanguageModeling,
) -> dict[str, str | list[str]]:
    """Randomly mask `text` via the MLM collator, then ask the model to fill the masks.

    Returns a dict holding the masked string under "masked_text" plus, for each
    masked position, the original token mapped to the model's top-5 predictions.
    """
    # Let the collator apply its random [MASK] corruption; `labels` must be a
    # copy of `input_ids` so the collator can preserve the original tokens.
    collator_input = tokenizer(text)
    collator_input["labels"] = collator_input["input_ids"].copy()
    collator_output = data_collator([collator_input])
    masked_text = tokenizer.decode(collator_output["input_ids"][0])
    pred_dict = {"masked_text": masked_text}

    # Re-encode the decoded string. `masked_text` already contains the special
    # tokens, so this pass prepends a second [CLS] and shifts every position by
    # +1 relative to `collator_output`; the `- 1` below undoes that shift.
    inputs = tokenizer(masked_text, return_tensors="pt", padding="max_length", truncation=True)
    with torch.no_grad():
        token_logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits

    # Each row of `all_masked_token_index` is a (batch, position) pair.
    all_masked_token_index = torch.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)
    if all_masked_token_index.size(0) != 0:
        for masked_index_token in all_masked_token_index[:, 1]:
            masked_token_logits = token_logits[0, masked_index_token, :]
            top_5_tokens = torch.argsort(masked_token_logits, descending=True)[:5].tolist()
            # The collator stored the original token id in `labels` at each
            # masked position; map it to the model's top-5 guesses.
            value = tokenizer.decode(collator_output["labels"][0, masked_index_token - 1])
            pred_dict[value] = [tokenizer.decode(token) for token in top_5_tokens]
    return pred_dict
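

# Usage sketch. Assumption: "bert-base-uncased" is an illustrative checkpoint,
# not part of this repo; any BERT-style masked-LM checkpoint should work.
if __name__ == "__main__":
    torch.manual_seed(0)  # masking is random; seed for a reproducible demo
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval()
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
    preds = mask_and_unmask("The capital of France is Paris.", tokenizer, model, data_collator)
    # Note: with a short input, the collator may happen to mask no tokens, in
    # which case only the "masked_text" entry is returned.
    for original_token, top_5 in preds.items():
        print(original_token, "->", top_5)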