| import torch |
| import pandas as pd |
|
|
|
|
def get_prompt_length(tokenizer, prompt):
    """Return how many token ids ``tokenizer.encode(prompt)`` yields.

    ``encode`` is called with its default arguments, so any special tokens the
    tokenizer adds by default are counted as well (Hugging Face tokenizers
    add them unless ``add_special_tokens=False`` is passed).
    """
    token_ids = tokenizer.encode(prompt)
    return len(token_ids)
|
|
|
|
def tokenize_multipart_input(
    tokenizer,
    input_text_list: list,
    max_seq_len: int,
    template=None,
    prompt=None,
):
    """Tokenize one or more input texts, optionally wrapped in a prompt template.

    This function is an adaptation of the `tokenize_multipart_input` found in princeton-nlp's repository
    at https://github.com/princeton-nlp/LM-BFF/blob/main/src/dataset.py.

    Modifications include:
    - Extension of automatic prompt generation for multi-label classification.
    - Removal of parameters like `first_sent_limit`, `other_sent_limit`, `gpt3`, `truncate_head`, and `support_labels`.
    - Optimization of the code flow.

    Args:
        tokenizer: a pre-trained tokenizer from Hugging Face Transformers.
        input_text_list (list): documents ready for tokenization. On the
            no-prompt path, ``None`` entries are skipped and entries for which
            ``pd.isna`` is true are encoded as empty text.
        max_seq_len (int): max sequence length after adding the prompt along
            with special tokens from BERT.
        template (str, optional): ``*``-separated template. The parts ``cls``,
            ``mask``, ``sep`` and ``sep+`` become special tokens (``sep+`` also
            increments the segment id for the tokens that follow it);
            ``sent_<i>`` / ``+sent_<i>`` insert ``input_text_list[i]``; any
            other part is encoded as text with ``_`` replaced by spaces.
            ``[PROMPT]`` is substituted with ``prompt`` unless
            ``prompt == "auto"``. Required when ``prompt`` is truthy.
        prompt (str, optional): the prompt used for the input text, or
            ``"auto"`` to read it from the template itself. When falsy, the
            texts are encoded as ``[CLS] text_0 [SEP] text_1 [SEP] ...``.

    Returns:
        tuple: ``(input_ids, attention_mask, token_type_ids, mask_pos)``.
        ``mask_pos`` is a one-element list with the index of the first mask
        token when a prompt is used, otherwise ``None``.

    Raises:
        ValueError: if ``prompt`` is truthy but ``template`` is None, or if the
            templated sequence contains no mask token.
        AssertionError: if the mask token lands at or beyond ``max_seq_len``.
    """
    if prompt:
        return _tokenize_with_template(
            tokenizer, input_text_list, max_seq_len, template, prompt
        )
    return _tokenize_plain(tokenizer, input_text_list, max_seq_len)


def _encode_no_special(tokenizer, text):
    """Encode *text* without letting the tokenizer add special tokens."""
    return tokenizer.encode(text, add_special_tokens=False)


def _resolve_auto_prompt(template_list):
    """Pick the prompt text out of a ``*``-split template (``prompt == "auto"``).

    If the part right after ``cls`` is empty, the prompt is taken three parts
    after ``cls`` (e.g. ``*cls**sent_0*_It_was*mask*...`` -> ``_It_was``).
    Otherwise, if the part right after ``cls`` starts with ``_``, that part is
    the prompt (e.g. ``*cls*_It_was*mask*...``). A single leading ``_`` is
    stripped. Returns ``"auto"`` unchanged when neither layout matches.
    """
    prompt = "auto"
    cls_pos = template_list.index("cls")
    after_cls = template_list[cls_pos + 1]
    if after_cls == "":
        prompt = template_list[cls_pos + 3]
    elif after_cls.startswith("_"):
        prompt = after_cls
    if prompt.startswith("_"):
        prompt = prompt[1:]
    return prompt


def _tokenize_with_template(tokenizer, input_text_list, max_seq_len, template, prompt):
    """Build the token sequence described by *template* (prompt path)."""
    if template is None:
        raise ValueError("a template is required when a prompt is given")

    special_token_mapping = {
        "cls": tokenizer.cls_token_id,
        "mask": tokenizer.mask_token_id,
        "sep": tokenizer.sep_token_id,
        "sep+": tokenizer.sep_token_id,
    }

    if prompt != "auto":
        template = template.replace("[PROMPT]", prompt)
    template_list = template.split("*")
    if prompt == "auto":
        prompt = _resolve_auto_prompt(template_list)

    # Per-sentence token budget; loop-invariant, so computed once.
    max_len = max_seq_len - 3 - get_prompt_length(tokenizer, prompt)

    input_ids = []
    attention_mask = []
    token_type_ids = []
    segment_id = 0

    for part in template_list:
        new_tokens = []
        if part in special_token_mapping:
            new_tokens.append(special_token_mapping[part])
        elif part.startswith(("sent_", "+sent_")):
            sent_id = int(part.split("_")[1])
            sent_tokens = _encode_no_special(tokenizer, input_text_list[sent_id])
            # Keep the tail of the text. Guard max_len <= 0 explicitly:
            # tokens[-0:] would keep everything, and a negative bound would
            # keep a head-trimmed slice instead of nothing.
            new_tokens += sent_tokens[-max_len:] if max_len > 0 else []
        else:
            # Natural-language template text; "_" stands for a space.
            text = part.replace("_", " ")
            if len(text) == 1:
                # Single characters are looked up directly as tokens.
                new_tokens.append(tokenizer.convert_tokens_to_ids(text))
            else:
                new_tokens += _encode_no_special(tokenizer, text)

        input_ids += new_tokens
        attention_mask += [1] * len(new_tokens)
        token_type_ids += [segment_id] * len(new_tokens)

        # Tokens after a "sep+" part belong to the next segment.
        if part == "sep+":
            segment_id += 1

    mask_pos = [input_ids.index(tokenizer.mask_token_id)]
    # The mask must fall inside the model's window.
    assert mask_pos[0] < max_seq_len

    return input_ids, attention_mask, token_type_ids, mask_pos


def _tokenize_plain(tokenizer, input_text_list, max_seq_len):
    """Encode texts as ``[CLS] t0 [SEP] t1 [SEP] ...`` within ``max_seq_len``."""
    input_ids = [tokenizer.cls_token_id]
    attention_mask = [1]
    token_type_ids = [0]

    for sent_id, input_text in enumerate(input_text_list):
        if input_text is None:
            # Missing segment (e.g. no second text): skip it entirely.
            continue
        if pd.isna(input_text):
            input_text = ""
        # Room left for this text once its trailing SEP is reserved. The
        # budget shrinks as segments are added, so the total never exceeds
        # max_seq_len. For the first text this equals max_seq_len - 2.
        budget = max_seq_len - len(input_ids) - 1
        if budget < 0:
            break
        input_tokens = _encode_no_special(tokenizer, input_text)[:budget] + [
            tokenizer.sep_token_id
        ]
        input_ids += input_tokens
        attention_mask += [1] * len(input_tokens)
        token_type_ids += [sent_id] * len(input_tokens)

    return input_ids, attention_mask, token_type_ids, None
|
|
|
|
class InferenceDataset(torch.utils.data.Dataset):
    """Single-item PyTorch dataset that tokenizes one input text for inference.

    The text is tokenized on access with ``tokenize_multipart_input``,
    optionally through a prompt template.

    Attributes:
        doc (str): the input text to tokenize.
        template: template forwarded to ``tokenize_multipart_input``.
            Defaults to None.
        prompt: prompt forwarded to ``tokenize_multipart_input``; when truthy
            the returned item also contains ``mask_pos``. Defaults to None.
        tokenizer: a pre-trained Hugging Face tokenizer.
        max_seq_len (int): maximum length for a sequence.
    """

    def __init__(
        self,
        input_text: str,
        tokenizer,
        max_seq_len: int,
        template=None,
        prompt=None,
    ):
        self.doc = input_text
        self.template = template
        self.prompt = prompt
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __getitem__(self, idx):
        """Return the tokenized document as a dict of lists.

        The dict has keys ``input_ids``, ``token_type_ids`` and
        ``attention_mask``, plus ``mask_pos`` when a prompt is set.

        Raises:
            IndexError: for any index other than 0 (or -1). Without this,
                iterating the dataset directly (``for x in ds``) never stops,
                because ``Dataset`` defines no ``__iter__`` and Python's
                fallback iteration ends only on ``IndexError``.
        """
        if idx not in (0, -1):
            raise IndexError(f"InferenceDataset index out of range: {idx}")

        input_ids, attn_mask, segs, mask_pos = tokenize_multipart_input(
            tokenizer=self.tokenizer,
            input_text_list=[self.doc],
            template=self.template,
            prompt=self.prompt,
            max_seq_len=self.max_seq_len,
        )
        item = {
            "input_ids": input_ids,
            "token_type_ids": segs,
            "attention_mask": attn_mask,
        }
        if self.prompt:
            item["mask_pos"] = mask_pos
        return item

    def __len__(self):
        """The dataset always holds exactly one document."""
        return 1
|
|