# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Implementation of 2D Rotary Position Embeddings (RoPE).
# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.
# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
from typing import List, Optional, Tuple, Union
class PositionGetter:
    """Produces (y, x) coordinates for every patch in a 2D grid.

    Coordinate tensors are memoized per (height, width) so repeated calls
    with the same grid size reuse the cached result instead of recomputing.

    Attributes:
        position_cache: Maps (height, width) to a tensor of shape
            (height * width, 2) holding row-major (y, x) pairs.
    """

    def __init__(self):
        """Start with an empty coordinate cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return per-batch grid coordinates.

        Args:
            batch_size: Number of batch items to replicate positions for.
            height: Grid height in patches.
            width: Grid width in patches.
            device: Device on which newly created position tensors live.

        Returns:
            Tensor of shape (batch_size, height * width, 2) containing (y, x)
            pairs in row-major order, one independent copy per batch item.
        """
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            # cartesian_prod yields row-major (y, x) pairs: (0,0), (0,1), ...
            self.position_cache[key] = torch.cartesian_prod(rows, cols)
        grid = self.position_cache[key]
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
    """Rotary position embeddings over 2D (y, x) coordinates.

    Half of the feature channels are rotated according to the vertical
    coordinate and the other half according to the horizontal coordinate,
    extending 1D RoPE to spatial grids.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency used when computing rotation angles.
        scaling_factor: Stored scaling factor.
        frequency_cache: Memoized (cos, sin) tables keyed by
            (dim, seq_len, device, dtype).
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Stores configuration and prepares an empty frequency cache."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Builds (or fetches cached) cosine/sine rotation tables.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Number of positions to tabulate.
            device: Target device for the tables.
            dtype: Data type of the returned tensors.

        Returns:
            Tuple (cos, sin), each of shape (seq_len, dim).
        """
        key = (dim, seq_len, device, dtype)
        cached = self.frequency_cache.get(key)
        if cached is None:
            # Inverse frequencies: base_frequency^(-2i/dim) over even i.
            inv_freq = 1.0 / (self.base_frequency ** (torch.arange(0, dim, 2, device=device).float() / dim))
            # Angle for every (position, frequency) pair.
            pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", pos, inv_freq).to(dtype)
            # Duplicate along features so the rotation covers both halves.
            angles = torch.cat((angles, angles), dim=-1)
            cached = (angles.cos().to(dtype), angles.sin().to(dtype))
            self.frequency_cache[key] = cached
        return cached

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Maps feature halves (a, b) to (-b, a) — the RoPE quarter-turn.

        Args:
            x: Input tensor whose last dimension is split in half.

        Returns:
            Tensor of the same shape with halves swapped and negated.
        """
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((-second, first), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Rotates tokens by the angles looked up at their positions.

        Args:
            tokens: Token features to rotate.
            positions: Integer position indices used to index the tables.
            cos_comp: Cosine table from _compute_frequency_components.
            sin_comp: Sine table from _compute_frequency_components.

        Returns:
            Rotated tokens, same shape as the input.
        """
        # Gather per-position components; add a broadcast head dimension.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        return tokens * cos + self._rotate_features(tokens) * sin

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Tensor of shape (batch_size, n_heads, n_tokens, dim);
                dim must be divisible by 4.
            positions: Tensor of shape (batch_size, n_tokens, 2) holding the
                (y, x) coordinate of each token.

        Returns:
            Tensor of the same shape as ``tokens`` with RoPE applied.

        Raises:
            AssertionError: If the feature dimension is odd or positions are
                malformed.
        """
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
        # Each spatial direction gets half of the feature channels.
        half_dim = tokens.size(-1) // 2
        # Tables must cover the largest coordinate present.
        table_len = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(half_dim, table_len, tokens.device, tokens.dtype)
        # Rotate each half by its own coordinate, then reassemble.
        y_feats, x_feats = tokens.chunk(2, dim=-1)
        y_feats = self._apply_1d_rope(y_feats, positions[..., 0], cos_comp, sin_comp)
        x_feats = self._apply_1d_rope(x_feats, positions[..., 1], cos_comp, sin_comp)
        return torch.cat((y_feats, x_feats), dim=-1)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """Precompute 1D rotary position embedding (RoPE) frequency tensors.

    RoPE encodes positions as rotations so relative offsets are preserved:
    for position m and dimension pair i the rotation angle is
    m * theta^(-2i/dim).

    Args:
        dim: Feature dimension; must be even (dimensions rotate in pairs).
        pos: Either an int S (generating positions 0..S-1) or an array of
            position indices of shape [S].
        theta: Base frequency controlling the rotation periods.
        use_real: Return separate (cos, sin) real tensors instead of a
            single complex tensor.
        linear_factor: Linear frequency scaling, used for context extension.
        ntk_factor: NTK-aware scaling applied to theta for longer sequences.
        repeat_interleave_real: With use_real=True, interleave each value
            ([c0, c0, c1, c1, ...]) rather than concatenating two copies
            ([c0..cn, c0..cn]); layout depends on the model architecture.
        freqs_dtype: dtype used for the intermediate angle computation.

    Returns:
        Complex form: tensor of shape [S, dim/2] holding e^(i*m*theta_j).
        Real form: a (cos, sin) pair of float tensors, each of shape [S, dim].
    """
    assert dim % 2 == 0  # pairs of dimensions are rotated together
    # Normalize `pos` to a 1D torch tensor of position indices.
    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]
    # NTK-aware scaling of the base frequency.
    theta = theta * ntk_factor
    # Per-pair inverse frequencies: 1 / (theta^(2i/dim) * linear_factor),
    # for i in {0, 2, ..., dim-2}.  Shape: [D/2]
    exponents = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim
    inv_freq = 1.0 / (theta**exponents) / linear_factor
    # Outer product position x frequency -> rotation angles.  Shape: [S, D/2]
    angles = torch.outer(pos, inv_freq)
    if use_real and repeat_interleave_real:
        # Interleaved real layout (flux, hunyuan-dit, cogvideox):
        # [cos_0, cos_0, cos_1, cos_1, ...]
        cos_part = angles.cos().repeat_interleave(2, dim=1, output_size=angles.shape[1] * 2).float()  # [S, D]
        sin_part = angles.sin().repeat_interleave(2, dim=1, output_size=angles.shape[1] * 2).float()  # [S, D]
        return cos_part, sin_part
    if use_real:
        # Concatenated real layout (stable audio, allegro):
        # [cos_0, ..., cos_n, cos_0, ..., cos_n]
        cos_part = torch.cat([angles.cos(), angles.cos()], dim=-1).float()  # [S, D]
        sin_part = torch.cat([angles.sin(), angles.sin()], dim=-1).float()  # [S, D]
        return cos_part, sin_part
    # Complex layout (lumina): Euler's formula via torch.polar gives
    # 1 * e^(i*angle) = cos(angle) + i*sin(angle).
    return torch.polar(torch.ones_like(angles), angles)  # complex64: [S, D/2]
class WanRotaryPosEmbed(nn.Module):
    """3D rotary position embedding (3D RoPE) for video-like token grids.

    Extends RoPE to three axes (frame/time, height, width): each axis is
    given its own slice of the attention head dimension and its own 1D RoPE
    frequencies, and the three complex frequency tensors are concatenated.

    For a 3D position (f, h, w):
      - the frame axis uses dim_f feature dimensions,
      - the height axis uses dim_h,
      - the width axis uses dim_w,
    with dim_f + dim_h + dim_w == attention_head_dim.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        # Fix: default was a mutable list ([20, 22, 22]) despite the Tuple
        # annotation; an immutable tuple is equivalent and safe as a default.
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
    ):
        """Precomputes per-axis complex RoPE frequencies.

        Args:
            attention_head_dim: Total feature dimension of one attention head.
            patch_size: Patchification factors (patch_f, patch_h, patch_w).
            max_seq_len: Maximum position index to precompute frequencies for.
            theta: RoPE base frequency.
            fhw_dim: Explicit (dim_f, dim_h, dim_w) split of
                attention_head_dim, or None to auto-split (h and w each get
                roughly one third rounded down to even; t takes the rest).

        Raises:
            AssertionError: If fhw_dim is given and does not sum to
                attention_head_dim.
        """
        super().__init__()
        self.attention_head_dim = attention_head_dim  # total head dimension
        self.patch_size = patch_size  # (patch_f, patch_h, patch_w)
        self.max_seq_len = max_seq_len  # positions precomputed per axis
        # Step 1: allocate feature dimensions to the three axes.
        if fhw_dim is not None:
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # Auto split: h and w each ~1/3 (even), t takes the remainder.
            # e.g. attention_head_dim=64 -> h_dim=w_dim=20, t_dim=24.
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim
        # Keep the resolved split for use in forward.
        self.fhw_dim = (t_dim, h_dim, w_dim)
        # Step 2: precompute complex 1D RoPE frequencies per axis,
        # each of shape [max_seq_len, dim // 2], then concatenate to
        # [max_seq_len, attention_head_dim // 2].
        freqs = []
        for dim in (t_dim, h_dim, w_dim):
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """Builds complex RoPE frequencies for a 3D grid of tokens.

        Args:
            ppf: Number of frames (patches along time); ignored/recomputed
                when f_end is given.
            pph: Patches per frame along height.
            ppw: Patches per frame along width.
            patch_start_idx: Number of special tokens preceding the patches
                in each frame (0 for none).
            device: Target device for the returned tensor.
            f_start: First frame index (causal mode). Default 0.
            f_end: One-past-last frame index (causal mode); None selects
                frames [0, ppf).

        Returns:
            Complex tensor of shape
            (1, 1, ppf * (patch_start_idx + pph * ppw), attention_head_dim // 2).

        Token order per frame: [special_token_0..N, patch_0..M], frames
        concatenated in sequence.

        Special tokens are placed on the (f, i, i) diagonal; patches are
        offset by patch_start_idx along h and w so their positions never
        collide with the special tokens.
        """
        # Move the precomputed table once; subsequent calls are no-ops.
        self.freqs = self.freqs.to(device)
        # Resolve the per-axis dimension split (fallback recomputes the
        # auto split for instances restored without fhw_dim).
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim
        # Split the concatenated table back into per-axis tables; each axis
        # holds dim // 2 complex entries per position.
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2,  # time axis
                h_dim // 2,  # height axis
                w_dim // 2,  # width axis
            ],
            dim=1,
        )
        # Causal mode: an explicit frame window overrides ppf.
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # Non-causal: frames [0, ppf).
            frame_slice = slice(0, ppf)
        if patch_start_idx > 0:
            # Special tokens sit on the diagonal (f, i, i): e.g.
            # camera: (f, 0, 0), register_0: (f, 1, 1), ...
            # Shapes below: (ppf, patch_start_idx, dim_axis // 2).
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)
            # (ppf, patch_start_idx, head_dim // 2)
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)
            # Patches live at (f, patch_start_idx + h, patch_start_idx + w):
            # the offset keeps them clear of the special-token positions.
            # Shapes below: (ppf, pph, ppw, dim_axis // 2).
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)
            # Flatten the spatial grid: (ppf, pph * ppw, head_dim // 2).
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)
            # Per-frame order is [special tokens, patches].
            freqs = torch.cat([freqs_special, freqs_patches], dim=1)
            # Flatten frames and add batch/head broadcast dimensions.
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            return freqs.unsqueeze(0).unsqueeze(0)
        # No special tokens: patches occupy (f, 0:pph, 0:ppw) directly.
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        return torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
def apply_rotary_emb(x, freqs):
    """Rotate token features by precomputed complex RoPE frequencies.

    Each consecutive feature pair (x[..., 2k], x[..., 2k+1]) is treated as
    the real/imaginary parts of a complex number and multiplied by
    freqs = e^(i*theta).  Complex multiplication by e^(i*theta) is exactly
    the 2D rotation:

        [cos(t)  -sin(t)] [x1]
        [sin(t)   cos(t)] [x2]

    Args:
        x: Input features of shape [batch, heads, seq_len, head_dim].
        freqs: Complex rotation factors of shape [1, 1, seq_len, head_dim // 2].

    Returns:
        Rotated features with the same shape and dtype as ``x``.
    """
    b, h, s = x.shape[0], x.shape[1], x.shape[2]
    # Pair up features in float64: [..., head_dim] -> [..., head_dim//2, 2].
    pairs = x.to(torch.float64).reshape(b, h, s, -1, 2)
    # View each (real, imag) pair as one complex number and rotate via
    # complex multiplication with e^(i*theta).
    rotated = torch.view_as_complex(pairs) * freqs
    # Back to real pairs, flatten to original layout, restore input dtype.
    return torch.view_as_real(rotated).flatten(3).to(x.dtype)
|