Medical-VQA / src /models /encoder.py
SpringWang08's picture
Deploy Medical VQA app
d63774a
raw
history blame
950 Bytes
import torch
import torch.nn as nn
import torchxrayvision as xrv
class MedicalImageEncoder(nn.Module):
    """Encode X-ray images into a sequence of projected patch embeddings.

    A TorchXRayVision DenseNet serves as the convolutional backbone. Its
    classification head is replaced with an identity layer, and the spatial
    feature map is flattened into a token sequence and linearly projected
    so it can be consumed by a text model.

    The original author's note describes the backbone as a DenseNet-121
    pretrained on 200K+ X-ray images (CheXpert, NIH, etc.).

    Args:
        pretrained: When true, load the ``densenet121-res224-chex`` weights;
            otherwise build the backbone with ``weights=None``.
        out_dim: Size of each projected token embedding. Defaults to 768,
            which the original code chose to match PhoBERT's dimension.
    """

    # Channel count of the backbone feature map that the projector consumes.
    BACKBONE_DIM = 1024

    def __init__(self, pretrained: bool = True, out_dim: int = 768) -> None:
        super().__init__()
        weights = "densenet121-res224-chex" if pretrained else None
        self.model = xrv.models.DenseNet(weights=weights)
        # Drop the classification layer; only `features` is used in forward.
        self.model.classifier = nn.Identity()
        # Attribute names `model` / `projector` are kept unchanged so that
        # state_dict keys of previously saved checkpoints still match.
        self.projector = nn.Linear(self.BACKBONE_DIM, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return projected patch embeddings for a batch of images.

        Args:
            x: Image batch accepted by the backbone's ``features`` module.
                NOTE(review): presumably single-channel 224x224 X-rays, as
                the weight name suggests; confirm against the data pipeline.

        Returns:
            Tensor of shape ``[B, H*W, out_dim]``, where ``H x W`` is the
            spatial size of the backbone feature map (7x7, i.e. 49 tokens,
            for 224x224 input according to the original comments).
        """
        feat_map = self.model.features(x)  # [B, C, H, W]
        # Merge the spatial dims and move channels last: [B, H*W, C].
        tokens = feat_map.flatten(2).transpose(1, 2)
        return self.projector(tokens)