| from transformers import PretrainedConfig, PreTrainedTokenizerBase |
| from freegroup import tools |
|
|
class GreedyConfig(PretrainedConfig):
    """Model configuration that also carries token-id tables for free-group words.

    Besides the regular ``PretrainedConfig`` fields, an instance holds:

    * ``reciprocals`` -- for every generator ``x`` in ``1..freegroup_dimension``,
      the pair ``[id of str(x), id of str(-x)]``.
    * ``reducables`` -- for every single-letter word ``[x]``, and then for the
      word ``[1, ..., freegroup_dimension]``, the pair
      ``[ids of the word, ids of tools.reciprocal(word)]``.
      (``tools.reciprocal`` presumably returns the inverse word -- confirm
      against ``freegroup.tools``.)

    Both tables are ``None`` unless passed to ``__init__`` or filled in by
    ``from_tokenizer``.
    """

    @classmethod
    def from_tokenizer(cls, freegroup_dimension, tokenizer: PreTrainedTokenizerBase, **kwargs):
        """Build a config whose vocabulary and special tokens match ``tokenizer``.

        Args:
            freegroup_dimension: number of free-group generators. The generators
                are the integers ``1..freegroup_dimension``; both each generator
                and its negative are looked up as tokens via ``str()``.
            tokenizer: tokenizer expected to contain those generator tokens
                (assumed -- the lookups below are not validated).
            **kwargs: forwarded unchanged to the constructor.

        Returns:
            A new config with ``vocab_size``, ``eos_token_id`` and
            ``pad_token_id`` taken from ``tokenizer`` and with ``reciprocals`` /
            ``reducables`` populated.
        """
        tokenizer_fields = dict(
            vocab_size=len(tokenizer),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        config = cls(**tokenizer_fields, **kwargs)
        config._from_tokenizer(freegroup_dimension, tokenizer)
        return config

    def _from_tokenizer(self, freegroup_dimension, tokenizer):
        """Populate ``reciprocals`` and ``reducables`` using ``tokenizer``'s vocabulary."""
        generators = [*range(1, freegroup_dimension + 1)]

        def word_to_ids(word):
            # Every letter of a word is its own token, spelled as str() of the integer.
            return tokenizer.convert_tokens_to_ids([str(letter) for letter in word])

        # One [generator id, inverse id] pair per generator; the tuple unpacking
        # requires the lookup to return exactly two ids.
        self.reciprocals = [
            [positive_id, negative_id]
            for positive_id, negative_id in (
                tokenizer.convert_tokens_to_ids([str(g), str(-g)]) for g in generators
            )
        ]

        # Single-letter words first, then the word made of all generators in order
        # (passed as a copy so the generator list itself is never handed out).
        words = [[g] for g in generators]
        words.append(generators[:])

        # Zipping against range(freegroup_dimension + 1) caps the table at
        # freegroup_dimension + 1 entries (and yields none for a negative dimension).
        self.reducables = [
            [word_to_ids(word), word_to_ids(tools.reciprocal(word))]
            for _, word in zip(range(freegroup_dimension + 1), words)
        ]

    def __init__(self, **kwargs):
        """Create the config.

        ``reciprocals`` and ``reducables`` are removed from ``kwargs`` and stored
        on the instance (``None`` when absent). Every remaining keyword argument
        is passed on to ``PretrainedConfig.__init__``.
        """
        # NOTE(review): presumably popped so a config rebuilt from a saved dict
        # restores these tables explicitly -- confirm against the save/load path.
        self.reciprocals, self.reducables = (
            kwargs.pop('reciprocals', None),
            kwargs.pop('reducables', None),
        )
        super().__init__(**kwargs)
|
|
|
|
|
|