| --- |
| license: creativeml-openrail-m |
| language: |
| - en |
| base_model: |
| - CompVis/stable-diffusion-v1-4 |
| - limuloo1999/MIGC |
| pipeline_tag: text-to-image |
| --- |
| # About file |
|
|
| <!-- Provide a quick summary of what the model is/does. --> |
|
|
| Diffusers version of the MIGC adapter state dict. The actual values are identical to those in the original checkpoint file [MIGC_SD14.ckpt](https://huggingface.co/limuloo1999/MIGC). |
| For details on MIGC, please see the [MIGC repository](https://github.com/limuloo/MIGC). |
|
|
|
|
| # How to use |
|
|
| Please use the modified pipeline class defined in the `pipeline_stable_diffusion_migc.py` file. |
|
|
| ```python |
| import random |
| |
| import numpy as np |
| import safetensors.torch |
| import torch |
| from huggingface_hub import hf_hub_download |
| |
| from pipeline_stable_diffusion_migc import StableDiffusionMIGCPipeline |
| |
| |
| DEVICE="cuda" |
| SEED=42 |
| |
| pipe = StableDiffusionMIGCPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(DEVICE) |
| adapter_path = hf_hub_download(repo_id="thisiswooyeol/MIGC-diffusers", filename="migc_adapter_weights.safetensors") |
| |
| # Load MIGC adapter to UNet attn2 layers |
| state_dict = safetensors.torch.load_file(adapter_path) |
| for name, module in pipe.unet.named_modules(): |
| if hasattr(module, "migc"): |
| print(f"Found MIGC in {name}") |
| |
| # Get the state dict with the incorrect keys |
| state_dict_to_load = {k: v for k, v in state_dict.items() if k.startswith(name)} |
| |
| # Create a new state dict, removing the "attn2." prefix from each key |
| new_state_dict = {k.replace(f"{name}.migc.", "", 1): v for k, v in state_dict_to_load.items()} |
| |
| # Load the corrected state dict |
| module.migc.load_state_dict(new_state_dict) |
| module.to(device=pipe.unet.device, dtype=pipe.unet.dtype) |
| |
| |
| # Sample inference inputs. |
| # `prompt` is the full caption; `phrases` lists the per-object descriptions that |
| # also appear inside the prompt. |
| prompt = "bestquality, detailed, 8k.a photo of a black potted plant and a yellow refrigerator and a brown surfboard" |
| phrases = [ |
| "a black potted plant", |
| "a brown surfboard", |
| "a yellow refrigerator", |
| ] |
| # One box per phrase. NOTE(review): presumably matched to `phrases` by index and |
| # given as normalized [x_min, y_min, x_max, y_max] in [0, 1] — confirm against |
| # the pipeline implementation. |
| bboxes = [ |
| [0.5717187499999999, 0.0, 0.8179531250000001, 0.29807511737089204], |
| [0.85775, 0.058755868544600943, 0.9991875, 0.646525821596244], |
| [0.6041562500000001, 0.284906103286385, 0.799046875, 0.9898591549295774], |
| ] |
| |
| def seed_everything(seed): |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| torch.cuda.manual_seed_all(seed) |
| |
| |
| seed_everything(SEED) |
| |
| image = pipe( |
| prompt=prompt, |
| phrases=phrases, |
| bboxes=bboxes, |
| negative_prompt="worst quality, low quality, bad anatomy", |
| generator=torch.Generator(DEVICE).manual_seed(SEED), |
| ).images[0] |
| image.save("image.png") |
| ``` |
|
|