File size: 2,561 Bytes
4201802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
---
license: creativeml-openrail-m
language:
- en
base_model:
- CompVis/stable-diffusion-v1-4
- limuloo1999/MIGC
pipeline_tag: text-to-image
---
# About file

<!-- Provide a quick summary of what the model is/does. -->

Diffusers version of the MIGC adapter state dict. The actual values are identical to those in the original checkpoint file [MIGC_SD14.ckpt](https://huggingface.co/limuloo1999/MIGC).
Please see the [MIGC repository](https://github.com/limuloo/MIGC) for details of MIGC.


# How to use

Please use the modified pipeline class in the `pipeline_stable_diffusion_migc.py` file.

```python
import random

import numpy as np
import safetensors.torch
import torch
from huggingface_hub import hf_hub_download

from pipeline_stable_diffusion_migc import StableDiffusionMIGCPipeline


# DEVICE is where the pipeline is moved and where the sampling generator is created;
# SEED is used both for the global RNGs and for that generator.
DEVICE="cuda"
SEED=42

# Build the MIGC pipeline from the Stable Diffusion v1.4 weights, then fetch the
# adapter weights file from the Hugging Face Hub (hf_hub_download returns a local path).
pipe = StableDiffusionMIGCPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(DEVICE)
adapter_path = hf_hub_download(repo_id="thisiswooyeol/MIGC-diffusers", filename="migc_adapter_weights.safetensors")

# Load MIGC adapter to UNet attn2 layers
state_dict = safetensors.torch.load_file(adapter_path)
for name, module in pipe.unet.named_modules():
    # Only modules that own a `migc` submodule receive adapter weights.
    if not hasattr(module, "migc"):
        continue
    print(f"Found MIGC in {name}")

    # Checkpoint keys are fully qualified: "<module name>.migc.<parameter name>".
    # Match on the complete prefix (including the trailing dot) so that only keys
    # belonging to this module's `migc` submodule are selected, then strip that
    # prefix so the keys line up with `module.migc`'s own parameter names.
    prefix = f"{name}.migc."
    migc_state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

    # Load the re-keyed weights and align the module with the UNet's device/dtype.
    module.migc.load_state_dict(migc_state_dict)
    module.to(device=pipe.unet.device, dtype=pipe.unet.dtype)


# Sample inference inputs: a global prompt plus one phrase and one bounding box per instance.
prompt = "bestquality, detailed, 8k.a photo of a black potted plant and a yellow refrigerator and a brown surfboard"
# Instance descriptions; each one also appears in the global prompt above.
phrases = [
  "a black potted plant",
  "a brown surfboard",
  "a yellow refrigerator",
]
# NOTE(review): each box presumably pairs with the phrase at the same index and is given as
# normalized [x_min, y_min, x_max, y_max] in the 0-1 range — confirm against the pipeline.
bboxes = [
  [0.5717187499999999, 0.0, 0.8179531250000001, 0.29807511737089204],
  [0.85775, 0.058755868544600943, 0.9991875, 0.646525821596244],
  [0.6041562500000001, 0.284906103286385, 0.799046875, 0.9898591549295774],
]

def seed_everything(seed):
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices) with `seed`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Apply the same seed to each RNG, in the order listed above.
    for apply_seed in seeders:
        apply_seed(seed)


# Seed the global RNGs before sampling.
seed_everything(SEED)

# Run the pipeline with the prompt, per-instance phrases and boxes, and an explicitly
# seeded generator created on DEVICE; keep the first returned image.
image = pipe(
    prompt=prompt,
    phrases=phrases,
    bboxes=bboxes,
    negative_prompt="worst quality, low quality, bad anatomy",
    generator=torch.Generator(DEVICE).manual_seed(SEED),
).images[0]
# Write the generated image to the current working directory.
image.save("image.png")
```