From ecb32fe2314184cb3e3ed6d26d9524c69c0d1a21 Mon Sep 17 00:00:00 2001 From: Zheng Linfeng Date: Tue, 24 Mar 2026 00:12:01 -0700 Subject: [PATCH] [CLI] Fix tutorial issues --- .../blackwell/tutorial_gemm/fp16_gemm_0.py | 15 +++++++++++--- .../blackwell/tutorial_gemm/fp16_gemm_1.py | 14 ++++++++++--- .../blackwell/tutorial_gemm/fp16_gemm_2.py | 16 +++++++++++---- .../blackwell/tutorial_gemm/fp16_gemm_3.py | 20 ++++++++++++++----- .../blackwell/tutorial_gemm/fp16_gemm_3_1.py | 15 +++++++++++--- .../blackwell/tutorial_gemm/fp16_gemm_4.py | 15 +++++++++++--- .../blackwell/tutorial_gemm/fp16_gemm_5.py | 16 ++++++++++++--- .../blackwell/tutorial_gemm/fp16_gemm_6.py | 15 +++++++++++--- 8 files changed, 99 insertions(+), 27 deletions(-) diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py index b706a3d98..d953be1ed 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py @@ -235,9 +235,17 @@ def kernel( ab_full = ab_consumer.wait_and_advance() # tCtAcc += tCrA * tCrB - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, ab_full.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, ab_full.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load ab_full.release() @@ -445,3 +453,4 @@ if __name__ == "__main__": run_dense_gemm(args.mnk, args.tolerance) print("PASS") + diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py index 45f1b9a40..f4abf4f04 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py @@ -294,9 +294,17 @@ def kernel( if is_leader_cta: ab_full = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0) - tile_crd = (None, None, None, ab_full.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, ab_full.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) ab_full.release() # Signal that the accumulator is fully computed diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py index 4cc79b74e..d2e1ccff1 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py @@ -49,7 +49,7 @@ To store the results from registers to global memory using TMA actually requires 1). Write the tile from registers to shared memory 2). Write the tile from shared memory to global memory Here we continue to use epiolgue subtiles, one reason is that it reduces the shared memory usage in the epilogue, -and another reason is that it can hide the STS latency, that is the STS of the next subtile can be overlapped with the TMA store of the current subtile. +and another reason is that it can hide the st.shared latency, that is the st.shared of the next subtile can be overlapped with the TMA store of the current subtile. For large mma tile size, the mainloop performance between Non-WS and WS version could be similar if there are enough ab_stages to hide the dram latency. The performance gain of WS version mainly comes from the prologue and epilogue in this case. @@ -338,9 +338,17 @@ def kernel( handle = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, handle.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load handle.release() diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py index f51236b5e..9f831e961 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py @@ -362,14 +362,23 @@ def kernel( # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)] + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) for k_tile_idx in range(num_k_tiles): # Wait for TMA copies to complete handle = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, handle.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load handle.release() @@ -699,8 +708,9 @@ def run_dense_gemm( def make_tensors(mn, k, dtype): shape = (mn, k) return ( - torch.empty(*shape, dtype=torch.int32) - .random_(-2, 2) + # torch.empty(*shape, dtype=torch.int32) + # .random_(-2, 2) + torch.ones(*shape, dtype=torch.int32) .to(device="cuda", dtype=dtype) ) diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py index b18ab1792..553758cf7 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py @@ -447,14 +447,23 @@ def kernel( # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)] + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) for k_tile_idx in range(num_k_tiles): # Wait for TMA copies to complete handle = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, handle.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load handle.release() diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py index c8b999f68..b5ad6d3b1 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py @@ -471,14 +471,23 @@ def cluster_specific_kernel( # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)] + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) for k_tile_idx in range(num_k_tiles): # Wait for TMA copies to complete handle = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, handle.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load handle.release() diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py index cf185fef0..a483120af 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py @@ -483,14 +483,24 @@ def kernel( # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)] + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + for k_tile_idx in range(num_k_tiles): # Wait for TMA copies to complete handle = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, handle.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load handle.release() diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py index 2eae1f34c..371880ac6 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py +++ b/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py @@ -508,14 +508,23 @@ def gemm( # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)] + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) for k_tile_idx in range(num_k_tiles): # Wait for TMA copies to complete handle = ab_consumer.wait_and_advance() # Execute one K-block worth of MMA instructions - tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0) - tile_crd = (None, None, None, handle.index) - cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Signal that the A/B buffers have been consumed and are ready for the next load handle.release()