| ''' |
| Author: Qiguang Chen |
| Date: 2023-01-11 10:39:26 |
| LastEditors: Qiguang Chen |
| LastEditTime: 2023-01-31 20:07:00 |
| Description: |
| |
| ''' |
| import random |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
|
|
| from model.decoder import decoder_utils |
|
|
| from torchcrf import CRF |
|
|
| from common.utils import HiddenData, OutputData, InputData, ClassifierOutputData, unpack_sequence, pack_sequence, \ |
| instantiate |
|
|
|
|
| class BaseClassifier(nn.Module): |
| """Base class for all classifier module |
| """ |
| def __init__(self, **config): |
| super().__init__() |
| self.config = config |
| if config.get("loss_fn"): |
| self.loss_fn = instantiate(config.get("loss_fn")) |
| else: |
| self.loss_fn = CrossEntropyLoss(ignore_index=self.config.get("ignore_index")) |
|
|
| def forward(self, *args, **kwargs): |
| raise NotImplementedError("No implemented classifier.") |
|
|
| def decode(self, output: OutputData, |
| target: InputData = None, |
| return_list=True, |
| return_sentence_level=None): |
| """decode output logits |
| |
| Args: |
| output (OutputData): output logits data |
| target (InputData, optional): input data with attention mask. Defaults to None. |
| return_list (bool, optional): if True return list else return torch Tensor.. Defaults to True. |
| return_sentence_level (_type_, optional): if True decode sentence level intent else decode token level intent. Defaults to None. |
| |
| Returns: |
| List or Tensor: decoded sequence ids |
| """ |
| if self.config.get("return_sentence_level") is not None and return_sentence_level is None: |
| return_sentence_level = self.config.get("return_sentence_level") |
| elif self.config.get("return_sentence_level") is None and return_sentence_level is None: |
| return_sentence_level = False |
| return decoder_utils.decode(output, target, |
| return_list=return_list, |
| return_sentence_level=return_sentence_level, |
| pred_type=self.config.get("mode"), |
| use_multi=self.config.get("use_multi"), |
| multi_threshold=self.config.get("multi_threshold")) |
|
|
| def compute_loss(self, pred: OutputData, target: InputData): |
| """compute loss |
| |
| Args: |
| pred (OutputData): output logits data |
| target (InputData): input golden data |
| |
| Returns: |
| Tensor: loss result |
| """ |
| _CRF = None |
| if self.config.get("use_crf"): |
| _CRF = self.CRF |
| return decoder_utils.compute_loss(pred, target, criterion_type=self.config["mode"], |
| use_crf=_CRF is not None, |
| ignore_index=self.config["ignore_index"], |
| use_multi=self.config.get("use_multi"), |
| loss_fn=self.loss_fn, |
| CRF=_CRF) |
|
|
|
|
| class LinearClassifier(BaseClassifier): |
| """ |
| Decoder structure based on Linear. |
| """ |
| def __init__(self, **config): |
| """Construction function for LinearClassifier |
| |
| Args: |
| config (dict): |
| input_dim (int): hidden state dim. |
| use_slot (bool): whether to classify slot label. |
| slot_label_num (int, optional): the number of slot label. Enabled if use_slot is True. |
| use_intent (bool): whether to classify intent label. |
| intent_label_num (int, optional): the number of intent label. Enabled if use_intent is True. |
| use_crf (bool): whether to use crf for slot. |
| """ |
| super().__init__(**config) |
| self.config = config |
| if config.get("use_slot"): |
| self.slot_classifier = nn.Linear(config["input_dim"], config["slot_label_num"]) |
| if self.config.get("use_crf"): |
| self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True) |
| if config.get("use_intent"): |
| self.intent_classifier = nn.Linear(config["input_dim"], config["intent_label_num"]) |
|
|
| def forward(self, hidden: HiddenData): |
| if self.config.get("use_intent"): |
| return ClassifierOutputData(self.intent_classifier(hidden.get_intent_hidden_state())) |
| if self.config.get("use_slot"): |
| return ClassifierOutputData(self.slot_classifier(hidden.get_slot_hidden_state())) |
|
|
|
|
|
|
| class AutoregressiveLSTMClassifier(BaseClassifier): |
| """ |
| Decoder structure based on unidirectional LSTM. |
| """ |
|
|
| def __init__(self, **config): |
| """ Construction function for Decoder. |
| |
| Args: |
| config (dict): |
| input_dim (int): input dimension of Decoder. In fact, it's encoder hidden size. |
| use_slot (bool): whether to classify slot label. |
| slot_label_num (int, optional): the number of slot label. Enabled if use_slot is True. |
| use_intent (bool): whether to classify intent label. |
| intent_label_num (int, optional): the number of intent label. Enabled if use_intent is True. |
| use_crf (bool): whether to use crf for slot. |
| hidden_dim (int): hidden dimension of iterative LSTM. |
| embedding_dim (int): if it's not None, the input and output are relevant. |
| dropout_rate (float): dropout rate of network which is only useful for embedding. |
| """ |
|
|
| super(AutoregressiveLSTMClassifier, self).__init__(**config) |
| if config.get("use_slot") and config.get("use_crf"): |
| self.CRF = CRF(num_tags=config["slot_label_num"], batch_first=True) |
| self.input_dim = config["input_dim"] |
| self.hidden_dim = config["hidden_dim"] |
| if config.get("use_intent"): |
| self.output_dim = config["intent_label_num"] |
| if config.get("use_slot"): |
| self.output_dim = config["slot_label_num"] |
| self.dropout_rate = config["dropout_rate"] |
| self.embedding_dim = config.get("embedding_dim") |
| self.force_ratio = config.get("force_ratio") |
| self.config = config |
| self.ignore_index = config.get("ignore_index") if config.get("ignore_index") is not None else -100 |
| |
| |
| if self.embedding_dim is not None: |
| self.embedding_layer = nn.Embedding(self.output_dim, self.embedding_dim) |
| self.init_tensor = nn.Parameter( |
| torch.randn(1, self.embedding_dim), |
| requires_grad=True |
| ) |
|
|
| |
| if self.embedding_dim is not None: |
| lstm_input_dim = self.input_dim + self.embedding_dim |
| else: |
| lstm_input_dim = self.input_dim |
|
|
| |
| self.dropout_layer = nn.Dropout(self.dropout_rate) |
| self.lstm_layer = nn.LSTM( |
| input_size=lstm_input_dim, |
| hidden_size=self.hidden_dim, |
| batch_first=True, |
| bidirectional=self.config["bidirectional"], |
| dropout=self.dropout_rate, |
| num_layers=self.config["layer_num"] |
| ) |
| self.linear_layer = nn.Linear( |
| self.hidden_dim, |
| self.output_dim |
| ) |
| |
|
|
| def forward(self, hidden: HiddenData, internal_interaction=None, **interaction_args): |
| """ Forward process for decoder. |
| |
| :param internal_interaction: |
| :param hidden: |
| :return: is distribution of prediction labels. |
| """ |
| input_tensor = hidden.slot_hidden |
| seq_lens = hidden.inputs.attention_mask.sum(-1).detach().cpu().tolist() |
| output_tensor_list, sent_start_pos = [], 0 |
| input_tensor = pack_sequence(input_tensor, seq_lens) |
| forced_input = None |
| if self.training: |
| if random.random() < self.force_ratio: |
| if self.config["mode"]=="slot": |
|
|
| forced_slot = pack_sequence(hidden.inputs.slot, seq_lens) |
| temp_slot = [] |
| for index, x in enumerate(forced_slot): |
| if index == 0: |
| temp_slot.append(x.reshape(1)) |
| elif x == self.ignore_index: |
| temp_slot.append(temp_slot[-1]) |
| else: |
| temp_slot.append(x.reshape(1)) |
| forced_input = torch.cat(temp_slot, 0) |
| if self.config["mode"]=="token-level-intent": |
| forced_intent = hidden.inputs.intent.unsqueeze(1).repeat(1, hidden.inputs.slot.shape[1]) |
| forced_input = pack_sequence(forced_intent, seq_lens) |
| if self.embedding_dim is None or forced_input is not None: |
|
|
| for sent_i in range(0, len(seq_lens)): |
| sent_end_pos = sent_start_pos + seq_lens[sent_i] |
|
|
| |
| seg_hiddens = input_tensor[sent_start_pos: sent_end_pos, :] |
|
|
| if self.embedding_dim is not None and forced_input is not None: |
| if seq_lens[sent_i] > 1: |
| seg_forced_input = forced_input[sent_start_pos: sent_end_pos] |
|
|
| seg_forced_tensor = self.embedding_layer(seg_forced_input)[:-1] |
| seg_prev_tensor = torch.cat([self.init_tensor, seg_forced_tensor], dim=0) |
| else: |
| seg_prev_tensor = self.init_tensor |
|
|
| |
| combined_input = torch.cat([seg_hiddens, seg_prev_tensor], dim=1) |
| else: |
| combined_input = seg_hiddens |
| dropout_input = self.dropout_layer(combined_input) |
| lstm_out, _ = self.lstm_layer(dropout_input.view(1, seq_lens[sent_i], -1)) |
| if internal_interaction is not None: |
| interaction_args["sent_id"] = sent_i |
| lstm_out = internal_interaction(torch.transpose(lstm_out, 0, 1), **interaction_args)[:, 0] |
| linear_out = self.linear_layer(lstm_out.view(seq_lens[sent_i], -1)) |
|
|
| output_tensor_list.append(linear_out) |
| sent_start_pos = sent_end_pos |
| else: |
| for sent_i in range(0, len(seq_lens)): |
| prev_tensor = self.init_tensor |
|
|
| |
| |
| last_h, last_c = None, None |
|
|
| sent_end_pos = sent_start_pos + seq_lens[sent_i] |
| for word_i in range(sent_start_pos, sent_end_pos): |
| seg_input = input_tensor[[word_i], :] |
| combined_input = torch.cat([seg_input, prev_tensor], dim=1) |
| dropout_input = self.dropout_layer(combined_input).view(1, 1, -1) |
| if last_h is None and last_c is None: |
| lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input) |
| else: |
| lstm_out, (last_h, last_c) = self.lstm_layer(dropout_input, (last_h, last_c)) |
|
|
| if internal_interaction is not None: |
| interaction_args["sent_id"] = sent_i |
| lstm_out = internal_interaction(lstm_out, **interaction_args)[:, 0] |
|
|
| lstm_out = self.linear_layer(lstm_out.view(1, -1)) |
| output_tensor_list.append(lstm_out) |
|
|
| _, index = lstm_out.topk(1, dim=1) |
| prev_tensor = self.embedding_layer(index).view(1, -1) |
| sent_start_pos = sent_end_pos |
| seq_unpacked = unpack_sequence(torch.cat(output_tensor_list, dim=0), seq_lens) |
| |
| if self.config.get("use_multi"): |
| pred_output = ClassifierOutputData(seq_unpacked) |
| else: |
| pred_output = ClassifierOutputData(F.log_softmax(seq_unpacked, dim=-1)) |
| return pred_output |
|
|
|
|
| class MLPClassifier(BaseClassifier): |
| """ |
| Decoder structure based on MLP. |
| """ |
| def __init__(self, **config): |
| """ Construction function for Decoder. |
| |
| Args: |
| config (dict): |
| use_slot (bool): whether to classify slot label. |
| use_intent (bool): whether to classify intent label. |
| mlp (List): |
| |
| - _model_target_: torch.nn.Linear |
| |
| in_features (int): input feature dim |
| |
| out_features (int): output feature dim |
| |
| - _model_target_: torch.nn.LeakyReLU |
| |
| negative_slope: 0.2 |
| |
| - ... |
| """ |
| super(MLPClassifier, self).__init__(**config) |
| self.config = config |
| for i, x in enumerate(config["mlp"]): |
| if isinstance(x.get("in_features"), str): |
| config["mlp"][i]["in_features"] = self.config[x["in_features"][1:-1]] |
| if isinstance(x.get("out_features"), str): |
| config["mlp"][i]["out_features"] = self.config[x["out_features"][1:-1]] |
| mlp = [instantiate(x) for x in config["mlp"]] |
| self.seq = nn.Sequential(*mlp) |
|
|
|
|
| def forward(self, hidden: HiddenData): |
| if self.config.get("use_intent"): |
| res = self.seq(hidden.intent_hidden) |
| else: |
| res = self.seq(hidden.slot_hidden) |
| return ClassifierOutputData(res) |
|
|