| import torch |
| import torch.nn as nn |
|
|
| from .base import Attack, LabelMixin |
| from .utils import ctx_noparamgrad_and_eval |
| from .utils import batch_multiply |
| from .utils import clamp ,normalize_by_pnorm |
| from utils.distributed import DistributedMetric |
| from tqdm import tqdm |
| from torchpack import distributed as dist |
| from utils import accuracy |
| from typing import Dict |
class FGSMAttack(Attack, LabelMixin):
    """
    One-step fast gradient sign method (Goodfellow et al., 2014).

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function. When the base class leaves it as
            None, a summed cross-entropy loss is used.
        eps (float): attack step size.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): indicate if this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
        super(FGSMAttack, self).__init__(predict, loss_fn, clip_min, clip_max)

        self.eps = eps
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with an attack length of eps.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.
                (Both cases are delegated to `_verify_and_process_inputs`.)

        Returns:
            torch.Tensor containing perturbed inputs.
            torch.Tensor containing the perturbation.
        """
        x, y = self._verify_and_process_inputs(x, y)

        xadv = x.requires_grad_()
        outputs = self.predict(xadv)

        loss = self.loss_fn(outputs, y)
        if self.targeted:
            # Step towards the target class: ascend on the negated loss.
            loss = -loss

        # Ask autograd for the input gradient only. Unlike `loss.backward()`,
        # this neither accumulates into the model parameters' `.grad` nor
        # depends on (or adds to) a stale `x.grad` left by an earlier call.
        (grad,) = torch.autograd.grad(loss, xadv)
        grad_sign = grad.detach().sign()

        xadv = xadv + batch_multiply(self.eps, grad_sign)
        xadv = clamp(xadv, self.clip_min, self.clip_max)
        # The perturbation is measured after clamping, so it is what was
        # actually applied to x.
        radv = xadv - x
        return xadv.detach(), radv.detach()


# Alternative name for the same class.
LinfFastGradientAttack = FGSMAttack
|
|
|
|
class FGMAttack(Attack, LabelMixin):
    """
    One-step fast gradient method. Perturbs the input with the gradient (not the gradient sign)
    of the loss with respect to the input.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function. When the base class leaves it as
            None, a summed cross-entropy loss is used.
        eps (float): attack step size.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
        targeted (bool): indicate if this is a targeted attack.
    """

    def __init__(self, predict, loss_fn=None, eps=0.3, clip_min=0., clip_max=1., targeted=False):
        super(FGMAttack, self).__init__(
            predict, loss_fn, clip_min, clip_max)

        self.eps = eps
        self.targeted = targeted
        if self.loss_fn is None:
            self.loss_fn = nn.CrossEntropyLoss(reduction="sum")

    def perturb(self, x, y=None):
        """
        Given examples (x, y), returns their adversarial counterparts with an attack length of eps.

        Arguments:
            x (torch.Tensor): input tensor.
            y (torch.Tensor): label tensor.
                - if None and self.targeted=False, compute y as predicted labels.
                - if self.targeted=True, then y must be the targeted labels.
                (Both cases are delegated to `_verify_and_process_inputs`.)

        Returns:
            torch.Tensor containing perturbed inputs.
            torch.Tensor containing the perturbation.
        """
        x, y = self._verify_and_process_inputs(x, y)
        xadv = x.requires_grad_()
        outputs = self.predict(xadv)

        loss = self.loss_fn(outputs, y)
        if self.targeted:
            # Step towards the target class: ascend on the negated loss.
            loss = -loss

        # Ask autograd for the input gradient only. Unlike `loss.backward()`,
        # this neither accumulates into the model parameters' `.grad` nor
        # depends on (or adds to) a stale `x.grad` left by an earlier call.
        (input_grad,) = torch.autograd.grad(loss, xadv)
        grad = normalize_by_pnorm(input_grad)

        xadv = xadv + batch_multiply(self.eps, grad)
        xadv = clamp(xadv, self.clip_min, self.clip_max)
        # The perturbation is measured after clamping, so it is what was
        # actually applied to x.
        radv = xadv - x

        return xadv.detach(), radv.detach()

    def eval_fgsm(self, data_loader_dict: Dict) -> Dict:
        """
        Evaluate `self.predict` on clean and adversarial validation batches.

        For every batch of `data_loader_dict["val"]` the model is scored on the
        clean images and on the images returned by `self.perturb`. Loss, top-1
        and top-5 accuracy are accumulated separately for both cases.

        NOTE(review): despite the method name, the attack used is whatever
        `self.perturb` implements on this class - confirm this is intended.

        Arguments:
            data_loader_dict (Dict): must contain a "val" entry that supports
                `len()` and yields `(images, labels)` batches. Batches are
                moved to the GPU with `.cuda()`, so CUDA is required.

        Returns:
            Dict with the keys "val_top1", "val_top5", "val_loss",
            "val_advtop1", "val_advtop5" and "val_advloss", each holding the
            metric's running average as a Python number.
        """
        test_criterion = nn.CrossEntropyLoss().cuda()
        val_loss = DistributedMetric()
        val_top1 = DistributedMetric()
        val_top5 = DistributedMetric()
        val_advloss = DistributedMetric()
        val_advtop1 = DistributedMetric()
        val_advtop5 = DistributedMetric()

        self.predict.eval()
        with tqdm(
            total=len(data_loader_dict["val"]),
            desc="Eval",
            disable=not dist.is_master(),
        ) as t:
            for images, labels in data_loader_dict["val"]:
                images, labels = images.cuda(), labels.cuda()
                batch_size = images.shape[0]

                # Clean evaluation: nothing is back-propagated from here, so
                # skip autograd bookkeeping and keep graph-free tensors in the
                # metric accumulators.
                with torch.no_grad():
                    output = self.predict(images)
                    loss = test_criterion(output, labels)
                val_loss.update(loss, batch_size)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                val_top5.update(acc5[0], batch_size)
                val_top1.update(acc1[0], batch_size)

                # Crafting the adversarial batch needs gradients with respect
                # to the input, so it must run outside `no_grad`.
                with ctx_noparamgrad_and_eval(self.predict):
                    images_adv, _ = self.perturb(images, labels)

                # Adversarial evaluation: again forward-only.
                with torch.no_grad():
                    output_adv = self.predict(images_adv)
                    loss_adv = test_criterion(output_adv, labels)
                val_advloss.update(loss_adv, batch_size)
                acc1_adv, acc5_adv = accuracy(output_adv, labels, topk=(1, 5))
                val_advtop1.update(acc1_adv[0], batch_size)
                val_advtop5.update(acc5_adv[0], batch_size)

                t.set_postfix(
                    {
                        "loss": val_loss.avg.item(),
                        "top1": val_top1.avg.item(),
                        "top5": val_top5.avg.item(),
                        "adv_loss": val_advloss.avg.item(),
                        "adv_top1": val_advtop1.avg.item(),
                        "adv_top5": val_advtop5.avg.item(),
                        "#samples": val_top1.count.item(),
                        "batch_size": batch_size,
                        "img_size": images.shape[2],
                    }
                )
                t.update()

        val_results = {
            "val_top1": val_top1.avg.item(),
            "val_top5": val_top5.avg.item(),
            "val_loss": val_loss.avg.item(),
            "val_advtop1": val_advtop1.avg.item(),
            "val_advtop5": val_advtop5.avg.item(),
            "val_advloss": val_advloss.avg.item(),
        }
        return val_results


# Alternative name for the same class.
L2FastGradientAttack = FGMAttack
|
|