| import numpy as np |
| import json |
|
|
| import torch |
| from torchvision import transforms |
|
|
| import util_functions.torch_utils as torch_utils |
| import util_functions.image_utils as image_utils |
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Seed both RNGs so the random carrier drawn below is identical on every
# run; encode() and decode() both read this same module-level carrier.
torch.manual_seed(0)
np.random.seed(0)


print('Building backbone and normalization layer...')
backbone = torch_utils.build_backbone(path='models/dino_r50.pth')
normlayer = torch_utils.load_normalization_layer(path='models/out2048.pth')
model = torch_utils.NormLayerWrapper(backbone, normlayer)


print('Building the hypercone...')
FPR = 1e-6
# NOTE(review): presumably the cone half-angle that yields the FPR above
# for 2048-dimensional features -- confirm how this constant was derived.
angle = 1.462771101178447
# rho = 1 + tan^2(angle) = 1 / cos^2(angle); it is used to test whether a
# feature's squared cosine with the carrier exceeds cos^2(angle).
rho = 1 + np.tan(angle)**2
# Draw the carrier on the CPU (so the seeded values stay the same as
# before), normalise it to unit length, then move it to the compute device.
# Without the move, `ft @ carrier.T` mixes a CUDA feature with a CPU carrier
# and raises a device-mismatch error on GPU hosts.
carrier = torch.randn(1, 2048)
carrier /= torch.norm(carrier, dim=1, keepdim=True)
carrier = carrier.to(device, non_blocking=True)


# Convert an input image to a float tensor and standardise each channel
# (these are the usual ImageNet channel means / standard deviations).
default_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
def encode(image, epochs=10, psnr=44, lambda_w=1, lambda_i=1):
    """Optimise an image so that its feature moves inside the hypercone.

    The image pixels are optimised directly with Adam. At every step the
    candidate is passed through ``image_utils.ssim_attenuation`` and
    ``image_utils.psnr_clip`` before being fed to ``model``; the loss is
    ``lambda_w * loss_R + lambda_i * ||x - img_orig||^2``.

    Args:
        image: Input accepted by ``default_transform`` (e.g. a PIL image).
        epochs: Number of optimisation steps.
        psnr: Value forwarded to ``image_utils.psnr_clip``
            (presumably the minimum PSNR in dB -- confirm in image_utils).
        lambda_w: Weight of the hypercone (watermark) loss.
        lambda_i: Weight of the squared L2 distance to the original image.

    Returns:
        The optimised image as produced by ``transforms.ToPILImage()``.
    """
    img_orig = default_transform(image).to(device, non_blocking=True).unsqueeze(0)
    # clone() already lives on the same device as img_orig, so no extra
    # .to(device) is needed before turning it into the optimised leaf.
    img = img_orig.clone()
    img.requires_grad = True
    optimizer = torch.optim.Adam([img], lr=1e-2)

    for iteration in range(epochs):
        print(f'iteration: {iteration}')
        x = image_utils.ssim_attenuation(img, img_orig)
        x = image_utils.psnr_clip(x, img_orig, psnr)

        ft = model(x)

        # The carrier may live on the CPU while the feature is on the GPU;
        # align devices so the matmul below cannot fail with a device
        # mismatch. .to() returns the same tensor when nothing changes.
        carrier_dev = carrier.to(ft.device)

        dot_product = ft @ carrier_dev.T
        norm = torch.norm(ft, dim=-1, keepdim=True)
        cosines = torch.abs(dot_product / norm)
        log10_pvalue = np.log10(
            torch_utils.cosine_pvalue(cosines.item(), ft.shape[-1])
        )
        # Negative exactly when rho * dot^2 > norm^2, i.e. when the squared
        # cosine between the feature and the carrier exceeds 1 / rho.
        loss_R = -(rho * dot_product**2 - norm**2)

        loss_l2_img = torch.norm(x - img_orig)**2
        loss = lambda_w*loss_R + lambda_i*loss_l2_img

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs = {
            "keyword": "img_optim",
            "iteration": iteration,
            "loss": loss.item(),
            "loss_R": loss_R.item(),
            "loss_l2_img": loss_l2_img.item(),
            "log10_pvalue": log10_pvalue.item(),
        }
        print("__log__:%s" % json.dumps(logs))

    # Apply the same perceptual constraints one last time to the optimised
    # pixels, round them, and convert the result back to an image.
    img = image_utils.ssim_attenuation(img, img_orig)
    img = image_utils.psnr_clip(img, img_orig, psnr)
    img = image_utils.round_pixel(img)
    img = img.squeeze(0).detach().cpu()
    img = transforms.ToPILImage()(image_utils.unnormalize_img(img).squeeze(0))

    return img
|
|
def decode(image):
    """Report whether an image's feature lies inside the hypercone.

    The image is transformed with ``default_transform``, passed through
    ``model``, and the resulting feature is compared with the module-level
    ``carrier``: the image counts as marked when
    ``rho * (ft . carrier)^2 > ||ft||^2``.

    Args:
        image: Input accepted by ``default_transform`` (e.g. a PIL image).

    Returns:
        ``'Image is marked'`` or ``'Image is unmarked'``.
    """
    img = default_transform(image).to(device, non_blocking=True).unsqueeze(0)

    # Detection is pure inference: skip autograd bookkeeping entirely.
    with torch.no_grad():
        ft = model(img)
        # Align the carrier with the feature's device so the matmul cannot
        # fail with a CPU/GPU mismatch; a no-op when they already agree.
        carrier_dev = carrier.to(ft.device)
        dot_product = ft @ carrier_dev.T
        norm = torch.norm(ft, dim=-1, keepdim=True)
        # Negative exactly when the feature falls inside the cone.
        loss_R = -(rho * dot_product**2 - norm**2)

    text_marked = "marked" if loss_R.item() < 0 else "unmarked"
    return f'Image is {text_marked}'
|
|