cfgpp commited on
Commit
2885278
·
verified ·
1 Parent(s): 20cd6c0

Upload 2 files

Browse files
Files changed (2) hide show
  1. model_utils.py +41 -0
  2. preprocessing.py +11 -0
model_utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/model_utils.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.models import densenet121, DenseNet121_Weights
5
+
6
+ # Disease labels
7
+ DISEASE_LIST = [
8
+ 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Effusion',
9
+ 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration', 'Mass',
10
+ 'Nodule', 'Pleural_Thickening', 'Pneumonia', 'Pneumothorax'
11
+ ]
12
+
13
+ # Load trained CheXNet model
14
+ class CheXNet(nn.Module):
15
+ def __init__(self, num_classes=14):
16
+ super().__init__()
17
+ base_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
18
+ self.features = base_model.features
19
+ self.classifier = nn.Linear(base_model.classifier.in_features, num_classes)
20
+
21
+ def forward(self, x):
22
+ x = self.features(x)
23
+ x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
24
+ x = torch.flatten(x, 1)
25
+ return self.classifier(x)
26
+
27
+ def load_model(url, device):
28
+ model_path = "dannynet.pth"
29
+ torch.hub.download_url_to_file(url,model_path)
30
+ model = torch.load(model_path, map_location = device)
31
+ model.eval()
32
+ return model
33
+
34
+ def predict(model, img_tensor, device):
35
+ with torch.no_grad():
36
+ output = model(img_tensor.unsqueeze(0).to(device))
37
+ probs = torch.sigmoid(output[0]).cpu().numpy()
38
+ return dict(zip(DISEASE_LIST, probs))
39
+
40
+
41
+
preprocessing.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils/preprocessing.py
2
+ from torchvision import transforms
3
+
4
+ def preprocess_image(img):
5
+ transform = transforms.Compose([
6
+ transforms.Resize(256),
7
+ transforms.CenterCrop(224),
8
+ transforms.ToTensor(),
9
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
10
+ ])
11
+ return transform(img)