# Spaces: Paused
import torch
import torch.nn as nn
import torchxrayvision as xrv
class MedicalImageEncoder(nn.Module):
    """Image encoder built on the TorchXRayVision DenseNet-121 backbone.

    The backbone's spatial feature map is flattened into a sequence of
    per-location feature vectors and linearly projected to ``embed_dim``
    channels. The default of 768 was chosen to match the PhoBERT hidden
    size.

    Per the original author's note, the pretrained weights come from
    training on 200K+ X-ray images (CheXpert, NIH, etc.).

    Args:
        pretrained: When true, load the ``densenet121-res224-chex``
            weights; otherwise build the backbone with ``weights=None``.
        embed_dim: Output width of the linear projection.
    """

    # Channel count of the DenseNet-121 feature map fed to the projector.
    BACKBONE_DIM = 1024

    def __init__(self, pretrained: bool = True, embed_dim: int = 768) -> None:
        super().__init__()
        weights = "densenet121-res224-chex" if pretrained else None
        self.model = xrv.models.DenseNet(weights=weights)
        # Drop the classification head; forward() only uses the conv features.
        self.model.classifier = nn.Identity()
        # Map backbone channels to the text-encoder dimension.
        self.projector = nn.Linear(self.BACKBONE_DIM, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one projected feature vector per spatial location.

        Args:
            x: Image batch handed directly to ``self.model.features``.
                Assumed to be single-channel 224x224 X-rays, as the weight
                name suggests -- TODO confirm against the caller.

        Returns:
            Tensor of shape ``[B, H*W, embed_dim]``, where ``H x W`` is the
            spatial size of the backbone feature map. The original notes
            give ``[B, 49, 768]``, i.e. a 7x7 map.
        """
        feat_map = self.model.features(x)  # [B, C, H, W]
        # Collapse the spatial grid and move channels last: [B, H*W, C].
        tokens = feat_map.flatten(2).transpose(1, 2)
        return self.projector(tokens)