v4.6 dev update. (#3315)

* v4.6 dev update.

* Remove CUTLASS_HOST_DEVICE from CudaHostAdapater::memsetDevice (#3286)

* [SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM (#3280)

* gemm: add SM120 array TMA collective for tensor/token-scaled FP8 grouped GEMM

Adds CollectiveMma and CollectiveBuilder specializations for
MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM
(MoE expert dispatch) with tensor- and token-level FP8 scaling on
SM_120/SM_121 consumer Blackwell (RTX 5090/5080/5070, DGX Spark GB10).

New files:
- include/cutlass/gemm/collective/sm120_mma_array_tma.hpp
  CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized.
  Handles both Cooperative (4x2 atom layout) and Pingpong (2x2) schedules.
  Grouped GEMM via pointer-array indirection through params.ptr_A / ptr_B.
  Supports F8F6F4 MMA with TMA loads for both A and B operands.

- include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl
  CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecialized
  Cooperative/PingpongSm120<N> schedule tags. Computes tile/stage counts
  from smem capacity, routes to MainloopSm120ArrayTmaWarpSpecialized
  dispatch policy, produces correctly-typed CollectiveOp.

Modified files:
- collective_mma.hpp: include sm120_mma_array_tma.hpp
- collective_builder.hpp: include sm120_array_mma_builder.inl
- sm120_mma_builder.inl: remove ptr-array schedules from enable_if
  (they now route to sm120_array_mma_builder.inl) and drop the
  IsPtrArrayKernel static_assert that enforced the restriction

Validated on real SM_121 hardware (DGX Spark, 128 GB LPDDR5X) running
vLLM with RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B
total / 4B active). Previously fell back to a non-CUTLASS Triton path;
with this patch, the SM120 CUTLASS grouped GEMM collective activates and
produces correct outputs. Short-sequence throughput improved ~7% vs the
fallback baseline (76.3 → 81.9 tok/s).

Closes #3263

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>

* test: add SM120 ptr-array grouped GEMM unit tests

Adds 6 device-level tests for the CollectiveMma/CollectiveBuilder
specializations introduced for MainloopSm120ArrayTmaWarpSpecialized,
covering both KernelPtrArrayTmaWarpSpecializedPingpongSm120<2> and
KernelPtrArrayTmaWarpSpecializedCooperativeSm120<2> schedule tags across
e4m3×e4m3 (symmetric), e4m3×e5m2 (mixed), float and bfloat16 outputs,
and two tile shapes.

Tests land in test/unit/gemm/device/sm120_tensorop_gemm/ under the new
cutlass_test_unit_sm120_grouped_gemm_device_tensorop CMake target, per
reviewer request in PR #3280.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>

---------

Signed-off-by: Tyler Merritt <tgmerritt@gmail.com>
Co-authored-by: Alex Georgiev <89279829+alexngUNC@users.noreply.github.com>
Co-authored-by: Tyler <tgmerritt@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Junkai-Wu
2026-06-16 11:23:20 +08:00
committed by GitHub
parent 0ce648f53f
commit 39b352fa93
175 changed files with 18275 additions and 7581 deletions

View File

@@ -49,7 +49,7 @@ To run this example:
.. code-block:: bash
python examples/ampere/hstu_attention.py --batch_size 4 --seqlen_q 8192 --seqlen_kv 8192 --num_head 4 --head_dim 128 --m_block_size 128 --n_block_size 64 --is_causal --perf_test
python examples/cute/ampere/kernel/attention/hstu_attention.py --batch_size 4 --seqlen_q 8192 --seqlen_kv 8192 --num_head 4 --head_dim 128 --m_block_size 128 --n_block_size 64 --is_causal --perf_test
The above example tests the performance of HSTU attention with batch size 4, sequence length 8192, 4 attention heads, and head dimension 128. The m_block_size is 128, and n_block_size is 64. The causal masking is enabled.
@@ -268,6 +268,104 @@ class HSTUAttentionForwardAmpere(object):
stream=stream,
)
@cute.jit
def _copy_with_residue(
self,
copy_atom,
src_tile,
dst_tile,
coord_tile,
head_dim_pred,
has_outer_residue: cutlass.Constexpr,
has_inner_residue: cutlass.Constexpr,
outer_size,
block_end,
fill_zero_on_oob: cutlass.Constexpr = True,
is_known_boundary: cutlass.Constexpr = False,
):
# Copy a (CPY_Atom, M, K) tile with optional outer-axis (M) and head-dim (K) residue.
# `is_known_boundary=True` skips the runtime boundary check (caller knows the tile straddles outer_size).
# `fill_zero_on_oob=False` for stores; loads zero-fill out-of-bounds rows so SMEM contents are well-defined.
if cutlass.const_expr(not has_outer_residue and not has_inner_residue):
cute.copy(copy_atom, src_tile, dst_tile)
elif cutlass.const_expr(not has_outer_residue):
cute.copy(copy_atom, src_tile, dst_tile, pred=head_dim_pred)
else:
if cutlass.const_expr(is_known_boundary):
is_boundary = True
else:
is_boundary = cute.elem_less(outer_size, block_end)
if is_boundary:
for m in cutlass.range_constexpr(cute.size(dst_tile.shape[1])):
if cute.elem_less(coord_tile[0, m, 0][1], outer_size):
if cutlass.const_expr(has_inner_residue):
cute.copy(
copy_atom,
src_tile[None, m, None],
dst_tile[None, m, None],
pred=head_dim_pred[None, m, None],
)
else:
cute.copy(
copy_atom,
src_tile[None, m, None],
dst_tile[None, m, None],
)
elif cutlass.const_expr(fill_zero_on_oob):
dst_tile[None, m, None].fill(0)
else:
if cutlass.const_expr(has_inner_residue):
cute.copy(copy_atom, src_tile, dst_tile, pred=head_dim_pred)
else:
cute.copy(copy_atom, src_tile, dst_tile)
@cute.jit
def _copy_rab_tile(
self,
copy_atom,
src_tile,
dst_tile,
coord_tile,
has_q_residue: cutlass.Constexpr,
has_kv_residue: cutlass.Constexpr,
seqlen_q,
seqlen_kv,
q_block_end,
kv_block_end,
is_known_kv_interior: cutlass.Constexpr = False,
):
# Copy a (CPY_Atom, M, N) RAB tile with optional 2D residue.
# Coord entries are 4-tuples from the (B, H, Q, KV) identity tensor; index 2 is q-coord, index 3 is kv-coord.
# `is_known_kv_interior=True` skips the kv-side runtime check when the loaded tile is guaranteed inside seqlen_kv.
kv_check_active = has_kv_residue and not is_known_kv_interior
if cutlass.const_expr(not has_q_residue and not kv_check_active):
cute.copy(copy_atom, src_tile, dst_tile)
else:
# Pick the runtime predicate without mixing cute.Boolean with Python bool.
if cutlass.const_expr(has_q_residue and kv_check_active):
needs_per_elem = cute.elem_less(
seqlen_q, q_block_end
) or cute.elem_less(seqlen_kv, kv_block_end)
elif cutlass.const_expr(has_q_residue):
needs_per_elem = cute.elem_less(seqlen_q, q_block_end)
else:
needs_per_elem = cute.elem_less(seqlen_kv, kv_block_end)
if needs_per_elem:
for m in cutlass.range_constexpr(cute.size(dst_tile.shape[1])):
for n in cutlass.range_constexpr(cute.size(dst_tile.shape[2])):
if cute.elem_less(
coord_tile[0, m, n][2], seqlen_q
) and cute.elem_less(coord_tile[0, m, n][3], seqlen_kv):
cute.copy(
copy_atom,
src_tile[None, m, n],
dst_tile[None, m, n],
)
else:
dst_tile[None, m, n].fill(0)
else:
cute.copy(copy_atom, src_tile, dst_tile)
@cute.kernel
def kernel(
self,
@@ -395,7 +493,7 @@ class HSTUAttentionForwardAmpere(object):
tVsV = gmem_thr_copy_QKV.partition_D(sV)
# (CPY_Atom, CPY_M, CPY_N, n_block)
tRABgRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_S(gRAB)
tRabsRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_D(sRAB)
tRABsRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_D(sRAB)
# ///////////////////////////////////////////////////////////////////////////////
# Tile MMA compute thread partitions and allocate accumulators
@@ -446,115 +544,118 @@ class HSTUAttentionForwardAmpere(object):
tOsVt = smem_thr_copy_V.partition_S(sVt)
tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt)
tSsRAB = smem_thr_copy_RAB.partition_S(sRAB)
has_head_dim_residue = self._head_dim != self._head_dim_padded
has_q_residue = self._seqlen_q % self._m_block_size != 0
has_kv_residue = self._seqlen_kv % self._n_block_size != 0
need_q_predicates = has_head_dim_residue or has_q_residue
need_kv_predicates = has_head_dim_residue or has_kv_residue
need_rab_predicates = has_q_residue or has_kv_residue
# ///////////////////////////////////////////////////////////////////////////////
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
# of tile_shape
# ///////////////////////////////////////////////////////////////////////////////
# Construct identity layout for Q, KV and RAB
mcQ = cute.make_identity_tensor(mQ.layout.shape)
mcKV = cute.make_identity_tensor(mK.layout.shape)
mcRAB = cute.make_identity_tensor(mRAB.layout.shape)
cQ = cute.local_tile(
mcQ[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
cKV = cute.local_tile(
mcKV[batch_size, None, num_head, None],
(self._n_block_size, self._head_dim_padded),
(n_block, 0),
)
cRAB = cute.local_tile(
mcRAB[batch_size, num_head, None, None],
(self._m_block_size, self._n_block_size),
(m_block, None),
)
# Repeat the partitioning with identity layouts
tQcQ = gmem_thr_copy_QKV.partition_S(cQ)
tKVcKV = gmem_thr_copy_QKV.partition_S(cKV)
tRABcRAB = gmem_thr_copy_QKV.partition_S(cRAB)
tQpQ = cute.make_rmem_tensor(
cute.make_layout(
(
tQsQ.shape[0][1],
cute.size(tQsQ, mode=[1]),
cute.size(tQsQ, mode=[2]),
),
stride=(cute.size(tQsQ, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
tKVpKV = cute.make_rmem_tensor(
cute.make_layout(
(
tKsK.shape[0][1],
cute.size(tKsK, mode=[1]),
cute.size(tKsK, mode=[2]),
),
stride=(cute.size(tKsK, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
# Set predicates for head_dim bounds, seqlen_q/k/v bounds is processed at the first tile.
for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
tQpQ[rest_v, 0, rest_k] = cute.elem_less(
tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
if cutlass.const_expr(need_q_predicates):
mcQ = cute.make_identity_tensor(mQ.layout.shape)
cQ = cute.local_tile(
mcQ[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
tQcQ = gmem_thr_copy_QKV.partition_S(cQ)
if cutlass.const_expr(has_head_dim_residue):
tQpQ = cute.make_rmem_tensor(
cute.make_layout(
(
tQsQ.shape[0][1],
cute.size(tQsQ, mode=[1]),
cute.size(tQsQ, mode=[2]),
),
stride=(cute.size(tQsQ, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
tQpQ[rest_v, 0, rest_k] = cute.elem_less(
tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
)
if cutlass.const_expr(need_kv_predicates):
mcKV = cute.make_identity_tensor(mK.layout.shape)
cKV = cute.local_tile(
mcKV[batch_size, None, num_head, None],
(self._n_block_size, self._head_dim_padded),
(n_block, 0),
)
tKVcKV = gmem_thr_copy_QKV.partition_S(cKV)
if cutlass.const_expr(has_head_dim_residue):
tKVpKV = cute.make_rmem_tensor(
cute.make_layout(
(
tKsK.shape[0][1],
cute.size(tKsK, mode=[1]),
cute.size(tKsK, mode=[2]),
),
stride=(cute.size(tKsK, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
)
if cutlass.const_expr(need_rab_predicates):
mcRAB = cute.make_identity_tensor(mRAB.layout.shape)
cRAB = cute.local_tile(
mcRAB[batch_size, num_head, None, None],
(self._m_block_size, self._n_block_size),
(m_block, None),
)
tRABcRAB = gmem_thr_copy_QKV.partition_S(cRAB)
# ///////////////////////////////////////////////////////////////////////////////
# Prefetch Prologue
# ///////////////////////////////////////////////////////////////////////////////
# Start async loads of the last mn-tile, where we take care of the mn residue
for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
tQgQ[None, m, None],
tQsQ[None, m, None],
pred=tQpQ[None, m, None],
)
else:
# Clear the smem tiles to account for predicated off loads
tQsQ[None, m, None].fill(0)
for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
tKgK[None, n, None, n_block],
tKsK[None, n, None],
pred=tKVpKV[None, n, None],
)
else:
# Clear the smem tiles to account for predicated off loads
tKsK[None, n, None].fill(0)
for m in cutlass.range_constexpr(cute.size(tRABcRAB.shape[1])):
for n in cutlass.range_constexpr(cute.size(tRABcRAB.shape[2])):
if cute.elem_less(
tRABcRAB[0, m, n, n_block][1], mRAB.layout.shape[2]
) and cute.elem_less(
tRABcRAB[0, m, n, n_block][2], mRAB.layout.shape[3]
):
cute.copy(
gmem_tiled_copy_QKV,
tRABgRAB[None, m, n, n_block],
tRabsRAB[None, m, n],
)
else:
# Clear the smem tiles to account for predicated off loads
tRabsRAB[None, m, n].fill(0)
self._copy_with_residue(
gmem_tiled_copy_QKV,
tQgQ[None, None, None],
tQsQ[None, None, None],
tQcQ if has_q_residue else None,
tQpQ if has_head_dim_residue else None,
has_q_residue,
has_head_dim_residue,
mQ.layout.shape[1],
(m_block + 1) * self._m_block_size,
)
# n_block is the last n-tile by construction, so any kv-residue is in this tile.
self._copy_with_residue(
gmem_tiled_copy_QKV,
tKgK[None, None, None, n_block],
tKsK[None, None, None],
tKVcKV if has_kv_residue else None,
tKVpKV if has_head_dim_residue else None,
has_kv_residue,
has_head_dim_residue,
mK.layout.shape[1],
(n_block + 1) * self._n_block_size,
is_known_boundary=has_kv_residue,
)
self._copy_rab_tile(
gmem_tiled_copy_QKV,
tRABgRAB[None, None, None, n_block],
tRABsRAB[None, None, None],
tRABcRAB[None, None, None, n_block] if need_rab_predicates else None,
has_q_residue,
has_kv_residue,
mRAB.layout.shape[2],
mRAB.layout.shape[3],
(m_block + 1) * self._m_block_size,
(n_block + 1) * self._n_block_size,
)
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
@@ -565,24 +666,17 @@ class HSTUAttentionForwardAmpere(object):
cute.arch.cp_async_wait_group(0)
self.cta_sync_barrier.arrive_and_wait()
if n_block_idx == n_block:
for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
if cute.elem_less(tKVcKV[0, n, 0][1], mV.layout.shape[1]):
cute.copy(
gmem_tiled_copy_QKV,
tVgV[None, n, None, n_block_idx],
tVsV[None, n, None],
pred=tKVpKV[None, n, None],
)
else:
tVsV[None, n, None].fill(0)
else:
cute.copy(
gmem_tiled_copy_QKV,
tVgV[None, None, None, n_block_idx],
tVsV[None, None, None],
pred=tKVpKV[None, None, None],
)
self._copy_with_residue(
gmem_tiled_copy_QKV,
tVgV[None, None, None, n_block_idx],
tVsV[None, None, None],
tKVcKV if has_kv_residue else None,
tKVpKV if has_head_dim_residue else None,
has_kv_residue,
has_head_dim_residue,
mV.layout.shape[1],
(n_block_idx + 1) * self._n_block_size,
)
cute.arch.cp_async_commit_group()
acc_shape_S = thr_mma.partition_shape_C(
@@ -643,25 +737,33 @@ class HSTUAttentionForwardAmpere(object):
self.cta_sync_barrier.arrive_and_wait()
if n_block_idx > 0:
cute.copy(
# tile (n_block_idx - 1) is always inside seqlen_kv, so only head_dim residue can apply
self._copy_with_residue(
gmem_tiled_copy_QKV,
tKgK[None, None, None, n_block_idx - 1],
tKsK[None, None, None],
pred=tKVpKV[None, None, None],
None,
tKVpKV[None, None, None] if has_head_dim_residue else None,
False,
has_head_dim_residue,
None,
None,
)
self._copy_rab_tile(
gmem_tiled_copy_QKV,
tRABgRAB[None, None, None, n_block_idx - 1],
tRABsRAB[None, None, None],
tRABcRAB[None, None, None, n_block_idx - 1]
if need_rab_predicates
else None,
has_q_residue,
has_kv_residue,
mRAB.layout.shape[2],
mRAB.layout.shape[3],
(m_block + 1) * self._m_block_size,
None,
is_known_kv_interior=True,
)
# m residue handling for RAB
for m in cutlass.range_constexpr(cute.size(tRABcRAB.shape[1])):
if cute.elem_less(
tRABcRAB[0, m, 0, n_block_idx - 1][1], mRAB.layout.shape[2]
):
cute.copy(
gmem_tiled_copy_QKV,
tRABgRAB[None, m, None, n_block_idx - 1],
tRabsRAB[None, m, None],
)
else:
tRabsRAB[None, m, None].fill(0)
cute.arch.cp_async_commit_group()
# ///////////////////////////////////////////////////////////////////////////////
@@ -695,26 +797,29 @@ class HSTUAttentionForwardAmpere(object):
t4 = t1 / t3
acc_S.store(t4)
mACC = cute.make_identity_tensor(
(mRAB.layout.shape[2], mRAB.layout.shape[3])
) # (seqlen_q, seqlen_kv)
cACC = cute.local_tile(
mACC[None, None],
(self._m_block_size, self._n_block_size),
(m_block, n_block_idx),
)
if cutlass.const_expr(self._is_causal):
mACC = cute.make_identity_tensor(
(mRAB.layout.shape[2], mRAB.layout.shape[3])
) # (seqlen_q, seqlen_kv)
cACC = cute.local_tile(
mACC[None, None],
(self._m_block_size, self._n_block_size),
(m_block, n_block_idx),
)
if self._is_causal and (n_block - n_block_idx) < cute.ceil_div(
self._m_block_size, self._n_block_size
):
tACCcACC = thr_mma.partition_C(cACC)
for i in cutlass.range_constexpr(cute.size(tACCcACC.shape[0])):
for j in cutlass.range_constexpr(cute.size(tACCcACC.shape[1])):
for k in cutlass.range_constexpr(cute.size(tACCcACC.shape[2])):
if cute.elem_less(
tACCcACC[i, j, k][0], tACCcACC[i, j, k][1]
if (n_block - n_block_idx) < cute.ceil_div(
self._m_block_size, self._n_block_size
):
tACCcACC = thr_mma.partition_C(cACC)
for i in cutlass.range_constexpr(cute.size(tACCcACC.shape[0])):
for j in cutlass.range_constexpr(cute.size(tACCcACC.shape[1])):
for k in cutlass.range_constexpr(
cute.size(tACCcACC.shape[2])
):
acc_S[i, j, k] = 0.0
if cute.elem_less(
tACCcACC[i, j, k][0], tACCcACC[i, j, k][1]
):
acc_S[i, j, k] = 0.0
rP = cute.make_rmem_tensor_like(acc_S, self._dtype)
rP.store(acc_S.load().to(self._dtype))
@@ -803,35 +908,39 @@ class HSTUAttentionForwardAmpere(object):
tOsO,
tOrO,
)
# predicate for O
mcO = cute.make_identity_tensor(mO.layout.shape)
cO = cute.local_tile(
mcO[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
tOcO = gmem_thr_copy_O.partition_D(cO)
tOpO = cute.make_rmem_tensor(
cute.make_layout(
(tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]),
stride=(tOgO.shape[2], 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
tOpO[rest_v, 0, rest_n] = cute.elem_less(
tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
)
# copy acc O from rmem to gmem
for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])):
if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]):
cute.copy(
gmem_tiled_copy_O,
tOrO[None, rest_m, None],
tOgO[None, rest_m, None],
pred=tOpO[None, rest_m, None],
if cutlass.const_expr(need_q_predicates):
mcO = cute.make_identity_tensor(mO.layout.shape)
cO = cute.local_tile(
mcO[batch_size, None, num_head, None],
(self._m_block_size, self._head_dim_padded),
(m_block, 0),
)
tOcO = gmem_thr_copy_O.partition_D(cO)
if cutlass.const_expr(has_head_dim_residue):
tOpO = cute.make_rmem_tensor(
cute.make_layout(
(tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]),
stride=(tOgO.shape[2], 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
tOpO[rest_v, 0, rest_n] = cute.elem_less(
tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
)
self._copy_with_residue(
gmem_tiled_copy_O,
tOrO[None, None, None],
tOgO[None, None, None],
tOcO if has_q_residue else None,
tOpO if has_head_dim_residue else None,
has_q_residue,
has_head_dim_residue,
mO.layout.shape[1],
(m_block + 1) * self._m_block_size,
fill_zero_on_oob=False,
)
def run_pytorch_hstu_test(

View File

@@ -27,7 +27,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import enum
import math
import os
import sys
@@ -42,12 +41,20 @@ import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.nvgpu.common import OperandMajorMode
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import Int32, Float32, Float8E4M3FN, Float16, BFloat16, Boolean
from cutlass.cute.typing import (
Int32,
Float32,
Float8E4M3FN,
Float16,
BFloat16,
Boolean,
)
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -74,7 +81,7 @@ To run this example:
.. code-block:: bash
python examples/blackwell/fmha_bwd.py \\
python examples/cute/blackwell/kernel/attention/fmha/fmha_bwd.py \\
--s_q_max 1024 --s_k_max 1024 \\
--h_q 8 --h_k 8 --d 128 --b 1 \\
--element_dtype float16 --acc_dtype float32 \\
@@ -193,6 +200,22 @@ class BlackwellFusedMultiHeadAttentionBackward:
barrier_id=5,
num_threads=self.num_reduce_warps * self.threads_per_warp,
)
self.load_reduce_tma_sync_barrier_0 = pipeline.NamedBarrier(
barrier_id=6,
num_threads=2 * self.threads_per_warp,
)
self.load_reduce_tma_sync_barrier_1 = pipeline.NamedBarrier(
barrier_id=7,
num_threads=2 * self.threads_per_warp,
)
self.load_reduce_tma_sync_barrier_2 = pipeline.NamedBarrier(
barrier_id=8,
num_threads=2 * self.threads_per_warp,
)
self.load_reduce_tma_sync_barrier_3 = pipeline.NamedBarrier(
barrier_id=9,
num_threads=2 * self.threads_per_warp,
)
self.tmem_dK_offset = 0
self.tmem_dV_offset = self.tmem_dK_offset + mma_tiler[2]
@@ -200,8 +223,8 @@ class BlackwellFusedMultiHeadAttentionBackward:
self.tmem_dP_offset = self.tmem_dQ_offset
self.tmem_S_offset = self.tmem_dQ_offset + max(mma_tiler[0], mma_tiler[2])
self.num_regs_reduce = 152
self.num_regs_compute = 128
self.num_regs_reduce = 144
self.num_regs_compute = 136
self.num_regs_mma = 96
self.num_regs_empty = 96
self.num_regs_load = 96
@@ -344,17 +367,17 @@ class BlackwellFusedMultiHeadAttentionBackward:
self.dO_major_mode = utils.LayoutEnum.from_tensor(dO).mma_major_mode()
self.dQ_layout = utils.LayoutEnum.from_tensor(dQ)
if cutlass.const_expr(self.Q_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.Q_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of q is not supported")
if cutlass.const_expr(self.dQ_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.dQ_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of dq is not supported")
if cutlass.const_expr(self.K_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.K_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of k is not supported")
if cutlass.const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.dK_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of dk is not supported")
if cutlass.const_expr(self.V_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.V_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of v is not supported")
if cutlass.const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K):
if cutlass.const_expr(self.dV_major_mode != OperandMajorMode.K):
raise RuntimeError("The layout of dv is not supported")
self._setup_attributes()
@@ -364,8 +387,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# compute S
KQ_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.element_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
self.element_dtype,
OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
cta_group,
self.KQ_mma_tiler[:2],
@@ -373,8 +397,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# compute dP
VdO_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.element_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
self.element_dtype,
OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
cta_group,
self.VdO_mma_tiler[:2],
@@ -382,8 +407,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# compute dV
PdO_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.element_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.MN,
self.element_dtype,
OperandMajorMode.K,
OperandMajorMode.MN,
self.acc_dtype,
cta_group,
self.PdO_mma_tiler[:2],
@@ -392,8 +418,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# compute dK
dSQ_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.element_dtype,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.MN,
self.element_dtype,
OperandMajorMode.K,
OperandMajorMode.MN,
self.acc_dtype,
cta_group,
self.dSQ_mma_tiler[:2],
@@ -401,8 +428,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# compute dQ
dSK_tiled_mma = sm100_utils.make_trivial_tiled_mma(
self.element_dtype,
tcgen05.OperandMajorMode.MN,
tcgen05.OperandMajorMode.MN,
self.element_dtype,
OperandMajorMode.MN,
OperandMajorMode.MN,
self.acc_dtype,
cta_group,
self.dSK_mma_tiler[:2],
@@ -483,7 +511,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
dQ_smem_layout_atom = sm100_utils.make_smem_layout_atom(
sm100_utils.get_smem_layout_atom_ab(
tcgen05.OperandMajorMode.K,
OperandMajorMode.K,
self.acc_dtype,
(self.tile_shape_Q, 32),
),
@@ -844,9 +872,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
LSE_smem_layout: cute.Layout,
sum_OdO_smem_layout: cute.Layout,
):
tidx, tidy, tidz = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
grid_dim_x, grid_dim_y, grid_dim_z = cute.arch.grid_dim()
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
if warp_idx == self.load_warp_id:
@@ -1265,12 +1291,10 @@ class BlackwellFusedMultiHeadAttentionBackward:
# (load_mma_Q_pipeline, load_compute_LSE_pipeline, load_mma_dO_pipeline, load_compute_sum_OdO_pipeline)
pipeline_args: tuple,
):
tidx, tidy, tidz = cute.arch.thread_idx()
tidx = cute.arch.thread_idx()[0]
blk_coord_k, blk_coord_h_k, blk_coord_b = cute.arch.block_idx()
blk_coord_h_r = Int32(0)
blk_coord_h = (blk_coord_h_r, blk_coord_h_k)
seq_Q, seq_K, D, HB = problem_shape
H, B = HB
iter_index = iter_start
(
load_mma_Q_pipeline,
@@ -1362,6 +1386,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
load_compute_sum_OdO_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.load_compute_sum_OdO_stage
)
load_mma_Q_pipeline.producer_acquire(load_mma_Q_producer_state)
tma_barrier = load_mma_Q_pipeline.producer_get_barrier(
load_mma_Q_producer_state
@@ -1396,7 +1421,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp
atom_async_copy = cute.make_copy_atom(
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
cpasync.CopyG2SOp(cache_mode=cute.nvgpu.LoadCacheMode.ALWAYS),
self.acc_dtype,
num_bits_per_copy=self.acc_dtype.width,
)
@@ -1486,6 +1511,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
iter_count -= 1
iter_index += 1
load_reduce_tma_sync_phase = Int32(0)
while iter_count > 0:
if iter_index == iter_end:
iter_index = iter_start
@@ -1507,6 +1533,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
load_mma_Q_producer_state.advance()
self.load_reduce_tma_sync_arrive(load_reduce_tma_sync_phase)
load_reduce_tma_sync_phase += 1
load_compute_LSE_pipeline.producer_acquire(load_compute_LSE_producer_state)
# Load LSE
@@ -1589,6 +1618,30 @@ class BlackwellFusedMultiHeadAttentionBackward:
iter_count -= 1
iter_index += 1
@cute.jit
def load_reduce_tma_sync_arrive(self, phase: Int32):
phase_mod = phase % 4
if phase_mod == 0:
self.load_reduce_tma_sync_barrier_0.arrive()
elif phase_mod == 1:
self.load_reduce_tma_sync_barrier_1.arrive()
elif phase_mod == 2:
self.load_reduce_tma_sync_barrier_2.arrive()
else:
self.load_reduce_tma_sync_barrier_3.arrive()
@cute.jit
def load_reduce_tma_sync_wait(self, phase: Int32):
phase_mod = phase % 4
if phase_mod == 0:
self.load_reduce_tma_sync_barrier_0.arrive_and_wait()
elif phase_mod == 1:
self.load_reduce_tma_sync_barrier_1.arrive_and_wait()
elif phase_mod == 2:
self.load_reduce_tma_sync_barrier_2.arrive_and_wait()
else:
self.load_reduce_tma_sync_barrier_3.arrive_and_wait()
@cute.jit
def mma(
self,
@@ -1909,11 +1962,10 @@ class BlackwellFusedMultiHeadAttentionBackward:
# (mma_compute_S_pipeline, compute_mma_P_pipeline, load_compute_LSE_pipeline, load_compute_sum_OdO_pipeline, mma_compute_dP_pipeline, compute_mma_dS_pipeline, mma_compute_dKdV_pipeline)
pipeline_args: tuple,
):
tidx, tidy, tidz = cute.arch.thread_idx()
bidx, bidy, bidz = cute.arch.block_idx()
tidx = cute.arch.thread_idx()[0]
Q, K, D, HB = problem_shape
blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord
Q, K = problem_shape[0], problem_shape[1]
blk_coord_k, blk_coord_batch = blk_coord[1], blk_coord[3]
iter_index = iter_start
(
mma_compute_S_pipeline,
@@ -2205,10 +2257,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# (mma_reduce_dQ_pipeline, reduce_tma_store_pipeline)
pipeline_args: tuple,
):
tidx, tidy, tidz = cute.arch.thread_idx()
tidx = cute.arch.thread_idx()[0]
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
Q, K, D, HB = problem_shape
blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord
blk_coord_batch = blk_coord[3]
blk_coord_h, blk_coord_b = blk_coord_batch
blk_coord_h_r, blk_coord_h_k = blk_coord_h
@@ -2241,7 +2292,6 @@ class BlackwellFusedMultiHeadAttentionBackward:
thr_t2r = tiled_t2r.get_slice(thread_idx)
tTR_cdQ = thr_t2r.partition_D(cdQ)
tTR_gdQ = thr_t2r.partition_D(gdQ)
tTR_sdQ = thr_t2r.partition_D(sdQ)
tTR_tdQ = thr_t2r.partition_S(tdQtdQ)
@@ -2253,6 +2303,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
cute.group_modes(gdQ, 0, 2),
)
load_reduce_tma_sync_phase = Int32(0)
while iter_count > 0:
mma_reduce_dQ_pipeline.consumer_wait(mma_reduce_dQ_consumer_state)
@@ -2286,6 +2337,10 @@ class BlackwellFusedMultiHeadAttentionBackward:
self.reduce_sync_barrier.arrive_and_wait()
if warp_idx == 0:
if iter_count > 1 and i == 0:
self.load_reduce_tma_sync_wait(load_reduce_tma_sync_phase)
load_reduce_tma_sync_phase += 1
cute.copy(
tma_atom_dQ_acc,
tdQsdQ[None, reduce_tma_store_producer_state.index],
@@ -2346,7 +2401,6 @@ class BlackwellFusedMultiHeadAttentionBackward:
input: cute.Tensor,
frg_cnt: Int32,
) -> cute.Tensor:
tidx, tidy, tidz = cute.arch.thread_idx()
output = cute.make_rmem_tensor(input.shape, self.element_dtype)
frg_tile = cute.size(input) // frg_cnt
t_frg = cute.logical_divide(input, cute.make_layout(frg_cnt))
@@ -2406,9 +2460,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
# (mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state)
pipeline_args: tuple,
):
tidx, tidy, tidz = cute.arch.thread_idx()
Q, K, D, HB = problem_shape
blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord
tidx = cute.arch.thread_idx()[0]
K, D, HB = problem_shape[1], problem_shape[2], problem_shape[3]
blk_coord_k, blk_coord_batch = blk_coord[1], blk_coord[3]
mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state = pipeline_args
load_op = cute.make_copy_atom(
@@ -2511,7 +2565,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
workspace: cute.Tensor,
acc_dtype: Type[cutlass.Numeric],
) -> Tuple[cute.Tensor, cute.Tensor, cute.Tensor]:
Q, D, HB = (
Q, D, _HB = (
problem_shape[0],
problem_shape[2],
problem_shape[3],
@@ -2524,7 +2578,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
acc_bytes = acc_dtype.width // 8
sum_OdO_bytes = cute.assume(B * H * Q * acc_bytes, divby=acc_bytes)
scaled_lse_bytes = cute.assume(B * H * Q * acc_bytes, divby=acc_bytes)
dQ_acc_bytes = cute.assume(B * H * Q * D * acc_bytes, divby=acc_bytes)
cute.assume(B * H * Q * D * acc_bytes, divby=acc_bytes)
sum_OdO_iter = workspace.iterator
scaled_lse_iter = sum_OdO_iter + sum_OdO_bytes
@@ -2899,7 +2953,9 @@ def run(
window_size_right,
bottom_right_align,
):
raise ValueError("sliding window doesn't support current setting")
raise testing.CantImplementError(
"sliding window doesn't support current setting"
)
# create sequence lengths for variable length inputs
cumulative_s_q = [0]
@@ -2985,10 +3041,10 @@ def run(
lse_ref = cutlass_torch.create_and_permute_torch_tensor(
(b, h_k, h_r, s_q),
cutlass.torch.dtype(acc_dtype),
cutlass_torch.dtype(acc_dtype),
permute_order=(3, 2, 1, 0),
init_type=cutlass.torch.TensorInitType.RANDOM,
init_config=cutlass.torch.RandomInitConfig(min_val=10, max_val=11),
init_type=cutlass_torch.TensorInitType.RANDOM,
init_config=cutlass_torch.RandomInitConfig(min_val=10, max_val=11),
)
lse_torch = lse_ref.cuda()
lse_tensor = from_dlpack(lse_torch, assumed_align=16)
@@ -2997,7 +3053,11 @@ def run(
mma_tiler = (*mma_tiler_mn, d)
fmha_bwd = BlackwellFusedMultiHeadAttentionBackward(
element_dtype, acc_dtype, mma_tiler, varlen, mask_type
element_dtype,
acc_dtype,
mma_tiler,
varlen,
mask_type,
)
workspace_size = BlackwellFusedMultiHeadAttentionBackward._get_workspace_size(
@@ -3112,11 +3172,11 @@ def run(
torch.cuda.synchronize()
print("Verifying results...")
q_ref = q_ref.cuda().to(cutlass.torch.dtype(element_dtype))
k_ref = k_ref.cuda().to(cutlass.torch.dtype(element_dtype))
v_ref = v_ref.cuda().to(cutlass.torch.dtype(element_dtype))
o_ref = o_ref.cuda().to(cutlass.torch.dtype(element_dtype))
do_ref = do_ref.cuda().to(cutlass.torch.dtype(element_dtype))
q_ref = q_ref.cuda().to(cutlass_torch.dtype(element_dtype))
k_ref = k_ref.cuda().to(cutlass_torch.dtype(element_dtype))
v_ref = v_ref.cuda().to(cutlass_torch.dtype(element_dtype))
o_ref = o_ref.cuda().to(cutlass_torch.dtype(element_dtype))
do_ref = do_ref.cuda().to(cutlass_torch.dtype(element_dtype))
dv = dv_torch.to(dtype=torch.float32)
dk = dk_torch.to(dtype=torch.float32)
dq = dq_torch.to(dtype=torch.float32)
@@ -3227,10 +3287,10 @@ def run(
lse_ref_new = cutlass_torch.create_and_permute_torch_tensor(
(b, h_k, h_r, s_q),
cutlass.torch.dtype(acc_dtype),
cutlass_torch.dtype(acc_dtype),
permute_order=(3, 2, 1, 0),
init_type=cutlass.torch.TensorInitType.RANDOM,
init_config=cutlass.torch.RandomInitConfig(min_val=10, max_val=11),
init_type=cutlass_torch.TensorInitType.RANDOM,
init_config=cutlass_torch.RandomInitConfig(min_val=10, max_val=11),
)
lse_torch_new = lse_ref_new.cuda()
lse_tensor_new = from_dlpack(lse_torch_new, assumed_align=16)

View File

@@ -27,54 +27,21 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import torch
import argparse
import numpy as np
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import cuda.bindings.driver as cuda
try:
from cuda.core import Device
except ImportError:
from cuda.core.experimental import Device
from cuda.pathfinder import load_nvidia_dynamic_lib
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
from cutlass import testing
from cutlass.cute.runtime import from_dlpack
from cutlass.cutlass_dsl import T
from cutlass._mlir.dialects import vector
try:
import nvshmem.core
except ImportError as exc:
raise ImportError(
"nvshmem4py is required but not installed. Please install it using:\n"
" For CUDA 12: pip install nvshmem4py-cu12\n"
" For CUDA 13: pip install nvshmem4py-cu13\n"
"Note: nvshmem4py version >= 0.1.3 is recommended."
) from None
try:
load_nvidia_dynamic_lib("nvshmem_host")
except RuntimeError as exc:
raise ImportError(
"nvshmem lib is required but not installed. Please install it using:\n"
" For CUDA 12: pip install nvidia-nvshmem-cu12\n"
" For CUDA 13: pip install nvidia-nvshmem-cu13\n"
) from None
"""
A Distributed One-Shot All-Reduce Example using CuTe DSL and fine-grained memory control. This is a mirrored version of the
existing tensorrt_llm kernel:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
In Lamport terminology this is a classic flag-based busy-wait: every participant keeps polling the shared slot until the
flag changes from the sentinel (negative zero) to real data, which indicates that the Lamport-style logical ordering has
advanced and the payload is safe to consume.
This example kernel demonstrates a one-shot all-reduce operation using the CuTe DSL with fine-grained memory control.
It uses dedicated communication buffers for data exchange, and these buffers act as ping-pong buffers. During the
process, the kernel uses one buffer for communication and initializes the next buffer to all negative zeros.
@@ -90,8 +57,8 @@ The .SYS memory scope and .VOLATILE memory order are used to ensure that the dat
.. code-block:: bash
torchrun --nproc-per-node 8 examples/distributed/all_reduce_one_shot_lamport.py --M 8192 --N 8192
torchrun --nproc-per-node 8 examples/distributed/all_reduce_one_shot_lamport.py \
torchrun --nproc-per-node 8 examples/cute/blackwell/distributed/all_reduce_one_shot_lamport.py --M 8192 --N 8192
torchrun --nproc-per-node 8 examples/cute/blackwell/distributed/all_reduce_one_shot_lamport.py \
--M 8192 --N 8192 --benchmark --warmup_iterations 2 --iterations 10
"""
@@ -174,14 +141,14 @@ class AllReduceOneShotLamportKernel:
# assume all buffers have the same element type with input
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
cute.nvgpu.CopyG2ROp(),
buffers[0].element_type,
num_bits_per_copy=128,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
)
copy_atom_store = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
cute.nvgpu.CopyR2GOp(),
buffers[0].element_type,
num_bits_per_copy=128,
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
@@ -276,6 +243,10 @@ def run_all_reduce_one_shot(
skip_ref_check=False,
benchmark=True,
):
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if rank == 0:
@@ -284,8 +255,20 @@ def run_all_reduce_one_shot(
print(f"GPU count: {world_size}")
# init buffer tensors to be neg 0
local_buffer_tensor = nvshmem.core.tensor([PING_PONG_SIZE, world_size, M, N,], dtype=torch.float32).neg_()
buffer_tensor_list = [nvshmem.core.get_peer_tensor(local_buffer_tensor, rank).permute(2, 3, 1, 0) for rank in range(world_size)]
t = symm_mem.empty(
[
PING_PONG_SIZE,
world_size,
M,
N,
],
device="cuda",
).neg_()
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
buffer_tensor_list = [
hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 3, 1, 0)
for rank in range(world_size)
]
signal = cutlass.Int32(0)
input_tensor = torch.randn([M, N], device=f"cuda:{rank}")
output_tensor = torch.zeros([M, N], device=f"cuda:{rank}")
@@ -319,33 +302,40 @@ def run_all_reduce_one_shot(
if rank == 0:
print("Results verified successfully!")
for t in buffer_tensor_list:
nvshmem.core.free_tensor(t)
if not benchmark:
return
free_func_and_tensor_pairs = []
def add_free_func_and_tensor(free_func, tensor):
free_func_and_tensor_pairs.append((free_func, tensor))
def generate_tensors():
local_buffer = nvshmem.core.tensor([PING_PONG_SIZE, world_size, M, N,], dtype=torch.float32).neg_()
buffer_tensor_list = [nvshmem.core.get_peer_tensor(local_buffer, rank).permute(2, 3, 1, 0) for rank in range(world_size)]
input_tensor = torch.randn([M, N], device=f"cuda:{rank}")
output_tensor = torch.zeros([M, N], device=f"cuda:{rank}")
t = symm_mem.empty(
[
PING_PONG_SIZE,
world_size,
M * N,
],
device="cuda",
).neg_()
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD.group_name)
# get tensors from other devices from the symmetric memory
buffers = [
hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 1, 0)
for rank in range(world_size)
]
input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")
ja = testing.JitArguments(
cutlass.Int32(0),
from_dlpack(input_tensor, assumed_align=32),
from_dlpack(output_tensor, assumed_align=32),
[from_dlpack(t, assumed_align=32) for t in buffer_tensor_list],
stream=stream
[from_dlpack(t, assumed_align=32) for t in buffers],
stream=stream,
)
for tensor in buffer_tensor_list:
add_free_func_and_tensor(nvshmem.core.free_tensor, tensor)
ja._hdl = (
hdl # in order to extend the life cycle of hdl for the kernel execution
)
ja._t = t # same reason
return ja
avg_time_us = testing.benchmark(
compiled_func,
workspace_generator=generate_tensors,
@@ -361,47 +351,39 @@ def run_all_reduce_one_shot(
f"Achieved memory throughput: {((world_size + 1) * output_tensor.numel() * 32 // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
)
for free_func, tensor in free_func_and_tensor_pairs:
free_func(tensor)
def torchrun_uid_init_bcast():
"""
Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
def run(
M,
N,
warmup_iterations=2,
iterations=10,
skip_ref_check=False,
benchmark=True,
):
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
It uses torch.distributed.broadcast on a NumPy array to handle the broadcasting
"""
# Set Torch device
local_rank = int(os.environ['LOCAL_RANK'])
globals()["torch"] = torch
globals()["dist"] = dist
globals()["symm_mem"] = symm_mem
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
# nvshmem4py requires a cuda.core Device at init time
dev = Device(local_rank)
dev.set_current()
global stream
stream = dev.create_stream()
# Initialize torch.distributed process group
dist.init_process_group(
backend="cpu:gloo,cuda:nccl",
)
# Extract rank, nranks from process group
num_ranks = dist.get_world_size()
# Create an empty uniqueid for all ranks
uid = nvshmem.core.get_unique_id(empty=(local_rank != 0))
uid_bytes = uid._data.view(np.uint8).copy()
uid_tensor = torch.from_numpy(uid_bytes).cuda()
dist.broadcast(uid_tensor, src=0)
dist.barrier()
uid._data[:] = uid_tensor.cpu().numpy().view(uid._data.dtype)
nvshmem.core.init(device=dev, uid=uid, rank=local_rank, nranks=num_ranks, initializer_method="uid")
def torchrun_finalize():
nvshmem.core.finalize()
dist.destroy_process_group()
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
try:
run_all_reduce_one_shot(
M,
N,
warmup_iterations,
iterations,
skip_ref_check,
benchmark,
)
finally:
if dist.is_initialized():
dist.destroy_process_group()
def main():
@@ -414,13 +396,17 @@ def main():
parser.add_argument("--iterations", default=10, type=int)
parser.add_argument("--skip_ref_check", action="store_true")
parser.add_argument("--benchmark", action="store_true")
args = parser.parse_args()
torchrun_uid_init_bcast()
run_all_reduce_one_shot(args.M, args.N, args.warmup_iterations, args.iterations, args.skip_ref_check, args.benchmark)
torchrun_finalize()
run(
args.M,
args.N,
args.warmup_iterations,
args.iterations,
args.skip_ref_check,
args.benchmark,
)
return

View File

@@ -269,8 +269,8 @@ class GroupedGemmKernel:
c_bytes = c_bytes_per_stage * self.num_c_stage
self.num_sched_stages = 2
sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32)
sched_bytes = sched_work_tile_bytes_per_stage * self.num_sched_stages
self.sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32)
sched_bytes = self.sched_work_tile_bytes_per_stage * self.num_sched_stages
fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes
@@ -774,7 +774,11 @@ class GroupedGemmKernel:
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4]
sched_buf_align_bytes = self.sched_work_tile_bytes_per_stage
sched_buf: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4],
sched_buf_align_bytes,
]
sched_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_sched_stages * 2
]
@@ -2010,9 +2014,6 @@ if __name__ == "__main__":
compare_with_bmm=args.compare_with_bmm,
compare_with_sol=args.compare_with_sol,
)
if misc.no_torch_210:
misc.compare_with_bmm = True
print("Override to set --compare_with_bmm to avoid possible torch crash.")
tester = GroupedGemmTester(problem, impl, misc)
tester.run()

View File

@@ -359,6 +359,7 @@ class ScaledGroupedGemmKernel:
)
self.num_sched_stages = 2
self.sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32)
# ── SMEM layouts ──
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
@@ -1371,7 +1372,11 @@ class ScaledGroupedGemmKernel:
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_pipeline_stages * 2
]
sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4]
sched_buf_align_bytes = self.sched_work_tile_bytes_per_stage
sched_buf: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4],
sched_buf_align_bytes,
]
sched_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_sched_stages * 2
]

View File

@@ -1,12 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# This is the third tutorial GEMM. It further enhances the second tutorial by adding warp
# specialization for TMA, MMA, and epilogue warps.
@@ -32,7 +50,7 @@ The dynamic scheduler is more flexible than the static scheduler, as it can hand
To run this example:
.. code-block:: bash
python examples/blackwell/tutorial_gemm/fp16_gemm_3_1.py \
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_3_1.py \
--mnk 8192,8192,8192
Constraints for this example:
@@ -72,29 +90,32 @@ class SharedStorage:
tmem_holding_buffer: cutlass.Int32
# Only for CLC Dynamic Scheduler
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
clc_response_align_bytes = num_clc_response_bytes
clc_response: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 4],
clc_response_align_bytes,
]
@cute.kernel()
def kernel(
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
a_smem_layout: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
tma_a: cpasync.TmaInfo,
tma_b: cpasync.TmaInfo,
tma_c: cpasync.TmaInfo,
c_smem_layout_kind: cutlass.Constexpr,
epi_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
cta_layout_vmnk: cute.Layout,
tile_sched_params: Union[
utils.ClcDynamicPersistentTileSchedulerParams,
utils.PersistentTileSchedulerParams,
],
):
):
# Extract tma_tensor from TmaInfo
mA_mkl = tma_a.tma_tensor
mB_nkl = tma_b.tma_tensor
mC_mnl = tma_c.tma_tensor
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
@@ -123,9 +144,9 @@ def kernel(
# Prefetch tma descriptor
if warp_idx == tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
cpasync.prefetch_descriptor(tma_a.atom)
cpasync.prefetch_descriptor(tma_b.atom)
cpasync.prefetch_descriptor(tma_c.atom)
# As many participants as the number of threads issuing the MMA in the same row and column
# Substract one to not count twice the same thread
@@ -167,8 +188,8 @@ def kernel(
)
num_tma_copy_bytes = (
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
) * cute.size(cta_layout_vmnk, mode=[0])
# Threads/warps participating in the mainloop pipeline
@@ -248,21 +269,21 @@ def kernel(
# Allocate SMEM
sA = smem.allocate_tensor(
element_type=io_dtype,
layout=a_smem_layout.outer,
layout=tma_a.smem_layout.outer,
byte_alignment=128,
swizzle=a_smem_layout.inner,
swizzle=tma_a.smem_layout.inner,
)
sB = smem.allocate_tensor(
element_type=io_dtype,
layout=b_smem_layout.outer,
layout=tma_b.smem_layout.outer,
byte_alignment=128,
swizzle=b_smem_layout.inner,
swizzle=tma_b.smem_layout.inner,
)
sC = smem.allocate_tensor(
element_type=io_dtype,
layout=epi_smem_layout_staged.outer,
layout=tma_c.smem_layout.outer,
byte_alignment=128,
swizzle=epi_smem_layout_staged.inner,
swizzle=tma_c.smem_layout.inner,
)
# Partition tensors for MMA and make fragments
@@ -301,7 +322,7 @@ def kernel(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK)
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
tma_a.atom,
cta_in_cluster_coord_vmnk[2],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
cute.group_modes(sA, 0, 3),
@@ -311,7 +332,7 @@ def kernel(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestN, RestK)
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
tma_b.atom,
cta_in_cluster_coord_vmnk[1],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
cute.group_modes(sB, 0, 3),
@@ -321,7 +342,7 @@ def kernel(
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
tma_c.atom,
0,
cute.make_layout(1),
cute.group_modes(sC, 0, 2),
@@ -379,14 +400,14 @@ def kernel(
# Issue TMA loads
cute.copy(
tma_atom_a,
tma_a.atom,
tAgA_slice[(None, k_tile_idx)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=tma_mcast_mask_a,
)
cute.copy(
tma_atom_b,
tma_b.atom,
tBgB_slice[(None, k_tile_idx)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
@@ -447,23 +468,14 @@ def kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()
@@ -496,10 +508,10 @@ def kernel(
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
# Initialize TMA store pipeline for epilogue
# Initialize TMA store pipeline for epilogue with 4 warps
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
size=128,
pipeline.Agent.Warp,
size=4,
)
epilogue_pipeline = pipeline.PipelineTmaStore.create(
num_stages=epi_stages,
@@ -594,7 +606,7 @@ def kernel(
# SMEM -> GMEM
if warp_idx == epilogue_warp_ids[0]:
cute.copy(
tma_atom_c,
tma_c.atom,
tCsC[(None, c_buffer)],
tCgC_grouped[(None, subtile_idx)],
)
@@ -678,8 +690,8 @@ def host_function(
mma_inst_shape_mnk,
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
tcgen05.OperandSource.SMEM,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(op)
@@ -717,21 +729,18 @@ def host_function(
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
)
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
tma_a = cute.nvgpu.make_tiled_tma_atom_A(
op,
a,
a_smem_layout_slice,
a_smem_layout,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
)
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
tma_b = cute.nvgpu.make_tiled_tma_atom_B(
op,
b,
b_smem_layout_slice,
b_smem_layout,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
@@ -757,11 +766,10 @@ def host_function(
epi_stages,
)
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
c_tma_atom, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
epi_smem_layout_staged,
epi_tile,
)
@@ -779,16 +787,10 @@ def host_function(
kernel(
tiled_mma,
a_tma_atom,
a_tma_tensor,
b_tma_atom,
b_tma_tensor,
c_tma_atom,
c_tma_tensor,
a_smem_layout,
b_smem_layout,
tma_a,
tma_b,
tma_c,
c_smem_layout_kind,
epi_smem_layout_staged,
epi_tile,
cta_layout_vmnk,
tile_sched_params,

View File

@@ -1,12 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# This is the fourth tutorial GEMM (4). It extends fp16_gemm_3_1.py by adding TMA prefetch.
# TMA prefetch uses cute.prefetch() to bring data into L2 cache before TMA copy needs it,
@@ -55,7 +73,7 @@ CuTe DSL Blackwell SM100 kernels. Users can specify preferred and fallback clust
To run this example:
.. code-block:: bash
python examples/blackwell/tutorial_gemm/fp16_gemm_4.py \
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_4.py \
--mnk 8192,8192,8192
Constraints for this example:
@@ -96,22 +114,20 @@ class SharedStorage:
tmem_holding_buffer: cutlass.Int32
# Only for CLC Dynamic Scheduler
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
clc_response_align_bytes = num_clc_response_bytes
clc_response: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 4],
clc_response_align_bytes,
]
@cute.jit
def cluster_specific_kernel(
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
a_smem_layout: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
tma_a: cpasync.TmaInfo,
tma_b: cpasync.TmaInfo,
tma_c: cpasync.TmaInfo,
c_smem_layout_kind: cutlass.Constexpr,
epi_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
cta_layout_vmnk: cute.Layout,
cluster_shape_mnk: Tuple[int, int, int],
@@ -120,6 +136,11 @@ def cluster_specific_kernel(
utils.PersistentTileSchedulerParams,
],
):
# Extract tma_tensor from TmaInfo
mA_mkl = tma_a.tma_tensor
mB_nkl = tma_b.tma_tensor
mC_mnl = tma_c.tma_tensor
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
@@ -148,9 +169,9 @@ def cluster_specific_kernel(
# Prefetch tma descriptor
if warp_idx == tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
cpasync.prefetch_descriptor(tma_a.atom)
cpasync.prefetch_descriptor(tma_b.atom)
cpasync.prefetch_descriptor(tma_c.atom)
# As many participants as the number of threads issuing the MMA in the same row and column
# Substract one to not count twice the same thread
@@ -192,8 +213,8 @@ def cluster_specific_kernel(
)
num_tma_copy_bytes = (
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
) * cute.size(cta_layout_vmnk, mode=[0])
# Threads/warps participating in the mainloop pipeline
@@ -273,21 +294,21 @@ def cluster_specific_kernel(
# Allocate SMEM
sA = smem.allocate_tensor(
element_type=io_dtype,
layout=a_smem_layout.outer,
layout=tma_a.smem_layout.outer,
byte_alignment=128,
swizzle=a_smem_layout.inner,
swizzle=tma_a.smem_layout.inner,
)
sB = smem.allocate_tensor(
element_type=io_dtype,
layout=b_smem_layout.outer,
layout=tma_b.smem_layout.outer,
byte_alignment=128,
swizzle=b_smem_layout.inner,
swizzle=tma_b.smem_layout.inner,
)
sC = smem.allocate_tensor(
element_type=io_dtype,
layout=epi_smem_layout_staged.outer,
layout=tma_c.smem_layout.outer,
byte_alignment=128,
swizzle=epi_smem_layout_staged.inner,
swizzle=tma_c.smem_layout.inner,
)
# Partition tensors for MMA and make fragments
@@ -326,7 +347,7 @@ def cluster_specific_kernel(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK)
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
tma_a.atom,
cta_in_cluster_coord_vmnk[2],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
cute.group_modes(sA, 0, 3),
@@ -336,7 +357,7 @@ def cluster_specific_kernel(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestN, RestK)
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
tma_b.atom,
cta_in_cluster_coord_vmnk[1],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
cute.group_modes(sB, 0, 3),
@@ -346,7 +367,7 @@ def cluster_specific_kernel(
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
tma_c.atom,
0,
cute.make_layout(1),
cute.group_modes(sC, 0, 2),
@@ -403,14 +424,14 @@ def cluster_specific_kernel(
# Issue TMA loads
cute.copy(
tma_atom_a,
tma_a.atom,
tAgA_slice[(None, k_tile_idx)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=tma_mcast_mask_a,
)
cute.copy(
tma_atom_b,
tma_b.atom,
tBgB_slice[(None, k_tile_idx)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
@@ -471,23 +492,14 @@ def cluster_specific_kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()
@@ -520,10 +532,10 @@ def cluster_specific_kernel(
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
# Initialize TMA store pipeline for epilogue
# Initialize TMA store pipeline for epilogue with 4 warps
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
size=128,
pipeline.Agent.Warp,
size=4,
)
epilogue_pipeline = pipeline.PipelineTmaStore.create(
num_stages=epi_stages,
@@ -618,7 +630,7 @@ def cluster_specific_kernel(
# SMEM -> GMEM
if warp_idx == epilogue_warp_ids[0]:
cute.copy(
tma_atom_c,
tma_c.atom,
tCsC[(None, c_buffer)],
tCgC_grouped[(None, subtile_idx)],
)
@@ -652,20 +664,12 @@ def cluster_specific_kernel(
@cute.kernel
def kernel(
tiled_mma: cute.TiledMma,
tma_atom_a_preferred: cute.CopyAtom,
mA_mkl_preferred: cute.Tensor,
tma_atom_b_preferred: cute.CopyAtom,
mB_nkl_preferred: cute.Tensor,
tma_atom_a_fallback: cute.CopyAtom,
mA_mkl_fallback: cute.Tensor,
tma_atom_b_fallback: cute.CopyAtom,
mB_nkl_fallback: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
a_smem_layout: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
tma_a_preferred: cpasync.TmaInfo,
tma_b_preferred: cpasync.TmaInfo,
tma_a_fallback: cpasync.TmaInfo,
tma_b_fallback: cpasync.TmaInfo,
tma_c: cpasync.TmaInfo,
c_smem_layout_kind: cutlass.Constexpr,
epi_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
preferred_cta_layout_vmnk: cute.Layout,
fallback_cta_layout_vmnk: cute.Layout,
@@ -687,19 +691,15 @@ def kernel(
)
# As for now, only support preferred cluster kernel via the mega-kernel approach
# mega-kernel approach has 2 mutually exclusive code branches, only one path runs per launch,
# specify `smem_merge_branch_allocs=True` at launch to enables shared memory reuse between two paths
if is_preferred_cluster:
cluster_specific_kernel(
tiled_mma,
tma_atom_a_preferred,
mA_mkl_preferred,
tma_atom_b_preferred,
mB_nkl_preferred,
tma_atom_c,
mC_mnl,
a_smem_layout,
b_smem_layout,
tma_a_preferred,
tma_b_preferred,
tma_c,
c_smem_layout_kind,
epi_smem_layout_staged,
epi_tile,
preferred_cta_layout_vmnk,
preferred_cluster_shape_mnk,
@@ -708,16 +708,10 @@ def kernel(
else:
cluster_specific_kernel(
tiled_mma,
tma_atom_a_fallback,
mA_mkl_fallback,
tma_atom_b_fallback,
mB_nkl_fallback,
tma_atom_c,
mC_mnl,
a_smem_layout,
b_smem_layout,
tma_a_fallback,
tma_b_fallback,
tma_c,
c_smem_layout_kind,
epi_smem_layout_staged,
epi_tile,
fallback_cta_layout_vmnk,
fallback_cluster_shape_mnk,
@@ -814,8 +808,8 @@ def host_function(
mma_inst_shape_mnk,
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
tcgen05.OperandSource.SMEM,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(op)
@@ -860,37 +854,35 @@ def host_function(
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
)
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
tma_atom_a_fallback, a_tma_tensor_fallback = cute.nvgpu.make_tiled_tma_atom_A(
tma_a_fallback = cute.nvgpu.make_tiled_tma_atom_A(
op,
a,
a_smem_layout_slice,
a_smem_layout,
mma_tiler_mnk,
tiled_mma,
fallback_cta_layout_vmnk.shape,
)
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
tma_atom_b_fallback, b_tma_tensor_fallback = cute.nvgpu.make_tiled_tma_atom_B(
tma_b_fallback = cute.nvgpu.make_tiled_tma_atom_B(
op,
b,
b_smem_layout_slice,
b_smem_layout,
mma_tiler_mnk,
tiled_mma,
fallback_cta_layout_vmnk.shape,
)
tma_atom_a_preferred, a_tma_tensor_preferred = cute.nvgpu.make_tiled_tma_atom_A(
tma_a_preferred = cute.nvgpu.make_tiled_tma_atom_A(
op,
a,
a_smem_layout_slice,
a_smem_layout,
mma_tiler_mnk,
tiled_mma,
preferred_cta_layout_vmnk.shape,
)
tma_atom_b_preferred, b_tma_tensor_preferred = cute.nvgpu.make_tiled_tma_atom_B(
tma_b_preferred = cute.nvgpu.make_tiled_tma_atom_B(
op,
b,
b_smem_layout_slice,
b_smem_layout,
mma_tiler_mnk,
tiled_mma,
preferred_cta_layout_vmnk.shape,
@@ -915,12 +907,11 @@ def host_function(
epi_tile,
epi_stages,
)
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
epi_smem_layout_staged,
epi_tile,
)
@@ -945,20 +936,12 @@ def host_function(
kernel(
tiled_mma,
tma_atom_a_preferred,
a_tma_tensor_preferred,
tma_atom_b_preferred,
b_tma_tensor_preferred,
tma_atom_a_fallback,
a_tma_tensor_fallback,
tma_atom_b_fallback,
b_tma_tensor_fallback,
tma_atom_c,
c_tma_tensor,
a_smem_layout,
b_smem_layout,
tma_a_preferred,
tma_b_preferred,
tma_a_fallback,
tma_b_fallback,
tma_c,
c_smem_layout_kind,
epi_smem_layout_staged,
epi_tile,
preferred_cta_layout_vmnk,
fallback_cta_layout_vmnk,
@@ -969,6 +952,7 @@ def host_function(
block=[224, 1, 1] if use_clc_dynamic_scheduler else [192, 1, 1],
cluster=preferred_cluster_shape_mnk,
fallback_cluster=fallback_cluster_shape_mnk,
smem_merge_branch_allocs=True,
)
@@ -981,7 +965,7 @@ def run_dense_gemm(
import cutlass.torch as cutlass_torch
print("===================================================================")
print("Running Blackwell fp16 GEMM example 4 (with MIX cluster size support):")
print("Running Blackwell fp16 GEMM example 4 (with MIX cluster support):")
print(f" mnk: {mnk}")
print(f" tolerance: {tolerance}")
print(f" Preferred cluster shape: {preferred_cluster_shape_mnk}")

View File

@@ -1,12 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# This is the fifth tutorial GEMM (5). It extends fp16_gemm_3_1.py by adding TMA prefetch.
# TMA prefetch uses cute.prefetch() to bring data into L2 cache before TMA copy needs it,
@@ -44,7 +62,7 @@ Key differences from fp16_gemm_3_1.py:
To run this example:
.. code-block:: bash
python examples/blackwell/tutorial_gemm/fp16_gemm_5.py \
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_5.py \
--mnk 8192,8192,8192
Constraints for this example:
@@ -84,22 +102,20 @@ class SharedStorage:
tmem_holding_buffer: cutlass.Int32
# Only for CLC Dynamic Scheduler
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
clc_response_align_bytes = num_clc_response_bytes
clc_response: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 4],
clc_response_align_bytes,
]
@cute.kernel()
def kernel(
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
a_smem_layout: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
tma_a: cpasync.TmaInfo,
tma_b: cpasync.TmaInfo,
tma_c: cpasync.TmaInfo,
c_smem_layout_kind: cutlass.Constexpr,
epi_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
cta_layout_vmnk: cute.Layout,
tile_sched_params: Union[
@@ -107,6 +123,11 @@ def kernel(
utils.PersistentTileSchedulerParams,
],
):
# Extract tma_tensor from TmaInfo
mA_mkl = tma_a.tma_tensor
mB_nkl = tma_b.tma_tensor
mC_mnl = tma_c.tma_tensor
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
@@ -135,9 +156,9 @@ def kernel(
# Prefetch tma descriptor
if warp_idx == tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
cpasync.prefetch_descriptor(tma_a.atom)
cpasync.prefetch_descriptor(tma_b.atom)
cpasync.prefetch_descriptor(tma_c.atom)
# As many participants as the number of threads issuing the MMA in the same row and column
# Substract one to not count twice the same thread
@@ -179,8 +200,8 @@ def kernel(
)
num_tma_copy_bytes = (
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
) * cute.size(cta_layout_vmnk, mode=[0])
# Threads/warps participating in the mainloop pipeline
@@ -260,21 +281,21 @@ def kernel(
# Allocate SMEM
sA = smem.allocate_tensor(
element_type=io_dtype,
layout=a_smem_layout.outer,
layout=tma_a.smem_layout.outer,
byte_alignment=128,
swizzle=a_smem_layout.inner,
swizzle=tma_a.smem_layout.inner,
)
sB = smem.allocate_tensor(
element_type=io_dtype,
layout=b_smem_layout.outer,
layout=tma_b.smem_layout.outer,
byte_alignment=128,
swizzle=b_smem_layout.inner,
swizzle=tma_b.smem_layout.inner,
)
sC = smem.allocate_tensor(
element_type=io_dtype,
layout=epi_smem_layout_staged.outer,
layout=tma_c.smem_layout.outer,
byte_alignment=128,
swizzle=epi_smem_layout_staged.inner,
swizzle=tma_c.smem_layout.inner,
)
# Partition tensors for MMA and make fragments
@@ -313,7 +334,7 @@ def kernel(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK)
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
tma_a.atom,
cta_in_cluster_coord_vmnk[2],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
cute.group_modes(sA, 0, 3),
@@ -323,7 +344,7 @@ def kernel(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestN, RestK)
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
tma_b.atom,
cta_in_cluster_coord_vmnk[1],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
cute.group_modes(sB, 0, 3),
@@ -333,7 +354,7 @@ def kernel(
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
tma_c.atom,
0,
cute.make_layout(1),
cute.group_modes(sC, 0, 2),
@@ -396,8 +417,8 @@ def kernel(
for pf_k_tile in cutlass.range(
cutlass.min(prefetch_dist, num_k_tiles), unroll=1
):
cute.prefetch(tma_atom_a, tAgA_slice[(None, pf_k_tile)])
cute.prefetch(tma_atom_b, tBgB_slice[(None, pf_k_tile)])
cute.prefetch(tma_a.atom, tAgA_slice[(None, pf_k_tile)])
cute.prefetch(tma_b.atom, tBgB_slice[(None, pf_k_tile)])
# =========================================================
# TMA Load Loop with Rolling Prefetch
@@ -408,14 +429,14 @@ def kernel(
# Issue TMA loads (use k_tile_idx like fp16_gemm_3_1.py)
cute.copy(
tma_atom_a,
tma_a.atom,
tAgA_slice[(None, k_tile_idx)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=tma_mcast_mask_a,
)
cute.copy(
tma_atom_b,
tma_b.atom,
tBgB_slice[(None, k_tile_idx)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
@@ -426,8 +447,8 @@ def kernel(
# This keeps the L2 primed as we progress through the K dimension
if k_tile_idx + prefetch_dist < num_k_tiles:
future_k_tile = k_tile_idx + prefetch_dist
cute.prefetch(tma_atom_a, tAgA_slice[(None, future_k_tile)])
cute.prefetch(tma_atom_b, tBgB_slice[(None, future_k_tile)])
cute.prefetch(tma_a.atom, tAgA_slice[(None, future_k_tile)])
cute.prefetch(tma_b.atom, tBgB_slice[(None, future_k_tile)])
# Advance to next tile
if cutlass.const_expr(use_clc_dynamic_scheduler):
@@ -483,24 +504,14 @@ def kernel(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()
@@ -533,10 +544,10 @@ def kernel(
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
# Initialize TMA store pipeline for epilogue
# Initialize TMA store pipeline for epilogue with 4 warps
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
size=128,
pipeline.Agent.Warp,
size=4,
)
epilogue_pipeline = pipeline.PipelineTmaStore.create(
num_stages=epi_stages,
@@ -631,7 +642,7 @@ def kernel(
# SMEM -> GMEM
if warp_idx == epilogue_warp_ids[0]:
cute.copy(
tma_atom_c,
tma_c.atom,
tCsC[(None, c_buffer)],
tCgC_grouped[(None, subtile_idx)],
)
@@ -715,8 +726,8 @@ def host_function(
mma_inst_shape_mnk,
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
tcgen05.OperandSource.SMEM,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(op)
@@ -754,20 +765,18 @@ def host_function(
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
)
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
tma_atom_a, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
tma_a = cute.nvgpu.make_tiled_tma_atom_A(
op,
a,
a_smem_layout_slice,
a_smem_layout,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
)
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
tma_atom_b, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
tma_b = cute.nvgpu.make_tiled_tma_atom_B(
op,
b,
b_smem_layout_slice,
b_smem_layout,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
@@ -793,11 +802,10 @@ def host_function(
epi_stages,
)
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
epi_smem_layout_staged,
epi_tile,
)
@@ -815,16 +823,10 @@ def host_function(
kernel(
tiled_mma,
tma_atom_a,
a_tma_tensor,
tma_atom_b,
b_tma_tensor,
tma_atom_c,
c_tma_tensor,
a_smem_layout,
b_smem_layout,
tma_a,
tma_b,
tma_c,
c_smem_layout_kind,
epi_smem_layout_staged,
epi_tile,
cta_layout_vmnk,
tile_sched_params,

View File

@@ -1,12 +1,30 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# This is the sixth tutorial GEMM. It enables programmatic dependent launch (PDL) features.
@@ -57,7 +75,7 @@ For --mnk 256,8192,128, the speedup pdl v.s. no pdl can be up to 1.16x.
To run this example:
.. code-block:: bash
python examples/blackwell/tutorial_gemm/fp16_gemm_6.py \
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_6.py \
--mnk 256,8192,128
Constraints for this example:
@@ -100,7 +118,11 @@ class SharedStorage:
tmem_holding_buffer: cutlass.Int32
# Only for CLC Dynamic Scheduler
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
clc_response_align_bytes = num_clc_response_bytes
clc_response: cute.struct.Align[
cute.struct.MemRange[cutlass.Int32, 4],
clc_response_align_bytes,
]
@cute.kernel()
@@ -133,16 +155,10 @@ def dequantize(
@cute.kernel()
def gemm(
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
mA_mkl: cute.Tensor,
tma_atom_b: cute.CopyAtom,
mB_nkl: cute.Tensor,
tma_atom_c: cute.CopyAtom,
mC_mnl: cute.Tensor,
a_smem_layout: cute.ComposedLayout,
b_smem_layout: cute.ComposedLayout,
tma_a: cpasync.TmaInfo,
tma_b: cpasync.TmaInfo,
tma_c: cpasync.TmaInfo,
c_smem_layout_kind: cutlass.Constexpr,
epi_smem_layout_staged: cute.ComposedLayout,
epi_tile: cute.Tile,
cta_layout_vmnk: cute.Layout,
tile_sched_params: Union[
@@ -150,6 +166,11 @@ def gemm(
utils.PersistentTileSchedulerParams,
],
):
# Extract tma_tensor from TmaInfo
mA_mkl = tma_a.tma_tensor
mB_nkl = tma_b.tma_tensor
mC_mnl = tma_c.tma_tensor
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
@@ -178,9 +199,9 @@ def gemm(
# Prefetch tma descriptor
if warp_idx == tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_c)
cpasync.prefetch_descriptor(tma_a.atom)
cpasync.prefetch_descriptor(tma_b.atom)
cpasync.prefetch_descriptor(tma_c.atom)
# As many participants as the number of threads issuing the MMA in the same row and column
# Substract one to not count twice the same thread
@@ -222,8 +243,8 @@ def gemm(
)
num_tma_copy_bytes = (
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
) * cute.size(cta_layout_vmnk, mode=[0])
# Threads/warps participating in the mainloop pipeline
@@ -303,21 +324,21 @@ def gemm(
# Allocate SMEM
sA = smem.allocate_tensor(
element_type=io_dtype,
layout=a_smem_layout.outer,
layout=tma_a.smem_layout.outer,
byte_alignment=128,
swizzle=a_smem_layout.inner,
swizzle=tma_a.smem_layout.inner,
)
sB = smem.allocate_tensor(
element_type=io_dtype,
layout=b_smem_layout.outer,
layout=tma_b.smem_layout.outer,
byte_alignment=128,
swizzle=b_smem_layout.inner,
swizzle=tma_b.smem_layout.inner,
)
sC = smem.allocate_tensor(
element_type=io_dtype,
layout=epi_smem_layout_staged.outer,
layout=tma_c.smem_layout.outer,
byte_alignment=128,
swizzle=epi_smem_layout_staged.inner,
swizzle=tma_c.smem_layout.inner,
)
# Partition tensors for MMA and make fragments
@@ -356,7 +377,7 @@ def gemm(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestM, RestK)
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_a,
tma_a.atom,
cta_in_cluster_coord_vmnk[2],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
cute.group_modes(sA, 0, 3),
@@ -366,7 +387,7 @@ def gemm(
# ((atom_v, rest_v), STAGE)
# ((atom_v, rest_v), RestN, RestK)
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_b,
tma_b.atom,
cta_in_cluster_coord_vmnk[1],
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
cute.group_modes(sB, 0, 3),
@@ -376,7 +397,7 @@ def gemm(
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
tma_atom_c,
tma_c.atom,
0,
cute.make_layout(1),
cute.group_modes(sC, 0, 2),
@@ -440,14 +461,14 @@ def gemm(
# Issue TMA loads
cute.copy(
tma_atom_a,
tma_a.atom,
tAgA_slice[(None, k_tile_idx)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
mcast_mask=tma_mcast_mask_a,
)
cute.copy(
tma_atom_b,
tma_b.atom,
tBgB_slice[(None, k_tile_idx)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
@@ -508,23 +529,14 @@ def gemm(
# (MMA, MMA_M, MMA_N)
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_tile_idx in range(num_k_tiles):
# Wait for TMA copies to complete
handle = ab_consumer.wait_and_advance()
# Execute one K-block worth of MMA instructions
num_k_blocks = cute.size(tCrA, mode=[2])
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
k_block_coord = (None, None, k_block_idx, handle.index)
cute.gemm(
tiled_mma,
tCtAcc,
tCrA[k_block_coord],
tCrB[k_block_coord],
tCtAcc,
)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
# Signal that the A/B buffers have been consumed and are ready for the next load
handle.release()
@@ -557,10 +569,10 @@ def gemm(
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
# Initialize TMA store pipeline for epilogue
# Initialize TMA store pipeline for epilogue with 4 warps
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
size=128,
pipeline.Agent.Warp,
size=4,
)
epilogue_pipeline = pipeline.PipelineTmaStore.create(
num_stages=epi_stages,
@@ -654,7 +666,7 @@ def gemm(
# SMEM -> GMEM
if warp_idx == epilogue_warp_ids[0]:
cute.copy(
tma_atom_c,
tma_c.atom,
tCsC[(None, c_buffer)],
tCgC_grouped[(None, subtile_idx)],
)
@@ -750,8 +762,8 @@ def host_function(
mma_inst_shape_mnk,
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
tcgen05.OperandSource.SMEM,
tcgen05.OperandMajorMode.K,
tcgen05.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
cute.nvgpu.OperandMajorMode.K,
)
tiled_mma = cute.make_tiled_mma(op)
@@ -789,20 +801,18 @@ def host_function(
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
)
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
tma_atom_a, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
tma_a = cute.nvgpu.make_tiled_tma_atom_A(
op,
a,
a_smem_layout_slice,
a_smem_layout,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
)
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
tma_atom_b, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
tma_b = cute.nvgpu.make_tiled_tma_atom_B(
op,
b_dequantized,
b_smem_layout_slice,
b_smem_layout,
mma_tiler_mnk,
tiled_mma,
cta_layout_vmnk.shape,
@@ -828,11 +838,10 @@ def host_function(
epi_stages,
)
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
c,
epi_smem_layout,
epi_smem_layout_staged,
epi_tile,
)
@@ -862,16 +871,10 @@ def host_function(
gemm(
tiled_mma,
tma_atom_a,
a_tma_tensor,
tma_atom_b,
b_tma_tensor,
tma_atom_c,
c_tma_tensor,
a_smem_layout,
b_smem_layout,
tma_a,
tma_b,
tma_c,
c_smem_layout_kind,
epi_smem_layout_staged,
epi_tile,
cta_layout_vmnk,
tile_sched_params,

View File

@@ -34,7 +34,7 @@ import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync
import cutlass.cute.testing as testing
from cutlass import testing
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.cute.runtime import from_dlpack
@@ -42,6 +42,11 @@ import cutlass.utils.hopper_helpers as sm90_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
import cutlass.utils.blackwell_helpers as sm120_utils
# SM120 block-scaled GEMM dispatch helpers (sibling utility module). Try the
# namespace-package import path first (used under pytest, where
# `python/examples/CuTeDSL/cute` is on sys.path); fall back to the bare local
# import for standalone-script invocation, where Python places only the
# script's own directory on sys.path[0].
try:
from blackwell_geforce.kernel.blockscaled_gemm.blockscaled_gemm_dispatch import (
FP4_SHIFT_BITS,
@@ -49,8 +54,8 @@ try:
make_sm120_blockscaled_mma_op,
validate_blockscaled_args,
)
except ImportError:
from blockscaled_gemm_dispatch import (
except ImportError: # pragma: no cover - exercised only via standalone-script invocation
from blockscaled_gemm_dispatch import ( # noqa: E402
FP4_SHIFT_BITS,
make_ldmatrix_atom,
make_sm120_blockscaled_mma_op,
@@ -705,9 +710,6 @@ class Sm120BlockScaledGemmKernel:
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
tCrSFA = sm120_utils.partition_fragment_SFA(sSFA[None, None, 0], thr_mma, tidx)
tCrSFB = sm120_utils.partition_fragment_SFB(sSFB[None, None, 0], thr_mma, tidx)
# Keep residual K modes nested to match the C++ SM120 block-scaled mainloop.
tCrSFA = cute.group_modes(tCrSFA, 2, cute.rank(tCrSFA))
tCrSFB = cute.group_modes(tCrSFB, 2, cute.rank(tCrSFB))
tCgC = thr_mma.partition_C(gC_mnl)
acc_shape = tCgC.shape[:3]
@@ -879,8 +881,6 @@ class Sm120BlockScaledGemmKernel:
tCsSFB_copy_view = thr_copy_ldmatrix_SFB.partition_S(sSFB)
tCrSFB_copy_view = thr_copy_ldmatrix_SFB.retile(tCrSFB)
epi_buffer = cutlass.Int32(0)
if warp_group_idx == 1:
tile_sched.advance_to_next_work()
mainloop_consumer_state = self.advance(
@@ -1051,7 +1051,6 @@ class Sm120BlockScaledGemmKernel:
)
if k_block_idx == num_k_blocks - 1:
cute.arch.fence_proxy("async.shared", space="cta")
mainloop_pipeline.consumer_release(mainloop_consumer_state)
mainloop_consumer_state.advance()
@@ -1210,8 +1209,7 @@ class Sm120BlockScaledGemmKernel:
tRS_rD_out.store(acc_vec.to(self.c_dtype))
# Register to shared memory
epi_buffer = epi_buffer + 1
epi_buffer = epi_buffer % cute.size(
epi_buffer = (epi_m * epi_rest_n + epi_n) % cute.size(
tRS_sD, mode=[3]
)
self.epilog_sync_barrier.arrive_and_wait()
@@ -1242,6 +1240,7 @@ class Sm120BlockScaledGemmKernel:
tile_sched.advance_to_next_work()
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
tma_store_pipeline.producer_tail()
math_wg_order_state = math_wg_order_barrier.arrive(math_wg_order_state)
# End of for k_tile loop
# End of while loop
@@ -1886,24 +1885,6 @@ def run_bs(
cute.testing.convert(ref_f8, ref_tensor)
ref = ref_device.cpu()
torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02)
elif c_dtype is cutlass.Float4E2M1FN:
# Convert ref : f32 -> f4 -> f32
ref_f4_ = torch.empty(*(l, m, n), dtype=torch.uint8, device="cuda").permute(
1, 2, 0
)
ref_f4 = from_dlpack(ref_f4_, assumed_align=16).mark_layout_dynamic(
leading_dim=1
)
ref_f4.element_type = c_dtype
ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda()
ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic(
leading_dim=1
)
cute.testing.convert(ref_tensor, ref_f4)
cute.testing.convert(ref_f4, ref_tensor)
ref = ref_device.cpu()
torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02)
def generate_tensors():
a_tensor, _ = cutlass_torch.cute_tensor_like(
a_ref, a_dtype, is_dynamic_layout=True, assumed_align=16
@@ -1933,7 +1914,7 @@ def run_bs(
_, sfa_tensor, _ = create_scale_factor_tensor(l, m, k, sf_vec_size, sf_dtype)
_, sfb_tensor, _ = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype)
return cute.testing.JitArguments(
return cutlass.testing.JitArguments(
a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, stream
)

View File

@@ -13,7 +13,6 @@
"outputs": [],
"source": [
"import torch\n",
"from functools import partial\n",
"from typing import List\n",
"\n",
"import cutlass\n",
@@ -332,8 +331,8 @@
"\n",
" # Print results\n",
" # ------------\n",
" print(f\"Performance Metrics:\")\n",
" print(f\"-------------------\")\n",
" print(\"Performance Metrics:\")\n",
" print(\"-------------------\")\n",
" print(f\"Kernel execution time: {avg_time_us:.4f} us\")\n",
" print(f\"Memory throughput: {achieved_bandwidth:.2f} GB/s\")"
]
@@ -1082,7 +1081,7 @@
" ###############################################################################\n",
" # Compute predicate for out of boundary checks\n",
" ###############################################################################\n",
" frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)\n",
" frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)\n",
" print(f\"[DSL INFO] frgPred = {frgPred.type}\")\n",
"\n",
" for i in cutlass.range_constexpr(cute.size(frgPred)):\n",

View File

@@ -0,0 +1,106 @@
# DSL Feature Examples
This directory demonstrates **CuTe DSL capabilities** beyond kernel authoring itself:
exporting compiled kernels for deployment, integrating with ML frameworks, using
foreign function interfaces, and accessing low-level DSL features like inline PTX
and shared memory allocation.
---
## Directory Structure
```
dsl/
export/ Exporting kernels to C shared libraries
export_to_c.py Compile a kernel and export as .so/.dylib
load_in_python.py Load and call the exported library from Python
run_with_dynamic_loading.cpp C++ driver using dlopen
run_with_dynamic_loading.sh Build/run script for dynamic loading
run_with_static_linking.cpp C++ driver using static linking
run_with_static_linking.sh Build/run script for static linking
ffi/ Foreign function interface
jit_argument.py JIT compilation with argument passing
tensor.cpp C++ tensor interop implementation
CMakeLists.txt CMake build for FFI examples
jax/ JAX integration
cutlass_call_basic.py Basic CUTLASS kernel call from JAX
cutlass_call_export.py Export a CUTLASS kernel for JAX
cutlass_call_sharding.py Multi-device sharding with CUTLASS kernels
elementwise_apply_example.py Elementwise apply via JAX
tvm_ffi/ TVM FFI integration
jit_and_use_in_torch.py JIT compile and call from PyTorch
jit_and_use_in_jax.py JIT compile and call from JAX
aot_export.py Ahead-of-time export
aot_use_in_torch.py Use AOT-exported kernel in PyTorch
aot_use_in_jax.py Use AOT-exported kernel in JAX
aot_use_in_cpp_bundle.cpp Use AOT-exported kernel in C++
aot_use_in_cpp_bundle.sh Build/run script for C++ AOT usage
compile_with_fake_tensor.py Compile using fake tensors
compile_with_symint_arg.py Compile with symbolic integer arguments
ampere_gemm_with_fake_tensor.py Ampere GEMM with fake tensor compilation
error_reporting.py Error reporting and diagnostics
call_bypass_dlpack.py Calling kernels bypassing DLPack
call_from_jit.py Calling conventions from JIT-compiled code
cooperative_launch.py Cooperative kernel launch (multi-CTA)
dynamic_smem_size.py Dynamic shared memory allocation
inline_ptx.py Embedding inline PTX assembly
launch_completion_and_programmatic_events.py
Launch completion / programmatic events with cudaEvent_t and CUevent
pointer.py Pointer manipulation in DSL
print_latex.py LaTeX rendering of CuTe layouts
programmatic_dependent_launch.py Programmatic dependent launch (PDL)
smem_allocator.py Shared memory allocator usage
torch_fake_tensor.py PyTorch fake tensor integration
torch_fp4.py PyTorch FP4 tensor support
```
---
## Subdirectory Guides
### `export/` -- Kernel Export
Shows how to compile a CuTe DSL kernel into a standalone C shared library (`.so`)
that can be loaded and called from C++ or Python without any CuTe DSL dependency
at runtime. Includes complete examples for both dynamic loading (`dlopen`) and
static linking workflows.
### `ffi/` -- Foreign Function Interface
Demonstrates how to pass arguments between Python/CuTe DSL and C++ code using
the FFI layer. Useful for integrating CuTe DSL kernels into existing C++
applications.
### `jax/` -- JAX Integration
Shows how to call CuTe DSL kernels from JAX using `cutlass_call`, including
basic invocation, kernel export for JAX, multi-device sharding, and elementwise
application patterns.
### `tvm_ffi/` -- TVM FFI Integration
Comprehensive examples for using CuTe DSL kernels through TVM's foreign function
interface. Covers both JIT and AOT (ahead-of-time) compilation workflows, with
usage examples for PyTorch, JAX, and C++. Also demonstrates fake-tensor
compilation (no GPU required at compile time) and symbolic integer arguments.
---
## Top-Level Files
The top-level Python files demonstrate individual DSL features:
- **`call_bypass_dlpack.py`** / **`call_from_jit.py`** -- Kernel calling conventions
- **`inline_ptx.py`** -- Embedding inline PTX assembly in CuTe DSL kernels
- **`launch_completion_and_programmatic_events.py`** -- Examples of
``launch_completion_event`` and ``programmatic_event`` launch attributes,
using events created via ``torch.cuda.Event(enable_timing=False)`` and
presented as either ``cudaEvent_t`` (`cuda.bindings.runtime`) or ``CUevent`` (`cuda.bindings.driver`). The
stream is passed as a ``cudaStream_t`` (`cuda.bindings.runtime`)
- **`programmatic_dependent_launch.py`** -- Programmatic dependent launch for
chaining kernels with data dependencies
- **`cooperative_launch.py`** -- Cooperative launch for multi-CTA kernels
- **`dynamic_smem_size.py`** / **`smem_allocator.py`** -- Shared memory allocation
- **`torch_fake_tensor.py`** / **`torch_fp4.py`** -- PyTorch integration features
- **`pointer.py`** -- Pointer manipulation within DSL kernels
- **`print_latex.py`** -- Render CuTe layouts as LaTeX for visualization

View File

@@ -0,0 +1,367 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Launch Completion Events and Programmatic Events Example
=======================================================
This module demonstrates the two CUDA kernel-launch attributes that record a
``cudaEvent_t`` / ``CUevent`` as part of a launch:
1. ``cudaLaunchAttributeLaunchCompletionEvent``
The event is recorded when all blocks of the grid have begun executing
(best-effort, used for launch ordering). This attribute is
usable on **any** compute capability supported by CuTeDSL.
2. ``cudaLaunchAttributeProgrammaticEvent``
Part of the Programmatic Dependent Launch (PDL) model. The event is recorded
either:
* after **every block** in the grid has called
``cute.arch.griddepcontrol_launch_dependents()`` (or terminated) - this is
selected with ``trigger_at_block_start=0``. The kernel must call the trigger
itself.
* automatically at each block start - selected with ``trigger_at_block_start=1``.
The timing is similar to the launch-completion event, but the resulting
event remains part of the programmatic dependency model and is visible to
the next kernel's ``cute.arch.griddepcontrol_wait()``.
``programmatic_event`` requires sm_90 (Hopper) or newer.
The example demonstrates each attribute by launching a kernel with the attribute
attached, passing the stream as a ``cudaStream_t`` (runtime bindings) and the event
either as a ``cudaEvent_t`` (runtime) or ``CUevent`` (driver). The events
themselves are created from PyTorch with ``torch.cuda.Event(enable_timing=False)``
so they carry ``cudaEventDisableTiming``, which is required by both launch
attributes.
Usage::
python launch_completion_and_programmatic_events.py
"""
import argparse
from typing import Literal, Tuple, Union
import cuda.bindings.driver as cuda_driver
import cuda.bindings.runtime as cuda_runtime
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.runtime import from_dlpack
# =============================================================================
# Feature gate - queried via torch
# =============================================================================
def supports_programmatic_event() -> bool:
"""``programmatic_event`` is part of the programmatic dependency model and requires Hopper (sm_90+)."""
return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
# =============================================================================
# Kernels
# =============================================================================
@cute.kernel
def simple_kernel(out: cute.Tensor, value: cutlass.Int32):
"""Write ``value`` into each element of ``out``.
Used to demonstrate ``launch_completion_event``: the event fires
automatically when all blocks have begun executing.
"""
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
if tidx < cute.size(out, [0]) and bidx < cute.size(out, [1]):
out[tidx, bidx] = value
@cute.kernel
def programmatic_trigger_kernel(
out: cute.Tensor,
value: cutlass.Int32,
trigger_at_block_start: cutlass.Constexpr[bool],
):
"""Write ``value`` into each element of ``out``, then trigger the
programmatic launch-completion signal.
Used to demonstrate ``programmatic_event``.
With ``trigger_at_block_start=False``, every block must execute
``cute.arch.griddepcontrol_launch_dependents()`` (from the DSL:
``cute.arch.griddepcontrol_launch_dependents()``) for the event to fire.
With ``trigger_at_block_start=True`` the event fires automatically at the
block start.
"""
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
cute.arch.griddepcontrol_wait()
if cutlass.const_expr(not trigger_at_block_start):
cute.arch.griddepcontrol_launch_dependents()
if tidx < cute.size(out, [0]) and bidx < cute.size(out, [1]):
out[tidx, bidx] = value
# =============================================================================
# JIT host functions - each exercises a single launch attribute
# =============================================================================
THREADS_PER_BLOCK = 128
@cute.jit
def launch_with_launch_completion_event(
out: cute.Tensor,
value: cutlass.Int32,
stream: cuda_runtime.cudaStream_t,
launch_completion_event: Union[cuda_runtime.cudaEvent_t, cuda_driver.CUevent],
):
"""Launch ``simple_kernel`` with ``launch_completion_event=launch_completion_event``."""
simple_kernel(out, value).launch(
grid=(cute.size(out, [1]), 1, 1),
block=(cute.size(out, [0]), 1, 1),
stream=stream,
launch_completion_event=launch_completion_event,
launch_completion_event_flags=0, # Optional flags
)
@cute.jit
def launch_with_programmatic_event(
out: cute.Tensor,
value: cutlass.Int32,
stream: cuda_runtime.cudaStream_t,
programmatic_event: Union[cuda_runtime.cudaEvent_t, cuda_driver.CUevent],
trigger_at_block_start: cutlass.Constexpr[int] = 0,
):
"""Launch ``programmatic_trigger_kernel`` with ``programmatic_event=programmatic_event``."""
programmatic_trigger_kernel(out, value, trigger_at_block_start == 1).launch(
grid=(cute.size(out, [1]), 1, 1),
block=(cute.size(out, [0]), 1, 1),
stream=stream,
programmatic_event=programmatic_event,
programmatic_event_flags=0, # Optional flags
programmatic_event_trigger_at_block_start=trigger_at_block_start, # Defaults to zero
use_pdl=True,
)
# =============================================================================
# Helpers for event creation and synchronization
# =============================================================================
def _make_event(
kind: Literal["runtime", "driver"], init_stream: torch.cuda.Stream
) -> Tuple[torch.cuda.Event, Union[cuda_runtime.cudaEvent_t, cuda_driver.CUevent]]:
"""Create a CUDA event from torch with timing disabled and wrap it as the
requested low-level API type.
Using ``torch.cuda.Event(enable_timing=False)`` guarantees the underlying
CUDA event is created with ``cudaEventDisableTiming``, which is required
for ``launch_completion_event`` and ``programmatic_event`` launch attributes.
PyTorch lazily initializes the underlying ``cudaEvent_t`` on the first
``record()`` call. We force initialization by recording the event.
Returns ``(torch_event, cuda_event)``. The torch event
must be kept alive for the lifetime of the wrapped cuda event. Torch will
destroy the event when ``torch_event`` is garbage-collected.
"""
torch_event = torch.cuda.Event(enable_timing=False)
torch_event.record(init_stream)
raw_event = int(torch_event.cuda_event)
if raw_event == 0:
raise RuntimeError("torch.cuda.Event was not created")
if kind == "runtime":
return torch_event, cuda_runtime.cudaEvent_t(raw_event)
elif kind == "driver":
return torch_event, cuda_driver.CUevent(raw_event)
else:
raise ValueError(
f"Unknown event kind: {kind!r}; expected 'runtime' or 'driver'"
)
# =============================================================================
# Run functions
# =============================================================================
def run_launch_completion_event_example(
n_elements_per_block: int = 128,
n_blocks: int = 48,
launch_completion_event_kind: Literal["runtime", "driver"] = "runtime",
) -> None:
"""Run the ``launch_completion_event`` demonstration.
Allocates an output tensor, launches ``simple_kernel`` with
``launch_completion_event`` attached, then blocks on the event using the
matching API and verifies the output.
"""
print(
f"\n[Launch Completion Event] Running launch_completion_event example "
f"(launch_completion_event_kind={launch_completion_event_kind!r}, "
f"n_elements_per_block={n_elements_per_block}, n_blocks={n_blocks})"
)
out = torch.full(
(n_elements_per_block, n_blocks), -1, dtype=torch.int32, device="cuda"
)
expected = torch.full(
(n_elements_per_block, n_blocks), 0, dtype=torch.int32, device="cuda"
)
torch_stream = torch.cuda.default_stream()
stream = cuda_runtime.cudaStream_t(torch_stream.cuda_stream)
torch_event, launch_completion_event = _make_event(
launch_completion_event_kind, torch_stream
)
out_tensor = from_dlpack(out)
launch_with_launch_completion_event(out_tensor, 0, stream, launch_completion_event)
torch_event.synchronize()
torch.testing.assert_close(out, expected)
print(
f"[Launch Completion Event] {launch_completion_event_kind} event fired and output verified."
)
def run_programmatic_event_example(
n_elements_per_block: int = 128,
n_blocks: int = 48,
programmatic_event_kind: Literal["runtime", "driver"] = "runtime",
trigger_at_block_start: int = 0,
) -> None:
"""Run the ``programmatic_event`` demonstration.
Allocates an output tensor, launches ``programmatic_trigger_kernel`` with
``programmatic_event`` attached, then blocks on the event using the matching
API and verifies the output.
"""
if not supports_programmatic_event():
raise RuntimeError("programmatic_event requires Hopper (sm_90) or newer")
print(
f"\n[Programmatic Event] Running programmatic_event example "
f"(programmatic_event_kind={programmatic_event_kind!r}, "
f"n_elements_per_block={n_elements_per_block}, n_blocks={n_blocks}, "
f"trigger_at_block_start={trigger_at_block_start})"
)
out = torch.full(
(n_elements_per_block, n_blocks), -1, dtype=torch.int32, device="cuda"
)
expected = torch.full(
(n_elements_per_block, n_blocks), 1, dtype=torch.int32, device="cuda"
)
torch_stream = torch.cuda.default_stream()
stream = cuda_runtime.cudaStream_t(torch_stream.cuda_stream)
out_tensor = from_dlpack(out)
graph = torch.cuda.CUDAGraph()
torch_event = None
with torch.cuda.graph(graph):
stream = cuda_runtime.cudaStream_t(torch.cuda.current_stream().cuda_stream)
torch_event, programmatic_event = _make_event(
programmatic_event_kind, torch_stream
)
launch_with_programmatic_event(
out_tensor,
0,
stream,
programmatic_event,
trigger_at_block_start=trigger_at_block_start,
)
# Overlaps with the prior launch after every CTA has been launched
launch_with_programmatic_event(
out_tensor,
1,
stream,
programmatic_event,
trigger_at_block_start=trigger_at_block_start,
)
graph.replay()
assert torch_event is not None, "torch.cuda.Event was not created"
torch.cuda.synchronize()
torch.testing.assert_close(out, expected)
print(
f"[Programmatic Event] {programmatic_event_kind} event fired and output verified."
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=(
"Demonstrate launch_completion_event and programmatic_event launch "
"attributes with both cudaEvent_t and CUevent."
)
)
parser.add_argument("--n-blocks", default=48, type=int)
parser.add_argument("--n-elements-per-block", default=128, type=int)
parser.add_argument("--use-driver-api", action="store_true")
parser.add_argument("--trigger-at-block-start", action="store_true")
args = parser.parse_args()
run_launch_completion_event_example(
n_elements_per_block=args.n_elements_per_block,
n_blocks=args.n_blocks,
launch_completion_event_kind="driver" if args.use_driver_api else "runtime",
)
if supports_programmatic_event():
run_programmatic_event_example(
n_elements_per_block=args.n_elements_per_block,
n_blocks=args.n_blocks,
programmatic_event_kind="driver" if args.use_driver_api else "runtime",
trigger_at_block_start=1 if args.trigger_at_block_start else 0,
)
else:
print("\nSkipping programmatic_event: requires Hopper (sm_90) or newer.")
print("\nPASS")