| ''' |
| Author: Qiguang Chen |
| Date: 2023-01-11 10:39:26 |
| LastEditors: Qiguang Chen |
| LastEditTime: 2023-01-31 18:22:36 |
| Description: |
| |
| ''' |
| from torch import nn |
|
|
| from common.utils import HiddenData, OutputData, InputData |
|
|
|
|
class BaseDecoder(nn.Module):
    """Base class for all decoder modules.

    Notice: it is often only necessary to change this module and its sub-modules.
    """

    def __init__(self, intent_classifier=None, slot_classifier=None, interaction=None):
        """Store the optional sub-modules the decoder is composed of.

        Args:
            intent_classifier (optional): module called on the hidden data to predict intent. Defaults to None.
            slot_classifier (optional): module called on the hidden data to predict slots. Defaults to None.
            interaction (optional): module applied to the hidden data before classification. Defaults to None.
        """
        super().__init__()
        self.intent_classifier = intent_classifier
        self.slot_classifier = slot_classifier
        self.interaction = interaction

    def forward(self, hidden: HiddenData):
        """forward

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: prediction logits; a field is None when the matching classifier is missing.
        """
        if self.interaction is not None:
            hidden = self.interaction(hidden)
        intent = None
        slot = None
        if self.intent_classifier is not None:
            intent = self.intent_classifier(hidden)
        if self.slot_classifier is not None:
            slot = self.slot_classifier(hidden)
        return OutputData(intent, slot)

    def decode(self, output: OutputData, target: InputData = None):
        """decode output logits

        Args:
            output (OutputData): output logits data
            target (InputData, optional): input data with attention mask. Defaults to None.

        Returns:
            OutputData: decoded intent and slot results; a field is None when the matching classifier is missing.
        """
        intent, slot = None, None
        if self.intent_classifier is not None:
            intent = self.intent_classifier.decode(output, target)
        if self.slot_classifier is not None:
            slot = self.slot_classifier.decode(output, target)
        return OutputData(intent, slot)

    @staticmethod
    def _loss_weight(classifier):
        """Return the 'weight' entry of the classifier config, or 1. when it is not set."""
        weight = classifier.config.get("weight")
        return weight if weight is not None else 1.

    def compute_loss(self, pred: OutputData, target: InputData, compute_intent_loss=True, compute_slot_loss=True):
        """compute loss.
        Notice: can set intent and slot loss weight by adding 'weight' config item in corresponding classifier configuration.

        Args:
            pred (OutputData): output logits data
            target (InputData): input golden data
            compute_intent_loss (bool, optional): whether to compute intent loss. Defaults to True.
            compute_slot_loss (bool, optional): whether to compute slot loss. Defaults to True.

        Returns:
            tuple: (loss, intent_loss, slot_loss). `loss` is the weighted sum of the
            losses that were computed (0 if none were); a skipped or unavailable
            loss is returned as None.
        """
        loss = 0
        intent_loss = None
        slot_loss = None
        # A disabled loss must be skipped entirely: multiplying None by the
        # weight would raise a TypeError.
        if self.intent_classifier is not None and compute_intent_loss:
            intent_loss = self.intent_classifier.compute_loss(pred, target)
            loss += intent_loss * self._loss_weight(self.intent_classifier)
        if self.slot_classifier is not None and compute_slot_loss:
            slot_loss = self.slot_classifier.compute_loss(pred, target)
            loss += slot_loss * self._loss_weight(self.slot_classifier)
        return loss, intent_loss, slot_loss
|
|
|
|
class StackPropagationDecoder(BaseDecoder):
    """Decoder that runs the intent classifier first and feeds its output into the slot branch."""

    def forward(self, hidden: HiddenData):
        """forward

        The intent prediction is computed from the encoded data, passed together
        with the encoded data through the interaction module, and the result is
        given to the slot classifier. All three sub-modules are used
        unconditionally, so none of them may be None here.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent and slot predictions
        """
        intent_out = self.intent_classifier(hidden)
        # The interaction module receives the intent prediction as its first argument.
        fused = self.interaction(intent_out, hidden)
        slot_out = self.slot_classifier(fused)
        return OutputData(intent_out, slot_out)
|
|
class DCANetDecoder(BaseDecoder):
    """Decoder whose interaction module is also handed the intent and slot classifiers."""

    def forward(self, hidden: HiddenData):
        """forward

        When an interaction module is configured, it is called with the encoded
        data plus the two classifiers (as `intent_emb` and `slot_emb`) and its
        result replaces the encoded data. Both classifiers are then applied to
        the same hidden data.

        Args:
            hidden (HiddenData): encoded data

        Returns:
            OutputData: intent and slot predictions
        """
        interact = self.interaction
        if interact is not None:
            hidden = interact(
                hidden,
                intent_emb=self.intent_classifier,
                slot_emb=self.slot_classifier,
            )
        # Intent is evaluated before slot, as in a left-to-right argument list.
        intent_out = self.intent_classifier(hidden)
        slot_out = self.slot_classifier(hidden)
        return OutputData(intent_out, slot_out)
|
|
|
|