cutlass/examples/python/CuTeDSL/notebooks/tour_to_sol_gemm.ipynb
Longsheng Du 08185b9c3e Update blackwell tutorial to be compatible with 4.5-dev version (#3130)
* Update blackwell tutorial to be compatible with 4.5-dev version

* update example for reverted changes

* add more example fix
2026-04-09 14:40:33 +08:00


{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1503e37e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import cutlass\n",
"import cutlass.cute as cute\n",
"import cutlass.utils as utils\n",
"import cutlass.torch as cutlass_torch\n",
"import cutlass.pipeline as pipeline\n",
"from cutlass.cute.nvgpu import cpasync, tcgen05\n",
"import cutlass.utils.blackwell_helpers as sm100_utils\n",
"from cutlass.cute.runtime import from_dlpack"
]
},
{
"cell_type": "markdown",
"id": "ee111a48",
"metadata": {},
"source": [
"# Tour of SOL GEMM\n",
"\n",
"This notebook shows, step by step, how to write a GEMM (GEneral Matrix Multiplication) kernel that approaches SOL (Speed Of Light) performance on Blackwell using tcgen05.\n",
"\n",
"Before going through it, you should be familiar with:\n",
"\n",
"- tensor.ipynb\n",
"- tensorssa.ipynb\n",
"- cute_layout_algebra.ipynb\n",
"- composed_layout.ipynb\n",
"- elementwise_add.ipynb\n",
"- async_pipeline.ipynb\n",
"\n",
"These notebooks cover the basics of writing a kernel with CuTeDSL.\n",
"\n",
"## Learning Objectives\n",
"\n",
"In this tutorial, you will learn how to write an efficient GEMM step by step:\n",
"- How to implement a basic GEMM kernel using CuTeDSL\n",
"- How to subtile the accumulator (acc)\n",
"- How to apply multi-stage software pipelining\n",
"- How to vectorize the instructions that store the output\n",
"\n",
"## Understanding GEMM\n",
"\n",
"GEMM is one of the most important operations in linear algebra and deep learning. Given two 2D tensors A with shape $(M, K)$ and B with shape $(N, K)$, the GEMM operation $C = A \\times B^T$ is defined as:\n",
"\n",
"$$\n",
"    C_{i,j} = \\sum_{k=0}^{K-1} A_{i,k} \\cdot B_{j,k}\n",
"$$\n",
"\n",
"The result is a 2D tensor C with shape $(M, N)$, where:\n",
"\n",
"- $i \\in [0, M)$ represents the row index of $A$ and $C$\n",
"- $j \\in [0, N)$ represents the column index of $C$ and the row index of $B$\n",
"- $k \\in [0, K)$ represents the column index of $A$ and $B$\n",
"- $A_{i,k}$, $B_{j,k}$, and $C_{i,j}$ are the elements at position $(i,k)$, $(j,k)$ and $(i,j)$ in tensors $A$, $B$, and $C$ respectively\n",
"\n",
"This operation has several important characteristics:\n",
"\n",
"1. **Parallelizable**: Each element of $C$ can be computed independently, which helps make full use of the SMs in a GPU.\n",
"2. **Data Reusable**: All elements in row $i$ of $C$ ($C_{i,:}$) need the same data from $A_{i,:}$, while all elements in column $j$ of $C$ ($C_{:,j}$) need the same data from $B_{j,:}$. This reuse pattern helps reduce IO pressure.\n",
"3. **Block-friendly**: A block of elements can be processed together, and each block is a sub-problem of the whole GEMM. This reduces the IO pressure on each SM and makes it possible to accelerate the computation using MMA instructions.\n",
"4. **Bottleneck-flexible**: Unlike elementwise_add, the bottleneck of GEMM varies with the problem size. A rough compute/memory ratio for GEMM is: $ratio = \\frac{M * N * K}{M*K + N*K + M*N} = \\frac{1}{\\frac{1}{N} + \\frac{1}{M} + \\frac{1}{K}}$.\n",
"It depends on all of M, N and K, so reaching good performance requires different strategies for different problem sizes.\n",
"\n",
"\n",
"## Naive GEMM\n",
"\n",
"Let's start with a naive implementation to establish a baseline before exploring optimizations.\n",
"\n",
"![figure1](./images/blocked_gemm.svg \"figure1: blocked gemm\")\n",
"\n",
"First of all, we need to set basic configurations.\n",
"\n",
"- io_dtype: The datatype for tensors $A$, $B$, and $C$. For the most cases, it's also the input datatype of mma instructions (there're some exceptions, e.g. TF32 datatype, input transformation, etc.).\n",
"\n",
"- acc_dtype: The datatype for the accumulation. Normally, set it as FP32 to avoid overflow. As C's datatype could be different from acc_dtype, the acc data needs to be converted to io_dtype before storing out.\n",
"\n",
"- mma_inst_shape_mnk: The shape of one tcgen05 mma instruction can deal with. See more details in [PTX Document 9.7.16.2.1. Matrix Shape](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape).\n",
"From beginning, we choose the biggest one as it's easy to reach SOL.\n",
"\n",
"- mma_tiler_mnk: The GEMM kernel is normally implemented as blocked GEMM (see figure 1). Mma tiler is the block shape that one CTA or two CTAs will process. Whether one or two is determined by the issue granularity of tcgen05. See more details in [PTX Document 9.7.16.5. Issue Granularity](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-issue-granularity).\n",
"From beginning, we choose one CTA to issue tcgen05 for simplicity.\n",
"\n",
"- threads_per_cta: The number of threads we need to use in one CTA. To take fully use of a SM (streaming multiprocessor), it's at least 128.\n",
"\n",
"- ab_stages: The number demonstrates how many blocks that TMA can load before each block's computation. It's usually limited by the smem capacity. For mma_tiler_mnk (128, 256, 64), we can set it as 4 at most.\n",
"\n",
"- acc_stage: As each CTA only computes one block of acc and stores out, the number is 1.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15fd8f73",
"metadata": {},
"outputs": [],
"source": [
"io_dtype = cutlass.Float16\n",
"acc_dtype = cutlass.Float32\n",
"mma_inst_shape_mnk = (128, 256, 16)\n",
"mma_tiler_mnk = (128, 256, 64)\n",
"threads_per_cta = 128\n",
"\n",
"# Pipeline stage configuration\n",
"ab_stages = 4\n",
"acc_stage = 1"
]
},
{
"cell_type": "markdown",
"id": "91ad74f2",
"metadata": {},
"source": [
"Next, let's define the problem size and initialize the input and output tensors, i.e. $A$, $B$, and $C$.\n",
"\n",
"We start with a typical compute-bound case, 8192 x 8192 x 8192. Each dimension is also large enough to avoid tile quantization issues."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f2bf3d6f",
"metadata": {},
"outputs": [],
"source": [
"m, n, k = 8192, 8192, 8192\n",
"\n",
"# Make K-major tensors (torch tensors are row-major)\n",
"def make_tensors(mn, k, dtype):\n",
"    shape = (mn, k)\n",
"    return (\n",
"        torch.empty(*shape, dtype=torch.int32)\n",
"        .random_(-2, 2)\n",
"        .to(dtype=dtype, device=\"cuda\")\n",
"    )\n",
"\n",
"a = make_tensors(m, k, cutlass_torch.dtype(io_dtype))\n",
"b = make_tensors(n, k, cutlass_torch.dtype(io_dtype))\n",
"c = make_tensors(m, n, cutlass_torch.dtype(io_dtype))\n",
"a_tensor = (\n",
"    from_dlpack(a)\n",
"    .mark_layout_dynamic(leading_dim=1)\n",
")\n",
"b_tensor = (\n",
"    from_dlpack(b)\n",
"    .mark_layout_dynamic(leading_dim=1)\n",
")\n",
"c_tensor = (\n",
"    from_dlpack(c)\n",
"    .mark_layout_dynamic(leading_dim=1)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "81bd00ba",
"metadata": {},
"source": [
"Before writing the kernel, we need to configure the basic components of a GEMM operation.\n",
"\n",
"1. Tiled MMA. The tiled MMA computes the GEMM for one MMA tile. We configure it as a tcgen05 MMA.\n",
"\n",
"2. SMEM layouts for A and B. They must match the post-partitioned (CTA-local) shapes expected by the MMA instructions.\n",
"sm100_utils provides functions to determine the post-partitioned shape.\n",
"These functions take the tiled MMA and the MMA tiler as inputs and return a shape that is at least rank-3,\n",
"where the first mode has the same shape as the MMA instruction, and the 2nd and 3rd modes express the number of times\n",
"the MMA instruction is repeated in the M/N mode and the K mode of the MMA tile, respectively.\n",
"These SMEM layouts are swizzled to improve MMA performance.\n",
"\n",
"3. TMA descriptors for A and B. We already know the layouts of the A and B tensors in both GMEM (global memory) and SMEM (shared memory). We use them to generate the TMA descriptors and TMA tensors, which help load a block of A and B from GMEM to SMEM.\n",
"\n",
"## Host Code\n",
"\n",
"The host code constructs the components introduced above. It also calculates the grid shape and launches the kernel with these components as parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5bbbb368",
"metadata": {},
"outputs": [],
"source": [
"@cute.jit\n",
"def host_function(\n",
"    a: cute.Tensor,\n",
"    b: cute.Tensor,\n",
"    c: cute.Tensor,\n",
"    kernel: cutlass.Constexpr,\n",
"):\n",
"    # Construct tiled MMA\n",
"    op = tcgen05.MmaF16BF16Op(\n",
"        io_dtype,\n",
"        acc_dtype,\n",
"        mma_inst_shape_mnk,\n",
"        tcgen05.CtaGroup.ONE,\n",
"        tcgen05.OperandSource.SMEM,\n",
"        tcgen05.OperandMajorMode.K,\n",
"        tcgen05.OperandMajorMode.K,\n",
"    )\n",
"    tiled_mma = cute.make_tiled_mma(op)\n",
"\n",
"    # Construct SMEM layouts for A and B\n",
"    a_smem_layout = sm100_utils.make_smem_layout_a(\n",
"        tiled_mma,\n",
"        mma_tiler_mnk,\n",
"        a.element_type,\n",
"        ab_stages,\n",
"    )\n",
"    b_smem_layout = sm100_utils.make_smem_layout_b(\n",
"        tiled_mma,\n",
"        mma_tiler_mnk,\n",
"        b.element_type,\n",
"        ab_stages,\n",
"    )\n",
"    a_smem_layout_one_stage = cute.select(a_smem_layout, mode=[0, 1, 2])\n",
"    b_smem_layout_one_stage = cute.select(b_smem_layout, mode=[0, 1, 2])\n",
"\n",
"    # Construct TMA load atoms\n",
"    op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE)\n",
"    a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(\n",
"        op,\n",
"        a,\n",
"        a_smem_layout_one_stage,\n",
"        mma_tiler_mnk,\n",
"        tiled_mma,\n",
"    )\n",
"    b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(\n",
"        op,\n",
"        b,\n",
"        b_smem_layout_one_stage,\n",
"        mma_tiler_mnk,\n",
"        tiled_mma,\n",
"    )\n",
"\n",
"    # Launch the kernel\n",
"    grid_shape = cute.ceil_div((*c.layout.shape, 1), mma_tiler_mnk[:2])\n",
"    kernel(\n",
"        tiled_mma,\n",
"        a_tma_atom,\n",
"        a_tma_tensor,\n",
"        b_tma_atom,\n",
"        b_tma_tensor,\n",
"        c,\n",
"        a_smem_layout,\n",
"        b_smem_layout,\n",
"    ).launch(\n",
"        grid=grid_shape,\n",
"        block=(threads_per_cta, 1, 1),\n",
"    )"
]
},
{
"cell_type": "markdown",
"id": "18f4322d",
"metadata": {},
"source": [
"## Structure of the Kernel\n",
"\n",
"Let's break down how a GEMM kernel is organized:\n",
"\n",
"1. **Prologue**: The phase before the first MMA instruction. It usually defines, fetches, allocates, partitions or calculates the necessary components (listed below). In addition, it can load multiple stages of data ahead of the first MMA to help hide GMEM latency.\n",
"    - Indexing\n",
"        * `block_idx` (bidx, bidy): Block index in the grid\n",
"        * `mma_coord_mnk`: The coordinate of the block that the current MMA unit computes (see details in figure 1)\n",
"        * `thread_idx` (tidx): Thread index within a block (0 to threads_per_cta - 1). We need it to slice each thread's partition of tensor memory (see details in [PTX Document 9.7.16.2.3.1 Memory Layout](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-memory-layout))\n",
"        * `warp_idx`: Since TMA and tcgen05.mma need only one thread to issue them, some code only needs to be executed by warp 0\n",
"    - Allocation\n",
"        * `smem` (storage, sA, sB): Allocate the SMEM needed for the pipelines and for the A/B SMEM tensors that are inputs of tcgen05.mma\n",
"        * `tmem`: Allocate the TMEM needed for Acc\n",
"    - Pipeline (see more details in async_pipeline.ipynb)\n",
"        * `PipelineTmaUmma`: The TMA and tcgen05.mma units are asynchronous. PipelineTmaUmma notifies: 1. tcgen05.mma when TMA has filled an A/B buffer; 2. TMA when tcgen05.mma has consumed an A/B buffer and it is empty again\n",
"        * `PipelineUmmaAsync`: Notifies the threads when tcgen05.mma has finished the accumulation and Acc is ready\n",
"        * `Barrier initialization`: Barrier initialization is done inside the pipeline create functions\n",
"    - Partition\n",
"        * `local_tile`: Get the block of the A/B/C GMEM tensors for the current MMA unit according to `mma_coord_mnk`.\n",
"        * `TMA`: Get the tensor view for each TMA instruction\n",
"        * `MMA`: Get the tensor view for each tcgen05.mma instruction\n",
"    - TMA descriptor prefetch\n",
"        * `cpasync.prefetch_descriptor`: Helps shorten the latency of accessing the TMA descriptors, i.e. tma_atom_a and tma_atom_b\n",
"\n",
"2. **Mainloop**: The phase that carries out the main computation of GEMM. It's usually organized as a loop that iterates over blocks along the K dimension for accumulation. The loop body contains:\n",
"    - `Data prefetch` with a fixed stride (ab_stages - 1) ahead of the current K block\n",
"    - `MMA computation` for the current K block\n",
"\n",
"3. **Epilogue**: The phase after the MMA instructions finish the accumulation. It usually contains:\n",
"    - `Partition`: Get the tensor views for the epi tiler (acc subtile) and for each tcgen05.ld instruction\n",
"    - `Acc fetch`: Load data from tensor memory to registers\n",
"    - `Fusion & datatype conversion`: Optionally fuse some operations on C; convert the datatype if the output type differs from the acc type\n",
"    - `Relinquish tmem alloc permit`: Give up the TMEM allocation permit for the benefit of subsequently launched kernels\n",
"    - `Storing`: Use TMA or st.global to store the result\n",
"    - `TMEM deallocation`: Deallocate the TMEM used by the Acc buffer\n",
"\n",
"    Usually, we subtile the acc buffer to save registers and SMEM (if TMA is used to store C). For our mma_tiler (128, 256), each thread would need 256 registers without subtiling. Subtiling also gives better instruction-level parallelism by interleaving the issue of tcgen05.ld, data conversion and st.global.\n",
"\n",
"``` python\n",
" for i in cutlass.range(cute.size(tDtC, mode=[2])):\n",
" cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)\n",
" tCrC.store(tCrAcc.load().to(io_dtype))\n",
" cute.autovec_copy(tCrC, tDgC[None, None, i])\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "358609c4",
"metadata": {},
"outputs": [],
"source": [
"@cute.struct\n",
"class SharedStorage:\n",
"    ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2]\n",
"    acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stage * 2]\n",
"    tmem_holding_buf: cutlass.Int32\n",
"\n",
"\n",
"@cute.kernel\n",
"def kernel(\n",
"    tiled_mma: cute.TiledMma,\n",
"    tma_atom_a: cute.CopyAtom,\n",
"    mA_mkl: cute.Tensor,\n",
"    tma_atom_b: cute.CopyAtom,\n",
"    mB_nkl: cute.Tensor,\n",
"    mC_mnl: cute.Tensor,\n",
"    a_smem_layout: cute.ComposedLayout,\n",
"    b_smem_layout: cute.ComposedLayout,\n",
"):\n",
"    #\n",
"    # 1. Prepare args\n",
"    #\n",
"\n",
"    # Current thread/warp/block coordinates\n",
"    tidx, _, _ = cute.arch.thread_idx()\n",
"    warp_idx = cute.arch.warp_idx()\n",
"    warp_idx = cute.arch.make_warp_uniform(warp_idx)\n",
"    bidx, bidy, _ = cute.arch.block_idx()\n",
"    mma_coord_mnk = (bidx, bidy, None)\n",
"\n",
"    # Allocate SMEM\n",
"    smem = cutlass.utils.SmemAllocator()\n",
"    storage = smem.allocate(SharedStorage)\n",
"    sA = smem.allocate_tensor(\n",
"        element_type=io_dtype,\n",
"        layout=a_smem_layout.outer,\n",
"        byte_alignment=128,\n",
"        swizzle=a_smem_layout.inner,\n",
"    )\n",
"    sB = smem.allocate_tensor(\n",
"        element_type=io_dtype,\n",
"        layout=b_smem_layout.outer,\n",
"        byte_alignment=128,\n",
"        swizzle=b_smem_layout.inner,\n",
"    )\n",
"\n",
"    # Allocate all TMEM columns\n",
"    tmem_alloc_barrier = pipeline.NamedBarrier(\n",
"        barrier_id=1,\n",
"        num_threads=threads_per_cta,\n",
"    )\n",
"    tmem = utils.TmemAllocator(\n",
"        storage.tmem_holding_buf.ptr,\n",
"        barrier_for_retrieve=tmem_alloc_barrier,\n",
"    )\n",
"    num_tmem_cols = 512\n",
"    tmem.allocate(num_tmem_cols)\n",
"\n",
"    # Prefetch tma descriptor\n",
"    if warp_idx == 0:\n",
"        cpasync.prefetch_descriptor(tma_atom_a)\n",
"        cpasync.prefetch_descriptor(tma_atom_b)\n",
"\n",
"    # Pipeline configuration\n",
"    num_tma_copy_bytes = cute.size_in_bytes(\n",
"        io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2])\n",
"    ) + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))\n",
"    ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(\n",
"        num_stages=ab_stages,\n",
"        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),\n",
"        consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),\n",
"        tx_count=num_tma_copy_bytes,\n",
"        barrier_storage=storage.ab_mbar_ptr.data_ptr(),\n",
"    ).make_participants()\n",
"    acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(\n",
"        num_stages=acc_stage,\n",
"        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),\n",
"        consumer_group=pipeline.CooperativeGroup(\n",
"            pipeline.Agent.Thread, threads_per_cta\n",
"        ),\n",
"        barrier_storage=storage.acc_mbar_ptr.data_ptr(),\n",
"    ).make_participants()\n",
"\n",
"    # Partition tensors for MMA and make fragments\n",
"    # (bM, bK, RestK)\n",
"    gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))\n",
"    # (bN, bK, RestK)\n",
"    gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))\n",
"    # (bM, bN)\n",
"    gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))\n",
"    thr_mma = tiled_mma.get_slice(0)\n",
"    # (MMA, MMA_M, MMA_K)\n",
"    tCgA = thr_mma.partition_A(gA)\n",
"    # (MMA, MMA_N, MMA_K)\n",
"    tCgB = thr_mma.partition_B(gB)\n",
"    # (MMA, MMA_M, MMA_N)\n",
"    tCgC = thr_mma.partition_C(gC)\n",
"    # (MMA, MMA_M, MMA_K)\n",
"    tCrA = tiled_mma.make_fragment_A(sA)\n",
"    # (MMA, MMA_N, MMA_K)\n",
"    tCrB = tiled_mma.make_fragment_B(sB)\n",
"    # (MMA, MMA_M, MMA_N)\n",
"    acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])\n",
"    # (MMA, MMA_M, MMA_N)\n",
"    tCtAcc = tiled_mma.make_fragment_C(acc_shape)\n",
"    # Partition tensors for TMA; This requires the tensors partitioned for MMA\n",
"    tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(\n",
"        tma_atom_a,\n",
"        0,\n",
"        cute.make_layout(1),\n",
"        cute.group_modes(sA, 0, 3),\n",
"        cute.group_modes(tCgA, 0, 3),\n",
"    )\n",
"    tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(\n",
"        tma_atom_b,\n",
"        0,\n",
"        cute.make_layout(1),\n",
"        cute.group_modes(sB, 0, 3),\n",
"        cute.group_modes(tCgB, 0, 3),\n",
"    )\n",
"\n",
"    # CTA-wide sync before retrieving the pointer to the start of the allocated TMEM\n",
"    # Only warp 0 does the allocation so we need to sync before retrieving the TMEM start address\n",
"    tmem.wait_for_alloc()\n",
"    tmem_ptr = tmem.retrieve_ptr(acc_dtype)\n",
"    # Swap the pointer in tCtAcc\n",
"    tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)\n",
"\n",
"    subtile_cnt = 4\n",
"    # (EpiTile)\n",
"    epi_tiler = (\n",
"        (cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt),\n",
"    )\n",
"    # (EpiTile, NumTiles)\n",
"    tCtAcc_epi = cute.zipped_divide(tCtAcc, epi_tiler)\n",
"    # (EpiTile, NumTiles)\n",
"    gC_epi = cute.zipped_divide(tCgC, epi_tiler)\n",
"\n",
"    # Every thread loads 32x128 bits\n",
"    tmem_atom = cute.make_copy_atom(\n",
"        tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),\n",
"        cutlass.Float32,\n",
"    )\n",
"    tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0])\n",
"    tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)\n",
"\n",
"    # (TmemCpy,NumTmemCpy,NumTiles)\n",
"    tDtC = tmem_thr_copy.partition_S(tCtAcc_epi)\n",
"    # (TmemCpy,NumTmemCpy,NumTiles)\n",
"    tDgC = tmem_thr_copy.partition_D(gC_epi)\n",
"\n",
"    # (TmemCpy,NumTmemCpy)\n",
"    tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype)\n",
"    # (TmemCpy,NumTmemCpy)\n",
"    tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype)\n",
"\n",
"    #\n",
"    # 2. Main loop\n",
"    #\n",
"    num_k_tiles = cute.size(gA, mode=[2])\n",
"    if warp_idx == 0:\n",
"        # Wait for an empty accumulator buffer\n",
"        acc_empty = acc_producer.acquire_and_advance()\n",
"        for k_tile_idx in cutlass.range(num_k_tiles):\n",
"            # Issue TMA loads\n",
"            ab_empty = ab_producer.acquire_and_advance()\n",
"            cute.copy(\n",
"                tma_atom_a,\n",
"                tAgA[(None, ab_empty.count)],\n",
"                tAsA[(None, ab_empty.index)],\n",
"                tma_bar_ptr=ab_empty.barrier,\n",
"            )\n",
"            cute.copy(\n",
"                tma_atom_b,\n",
"                tBgB[(None, ab_empty.count)],\n",
"                tBsB[(None, ab_empty.index)],\n",
"                tma_bar_ptr=ab_empty.barrier,\n",
"            )\n",
"\n",
"            # Execute one K-block worth of MMA instructions\n",
"            ab_full = ab_consumer.wait_and_advance()\n",
"            num_k_blocks = cute.size(tCrA, mode=[2])\n",
"            for k_block_idx in cutlass.range_constexpr(num_k_blocks):\n",
"                k_block_coord = (None, None, k_block_idx, ab_full.index)\n",
"                cute.gemm(\n",
"                    tiled_mma,\n",
"                    tCtAcc,\n",
"                    tCrA[k_block_coord],\n",
"                    tCrB[k_block_coord],\n",
"                    tCtAcc,\n",
"                )\n",
"                tiled_mma.set(tcgen05.Field.ACCUMULATE, True)\n",
"\n",
"            # Signal that the A/B buffers have been consumed and are ready for the next load\n",
"            ab_full.release()\n",
"\n",
"        # Signal that the accumulator is fully computed\n",
"        acc_empty.commit()\n",
"\n",
"    #\n",
"    # 3. Epilogue\n",
"    #\n",
"\n",
"    # Release TMEM allocation lock\n",
"    tmem.relinquish_alloc_permit()\n",
"\n",
"    # Wait for the accumulator buffer to be full\n",
"    acc_full = acc_consumer.wait_and_advance()\n",
"\n",
"    # TMEM -> RMEM -> GMEM\n",
"    # Sub-tiling for better instruction-level parallelism\n",
"    for i in cutlass.range(cute.size(tDtC, mode=[2])):\n",
"        cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)\n",
"        tCrC.store(tCrAcc.load().to(io_dtype))\n",
"        cute.autovec_copy(tCrC, tDgC[None, None, i])\n",
"    acc_full.release()\n",
"\n",
"    # Deallocate TMEM\n",
"    pipeline.sync(barrier_id=1)\n",
"    tmem.free(tmem_ptr)"
]
},
{
"cell_type": "markdown",
"id": "fe62727d",
"metadata": {},
"source": [
"## Performance Analysis and Benchmarking\n",
"\n",
"To understand and improve our kernel's performance, we need to measure its execution time and computation throughput. Let's analyze several key metrics:\n",
"\n",
"* **Execution Time**\n",
"    - Measures raw kernel performance in microseconds\n",
"    - Lower is better\n",
"    - Affected by GPU clock speed, memory bandwidth, and kernel efficiency\n",
"* **Computation Throughput**\n",
"    - Measures how fast we compute (TFlops)\n",
"    - Higher is better\n",
"    - Theoretical peak varies by GPU model\n",
"    - For GEMM:\n",
"        * M * N * K FMAs to finish GEMM\n",
"        * 2 floating-point operations for each FMA\n",
"        * Total = M * N * K * 2\n",
"\n",
"Below is our benchmarking utility that measures these metrics:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5833f3fc",
"metadata": {},
"outputs": [],
"source": [
"def benchmark(callable, a_tensor, b_tensor, c_tensor):\n",
"    avg_time_us = cute.testing.benchmark(\n",
"        callable,\n",
"        kernel_arguments=cute.testing.JitArguments(a_tensor, b_tensor, c_tensor),\n",
"        warmup_iterations=1,\n",
"        iterations=2,\n",
"    )\n",
"\n",
"    # Calculate metrics\n",
"\n",
"    # Total floating-point operations:\n",
"    # - M * N * K * 2 (each FMA counts as 2 operations)\n",
"    total_float_ops = m * n * k * 2\n",
"\n",
"    # Achieved TFlops: ops / (time in us * 1e6) = ops per second / 1e12\n",
"    achieved_tflops = total_float_ops / (avg_time_us * 1000000)  # TFlops\n",
"\n",
"    # Print results\n",
"    # ------------\n",
"    print(\"Performance Metrics:\")\n",
"    print(\"-------------------\")\n",
"    print(f\"Kernel execution time: {avg_time_us:.4f} us\")\n",
"    print(f\"Computation throughput: {achieved_tflops:.2f} TFlops\")"
]
},
{
"cell_type": "markdown",
"id": "21f59ec5",
"metadata": {},
"source": [
"## Test the first version of GEMM\n",
"\n",
"Run the following code to measure the TFlops and to verify the result against a `torch.einsum` reference.\n",
"\n",
"You should be able to reach about **450** TFlops."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05e62b4b",
"metadata": {},
"outputs": [],
"source": [
"# Compile the kernel for the specific input types\n",
"naive_kernel = cute.compile(host_function, a_tensor, b_tensor, c_tensor, kernel)\n",
"\n",
"# Run the kernel\n",
"benchmark(naive_kernel, a_tensor, b_tensor, c_tensor)\n",
"\n",
"# Verify Results\n",
"# -------------\n",
"# Compare our kernel output with PyTorch's native implementation\n",
"# Compute reference result and verify\n",
"ref = (torch.einsum(\"mk,nk->mn\", a.to(torch.float32), b.to(torch.float32))).cpu()\n",
"torch.testing.assert_close(\n",
"    c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=1e-1, rtol=1e-05\n",
")\n",
"print(\"Verification passed!\")"
]
},
{
"cell_type": "markdown",
"id": "b19916a9",
"metadata": {},
"source": [
"\n",
"## Enable software pipelining\n",
"\n",
"As mentioned before, we usually prefetch multiple stages (ab_stages - 2) of the A/B tensors to hide GMEM latency (see figure 2). The dark areas show when TMA/tcgen05.mma instructions are issued, while the light areas show the corresponding latency.\n",
"This leaves (ab\\_stages - 1) * (time of one stage of MMA) to hide the GMEM latency.\n",
"\n",
"![figure2](./images/software_pipelining_ab_stages_minus_2.svg \"figure2: software_pipelining_with_prefetch_ab_stages_minus_2\")\n",
"\n",
"\n",
"\n",
"To enable this strategy, we:\n",
"1. Write a loop that prefetches before the mainloop.\n",
"2. Issue a copy a fixed stride ahead inside the mainloop.\n",
"\n",
"\n",
"``` python\n",
"# Prefetch ab_stages - 2 blocks of A/B\n",
"for stage in cutlass.range(ab_stages - 2):\n",
"    ab_empty = ab_producer.acquire_and_advance()\n",
"    cute.copy(...)\n",
"\n",
"for k_tile_idx in cutlass.range(num_k_tiles):\n",
"    # Issue TMA loads\n",
"    if k_tile_idx + ab_stages - 2 < num_k_tiles:\n",
"        ab_empty = ab_producer.acquire_and_advance()\n",
"        cute.copy(...)\n",
"    # Execute one K-block worth of MMA instructions\n",
"    ab_full = ab_consumer.wait_and_advance()\n",
"    cute.gemm(...)\n",
"    # Signal that the A/B buffers have been consumed and are ready for the next load\n",
"    ab_full.release()\n",
"```\n",
"\n",
"CuTeDSL provides a `prefetch_stages` attribute for cutlass.range. It lets us write the loop in the general form while the data is prefetched as in the code above.\n",
"\n",
"``` python\n",
"for k_tile_idx in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2):\n",
"    # Issue TMA loads\n",
"    ab_empty = ab_producer.acquire_and_advance()\n",
"    cute.copy(...)\n",
"    # Execute one K-block worth of MMA instructions\n",
"    ab_full = ab_consumer.wait_and_advance()\n",
"    cute.gemm(...)\n",
"    # Signal that the A/B buffers have been consumed and are ready for the next load\n",
"    ab_full.release()\n",
"```\n",
"\n",
"Figure 3 explains why we prefetch ab_stages - 2 instead of ab_stages - 1. With ab_stages - 1, each TMA copy inside the mainloop can only be issued after the previous MMA has finished. This delays the issue of the next MMA and causes bubbles between two blocks.\n",
"\n",
"![figure3](./images/software_pipelining_ab_stages_minus_1.svg \"figure3: software_pipelining_with_prefetch_ab_stages_minus_1\")\n",
"\n",
"Let's test the performance with prefetch enabled. You should be able to reach about **880** TFlops."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1abeca4",
"metadata": {},
"outputs": [],
"source": [
"@cute.kernel\n",
"def kernel_with_prefetch(\n",
" tiled_mma: cute.TiledMma,\n",
" tma_atom_a: cute.CopyAtom,\n",
" mA_mkl: cute.Tensor,\n",
" tma_atom_b: cute.CopyAtom,\n",
" mB_nkl: cute.Tensor,\n",
" mC_mnl: cute.Tensor,\n",
" a_smem_layout: cute.ComposedLayout,\n",
" b_smem_layout: cute.ComposedLayout,\n",
"):\n",
" #\n",
" # 1. Prepare args\n",
" #\n",
"\n",
" # Current thread/warp/block coordinates\n",
" tidx, _, _ = cute.arch.thread_idx()\n",
" warp_idx = cute.arch.warp_idx()\n",
" warp_idx = cute.arch.make_warp_uniform(warp_idx)\n",
" bidx, bidy, _ = cute.arch.block_idx()\n",
" mma_coord_mnk = (bidx, bidy, None)\n",
"\n",
" # Allocate SMEM\n",
" smem = cutlass.utils.SmemAllocator()\n",
" storage = smem.allocate(SharedStorage)\n",
" sA = smem.allocate_tensor(\n",
" element_type=io_dtype,\n",
" layout=a_smem_layout.outer,\n",
" byte_alignment=128,\n",
" swizzle=a_smem_layout.inner,\n",
" )\n",
" sB = smem.allocate_tensor(\n",
" element_type=io_dtype,\n",
" layout=b_smem_layout.outer,\n",
" byte_alignment=128,\n",
" swizzle=b_smem_layout.inner,\n",
" )\n",
"\n",
" # Allocate all TMEM columns\n",
" tmem_alloc_barrier = pipeline.NamedBarrier(\n",
" barrier_id=1,\n",
" num_threads=threads_per_cta,\n",
" )\n",
" tmem = utils.TmemAllocator(\n",
" storage.tmem_holding_buf.ptr,\n",
" barrier_for_retrieve=tmem_alloc_barrier,\n",
" )\n",
" num_tmem_cols = 512\n",
" tmem.allocate(num_tmem_cols)\n",
"\n",
" # Prefetch tma descriptor\n",
" if warp_idx == 0:\n",
" cpasync.prefetch_descriptor(tma_atom_a)\n",
" cpasync.prefetch_descriptor(tma_atom_b)\n",
"\n",
" # Pipeline configuration\n",
" num_tma_copy_bytes = cute.size_in_bytes(\n",
" io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2])\n",
" ) + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))\n",
" ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(\n",
" num_stages=ab_stages,\n",
" producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),\n",
" consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),\n",
" tx_count=num_tma_copy_bytes,\n",
" barrier_storage=storage.ab_mbar_ptr.data_ptr(),\n",
" ).make_participants()\n",
" acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create(\n",
" num_stages=acc_stage,\n",
" producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),\n",
" consumer_group=pipeline.CooperativeGroup(\n",
" pipeline.Agent.Thread, threads_per_cta\n",
" ),\n",
" barrier_storage=storage.acc_mbar_ptr.data_ptr(),\n",
" ).make_participants()\n",
"\n",
" # Partition tensors for MMA and make fragments\n",
" # (bM, bK, RestK)\n",
" gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1))\n",
" # (bN, bK, RestK)\n",
" gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1))\n",
" # (bM, bN)\n",
" gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None))\n",
" thr_mma = tiled_mma.get_slice(0)\n",
" # (MMA, MMA_M, MMA_K)\n",
" tCgA = thr_mma.partition_A(gA)\n",
" # (MMA, MMA_N, MMA_K)\n",
" tCgB = thr_mma.partition_B(gB)\n",
" # (MMA, MMA_M, MMA_N)\n",
" tCgC = thr_mma.partition_C(gC)\n",
" # (MMA, MMA_M, MMA_K)\n",
" tCrA = tiled_mma.make_fragment_A(sA)\n",
" # (MMA, MMA_N, MMA_K)\n",
" tCrB = tiled_mma.make_fragment_B(sB)\n",
" # (MMA, MMA_M, MMA_N)\n",
" acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2])\n",
" # (MMA, MMA_M, MMA_N)\n",
" tCtAcc = tiled_mma.make_fragment_C(acc_shape)\n",
" # Partition tensors for TMA; This requires the tensors partitioned for MMA\n",
" tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(\n",
" tma_atom_a,\n",
" 0,\n",
" cute.make_layout(1),\n",
" cute.group_modes(sA, 0, 3),\n",
" cute.group_modes(tCgA, 0, 3),\n",
" )\n",
" tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(\n",
" tma_atom_b,\n",
" 0,\n",
" cute.make_layout(1),\n",
" cute.group_modes(sB, 0, 3),\n",
" cute.group_modes(tCgB, 0, 3),\n",
" )\n",
"\n",
" # CTA-wide sync before retrieving the pointer to the start of the allocated TMEM\n",
" # Only warp 0 does the allocation so we need to sync before retrieving the TMEM start address\n",
" tmem.wait_for_alloc()\n",
" tmem_ptr = tmem.retrieve_ptr(acc_dtype)\n",
" # Swap the pointer in tCtAcc\n",
" tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout)\n",
"\n",
" subtile_cnt = 4\n",
" # (EpiTile)\n",
" epi_tiler = (\n",
" (cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt),\n",
" )\n",
" # (EpiTile, NumTiles)\n",
" tCtAcc_epi = cute.zipped_divide(tCtAcc, epi_tiler)\n",
" # (EpiTile, NumTiles)\n",
" gC_epi = cute.zipped_divide(tCgC, epi_tiler)\n",
"\n",
" # Every thread loads 32x128 bits\n",
" tmem_atom = cute.make_copy_atom(\n",
" tcgen05.Ld32x32bOp(tcgen05.Repetition.x64),\n",
" cutlass.Float32,\n",
" )\n",
" tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0])\n",
" tmem_thr_copy = tmem_tiled_copy.get_slice(tidx)\n",
"\n",
" # (TmemCpy,NumTmemCpy,NumTiles)\n",
" tDtC = tmem_thr_copy.partition_S(tCtAcc_epi)\n",
" # (TmemCpy,NumTmemCpy,NumTiles)\n",
" tDgC = tmem_thr_copy.partition_D(gC_epi)\n",
"\n",
" # (TmemCpy,NumTmemCpy)\n",
" tCrAcc = cute.make_rmem_tensor(tDgC[None, None, 0].shape, acc_dtype)\n",
" # (TmemCpy,NumTmemCpy)\n",
" tCrC = cute.make_rmem_tensor(tDgC[None, None, 0].shape, io_dtype)\n",
"\n",
" #\n",
" # 2. Main loop\n",
" #\n",
" num_k_tiles = cute.size(gA, mode=[2])\n",
" if warp_idx == 0:\n",
" # Wait for a empty accumulator buffer\n",
" acc_empty = acc_producer.acquire_and_advance()\n",
" for k_tile_idx in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2):\n",
" # Issue TMA loads\n",
" ab_empty = ab_producer.acquire_and_advance()\n",
" cute.copy(\n",
" tma_atom_a,\n",
" tAgA[(None, ab_empty.count)],\n",
" tAsA[(None, ab_empty.index)],\n",
" tma_bar_ptr=ab_empty.barrier,\n",
" )\n",
" cute.copy(\n",
" tma_atom_b,\n",
" tBgB[(None, ab_empty.count)],\n",
" tBsB[(None, ab_empty.index)],\n",
" tma_bar_ptr=ab_empty.barrier,\n",
" )\n",
"\n",
" # Execute one K-block worth of MMA instructions\n",
" ab_full = ab_consumer.wait_and_advance()\n",
" num_k_blocks = cute.size(tCrA, mode=[2])\n",
" for k_block_idx in cutlass.range_constexpr(num_k_blocks):\n",
" k_block_coord = (None, None, k_block_idx, ab_full.index)\n",
" cute.gemm(\n",
" tiled_mma,\n",
" tCtAcc,\n",
" tCrA[k_block_coord],\n",
" tCrB[k_block_coord],\n",
" tCtAcc,\n",
" )\n",
" tiled_mma.set(tcgen05.Field.ACCUMULATE, True)\n",
"\n",
" # Signal that the A/B buffers have been consumed and are ready for the next load\n",
" ab_full.release()\n",
"\n",
" # Signal that the accumulator is fully computed\n",
" acc_empty.commit()\n",
"\n",
" #\n",
" # 3. Epilogue\n",
" #\n",
"\n",
" # Release TMEM allocation lock\n",
" tmem.relinquish_alloc_permit()\n",
"\n",
" # Wait for the accumulator buffer to be full\n",
" acc_full = acc_consumer.wait_and_advance()\n",
"\n",
" # TMEM -> RMEM -> GMEM\n",
" # Sub-tiling for better instruction-level parallelism\n",
" for i in cutlass.range(cute.size(tDtC, mode=[2])):\n",
" cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc)\n",
" tCrC.store(tCrAcc.load().to(io_dtype))\n",
" cute.autovec_copy(tCrC, tDgC[None, None, i])\n",
" acc_full.release()\n",
"\n",
" # Deallocate TMEM\n",
" pipeline.sync(barrier_id=1)\n",
" tmem.free(tmem_ptr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23e745b5",
"metadata": {},
"outputs": [],
"source": [
"# Compile the kernel for the specific input types\n",
"prefetch_kernel = cute.compile(host_function, a_tensor, b_tensor, c_tensor, kernel_with_prefetch)\n",
"\n",
"# Run the kernel\n",
"benchmark(prefetch_kernel, a_tensor, b_tensor, c_tensor)\n",
"\n",
"# Verify Results\n",
"# -------------\n",
"# Compare our kernel output with PyTorch's native implementation\n",
"# Compute reference result and verify\n",
"ref = (torch.einsum(\"mk,nk->mn\", a.to(torch.float32), b.to(torch.float32))).cpu()\n",
"torch.testing.assert_close(\n",
" c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=1e-1, rtol=1e-05\n",
")\n",
"print(\"Verification passed!\")"
]
},
{
"cell_type": "markdown",
"id": "8452a70b",
"metadata": {},
"source": [
"## Vectorized instructions for storing out\n",
"\n",
"If we profile this kernel with NCU (Nsight Compute), we see a sharp drop in Tensor Core utilization at each wave switch. The cause is the large number of `st.global.b16` instructions, each of which stores only 16 bits.\n",
"CuTeDSL needs alignment and divisibility information to select vectorized instructions for `cute.copy`.\n",
"We need to set these attributes correctly on the cute tensors.\n",
"\n",
"With vectorized instructions, the kernel can reach about **1400** TFLOPS."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e713a4b2",
"metadata": {},
"outputs": [],
"source": [
"# Declare pointer alignment and shape divisibility so that vectorized instructions can be selected\n",
"a_tensor_ = (\n",
" from_dlpack(a, assumed_align=32)\n",
" .mark_layout_dynamic(leading_dim=1)\n",
" .mark_compact_shape_dynamic(mode=1, divisibility=k)\n",
")\n",
"b_tensor_ = (\n",
" from_dlpack(b, assumed_align=32)\n",
" .mark_layout_dynamic(leading_dim=1)\n",
" .mark_compact_shape_dynamic(mode=1, divisibility=k)\n",
")\n",
"c_tensor_ = (\n",
" from_dlpack(c, assumed_align=32)\n",
" .mark_layout_dynamic(leading_dim=1)\n",
" .mark_compact_shape_dynamic(mode=1, divisibility=n)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e3b5d8f",
"metadata": {},
"outputs": [],
"source": [
"# Compile the kernel for the specific input types\n",
"vectorized_kernel = cute.compile(host_function, a_tensor_, b_tensor_, c_tensor_, kernel_with_prefetch)\n",
"\n",
"# Run the kernel\n",
"benchmark(vectorized_kernel, a_tensor_, b_tensor_, c_tensor_)\n",
"\n",
"# Verify Results\n",
"# -------------\n",
"# Compare our kernel output with PyTorch's native implementation\n",
"# Compute reference result and verify\n",
"ref = (torch.einsum(\"mk,nk->mn\", a.to(torch.float32), b.to(torch.float32))).cpu()\n",
"torch.testing.assert_close(\n",
" c.cpu(), ref.to(cutlass_torch.dtype(io_dtype)), atol=1e-1, rtol=1e-05\n",
")\n",
"print(\"Verification passed!\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}