| import torch |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .configuration_greedy import GreedyConfig |
| from freegroup import tools |
|
|
class GreedyModel(PreTrainedModel):
    """Rule-based next-token scorer built on greedy word reduction.

    Nothing here is learned: ``forward`` never reads ``stub``. For each
    sequence in a batch the model keeps one stack per entry of
    ``config.reducables``. Every incoming token is pushed onto each stack, and
    a trailing segment is cancelled when it matches one of
    ``config.reciprocals`` or one of that entry's reducable words (see
    ``_reduce_step``). Next-token logits are vote counts derived from the top
    of each stack (see ``_score_token``).
    """

    config_class = GreedyConfig

    def __init__(self, config: GreedyConfig):
        super().__init__(config)
        # Dummy parameter; forward never uses it. NOTE(review): presumably it
        # exists so PreTrainedModel machinery that expects at least one
        # parameter (device/dtype lookup, save/load) keeps working - confirm.
        self.stub = torch.nn.parameter.Parameter(torch.tensor(0.))

    def _reduce_step(self, token, stack, reducables):
        """Push ``token`` onto ``stack`` and cancel a matching suffix.

        Candidates are ``self.config.reciprocals`` followed by ``reducables``,
        checked in that order. For a candidate of length ``n``, the last ``n``
        stack items are removed when ``tools.occurs(suffix, candidate * 2)``
        is true. NOTE(review): presumably this means "suffix is a cyclic
        rotation of the candidate"; confirm against ``freegroup.tools``.

        Args:
            token: next token id, as a Python int or a 0-d tensor.
            stack: list of token ids for one sequence; mutated in place.
            reducables: list of candidate words (lists of token ids).

        Returns:
            ``stack`` itself, after the update.
        """
        if isinstance(token, torch.Tensor):
            token = token.item()
        stack.append(token)

        for reducable in self.config.reciprocals + reducables:
            n = len(reducable)
            # Skip empty candidates explicitly: ``stack[-0:]`` is the whole
            # stack, so ``del stack[-0:]`` would wipe it rather than remove
            # nothing.
            if n == 0 or len(stack) < n:
                continue
            if tools.occurs(stack[-n:], reducable * 2):
                del stack[-n:]

        return stack

    @staticmethod
    def _next_in_cycle(cycle, item):
        """Return the element after the first ``item`` in ``cycle``, wrapping to the start."""
        return cycle[(cycle.index(item) + 1) % len(cycle)]

    def _score_token(self, token, row_stacks, scores):
        """Consume one token of one sequence and write its next-token scores.

        Args:
            token: token id (int) at the current position.
            row_stacks: this sequence's stacks, one per entry of
                ``self.config.reducables``; each is updated in place.
            scores: 1-D tensor of length ``vocab_size``, expected to start at
                zero; written in place.
        """
        reciprocals = self.config.reciprocals

        for stack, reducables in zip(row_stacks, self.config.reducables):
            self._reduce_step(token, stack, reducables)
            if not stack:
                continue

            top = stack[-1]
            # Each candidate containing the stack top votes for the element
            # that follows it (cyclically); the others vote for their first
            # element.
            for reducable in reducables:
                if top in reducable:
                    scores[self._next_in_cycle(reducable, top)] += 1
                else:
                    scores[reducable[0]] += 1
            for reciprocal in reciprocals:
                if top in reciprocal:
                    scores[self._next_in_cycle(reciprocal, top)] += 1

        # Ban the element that follows the current token in any reciprocal
        # entry containing it. NOTE(review): presumably this is the token's
        # inverse, so the next step cannot immediately cancel it - confirm.
        for reciprocal in reciprocals:
            if token in reciprocal:
                scores[self._next_in_cycle(reciprocal, token)] = -torch.inf

        # All stacks fully reduced: force the end-of-sequence token.
        if all(len(stack) == 0 for stack in row_stacks):
            scores[self.config.eos_token_id] = torch.inf

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the ``forward`` kwargs for one ``generate`` step.

        The cached ``(stacks, hidden_states)`` state may arrive as ``past``
        (older ``transformers`` releases) or ``past_key_values`` (newer ones).
        Either is forwarded as ``past``. The full ``input_ids`` are passed on
        because ``forward`` itself skips positions already covered by the
        cache.
        """
        past = kwargs.pop('past', None)
        if past is None:
            past = kwargs.pop('past_key_values', None)
        return {'input_ids': input_ids, 'past': past}

    def forward(self, input_ids=None, past=None, **kwargs):
        """Score the next token after every not-yet-processed position.

        Args:
            input_ids: integer token ids of shape
                ``(batch_size, sequence_length)``.
            past: ``(stacks, hidden_states)`` from a previous call's
                ``past_key_values``, or None. Positions below
                ``hidden_states.size(0)`` are treated as already processed and
                skipped; ``stacks`` is extended in place.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` have shape
            ``(batch_size, sequence_length, vocab_size)`` and whose
            ``past_key_values`` is ``(stacks, hidden_states)``, with
            ``hidden_states`` laid out as
            ``(sequence_length, batch_size, vocab_size)``.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("Can't be None")

        batch_size, sequence_length = input_ids.shape
        vocab_size = self.config.vocab_size

        if past is None:
            stacks = [[[] for _ in self.config.reducables] for _ in range(batch_size)]
            hidden_states = None
        else:
            stacks, hidden_states = past
        if hidden_states is None:
            # Start from an empty history on the inputs' device, so the
            # returned logits share input_ids' device and an empty input
            # yields empty logits instead of crashing.
            hidden_states = torch.zeros((0, batch_size, vocab_size), device=input_ids.device)

        begin_idx = hidden_states.size(0)
        # Only the unprocessed positions, as plain ints: membership and index
        # tests against the config's Python lists then need no tensor
        # comparisons.
        new_tokens = input_ids[:, begin_idx:].tolist()

        new_steps = []
        for offset in range(sequence_length - begin_idx):
            # Filled via many scalar writes, so build on CPU and move once below.
            step_scores = torch.zeros((batch_size, vocab_size))
            for batch_idx, row in enumerate(new_tokens):
                self._score_token(row[offset], stacks[batch_idx], step_scores[batch_idx])
            new_steps.append(step_scores)

        if new_steps:
            # One concatenation per call instead of one per position.
            new_states = torch.stack(new_steps).to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            hidden_states = torch.cat((hidden_states, new_states))

        return CausalLMOutputWithPast(
            logits=hidden_states.permute(1, 0, 2),
            past_key_values=(stacks, hidden_states),
        )