| """ |
| Deprecated network.py module. This file only exists to support backwards-compatibility |
| with old pickle files. See lib/__init__.py for more information. |
| """ |
|
|
| from __future__ import print_function |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.autograd import Variable |
| from torch.nn.parameter import Parameter |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
def choose(matrix, idxs):
    """Pick one element per row of a 2-D tensor.

    :param matrix: 2-D tensor of shape (rows, cols)
    :param idxs: per-row column indices (Tensor or legacy Variable)
    :return: 1-D tensor containing matrix[r, idxs[r]] for each row r
    """
    # Legacy-PyTorch compatibility: unwrap autograd Variables
    if isinstance(idxs, Variable):
        idxs = idxs.data
    assert matrix.ndimension() == 2
    n_rows, n_cols = matrix.size(0), matrix.size(1)
    # Offset each row's index into the flattened matrix
    row_offsets = torch.arange(0, n_rows).type_as(idxs) * n_cols
    flat = matrix.view(matrix.nelement())
    return flat[idxs + row_offsets]
|
|
|
|
class Network(nn.Module):
    """
    Attention-based sequence-to-sequence model over sets of input/output
    example pairs: each example's input is encoded, its output is encoded
    while attending over the input encoding, and the target is decoded
    while attending over the output encodings (see `run`).

    Todo:
    - Beam search
    - check if this is right? attend during P->FC rather than during softmax->P?
    - allow length 0 inputs/targets
    - give n_examples as input to FC
    - Initialise new weights randomly, rather than as zeroes
    """

    def __init__(
            self,
            input_vocabulary,
            target_vocabulary,
            hidden_size=512,
            embedding_size=128,
            cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        :param int hidden_size: width of all three recurrent hidden states
        :param int embedding_size: width of the pre-softmax embedding layer
        :param str cell_type: "GRU" or "LSTM"
        """
        super(Network, self).__init__()
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Vocabulary sizes; index v_input / v_target is reserved for EOS
        self.v_input = len(input_vocabulary)
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,  # one-hot token + EOS slot
                hidden_size=self.h_input_encoder_size,
                bias=True)
            # Learned initial hidden state (GRU state is a single tensor)
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input +
                1 +
                self.h_input_encoder_size,  # token one-hot + attention context
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(
                input_size=self.v_target + 1,
                hidden_size=self.h_decoder_size,
                bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            # LSTM state is an (h, c) pair, hence two learned init vectors
            self.input_encoder_init = nn.ParameterList([Parameter(torch.rand(
                1, self.h_input_encoder_size)), Parameter(torch.rand(1, self.h_input_encoder_size))])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input +
                1 +
                self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(
                input_size=self.v_target + 1,
                hidden_size=self.h_decoder_size,
                bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        # W, V: project [decoder h ; attention context] -> embedding -> logits
        self.W = nn.Linear(
            self.h_output_encoder_size +
            self.h_decoder_size,
            self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        # Bilinear scorers for 'general' attention (Luong et al., 2015)
        self.input_A = nn.Bilinear(
            self.h_input_encoder_size,
            self.h_output_encoder_size,
            1,
            bias=False)
        self.output_A = nn.Bilinear(
            self.h_output_encoder_size,
            self.h_decoder_size,
            1,
            bias=False)
        # One-hot EOS rows. Outputs are encoded with the *input* vocabulary
        # (inputsToTensors is used for both), hence v_input for output_EOS.
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)

    def __getstate__(self):
        """Pickle support: replace the live optimiser with its state_dict
        (the optimiser object itself is dropped from the pickled dict)."""
        if hasattr(self, 'opt'):
            # NOTE(review): `is not 'opt'` is an identity comparison that only
            # works because CPython interns short strings; it should be
            # `k != 'opt'` (left unchanged here for byte-compatibility).
            return dict([(k, v) for k, v in self.__dict__.items(
            ) if k is not 'opt'] + [('optstate', self.opt.state_dict())])
        else:
            return self.__dict__

    def __setstate__(self, state):
        """Unpickle support; upgrades legacy pickles in place."""
        self.__dict__.update(state)
        # Older pickles stored the LSTM (h, c) init as a plain tuple
        if isinstance(self.input_encoder_init, tuple):
            self.input_encoder_init = nn.ParameterList(
                list(self.input_encoder_init))

    def clear_optimiser(self):
        # Drop both the live optimiser and any restored optimiser state
        if hasattr(self, 'opt'):
            del self.opt
        if hasattr(self, 'optstate'):
            del self.optstate

    def get_optimiser(self):
        # (Re)create Adam, restoring saved state if one was unpickled
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        if hasattr(self, 'optstate'):
            self.opt.load_state_dict(self.optstate)

    def optimiser_step(self, inputs, outputs, target):
        """Take one gradient step maximising mean log p(target|inputs,outputs);
        returns the (scalar) mean score before the step."""
        if not hasattr(self, 'opt'):
            self.get_optimiser()
        score = self.score(inputs, outputs, target, autograd=True).mean()
        (-score).backward()  # ascend score == descend -score
        self.opt.step()
        self.opt.zero_grad()
        # NOTE(review): `.data[0]` is the pre-0.4 PyTorch scalar-extraction
        # idiom; modern torch would use `.item()`.
        return score.data[0]

    def set_target_vocabulary(self, target_vocabulary):
        """Swap in a new target vocabulary, keeping learned output-layer and
        decoder-input weights for tokens shared with the old vocabulary.
        New tokens get zero weights and a -10 bias (near-zero initial
        probability). The optimiser is cleared since parameter shapes change."""
        if target_vocabulary == self.target_vocabulary:
            return

        V_weight = []
        V_bias = []
        decoder_ih = []

        for i in range(len(target_vocabulary)):
            if target_vocabulary[i] in self.target_vocabulary:
                # Reuse the old token's row of V and column of decoder weight_ih
                j = self.target_vocabulary.index(target_vocabulary[i])
                V_weight.append(self.V.weight.data[j:j + 1])
                V_bias.append(self.V.bias.data[j:j + 1])
                decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1])
            else:
                # Unseen token: zero weights, strongly negative bias
                V_weight.append(torch.zeros(1, self.V.weight.size(1)))
                V_bias.append(torch.ones(1) * -10)
                decoder_ih.append(
                    torch.zeros(
                        self.decoder_cell.weight_ih.data.size(0), 1))

        # EOS always occupies the final slot
        V_weight.append(self.V.weight.data[-1:])
        V_bias.append(self.V.bias.data[-1:])
        decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:])

        self.target_vocabulary = target_vocabulary
        self.v_target = len(target_vocabulary)
        self.target_EOS.data = torch.zeros(1, self.v_target + 1)
        self.target_EOS.data[:, -1] = 1

        self.V.weight.data = torch.cat(V_weight, dim=0)
        self.V.bias.data = torch.cat(V_bias, dim=0)
        self.V.out_features = self.V.bias.data.size(0)

        self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1)
        self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1)

        # Saved Adam state no longer matches the resized parameters
        self.clear_optimiser()

    def input_encoder_get_init(self, batch_size):
        # Learned initial state, tiled across the batch
        if self.cell_type == "GRU":
            return self.input_encoder_init.repeat(batch_size, 1)
        if self.cell_type == "LSTM":
            return tuple(x.repeat(batch_size, 1)
                         for x in self.input_encoder_init)

    def output_encoder_get_init(self, input_encoder_h):
        # Seed the output encoder's h with the input-example embedding
        if self.cell_type == "GRU":
            return input_encoder_h
        if self.cell_type == "LSTM":
            return (
                input_encoder_h,
                self.output_encoder_init_c.repeat(
                    input_encoder_h.size(0),
                    1))

    def decoder_get_init(self, output_encoder_h):
        # Seed the decoder's h with the output-example embedding
        if self.cell_type == "GRU":
            return output_encoder_h
        if self.cell_type == "LSTM":
            return (
                output_encoder_h,
                self.decoder_init_c.repeat(
                    output_encoder_h.size(0),
                    1))

    def cell_get_h(self, cell_state):
        # GRU state is h itself; LSTM state is the (h, c) pair
        if self.cell_type == "GRU":
            return cell_state
        if self.cell_type == "LSTM":
            return cell_state[0]

    def score(self, inputs, outputs, target, autograd=False):
        """Return log p(target | inputs, outputs) per batch element.
        With autograd=True the result stays attached to the graph."""
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target = self.targetToTensor(target)
        target, score = self.run(inputs, outputs, target=target, mode="score")
        if autograd:
            return score
        else:
            return score.data

    def sample(self, inputs, outputs):
        """Sample target sequences from p(- | inputs, outputs)."""
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        target, score = self.run(inputs, outputs, mode="sample")
        target = self.tensorToOutput(target)
        return target

    def sampleAndScore(self, inputs, outputs, nRepeats=None):
        """Sample targets with their log-probabilities; with nRepeats, draw
        repeated batches and concatenate targets and scores into flat lists."""
        inputs = self.inputsToTensors(inputs)
        outputs = self.inputsToTensors(outputs)
        if nRepeats is None:
            target, score = self.run(inputs, outputs, mode="sample")
            target = self.tensorToOutput(target)
            return target, score.data
        else:
            target = []
            score = []
            for i in range(nRepeats):
                t, s = self.run(inputs, outputs, mode="sample")
                t = self.tensorToOutput(t)
                target.extend(t)
                score.extend(list(s.data))
            return target, score

    def run(self, inputs, outputs, target=None, mode="sample"):
        """
        :param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input)
        :param List[LongTensor] inputs: n_examples * (max_length_input * batch_size)
        :param List[LongTensor] target: max_length_target * batch_size
        """
        assert((mode == "score" and target is not None) or mode == "sample")

        n_examples = len(inputs)
        max_length_input = [inputs[j].size(0) for j in range(n_examples)]
        max_length_output = [outputs[j].size(0) for j in range(n_examples)]
        # When sampling, generated sequences are capped at 10 tokens
        max_length_target = target.size(0) if target is not None else 10
        batch_size = inputs[0].size(1)

        score = Variable(torch.zeros(batch_size))
        # One-hot encodings of the index tensors (extra slot is EOS)
        inputs_scatter = [Variable(torch.zeros(max_length_input[j], batch_size, self.v_input + 1).scatter_(
            2, inputs[j][:, :, None], 1)) for j in range(n_examples)]
        outputs_scatter = [Variable(torch.zeros(max_length_output[j], batch_size, self.v_input + 1).scatter_(
            2, outputs[j][:, :, None], 1)) for j in range(n_examples)]
        if target is not None:
            target_scatter = Variable(torch.zeros(max_length_target,
                                                  batch_size,
                                                  self.v_target + 1).scatter_(2,
                                                                              target[:,
                                                                                     :,
                                                                                     None],
                                                                              1))

        # --- Encode each input example ---
        input_H = []
        input_embeddings = []
        input_attention_mask = []
        for j in range(n_examples):
            # `active` marks timesteps at or before EOS (used to mask attention)
            active = torch.Tensor(max_length_input[j], batch_size).byte()
            active[0, :] = 1
            state = self.input_encoder_get_init(batch_size)
            hs = []
            for i in range(max_length_input[j]):
                state = self.input_encoder_cell(
                    inputs_scatter[j][i, :, :], state)
                if i + 1 < max_length_input[j]:
                    # Remain active until the EOS index (== v_input) is seen
                    active[i + 1, :] = active[i, :] * \
                        (inputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            input_H.append(torch.cat(hs, 0))
            # Example embedding = hidden state at the last active timestep
            embedding_idx = active.sum(0).long() - 1
            embedding = input_H[j].gather(0, Variable(
                embedding_idx[None, :, None].repeat(1, 1, self.h_input_encoder_size)))[0]
            input_embeddings.append(embedding)
            # log(0) = -inf on inactive positions zeroes them out in softmax
            input_attention_mask.append(Variable(active.float().log()))

        def input_attend(j, h_out):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_out: batch_size * h_output_encoder_size
            """
            scores = self.input_A(
                input_H[j].view(
                    max_length_input[j] * batch_size,
                    self.h_input_encoder_size),
                h_out.view(
                    batch_size,
                    self.h_output_encoder_size).repeat(
                    max_length_input[j],
                    1)).view(
                max_length_input[j],
                batch_size) + input_attention_mask[j]
            # Softmax over timesteps, then weighted sum of encoder states
            c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0)
            return c

        # --- Encode each output example, attending over its input encoding ---
        output_H = []
        output_embeddings = []
        output_attention_mask = []
        for j in range(n_examples):
            active = torch.Tensor(max_length_output[j], batch_size).byte()
            active[0, :] = 1
            state = self.output_encoder_get_init(input_embeddings[j])
            hs = []
            h = self.cell_get_h(state)
            for i in range(max_length_output[j]):
                # Cell input = output token one-hot + input-attention context
                state = self.output_encoder_cell(torch.cat(
                    [outputs_scatter[j][i, :, :], input_attend(j, h)], 1), state)
                if i + 1 < max_length_output[j]:
                    # Outputs use the input vocabulary encoding, so EOS == v_input
                    active[i + 1, :] = active[i, :] * \
                        (outputs[j][i, :] != self.v_input)
                h = self.cell_get_h(state)
                hs.append(h[None, :, :])
            output_H.append(torch.cat(hs, 0))
            embedding_idx = active.sum(0).long() - 1
            embedding = output_H[j].gather(0, Variable(
                embedding_idx[None, :, None].repeat(1, 1, self.h_output_encoder_size)))[0]
            output_embeddings.append(embedding)
            output_attention_mask.append(Variable(active.float().log()))

        def output_attend(j, h_dec):
            """
            'general' attention from https://arxiv.org/pdf/1508.04025.pdf
            :param j: Index of example
            :param h_dec: batch_size * h_decoder_size
            """
            scores = self.output_A(
                output_H[j].view(
                    max_length_output[j] * batch_size,
                    self.h_output_encoder_size),
                h_dec.view(
                    batch_size,
                    self.h_decoder_size).repeat(
                    max_length_output[j],
                    1)).view(
                max_length_output[j],
                batch_size) + output_attention_mask[j]
            c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0)
            return c

        # --- Decode: one decoder per example, max-pooled before the softmax ---
        target = target if mode == "score" else torch.zeros(
            max_length_target, batch_size).long()
        decoder_states = [
            self.decoder_get_init(
                output_embeddings[j]) for j in range(n_examples)]
        active = torch.ones(batch_size).byte()
        for i in range(max_length_target):
            FC = []
            for j in range(n_examples):
                h = self.cell_get_h(decoder_states[j])
                p_aug = torch.cat([h, output_attend(j, h)], 1)
                FC.append(F.tanh(self.W(p_aug)[None, :, :]))
            # Pool per-example embeddings with elementwise max, then project
            m = torch.max(torch.cat(FC, 0), 0)[0]
            logsoftmax = F.log_softmax(self.V(m), dim=1)
            if mode == "sample":
                target[i, :] = torch.multinomial(
                    logsoftmax.data.exp(), 1)[:, 0]
            # Accumulate log-probability only for still-active sequences
            score = score + \
                choose(logsoftmax, target[i, :]) * Variable(active.float())
            # A sequence goes inactive once it emits EOS (== v_target)
            active *= (target[i, :] != self.v_target)
            # Feed the chosen/target token back into every per-example decoder
            for j in range(n_examples):
                if mode == "score":
                    target_char_scatter = target_scatter[i, :, :]
                elif mode == "sample":
                    target_char_scatter = Variable(torch.zeros(
                        batch_size, self.v_target + 1).scatter_(1, target[i, :, None], 1))
                decoder_states[j] = self.decoder_cell(
                    target_char_scatter, decoder_states[j])
        return target, score

    def inputsToTensors(self, inputss):
        """
        Convert a batch of example lists to per-example LongTensors of
        vocabulary indices, padded with the EOS index (v_input).
        :param inputss: size = nBatch * nExamples
        """
        tensors = []
        for j in range(len(inputss[0])):
            inputs = [x[j] for x in inputss]
            maxlen = max(len(s) for s in inputs)
            # Fill with EOS; the +1 row guarantees a terminating EOS
            t = torch.ones(
                1 if maxlen == 0 else maxlen + 1,
                len(inputs)).long() * self.v_input
            for i in range(len(inputs)):
                s = inputs[i]
                if len(s) > 0:
                    t[:len(s), i] = torch.LongTensor(
                        [self.input_vocabulary.index(x) for x in s])
            tensors.append(t)
        return tensors

    def targetToTensor(self, targets):
        """
        Convert a batch of target sequences to one LongTensor of vocabulary
        indices, padded with the EOS index (v_target).
        :param targets: list of target sequences
        """
        maxlen = max(len(s) for s in targets)
        t = torch.ones(
            1 if maxlen == 0 else maxlen + 1,
            len(targets)).long() * self.v_target
        for i in range(len(targets)):
            s = targets[i]
            if len(s) > 0:
                t[:len(s), i] = torch.LongTensor(
                    [self.target_vocabulary.index(x) for x in s])
        return t

    def tensorToOutput(self, tensor):
        """
        Map index columns back to lists of vocabulary tokens, truncating
        each sequence at its first EOS (== v_target).
        :param tensor: max_length * batch_size
        """
        out = []
        for i in range(tensor.size(1)):
            l = tensor[:, i].tolist()
            if l[0] == self.v_target:
                out.append([])
            elif self.v_target in l:
                final = tensor[:, i].tolist().index(self.v_target)
                out.append([self.target_vocabulary[x]
                            for x in tensor[:final, i]])
            else:
                out.append([self.target_vocabulary[x] for x in tensor[:, i]])
        return out
|
|