Anirudh Balaraman commited on
Commit
5fa0689
·
1 Parent(s): 1927437

update mem requirement

Browse files
Files changed (1) hide show
  1. run_inference.py +32 -2
run_inference.py CHANGED
@@ -16,6 +16,32 @@ from src.preprocessing.histogram_match import histmatch
16
  from src.preprocessing.prostate_mask import get_segmask
17
  from src.preprocessing.register_and_crop import register_files
18
  from src.utils import get_parent_image, get_patch_coordinate, setup_logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  def parse_args():
@@ -70,20 +96,24 @@ if __name__ == "__main__":
70
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
71
 
72
  logging.info("Loading PIRADS model")
73
- pirads_model = MILModel3D(num_classes=args.num_classes, mil_mode=args.mil_mode)
 
74
  pirads_checkpoint = torch.load(
75
  os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
76
  )
77
  pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
78
  pirads_model.to(args.device)
 
79
  logging.info("Loading csPCa model")
 
 
80
  cspca_model = CSPCAModel(backbone=pirads_model).to(args.device)
81
  checkpt = torch.load(
82
  os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
83
  )
84
  cspca_model.load_state_dict(checkpt["state_dict"])
85
  cspca_model = cspca_model.to(args.device)
86
-
87
  transform = data_transform(args)
88
  files = os.listdir(args.t2_dir)
89
  args.data_list = []
 
16
  from src.preprocessing.prostate_mask import get_segmask
17
  from src.preprocessing.register_and_crop import register_files
18
  from src.utils import get_parent_image, get_patch_coordinate, setup_logging
19
+ import streamlit as st
20
+
21
+ @st.cache_resource # <--- This decorator is the magic!
22
+ def load_pirads_model(num_classes, mil_mode, project_dir, device):
23
+ # Move the model initialization inside here
24
+ model = MILModel3D(num_classes=num_classes, mil_mode=mil_mode)
25
+ checkpoint = torch.load(
26
+ os.path.join(project_dir, "models", "pirads.pt"), map_location="cpu"
27
+ )
28
+ model.load_state_dict(checkpoint["state_dict"])
29
+ model.to(device)
30
+
31
+ model.eval() # Set to evaluation mode
32
+ return model
33
+ @st.cache_resource
34
+ def load_cspca_model(pirads_model, project_dir, device):
35
+ # Move the model initialization inside here
36
+ model = CSPCAModel(backbone=pirads_model).to(device)
37
+ checkpt = torch.load(
38
+ os.path.join(project_dir, "models", "cspca_model.pth"), map_location="cpu"
39
+ )
40
+ model.load_state_dict(checkpt["state_dict"])
41
+ model = model.to(device)
42
+
43
+ model.eval() # Set to evaluation mode
44
+ return model
45
 
46
 
47
  def parse_args():
 
96
  args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
97
 
98
  logging.info("Loading PIRADS model")
99
+ pirads_model = load_pirads_model(args.num_classes, args.mil_mode, args.project_dir, args.device)
100
+ '''
101
  pirads_checkpoint = torch.load(
102
  os.path.join(args.project_dir, "models", "pirads.pt"), map_location="cpu"
103
  )
104
  pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
105
  pirads_model.to(args.device)
106
+ '''
107
  logging.info("Loading csPCa model")
108
+ cspca_model = load_cspca_model(pirads_model, args.project_dir, args.device)
109
+ '''
110
  cspca_model = CSPCAModel(backbone=pirads_model).to(args.device)
111
  checkpt = torch.load(
112
  os.path.join(args.project_dir, "models", "cspca_model.pth"), map_location="cpu"
113
  )
114
  cspca_model.load_state_dict(checkpt["state_dict"])
115
  cspca_model = cspca_model.to(args.device)
116
+ '''
117
  transform = data_transform(args)
118
  files = os.listdir(args.t2_dir)
119
  args.data_list = []