| from transformers.modeling_utils import PreTrainedModel |
| from torch import nn |
| from transformers.models.bert.configuration_bert import BertConfig |
| from transformers.models.bert.modeling_bert import BertModel |
| import torch |
| import torch.nn.functional as F |
class BertChunker(PreTrainedModel):
    """BERT encoder plus a per-token two-class head used to split text into chunks.

    ``chunklayer`` maps every token embedding to two logits. The chunking
    methods treat a softmax probability of class 1 above ``prob_threshold``
    as "start a new chunk at this token".
    """

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = BertModel(config)
        # Sized from the config instead of a hard-coded 384 so encoders with a
        # different hidden size also work; for hidden_size == 384 nothing changes.
        self.chunklayer = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the encoder and the chunk head.

        Args:
            input_ids: token ids forwarded to ``BertModel``.
            attention_mask: attention mask forwarded to ``BertModel``.
            labels: optional per-token class indices (0 or 1, as required by a
                two-class ``CrossEntropyLoss``). When given, the loss over all
                positions is stored under ``"loss"``.
            **kwargs: extra keyword arguments forwarded to ``BertModel``.

        Returns:
            The ``BertModel`` output with ``"logits"`` (last dimension 2) added,
            plus ``"loss"`` when ``labels`` is given. NOTE(review): the key
            assignment assumes a dict-like ``ModelOutput``; a tuple output
            (e.g. ``return_dict=False``) would fail here.
        """
        model_output = self.model(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        token_embeddings = model_output[0]
        logits = self.chunklayer(token_embeddings)
        model_output["logits"] = logits
        # ``if labels:`` raised for any multi-element tensor; test for None instead.
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            flat_logits = logits.contiguous().view(-1, logits.shape[-1])
            # Put the targets on the logits' device (the old code moved them to
            # their own device, which did nothing).
            flat_labels = labels.contiguous().view(-1).to(logits.device)
            model_output["loss"] = loss_fct(flat_logits, flat_labels)
        return model_output

    def _chunk_decision(self, batch_input, prob_threshold):
        """Return a boolean mask of predicted split tokens for ``batch_input``.

        ``batch_input`` is a 2-D tensor of token ids on ``self.device`` whose
        first and last columns hold the special start/end tokens. Those two
        columns are dropped, so the mask has shape ``(batch, seq_len - 2)``.
        """
        attention_mask = torch.ones(
            batch_input.shape[0], batch_input.shape[1], device=self.device
        )
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            output = self(input_ids=batch_input, attention_mask=attention_mask)
        logits = output["logits"][:, 1:-1, :]
        return F.softmax(logits, dim=-1)[:, :, 1] > prob_threshold

    @staticmethod
    def _split_at(text: str, positions: list[int]) -> list[str]:
        """Cut ``text`` at the given character offsets.

        Offsets are de-duplicated and sorted. Offsets at 0 or at/after
        ``len(text)`` are dropped, so repeated or boundary offsets cannot
        create empty chunks. An empty ``text`` yields ``[""]``. The returned
        pieces always concatenate back to ``text``.
        """
        cuts = sorted({p for p in positions if 0 < p < len(text)})
        bounds = [0, *cuts, len(text)]
        return [text[a:b] for a, b in zip(bounds, bounds[1:])]

    def chunk_text(self, text: str, tokenizer, prob_threshold=0.5) -> list[str]:
        """Split ``text`` into chunks using overlapping sliding windows.

        The text is tokenized once, without truncation. Windows of at most 253
        tokens are wrapped in the encoding's first and last token ids and then
        classified:

        - If a window predicts splits (other than a lone split at its first
          token), every predicted position is recorded and the next window
          restarts at the last of them.
        - Otherwise the next window starts where this one ended.

        Args:
            text: the text to split.
            tokenizer: a Hugging Face tokenizer whose output supports
                ``token_to_chars`` (a "fast" tokenizer). Assumes the encoding
                has exactly one leading and one trailing special token; TODO
                confirm for non-BERT tokenizers.
            prob_threshold: minimum class-1 probability for a split.

        Returns:
            Consecutive, non-empty substrings of ``text`` whose concatenation
            is ``text`` (``[""]`` for an empty ``text``).

        Note:
            Calls ``self.eval()`` as a side effect.
        """
        MAX_TOKENS = 255
        USEFUL_TOKENS = MAX_TOKENS - 2  # leave room for the two special tokens
        self.eval()
        tokens = tokenizer(text, return_tensors="pt", truncation=False)
        input_ids = tokens["input_ids"].to(self.device)
        cls_id = input_ids[:, 0].unsqueeze(0)  # shape (1, 1)
        sep_id = input_ids[:, -1].unsqueeze(0)  # shape (1, 1)
        input_ids = input_ids[:, 1:-1]
        token_num = input_ids.shape[1]

        split_str_poses = []
        windows_start = 0
        windows_end = 0
        while windows_end <= token_num:
            windows_end = windows_start + USEFUL_TOKENS
            ids = torch.cat(
                (cls_id, input_ids[:, windows_start:windows_end], sep_id), 1
            )
            split_idx = torch.where(self._chunk_decision(ids, prob_threshold))[1].tolist()
            if split_idx and split_idx != [0]:
                # +1 skips the leading special token of the full encoding.
                split_str_poses += [
                    tokens.token_to_chars(windows_start + sp + 1).start
                    for sp in split_idx
                ]
                # Indices are ascending, and this branch excludes the lone-[0]
                # case, so split_idx[-1] > 0 and the window always advances.
                windows_start += split_idx[-1]
            else:
                windows_start = windows_end
        # Restarting at a boundary can predict that boundary again;
        # _split_at removes the duplicates, which would otherwise yield "" chunks.
        return self._split_at(text, split_str_poses)

    def chunk_text_fast(
        self, text: str, tokenizer, batchsize=20, prob_threshold=0.5
    ) -> list[str]:
        """Split ``text`` into chunks using non-overlapping, batched windows.

        How it works:

        1. The token sequence, without its leading/trailing special tokens, is
           cut into consecutive 253-token windows.
        2. Each window is wrapped in ``tokenizer.cls_token_id`` /
           ``tokenizer.sep_token_id``.
        3. Windows are classified ``batchsize`` at a time.
        4. A shorter final window holds any remaining tokens.

        Unlike ``chunk_text``, windows never overlap.

        Args:
            text: the text to split.
            tokenizer: a Hugging Face tokenizer whose output supports
                ``token_to_chars`` (a "fast" tokenizer).
            batchsize: number of full windows per forward pass; must be >= 1.
            prob_threshold: minimum class-1 probability for a split.

        Returns:
            Consecutive, non-empty substrings of ``text`` whose concatenation
            is ``text`` (``[""]`` for an empty ``text``).

        Raises:
            ValueError: if ``batchsize`` < 1 and the text fills at least one
                full window.

        Note:
            Calls ``self.eval()`` as a side effect.
        """
        self.eval()
        MAX_TOKENS = 255
        USEFUL_TOKENS = MAX_TOKENS - 2  # leave room for the two special tokens
        tokens = tokenizer(text, return_tensors="pt", truncation=False)
        cls_id = tokenizer.cls_token_id
        sep_id = tokenizer.sep_token_id
        # Index row 0 instead of squeeze(): squeeze() turned a one-token text
        # into a 0-d tensor, and ``shape[0]`` then raised IndexError.
        input_ids = tokens["input_ids"][0, 1:-1].contiguous()
        token_num = input_ids.shape[0]
        seq_num, left_token_num = divmod(token_num, USEFUL_TOKENS)

        split_str_poses = []
        if seq_num > 0:
            if batchsize < 1:
                raise ValueError(f"batchsize must be >= 1, got {batchsize}")
            windows = input_ids[: seq_num * USEFUL_TOKENS].view(seq_num, USEFUL_TOKENS)
            # Index of every window slot in the full encoding: slot 0 of window
            # 0 is token 1, right after the leading special token.
            position_id = (
                torch.arange(1, seq_num * USEFUL_TOKENS + 1)
                .view(seq_num, USEFUL_TOKENS)
                .to(self.device)
            )
            windows = torch.cat(
                (
                    torch.full((seq_num, 1), cls_id, dtype=windows.dtype),
                    windows,
                    torch.full((seq_num, 1), sep_id, dtype=windows.dtype),
                ),
                1,
            )
            # Step by batchsize. The old loop sliced rows [i, i + batchsize) for
            # i = 0, 1, 2, ..., which re-ran overlapping rows and skipped the rest.
            for start in range(0, seq_num, batchsize):
                stop = start + batchsize
                decision = self._chunk_decision(
                    windows[start:stop].to(self.device), prob_threshold
                )
                # Boolean indexing flattens row-major, so offsets stay ascending.
                split_token_ids = position_id[start:stop][decision].tolist()
                split_str_poses += [
                    tokens.token_to_chars(p).start for p in split_token_ids
                ]

        if left_token_num > 0:
            tail = (
                torch.cat(
                    (
                        torch.tensor([cls_id], dtype=input_ids.dtype),
                        input_ids[-left_token_num:],
                        torch.tensor([sep_id], dtype=input_ids.dtype),
                    )
                )
                .unsqueeze(0)
                .to(self.device)
            )
            decision = self._chunk_decision(tail, prob_threshold)
            # Tail slot k is token (token_num - left_token_num + 1 + k) of the full encoding.
            offset = token_num - left_token_num + 1
            split_str_poses += [
                tokens.token_to_chars(p).start
                for p in (torch.where(decision)[1] + offset).tolist()
            ]
        return self._split_at(text, split_str_poses)
|
|