{{def_kernel("A", "B")}} M = {{size("A", 0)}} N = {{size("B", 1)}} K = {{size("A", 1)}} if M * N == 0: # early exit due to zero-size input(s) return stride_am = {{stride("A", 0)}} stride_ak = {{stride("A", 1)}} stride_bk = {{stride("B", 0)}} stride_bn = {{stride("B", 1)}} # based on triton.ops.matmul pid = tl.program_id(0).to(INDEX_DTYPE) grid_m = (M + BLOCK_M - 1) // BLOCK_M grid_n = (N + BLOCK_N - 1) // BLOCK_N # re-order program ID for better L2 performance width = GROUP_M * grid_n group_id = pid // width group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: offs_a_m = rm % M if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: offs_b_n = rn % N offs_k = tl.arange(0, BLOCK_K) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for k_idx in range(0, tl.cdiv(K, BLOCK_K)): {% if not EVEN_K %} a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) {% endif %} a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) idx_m = offs_a_m[:, None] idx_n = a_k_idx_vals {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8, index_shape=("BLOCK_M", "BLOCK_K"))}} idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8, index_shape=("BLOCK_K", "BLOCK_N"))}} {% if USE_FAST_ACCUM %} acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) {% else %} acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) {% endif %} # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) idx_m = rm[:, None] idx_n = rn[None, :] mask = (idx_m < M) & (idx_n < N) # inductor generates a suffix {{store_output(("idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}}