# Common Imports

# forward_block_mn: process ONE key/value block for the current query tile and
# fold its contribution into the running accumulators (acc, l_i, m_i).
#
# Steps, in the order the code performs them:
#   1. Load K at row offset kv_start + kv_offset (tensor descriptor when the
#      template is rendered with USE_TMA, load_checked_2d otherwise) and
#      transpose it.
#   2. qk = dot(q, k), multiplied by SM_SCALE unless PRESCALE_QK is set.
#   3. Render template subgraph 0 (score modification -> post_mod_scores) and,
#      unless IS_FULL_BLOCKS, subgraph 1 (mask modification -> mask_mod_output);
#      masked-out entries become -inf.
#   4. Update the running row maximum and rescale l_i / acc, using base-2
#      exponentials.
#   5. Load V and accumulate dot(p, v) into acc.
#
# Arguments:
#   q                       query tile fed to tl.dot
#                           (presumably [BLOCK_M, QK_HEAD_DIM_ROUNDED] -- TODO confirm)
#   K, V                    passed to load_checked_2d in the non-TMA path
#   desc_k, desc_v          tensor descriptors used in the TMA path
#   Q_LEN, KV_LEN           bounds used for index clamping / masking
#   acc, l_i, m_i           running accumulators, returned updated
#   off_z, off_h            forwarded to the modification subgraphs as b and h
#   offs_m, offs_n          query / key indices for this block
#   kv_start, kv_offset     summed to form the K/V row offset for the loads
#   MATMUL_PRECISION        dtype p is cast to before the second tl.dot
#   RCP_LN2                 factor applied to the scores before exp2
#   stride_*                K / V strides for load_checked_2d
#   IS_FULL_BLOCKS          when true, the mask modification is skipped
#   CHECK_BLOCK_BOUNDARY    when true, indices are bounded by Q_LEN / KV_LEN and
#                           columns at or beyond KV_LEN are masked out
#
# Returns: (acc, l_i, m_i).
@triton.jit
def forward_block_mn(
    {{gen_argdefs()}},
    q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    kv_offset,
    MATMUL_PRECISION, RCP_LN2,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    {{gen_defines() | indent_except_first(1)}}

    # -- load k --
    # NB: reversed order since K is transposed
    # Row offset shared by the K load here and the V load further down.
    kv_base_offset = kv_start + kv_offset
    {%- if USE_TMA %}
    k = tl.load_tensor_descriptor(
        desc_k,
        [kv_base_offset, 0],
    )
    {%- else %}
    # Load K as [BLOCK_N, QK_HEAD_DIM_ROUNDED] then transpose to [QK_HEAD_DIM_ROUNDED, BLOCK_N]
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_n_load = kv_base_offset + tl.arange(0, BLOCK_N)
    k = load_checked_2d(K, offs_n_load, offs_k, stride_kn, stride_kk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
    {%- endif %}

    k = tl.trans(k)
    # -- compute qk ---
    qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
    # When PRESCALE_QK is set, the caller (forward_inner) has already folded
    # SM_SCALE into q, so the scale is not applied again here.
    if not PRESCALE_QK:
        qk *= SM_SCALE
    # ~~~~~~~~~~~~~~~~~~~ Apply score modification  ~~~~~~~~~~~~~~~~~~~
    # If this is the last block of a non-divisible seqlen, we still load [BLOCK_M, BLOCK_N] elements,
    # which is more than the actual number of elements. To avoid out-of-bounds memory accesses,
    # we need to mask out the elements that lie beyond Q_LEN & KV_LEN.
    m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
    n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)

    # Template-rendered score modification; the string arguments name the local
    # variables the generated code reads and writes, so those locals must keep
    # these exact names.
    {{ modification(
        subgraph_number=0,
        output_name="post_mod_scores",
        score="qk",
        b="off_z",
        h="off_h",
        m="m",
        n="n",
        out="qk"
    ) | indent_except_first(1) }}

    if CHECK_BLOCK_BOUNDARY:
        # Mask out the elements that are out of the KV_LEN for non divisible seqlen.
        post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))

    if not IS_FULL_BLOCKS:
        # Template-rendered mask modification; produces mask_mod_output.
        {{ modification(
            subgraph_number=1,
            output_name="mask_mod_output",
            score="qk",
            b="off_z",
            h="off_h",
            m="m",
            n="n",
        ) | indent_except_first(2) }}

        if CHECK_BLOCK_BOUNDARY:
            # Columns at or beyond KV_LEN are always treated as masked.
            mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
        # apply mask for partially unmasked blocks
        post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))

    # Scores are multiplied by RCP_LN2 because the exponentials below use exp2.
    # With PRESCALE_QK the factor was already folded into q by the caller.
    if not PRESCALE_QK:
        post_mod_scores *= RCP_LN2
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    # -- compute scaling constant ---
    m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
    if not ROWS_GUARANTEED_SAFE:
        # A row whose running max is still -inf has every score masked so far.
        # Substitute 0 for its max so the subtractions below never evaluate
        # (-inf) - (-inf).
        masked_out_rows = (m_ij == float("-inf"))
        m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
    else:
        m_ij_masked = m_ij

    # alpha rescales the previous accumulators to the new running max;
    # p holds this block's exponentiated scores relative to that max.
    alpha = tl.math.exp2(m_i - m_ij_masked)
    p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])

    # NB: l_i update is pulled up here since it's a bit faster
    # NB: For headdim=256, it's faster to move it back down to after m_i =
    # m_ij
    l_i = l_i * alpha + tl.sum(p, 1)
    # -- scale and update acc --
    acc = acc * alpha[:, None]
    {%- if USE_TMA %}
    v = tl.load_tensor_descriptor(
        desc_v,
        [kv_base_offset, 0],
    )
    {%- else %}
    # Calculate offsets for V loading - reuse kv_base_offset from K loading
    # (offs_n_load was defined in the non-TMA branch of the K load above).
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    v = load_checked_2d(V, offs_n_load, offs_v, stride_vn, stride_vk, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
    {%- endif %}
    # Third positional argument of tl.dot is the accumulator: acc += p @ v.
    acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)

    # -- update m_i
    # The unguarded m_ij (not m_ij_masked) is carried forward, so rows that are
    # still fully masked keep a running max of -inf.
    m_i = m_ij

    return acc, l_i, m_i


# forward_inner: iterate over blocks block_n_start .. block_n_end - 1, calling
# forward_block_mn once per iteration and advancing the key offsets between
# iterations.
#
# Arguments not described on forward_block_mn:
#   kv_indices, kv_num_blocks   block-sparse data forwarded to
#                               get_offset_for_next_block
#   block_n_start, block_n_end  half-open range of loop iterations
#
# Returns: (acc, l_i, m_i) after the last processed block.
@triton.jit
def forward_inner(
    {{gen_argdefs()}},
    q, K, V,
    desc_k, desc_v, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # Offsets used as inputs to score_mod & mask_mod
    # of size [BLOCK_M, BLOCK_N] or scalar.
    off_z, off_h, offs_m, offs_n,
    # Offsets needed for TMA loads
    kv_start,
    # blocksparse data
    kv_indices, kv_num_blocks,
    # start kv and end kv block
    block_n_start, block_n_end,
    MATMUL_PRECISION,
    # Strides for K and V
    stride_kk, stride_kn, stride_vn, stride_vk,
    IS_FULL_BLOCKS,
):
    # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
    {{gen_defines() | indent_except_first(1)}}

    # Number of BLOCK_N-sized blocks per sparse KV block.
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
    # 1 / ln(2), i.e. log2(e): scaling scores by this lets the block kernel use
    # exp2 in place of exp.
    RCP_LN2: tl.constexpr = 1.44269504

    if PRESCALE_QK:
        # Fold both scale factors into q once, up front, so forward_block_mn
        # skips its own SM_SCALE and RCP_LN2 multiplications.
        q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)

    # Running offset added to kv_start for the K/V loads; advanced in lock-step
    # with offs_n at the end of every iteration.
    kv_offset = 0

    # loop over k, v and update accumulator until block_n_end
    for start_n in range(block_n_start, block_n_end):
        # Here IS_DIVISIBLE plays the role of
        # start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
        if IS_DIVISIBLE:
            acc, l_i, m_i = forward_block_mn(
                {{gen_argdefs()}},
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS,
            )
        else:
            # Benchmarks show that even when we apply mod & mask to every block for a
            # non-divisible seqlen, it's on par with or slightly faster than applying them
            # only to the last block in fwd. However, we choose a different strategy for
            # bwd, where we only apply mod & mask to the last block because it's much faster.
            acc, l_i, m_i = forward_block_mn(
                {{gen_argdefs()}},
                q, K, V, desc_k, desc_v, Q_LEN, KV_LEN,
                # accumulated values
                acc, l_i, m_i,
                # Offsets
                off_z, off_h, offs_m, offs_n,
                # Offsets needed for TMA loads
                kv_start,
                kv_offset,
                MATMUL_PRECISION, RCP_LN2,
                # Strides for K and V
                stride_kk, stride_kn, stride_vn, stride_vk,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )

        # Distance from the current block to the next one to visit, as computed
        # by get_offset_for_next_block from the block-sparse data.
        offset = get_offset_for_next_block(
            start_n, kv_indices, kv_num_blocks,
            SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
        )

        # Advance both the key indices passed to the mods and the load offset.
        offs_n = offs_n + offset
        kv_offset += offset

    return acc, l_i, m_i