| """A local gradio app that detects seizures with EEG using FHE.""" |
import os
import re
import secrets
import shutil
import subprocess
import time
from itertools import chain

import gradio as gr
import numpy
import numpy as np
import requests
from PIL import Image
from concrete.ml.deployment import FHEModelClient

from common import (
    CLIENT_TMP_PATH,
    SEIZURE_DETECTION_MODEL_PATH,
    SERVER_TMP_PATH,
    EXAMPLES,
    INPUT_SHAPE,
    KEYS_PATH,
    REPO_DIR,
    SERVER_URL,
)
from client_server_interface import FHEClient
| |
# Launch the inference server in the background: uvicorn serves the `app` object of the
# `server` module, using the repository directory as working directory. The fixed pause is
# presumably meant to give it time to start before the UI sends requests to SERVER_URL.
SERVER_COMMAND = ["uvicorn", "server:app"]
subprocess.Popen(SERVER_COMMAND, cwd=REPO_DIR)
time.sleep(3)
|
|
def shorten_bytes_object(bytes_object, limit=500):
    """Return a short hexadecimal excerpt of a (large) bytes object.

    Encrypted payloads are far too large to be displayed in the browser through Gradio, so
    only a window of ``limit`` bytes is shown. The first 100 bytes are always skipped.

    Args:
        bytes_object (bytes): The data to shorten.
        limit (int): Number of bytes kept in the excerpt. Defaults to 500.

    Returns:
        str: Hexadecimal string of bytes ``[100, 100 + limit)`` of the input, i.e. at most
        ``2 * limit`` characters (empty if the input has 100 bytes or fewer).
    """
    start = 100
    end = start + limit
    excerpt = bytes_object[start:end]
    return excerpt.hex()
|
|
def get_client(user_id):
    """Build the FHE client API bound to the given user's key directory.

    Args:
        user_id (int): The current user's ID.

    Returns:
        FHEClient: A client created with ``key_dir=KEYS_PATH / "seizure_detection_<user_id>"``.
    """
    user_key_dir = KEYS_PATH / f"seizure_detection_{user_id}"
    return FHEClient(key_dir=user_key_dir)
|
|
def get_client_file_path(name, user_id):
    """Return the path of a per-user temporary file on the client side.

    Args:
        name (str): The base name of the file (e.g. "encrypted_image").
        user_id (int): The current user's ID.

    Returns:
        pathlib.Path: ``CLIENT_TMP_PATH / "<name>_seizure_detection_<user_id>"``.
    """
    file_name = f"{name}_seizure_detection_{user_id}"
    return CLIENT_TMP_PATH / file_name
|
|
def clean_temporary_files(n_keys=20):
    """Delete the keys and temporary files of the oldest users.

    At most ``n_keys`` entries are kept in ``KEYS_PATH``. When there are more, the oldest ones
    (by modification time) are removed, together with every client and server temporary file
    whose name refers to the same user.

    Args:
        n_keys (int): The maximum number of key directories (and associated files) to keep.
            Defaults to 20.
    """
    # Oldest entries first.
    key_dirs = sorted(KEYS_PATH.iterdir(), key=os.path.getmtime)

    n_keys_to_delete = len(key_dirs) - n_keys
    if n_keys_to_delete <= 0:
        return

    user_ids = []
    for key_dir in key_dirs[:n_keys_to_delete]:
        # The entry name is what the temporary file names are matched against below.
        user_ids.append(key_dir.name)
        if key_dir.is_dir():
            shutil.rmtree(key_dir)
        else:
            key_dir.unlink()

    # Match an ID only when it is not part of a longer number: a plain substring test would let
    # user "12" also claim (and delete) the files of user "123".
    owner_patterns = [
        re.compile(rf"(?<!\d){re.escape(user_id)}(?!\d)") for user_id in user_ids
    ]

    # Snapshot both listings before removing anything from those directories.
    temporary_files = list(chain(CLIENT_TMP_PATH.iterdir(), SERVER_TMP_PATH.iterdir()))
    for file in temporary_files:
        # `any` guarantees each file is unlinked at most once, even if several IDs match it.
        if any(pattern.search(file.name) for pattern in owner_patterns):
            file.unlink()
|
|
def keygen():
    """Generate the private and evaluation keys for a new user session.

    Old sessions are pruned first with ``clean_temporary_files``. The keys are generated in the
    same key directory that ``get_client`` (used for encryption and decryption) points to, and
    the serialized evaluation key is written to the client temporary file that ``send_input``
    uploads to the server.

    Returns:
        (user_id, True) (Tuple[int, bool]): The new user's ID and a boolean used for visual
        display (it ticks the "Private key generated" checkbox).

    Raises:
        TypeError: If the serialized evaluation keys are not a bytes object.
    """
    clean_temporary_files()

    # The ID is the only handle to this user's keys and files on disk, so draw it from a CSPRNG.
    # (np.random.randint(0, 2**32) is predictable and overflows the int32 default on Windows.)
    user_id = secrets.randbelow(2**32)

    # Use the key directory of `get_client`, so that the ciphertexts produced in `encrypt` match
    # the evaluation key that is sent to the server.
    client = FHEModelClient(
        path_dir=SEIZURE_DETECTION_MODEL_PATH,
        key_dir=KEYS_PATH / f"seizure_detection_{user_id}",
    )
    client.load()
    client.generate_private_and_evaluation_keys()

    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    if not isinstance(serialized_evaluation_keys, bytes):
        raise TypeError(
            "Expected the serialized evaluation keys to be bytes, got "
            f"{type(serialized_evaluation_keys).__name__}."
        )

    # `send_input` uploads the evaluation key from this client temporary file.
    evaluation_key_path = get_client_file_path("evaluation_key", user_id)
    with evaluation_key_path.open("wb") as evaluation_key_file:
        evaluation_key_file.write(serialized_evaluation_keys)

    return (user_id, True)
|
|
def encrypt(user_id, input_image):
    """Pre-process and encrypt the given image for seizure detection.

    The image is resized to 32x32 pixels and its channels are averaged into a single grey level.
    The values are then mapped from [0, 255] onto the signed integer range [-2048, 2047],
    encrypted with the user's client, and stored in the client temporary directory.

    Args:
        user_id (int): The current user's ID.
        input_image (numpy.ndarray): The image to encrypt, either 2-D (grayscale) or
            (height, width, channels).

    Returns:
        str: A short hexadecimal representation of the encrypted image, for display.

    Raises:
        gr.Error: If no private key was generated or no image was provided.
    """
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    if input_image is None:
        raise gr.Error("Please choose an image first.")

    image = np.asarray(input_image)

    # PIL cannot build an image from a (H, W, 1) array, so drop a trailing singleton channel.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]

    if image.shape[:2] != (32, 32):
        image = np.array(Image.fromarray(image).resize((32, 32)))

    # Average colour channels into one grey level; grayscale (2-D) images are used as they are.
    if image.ndim == 3:
        image = np.mean(image, axis=2)

    image = image.astype(np.float32).reshape(1, 1, 32, 32)

    # Map grey levels [0, 255] onto the signed 12-bit range [-2048, 2047].
    image = (image / 255.0 * 4095 - 2048).astype(np.int16)
    image = np.clip(image, -2048, 2047)

    client = get_client(user_id)
    encrypted_image = client.encrypt_serialize(image)

    encrypted_image_path = get_client_file_path("encrypted_image", user_id)
    with encrypted_image_path.open("wb") as encrypted_image_file:
        encrypted_image_file.write(encrypted_image)

    # The ciphertext is too large to display; show the same kind of hex excerpt as `get_output`.
    return shorten_bytes_object(encrypted_image)
|
|
|
|
def send_input(user_id):
    """Send the encrypted input image as well as the evaluation key to the server.

    Args:
        user_id (int): The current user's ID.

    Returns:
        bool: ``response.ok`` of the upload request (True when the HTTP status is below 400).

    Raises:
        gr.Error: If the evaluation key or the encrypted image is missing on the client side.
    """
    evaluation_key_path = get_client_file_path("evaluation_key", user_id)

    if user_id == "" or not evaluation_key_path.is_file():
        raise gr.Error("Please generate the private key first.")

    encrypted_input_path = get_client_file_path("encrypted_image", user_id)

    if not encrypted_input_path.is_file():
        raise gr.Error("Please generate the private key and then encrypt an image first.")

    data = {
        "user_id": user_id,
    }
    url = SERVER_URL + "send_input"

    # Open both files in context managers so their handles are released once the upload is done.
    with encrypted_input_path.open("rb") as encrypted_input_file:
        with evaluation_key_path.open("rb") as evaluation_key_file:
            # Both files are sent under the same "files" form field: image first, then key.
            files = [
                ("files", encrypted_input_file),
                ("files", evaluation_key_file),
            ]
            with requests.post(url=url, data=data, files=files) as response:
                return response.ok
|
|
def run_fhe(user_id):
    """Apply the seizure detection model on the encrypted image previously sent using FHE.

    Args:
        user_id (int): The current user's ID.

    Returns:
        The JSON-decoded body of the server's response (the UI shows it in the FHE execution
        time box).

    Raises:
        gr.Error: If no private key was generated, or if the server request failed.
    """
    # Same guard as the other handlers, so a missing key is reported as such instead of as a
    # server-side failure.
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    url = SERVER_URL + "run_fhe"
    with requests.post(url=url, data={"user_id": user_id}) as response:
        if not response.ok:
            raise gr.Error("Please wait for the input image to be sent to the server.")
        return response.json()
|
|
def get_output(user_id):
    """Retrieve the encrypted output from the server and store it on the client side.

    Args:
        user_id (int): The current user's ID.

    Returns:
        str: A short hexadecimal representation of the encrypted result.

    Raises:
        gr.Error: If no private key was generated, or if the server has no output yet.
    """
    # Same guard as the other handlers, so a missing key is not misreported as "still running".
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    url = SERVER_URL + "get_output"
    with requests.post(url=url, data={"user_id": user_id}) as response:
        if not response.ok:
            raise gr.Error("Please wait for the FHE execution to be completed.")
        encrypted_output = response.content

    # Saved so that `decrypt_output` can read it back later.
    encrypted_output_path = get_client_file_path("encrypted_output", user_id)
    with encrypted_output_path.open("wb") as encrypted_output_file:
        encrypted_output_file.write(encrypted_output)

    return shorten_bytes_object(encrypted_output)
|
|
def decrypt_output(user_id):
    """Decrypt the encrypted detection result previously received from the server.

    Args:
        user_id (int): The current user's ID.

    Returns:
        str: "Seizure detected" when the decrypted value is truthy, "No seizure detected"
        otherwise.

    Raises:
        gr.Error: If no private key was generated or no encrypted output was retrieved yet.
    """
    if user_id == "":
        raise gr.Error("Please generate the private key first.")

    output_path = get_client_file_path("encrypted_output", user_id)
    if not output_path.is_file():
        raise gr.Error("Please run the FHE execution first.")

    ciphertext = output_path.read_bytes()

    # NOTE(review): the decrypted value is used through its truth value. If
    # deserialize_decrypt_post_process returns a multi-element array (e.g. per-class scores),
    # this raises ValueError; confirm the return type of FHEClient.
    prediction = get_client(user_id).deserialize_decrypt_post_process(ciphertext)

    if prediction:
        return "Seizure detected"
    return "No seizure detected"
|
|
def resize_img(img, width=256, height=256):
    """Resize an image array to ``width`` x ``height`` pixels.

    Arrays that are not uint8 are cast to uint8 before being handed to PIL.

    Args:
        img (numpy.ndarray): The image to resize.
        width (int): Target width in pixels. Defaults to 256.
        height (int): Target height in pixels. Defaults to 256.

    Returns:
        numpy.ndarray: The resized image.
    """
    as_uint8 = img if img.dtype == numpy.uint8 else img.astype(numpy.uint8)
    resized = Image.fromarray(as_uint8).resize((width, height))
    return numpy.array(resized)
|
|
# Build the Gradio UI: client-side steps (upload, keygen, encrypt), server-side steps (send, run,
# receive) and final client-side decryption, then launch it.
demo = gr.Blocks()

print("Starting the demo...")
with demo:
    gr.Markdown(
        """
<h1 align="center">Seizure Detection on Encrypted EEG Data Using Fully Homomorphic Encryption</h1>
"""
    )

    gr.Markdown("## Client side")
    gr.Markdown("### Step 1: Upload an EEG image. ")
    gr.Markdown(
        "The image will automatically be resized to shape (32, 32). "
        "The image here, however, is displayed in its original resolution."
    )
    with gr.Row():
        input_image = gr.Image(
            value=None,
            label="Upload an EEG image here.",
            height=256,
            width=256,
            sources="upload",
            interactive=True,
        )

        examples = gr.Examples(
            examples=EXAMPLES, inputs=[input_image], examples_per_page=5, label="Examples to use."
        )

    gr.Markdown("### Step 2: Generate the private key.")
    keygen_button = gr.Button("Generate the private key.")

    with gr.Row():
        keygen_checkbox = gr.Checkbox(label="Private key generated:", interactive=False)

    # Hidden holder for the session's user ID, passed as input to every handler.
    user_id = gr.Textbox(label="", max_lines=2, interactive=False, visible=False)

    gr.Markdown("### Step 3: Encrypt the image using FHE.")
    encrypt_button = gr.Button("Encrypt the image using FHE.")

    with gr.Row():
        encrypted_input = gr.Textbox(
            label="Encrypted input representation:", max_lines=2, interactive=False
        )

    gr.Markdown("## Server side")
    gr.Markdown(
        "The encrypted value is received by the server. The server can then compute the seizure "
        "detection directly over encrypted values. Once the computation is finished, the server returns "
        "the encrypted results to the client."
    )
    gr.Markdown("### Step 4: Send the encrypted image to the server.")
    send_input_button = gr.Button("Send the encrypted image to the server.")
    send_input_checkbox = gr.Checkbox(label="Encrypted image sent.", interactive=False)

    gr.Markdown("### Step 5: Run FHE execution.")
    execute_fhe_button = gr.Button("Run FHE execution.")
    fhe_status = gr.Textbox(label="FHE execution status:", max_lines=1, interactive=False)
    fhe_execution_time = gr.Textbox(
        label="Total FHE execution time (in seconds):", max_lines=1, interactive=False
    )
    task_id = gr.Textbox(label="Task ID:", visible=False)

    gr.Markdown("### Step 6: Check FHE execution status and receive the encrypted output from the server.")
    # TODO(review): no event handler is attached to this button, and `fhe_status` / `task_id`
    # are never written.
    check_status_button = gr.Button("Check FHE execution status")
    # Left interactive: no event ever re-enables a disabled button, and `get_output` already
    # reports a friendly error when the output is not ready yet.
    get_output_button = gr.Button("Receive the encrypted output from the server.")

    with gr.Row():
        encrypted_output = gr.Textbox(
            label="Encrypted output representation:",
            max_lines=2,
            interactive=False,
        )

    gr.Markdown("## Client side")
    gr.Markdown(
        "The encrypted output is sent back to the client, who can finally decrypt it with the "
        "private key. Only the client is aware of the original image and the detection result."
    )

    gr.Markdown("### Step 7: Decrypt the output.")
    decrypt_button = gr.Button("Decrypt the output")

    with gr.Row():
        decrypted_output = gr.Textbox(
            label="Seizure detection result:",
            interactive=False,
        )

    # Wire each step's button to its handler.
    keygen_button.click(
        keygen,
        outputs=[user_id, keygen_checkbox],
    )

    encrypt_button.click(
        encrypt,
        inputs=[user_id, input_image],
        outputs=[encrypted_input],
    )

    send_input_button.click(
        send_input,
        inputs=[user_id],
        outputs=[send_input_checkbox],
    )

    execute_fhe_button.click(run_fhe, inputs=[user_id], outputs=[fhe_execution_time])

    get_output_button.click(
        get_output,
        inputs=[user_id],
        outputs=[encrypted_output],
    )

    decrypt_button.click(
        decrypt_output,
        inputs=[user_id],
        outputs=[decrypted_output],
    )

    gr.Markdown(
        "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a "
        "Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). "
        "Try it yourself and don't forget to star on Github ⭐."
    )

demo.launch(share=False)
|
|