Update geneformer/mtl/model.py

#585
by MajorasMeow - opened

change in line 36:
old:
self.bert = BertModel(self.config)
new:
self.bert = BertModel.from_pretrained(pretrained_path)

Reason:
BertModel(self.config) initializes random weights instead of loading the weights from the pretrained model. This results in training from scratch instead of fine-tuning the pretrained model.
BertModel.from_pretrained(pretrained_path) loads the pretrained weights.
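To illustrate the difference, here is a minimal self-contained sketch (assuming `transformers` and `torch` are installed; the tiny config and temp directory are hypothetical, chosen just for the demo): constructing `BertModel(config)` twice gives two different random initializations, while `from_pretrained` restores the exact saved weights.

```python
import tempfile
from transformers import BertConfig, BertModel

# Tiny hypothetical config so the demo runs quickly (not Geneformer's config)
config = BertConfig(hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)

a = BertModel(config)  # fresh random init
b = BertModel(config)  # another fresh random init -> different weights

w_a = a.encoder.layer[0].attention.self.query.weight
w_b = b.encoder.layer[0].attention.self.query.weight
print(bool((w_a == w_b).all()))  # random inits almost surely differ

# Saving and reloading with from_pretrained reproduces the exact weights
save_dir = tempfile.mkdtemp()
a.save_pretrained(save_dir)
c = BertModel.from_pretrained(save_dir)
w_c = c.encoder.layer[0].attention.self.query.weight
print(bool((w_a == w_c).all()))  # pretrained weights restored exactly
```

This is why the original line 36 silently discarded the pretrained checkpoint: the config alone only defines the architecture, not the learned parameters.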

Tested by training with MTLClassifier both ways, and also via the following check:

from transformers import BertModel
from geneformer.mtl.model import GeneformerMultiTask

pretrained_path = "/Geneformer/Geneformer-V2-104M"
ref = BertModel.from_pretrained(pretrained_path)
mtl = GeneformerMultiTask(pretrained_path, num_labels_list=[2])

print("mtl layer0 q mean:", mtl.bert.encoder.layer[0].attention.self.query.weight.abs().mean().item())
print("ref layer0 q mean:", ref.encoder.layer[0].attention.self.query.weight.abs().mean().item())

# The two means should be identical floats, confirming the pretrained weights were loaded.

Thank you so much for finding this and for your contribution in resolving it!

ctheodoris changed pull request status to merged

Happy to support :)
