| from config import DatasetName |
| import tensorflow as tf |
|
|
|
|
class ACRLoss:
    """Piecewise regression loss that treats small and large residuals differently."""

    def acr_loss(self, x_pr, x_gt, phi, lambda_weight, ds_name):
        """Compute the two-branch loss between predictions and ground truth.

        For every element, with ``d = x_pr - x_gt``:

        * ``|d| <= 1``: ``lambda_weight * log(1 + |d| ** (2 - phi))``
        * ``|d| > 1``:  ``d ** 2 + (phi * ln(2) - 1)``

        Each branch is masked, averaged over *all* elements (not only the
        elements that fall in that branch) and scaled by 100.

        Args:
            x_pr: Predicted values (tensor).
            x_gt: Ground-truth values; must be subtractable from ``x_pr``.
            phi: Exponent/offset control. It is multiplied against a tensor
                shaped like ``x_pr``, so it has to broadcast with it.
            lambda_weight: Multiplier applied to the small-error branch.
            ds_name: Not used by this method; kept for interface
                compatibility with callers.

        Returns:
            Tuple ``(loss_total, loss_low, loss_high)`` where ``loss_total``
            is the sum of the other two scalars.
        """
        residual = x_pr - x_gt
        magnitude = tf.abs(residual)

        # 0/1 indicator masks selecting which branch each element belongs to.
        small_mask = tf.cast(magnitude <= 1.0, dtype=tf.float32)
        large_mask = tf.cast(magnitude > 1.0, dtype=tf.float32)

        # Large errors: squared residual plus the offset (phi * ln 2 - 1).
        # The offset is evaluated in float64 and then cast back to float32.
        ln_two = tf.ones_like(x_pr, dtype=tf.float32) * tf.math.log(2.0)
        offset_f64 = tf.cast(phi, dtype=tf.double) * tf.cast(ln_two, dtype=tf.double) - 1.0
        offset = tf.cast(offset_f64, dtype=tf.float32)
        large_term = tf.square(residual) + offset
        loss_high = 100 * tf.reduce_mean(tf.math.multiply(large_mask, large_term))

        # Small errors: log(1 + |d| ** (2 - phi)) weighted by lambda_weight.
        exponent = tf.cast(2.0 - phi, tf.float32)
        powered = tf.pow(magnitude, exponent)
        small_term = lambda_weight * tf.math.log(1.0 + powered)
        loss_low = 100 * tf.reduce_mean(tf.math.multiply(small_mask, small_term))

        return loss_low + loss_high, loss_low, loss_high
|
|