Spaces:
Sleeping
Sleeping
Anirudh Balaraman commited on
Commit ·
16c0de3
1
Parent(s): 80a9c91
simplifiy finetuning
Browse files- run_cspca.py +63 -71
run_cspca.py
CHANGED
|
@@ -27,81 +27,73 @@ def main_worker(args):
|
|
| 27 |
model_dir = os.path.join(args.logdir, "models")
|
| 28 |
os.makedirs(model_dir, exist_ok=True)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
for st in list(range(args.num_seeds)):
|
| 32 |
-
set_determinism(seed=st)
|
| 33 |
-
|
| 34 |
-
train_loader = get_dataloader(args, split="train")
|
| 35 |
-
valid_loader = get_dataloader(args, split="test")
|
| 36 |
-
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 37 |
-
for submodule in [
|
| 38 |
-
cspca_model.backbone.net,
|
| 39 |
-
cspca_model.backbone.myfc,
|
| 40 |
-
cspca_model.backbone.transformer,
|
| 41 |
-
]:
|
| 42 |
-
for param in submodule.parameters():
|
| 43 |
-
param.requires_grad = False
|
| 44 |
-
|
| 45 |
-
optimizer = torch.optim.AdamW(
|
| 46 |
-
filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
|
| 47 |
-
)
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
for epoch in range(args.epochs):
|
| 52 |
-
train_loss, train_auc = train_epoch(
|
| 53 |
-
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 54 |
-
)
|
| 55 |
-
logging.info(
|
| 56 |
-
f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
|
| 57 |
-
)
|
| 58 |
-
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 59 |
-
logging.info(
|
| 60 |
-
f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
|
| 61 |
-
)
|
| 62 |
-
val_metric["state"] = st
|
| 63 |
-
if val_metric["loss"] < old_loss:
|
| 64 |
-
old_loss = val_metric["loss"]
|
| 65 |
-
old_auc = val_metric["auc"]
|
| 66 |
-
sensitivity = val_metric["sensitivity"]
|
| 67 |
-
specificity = val_metric["specificity"]
|
| 68 |
-
if not metrics_dict["auc"] or val_metric["auc"] >= max(metrics_dict["auc"]):
|
| 69 |
-
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 70 |
-
|
| 71 |
-
metrics_dict["auc"].append(old_auc)
|
| 72 |
-
metrics_dict["sensitivity"].append(sensitivity)
|
| 73 |
-
metrics_dict["specificity"].append(specificity)
|
| 74 |
-
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 75 |
-
shutil.rmtree(cache_dir_path)
|
| 76 |
-
|
| 77 |
-
get_metrics(metrics_dict)
|
| 78 |
-
|
| 79 |
-
elif args.mode == "test":
|
| 80 |
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
logging.info(
|
| 87 |
-
f"
|
| 88 |
)
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
def parse_args():
|
|
|
|
| 27 |
model_dir = os.path.join(args.logdir, "models")
|
| 28 |
os.makedirs(model_dir, exist_ok=True)
|
| 29 |
|
| 30 |
+
set_determinism(seed=42)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
train_loader = get_dataloader(args, split="train")
|
| 33 |
+
valid_loader = get_dataloader(args, split="test")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 35 |
+
for submodule in [
|
| 36 |
+
cspca_model.backbone.net,
|
| 37 |
+
cspca_model.backbone.myfc,
|
| 38 |
+
cspca_model.backbone.transformer,
|
| 39 |
+
]:
|
| 40 |
+
for param in submodule.parameters():
|
| 41 |
+
param.requires_grad = False
|
| 42 |
+
|
| 43 |
+
optimizer = torch.optim.AdamW(
|
| 44 |
+
filter(lambda p: p.requires_grad, cspca_model.parameters()), lr=args.optim_lr
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
old_loss = float("inf")
|
| 48 |
+
old_auc = 0.0
|
| 49 |
+
for epoch in range(args.epochs):
|
| 50 |
+
train_loss, train_auc = train_epoch(
|
| 51 |
+
cspca_model, train_loader, optimizer, epoch=epoch, args=args
|
| 52 |
+
)
|
| 53 |
+
logging.info(
|
| 54 |
+
f"STATE {st} EPOCH {epoch} TRAIN loss: {train_loss:.4f} AUC: {train_auc:.4f}"
|
| 55 |
+
)
|
| 56 |
+
val_metric = val_epoch(cspca_model, valid_loader, epoch=epoch, args=args)
|
| 57 |
logging.info(
|
| 58 |
+
f"STATE {st} EPOCH {epoch} VAL loss: {val_metric['loss']:.4f} AUC: {val_metric['auc']:.4f}"
|
| 59 |
)
|
| 60 |
+
if val_metric["loss"] < old_loss:
|
| 61 |
+
old_loss = val_metric["loss"]
|
| 62 |
+
old_auc = val_metric["auc"]
|
| 63 |
+
sensitivity = val_metric["sensitivity"]
|
| 64 |
+
specificity = val_metric["specificity"]
|
| 65 |
+
save_cspca_checkpoint(cspca_model, val_metric, model_dir)
|
| 66 |
+
|
| 67 |
+
args.checkpoint_cspca = os.path.join(model_dir, "cspca_model.pth")
|
| 68 |
+
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 69 |
+
shutil.rmtree(cache_dir_path)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
cspca_model = csPCa_Model(backbone=mil_model).to(args.device)
|
| 73 |
+
checkpt = torch.load(args.checkpoint_cspca, map_location="cpu")
|
| 74 |
+
cspca_model.load_state_dict(checkpt["state_dict"])
|
| 75 |
+
cspca_model = cspca_model.to(args.device)
|
| 76 |
+
if "auc" in checkpt and "sensitivity" in checkpt and "specificity" in checkpt:
|
| 77 |
+
auc, sens, spec = checkpt["auc"], checkpt["sensitivity"], checkpt["specificity"]
|
| 78 |
+
logging.info(
|
| 79 |
+
f"csPCa Model loaded from {args.checkpoint_cspca} with AUC: {auc}, Sensitivity: {sens}, Specificity: {spec} on the test set."
|
| 80 |
+
)
|
| 81 |
+
else:
|
| 82 |
+
logging.info(f"csPCa Model loaded from {args.checkpoint_cspca}.")
|
| 83 |
+
|
| 84 |
+
metrics_dict = {"auc": [], "sensitivity": [], "specificity": []}
|
| 85 |
+
for st in list(range(args.num_seeds)):
|
| 86 |
+
set_determinism(seed=st)
|
| 87 |
+
test_loader = get_dataloader(args, split="test")
|
| 88 |
+
test_metric = val_epoch(cspca_model, test_loader, epoch=0, args=args)
|
| 89 |
+
metrics_dict["auc"].append(test_metric["auc"])
|
| 90 |
+
metrics_dict["sensitivity"].append(test_metric["sensitivity"])
|
| 91 |
+
metrics_dict["specificity"].append(test_metric["specificity"])
|
| 92 |
+
|
| 93 |
+
if cache_dir_path.exists() and cache_dir_path.is_dir():
|
| 94 |
+
shutil.rmtree(cache_dir_path)
|
| 95 |
+
|
| 96 |
+
get_metrics(metrics_dict)
|
| 97 |
|
| 98 |
|
| 99 |
def parse_args():
|