import os

from pytorch_lightning.loggers import WandbLogger
from transformers import AutoModelForMaskedLM, AutoTokenizer, PreTrainedModel

from config import DEVICE
from src.regression.HF.configs import FullModelConfigHF
from src.regression.PL import FullModelPL, EncoderPL, DecoderPL
|
|
|
|
class FullModelHF(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the Lightning full model.

    Assembles the full regression pipeline from a config:

    * a tokenizer and a (Distil)BERT backbone pulled from Hugging Face
      checkpoints named in the config,
    * a decoder restored from a Weights & Biases artifact checkpoint,
    * a ``FullModelPL`` combining encoder and decoder.

    Note: ``forward`` delegates to the PL model's ``_get_loss`` and therefore
    returns a loss value rather than logits.
    """

    config_class = FullModelConfigHF

    def __init__(self, config):
        """Build the full model from ``config``.

        Args:
            config: ``FullModelConfigHF`` instance providing
                ``tokenizer_ckpt``, ``bert_ckpt``, ``decoder_ckpt``,
                ``layer_norm`` and ``nontext_features``.

        Side effects: downloads checkpoints from the Hugging Face hub and a
        decoder artifact from Weights & Biases; moves modules to ``DEVICE``.
        """
        super().__init__(config)

        # Tokenizer and the bare encoder taken out of the MLM checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_ckpt)
        mlm_bert = AutoModelForMaskedLM.from_pretrained(config.bert_ckpt)
        # NOTE(review): assumes the MLM checkpoint is a DistilBERT model
        # (exposes a ``.distilbert`` attribute) — confirm for other backbones.
        self.bert = mlm_bert.distilbert

        encoder = EncoderPL(tokenizer=self.tokenizer, bert=self.bert).to(DEVICE)

        # The W&B logger is used here only to resolve and download the
        # decoder checkpoint artifact named in the config.
        wandb_logger = WandbLogger(
            project="transformers",
            entity="sanjin_juric_fot",
        )
        artifact = wandb_logger.use_artifact(config.decoder_ckpt)
        artifact_dir = artifact.download()
        # os.path.join instead of manual "+" concatenation so the checkpoint
        # path is built with the correct separator on every platform.
        ckpt_path = os.path.join(artifact_dir, "model.ckpt")
        decoder = DecoderPL.load_from_checkpoint(ckpt_path).to(DEVICE)

        self.model = FullModelPL(
            encoder=encoder,
            decoder=decoder,
            layer_norm=config.layer_norm,
            nontext_features=config.nontext_features,
        ).to(DEVICE)

    def forward(self, input):
        """Run the wrapped PL model's loss computation on ``input``.

        Returns whatever ``FullModelPL._get_loss`` produces (a loss, not
        logits) — callers should treat this model as a loss head.
        """
        return self.model._get_loss(input)
|
|