File size: 950 Bytes
d63774a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
import torchxrayvision as xrv

class MedicalImageEncoder(nn.Module):
    """Chest X-ray image encoder built on TorchXRayVision's DenseNet-121.

    The DenseNet classification head is discarded. The backbone's final
    convolutional feature map is flattened into a sequence of spatial tokens,
    and each token is linearly projected to ``out_dim`` so the image sequence
    can be used alongside text-encoder embeddings (768 matches PhoBERT's
    hidden size).

    Args:
        pretrained: If True, load the TorchXRayVision checkpoint named by
            ``weights``; otherwise build a randomly initialised DenseNet.
        out_dim: Width of each output token. Defaults to 768 (PhoBERT).
        weights: TorchXRayVision weight identifier used when ``pretrained``
            is True. The default, ``"densenet121-res224-chex"``, is the
            CheXpert-trained 224x224 DenseNet-121 checkpoint.
    """

    def __init__(
        self,
        pretrained: bool = True,
        *,
        out_dim: int = 768,
        weights: str = "densenet121-res224-chex",
    ):
        super().__init__()
        self.model = xrv.models.DenseNet(weights=weights if pretrained else None)

        # Read the backbone's feature width from its classifier before the
        # head is discarded, instead of hard-coding 1024 (DenseNet-121).
        feat_dim = self.model.classifier.in_features
        # forward() never uses the head. Replacing it with Identity also drops
        # its parameters from parameters() / state_dict().
        self.model.classifier = nn.Identity()
        # Project every spatial token to the text-side embedding width.
        self.projector = nn.Linear(feat_dim, out_dim)

    def forward(self, x):
        """Encode an image batch into a sequence of projected spatial tokens.

        Args:
            x: Image batch passed straight to the DenseNet feature extractor.
                Presumably ``[B, 1, H, W]`` X-rays preprocessed the way
                TorchXRayVision expects (TODO confirm against the caller).

        Returns:
            Tensor of shape ``[B, H' * W', out_dim]``, with one token per cell
            of the backbone's final feature grid. For 224x224 inputs the grid
            is 7x7, i.e. ``[B, 49, out_dim]``.
        """
        # NOTE(review): TorchXRayVision's own pooling path applies ReLU to this
        # feature map before using it; here the raw map is projected. Confirm
        # this is intended.
        feat_map = self.model.features(x)              # [B, C, H', W']
        tokens = feat_map.flatten(2).transpose(1, 2)   # [B, H'*W', C]
        return self.projector(tokens)                  # [B, H'*W', out_dim]