add to mps device in _build_GOT_vision
Browse files- got_vision_b.py +4 -2
got_vision_b.py
CHANGED
|
@@ -448,6 +448,8 @@ def _build_GOT_vision(
|
|
| 448 |
image_size = 1024
|
| 449 |
vit_patch_size = 16
|
| 450 |
image_embedding_size = image_size // vit_patch_size
|
|
|
|
|
|
|
| 451 |
image_encoder=ImageEncoderViT(
|
| 452 |
depth=encoder_depth,
|
| 453 |
embed_dim=encoder_embed_dim,
|
|
@@ -461,8 +463,8 @@ def _build_GOT_vision(
|
|
| 461 |
global_attn_indexes=encoder_global_attn_indexes,
|
| 462 |
window_size=14,
|
| 463 |
out_chans=prompt_embed_dim,
|
| 464 |
-
)
|
| 465 |
-
|
| 466 |
|
| 467 |
return image_encoder
|
| 468 |
|
|
|
|
| 448 |
image_size = 1024
|
| 449 |
vit_patch_size = 16
|
| 450 |
image_embedding_size = image_size // vit_patch_size
|
| 451 |
+
device = torch.device('mps')
|
| 452 |
+
|
| 453 |
image_encoder=ImageEncoderViT(
|
| 454 |
depth=encoder_depth,
|
| 455 |
embed_dim=encoder_embed_dim,
|
|
|
|
| 463 |
global_attn_indexes=encoder_global_attn_indexes,
|
| 464 |
window_size=14,
|
| 465 |
out_chans=prompt_embed_dim,
|
| 466 |
+
).to(device)
|
| 467 |
+
|
| 468 |
|
| 469 |
return image_encoder
|
| 470 |
|