pwaldron committed on
Commit
aca26e9
·
verified ·
1 Parent(s): decee8f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +129 -86
handler.py CHANGED
@@ -1,87 +1,130 @@
1
- from typing import Dict, List, Any
2
- import torch
3
- import base64
4
- from PIL import Image
5
- from io import BytesIO
6
- from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, AutoencoderKL
7
- from controlnet_aux.pidi import PidiNetDetector
8
-
9
- # set device
10
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
-
12
- if device.type != 'cuda':
13
- raise ValueError("need to run on GPU")
14
-
15
- class EndpointHandler():
16
- def __init__(self, path=""):
17
- # Preload all the elements you are going to need at inference.
18
- # pseudo:
19
- # self.model= load_model(path)
20
-
21
- adapter = T2IAdapter.from_pretrained(
22
- "Adapter/t2iadapter",
23
- subfolder="sketch_sdxl_1.0",
24
- torch_dtype=torch.float16,
25
- adapter_type="full_adapter_xl"
26
- )
27
-
28
- vae = AutoencoderKL.from_pretrained(
29
- "madebyollin/sdxl-vae-fp16-fix",
30
- torch_dtype=torch.float16,
31
- use_safetensors=True
32
- )
33
-
34
- self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
35
- "stabilityai/stable-diffusion-xl-base-1.0",
36
- adapter=adapter,
37
- vae=vae,
38
- torch_dtype=torch.float16,
39
- variant="fp16"
40
- ).to("cuda")
41
- self.pipeline.enable_sequential_cpu_offload()
42
-
43
- self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
44
-
45
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
46
- """
47
- data args:
48
- inputs (:obj: `str` | `PIL.Image` | `np.array`)
49
- kwargs
50
- Return:
51
- A :obj:`list` | `dict`: will be serialized and returned
52
- """
53
-
54
- # pseudo
55
- # self.model(input)
56
-
57
- # get inputs
58
- inputs = data.pop("inputs", "")
59
- encoded_image = data.pop("image", None)
60
-
61
- # Decode image and convert to black and white sketch
62
- decoded_image = self.decode_base64_image(encoded_image).convert('RGB')
63
- sketch_image = self.pidinet(
64
- decoded_image,
65
- detect_resolution=1024,
66
- image_resolution=1024,
67
- apply_filter=True
68
- ).convert('L')
69
-
70
- # sketch_image.save("./output1.png")
71
-
72
- output_image = self.pipeline(
73
- prompt=inputs,
74
- negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
75
- image=sketch_image,
76
- guidance_scale=7.5,
77
- ).images[0]
78
-
79
- # output_image.save("./output2.png")
80
- return output_image
81
-
82
- # helper to decode input image
83
- def decode_base64_image(self, image_string):
84
- base64_image = base64.b64decode(image_string)
85
- buffer = BytesIO(base64_image)
86
- image = Image.open(buffer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  return image
 
1
+ from typing import Dict, List, Any
2
+ import torch
3
+ import base64
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from diffusers import T2IAdapter, StableDiffusionXLAdapterPipeline, AutoencoderKL, DPMSolverMultistepScheduler
7
+ from controlnet_aux.pidi import PidiNetDetector
8
+
9
# Resolve the compute device up front; this handler is GPU-only.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fail fast at load time rather than erroring mid-inference.
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
14
+
15
class EndpointHandler():
    """Inference Endpoints handler: sketch-conditioned SDXL image generation.

    Converts an input photo into a PidiNet edge sketch, then runs a
    T2I-Adapter SDXL base pass (in latent space) followed by a refiner pass.
    """

    # Preload all the elements you are going to need at inference.
    def __init__(self, path=""):

        # load the T2I adapter (sketch-conditioning weights for SDXL)
        adapter = T2IAdapter.from_pretrained(
            "Adapter/t2iadapter",
            subfolder="sketch_sdxl_1.0",
            torch_dtype=torch.float16,
            adapter_type="full_adapter_xl",
            use_safetensors=True,
        )

        # load variational autoencoder (VAE) — fp16-safe SDXL VAE
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        # load the scheduler
        scheduler = DPMSolverMultistepScheduler.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="scheduler",
            use_lu_lambdas=True,
            euler_at_final=True,
        )

        # instantiate HF pipeline to combine all the components
        self.pipeline = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            adapter=adapter,
            vae=vae,
            scheduler=scheduler,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

        # instantiate HF refiner to improve output image
        # NOTE(review): the refiner checkpoint is normally driven through
        # StableDiffusionXLImg2ImgPipeline; confirm this adapter pipeline
        # accepts `denoising_start` for the refiner checkpoint.
        self.refiner = StableDiffusionXLAdapterPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipeline.text_encoder_2,
            adapter=adapter,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

        # edge detector that turns the input photo into a sketch
        self.pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")

        # BUG FIX: the original compiled `pipe.unet` (an undefined name,
        # NameError at init) for all three modules, and assigned the base
        # UNet onto the detector, which has no `unet` attribute and no
        # pipeline offload API. Compile each pipeline's own UNet instead
        # and leave the detector untouched.
        self.pipeline.unet = torch.compile(self.pipeline.unet, mode="reduce-overhead", fullgraph=True)
        self.refiner.unet = torch.compile(self.refiner.unet, mode="reduce-overhead", fullgraph=True)

        # NOTE(review): CPU offload after `.to("cuda")` re-hooks modules back
        # onto the CPU, and offload hooks can interact badly with
        # torch.compile — verify memory/latency on the target hardware.
        self.pipeline.enable_sequential_cpu_offload()
        self.refiner.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str`): text prompt
            image (:obj: `str`): base64-encoded input photograph
        Return:
            A `PIL.Image.Image` — the generated image (annotation kept for
            endpoint-contract compatibility; serving layer serializes it).
        """

        # get inputs
        inputs = data.pop("inputs", "")
        encoded_image = data.pop("image", None)

        # Decode image and convert to black and white sketch
        decoded_image = self.decode_base64_image(encoded_image).convert('RGB')
        sketch_image = self.pidinet(
            decoded_image,
            detect_resolution=1024,
            image_resolution=1024,
            apply_filter=True
        ).convert('L')

        # split the denoising schedule: base handles the first 70% of steps,
        # refiner finishes the remaining 30%
        num_inference_steps = 25
        high_noise_frac = 0.7

        # base pass — output latents so the refiner can continue denoising
        base_image = self.pipeline(
            prompt=inputs,
            negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
            image=sketch_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            guidance_scale=7.5,
            output_type="latent",
        ).images

        # refiner pass — resumes the schedule where the base pass stopped
        output_image = self.refiner(
            prompt=inputs,
            negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
            image=base_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
            guidance_scale=7.5,
        ).images[0]

        return output_image

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64-encoded string into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image