gpu for cuda
Browse files- handler.py +3 -4
handler.py
CHANGED
|
@@ -20,9 +20,7 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| 20 |
if device.type != 'cuda':
|
| 21 |
raise ValueError("need to run on GPU")
|
| 22 |
# set mixed precision dtype
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
dtype = torch.float32
|
| 26 |
|
| 27 |
# controlnet mapping for controlnet id and control hinter
|
| 28 |
CONTROLNET_MAPPING = {
|
|
@@ -87,7 +85,8 @@ class EndpointHandler():
|
|
| 87 |
|
| 88 |
|
| 89 |
# Define Generator with seed
|
| 90 |
-
self.generator = torch.Generator(device="cpu").manual_seed(3)
|
|
|
|
| 91 |
|
| 92 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
| 93 |
"""
|
|
|
|
| 20 |
if device.type != 'cuda':
|
| 21 |
raise ValueError("need to run on GPU")
|
| 22 |
# set mixed precision dtype
|
| 23 |
+
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# controlnet mapping for controlnet id and control hinter
|
| 26 |
CONTROLNET_MAPPING = {
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
# Define Generator with seed
|
| 88 |
+
# self.generator = torch.Generator(device="cpu").manual_seed(3)
|
| 89 |
+
self.generator = torch.Generator(device="cuda").manual_seed(3)
|
| 90 |
|
| 91 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
| 92 |
"""
|