import torch


class Dataset(torch.utils.data.Dataset):
    """
    Loads and preprocesses text data for masked-language-model training.
    The data is split across multiple files, which are read one at a time
    as the dataset is consumed.
    """
    def __init__(self, paths, tokenizer):
        """
        Initialises the dataset with the given file paths and tokenizer.
        """
        # The last path is dropped (e.g. held out for evaluation).
        self.paths = paths[:-1]
        self.tokenizer = tokenizer
        self.data = self.read_file(self.paths[0])
        self.current_file = 1
        self.remaining = len(self.data)
        self.encodings = self.get_encodings(self.data)

    def __len__(self):
        """
        Returns the length of the dataset.
        Assumes that every file contains exactly 10,000 lines.
        """
        return 10000 * len(self.paths)

    def read_file(self, path):
        """
        Reads the given file and returns its lines.
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self, lines_all):
        """
        Tokenizes the given lines and builds the encodings for
        masked-language-model training.
        """
        batch = self.tokenizer(lines_all, max_length=512,
                               padding='max_length', truncation=True)

        # The original token IDs are kept as the labels.
        labels = torch.tensor(batch['input_ids'])
        mask = torch.tensor(batch['attention_mask'])

        # Mask roughly 15% of the tokens, skipping the special tokens
        # (IDs 0, 2 and 3 here); ID 4 is assumed to be the mask token.
        input_ids = labels.detach().clone()
        rand = torch.rand(input_ids.shape)
        mask_arr = (rand < 0.15) * (input_ids != 0) * (input_ids != 2) * (input_ids != 3)
        input_ids[mask_arr] = 4

        return {'input_ids': input_ids, 'attention_mask': mask, 'labels': labels}

    def __getitem__(self, i):
        """
        Returns item i.
        Note: do not use shuffling with this dataset; files are read
        sequentially, so items must be requested in order.
        """
        # When the current file is exhausted, read the next one and
        # rebuild the encodings.
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings(self.data)

        # Wrap around to the first file once the last one has been read.
        if self.current_file == len(self.paths):
            self.current_file = 0

        self.remaining -= 1
        return {key: tensor[i % 10000] for key, tensor in self.encodings.items()}
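

# A minimal usage sketch (illustrative only): it assumes a HuggingFace-style
# tokenizer whose special-token IDs match the assumptions in get_encodings
# (mask token ID 4), and the file paths below are placeholders.
#
#     tokenizer = RobertaTokenizerFast.from_pretrained('path/to/tokenizer')
#     dataset = Dataset(paths=['data_0.txt', 'data_1.txt', 'data_2.txt'],
#                       tokenizer=tokenizer)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)

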
def test_model(model, optim, test_ds_loader, device):
    """
    Tests that the trainable parameters of the model change after one
    optimisation step, that the frozen parameters do not change, and
    that no parameter becomes NaN or Inf.

    :param model: model to be tested
    :param optim: optimiser used for training
    :param test_ds_loader: data loader to draw a batch from for the forward pass
    :param device: current device
    :raises Exception: if any of the above conditions is not met
    """
    # Snapshot the trainable parameters before the optimisation step.
    params = [np for np in model.named_parameters() if np[1].requires_grad]
    initial_params = [(name, p.clone()) for (name, p) in params]

    # Snapshot the frozen parameters as well.
    params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad]
    initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen]

    optim.zero_grad()

    # Run a single forward/backward pass and one optimisation step.
    batch = next(iter(test_ds_loader))
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optim.step()

    # Every trainable parameter must have changed and must remain finite.
    for (_, p0), (name, p1) in zip(initial_params, params):
        if torch.equal(p0.to(device), p1.to(device)):
            raise Exception(f"{name} did not change!")
        if torch.isnan(p1).any():
            raise Exception(f"{name} is NaN!")
        if not torch.isfinite(p1).all():
            raise Exception(f"{name} is Inf!")

    # Every frozen parameter must be unchanged and must remain finite.
    for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
        if not torch.equal(p0.to(device), p1.to(device)):
            raise Exception(f"{name} changed!")
        if torch.isnan(p1).any():
            raise Exception(f"{name} is NaN!")
        if not torch.isfinite(p1).all():
            raise Exception(f"{name} is Inf!")

    print("Passed")
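

if __name__ == '__main__':
    # Minimal smoke-test sketch, illustrative only: the checkpoint name,
    # file paths and hyperparameters below are placeholder assumptions,
    # and get_encodings expects a tokenizer whose mask token has ID 4 and
    # whose special tokens have IDs 0, 2 and 3.
    from transformers import RobertaForMaskedLM, RobertaTokenizerFast

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
    model = RobertaForMaskedLM.from_pretrained('roberta-base').to(device)

    # Note: Dataset drops the last path, so pass one more file than needed.
    dataset = Dataset(paths=['data_0.txt', 'data_1.txt', 'data_2.txt'],
                      tokenizer=tokenizer)
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)

    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    test_model(model, optim, loader, device)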