mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-19 22:38:56 +00:00
[CLI] add cutedsl fp16 gemm tutorial from 2 to 6 (#3106)
* [CLI] add fp16 gemm tutorial from 2 to 6 * [CLI] refine comments
This commit is contained in:
@@ -8,10 +8,13 @@
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
|
||||
import argparse
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Type, Callable
|
||||
from functools import partial, lru_cache
|
||||
|
||||
import cutlass
|
||||
from cutlass import Numeric
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
@@ -66,6 +69,7 @@ def kernel(
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
):
|
||||
|
||||
# Current thread/warp/block coordinates
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
@@ -139,15 +143,15 @@ def kernel(
|
||||
# (bM, bN)
|
||||
gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
|
||||
thr_mma = tiled_mma.get_slice(0)
|
||||
# (MMA, MMA_M, MMA_K)
|
||||
# (MMA, MMA_M, MMA_K, RestK)
|
||||
tCgA = thr_mma.partition_A(gA)
|
||||
# (MMA, MMA_N, MMA_K)
|
||||
# (MMA, MMA_N, MMA_K, RestK)
|
||||
tCgB = thr_mma.partition_B(gB)
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
# (MMA, MMA_M, MMA_K)
|
||||
# (MMA, MMA_M, MMA_K, STAGE)
|
||||
tCrA = tiled_mma.make_fragment_A(sA)
|
||||
# (MMA, MMA_N, MMA_K)
|
||||
# (MMA, MMA_N, MMA_K, STAGE)
|
||||
tCrB = tiled_mma.make_fragment_B(sB)
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
|
||||
@@ -195,14 +199,14 @@ def kernel(
|
||||
tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)
|
||||
|
||||
# (TmemCpy,NumTmemCpy,NumTiles)
|
||||
tDtC = tmem_thr_copy.partition_S(tCtAcc_epi)
|
||||
tCtC = tmem_thr_copy.partition_S(tCtAcc_epi)
|
||||
# (TmemCpy,NumTmemCpy,NumTiles)
|
||||
tDgC = tmem_thr_copy.partition_D(gC_epi)
|
||||
tCgC = tmem_thr_copy.partition_D(gC_epi)
|
||||
|
||||
# (TmemCpy,NumTmemCpy)
|
||||
tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype)
|
||||
tCrAcc = cute.make_rmem_tensor(tCgC[None, None, 0].shape, acc_dtype)
|
||||
# (TmemCpy,NumTmemCpy)
|
||||
tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype)
|
||||
tCrC = cute.make_rmem_tensor(tCgC[None, None, 0].shape, io_dtype)
|
||||
|
||||
#
|
||||
# 2. Main loop
|
||||
@@ -229,17 +233,11 @@ def kernel(
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
ab_full = ab_consumer.wait_and_advance()
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, ab_full.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
# tCtAcc += tCrA * tCrB
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, ab_full.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
ab_full.release()
|
||||
@@ -259,10 +257,10 @@ def kernel(
|
||||
|
||||
# TMEM -> RMEM -> GEMM
|
||||
# Sub-tiling for better instruction-level parallelism
|
||||
for i in cutlass.range(cute.size(tDtC, mode=[2])):
|
||||
cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)
|
||||
for i in cutlass.range(cute.size(tCtC, mode=[2])):
|
||||
cute.copy(tmem_tiled_copy, tCtC[None, None, i], tCrAcc)
|
||||
tCrC.store(tCrAcc.load().to(io_dtype))
|
||||
cute.autovec_copy(tCrC, tDgC[None, None, i])
|
||||
cute.autovec_copy(tCrC, tCgC[None, None, i])
|
||||
acc_full.release()
|
||||
|
||||
# Deallocate TMEM
|
||||
@@ -344,10 +342,44 @@ def host_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor):
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def prepare_run(
|
||||
callable: Callable,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_dtype: Type[Numeric],
|
||||
b_dtype: Type[Numeric],
|
||||
c_dtype: Type[Numeric],
|
||||
) -> tuple[Callable, tuple]:
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
a, b, c = cutlass_torch.prepare_tensors_for_gemm(
|
||||
(m, n, k), a_dtype, b_dtype, c_dtype
|
||||
)
|
||||
a_ = (
|
||||
from_dlpack(a, assumed_align=32)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=k)
|
||||
)
|
||||
b_ = (
|
||||
from_dlpack(b, assumed_align=32)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=k)
|
||||
)
|
||||
c_ = (
|
||||
from_dlpack(c, assumed_align=32)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=n)
|
||||
)
|
||||
compiled_fn = cute.compile(callable, a_, b_, c_, options="--generate-line-info")
|
||||
return partial(compiled_fn, a_, b_, c_), (a, b, c)
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
) -> None:
|
||||
global torch, cutlass_torch
|
||||
import torch
|
||||
import cutlass.torch as cutlass_torch
|
||||
@@ -362,48 +394,23 @@ def run_dense_gemm(
|
||||
m, n, k = mnk
|
||||
torch.manual_seed(1111)
|
||||
|
||||
# Make K-major tensors (torch tensors are row-major)
|
||||
def make_tensors(mn, k, dtype):
|
||||
shape = (mn, k)
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(dtype=dtype, device="cuda")
|
||||
)
|
||||
|
||||
a = make_tensors(m, k, cutlass_torch.dtype(io_dtype))
|
||||
b = make_tensors(n, k, cutlass_torch.dtype(io_dtype))
|
||||
c = make_tensors(m, n, cutlass_torch.dtype(io_dtype))
|
||||
a_tensor = (
|
||||
from_dlpack(a, assumed_align=32)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=k)
|
||||
run_fn, (a, b, c) = prepare_run(
|
||||
host_function, m, n, k, io_dtype, io_dtype, io_dtype
|
||||
)
|
||||
b_tensor = (
|
||||
from_dlpack(b, assumed_align=32)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=k)
|
||||
)
|
||||
c_tensor = (
|
||||
from_dlpack(c, assumed_align=32)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=n)
|
||||
)
|
||||
|
||||
# Entry point to the host JIT function
|
||||
host_function(a_tensor, b_tensor, c_tensor, no_cache=True)
|
||||
run_fn()
|
||||
|
||||
# Compute reference result and verify
|
||||
ref = (torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32))).cpu()
|
||||
ref = torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32))
|
||||
|
||||
torch.testing.assert_close(
|
||||
c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05
|
||||
c, ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
def parse_comma_separated_ints(s: str) -> list[int]:
|
||||
try:
|
||||
return [int(x.strip()) for x in s.split(",")]
|
||||
except ValueError:
|
||||
@@ -428,14 +435,13 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if len(args.mnk) != 3:
|
||||
parser.error("--mnk must contain exactly 3 values")
|
||||
if args.mnk[0] % mma_tiler_mnk[0] != 0 or args.mnk[1] % mma_tiler_mnk[1] != 0:
|
||||
parser.error("m n must be divisible by mma_tiler_mn")
|
||||
|
||||
run_dense_gemm(
|
||||
args.mnk,
|
||||
args.tolerance,
|
||||
)
|
||||
run_dense_gemm(args.mnk, args.tolerance)
|
||||
print("PASS")
|
||||
|
||||
@@ -65,7 +65,8 @@ Constraints for this example:
|
||||
|
||||
io_dtype = cutlass.Float16
|
||||
acc_dtype = cutlass.Float32
|
||||
cluster_shape_mnk = (2, 1, 1)
|
||||
use_2cta_instrs = True
|
||||
cluster_shape_mnk = (2, 1, 1) if use_2cta_instrs else (1, 1, 1)
|
||||
mma_inst_shape_mnk = (256, 256, 16)
|
||||
mma_tiler_mnk = (256, 256, 64)
|
||||
threads_per_cta = 128
|
||||
@@ -79,7 +80,7 @@ acc_stage = 1
|
||||
class SharedStorage:
|
||||
ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
|
||||
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stage * 2]
|
||||
tmem_dealloc_mbar_ptr: cutlass.Int64
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buf: cutlass.Int32
|
||||
|
||||
|
||||
@@ -95,6 +96,7 @@ def kernel(
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
):
|
||||
|
||||
# Current thread/warp/block coordinates
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
@@ -172,15 +174,15 @@ def kernel(
|
||||
# (bM, bN)
|
||||
gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
|
||||
thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0])
|
||||
# (MMA, MMA_M, MMA_K)
|
||||
# (MMA, MMA_M, MMA_K, RestK)
|
||||
tCgA = thr_mma.partition_A(gA)
|
||||
# (MMA, MMA_N, MMA_K)
|
||||
# (MMA, MMA_N, MMA_K, RestK)
|
||||
tCgB = thr_mma.partition_B(gB)
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
# (MMA, MMA_M, MMA_K)
|
||||
# (MMA, MMA_M, MMA_K, STAGE)
|
||||
tCrA = tiled_mma.make_fragment_A(sA)
|
||||
# (MMA, MMA_N, MMA_K)
|
||||
# (MMA, MMA_N, MMA_K, STAGE)
|
||||
tCrB = tiled_mma.make_fragment_B(sB)
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
|
||||
@@ -218,7 +220,7 @@ def kernel(
|
||||
storage.tmem_holding_buf,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
is_two_cta=cute.size(cta_layout_vmnk, mode=[0]) > 1,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
)
|
||||
num_tmem_cols = 512
|
||||
tmem.allocate(num_tmem_cols)
|
||||
@@ -230,7 +232,7 @@ def kernel(
|
||||
# Swap the pointer in tCtAcc
|
||||
tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)
|
||||
|
||||
subtile_cnt = 4
|
||||
subtile_cnt = 1 if mma_tiler_mnk[0] == 64 else 4
|
||||
# (EpiTile)
|
||||
epi_tiler = (
|
||||
(cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt),
|
||||
@@ -242,21 +244,24 @@ def kernel(
|
||||
|
||||
# Every thread loads 64 x fp32
|
||||
tmem_atom = cute.make_copy_atom(
|
||||
tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),
|
||||
tcgen05.Ld16x256bOp(tcgen05.Repetition.x8)
|
||||
if mma_tiler_mnk[0] == 64
|
||||
else tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),
|
||||
cutlass.Float32,
|
||||
)
|
||||
|
||||
tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0])
|
||||
tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)
|
||||
|
||||
# (TmemCpy,NumTmemCpy,NumTiles)
|
||||
tDtC = tmem_thr_copy.partition_S(tCtAcc_epi)
|
||||
tCtC = tmem_thr_copy.partition_S(tCtAcc_epi)
|
||||
# (TmemCpy,NumTmemCpy,NumTiles)
|
||||
tDgC = tmem_thr_copy.partition_D(gC_epi)
|
||||
tCgC = tmem_thr_copy.partition_D(gC_epi)
|
||||
|
||||
# (TmemCpy,NumTmemCpy)
|
||||
tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype)
|
||||
tCrAcc = cute.make_rmem_tensor(tCgC[None, None, 0].shape, acc_dtype)
|
||||
# (TmemCpy,NumTmemCpy)
|
||||
tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype)
|
||||
tCrC = cute.make_rmem_tensor(tCgC[None, None, 0].shape, io_dtype)
|
||||
|
||||
#
|
||||
# 2. Main loop
|
||||
@@ -266,8 +271,8 @@ def kernel(
|
||||
if warp_idx == 0:
|
||||
# Wait for an empty accumulator buffer
|
||||
if is_leader_cta:
|
||||
acc_producer.acquire_and_advance()
|
||||
for _ in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2):
|
||||
acc_producer.acquire()
|
||||
for k_tile in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2):
|
||||
# Issue TMA loads
|
||||
ab_empty = ab_producer.acquire_and_advance()
|
||||
cute.copy(
|
||||
@@ -289,22 +294,15 @@ def kernel(
|
||||
if is_leader_cta:
|
||||
ab_full = ab_consumer.wait_and_advance()
|
||||
# Execute one K-block worth of MMA instructions
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, ab_full.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0)
|
||||
tile_crd = (None, None, None, ab_full.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
ab_full.release()
|
||||
|
||||
# Signal that the accumulator is fully computed
|
||||
if is_leader_cta:
|
||||
acc_producer.commit()
|
||||
acc_producer.advance()
|
||||
|
||||
#
|
||||
# 3. Epilogue
|
||||
@@ -315,12 +313,13 @@ def kernel(
|
||||
|
||||
# Wait for the accumulator buffer to be full
|
||||
acc_full = acc_consumer.wait_and_advance()
|
||||
|
||||
# TMEM -> RMEM -> GEMM
|
||||
# Sub-tiling for better instruction-level parallelism
|
||||
for i in cutlass.range(cute.size(tDtC, mode=[2])):
|
||||
cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)
|
||||
for i in cutlass.range(cute.size(tCtC, mode=[2])):
|
||||
cute.copy(tmem_tiled_copy, tCtC[None, None, i], tCrAcc)
|
||||
tCrC.store(tCrAcc.load().to(io_dtype))
|
||||
cute.autovec_copy(tCrC, tDgC[None, None, i])
|
||||
cute.autovec_copy(tCrC, tCgC[None, None, i])
|
||||
acc_full.release()
|
||||
|
||||
# Ensure used buffers are properly synchronized before producer exit.
|
||||
@@ -346,7 +345,7 @@ def host_function(
|
||||
io_dtype,
|
||||
acc_dtype,
|
||||
mma_inst_shape_mnk,
|
||||
tcgen05.CtaGroup.TWO,
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
|
||||
tcgen05.OperandSource.SMEM,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
@@ -374,14 +373,16 @@ def host_function(
|
||||
cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))
|
||||
|
||||
# Construct TMA load atoms
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
)
|
||||
a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_one_stage,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape, # take the layout and extract the shape internally
|
||||
cta_layout_vmnk.shape,
|
||||
)
|
||||
b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
@@ -394,7 +395,8 @@ def host_function(
|
||||
|
||||
grid_shape = cute.round_up(
|
||||
cute.ceil_div(
|
||||
(*c.layout.shape, 1), (mma_tiler_mnk[0] // 2, *mma_tiler_mnk[1:])
|
||||
(*c.layout.shape, 1),
|
||||
(mma_tiler_mnk[0] // (2 if use_2cta_instrs else 1), *mma_tiler_mnk[1:]),
|
||||
),
|
||||
cluster_shape_mnk,
|
||||
)
|
||||
|
||||
679
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py
Normal file
679
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
# This is the third tutorial GEMM. It further enhances the second tutorial by adding warp
|
||||
# specialization for TMA, MMA, and epilogue warps.
|
||||
|
||||
|
||||
import argparse
|
||||
from typing import Tuple
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
"""
|
||||
The third tutorial GEMM demonstrates a simple kernel implementation in CuTeDSL.
|
||||
|
||||
It further enhances fp16_gemm_1.py by adding warp specialization for TMA, MMA, and epilogue warps.
|
||||
In the epilogue warp, we use TMA store instead of regular copy to store the result from registers to global memory.
|
||||
|
||||
This example can achieve better performance than fp16_gemm_1.py due to:
|
||||
1. We use warp specialization(WS) to overlap the memory loads and MMA computations.
|
||||
|
||||
Core concept of WS is to specialize warps with different tasks (e.g., DMA, MMA, epilogue),
|
||||
therefore, different warps in a CTA must communicate with each other.
|
||||
|
||||
Warp specialization's benefit comes from task parallelism between warps in the CTA.
|
||||
For example, the DMA warps proceed to start loading A/B tensors for the next K-block as soon as they finish loading the current K-block.
|
||||
While the MMA warps are computing the result of the current K-block. So the dram latency is hidden.
|
||||
The DRAM latency can also be hidden by prefetch in the non-WS version,
|
||||
but WS version has better instruction level parallelism as different types of instructions are issued in different warps.
|
||||
For example, in the non-WS version, tmem allocation and TMA loads are both issued in the same warp, so TMA loads can only issue after tmem allocation is finished.
|
||||
But in WS version, tmem allocation and TMA loads are issued in different warps, tmem allocation can be overlapped with TMA loads.
|
||||
|
||||
2. We use TMA store instead of regular copy to store the results from registers to global memory.
|
||||
|
||||
To store the results from registers to global memory using TMA actually requires two steps:
|
||||
1). Write the tile from registers to shared memory
|
||||
2). Write the tile from shared memory to global memory
|
||||
Here we continue to use epilogue subtiles; one reason is that it reduces the shared memory usage in the epilogue,
|
||||
and another reason is that it can hide the STS latency, that is the STS of the next subtile can be overlapped with the TMA store of the current subtile.
|
||||
|
||||
For large mma tile size, the mainloop performance between Non-WS and WS version could be similar if there are enough ab_stages to hide the dram latency.
|
||||
The performance gain of WS version mainly comes from the prologue and epilogue in this case.
|
||||
That means, if k-dimension is small, then the performance of WS version will be obviously better than non-WS version.
|
||||
For small mma tile size, we may also see better mainloop performance for WS version.
|
||||
This is because there are ALU instructions (preparation work for MMA) for each MMA instruction, and ALU proportion is higher for small mma tile size.
|
||||
In Non-WS version, warp 0 will issue the ALU operations for both TMA and MMA instruction, while in WS version, they are issued in different warps,
|
||||
so less ALU instructions are issued in MMA warp, and mma instructions can be issued more efficiently.
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_2.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
* The problem size of m and n must be divisible by the tile size m & n (256, 256)
|
||||
"""
|
||||
|
||||
io_dtype = cutlass.Float16
|
||||
acc_dtype = cutlass.Float32
|
||||
use_2cta_instrs = True
|
||||
cluster_shape_mnk = (2, 1, 1) if use_2cta_instrs else (1, 1, 1)
|
||||
mma_inst_shape_mnk = (256, 256, 16)
|
||||
mma_tiler_mnk = (256, 256, 64)
|
||||
threads_in_epilogue = 128 # epilogue threads per cta
|
||||
|
||||
# Pipeline stage configuration
|
||||
ab_stages = 6
|
||||
epi_stages = 2
|
||||
acc_stages = 1
|
||||
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
|
||||
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stages * 2]
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buffer: cutlass.Int32
|
||||
|
||||
|
||||
@cute.kernel()
|
||||
def kernel(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a: cute.CopyAtom,
|
||||
mA_mkl: cute.Tensor,
|
||||
tma_atom_b: cute.CopyAtom,
|
||||
mB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
):
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, bidy, _ = cute.arch.block_idx()
|
||||
|
||||
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
||||
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
||||
|
||||
mma_coord_vmnk = (
|
||||
bidx % cute.size(cta_layout_vmnk, mode=[0]),
|
||||
bidx // cute.size(cta_layout_vmnk, mode=[0]),
|
||||
bidy,
|
||||
None,
|
||||
)
|
||||
mma_coord_mnk = mma_coord_vmnk[1:]
|
||||
is_leader_cta = mma_coord_vmnk[0] == 0
|
||||
|
||||
epilogue_warp_ids = (
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
)
|
||||
mma_warp_id = 4
|
||||
tma_warp_id = 5
|
||||
|
||||
#
|
||||
# 1. Prepare args
|
||||
#
|
||||
|
||||
# Allocate SMEM
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
storage = smem.allocate(SharedStorage)
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=a_smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=b_smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout.inner,
|
||||
)
|
||||
sC = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=epi_smem_layout_staged.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=epi_smem_layout_staged.inner,
|
||||
)
|
||||
|
||||
# Prefetch tma descriptor
|
||||
if warp_idx == tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
|
||||
# As many participants as the number of threads issuing the MMA in the same row and column
|
||||
# Subtract one to avoid counting the same thread twice
|
||||
num_mcast_participants = (
|
||||
cute.size(cta_layout_vmnk, mode=[1]) + cute.size(cta_layout_vmnk, mode=[2]) - 1
|
||||
)
|
||||
|
||||
# Mcast mask initialization
|
||||
tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
|
||||
)
|
||||
tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
|
||||
)
|
||||
|
||||
# Partition tensors for MMA and make fragments
|
||||
# (bM, bK, RestK)
|
||||
gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))
|
||||
# (bN, bK, RestK)
|
||||
gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))
|
||||
# (bM, bN)
|
||||
gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
|
||||
|
||||
thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0])
|
||||
# (MMA, MMA_M, MMA_K, RestK)
|
||||
tCgA = thr_mma.partition_A(gA)
|
||||
# (MMA, MMA_N, MMA_K, RestK)
|
||||
tCgB = thr_mma.partition_B(gB)
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
|
||||
# (MMA, MMA_M, MMA_K, STAGE)
|
||||
tCrA = tiled_mma.make_fragment_A(sA)
|
||||
# (MMA, MMA_N, MMA_K, STAGE)
|
||||
tCrB = tiled_mma.make_fragment_B(sB)
|
||||
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape)
|
||||
|
||||
# Barrier 1 for epilogue synchronization
|
||||
epilogue_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=1,
|
||||
num_threads=threads_in_epilogue,
|
||||
)
|
||||
|
||||
# Only MMA warp and epilogue warps participate in TMEM allocation synchronization
|
||||
# TMA warp does NOT participate
|
||||
tmem_alloc_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=2,
|
||||
num_threads=32
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True if use_2cta_instrs else False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
)
|
||||
|
||||
# Partition tensors for TMA; This requires the tensors partitioned for MMA
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
cta_in_cluster_coord_vmnk[2],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
|
||||
cute.group_modes(sA, 0, 3),
|
||||
cute.group_modes(tCgA, 0, 3),
|
||||
)
|
||||
|
||||
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
cta_in_cluster_coord_vmnk[1],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
|
||||
cute.group_modes(sB, 0, 3),
|
||||
cute.group_modes(tCgB, 0, 3),
|
||||
)
|
||||
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
|
||||
tCgC_epi = cute.flat_divide(tCgC[((None, None), 0, 0)], epi_tile)
|
||||
|
||||
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
|
||||
) * cute.size(cta_layout_vmnk, mode=[0])
|
||||
|
||||
# Threads/warps participating in the mainloop pipeline
|
||||
mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, size=num_mcast_participants
|
||||
)
|
||||
|
||||
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.ab_mbar_ptr.data_ptr(),
|
||||
num_stages=ab_stages,
|
||||
producer_group=mainloop_pipeline_producer_group,
|
||||
consumer_group=mainloop_pipeline_consumer_group,
|
||||
tx_count=num_tma_copy_bytes,
|
||||
cta_layout_vmnk=cta_layout_vmnk,
|
||||
).make_participants()
|
||||
|
||||
# Threads/warps participating in the accumulator pipeline
|
||||
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=cute.size(cta_layout_vmnk, mode=[0]) * len(epilogue_warp_ids),
|
||||
)
|
||||
|
||||
acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
|
||||
num_stages=acc_stages,
|
||||
producer_group=acc_pipeline_producer_group,
|
||||
consumer_group=acc_pipeline_consumer_group,
|
||||
cta_layout_vmnk=cta_layout_vmnk,
|
||||
).make_participants()
|
||||
|
||||
#
|
||||
# Main loop
|
||||
#
|
||||
|
||||
num_k_tiles = cute.size(gA, mode=[2])
|
||||
|
||||
# TMA warp
|
||||
if warp_idx == tma_warp_id:
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for A/B buffers to be empty before loading into them
|
||||
handle = ab_producer.acquire_and_advance()
|
||||
|
||||
# Issue TMA loads
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tAgA[(None, k_tile_idx)],
|
||||
tAsA[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_a,
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tBgB[(None, k_tile_idx)],
|
||||
tBsB[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_b,
|
||||
)
|
||||
|
||||
# This mbarrier_wait is preventing threadblocks within a set of dependent threadblocks within the cluster
|
||||
# (dependent in the context of the TMA/MMA synchronization pattern) to exit early making
|
||||
# a late tcgen05 commit_arrive illegal
|
||||
ab_producer.tail()
|
||||
|
||||
# MMA warp
|
||||
elif warp_idx == mma_warp_id:
|
||||
# Wait for TMEM allocation and retrieve pointer
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(acc_dtype)
|
||||
tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Wait for an empty accumulator buffer
|
||||
if is_leader_cta:
|
||||
acc_empty = acc_producer.acquire_and_advance()
|
||||
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
# Signal that the accumulator is fully computed
|
||||
acc_empty.commit()
|
||||
|
||||
# Epilogue warps
|
||||
elif warp_idx < mma_warp_id:
|
||||
# Allocate TMEM (only epilogue warp 0 actually allocates)
|
||||
num_tmem_cols = 512
|
||||
tmem.allocate(num_tmem_cols)
|
||||
|
||||
# Wait for TMEM allocation and retrieve pointer
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(acc_dtype)
|
||||
tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Initialize TMA store pipeline for epilogue
|
||||
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=128,
|
||||
)
|
||||
epilogue_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=epi_stages,
|
||||
producer_group=epilogue_pipeline_producer_group,
|
||||
)
|
||||
|
||||
# Wait for the accumulator buffer to be full
|
||||
acc_consumer.wait_and_advance()
|
||||
|
||||
copy_atom_t2r = cute.make_copy_atom(
|
||||
tcgen05.Ld16x256bOp(tcgen05.Repetition.x8)
|
||||
if mma_tiler_mnk[0] == 64
|
||||
else tcgen05.Ld32x32bOp(tcgen05.Repetition.x32),
|
||||
cutlass.Float32,
|
||||
)
|
||||
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
|
||||
tCtAcc_epi = cute.flat_divide(
|
||||
tCtAcc[((None, None), 0, 0)],
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
# Tiled copy for TMEM -> RMEM load
|
||||
tiled_copy_t2r = tcgen05.make_tmem_copy(
|
||||
copy_atom_t2r, tCtAcc_epi[(None, None, 0, 0)]
|
||||
)
|
||||
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
||||
tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc_epi)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
||||
tTR_gC = thr_copy_t2r.partition_D(tCgC_epi)
|
||||
# (T2R, T2R_M, T2R_N)
|
||||
tTR_rAcc = cute.make_rmem_tensor(
|
||||
tTR_gC[(None, None, None, 0, 0)].shape, cutlass.Float32
|
||||
)
|
||||
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
|
||||
# Copy atom and tiled copy for RMEM -> SMEM load
|
||||
copy_atom_r2s = cutlass.utils.blackwell_helpers.get_smem_store_op(
|
||||
c_smem_layout_kind, cutlass.Float32, cutlass.Float32, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
|
||||
tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
|
||||
tRS_rC = cute.make_rmem_tensor(tRS_rAcc.shape, io_dtype)
|
||||
|
||||
tCgC_grouped = cute.group_modes(tCgC_tma, 1, cute.rank(tCgC_tma))
|
||||
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
|
||||
# Epilogue tiling loop
|
||||
for subtile_idx in cutlass.range(subtile_cnt):
|
||||
# TMEM -> RMEM
|
||||
tTR_tAcc_slice = tTR_tAcc[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_slice, tTR_rAcc)
|
||||
|
||||
# RMEM -> SMEM
|
||||
c_buffer = subtile_idx % epi_stages
|
||||
tRS_sC_slice = tRS_sC[(None, None, None, c_buffer)]
|
||||
|
||||
# type conversion
|
||||
tRS_rC.store(tRS_rAcc.load().to(io_dtype))
|
||||
|
||||
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC_slice)
|
||||
|
||||
# Memory fence and barrier to ensure shared memory stores are visible to TMA stores
|
||||
cute.arch.fence_view_async_shared()
|
||||
epilogue_sync_barrier.arrive_and_wait()
|
||||
# SMEM -> GMEM
|
||||
if warp_idx == epilogue_warp_ids[0]:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
tCsC[(None, c_buffer)],
|
||||
tCgC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
|
||||
epilogue_pipeline.producer_commit()
|
||||
epilogue_pipeline.producer_acquire()
|
||||
epilogue_sync_barrier.arrive_and_wait()
|
||||
|
||||
epilogue_pipeline.producer_tail()
|
||||
|
||||
# Dealloc the tensor memory buffer
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
@cute.jit
def host_function(
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
):
    """Host-side JIT entry point: build MMA/TMA atoms and SMEM layouts, then launch the kernel.

    Args:
        a: (M, K) K-major input tensor.
        b: (N, K) K-major input tensor.
        c: (M, N) output tensor.
    """
    #
    # Construct tiled MMA
    #

    op = tcgen05.MmaF16BF16Op(
        io_dtype,
        acc_dtype,
        mma_inst_shape_mnk,
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
        tcgen05.OperandSource.SMEM,
        tcgen05.OperandMajorMode.K,
        tcgen05.OperandMajorMode.K,
    )
    tiled_mma = cute.make_tiled_mma(op)

    #
    # Construct SMEM layouts for A and B
    #

    a_smem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        a.element_type,
        ab_stages,
    )
    b_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        b.element_type,
        ab_stages,
    )

    # c_smem_layout_kind is an enum for row/column major, not a CuTe layout
    c_smem_layout_kind = utils.LayoutEnum.from_tensor(c)

    #
    # Construct the VMNK layout
    #

    cta_layout_mnk = cute.make_layout(cluster_shape_mnk)
    cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))

    #
    # Construct TMA load atoms
    #

    op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
    )
    # Drop the stage mode: the TMA atom describes a single-stage transfer.
    a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
    a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
        op,
        a,
        a_smem_layout_slice,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,  # take the layout and extract the shape internally
    )
    b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
    b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
        op,
        b,
        b_smem_layout_slice,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )

    # Per-CTA tile: the M extent of the MMA tiler is split across the CTAs of the instruction.
    cta_tile_shape_mnk = (
        mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id),
        mma_tiler_mnk[1],
        mma_tiler_mnk[2],
    )

    epi_tile = utils.compute_epilogue_tile_shape(
        cta_tile_shape_mnk,
        use_2cta_instrs,
        c_smem_layout_kind,
        io_dtype,
    )

    epi_smem_layout_staged = cutlass.utils.blackwell_helpers.make_smem_layout_epi(
        io_dtype,
        c_smem_layout_kind,
        epi_tile,
        epi_stages,
    )
    epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))

    c_tma_atom, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
        c,
        epi_smem_layout,
        epi_tile,
    )

    #
    # Launch the kernel
    #

    # One CTA per output tile, rounded up to a whole number of clusters.
    grid_shape = cute.round_up(
        (
            cute.ceil_div(
                c.layout.shape[0], mma_tiler_mnk[0] // (2 if use_2cta_instrs else 1)
            ),
            cute.ceil_div(c.layout.shape[1], mma_tiler_mnk[1]),
            1,
        ),
        cluster_shape_mnk,
    )

    kernel(
        tiled_mma,
        a_tma_atom,
        a_tma_tensor,
        b_tma_atom,
        b_tma_tensor,
        c_tma_atom,
        c_tma_tensor,
        a_smem_layout,
        b_smem_layout,
        c_smem_layout_kind,
        epi_smem_layout_staged,
        epi_tile,
        cta_layout_vmnk,
    ).launch(
        grid=grid_shape,
        block=[192, 1, 1],
        cluster=cluster_shape_mnk,
    )
|
||||
|
||||
|
||||
def run_dense_gemm(
    mnk: Tuple[int, int, int],
    tolerance: float,
):
    """Build random fp16 operands on the GPU, run the GEMM, and verify against torch.einsum.

    Args:
        mnk: Problem size (M, N, K).
        tolerance: Absolute tolerance for the result comparison.
    """
    # torch is imported lazily so the module can be inspected without it installed.
    global torch, cutlass_torch
    import torch
    import cutlass.torch as cutlass_torch

    print("===================================================================")
    print("Running Blackwell fp16 GEMM example 2 with:")
    print(f" mnk: {mnk}")
    print(f" tolerance: {tolerance}")
    print("===================================================================")
    print()

    m, n, k = mnk
    torch.manual_seed(1111)

    # Make K-major tensors (torch tensors are row-major)
    def make_tensors(mn, k, dtype):
        # Small integer values keep the fp16 computation exact enough to compare.
        shape = (mn, k)
        return (
            torch.empty(*shape, dtype=torch.int32)
            .random_(-2, 2)
            .to(device="cuda", dtype=dtype)
        )

    a = make_tensors(m, k, cutlass_torch.dtype(io_dtype))
    b = make_tensors(n, k, cutlass_torch.dtype(io_dtype))
    c = make_tensors(m, n, cutlass_torch.dtype(io_dtype))
    a_memref = from_dlpack(a).mark_layout_dynamic()
    b_memref = from_dlpack(b).mark_layout_dynamic()
    c_memref = from_dlpack(c).mark_layout_dynamic()

    # Entry point to the host JIT function
    host_function(
        a_memref,
        b_memref,
        c_memref,
        no_cache=True,
    )

    # Compute reference result and verify
    # Both operands are K-major, hence "mk,nk->mn".
    ref = (torch.einsum("mk,nk->mn", a, b)).cpu()
    torch.testing.assert_close(
        c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> list[int]:
        """argparse `type=` hook: parse e.g. "8192,8192,8192" into [8192, 8192, 8192]."""
        try:
            return [int(x.strip()) for x in s.split(",")]
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    from cuda.bindings import driver as cu_driver

    # Fail fast with a clear message when no CUDA device is present.
    cu_driver.cuInit(0)
    err, device_count = cu_driver.cuDeviceGetCount()
    if err != cu_driver.CUresult.CUDA_SUCCESS or device_count < 1:
        raise RuntimeError("A GPU is required to run this example")

    parser = argparse.ArgumentParser(description="Blackwell fp16 GEMM example 2")
    parser.add_argument(
        "--mnk",
        type=parse_comma_separated_ints,
        default=(8192, 8192, 8192),
        help="MNK dimensions (comma-separated)",
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    args = parser.parse_args()
    if len(args.mnk) != 3:
        parser.error("--mnk must contain exactly 3 values")

    run_dense_gemm(
        args.mnk,
        args.tolerance,
    )
    print("PASS")
|
||||
769
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py
Normal file
769
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py
Normal file
@@ -0,0 +1,769 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
# This is the third tutorial GEMM. It further enhances the second tutorial by adding a
# static persistent tile scheduler, so CTAs stay resident and each processes multiple output tiles.
|
||||
|
||||
|
||||
import argparse
|
||||
from typing import Tuple
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
|
||||
"""
|
||||
The third tutorial GEMM demonstrates a persistent, warp-specialized GEMM kernel implementation in CuTeDSL.
|
||||
|
||||
Compared to fp16_gemm_2.py, this kernel uses a static persistent tile scheduler (StaticPersistentTileScheduler).
|
||||
The static scheduler simplifies work distribution by assigning tiles to CTAs in a fixed, deterministic order,
|
||||
suitable for well-partitioned workloads. With static scheduling,
|
||||
the persistent clusters can stay on the GPU throughout kernel execution and process multiple tiles, hiding prologue and epilogue costs.
|
||||
Note that the static scheduler is susceptible to workload imbalance if the resources of some SMs are unavailable,
|
||||
which is why we add a dynamic scheduler in the next example (fp16_gemm_3_1.py).
|
||||
|
||||
Therefore, for larger problem sizes the performance will be more advantageous, since a larger problem size leads to more tiles,
|
||||
and the prologue/epilogue can be hidden among different tiles.
|
||||
This is especially true when the main loop is relatively short and the prologue/epilogue accounts for a large proportion of the work,
|
||||
in which case the performance gains become even more significant.
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
    python examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
* The problem size of m and n must be divisible by the tile size m & n (256, 256)
|
||||
"""
|
||||
|
||||
# I/O operands are fp16; accumulation is done in fp32.
io_dtype = cutlass.Float16
acc_dtype = cutlass.Float32
# 2-CTA tcgen05 MMA instructions pair two CTAs along M, hence the (2, 1, 1) cluster.
use_2cta_instrs = True
cluster_shape_mnk = (2, 1, 1) if use_2cta_instrs else (1, 1, 1)
# Single MMA instruction shape and the MMA tile it is repeated over (K: 64 = 4 x 16).
mma_inst_shape_mnk = (256, 256, 16)
mma_tiler_mnk = (256, 256, 64)
threads_in_epilogue = 128  # epilogue threads per cta

# Pipeline stage configuration
ab_stages = 6  # SMEM staging buffers for A/B mainloop tiles
epi_stages = 2  # SMEM staging buffers for the C epilogue
acc_stages = 2  # TMEM accumulator buffers

# Scheduler
scheduler_type = utils.StaticPersistentTileScheduler
|
||||
|
||||
|
||||
@cute.struct
class SharedStorage:
    # Mbarrier storage: ab_stages * 2 and acc_stages * 2 Int64 mbarriers
    # (presumably full/empty pairs per stage — matches the pipeline stage counts).
    ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
    acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stages * 2]
    # Mbarrier used by the TmemAllocator to coordinate TMEM deallocation across the 2-CTA pair.
    tmem_dealloc_mbar: cutlass.Int64
    # Scratch slot through which the allocator warp publishes the TMEM base address.
    tmem_holding_buffer: cutlass.Int32
|
||||
|
||||
|
||||
@cute.kernel()
def kernel(
    tiled_mma: cute.TiledMma,
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_c: cute.CopyAtom,
    mC_mnl: cute.Tensor,
    a_smem_layout: cute.ComposedLayout,
    b_smem_layout: cute.ComposedLayout,
    c_smem_layout_kind: cutlass.Constexpr,
    epi_smem_layout_staged: cute.ComposedLayout,
    epi_tile: cute.Tile,
    cta_layout_vmnk: cute.Layout,
    tile_sched_params: utils.PersistentTileSchedulerParams,
):
    """Warp-specialized, persistent GEMM kernel.

    Warp roles within each 192-thread CTA:
      * warps 0-3 (epilogue): allocate TMEM, copy accumulators TMEM -> RMEM,
        convert to io_dtype, stage through SMEM, and TMA-store to GMEM;
      * warp 4 (MMA): issues tcgen05 MMA instructions (leader CTA only);
      * warp 5 (TMA): issues multicast TMA loads of A/B tiles into SMEM.
    Each warp loops over output tiles handed out by the persistent tile
    scheduler until no valid tile remains.
    """
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)

    # Position of this CTA within the (possibly 2-CTA) MMA instruction.
    mma_tile_coord_v = bidx % cute.size(cta_layout_vmnk, mode=[0])
    is_leader_cta = mma_tile_coord_v == 0

    # Warp specialization: 4 epilogue warps, 1 MMA warp, 1 TMA warp.
    epilogue_warp_ids = (
        0,
        1,
        2,
        3,
    )
    mma_warp_id = 4
    tma_warp_id = 5

    epilog_sync_bar_id = 1
    tmem_alloc_sync_bar_id = 2

    # Prefetch tma descriptor
    if warp_idx == tma_warp_id:
        cpasync.prefetch_descriptor(tma_atom_a)
        cpasync.prefetch_descriptor(tma_atom_b)
        cpasync.prefetch_descriptor(tma_atom_c)

    # As many participants as the number of threads issuing the MMA in the same row and column
    # Subtract one to not count twice the same thread
    num_mcast_participants = (
        cute.size(cta_layout_vmnk, mode=[1]) + cute.size(cta_layout_vmnk, mode=[2]) - 1
    )

    # Mcast mask initialization: A multicasts along N-mode peers, B along M-mode peers.
    tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
        cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
    )
    tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
        cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
    )

    # Allocate SMEM
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)

    # Barrier 1 for epilogue synchronization
    epilogue_sync_barrier = pipeline.NamedBarrier(
        barrier_id=epilog_sync_bar_id,
        num_threads=threads_in_epilogue,
    )

    # Only MMA warp and epilogue warps participate in TMEM allocation synchronization
    # TMA warp does NOT participate
    tmem_alloc_barrier = pipeline.NamedBarrier(
        barrier_id=tmem_alloc_sync_bar_id,
        num_threads=32
        * len((mma_warp_id, *epilogue_warp_ids)),  # 5 warps = 160 threads
    )
    tmem = utils.TmemAllocator(
        storage.tmem_holding_buffer,
        barrier_for_retrieve=tmem_alloc_barrier,
        allocator_warp_id=epilogue_warp_ids[0],
        is_two_cta=True,
        two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
    )

    # Bytes moved per mainloop stage (one A tile + one B tile), used as the TMA tx-count.
    num_tma_copy_bytes = (
        cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
        + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
    ) * cute.size(cta_layout_vmnk, mode=[0])

    # Threads/warps participating in the mainloop pipeline
    mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, size=num_mcast_participants
    )

    # A/B pipeline: TMA warp produces SMEM tiles, MMA warp consumes them.
    ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
        barrier_storage=storage.ab_mbar_ptr.data_ptr(),
        num_stages=ab_stages,
        producer_group=mainloop_pipeline_producer_group,
        consumer_group=mainloop_pipeline_consumer_group,
        tx_count=num_tma_copy_bytes,
        cta_layout_vmnk=cta_layout_vmnk,
    ).make_participants()

    # Threads/warps participating in the accumulator pipeline
    acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    acc_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        size=cute.size(cta_layout_vmnk, mode=[0]) * len(epilogue_warp_ids),
    )

    # Accumulator pipeline: MMA warp produces TMEM accumulators, epilogue warps consume.
    acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
        barrier_storage=storage.acc_mbar_ptr.data_ptr(),
        num_stages=acc_stages,
        producer_group=acc_pipeline_producer_group,
        consumer_group=acc_pipeline_consumer_group,
        cta_layout_vmnk=cta_layout_vmnk,
    ).make_participants()

    # Cluster arrive after barrier init
    pipeline_init_arrive(cluster_shape_mn=cluster_shape_mnk, is_relaxed=True)

    # Allocate SMEM
    sA = smem.allocate_tensor(
        element_type=io_dtype,
        layout=a_smem_layout.outer,
        byte_alignment=128,
        swizzle=a_smem_layout.inner,
    )
    sB = smem.allocate_tensor(
        element_type=io_dtype,
        layout=b_smem_layout.outer,
        byte_alignment=128,
        swizzle=b_smem_layout.inner,
    )
    sC = smem.allocate_tensor(
        element_type=io_dtype,
        layout=epi_smem_layout_staged.outer,
        byte_alignment=128,
        swizzle=epi_smem_layout_staged.inner,
    )

    # Partition tensors for MMA and make fragments
    # (bM, bK, RestM, RestK)
    gA = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None)
    )
    # (bN, bK, RestN, RestK)
    gB = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None)
    )
    # (bM, bN, RestM, RestN)
    gC = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None)
    )

    thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
    # (MMA, MMA_M, MMA_K, RestM, RestK)
    tCgA = thr_mma.partition_A(gA)
    # (MMA, MMA_N, MMA_K, RestN, RestK)
    tCgB = thr_mma.partition_B(gB)
    # (MMA, MMA_M, MMA_N, RestM, RestN)
    tCgC = thr_mma.partition_C(gC)

    # (MMA, MMA_M, MMA_K, STAGE)
    tCrA = tiled_mma.make_fragment_A(sA)
    # (MMA, MMA_N, MMA_K, STAGE)
    tCrB = tiled_mma.make_fragment_B(sB)

    # (MMA, MMA_M, MMA_N)
    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
    # (MMA, MMA_M, MMA_N, STAGE)
    # "fake" because only the layout is used; the real TMEM pointer is retrieved later.
    tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, acc_stages))

    # Partition tensors for TMA; This requires the tensors partitioned for MMA
    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestM, RestK)
    tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
        tma_atom_a,
        cta_in_cluster_coord_vmnk[2],
        cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
        cute.group_modes(sA, 0, 3),
        cute.group_modes(tCgA, 0, 3),
    )

    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestN, RestK)
    tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
        tma_atom_b,
        cta_in_cluster_coord_vmnk[1],
        cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
        cute.group_modes(sB, 0, 3),
        cute.group_modes(tCgB, 0, 3),
    )

    # Subdivide the per-CTA C tile into epilogue subtiles.
    gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)

    tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
        tma_atom_c,
        0,
        cute.make_layout(1),
        cute.group_modes(sC, 0, 2),
        cute.group_modes(gC_epi, 0, 2),
    )

    # Cluster wait before starting work
    pipeline_init_wait(cluster_shape_mn=cluster_shape_mnk)

    tile_sched = scheduler_type.create(
        tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
    )
    work_tile = tile_sched.initial_work_tile_info()

    #
    # Main loop
    #

    num_k_tiles = cute.size(gA, mode=[3])

    # TMA warp
    if warp_idx == tma_warp_id:
        #
        # Persistent tile scheduling loop
        #
        while work_tile.is_valid_tile:
            # Get tile coord from tile scheduler
            cur_tile_coord = work_tile.tile_idx
            mma_tile_coord_mnl = (
                cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
                cur_tile_coord[1],
                cur_tile_coord[2],
            )

            # Slice to per mma tile index
            # ((atom_v, rest_v), RestK)
            tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None)]

            # ((atom_v, rest_v), RestK)
            tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None)]

            # Tma load loop
            for k_tile_idx in range(num_k_tiles):
                # Wait for A/B buffers to be empty before loading into them
                handle = ab_producer.acquire_and_advance()

                # Issue TMA loads
                cute.copy(
                    tma_atom_a,
                    tAgA_slice[(None, k_tile_idx)],
                    tAsA[(None, handle.index)],
                    tma_bar_ptr=handle.barrier,
                    mcast_mask=tma_mcast_mask_a,
                )
                cute.copy(
                    tma_atom_b,
                    tBgB_slice[(None, k_tile_idx)],
                    tBsB[(None, handle.index)],
                    tma_bar_ptr=handle.barrier,
                    mcast_mask=tma_mcast_mask_b,
                )

            # Advance to the next output tile
            tile_sched.advance_to_next_work()
            work_tile = tile_sched.get_current_work()

        # This mbarrier_wait is preventing threadblocks within a set of dependent threadblocks within the cluster
        # (dependent in the context of the TMA/MMA synchronization pattern) to exit early making
        # a late tcgen05 commit_arrive illegal
        ab_producer.tail()

    # MMA warp
    elif warp_idx == mma_warp_id:
        # Wait for TMEM allocation and retrieve pointer
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(acc_dtype)
        # (MMA, MMA_M, MMA_N, STAGE)
        tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)

        while work_tile.is_valid_tile:
            # Only the leader CTA of the 2-CTA pair issues MMA instructions.
            if is_leader_cta:
                # Wait for accumulator buffer empty
                acc_empty = acc_producer.acquire_and_advance()

                # Set tensor memory buffer for current tile
                # (MMA, MMA_M, MMA_N)
                tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]

                for k_tile_idx in range(num_k_tiles):
                    # Wait for TMA copies to complete
                    handle = ab_consumer.wait_and_advance()

                    # Execute one K-block worth of MMA instructions
                    # First k-tile overwrites the accumulator, the rest accumulate.
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
                    tile_crd = (None, None, None, handle.index)
                    cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)

                    # Signal that the A/B buffers have been consumed and are ready for the next load
                    handle.release()

                # Signal that the accumulator is fully computed
                acc_empty.commit()

            # Advance to next tile
            tile_sched.advance_to_next_work()
            work_tile = tile_sched.get_current_work()

        # Wait for accumulator buffer empty
        acc_producer.tail()

    # Epilogue warps
    elif warp_idx < mma_warp_id:
        # Allocate TMEM (only epilogue warp 0 actually allocates)
        num_tmem_cols = 512
        tmem.allocate(num_tmem_cols)

        # Wait for TMEM allocation and retrieve pointer
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(acc_dtype)
        # (MMA, MMA_M, MMA_N, STAGE)
        tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)

        # Initialize TMA store pipeline for epilogue
        epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            size=128,
        )
        epilogue_pipeline = pipeline.PipelineTmaStore.create(
            num_stages=epi_stages,
            producer_group=epilogue_pipeline_producer_group,
        )

        copy_atom_t2r = cute.make_copy_atom(
            tcgen05.Ld32x32bOp(tcgen05.Repetition.x32, tcgen05.Pack.NONE),
            cutlass.Float32,
        )

        while work_tile.is_valid_tile:
            # Get tile coord from tile scheduler
            cur_tile_coord = work_tile.tile_idx
            mma_tile_coord_mnl = (
                cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
                cur_tile_coord[1],
                cur_tile_coord[2],
            )

            # Wait for accumulator buffer full
            acc_full = acc_consumer.wait_and_advance()

            # Set tensor memory buffer for current tile
            # (MMA, MMA_M, MMA_N)
            tCtAcc = tCtAcc_base[(None, None, None, acc_full.index)]

            # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
            tCtAcc_epi = cute.flat_divide(
                tCtAcc[((None, None), 0, 0)],  # index out MMA_M/MMA_N (size 1 here, presumably — TODO confirm)
                epi_tile,
            )

            mma_tile_coord_mn = cute.slice_(mma_tile_coord_mnl, (None, None, 0))
            # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN)
            tCgC_epi = cute.flat_divide(
                tCgC[((None, None), 0, 0, *mma_tile_coord_mn)], epi_tile
            )

            tCgC_tma_cur_tile = tCgC_tma[(None, None, None, *mma_tile_coord_mn)]

            # Tiled copy for TMEM -> RMEM load
            tiled_copy_t2r = tcgen05.make_tmem_copy(
                copy_atom_t2r, tCtAcc_epi[(None, None, 0, 0)]
            )
            thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
            # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
            tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc_epi)
            # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
            tTR_gC = thr_copy_t2r.partition_D(tCgC_epi)
            # (T2R, T2R_M, T2R_N)
            tTR_rAcc = cute.make_rmem_tensor(
                tTR_gC[(None, None, None, 0, 0)].shape, cutlass.Float32
            )
            # Flatten the epilogue subtile modes into one loop index.
            tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))

            # Copy atom and tiled copy for RMEM -> SMEM load
            copy_atom_r2s = cutlass.utils.blackwell_helpers.get_smem_store_op(
                c_smem_layout_kind, cutlass.Float32, cutlass.Float32, tiled_copy_t2r
            )
            tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
            # (R2S, R2S_M, R2S_N, PIPE_D)
            thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
            tRS_sC = thr_copy_r2s.partition_D(sC)

            tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
            tRS_rC = cute.make_rmem_tensor(tRS_rAcc.shape, io_dtype)
            tCgC_grouped = cute.group_modes(
                tCgC_tma_cur_tile, 1, cute.rank(tCgC_tma_cur_tile)
            )

            subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])

            # Epilogue tiling loop
            for subtile_idx in cutlass.range(subtile_cnt):
                # TMEM -> RMEM
                tTR_tAcc_slice = tTR_tAcc[(None, None, None, subtile_idx)]
                cute.copy(tiled_copy_t2r, tTR_tAcc_slice, tTR_rAcc)

                # RMEM -> SMEM
                # Subtiles round-robin over the epi_stages SMEM buffers.
                c_buffer = subtile_idx % epi_stages
                tRS_sC_slice = tRS_sC[(None, None, None, c_buffer)]

                # type conversion
                tRS_rC.store(tRS_rAcc.load().to(io_dtype))

                cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC_slice)

                # Memory fence and barrier to ensure shared memory stores are visible to TMA stores
                cute.arch.fence_view_async_shared()
                epilogue_sync_barrier.arrive_and_wait()
                # SMEM -> GMEM
                if warp_idx == epilogue_warp_ids[0]:
                    cute.copy(
                        tma_atom_c,
                        tCsC[(None, c_buffer)],
                        tCgC_grouped[(None, subtile_idx)],
                    )

                    epilogue_pipeline.producer_commit()
                    epilogue_pipeline.producer_acquire()
                epilogue_sync_barrier.arrive_and_wait()

            # Async arrive accumulator buffer empty
            with cute.arch.elect_one():
                acc_full.release()

            # Advance to next tile
            tile_sched.advance_to_next_work()
            work_tile = tile_sched.get_current_work()

        # Wait for C store complete
        epilogue_pipeline.producer_tail()

        # Dealloc the tensor memory buffer
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def compute_grid(
    c: cute.Tensor,
    mma_tiler_mnk: Tuple[int, int, int],
    cluster_shape_mnk: Tuple[int, int, int],
    max_active_clusters: cutlass.Constexpr,
) -> Tuple[
    utils.PersistentTileSchedulerParams,
    Tuple[int, int, int],
]:
    """Build the persistent tile-scheduler parameters and the launch grid.

    C is tiled by the MN extent of the tiler; the number of resulting tiles
    per dimension gives the CTA counts handed to the static scheduler, which
    in turn sizes the persistent launch grid from `max_active_clusters`.
    """
    # Only the MN extent of the tiler is relevant for tiling C — drop K.
    tiler_mn = cute.slice_(mma_tiler_mnk, (None, None, 0))
    tiled_c = cute.zipped_divide(c, tiler=tiler_mn)
    num_ctas_mn = tiled_c[(0, (None, None))].shape

    params = utils.PersistentTileSchedulerParams(
        (*num_ctas_mn, 1), cluster_shape_mnk
    )
    launch_grid = utils.StaticPersistentTileScheduler.get_grid_shape(
        params, max_active_clusters
    )
    return params, launch_grid
|
||||
|
||||
|
||||
@cute.jit
def host_function(
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
    max_active_clusters: cutlass.Constexpr,
):
    """Host-side JIT entry point: build MMA/TMA atoms and SMEM layouts, then launch the kernel.

    Args:
        a: (M, K) K-major input tensor.
        b: (N, K) K-major input tensor.
        c: (M, N) output tensor.
        max_active_clusters: Cluster occupancy bound used to size the persistent grid.
    """
    #
    # Construct tiled MMA
    #

    op = tcgen05.MmaF16BF16Op(
        io_dtype,
        acc_dtype,
        mma_inst_shape_mnk,
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
        tcgen05.OperandSource.SMEM,
        tcgen05.OperandMajorMode.K,
        tcgen05.OperandMajorMode.K,
    )
    tiled_mma = cute.make_tiled_mma(op)

    #
    # Construct SMEM layouts for A and B
    #

    a_smem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        a.element_type,
        ab_stages,
    )
    b_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        b.element_type,
        ab_stages,
    )

    # c_smem_layout_kind is an enum for row/column major, not a CuTe layout
    c_smem_layout_kind = utils.LayoutEnum.from_tensor(c)

    #
    # Construct the VMNK layout
    #

    cta_layout_mnk = cute.make_layout(cluster_shape_mnk)
    cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))

    #
    # Construct TMA load atoms
    #

    op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
    )
    # Drop the stage mode: the TMA atom describes a single-stage transfer.
    a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
    a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
        op,
        a,
        a_smem_layout_slice,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )
    b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
    b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
        op,
        b,
        b_smem_layout_slice,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )

    # Per-CTA tile: the M extent of the MMA tiler is split across the CTAs of the instruction.
    cta_tile_shape_mnk = (
        mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id),
        mma_tiler_mnk[1],
        mma_tiler_mnk[2],
    )

    epi_tile = utils.compute_epilogue_tile_shape(
        cta_tile_shape_mnk,
        use_2cta_instrs,
        c_smem_layout_kind,
        io_dtype,
    )

    epi_smem_layout_staged = cutlass.utils.blackwell_helpers.make_smem_layout_epi(
        io_dtype,
        c_smem_layout_kind,
        epi_tile,
        epi_stages,
    )

    epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
    c_tma_atom, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
        c,
        epi_smem_layout,
        epi_tile,
    )

    #
    # Launch the kernel
    #

    # NOTE: the CTA tile (not the full MMA tiler) determines the per-dimension CTA counts.
    tile_sched_params, grid_shape = compute_grid(
        c,
        cta_tile_shape_mnk,
        cluster_shape_mnk,
        max_active_clusters,
    )

    kernel(
        tiled_mma,
        a_tma_atom,
        a_tma_tensor,
        b_tma_atom,
        b_tma_tensor,
        c_tma_atom,
        c_tma_tensor,
        a_smem_layout,
        b_smem_layout,
        c_smem_layout_kind,
        epi_smem_layout_staged,
        epi_tile,
        cta_layout_vmnk,
        tile_sched_params,
    ).launch(
        grid=grid_shape,
        block=[192, 1, 1],
        cluster=cluster_shape_mnk,
    )
|
||||
|
||||
|
||||
def run_dense_gemm(
    mnk: Tuple[int, int, int],
    tolerance: float,
):
    """Build random fp16 operands on the GPU, run the GEMM, and verify against torch.einsum."""
    # torch is imported lazily so the module can be inspected without it installed.
    global torch, cutlass_torch
    import torch
    import cutlass.torch as cutlass_torch

    banner = "==================================================================="
    print(banner)
    print("Running Blackwell fp16 GEMM example 3 with:")
    print(f" mnk: {mnk}")
    print(f" tolerance: {tolerance}")
    print(banner)
    print()

    torch.manual_seed(1111)
    m, n, k = mnk

    # K-major operands: torch tensors are row-major, so a (rows, k) tensor is K-contiguous.
    def _rand_k_major(rows, cols, dtype):
        # Small integer values keep the fp16 computation exact enough to compare.
        ints = torch.empty(rows, cols, dtype=torch.int32).random_(-2, 2)
        return ints.to(device="cuda", dtype=dtype)

    torch_dtype = cutlass_torch.dtype(io_dtype)
    a = _rand_k_major(m, k, torch_dtype)
    b = _rand_k_major(n, k, torch_dtype)
    c = _rand_k_major(m, n, torch_dtype)

    a_memref, b_memref, c_memref = (
        from_dlpack(t).mark_layout_dynamic() for t in (a, b, c)
    )

    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
        cluster_shape_mnk[0] * cluster_shape_mnk[1]
    )

    # Entry point to the host JIT function
    host_function(
        a_memref,
        b_memref,
        c_memref,
        max_active_clusters,
        no_cache=True,
    )

    # Reference on CPU: both operands are K-major, hence "mk,nk->mn".
    ref = torch.einsum("mk,nk->mn", a, b).cpu()
    torch.testing.assert_close(
        c.cpu(), ref.to(torch_dtype), atol=tolerance, rtol=1e-05
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str) -> list[int]:
        """argparse `type=` hook: parse e.g. "8192,8192,8192" into [8192, 8192, 8192]."""
        try:
            return [int(x.strip()) for x in s.split(",")]
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )

    from cuda.bindings import driver as cu_driver

    # Fail fast with a clear message when no CUDA device is present.
    cu_driver.cuInit(0)
    err, device_count = cu_driver.cuDeviceGetCount()
    if err != cu_driver.CUresult.CUDA_SUCCESS or device_count < 1:
        raise RuntimeError("A GPU is required to run this example")

    parser = argparse.ArgumentParser(description="Blackwell fp16 GEMM example 3")
    parser.add_argument(
        "--mnk",
        type=parse_comma_separated_ints,
        default=(8192, 8192, 8192),
        help="MNK dimensions (comma-separated)",
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    args = parser.parse_args()
    if len(args.mnk) != 3:
        parser.error("--mnk must contain exactly 3 values")

    run_dense_gemm(
        args.mnk,
        args.tolerance,
    )
    print("PASS")
|
||||
882
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py
Normal file
882
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py
Normal file
@@ -0,0 +1,882 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
# This is the third tutorial GEMM. It further enhances the second tutorial by adding warp
|
||||
# specialization for TMA, MMA, and epilogue warps.
|
||||
|
||||
|
||||
import argparse
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
|
||||
"""
|
||||
The third tutorial GEMM demonstrates a simple kernel implementation in CuTeDSL.
|
||||
|
||||
Compared to fp16_gemm_3.py, this kernel uses a dynamic persistent tile scheduler (ClcDynamicPersistentTileScheduler).
|
||||
The dynamic scheduler is more flexible than the static scheduler, as it can handle workload imbalance better.
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
    python examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
* The problem size of m and n must be divisible by the tile size m & n (256, 256)
|
||||
"""
|
||||
|
||||
# GEMM configuration shared by the host function and the kernel.
# I/O operands are fp16; accumulation is performed in fp32.
io_dtype = cutlass.Float16
acc_dtype = cutlass.Float32
# Use the 2-CTA tcgen05 MMA variant: two CTAs of a cluster cooperate on one MMA.
use_2cta_instrs = True
# With 2-CTA MMA the cluster spans 2 CTAs along M.
cluster_shape_mnk = (2, 1, 1) if use_2cta_instrs else (1, 1, 1)
# Shape of a single tcgen05 MMA instruction (M, N, K).
mma_inst_shape_mnk = (256, 256, 16)
# Tile processed per mainloop iteration; K is consumed in chunks of 64.
mma_tiler_mnk = (256, 256, 64)
threads_in_epilogue = 128  # epilogue threads per cta

# Pipeline stage configuration
ab_stages = 6  # A/B operand SMEM buffers in flight
epi_stages = 2  # staged epilogue SMEM buffers
acc_stages = 2  # TMEM accumulator buffers
num_clc_stage = 1  # CLC (cluster launch control) response buffers

# Scheduler selection: CLC-based dynamic scheduler vs. static persistent scheduler.
use_clc_dynamic_scheduler = True
scheduler_type = (
    utils.ClcDynamicPersistentTileScheduler
    if use_clc_dynamic_scheduler
    else utils.StaticPersistentTileScheduler
)
# Response size is 4B * 4 elements
num_clc_response_bytes = 16
|
||||
|
||||
|
||||
# Statically laid-out shared-memory storage for the kernel. Field order defines
# the SMEM layout, so fields must not be reordered.
@cute.struct
class SharedStorage:
    # Mainloop A/B pipeline mbarriers — presumably 2 per stage (full/empty
    # phases), hence `ab_stages * 2`; confirm against PipelineTmaUmma.
    ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
    # Accumulator pipeline mbarriers, same 2-per-stage arrangement.
    acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stages * 2]
    # Mbarrier coordinating TMEM deallocation across the 2-CTA pair.
    tmem_dealloc_mbar: cutlass.Int64
    # Slot where the allocator warp publishes the TMEM base address.
    tmem_holding_buffer: cutlass.Int32
    # Only for CLC Dynamic Scheduler
    clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
    # CLC query response buffer (4 x 4B = num_clc_response_bytes).
    clc_response: cute.struct.MemRange[cutlass.Int32, 4]
|
||||
|
||||
|
||||
@cute.kernel()
def kernel(
    tiled_mma: cute.TiledMma,
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_c: cute.CopyAtom,
    mC_mnl: cute.Tensor,
    a_smem_layout: cute.ComposedLayout,
    b_smem_layout: cute.ComposedLayout,
    c_smem_layout_kind: cutlass.Constexpr,
    epi_smem_layout_staged: cute.ComposedLayout,
    epi_tile: cute.Tile,
    cta_layout_vmnk: cute.Layout,
    tile_sched_params: Union[
        utils.ClcDynamicPersistentTileSchedulerParams,
        utils.PersistentTileSchedulerParams,
    ],
):
    """Warp-specialized persistent GEMM kernel with a CLC dynamic tile scheduler.

    Warp roles: warps 0-3 run the epilogue, warp 4 issues MMAs, warp 5 issues
    TMA loads, and warp 6 (dynamic scheduler only) drives CLC tile fetches.
    """
    # Current thread/warp coordinates; make_warp_uniform lets the compiler
    # treat warp_idx as uniform across the warp.
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)

    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    # Position of this CTA inside its cluster, in (V, M, N, K) coordinates.
    cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
    cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)

    # V-mode coordinate; the leader CTA (v == 0) of a 2-CTA pair issues the MMAs.
    mma_tile_coord_v = bidx % cute.size(cta_layout_vmnk, mode=[0])
    is_leader_cta = mma_tile_coord_v == 0

    # Warp-role assignment (see docstring).
    epilogue_warp_ids = (
        0,
        1,
        2,
        3,
    )
    mma_warp_id = 4
    tma_warp_id = 5
    # sched_warp_id only for dynamic scheduler
    sched_warp_id = 6

    # Named-barrier ids used below (0 is reserved for the CTA-wide barrier).
    epilog_sync_bar_id = 1
    tmem_alloc_sync_bar_id = 2

    # Prefetch tma descriptor
    if warp_idx == tma_warp_id:
        cpasync.prefetch_descriptor(tma_atom_a)
        cpasync.prefetch_descriptor(tma_atom_b)
        cpasync.prefetch_descriptor(tma_atom_c)

    # As many participants as the number of threads issuing the MMA in the same
    # row and column. Subtract one to not count the same thread twice.
    num_mcast_participants = (
        cute.size(cta_layout_vmnk, mode=[1]) + cute.size(cta_layout_vmnk, mode=[2]) - 1
    )

    # Multicast masks: A is broadcast along the cluster N-mode, B along M-mode.
    tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
        cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
    )
    tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
        cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
    )

    # Allocate SMEM for the statically-sized storage struct (mbarriers etc.).
    smem = cutlass.utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)

    # Barrier 1: synchronizes the 128 epilogue threads around SMEM reuse.
    epilogue_sync_barrier = pipeline.NamedBarrier(
        barrier_id=epilog_sync_bar_id,
        num_threads=threads_in_epilogue,
    )

    # Only the MMA warp and the epilogue warps participate in TMEM allocation
    # synchronization; the TMA warp does NOT participate.
    tmem_alloc_barrier = pipeline.NamedBarrier(
        barrier_id=tmem_alloc_sync_bar_id,
        num_threads=32
        * len((mma_warp_id, *epilogue_warp_ids)),  # 5 warps = 160 threads
    )
    tmem = utils.TmemAllocator(
        storage.tmem_holding_buffer,
        barrier_for_retrieve=tmem_alloc_barrier,
        allocator_warp_id=epilogue_warp_ids[0],
        is_two_cta=True,
        two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
    )

    # Bytes transferred per mainloop stage (A tile + B tile, per CTA pair);
    # used as the TMA transaction count for the mbarrier.
    num_tma_copy_bytes = (
        cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
        + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
    ) * cute.size(cta_layout_vmnk, mode=[0])

    # Threads/warps participating in the mainloop (A/B) pipeline.
    mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread, size=num_mcast_participants
    )

    # TMA-producer / UMMA-consumer pipeline over the A/B SMEM stages.
    ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
        barrier_storage=storage.ab_mbar_ptr.data_ptr(),
        num_stages=ab_stages,
        producer_group=mainloop_pipeline_producer_group,
        consumer_group=mainloop_pipeline_consumer_group,
        tx_count=num_tma_copy_bytes,
        cta_layout_vmnk=cta_layout_vmnk,
    ).make_participants()

    # Threads/warps participating in the accumulator pipeline
    # (MMA warp produces, epilogue warps of both CTAs consume).
    acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
    acc_pipeline_consumer_group = pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        size=cute.size(cta_layout_vmnk, mode=[0]) * len(epilogue_warp_ids),
    )

    acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
        barrier_storage=storage.acc_mbar_ptr.data_ptr(),
        num_stages=acc_stages,
        producer_group=acc_pipeline_producer_group,
        consumer_group=acc_pipeline_consumer_group,
        cta_layout_vmnk=cta_layout_vmnk,
    ).make_participants()

    # Initialize clc_pipeline (barrier) and states.
    # ONLY for CLC Dynamic Scheduler.
    if cutlass.const_expr(use_clc_dynamic_scheduler):
        clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
        cluster_size = cute.size(cluster_shape_mnk)
        # Consumers: the single sched warp plus (mma, tma, 4 epilogue) warps
        # replicated across every CTA in the cluster (tuple repetition).
        num_clc_consumer_threads = 32 * len(
            (
                sched_warp_id,
                *(
                    cluster_size
                    * (
                        mma_warp_id,
                        tma_warp_id,
                        *epilogue_warp_ids,
                    )
                ),
            )
        )
        clc_pipeline_consumer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread, num_clc_consumer_threads
        )
        clc_pipeline = pipeline.PipelineClcFetchAsync.create(
            barrier_storage=storage.clc_mbar_ptr.data_ptr(),
            num_stages=num_clc_stage,
            producer_group=clc_pipeline_producer_group,
            consumer_group=clc_pipeline_consumer_group,
            tx_count=num_clc_response_bytes,
            cta_layout_vmnk=cta_layout_vmnk,
            # Deferred: the cluster-wide arrive/wait is issued explicitly below.
            defer_sync=True,
        )
        # Initial clc response pointer
        clc_response_ptr = storage.clc_response.data_ptr()

        clc_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, num_clc_stage
        )
    else:
        clc_pipeline = None
        clc_response_ptr = None
        clc_consumer_state = None

    # Cluster arrive after barrier init
    pipeline_init_arrive(cluster_shape_mn=cluster_shape_mnk, is_relaxed=True)

    # Allocate SMEM tensors for the A/B mainloop stages and the epilogue C stages.
    sA = smem.allocate_tensor(
        element_type=io_dtype,
        layout=a_smem_layout.outer,
        byte_alignment=128,
        swizzle=a_smem_layout.inner,
    )
    sB = smem.allocate_tensor(
        element_type=io_dtype,
        layout=b_smem_layout.outer,
        byte_alignment=128,
        swizzle=b_smem_layout.inner,
    )
    sC = smem.allocate_tensor(
        element_type=io_dtype,
        layout=epi_smem_layout_staged.outer,
        byte_alignment=128,
        swizzle=epi_smem_layout_staged.inner,
    )

    # Partition tensors for MMA and make fragments.
    # (bM, bK, RestM, RestK)
    gA = cute.local_tile(
        mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None)
    )
    # (bN, bK, RestN, RestK)
    gB = cute.local_tile(
        mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None)
    )
    # (bM, bN, RestM, RestN)
    gC = cute.local_tile(
        mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None)
    )

    thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
    # (MMA, MMA_M, MMA_K, RestM, RestK)
    tCgA = thr_mma.partition_A(gA)
    # (MMA, MMA_N, MMA_K, RestN, RestK)
    tCgB = thr_mma.partition_B(gB)
    # (MMA, MMA_M, MMA_N, RestM, RestN)
    tCgC = thr_mma.partition_C(gC)

    # (MMA, MMA_M, MMA_K, STAGE)
    tCrA = tiled_mma.make_fragment_A(sA)
    # (MMA, MMA_N, MMA_K, STAGE)
    tCrB = tiled_mma.make_fragment_B(sB)

    # (MMA, MMA_M, MMA_N)
    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
    # (MMA, MMA_M, MMA_N, STAGE) — a "fake" fragment used only for its layout;
    # the real TMEM base pointer is substituted after allocation.
    tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, acc_stages))

    # Partition tensors for TMA; this requires the tensors partitioned for MMA.
    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestM, RestK)
    tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
        tma_atom_a,
        cta_in_cluster_coord_vmnk[2],
        cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
        cute.group_modes(sA, 0, 3),
        cute.group_modes(tCgA, 0, 3),
    )

    # ((atom_v, rest_v), STAGE)
    # ((atom_v, rest_v), RestN, RestK)
    tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
        tma_atom_b,
        cta_in_cluster_coord_vmnk[1],
        cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
        cute.group_modes(sB, 0, 3),
        cute.group_modes(tCgB, 0, 3),
    )

    # Split the per-CTA C tile into epilogue subtiles.
    gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)

    # TMA-store partition for C: no multicast (single participant).
    tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
        tma_atom_c,
        0,
        cute.make_layout(1),
        cute.group_modes(sC, 0, 2),
        cute.group_modes(gC_epi, 0, 2),
    )

    # Cluster wait before starting work
    pipeline_init_wait(cluster_shape_mn=cluster_shape_mnk)

    # Construct the scheduler (the dynamic variant needs the CLC response slot).
    if cutlass.const_expr(use_clc_dynamic_scheduler):
        tile_sched = scheduler_type.create(
            tile_sched_params,
            cute.arch.block_idx(),
            cute.arch.grid_dim(),
            clc_response_ptr,
        )
    else:
        tile_sched = scheduler_type.create(
            tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
        )
    work_tile = tile_sched.initial_work_tile_info()

    #
    # Main loop
    #

    num_k_tiles = cute.size(gA, mode=[3])

    # TMA warp: streams A/B k-tiles into SMEM for every scheduled output tile.
    if warp_idx == tma_warp_id:
        #
        # Persistent tile scheduling loop
        #
        while work_tile.is_valid_tile:
            # Get tile coord from tile scheduler; fold the V replication out of
            # the M coordinate.
            cur_tile_coord = work_tile.tile_idx
            mma_tile_coord_mnl = (
                cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
                cur_tile_coord[1],
                cur_tile_coord[2],
            )

            # Slice to per mma tile index
            # ((atom_v, rest_v), RestK)
            tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None)]
            # ((atom_v, rest_v), RestK)
            tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None)]

            # Tma load loop
            for k_tile_idx in range(num_k_tiles):
                # Wait for A/B buffers to be empty before loading into them
                handle = ab_producer.acquire_and_advance()

                # Issue TMA loads (multicast across the cluster).
                cute.copy(
                    tma_atom_a,
                    tAgA_slice[(None, k_tile_idx)],
                    tAsA[(None, handle.index)],
                    tma_bar_ptr=handle.barrier,
                    mcast_mask=tma_mcast_mask_a,
                )
                cute.copy(
                    tma_atom_b,
                    tBgB_slice[(None, k_tile_idx)],
                    tBsB[(None, handle.index)],
                    tma_bar_ptr=handle.barrier,
                    mcast_mask=tma_mcast_mask_b,
                )

            # Advance to next tile
            if cutlass.const_expr(use_clc_dynamic_scheduler):
                clc_pipeline.consumer_wait(clc_consumer_state)
                work_tile = tile_sched.get_current_work()
                clc_pipeline.consumer_release(clc_consumer_state)
                clc_consumer_state.advance()
            else:
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()

        # This mbarrier_wait prevents threadblocks within a set of dependent
        # threadblocks in the cluster (dependent in the context of the TMA/MMA
        # synchronization pattern) from exiting early, which would make a late
        # tcgen05 commit_arrive illegal.
        ab_producer.tail()

    # Sched warp (only for dynamic scheduler): warp 6 of the first CTA in the
    # cluster issues CLC queries that hand out the next tile to all consumers.
    if cutlass.const_expr(use_clc_dynamic_scheduler):
        is_first_cta_in_cluster = cta_rank_in_cluster == 0

        if warp_idx == sched_warp_id and is_first_cta_in_cluster:
            # Persistent tile scheduling loop
            clc_producer_state = pipeline.make_pipeline_state(
                pipeline.PipelineUserType.ProducerConsumer, num_clc_stage
            )
            while work_tile.is_valid_tile:
                # Advance to next tile: launch a CLC query whose response lands
                # behind the pipeline mbarrier...
                clc_pipeline.producer_acquire(clc_producer_state)
                mbarrier_addr = clc_pipeline.producer_get_barrier(clc_producer_state)
                tile_sched.advance_to_next_work(mbarrier_addr)
                clc_producer_state.advance()

                # ...then consume it like every other warp.
                clc_pipeline.consumer_wait(clc_consumer_state)
                work_tile = tile_sched.get_current_work()
                clc_pipeline.consumer_release(clc_consumer_state)
                clc_consumer_state.advance()
            clc_pipeline.producer_tail(clc_producer_state)

    # MMA warp: consumes A/B stages and accumulates into TMEM.
    if warp_idx == mma_warp_id:
        # Wait for TMEM allocation (done by epilogue warp 0) and retrieve pointer.
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(acc_dtype)
        # (MMA, MMA_M, MMA_N, STAGE)
        tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)

        while work_tile.is_valid_tile:
            # Only the leader CTA of a 2-CTA pair issues the MMA instructions.
            if is_leader_cta:
                # Wait for accumulator buffer empty
                acc_empty = acc_producer.acquire_and_advance()

                # Set tensor memory buffer for current tile
                # (MMA, MMA_M, MMA_N)
                tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]

                for k_tile_idx in range(num_k_tiles):
                    # Wait for TMA copies to complete
                    handle = ab_consumer.wait_and_advance()

                    # Execute one K-block worth of MMA instructions; the first
                    # k-tile overwrites the accumulator instead of adding to it.
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
                    tile_crd = (None, None, None, handle.index)
                    cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)

                    # Signal that the A/B buffers have been consumed and are
                    # ready for the next load
                    handle.release()

                # Signal that the accumulator is fully computed
                acc_empty.commit()

            # Advance to next tile
            if cutlass.const_expr(use_clc_dynamic_scheduler):
                clc_pipeline.consumer_wait(clc_consumer_state)
                work_tile = tile_sched.get_current_work()
                clc_pipeline.consumer_release(clc_consumer_state)
                clc_consumer_state.advance()
            else:
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()

        # Wait for accumulator buffer empty
        acc_producer.tail()

    # Epilogue warps (0-3): move accumulators TMEM -> RMEM -> SMEM -> GMEM.
    if warp_idx < mma_warp_id:
        # Allocate TMEM (only epilogue warp 0 actually allocates)
        num_tmem_cols = 512
        tmem.allocate(num_tmem_cols)

        # Wait for TMEM allocation and retrieve pointer
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(acc_dtype)
        # (MMA, MMA_M, MMA_N, STAGE)
        tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)

        # Initialize TMA store pipeline for epilogue
        epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            size=128,
        )
        epilogue_pipeline = pipeline.PipelineTmaStore.create(
            num_stages=epi_stages,
            producer_group=epilogue_pipeline_producer_group,
        )

        # Copy atom for TMEM -> RMEM loads (32 lanes x 32 bits, x32 repetition).
        copy_atom_t2r = cute.make_copy_atom(
            tcgen05.Ld32x32bOp(tcgen05.Repetition.x32, tcgen05.Pack.NONE),
            cutlass.Float32,
        )

        while work_tile.is_valid_tile:
            # Get tile coord from tile scheduler
            cur_tile_coord = work_tile.tile_idx
            mma_tile_coord_mnl = (
                cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
                cur_tile_coord[1],
                cur_tile_coord[2],
            )

            # Wait for accumulator buffer full
            acc_full = acc_consumer.wait_and_advance()

            # Set tensor memory buffer for current tile
            # (MMA, MMA_M, MMA_N)
            tCtAcc = tCtAcc_base[(None, None, None, acc_full.index)]

            # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
            tCtAcc_epi = cute.flat_divide(
                # Keep the value modes and drop MMA_M/MMA_N — presumably both
                # are singleton for this tiler, so index 0 selects the only
                # element; TODO(review): confirm.
                tCtAcc[((None, None), 0, 0)],
                epi_tile,
            )

            mma_tile_coord_mn = cute.slice_(mma_tile_coord_mnl, (None, None, 0))
            # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN)
            tCgC_epi = cute.flat_divide(
                tCgC[((None, None), 0, 0, *mma_tile_coord_mn)], epi_tile
            )

            tCgC_tma_cur_tile = tCgC_tma[(None, None, None, *mma_tile_coord_mn)]

            # Tiled copy for TMEM -> RMEM load
            tiled_copy_t2r = tcgen05.make_tmem_copy(
                copy_atom_t2r, tCtAcc_epi[(None, None, 0, 0)]
            )
            thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
            # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
            tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc_epi)
            # (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
            tTR_gC = thr_copy_t2r.partition_D(tCgC_epi)
            # (T2R, T2R_M, T2R_N)
            tTR_rAcc = cute.make_rmem_tensor(
                tTR_gC[(None, None, None, 0, 0)].shape, cutlass.Float32
            )
            # Collapse the (EPI_M, EPI_N) modes into one subtile index.
            tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))

            # Copy atom and tiled copy for RMEM -> SMEM store
            copy_atom_r2s = cutlass.utils.blackwell_helpers.get_smem_store_op(
                c_smem_layout_kind, cutlass.Float32, cutlass.Float32, tiled_copy_t2r
            )
            tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
            # (R2S, R2S_M, R2S_N, PIPE_D)
            thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
            tRS_sC = thr_copy_r2s.partition_D(sC)

            tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
            tRS_rC = cute.make_rmem_tensor(tRS_rAcc.shape, io_dtype)
            tCgC_grouped = cute.group_modes(
                tCgC_tma_cur_tile, 1, cute.rank(tCgC_tma_cur_tile)
            )

            subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])

            # Epilogue tiling loop over the subtiles of this output tile.
            for subtile_idx in cutlass.range(subtile_cnt):
                # TMEM -> RMEM
                tTR_tAcc_slice = tTR_tAcc[(None, None, None, subtile_idx)]
                cute.copy(tiled_copy_t2r, tTR_tAcc_slice, tTR_rAcc)

                # RMEM -> SMEM, round-robin over the staged epilogue buffers.
                c_buffer = subtile_idx % epi_stages
                tRS_sC_slice = tRS_sC[(None, None, None, c_buffer)]

                # Type conversion fp32 accumulator -> fp16 output.
                tRS_rC.store(tRS_rAcc.load().to(io_dtype))

                cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC_slice)

                # Memory fence and barrier to make shared memory stores visible
                # to the subsequent TMA store.
                cute.arch.fence_view_async_shared()
                epilogue_sync_barrier.arrive_and_wait()
                # SMEM -> GMEM: only one warp issues the TMA store.
                if warp_idx == epilogue_warp_ids[0]:
                    cute.copy(
                        tma_atom_c,
                        tCsC[(None, c_buffer)],
                        tCgC_grouped[(None, subtile_idx)],
                    )

                    epilogue_pipeline.producer_commit()
                    epilogue_pipeline.producer_acquire()
                epilogue_sync_barrier.arrive_and_wait()

            # Async arrive accumulator buffer empty (one thread per warp).
            with cute.arch.elect_one():
                acc_full.release()

            # Advance to next tile
            if cutlass.const_expr(use_clc_dynamic_scheduler):
                clc_pipeline.consumer_wait(clc_consumer_state)
                work_tile = tile_sched.get_current_work()
                clc_pipeline.consumer_release(clc_consumer_state)
                clc_consumer_state.advance()
            else:
                tile_sched.advance_to_next_work()
                work_tile = tile_sched.get_current_work()

        # Wait for C store complete
        epilogue_pipeline.producer_tail()

        # Dealloc the tensor memory buffer
        tmem.relinquish_alloc_permit()
        tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def compute_grid(
    c: cute.Tensor,
    mma_tiler_mnk: Tuple[int, int, int],
    cluster_shape_mnk: Tuple[int, int, int],
    scheduler_type: Union[
        utils.StaticPersistentTileScheduler, utils.ClcDynamicPersistentTileScheduler
    ],
    max_active_clusters: cutlass.Constexpr,
) -> Tuple[
    Union[
        utils.ClcDynamicPersistentTileSchedulerParams,
        utils.PersistentTileSchedulerParams,
    ],
    Tuple[int, int, int],
]:
    """Build the tile-scheduler parameters and launch grid for the chosen scheduler.

    ``mma_tiler_mnk`` is the per-CTA output tile shape (the caller passes its
    CTA tile, not the raw MMA tiler — NOTE(review): the parameter name is
    misleading). ``max_active_clusters`` is only used by the static scheduler.
    Returns (scheduler params, grid shape).
    """
    # Number of output tiles along M and N: divide C by the (M, N) tile shape.
    c_shape = cute.slice_(mma_tiler_mnk, (None, None, 0))
    gc = cute.zipped_divide(c, tiler=c_shape)
    num_ctas_mn = gc[(0, (None, None))].shape

    # const_expr dispatch: the branch is resolved at trace time.
    if cutlass.const_expr(
        issubclass(scheduler_type, utils.ClcDynamicPersistentTileScheduler)
    ):
        tile_sched_params = utils.ClcDynamicPersistentTileSchedulerParams(
            (*num_ctas_mn, 1), cluster_shape_mnk
        )
        grid = utils.ClcDynamicPersistentTileScheduler.get_grid_shape(tile_sched_params)
    else:
        tile_sched_params = utils.PersistentTileSchedulerParams(
            (*num_ctas_mn, 1), cluster_shape_mnk
        )
        grid = utils.StaticPersistentTileScheduler.get_grid_shape(
            tile_sched_params, max_active_clusters
        )
    return tile_sched_params, grid
|
||||
|
||||
|
||||
@cute.jit
def host_function(
    a: cute.Tensor,
    b: cute.Tensor,
    c: cute.Tensor,
    max_active_clusters: cutlass.Constexpr,
):
    """Host-side JIT entry: build MMA/TMA/layout objects and launch the kernel.

    ``a`` is (M, K), ``b`` is (N, K), ``c`` is (M, N); all fp16, K-major for
    A/B per the module configuration. ``max_active_clusters`` feeds the static
    scheduler's grid computation.
    """
    #
    # Construct tiled MMA
    #

    op = tcgen05.MmaF16BF16Op(
        io_dtype,
        acc_dtype,
        mma_inst_shape_mnk,
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
        tcgen05.OperandSource.SMEM,
        tcgen05.OperandMajorMode.K,
        tcgen05.OperandMajorMode.K,
    )
    tiled_mma = cute.make_tiled_mma(op)

    #
    # Construct SMEM layouts for A and B (staged over ab_stages buffers)
    #

    a_smem_layout = sm100_utils.make_smem_layout_a(
        tiled_mma,
        mma_tiler_mnk,
        a.element_type,
        ab_stages,
    )
    b_smem_layout = sm100_utils.make_smem_layout_b(
        tiled_mma,
        mma_tiler_mnk,
        b.element_type,
        ab_stages,
    )

    # c_smem_layout_kind is an enum for row/column major, not a CuTe layout
    c_smem_layout_kind = utils.LayoutEnum.from_tensor(c)

    #
    # Construct the VMNK layout (V = CTAs cooperating on one MMA)
    #

    cta_layout_mnk = cute.make_layout(cluster_shape_mnk)
    cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))

    #
    # Construct TMA load atoms (`op` is reused here for the copy op)
    #

    op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
    )
    # Per-stage slice: the TMA atom describes a single-stage transfer.
    a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
    a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
        op,
        a,
        a_smem_layout_slice,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )
    b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
    b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
        op,
        b,
        b_smem_layout_slice,
        mma_tiler_mnk,
        tiled_mma,
        cta_layout_vmnk.shape,
    )

    # Per-CTA output tile: the MMA tiler's M is split across the V-mode CTAs.
    cta_tile_shape_mnk = (
        mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id),
        mma_tiler_mnk[1],
        mma_tiler_mnk[2],
    )

    # Subtile shape used by the staged epilogue.
    epi_tile = utils.compute_epilogue_tile_shape(
        cta_tile_shape_mnk,
        use_2cta_instrs,
        c_smem_layout_kind,
        io_dtype,
    )

    epi_smem_layout_staged = cutlass.utils.blackwell_helpers.make_smem_layout_epi(
        io_dtype,
        c_smem_layout_kind,
        epi_tile,
        epi_stages,
    )

    # TMA store atom for C, built on a single epilogue stage.
    epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
    c_tma_atom, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
        cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
        c,
        epi_smem_layout,
        epi_tile,
    )

    #
    # Launch the kernel
    #

    tile_sched_params, grid_shape = compute_grid(
        c,
        cta_tile_shape_mnk,
        cluster_shape_mnk,
        scheduler_type,
        max_active_clusters,
    )

    kernel(
        tiled_mma,
        a_tma_atom,
        a_tma_tensor,
        b_tma_atom,
        b_tma_tensor,
        c_tma_atom,
        c_tma_tensor,
        a_smem_layout,
        b_smem_layout,
        c_smem_layout_kind,
        epi_smem_layout_staged,
        epi_tile,
        cta_layout_vmnk,
        tile_sched_params,
    ).launch(
        grid=grid_shape,
        # 7 warps (224 threads) with the extra sched warp, else 6 warps (192).
        block=[224, 1, 1] if use_clc_dynamic_scheduler else [192, 1, 1],
        cluster=cluster_shape_mnk,
    )
|
||||
|
||||
|
||||
def run_dense_gemm(
    mnk: Tuple[int, int, int],
    tolerance: float,
):
    """Allocate random fp16 operands on the GPU, run the GEMM, and check it against torch.

    ``mnk`` gives the problem size; ``tolerance`` is the absolute tolerance used
    when comparing against the torch einsum reference.
    """
    # torch / cutlass.torch are imported lazily so the module can be loaded
    # without a torch installation; `global` publishes them at module scope.
    global torch, cutlass_torch
    import torch
    import cutlass.torch as cutlass_torch

    rule = "==================================================================="
    print(rule)
    print("Running Blackwell fp16 GEMM example 3_1 with:")
    print(f" mnk: {mnk}")
    print(f" tolerance: {tolerance}")
    print(rule)
    print()

    m, n, k = mnk
    torch.manual_seed(1111)

    # torch tensors are row-major, so an (rows, cols) allocation is K-major
    # when cols is the K extent. Values are small ints in [-2, 1] cast to fp16.
    def random_k_major(rows: int, cols: int, dtype):
        seed = torch.empty(rows, cols, dtype=torch.int32).random_(-2, 2)
        return seed.to(device="cuda", dtype=dtype)

    elem_dtype = cutlass_torch.dtype(io_dtype)
    a = random_k_major(m, k, elem_dtype)
    b = random_k_major(n, k, elem_dtype)
    c = random_k_major(m, n, elem_dtype)

    # Wrap the torch tensors as CuTe memrefs with runtime-dynamic layouts.
    a_memref = from_dlpack(a).mark_layout_dynamic()
    b_memref = from_dlpack(b).mark_layout_dynamic()
    c_memref = from_dlpack(c).mark_layout_dynamic()

    cluster_size_mn = cluster_shape_mnk[0] * cluster_shape_mnk[1]
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size_mn)

    # Entry point to the host JIT function
    host_function(
        a_memref,
        b_memref,
        c_memref,
        max_active_clusters,
        no_cache=True,
    )

    # Reference: C[m, n] = sum_k A[m, k] * B[n, k], computed on the CPU.
    ref = (torch.einsum("mk,nk->mn", a, b)).cpu()
    torch.testing.assert_close(
        c.cpu(), ref.to(elem_dtype), atol=tolerance, rtol=1e-05
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def parse_comma_separated_ints(s: str):
        # argparse `type=` callback: "8192,8192,8192" -> [8192, 8192, 8192].
        try:
            values = [int(token.strip()) for token in s.split(",")]
        except ValueError:
            raise argparse.ArgumentTypeError(
                "Invalid format. Expected comma-separated integers."
            )
        return values

    from cuda.bindings import driver as cu_driver

    # Fail fast when no CUDA device is available.
    cu_driver.cuInit(0)
    err, device_count = cu_driver.cuDeviceGetCount()
    if err != cu_driver.CUresult.CUDA_SUCCESS or device_count < 1:
        raise RuntimeError("A GPU is required to run this example")

    parser = argparse.ArgumentParser(description="Blackwell fp16 GEMM example 3_1")
    parser.add_argument(
        "--mnk",
        type=parse_comma_separated_ints,
        default=(8192, 8192, 8192),
        help="MNK dimensions (comma-separated)",
    )
    parser.add_argument(
        "--tolerance", type=float, default=1e-01, help="Tolerance for validation"
    )
    args = parser.parse_args()

    if len(args.mnk) != 3:
        parser.error("--mnk must contain exactly 3 values")

    run_dense_gemm(args.mnk, args.tolerance)
    print("PASS")
|
||||
1065
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py
Normal file
1065
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py
Normal file
File diff suppressed because it is too large
Load Diff
919
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py
Normal file
919
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py
Normal file
@@ -0,0 +1,919 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
# This is the fifth tutorial GEMM (5). It extends fp16_gemm_3_1.py by adding TMA prefetch.
|
||||
# TMA prefetch uses cute.prefetch() to bring data into L2 cache before TMA copy needs it,
|
||||
# helping to hide DRAM latency for memory-bound workloads.
|
||||
|
||||
|
||||
import argparse
|
||||
from typing import Tuple, Union
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
|
||||
"""
|
||||
The fifth tutorial GEMM (5) demonstrates TMA prefetch optimization in CuTeDSL.
|
||||
|
||||
TMA Prefetch uses cute.prefetch() to bring data from DRAM into L2 cache ahead of time.
|
||||
This helps hide DRAM latency, which is particularly beneficial for memory-bound workloads.
|
||||
|
||||
TMA Prefetch consists of two phases:
|
||||
1. Initial Phase: Before the TMA load loop starts, prefetch the first `prefetch_dist`
|
||||
k-tiles into L2 cache. This primes the cache before any TMA copies begin.
|
||||
2. Rolling Phase: During each iteration of the TMA load loop, after issuing TMA copy
|
||||
for the current k-tile, prefetch the k-tile that is `prefetch_dist` ahead.
|
||||
|
||||
Key differences from fp16_gemm_3_1.py:
|
||||
1. Added cute.prefetch() calls to bring data into L2 cache before TMA copy
|
||||
2. Initial prefetch loop before the main TMA load loop
|
||||
3. Rolling prefetch inside the TMA load loop to keep L2 primed
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_5.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
* The problem size of m and n must be divisible by the tile size m & n (256, 256)
|
||||
"""
|
||||
|
||||
io_dtype = cutlass.Float16
|
||||
acc_dtype = cutlass.Float32
|
||||
use_2cta_instrs = True
|
||||
cluster_shape_mnk = (2, 2, 1) if use_2cta_instrs else (1, 1, 1)
|
||||
mma_inst_shape_mnk = (256, 64, 16)
|
||||
mma_tiler_mnk = (256, 64, 64)
|
||||
threads_in_epilogue = 128 # epilogue threads per cta
|
||||
|
||||
# Pipeline stage configuration
|
||||
ab_stages = 10
|
||||
epi_stages = 2
|
||||
acc_stages = 2
|
||||
num_clc_stage = 1
|
||||
|
||||
# Scheduler
|
||||
use_clc_dynamic_scheduler = True
|
||||
scheduler_type = (
|
||||
utils.ClcDynamicPersistentTileScheduler
|
||||
if use_clc_dynamic_scheduler
|
||||
else utils.StaticPersistentTileScheduler
|
||||
)
|
||||
# Response size is 4B * 4 elements
|
||||
num_clc_response_bytes = 16
|
||||
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
|
||||
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stages * 2]
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding_buffer: cutlass.Int32
|
||||
# Only for CLC Dynamic Scheduler
|
||||
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
|
||||
|
||||
|
||||
@cute.kernel()
|
||||
def kernel(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a: cute.CopyAtom,
|
||||
mA_mkl: cute.Tensor,
|
||||
tma_atom_b: cute.CopyAtom,
|
||||
mB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
tile_sched_params: Union[
|
||||
utils.ClcDynamicPersistentTileSchedulerParams,
|
||||
utils.PersistentTileSchedulerParams,
|
||||
],
|
||||
):
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
|
||||
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
|
||||
|
||||
mma_tile_coord_v = bidx % cute.size(cta_layout_vmnk, mode=[0])
|
||||
is_leader_cta = mma_tile_coord_v == 0
|
||||
|
||||
epilogue_warp_ids = (
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
)
|
||||
mma_warp_id = 4
|
||||
tma_warp_id = 5
|
||||
# sched_warp_id only for dynamic scheduler
|
||||
sched_warp_id = 6
|
||||
|
||||
epilog_sync_bar_id = 1
|
||||
tmem_alloc_sync_bar_id = 2
|
||||
|
||||
# Prefetch tma descriptor
|
||||
if warp_idx == tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
|
||||
# As many participants as the number of threads issuing the MMA in the same row and column
|
||||
# Substract one to not count twice the same thread
|
||||
num_mcast_participants = (
|
||||
cute.size(cta_layout_vmnk, mode=[1]) + cute.size(cta_layout_vmnk, mode=[2]) - 1
|
||||
)
|
||||
|
||||
# Mcast mask initialization
|
||||
tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
|
||||
)
|
||||
tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
|
||||
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
|
||||
)
|
||||
|
||||
# Allocate SMEM
|
||||
smem = cutlass.utils.SmemAllocator()
|
||||
storage = smem.allocate(SharedStorage)
|
||||
|
||||
# Barrier 1 for epilogue synchronization
|
||||
epilogue_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=epilog_sync_bar_id,
|
||||
num_threads=threads_in_epilogue,
|
||||
)
|
||||
|
||||
# Only MMA warp and epilogue warps participate in TMEM allocation synchronization
|
||||
# TMA warp does NOT participate
|
||||
tmem_alloc_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=tmem_alloc_sync_bar_id,
|
||||
num_threads=32
|
||||
* len((mma_warp_id, *epilogue_warp_ids)), # 5 warps = 160 threads
|
||||
)
|
||||
tmem = utils.TmemAllocator(
|
||||
storage.tmem_holding_buffer,
|
||||
barrier_for_retrieve=tmem_alloc_barrier,
|
||||
allocator_warp_id=epilogue_warp_ids[0],
|
||||
is_two_cta=True,
|
||||
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar,
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
|
||||
) * cute.size(cta_layout_vmnk, mode=[0])
|
||||
|
||||
# Threads/warps participating in the mainloop pipeline
|
||||
mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, size=num_mcast_participants
|
||||
)
|
||||
|
||||
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.ab_mbar_ptr.data_ptr(),
|
||||
num_stages=ab_stages,
|
||||
producer_group=mainloop_pipeline_producer_group,
|
||||
consumer_group=mainloop_pipeline_consumer_group,
|
||||
tx_count=num_tma_copy_bytes,
|
||||
cta_layout_vmnk=cta_layout_vmnk,
|
||||
).make_participants()
|
||||
|
||||
# Threads/warps participating in the accumulator pipeline
|
||||
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=cute.size(cta_layout_vmnk, mode=[0]) * len(epilogue_warp_ids),
|
||||
)
|
||||
|
||||
acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
|
||||
num_stages=acc_stages,
|
||||
producer_group=acc_pipeline_producer_group,
|
||||
consumer_group=acc_pipeline_consumer_group,
|
||||
cta_layout_vmnk=cta_layout_vmnk,
|
||||
).make_participants()
|
||||
|
||||
# Initialize clc_pipeline (barrier) and states
|
||||
# ONLY for CLC Dynamic Scheduler
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
|
||||
cluster_size = cute.size(cluster_shape_mnk)
|
||||
num_clc_consumer_threads = 32 * len(
|
||||
(
|
||||
sched_warp_id,
|
||||
*(
|
||||
cluster_size
|
||||
* (
|
||||
mma_warp_id,
|
||||
tma_warp_id,
|
||||
*epilogue_warp_ids,
|
||||
)
|
||||
),
|
||||
)
|
||||
)
|
||||
clc_pipeline_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, num_clc_consumer_threads
|
||||
)
|
||||
clc_pipeline = pipeline.PipelineClcFetchAsync.create(
|
||||
barrier_storage=storage.clc_mbar_ptr.data_ptr(),
|
||||
num_stages=num_clc_stage,
|
||||
producer_group=clc_pipeline_producer_group,
|
||||
consumer_group=clc_pipeline_consumer_group,
|
||||
tx_count=num_clc_response_bytes,
|
||||
cta_layout_vmnk=cta_layout_vmnk,
|
||||
defer_sync=True,
|
||||
)
|
||||
# Initial clc response pointer
|
||||
clc_response_ptr = storage.clc_response.data_ptr()
|
||||
|
||||
clc_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, num_clc_stage
|
||||
)
|
||||
else:
|
||||
clc_pipeline = None
|
||||
clc_response_ptr = None
|
||||
clc_consumer_state = None
|
||||
|
||||
# Cluster arrive after barrier init
|
||||
pipeline_init_arrive(cluster_shape_mn=cluster_shape_mnk, is_relaxed=True)
|
||||
|
||||
# Allocate SMEM
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=a_smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=b_smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout.inner,
|
||||
)
|
||||
sC = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=epi_smem_layout_staged.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=epi_smem_layout_staged.inner,
|
||||
)
|
||||
|
||||
# Partition tensors for MMA and make fragments
|
||||
# (bM, bK, RestM, RestK)
|
||||
gA = cute.local_tile(
|
||||
mA_mkl, cute.slice_(mma_tiler_mnk, (None, 0, None)), (None, None)
|
||||
)
|
||||
# (bN, bK, RestN, RestK)
|
||||
gB = cute.local_tile(
|
||||
mB_nkl, cute.slice_(mma_tiler_mnk, (0, None, None)), (None, None)
|
||||
)
|
||||
# (bM, bN, RestM, RestN)
|
||||
gC = cute.local_tile(
|
||||
mC_mnl, cute.slice_(mma_tiler_mnk, (None, None, 0)), (None, None)
|
||||
)
|
||||
|
||||
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
|
||||
# (MMA, MMA_M, MMA_K, RestM, RestK)
|
||||
tCgA = thr_mma.partition_A(gA)
|
||||
# (MMA, MMA_N, MMA_K, RestN, RestK)
|
||||
tCgB = thr_mma.partition_B(gB)
|
||||
# (MMA, MMA_M, MMA_N, RestM, RestN)
|
||||
tCgC = thr_mma.partition_C(gC)
|
||||
|
||||
# (MMA, MMA_M, MMA_K, STAGE)
|
||||
tCrA = tiled_mma.make_fragment_A(sA)
|
||||
# (MMA, MMA_N, MMA_K, STAGE)
|
||||
tCrB = tiled_mma.make_fragment_B(sB)
|
||||
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, acc_stages))
|
||||
|
||||
# Partition tensors for TMA; This requires the tensors partitioned for MMA
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestM, RestK)
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
cta_in_cluster_coord_vmnk[2],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
|
||||
cute.group_modes(sA, 0, 3),
|
||||
cute.group_modes(tCgA, 0, 3),
|
||||
)
|
||||
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestN, RestK)
|
||||
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
cta_in_cluster_coord_vmnk[1],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
|
||||
cute.group_modes(sB, 0, 3),
|
||||
cute.group_modes(tCgB, 0, 3),
|
||||
)
|
||||
|
||||
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
|
||||
|
||||
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(gC_epi, 0, 2),
|
||||
)
|
||||
|
||||
# Cluster wait before starting work
|
||||
pipeline_init_wait(cluster_shape_mn=cluster_shape_mnk)
|
||||
|
||||
# Construct the scheduler
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
tile_sched = scheduler_type.create(
|
||||
tile_sched_params,
|
||||
cute.arch.block_idx(),
|
||||
cute.arch.grid_dim(),
|
||||
clc_response_ptr,
|
||||
)
|
||||
else:
|
||||
tile_sched = scheduler_type.create(
|
||||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
#
|
||||
# Main loop
|
||||
#
|
||||
|
||||
num_k_tiles = cute.size(gA, mode=[3])
|
||||
|
||||
# Prefetch distance: how many k-tiles ahead to prefetch into L2 cache
|
||||
# This helps hide DRAM latency by bringing data to L2 before TMA copy needs it
|
||||
prefetch_dist = ab_stages
|
||||
|
||||
# TMA warp with prefetch
|
||||
if warp_idx == tma_warp_id:
|
||||
#
|
||||
# Persistent tile scheduling loop
|
||||
#
|
||||
while work_tile.is_valid_tile:
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
mma_tile_coord_mnl = (
|
||||
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
cur_tile_coord[1],
|
||||
cur_tile_coord[2],
|
||||
)
|
||||
|
||||
# Slice to per mma tile index
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None)]
|
||||
|
||||
# ((atom_v, rest_v), RestK)
|
||||
tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None)]
|
||||
|
||||
# =========================================================
|
||||
# TMA Prefetch - Initial Phase
|
||||
# =========================================================
|
||||
# Prefetch the first `prefetch_dist` k-tiles into L2 cache
|
||||
# This primes the cache before TMA copies start
|
||||
for pf_k_tile in cutlass.range(
|
||||
cutlass.min(prefetch_dist, num_k_tiles), unroll=1
|
||||
):
|
||||
cute.prefetch(tma_atom_a, tAgA_slice[(None, pf_k_tile)])
|
||||
cute.prefetch(tma_atom_b, tBgB_slice[(None, pf_k_tile)])
|
||||
|
||||
# =========================================================
|
||||
# TMA Load Loop with Rolling Prefetch
|
||||
# =========================================================
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for A/B buffers to be empty before loading into them
|
||||
handle = ab_producer.acquire_and_advance()
|
||||
|
||||
# Issue TMA loads (use k_tile_idx like fp16_gemm_3_1.py)
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tAgA_slice[(None, k_tile_idx)],
|
||||
tAsA[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_a,
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tBgB_slice[(None, k_tile_idx)],
|
||||
tBsB[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_b,
|
||||
)
|
||||
|
||||
# Rolling prefetch: prefetch future k-tiles into L2 cache
|
||||
# This keeps the L2 primed as we progress through the K dimension
|
||||
if k_tile_idx + prefetch_dist < num_k_tiles:
|
||||
future_k_tile = k_tile_idx + prefetch_dist
|
||||
cute.prefetch(tma_atom_a, tAgA_slice[(None, future_k_tile)])
|
||||
cute.prefetch(tma_atom_b, tBgB_slice[(None, future_k_tile)])
|
||||
|
||||
# Advance to next tile
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
clc_pipeline.consumer_wait(clc_consumer_state)
|
||||
work_tile = tile_sched.get_current_work()
|
||||
clc_pipeline.consumer_release(clc_consumer_state)
|
||||
clc_consumer_state.advance()
|
||||
else:
|
||||
tile_sched.advance_to_next_work()
|
||||
work_tile = tile_sched.get_current_work()
|
||||
|
||||
# This mbarrier_wait is preventing threadblocks within a set of dependent threadblocks within the cluster
|
||||
# (dependent in the context of the TMA/MMA synchronization pattern) to exit early making
|
||||
# a late tcgen05 commit_arrive illegal
|
||||
ab_producer.tail()
|
||||
|
||||
# Sched warp (only for dynamic scheduler)
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
is_first_cta_in_cluster = cta_rank_in_cluster == 0
|
||||
|
||||
if warp_idx == sched_warp_id and is_first_cta_in_cluster:
|
||||
# Persistent tile scheduling loop
|
||||
clc_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.ProducerConsumer, num_clc_stage
|
||||
)
|
||||
while work_tile.is_valid_tile:
|
||||
# Advance to next tile
|
||||
clc_pipeline.producer_acquire(clc_producer_state)
|
||||
mbarrier_addr = clc_pipeline.producer_get_barrier(clc_producer_state)
|
||||
tile_sched.advance_to_next_work(mbarrier_addr)
|
||||
clc_producer_state.advance()
|
||||
|
||||
clc_pipeline.consumer_wait(clc_consumer_state)
|
||||
work_tile = tile_sched.get_current_work()
|
||||
clc_pipeline.consumer_release(clc_consumer_state)
|
||||
clc_consumer_state.advance()
|
||||
clc_pipeline.producer_tail(clc_producer_state)
|
||||
|
||||
# MMA warp
|
||||
if warp_idx == mma_warp_id:
|
||||
# Wait for TMEM allocation and retrieve pointer
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(acc_dtype)
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
if is_leader_cta:
|
||||
# Wait for accumulator buffer empty
|
||||
acc_empty = acc_producer.acquire_and_advance()
|
||||
|
||||
# Set tensor memory buffer for current tile
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
|
||||
# Signal that the accumulator is fully computed
|
||||
acc_empty.commit()
|
||||
|
||||
# Advance to next tile
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
clc_pipeline.consumer_wait(clc_consumer_state)
|
||||
work_tile = tile_sched.get_current_work()
|
||||
clc_pipeline.consumer_release(clc_consumer_state)
|
||||
clc_consumer_state.advance()
|
||||
else:
|
||||
tile_sched.advance_to_next_work()
|
||||
work_tile = tile_sched.get_current_work()
|
||||
|
||||
# Wait for accumulator buffer empty
|
||||
acc_producer.tail()
|
||||
|
||||
# Epilogue warps
|
||||
if warp_idx < mma_warp_id:
|
||||
# Allocate TMEM (only epilogue warp 0 actually allocates)
|
||||
num_tmem_cols = 512
|
||||
tmem.allocate(num_tmem_cols)
|
||||
|
||||
# Wait for TMEM allocation and retrieve pointer
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(acc_dtype)
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Initialize TMA store pipeline for epilogue
|
||||
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=128,
|
||||
)
|
||||
epilogue_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=epi_stages,
|
||||
producer_group=epilogue_pipeline_producer_group,
|
||||
)
|
||||
|
||||
copy_atom_t2r = cute.make_copy_atom(
|
||||
tcgen05.Ld32x32bOp(tcgen05.Repetition.x32, tcgen05.Pack.NONE),
|
||||
cutlass.Float32,
|
||||
)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
mma_tile_coord_mnl = (
|
||||
cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
cur_tile_coord[1],
|
||||
cur_tile_coord[2],
|
||||
)
|
||||
|
||||
# Wait for accumulator buffer full
|
||||
acc_full = acc_consumer.wait_and_advance()
|
||||
|
||||
# Set tensor memory buffer for current tile
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_full.index)]
|
||||
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N)
|
||||
tCtAcc_epi = cute.flat_divide(
|
||||
tCtAcc[((None, None), 0, 0)], # why 0,0 ?
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
mma_tile_coord_mn = cute.slice_(mma_tile_coord_mnl, (None, None, 0))
|
||||
# (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN)
|
||||
tCgC_epi = cute.flat_divide(
|
||||
tCgC[((None, None), 0, 0, *mma_tile_coord_mn)], epi_tile
|
||||
)
|
||||
|
||||
tCgC_tma_cur_tile = tCgC_tma[(None, None, None, *mma_tile_coord_mn)]
|
||||
|
||||
# Tiled copy for TMEM -> RMEM load
|
||||
tiled_copy_t2r = tcgen05.make_tmem_copy(
|
||||
copy_atom_t2r, tCtAcc_epi[(None, None, 0, 0)]
|
||||
)
|
||||
thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
||||
tTR_tAcc = thr_copy_t2r.partition_S(tCtAcc_epi)
|
||||
# (T2R, T2R_M, T2R_N, EPI_M, EPI_N)
|
||||
tTR_gC = thr_copy_t2r.partition_D(tCgC_epi)
|
||||
# (T2R, T2R_M, T2R_N)
|
||||
tTR_rAcc = cute.make_rmem_tensor(
|
||||
tTR_gC[(None, None, None, 0, 0)].shape, cutlass.Float32
|
||||
)
|
||||
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
|
||||
# Copy atom and tiled copy for RMEM -> SMEM load
|
||||
copy_atom_r2s = cutlass.utils.blackwell_helpers.get_smem_store_op(
|
||||
c_smem_layout_kind, cutlass.Float32, cutlass.Float32, tiled_copy_t2r
|
||||
)
|
||||
tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r)
|
||||
# (R2S, R2S_M, R2S_N, PIPE_D)
|
||||
thr_copy_r2s = tiled_copy_r2s.get_slice(tidx)
|
||||
tRS_sC = thr_copy_r2s.partition_D(sC)
|
||||
|
||||
tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc)
|
||||
tRS_rC = cute.make_rmem_tensor(tRS_rAcc.shape, io_dtype)
|
||||
tCgC_grouped = cute.group_modes(
|
||||
tCgC_tma_cur_tile, 1, cute.rank(tCgC_tma_cur_tile)
|
||||
)
|
||||
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
|
||||
# Epilogue tiling loop
|
||||
for subtile_idx in cutlass.range(subtile_cnt):
|
||||
# TMEM -> RMEM
|
||||
tTR_tAcc_slice = tTR_tAcc[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_slice, tTR_rAcc)
|
||||
|
||||
# RMEM -> SMEM
|
||||
c_buffer = subtile_idx % epi_stages
|
||||
tRS_sC_slice = tRS_sC[(None, None, None, c_buffer)]
|
||||
|
||||
# type conversion
|
||||
tRS_rC.store(tRS_rAcc.load().to(io_dtype))
|
||||
|
||||
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC_slice)
|
||||
|
||||
# Memory fence and barrier to ensure shared memory stores are visible to TMA stores
|
||||
cute.arch.fence_view_async_shared()
|
||||
epilogue_sync_barrier.arrive_and_wait()
|
||||
# SMEM -> GMEM
|
||||
if warp_idx == epilogue_warp_ids[0]:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
tCsC[(None, c_buffer)],
|
||||
tCgC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
|
||||
epilogue_pipeline.producer_commit()
|
||||
epilogue_pipeline.producer_acquire()
|
||||
epilogue_sync_barrier.arrive_and_wait()
|
||||
|
||||
# Async arrive accumulator buffer empty
|
||||
with cute.arch.elect_one():
|
||||
acc_full.release()
|
||||
|
||||
# Advance to next tile
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
clc_pipeline.consumer_wait(clc_consumer_state)
|
||||
work_tile = tile_sched.get_current_work()
|
||||
clc_pipeline.consumer_release(clc_consumer_state)
|
||||
clc_consumer_state.advance()
|
||||
else:
|
||||
tile_sched.advance_to_next_work()
|
||||
work_tile = tile_sched.get_current_work()
|
||||
|
||||
# Wait for C store complete
|
||||
epilogue_pipeline.producer_tail()
|
||||
|
||||
# Dealloc the tensor memory buffer
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def compute_grid(
|
||||
c: cute.Tensor,
|
||||
mma_tiler_mnk: Tuple[int, int, int],
|
||||
cluster_shape_mnk: Tuple[int, int, int],
|
||||
scheduler_type: Union[
|
||||
utils.StaticPersistentTileScheduler, utils.ClcDynamicPersistentTileScheduler
|
||||
],
|
||||
max_active_clusters: cutlass.Constexpr,
|
||||
) -> Tuple[
|
||||
Union[
|
||||
utils.ClcDynamicPersistentTileSchedulerParams,
|
||||
utils.PersistentTileSchedulerParams,
|
||||
],
|
||||
Tuple[int, int, int],
|
||||
]:
|
||||
c_shape = cute.slice_(mma_tiler_mnk, (None, None, 0))
|
||||
gc = cute.zipped_divide(c, tiler=c_shape)
|
||||
num_ctas_mn = gc[(0, (None, None))].shape
|
||||
|
||||
if cutlass.const_expr(
|
||||
issubclass(scheduler_type, utils.ClcDynamicPersistentTileScheduler)
|
||||
):
|
||||
tile_sched_params = utils.ClcDynamicPersistentTileSchedulerParams(
|
||||
(*num_ctas_mn, 1), cluster_shape_mnk
|
||||
)
|
||||
grid = utils.ClcDynamicPersistentTileScheduler.get_grid_shape(tile_sched_params)
|
||||
else:
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||||
(*num_ctas_mn, 1), cluster_shape_mnk
|
||||
)
|
||||
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
|
||||
tile_sched_params, max_active_clusters
|
||||
)
|
||||
return tile_sched_params, grid
|
||||
|
||||
|
||||
@cute.jit
|
||||
def host_function(
|
||||
a: cute.Tensor,
|
||||
b: cute.Tensor,
|
||||
c: cute.Tensor,
|
||||
max_active_clusters: cutlass.Constexpr,
|
||||
):
|
||||
#
|
||||
# Construct tiled MMA
|
||||
#
|
||||
|
||||
op = tcgen05.MmaF16BF16Op(
|
||||
io_dtype,
|
||||
acc_dtype,
|
||||
mma_inst_shape_mnk,
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
|
||||
tcgen05.OperandSource.SMEM,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
#
|
||||
# Construct SMEM layouts for A and B
|
||||
#
|
||||
|
||||
a_smem_layout = sm100_utils.make_smem_layout_a(
|
||||
tiled_mma,
|
||||
mma_tiler_mnk,
|
||||
a.element_type,
|
||||
ab_stages,
|
||||
)
|
||||
b_smem_layout = sm100_utils.make_smem_layout_b(
|
||||
tiled_mma,
|
||||
mma_tiler_mnk,
|
||||
b.element_type,
|
||||
ab_stages,
|
||||
)
|
||||
|
||||
# c_smem_layout_kind is an enum for row/column major, not a CuTe layout
|
||||
c_smem_layout_kind = utils.LayoutEnum.from_tensor(c)
|
||||
|
||||
#
|
||||
# Construct the VMNK layout
|
||||
#
|
||||
|
||||
cta_layout_mnk = cute.make_layout(cluster_shape_mnk)
|
||||
cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))
|
||||
|
||||
#
|
||||
# Construct TMA load atoms
|
||||
#
|
||||
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
)
|
||||
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
|
||||
tma_atom_a, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_slice,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
)
|
||||
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
|
||||
tma_atom_b, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
b,
|
||||
b_smem_layout_slice,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
)
|
||||
|
||||
cta_tile_shape_mnk = (
|
||||
mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id),
|
||||
mma_tiler_mnk[1],
|
||||
mma_tiler_mnk[2],
|
||||
)
|
||||
|
||||
epi_tile = utils.compute_epilogue_tile_shape(
|
||||
cta_tile_shape_mnk,
|
||||
use_2cta_instrs,
|
||||
c_smem_layout_kind,
|
||||
io_dtype,
|
||||
)
|
||||
|
||||
epi_smem_layout_staged = cutlass.utils.blackwell_helpers.make_smem_layout_epi(
|
||||
io_dtype,
|
||||
c_smem_layout_kind,
|
||||
epi_tile,
|
||||
epi_stages,
|
||||
)
|
||||
|
||||
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
#
|
||||
# Launch the kernel
|
||||
#
|
||||
|
||||
tile_sched_params, grid_shape = compute_grid(
|
||||
c,
|
||||
cta_tile_shape_mnk,
|
||||
cluster_shape_mnk,
|
||||
scheduler_type,
|
||||
max_active_clusters,
|
||||
)
|
||||
|
||||
kernel(
|
||||
tiled_mma,
|
||||
tma_atom_a,
|
||||
a_tma_tensor,
|
||||
tma_atom_b,
|
||||
b_tma_tensor,
|
||||
tma_atom_c,
|
||||
c_tma_tensor,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
cta_layout_vmnk,
|
||||
tile_sched_params,
|
||||
).launch(
|
||||
grid=grid_shape,
|
||||
block=[224, 1, 1] if use_clc_dynamic_scheduler else [192, 1, 1],
|
||||
cluster=cluster_shape_mnk,
|
||||
)
|
||||
|
||||
|
||||
def run_dense_gemm(
|
||||
mnk: Tuple[int, int, int],
|
||||
tolerance: float,
|
||||
):
|
||||
global torch, cutlass_torch
|
||||
import torch
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
print("===================================================================")
|
||||
print("Running Blackwell fp16 GEMM example 5 (with TMA prefetch):")
|
||||
print(f" mnk: {mnk}")
|
||||
print(f" tolerance: {tolerance}")
|
||||
print("===================================================================")
|
||||
print()
|
||||
|
||||
m, n, k = mnk
|
||||
torch.manual_seed(1111)
|
||||
|
||||
# Make K-major tensors (torch tensors are row-major)
|
||||
def make_tensors(mn, k, dtype):
|
||||
shape = (mn, k)
|
||||
return (
|
||||
torch.empty(*shape, dtype=torch.int32)
|
||||
.random_(-2, 2)
|
||||
.to(device="cuda", dtype=dtype)
|
||||
)
|
||||
|
||||
a = make_tensors(m, k, cutlass_torch.dtype(io_dtype))
|
||||
b = make_tensors(n, k, cutlass_torch.dtype(io_dtype))
|
||||
c = make_tensors(m, n, cutlass_torch.dtype(io_dtype))
|
||||
a_memref = from_dlpack(a).mark_layout_dynamic()
|
||||
b_memref = from_dlpack(b).mark_layout_dynamic()
|
||||
c_memref = from_dlpack(c).mark_layout_dynamic()
|
||||
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(
|
||||
cluster_shape_mnk[0] * cluster_shape_mnk[1]
|
||||
)
|
||||
|
||||
# Entry point to the host JIT function
|
||||
host_function(
|
||||
a_memref,
|
||||
b_memref,
|
||||
c_memref,
|
||||
max_active_clusters,
|
||||
no_cache=True,
|
||||
)
|
||||
|
||||
# Compute reference result and verify
|
||||
ref = (torch.einsum("mk,nk->mn", a, b)).cpu()
|
||||
torch.testing.assert_close(
|
||||
c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
try:
|
||||
return [int(x.strip()) for x in s.split(",")]
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
from cuda.bindings import driver as cu_driver
|
||||
|
||||
cu_driver.cuInit(0)
|
||||
err, device_count = cu_driver.cuDeviceGetCount()
|
||||
if err != cu_driver.CUresult.CUDA_SUCCESS or device_count < 1:
|
||||
raise RuntimeError("A GPU is required to run this example")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Blackwell fp16 GEMM example 5 (with TMA prefetch)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mnk",
|
||||
type=parse_comma_separated_ints,
|
||||
default=(8192, 8192, 8192),
|
||||
help="MNK dimensions (comma-separated)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if len(args.mnk) != 3:
|
||||
parser.error("--mnk must contain exactly 3 values")
|
||||
|
||||
run_dense_gemm(
|
||||
args.mnk,
|
||||
args.tolerance,
|
||||
)
|
||||
print("PASS")
|
||||
1002
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py
Normal file
1002
examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user