import os

import datasets

from tokenizers import Tokenizer
from tokenizers.decoders import WordPiece
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import WordLevelTrainer

from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForMaskedLM,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)


# Train on a single GPU and disable Weights & Biases logging.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["WANDB_DISABLED"] = "true"

NUM_TRAIN_EPOCHS = 100

# Keep only UniProt records that carry a GO annotation.
go_uni = datasets.load_dataset("damlab/uniprot")["train"].filter(
    lambda x: x["go"] is not None
)
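
# Optional sanity check. Assumption about the damlab/uniprot schema: "go"
# holds a whitespace-delimited string of GO terms per record.
print(len(go_uni))
print(go_uni[0]["go"][:120])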


# Word-level tokenizer: each whitespace-separated GO term is one token.
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = WhitespaceSplit()

# Named to avoid shadowing the Hugging Face Trainer created below.
wordlevel_trainer = WordLevelTrainer(
    special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "[BOS]", "[EOS]"]
)
tokenizer.train_from_iterator(go_uni["go"], trainer=wordlevel_trainer)

cls_token_id = tokenizer.token_to_id("[CLS]")
sep_token_id = tokenizer.token_to_id("[SEP]")
print(cls_token_id, sep_token_id)

# Wrap single sequences as [CLS] $A [SEP]; pairs get segment ids 0/1.
tokenizer.post_processor = TemplateProcessing(
    single="[CLS]:0 $A:0 [SEP]:0",
    pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
    special_tokens=[("[CLS]", cls_token_id), ("[SEP]", sep_token_id)],
)

# The WordPiece decoder just joins tokens with spaces here: a word-level
# vocabulary contains no "##" continuation prefixes for it to strip.
tokenizer.decoder = WordPiece(prefix="##")
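
# Quick check (optional; the GO ids are placeholders and unseen terms fall
# back to [UNK]): confirm the template wraps inputs in [CLS] ... [SEP].
print(tokenizer.encode("GO:0005575 GO:0008150").tokens)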


# Wrap the raw tokenizer so it plugs into the transformers training stack.
wrapped_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=tokenizer,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)

wrapped_tokenizer.save_pretrained("./")
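
# Round-trip check (optional): the tokenizer files just written should
# reload cleanly through AutoTokenizer.
reloaded = AutoTokenizer.from_pretrained("./")
assert reloaded.mask_token == "[MASK]"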


def tkn_func(examples):
    # Truncate to the model's 256-position window; padding is deferred to
    # the data collator so batches pad dynamically.
    return wrapped_tokenizer(examples["go"], max_length=256, truncation=True)


tokenized_dataset = go_uni.map(
    tkn_func, batched=True, remove_columns=go_uni.column_names
)
# train_test_split defaults to a 75/25 train/test split.
split_dataset = tokenized_dataset.train_test_split(seed=1234)


# Dynamic MLM batching: mask 15% of tokens and pad to a multiple of 8,
# which is friendlier to tensor cores.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=wrapped_tokenizer, mlm_probability=0.15, pad_to_multiple_of=8
)
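
# Illustration (optional): collate two examples; labels are -100 everywhere
# except the ~15% of positions selected for masking.
demo_batch = data_collator([split_dataset["train"][i] for i in range(2)])
print(demo_batch["input_ids"].shape, demo_batch["labels"].shape)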


training_args = TrainingArguments(
    "trainer",  # output directory for logs; checkpointing is disabled below
    evaluation_strategy="steps",
    load_best_model_at_end=False,
    save_strategy="no",
    logging_first_step=True,
    logging_steps=10,
    eval_steps=10,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    warmup_steps=10,
    weight_decay=0.01,
    per_device_train_batch_size=24,
    per_device_eval_batch_size=24,
    # Effective batch size per optimizer step: 24 * 96 = 2,304 sequences.
    gradient_accumulation_steps=96,
    lr_scheduler_type="cosine_with_restarts",
)


encoder_bert = BertConfig(
    vocab_size=tokenizer.get_vocab_size(),
    hidden_size=1024,
    num_hidden_layers=12,
    num_attention_heads=32,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
    max_position_embeddings=256,
    type_vocab_size=2,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    # [PAD] is not id 0 in this vocabulary ([UNK] is registered first), so
    # take the id from the tokenizer instead of hard-coding 0.
    pad_token_id=wrapped_tokenizer.pad_token_id,
    position_embedding_type="absolute",
)


def model_init():
    # A fresh, randomly initialized BERT encoder with an MLM head; training
    # starts from scratch rather than from a pretrained checkpoint.
    return BertForMaskedLM(encoder_bert)
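
# Rough size check (optional): instantiate once just to count parameters.
_probe = model_init()
print(f"{sum(p.numel() for p in _probe.parameters()):,} parameters")
del _probe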


trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=data_collator,
)

results = trainer.train()
trainer.save_model("./")
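
# Example downstream use (a sketch; the GO ids below are placeholders that
# may map to [UNK]): reload the saved model in a fill-mask pipeline and
# predict a masked term.
from transformers import pipeline

fill_go = pipeline("fill-mask", model="./", tokenizer="./")
print(fill_go("GO:0005575 [MASK] GO:0008150"))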