from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertTokenizerFast,
    DataCollatorForLanguageModeling,
)
import torch


def mask_and_unmask(
    text: str,
    tokenizer: AutoTokenizer | BertTokenizerFast,
    model: AutoModelForMaskedLM,
    data_collator: DataCollatorForLanguageModeling,
) -> dict:
    # Tokenize the input and let the MLM collator randomly replace tokens with [MASK].
    collator_input = tokenizer(text)
    collator_input["labels"] = collator_input["input_ids"].copy()
    collator_output = data_collator([collator_input])

    # Decode the masked sequence (special tokens included) so the caller can inspect it.
    masked_text = tokenizer.decode(collator_output["input_ids"][0])
    pred_dict = {"masked_text": masked_text}

    # Re-tokenize the masked text and run the model to get per-token logits.
    inputs = tokenizer(masked_text, return_tensors="pt", padding="max_length", truncation=True)
    with torch.no_grad():
        token_logits = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        ).logits

    # Collect the top-5 predictions for every [MASK] position in the re-tokenized input.
    all_masked_token_index = torch.argwhere(inputs["input_ids"] == tokenizer.mask_token_id)
    for masked_index_token in all_masked_token_index[:, 1]:
        masked_token_logits = token_logits[0, masked_index_token, :]
        top_5_tokens = torch.argsort(masked_token_logits, descending=True)[:5].tolist()
        # Re-tokenizing the decoded string prepends a fresh [CLS], shifting every position
        # by one, so the original (unmasked) token sits at index - 1 in the collator labels.
        value = tokenizer.decode(collator_output["labels"][0, masked_index_token - 1])
        pred_dict[value] = [tokenizer.decode(token) for token in top_5_tokens]
    return pred_dict
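

# A minimal usage sketch, assuming the "bert-base-uncased" checkpoint and a 15% masking
# rate; any masked-LM checkpoint with a compatible fast tokenizer should work the same way.
if __name__ == "__main__":
    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

    predictions = mask_and_unmask(
        "The quick brown fox jumps over the lazy dog.",
        tokenizer,
        model,
        data_collator,
    )
    # Maps each originally masked token to the model's top-5 replacement candidates,
    # plus the masked text itself under the "masked_text" key.
    print(predictions)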