import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
from transformers import BertTokenizer, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateOutput, GenerationConfig
from transformers.utils import ModelOutput

if TYPE_CHECKING:
    from transformers import PreTrainedModel
    from transformers.generation.streamers import BaseStreamer


class TSGenerationMixin(GenerationMixin):
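    """Generation mixin for single-pass time-series forecasting.

    Replaces the autoregressive ``GenerationMixin.generate`` loop with one
    forward call: the input series is optionally standardized (RevIN-style
    instance normalization), optional text/vision context is attached, and
    the model emits the whole forecast horizon at once, after which the
    predictions are de-standardized.
    """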
    # Shared text tokenizer, loaded from the BERT vocab/config files shipped
    # next to this module; local_files_only avoids any network access.
    tokenizer = BertTokenizer.from_pretrained(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bert_config'),
        local_files_only=True,
    )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        text_inputs: Optional[Union[str, List[str]]] = None,
        text_input_ids: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_token_type_ids: Optional[torch.Tensor] = None,
        vision_inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        revin: bool = True,
        num_samples: int = 1,
        max_output_length: int = 96,
        inference_token_len: Optional[int] = None,
        max_text_token_length: int = 125,
        **kwargs,
    ) -> Union[GenerateOutput, torch.Tensor]:
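        """Forecast a batch of time series in a single forward pass.

        Args:
            inputs: Raw series of shape ``[batch_size, seq_len]``.
            text_inputs: Optional raw text context, tokenized here with the
                class tokenizer; takes precedence over ``text_input_ids``.
            vision_inputs: Optional vision context tensor.
            revin: If True, standardize each series before the forward pass
                and de-standardize the predictions afterwards.
            num_samples: Number of sample paths to return per series.
            max_output_length: Forecast horizon in time steps.

        The remaining Hugging Face ``generate()`` arguments are accepted for
        interface compatibility but are unused by this single-pass routine.

        Returns:
            Prediction tensor, de-normalized when ``revin`` is True.
        """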
        if inputs is None:
            raise ValueError('`inputs` must be provided for time-series generation')
        if inputs.dim() != 2:
            raise ValueError('Input shape must be: [batch_size, seq_len]')
        if revin:
            # RevIN-style instance normalization: standardize each series with
            # its own statistics; predictions are de-standardized below with
            # the same statistics.
            means = inputs.mean(dim=-1, keepdim=True)
            stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
            inputs = (inputs - means) / stdev
        if text_inputs is not None:
            # Raw text takes precedence over pre-tokenized ids. Note that
            # squeeze(0) collapses the leading dim when a single string is passed.
            tokenized_text = self._tokenize(text_inputs, max_length=max_text_token_length)
            text_input_ids = tokenized_text['input_ids'].squeeze(0)
            text_attention_mask = tokenized_text['attention_mask'].squeeze(0)
            text_token_type_ids = tokenized_text.get(
                'token_type_ids', torch.zeros_like(text_input_ids)
            ).squeeze(0)

        # Bundle the series and the optional multimodal context into the
        # keyword arguments the model's forward pass expects.
        model_inputs = self.prepare_inputs_for_generation(
            inputs,
            text_input_ids=text_input_ids,
            text_attention_mask=text_attention_mask,
            text_token_type_ids=text_token_type_ids,
            vision_inputs=vision_inputs,
            generation_config=generation_config,
            max_output_length=max_output_length,
            inference_token_len=inference_token_len,
            **kwargs,
        )

        # Single forward pass; revin=False because normalization (if any) was
        # already applied above.
        outputs = self(**model_inputs, return_dict=True, revin=False, num_samples=num_samples)
        predictions = outputs.logits

        if revin:
            # De-standardize: broadcast each series' statistics across the
            # num_samples dimension of the predictions.
            stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1)
            means = means.unsqueeze(1).repeat(1, num_samples, 1)
            predictions = (predictions * stdev) + means

        return predictions

    def prepare_inputs_for_generation(
        self,
        inputs: torch.Tensor,
        text_input_ids: Optional[torch.Tensor] = None,
        text_attention_mask: Optional[torch.Tensor] = None,
        text_token_type_ids: Optional[torch.Tensor] = None,
        vision_inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        max_output_length: Optional[int] = None,
        inference_token_len: Optional[int] = None,
        **kwargs,
    ):
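        """Map generation arguments onto the model's forward signature.

        Note the renames: ``inputs`` is passed on as ``input_ids`` and
        ``vision_inputs`` as ``vision_ids``, matching the parameter names
        the model's ``forward`` expects.
        """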
        return {
            "input_ids": inputs,
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "text_token_type_ids": text_token_type_ids,
            "vision_ids": vision_inputs,
            "max_output_length": max_output_length,
            "inference_token_len": inference_token_len,
            **kwargs,
        }

    def _tokenize(self, texts: Union[str, List[str]], max_length: int):
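        """Tokenize text context, padding/truncating to exactly ``max_length`` tokens."""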
        return self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        horizon_length: int = 1,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
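        """No-op override: generation is a single forward pass here, so the
        base class's autoregressive kwargs bookkeeping is not needed.
        """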
        return model_kwargs
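
# Minimal usage sketch (illustrative only): assumes a hypothetical `TSModel`
# that mixes `TSGenerationMixin` into a `PreTrainedModel`; the class name and
# checkpoint path below are placeholders, not part of this module.
#
#     model = TSModel.from_pretrained("path/to/checkpoint")
#     series = torch.randn(4, 512)                  # [batch_size, seq_len]
#     preds = model.generate(
#         series,
#         text_inputs="hourly electricity load",    # optional text context
#         max_output_length=96,                     # forecast horizon
#         num_samples=20,                           # sample paths per series
#     )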