# Utility triton funcs


# Offset returned for stepping from the block handled at `loop_iter` to the
# one handled on the following iteration.
#
# When BLOCKS_ARE_CONTIGUOUS is set the result is always BLOCK. Otherwise two
# neighbouring entries of `col_indices` are read (position
# loop_iter // SPARSE_BLOCK_MULTIPLE and the position after it), and:
#   * on the last of every SPARSE_BLOCK_MULTIPLE iterations the result is
#       (next - cur) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
#   * on every other iteration the result is BLOCK.
@triton.jit
def get_offset_for_next_block(
    loop_iter,
    col_indices,
    total_blocks,
    SPARSE_BLOCK,
    SPARSE_BLOCK_MULTIPLE,
    BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr,
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK

    sparse_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    entry_ptr = col_indices + sparse_idx
    this_col = tl.load(entry_ptr, eviction_policy="evict_last")
    # The following entry is read only while its position is below
    # `total_blocks`.
    # NOTE(review): no `other=` is passed, so the value used when the mask is
    # false is whatever tl.load yields for masked-off elements - confirm that
    # callers do not rely on the offset produced in that case.
    following_col = tl.load(
        entry_ptr + 1,
        eviction_policy="evict_last",
        mask=sparse_idx + 1 < total_blocks,
    )

    # True on the last of each group of SPARSE_BLOCK_MULTIPLE iterations.
    at_group_end = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    sparse_jump = (
        (following_col - this_col) * SPARSE_BLOCK
        - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    )
    # `at_group_end` acts as a 0/1 multiplier that picks one of the two terms.
    return sparse_jump * at_group_end + (1 - at_group_end) * BLOCK


# Return `indices % max_len` when `max_len` is supplied; with the default
# `max_len=None` the indices are returned unchanged.
@triton.jit
def get_bounded_indices(indices, max_len=None):
    bounded = indices if max_len is None else indices % max_len
    return bounded


# Load through `block_ptr`, turning on zero-padded boundary checks per
# dimension: dimension 0 is checked when IS_DIVISIBLE is false, dimension 1 is
# checked when SAFE_HEAD_DIM is false. With both flags set the load is
# unchecked.
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr)
        else:
            return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    else:
        if SAFE_HEAD_DIM:
            return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
        else:
            return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")


# Load a 2D tile, substituting 0.0 wherever a mask is false.
#
# If both strides are given, element addresses are built as
#   ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# otherwise `ptr` is used exactly as passed in.
# A row mask `offs_m < M_LEN` is applied when IS_DIVISIBLE_M is false and a
# column mask `offs_n < N_LEN` when IS_DIVISIBLE_N is false; with both flags
# set the load is unmasked.
@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_LEN: tl.constexpr,
):
    # Build the per-element pointers only when both strides were supplied.
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Pick the mask from the two compile-time divisibility flags.
    if IS_DIVISIBLE_M:
        if IS_DIVISIBLE_N:
            # Both flags set: unmasked load.
            return tl.load(ptr)
        else:
            col_in_range = offs_n[None, :] < N_LEN
            return tl.load(ptr, mask=col_in_range, other=0.0)
    else:
        row_in_range = offs_m[:, None] < M_LEN
        if IS_DIVISIBLE_N:
            return tl.load(ptr, mask=row_in_range, other=0.0)
        else:
            col_in_range = offs_n[None, :] < N_LEN
            return tl.load(ptr, mask=row_in_range & col_in_range, other=0.0)