File size: 21,654 Bytes
4700ca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


# Implementation of 2D Rotary Position Embeddings (RoPE).

# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.

# Inspired by:
#         https://github.com/meta-llama/codellama/blob/main/llama/model.py
#         https://github.com/naver-ai/rope-vit


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple

from typing import List, Optional, Tuple, Union


class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    This class efficiently manages the generation of spatial coordinates for patches
    in a 2D grid, caching results to avoid redundant computations.

    Attributes:
        position_cache: Dictionary storing precomputed position tensors for different
            grid dimensions.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
            for each position in the grid, repeated for each batch item.
        """
        if (height, width) not in self.position_cache:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[height, width] = positions

        cached_positions = self.position_cache[height, width]
        return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. It handles the position-dependent rotation of features
    separately for vertical and horizontal dimensions.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
        frequency_cache: Cache for storing precomputed frequency components.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors for frequency components.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Embed positions with frequency components
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                   The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                      the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
    


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  #  torch.float32, torch.float64 (flux)
):
    """
    计算1D旋转位置编码(RoPE)的频率张量。
    
    RoPE的核心思想:使用旋转矩阵来编码位置信息,使得相对位置关系保持不变。
    公式:对于位置m和维度i,频率为 θ_i = θ^(-2i/d),其中θ是基础频率(默认10000)
    
    Args:
        dim: 特征维度,必须是偶数(因为要成对处理)
        pos: 位置索引,可以是整数(自动生成0到pos-1的序列)或位置数组 [S]
        theta: 基础频率,控制位置编码的周期性(默认10000)
        use_real: 是否返回实数形式(cos和sin分开)还是复数形式
        linear_factor: 线性缩放因子,用于上下文扩展
        ntk_factor: NTK-Aware缩放因子,用于处理更长的序列
        repeat_interleave_real: 当use_real=True时,是否交错重复(用于某些模型架构)
        freqs_dtype: 频率张量的数据类型
        
    Returns:
        复数形式:[S, D/2] 的复数张量,表示 e^(i*m*θ_j)
        实数形式:两个 [S, D] 的张量(cos和sin)
    """
    # 确保维度是偶数(RoPE需要成对处理维度)
    assert dim % 2 == 0

    # 将位置转换为torch张量
    if isinstance(pos, int):
        pos = torch.arange(pos)  # 生成 [0, 1, 2, ..., pos-1]
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    # 应用NTK缩放(Neural Tangent Kernel,用于处理训练时未见过的长序列)
    theta = theta * ntk_factor
    
    # 步骤1:计算频率 θ_i = 1 / (θ^(2i/d))
    # 其中 i ∈ {0, 2, 4, ..., dim-2}(只取偶数索引,因为成对处理)
    # 公式:freq_i = 1 / (theta^(2i/d) * linear_factor)
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2],每个频率对应一个维度对
    
    # 步骤2:计算位置-频率矩阵
    # 使用外积:pos[m] * freqs[i] = m * θ_i
    # 结果:每个位置m和每个频率i的组合
    freqs = torch.outer(pos, freqs)  # [S, D/2]
    
    # 步骤3:根据返回格式转换
    if use_real and repeat_interleave_real:
        # 方式1:交错重复(用于flux, hunyuan-dit, cogvideox等模型)
        # 将每个频率的cos和sin交错排列:[cos_0, cos_0, cos_1, cos_1, ...]
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # 方式2:拼接重复(用于stable audio, allegro等模型)
        # 将所有cos拼接,然后是所有sin:[cos_0, cos_1, ..., cos_n, cos_0, cos_1, ..., cos_n]
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # 方式3:复数形式(用于lumina等模型)
        # 使用欧拉公式:e^(iθ) = cos(θ) + i*sin(θ)
        # torch.polar(r, θ) 返回 r * e^(iθ),这里r=1,所以就是 e^(i*freqs)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: [S, D/2]
        return freqs_cis


class WanRotaryPosEmbed(nn.Module):
    """
    3D旋转位置编码(3D RoPE)模块
    
    核心思想:将RoPE扩展到3D空间(时间、高度、宽度),为视频或3D数据提供位置编码。
    每个维度(t, h, w)独立使用RoPE,然后拼接起来。
    
    公式:
    对于3D位置 (f, h, w)(帧、高度、宽度):
    - 帧维度使用 dim_f 个特征维度
    - 高度维度使用 dim_h 个特征维度  
    - 宽度维度使用 dim_w 个特征维度
    其中 dim_f + dim_h + dim_w = attention_head_dim
    """
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = [20, 22, 22],
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim  # 注意力头的总维度
        self.patch_size = patch_size  # patch大小 (patch_f, patch_h, patch_w)
        self.max_seq_len = max_seq_len  # 最大序列长度(用于预计算频率)

        # 步骤1:分配维度给三个空间维度
        if fhw_dim is not None:
            # 如果指定了维度分配,使用指定的
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # 否则自动分配:h和w各占1/3,t占剩余
            # 例如:如果attention_head_dim=64,则 h_dim=w_dim=21,t_dim=22
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim
        
        # 保存维度分配以便在forward中使用
        self.fhw_dim = (t_dim, h_dim, w_dim)

        # 步骤2:为每个维度预计算频率
        # 分别计算时间、高度、宽度三个维度的RoPE频率
        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            # 每个维度独立调用1D RoPE
            # 返回复数形式的频率: [max_seq_len, dim//2]
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        # 将三个维度的频率在最后一维拼接: [max_seq_len, (t_dim + h_dim + w_dim)//2]
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """
        前向传播:为3D输入(视频帧+patch)生成旋转位置编码
        
        参数:
        - ppf (int): 帧数(patches per frame),当f_end为None时使用
        - pph (int): 每帧的patch高度数量
        - ppw (int): 每帧的patch宽度数量  
        - patch_start_idx (int): 每帧的特殊token数量(在patches之前)
        - device: 计算设备(CPU/GPU)
        - f_start (int): 起始帧索引(用于causal模式),默认为0
        - f_end (Optional[int]): 结束帧索引(用于causal模式),如果为None则使用ppf作为帧数
        
        返回:
        - freqs: [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2] 复数频率tensor
        
        Token排列顺序:
        [frame0_special_token_0, ..., frame0_special_token_N,
         frame0_patch_0, ..., frame0_patch_M,
         frame1_special_token_0, ..., frame1_special_token_N,
         frame1_patch_0, ..., frame1_patch_M,
         ...]
        
        模式:
        - 非causal模式:f_end=None,使用ppf作为帧数,从位置0开始
        - Causal模式:f_end不为None,使用[f_start, f_end)范围的帧,ppf会被重新计算
        """

        # 步骤1:将预计算的频率移到目标设备,并分割成三个维度
        self.freqs = self.freqs.to(device)
        # 获取实际的维度分配
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # 自动分配的情况
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim
        
        # 使用正确的split sizes(每个维度的一半)
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2,  # 时间维度
                h_dim // 2,  # 高度维度
                w_dim // 2,  # 宽度维度
            ],
            dim=1,
        )
        
        # 处理causal模式:如果指定了f_end,重新计算ppf和帧范围
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # 非causal模式:使用从0开始的ppf个帧
            frame_slice = slice(0, ppf)
        
        # 步骤2:处理特殊token(如果存在)
        ## For other tokens
        if patch_start_idx > 0:
            # 2.1 为特殊token生成位置编码
            # 特殊token位于对角线位置 (f, i, i),每个特殊token有唯一位置
            # camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5)
            # Shape: (ppf, patch_start_idx, dim)
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_f) 帧维度变化
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_h) 高度=0,1,2,...
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_w) 宽度=0,1,2,...
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)  # (ppf, patch_start_idx, dim) 拼接三维
            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim)

            # 2.2 为图像patch生成位置编码
            # Patch位于 (f, patch_start_idx+h, patch_start_idx+w),h,w 整体偏移 patch_start_idx
            # 这样 patches 与 special tokens 位置不冲突,且 h,w 对称处理
            # Shape: (ppf, pph, ppw, dim)
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_f) 帧维度
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_h) 高度从patch_start_idx开始
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_w) 宽度从patch_start_idx开始
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)  # (ppf, pph, ppw, dim) 拼接三维
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)  # (ppf, pph * ppw, dim) 展平空间维度
            
            # 步骤3:按照正确的顺序组合特殊token和patches
            # 每帧内部顺序:[特殊tokens, patches]
            # Concatenate special tokens and patches for each frame along the second dimension
            # Shape: (ppf, patch_start_idx + pph * ppw, dim)
            freqs = torch.cat([freqs_special, freqs_patches], dim=1)  # (ppf, patch_start_idx + pph * ppw, dim)
            
            # 步骤4:展平为最终形状并添加batch和head维度
            # Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim)
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            freqs = freqs.unsqueeze(0).unsqueeze(0)  # (1, 1, ppf * (patch_start_idx + pph * ppw), dim) 添加batch和head维度
            return freqs
        
        # 如果没有特殊token(patch_start_idx == 0),只处理图像patches
        # 所有patches位于 (f, 0:pph, 0:ppw)
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_f) 帧维度
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_h) 高度从0开始
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_w) 宽度从0开始
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)  # (1, 1, ppf * pph * ppw, dim)
        return freqs
    
def apply_rotary_emb(x, freqs):
    """
    应用旋转位置编码到输入特征
    
    核心思想:使用复数乘法实现特征旋转,保持相对位置信息
    
    数学原理:
    对于2D向量 [x1, x2],旋转θ角度可以表示为复数乘法:
    (x1 + ix2) * e^(iθ) = (x1 + ix2) * (cos(θ) + i*sin(θ))
                        = (x1*cos(θ) - x2*sin(θ)) + i*(x1*sin(θ) + x2*cos(θ))
    
    这等价于旋转矩阵:
    [cos(θ)  -sin(θ)] [x1]
    [sin(θ)   cos(θ)] [x2]
    
    参数:
    - x: 输入特征 [batch, heads, seq_len, head_dim]
    - freqs: 旋转频率(复数) [1, 1, seq_len, head_dim//2]
    
    返回:
    - x_out: 旋转后的特征 [batch, heads, seq_len, head_dim]
    
    实现步骤:
    1. 将x的每两个连续特征看作一个复数 (real, imag)
    2. 与预计算的复数频率 e^(iθ) 相乘
    3. 转回实数表示
    """
    # 步骤1:reshape成 [..., head_dim//2, 2] 形式,最后一维表示(real, imag)
    # 例如:[b, h, seq, 64] -> [b, h, seq, 32, 2]
    x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
    
    # 步骤2:转换为复数表示 [b, h, seq, 32]
    # 每个元素是 real + imag*i
    x_complex = torch.view_as_complex(x_reshaped)
    
    # 步骤3:复数乘法实现旋转
    # x_complex * freqs 相当于将每对特征旋转θ角度
    # freqs已经是 e^(iθ) = cos(θ) + i*sin(θ) 的形式
    x_rotated = x_complex * freqs
    
    # 步骤4:转回实数表示 [b, h, seq, 32, 2]
    x_real = torch.view_as_real(x_rotated)
    
    # 步骤5:展平最后两维 [b, h, seq, 64]
    x_out = x_real.flatten(3)
    
    # 步骤6:转回原始数据类型
    return x_out.to(x.dtype)