# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import math
import sys
from typing import Callable, Optional

import torch
from packaging import version
from torch import nn

# The importlib_metadata package lives in a different place depending on the Python version.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata

_xformers_available = importlib.util.find_spec("xformers") is not None
try:
    _xformers_version = importlib_metadata.version("xformers")

    if version.Version(torch.__version__) < version.Version("1.12"):
        raise ValueError("PyTorch should be >= 1.12")
    print(f"Successfully imported xformers version {_xformers_version}")
except importlib_metadata.PackageNotFoundError:
    _xformers_available = False

if _xformers_available:
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
    to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    Uses three q, k, v linear layers to compute attention.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*): The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    # IMPORTANT: TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore.

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)

        # define q, k, v as linear layers
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        self.proj_attn = nn.Linear(channels, channels, bias=True)

        self._use_memory_efficient_attention_xformers = False
        self._attention_op = None

    def reshape_heads_to_batch_dim(self, tensor):
        # (batch, seq_len, dim) -> (batch * heads, seq_len, dim // heads)
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        # (batch * heads, seq_len, dim) -> (batch, seq_len, dim * heads)
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        if use_memory_efficient_attention_xformers:
            if not _xformers_available:
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self._attention_op = attention_op

    def forward(self, hidden_states):
        residual = hidden_states
        batch, channel, height, width = hidden_states.shape

        # norm
        hidden_states = self.group_norm(hidden_states)

        # flatten spatial dims: (batch, channel, height * width) -> (batch, height * width, channel)
        hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

        # proj to q, k, v
        query_proj = self.query(hidden_states)
        key_proj = self.key(hidden_states)
        value_proj = self.value(hidden_states)

        # 1 / sqrt(d_head) attention scaling
        scale = 1 / math.sqrt(self.channels / self.num_heads)

        query_proj = self.reshape_heads_to_batch_dim(query_proj)
        key_proj = self.reshape_heads_to_batch_dim(key_proj)
        value_proj = self.reshape_heads_to_batch_dim(value_proj)

        if self._use_memory_efficient_attention_xformers:
            # Memory efficient attention
            hidden_states = xformers.ops.memory_efficient_attention(
                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
            )
            hidden_states = hidden_states.to(query_proj.dtype)
        else:
            attention_scores = torch.baddbmm(
                torch.empty(
                    query_proj.shape[0],
                    query_proj.shape[1],
                    key_proj.shape[1],
                    dtype=query_proj.dtype,
                    device=query_proj.device,
                ),
                query_proj,
                key_proj.transpose(-1, -2),
                beta=0,
                alpha=scale,
            )
            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
            hidden_states = torch.bmm(attention_probs, value_proj)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)

        # compute next hidden_states
        hidden_states = self.proj_attn(hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)

        # res connect and rescale
        hidden_states = (hidden_states + residual) / self.rescale_output_factor
        return hidden_states
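

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it assumes a 4-D
# input of shape (batch, channels, height, width), which matches how forward()
# unpacks hidden_states.shape; the channel and spatial sizes below are
# arbitrary illustration values, not values prescribed by the class.
if __name__ == "__main__":
    block = AttentionBlock(channels=32, num_head_channels=8)  # 4 heads of 8 channels each
    sample = torch.randn(2, 32, 16, 16)
    output = block(sample)
    print(output.shape)  # torch.Size([2, 32, 16, 16]) -- the input shape is preserved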