| from dataclasses import dataclass |
| from typing import Optional, List |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig |
| import regex as re |
| import torch |
| import torch.nn.functional as F |
|
|
# Sentinel tokens reserved as markers in model inputs/outputs.
# PROGRAM_SPECIAL_TOKEN is the default decoder prefix for program generation,
# UTTERANCES_SPECIAL_TOKEN the default prefix/separator for serialized
# utterances. GT_PROGRAM_SPECIAL_TOKEN is not used in this chunk —
# presumably it marks ground-truth programs; confirm against its users.
PROGRAM_SPECIAL_TOKEN = "<extra_id_124>"
UTTERANCES_SPECIAL_TOKEN = "<extra_id_123>"
GT_PROGRAM_SPECIAL_TOKEN = "<extra_id_122>"
|
|
def consistent(rx, spec):
    """Check whether the regex ``rx`` agrees with every labelled example.

    ``spec`` is an iterable of ``(string, label)`` pairs. Label ``'+'`` means
    the string must fully match ``rx``; label ``'-'`` means it must not.

    Examples are checked in order and the first problem decides the result:

    * ``None`` if a label is neither ``'+'`` nor ``'-'``, if ``rx`` fails to
      compile (``re.error``), or if matching raises ``TimeoutError``
      (matching is called with ``timeout=1``);
    * ``False`` at the first example whose match result contradicts its label;
    * ``True`` if every example agrees (including an empty ``spec``).
    """
    for example, label in spec:
        if label not in ('+', '-'):
            return None
        try:
            matched = re.fullmatch(rx, example, timeout=1) is not None
        except (re.error, TimeoutError):
            return None
        # A positive example must match; a negative one must not.
        if matched != (label == '+'):
            return False
    return True
|
|
def get_utterance_processing_functions(label_pos, idx, separator=' '):
    """Return a matched ``(utterances_to_string, string_to_utterances)`` pair.

    An utterance is a ``(s, label)`` pair whose ``label`` is a single
    character. With ``label_pos == "suffix"`` each utterance is written as
    ``s + label``; any other value writes ``label + s``.

    When ``idx`` is truthy, the i-th utterance is preceded by the sentinel
    ``<extra_id_i>`` and utterances are concatenated with nothing between
    them. Parsing splits on those sentinels, so any ``<extra_id_N>`` token in
    the input (e.g. a leading context marker) acts as a boundary. Otherwise
    utterances are joined and split with ``separator``.

    Parsing drops empty chunks and assumes every label is exactly one
    character.
    """
    label_last = label_pos == "suffix"

    def _format(s, label):
        # Place the label at the configured end of the utterance text.
        return f"{s}{label}" if label_last else f"{label}{s}"

    def _parse(chunk):
        # The label occupies exactly one character at the configured end.
        return (chunk[:-1], chunk[-1]) if label_last else (chunk[1:], chunk[0])

    if idx:
        def utterances_to_string(spec):
            return ''.join(
                f"<extra_id_{i}>{_format(s, label)}"
                for i, (s, label) in enumerate(spec)
            )

        def string_to_utterances(string):
            # Split on the index sentinels themselves: this is the exact
            # inverse of utterances_to_string, which puts no separator between
            # utterances, and it keeps spaces inside an utterance intact.
            return [
                _parse(chunk)
                for chunk in re.split(r'<extra_id_\d+>', string)
                if chunk
            ]
    else:
        def utterances_to_string(spec):
            return separator.join(_format(s, label) for s, label in spec)

        def string_to_utterances(string):
            return [_parse(chunk) for chunk in string.split(separator) if chunk]

    return utterances_to_string, string_to_utterances
|
|
def decode(c):
    """Render a single ByT5 token id as text.

    Id layout used throughout this file:

    * ids 0-2 are special and render as ``<0>``, ``<1>``, ``<2>``;
    * ids 3..258 are the 256 byte values (byte = id - 3) and render as
      ``chr(byte)``, i.e. one character per byte, with no UTF-8 reassembly;
    * ids from 259 up render as ``<extra_id_{id - 259}>``.
    """
    if c < 3:
        return f"<{c}>"
    if c < 259:  # 3..258 inclusive covers all 256 byte values, 0xFF included
        return chr(c - 3)
    return f"<extra_id_{c - 259}>"
| |
def byt5_decode_batch(outputs, skip_special_tokens=True, skip_position_token=False):
    """Decode nested ByT5 token-id lists into strings.

    Args:
        outputs: a list with one entry per batch item, each a list of
            token-id sequences. The nesting is preserved in the result.
        skip_special_tokens: drop ids below 3 (the reserved special ids).
        skip_position_token: drop ids above 258 (the ids past the byte range,
            rendered elsewhere as ``<extra_id_*>`` sentinels).

    Returns:
        ``outputs`` with every id sequence replaced by its decoded string.

    Filtering happens before decoding. Consecutive byte ids (3..258,
    byte = id - 3) are decoded together as UTF-8 with ``errors="ignore"``, so
    multi-byte characters come back intact instead of one character per byte.
    Any other id that survives the filters is rendered with ``decode``.
    """

    def wanted(t):
        if skip_special_tokens and t < 3:
            return False
        return not (skip_position_token and t > 258)

    def decode_sequence(ids):
        pieces = []
        pending = bytearray()  # current run of raw bytes awaiting UTF-8 decode
        for t in ids:
            if not wanted(t):
                continue
            if 3 <= t <= 258:
                pending.append(t - 3)
                continue
            # A non-byte token ends the current byte run.
            if pending:
                pieces.append(pending.decode("utf-8", errors="ignore"))
                pending.clear()
            pieces.append(decode(t))
        if pending:
            pieces.append(pending.decode("utf-8", errors="ignore"))
        return "".join(pieces)

    return [[decode_sequence(seq) for seq in beam] for beam in outputs]
|
|
class Agent:
    """Bundle a seq2seq model, its tokenizer and a generation config.

    Args:
        model_path: checkpoint identifier passed to both
            ``AutoModelForSeq2SeqLM.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        gen_config: keyword arguments used to build a ``GenerationConfig``.
        device: device the model is moved to (default ``"cuda"``).
    """

    def __init__(
        self,
        model_path: str,
        gen_config: dict,
        device: str = "cuda",
    ):
        self.device = device
        # Load the weights, then place them on the requested device.
        checkpoint = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.model = checkpoint.to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.gen_config = GenerationConfig(**gen_config)
|
|
@dataclass
class ListenerOutput:
    """Per-context results of program synthesis.

    Only ``programs`` is required; every other field defaults to ``None``.
    Each outer list has one entry per input context.
    """

    # Programs kept for each context.
    programs: List[List[str]]
    # Positions of the kept programs within each context's decoded candidates.
    idx: Optional[List[List[int]]] = None
    # Every decoded candidate for each context, kept or not.
    decoded: Optional[List[List[str]]] = None
    # One score per decoded candidate (Listener.synthesize stores summed log-probs).
    decoded_scores: Optional[List[List[float]]] = None
    # Listener.synthesize never sets this field.
    pruned: Optional[List[List[str]]] = None
|
|
|
|
class Listener(Agent):
    """Seq2seq listener that maps labelled utterances to candidate regex programs.

    A context is either a list of ``(string, label)`` utterances or an
    already-serialized string. Utterance lists are serialized with the
    functions from ``get_utterance_processing_functions``. Encoder inputs
    start with ``utterances_special_token``; decoding starts from
    ``program_special_token``.

    Args:
        model_path, gen_config, device: forwarded to ``Agent``.
        label_pos: ``"suffix"`` puts each label after its string; any other
            value puts it before.
        idx: when True, utterances are serialized with ``<extra_id_i>``
            index sentinels.
        program_special_token: decoder prompt token for programs.
        utterances_special_token: encoder prefix for contexts, also used as
            the separator when ``idx`` is False.
    """

    def __init__(
        self,
        model_path,
        gen_config,
        device="cuda",
        label_pos="suffix",
        idx: bool = True,
        program_special_token=PROGRAM_SPECIAL_TOKEN,
        utterances_special_token=UTTERANCES_SPECIAL_TOKEN,
    ):
        super().__init__(model_path, gen_config, device=device)
        self.label_pos = label_pos
        self.idx = idx
        self.program_special_token = program_special_token
        self.utterances_special_token = utterances_special_token
        self.utterances_to_string, self.string_to_utterances = (
            get_utterance_processing_functions(
                label_pos, idx, separator=utterances_special_token
            )
        )

    def _pad_id(self):
        """Return the tokenizer's pad id, or 0 when it defines none."""
        pad_id = self.tokenizer.pad_token_id
        return 0 if pad_id is None else pad_id

    def _context_strings(self, contexts):
        """Serialize utterance-list contexts; pass string contexts through."""
        if isinstance(contexts[0], str):
            return list(contexts)
        return [self.utterances_to_string(c) for c in contexts]

    def _context_specs(self, contexts):
        """Return every context as a list of ``(string, label)`` utterances."""
        if isinstance(contexts[0], str):
            return [self.string_to_utterances(c) for c in contexts]
        return list(contexts)

    def _encode_contexts(self, contexts):
        """Tokenize contexts as one padded batch on ``self.device``.

        Each serialized context is prefixed with ``utterances_special_token``
        unless it already starts with it.
        """
        prefix = self.utterances_special_token
        texts = [
            c if c.startswith(prefix) else f"{prefix}{c}"
            for c in self._context_strings(contexts)
        ]
        return self.tokenizer(texts, return_tensors="pt", padding=True).to(self.device)

    def _sequence_logprobs(self, outputs):
        """Sum the per-token log-probs of every sequence returned by ``generate``.

        Log-probs come from ``outputs.scores``, which ``generate`` reports
        after its logits processing, so they reflect whatever processing
        ``gen_config`` enables.
        NOTE(review): with beam search, the rows of ``outputs.scores`` are
        beams rather than returned sequences, so they may not line up with
        ``outputs.sequences``. ``compute_transition_scores`` with
        ``beam_indices`` handles that case; confirm which decoding strategy
        gen_config uses.
        """
        sequences = outputs.sequences
        n_steps = len(outputs.scores)
        # Generated tokens are the last n_steps positions; anything before
        # them is the decoder prompt.
        generated = sequences[:, sequences.shape[-1] - n_steps:]
        logprobs = torch.stack(outputs.scores, dim=1).log_softmax(dim=-1)
        token_logprobs = torch.gather(logprobs, 2, generated[:, :, None]).squeeze(-1)
        # Drop -inf entries and the padding emitted after a sequence has
        # finished, so neither contributes to the sum.
        ignore = token_logprobs.isinf() | (generated == self._pad_id())
        return token_logprobs.masked_fill(ignore, 0).sum(-1)

    def synthesize(self, context, return_scores=False, enforce_consistency=True):
        """Generate candidate programs for each context.

        Args:
            context: sequence of contexts. Each is a list of
                ``(string, label)`` utterances or a serialized string.
            return_scores: when True, also return the kept indices, every
                decoded candidate and each candidate's summed log-probability.
            enforce_consistency: when True, keep only candidates ``p`` for
                which ``consistent(p, utterances)`` is truthy. String contexts
                are parsed with ``self.string_to_utterances`` for this check.

        Returns:
            ``ListenerOutput`` with ``programs`` set. With ``return_scores``,
            ``idx``, ``decoded`` and ``decoded_scores`` are set as well. The
            number of candidates per context is however many sequences
            ``generate`` returns per input (controlled by ``gen_config``).
            An empty ``context`` yields empty results.
        """
        n_ctx = len(context)
        if n_ctx == 0:
            if return_scores:
                return ListenerOutput([], [], [], [])
            return ListenerOutput([])

        context_tokens = self._encode_contexts(context)
        decoder_inputs = self.tokenizer(
            [self.program_special_token] * n_ctx,
            return_tensors="pt",
            add_special_tokens=False,
        ).to(self.device)

        outputs = self.model.generate(
            **context_tokens,
            decoder_input_ids=decoder_inputs.input_ids,
            generation_config=self.gen_config,
            return_dict_in_generate=True,
            output_scores=True,
        )
        sequences = outputs.sequences

        # Group the flat (n_ctx * n_candidates, length) output by context.
        decoded_batch = byt5_decode_batch(
            sequences.reshape((n_ctx, -1, sequences.shape[-1])).tolist(),
            skip_position_token=True,
            skip_special_tokens=True,
        )

        # consistent() iterates (string, label) pairs, so string contexts
        # must be parsed back into utterances before the check.
        specs = self._context_specs(context) if enforce_consistency else [None] * n_ctx
        consistent_programs = []
        idxs = []
        for candidates, spec in zip(decoded_batch, specs):
            keep = [
                i for i, program in enumerate(candidates)
                if not enforce_consistency or consistent(program, spec)
            ]
            consistent_programs.append([candidates[i] for i in keep])
            idxs.append(keep)

        if not return_scores:
            return ListenerOutput(consistent_programs)

        scores_list = self._sequence_logprobs(outputs).reshape((n_ctx, -1)).tolist()
        return ListenerOutput(consistent_programs, idxs, decoded_batch, scores_list)

    def score_program(self, contexts, programs):
        """Return the summed log-probability of each program given its context.

        Args:
            contexts: contexts (utterance lists or serialized strings), one
                per program.
            programs: program strings, aligned with ``contexts``.

        Returns:
            A list of floats, one per program. Tokens the tokenizer adds by
            default are scored too; padding positions are not. An empty
            ``programs`` yields ``[]``.
        """
        if len(programs) == 0:
            return []

        context_tokens = self._encode_contexts(contexts)
        # padding=True: programs of different lengths must share one tensor.
        program_tokens = self.tokenizer(
            [f"{self.program_special_token}{p}" for p in programs],
            return_tensors="pt",
            padding=True,
        ).to(self.device)
        decoder_ids = program_tokens.input_ids

        # Scoring only; no gradients are needed for the returned floats.
        with torch.no_grad():
            logits = self.model(
                input_ids=context_tokens.input_ids,
                # Without the mask, padded contexts would attend to pad tokens.
                attention_mask=context_tokens.attention_mask,
                decoder_input_ids=decoder_ids,
                return_dict=True,
            ).logits

        # Logits at decoder position t predict token t + 1.
        targets = decoder_ids[:, 1:]
        token_logprobs = torch.gather(
            F.log_softmax(logits[:, :-1], dim=-1), 2, targets[:, :, None]
        ).squeeze(-1)
        token_logprobs = token_logprobs.masked_fill(targets == self._pad_id(), 0)
        return token_logprobs.sum(-1).tolist()