# Source: CGSCORE/examples/graph/test_median_gcn.py
# (uploaded by Yaning1001 via the upload-large-folder tool; commit c91d7b1, verified)
import torch
import numpy as np
from deeprobust.graph.defense import MedianGCN
from deeprobust.graph.targeted_attack import SGAttack
from deeprobust.graph.utils import *
from deeprobust.graph.data import Dataset, Dpr2Pyg
from deeprobust.graph.defense import SGC
import argparse
from tqdm import tqdm
# ---------------------------------------------------------------------------
# Command-line arguments and reproducibility
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='cora',
                    choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'],
                    help='dataset')
# NOTE(review): --ptb_rate is parsed but not read anywhere in this file.
parser.add_argument('--ptb_rate', type=float, default=0.05, help='perturbation rate')
args = parser.parse_args()

args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)
# Derive the device from the flag that was just printed so the two cannot disagree.
device = torch.device("cuda:0" if args.cuda else "cpu")

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# ---------------------------------------------------------------------------
# Data: load the dataset (root directory '/tmp/') and its predefined splits.
# ---------------------------------------------------------------------------
data = Dataset(root='/tmp/', name=args.dataset)
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
# Candidate targets are the nodes whose labels are not used for training.
idx_unlabeled = np.union1d(idx_val, idx_test)

# ---------------------------------------------------------------------------
# Surrogate model: the SGC that SGAttack uses to craft perturbations.
# ---------------------------------------------------------------------------
surrogate = SGC(nfeat=features.shape[1],
                nclass=labels.max().item() + 1, K=2,
                lr=0.01, device=device).to(device)
# Shared PyG-style graph object; the functions below overwrite its edges in
# place via `update_edge_index`.
pyg_data = Dpr2Pyg(data)
surrogate.fit(pyg_data, verbose=False)  # train with earlystopping
surrogate.test()

# ---------------------------------------------------------------------------
# Attack model
# ---------------------------------------------------------------------------
target_node = 0
# Explicit check instead of `assert`, which is stripped under `python -O`.
if target_node not in idx_unlabeled:
    raise ValueError('target_node %d is not an unlabeled (val/test) node' % target_node)
model = SGAttack(surrogate, attack_structure=True, device=device)
model = model.to(device)
def main():
    """Run a direct SGAttack on the module-level ``target_node`` and evaluate MedianGCN.

    The perturbation budget is the target node's degree in the clean graph
    (column sums of ``adj``). The structure perturbations chosen by the
    attacker are printed. The attacked adjacency matrix is then handed to
    :func:`test`, which compares MedianGCN on the clean and perturbed graphs.
    """
    degrees = adj.sum(0).A1
    # How many perturbations to perform. Default: degree of the node.
    n_perturbations = int(degrees[target_node])

    # Direct attack. For an indirect / influencer attack use instead:
    #   model.attack(features, adj, labels, target_node, n_perturbations,
    #                direct=False, n_influencers=5)
    model.attack(features, adj, labels, target_node, n_perturbations)
    modified_adj = model.modified_adj
    print(model.structure_perturbations)

    test(adj, modified_adj, features, target_node)
def _report_target_and_accuracy(output, target_node):
    """Print the target node's class probabilities and the overall test accuracy.

    ``output`` is passed through ``torch.exp`` to obtain probabilities, so it
    is treated as log-probabilities (TODO confirm against MedianGCN.predict).

    Returns
    -------
    float
        Accuracy on ``idx_test``.
    """
    probs = torch.exp(output[[target_node]])[0]
    print('Target node probs: {}'.format(probs.detach().cpu().numpy()))
    acc_test = accuracy(output[idx_test], labels[idx_test])
    print("Overall test set results:",
          "accuracy= {:.4f}".format(acc_test.item()))
    return acc_test.item()


def test(adj, modified_adj, features, target_node):
    """Evasion test on MedianGCN.

    Trains a fresh MedianGCN once on the clean graph ``adj``. It then reports
    the target-node probabilities and test accuracy first on ``adj`` and then
    on ``modified_adj``, without retraining in between (evasion setting).

    Side effect: the shared ``pyg_data`` is left holding ``modified_adj``.

    Parameters
    ----------
    adj, modified_adj
        Clean and attacked adjacency matrices.
    features
        Only ``features.shape[1]`` is used, as the model's input dimension.
    target_node : int
        Node whose class probabilities are printed.

    Returns
    -------
    float
        Test-set accuracy on the perturbed graph.
    """
    gcn = MedianGCN(nfeat=features.shape[1],
                    nhid=16,
                    nclass=labels.max().item() + 1,
                    dropout=0.5, device=device)
    gcn = gcn.to(device)
    pyg_data.update_edge_index(adj)
    gcn.fit(pyg_data)
    # eval mode persists for both predictions below; nothing switches back to train.
    gcn.eval()

    print('=== testing MedianGCN on original(clean) graph ===')
    output = gcn.predict()
    _report_target_and_accuracy(output, target_node)

    print('=== testing MedianGCN on perturbed graph ===')
    pyg_data.update_edge_index(modified_adj)
    output = gcn.predict()
    return _report_target_and_accuracy(output, target_node)
def select_nodes(target_gcn=None, n_high=10, n_low=10, n_random=20):
    """Select attack targets as reported in the Nettack paper.

    Only test nodes with a non-negative classification margin are kept
    (i.e. correctly classified by ``target_gcn``). From those, take:

    (i) the ``n_high`` nodes with the highest margin, i.e. clearly correctly
        classified;
    (ii) the ``n_low`` nodes with the lowest margin (but still correctly
        classified);
    (iii) ``n_random`` further nodes, drawn without replacement from the rest.

    The three groups never overlap. When there are fewer candidates than
    requested, the groups are truncated instead of raising.

    Parameters
    ----------
    target_gcn : MedianGCN, optional
        Trained model used to compute margins. If None, a fresh MedianGCN is
        trained on the clean graph ``adj``.
    n_high, n_low, n_random : int
        Group sizes. The defaults (10, 10, 20) reproduce the original setup.

    Returns
    -------
    list
        Node indices: high-margin nodes, then low-margin nodes, then random ones.
    """
    if target_gcn is None:
        target_gcn = MedianGCN(nfeat=features.shape[1],
                               nhid=16,
                               nclass=labels.max().item() + 1,
                               dropout=0.5, device=device)
        target_gcn = target_gcn.to(device)
        # `update_edge_index` is an in-place operation, so it is necessary
        # to restore the initial clean graph before training.
        pyg_data.update_edge_index(adj)
        target_gcn.fit(pyg_data)
    target_gcn.eval()
    output = target_gcn.predict()

    margin_dict = {}
    for idx in idx_test:
        margin = classification_margin(output[idx], labels[idx])
        if margin < 0:  # only keep the nodes correctly classified
            continue
        margin_dict[idx] = margin

    # Node ids ordered from highest to lowest margin.
    ranked = [node for node, _ in sorted(margin_dict.items(), key=lambda kv: kv[1], reverse=True)]
    # Clamp the start of the low-margin group so it never overlaps the high group.
    low_start = max(n_high, len(ranked) - n_low)
    high = ranked[:n_high]
    low = ranked[low_start:]
    middle = ranked[n_high:low_start]
    # Sample at most as many nodes as are available (np.random.choice with
    # replace=False raises ValueError when asked for more than exist).
    n_pick = min(n_random, len(middle))
    other = np.random.choice(middle, n_pick, replace=False).tolist() if n_pick > 0 else []
    return high + low + other
def single_test(adj, features, target_node, gcn=None):
    """Return whether MedianGCN classifies ``target_node`` correctly on graph ``adj``.

    Two settings are supported:

    * ``gcn is None`` (poisoning): a fresh MedianGCN is trained on ``adj``
      and then evaluated on it.
    * ``gcn`` given (evasion): the already trained ``gcn`` is evaluated on
      ``adj`` without retraining.

    Side effect: the shared ``pyg_data`` is updated in place to ``adj`` in
    both cases.

    Parameters
    ----------
    adj
        Adjacency matrix to evaluate on (typically the attacked one).
    features
        Only ``features.shape[1]`` is used, to size a new model.
        NOTE(review): any feature perturbations in ``features`` are not
        applied to ``pyg_data`` here — confirm this is intended.
    target_node : int
        Node to check.
    gcn : MedianGCN, optional
        Pre-trained model for the evasion setting.

    Returns
    -------
    bool
        True if the predicted class of ``target_node`` equals its label.
    """
    if gcn is None:
        # test on MedianGCN (poisoning attack)
        gcn = MedianGCN(nfeat=features.shape[1],
                        nhid=16,
                        nclass=labels.max().item() + 1,
                        dropout=0.5, device=device)
        gcn = gcn.to(device)
        pyg_data.update_edge_index(adj)
        gcn.fit(pyg_data)
        gcn.eval()
        output = gcn.predict()
    else:
        pyg_data.update_edge_index(adj)
        # test on MedianGCN (evasion attack)
        output = gcn.predict(pyg_data)
    is_correct = output.argmax(1)[target_node] == labels[target_node]
    return is_correct.item()
def multi_test_evasion():
    """Run SGAttack in the evasion setting on the nodes chosen by :func:`select_nodes`.

    A single MedianGCN is trained once on the clean graph. For every selected
    target, SGAttack perturbs the structure with a budget equal to the node's
    degree. The trained model is then evaluated on the perturbed graph
    without retraining.

    Returns
    -------
    float
        Fraction of targets that end up misclassified (0.0 when no target
        node was selected).
    """
    # test on 40 nodes (by default) on evasion attack
    target_gcn = MedianGCN(nfeat=features.shape[1],
                           nhid=16,
                           nclass=labels.max().item() + 1,
                           dropout=0.5, device=device)
    target_gcn = target_gcn.to(device)
    # Restore the clean graph: earlier calls may have left pyg_data perturbed.
    pyg_data.update_edge_index(adj)
    target_gcn.fit(pyg_data)

    degrees = adj.sum(0).A1
    node_list = select_nodes(target_gcn)
    num = len(node_list)
    print('=== [Evasion] Attacking %s nodes respectively ===' % num)
    if num == 0:
        # Nothing to attack; avoids a ZeroDivisionError when computing the rate.
        print('misclassification rate : %s' % 0.0)
        return 0.0

    cnt = 0
    for target_node in tqdm(node_list):
        n_perturbations = int(degrees[target_node])
        # A new attacker per target, presumably so no attack state carries over.
        model = SGAttack(surrogate, attack_structure=True, device=device)
        model = model.to(device)
        model.attack(features, adj, labels, target_node, n_perturbations, verbose=False)
        acc = single_test(model.modified_adj, model.modified_features, target_node, gcn=target_gcn)
        if acc == 0:
            cnt += 1

    rate = cnt / num
    print('misclassification rate : %s' % rate)
    return rate
if __name__ == '__main__':
    # MedianGCN is designed mainly against evasion attacks, so the poisoning
    # setting is not evaluated here (it still suffers from strong poisoning
    # attacks). Run the single-target demo first, then the multi-target
    # evasion experiment.
    for experiment in (main, multi_test_evasion):
        experiment()