Anirudh Balaraman commited on
Commit
16c0de3
·
1 Parent(s): 80a9c91

simplifiy finetuning

Browse files
Files changed (1) hide show
  1. run_cspca.py +63 -71
run_cspca.py CHANGED
@@ -27,81 +27,73 @@ def main_worker(args):
27
  model_dir = os.path.join(args.logdir, "models")
28
  os.makedirs(model_dir, exist_ok=True)
29
 
30
- metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
31
- for st in list(range(args.num_seeds)):
32
- set_determinism(seed=st)
33
-
34
- train_loader = get_dataloader(args, split="train")
35
- valid_loader = get_dataloader(args, split="test")
36
- cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
37
- for submodule in [
38
- cspca_model.backbone.net,
39
- cspca_model.backbone.myfc,
40
- cspca_model.backbone.transformer,
41
- ]:
42
- for param in submodule.parameters():
43
- param.requires_grad = False
44
-
45
- optimizer = torch.optim.AdamW(
46
- filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
47
- )
48
 
49
- old_loss = float("inf")
50
- old_auc = 0.0
51
- for epoch in range(args.epochs):
52
- train_loss, train_auc = train_epoch(
53
- cspca_model, train_loader, optimizer, epoch=epoch, args=args
54
- )
55
- logging.info(
56
- f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
57
- )
58
- val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
59
- logging.info(
60
- f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
61
- )
62
- val_metric["state"] = st
63
- if val_metric["loss"] < old_loss:
64
- old_loss = val_metric["loss"]
65
- old_auc = val_metric["auc"]
66
- sensitivity = val_metric["sensitivity"]
67
- specificity = val_metric["specificity"]
68
- if not metrics_dict["auc"] or val_metric["auc"] >= max(metrics_dict["auc"]):
69
- save_cspca_checkpoint(cspca_model, val_metric, model_dir)
70
-
71
- metrics_dict["auc"].append(old_auc)
72
- metrics_dict["sensitivity"].append(sensitivity)
73
- metrics_dict["specificity"].append(specificity)
74
- if cache_dir_path.exists() and cache_dir_path.is_dir():
75
- shutil.rmtree(cache_dir_path)
76
-
77
- get_metrics(metrics_dict)
78
-
79
- elif args.mode == "test":
80
  cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
81
- checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
82
- cspca_model.load_state_dict(checkpt["state_dict"])
83
- cspca_model = cspca_model.to(args.device)
84
- if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
85
- auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  logging.info(
87
- f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set."
88
  )
89
- else:
90
- logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
91
-
92
- metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
93
- for st in list(range(args.num_seeds)):
94
- set_determinism(seed=st)
95
- test_loader = get_dataloader(args, split="test")
96
- test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
97
- metrics_dict["auc"].append(test_metric["auc"])
98
- metrics_dict["sensitivity"].append(test_metric["sensitivity"])
99
- metrics_dict["specificity"].append(test_metric["specificity"])
100
-
101
- if cache_dir_path.exists() and cache_dir_path.is_dir():
102
- shutil.rmtree(cache_dir_path)
103
-
104
- get_metrics(metrics_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
 
107
  def parse_args():
 
27
  model_dir = os.path.join(args.logdir, "models")
28
  os.makedirs(model_dir, exist_ok=True)
29
 
30
+ set_determinism(seed=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ train_loader = get_dataloader(args, split="train")
33
+ valid_loader = get_dataloader(args, split="test")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
35
+ for submodule in [
36
+ cspca_model.backbone.net,
37
+ cspca_model.backbone.myfc,
38
+ cspca_model.backbone.transformer,
39
+ ]:
40
+ for param in submodule.parameters():
41
+ param.requires_grad = False
42
+
43
+ optimizer = torch.optim.AdamW(
44
+ filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
45
+ )
46
+
47
+ old_loss = float("inf")
48
+ old_auc = 0.0
49
+ for epoch in range(args.epochs):
50
+ train_loss, train_auc = train_epoch(
51
+ cspca_model, train_loader, optimizer, epoch=epoch, args=args
52
+ )
53
+ logging.info(
54
+ f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
55
+ )
56
+ val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
57
  logging.info(
58
+ f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
59
  )
60
+ if val_metric["loss"] < old_loss:
61
+ old_loss = val_metric["loss"]
62
+ old_auc = val_metric["auc"]
63
+ sensitivity = val_metric["sensitivity"]
64
+ specificity = val_metric["specificity"]
65
+ save_cspca_checkpoint(cspca_model, val_metric, model_dir)
66
+
67
+ args.checkpoint_cspca = os.path.join(model_dir, "cspca_model.pth")
68
+ if cache_dir_path.exists() and cache_dir_path.is_dir():
69
+ shutil.rmtree(cache_dir_path)
70
+
71
+
72
+ cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
73
+ checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
74
+ cspca_model.load_state_dict(checkpt["state_dict"])
75
+ cspca_model = cspca_model.to(args.device)
76
+ if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
77
+ auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
78
+ logging.info(
79
+ f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set."
80
+ )
81
+ else:
82
+ logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
83
+
84
+ metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
85
+ for st in list(range(args.num_seeds)):
86
+ set_determinism(seed=st)
87
+ test_loader = get_dataloader(args, split="test")
88
+ test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
89
+ metrics_dict["auc"].append(test_metric["auc"])
90
+ metrics_dict["sensitivity"].append(test_metric["sensitivity"])
91
+ metrics_dict["specificity"].append(test_metric["specificity"])
92
+
93
+ if cache_dir_path.exists() and cache_dir_path.is_dir():
94
+ shutil.rmtree(cache_dir_path)
95
+
96
+ get_metrics(metrics_dict)
97
 
98
 
99
  def parse_args():