Add tutorial fp16_gemm_1 (#2750)

* Add tutorial fp16_gemm_1

* refine

* refine

* refine

* revert changes in fp16_gemm_0.py
This commit is contained in:
Linfeng Zheng
2025-11-07 11:40:09 +08:00
committed by GitHub
parent d1ef0e87f2
commit 2252254ce2

View File

@@ -0,0 +1,527 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
# This is the second tutorial GEMM. It builds on the first tutorial by adding 2CTA MMA
# instructions with a 2x1 cluster.
import argparse
import torch
from typing import Tuple
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.torch as cutlass_torch
import cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.runtime import from_dlpack
"""
The second tutorial GEMM demonstrating a simple kernel implementation in CuTeDSL
With large tile sizes, it can as well achieve very high performance on 8kx8kx8k problem sizes.
Compared with fp16_gemm_0.py, this example adds 2CTA MMA & TMA multicast supports.
For fp16_gemm_0.py running at relative high SM frequency, the dram latency will be a potential performance issue.
This example can achieve better performance than fp16_gemm_0.py due to:
1. The 2CTA MMA can reduce B tensor smem size which allows larger ab stages to hide dram latency.
For both 1CTA & 2CTA, one stage of A tensor smem size is 128x64xsizeof(float16)=16KB
Situation for B is different.
For 1CTA, one stage of B tensor smem size is 256x64xsizeof(float16)=32KB,
while for 2CTA, we can only take half size, i.e. 16KB.
So, the maxmimum AB stage for 1CTA is 227 // (16 + 32) = 4, while for 2CTA is 227 // (16 + 16) = 7.
The latency hiding capability is 512 * (4 - 1) = 1.5K cycles for 1CTA, while 512 * (7 - 1) = 3K cycles for 2CTA.
2. The L2 traffic is reduced due to the TMA multicast.
For a (m, n) cluster shape, the L2 traffic for one tile is 16KB / n + 32KB / m.
16KB / 1 + 32KB / 2 = 24KB for 2x1 cluster shape
16KB / 4 + 32KB / 4 = 12KB for 4x4 cluster shape
If no TMA multicast enabled, the L2 traffic for one tile is less than 16KB + 32KB = 48KB, which depends on hardware optimization.
The first one can provide large latency hiding capability while the second one can reduce the data ready time.
These two factors should be considered for latency/memory throughput bound cases.
To run this example:
.. code-block:: bash
python examples/blackwell/tutorial_gemm/fp16_gemm_1.py \
--mnk 8192,8192,8192
Constraints for this example:
* The problem size of m and n must be divisible by the tile size m & n (256, 256)
"""
io_dtype = cutlass.Float16
acc_dtype = cutlass.Float32
cluster_shape_mnk = (2, 1, 1)
mma_inst_shape_mnk = (256, 256, 16)
mma_tiler_mnk = (256, 256, 64)
threads_per_cta = 128
# Pipeline stage configuration
ab_stages = 7
acc_stage = 1
@cute.struct
class SharedStorage:
# each stage has 2 kinds of barrier, i.e. empty & full
ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]
acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stage * 2]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
@cute.kernel()
def kernel(
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
mC_mnl: cute.Tensor,
a_smem_layout: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
cta_layout_vmnk: cute.Layout,
):
# Current thread/warp/block coordinates
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
bidx, bidy, _ = cute.arch.block_idx()
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster)
mma_coord_vmnk = (
bidx % cute.size(cta_layout_vmnk, mode=[0]),
bidx // cute.size(cta_layout_vmnk, mode=[0]),
bidy,
None,
)
mma_coord_mnk = mma_coord_vmnk[1:]
#
# 1. Prepare args
#
# Allocate SMEM
smem = cutlass.utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
sA = smem.allocate_tensor(
element_type=io_dtype,
layout=a_smem_layout.outer,
byte_alignment=128,
swizzle=a_smem_layout.inner,
)
sB = smem.allocate_tensor(
element_type=io_dtype,
layout=b_smem_layout.outer,
byte_alignment=128,
swizzle=b_smem_layout.inner,
)
# Prefetch tma descriptor
if warp_idx == 0:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
# Pipeline configuration
num_tma_copy_bytes = (
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
) * cute.size(cta_layout_vmnk, mode=[0])
num_mcast_ctas_a = cute.size(cta_layout_vmnk.shape[2])
num_mcast_ctas_b = cute.size(cta_layout_vmnk.shape[1])
num_tma_producer = num_mcast_ctas_a + num_mcast_ctas_b - 1
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
num_stages=ab_stages,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
),
tx_count=num_tma_copy_bytes,
barrier_storage=storage.ab_mbar_ptr.data_ptr(),
cta_layout_vmnk=cta_layout_vmnk,
).make_participants()
acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(
num_stages=acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread,
cute.size(cta_layout_vmnk, mode=[0]) * threads_per_cta,
),
barrier_storage=storage.acc_mbar_ptr.data_ptr(),
cta_layout_vmnk=cta_layout_vmnk,
).make_participants()
# Partition tensors for MMA and make fragments
# (bM, bK, RestK)
gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))
# (bN, bK, RestK)
gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))
# (bM, bN)
gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))
thr_mma = tiled_mma.get_slice(mma_coord_vmnk[0])
# (MMA, MMA_M, MMA_K)
tCgA = thr_mma.partition_A(gA)
# (MMA, MMA_N, MMA_K)
tCgB = thr_mma.partition_B(gB)
# (MMA, MMA_M, MMA_N)
tCgC = thr_mma.partition_C(gC)
# (MMA, MMA_M, MMA_K)
tCrA = tiled_mma.make_fragment_A(sA)
# (MMA, MMA_N, MMA_K)
tCrB = tiled_mma.make_fragment_B(sB)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])
# (MMA, MMA_M, MMA_N)
tCtAcc = tiled_mma.make_fragment_C(acc_shape)
# Partition tensors for TMA; This requires the tensors partitioned for MMA
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
cta_in_cluster_coord_vmnk[2],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
cute.group_modes(sA, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
cta_in_cluster_coord_vmnk[1],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
cute.group_modes(sB, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2
)
tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask(
cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1
)
# Allocate TMEM and swap the pointer in tCtAcc
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=threads_per_cta,
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf,
barrier_for_retrieve=tmem_alloc_barrier,
is_two_cta=cute.size(cta_layout_vmnk, mode=[0]) > 1,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr,
)
num_tmem_cols = 512
tmem.allocate(num_tmem_cols)
# CTA-wide sync before retrieving the pointer to the start of the allocated TMEM
# Only warp 0 does the allocation so we need to sync before retrieving the TMEM start address
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(acc_dtype)
# Swap the pointer in tCtAcc
tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)
subtile_cnt = 4
# (EpiTile)
epi_tiler = (
(cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt),
)
# (EpiTile, NumTiles)
tCtAcc_epi = cute.zipped_divide(tCtAcc, epi_tiler)
# (EpiTile, NumTiles)
gC_epi = cute.zipped_divide(tCgC, epi_tiler)
# Every thread loads 64 x fp32
tmem_atom = cute.make_copy_atom(
tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),
cutlass.Float32,
)
tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0])
tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)
# (TmemCpy,NumTmemCpy,NumTiles)
tDtC = tmem_thr_copy.partition_S(tCtAcc_epi)
# (TmemCpy,NumTmemCpy,NumTiles)
tDgC = tmem_thr_copy.partition_D(gC_epi)
# (TmemCpy,NumTmemCpy)
tCrAcc = cute.make_rmem_tensor_like(tDgC[None, None, 0], acc_dtype)
# (TmemCpy,NumTmemCpy)
tCrC = cute.make_rmem_tensor_like(tDgC[None, None, 0], io_dtype)
#
# 2. Main loop
#
is_leader_cta = mma_coord_vmnk[0] == 0
num_k_tiles = cute.size(gA, mode=[2])
if warp_idx == 0:
# Wait for a empty accumulator buffer
if is_leader_cta:
acc_producer.acquire_and_advance()
for _ in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2):
# Issue TMA loads
ab_empty = ab_producer.acquire_and_advance()
cute.copy(
tma_atom_a,
tAgA[(None, ab_empty.count)],
tAsA[(None, ab_empty.index)],
tma_bar_ptr=ab_empty.barrier,
mcast_mask=tma_mcast_mask_a,
)
cute.copy(
tma_atom_b,
tBgB[(None, ab_empty.count)],
tBsB[(None, ab_empty.index)],
tma_bar_ptr=ab_empty.barrier,
mcast_mask=tma_mcast_mask_b,
)
# Execute one K-block worth of MMA instructions
if is_leader_cta:
ab_full = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, ab_full.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
ab_full.release()
# Signal that the accumulator is fully computed
if is_leader_cta:
acc_producer.commit()
#
# 3. Epilogue
#
# Release TMEM allocation lock
tmem.relinquish_alloc_permit()
# Wait for the accumulator buffer to be full
acc_full = acc_consumer.wait_and_advance()
# TMEM -> RMEM -> GEMM
# Sub-tiling for better instruction-level parallelism
for i in cutlass.range(cute.size(tDtC, mode=[2])):
cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)
tCrC.store(tCrAcc.load().to(io_dtype))
cute.autovec_copy(tCrC, tDgC[None, None, i])
acc_full.release()
# Ensure used buffers are properly synchronized before producer exit.
# This could avoid the invalid dsmem access due to early leading CTA exit.
if warp_idx == 0:
ab_producer.tail()
if is_leader_cta:
acc_producer.tail()
# Deallocate TMEM
pipeline.sync(barrier_id=1)
tmem.free(tmem_ptr)
@cute.jit
def host_function(
a: cute.Tensor,
b: cute.Tensor,
c: cute.Tensor,
):
# Construct tiled MMA
op = tcgen05.MmaF16BF16Op(
io_dtype,
acc_dtype,
mma_inst_shape_mnk,
tcgen05.CtaGroup.TWO,
tcgen05.OperandSource.SMEM,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(op)
# Construct SMEM layouts for A and B
a_smem_layout = utils.sm100.make_smem_layout_a(
tiled_mma,
mma_tiler_mnk,
a.element_type,
ab_stages,
)
b_smem_layout = utils.sm100.make_smem_layout_b(
tiled_mma,
mma_tiler_mnk,
b.element_type,
ab_stages,
)
a_smem_layout_one_stage = cute.select(a_smem_layout, mode=[0, 1, 2])
b_smem_layout_one_stage = cute.select(b_smem_layout, mode=[0, 1, 2])
# Construct the VMNK layout
cta_layout_mnk = cute.make_layout(cluster_shape_mnk)
cta_layout_vmnk = cute.tiled_divide(cta_layout_mnk, (tiled_mma.thr_id,))
# Construct TMA load atoms
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO)
a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
op,
a,
a_smem_layout_one_stage,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
)
b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
op,
b,
b_smem_layout_one_stage,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
)
grid_shape = cute.round_up(
cute.ceil_div(
(*c.layout.shape, 1), (mma_tiler_mnk[0] // 2, *mma_tiler_mnk[1:])
),
cluster_shape_mnk,
)
# Pretty prints kernel attributes useful for debugging
# print(f"a = {cute.pretty_str(a)}")
# print(f"b = {cute.pretty_str(b)}")
# print(f"c = {cute.pretty_str(c)}")
# print(f"tiled_mma = {cute.pretty_str(tiled_mma)}")
# print(f"a_smem_layout = {cute.pretty_str(a_smem_layout)}")
# print(f"b_smem_layout = {cute.pretty_str(b_smem_layout)}")
# print(f"cta_layout_mnk = {cute.pretty_str(cta_layout_mnk)}")
# print(f"cta_layout_vmnk = {cute.pretty_str(cta_layout_vmnk)}")
# print(f"a_tma_atom = {cute.pretty_str(a_tma_atom)}")
# print(f"b_tma_atom = {cute.pretty_str(b_tma_atom)}")
# print(f"a_tma_tensor = {cute.pretty_str(a_tma_tensor)}")
# print(f"b_tma_tensor = {cute.pretty_str(b_tma_tensor)}")
# cute.printf("grid_shape = {}", grid_shape)
# Launch the kernel
kernel(
tiled_mma,
a_tma_atom,
a_tma_tensor,
b_tma_atom,
b_tma_tensor,
c,
a_smem_layout,
b_smem_layout,
cta_layout_vmnk,
).launch(
grid=grid_shape,
block=[threads_per_cta, 1, 1],
cluster=cluster_shape_mnk,
)
def run_dense_gemm(
mnk: Tuple[int, int, int],
tolerance: float,
):
print("===================================================================")
print("Running Blackwell fp16 GEMM example 1 with:")
print(f" mnk: {mnk}")
print(f" tolerance: {tolerance}")
print("===================================================================")
print()
m, n, k = mnk
torch.manual_seed(1111)
# Make K-major tensors (torch tensors are row-major)
def make_tensors(mn, k, dtype):
shape = (mn, k)
return (
torch.empty(*shape, dtype=torch.int32)
.random_(-2, 2)
.to(device="cuda", dtype=dtype)
)
a = make_tensors(m, k, cutlass_torch.dtype(io_dtype))
b = make_tensors(n, k, cutlass_torch.dtype(io_dtype))
c = make_tensors(m, n, cutlass_torch.dtype(io_dtype))
a_tensor = (
from_dlpack(a, assumed_align=32)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=k)
)
b_tensor = (
from_dlpack(b, assumed_align=32)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=k)
)
c_tensor = (
from_dlpack(c, assumed_align=32)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=n)
)
# Entry point to the host JIT function
host_function(
a_tensor,
b_tensor,
c_tensor,
no_cache=True,
)
# Compute reference result and verify
ref = (torch.einsum("mk,nk->mn", a.to(torch.float32), b.to(torch.float32))).cpu()
torch.testing.assert_close(
c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=tolerance, rtol=1e-05
)
if __name__ == "__main__":
def parse_comma_separated_ints(s: str):
try:
return [int(x.strip()) for x in s.split(",")]
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
if not torch.cuda.is_available():
raise RuntimeError("A GPU is required to run this example")
parser = argparse.ArgumentParser(description="Blackwell fp16 GEMM example 1")
parser.add_argument(
"--mnk",
type=parse_comma_separated_ints,
default=[8192, 8192, 8192],
help="MNK dimensions (comma-separated)",
)
parser.add_argument(
"--tolerance", type=float, default=1e-01, help="Tolerance for validation"
)
args = parser.parse_args()
if len(args.mnk) != 3:
parser.error("--mnk must contain exactly 3 values")
if args.mnk[0] % mma_tiler_mnk[0] != 0 or args.mnk[1] % mma_tiler_mnk[1] != 0:
parser.error("m n must be divisible by mma_tiler_mn")
run_dense_gemm(
args.mnk,
args.tolerance,
)
print("PASS")