mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
[CLI] Fix tutorial issues
This commit is contained in:
@@ -235,9 +235,17 @@ def kernel(
|
||||
ab_full = ab_consumer.wait_and_advance()
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, ab_full.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, ab_full.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
ab_full.release()
|
||||
@@ -445,3 +453,4 @@ if __name__ == "__main__":
|
||||
|
||||
run_dense_gemm(args.mnk, args.tolerance)
|
||||
print("PASS")
|
||||
|
||||
|
||||
@@ -294,9 +294,17 @@ def kernel(
|
||||
if is_leader_cta:
|
||||
ab_full = ab_consumer.wait_and_advance()
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0)
|
||||
tile_crd = (None, None, None, ab_full.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, ab_full.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
ab_full.release()
|
||||
|
||||
# Signal that the accumulator is fully computed
|
||||
|
||||
@@ -49,7 +49,7 @@ To store the results from registers to global memory using TMA actually requires
|
||||
1). Write the tile from registers to shared memory
|
||||
2). Write the tile from shared memory to global memory
|
||||
Here we continue to use epilogue subtiles. One reason is that it reduces the shared memory usage in the epilogue,
|
||||
and another reason is that it can hide the STS latency, that is the STS of the next subtile can be overlapped with the TMA store of the current subtile.
|
||||
and another reason is that it can hide the st.shared latency, that is the st.shared of the next subtile can be overlapped with the TMA store of the current subtile.
|
||||
|
||||
For large MMA tile sizes, the mainloop performance of the Non-WS and WS versions could be similar if there are enough ab_stages to hide the DRAM latency.
|
||||
In this case, the performance gain of the WS version mainly comes from the prologue and epilogue.
|
||||
@@ -338,9 +338,17 @@ def kernel(
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
@@ -362,14 +362,23 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
@@ -699,8 +708,9 @@ def run_dense_gemm(
|
||||
def make_tensors(mn, k, dtype):
|
||||
shape = (mn, k)
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
# torch.empty(*shape, dtype=torch.int32)
|
||||
# .random_(-2, 2)
|
||||
torch.ones(*shape, dtype=torch.int32)
|
||||
.to(device="cuda", dtype=dtype)
|
||||
)
|
||||
|
||||
|
||||
@@ -447,14 +447,23 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
@@ -471,14 +471,23 @@ def cluster_specific_kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
@@ -483,14 +483,24 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
@@ -508,14 +508,23 @@ def gemm(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
Reference in New Issue
Block a user