| import argparse
|
| import os
|
|
|
| import numpy as np
|
| import pandas as pd
|
| import torch
|
| from tqdm import tqdm
|
|
|
| from model.CPICANN import CPICANN
|
|
|
|
|
def getAnnoMap():
    """Load the structure annotation table and index it by column 1.

    Reads the CSV file named by the module-level ``args.anno_struc`` and
    builds two dictionaries that share the same keys (the value found in
    column 1 of each row).

    Returns:
        tuple[dict, dict]: ``(annos, elems)`` where ``annos`` maps each key
        to its complete row and ``elems`` maps each key to the set of tokens
        obtained by splitting column 3 of that row on single spaces.
    """
    rows = pd.read_csv(args.anno_struc).values

    # Two passes over the same rows; a later duplicate key overrides an
    # earlier one in both mappings.
    annos = {row[1]: row for row in rows}
    elems = {row[1]: set(row[3].split(' ')) for row in rows}

    return annos, elems
|
|
|
|
|
def filter_by_elem(logits, elemMap, elem):
    """Mask out every class whose element set is not contained in ``elem``.

    Args:
        logits: Two-dimensional, column-indexable score array/tensor. It is
            modified in place.
        elemMap: Mapping from column index to the set of tokens associated
            with that column.
        elem: Set of allowed tokens.

    Returns:
        The same ``logits`` object, with each column whose token set is not
        a subset of ``elem`` overwritten by ``-10 ** 9``.
    """
    suppressed_score = -10 ** 9

    for column, required in elemMap.items():
        if required <= elem:
            # Every required token is available: leave this column untouched.
            continue
        logits[:, column] = suppressed_score

    return logits
|
|
|
|
|
def _scale_to_percent(v):
    """Linearly rescale array ``v`` so its minimum maps to 0 and its maximum to 100."""
    lowest = v.min()
    return (v - lowest) / (v.max() - lowest) * 100


def _joint_hit_rank(m):
    """Return the smallest ``k`` with both labels 1 and 2 in ``m[:k]``, or ``None``."""
    if 1 not in m or 2 not in m:
        return None
    return max(m.index(1), m.index(2)) + 1


def main():
    """Evaluate the model on random two-pattern mixtures and print top-k hit rates.

    All settings are read from the module-level ``args`` namespace.

    Workflow:
      1. Load the annotation maps via ``getAnnoMap`` and restore the
         ``CPICANN`` weights stored under the ``'model'`` key of the
         checkpoint at ``args.load_path``.
      2. Repeat ``args.infTimes`` times: draw two rows of ``args.anno_val``
         whose column 6 values differ, read the CSV file named after column 0
         of each row from ``args.data_dir``, blend them with integer weights
         ``r`` and ``100 - r`` (``r`` drawn from 20..80), rescale the blend
         to the range 0..100 and take the model's ten highest-scoring classes.
      3. A prediction matches a source row when column 0 of its ``annoMap``
         entry equals the integer parsed from the first seven characters of
         that row's column 0. ``top10Hits[k - 1]`` counts the trials in which
         both source rows are matched within the first ``k`` predictions.
      4. Print the hit percentage for k = 2..10 (a single prediction can never
         cover both sources, so index 0 is never incremented or reported).

    When ``args.elem_filtration`` is truthy, classes whose token set (from
    ``elemMap``) is not contained in the union of the two rows' column 3
    tokens are suppressed before the softmax.
    """
    annoMap, elemMap = getAnnoMap()

    model = CPICANN(embed_dim=128, num_classes=args.num_classes)

    # map_location lets a checkpoint saved on one device be restored on the
    # requested one (e.g. a CUDA-saved checkpoint with --device cpu).
    loaded = torch.load(args.load_path, map_location=args.device)
    model.load_state_dict(loaded['model'])
    model.to(args.device)
    model.eval()
    print('loaded model from {}'.format(args.load_path))
    print(model)

    if args.elem_filtration:
        print('elem_filtration activated!')
    else:
        print('elem_filtration deactivated!')

    lst = pd.read_csv(args.anno_val).values
    dataLen = len(lst)

    top10Hits = np.array([0] * 10, dtype=np.int32)

    def runFile(v, elem):
        """Return the top-10 softmax result of the model for pattern ``v``."""
        v = _scale_to_percent(v)

        v = torch.tensor(v, dtype=torch.float32).reshape(1, 1, -1)
        v = v.to(args.device)
        with torch.no_grad():
            logits = model(v)

        if args.elem_filtration:
            logits = filter_by_elem(logits, elemMap, elem)

        _pred = torch.nn.functional.softmax(logits.squeeze(), dim=0)
        return _pred.topk(10)

    # Iterating the tqdm object advances and closes the progress bar for us.
    for _ in tqdm(range(args.infTimes)):
        # Re-draw until the two rows differ in column 6
        # (presumably a structure/phase identifier -- TODO confirm).
        while True:
            c1, c2 = np.random.randint(0, dataLen, 2)
            anno1, anno2 = lst[c1], lst[c2]
            if anno1[6] != anno2[6]:
                break

        data1 = pd.read_csv(os.path.join(args.data_dir, f'{anno1[0]}.csv')).values
        data2 = pd.read_csv(os.path.join(args.data_dir, f'{anno2[0]}.csv')).values

        mixRate1 = np.random.randint(20, 81)
        mixRate2 = 100 - mixRate1

        data = mixRate1 * data1 + mixRate2 * data2
        elem = set(anno2[3].strip().split(' ')) | set(anno1[3].strip().split(' '))

        top10 = runFile(data, elem)

        # Leading seven characters of column 0 are parsed as the id that is
        # compared against column 0 of the predicted annotation row.
        id1 = int(anno1[0][:7])
        id2 = int(anno2[0][:7])

        # m[rank] is 1 / 2 when the prediction at that rank matches the
        # first / second source row, 0 otherwise.
        m = [0] * 10
        for no, class_idx in enumerate(top10.indices.tolist()):
            pred = annoMap[class_idx]

            if pred[0] == id1:
                m[no] = 1
            elif pred[0] == id2:
                m[no] = 2

        # A joint hit at rank k also counts for every larger k.
        rank = _joint_hit_rank(m)
        if rank is not None:
            top10Hits[rank - 1:] += 1

    for i in range(1, 10):
        print('top{}Hits: {}%'.format(i + 1, round(top10Hits[i] / args.infTimes * 100, 2)))
|
|
|
|
|
def str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``'False'``) as ``True``; this converter interprets the text instead.

    Args:
        value: A bool, or a string such as ``'true'``/``'false'``,
            ``'yes'``/``'no'``, ``'1'``/``'0'`` (case-insensitive). An empty
            string maps to ``False``.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_arg_parser():
    """Create the command-line parser for this evaluation script.

    Returns:
        argparse.ArgumentParser: Parser exposing ``--device``, ``--data_dir``,
        ``--infTimes``, ``--load_path``, ``--anno_struc``, ``--anno_val``,
        ``--num_classes`` and ``--elem_filtration``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--data_dir', default='data/val/', type=str)
    parser.add_argument('--infTimes', default=1000, type=int, help='number of mixed pattern to be inferenced')
    parser.add_argument('--load_path', default='pretrained/bi-phase_checkpoint_2000.pth', type=str,
                        help='path to load pretrained single-phase identification model')
    parser.add_argument('--anno_struc', default='annotation/anno_struc.csv', type=str,
                        help='path to annotation file for training data')
    parser.add_argument('--anno_val', default='annotation/anno_val.csv', type=str,
                        help='path to annotation file for validation data')
    parser.add_argument('--num_classes', default=23073, type=int, metavar='N')

    # nargs='?' + const=True keeps "--elem_filtration True" working, makes
    # "--elem_filtration False" actually disable the filter, and also allows
    # the bare flag "--elem_filtration".
    parser.add_argument('--elem_filtration', default=False, type=str2bool, nargs='?', const=True)

    return parser


if __name__ == '__main__':
    parser = build_arg_parser()

    # ``args`` is deliberately module-level: the functions above read it.
    args = parser.parse_args()

    main()
    print('THE END')
|
|
|