MIGC-diffusers / README.md
thisiswooyeol's picture
Create README.md
4201802 verified
|
raw
history blame
2.56 kB
metadata
license: creativeml-openrail-m
language:
  - en
base_model:
  - CompVis/stable-diffusion-v1-4
  - limuloo1999/MIGC
pipeline_tag: text-to-image

About file

Diffusers version of the MIGC adapter state dict. The actual values are identical to the original checkpoint file MIGC_SD14.ckpt. Please see the details of MIGC in the MIGC repository.

How to use

Please use the modified pipeline class provided in the pipeline_stable_diffusion_migc.py file.

import random

import numpy as np
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from pipeline_stable_diffusion_migc import StableDiffusionMIGCPipeline


# Target device for inference and the RNG seed used for reproducible sampling.
DEVICE="cuda"
SEED=42

# Build the MIGC-enabled Stable Diffusion 1.4 pipeline, then fetch the
# converted adapter weights from the Hub (hf_hub_download caches locally
# and returns the path to the downloaded .safetensors file).
pipe = StableDiffusionMIGCPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(DEVICE)
adapter_path = hf_hub_download(repo_id="thisiswooyeol/MIGC-diffusers", filename="migc_adapter_weights.safetensors")

# Load the MIGC adapter weights into every UNet submodule that exposes a
# `migc` attribute (the modified attn2 layers provided by
# pipeline_stable_diffusion_migc).
state_dict = safetensors.torch.load_file(adapter_path)
for name, module in pipe.unet.named_modules():
    if hasattr(module, "migc"):
        print(f"Found MIGC in {name}")

        # Select this module's entries and strip the checkpoint prefix in one
        # pass. Match on the full "<name>.migc." prefix rather than bare
        # `name`: a bare-prefix test would also capture keys of a sibling
        # module whose name merely starts with `name` (e.g. "...attn" vs
        # "...attn2"), and those foreign keys would make load_state_dict
        # fail on unexpected entries.
        prefix = f"{name}.migc."
        new_state_dict = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }

        # Load the re-keyed weights, then move the module onto the UNet's
        # device/dtype so the adapter matches the rest of the pipeline.
        module.migc.load_state_dict(new_state_dict)
        module.to(device=pipe.unet.device, dtype=pipe.unet.dtype)


# Sample inference !
# Global prompt describing the full scene, plus one grounding phrase per
# instance to be placed in the image.
prompt = "bestquality, detailed, 8k.a photo of a black potted plant and a yellow refrigerator and a brown surfboard"
phrases = [
  "a black potted plant",
  "a brown surfboard",
  "a yellow refrigerator",
]
# Bounding boxes as [x0, y0, x1, y1] fractions of the image size.
# NOTE(review): presumably bboxes[i] pairs with phrases[i] — confirm against
# the StableDiffusionMIGCPipeline.__call__ signature.
bboxes = [
  [0.5717187499999999, 0.0, 0.8179531250000001, 0.29807511737089204],
  [0.85775, 0.058755868544600943, 0.9991875, 0.646525821596244],
  [0.6041562500000001, 0.284906103286385, 0.799046875, 0.9898591549295774],
]

def seed_everything(seed: int) -> None:
    """Seed Python's, NumPy's, and every torch RNG with the same value.

    Makes the subsequent sampling run reproducible across the stdlib
    `random` module, NumPy, and torch (CPU and all CUDA devices).
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)


# Seed all RNGs up front so the whole run (not just the diffusion sampler,
# which additionally gets an explicit generator below) is reproducible.
seed_everything(SEED)

# Run grounded text-to-image generation; the pipeline returns a batch and we
# keep the first (only) image.
image = pipe(
    prompt=prompt,
    phrases=phrases,
    bboxes=bboxes,
    negative_prompt="worst quality, low quality, bad anatomy",
    generator=torch.Generator(DEVICE).manual_seed(SEED),
).images[0]
image.save("image.png")