[CLI] Fix tutorial issues

This commit is contained in:
Zheng Linfeng
2026-03-24 00:12:01 -07:00
parent 982748aa73
commit ecb32fe231
8 changed files with 99 additions and 27 deletions

View File

@@ -235,9 +235,17 @@ def kernel(
ab_full = ab_consumer.wait_and_advance()
# tCtAcc += tCrA * tCrB
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, ab_full.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, ab_full.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
ab_full.release()
@@ -445,3 +453,4 @@ if __name__ == "__main__":
run_dense_gemm(args.mnk, args.tolerance)
print("PASS")

View File

@@ -294,9 +294,17 @@ def kernel(
if is_leader_cta:
ab_full = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0)
tile_crd = (None, None, None, ab_full.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, ab_full.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
ab_full.release()
# Signal that the accumulator is fully computed

View File

@@ -49,7 +49,7 @@ To store the results from registers to global memory using TMA actually requires
1). Write the tile from registers to shared memory
2). Write the tile from shared memory to global memory
Here we continue to use epilogue subtiles, one reason is that it reduces the shared memory usage in the epilogue,
and another reason is that it can hide the STS latency, that is the STS of the next subtile can be overlapped with the TMA store of the current subtile.
and another reason is that it can hide the st.shared latency, that is the st.shared of the next subtile can be overlapped with the TMA store of the current subtile.
For large mma tile size, the mainloop performance between Non-WS and WS version could be similar if there are enough ab_stages to hide the dram latency.
The performance gain of WS version mainly comes from the prologue and epilogue in this case.
@@ -338,9 +338,17 @@ def kernel(
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()

View File

@@ -362,14 +362,23 @@ def kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()
@@ -699,8 +708,9 @@ def run_dense_gemm(
def make_tensors(mn, k, dtype):
shape = (mn, k)
return (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
# torch.empty(*shape, dtype=torch.int32)
# .random_(-2, 2)
torch.ones(*shape, dtype=torch.int32)
.to(device="cuda", dtype=dtype)
)

View File

@@ -447,14 +447,23 @@ def kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()

View File

@@ -471,14 +471,23 @@ def cluster_specific_kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()

View File

@@ -483,14 +483,24 @@ def kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()

View File

@@ -508,14 +508,23 @@ def gemm(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()