"""
This module contains the implementation of the QA models: three DistilBERT-based question
answering models, a dataset class, and a sanity check for a single training step.
The structure is based on the Hugging Face implementation:
https://huggingface.co/docs/transformers/model_doc/distilbert
"""

import copy
from typing import Optional

import pandas as pd
import torch
from torch import nn


class SimpleQuestionDistilBERT(nn.Module):
    """
    This class implements a simple DistilBERT question answering model, following the Hugging Face
    implementation:
    https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/distilbert/modeling_distilbert.py#L805

    It fine-tunes the given DistilBERT model and only adds a single linear layer on top, which produces the start and end logits.
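
    Example (a minimal sketch; the Hugging Face `transformers` package and the
    `distilbert-base-uncased` checkpoint are assumptions, they are not imported in this module):

        from transformers import DistilBertModel
        base = DistilBertModel.from_pretrained("distilbert-base-uncased")
        model = SimpleQuestionDistilBERT(base)
        # `batch` stands for one batch produced by the Dataset class defined below.
        out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        start_logits, end_logits = out["start_logits"], out["end_logits"]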
| """ |
| def __init__(self, distilbert, dropout=0.1): |
| """ |
| Creates and initialises model |
| """ |
| super(SimpleQuestionDistilBERT, self).__init__() |
|
|
| self.distilbert = distilbert |
|
|
| self.dropout = nn.Dropout(dropout) |
|
|
| self.classifier = nn.Linear(768, 2) |
|
|
| |
| def init_weights(m): |
| if isinstance(m, nn.Linear): |
| nn.init.xavier_uniform_(m.weight) |
| m.bias.data.fill_(0.01) |
| self.classifier.apply(init_weights) |
|
|
|
|

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        """
        This function implements the forward pass of the model. It takes the input_ids and attention_mask
        and returns the start and end logits, plus the loss if start and end positions are given.
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = distilbert_output[0]
        hidden_states = self.dropout(hidden_states)

        # One logit pair per token: split into start and end logits of shape (batch, seq_len).
        logits = self.classifier(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
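
            # Target positions outside the model input are clamped to `ignored_index`
            # (= sequence length), which CrossEntropyLoss is told to ignore below.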
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {"loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": distilbert_output.hidden_states,
                "attentions": distilbert_output.attentions}


class QuestionDistilBERT(nn.Module):
    """
    This class implements the DistilBERT question answering model. We freeze all layers of the base model and only train the head.
    The head consists of a transformer encoder with three layers and a classifier on top.
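
    Example (a minimal sketch; the Hugging Face `transformers` package is an assumption,
    it is not imported in this module):

        from transformers import DistilBertModel
        base = DistilBertModel.from_pretrained("distilbert-base-uncased")
        model = QuestionDistilBERT(base)
        # Only the encoder head and the classifier are trainable, the base model is frozen.
        trainable = [name for name, p in model.named_parameters() if p.requires_grad]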
| """ |
| def __init__(self, distilbert, dropout=0.1): |
| """ |
| Creates and initialises QuestionDIstilBERT instance |
| """ |
| super(QuestionDistilBERT, self).__init__() |
|
|
| |
| for param in distilbert.parameters(): |
| param.requires_grad = False |
|
|
| self.distilbert = distilbert |
| self.relu = nn.ReLU() |
|
|
| self.dropout = nn.Dropout(dropout) |
| self.te = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=768, nhead=12), num_layers=3) |
|
|
| |
| self.classifier = nn.Sequential( |
| nn.Dropout(dropout), |
| nn.ReLU(), |
| nn.Linear(768, 512), |
| nn.Dropout(dropout), |
| nn.ReLU(), |
| nn.Linear(512, 256), |
| nn.Dropout(dropout), |
| nn.ReLU(), |
| nn.Linear(256, 128), |
| nn.Dropout(dropout), |
| nn.ReLU(), |
| nn.Linear(128, 64), |
| nn.Dropout(dropout), |
| nn.ReLU(), |
| nn.Linear(64, 2) |
| ) |
|
|
| |
| def init_weights(m): |
| if isinstance(m, nn.Linear): |
| nn.init.xavier_uniform_(m.weight) |
| m.bias.data.fill_(0.01) |
|
|
| self.classifier.apply(init_weights) |
|
|

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        """
        This function implements the forward pass of the model. It takes the input_ids and attention_mask
        and returns the start and end logits, plus the loss if start and end positions are given.
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = distilbert_output[0]
        hidden_states = self.dropout(hidden_states)
        attn_output = self.te(hidden_states)

        logits = self.classifier(attn_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {"loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": distilbert_output.hidden_states,
                "attentions": distilbert_output.attentions}


class ReuseQuestionDistilBERT(nn.Module):
    """
    This class implements a model where all layers of the base DistilBERT model are frozen.
    Instead of training a completely new head, we copy the last two layers of the base model and add a classifier on top.
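
    Example (a minimal sketch; the Hugging Face `transformers` package is an assumption,
    it is not imported in this module):

        from transformers import DistilBertModel
        base = DistilBertModel.from_pretrained("distilbert-base-uncased")
        model = ReuseQuestionDistilBERT(base)
        # Only the two copied transformer blocks and the classifier are trainable.
        n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)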
| """ |
| def __init__(self, distilbert, dropout=0.15): |
| """ |
| Creates and initialises QuestionDIstilBERT instance |
| """ |
| super(ReuseQuestionDistilBERT, self).__init__() |
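        # Copy the last two transformer blocks of the base model: children()[1] is DistilBERT's
        # `transformer` module and its single child is the ModuleList of layers. The copy is taken
        # before the base model is frozen, so the copied blocks stay trainable.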
        self.te = copy.deepcopy(list(list(distilbert.children())[1].children())[0][-2:])

        # Freeze the base model; only the copied blocks and the classifier are trained.
        for param in distilbert.parameters():
            param.requires_grad = False

        self.distilbert = distilbert
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(768, 2)

        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)

        self.classifier.apply(init_weights)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        """
        This function implements the forward pass of the model. It takes the input_ids and attention_mask
        and returns the start and end logits, plus the loss if start and end positions are given.
        """
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = distilbert_output[0]
        hidden_states = self.dropout(hidden_states)
        # Run the copied transformer blocks on top of the frozen base model; the hidden state
        # is always the last element of each block's output tuple.
        for te in self.te:
            hidden_states = te(
                x=hidden_states,
                attn_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions
            )[-1]
        hidden_states = self.dropout(hidden_states)

        logits = self.classifier(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {"loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": distilbert_output.hidden_states,
                "attentions": distilbert_output.attentions}


class Dataset(torch.utils.data.Dataset):
    """
    This class implements a dataset for the DistilBERT QA models.
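
    Each path is expected to point to a tab-separated file with one (context, question, answer,
    answer_start) example per line; __len__ and __getitem__ assume 1000 examples per file. Files
    are read lazily, one at a time, so items should be requested sequentially (e.g. with a
    non-shuffling DataLoader). A usage sketch (the tokenizer and the file names are assumptions):

        from transformers import DistilBertTokenizerFast
        tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
        ds = Dataset(squad_paths=["squad_0.tsv", "squad_1.tsv"], natural_question_paths=None,
                     hotpotqa_paths=None, tokenizer=tokenizer)
        loader = torch.utils.data.DataLoader(ds, batch_size=8, shuffle=False)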
| """ |
| def __init__(self, squad_paths, natural_question_paths, hotpotqa_paths, tokenizer): |
| """ |
| creates and initialises dataset object |
| """ |
| self.paths = [] |
| self.count = 0 |
| if squad_paths != None: |
| self.paths.extend(squad_paths[:len(squad_paths)-1]) |
| if natural_question_paths != None: |
| self.paths.extend(natural_question_paths[:len(natural_question_paths)-1]) |
| if hotpotqa_paths != None: |
| self.paths.extend(hotpotqa_paths[:len(hotpotqa_paths)-1]) |
| self.data = None |
| self.current_file = 0 |
| self.remaining = 0 |
| self.encodings = None |
| |
| self.tokenizer = tokenizer |
| |
| |
| def __len__(self): |
| """ |
| returns the length of the dataset |
| """ |
| return len(self.paths)*1000 |

    def read_file(self, path):
        """
        Reads the file stored at path and returns its lines.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self):
        """
        Tokenizes the currently loaded file and returns the encodings for the model,
        including the start and end token positions of each answer.
        """
        questions = [q.strip() for q in self.data["question"]]
        context = [c.lower() for c in self.data["context"]]

        inputs = self.tokenizer(
            questions,
            context,
            max_length=512,
            truncation="only_second",
            return_offsets_mapping=True,
            padding="max_length",
        )

        # The offset mapping gives the character span of every token in the original context.
        offset_mapping = inputs.pop("offset_mapping")
        answers = self.data["answer"]
        answer_start = self.data["answer_start"]

        start_positions = []
        end_positions = []
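
        # Map each character-level answer span onto token indices using the offset mapping.
        # Answers that do not lie fully inside the (possibly truncated) context are labelled (0, 0).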
        for i, offset in enumerate(offset_mapping):
            answer = answers[i]
            start_char = int(answer_start[i])
            end_char = start_char + len(answer)
            sequence_ids = inputs.sequence_ids(i)

            # Find the start and end of the context in the tokenized input.
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
                start_positions.append(0)
                end_positions.append(0)
                self.count += 1
            else:
                # Otherwise find the first and last token of the answer.
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)
                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions

        return {'input_ids': torch.tensor(inputs['input_ids']),
                'attention_mask': torch.tensor(inputs['attention_mask']),
                'start_positions': torch.tensor(inputs['start_positions']),
                'end_positions': torch.tensor(inputs['end_positions'])}

    def __getitem__(self, i):
        """
        Returns the encoding of item i. Files are read lazily: whenever the current file is
        exhausted, the next one is loaded and tokenized, so items should be requested in order.
        """
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.data = pd.DataFrame([line.split("\t") for line in self.data],
                                     columns=["context", "question", "answer", "answer_start"])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings()

            # Wrap around once every file has been used.
            if self.current_file == len(self.paths):
                self.current_file = 0
        self.remaining -= 1
        return {key: tensor[i % 1000] for key, tensor in self.encodings.items()}


def test_model(model, optim, test_ds_loader, device):
    """
    This function runs a single optimisation step and checks that the model behaves as expected:
    trainable parameters must change, frozen parameters must not, and no parameter may become
    NaN or infinite.
    :param model: pytorch model to evaluate
    :param optim: optimizer
    :param test_ds_loader: dataloader object
    :param device: device the model is on
    :raises Exception: if the model does not behave as expected
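
    Example (a minimal sketch; the base model, dataset and device are assumptions):

        model = SimpleQuestionDistilBERT(base).to(device)
        optim = torch.optim.AdamW(model.parameters(), lr=5e-5)
        loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=False)
        test_model(model, optim, loader, device)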
| """ |
| |
|
|
| |
| params = [np for np in model.named_parameters() if np[1].requires_grad] |
| initial_params = [(name, p.clone()) for (name, p) in params] |
|
|
| |
| params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad] |
| initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen] |
|
|
| |
| optim.zero_grad() |
| batch = next(iter(test_ds_loader)) |
|
|
| input_ids = batch['input_ids'].to(device) |
| attention_mask = batch['attention_mask'].to(device) |
| start_positions = batch['start_positions'].to(device) |
| end_positions = batch['end_positions'].to(device) |
|
|
| |
| outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, |
| end_positions=end_positions) |
| loss = outputs['loss'] |
| loss.backward() |
| optim.step() |
|
|
| |
| for (_, p0), (name, p1) in zip(initial_params, params): |
| |
| try: |
| assert not torch.equal(p0.to(device), p1.to(device)) |
| except AssertionError: |
| raise Exception( |
| "{var_name} {msg}".format( |
| var_name=name, |
| msg='did not change!' |
| ) |
| ) |
| |
| try: |
| assert not torch.isnan(p1).byte().any() |
| except AssertionError: |
| raise Exception( |
| "{var_name} {msg}".format( |
| var_name=name, |
| msg='is NaN!' |
| ) |
| ) |
| |
| try: |
| assert torch.isfinite(p1).byte().all() |
| except AssertionError: |
| raise Exception( |
| "{var_name} {msg}".format( |
| var_name=name, |
| msg='is Inf!' |
| ) |
| ) |
|
|
| |
| for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen): |
| |
| try: |
| assert torch.equal(p0.to(device), p1.to(device)) |
| except AssertionError: |
| raise Exception( |
| "{var_name} {msg}".format( |
| var_name=name, |
| msg='changed!' |
| ) |
| ) |
| |
| try: |
| assert not torch.isnan(p1).byte().any() |
| except AssertionError: |
| raise Exception( |
| "{var_name} {msg}".format( |
| var_name=name, |
| msg='is NaN!' |
| ) |
| ) |
|
|
| |
| try: |
| assert torch.isfinite(p1).byte().all() |
| except AssertionError: |
| raise Exception( |
| "{var_name} {msg}".format( |
| var_name=name, |
| msg='is Inf!' |
| ) |
| ) |
| print("Passed") |
|
|