{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AURA: text → layout → image\n",
    "\n",
    "Two-stage generation from a free-text finding description:\n",
    "\n",
    "1. **Text → layout** — a fine-tuned InstructPix2Pix UNet edits the organ mask according to the prompt; the finding is expected to come back as a red region (see `find_region`).\n",
    "2. **Layout → image** — that red region is recoloured with the finding's colour from `DISEASE_COLORS`, and a ControlNet on top of the `roentgen/` base model renders the final image from the layout.\n",
    "\n",
    "Paths, prompt, sampler settings and the seed all live in the configuration cell, so *Restart Kernel → Run All* reproduces the run for a fixed seed (up to GPU nondeterminism)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import warnings\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from diffusers import (\n",
    "    ControlNetModel,\n",
    "    StableDiffusionControlNetPipeline,\n",
    "    StableDiffusionInstructPix2PixPipeline,\n",
    "    UNet2DConditionModel,\n",
    "    UniPCMultistepScheduler,\n",
    ")\n",
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "from scipy.ndimage import binary_dilation, binary_erosion\n",
    "from skimage.morphology import disk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Checkpoints and test data default to `/home/AURA`; set the `AURA_ROOT` environment variable to run from another location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint / data root; override with the AURA_ROOT environment variable.\n",
    "AURA_ROOT = os.environ.get(\"AURA_ROOT\", \"/home/AURA\")\n",
    "\n",
    "BASE_MODEL_ID = \"CompVis/stable-diffusion-v1-4\"\n",
    "T2L_CHECKPOINT_DIR = os.path.join(AURA_ROOT, \"text2layout/instruct-pix2pix-model/checkpoint-7000\")\n",
    "CONTROLNET_DIR = os.path.join(AURA_ROOT, \"layout2image_multi/controlnet-model_multi/checkpoint-7000/controlnet\")\n",
    "L2I_BASE_DIR = os.path.join(AURA_ROOT, \"roentgen\")\n",
    "ORGAN_MASK_PATH = os.path.join(AURA_ROOT, \"test/41130/organ_41130.png\")\n",
    "\n",
    "PROMPT = \"mild Atelectasis on right upper lung\"\n",
    "T2L_STEPS, T2L_GUIDANCE = 20, 7  # text -> layout sampler\n",
    "L2I_STEPS, L2I_GUIDANCE = 70, 8  # layout -> image sampler\n",
    "SEED = 42\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helpers\n",
    "\n",
    "- `find_region` — pixels that are strongly red in the generated layout, cleaned with a morphological opening.\n",
    "- `find_disease` — the finding named in the prompt (handles multi-word names such as *Pleural Thickening*).\n",
    "- `post_process` — paints that region onto a copy of the organ mask in the finding's colour."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RGB colour painted into the layout for each finding. Multi-word names are matched\n",
    "# as whole phrases by find_disease.\n",
    "DISEASE_COLORS = {\n",
    "    \"Atelectasis\": (255, 0, 0),\n",
    "    \"Calcification\": (0, 255, 0),\n",
    "    \"Cardiomegaly\": (0, 0, 255),\n",
    "    \"Consolidation\": (255, 255, 0),\n",
    "    \"Diffuse Nodule\": (255, 165, 0),\n",
    "    \"Effusion\": (0, 255, 255),\n",
    "    \"Emphysema\": (255, 0, 255),\n",
    "    \"Fibrosis\": (128, 0, 128),\n",
    "    \"Fracture\": (255, 192, 203),\n",
    "    \"Mass\": (173, 255, 47),\n",
    "    \"Nodule\": (0, 128, 255),\n",
    "    \"Pleural Thickening\": (75, 0, 130),\n",
    "    \"Pneumothorax\": (255, 105, 180),\n",
    "}\n",
    "# Used when post_process receives a name missing from DISEASE_COLORS.\n",
    "UNKNOWN_DISEASE_COLOR = (0, 0, 0)\n",
    "\n",
    "\n",
    "def find_region(generated_image, erosion_dilation_radius=5, red_min=100, other_max=80):\n",
    "    \"\"\"Return a {0, 1} uint8 mask of the strongly red pixels of an RGB image.\n",
    "\n",
    "    A pixel counts as red when R > red_min and both G and B are < other_max.\n",
    "    The raw mask is eroded and then dilated with a disk of radius\n",
    "    erosion_dilation_radius (a morphological opening), which removes red\n",
    "    specks smaller than the disk.\n",
    "\n",
    "    Args:\n",
    "        generated_image: (H, W, C) array-like with C >= 3 in RGB order (a PIL\n",
    "            RGB image works); channels beyond the third are ignored.\n",
    "        erosion_dilation_radius: radius of the disk structuring element.\n",
    "        red_min: exclusive lower bound on the red channel.\n",
    "        other_max: exclusive upper bound on the green and blue channels.\n",
    "\n",
    "    Returns:\n",
    "        (H, W) uint8 array with values 0 and 1.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the input is not an (H, W, C>=3) image.\n",
    "    \"\"\"\n",
    "    pixels = np.asarray(generated_image)\n",
    "    if pixels.ndim != 3 or pixels.shape[2] < 3:\n",
    "        raise ValueError(f\"expected an (H, W, C>=3) image, got shape {pixels.shape}\")\n",
    "    red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]\n",
    "    red_region = (red > red_min) & (green < other_max) & (blue < other_max)\n",
    "    selem = disk(erosion_dilation_radius)\n",
    "    opened = binary_erosion(red_region, structure=selem)\n",
    "    opened = binary_dilation(opened, structure=selem)\n",
    "    return opened.astype(np.uint8)\n",
    "\n",
    "\n",
    "def find_disease(prompt):\n",
    "    \"\"\"Return the DISEASE_COLORS key named in a free-text prompt.\n",
    "\n",
    "    Matching is case-insensitive, ignores punctuation and handles multi-word\n",
    "    names; longer names are tried first so Diffuse Nodule wins over Nodule.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the prompt mentions none of the known findings.\n",
    "    \"\"\"\n",
    "    words = \" \".join(re.findall(r\"[a-z]+\", prompt.lower()))\n",
    "    padded = f\" {words} \"\n",
    "    for name in sorted(DISEASE_COLORS, key=len, reverse=True):\n",
    "        if f\" {name.lower()} \" in padded:\n",
    "            return name\n",
    "    raise ValueError(f\"no known disease in prompt {prompt!r}; expected one of {sorted(DISEASE_COLORS)}\")\n",
    "\n",
    "\n",
    "def post_process(generated_image, organ, disease):\n",
    "    \"\"\"Paint the red region of generated_image onto a copy of organ in the disease's colour.\n",
    "\n",
    "    Args:\n",
    "        generated_image: RGB output of the text -> layout pipeline.\n",
    "        organ: RGB organ mask (PIL image or (H, W, 3) uint8 array); not modified.\n",
    "        disease: key of DISEASE_COLORS. Unknown names trigger a warning and are\n",
    "            painted with UNKNOWN_DISEASE_COLOR.\n",
    "\n",
    "    Returns:\n",
    "        PIL.Image.Image: the organ mask with the region recoloured.\n",
    "    \"\"\"\n",
    "    mask = find_region(generated_image)\n",
    "    organ_np = np.array(organ)  # np.array copies, so the caller's image is left untouched\n",
    "    height, width = organ_np.shape[:2]\n",
    "    if mask.shape != (height, width):\n",
    "        # NOTE(review): the pipeline output may differ in size from the organ mask (e.g. when\n",
    "        # the organ size is not a multiple of 8); nearest-neighbour keeps the mask binary.\n",
    "        mask = np.array(Image.fromarray(mask).resize((width, height), resample=Image.NEAREST))\n",
    "    color = DISEASE_COLORS.get(disease)\n",
    "    if color is None:\n",
    "        warnings.warn(f\"unknown disease {disease!r}; painting its region with {UNKNOWN_DISEASE_COLOR}\")\n",
    "        color = UNKNOWN_DISEASE_COLOR\n",
    "    organ_np[mask == 1] = color\n",
    "    return Image.fromarray(organ_np)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text -> layout: SD v1.4 components with the fine-tuned InstructPix2Pix UNet passed in at\n",
    "# load time, so the base UNet and the safety checker are never loaded only to be discarded.\n",
    "t2l_unet = UNet2DConditionModel.from_pretrained(T2L_CHECKPOINT_DIR, subfolder=\"unet\")\n",
    "t2l_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(\n",
    "    BASE_MODEL_ID,\n",
    "    unet=t2l_unet,\n",
    "    safety_checker=None,\n",
    "    requires_safety_checker=False,\n",
    ").to(device)\n",
    "\n",
    "# Layout -> image: ControlNet on the roentgen base model, sampled with UniPC.\n",
    "controlnet = ControlNetModel.from_pretrained(CONTROLNET_DIR, use_safetensors=True)\n",
    "l2i_pipeline = StableDiffusionControlNetPipeline.from_pretrained(\n",
    "    L2I_BASE_DIR,\n",
    "    controlnet=controlnet,\n",
    "    safety_checker=None,\n",
    "    requires_safety_checker=False,\n",
    ")\n",
    "l2i_pipeline.scheduler = UniPCMultistepScheduler.from_config(l2i_pipeline.scheduler.config)\n",
    "l2i_pipeline = l2i_pipeline.to(device)\n",
    "\n",
    "print(\"Loaded models\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Input organ mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Image.open(ORGAN_MASK_PATH) as organ_file:\n",
    "    organ = organ_file.convert(\"RGB\")\n",
    "organ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 1: text → layout\n",
    "\n",
    "The finding is looked up by name in the prompt (so multi-word findings such as *Diffuse Nodule* work), the layout model marks its region in red, and `post_process` recolours that region onto the organ mask."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "disease = find_disease(PROMPT)  # 'Atelectasis' for the default prompt\n",
    "generator = torch.Generator(device=device).manual_seed(SEED)\n",
    "generated_mask = t2l_pipeline(\n",
    "    PROMPT,\n",
    "    image=organ,\n",
    "    num_inference_steps=T2L_STEPS,\n",
    "    guidance_scale=T2L_GUIDANCE,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "layout = post_process(generated_mask, organ, disease)\n",
    "display(generated_mask, layout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2: layout → image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator = torch.Generator(device=device).manual_seed(SEED)\n",
    "generated_image = l2i_pipeline(\n",
    "    PROMPT,\n",
    "    image=layout,\n",
    "    num_inference_steps=L2I_STEPS,\n",
    "    guidance_scale=L2I_GUIDANCE,\n",
    "    generator=generator,\n",
    ").images[0]\n",
    "generated_image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "panels = [(organ, \"Organ mask\"), (layout, f\"Layout ({disease})\"), (generated_image, \"Generated image\")]\n",
    "for ax, (img, title) in zip(axes, panels):\n",
    "    ax.imshow(img)\n",
    "    ax.set_title(title)\n",
    "    ax.axis(\"off\")\n",
    "fig.suptitle(PROMPT)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}