# This file contains a copy of the dam library. Please refer to the dam library for documentation and licenses. Licenses in this file are carried over from files in the dam library.
import base64
import dataclasses
import logging
import math
import os
import os.path as osp
import re
import string
import tempfile
import warnings
from abc import ABC
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum, auto
from io import BytesIO
from shutil import copyfile
from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from accelerate.hooks import add_hook_to_module
from huggingface_hub import repo_exists, snapshot_download
from huggingface_hub.utils import HFValidationError
from PIL import Image
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
PretrainedConfig,
PreTrainedModel,
StoppingCriteria,
TextIteratorStreamer,
)
from transformers.activations import ACT2FN
from transformers.convert_slow_tokenizer import import_protobuf
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict
from transformers.image_transforms import (
convert_to_rgb,
get_channel_dimension_axis,
get_resize_output_image_size,
normalize,
pad,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import ContextManagers, no_init_weights
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import (
AddedToken,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from transformers.utils import (
ModelOutput,
TensorType,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_tf_available,
is_torch_available,
is_torchvision_available,
is_vision_available,
logging,
replace_return_docstrings,
requires_backends,
)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/haotian-liu/LLaVA/
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
MASK_TOKEN_INDEX = -300
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
class LlavaConfig(PretrainedConfig):
model_type = "llava"
def __init__(
self,
llm_cfg=None,
vision_tower_cfg=None,
mm_projector_cfg=None,
mask_encoder_cfg=None,
context_provider_cfg=None,
architectures=None,
resume_path=None,
hidden_size=None,
mm_hidden_size=None,
image_aspect_ratio=None,
num_video_frames=None,
mm_vision_select_layer=None,
mm_vision_select_feature=None,
mm_use_im_start_end=False,
mm_use_im_patch_token=True,
mm_projector_lr=None,
vision_resolution=None,
interpolate_mode=None,
s2=None,
s2_scales=None,
s2_max_split_size=None,
**kwargs,
):
        super().__init__(**kwargs)
self.architectures = architectures
self.llm_cfg = llm_cfg
self.vision_tower_cfg = vision_tower_cfg
self.mm_projector_cfg = mm_projector_cfg
self.mask_encoder_cfg = mask_encoder_cfg
self.context_provider_cfg = context_provider_cfg
self.resume_path = resume_path
self.hidden_size = hidden_size
self.mm_hidden_size = mm_hidden_size
self.image_aspect_ratio = image_aspect_ratio
self.num_video_frames = num_video_frames
self.mm_vision_select_layer = mm_vision_select_layer
self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_use_im_start_end = mm_use_im_start_end
        self.mm_use_im_patch_token = mm_use_im_patch_token
self.mm_projector_lr = mm_projector_lr
self.vision_resolution = vision_resolution
self.interpolate_mode = interpolate_mode
self.s2 = s2
self.s2_scales = s2_scales
self.s2_max_split_size = s2_max_split_size
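# Example (a minimal sketch with hypothetical values; real checkpoints ship
# this config as config.json with component configs in subfolders):
#
# cfg = LlavaConfig(
#     llm_cfg="llm",                    # path or dict for the language model
#     vision_tower_cfg="vision_tower",  # path or dict for the vision encoder
#     mm_projector_cfg="mm_projector",  # path or dict for the projector
#     image_aspect_ratio="resize",
# )
# assert cfg.model_type == "llava"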
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/haotian-liu/LLaVA/
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
MISTRAL = auto()
LLAMA_3 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
    sep2: Optional[str] = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
init_msg = init_msg[0].replace("<image>", "").strip()
if "mmtag" in self.version:
messages[0] = (init_role, init_msg)
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
messages.insert(1, (self.roles[1], "Received."))
else:
messages[0] = (init_role, "<image>\n" + init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.LLAMA_3:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message = message[0]
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif (
self.sep_style == SeparatorStyle.LLAMA_2
or self.sep_style == SeparatorStyle.MISTRAL
):
if self.sep_style == SeparatorStyle.LLAMA_2:
def wrap_sys(msg):
return f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
else:
def wrap_sys(msg):
return f"{msg}" + ("\n" if msg else "")
def wrap_inst(msg):
return f"[INST] {msg} [/INST]"
ret = ""
if self.sep_style == SeparatorStyle.MISTRAL:
ret += "<s>"
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0:
message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
if self.sep_style == SeparatorStyle.LLAMA_2:
ret += " " + message + " " + self.sep2
else:
ret += message + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(
pil_img.mode, (width, width), background_color
)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(
pil_img.mode, (height, height), background_color
)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(
f"Invalid image_process_mode: {image_process_mode}"
)
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>", "").strip()
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version,
)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [
[x, y[0] if type(y) is tuple else y] for x, y in self.messages
],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
(
"Human",
"What are the key differences between renewable and non-renewable energy sources?",
),
(
"Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
),
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
# kentang-mit@: This conversation template is designed for SFT on VFLAN.
conv_vicuna_v1_nosys = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="v1_nosys",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llama_2 = Conversation(
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_mistral = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="mistral",
messages=(),
offset=0,
sep_style=SeparatorStyle.MISTRAL,
sep="",
sep2="</s>",
)
conv_llava_llama_2 = Conversation(
system="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_mpt = Conversation(
system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_llava_plain = Conversation(
system="",
roles=("", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="\n",
)
conv_llava_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_llava_v0_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("Human", "Assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
version="v0_mmtag",
)
conv_llava_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llava_v1_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
version="v1_mmtag",
)
hermes_2 = Conversation(
system="<|im_start|>system\nAnswer the questions.",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
messages=(),
offset=0,
version="hermes-2",
)
# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
llama_3_chat = Conversation(
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=(
"<|start_header_id|>user<|end_header_id|>\n\n",
"<|start_header_id|>system<|end_header_id|>\n\n",
),
version="llama_v3",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_3,
sep="<|end_of_text|>",
)
default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
"hermes-2": hermes_2,
"llama_3": llama_3_chat,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"vicuna_v1_nosys": conv_vicuna_v1_nosys,
"llama_2": conv_llama_2,
"mistral": conv_mistral,
"plain": conv_llava_plain,
"v0_plain": conv_llava_plain,
"llava_v0": conv_llava_v0,
"v0_mmtag": conv_llava_v0_mmtag,
"llava_v1": conv_llava_v1,
"v1_mmtag": conv_llava_v1_mmtag,
"llava_llama_2": conv_llava_llama_2,
"mpt": conv_mpt,
}
# if __name__ == "__main__":
# print(default_conversation.get_prompt())
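# Example (a minimal sketch of the template API; the question below is
# illustrative and not part of the original demo):
#
# conv = conv_templates["vicuna_v1"].copy()
# conv.append_message(conv.roles[0], "<image>\nWhat is shown in this image?")
# conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open
# prompt = conv.get_prompt()
# # -> "A chat between a curious user ... USER: <image>\n... ASSISTANT:"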
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
import cv2
if fps is None or frame_count is None:
        # if either fps or frame_count is missing, recompute both from the capture
fps = vidcap.get(cv2.CAP_PROP_FPS)
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
if fps == 0 or frame_count == 0:
print("Video file not found. return empty images.")
return [
Image.new("RGB", (720, 720)),
] * num_frames
    frame_interval = frame_count // num_frames
    if frame_interval == 0 and frame_count <= 1:
        print("frame_interval is equal to 0. Returning empty images.")
        return [
            Image.new("RGB", (720, 720)),
        ] * num_frames
# print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
images = []
count = 0
success = True
frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)
while success:
# print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
if frame_count >= num_frames:
success, frame = vidcap.read()
if count in frame_indices:
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
im_pil = Image.fromarray(img)
images.append(im_pil)
if len(images) >= num_frames:
return images
count += 1
else:
# Left padding frames if the video is not long enough
success, frame = vidcap.read()
if success:
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
im_pil = Image.fromarray(img)
images.append(im_pil)
count += 1
            elif count >= 1:
                width, height = images[-1].size
                num_padding = num_frames - len(images)
                print("padding frames:", num_padding)
                # left-pad with blank frames so num_frames images are returned
                images = [Image.new("RGB", (width, height))] * num_padding + images
                return images
else:
break
raise ValueError("Did not find enough frames in the video. return empty image.")
def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
"""
Extract frames from a video using OpenCV.
Args:
vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
frames (int): Number of frames to extract from the video.
Returns:
list: List of PIL Images extracted from the video.
Raises:
NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
"""
import cv2
if isinstance(vpath_or_bytesio, str):
vidcap = cv2.VideoCapture(vpath_or_bytesio)
return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
elif isinstance(vpath_or_bytesio, (BytesIO,)):
# assuming mp4
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
temp_video.write(vpath_or_bytesio.read())
temp_video_name = temp_video.name
vidcap = cv2.VideoCapture(temp_video_name)
return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
else:
raise NotImplementedError(type(vpath_or_bytesio))
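# Example (a sketch with a hypothetical path; assumes OpenCV can decode the
# file):
#
# frames = opencv_extract_frames("/path/to/video.mp4", frames=6)
# assert len(frames) == 6 and all(isinstance(f, Image.Image) for f in frames)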
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
def expand2square(pil_img, background_color):
"""
Expand the given PIL image to a square shape by adding padding.
Parameters:
- pil_img: The PIL image to be expanded.
- background_color: The color of the padding to be added.
Returns:
- The expanded PIL image.
If the image is already square, it is returned as is.
If the image is wider than it is tall, padding is added to the top and bottom.
If the image is taller than it is wide, padding is added to the left and right.
"""
width, height = pil_img.size
if pil_img.mode == "L":
background_color = background_color[0]
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
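# Worked example (illustrative): a 640x480 landscape image becomes 640x640,
# with the original pasted at y = (640 - 480) // 2 = 80, i.e. 80 px of padding
# on the top and bottom.
#
# img = Image.new("RGB", (640, 480), (255, 0, 0))
# padded = expand2square(img, (122, 116, 104))
# assert padded.size == (640, 640)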
def process_image(image_file, data_args, image_folder, pil_preprocess_fn=None):
processor = data_args.image_processor
if isinstance(image_file, str):
if image_folder is not None:
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
else:
# image is stored in bytearray
image = image_file.convert("RGB")
info = None
if pil_preprocess_fn is not None:
image = pil_preprocess_fn(image)
if isinstance(image, tuple):
image, info = image
if data_args.image_aspect_ratio == "resize":
if hasattr(data_args.image_processor, "crop_size"):
# CLIP vision tower
crop_size = data_args.image_processor.crop_size
else:
# SIGLIP vision tower
assert hasattr(data_args.image_processor, "size")
crop_size = data_args.image_processor.size
image = image.resize((crop_size["height"], crop_size["width"]))
if data_args.image_aspect_ratio == "pad":
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
else:
# Using default behavior of the vision encoder
# For CLIP, default is central crop
# For Radio, default is central crop
# For Siglip, default is resize
# For InternVIT, default is resize
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
if info is not None:
return image, info
return image
def process_images(images, image_processor, model_cfg):
model_cfg.image_processor = image_processor
new_images = [process_image(image, model_cfg, None) for image in images]
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
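# Example (a sketch; `image_processor` and `model.config` come from a loaded
# checkpoint, and `image_aspect_ratio` on the config selects the branch in
# process_image above):
#
# images = [Image.open("a.jpg"), Image.open("b.jpg")]  # hypothetical files
# pixel_values = process_images(images, image_processor, model.config)
# # -> tensor of shape [2, C, H, W] when both images preprocess to one shape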
# Note that newer VILA codebase adds an lstrip option that defaults to False, and the functionality is the same by default
def tokenizer_image_token(
prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
def insert_separator(X, sep):
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
input_ids = []
offset = 0
if (
len(prompt_chunks) > 0
and len(prompt_chunks[0]) > 0
and prompt_chunks[0][0] == tokenizer.bos_token_id
):
offset = 1
input_ids.append(prompt_chunks[0][0])
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
input_ids.extend(x[offset:])
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
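# Example (a sketch; exact ids depend on the tokenizer, BOS id 1 assumed):
#
# ids = tokenizer_image_token("<image>\nDescribe the region.", tokenizer)
# # Each "<image>" placeholder becomes IMAGE_TOKEN_INDEX (-200) in the id
# # sequence, e.g. [1, -200, ...]; the -200 entries are later replaced by
# # projected vision features in prepare_inputs_labels_for_multimodal.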
def is_gemma_tokenizer(tokenizer):
return "gemma" in tokenizer.__class__.__name__.lower()
def get_model_name_from_path(model_path):
if not model_path:
return "describe_anything_model"
model_path = model_path.strip("/")
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
return model_paths[-2] + "_" + model_paths[-1]
else:
return model_paths[-1]
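# Illustrative mappings:
# get_model_name_from_path("ckpts/dam-3b")                  -> "dam-3b"
# get_model_name_from_path("ckpts/dam-3b/checkpoint-100")   -> "dam-3b_checkpoint-100"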
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = []
self.max_keyword_len = 0
for keyword in keywords:
cur_keyword_ids = tokenizer(keyword).input_ids
if (
len(cur_keyword_ids) > 1
and cur_keyword_ids[0] == tokenizer.bos_token_id
):
cur_keyword_ids = cur_keyword_ids[1:]
if len(cur_keyword_ids) > self.max_keyword_len:
self.max_keyword_len = len(cur_keyword_ids)
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
self.tokenizer = tokenizer
self.start_len = input_ids.shape[1]
def call_for_batch(
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
self.keyword_ids = [
keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
]
for keyword_id in self.keyword_ids:
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
return True
outputs = self.tokenizer.batch_decode(
output_ids[:, -offset:], skip_special_tokens=True
)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def __call__(
self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> bool:
outputs = []
for i in range(output_ids.shape[0]):
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
return all(outputs)
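# Example (a sketch; `tokenizer`, `model`, and `input_ids` come from a loaded
# checkpoint, and the stop string should match the conversation separator):
#
# from transformers import StoppingCriteriaList
# stopping = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
# output_ids = model.generate(
#     input_ids,
#     max_new_tokens=512,
#     stopping_criteria=StoppingCriteriaList([stopping]),
# )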
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/haotian-liu/LLaVA/
def get_model_config(config):
# `mask_encoder_cfg` and `context_provider_cfg` are optional
default_keys = [
"llm_cfg",
"vision_tower_cfg",
"mm_projector_cfg",
"mask_encoder_cfg",
"context_provider_cfg",
]
if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2:
root_path = config._name_or_path
else:
root_path = config.resume_path
# download from huggingface
if root_path is not None and not osp.exists(root_path):
try:
valid_hf_repo = repo_exists(root_path)
except HFValidationError:
valid_hf_repo = False
if valid_hf_repo:
root_path = snapshot_download(root_path)
return_list = []
for key in default_keys:
cfg = getattr(config, key, None)
        if isinstance(cfg, dict):
            try:
                return_list.append(os.path.join(root_path, key[:-4]))
            except Exception:
                raise ValueError(f"Cannot find resume path in config for {key}!")
elif isinstance(cfg, PretrainedConfig):
return_list.append(os.path.join(root_path, key[:-4]))
elif isinstance(cfg, str):
return_list.append(cfg)
elif cfg is None:
# We still return even if the cfg is None or does not exist
return_list.append(cfg)
return return_list
def is_mm_model(model_path):
"""
Check if the model at the given path is a visual language model.
Args:
model_path (str): The path to the model.
Returns:
bool: True if the model is an MM model, False otherwise.
"""
config = AutoConfig.from_pretrained(model_path)
architectures = config.architectures
for architecture in architectures:
if "llava" in architecture.lower():
return True
return False
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if "llava" in config and "llava" not in cfg.model_type:
assert cfg.model_type == "llama"
print(
"You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
)
print(
"You must upgrade the checkpoint to the new code base (this can be done automatically)."
)
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = "LlavaLlamaForCausalLM"
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: decide whether we should use a metaclass
class LlavaMetaModel(ABC):
def init_vlm(self, config: PreTrainedModel = None, *args, **kwargs):
# TODO(ligeng): figure out how from_config and from_pretrained works in HF implementation.
if (
hasattr(self, "llm")
or hasattr(self, "vision_tower")
or hasattr(self, "mm_projector")
):
            # already initialized, skip
return
model_dtype = getattr(config, "model_dtype", "torch.float16")
if not hasattr(config, "model_dtype"):
warnings.warn(
"model_dtype not found in config, defaulting to torch.float16."
)
config.model_dtype = model_dtype
# print("init_vlm(): config", config); input("DEBUG init_vlm")
cfgs = get_model_config(config)
# Only the first three are required. Others are optional.
(
llm_cfg,
vision_tower_cfg,
mm_projector_cfg,
mask_encoder_cfg,
context_provider_cfg,
) = cfgs
        if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError(
                "`llm_cfg`, `vision_tower_cfg`, or `mm_projector_cfg` not found in the config."
            )
# print("init_vlm():", cfgs); input("DEBUG init_vlm")
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG init_vlm")
self.llm, self.tokenizer = build_llm_and_tokenizer(
llm_cfg, config, *args, **kwargs
)
self.vision_tower = build_vision_tower(vision_tower_cfg, config)
self.mm_projector = build_mm_projector(mm_projector_cfg, config)
self.context_provider = (
build_context_provider(context_provider_cfg, config)
if context_provider_cfg is not None
else None
)
self.post_config()
self.is_loaded = True
assert (
self.llm is not None
or self.vision_tower is not None
or self.mm_projector is not None
), "At least one of the components must be instantiated."
@classmethod
def load_from_config(cls, model_path_or_config, *args, **kwargs):
pass
# FIXME we will use this function to load model in the future
@classmethod
def load_pretrained(cls, model_path_or_config, *args, **kwargs):
config = kwargs.pop("config", None)
if config is None:
if isinstance(model_path_or_config, str):
config = AutoConfig.from_pretrained(model_path_or_config, trust_remote_code=True)
elif isinstance(model_path_or_config, LlavaConfig):
config = model_path_or_config
else:
                raise NotImplementedError(
                    f"Wrong type for model_path_or_config: {type(model_path_or_config)}"
                )
model_dtype = getattr(config, "model_dtype", "torch.float16")
if not hasattr(config, "model_dtype"):
warnings.warn(
"model_dtype not found in config, defaulting to torch.float16."
)
config.model_dtype = model_dtype
cfgs = get_model_config(config)
# Only the first three are required. Others are optional.
(
llm_cfg,
vision_tower_cfg,
mm_projector_cfg,
mask_encoder_cfg,
context_provider_cfg,
) = cfgs
        if llm_cfg is None or vision_tower_cfg is None or mm_projector_cfg is None:
            raise ValueError(
                "`llm_cfg`, `vision_tower_cfg`, or `mm_projector_cfg` not found in the config."
            )
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained")
with ContextManagers(
[
no_init_weights(),
]
):
vlm = cls(config, *args, **kwargs)
# print(llm_cfg, vision_tower_cfg, mm_projector_cfg); input("DEBUG load_pretrained finish")
if (
hasattr(vlm, "llm")
or hasattr(vlm, "vision_tower")
or hasattr(vlm, "mm_projector")
):
if vlm.is_loaded:
return vlm
vlm.llm, vlm.tokenizer = build_llm_and_tokenizer(
llm_cfg, config, *args, **kwargs
)
vlm.vision_tower = build_vision_tower(vision_tower_cfg, config)
vlm.mm_projector = build_mm_projector(mm_projector_cfg, config)
if mask_encoder_cfg is not None:
raise NotImplementedError("Mask encoder is not supported.")
vlm.context_provider = (
build_context_provider(context_provider_cfg, config)
if context_provider_cfg is not None
else None
)
        vlm.post_config()
        vlm.is_loaded = True
# FIXME(ligeng, yunhao): llm should never be none here.
assert (
vlm.llm is not None
or vlm.vision_tower is not None
or vlm.mm_projector is not None
), "At least one of the components must be instantiated."
return vlm
# FIXME we will use this function to save the model in the future
def save_pretrained(self, output_dir, state_dict=None):
if state_dict is None:
            # otherwise fetch from deepspeed
# state_dict = accelerator.get_state_dict(is_deepspeed_enabled)
state_dict = self.state_dict()
if getattr(self, "tokenizer", None):
self.tokenizer.save_pretrained(osp.join(output_dir, "llm"))
if self.get_llm():
print(f"saving llm to {osp.join(output_dir, 'llm')}")
self.llm.config._name_or_path = osp.join(output_dir, "llm")
llm_state_dict = OrderedDict(
{k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}
)
self.llm.save_pretrained(
os.path.join(output_dir, "llm"), state_dict=llm_state_dict
)
self.config.llm_cfg = self.llm.config
if (
self.get_vision_tower()
and "radio" not in self.get_vision_tower().__class__.__name__.lower()
):
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}")
self.vision_tower.config._name_or_path = osp.join(
output_dir, "vision_tower"
)
vision_tower_state_dict = OrderedDict(
{
k.split("vision_tower.vision_tower.")[-1]: v
for k, v in state_dict.items()
if "vision_tower" in k
}
)
self.vision_tower.vision_tower.save_pretrained(
os.path.join(output_dir, "vision_tower"),
state_dict=vision_tower_state_dict,
)
self.vision_tower.image_processor.save_pretrained(
os.path.join(output_dir, "vision_tower")
)
self.config.vision_tower_cfg = self.vision_tower.config
if hasattr(self.config.vision_tower_cfg, "auto_map"):
delattr(self.config.vision_tower_cfg, "auto_map")
if self.get_mm_projector():
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}")
self.mm_projector.config._name_or_path = osp.join(
output_dir, "mm_projector"
)
mm_projector_state_dict = OrderedDict(
{
k.split("mm_projector.")[-1]: v
for k, v in state_dict.items()
if "mm_projector" in k
}
)
self.mm_projector.save_pretrained(
os.path.join(output_dir, "mm_projector"),
state_dict=mm_projector_state_dict,
)
self.config.mm_projector_cfg = self.mm_projector.config
if self.get_context_provider():
print(
f"saving context_provider to {osp.join(output_dir, 'context_provider')}"
)
self.context_provider.config._name_or_path = osp.join(
output_dir, "context_provider"
)
context_provider_state_dict = OrderedDict(
{
k.split("context_provider.")[-1]: v
for k, v in state_dict.items()
if "context_provider" in k
}
)
self.context_provider.save_pretrained(
os.path.join(output_dir, "context_provider"),
state_dict=context_provider_state_dict,
)
self.config.context_provider_cfg = self.context_provider.config
# update and save top-level config
self.config._name_or_path = output_dir
self.config.architectures = [self.__class__.__name__]
self.config.save_pretrained(output_dir)
def get_llm(self):
llm = getattr(self, "llm", None)
if type(llm) is list:
llm = llm[0]
return llm
def get_lm_head(self):
lm_head = getattr(self.get_llm(), "lm_head", None)
return lm_head
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def get_mm_projector(self):
mm_projector = getattr(self, "mm_projector", None)
if type(mm_projector) is list:
mm_projector = mm_projector[0]
return mm_projector
def get_context_provider(self):
context_provider = getattr(self, "context_provider", None)
return context_provider
def post_config(self):
self.training = self.get_llm().training
# configuration
if getattr(self.config, "llm_cfg", None) is None:
self.config.llm_cfg = self.llm.config
if getattr(self.config, "vision_tower_cfg", None) is None:
self.config.vision_tower_cfg = self.vision_tower.config
if getattr(self.config, "mm_projector_cfg", None) is None:
self.config.mm_projector_cfg = self.mm_projector.config
if (
getattr(self.config, "context_provider_cfg", None) is None
and self.context_provider is not None
):
self.config.context_provider_cfg = self.context_provider.config
def freezed_module_patch(self):
"""
        Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the frozen modules.
"""
if self.training:
if self.get_llm() and not getattr(
self.config, "tune_language_model", False
):
                warnings.warn(
                    "Caution: Your LLM is currently in training mode, ensuring accurate gradient computation. Please be vigilant, particularly regarding BatchNorm and Dropout operations."
                )
if self.get_vision_tower() and not getattr(
self.config, "tune_vision_tower", False
):
self.get_vision_tower().eval()
if self.get_mm_projector() and not getattr(
self.config, "tune_mm_projector", False
):
self.get_mm_projector().eval()
if self.get_context_provider() and not getattr(
self.config, "tune_context_provider", False
):
self.get_context_provider().eval()
def encode_images(self, images):
image_features = self.get_vision_tower()(images)
image_features = self.get_mm_projector()(image_features)
return image_features
def encode_images_with_context(self, images):
context_provider = self.get_context_provider()
        # A sample is a cimage (image with context) when its first and last
        # four channels differ anywhere; otherwise it is a plain image.
cimage_mask = torch.any(
(images[:, :4, ...] != images[:, 4:, ...]).flatten(start_dim=1), dim=1
)
if context_provider.treat_image_as_cimage:
# If the context provider treats the image as cimage, then all images are cimage.
cimage_mask[:] = True
if context_provider.context_image_as_queries:
# Swap the crop image and full image since the model uses the full image as queries by default
images = torch.cat((images[:, 4:, ...], images[:, :4, ...]), dim=1)
# Process the first 4 channels for all images: for image it's the image, for cimage it's the full image
vision_tower = self.get_vision_tower()
# Encode context images (full images)
image_features = vision_tower(images[:, :4, ...]).to(self.device)
# Each cimage has 8 channels (full and crop concatenated)
cimage_concatenated = images[cimage_mask]
cimage_full_features = image_features[cimage_mask]
if context_provider.context_provider_type == "cross_attn_end_to_all":
cimage_features = self.context_provider(
cimage_full_features=cimage_full_features,
cimage_concatenated=cimage_concatenated,
vision_tower=vision_tower,
).to(self.device)
elif context_provider.context_provider_type == "concat":
# Full features of cimages are computed but not used.
cimage_features = self.context_provider(
cimage_concatenated=cimage_concatenated, vision_tower=vision_tower
).to(self.device)
else:
raise NotImplementedError(
f"Context provider type {context_provider.context_provider_type} not implemented."
)
# Put cimage_features into image_features
image_features[cimage_mask] = cimage_features
# Project to the llm space
image_features = self.get_mm_projector()(image_features)
return image_features
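    # Channel layout assumed above (a sketch): a plain image occupies 4
    # channels (RGB + mask); a cimage concatenates the full image and the
    # focal crop, 4 channels each:
    #   images[b] = cat([full_rgb_mask, crop_rgb_mask], dim=0)  # [8, H, W]
    # Samples whose two halves differ anywhere are detected as cimages.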
# @yunhao: is there a better way to handle function call and attributes for llm?
# support beam search
def _temporary_reorder_cache(self, past_key_values, sorted_idx):
return self.get_llm()._temporary_reorder_cache(past_key_values, sorted_idx)
def get_input_embeddings(self):
return self.get_llm().get_input_embeddings()
def get_output_embeddings(self):
return self.get_llm().get_output_embeddings()
def resize_token_embeddings(self, embed_size):
self.get_llm().resize_token_embeddings(embed_size)
class LlavaMetaForCausalLM(ABC):
"""This class is originally implemented by the LLaVA team and
modified by Haotian Tang and Jason Lu based on Ji Lin's implementation
to support multiple images and input packing."""
# TODO move the forward function here if there is no need to override it
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels, images
):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
if (
past_key_values is not None
and vision_tower is not None
and images is not None
and input_ids.shape[1] == 1
):
target_shape = past_key_values[-1][-1].shape[-2] + 1
attention_mask = torch.cat(
(
attention_mask,
torch.ones(
(
attention_mask.shape[0],
target_shape - attention_mask.shape[1],
),
dtype=attention_mask.dtype,
device=attention_mask.device,
),
),
dim=1,
)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
return (
input_ids,
position_ids,
attention_mask,
past_key_values,
None,
labels,
)
# handle different image dtypes for packing
if type(images) is list:
images = torch.cat(images, dim=0)
        elif images.ndim == 5:  # batch_size x num_frames x channels x height x width
images = images.flatten(0, 1)
if getattr(self, "context_provider", None):
image_features = self.encode_images_with_context(images)
else:
            # Slicing along the first index below neither copies data nor
            # degrades performance.
            # Example dimension: [1, 196, 2560]
assert (
images.shape[1] <= 4
), "images have more than 4 channels, but context provider is not included"
image_features = self.encode_images(images).to(self.device)
# Note (kentang-mit@): image start / end is not implemented here to support pretraining.
if getattr(self.config, "turn_mm_projector", False) and getattr(
self.config, "mm_use_im_start_end", False
):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask
input_ids_copy = input_ids.clone()
# kentang-mit@: Otherwise tokenizer out of bounds. Embeddings of image tokens will not be used.
input_ids_copy[input_ids_copy == IMAGE_TOKEN_INDEX] = 0
input_embeds = self.llm.model.embed_tokens(input_ids_copy)
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
input_embeds_1 = [
cur_input_embeds[cur_attention_mask]
for cur_input_embeds, cur_attention_mask in zip(
input_embeds, attention_mask
)
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
# print("BEFORE BATCH LOOP:", len(input_ids), input_ids[0].shape, input_ids[0].device, [(x == IMAGE_TOKEN_INDEX).sum() for x in input_ids])
        # kentang-mit@: If some part of the model is executed in the loop, the loop length needs to be a constant.
for batch_idx, cur_input_ids in enumerate(input_ids):
cur_input_ids = input_ids[batch_idx]
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[0]
# cur_input_embeds_1 = self.get_llm().embed_tokens(cur_input_ids)
cur_input_embeds_1 = input_embeds_1[batch_idx]
cur_input_embeds = torch.cat(
[cur_input_embeds_1, cur_image_features[0:0]], dim=0
)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
                # kentang-mit@: we do not have a placeholder image for text-only data now.
# cur_image_idx += 1
continue
cur_input_embeds = input_embeds_1[batch_idx]
image_token_indices = (
[-1]
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
+ [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
cur_input_embeds_no_im = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(
cur_input_ids[
image_token_indices[i] + 1 : image_token_indices[i + 1]
]
)
cur_labels_noim.append(
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
)
cur_input_embeds_no_im.append(
cur_input_embeds[
image_token_indices[i] + 1 : image_token_indices[i + 1]
]
)
            # split_sizes = [x.shape[0] for x in cur_labels_noim]
# cur_input_embeds = self.get_llm().embed_tokens(torch.cat(cur_input_ids_noim))
# cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(
self.llm.config, "tokenizer_model_max_length", None
)
if tokenizer_model_max_length is not None:
if any(len(x) > tokenizer_model_max_length for x in new_input_embeds):
warnings.warn("Inputs truncated!")
new_input_embeds = [
x[:tokenizer_model_max_length] for x in new_input_embeds
]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
IGNORE_INDEX,
dtype=new_labels[0].dtype,
device=new_labels[0].device,
)
attention_mask = torch.zeros(
(batch_size, max_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
position_ids = torch.zeros(
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
)
for i, (cur_new_embed, cur_new_labels) in enumerate(
zip(new_input_embeds, new_labels)
):
cur_len = cur_new_embed.shape[0]
if getattr(self.llm.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return (
None,
position_ids,
attention_mask,
past_key_values,
new_input_embeds,
new_labels,
)
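    # Sketch of the expansion above: for ids [bos, t1, IMAGE_TOKEN_INDEX, t2],
    # the returned embedding sequence is
    #   [emb(bos), emb(t1), f_1, ..., f_k, emb(t2)]
    # where f_1..f_k are that image's projected vision features, and the
    # labels inserted for f_1..f_k are IGNORE_INDEX (-100).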
def repack_multimodal_data(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
):
# kentang-mit@: reorder and repack (reduce computation overhead)
# requires transformers replacement.
new_inputs_embeds = []
new_position_ids = []
new_labels = []
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
sorted_seqlens_in_batch, sorted_idx = torch.sort(
seqlens_in_batch, descending=True
)
# print(sorted_seqlens_in_batch)
max_seqlen = inputs_embeds.shape[1]
cur_inputs_embeds = []
cur_position_ids = []
cur_labels = []
cur_batch_len = 0
# print(sorted_seqlens_in_batch.device, len(sorted_seqlens_in_batch), max_seqlen)
for i in range(len(sorted_seqlens_in_batch)):
cur_seqlen = sorted_seqlens_in_batch[i].item()
if cur_seqlen + cur_batch_len <= max_seqlen:
cur_batch_len += cur_seqlen
# each item: num_tokens x num_channels
# remove padding on-the-fly
cur_inputs_embeds.append(
inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]]
)
# each item: num_tokens
cur_position_ids.append(
torch.arange(
cur_inputs_embeds[-1].shape[0],
device=cur_inputs_embeds[-1].device,
)
)
# each item: num_tokens
# remove padding on-the-fly
cur_labels.append(labels[sorted_idx[i]][attention_mask[sorted_idx[i]]])
else:
new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
new_position_ids.append(torch.cat(cur_position_ids, 0))
new_labels.append(torch.cat(cur_labels, 0))
# The current batch is too long. We will start a new batch.
cur_batch_len = cur_seqlen
cur_inputs_embeds = [
inputs_embeds[sorted_idx[i]][attention_mask[sorted_idx[i]]]
]
cur_position_ids = [
torch.arange(
cur_inputs_embeds[-1].shape[0],
device=cur_inputs_embeds[-1].device,
)
]
cur_labels = [labels[sorted_idx[i]][attention_mask[sorted_idx[i]]]]
if len(cur_inputs_embeds):
new_inputs_embeds.append(torch.cat(cur_inputs_embeds, 0))
new_position_ids.append(torch.cat(cur_position_ids, 0))
new_labels.append(torch.cat(cur_labels, 0))
# print(new_position_ids[0].device, [x.shape for x in new_inputs_embeds], [x.shape for x in new_labels], [x.shape for x in new_position_ids])
# assert 0
new_inputs_embeds = torch.nn.utils.rnn.pad_sequence(
new_inputs_embeds, batch_first=True, padding_value=self.llm.pad_token_id
)
new_position_ids = torch.nn.utils.rnn.pad_sequence(
new_position_ids, batch_first=True, padding_value=-1
)
new_labels = torch.nn.utils.rnn.pad_sequence(
new_labels, batch_first=True, padding_value=IGNORE_INDEX
)
# yunhao: it's currently a workaround to avoid errors for seq_len < 100
new_attention_mask = new_position_ids.ne(-1)
# sanity check
assert new_attention_mask.sum() == attention_mask.sum()
# print(new_inputs_embeds.shape, (new_attention_mask.sum(1)))
# print(sorted_seqlens_in_batch.device, sorted_seqlens_in_batch, new_attention_mask.sum(1))
# return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
return (
None,
new_position_ids,
new_attention_mask,
past_key_values,
new_inputs_embeds,
new_labels,
sorted_seqlens_in_batch,
)
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
# TODO yunhao: handle cases for <im_st> <im_end>
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
)
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[
-num_new_tokens:
]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
)
elif model_args.mm_use_im_patch_token:
if model_args.mm_projector:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
# This file is modified from https://github.com/haotian-liu/LLaVA/
import torch # noqa
def build_mm_projector(
model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
if model_type_or_path is None:
return None
# load from pretrained model
if config.resume_path:
assert os.path.exists(
model_type_or_path
), f"Resume mm projector path {model_type_or_path} does not exist!"
return MultimodalProjector.from_pretrained(
model_type_or_path, config, torch_dtype=eval(config.model_dtype)
)
# build from scratch
else:
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path)
mm_projector = MultimodalProjector(mm_projector_cfg, config).to(
eval(config.model_dtype)
)
return mm_projector
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": "identity"}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
class DownSampleBlock(nn.Module):
def forward(self, x):
vit_embeds = x
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.flat_square(vit_embeds)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
return vit_embeds
def flat_square(self, x):
n, w, h, c = x.size()
if w % 2 == 1:
x = torch.concat(
[x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1
).contiguous()
n, w, h, c = x.size()
if h % 2 == 1:
x = torch.concat(
[x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2
).contiguous()
n, w, h, c = x.size()
x = x.view(n, w, int(h / 2), int(c * 2))
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h / 2), int(w / 2), int(c * 4))
return x
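# Shape walkthrough for DownSampleBlock/flat_square (illustrative numbers): a
# vision grid of 27x27 = 729 tokens with C channels is zero-padded to 28x28,
# then every 2x2 block of tokens is folded into the channel dimension, giving
# 14x14 = 196 tokens with 4*C channels. The "mlp_downsample" projector below
# therefore starts with LayerNorm(mm_hidden_size * 4).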
class MultimodalProjectorConfig(PretrainedConfig):
model_type = "v2l_projector"
def __init__(self, mm_projector_type: str = None, **kwargs):
super().__init__()
self.mm_projector_type = mm_projector_type
class MultimodalProjector(PreTrainedModel):
config_class = MultimodalProjectorConfig
def __init__(
self, mm_projector_cfg: MultimodalProjectorConfig, config: PretrainedConfig
):
super().__init__(mm_projector_cfg)
mm_projector_type = mm_projector_cfg.mm_projector_type
if mm_projector_type == "identity":
self.layers = IdentityMap()
elif mm_projector_type == "linear":
self.layers = nn.Linear(config.mm_hidden_size, config.hidden_size)
elif mm_projector_type == "mlp_downsample":
self.layers = nn.Sequential(
DownSampleBlock(),
nn.LayerNorm(config.mm_hidden_size * 4),
nn.Linear(config.mm_hidden_size * 4, config.hidden_size),
nn.GELU(),
nn.Linear(config.hidden_size, config.hidden_size),
)
else:
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", mm_projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
self.layers = nn.Sequential(*modules)
else:
raise ValueError(f"Unknown projector type: {mm_projector_type}")
def forward(self, x, *args, **kwargs):
return self.layers(x)
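# Minimal usage sketch (hypothetical names; `llm_config` stands in for the main
# model config, which must provide `mm_hidden_size` and `hidden_size`):
#   proj_cfg = MultimodalProjectorConfig(mm_projector_type="mlp2x_gelu")
#   projector = MultimodalProjector(proj_cfg, llm_config)
#   tokens = projector(vision_features)  # (B, N, mm_hidden_size) -> (B, N, hidden_size)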
AutoConfig.register("v2l_projector", MultimodalProjectorConfig)
AutoModel.register(MultimodalProjectorConfig, MultimodalProjector)
# This file is modified from https://github.com/haotian-liu/LLaVA/
import torch # noqa
def build_vision_tower(
model_name_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
# skip vision tower instantiation
if model_name_or_path is None:
return None
vision_tower_arch = None
if config.resume_path and "radio" not in model_name_or_path:
assert os.path.exists(
model_name_or_path
), f"Resume vision tower path {model_name_or_path} does not exist!"
vision_tower_cfg = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=True
)
vision_tower_arch = vision_tower_cfg.architectures[0].lower()
vision_tower_name = (
vision_tower_arch if vision_tower_arch is not None else model_name_or_path
)
if "siglip" in vision_tower_name:
vision_tower = SiglipVisionTower(model_name_or_path, config)
else:
raise ValueError(f"Unknown vision tower: {model_name_or_path}")
config.mm_hidden_size = vision_tower.config.hidden_size
return vision_tower
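# Note: build_vision_tower also writes `config.mm_hidden_size` as a side effect;
# build_mm_projector above relies on that value to size the projector's first
# layer, so the vision tower is expected to be built before the projector.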
def build_context_provider(
model_type_or_path: str, config: PretrainedConfig
) -> PreTrainedModel:
if model_type_or_path is None:
return None
# load from pretrained model
if config.resume_path:
assert os.path.exists(
model_type_or_path
), f"Resume context provider path {model_type_or_path} does not exist!"
return ContextProvider.from_pretrained(
model_type_or_path, config, torch_dtype=eval(config.model_dtype)
)
# build from scratch
else:
mm_projector_cfg = ContextProviderConfig(model_type_or_path)
mm_projector = ContextProvider(mm_projector_cfg, config).to(
eval(config.model_dtype)
)
return mm_projector
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# import deepspeed
class ContextProviderConfig(PretrainedConfig):
model_type = "context_provider"
def __init__(
self,
context_provider_type: str = None,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
num_mask_channels=0,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
zero_init_output=True,
residual_dropout=0.0,
context_image_as_queries=False,
context_provider_layer_indices=None,
masked_cross_attn=False,
crop_position_single_embedding=False,
trainable_crop_position_embedding=True,
crop_embedding_mode="add",
treat_image_as_cimage=False,
**kwargs,
):
super().__init__(**kwargs)
self.context_provider_type = context_provider_type
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.num_mask_channels = num_mask_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.zero_init_output = zero_init_output
self.residual_dropout = residual_dropout
self.context_image_as_queries = context_image_as_queries
# cross_attn_end_to_all
# the `num_hidden_layers` should be the same as the one in the vision tower
self.num_hidden_layers = num_hidden_layers
self.context_provider_layer_indices = context_provider_layer_indices
self.masked_cross_attn = masked_cross_attn
# If enabled, crop_position_embedding (delta to full pos) will be updated during training.
self.trainable_crop_position_embedding = trainable_crop_position_embedding
# If enabled, crop_position_embedding (delta to full pos) will be a single embedding for all positions.
self.crop_position_single_embedding = crop_position_single_embedding
# add: delta. replace: do not add the original positional embedding
self.crop_embedding_mode = crop_embedding_mode
# If True, the input image will be treated as a cimage (with mask as full 1s)
self.treat_image_as_cimage = treat_image_as_cimage
# Context Provider
class ContextProviderCrossAttention(nn.Module):
"""Multi-headed cross-attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
batch_size, kv_len, _ = encoder_hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(encoder_hidden_states)
value_states = self.v_proj(encoder_hidden_states)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
batch_size, kv_len, self.num_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
batch_size, kv_len, self.num_heads, self.head_dim
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# Visualizations (-inf are shown as white)
# import matplotlib.pyplot as plt
# plt.imshow(attention_mask[0, 0, 0].view(27, 27).detach().cpu().numpy())
# plt.title("Attention mask")
# plt.colorbar()
# plt.show()
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
# Visualizations: show the attention weights of the first head, with the first query
# import matplotlib.pyplot as plt
# plt.imshow(attn_weights[0, 0, 0].view(27, 27).detach().cpu().numpy())
# plt.title("Attention weights")
# plt.colorbar()
# plt.show()
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class ContextProviderMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
def get_token_mask_bias(mask, patch_size):
# Note: mask should be (0, 1)
with torch.no_grad():
# Add a channel dimension and perform conv
# mask_tokens_after_conv: (B, 1, H, W), example dimension: [1, 1, 27, 27]
mask_tokens_after_conv = F.conv2d(
input=mask[:, None],
weight=torch.ones(
(1, 1, patch_size, patch_size), device=mask.device, dtype=mask.dtype
),
bias=None,
stride=(patch_size, patch_size),
padding="valid",
)
token_mask_bias = torch.zeros_like(mask_tokens_after_conv)
token_mask_bias.masked_fill_(mask_tokens_after_conv < 1e-5, float("-inf"))
token_mask_bias = token_mask_bias.flatten(1)
# Flattened dimension: (1, 729)
return token_mask_bias
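# Worked example for get_token_mask_bias (illustrative sizes): with a 378x378
# binary mask and patch_size=14, the all-ones conv counts mask pixels inside
# each of the 27x27 patches; patches containing no mask pixels get a -inf bias
# and are therefore excluded from the softmax in masked cross-attention.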
def attn_mask_from_cimage_concatenated(cimage_concatenated, patch_size):
# Use the mask from input image (4th channel)
mask_normalized = cimage_concatenated[:, 3]
mask_unnormalized = (mask_normalized + 1) / 2
# (1, 729)
token_mask_bias = get_token_mask_bias(mask_unnormalized, patch_size=patch_size)
# attn_mask: (B, 1, Q, KV)
# print("Token positions:", token_mask.nonzero())
# Obtain token mask in the bias format: in mask 0, out of mask -inf
q_kv = token_mask_bias.shape[-1]
attn_mask_bias = token_mask_bias[:, None, None, :].repeat(1, 1, q_kv, 1)
# Visualizations
# print(f"token_mask_bias shape: {token_mask_bias.shape}, attn_mask_bias shape: {attn_mask_bias.shape}")
# import matplotlib.pyplot as plt
# plt.imshow(attn_mask_bias[0, 0, 0].view(27, 27).detach().cpu().numpy())
# plt.title("Attention mask (outside)")
# plt.show()
return attn_mask_bias
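# The (B, N) bias is broadcast to (B, 1, N, N): every query row sees the same
# set of masked-out keys, so attention can only gather from in-mask patches.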
# Adapted from SiglipEncoderLayer, with self-attention replaced by cross-attention.
class CrossAttnEncoderLayer(nn.Module):
def __init__(self, config: ContextProviderConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.cross_attn = ContextProviderCrossAttention(config)
self.residual_dropout = nn.Dropout(config.residual_dropout)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = ContextProviderMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
if config.zero_init_output:
# TODO: alternatively, we could parameterize with an MLP
# These factors are initialized with 0 (so only residual passes through)
if config.context_provider_type != "cross_attn_at_the_end":
self.register_parameter("attn_factor", nn.Parameter(torch.zeros((1,))))
self.register_parameter("mlp_factor", nn.Parameter(torch.zeros((1,))))
else:
# Use scalar tensor for compatibility
self.register_parameter(
"attn_factor", nn.Parameter(torch.zeros((1,)).view(()))
)
self.register_parameter(
"mlp_factor", nn.Parameter(torch.zeros((1,)).view(()))
)
else:
self.attn_factor = 1.0
self.mlp_factor = 1.0
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.cross_attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
        # Dropout on the residual path encourages the model to rely more on the context.
hidden_states = (
self.residual_dropout(residual) + self.attn_factor * hidden_states
)
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.mlp_factor * hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
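# Gating note: with `zero_init_output=True`, `attn_factor` and `mlp_factor`
# start at 0, so at initialization this layer is an identity mapping over
# `hidden_states` (up to residual dropout); the cross-attended context is
# blended in gradually as the factors are learned.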
class CrossAttnContextProviderEndToAll(nn.Module):
def __init__(self, config: ContextProviderConfig):
super().__init__()
self.layers = nn.ModuleList(
[
CrossAttnEncoderLayer(config)
                for i in range(config.num_hidden_layers)
if config.context_provider_layer_indices is None
or i in config.context_provider_layer_indices
]
)
self.patch_size = config.patch_size
self.masked_cross_attn = config.masked_cross_attn
def forward(self, context_image_features, cimage_concatenated, vision_tower):
# Use the mask from input image (4th channel)
if self.masked_cross_attn:
attn_mask = attn_mask_from_cimage_concatenated(
cimage_concatenated, patch_size=self.patch_size
)
else:
attn_mask = None
detail_raw_image = cimage_concatenated[:, 4:, ...]
        # NOTE: when using the context image as queries, the context image is
        # swapped with the detail image before being passed into the context provider.
outputs = vision_tower(
detail_raw_image,
context_provider_layers=self.layers,
contexts=context_image_features,
cross_attention_mask=attn_mask,
)
return outputs
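# Channel layout assumed by the forward above (inferred from the slicing here
# and in attn_mask_from_cimage_concatenated): channels 0-2 hold the full
# (context) image, channel 3 holds the mask in [-1, 1], and channels 4+ hold
# the focal crop ("detail") image.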
class ContextProvider(PreTrainedModel):
config_class = ContextProviderConfig
def __init__(
self, context_provider_cfg: ContextProviderConfig, config: PretrainedConfig
):
super().__init__(context_provider_cfg)
self.context_image_as_queries = context_provider_cfg.context_image_as_queries
self.context_provider_type = context_provider_type = (
context_provider_cfg.context_provider_type
)
self.treat_image_as_cimage = context_provider_cfg.treat_image_as_cimage
if self.context_image_as_queries:
assert (
not context_provider_cfg.masked_cross_attn
), "Masked cross-attention not implemented when using context image as queries."
assert (
"concat" not in context_provider_type
), "Concat not implemented when using context image as queries."
if context_provider_type == "cross_attn_end_to_all":
# Information flow: end of context features -> all detail features
self.context_provider_module = CrossAttnContextProviderEndToAll(
context_provider_cfg
)
else:
raise ValueError(f"Unknown context provider type: {context_provider_type}")
def forward(
self,
cimage_full_features=None,
cimage_crop_features=None,
cimage_concatenated=None,
vision_tower=None,
):
if self.context_provider_type == "cross_attn_end_to_all":
assert (
cimage_full_features.shape[0] == cimage_concatenated.shape[0]
), f"shape mismatches: {cimage_full_features.shape[0]} != {cimage_concatenated.shape[0]}"
return self.context_provider_module(
context_image_features=cimage_full_features,
cimage_concatenated=cimage_concatenated,
vision_tower=vision_tower,
)
else:
            raise ValueError(f"Unknown context provider type: {self.context_provider_type}")
AutoConfig.register("context_provider", ContextProviderConfig)
AutoModel.register(ContextProviderConfig, ContextProvider)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for RADIO."""
if is_torch_available():
import torch
if is_torchvision_available():
pass
if is_tf_available():
pass
logger = logging.get_logger(__name__)
def rank_print(s):
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
print(f"[Rank {rank}] {s}")
class ImageProcessor(BaseImageProcessor):
r"""
Constructs an image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"longest_edge": 1024}`):
Size of the output image after resizing. If "longest_edge" is specified, resizes the longest edge of the image to match
`size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
to that size, possibly changing the aspect ratio. Can be overridden by the `size` parameter in the
`preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
`preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to the specified `pad_size`. Can be overridden by the `do_pad` parameter in the
`preprocess` method.
pad_size (`dict`, *optional*, defaults to `{"height": 1024, "width": 1024}`):
Size of the output image after padding. Can be overridden by the `pad_size` parameter in the `preprocess`
method.
pad_value (`float` or `Iterable[float]`, *optional*, defaults to `0.`):
Value of padded pixels.
        pad_multiple (`int`, *optional*, defaults to `None`):
            Pad the image so that its height and width are multiples of this number.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
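
    Example (a minimal usage sketch with illustrative values):

    ```python
    >>> import numpy as np
    >>> processor = ImageProcessor(size={"longest_edge": 1024}, pad_multiple=16)
    >>> image = np.zeros((480, 640, 3), dtype=np.uint8)
    >>> features = processor.preprocess(image, return_tensors="pt")
    >>> tuple(features["pixel_values"].shape)  # longest edge -> 1024, padded to multiples of 16
    (1, 3, 768, 1024)
    ```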
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: bool = True,
        pad_size: Dict[str, int] = None,
pad_multiple: int = None,
pad_value: Optional[Union[float, List[float]]] = 0.0,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"longest_edge": 1024}
size = (
get_size_dict(max_size=size, default_to_square=False)
if not isinstance(size, dict)
else size
)
if pad_size is not None and pad_multiple is not None:
raise ValueError(
"pad_size and pad_multiple should not be set at the same time."
)
pad_size = (
pad_size
if pad_size is not None
else {"height": 1024, "width": 1024} if pad_multiple is not None else None
)
if do_pad:
pad_size = get_size_dict(pad_size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = (
image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
)
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
self.do_pad = do_pad
self.pad_multiple = pad_multiple
self.pad_size = pad_size
self.pad_value = tuple(pad_value) if isinstance(pad_value, list) else pad_value
self.do_convert_rgb = do_convert_rgb
self._valid_processor_keys = [
"images",
"segmentation_maps",
"do_resize",
"size",
"resample",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_pad",
"pad_size",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
def pad_image(
self,
image: np.ndarray,
pad_size: Dict[str, int],
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Pad an image to `(pad_size["height"], pad_size["width"])` to the right and bottom.
Args:
image (`np.ndarray`):
Image to pad.
pad_size (`Dict[str, int]`):
Size of the output image after padding.
data_format (`str` or `ChannelDimension`, *optional*):
The data format of the image. Can be either "channels_first" or "channels_last". If `None`, the
`data_format` of the `image` will be used.
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
output_height, output_width = pad_size["height"], pad_size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
pad_width = output_width - input_width
pad_height = output_height - input_height
padded_image = pad(
image,
((0, pad_height), (0, pad_width)),
data_format=data_format,
input_data_format=input_data_format,
constant_values=self.pad_value,
**kwargs,
)
return padded_image
def _get_preprocess_shape(self, old_shape: Tuple[int, int], longest_edge: int):
"""
Compute the output size given input size and target long side length.
"""
oldh, oldw = old_shape
scale = longest_edge * 1.0 / max(oldh, oldw)
newh, neww = oldh * scale, oldw * scale
newh = int(newh + 0.5)
neww = int(neww + 0.5)
return (newh, neww)
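    # Worked example for _get_preprocess_shape: old_shape=(480, 640) with
    # longest_edge=1024 gives scale = 1024 / 640 = 1.6, so the output is
    # (768, 1024); adding 0.5 before int() rounds to the nearest integer
    # instead of truncating.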
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"longest_edge": int}` or `{"width": int, "height": int}` specifying the size
of the output image. If "longest_edge" is specified, resizes the longest edge of the image to match
`size["longest_edge"]` while maintaining the aspect ratio. If "width" and "height" are specified, resizes the image
to that size, possibly changing the aspect ratio.
resample:
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "longest_edge" not in size:
if "width" not in size or "height" not in size:
raise ValueError(
f"The `size` dictionary must contain the key `longest_edge`, or `width` and `height`. Got {size.keys()}"
)
input_size = get_image_size(image, channel_dim=input_data_format)
if "longest_edge" in size:
output_height, output_width = self._get_preprocess_shape(
input_size, size["longest_edge"]
)
else:
output_height, output_width = size["height"], size["width"]
return resize(
image,
size=(output_height, output_width),
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def _preprocess(
self,
image: ImageInput,
do_resize: bool,
do_rescale: bool,
do_normalize: bool,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
rescale_factor: Optional[float] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[Dict[str, int]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
if do_resize:
image = self.resize(
image=image,
size=size,
resample=resample,
input_data_format=input_data_format,
)
reshaped_input_size = get_image_size(image, channel_dim=input_data_format)
if do_rescale:
image = self.rescale(
image=image, scale=rescale_factor, input_data_format=input_data_format
)
if do_normalize:
image = self.normalize(
image=image,
mean=image_mean,
std=image_std,
input_data_format=input_data_format,
)
if do_pad:
if self.pad_multiple:
h, w = get_image_size(image, channel_dim=input_data_format)
pad_size = {
"height": math.ceil(h / self.pad_multiple) * self.pad_multiple,
"width": math.ceil(w / self.pad_multiple) * self.pad_multiple,
}
image = self.pad_image(
image=image, pad_size=pad_size, input_data_format=input_data_format
)
return image, reshaped_input_size
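    # Padding example (illustrative numbers): with pad_multiple=16, an image
    # resized to (750, 1000) is padded on the bottom/right to (752, 1008),
    # the next multiples of 16.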
def _preprocess_image(
self,
image: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[Dict[str, int]] = None,
do_convert_rgb: Optional[bool] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
if isinstance(image, Image.Image):
# PIL always uses Channels Last.
input_data_format = ChannelDimension.LAST
# PIL RGBA images are converted to RGB
if do_convert_rgb:
image = convert_to_rgb(image)
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if len(image.shape) == 2:
h, w = image.shape
ret = np.empty((h, w, 3), dtype=np.uint8)
ret[:, :, 0] = image
ret[:, :, 1] = image
ret[:, :, 2] = image
image = ret
rank_print(f"preprocess new image shape={image.shape}")
        elif len(image.shape) == 3 and image.shape[-1] == 1:
            # Single-channel (H, W, 1) images are expanded to three channels.
            h, w = image.shape[:2]
            ret = np.empty((h, w, 3), dtype=np.uint8)
            ret[:, :, 0] = image[:, :, 0]
            ret[:, :, 1] = image[:, :, 0]
            ret[:, :, 2] = image[:, :, 0]
            image = ret
            rank_print(f"preprocess new image shape={image.shape}")
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)
original_size = get_image_size(image, channel_dim=input_data_format)
image, reshaped_input_size = self._preprocess(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
pad_size=pad_size,
input_data_format=input_data_format,
)
if data_format is not None:
image = to_channel_dimension_format(
image, data_format, input_channel_dim=input_data_format
)
        # If the image is still single-channel at this point, expand it to RGB.
        # The image is float after rescale/normalize, so keep its dtype rather
        # than casting to uint8.
        if do_convert_rgb and image.shape[0] == 1:
            c, h, w = image.shape
            ret = np.empty((3, h, w), dtype=image.dtype)
            ret[0, :, :] = image[0, :, :]
            ret[1, :, :] = image[0, :, :]
            ret[2, :, :] = image[0, :, :]
            image = ret
rank_print(f"preprocess final: {image.shape}")
return image, original_size, reshaped_input_size
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: Optional["PILImageResampling"] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[Union[int, float]] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_pad: Optional[bool] = None,
pad_size: Optional[Dict[str, int]] = None,
do_convert_rgb: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
):
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The longest edge of the image is resized to
`size["longest_edge"]` whilst preserving the aspect ratio.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image pixel values by rescaling factor.
rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to apply to the image pixel values.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_pad (`bool`, *optional*, defaults to `self.do_pad`):
Whether to pad the image.
pad_size (`Dict[str, int]`, *optional*, defaults to `self.pad_size`):
Controls the size of the padding applied to the image. The image is padded to `pad_size["height"]` and
`pad_size["width"]` if `do_pad` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = (
get_size_dict(max_size=size, default_to_square=False)
if not isinstance(size, dict)
else size
)
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_pad = do_pad if do_pad is not None else self.do_pad
pad_size = pad_size if pad_size is not None else self.pad_size
if do_pad:
pad_size = get_size_dict(pad_size, default_to_square=True)
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
images, original_sizes, reshaped_input_sizes = zip(
*(
self._preprocess_image(
image=img,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
pad_size=pad_size,
do_convert_rgb=do_convert_rgb,
data_format=data_format,
input_data_format=input_data_format,
)
for img in images
)
)
data = {
"pixel_values": images,
"original_sizes": original_sizes,
"reshaped_input_sizes": reshaped_input_sizes,
}
return BatchFeature(data=data, tensor_type=return_tensors)
# This file is modified from https://github.com/haotian-liu/LLaVA/
class VisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = getattr(args, "mm_vision_select_layer", -2)
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
self.cfg_only = None
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def _maybe_resize_pos_embeds(
self,
model: PreTrainedModel,
image_processor: BaseImageProcessor,
resolution: int = -1,
interpolate_mode: str = "linear",
):
if resolution in [model.config.image_size, -1]:
return
print(
f"Resizing vision model's position embeddings to support higher vision resolution: from {model.config.image_size} to {resolution} ..."
)
embeddings = model.vision_model.embeddings
patch_size = embeddings.patch_size
num_new_tokens = int((resolution // patch_size) ** 2)
old_embeddings = embeddings.position_embedding
match interpolate_mode:
case "linear":
# Step 1: Calculate the corresponding patch ID (pid) in the current resolution (M patches) based on the target resolution (N patches). Formula: pid = pid / N * M
# Step 2: Obtain new embeddings by interpolating between the embeddings of the two nearest calculated patch IDs. Formula: new_embeds = (pid - floor(pid)) * embeds[ceil(pid)] + (ceil(pid) - pid) * embeds[floor(pid)]
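                # Worked example (hypothetical resolutions): growing from 729
                # positions (378/14 = 27 per side) to 1444 (532/14 = 38 per
                # side) maps new position 100 to pid = 100 / 1443 * 728 ≈ 50.45,
                # so its embedding blends rows 50 and 51 of the old table with
                # weights ≈ 0.55 and 0.45.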
# At inference time, we assume deepspeed zero3 is not enabled.
# import deepspeed
# with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
# old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
new_embeddings = nn.Embedding(
num_new_tokens,
old_embedding_dim,
dtype=old_embeddings.weight.dtype,
device=old_embeddings.weight.device,
)
mapped_indices = (
torch.arange(num_new_tokens).to(old_embeddings.weight.device)
/ (num_new_tokens - 1)
* (old_num_tokens - 1)
)
floor_indices = torch.clamp(
mapped_indices.floor().long(), min=0, max=old_num_tokens - 1
)
ceil_indices = torch.clamp(
mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1
)
# At inference time, we assume deepspeed zero3 is not enabled.
# params = [old_embeddings.weight, new_embeddings.weight]
# with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
# interpolated_embeds = (mapped_indices - floor_indices)[:, None] * old_embeddings.weight.data[
# ceil_indices, :
# ] + (ceil_indices - mapped_indices)[:, None] * old_embeddings.weight.data[floor_indices, :]
                frac = (mapped_indices - floor_indices)[:, None]
                # Linear interpolation between the two nearest old embeddings;
                # using 1 - frac (rather than ceil - mapped) keeps exact-integer
                # positions on the floor embedding instead of zeroing them out.
                interpolated_embeds = (
                    (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                    + frac * old_embeddings.weight.data[ceil_indices, :]
                )
new_embeddings.weight.data = interpolated_embeds
case _:
raise NotImplementedError
if hasattr(old_embeddings, "_hf_hook"):
hook = old_embeddings._hf_hook
add_hook_to_module(new_embeddings, hook)
new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
# update vision encoder's configurations
model.config.image_size = resolution
if hasattr(image_processor, "crop_size"):
# CLIP vision tower
image_processor.crop_size = resolution
else:
# SIGLIP vision tower
assert hasattr(image_processor, "size")
image_processor.size = {"height": resolution, "width": resolution}
# TODO define a '_reinitialize' method for VisionTower
embeddings.position_embedding = new_embeddings
embeddings.image_size = resolution
embeddings.num_patches = embeddings.num_positions = num_new_tokens
embeddings.position_ids = (
torch.arange(embeddings.num_positions)
.expand((1, -1))
.to(old_embeddings.weight.device)
)
def forward(self, images, **kwargs):
        if isinstance(images, list):
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
**kwargs,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
**kwargs,
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Siglip model configuration"""
logger = logging.get_logger(__name__)
SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json",
}
class SiglipTextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
the `inputs_ids` passed when calling [`SiglipModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
max_position_embeddings (`int`, *optional*, defaults to 64):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
pad_token_id (`int`, *optional*, defaults to 1):
The id of the padding token in the vocabulary.
bos_token_id (`int`, *optional*, defaults to 49406):
The id of the beginning-of-sequence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 49407):
The id of the end-of-sequence token in the vocabulary.
Example:
```python
>>> from transformers import SiglipTextConfig, SiglipTextModel
>>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipTextConfig()
>>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipTextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_text_model"
def __init__(
self,
vocab_size=32000,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=64,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
# This differs from `CLIPTokenizer`'s default and from openai/siglip
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
pad_token_id=1,
bos_token_id=49406,
eos_token_id=49407,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.max_position_embeddings = max_position_embeddings
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.attention_dropout = attention_dropout
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
# cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
)
# get the text config dict if we are loading from SiglipConfig
if config_dict.get("model_type") == "siglip":
config_dict = config_dict["text_config"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class SiglipVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_mask_channels (`int`, *optional*, defaults to 0):
Number of mask channels in the input images.
Example:
```python
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipVisionConfig()
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_mask_channels=0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.num_mask_channels = num_mask_channels
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> "PretrainedConfig":
# cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs
)
# get the vision config dict if we are loading from SiglipConfig
if config_dict.get("model_type") == "siglip":
config_dict = config_dict["vision_config"]
if (
"model_type" in config_dict
and hasattr(cls, "model_type")
and config_dict["model_type"] != cls.model_type
):
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class SiglipConfig(PretrainedConfig):
r"""
[`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
text_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`SiglipTextConfig`].
vision_config (`dict`, *optional*):
Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
kwargs (*optional*):
Dictionary of keyword arguments.
Example:
```python
>>> from transformers import SiglipConfig, SiglipModel
>>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipConfig()
>>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
>>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
>>> from transformers import SiglipTextConfig, SiglipVisionConfig
>>> # Initializing a SiglipText and SiglipVision configuration
>>> config_text = SiglipTextConfig()
>>> config_vision = SiglipVisionConfig()
>>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
```"""
model_type = "siglip"
def __init__(self, text_config=None, vision_config=None, **kwargs):
super().__init__(**kwargs)
if text_config is None:
text_config = {}
logger.info(
"`text_config` is `None`. Initializing the `SiglipTextConfig` with default values."
)
if vision_config is None:
vision_config = {}
            logger.info(
                "`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values."
            )
self.text_config = SiglipTextConfig(**text_config)
self.vision_config = SiglipVisionConfig(**vision_config)
self.initializer_factor = 1.0
@classmethod
def from_text_vision_configs(
cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs
):
r"""
Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
model configuration.
Returns:
[`SiglipConfig`]: An instance of a configuration object
"""
return cls(
text_config=text_config.to_dict(),
vision_config=vision_config.to_dict(),
**kwargs,
)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for SigLIP."""
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD

logger = logging.get_logger(__name__)
def is_scaled_image(image: np.ndarray) -> bool:
"""
Checks to see whether the pixel values have already been rescaled to [0, 1].
"""
if image.dtype == np.uint8:
return False
# It's possible the image has pixel values in [0, 255] but is of floating type
return np.min(image) >= 0 and np.max(image) <= 1
if is_vision_available():
import PIL
class SiglipImageProcessor(BaseImageProcessor):
r"""
Constructs a SigLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
        size (`Dict[str, int]`, *optional*, defaults to `{"shortest_edge": 384}`):
Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
`do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 384}
size = get_size_dict(size, default_to_square=False)
image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
                Resampling filter to use when resizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
"""
# size = get_size_dict(size, default_to_square=False)
default_to_square = True
if "shortest_edge" in size:
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError(
"Size must contain either 'shortest_edge' or 'height' and 'width'."
)
output_size = get_resize_output_image_size(
image, size=size, default_to_square=default_to_square
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
**kwargs,
)
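    # Worked example (sketch): with size={"shortest_edge": 384}, a 600x400 (HxW)
    # input keeps its aspect ratio and is resized to 576x384 -- the shorter edge
    # becomes 384 and the longer edge scales by the same 384/400 factor.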
def preprocess(
self,
images: ImageInput,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
    ) -> BatchFeature:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
# do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
# crop_size = crop_size if crop_size is not None else self.crop_size
# crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = (
rescale_factor if rescale_factor is not None else self.rescale_factor
)
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = (
do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None:
raise ValueError("Size must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError(
"Image mean and std must be specified if do_normalize is True."
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
# if input_data_format is None:
# # We assume that all images have the same channel dimension format.
# input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample)
for image in images
]
if do_rescale:
images = [rescale(image=image, scale=rescale_factor) for image in images]
if do_normalize:
output_images = []
for image in images:
if get_channel_dimension_axis(image) == 0:
image = image.transpose((1, 2, 0))
if image.shape[-1] == 1:
image = np.dstack((image, image, image))
output_images.append(image)
images = output_images
images = [
normalize(image=image, mean=image_mean, std=image_std)
for image in images
]
images = [to_channel_dimension_format(image, data_format) for image in images]
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
# coding=utf-8
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Siglip model."""
# from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
logger = logging.get_logger(__name__)
# _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
# SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
# "google/siglip-base-patch16-224",
# # See all SigLIP models at https://huggingface.co/models?filter=siglip
# ]
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor,
mean: float = 0.0,
std: float = 1.0,
a: float = -2.0,
b: float = 2.0,
) -> torch.Tensor:
"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \\leq \\text{mean} \\leq b`.
    NOTE: this 'tf' variant behaves closer to the TensorFlow / JAX impl, where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0,
    and the result is subsequently scaled and shifted by the mean and std args.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
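# Usage sketch (illustrative): because the bounds are applied before scaling,
# trunc_normal_tf_(w, std=0.02) fills `w` with values confined to
# [-2 * 0.02, 2 * 0.02] around a mean of 0, e.g.
#   w = torch.empty(768, 768)
#   trunc_normal_tf_(w, std=0.02)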
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        raise ValueError(f"invalid mode {mode}")
    variance = scale / denom
if distribution == "truncated_normal":
# constant is stddev of standard normal truncated to (-2, 2)
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
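# Worked example (sketch): for a Linear(512 -> 256) layer, the weight has shape
# (256, 512), so fan_in = 512 and
# variance_scaling_(w, scale=1.0, mode="fan_in", distribution="normal")
# samples from N(0, 1/512); lecun_normal_ applies the same scaling with a
# truncated normal instead.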
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
class SiglipVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
class SiglipTextModelOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
text_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
class SiglipOutput(ModelOutput):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
        text_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipTextModel`].
        vision_model_output (`BaseModelOutputWithPooling`):
            The output of the [`SiglipVisionModel`].
"""
loss: Optional[torch.FloatTensor] = None
logits_per_image: torch.FloatTensor = None
logits_per_text: torch.FloatTensor = None
text_embeds: torch.FloatTensor = None
image_embeds: torch.FloatTensor = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
(
self[k]
if k not in ["text_model_output", "vision_model_output"]
else getattr(self, k).to_tuple()
)
for k in self.keys()
)
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
        logger.info(f"Number of mask channels: {config.num_mask_channels}")
if config.num_mask_channels:
# Mask should have the same output shape to be added.
# Currently we have bias in this embedding (so that mask vs no mask are different).
self.mask_patch_embedding = nn.Conv2d(
in_channels=config.num_mask_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.mask_patch_embedding.use_zero_init = True
else:
self.mask_patch_embedding = None
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False,
)
def forward(
self,
pixel_values: torch.FloatTensor,
additional_position_embedding: Optional[torch.Tensor] = None,
additional_embedding_mode: Optional[str] = None,
) -> torch.Tensor:
if self.mask_patch_embedding is None:
patch_embeds = self.patch_embedding(
pixel_values
) # shape = [*, width, grid, grid]
else:
# Comment this out if you want to encode both images without mask channel and with mask channel.
# However, if different samples in the batch have different number of channels, this is not applicable.
# assert pixel_values.size(1) == 4, f"Input does not have a mask channel, shape: {pixel_values.shape}"
patch_embeds = self.patch_embedding(
pixel_values[:, :3, ...]
) # shape = [*, width, grid, grid]
if pixel_values.size(1) == 4:
patch_embeds = patch_embeds + self.mask_patch_embedding(
pixel_values[:, 3:4, ...]
)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
if additional_position_embedding is not None:
if additional_embedding_mode == "add":
embeddings = embeddings + self.position_embedding(self.position_ids)
embeddings = embeddings + additional_position_embedding
elif additional_embedding_mode == "replace":
                # The original positional embedding is not used here; it is multiplied by zero so that its parameters still participate in the computation graph.
embeddings = (
embeddings + self.position_embedding(self.position_ids) * 0.0
)
embeddings = embeddings + additional_position_embedding
else:
raise ValueError(
f"additional_embedding_mode should be either 'add' or 'replace', got {additional_embedding_mode}"
)
else:
# Without additional position embedding
embeddings = embeddings + self.position_embedding(self.position_ids)
# print("No additional position embedding")
return embeddings
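# Input sketch for SiglipVisionEmbeddings (illustrative; `embeddings_module` is
# a stand-in name): a 4-channel input is interpreted as RGB plus a
# single-channel focal mask, e.g.
#   x = torch.cat([rgb, mask], dim=1)   # (B, 4, H, W)
#   embeds = embeddings_module(x)       # (B, num_patches, embed_dim)
# A 3-channel input skips the mask branch entirely.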
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
class SiglipTextEmbeddings(nn.Module):
def __init__(self, config: SiglipTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = (
input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
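# Shape sketch for SiglipAttention (illustrative numbers): with batch_size=2,
# q_len=196, num_heads=12, head_dim=64 (embed_dim=768), q/k/v are each
# (2, 12, 196, 64), attn_weights is (2, 12, 196, 196), and the returned
# attn_output is (2, 196, 768) after the final reshape and out_proj.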
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class SiglipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SiglipConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, SiglipVisionEmbeddings):
width = (
self.config.vision_config.hidden_size
if isinstance(self.config, SiglipConfig)
else self.config.hidden_size
)
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, SiglipAttention):
nn.init.xavier_uniform_(module.q_proj.weight)
nn.init.xavier_uniform_(module.k_proj.weight)
nn.init.xavier_uniform_(module.v_proj.weight)
nn.init.xavier_uniform_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, SiglipMLP):
nn.init.xavier_uniform_(module.fc1.weight)
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
elif isinstance(module, SiglipModel):
logit_scale_init = torch.log(torch.tensor(1.0))
module.logit_scale.data.fill_(logit_scale_init)
module.logit_bias.data.zero_()
elif isinstance(module, nn.Conv2d) and getattr(module, "use_zero_init", False):
param_list = [module.weight]
if module.bias is not None:
param_list += [module.bias]
# This is used in mask patch embedding
#
# with deepspeed.zero.GatheredParameters(param_list, modifier_rank=0):
# for param in param_list:
# nn.init.zeros_(param)
for param in param_list:
nn.init.zeros_(param)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
SIGLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
SIGLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
SIGLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SiglipEncoderLayer`].
Args:
config: SiglipConfig
"""
def __init__(self, config: SiglipConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
context_provider_layers: Optional[nn.ModuleList] = None,
contexts: Optional[List[torch.Tensor]] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
context_provider_layers (nn.ModuleList): ModuleList of context provider layers.
contexts: List of torch.Tensor for context (for KV in cross-attention).
cross_attention_mask (`torch.Tensor` of shape `(batch_size, q_sequence_length, kv_sequence_length)`, *optional*): mask for cross-attention.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for layer_index, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if context_provider_layers:
# Right now contexts is passed as the encoder_hidden_states (the output hidden_states of the context ViT).
context_provider_layer = context_provider_layers[layer_index]
if context_provider_layer is not None:
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
context_provider_layer.__call__,
hidden_states,
contexts,
cross_attention_mask,
output_attentions,
)
else:
layer_outputs = context_provider_layer(
hidden_states,
contexts,
cross_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, encoder_states, all_attentions]
if v is not None
)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
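# Sketch of the context-provider hook above (descriptive note, not from
# upstream transformers): `context_provider_layers`, when given, is expected to
# align one-to-one with `self.layers` (entries may be None). After
# self-attention layer i, the non-None provider at index i cross-attends from
# the current hidden states (queries) into `contexts` (keys/values), optionally
# restricted by `cross_attention_mask`.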
class SiglipTextTransformer(nn.Module):
def __init__(self, config: SiglipTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipTextEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, embed_dim)
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is None:
raise ValueError("You have to specify input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
# if attention_mask is not None:
# # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
# attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# Assuming "sticky" EOS tokenization, last token is always EOS.
pooled_output = last_hidden_state[:, -1, :]
pooled_output = self.head(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@add_start_docstrings(
"""The text model from SigLIP without any head or projection on top.""",
SIGLIP_START_DOCSTRING,
)
class SiglipTextModel(SiglipPreTrainedModel):
config_class = SiglipTextConfig
_no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
def __init__(self, config: SiglipTextConfig):
super().__init__(config)
self.text_model = SiglipTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig
)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from transformers import AutoTokenizer, SiglipTextModel
>>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
>>> # important: make sure to set padding="max_length" as that's how the model was trained
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled (EOS token) states
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = SiglipMultiheadAttentionPoolingHead(config)
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig
)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = self.head(last_hidden_state)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class SiglipMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(
config.hidden_size, config.num_attention_heads, batch_first=True
)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
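# Pooling sketch (illustrative): the learned probe attends over all patch
# tokens and collapses them into one vector per image, e.g.
#   head = SiglipMultiheadAttentionPoolingHead(config)
#   pooled = head(torch.randn(2, 196, config.hidden_size))  # -> (2, hidden_size)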
@add_start_docstrings(
"""The vision model from SigLIP without any head or projection on top.""",
SIGLIP_START_DOCSTRING,
)
class SiglipVisionModel(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "pixel_values"
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.vision_model = SiglipVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig
)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, SiglipVisionModel
>>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs,
)
@add_start_docstrings(SIGLIP_START_DOCSTRING)
class SiglipModel(SiglipPreTrainedModel):
config_class = SiglipConfig
def __init__(self, config: SiglipConfig):
super().__init__(config)
if not isinstance(config.text_config, SiglipTextConfig):
raise ValueError(
"config.text_config is expected to be of type SiglipTextConfig but is of type"
f" {type(config.text_config)}."
)
if not isinstance(config.vision_config, SiglipVisionConfig):
raise ValueError(
"config.vision_config is expected to be of type SiglipVisionConfig but is of type"
f" {type(config.vision_config)}."
)
text_config = config.text_config
vision_config = config.vision_config
self.text_model = SiglipTextTransformer(text_config)
self.vision_model = SiglipVisionTransformer(vision_config)
self.logit_scale = nn.Parameter(torch.randn(1))
self.logit_bias = nn.Parameter(torch.randn(1))
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
applying the projection layer to the pooled output of [`SiglipTextModel`].
Examples:
```python
>>> from transformers import AutoTokenizer, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
>>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
>>> # important: make sure to set padding="max_length" as that's how the model was trained
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
>>> with torch.no_grad():
... text_features = model.get_text_features(**inputs)
```"""
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = text_outputs[1]
return pooled_output
@add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of [`SiglipVisionModel`].
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> with torch.no_grad():
... image_features = model.get_image_features(**inputs)
```"""
# Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = vision_outputs[1]
return pooled_output
@add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SiglipOutput]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, AutoModel
>>> import torch
>>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
>>> inputs = processor(text=texts, images=image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image
>>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
>>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
31.9% that image 0 is 'a photo of 2 cats'
```"""
# Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]
text_embeds = text_outputs[1]
# normalized features
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# cosine similarity as logits
logits_per_text = (
torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp()
+ self.logit_bias
)
logits_per_image = logits_per_text.t()
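        # Unlike CLIP's batch-level softmax, SigLIP scores each image-text pair
        # independently: sigmoid(logit_scale.exp() * cosine_similarity + logit_bias)
        # is the probability that the pair matches.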
loss = None
if return_loss:
raise NotImplementedError("SigLIP loss to be implemented")
if not return_dict:
output = (
logits_per_image,
logits_per_text,
text_embeds,
image_embeds,
text_outputs,
vision_outputs,
)
return ((loss,) + output) if loss is not None else output
return SiglipOutput(
loss=loss,
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image/Text processor class for SigLIP.
"""
class SiglipProcessor(ProcessorMixin):
r"""
Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
[`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
[`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
Args:
image_processor ([`SiglipImageProcessor`]):
The image processor is a required input.
tokenizer ([`SiglipTokenizer`]):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "SiglipImageProcessor"
tokenizer_class = "SiglipTokenizer"
def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
def __call__(
self,
text: Union[
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
] = None,
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = "max_length",
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` argument to
        SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `max_length`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if text is None and images is None:
raise ValueError(
"You have to specify either text or images. Both cannot be none."
)
if text is not None:
encoding = self.tokenizer(
text,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors)
if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None:
return encoding
else:
return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
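# Hedged usage sketch for SiglipProcessor; the inputs are illustrative, not
# taken from this file:
#
#   processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
#   inputs = processor(text=["a photo of a cat"], images=pil_image)
#   # -> BatchFeature with "input_ids" (padded to max_length by default) and
#   #    "pixel_values" when images are given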
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for SigLIP model."""
if TYPE_CHECKING:
from transformers.tokenization_utils_base import TextInput
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/spiece.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"google/siglip-base-patch16-224": 256,
}
SPIECE_UNDERLINE = "▁"
class SiglipTokenizer(PreTrainedTokenizer):
"""
Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
contains the vocabulary necessary to instantiate a tokenizer.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (`str`, *optional*, defaults to `"</s>"`):
The token used for padding, for example when batching sequences of different lengths.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
sp_model_kwargs (`dict`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
to set:
- `enable_sampling`: Enable subword regularization.
- `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
- `nbest_size = {0,1}`: No sampling is performed.
- `nbest_size > 1`: samples from the nbest_size results.
- `nbest_size < 0`: assumes that nbest_size is infinite and samples from all hypotheses (lattice)
using forward-filtering-and-backward-sampling algorithm.
- `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
BPE-dropout.
model_max_length (`int`, *optional*, defaults to 64):
The maximum length (in number of tokens) for model inputs.
do_lower_case (`bool`, *optional*, defaults to `True`):
Whether or not to lowercase the input when tokenizing.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["input_ids", "attention_mask"]
def __init__(
self,
vocab_file,
eos_token="</s>",
unk_token="<unk>",
pad_token="</s>",
additional_special_tokens=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
model_max_length=64,
do_lower_case=True,
**kwargs,
) -> None:
requires_backends(self, "protobuf")
pad_token = (
AddedToken(
pad_token, rstrip=True, lstrip=True, normalized=False, special=True
)
if isinstance(pad_token, str)
else pad_token
)
unk_token = (
AddedToken(
unk_token, rstrip=True, lstrip=True, normalized=False, special=True
)
if isinstance(unk_token, str)
else unk_token
)
eos_token = (
AddedToken(
eos_token, rstrip=True, lstrip=True, normalized=False, special=True
)
if isinstance(eos_token, str)
else eos_token
)
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
self.do_lower_case = do_lower_case
self.vocab_file = vocab_file
self.sp_model = self.get_spm_processor()
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
additional_special_tokens=additional_special_tokens,
sp_model_kwargs=self.sp_model_kwargs,
model_max_length=model_max_length,
do_lower_case=do_lower_case,
**kwargs,
)
def get_spm_processor(self):
tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
with open(self.vocab_file, "rb") as f:
sp_model = f.read()
model_pb2 = import_protobuf()
model = model_pb2.ModelProto.FromString(sp_model)
normalizer_spec = model_pb2.NormalizerSpec()
normalizer_spec.add_dummy_prefix = False
model.normalizer_spec.MergeFrom(normalizer_spec)
sp_model = model.SerializeToString()
tokenizer.LoadFromSerializedProto(sp_model)
return tokenizer
@property
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
def vocab_size(self):
return self.sp_model.get_piece_size()
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
def get_vocab(self):
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False,
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True,
)
# normal case: some special tokens
if token_ids_1 is None:
return ([0] * len(token_ids_0)) + [1]
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
warnings.warn(
f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
" eos tokens being added."
)
return token_ids
else:
return token_ids + [self.eos_token_id]
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
def create_token_type_ids_from_sequences(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
eos = [self.eos_token_id]
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A sequence has the following format:
- single sequence: `X </s>`
- pair of sequences: `A </s> B </s>`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
token_ids_0 = self._add_eos_if_not_present(token_ids_0)
if token_ids_1 is None:
return token_ids_0
else:
token_ids_1 = self._add_eos_if_not_present(token_ids_1)
return token_ids_0 + token_ids_1
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
def __setstate__(self, d):
self.__dict__ = d
# for backward compatibility
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
def remove_punctuation(self, text: str) -> str:
return text.translate(str.maketrans("", "", string.punctuation))
# source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
"""Returns canonicalized `text` (puncuation removed).
Args:
text (`str`):
String to be canonicalized.
keep_punctuation_exact_string (`str`, *optional*):
If provided, then this exact string is kept. For example, providing '{}' will keep any occurrences of '{}'
(but will still remove '{' and '}' that appear separately).
"""
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(
self.remove_punctuation(part)
for part in text.split(keep_punctuation_exact_string)
)
else:
text = self.remove_punctuation(text)
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def tokenize(
self, text: "TextInput", add_special_tokens=False, **kwargs
) -> List[str]:
"""
Converts a string to a list of tokens.
"""
tokens = super().tokenize(
SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs
)
if (
len(tokens) > 1
and tokens[0] == SPIECE_UNDERLINE
and tokens[1] in self.all_special_tokens
):
tokens = tokens[1:]
return tokens
@property
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
def unk_token_length(self):
return len(self.sp_model.encode(str(self.unk_token)))
def _tokenize(self, text, **kwargs):
"""
Returns a tokenized string.
We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
SPIECE_UNDERLINE.
For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
`self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
"""
text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
# 1. Encode string + prefix ex: "<unk> Hey"
tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
# 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
return (
tokens[self.unk_token_length :]
if len(tokens) >= self.unk_token_length
else tokens
)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.convert_tokens_to_string
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
# since we manually add the prefix space, we have to remove it
tokens[0] = tokens[0].lstrip(SPIECE_UNDERLINE)
out_string = ""
prev_is_special = False
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string.strip()
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
def save_vocabulary(
self, save_directory: str, filename_prefix: Optional[str] = None
) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
out_vocab_file = os.path.join(
save_directory,
(filename_prefix + "-" if filename_prefix else "")
+ VOCAB_FILES_NAMES["vocab_file"],
)
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file
) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, "wb") as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file,)
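# Hedged usage sketch for SiglipTokenizer; the vocab path is hypothetical:
#
#   tok = SiglipTokenizer(vocab_file="spiece.model")
#   ids = tok("A photo of a cat!")["input_ids"]
#   # canonicalize_text strips punctuation and collapses whitespace before
#   # SentencePiece encoding, so "cat!" and "cat" tokenize identically.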
import torch # noqa
class SiglipVisionTower(VisionTower):
def __init__(
self, model_name_or_path: str, config: PretrainedConfig, state_dict=None
):
super().__init__(model_name_or_path, config)
self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
self.vision_tower = SiglipVisionModel.from_pretrained(
# TODO(ligeng): why pass config here leading to errors?
model_name_or_path,
torch_dtype=eval(config.model_dtype),
state_dict=state_dict,
)
self.is_loaded = True
AutoConfig.register("siglip_vision_model", SiglipVisionConfig, exist_ok=True)
AutoModel.register(SiglipVisionConfig, SiglipVisionModel, exist_ok=True)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modified from https://github.com/haotian-liu/LLaVA/
class LlavaLlamaConfig(LlavaConfig):
model_type = "llava_llama"
# FIXME we will follow the convention to add a new class for CausalLM in the future
class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
config_class = LlavaLlamaConfig
main_input_name = "input_embeds"
supports_gradient_checkpointing = True
tokenizer_image_token = staticmethod(tokenizer_image_token)
def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
super().__init__(config)
self.dam_model = None
self.pretrained_model_name_or_path = None
self.init_vlm(config=config, *args, **kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: bool = None,
torch_dtype: Optional[Union[str, torch.dtype]] = torch.float16,
init_dam: bool = False,
# conv_mode and prompt_mode are only used by `init_dam` in `from_pretrained` if `init_dam` is set to True
conv_mode: str = "v1",
prompt_mode: str = "full+focal_crop",
**kwargs,
):
if torch_dtype:
config.model_dtype = str(torch_dtype)
if hasattr(cls, "load_pretrained"):
obj = cls.load_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
else:
obj = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
**kwargs,
)
obj.pretrained_model_name_or_path = pretrained_model_name_or_path
# `init_dam` initializes a `DescribeAnythingModel` object inside this `LlavaLlamaModel`. If you initialize `DescribeAnythingModel` yourself outside, you don't need this option.
# This is useful when you load with `from_pretrained` and remote code execution and don't want to include the `DescribeAnythingModel` implementation in your own codebase.
if init_dam:
obj.init_dam(conv_mode, prompt_mode)
return obj
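# Hedged usage sketch; the checkpoint path is illustrative. This override
# assumes `config` is supplied (the Auto classes resolve and pass it when
# loading with remote code):
#
#   model = LlavaLlamaModel.from_pretrained(
#       "path/to/dam-checkpoint",
#       config=config,
#       torch_dtype=torch.float16,
#       init_dam=True,  # builds the DescribeAnythingModel wrapper eagerly
#       conv_mode="v1",
#       prompt_mode="full+focal_crop",
#   )
#   dam = model.dam  # the wrapper created by init_dam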
def forward(
self,
input_ids: torch.LongTensor = None,
images: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
self.freezed_module_patch()
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
) = self.prepare_inputs_labels_for_multimodal(
input_ids, position_ids, attention_mask, past_key_values, labels, images
)
# Note (kentang-mit@): we have a unit test for this function.
if self.training:
(
_,
new_position_ids,
new_attention_mask,
_,
new_inputs_embeds,
new_labels,
sorted_seqlens_in_batch,
) = self.repack_multimodal_data(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels,
)
new_input_ids = None
past_key_values = None
else:
new_attention_mask = attention_mask
new_position_ids = position_ids
new_inputs_embeds = inputs_embeds
new_labels = labels
sorted_seqlens_in_batch = attention_mask.sum(-1).int()
new_input_ids = input_ids
outputs = self.llm.forward(
input_ids=new_input_ids,
attention_mask=new_attention_mask,
position_ids=new_position_ids,
past_key_values=past_key_values,
inputs_embeds=new_inputs_embeds,
labels=new_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
seqlens_in_batch=sorted_seqlens_in_batch,
)
return outputs
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.FloatTensor] = None,
images: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**generation_kwargs,
):
if images is not None:
(
_,
_,
attention_mask,
_,
inputs_embeds,
_,
) = self.prepare_inputs_labels_for_multimodal(
input_ids, None, attention_mask, None, None, images
)
else:
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = inputs_embeds.to(self.dtype)
outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
**generation_kwargs,
)
return outputs
# Defaults mirror `from_pretrained` so that the `dam` property below, which
# calls `init_dam()` with no arguments, works.
def init_dam(self, conv_mode="v1", prompt_mode="full+focal_crop"):
model_name = get_model_name_from_path(self.pretrained_model_name_or_path)
self.dam_model = DescribeAnythingModel(
model_path=dict(
model=self, tokenizer=self.tokenizer, model_name=model_name
),
conv_mode=conv_mode,
prompt_mode=prompt_mode,
)
return self.dam_model
@property
def dam(self):
if self.dam_model is None:
self.init_dam()
return self.dam_model
AutoConfig.register("llava_llama", LlavaLlamaConfig)
AutoModel.register(LlavaLlamaConfig, LlavaLlamaModel)
import torch # noqa
def has_tokenizer(path):
if (
osp.exists(osp.join(path, "special_tokens_map.json"))
and osp.exists(osp.join(path, "tokenizer_config.json"))
and (
osp.exists(osp.join(path, "tokenizer.model"))
or osp.exists(osp.join(path, "tokenizer.json"))
)
):
# print("[has_tokenizer]", path, True)
return True
from huggingface_hub import HfApi, file_exists
from huggingface_hub.utils import HFValidationError
api = HfApi()
try:
valid_hf_repo = api.repo_exists(path)
except HFValidationError:
valid_hf_repo = False
if (
valid_hf_repo
and file_exists(path, "special_tokens_map.json")
and file_exists(path, "tokenizer_config.json")
and (
file_exists(path, "tokenizer.model") or file_exists(path, "tokenizer.json")
)
):
# print("[has_tokenizer]", path, True)
return True
# print("[has_tokenizer]", path, False)
return False
def context_length_extension(config):
orig_ctx_len = getattr(config, "max_position_embeddings", None)
model_max_length = getattr(config, "model_max_length", None)
if orig_ctx_len and model_max_length > orig_ctx_len:
print(f"Scaling RoPE from {orig_ctx_len} to {model_max_length}")
scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
config.rope_scaling = {"type": "linear", "factor": scaling_factor}
return config
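# Worked example (illustrative values): with max_position_embeddings=4096 and
# model_max_length=10000, the scaling factor is ceil(10000 / 4096) = 3.0, so
# the config gets rope_scaling = {"type": "linear", "factor": 3.0}.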
def build_llm_and_tokenizer(
model_name_or_path: str,
config: PretrainedConfig,
# config_cls: PretrainedConfig = None,
# llm_cls: PreTrainedModel = None,
attn_implementation=None,
model_max_length=None,
*args,
**kwargs,
) -> PreTrainedModel:
# if config_cls is None:
# config_cls = AutoConfig
# if llm_cls is None:
# llm_cls = AutoModelForCausalLM
# config_cls = AutoConfig
# llm_cls = AutoModelForCausalLM
# extra configuration for llm
# print("build_llm_and_tokenizer():", model_name_or_path); input("DEBUG")
llm_cfg = AutoConfig.from_pretrained(model_name_or_path)
llm_cfg._attn_implementation = attn_implementation
llm_cfg.model_max_length = model_max_length
if model_max_length is not None:
context_length_extension(llm_cfg)
llm = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
config=llm_cfg,
torch_dtype=eval(config.model_dtype),
*args,
**kwargs,
)
llm_path = model_name_or_path
if not has_tokenizer(llm_path):
warnings.warn(
"tokenizer found in VLM root folder. Move to ./{VILA}/llm in the future."
)
llm_path = osp.join(llm_path, "llm")
# TODO(ligeng): use the LLM class to decide, for better compatibility.
if "mpt" in model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
llm_path,
model_max_length=llm_cfg.model_max_length,
padding_side="right",
)
elif "yi" in model_name_or_path.lower():
tokenizer = AutoTokenizer.from_pretrained(
llm_path,
model_max_length=llm_cfg.model_max_length,
padding_side="right",
use_fast=False,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
llm_path,
model_max_length=llm_cfg.model_max_length,
padding_side="right",
use_fast=False,
legacy=False,
)
# TODO(ligeng): is this necessary for llava?
config.hidden_size = llm.config.hidden_size
return llm, tokenizer
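# Hedged usage sketch; the path and settings are illustrative. Note that
# `config.model_dtype` must hold an eval-able dtype string such as
# "torch.float16", since it is passed through eval() above:
#
#   llm, tokenizer = build_llm_and_tokenizer(
#       "path/to/llm", config, attn_implementation="flash_attention_2",
#       model_max_length=4096,
#   )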
# This file is modified from https://github.com/haotian-liu/LLaVA/ and https://github.com/NVlabs/VILA/
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: we may move LlavaConfig to configuration_llava.py
# from model.configuration_llava import LlavaConfig
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def load_pretrained_model(
model_path,
model_name,
model_base=None,
load_8bit=False,
load_4bit=False,
device_map="auto",
device="cuda",
**kwargs,
):
kwargs = {"device_map": device_map, **kwargs}
if device != "cuda":
kwargs["device_map"] = {"": device}
if load_8bit:
kwargs["load_in_8bit"] = True
elif load_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
else:
kwargs["torch_dtype"] = torch.float16
config = AutoConfig.from_pretrained(model_path)
config.resume_path = model_path
prepare_config_for_eval(config, kwargs)
model = LlavaLlamaModel(config=config, low_cpu_mem_usage=True, **kwargs)
tokenizer = model.tokenizer
model.eval()
# mm_use_im_start_end = getattr(
# model.config, "mm_use_im_start_end", False)
# mm_use_im_patch_token = getattr(
# model.config, "mm_use_im_patch_token", True)
# if mm_use_im_patch_token:
# tokenizer.add_tokens(
# [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
# if mm_use_im_start_end:
# tokenizer.add_tokens(
# [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
# )
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
vision_tower.to(device=device, dtype=torch.float16)
mm_projector = model.get_mm_projector()
mm_projector.to(device=device, dtype=torch.float16)
context_provider = model.get_context_provider()
if context_provider is not None:
context_provider.to(device=device, dtype=torch.float16)
image_processor = vision_tower.image_processor
if hasattr(model.llm.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
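# Hedged usage sketch; the checkpoint path is illustrative:
#
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       "path/to/checkpoint", model_name=None, load_4bit=True
#   )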
def parse_model_name_or_path(config: PretrainedConfig, model_name="llm", suffix="_cfg"):
target_model = f"{model_name}{suffix}"
target_cfg = getattr(config, target_model, None)
if isinstance(target_cfg, str):
return target_cfg
elif isinstance(target_cfg, dict):
return target_cfg["architectures"][0]
else:
raise ValueError(f"Invalid {target_model} configuration!")
def prepare_config_for_eval(config: PretrainedConfig, kwargs: dict):
try:
# compatible with deprecated config convention
if getattr(config, "vision_tower_cfg", None) is None:
config.vision_tower_cfg = config.mm_vision_tower
except AttributeError:
raise ValueError(
f"Invalid configuration! Cannot find vision_tower in config:\n{config}"
)
# default to float16: the 8-bit/4-bit loading paths above do not set torch_dtype
config.model_dtype = str(kwargs.pop("torch_dtype", torch.float16))
# siglip does not support device_map = "auto"
vision_tower_name = parse_model_name_or_path(config, "vision_tower")
if "siglip" in vision_tower_name.lower():
kwargs["device_map"] = "cuda"
class DescribeAnythingModel(nn.Module):
def __init__(self, model_path, conv_mode, prompt_mode, **kwargs):
super().__init__()
self.model_path = model_path
self.conv_mode = conv_mode
self.prompt_mode = prompt_mode
if isinstance(model_path, str):
self.tokenizer, self.model, _, _ = load_pretrained_model(
model_path, None, None, **kwargs
)
self.model_name = get_model_name_from_path(model_path)
else:
# model_path is actually a dict with model, tokenizer, and (optionally) model_name
self.model = model_path["model"]
self.tokenizer = model_path["tokenizer"]
self.model_name = model_path.get("model_name", None)
image_processor = self.model.vision_tower.image_processor
self.model.config.image_processor = image_processor
def get_prompt(self, qs):
if DEFAULT_IMAGE_TOKEN not in qs:
raise ValueError("no <image> tag found in input.")
conv = conv_templates[self.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt, conv
@staticmethod
def mask_to_box(mask_np):
mask_coords = np.argwhere(mask_np)
y0, x0 = mask_coords.min(axis=0)
y1, x1 = mask_coords.max(axis=0) + 1
h = y1 - y0
w = x1 - x0
return x0, y0, w, h
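# Worked example: for a mask whose True pixels span rows 2..4 and columns
# 3..6 (inclusive), argwhere yields min (2, 3) and max (4, 6), so the box is
# (x0, y0, w, h) = (3, 2, 4, 3).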
@classmethod
def crop_image(cls, pil_img, mask_np, crop_mode, min_box_w=48, min_box_h=48):
if crop_mode == "full":
# no crop
info = dict(mask_np=mask_np)
return pil_img, info
if crop_mode == "crop":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
assert (
img_np.shape[:2] == mask_np.shape
), f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
cropped_mask_np = mask_np[y0 : y0 + h, x0 : x0 + w]
cropped_img_np = img_np[y0 : y0 + h, x0 : x0 + w]
cropped_pil_img = Image.fromarray(cropped_img_np)
elif crop_mode == "context_crop":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
assert (
img_np.shape[:2] == mask_np.shape
), f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
img_h, img_w = img_np.shape[:2]
cropped_mask_np = mask_np[
max(y0 - h, 0) : min(y0 + 2 * h, img_h),
max(x0 - w, 0) : min(x0 + 2 * w, img_w),
]
cropped_img_np = img_np[
max(y0 - h, 0) : min(y0 + 2 * h, img_h),
max(x0 - w, 0) : min(x0 + 2 * w, img_w),
]
cropped_pil_img = Image.fromarray(cropped_img_np)
elif crop_mode == "focal_crop":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
assert (
img_np.shape[:2] == mask_np.shape
), f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
img_h, img_w = img_np.shape[:2]
xc, yc = x0 + w / 2, y0 + h / 2
# focal_crop: need to have at least min_box_w and min_box_h pixels, otherwise resizing to (384, 384) leads to artifacts that may be OOD
w, h = max(w, min_box_w), max(h, min_box_h)
x0, y0 = int(xc - w / 2), int(yc - h / 2)
cropped_mask_np = mask_np[
max(y0 - h, 0) : min(y0 + 2 * h, img_h),
max(x0 - w, 0) : min(x0 + 2 * w, img_w),
]
cropped_img_np = img_np[
max(y0 - h, 0) : min(y0 + 2 * h, img_h),
max(x0 - w, 0) : min(x0 + 2 * w, img_w),
]
cropped_pil_img = Image.fromarray(cropped_img_np)
elif crop_mode == "crop_mask":
# crop image and mask
x0, y0, w, h = cls.mask_to_box(mask_np)
img_np = np.asarray(pil_img)
assert (
img_np.shape[:2] == mask_np.shape
), f"image shape does not match mask shape: {img_np.shape}, {mask_np.shape}"
cropped_mask_np = mask_np[y0 : y0 + h, x0 : x0 + w]
cropped_img_np = img_np[y0 : y0 + h, x0 : x0 + w]
# Mask the image
cropped_img_np = cropped_img_np * cropped_mask_np[..., None]
cropped_pil_img = Image.fromarray(cropped_img_np)
else:
raise ValueError(f"Unsupported crop_mode: {crop_mode}")
info = dict(mask_np=cropped_mask_np)
return cropped_pil_img, info
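# Hedged example: given a PIL image and a binary numpy mask of matching
# height/width,
#
#   cropped, info = DescribeAnythingModel.crop_image(
#       pil_img, mask_np, crop_mode="focal_crop"
#   )
#
# returns the focal crop plus info["mask_np"], the mask cropped to match.
# Mode summary: "full" keeps the whole image; "crop" is a tight box around
# the mask; "context_crop"/"focal_crop" include surrounding context;
# "crop_mask" additionally zeroes pixels outside the mask.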
def get_description(
self,
image_pil,
mask_pil,
query,
streaming=False,
temperature=0.2,
top_p=0.5,
num_beams=1,
max_new_tokens=512,
**kwargs,
):
# kwargs is passed to generation_kwargs: https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig
prompt, conv = self.get_prompt(query)
if not isinstance(image_pil, (list, tuple)):
assert not isinstance(
mask_pil, (list, tuple)
), "image_pil and mask_pil must both be lists/tuples or both be single images."
image_pils = [image_pil]
mask_pils = [mask_pil]
else:
image_pils = image_pil
mask_pils = mask_pil
description = self.get_description_from_prompt(
image_pils,
mask_pils,
prompt,
conv,
streaming=streaming,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
**kwargs,
)
return description
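# Hedged end-to-end sketch; file names are illustrative and the query must
# contain the image token (assumed here to be "<image>"):
#
#   from PIL import Image
#   dam = DescribeAnythingModel(
#       "path/to/checkpoint", conv_mode="v1", prompt_mode="full+focal_crop"
#   )
#   image = Image.open("image.png").convert("RGB")
#   mask = Image.open("mask.png").convert("L")  # non-zero = region of interest
#   text = dam.get_description(image, mask, "<image>\nDescribe the masked region.")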
def get_image_tensor(self, image_pil, mask_pil, crop_mode, crop_mode2):
# the mask PIL image is binary: any non-zero value is treated as True
mask_np = (np.asarray(mask_pil) > 0).astype(np.uint8)
images_tensor, image_info = process_image(
image_pil,
self.model.config,
None,
pil_preprocess_fn=lambda pil_img: self.crop_image(
pil_img, mask_np=mask_np, crop_mode=crop_mode
),
)
images_tensor = images_tensor[None].to(self.model.device, dtype=torch.float16)
mask_np = image_info["mask_np"]
mask_pil = Image.fromarray(mask_np * 255)
masks_tensor = process_image(mask_pil, self.model.config, None)
masks_tensor = masks_tensor[None].to(self.model.device, dtype=torch.float16)
images_tensor = torch.cat((images_tensor, masks_tensor[:, :1, ...]), dim=1)
if crop_mode2 is not None:
images_tensor2, image_info2 = process_image(
image_pil,
self.model.config,
None,
pil_preprocess_fn=lambda pil_img: self.crop_image(
pil_img, mask_np=mask_np, crop_mode=crop_mode2
),
)
images_tensor2 = images_tensor2[None].to(
self.model.device, dtype=torch.float16
)
mask_np2 = image_info2["mask_np"]
mask_pil2 = Image.fromarray(mask_np2 * 255)
masks_tensor2 = process_image(mask_pil2, self.model.config, None)
masks_tensor2 = masks_tensor2[None].to(
self.model.device, dtype=torch.float16
)
images_tensor2 = torch.cat(
(images_tensor2, masks_tensor2[:, :1, ...]), dim=1
)
else:
images_tensor2 = None
return (
torch.cat((images_tensor, images_tensor2), dim=1)
if images_tensor2 is not None
else images_tensor
)
def get_description_from_prompt(
self,
image_pils,
mask_pils,
prompt,
conv,
streaming=False,
temperature=0.2,
top_p=0.5,
num_beams=1,
max_new_tokens=512,
**kwargs,
):
if streaming:
return self.get_description_from_prompt_iterator(
image_pils,
mask_pils,
prompt,
conv,
streaming=True,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
**kwargs,
)
else:
# If streaming is False, there will be only one output
output = self.get_description_from_prompt_iterator(
image_pils,
mask_pils,
prompt,
conv,
streaming=False,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
**kwargs,
)
return next(output)
def get_description_from_prompt_iterator(
self,
image_pils,
mask_pils,
prompt,
conv,
streaming=False,
temperature=0.2,
top_p=0.5,
num_beams=1,
max_new_tokens=512,
**kwargs,
):
crop_mode, crop_mode2 = self.prompt_mode.split("+")
assert (
crop_mode == "full"
), "The current prompt mode requires the first crop to be 'full' (non-cropped). Update the prompt if you need other settings."
assert len(image_pils) == len(
mask_pils
), f"image_pils and mask_pils must have the same length. Got {len(image_pils)} and {len(mask_pils)}."
image_tensors = [
self.get_image_tensor(
image_pil, mask_pil, crop_mode=crop_mode, crop_mode2=crop_mode2
)
for image_pil, mask_pil in zip(image_pils, mask_pils)
]
input_ids = (
tokenizer_image_token(
prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.cuda()
)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(
keywords, self.tokenizer, input_ids
)
streamer = (
TextIteratorStreamer(
self.tokenizer, skip_prompt=True, skip_special_tokens=True
)
if streaming
else None
)
generation_kwargs = dict(
input_ids=input_ids,
images=image_tensors,
do_sample=temperature > 0,
use_cache=True,
stopping_criteria=[stopping_criteria],
streamer=streamer,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
**kwargs,
)
if streaming:
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
generated_text += new_text
if stop_str in generated_text:
generated_text = generated_text[: generated_text.find(stop_str)]
break
yield new_text
thread.join()
else:
with torch.inference_mode():
output_ids = self.model.generate(**generation_kwargs)
outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[
0
]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
yield outputs