import torch
from torch.nn.attention.flex_attention import _mask_mod_signature


def causal_mask(
    batch_size: int,
    num_heads: int,
    q_idx: torch.Tensor,
    kv_idx: torch.Tensor
) -> torch.Tensor:
    """
    Returns a boolean tensor indicating which positions in the attention matrix
    are valid for causal (autoregressive) attention: position (i, j) is True
    exactly when i >= j, so each query attends only to itself and earlier tokens.

    Args:
        batch_size (int): Batch size (unused here).
        num_heads (int): Number of heads (unused here).
        q_idx (torch.Tensor): Tensor indexing the query positions.
        kv_idx (torch.Tensor): Tensor indexing the key/value positions.

    Returns:
        torch.Tensor: A boolean tensor where True indicates that the query at
        position i can attend to the key at position j, respecting i >= j.
    """
    return q_idx >= kv_idx
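
# A minimal usage sketch (illustrative, not called by main() below): functions
# with this signature are what FlexAttention accepts as a `mask_mod`. This
# assumes PyTorch >= 2.5, where `create_block_mask` and `flex_attention` are
# exposed from torch.nn.attention.flex_attention; eager/CPU support varies by
# version. Because `causal_mask` ignores its first two arguments, it also works
# when FlexAttention passes batch/head index tensors in their place.
def demo_causal_flex_attention() -> torch.Tensor:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    seq_len, num_heads, head_dim = 128, 2, 16
    # B=None / H=None broadcast the mask over batch and heads.
    block_mask = create_block_mask(
        causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
    )
    q = torch.randn(1, num_heads, seq_len, head_dim)
    k = torch.randn(1, num_heads, seq_len, head_dim)
    v = torch.randn(1, num_heads, seq_len, head_dim)
    return flex_attention(q, k, v, block_mask=block_mask)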


def generate_sliding_window_mask(window_size: int, causal: bool = True) -> _mask_mod_signature:
    """
    Creates a sliding window mask function.

    If `causal=True`, each query token at position i can attend only to tokens j
    in [i - window_size, i].
    If `causal=False`, each query token i can attend to any token j in
    [i - window_size, i + window_size], i.e. a symmetric window of radius
    `window_size` around i.

    Args:
        window_size (int): The maximum allowed distance between i and j.
        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.

    Returns:
        _mask_mod_signature: A callable mask function that takes
        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
        indicating allowed attention connections.
    """
    def sliding_window_mask(
        batch_size: int,
        num_heads: int,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor
    ) -> torch.Tensor:
        """
        If causal is True:
            within_window = (q_idx - kv_idx) <= window_size, and q_idx >= kv_idx.
        If causal is False:
            within_window = abs(q_idx - kv_idx) <= window_size.
        """
        if causal:
            # Look backwards only, at most window_size positions.
            distance = q_idx - kv_idx
            within_window = (distance >= 0) & (distance <= window_size)
        else:
            # Look both ways, at most window_size positions in either direction.
            distance = (q_idx - kv_idx).abs()
            within_window = distance <= window_size
        return within_window

    name_ext = "causal" if causal else "noncausal"
    sliding_window_mask.__name__ = f"sliding_window_{window_size}_{name_ext}"
    return sliding_window_mask
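
# An equivalent construction sketch: FlexAttention also ships mask combinators
# such as `and_masks` (assumed available in torch.nn.attention.flex_attention
# for PyTorch >= 2.5), so the causal sliding window above can be built by
# AND-ing `causal_mask` with a pure distance predicate instead of branching
# inside a single function.
def sliding_window_via_and_masks(window_size: int) -> _mask_mod_signature:
    from torch.nn.attention.flex_attention import and_masks

    def within_window(batch_size, num_heads, q_idx, kv_idx):
        # Symmetric distance check; the AND with causal_mask restores causality.
        return (q_idx - kv_idx).abs() <= window_size

    return and_masks(causal_mask, within_window)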


def generate_dilated_sliding_window_mask(
    window_size: int,
    dilation: int = 2,
    causal: bool = True
) -> _mask_mod_signature:
    """
    Creates a dilated sliding window mask function.

    If `causal=True`, each query token i can attend tokens j in [i - window_size, i]
    such that (i - j) % dilation == 0.
    If `causal=False`, each query token i can attend tokens j in [i - window_size,
    i + window_size] for which |i - j| % dilation == 0.

    Args:
        window_size (int): The maximum distance from i to j (backwards if causal=True,
            otherwise symmetric around i).
        dilation (int): The stride for skipping positions.
        causal (bool): Whether to enforce causal ordering (i >= j). Defaults to True.

    Returns:
        _mask_mod_signature: A callable mask function that takes
        (batch_size, num_heads, q_idx, kv_idx) and returns a boolean tensor
        indicating allowed attention connections.
    """
    def dilated_sliding_window_mask(
        batch_size: int,
        num_heads: int,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor
    ) -> torch.Tensor:
        """
        If causal is True:
            distance = q_idx - kv_idx
            0 <= distance <= window_size and distance % dilation == 0.
        If causal is False:
            distance = (q_idx - kv_idx).abs()
            distance <= window_size and distance % dilation == 0.
        """
        if causal:
            distance = q_idx - kv_idx
            within_window = (distance >= 0) & (distance <= window_size)
        else:
            distance = (q_idx - kv_idx).abs()
            within_window = distance <= window_size
        # Keep only positions whose distance is a multiple of the dilation stride.
        meets_dilation = (distance % dilation) == 0
        return within_window & meets_dilation

    mode_str = "causal" if causal else "noncausal"
    dilated_sliding_window_mask.__name__ = (
        f"dilated_sliding_window_{window_size}_dilation_{dilation}_{mode_str}"
    )
    return dilated_sliding_window_mask
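
# A small illustrative helper (hypothetical, not used by main() below): it
# materializes any of the mask functions above into a dense (q_len, kv_len)
# boolean grid by broadcasting index tensors, mirroring what main() does inline.
def mask_to_dense(mask_fn, q_len: int, kv_len: int) -> torch.Tensor:
    q_idx = torch.arange(q_len).unsqueeze(-1)   # column vector of query positions
    kv_idx = torch.arange(kv_len).unsqueeze(0)  # row vector of key/value positions
    return mask_fn(0, 0, q_idx, kv_idx)


# For example, mask_to_dense(generate_dilated_sliding_window_mask(4, 2), 8, 8)
# has row 6 equal to [0, 0, 1, 0, 1, 0, 1, 0]: only j in {2, 4, 6} lie within
# distance 4 of i=6 at even offsets.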


def main():
    """
    Demonstrates usage of each mask by printing attention grids. We include a few
    basic checks to ensure the masks behave as expected, and show both the causal
    and non-causal versions of the sliding window and dilated masks.
    """
    B, H = 1, 1
    Q_LEN, KV_LEN = 8, 8

    # Build (Q_LEN, KV_LEN) index grids: row i holds query position i,
    # column j holds key/value position j.
    q_idx = torch.arange(Q_LEN).unsqueeze(-1).expand(Q_LEN, KV_LEN)
    kv_idx = torch.arange(KV_LEN).unsqueeze(0).expand(Q_LEN, KV_LEN)

    print("= Causal Mask =")
    c_mask = causal_mask(B, H, q_idx, kv_idx)
    print(c_mask.int(), "\n")

    print("= Sliding Window (window_size=2, causal=True) =")
    sw_causal_fn = generate_sliding_window_mask(window_size=2, causal=True)
    sw_causal = sw_causal_fn(B, H, q_idx, kv_idx)
    print(sw_causal.int(), "\n")

    print("= Sliding Window (window_size=2, causal=False) =")
    sw_noncausal_fn = generate_sliding_window_mask(window_size=2, causal=False)
    sw_noncausal = sw_noncausal_fn(B, H, q_idx, kv_idx)
    print(sw_noncausal.int(), "\n")

    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=True) =")
    ds_causal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=True)
    ds_causal = ds_causal_fn(B, H, q_idx, kv_idx)
    print(ds_causal.int(), "\n")

    print("= Dilated Sliding Window (window_size=4, dilation=2, causal=False) =")
    ds_noncausal_fn = generate_dilated_sliding_window_mask(window_size=4, dilation=2, causal=False)
    ds_noncausal = ds_noncausal_fn(B, H, q_idx, kv_idx)
    print(ds_noncausal.int(), "\n")

    # Sanity checks.
    assert torch.all(c_mask == (q_idx >= kv_idx)), "Causal mask mismatch!"

    # Causal sliding window: the earliest allowed key must lie within the window.
    i = 5
    row_sw = sw_causal[i]
    allowed_js = torch.where(row_sw)[0]
    if len(allowed_js) > 0:
        assert (i - allowed_js.min()) <= 2, "Window mismatch for sliding_window_mask(causal=True)."
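    # Additional check (illustrative addition): for the non-causal window,
    # row i should allow exactly the positions with |i - j| <= window_size.
    i = 4
    expected = (torch.arange(KV_LEN) - i).abs() <= 2
    assert torch.all(sw_noncausal[i] == expected), \
        "Window mismatch for sliding_window_mask(causal=False)."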

    # Dilated causal window: every allowed offset must be a multiple of the dilation.
    i = 6
    row_ds = ds_causal[i]
    allowed_js = torch.where(row_ds)[0]
    for j in allowed_js:
        diff = i - j
        assert diff % 2 == 0, f"Dilation mismatch: got diff={diff}."

    print("All checks passed.")


if __name__ == "__main__":
    main()