| import torch |
| import torch.nn as nn |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
| import logging |
| logging.getLogger("transformers").setLevel(logging.ERROR) |
|
|
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| prompt_string = "Generate test cases for a given Python function using its source code. The output should be a series of assert statements that verify the function's correctness. \n\nExample:\n\nInput:\n\ndef add(a, b):\n\treturn a + b\n \n\nOutput:\n\ndef test_add(a, b):\n\n\tassert add(1, 2) == 3\n\tassert add(-1, 1) == 0\n\tassert add(0, 0) == 0\n\tassert add(1.5, 2.5) == 4.0\n\nMy Input:\n\n" |
| max_length = 512 |
| params_inference = { |
| 'max_new_tokens': 512, |
| 'do_sample': True, |
| 'top_k': 100, |
| 'top_p': 0.85, |
| } |
| num_lines = 5 |
|
|
| gpt2_name="bigcode/gpt_bigcode-santacoder" |
| saved_model_path = '/Users/chervonikov_alexey/Desktop/result_gpt_bigcode_15_epochs/model_val_loss.pth' |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
| class LargeCodeModelGPTBigCode(nn.Module): |
| '''Класс для большой языковой модели, которая обрабатывает входной код''' |
| def __init__(self, gpt2_name=gpt2_name, |
| prompt_string = prompt_string, |
| params_inference = params_inference, |
| max_length = max_length, |
| device = device, |
| saved_model_path = saved_model_path, |
| num_lines = num_lines, |
| flag_hugging_face = False, |
| flag_pretrained = False): |
| |
| ''' |
| Конструктор класса. Необходим для инициализации модели |
| |
| Параметры: |
| -gpt2_name: link модели на HuggingFace (default: "bigcode/gpt_bigcode-santacoder") |
| -prompt_string: дополнительная обертка для лучшего понимания моделью задачи (default: prompt_string) |
| -params_inference: параметры инференса (для использования self.gpt2.generate(**inputs, |
| **inference_params)) |
| -max_length: максимальное число токенов в последовательности (default: 512) |
| -device: устройство (default: cuda or cpu) |
| -saved_model_path: путь к затюненной модели |
| -num_lines: число линий (ввиду "неконечной" генерации модели) |
| -flag_hugging_face: флаг для использования HuggingFace (default: False) |
| -flag_pretrained: флаг для инициализации модели затюненными весами |
| |
| ''' |
| super(LargeCodeModelGPTBigCode, self).__init__() |
|
|
| self.new_special_tokens = ['<FUNC_TOKEN>', |
| '<INFO_TOKEN>', |
| '<CLS_TOKEN>', |
| '<AST_TOKEN>', |
| '<DESCRIPTION_TOKEN>', |
| '<COMMENTS_TOKEN>'] |
|
|
| self.special_tokens_dict = { |
| 'additional_special_tokens': self.new_special_tokens |
| } |
|
|
| self.tokenizerGPT = AutoTokenizer.from_pretrained(gpt2_name, padding_side='left') |
| self.tokenizerGPT.add_special_tokens({'pad_token': '<PAD>'}) |
| self.tokenizerGPT.add_special_tokens(self.special_tokens_dict) |
| self.gpt2 = AutoModelForCausalLM.from_pretrained(gpt2_name) |
| self.gpt2.resize_token_embeddings(len(self.tokenizerGPT)) |
| |
| self.max_length = max_length |
| self.inference_params = params_inference |
| self.additional_prompt = prompt_string |
| self.pretrained_path = saved_model_path |
| self.device = device |
| self.num_lines = num_lines |
|
|
| if flag_pretrained == True and flag_hugging_face == False: |
| if self.device == "cuda": |
| self.load_state_dict(torch.load(self.pretrained_path)) |
| else: |
| self.load_state_dict(torch.load(self.pretrained_path, |
| map_location=torch.device('cpu'))) |
|
|
| |
| def forward(self, input_ids, attention_mask, |
| response_ids): |
| |
| ''' |
| Forward call method |
| |
| Параметры: |
| -input_ids: входные токены |
| -attention_mask: маска внимания |
| -response_ids: метки |
| |
| Returns: |
| -результат forward call |
| |
| ''' |
| gpt2_outputs = self.gpt2(input_ids = input_ids, |
| attention_mask = attention_mask, |
| labels = response_ids) |
|
|
| return gpt2_outputs |
| |
| @staticmethod |
| def decode_sequence(tokens_ids, tokenizer): |
| ''' |
| Декодирование последовательности токенов |
| |
| Параметры: |
| -tokens_ids: последоавтельность токенов |
| -tokenizer: токенизатор |
| |
| Returns: |
| -code_decoded: Декодированная последовательность |
| ''' |
|
|
| code_decoded = tokenizer.decode(tokens_ids, skip_special_tokens = True) |
| return code_decoded |
| |
| @staticmethod |
| def remove_before_substring(text, substring="My Output:"): |
| ''' |
| Вспомогательная утилита, чтобы убрать все лишнее |
| |
| Параметры: |
| -text: строка |
| -substring: подстрока (default: "My Output:") |
| |
| Returns: |
| -text: обновленный текст |
| |
| ''' |
|
|
| index = text.find(substring) |
| if index != -1: |
| |
| return text[index:] |
| return text |
| |
| @staticmethod |
| def extract_text_between_markers(text, start_marker, end_marker): |
| ''' |
| Утилита для работы с входной строкой (получение чистой входной функции) |
| |
| Параметры: |
| -text: входная строка |
| -start_marker: стартовый маркер |
| -end_marker: конечный маркер |
| |
| Returns: |
| -Отфильтрованный текст |
| |
| ''' |
|
|
| start_index = text.find(start_marker) |
| end_index = text.find(end_marker) |
| if start_index == -1 or end_index == -1 or start_index >= end_index: |
| return None |
| return text[start_index + len(start_marker):end_index].strip() |
|
|
| def input_inference(self, code_text): |
| ''' |
| Инференс входной строки кода |
| |
| Параметры: |
| -code_text: строка с кодом |
| |
| Returns: |
| -dict: { |
| 'input_function': input_function, |
| 'generated_output': output_string |
| } |
| |
| Словарь в формате input_function + generated_output |
| ''' |
| model_input = self.additional_prompt + code_text + "\n\nMy Output:\n\n" |
|
|
| def encode_text(text, tokenizer = self.tokenizerGPT): |
| encoding = tokenizer.encode_plus( |
| text, |
| add_special_tokens=True, |
| max_length=max_length, |
| padding='max_length', |
| truncation=True, |
| return_attention_mask=True, |
| return_tensors='pt', |
| ) |
| input_ids_code_text = encoding['input_ids'].flatten() |
| attention_mask_code_text = encoding['attention_mask'].flatten() |
| return input_ids_code_text, attention_mask_code_text |
|
|
| input_ids_focal_method, attention_mask_focal_method = encode_text(model_input) |
|
|
| self.eval() |
| with torch.no_grad(): |
|
|
| inputs = {'input_ids': input_ids_focal_method.unsqueeze(0).to(device), |
| 'attention_mask': attention_mask_focal_method.unsqueeze(0).to(device)} |
| |
| input_function = self.decode_sequence(tokens_ids = inputs['input_ids'][0], |
| tokenizer=self.tokenizerGPT) |
| |
| input_function = self.extract_text_between_markers(input_function, |
| 'My Input:', |
| 'My Output:'), |
|
|
| output = self.gpt2.generate(**inputs, |
| **self.inference_params) |
|
|
| output_string = self.decode_sequence(tokens_ids = output[0], |
| tokenizer=self.tokenizerGPT) |
| |
| output_string = self.remove_before_substring(output_string) \ |
| .replace('My Output:', "") \ |
| .strip() |
| output_string = "\n".join(output_string.splitlines()[:self.num_lines]) |
|
|
| return { |
| 'input_function': input_function, |
| 'generated_output': output_string |
| } |
|
|
| if __name__ == "__main__": |
| CodeModel = LargeCodeModelGPTBigCode(gpt2_name, flag_pretrained=True) |
| |