import torch import triton import triton.language as tl def get_autotuning_configs(q_tile_sizes: list): """Generate autotuning configurations optimized for H100.""" warps = [4, 8] stages = [2, 3] return [ triton.Config({'q_tile_size': t}, num_warps=w, num_stages=s) for t in q_tile_sizes for w in warps for s in stages ] @triton.autotune( configs=get_autotuning_configs([64, 128]), key=['seq_len', 'feature_dim'], ) @triton.jit def mosaic_attn_fwd_kernel( q_ptr, k_ptr, v_ptr, output_ptr, lse_ptr, block_indices_ptr, softmax_scale: tl.constexpr, seq_len: tl.constexpr, num_kv_heads: tl.constexpr, num_q_heads: tl.constexpr, q_heads_per_kv_head: tl.constexpr, feature_dim: tl.constexpr, kv_block_size: tl.constexpr, num_kv_blocks_per_q_block: tl.constexpr, q_tile_size: tl.constexpr, ): """ Sparse attention forward kernel: for each query tile (i.e. block chunk), for each query head, attend to a subset of key/value blocks. """ LOG2_E: tl.constexpr = 1.44269504089 q_tile_id = tl.program_id(0) q_head_id = tl.program_id(1) batch_kv_head_id = tl.program_id(2) batch_idx = batch_kv_head_id // num_kv_heads kv_head_idx = batch_kv_head_id % num_kv_heads q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id batch_offset = batch_idx * seq_len q_tile_start = q_tile_id * q_tile_size num_blocks_in_seq = seq_len // kv_block_size tiles_per_block = kv_block_size // q_tile_size q_block_id = q_tile_id // tiles_per_block block_indices_offset = ( batch_idx * num_blocks_in_seq * num_kv_heads * num_kv_blocks_per_q_block + q_block_id * num_kv_heads * num_kv_blocks_per_q_block + kv_head_idx * num_kv_blocks_per_q_block ) q_base_ptr = q_ptr + batch_offset * num_q_heads * feature_dim + q_head_idx * feature_dim k_base_ptr = k_ptr + batch_offset * num_kv_heads * feature_dim + kv_head_idx * feature_dim v_base_ptr = v_ptr + batch_offset * num_kv_heads * feature_dim + kv_head_idx * feature_dim q_tile_ptr = tl.make_block_ptr( base=q_base_ptr, shape=(seq_len, feature_dim), strides=(num_q_heads * feature_dim, 1), offsets=(q_tile_start, 0), block_shape=(q_tile_size, feature_dim), order=(1, 0) ) output_tile_ptr = tl.make_block_ptr( base=output_ptr + batch_offset * num_q_heads * feature_dim + q_head_idx * feature_dim, shape=(seq_len, feature_dim), strides=(num_q_heads * feature_dim, 1), offsets=(q_tile_start, 0), block_shape=(q_tile_size, feature_dim), order=(1, 0) ) lse_base_ptr = lse_ptr + (batch_offset + q_tile_start) * num_q_heads + tl.arange(0, q_tile_size) * num_q_heads + q_head_idx output_accum = tl.zeros([q_tile_size, feature_dim], dtype=tl.float32) max_scores = tl.full([q_tile_size], float('-inf'), dtype=tl.float32) sum_exp_scores = tl.zeros([q_tile_size], dtype=tl.float32) q_tile = tl.load(q_tile_ptr) q_tile = (q_tile * softmax_scale * LOG2_E).to(tl.bfloat16) for i in range(num_kv_blocks_per_q_block): kv_block_start = kv_block_size * tl.load(block_indices_ptr + block_indices_offset + i).to(tl.int32) k_block_ptr = tl.make_block_ptr( base=k_base_ptr, shape=(feature_dim, seq_len), strides=(1, num_kv_heads * feature_dim), offsets=(0, kv_block_start), block_shape=(feature_dim, kv_block_size), order=(1, 0) ) v_block_ptr = tl.make_block_ptr( base=v_base_ptr, shape=(seq_len, feature_dim), strides=(num_kv_heads * feature_dim, 1), offsets=(kv_block_start, 0), block_shape=(kv_block_size, feature_dim), order=(1, 0) ) k_block = tl.load(k_block_ptr).to(tl.bfloat16) v_block = tl.load(v_block_ptr).to(tl.bfloat16) attention_scores = tl.dot(q_tile, k_block) new_max = tl.max(attention_scores, axis=1) old_max = max_scores max_scores = tl.maximum(max_scores, new_max) rescale = tl.exp2(old_max - max_scores) attention_probs = tl.exp2(attention_scores - max_scores[:, None]) sum_exp_scores = sum_exp_scores * rescale + tl.sum(attention_probs, axis=1) output_accum = output_accum * rescale[:, None] output_accum += tl.dot(attention_probs.to(tl.bfloat16), v_block) final_output = output_accum / sum_exp_scores[:, None] log_sum_exp = (max_scores + tl.log2(sum_exp_scores)) tl.store(output_tile_ptr, final_output.to(q_ptr.dtype.element_ty)) tl.store(lse_base_ptr, log_sum_exp.to(tl.float32)) def mosaic_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_indices: torch.LongTensor, block_size: int, softmax_scale: float, ): batch_size, seq_len, num_kv_heads, feature_dim = k.shape num_q_heads = q.shape[2] num_kv_blocks_per_q_block = block_indices.shape[-1] q_heads_per_kv_head = num_q_heads // num_kv_heads output = torch.empty(batch_size, seq_len, num_q_heads, feature_dim, dtype=v.dtype, device=q.device) lse = torch.empty(batch_size, seq_len, num_q_heads, dtype=torch.float32, device=q.device) grid = lambda META: ( triton.cdiv(seq_len, META['q_tile_size']), q_heads_per_kv_head, batch_size * num_kv_heads ) mosaic_attn_fwd_kernel[grid]( q_ptr = q, k_ptr = k, v_ptr = v, output_ptr = output, lse_ptr = lse, block_indices_ptr = block_indices, softmax_scale = softmax_scale, seq_len = seq_len, num_kv_heads = num_kv_heads, num_q_heads = num_q_heads, q_heads_per_kv_head = q_heads_per_kv_head, feature_dim = feature_dim, kv_block_size = block_size, num_kv_blocks_per_q_block = num_kv_blocks_per_q_block, ) return output, lse @triton.autotune( configs=get_autotuning_configs([64, 128]), key=['seq_len', 'feature_dim'], ) @triton.jit def mosaic_attn_bwd_q_kernel( q_ptr, k_ptr, v_ptr, lse_ptr, delta_ptr, grad_o_ptr, grad_q_ptr, block_indices_ptr, softmax_scale: tl.constexpr, seq_len: tl.constexpr, num_kv_heads: tl.constexpr, num_q_heads: tl.constexpr, q_heads_per_kv_head: tl.constexpr, feature_dim: tl.constexpr, kv_block_size: tl.constexpr, num_kv_blocks_per_q_block: tl.constexpr, q_tile_size: tl.constexpr, ): LOG2_E: tl.constexpr = 1.44269504089 LN_2: tl.constexpr = 0.69314718056 q_tile_id = tl.program_id(0) q_head_id = tl.program_id(1) batch_kv_head_id = tl.program_id(2) batch_idx = batch_kv_head_id // num_kv_heads kv_head_idx = batch_kv_head_id % num_kv_heads q_head_idx = kv_head_idx * q_heads_per_kv_head + q_head_id batch_offset = batch_idx * seq_len q_tile_start = q_tile_id * q_tile_size tiles_per_block = kv_block_size // q_tile_size q_block_id = q_tile_id // tiles_per_block num_q_blocks = seq_len // kv_block_size block_indices_offset = ( batch_idx * num_q_blocks * num_kv_heads * num_kv_blocks_per_q_block + q_block_id * num_kv_heads * num_kv_blocks_per_q_block + kv_head_idx * num_kv_blocks_per_q_block ) q_offsets = ( tl.arange(0, q_tile_size)[:, None] * num_q_heads * feature_dim + q_head_idx * feature_dim + tl.arange(0, feature_dim)[None, :] ) lse_offsets = tl.arange(0, q_tile_size) * num_q_heads + q_head_idx q_base_ptr = q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim grad_o_base_ptr = grad_o_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim delta_base_ptr = delta_ptr + (batch_offset + q_tile_start) * num_q_heads lse_base_ptr = lse_ptr + (batch_offset + q_tile_start) * num_q_heads grad_q_base_ptr = grad_q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim grad_q_accum = tl.zeros([q_tile_size, feature_dim], dtype=tl.float32) q_tile = tl.load(q_base_ptr + q_offsets) q_tile = (q_tile * softmax_scale * LOG2_E).to(tl.bfloat16) grad_o_tile = tl.load(grad_o_base_ptr + q_offsets).to(tl.bfloat16) delta_vals = tl.load(delta_base_ptr + lse_offsets) lse_vals = tl.load(lse_base_ptr + lse_offsets).to(tl.float32) for i in range(num_kv_blocks_per_q_block): kv_block_idx = tl.load(block_indices_ptr + block_indices_offset + i).to(tl.int32) k_block_ptr = tl.make_block_ptr( base=k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim, shape=(feature_dim, seq_len), strides=(1, num_kv_heads * feature_dim), offsets=(0, kv_block_idx * kv_block_size), block_shape=(feature_dim, kv_block_size), order=(0, 1) ) v_block_ptr = tl.make_block_ptr( base=v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim, shape=(feature_dim, seq_len), strides=(1, num_kv_heads * feature_dim), offsets=(0, kv_block_idx * kv_block_size), block_shape=(feature_dim, kv_block_size), order=(0, 1) ) k_block = tl.load(k_block_ptr).to(tl.bfloat16) v_block = tl.load(v_block_ptr).to(tl.bfloat16) attention_scores = tl.dot(q_tile, k_block) attention_probs = tl.exp2(attention_scores - lse_vals[:, None]) * LN_2 grad_times_v = tl.dot(grad_o_tile, v_block) grad_scores = attention_probs * (grad_times_v - delta_vals[:, None]) grad_q_accum += tl.dot(grad_scores.to(tl.bfloat16), tl.trans(k_block.to(tl.bfloat16))) grad_q_accum = grad_q_accum * softmax_scale * LOG2_E tl.store(grad_q_base_ptr + q_offsets, grad_q_accum.to(q_ptr.dtype.element_ty)) @torch.compile @torch.no_grad() def mosaic_block_mask( block_indices: torch.LongTensor, ): batch_size, num_blocks, num_heads, _ = block_indices.shape block_mask = torch.zeros( batch_size, num_blocks, num_heads, num_blocks, dtype=torch.bool, device=block_indices.device ) batch_idx = torch.arange(batch_size, device=block_indices.device)[:, None, None, None] q_block_idx = torch.arange(num_blocks, device=block_indices.device)[None, :, None, None] head_idx = torch.arange(num_heads, device=block_indices.device)[None, None, :, None] block_mask[batch_idx, q_block_idx, head_idx, block_indices] = True block_mask_transposed = block_mask.permute(0, 2, 3, 1).contiguous() return block_mask_transposed @triton.autotune( configs=get_autotuning_configs([16, 32]), key=['seq_len', 'feature_dim'], ) @triton.jit def mosaic_attn_bwd_kv_kernel( q_ptr, k_ptr, v_ptr, lse_ptr, delta_ptr, grad_o_ptr, grad_k_ptr, grad_v_ptr, block_mask_ptr, softmax_scale: tl.constexpr, seq_len: tl.constexpr, num_kv_heads: tl.constexpr, num_q_heads: tl.constexpr, q_heads_per_kv_head: tl.constexpr, feature_dim: tl.constexpr, kv_block_size: tl.constexpr, q_tile_size: tl.constexpr, ): LOG2_E: tl.constexpr = 1.44269504089 LN_2: tl.constexpr = 0.69314718056 kv_block_id = tl.program_id(0) batch_kv_head_id = tl.program_id(1) batch_idx = batch_kv_head_id // num_kv_heads kv_head_idx = batch_kv_head_id % num_kv_heads batch_offset = batch_idx * seq_len num_blocks_in_seq = seq_len // kv_block_size tiles_per_block = kv_block_size // q_tile_size fine_mask_start = ( batch_idx * num_kv_heads * num_blocks_in_seq * num_blocks_in_seq + kv_head_idx * num_blocks_in_seq * num_blocks_in_seq + kv_block_id * num_blocks_in_seq ) k_block_ptr = tl.make_block_ptr( k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim, (seq_len, feature_dim), (num_kv_heads * feature_dim, 1), (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0) ) v_block_ptr = tl.make_block_ptr( v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim, (seq_len, feature_dim), (num_kv_heads * feature_dim, 1), (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0) ) grad_k_ptr = tl.make_block_ptr( grad_k_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim, (seq_len, feature_dim), (num_kv_heads * feature_dim, 1), (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0) ) grad_v_ptr = tl.make_block_ptr( grad_v_ptr + (batch_offset * num_kv_heads + kv_head_idx) * feature_dim, (seq_len, feature_dim), (num_kv_heads * feature_dim, 1), (kv_block_id * kv_block_size, 0), (kv_block_size, feature_dim), (1, 0) ) k_block = tl.load(k_block_ptr).to(tl.bfloat16) v_block = tl.load(v_block_ptr).to(tl.bfloat16) grad_k_accum = tl.zeros([kv_block_size, feature_dim], dtype=tl.float32) grad_v_accum = tl.zeros([kv_block_size, feature_dim], dtype=tl.float32) for q_block_id in range(num_blocks_in_seq): is_connected = tl.load(block_mask_ptr + fine_mask_start + q_block_id) if is_connected: for tile_in_block in range(tiles_per_block): tile_idx = q_block_id * tiles_per_block + tile_in_block q_tile_start = tile_idx * q_tile_size q_tile_ptr = tl.make_block_ptr( base=q_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim, shape=(q_tile_size, num_q_heads, feature_dim), strides=(num_q_heads * feature_dim, feature_dim, 1), offsets=(0, kv_head_idx * q_heads_per_kv_head, 0), block_shape=(q_tile_size, q_heads_per_kv_head, feature_dim), order=(0, 1, 2), ) grad_o_tile_ptr = tl.make_block_ptr( base=grad_o_ptr + (batch_offset + q_tile_start) * num_q_heads * feature_dim, shape=(q_tile_size, num_q_heads, feature_dim), strides=(num_q_heads * feature_dim, feature_dim, 1), offsets=(0, kv_head_idx * q_heads_per_kv_head, 0), block_shape=(q_tile_size, q_heads_per_kv_head, feature_dim), order=(0, 1, 2), ) lse_tile_ptr = tl.make_block_ptr( base=lse_ptr + (batch_offset + q_tile_start) * num_q_heads, shape=(q_tile_size, num_q_heads), strides=(num_q_heads, 1), offsets=(0, kv_head_idx * q_heads_per_kv_head), block_shape=(q_tile_size, q_heads_per_kv_head), order=(1, 0), ) delta_tile_ptr = tl.make_block_ptr( base=delta_ptr + (batch_offset + q_tile_start) * num_q_heads, shape=(q_tile_size, num_q_heads), strides=(num_q_heads, 1), offsets=(0, kv_head_idx * q_heads_per_kv_head), block_shape=(q_tile_size, q_heads_per_kv_head), order=(1, 0), ) q_tile = tl.load(q_tile_ptr) * softmax_scale * LOG2_E q_tile = tl.reshape(q_tile, (q_tile_size * q_heads_per_kv_head, feature_dim)) q_tile = q_tile.to(tl.bfloat16) grad_o_block = tl.load(grad_o_tile_ptr) grad_o_block = tl.reshape(grad_o_block, (q_tile_size * q_heads_per_kv_head, feature_dim)) grad_o_block = grad_o_block.to(tl.bfloat16) lse_vals = tl.load(lse_tile_ptr) lse_vals = tl.reshape(lse_vals, (q_tile_size * q_heads_per_kv_head,)) delta_vals = tl.load(delta_tile_ptr) delta_vals = tl.reshape(delta_vals, (q_tile_size * q_heads_per_kv_head,)) attention_scores = tl.dot(k_block, tl.trans(q_tile)) attention_probs = tl.exp2(attention_scores - lse_vals[None, :]) grad_v_accum += tl.dot(attention_probs.to(tl.bfloat16), grad_o_block) grad_times_v = tl.dot(v_block, tl.trans(grad_o_block)) grad_scores = attention_probs * (grad_times_v - delta_vals[None, :]) * LN_2 grad_k_accum += tl.dot(grad_scores.to(tl.bfloat16), q_tile) tl.store(grad_k_ptr, grad_k_accum.to(grad_k_ptr.dtype.element_ty)) tl.store(grad_v_ptr, grad_v_accum.to(grad_v_ptr.dtype.element_ty)) def mosaic_attn_bwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, lse: torch.Tensor, grad_o: torch.Tensor, softmax_scale: float, block_indices: torch.LongTensor, block_size: int, ): batch_size, seq_len, num_kv_heads, feature_dim = k.shape num_q_heads = q.shape[2] num_kv_blocks_per_q_block = block_indices.shape[-1] q_heads_per_kv_head = num_q_heads // num_kv_heads num_blocks_in_seq = seq_len // block_size grad_q = torch.empty_like(q) grad_k = torch.empty_like(k) grad_v = torch.empty_like(v) block_mask = mosaic_block_mask(block_indices) delta = (output * grad_o).sum(dim=-1) grid_dq = lambda META: ( triton.cdiv(seq_len, META['q_tile_size']), q_heads_per_kv_head, batch_size * num_kv_heads ) mosaic_attn_bwd_q_kernel[grid_dq]( q_ptr=q, k_ptr=k, v_ptr=v, lse_ptr=lse, delta_ptr=delta, grad_o_ptr=grad_o, grad_q_ptr=grad_q, block_indices_ptr=block_indices, softmax_scale=softmax_scale, seq_len=seq_len, num_kv_heads=num_kv_heads, num_q_heads=num_q_heads, q_heads_per_kv_head=q_heads_per_kv_head, feature_dim=feature_dim, kv_block_size=block_size, num_kv_blocks_per_q_block=num_kv_blocks_per_q_block, ) grid_dkv = (num_blocks_in_seq, batch_size * num_kv_heads) mosaic_attn_bwd_kv_kernel[grid_dkv]( q_ptr=q, k_ptr=k, v_ptr=v, lse_ptr=lse, delta_ptr=delta, grad_o_ptr=grad_o, grad_k_ptr=grad_k, grad_v_ptr=grad_v, block_mask_ptr=block_mask, softmax_scale=softmax_scale, seq_len=seq_len, num_kv_heads=num_kv_heads, num_q_heads=num_q_heads, q_heads_per_kv_head=q_heads_per_kv_head, feature_dim=feature_dim, kv_block_size=block_size, ) return grad_q, grad_k, grad_v class MosaicAttnFunction(torch.autograd.Function): @staticmethod @torch.amp.custom_fwd(device_type='cuda') def forward( ctx: torch.autograd.function.FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_indices: torch.Tensor, block_size: int, softmax_scale: float ): q, k, v, block_indices = map(lambda x: x.contiguous(), (q, k, v, block_indices)) ctx.dtype = q.dtype output, lse = mosaic_attn_fwd( q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, softmax_scale=softmax_scale, ) ctx.save_for_backward(q, k, v, output, lse, block_indices) ctx.block_size = block_size ctx.softmax_scale = softmax_scale return output.to(q.dtype) @staticmethod @torch.amp.custom_bwd(device_type='cuda') def backward( ctx: torch.autograd.function.FunctionCtx, grad_o: torch.Tensor ): q, k, v, output, lse, block_indices = ctx.saved_tensors grad_o = grad_o.contiguous() grad_q, grad_k, grad_v = mosaic_attn_bwd( q=q, k=k, v=v, output=output, lse=lse, grad_o=grad_o, softmax_scale=ctx.softmax_scale, block_indices=block_indices, block_size=ctx.block_size, ) return grad_q, grad_k, grad_v, None, None, None def mosaic_sparse_attn( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_indices: torch.LongTensor, block_size: int, softmax_scale: float = None, ): softmax_scale = q.shape[-1] ** -0.5 if softmax_scale is None else softmax_scale return MosaicAttnFunction.apply(q, k, v, block_indices, block_size, softmax_scale)