import torch
import torch.nn as nn

from .utils import replicate_input


class Attack(object):
    """
    Abstract base class for all attack classes.

    Arguments:
        predict (nn.Module): forward pass function.
        loss_fn (nn.Module): loss function.
        clip_min (float): minimum value per input dimension.
        clip_max (float): maximum value per input dimension.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max):
        self.predict = predict
        self.loss_fn = loss_fn
        self.clip_min = clip_min
        self.clip_max = clip_max

    def perturb(self, x, **kwargs):
        """
        Virtual method for generating adversarial examples.

        Arguments:
            x (torch.Tensor): the model's input tensor.
            **kwargs: optional parameters used by child classes.
        Returns:
            adversarial examples.
        """
        error = "Sub-classes must implement perturb."
        raise NotImplementedError(error)

    def __call__(self, *args, **kwargs):
        return self.perturb(*args, **kwargs)


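# Illustrative sketch only: a minimal one-step gradient-sign attack showing
# how ``perturb`` is meant to be overridden. ``FGSMExample`` and its ``eps``
# parameter are hypothetical names, not part of this module.
class FGSMExample(Attack):
    def __init__(self, predict, loss_fn, clip_min=0., clip_max=1., eps=0.3):
        super(FGSMExample, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.eps = eps  # hypothetical step size for the single gradient step

    def perturb(self, x, y):
        # Work on a detached copy so the caller's tensor is never modified.
        x = x.detach().clone().requires_grad_()
        loss = self.loss_fn(self.predict(x), y)
        grad = torch.autograd.grad(loss, x)[0]
        # Ascend the loss, then project back into the valid input range.
        x_adv = x + self.eps * grad.sign()
        return torch.clamp(x_adv, self.clip_min, self.clip_max)

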
class LabelMixin(object):
    """
    Mixin providing label handling for attacks. Classes mixing this in are
    expected to define ``self.predict`` and ``self.targeted``.
    """

    def _get_predicted_label(self, x):
        """
        Compute predicted labels given x. Used to prevent label leaking
        during adversarial training.

        Arguments:
            x (torch.Tensor): the model's input tensor.
        Returns:
            torch.Tensor containing predicted labels.
        """
        with torch.no_grad():
            outputs = self.predict(x)
        _, y = torch.max(outputs, dim=1)
        return y

    def _verify_and_process_inputs(self, x, y):
        """
        Validate the (x, y) pair and return replicated copies of both.
        Targeted attacks require explicit target labels; untargeted attacks
        fall back to the model's own predictions when y is omitted.
        """
        if self.targeted:
            assert y is not None, "Targeted attacks require target labels y."
        elif y is None:
            y = self._get_predicted_label(x)

        x = replicate_input(x)
        y = replicate_input(y)
        return x, y
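

# Illustrative usage sketch: a concrete attack typically mixes in both classes
# so that ``perturb`` can call ``_verify_and_process_inputs``. ``MyAttack``
# and its ``targeted`` flag are hypothetical; only ``Attack`` and
# ``LabelMixin`` come from this module.
class MyAttack(Attack, LabelMixin):
    def __init__(self, predict, loss_fn, clip_min=0., clip_max=1.,
                 targeted=False):
        super(MyAttack, self).__init__(predict, loss_fn, clip_min, clip_max)
        self.targeted = targeted  # consumed by LabelMixin

    def perturb(self, x, y=None):
        # With targeted=False and y=None, labels are inferred from the model
        # itself, which avoids label leaking during adversarial training.
        x, y = self._verify_and_process_inputs(x, y)
        return x  # identity "attack"; a real subclass would perturb x here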