| import gradio as gr |
| import h5py |
| import mrcfile |
| import numpy as np |
| from PIL import Image |
| from omegaconf import DictConfig |
| import torch |
| from pathlib import Path |
| from torchvision.transforms import functional as F |
| import torchvision.transforms.v2 as v2 |
| import spaces |
|
|
|
|
| from draco.configuration import CfgNode |
| from draco.model import ( |
| build_model, |
| load_pretrained |
| ) |
|
|
# Bundled demo micrographs: the label shown in the example selector mapped to
# the path of the corresponding HDF5 file under ``example/``.
example_files = {
    f"EMPIAR-{entry}": f"example/empiar-{entry}-00-{frame}-full_patch_aligned.h5"
    for entry, frame in (
        ("10078", "000093"),
        ("10154", "000130"),
        ("10185", "000032"),
        ("10200", "000139"),
        ("10216", "000036"),
    )
}
|
|
class DRACODenoiser:
    """Wrap a pretrained DRACO model for denoising a single image.

    The model is built from ``cfg``, moved to CUDA when available (CPU
    otherwise), put in eval mode and then passed through ``load_pretrained``
    together with ``ckpt_path``.
    """

    def __init__(self,
                 cfg: DictConfig,
                 ckpt_path: Path,
                 ) -> None:
        """Build the model and load its weights.

        Args:
            cfg: Configuration handed to ``build_model``; ``cfg.MODEL.PATCH_SIZE``
                is read as the square patch edge length.
            ckpt_path: Checkpoint location handed to ``load_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.transform = self.build_transform()
        self.model = build_model(cfg).to(self.device).eval()
        self.model = load_pretrained(self.model, ckpt_path, self.device)
        self.patch_size = cfg.MODEL.PATCH_SIZE

    def patchify(self, image: torch.Tensor) -> torch.Tensor:
        """Split a ``(B, C, H, W)`` tensor into flattened non-overlapping patches.

        The bottom and right edges are zero-padded so that both spatial sizes
        become multiples of the patch size.

        Returns:
            Tensor of shape ``(B, num_patches, P * P * C)``.
        """
        B, C, H, W = image.shape
        P = self.patch_size
        if H % P != 0 or W % P != 0:
            # Pad order for the last two dims is (left, right, top, bottom).
            image = torch.nn.functional.pad(
                image,
                (0, (P - W % P) % P, 0, (P - H % P) % P),
                mode='constant',
                value=0,
            )

        patches = image.unfold(2, P, P).unfold(3, P, P)  # (B, C, h, w, P, P)
        patches = patches.permute(0, 2, 3, 4, 5, 1)      # (B, h, w, P, P, C)
        patches = patches.reshape(B, -1, P * P * C)
        return patches

    def unpatchify(self, patches: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Reassemble flattened patches into images and crop them to ``H`` x ``W``.

        This is the inverse of :meth:`patchify`: the patch grid is assumed to
        cover ``ceil(H / P)`` x ``ceil(W / P)`` patches, and whatever padding
        that implies is cropped away.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        B = patches.shape[0]
        P = self.patch_size
        rows = (H + P - 1) // P
        cols = (W + P - 1) // P

        images = patches.reshape(B, rows, cols, P, P, -1)  # (B, h, w, P, P, C)
        images = images.permute(0, 5, 1, 3, 2, 4)          # (B, C, h, P, w, P)
        images = images.reshape(B, -1, rows * P, cols * P)
        images = images[..., :H, :W]
        return images

    @classmethod
    def build_transform(cls) -> v2.Compose:
        """Return the preprocessing pipeline: convert to a float32 image tensor."""
        return v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True)
        ])

    # NOTE(review): `spaces.GPU` presumably requests a GPU for the duration of
    # the call on Hugging Face ZeroGPU hardware - confirm against the package.
    @spaces.GPU
    def inference(self, image: Image.Image) -> np.ndarray:
        """Denoise one image.

        Args:
            image: Input PIL image.

        Returns:
            The denoised result as a NumPy array laid out as ``(H, W, C)``.
            This assumes the model returns flattened patches in the layout
            produced by :meth:`patchify` - TODO confirm.
        """
        W, H = image.size

        x = self.transform(image).unsqueeze(0).to(self.device)
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            y = self.model(x)

        denoised = self.unpatchify(y, H, W).squeeze(0).permute(1, 2, 0)
        return denoised.detach().cpu().numpy()
|
|
| |
# Module-level singletons: the configuration and the one denoiser instance that
# serves every request. The checkpoint and YAML paths are relative to the
# current working directory.
ckpt_path = Path("denoise.ckpt")
cfg = CfgNode.load_yaml_with_base(Path("draco.yaml"))
# No command-line overrides are applied; the dotlist is intentionally empty.
CfgNode.merge_with_dotlist(cfg, [])
denoiser = DRACODenoiser(cfg=cfg, ckpt_path=ckpt_path)
|
|
def Auto_contrast(image, t_mean=150.0/255.0, t_sd=40.0/255.0) -> np.ndarray:
    """Stretch the contrast of ``image`` towards a target mean and spread.

    The image is first min-max scaled to ``[0, 1]``. Black and white points are
    then chosen around its mean so that the mean maps to ``t_mean`` and one
    standard deviation maps to ``t_sd``; values outside those points are
    clipped.

    Args:
        image: NumPy array of pixel values.
        t_mean: Target mean of the output, on a ``[0, 1]`` scale.
        t_sd: Target standard deviation of the output, on a ``[0, 1]`` scale.

    Returns:
        Array with the same shape as ``image`` and values in ``[0, 1]``.
        A constant image carries no contrast to stretch, so it is returned as a
        flat float32 image filled with ``t_mean`` instead of NaNs from ``0 / 0``.
    """
    lowest = image.min()
    span = image.max() - lowest
    if span == 0:
        return np.full(np.shape(image), t_mean, dtype=np.float32)

    image = (image - lowest) / span
    mean = image.mean()
    std = image.std()

    # Width of the [black, white] window, expressed in normalised intensity.
    f = std / t_sd

    black = mean - t_mean * f
    white = mean + (1 - t_mean) * f

    new_image = np.clip(image, black, white)
    new_image = (new_image - black) / (white - black)
    return new_image
|
|
|
|
def load_data(file_path) -> np.ndarray:
    """Load a micrograph from an ``.h5`` or ``.mrc`` file and standardise it.

    For HDF5 files the dataset named ``"micrograph"`` is used, falling back to
    ``"data"``. The ``"mean"`` and ``"std"`` attributes of that dataset are used
    when present; otherwise the statistics are computed from the pixel data.
    For MRC files the statistics are always computed from the data.

    Args:
        file_path: Path of the file to read. The extension is matched
            case-insensitively.

    Returns:
        float32 array equal to ``(data - mean) / std``. When ``std`` is zero the
        data is only mean-centred, so no NaN or inf values are produced.

    Raises:
        ValueError: If the file extension is neither ``.h5`` nor ``.mrc``.
    """
    lowered = str(file_path).lower()
    if lowered.endswith('.h5'):
        with h5py.File(file_path, "r") as f:
            full_micrograph = f["micrograph"] if "micrograph" in f else f["data"]
            # Read the dataset from disk once and reuse it for the statistics.
            data = full_micrograph[:].astype(np.float32)
            attrs = full_micrograph.attrs
            full_mean = attrs["mean"] if "mean" in attrs else data.mean()
            full_std = attrs["std"] if "std" in attrs else data.std()
    elif lowered.endswith('.mrc'):
        with mrcfile.open(file_path, "r") as f:
            data = f.data[:].astype(np.float32)
            full_mean = data.mean()
            full_std = data.std()
    else:
        raise ValueError("Unsupported file format. Please upload a .mrc or .h5 file.")

    data = data - full_mean
    if full_std != 0:
        data = data / full_std
    # Stored statistics may be float64; keep the result float32 for downstream use.
    return data.astype(np.float32, copy=False)
|
|
def display_crop(data, x_offset, y_offset, auto_contrast) -> "Image.Image | None":
    """Render a crop of up to 1024 x 1024 pixels of ``data`` as an 8-bit image.

    Args:
        data: 2-D micrograph array, or ``None`` when nothing is loaded.
        x_offset: Column of the crop's left edge.
        y_offset: Row of the crop's top edge.
        auto_contrast: When truthy, use ``Auto_contrast`` for the intensity
            mapping; otherwise min-max scale the crop.

    Returns:
        The crop as a PIL image, or ``None`` when ``data`` is ``None``.
    """
    if data is None:
        return None

    # Offsets may arrive as floats; slicing needs non-negative integers.
    x0 = max(int(x_offset), 0)
    y0 = max(int(y_offset), 0)
    crop = data[y0:y0 + 1024, x0:x0 + 1024]

    if auto_contrast:
        original_image_normalized = Auto_contrast(crop)
    else:
        lowest = crop.min()
        span = crop.max() - lowest
        if span > 0:
            original_image_normalized = (crop - lowest) / span
        else:
            # A flat crop has no range to stretch; show it as black, not 0 / 0.
            original_image_normalized = np.zeros(crop.shape, dtype=np.float32)

    input_image = Image.fromarray((original_image_normalized * 255).astype(np.uint8))
    return input_image
|
|
@spaces.GPU
def process_and_denoise(data, x_offset, y_offset, auto_contrast) -> "Image.Image | None":
    """Denoise a crop of up to 1024 x 1024 pixels and render it as an 8-bit image.

    Args:
        data: 2-D micrograph array, or ``None`` when nothing is loaded.
        x_offset: Column of the crop's left edge.
        y_offset: Row of the crop's top edge.
        auto_contrast: When truthy, use ``Auto_contrast`` for the intensity
            mapping of the result; otherwise min-max scale it.

    Returns:
        The denoised crop as a PIL image, or ``None`` when ``data`` is ``None``.
    """
    if data is None:
        return None

    # Offsets may arrive as floats; slicing needs non-negative integers.
    x0 = max(int(x_offset), 0)
    y0 = max(int(y_offset), 0)
    # PIL cannot build an image from float64 data, so force float32.
    crop = np.asarray(data[y0:y0 + 1024, x0:x0 + 1024], dtype=np.float32)
    denoised_data = denoiser.inference(Image.fromarray(crop))

    denoised_data = denoised_data.squeeze()
    if auto_contrast:
        denoised_image_normalized = Auto_contrast(denoised_data)
    else:
        lowest = denoised_data.min()
        span = denoised_data.max() - lowest
        if span > 0:
            denoised_image_normalized = (denoised_data - lowest) / span
        else:
            # A flat result has no range to stretch; show it as black, not 0 / 0.
            denoised_image_normalized = np.zeros(denoised_data.shape, dtype=np.float32)

    denoised_image = Image.fromarray((denoised_image_normalized * 255).astype(np.uint8))
    return denoised_image
|
|
def clear_images() -> tuple:
    """Return reset values for both images, the cached data and both sliders.

    The first three entries are ``None``; the last two are slider updates that
    set the value to 0 and the maximum to 1024.
    """
    slider_resets = tuple(gr.update(value=0, maximum=1024) for _ in range(2))
    return (None, None, None) + slider_resets
|
|
with gr.Blocks(css="""
.custom-size {
width: 735px;
height: 127px;
}
""") as demo:

    gr.Markdown(
        '''
<div style="text-align: center;">
<h1>Draco Denoising Demo 🙉</h1>
<p style="font-size:16px;">Upload a raw micrograph or select a example to visualize the original and denoised results</p>
<p style="font-size:16px;">Our denoising model supports a bin-1 micrograph (ends with .mrc or .h5). To achieve the optimal performance, the input should be <strong>motion corrected</strong> before passing to model.</p>
</div>
'''
    )

    with gr.Row():
        with gr.Column():
            example_selector = gr.Radio(label="Choose an example Raw Micrograph File", choices=list(example_files.keys()))
            file_input = gr.File(label="Or upload a Micrograph File in .h5 or .mrc format")

        with gr.Column():
            auto_contrast = gr.Checkbox(label="Enable Auto Contrast", value=False, elem_classes=["custom-size"])
            x_slider = gr.Slider(0, 1024, step=10, label="X Offset", elem_classes=["custom-size"])
            y_slider = gr.Slider(0, 1024, step=10, label="Y Offset", elem_classes=["custom-size"])

    with gr.Row():
        denoise_button = gr.Button("Denoise")
        clear_button = gr.Button("Clear")

    with gr.Row():
        with gr.Column():
            original_image = gr.Image(type="pil", label="Original Image")
        with gr.Column():
            denoised_image = gr.Image(type="pil", label="Denoised Image")

    # Holds the currently loaded, standardised micrograph between events.
    active_data = gr.State()

    def load_image_and_update_sliders(file_path, use_auto_contrast=False) -> tuple:
        """Load a micrograph and return updates for the state, images and sliders.

        ``use_auto_contrast`` is the checkbox *value*. The checkbox component
        itself must not be passed on: it is always truthy.
        """
        data = load_data(file_path)
        h, w = data.shape[:2]
        preview = display_crop(data, 0, 0, use_auto_contrast)
        # Images smaller than the crop leave no room to pan; never go below 0.
        return (
            data,
            preview,
            None,
            gr.update(value=0, maximum=max(w - 1024, 0)),
            gr.update(value=0, maximum=max(h - 1024, 0)),
        )

    def on_example_selected(choice, use_auto_contrast) -> tuple:
        """Load the bundled example named by ``choice``, if one is selected."""
        if choice is None:
            return None, None, None, gr.update(maximum=1024), gr.update(maximum=1024)
        return load_image_and_update_sliders(example_files[choice], use_auto_contrast)

    def on_file_changed(file, use_auto_contrast) -> tuple:
        """Load the uploaded file, or reset the outputs when it was removed."""
        if not file:
            return None, None, None, gr.update(maximum=1024), gr.update(maximum=1024)
        # The upload may be a plain path or an object exposing `.name`.
        return load_image_and_update_sliders(getattr(file, "name", file), use_auto_contrast)

    example_selector.change(
        on_example_selected,
        inputs=[example_selector, auto_contrast],
        outputs=[active_data, original_image, denoised_image, x_slider, y_slider]
    )

    file_input.clear(
        clear_images,
        inputs=None,
        outputs=[original_image, denoised_image, active_data, x_slider, y_slider]
    )

    file_input.change(
        on_file_changed,
        inputs=[file_input, auto_contrast],
        outputs=[active_data, original_image, denoised_image, x_slider, y_slider]
    )

    x_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )

    y_slider.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )

    # Re-render the preview when the contrast mode is toggled, so the checkbox
    # takes effect without having to move a slider first.
    auto_contrast.change(
        display_crop,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=original_image
    )

    denoise_button.click(
        process_and_denoise,
        inputs=[active_data, x_slider, y_slider, auto_contrast],
        outputs=denoised_image
    )

    clear_button.click(clear_images, inputs=None, outputs=[original_image, denoised_image, active_data, x_slider, y_slider])


demo.launch()
|
|
|
|