Upload folder using huggingface_hub
Browse files- modeling_MMRet_CLIP.py +14 -11
modeling_MMRet_CLIP.py
CHANGED
|
@@ -38,10 +38,10 @@ from transformers.utils import (
|
|
| 38 |
replace_return_docstrings,
|
| 39 |
)
|
| 40 |
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
| 41 |
-
|
| 42 |
|
| 43 |
if is_flash_attn_2_available():
|
| 44 |
-
from .
|
| 45 |
|
| 46 |
|
| 47 |
logger = logging.get_logger(__name__)
|
|
@@ -50,7 +50,7 @@ logger = logging.get_logger(__name__)
|
|
| 50 |
_CONFIG_FOR_DOC = "MMRet_CLIP"
|
| 51 |
|
| 52 |
# Image classification docstring
|
| 53 |
-
_IMAGE_CLASS_CHECKPOINT = "
|
| 54 |
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
| 55 |
|
| 56 |
|
|
@@ -1160,6 +1160,9 @@ class CLIPModel(CLIPPreTrainedModel):
|
|
| 1160 |
# Initialize weights and apply final processing
|
| 1161 |
self.post_init()
|
| 1162 |
|
|
|
|
|
|
|
|
|
|
| 1163 |
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
| 1164 |
def get_text_features(
|
| 1165 |
self,
|
|
@@ -1258,18 +1261,18 @@ class CLIPModel(CLIPPreTrainedModel):
|
|
| 1258 |
|
| 1259 |
|
| 1260 |
def encode_image(self, images):
|
| 1261 |
-
embeddings = self.
|
| 1262 |
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
| 1263 |
return embeddings
|
| 1264 |
|
| 1265 |
def encode_text(self, text):
|
| 1266 |
-
embeddings = self.
|
| 1267 |
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
| 1268 |
return embeddings
|
| 1269 |
|
| 1270 |
def encode_multimodal(self, images, text):
|
| 1271 |
-
text_embeddings = self.
|
| 1272 |
-
image_embeddings = self.
|
| 1273 |
|
| 1274 |
embeddings = text_embeddings + image_embeddings
|
| 1275 |
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
|
@@ -1278,7 +1281,7 @@ class CLIPModel(CLIPPreTrainedModel):
|
|
| 1278 |
|
| 1279 |
def data_process(self, images=None, text=None):
|
| 1280 |
if images is None and text is not None:
|
| 1281 |
-
text = self.processor(text=text, return_tensors="pt", padding=True).to(self.
|
| 1282 |
|
| 1283 |
return images, text, "text"
|
| 1284 |
elif images is not None and text is None:
|
|
@@ -1286,7 +1289,7 @@ class CLIPModel(CLIPPreTrainedModel):
|
|
| 1286 |
images = Image.open(images).convert("RGB")
|
| 1287 |
elif isinstance(images, list):
|
| 1288 |
images = [Image.open(image).convert("RGB") for image in images]
|
| 1289 |
-
images = self.processor(images=images, return_tensors="pt").to(self.
|
| 1290 |
images = images["pixel_values"]
|
| 1291 |
return images, text, "images"
|
| 1292 |
elif images is not None and text is not None:
|
|
@@ -1296,9 +1299,9 @@ class CLIPModel(CLIPPreTrainedModel):
|
|
| 1296 |
elif isinstance(images, list):
|
| 1297 |
assert len(images) == len(text), "images and text must be lists of the same length when use list"
|
| 1298 |
images = [Image.open(image).convert("RGB") for image in images]
|
| 1299 |
-
images = self.processor(images=images, return_tensors="pt").to(self.
|
| 1300 |
images = images["pixel_values"]
|
| 1301 |
-
text = self.processor(text=text, return_tensors="pt", padding=True).to(self.
|
| 1302 |
return images, text, "multimodal"
|
| 1303 |
else:
|
| 1304 |
raise ValueError("images and text cannot both be None")
|
|
|
|
| 38 |
replace_return_docstrings,
|
| 39 |
)
|
| 40 |
from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
| 41 |
+
from transformers import CLIPProcessor
|
| 42 |
|
| 43 |
if is_flash_attn_2_available():
|
| 44 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 45 |
|
| 46 |
|
| 47 |
logger = logging.get_logger(__name__)
|
|
|
|
| 50 |
_CONFIG_FOR_DOC = "MMRet_CLIP"
|
| 51 |
|
| 52 |
# Image classification docstring
|
| 53 |
+
_IMAGE_CLASS_CHECKPOINT = "JUNJIE99/MMRet-base"
|
| 54 |
_IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_0"
|
| 55 |
|
| 56 |
|
|
|
|
| 1160 |
# Initialize weights and apply final processing
|
| 1161 |
self.post_init()
|
| 1162 |
|
| 1163 |
+
def set_processor(self, model_name):
|
| 1164 |
+
self.processor = CLIPProcessor.from_pretrained(model_name)
|
| 1165 |
+
|
| 1166 |
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
|
| 1167 |
def get_text_features(
|
| 1168 |
self,
|
|
|
|
| 1261 |
|
| 1262 |
|
| 1263 |
def encode_image(self, images):
|
| 1264 |
+
embeddings = self.get_image_features(images)
|
| 1265 |
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
| 1266 |
return embeddings
|
| 1267 |
|
| 1268 |
def encode_text(self, text):
|
| 1269 |
+
embeddings = self.get_text_features(**text)
|
| 1270 |
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
| 1271 |
return embeddings
|
| 1272 |
|
| 1273 |
def encode_multimodal(self, images, text):
|
| 1274 |
+
text_embeddings = self.get_text_features(**text)
|
| 1275 |
+
image_embeddings = self.get_image_features(images)
|
| 1276 |
|
| 1277 |
embeddings = text_embeddings + image_embeddings
|
| 1278 |
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
|
|
|
|
| 1281 |
|
| 1282 |
def data_process(self, images=None, text=None):
|
| 1283 |
if images is None and text is not None:
|
| 1284 |
+
text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
|
| 1285 |
|
| 1286 |
return images, text, "text"
|
| 1287 |
elif images is not None and text is None:
|
|
|
|
| 1289 |
images = Image.open(images).convert("RGB")
|
| 1290 |
elif isinstance(images, list):
|
| 1291 |
images = [Image.open(image).convert("RGB") for image in images]
|
| 1292 |
+
images = self.processor(images=images, return_tensors="pt").to(self.device)
|
| 1293 |
images = images["pixel_values"]
|
| 1294 |
return images, text, "images"
|
| 1295 |
elif images is not None and text is not None:
|
|
|
|
| 1299 |
elif isinstance(images, list):
|
| 1300 |
assert len(images) == len(text), "images and text must be lists of the same length when use list"
|
| 1301 |
images = [Image.open(image).convert("RGB") for image in images]
|
| 1302 |
+
images = self.processor(images=images, return_tensors="pt").to(self.device)
|
| 1303 |
images = images["pixel_values"]
|
| 1304 |
+
text = self.processor(text=text, return_tensors="pt", padding=True).to(self.device)
|
| 1305 |
return images, text, "multimodal"
|
| 1306 |
else:
|
| 1307 |
raise ValueError("images and text cannot both be None")
|