| from typing import List |
| import torch |
|
|
| from common import utils |
| from common.utils import OutputData, InputData |
| from torch import Tensor |
|
|
def argmax_for_seq_len(inputs, seq_lens, padding_value=-100):
    """Take the argmax over the class dimension for a batch of padded sequences.

    Args:
        inputs: padded batch of per-token logits (presumably
            ``(batch, seq, n_classes)`` — TODO confirm against utils.pack_sequence).
        seq_lens: true (unpadded) length of each sequence in the batch.
        padding_value (int, optional): value used to re-pad positions beyond
            each sequence's true length. Defaults to -100.

    Returns:
        Padded batch of argmax class ids, with ``padding_value`` past each
        sequence's true length.
    """
    # Flatten away the padding so argmax only runs over real tokens.
    packed = utils.pack_sequence(inputs, seq_lens)
    best_ids = torch.argmax(packed, dim=-1, keepdim=True)
    # Restore the padded batch layout, then drop the kept class dimension.
    unpacked = utils.unpack_sequence(best_ids, seq_lens, padding_value)
    return unpacked.squeeze(-1)
|
|
|
|
def decode(output: OutputData,
           target: InputData = None,
           pred_type="slot",
           multi_threshold=0.5,
           ignore_index=-100,
           return_list=True,
           return_sentence_level=True,
           use_multi=False,
           use_crf=False,
           CRF=None) -> List or Tensor:
    """ decode output logits into predicted ids

    Args:
        output (OutputData): output logits data (``slot_ids`` and/or ``intent_ids``).
        target (InputData, optional): input data with attention mask / seq_lens;
            required for CRF decoding, multi-intent list building, and
            token-level-intent decoding. Defaults to None.
        pred_type (str, optional): prediction type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        multi_threshold (float, optional): sigmoid threshold for multi-intent decoding. Defaults to 0.5.
        ignore_index (int, optional): align and pad token with ignore index. Defaults to -100.
        return_list (bool, optional): if True return list else return torch Tensor. Defaults to True.
        return_sentence_level (bool, optional): if True decode sentence level intent else decode token level intent. Defaults to True.
        use_multi (bool, optional): whether to decode to multi intent. Defaults to False.
        use_crf (bool, optional): whether to use crf. Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        List or Tensor: decoded sequence ids
    """
    # Select which logits to decode; anything other than "slot" decodes intents.
    if pred_type == "slot":
        inputs = output.slot_ids
    else:
        inputs = output.intent_ids


    if pred_type == "slot":
        if not use_multi:
            if use_crf:
                # CRF.decode already returns per-sequence label lists
                # (presumably a list of lists — TODO confirm CRF implementation),
                # so the trailing `if return_list` tolist() is a no-op for it.
                res = CRF.decode(inputs, mask=target.attention_mask)
            else:
                res = torch.argmax(inputs, dim=-1)
        else:
            raise NotImplementedError("Multi-slot prediction is not supported.")
    elif pred_type == "intent":
        if not use_multi:
            # Single-intent: plain argmax over classes per sentence.
            res = torch.argmax(inputs, dim=-1)
        else:
            # Multi-intent: every class whose sigmoid exceeds the threshold.
            # nonzero() yields (sample_index, class_index) pairs.
            res = (torch.sigmoid(inputs) > multi_threshold).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                # Scatter the (sample, class) pairs into one list per sample.
                res_list = [[] for _ in range(len(target.seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    elif pred_type == "token-level-intent":
        if not use_multi:
            # Per-token argmax; optionally reduce to one intent per sentence.
            res = torch.argmax(inputs, dim=-1)
            if not return_sentence_level:
                return res
            # NOTE(review): when return_sentence_level=True but return_list=False,
            # this block is skipped and control falls through to the final
            # `if return_list` below, returning the TOKEN-level tensor instead
            # of a sentence-level result — confirm this is intended.
            if return_list:
                res = res.detach().cpu().tolist()
                attention_mask = target.attention_mask
                # Truncate each row to its attended (non-padding) prefix.
                for i in range(attention_mask.shape[0]):
                    temp = []
                    for j in range(attention_mask.shape[1]):
                        if attention_mask[i][j] == 1:
                            temp.append(res[i][j])
                        else:
                            # Assumes the mask is a contiguous 1-prefix — stop
                            # at the first 0 rather than scanning the whole row.
                            break
                    res[i] = temp
                # Majority vote over the token-level intents of each sentence.
                # NOTE(review): raises ValueError on an all-masked row (empty it).
                return [max(it, key=lambda v: it.count(v)) for it in res]
        else:
            seq_lens = target.seq_lens


            if not return_sentence_level:
                # Per-token multi-intent booleans, packed over true lengths
                # then re-padded with ignore_index.
                token_res = torch.cat([
                    torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold
                    for i in range(len(seq_lens))],
                    dim=0)
                return utils.unpack_sequence(token_res, seq_lens, padding_value=ignore_index)


            # Per-class vote count: how many valid tokens predicted each class.
            intent_index_sum = torch.cat([
                torch.sum(torch.sigmoid(inputs[i, 0:seq_lens[i], :]) > multi_threshold, dim=0).unsqueeze(0)
                for i in range(len(seq_lens))],
                dim=0)


            # A class is a sentence-level intent if more than half of the
            # sentence's tokens voted for it. nonzero() -> (sample, class) pairs.
            res = (intent_index_sum > torch.div(seq_lens, 2, rounding_mode='floor').unsqueeze(1)).nonzero()
            if return_list:
                res_index = res.detach().cpu().tolist()
                res_list = [[] for _ in range(len(seq_lens))]
                for item in res_index:
                    res_list[item[0]].append(item[1])
                return res_list
            else:
                return res
    else:
        raise NotImplementedError("Prediction mode except ['slot','intent','token-level-intent'] is not supported.")
    # Fall-through conversion for branches that did not return above
    # (slot decoding, single-intent, and the token-level fall-through noted above).
    if return_list:
        res = res.detach().cpu().tolist()
    return res
|
|
|
|
def compute_loss(pred: OutputData,
                 target: InputData,
                 criterion_type="slot",
                 use_crf=False,
                 ignore_index=-100,
                 loss_fn=None,
                 use_multi=False,
                 CRF=None):
    """ compute loss between predicted logits and golden targets

    Args:
        pred (OutputData): output logits data
        target (InputData): input golden data
        criterion_type (str, optional): criterion type in ["slot", "intent", "token-level-intent"]. Defaults to "slot".
        use_crf (bool, optional): whether to use crf for the slot loss. Defaults to False.
        ignore_index (int, optional): compute loss with ignore index. Defaults to -100.
        loss_fn (_type_, optional): loss function. Defaults to None.
        use_multi (bool, optional): whether targets are multi-intent (one-hot per token). Defaults to False.
        CRF (CRF, optional): CRF function. Defaults to None.

    Returns:
        Tensor: loss result
    """
    if criterion_type == "slot":
        if use_crf:
            # The CRF returns a log-likelihood, so negate it to get a loss.
            slot_mask = target.get_slot_mask(ignore_index).byte()
            return -1 * CRF(pred.slot_ids, target.slot, slot_mask)
        # Pack away padding so the loss only covers real tokens.
        packed_pred = utils.pack_sequence(pred.slot_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(target.slot, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    if criterion_type == "token-level-intent":
        # Broadcast the sentence-level intent target to every token position.
        # Multi-intent targets carry an extra class dimension, hence the
        # different repeat pattern.
        seq_dim = pred.intent_ids.shape[1]
        repeats = (1, seq_dim) if not use_multi else (1, seq_dim, 1)
        tiled_target = target.intent.unsqueeze(1).repeat(*repeats)
        packed_pred = utils.pack_sequence(pred.intent_ids, target.seq_lens)
        packed_gold = utils.pack_sequence(tiled_target, target.seq_lens)
        return loss_fn(packed_pred, packed_gold)

    # Sentence-level intent loss: logits vs. golden intents directly.
    return loss_fn(pred.intent_ids, target.intent)
|
|