Spaces:
Paused
Paused
File size: 1,064 Bytes
d63774a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 | import torch.nn as nn
from transformers import AutoModel
class PhoBERTEncoder(nn.Module):
"""
Text Encoder sử dụng PhoBERT pretrained.
Hỗ trợ tiếng Việt tốt nhất cho Medical VQA.
"""
def __init__(self, model_name="vinai/phobert-base", freeze_layers=10):
super(PhoBERTEncoder, self).__init__()
self.bert = AutoModel.from_pretrained(model_name, use_safetensors=True)
# Đóng băng các lớp Transformer đầu tiên nếu cần
if freeze_layers > 0:
for param in self.bert.embeddings.parameters():
param.requires_grad = False
for layer in self.bert.encoder.layer[:freeze_layers]:
for param in layer.parameters():
param.requires_grad = False
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
# Lấy [CLS] token đại diện cho toàn bộ câu hỏi
return outputs.last_hidden_state[:, 0, :]
|