| import numpy as np |
| import torch |
| import torch.optim |
| import os |
| import random |
|
|
| from methods import backbone |
| from methods.backbone_multiblock import model_dict |
| from data.datamgr import SimpleDataManager, SetDataManager |
| from methods.StyleAdv_RN_GNN import StyleAdvGNN |
|
|
| from options import parse_args, get_resume_file, load_warmup_state |
| from test_function_fwt_benchmark import test_bestmodel |
| from test_function_bscdfsl_benchmark import test_bestmodel_bscdfsl |
|
|
|
|
def record_test_result(params):
    """Run ``test_bestmodel`` on miniImagenet and log the result to ``tmp2.txt``.

    The output file is truncated on every call. A header line is written
    first, then the open file handle is passed to ``test_bestmodel`` so it
    can write its own output.

    The header still names cub/cars/places/plantae, although only the
    miniImagenet evaluation is invoked here.

    Args:
        params: Object exposing ``name``, ``n_shot`` and ``method``
            attributes (presumably the namespace returned by
            ``parse_args`` -- confirm against the caller).
    """
    acc_file_path = "tmp2.txt"
    # Forwarded unchanged to test_bestmodel; -1 presumably means
    # "no specific epoch / best checkpoint" -- TODO confirm.
    epoch_id = -1

    # The context manager guarantees the file is closed even when the
    # evaluation raises; the original only closed it on the success path.
    with open(acc_file_path, "w") as acc_file:
        print(
            "epoch",
            epoch_id,
            "miniImagenet:",
            "cub:",
            "cars:",
            "places:",
            "plantae:",
            file=acc_file,
        )
        name = params.name
        n_shot = params.n_shot
        method = params.method
        test_bestmodel(acc_file, name, method, "miniImagenet", n_shot, epoch_id)
|
|
|
|
def record_test_result_bscdfsl(params):
    """Run ``test_bestmodel_bscdfsl`` on EuroSAT and log to ``tmp_bscdfsl2.txt``.

    The output file is truncated on every call. A header line is written
    first, then the open file handle is passed to
    ``test_bestmodel_bscdfsl`` so it can write its own output.

    The header still names ChestX/ISIC/CropDisease, although only the
    EuroSAT evaluation is invoked here.

    Args:
        params: Object exposing ``name``, ``n_shot`` and ``method``
            attributes (presumably the namespace returned by
            ``parse_args`` -- confirm against the caller).
    """
    print("hhhhhhh testing for bscdfsl")
    acc_file_path = "tmp_bscdfsl2.txt"
    # Forwarded unchanged to test_bestmodel_bscdfsl; -1 presumably means
    # "no specific epoch / best checkpoint" -- TODO confirm.
    epoch_id = -1

    # The context manager guarantees the file is closed even when the
    # evaluation raises; the original only closed it on the success path.
    with open(acc_file_path, "w") as acc_file:
        print(
            "epoch",
            epoch_id,
            "ChestX:",
            "ISIC:",
            "EuroSAT:",
            "CropDisease",
            file=acc_file,
        )
        name = params.name
        n_shot = params.n_shot
        method = params.method
        test_bestmodel_bscdfsl(acc_file, name, method, "EuroSAT", n_shot, epoch_id)
|
|
|
|
| |
if __name__ == "__main__":
    # Seed every random number generator the script touches with one
    # fixed value so repeated runs draw the same random numbers.
    seed = 0
    print("set seed = %d" % seed)
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and turn off its autotuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Options are parsed with the "train" argument even though this script
    # only runs the BSCD-FSL evaluation.
    params = parse_args("train")

    record_test_result_bscdfsl(params)
|
|