import tensorflow as tf
import numpy as np
import torch


class ModelAdapter:
    def __init__(self, model, num_classes=10):
        """
        Wrap a tf.keras model so that it can be queried with PyTorch tensors.

        Note that `model` must be a tf.keras model whose output layer has no
        'softmax' activation, i.e. the model must return logits.

        Args:
            model: A tf.keras model returning pre-softmax outputs
            num_classes: (int) Number of output classes
        """
        self.num_classes = num_classes
        self.tf_model = model
        self.data_format = self.__check_channel_ordering()

    def __tf_to_pt(self, tf_tensor):
        """ Private function
        Convert a TF tensor to a PyTorch tensor.

        Args:
            tf_tensor: (tf_tensor) TF tensor

        Returns:
            pt_tensor: (pt_tensor) PyTorch tensor
        """

        cpu_tensor = tf_tensor.numpy()
        pt_tensor = torch.from_numpy(cpu_tensor)
        # Move to GPU only when one is present, so CPU-only hosts still work.
        if torch.cuda.is_available():
            pt_tensor = pt_tensor.cuda()

        return pt_tensor

    def set_data_format(self, data_format):
        """
        Set data_format manually.

        Args:
            data_format: A string, either 'channels_last' or 'channels_first'
        """

        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError("data_format should be either 'channels_last' or 'channels_first'")

        self.data_format = data_format

    def __check_channel_ordering(self):
        """ Private function
        Determine the TF model's channel ordering from the model itself.
        TF defaults to 'channels_last', whereas PyTorch uses 'channels_first'.

        Returns:
            data_format: A string, either 'channels_last' or 'channels_first'
        """

        data_format = None

        # Prefer the data_format recorded on the first Conv2D layer, if any.
        for layer in self.tf_model.layers:
            if isinstance(layer, tf.keras.layers.Conv2D):
                print("[INFO] set data_format = '{:s}'".format(layer.data_format))
                data_format = layer.data_format
                break

        # Otherwise fall back to a heuristic on the input shape: an axis of
        # size 1 or 3 is assumed to be the image-channel axis.
        if data_format is None:
            print("[WARNING] Cannot find a Conv2D layer")
            input_shape = self.tf_model.input_shape

            if input_shape[3] in (1, 3):
                print("[INFO] input_shape[3] == {:d}, set data_format = 'channels_last'".format(input_shape[3]))
                data_format = 'channels_last'
            elif input_shape[1] in (1, 3):
                print("[INFO] input_shape[1] == {:d}, set data_format = 'channels_first'".format(input_shape[1]))
                data_format = 'channels_first'
            else:
                print("[ERROR] Unknown case")

        return data_format

    def __get_logits(self, x_input):
        """ Private function
        Get the model's pre-softmax output in inference mode.

        Args:
            x_input: (tf_tensor) Input data

        Returns:
            logits: (tf_tensor) Logits
        """

        return self.tf_model(x_input, training=False)

    def __get_xent(self, logits, y_input):
        """ Private function
        Get the cross-entropy loss.

        Args:
            logits: (tf_tensor) Logits
            y_input: (tf_tensor) Label

        Returns:
            xent: (tf_tensor) Cross entropy
        """

        return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_input)

    def __get_dlr(self, logit, y_input):
        """ Private function
        Get the DLR loss.

        Args:
            logit: (tf_tensor) Logits
            y_input: (tf_tensor) Input label

        Returns:
            loss: (tf_tensor) DLR loss
        """

        # Sort logits in ascending order, so [:, -1] is the largest entry.
        logit_sort = tf.sort(logit, axis=1)

        # Logit of the true class.
        y_onehot = tf.one_hot(y_input, self.num_classes, dtype=tf.float32)
        logit_y = tf.reduce_sum(y_onehot * logit, axis=1)

        # z_i: the largest logit excluding the true class.
        logit_pred = tf.reduce_max(logit, axis=1)
        cond = (logit_pred == logit_y)
        z_i = tf.where(cond, logit_sort[:, -2], logit_sort[:, -1])

        z_y = logit_y
        z_p1 = logit_sort[:, -1]
        z_p3 = logit_sort[:, -3]

        loss = -(z_y - z_i) / (z_p1 - z_p3 + 1e-12)
        return loss
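
    # For reference: the method above implements the (untargeted) DLR loss of
    # Croce & Hein (2020),
    #     DLR(x, y) = -(z_y - max_{i != y} z_i) / (z_pi1 - z_pi3 + eps),
    # where z_pi1 >= z_pi2 >= z_pi3 are the sorted logits and eps = 1e-12
    # guards against division by zero when the top logits are tied.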

    def __get_dlr_target(self, logits, y_input, y_target):
        """ Private function
        Get the targeted version of the DLR loss.

        Args:
            logits: (tf_tensor) Logits
            y_input: (tf_tensor) Input label
            y_target: (tf_tensor) Input targeted label

        Returns:
            loss: (tf_tensor) Targeted DLR loss
        """

        x = logits
        x_sort = tf.sort(x, axis=1)
        y_onehot = tf.one_hot(y_input, self.num_classes)
        y_target_onehot = tf.one_hot(y_target, self.num_classes)
        loss = -(tf.reduce_sum(x * y_onehot, axis=1) - tf.reduce_sum(x * y_target_onehot, axis=1)) \
            / (x_sort[:, -1] - .5 * x_sort[:, -3] - .5 * x_sort[:, -4] + 1e-12)

        return loss
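
    # For reference: the targeted variant above follows Croce & Hein (2020),
    #     DLR_t(x, y, t) = -(z_y - z_t) / (z_pi1 - (z_pi3 + z_pi4) / 2 + eps),
    # which decreases as the target class t overtakes the true class y.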

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_jacobian(self, x_input):
        """ Private function
        Get the Jacobian of the logits with respect to the input.

        Args:
            x_input: (tf_tensor) Input data

        Returns:
            logits: (tf_tensor) Logits
            jacobian: (tf_tensor) Jacobian
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)

        jacobian = g.batch_jacobian(logits, x_input)

        return logits, jacobian
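
    # Note: for NHWC inputs of shape (N, H, W, C) and logits of shape (N, K),
    # batch_jacobian returns shape (N, K, H, W, C); grad_logits() below
    # transposes this back to the PyTorch NCHW layout (N, K, C, H, W).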

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_xent(self, x_input, y_input):
        """ Private function
        Get the gradient of the cross-entropy loss.

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label

        Returns:
            logits: (tf_tensor) Logits
            xent: (tf_tensor) Cross entropy
            grad_xent: (tf_tensor) Gradient of cross entropy
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            xent = self.__get_xent(logits, y_input)

        grad_xent = g.gradient(xent, x_input)

        return logits, xent, grad_xent

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_diff_logits_target(self, x, la, la_target):
        """ Private function
        Get the difference of logits and its corresponding gradient.

        Args:
            x: (tf_tensor) Input data
            la: (tf_tensor) Input label
            la_target: (tf_tensor) Input targeted label

        Returns:
            difflogits: (tf_tensor) Difference of logits
            grad_diff: (tf_tensor) Gradient of the difference of logits
        """

        la_mask = tf.one_hot(la, self.num_classes)
        la_target_mask = tf.one_hot(la_target, self.num_classes)

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x)
            logits = self.__get_logits(x)
            difflogits = tf.reduce_sum((la_target_mask - la_mask) * logits, axis=1)

        grad_diff = g.gradient(difflogits, x)

        return difflogits, grad_diff

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr(self, x_input, y_input):
        """ Private function
        Get the gradient of the DLR loss.

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label

        Returns:
            logits: (tf_tensor) Logits
            val_dlr: (tf_tensor) DLR loss
            grad_dlr: (tf_tensor) Gradient of DLR loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            val_dlr = self.__get_dlr(logits, y_input)

        grad_dlr = g.gradient(val_dlr, x_input)

        return logits, val_dlr, grad_dlr

    @tf.function
    @tf.autograph.experimental.do_not_convert
    def __get_grad_dlr_target(self, x_input, y_input, y_target):
        """ Private function
        Get the gradient of the targeted DLR loss.

        Args:
            x_input: (tf_tensor) Input data
            y_input: (tf_tensor) Input label
            y_target: (tf_tensor) Input targeted label

        Returns:
            logits: (tf_tensor) Logits
            dlr_target: (tf_tensor) Targeted DLR loss
            grad_target: (tf_tensor) Gradient of targeted DLR loss
        """

        with tf.GradientTape(watch_accessed_variables=False) as g:
            g.watch(x_input)
            logits = self.__get_logits(x_input)
            dlr_target = self.__get_dlr_target(logits, y_input, y_target)

        grad_target = g.gradient(dlr_target, x_input)

        return logits, dlr_target, grad_target

    def predict(self, x):
        """
        Get the model's pre-softmax output in inference mode.

        Args:
            x: (pytorch_tensor) Input data

        Returns:
            y: (pytorch_tensor) Pre-softmax output
        """

        # Convert the PyTorch NCHW input to a TF tensor, transposing to NHWC
        # if the TF model expects channels last.
        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        y = self.__get_logits(x2)

        y = self.__tf_to_pt(y)

        return y

    def grad_logits(self, x):
        """
        Get logits and the gradient (Jacobian) of the logits.

        Args:
            x: (pytorch_tensor) Input data

        Returns:
            logits: (pytorch_tensor) Logits
            g2: (pytorch_tensor) Jacobian
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits, g2 = self.__get_jacobian(x2)

        # Transpose the Jacobian from (N, K, H, W, C) back to (N, K, C, H, W).
        if self.data_format == 'channels_last':
            g2 = tf.transpose(g2, perm=[0, 1, 4, 2, 3])
        logits = self.__tf_to_pt(logits)
        g2 = self.__tf_to_pt(g2)

        return logits, g2

    def get_logits_loss_grad_xent(self, x, y):
        """
        Get the gradient of the cross-entropy loss.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Cross entropy
            grad_val: (pytorch_tensor) Gradient of cross entropy
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits_val, loss_indiv_val, grad_val = self.__get_grad_xent(x2, y2)

        if self.data_format == 'channels_last':
            grad_val = tf.transpose(grad_val, perm=[0, 3, 1, 2])
        logits_val = self.__tf_to_pt(logits_val)
        loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
        grad_val = self.__tf_to_pt(grad_val)

        return logits_val, loss_indiv_val, grad_val

    def set_target_class(self, y, y_target):
        """
        No-op placeholder: target labels are passed directly to the
        targeted-loss methods of this adapter instead.
        """
        pass

    def get_grad_diff_logits_target(self, x, y, y_target):
        """
        Get the difference of logits and its corresponding gradient.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label
            y_target: (pytorch_tensor) Input targeted label

        Returns:
            difflogits: (pytorch_tensor) Difference of logits
            g2: (pytorch_tensor) Gradient of the difference of logits
        """

        la = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        la_target = tf.convert_to_tensor(y_target.cpu().numpy(), dtype=tf.int32)
        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        difflogits, g2 = self.__get_grad_diff_logits_target(x2, la, la_target)

        if self.data_format == 'channels_last':
            g2 = tf.transpose(g2, perm=[0, 3, 1, 2])
        difflogits = self.__tf_to_pt(difflogits)
        g2 = self.__tf_to_pt(g2)

        return difflogits, g2

    def get_logits_loss_grad_dlr(self, x, y):
        """
        Get the gradient of the DLR loss.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) DLR loss
            grad_val: (pytorch_tensor) Gradient of DLR loss
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr(x2, y2)

        if self.data_format == 'channels_last':
            grad_val = tf.transpose(grad_val, perm=[0, 3, 1, 2])
        logits_val = self.__tf_to_pt(logits_val)
        loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
        grad_val = self.__tf_to_pt(grad_val)

        return logits_val, loss_indiv_val, grad_val

    def get_logits_loss_grad_target(self, x, y, y_target):
        """
        Get the gradient of the targeted DLR loss.

        Args:
            x: (pytorch_tensor) Input data
            y: (pytorch_tensor) Input label
            y_target: (pytorch_tensor) Input targeted label

        Returns:
            logits_val: (pytorch_tensor) Logits
            loss_indiv_val: (pytorch_tensor) Targeted DLR loss
            grad_val: (pytorch_tensor) Gradient of targeted DLR loss
        """

        x2 = tf.convert_to_tensor(x.cpu().numpy(), dtype=tf.float32)
        y2 = tf.convert_to_tensor(y.cpu().numpy(), dtype=tf.int32)
        y_targ = tf.convert_to_tensor(y_target.cpu().numpy(), dtype=tf.int32)
        if self.data_format == 'channels_last':
            x2 = tf.transpose(x2, perm=[0, 2, 3, 1])

        logits_val, loss_indiv_val, grad_val = self.__get_grad_dlr_target(x2, y2, y_targ)

        if self.data_format == 'channels_last':
            grad_val = tf.transpose(grad_val, perm=[0, 3, 1, 2])
        logits_val = self.__tf_to_pt(logits_val)
        loss_indiv_val = self.__tf_to_pt(loss_indiv_val)
        grad_val = self.__tf_to_pt(grad_val)

        return logits_val, loss_indiv_val, grad_val
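

# A minimal usage sketch (hypothetical, for illustration only): `tf_model`
# below is a stand-in CIFAR-10-sized tf.keras classifier, not part of this
# module; any logits-producing tf.keras model works the same way.
if __name__ == "__main__":
    tf_model = tf.keras.Sequential([
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(16, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),  # logits: no softmax activation
    ])

    adapter = ModelAdapter(tf_model, num_classes=10)

    # PyTorch tensors use NCHW; the adapter transposes to NHWC internally.
    x = torch.rand(4, 3, 32, 32)
    y = torch.randint(0, 10, (4,))

    logits = adapter.predict(x)  # shape (4, 10)
    logits_val, xent, grad = adapter.get_logits_loss_grad_xent(x, y)
    print(logits.shape, xent.shape, grad.shape)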