| import random |
| import numpy as np |
| from torch import nn |
| import torch |
|
|
| from concrete.fhe.compilation.compiler import Compiler |
| from concrete.ml.common.utils import generate_proxy_function |
| from concrete.ml.torch.numpy_module import NumpyModule |
|
|
| from common import AVAILABLE_MATCHERS |
|
|
|
|
class TorchRandomGuessing(nn.Module):
    """Torch module that ignores its input's content and guesses a label.

    ``forward`` draws one label from ``classes_`` with ``random.choice`` and
    returns it as a one-element tensor.
    """

    def __init__(self, classes_=None):
        """Store the candidate labels.

        Args:
            classes_ (list, optional): Labels ``forward`` may draw from.
                When omitted, a fresh ``[0, 1]`` list is created for this
                instance.
        """
        super().__init__()
        # A literal list default would be one object shared by every
        # instance, so mutating it on one model would leak into the others.
        self.classes_ = [0, 1] if classes_ is None else classes_

    def forward(self, x):
        """Random guessing forward pass.

        Args:
            x (torch.Tensor): concat of query and reference.

        Returns:
            (torch.Tensor): one-element tensor holding the label drawn from
            ``classes_`` (dtype follows torch's promotion with ``x.sum()``).
        """
        x = x.sum()
        # `+ x - x` cancels out for finite inputs but keeps `x` in the
        # expression. NOTE(review): presumably so the traced graph still
        # depends on the input -- confirm against the tracing/compile step.
        return torch.tensor([random.choice(self.classes_)]) + x - x
|
|
|
|
class Matcher:
    """Pair a named matcher with its torch model and compiled FHE circuit.

    Attributes set by ``__init__``:
        fhe_circuit: ``None`` until ``compile`` stores its result here.
        matcher_name: the name the instance was created with.
        torch_model: the torch module used by ``compile``, or ``None`` when
            no model is wired up for ``matcher_name``.
    """

    def __init__(self, matcher_name):
        """Validate the matcher name and select the matching torch model.

        Args:
            matcher_name (str): must be a member of ``AVAILABLE_MATCHERS``.

        Raises:
            AssertionError: if ``matcher_name`` is not in
                ``AVAILABLE_MATCHERS``.
        """
        # The message must be a plain string: a trailing comma inside the
        # parentheses would turn it into a 1-tuple and garble the error.
        assert matcher_name in AVAILABLE_MATCHERS, (
            f"Unsupported image matcher. Expected one of {*AVAILABLE_MATCHERS,}, "
            f"but got {matcher_name}"
        )
        self.fhe_circuit = None
        self.matcher_name = matcher_name
        # Always define the attribute so compile() can report a missing model
        # explicitly instead of failing with an AttributeError.
        self.torch_model = None

        if self.matcher_name == "random guessing":
            self.torch_model = TorchRandomGuessing()

    def compile(self):
        """Compile the torch model into an FHE circuit.

        The torch model is wrapped in a ``NumpyModule``, its ``numpy_forward``
        is turned into a proxy function taking a single ``inputs`` argument,
        and that proxy is compiled with ``inputs`` marked as encrypted.

        Returns:
            The object returned by ``Compiler.compile``; it is also stored in
            ``self.fhe_circuit``.

        Raises:
            NotImplementedError: if no torch model is wired up for
                ``self.matcher_name``.
        """
        if self.torch_model is None:
            raise NotImplementedError(
                f"No torch model is available for matcher {self.matcher_name!r}"
            )

        # Samples handed to the compiler; the first one doubles as the dummy
        # input used to build the NumpyModule.
        inputset = (np.array([10]), np.array([5]))

        print("torch module > numpy module ...")
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        print("get proxy function ...")
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward, ["inputs"]
        )

        print("Compile the filter and retrieve its FHE circuit ...")
        compiler = Compiler(
            numpy_filter_proxy,
            {
                parameters_mapping["inputs"]: "encrypted",
            },
        )
        self.fhe_circuit = compiler.compile(inputset)
        return self.fhe_circuit

    def post_processing(self, output_result) -> str:
        """Apply post-processing to the decrypted output result.

        Args:
            output_result (np.ndarray): The decrypted result to post-process;
                only its first element is inspected.

        Returns:
            str: ``"PASS"`` if the first element equals 1, else ``"FAIL"``.
        """
        print(f"{output_result=}")

        return "PASS" if output_result[0] == 1 else "FAIL"
|
|
|
|
| |
| |
|
|