mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-06-28 18:37:05 +00:00
v4.6 dev update. (#3315)
* v4.6 dev update. * Remove CUTLASS_HOST_DEVICE from CudaHostAdapater::memsetDevice (#3286) * [SM120] Add ptr-array TMA collective for tensor/token-scaled FP8 grouped GEMM (#3280) * gemm: add SM120 array TMA collective for tensor/token-scaled FP8 grouped GEMM Adds CollectiveMma and CollectiveBuilder specializations for MainloopSm120ArrayTmaWarpSpecialized, enabling ptr-array grouped GEMM (MoE expert dispatch) with tensor- and token-level FP8 scaling on SM_120/SM_121 consumer Blackwell (RTX 5090/5080/5070, DGX Spark GB10). New files: - include/cutlass/gemm/collective/sm120_mma_array_tma.hpp CollectiveMma specialization for MainloopSm120ArrayTmaWarpSpecialized. Handles both Cooperative (4x2 atom layout) and Pingpong (2x2) schedules. Grouped GEMM via pointer-array indirection through params.ptr_A / ptr_B. Supports F8F6F4 MMA with TMA loads for both A and B operands. - include/cutlass/gemm/collective/builders/sm120_array_mma_builder.inl CollectiveBuilder specialization for KernelPtrArrayTmaWarpSpecialized Cooperative/PingpongSm120<N> schedule tags. Computes tile/stage counts from smem capacity, routes to MainloopSm120ArrayTmaWarpSpecialized dispatch policy, produces correctly-typed CollectiveOp. Modified files: - collective_mma.hpp: include sm120_mma_array_tma.hpp - collective_builder.hpp: include sm120_array_mma_builder.inl - sm120_mma_builder.inl: remove ptr-array schedules from enable_if (they now route to sm120_array_mma_builder.inl) and drop the IsPtrArrayKernel static_assert that enforced the restriction Validated on real SM_121 hardware (DGX Spark, 128 GB LPDDR5X) running vLLM with RedHatAI/gemma-4-26B-A4B-it-FP8-Dynamic (Gemma 4 MoE, 26B total / 4B active). Previously fell back to a non-CUTLASS Triton path; with this patch, the SM120 CUTLASS grouped GEMM collective activates and produces correct outputs. Short-sequence throughput improved ~7% vs the fallback baseline (76.3 → 81.9 tok/s). Closes #3263 Co-authored-by: Claude <noreply@anthropic.com> Signed-off-by: Tyler Merritt <tgmerritt@gmail.com> * test: add SM120 ptr-array grouped GEMM unit tests Adds 6 device-level tests for the CollectiveMma/CollectiveBuilder specializations introduced for MainloopSm120ArrayTmaWarpSpecialized, covering both KernelPtrArrayTmaWarpSpecializedPingpongSm120<2> and KernelPtrArrayTmaWarpSpecializedCooperativeSm120<2> schedule tags across e4m3×e4m3 (symmetric), e4m3×e5m2 (mixed), float and bfloat16 outputs, and two tile shapes. Tests land in test/unit/gemm/device/sm120_tensorop_gemm/ under the new cutlass_test_unit_sm120_grouped_gemm_device_tensorop CMake target, per reviewer request in PR #3280. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Signed-off-by: Tyler Merritt <tgmerritt@gmail.com> Co-authored-by: Claude <noreply@anthropic.com> --------- Signed-off-by: Tyler Merritt <tgmerritt@gmail.com> Co-authored-by: Alex Georgiev <89279829+alexngUNC@users.noreply.github.com> Co-authored-by: Tyler <tgmerritt@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -49,7 +49,7 @@ To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/ampere/hstu_attention.py --batch_size 4 --seqlen_q 8192 --seqlen_kv 8192 --num_head 4 --head_dim 128 --m_block_size 128 --n_block_size 64 --is_causal --perf_test
|
||||
python examples/cute/ampere/kernel/attention/hstu_attention.py --batch_size 4 --seqlen_q 8192 --seqlen_kv 8192 --num_head 4 --head_dim 128 --m_block_size 128 --n_block_size 64 --is_causal --perf_test
|
||||
|
||||
The above example tests the performance of HSTU attention with batch size 4, sequence length 8192, 4 attention heads, and head dimension 128. The m_block_size is 128, and n_block_size is 64. The causal masking is enabled.
|
||||
|
||||
@@ -268,6 +268,104 @@ class HSTUAttentionForwardAmpere(object):
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
@cute.jit
|
||||
def _copy_with_residue(
|
||||
self,
|
||||
copy_atom,
|
||||
src_tile,
|
||||
dst_tile,
|
||||
coord_tile,
|
||||
head_dim_pred,
|
||||
has_outer_residue: cutlass.Constexpr,
|
||||
has_inner_residue: cutlass.Constexpr,
|
||||
outer_size,
|
||||
block_end,
|
||||
fill_zero_on_oob: cutlass.Constexpr = True,
|
||||
is_known_boundary: cutlass.Constexpr = False,
|
||||
):
|
||||
# Copy a (CPY_Atom, M, K) tile with optional outer-axis (M) and head-dim (K) residue.
|
||||
# `is_known_boundary=True` skips the runtime boundary check (caller knows the tile straddles outer_size).
|
||||
# `fill_zero_on_oob=False` for stores; loads zero-fill out-of-bounds rows so SMEM contents are well-defined.
|
||||
if cutlass.const_expr(not has_outer_residue and not has_inner_residue):
|
||||
cute.copy(copy_atom, src_tile, dst_tile)
|
||||
elif cutlass.const_expr(not has_outer_residue):
|
||||
cute.copy(copy_atom, src_tile, dst_tile, pred=head_dim_pred)
|
||||
else:
|
||||
if cutlass.const_expr(is_known_boundary):
|
||||
is_boundary = True
|
||||
else:
|
||||
is_boundary = cute.elem_less(outer_size, block_end)
|
||||
if is_boundary:
|
||||
for m in cutlass.range_constexpr(cute.size(dst_tile.shape[1])):
|
||||
if cute.elem_less(coord_tile[0, m, 0][1], outer_size):
|
||||
if cutlass.const_expr(has_inner_residue):
|
||||
cute.copy(
|
||||
copy_atom,
|
||||
src_tile[None, m, None],
|
||||
dst_tile[None, m, None],
|
||||
pred=head_dim_pred[None, m, None],
|
||||
)
|
||||
else:
|
||||
cute.copy(
|
||||
copy_atom,
|
||||
src_tile[None, m, None],
|
||||
dst_tile[None, m, None],
|
||||
)
|
||||
elif cutlass.const_expr(fill_zero_on_oob):
|
||||
dst_tile[None, m, None].fill(0)
|
||||
else:
|
||||
if cutlass.const_expr(has_inner_residue):
|
||||
cute.copy(copy_atom, src_tile, dst_tile, pred=head_dim_pred)
|
||||
else:
|
||||
cute.copy(copy_atom, src_tile, dst_tile)
|
||||
|
||||
@cute.jit
|
||||
def _copy_rab_tile(
|
||||
self,
|
||||
copy_atom,
|
||||
src_tile,
|
||||
dst_tile,
|
||||
coord_tile,
|
||||
has_q_residue: cutlass.Constexpr,
|
||||
has_kv_residue: cutlass.Constexpr,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
q_block_end,
|
||||
kv_block_end,
|
||||
is_known_kv_interior: cutlass.Constexpr = False,
|
||||
):
|
||||
# Copy a (CPY_Atom, M, N) RAB tile with optional 2D residue.
|
||||
# Coord entries are 4-tuples from the (B, H, Q, KV) identity tensor; index 2 is q-coord, index 3 is kv-coord.
|
||||
# `is_known_kv_interior=True` skips the kv-side runtime check when the loaded tile is guaranteed inside seqlen_kv.
|
||||
kv_check_active = has_kv_residue and not is_known_kv_interior
|
||||
if cutlass.const_expr(not has_q_residue and not kv_check_active):
|
||||
cute.copy(copy_atom, src_tile, dst_tile)
|
||||
else:
|
||||
# Pick the runtime predicate without mixing cute.Boolean with Python bool.
|
||||
if cutlass.const_expr(has_q_residue and kv_check_active):
|
||||
needs_per_elem = cute.elem_less(
|
||||
seqlen_q, q_block_end
|
||||
) or cute.elem_less(seqlen_kv, kv_block_end)
|
||||
elif cutlass.const_expr(has_q_residue):
|
||||
needs_per_elem = cute.elem_less(seqlen_q, q_block_end)
|
||||
else:
|
||||
needs_per_elem = cute.elem_less(seqlen_kv, kv_block_end)
|
||||
if needs_per_elem:
|
||||
for m in cutlass.range_constexpr(cute.size(dst_tile.shape[1])):
|
||||
for n in cutlass.range_constexpr(cute.size(dst_tile.shape[2])):
|
||||
if cute.elem_less(
|
||||
coord_tile[0, m, n][2], seqlen_q
|
||||
) and cute.elem_less(coord_tile[0, m, n][3], seqlen_kv):
|
||||
cute.copy(
|
||||
copy_atom,
|
||||
src_tile[None, m, n],
|
||||
dst_tile[None, m, n],
|
||||
)
|
||||
else:
|
||||
dst_tile[None, m, n].fill(0)
|
||||
else:
|
||||
cute.copy(copy_atom, src_tile, dst_tile)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
@@ -395,7 +493,7 @@ class HSTUAttentionForwardAmpere(object):
|
||||
tVsV = gmem_thr_copy_QKV.partition_D(sV)
|
||||
# (CPY_Atom, CPY_M, CPY_N, n_block)
|
||||
tRABgRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_S(gRAB)
|
||||
tRabsRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_D(sRAB)
|
||||
tRABsRAB = gmem_tiled_copy_QKV.get_slice(tidx).partition_D(sRAB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Tile MMA compute thread partitions and allocate accumulators
|
||||
@@ -446,115 +544,118 @@ class HSTUAttentionForwardAmpere(object):
|
||||
tOsVt = smem_thr_copy_V.partition_S(sVt)
|
||||
tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt)
|
||||
tSsRAB = smem_thr_copy_RAB.partition_S(sRAB)
|
||||
has_head_dim_residue = self._head_dim != self._head_dim_padded
|
||||
has_q_residue = self._seqlen_q % self._m_block_size != 0
|
||||
has_kv_residue = self._seqlen_kv % self._n_block_size != 0
|
||||
need_q_predicates = has_head_dim_residue or has_q_residue
|
||||
need_kv_predicates = has_head_dim_residue or has_kv_residue
|
||||
need_rab_predicates = has_q_residue or has_kv_residue
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Predicate: Mark indices that need to copy when problem_shape isn't a multiple
|
||||
# of tile_shape
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Construct identity layout for Q, KV and RAB
|
||||
mcQ = cute.make_identity_tensor(mQ.layout.shape)
|
||||
mcKV = cute.make_identity_tensor(mK.layout.shape)
|
||||
mcRAB = cute.make_identity_tensor(mRAB.layout.shape)
|
||||
|
||||
cQ = cute.local_tile(
|
||||
mcQ[batch_size, None, num_head, None],
|
||||
(self._m_block_size, self._head_dim_padded),
|
||||
(m_block, 0),
|
||||
)
|
||||
cKV = cute.local_tile(
|
||||
mcKV[batch_size, None, num_head, None],
|
||||
(self._n_block_size, self._head_dim_padded),
|
||||
(n_block, 0),
|
||||
)
|
||||
cRAB = cute.local_tile(
|
||||
mcRAB[batch_size, num_head, None, None],
|
||||
(self._m_block_size, self._n_block_size),
|
||||
(m_block, None),
|
||||
)
|
||||
|
||||
# Repeat the partitioning with identity layouts
|
||||
tQcQ = gmem_thr_copy_QKV.partition_S(cQ)
|
||||
tKVcKV = gmem_thr_copy_QKV.partition_S(cKV)
|
||||
tRABcRAB = gmem_thr_copy_QKV.partition_S(cRAB)
|
||||
|
||||
tQpQ = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tQsQ.shape[0][1],
|
||||
cute.size(tQsQ, mode=[1]),
|
||||
cute.size(tQsQ, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tQsQ, mode=[2]), 0, 1),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
tKVpKV = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tKsK.shape[0][1],
|
||||
cute.size(tKsK, mode=[1]),
|
||||
cute.size(tKsK, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tKsK, mode=[2]), 0, 1),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
|
||||
# Set predicates for head_dim bounds, seqlen_q/k/v bounds is processed at the first tile.
|
||||
for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
|
||||
for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
|
||||
tQpQ[rest_v, 0, rest_k] = cute.elem_less(
|
||||
tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
|
||||
if cutlass.const_expr(need_q_predicates):
|
||||
mcQ = cute.make_identity_tensor(mQ.layout.shape)
|
||||
cQ = cute.local_tile(
|
||||
mcQ[batch_size, None, num_head, None],
|
||||
(self._m_block_size, self._head_dim_padded),
|
||||
(m_block, 0),
|
||||
)
|
||||
tQcQ = gmem_thr_copy_QKV.partition_S(cQ)
|
||||
if cutlass.const_expr(has_head_dim_residue):
|
||||
tQpQ = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tQsQ.shape[0][1],
|
||||
cute.size(tQsQ, mode=[1]),
|
||||
cute.size(tQsQ, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tQsQ, mode=[2]), 0, 1),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
|
||||
for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
|
||||
tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
|
||||
tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
|
||||
for rest_v in cutlass.range_constexpr(tQpQ.shape[0]):
|
||||
for rest_k in cutlass.range_constexpr(tQpQ.shape[2]):
|
||||
tQpQ[rest_v, 0, rest_k] = cute.elem_less(
|
||||
tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3]
|
||||
)
|
||||
|
||||
if cutlass.const_expr(need_kv_predicates):
|
||||
mcKV = cute.make_identity_tensor(mK.layout.shape)
|
||||
cKV = cute.local_tile(
|
||||
mcKV[batch_size, None, num_head, None],
|
||||
(self._n_block_size, self._head_dim_padded),
|
||||
(n_block, 0),
|
||||
)
|
||||
tKVcKV = gmem_thr_copy_QKV.partition_S(cKV)
|
||||
if cutlass.const_expr(has_head_dim_residue):
|
||||
tKVpKV = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(
|
||||
tKsK.shape[0][1],
|
||||
cute.size(tKsK, mode=[1]),
|
||||
cute.size(tKsK, mode=[2]),
|
||||
),
|
||||
stride=(cute.size(tKsK, mode=[2]), 0, 1),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]):
|
||||
for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]):
|
||||
tKVpKV[rest_v, 0, rest_k] = cute.elem_less(
|
||||
tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3]
|
||||
)
|
||||
|
||||
if cutlass.const_expr(need_rab_predicates):
|
||||
mcRAB = cute.make_identity_tensor(mRAB.layout.shape)
|
||||
cRAB = cute.local_tile(
|
||||
mcRAB[batch_size, num_head, None, None],
|
||||
(self._m_block_size, self._n_block_size),
|
||||
(m_block, None),
|
||||
)
|
||||
tRABcRAB = gmem_thr_copy_QKV.partition_S(cRAB)
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Prefetch Prologue
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
# Start async loads of the last mn-tile, where we take care of the mn residue
|
||||
for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])):
|
||||
if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]):
|
||||
cute.copy(
|
||||
gmem_tiled_copy_QKV,
|
||||
tQgQ[None, m, None],
|
||||
tQsQ[None, m, None],
|
||||
pred=tQpQ[None, m, None],
|
||||
)
|
||||
else:
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tQsQ[None, m, None].fill(0)
|
||||
|
||||
for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])):
|
||||
if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]):
|
||||
cute.copy(
|
||||
gmem_tiled_copy_QKV,
|
||||
tKgK[None, n, None, n_block],
|
||||
tKsK[None, n, None],
|
||||
pred=tKVpKV[None, n, None],
|
||||
)
|
||||
else:
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tKsK[None, n, None].fill(0)
|
||||
|
||||
for m in cutlass.range_constexpr(cute.size(tRABcRAB.shape[1])):
|
||||
for n in cutlass.range_constexpr(cute.size(tRABcRAB.shape[2])):
|
||||
if cute.elem_less(
|
||||
tRABcRAB[0, m, n, n_block][1], mRAB.layout.shape[2]
|
||||
) and cute.elem_less(
|
||||
tRABcRAB[0, m, n, n_block][2], mRAB.layout.shape[3]
|
||||
):
|
||||
cute.copy(
|
||||
gmem_tiled_copy_QKV,
|
||||
tRABgRAB[None, m, n, n_block],
|
||||
tRabsRAB[None, m, n],
|
||||
)
|
||||
else:
|
||||
# Clear the smem tiles to account for predicated off loads
|
||||
tRabsRAB[None, m, n].fill(0)
|
||||
self._copy_with_residue(
|
||||
gmem_tiled_copy_QKV,
|
||||
tQgQ[None, None, None],
|
||||
tQsQ[None, None, None],
|
||||
tQcQ if has_q_residue else None,
|
||||
tQpQ if has_head_dim_residue else None,
|
||||
has_q_residue,
|
||||
has_head_dim_residue,
|
||||
mQ.layout.shape[1],
|
||||
(m_block + 1) * self._m_block_size,
|
||||
)
|
||||
# n_block is the last n-tile by construction, so any kv-residue is in this tile.
|
||||
self._copy_with_residue(
|
||||
gmem_tiled_copy_QKV,
|
||||
tKgK[None, None, None, n_block],
|
||||
tKsK[None, None, None],
|
||||
tKVcKV if has_kv_residue else None,
|
||||
tKVpKV if has_head_dim_residue else None,
|
||||
has_kv_residue,
|
||||
has_head_dim_residue,
|
||||
mK.layout.shape[1],
|
||||
(n_block + 1) * self._n_block_size,
|
||||
is_known_boundary=has_kv_residue,
|
||||
)
|
||||
self._copy_rab_tile(
|
||||
gmem_tiled_copy_QKV,
|
||||
tRABgRAB[None, None, None, n_block],
|
||||
tRABsRAB[None, None, None],
|
||||
tRABcRAB[None, None, None, n_block] if need_rab_predicates else None,
|
||||
has_q_residue,
|
||||
has_kv_residue,
|
||||
mRAB.layout.shape[2],
|
||||
mRAB.layout.shape[3],
|
||||
(m_block + 1) * self._m_block_size,
|
||||
(n_block + 1) * self._n_block_size,
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -565,24 +666,17 @@ class HSTUAttentionForwardAmpere(object):
|
||||
cute.arch.cp_async_wait_group(0)
|
||||
self.cta_sync_barrier.arrive_and_wait()
|
||||
|
||||
if n_block_idx == n_block:
|
||||
for n in cutlass.range_constexpr(cute.size(tVsV.shape[1])):
|
||||
if cute.elem_less(tKVcKV[0, n, 0][1], mV.layout.shape[1]):
|
||||
cute.copy(
|
||||
gmem_tiled_copy_QKV,
|
||||
tVgV[None, n, None, n_block_idx],
|
||||
tVsV[None, n, None],
|
||||
pred=tKVpKV[None, n, None],
|
||||
)
|
||||
else:
|
||||
tVsV[None, n, None].fill(0)
|
||||
else:
|
||||
cute.copy(
|
||||
gmem_tiled_copy_QKV,
|
||||
tVgV[None, None, None, n_block_idx],
|
||||
tVsV[None, None, None],
|
||||
pred=tKVpKV[None, None, None],
|
||||
)
|
||||
self._copy_with_residue(
|
||||
gmem_tiled_copy_QKV,
|
||||
tVgV[None, None, None, n_block_idx],
|
||||
tVsV[None, None, None],
|
||||
tKVcKV if has_kv_residue else None,
|
||||
tKVpKV if has_head_dim_residue else None,
|
||||
has_kv_residue,
|
||||
has_head_dim_residue,
|
||||
mV.layout.shape[1],
|
||||
(n_block_idx + 1) * self._n_block_size,
|
||||
)
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
acc_shape_S = thr_mma.partition_shape_C(
|
||||
@@ -643,25 +737,33 @@ class HSTUAttentionForwardAmpere(object):
|
||||
self.cta_sync_barrier.arrive_and_wait()
|
||||
|
||||
if n_block_idx > 0:
|
||||
cute.copy(
|
||||
# tile (n_block_idx - 1) is always inside seqlen_kv, so only head_dim residue can apply
|
||||
self._copy_with_residue(
|
||||
gmem_tiled_copy_QKV,
|
||||
tKgK[None, None, None, n_block_idx - 1],
|
||||
tKsK[None, None, None],
|
||||
pred=tKVpKV[None, None, None],
|
||||
None,
|
||||
tKVpKV[None, None, None] if has_head_dim_residue else None,
|
||||
False,
|
||||
has_head_dim_residue,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
self._copy_rab_tile(
|
||||
gmem_tiled_copy_QKV,
|
||||
tRABgRAB[None, None, None, n_block_idx - 1],
|
||||
tRABsRAB[None, None, None],
|
||||
tRABcRAB[None, None, None, n_block_idx - 1]
|
||||
if need_rab_predicates
|
||||
else None,
|
||||
has_q_residue,
|
||||
has_kv_residue,
|
||||
mRAB.layout.shape[2],
|
||||
mRAB.layout.shape[3],
|
||||
(m_block + 1) * self._m_block_size,
|
||||
None,
|
||||
is_known_kv_interior=True,
|
||||
)
|
||||
# m residue handling for RAB
|
||||
for m in cutlass.range_constexpr(cute.size(tRABcRAB.shape[1])):
|
||||
if cute.elem_less(
|
||||
tRABcRAB[0, m, 0, n_block_idx - 1][1], mRAB.layout.shape[2]
|
||||
):
|
||||
cute.copy(
|
||||
gmem_tiled_copy_QKV,
|
||||
tRABgRAB[None, m, None, n_block_idx - 1],
|
||||
tRabsRAB[None, m, None],
|
||||
)
|
||||
else:
|
||||
tRabsRAB[None, m, None].fill(0)
|
||||
|
||||
cute.arch.cp_async_commit_group()
|
||||
|
||||
# ///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -695,26 +797,29 @@ class HSTUAttentionForwardAmpere(object):
|
||||
t4 = t1 / t3
|
||||
acc_S.store(t4)
|
||||
|
||||
mACC = cute.make_identity_tensor(
|
||||
(mRAB.layout.shape[2], mRAB.layout.shape[3])
|
||||
) # (seqlen_q, seqlen_kv)
|
||||
cACC = cute.local_tile(
|
||||
mACC[None, None],
|
||||
(self._m_block_size, self._n_block_size),
|
||||
(m_block, n_block_idx),
|
||||
)
|
||||
if cutlass.const_expr(self._is_causal):
|
||||
mACC = cute.make_identity_tensor(
|
||||
(mRAB.layout.shape[2], mRAB.layout.shape[3])
|
||||
) # (seqlen_q, seqlen_kv)
|
||||
cACC = cute.local_tile(
|
||||
mACC[None, None],
|
||||
(self._m_block_size, self._n_block_size),
|
||||
(m_block, n_block_idx),
|
||||
)
|
||||
|
||||
if self._is_causal and (n_block - n_block_idx) < cute.ceil_div(
|
||||
self._m_block_size, self._n_block_size
|
||||
):
|
||||
tACCcACC = thr_mma.partition_C(cACC)
|
||||
for i in cutlass.range_constexpr(cute.size(tACCcACC.shape[0])):
|
||||
for j in cutlass.range_constexpr(cute.size(tACCcACC.shape[1])):
|
||||
for k in cutlass.range_constexpr(cute.size(tACCcACC.shape[2])):
|
||||
if cute.elem_less(
|
||||
tACCcACC[i, j, k][0], tACCcACC[i, j, k][1]
|
||||
if (n_block - n_block_idx) < cute.ceil_div(
|
||||
self._m_block_size, self._n_block_size
|
||||
):
|
||||
tACCcACC = thr_mma.partition_C(cACC)
|
||||
for i in cutlass.range_constexpr(cute.size(tACCcACC.shape[0])):
|
||||
for j in cutlass.range_constexpr(cute.size(tACCcACC.shape[1])):
|
||||
for k in cutlass.range_constexpr(
|
||||
cute.size(tACCcACC.shape[2])
|
||||
):
|
||||
acc_S[i, j, k] = 0.0
|
||||
if cute.elem_less(
|
||||
tACCcACC[i, j, k][0], tACCcACC[i, j, k][1]
|
||||
):
|
||||
acc_S[i, j, k] = 0.0
|
||||
|
||||
rP = cute.make_rmem_tensor_like(acc_S, self._dtype)
|
||||
rP.store(acc_S.load().to(self._dtype))
|
||||
@@ -803,35 +908,39 @@ class HSTUAttentionForwardAmpere(object):
|
||||
tOsO,
|
||||
tOrO,
|
||||
)
|
||||
# predicate for O
|
||||
mcO = cute.make_identity_tensor(mO.layout.shape)
|
||||
cO = cute.local_tile(
|
||||
mcO[batch_size, None, num_head, None],
|
||||
(self._m_block_size, self._head_dim_padded),
|
||||
(m_block, 0),
|
||||
)
|
||||
tOcO = gmem_thr_copy_O.partition_D(cO)
|
||||
tOpO = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]),
|
||||
stride=(tOgO.shape[2], 0, 1),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
|
||||
for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
|
||||
tOpO[rest_v, 0, rest_n] = cute.elem_less(
|
||||
tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
|
||||
)
|
||||
# copy acc O from rmem to gmem
|
||||
for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])):
|
||||
if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]):
|
||||
cute.copy(
|
||||
gmem_tiled_copy_O,
|
||||
tOrO[None, rest_m, None],
|
||||
tOgO[None, rest_m, None],
|
||||
pred=tOpO[None, rest_m, None],
|
||||
if cutlass.const_expr(need_q_predicates):
|
||||
mcO = cute.make_identity_tensor(mO.layout.shape)
|
||||
cO = cute.local_tile(
|
||||
mcO[batch_size, None, num_head, None],
|
||||
(self._m_block_size, self._head_dim_padded),
|
||||
(m_block, 0),
|
||||
)
|
||||
tOcO = gmem_thr_copy_O.partition_D(cO)
|
||||
if cutlass.const_expr(has_head_dim_residue):
|
||||
tOpO = cute.make_rmem_tensor(
|
||||
cute.make_layout(
|
||||
(tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]),
|
||||
stride=(tOgO.shape[2], 0, 1),
|
||||
),
|
||||
cutlass.Boolean,
|
||||
)
|
||||
for rest_v in cutlass.range_constexpr(tOpO.shape[0]):
|
||||
for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])):
|
||||
tOpO[rest_v, 0, rest_n] = cute.elem_less(
|
||||
tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3]
|
||||
)
|
||||
self._copy_with_residue(
|
||||
gmem_tiled_copy_O,
|
||||
tOrO[None, None, None],
|
||||
tOgO[None, None, None],
|
||||
tOcO if has_q_residue else None,
|
||||
tOpO if has_head_dim_residue else None,
|
||||
has_q_residue,
|
||||
has_head_dim_residue,
|
||||
mO.layout.shape[1],
|
||||
(m_block + 1) * self._m_block_size,
|
||||
fill_zero_on_oob=False,
|
||||
)
|
||||
|
||||
|
||||
def run_pytorch_hstu_test(
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import enum
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
@@ -42,12 +41,20 @@ import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass.cute.nvgpu.common import OperandMajorMode
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
import cutlass.torch as cutlass_torch
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cute.typing import Int32, Float32, Float8E4M3FN, Float16, BFloat16, Boolean
|
||||
from cutlass.cute.typing import (
|
||||
Int32,
|
||||
Float32,
|
||||
Float8E4M3FN,
|
||||
Float16,
|
||||
BFloat16,
|
||||
Boolean,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -74,7 +81,7 @@ To run this example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python examples/blackwell/fmha_bwd.py \\
|
||||
python examples/cute/blackwell/kernel/attention/fmha/fmha_bwd.py \\
|
||||
--s_q_max 1024 --s_k_max 1024 \\
|
||||
--h_q 8 --h_k 8 --d 128 --b 1 \\
|
||||
--element_dtype float16 --acc_dtype float32 \\
|
||||
@@ -193,6 +200,22 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
barrier_id=5,
|
||||
num_threads=self.num_reduce_warps * self.threads_per_warp,
|
||||
)
|
||||
self.load_reduce_tma_sync_barrier_0 = pipeline.NamedBarrier(
|
||||
barrier_id=6,
|
||||
num_threads=2 * self.threads_per_warp,
|
||||
)
|
||||
self.load_reduce_tma_sync_barrier_1 = pipeline.NamedBarrier(
|
||||
barrier_id=7,
|
||||
num_threads=2 * self.threads_per_warp,
|
||||
)
|
||||
self.load_reduce_tma_sync_barrier_2 = pipeline.NamedBarrier(
|
||||
barrier_id=8,
|
||||
num_threads=2 * self.threads_per_warp,
|
||||
)
|
||||
self.load_reduce_tma_sync_barrier_3 = pipeline.NamedBarrier(
|
||||
barrier_id=9,
|
||||
num_threads=2 * self.threads_per_warp,
|
||||
)
|
||||
|
||||
self.tmem_dK_offset = 0
|
||||
self.tmem_dV_offset = self.tmem_dK_offset + mma_tiler[2]
|
||||
@@ -200,8 +223,8 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
self.tmem_dP_offset = self.tmem_dQ_offset
|
||||
self.tmem_S_offset = self.tmem_dQ_offset + max(mma_tiler[0], mma_tiler[2])
|
||||
|
||||
self.num_regs_reduce = 152
|
||||
self.num_regs_compute = 128
|
||||
self.num_regs_reduce = 144
|
||||
self.num_regs_compute = 136
|
||||
self.num_regs_mma = 96
|
||||
self.num_regs_empty = 96
|
||||
self.num_regs_load = 96
|
||||
@@ -344,17 +367,17 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
self.dO_major_mode = utils.LayoutEnum.from_tensor(dO).mma_major_mode()
|
||||
self.dQ_layout = utils.LayoutEnum.from_tensor(dQ)
|
||||
|
||||
if cutlass.const_expr(self.Q_major_mode != tcgen05.OperandMajorMode.K):
|
||||
if cutlass.const_expr(self.Q_major_mode != OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of q is not supported")
|
||||
if cutlass.const_expr(self.dQ_major_mode != tcgen05.OperandMajorMode.K):
|
||||
if cutlass.const_expr(self.dQ_major_mode != OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of dq is not supported")
|
||||
if cutlass.const_expr(self.K_major_mode != tcgen05.OperandMajorMode.K):
|
||||
if cutlass.const_expr(self.K_major_mode != OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of k is not supported")
|
||||
if cutlass.const_expr(self.dK_major_mode != tcgen05.OperandMajorMode.K):
|
||||
if cutlass.const_expr(self.dK_major_mode != OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of dk is not supported")
|
||||
if cutlass.const_expr(self.V_major_mode != tcgen05.OperandMajorMode.K):
|
||||
if cutlass.const_expr(self.V_major_mode != OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of v is not supported")
|
||||
if cutlass.const_expr(self.dV_major_mode != tcgen05.OperandMajorMode.K):
|
||||
if cutlass.const_expr(self.dV_major_mode != OperandMajorMode.K):
|
||||
raise RuntimeError("The layout of dv is not supported")
|
||||
|
||||
self._setup_attributes()
|
||||
@@ -364,8 +387,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# compute S
|
||||
KQ_tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
||||
self.element_dtype,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
self.element_dtype,
|
||||
OperandMajorMode.K,
|
||||
OperandMajorMode.K,
|
||||
self.acc_dtype,
|
||||
cta_group,
|
||||
self.KQ_mma_tiler[:2],
|
||||
@@ -373,8 +397,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# compute dP
|
||||
VdO_tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
||||
self.element_dtype,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
self.element_dtype,
|
||||
OperandMajorMode.K,
|
||||
OperandMajorMode.K,
|
||||
self.acc_dtype,
|
||||
cta_group,
|
||||
self.VdO_mma_tiler[:2],
|
||||
@@ -382,8 +407,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# compute dV
|
||||
PdO_tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
||||
self.element_dtype,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.MN,
|
||||
self.element_dtype,
|
||||
OperandMajorMode.K,
|
||||
OperandMajorMode.MN,
|
||||
self.acc_dtype,
|
||||
cta_group,
|
||||
self.PdO_mma_tiler[:2],
|
||||
@@ -392,8 +418,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# compute dK
|
||||
dSQ_tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
||||
self.element_dtype,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.MN,
|
||||
self.element_dtype,
|
||||
OperandMajorMode.K,
|
||||
OperandMajorMode.MN,
|
||||
self.acc_dtype,
|
||||
cta_group,
|
||||
self.dSQ_mma_tiler[:2],
|
||||
@@ -401,8 +428,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# compute dQ
|
||||
dSK_tiled_mma = sm100_utils.make_trivial_tiled_mma(
|
||||
self.element_dtype,
|
||||
tcgen05.OperandMajorMode.MN,
|
||||
tcgen05.OperandMajorMode.MN,
|
||||
self.element_dtype,
|
||||
OperandMajorMode.MN,
|
||||
OperandMajorMode.MN,
|
||||
self.acc_dtype,
|
||||
cta_group,
|
||||
self.dSK_mma_tiler[:2],
|
||||
@@ -483,7 +511,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
|
||||
dQ_smem_layout_atom = sm100_utils.make_smem_layout_atom(
|
||||
sm100_utils.get_smem_layout_atom_ab(
|
||||
tcgen05.OperandMajorMode.K,
|
||||
OperandMajorMode.K,
|
||||
self.acc_dtype,
|
||||
(self.tile_shape_Q, 32),
|
||||
),
|
||||
@@ -844,9 +872,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
LSE_smem_layout: cute.Layout,
|
||||
sum_OdO_smem_layout: cute.Layout,
|
||||
):
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
grid_dim_x, grid_dim_y, grid_dim_z = cute.arch.grid_dim()
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
|
||||
if warp_idx == self.load_warp_id:
|
||||
@@ -1265,12 +1291,10 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# (load_mma_Q_pipeline, load_compute_LSE_pipeline, load_mma_dO_pipeline, load_compute_sum_OdO_pipeline)
|
||||
pipeline_args: tuple,
|
||||
):
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
tidx = cute.arch.thread_idx()[0]
|
||||
blk_coord_k, blk_coord_h_k, blk_coord_b = cute.arch.block_idx()
|
||||
blk_coord_h_r = Int32(0)
|
||||
blk_coord_h = (blk_coord_h_r, blk_coord_h_k)
|
||||
seq_Q, seq_K, D, HB = problem_shape
|
||||
H, B = HB
|
||||
iter_index = iter_start
|
||||
(
|
||||
load_mma_Q_pipeline,
|
||||
@@ -1362,6 +1386,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
load_compute_sum_OdO_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.load_compute_sum_OdO_stage
|
||||
)
|
||||
|
||||
load_mma_Q_pipeline.producer_acquire(load_mma_Q_producer_state)
|
||||
tma_barrier = load_mma_Q_pipeline.producer_get_barrier(
|
||||
load_mma_Q_producer_state
|
||||
@@ -1396,7 +1421,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
|
||||
async_copy_num_elts = sLSE.shape[0] // self.threads_per_warp
|
||||
atom_async_copy = cute.make_copy_atom(
|
||||
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS),
|
||||
cpasync.CopyG2SOp(cache_mode=cute.nvgpu.LoadCacheMode.ALWAYS),
|
||||
self.acc_dtype,
|
||||
num_bits_per_copy=self.acc_dtype.width,
|
||||
)
|
||||
@@ -1486,6 +1511,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
iter_count -= 1
|
||||
iter_index += 1
|
||||
|
||||
load_reduce_tma_sync_phase = Int32(0)
|
||||
while iter_count > 0:
|
||||
if iter_index == iter_end:
|
||||
iter_index = iter_start
|
||||
@@ -1507,6 +1533,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
|
||||
load_mma_Q_producer_state.advance()
|
||||
|
||||
self.load_reduce_tma_sync_arrive(load_reduce_tma_sync_phase)
|
||||
load_reduce_tma_sync_phase += 1
|
||||
|
||||
load_compute_LSE_pipeline.producer_acquire(load_compute_LSE_producer_state)
|
||||
|
||||
# Load LSE
|
||||
@@ -1589,6 +1618,30 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
iter_count -= 1
|
||||
iter_index += 1
|
||||
|
||||
@cute.jit
|
||||
def load_reduce_tma_sync_arrive(self, phase: Int32):
|
||||
phase_mod = phase % 4
|
||||
if phase_mod == 0:
|
||||
self.load_reduce_tma_sync_barrier_0.arrive()
|
||||
elif phase_mod == 1:
|
||||
self.load_reduce_tma_sync_barrier_1.arrive()
|
||||
elif phase_mod == 2:
|
||||
self.load_reduce_tma_sync_barrier_2.arrive()
|
||||
else:
|
||||
self.load_reduce_tma_sync_barrier_3.arrive()
|
||||
|
||||
@cute.jit
|
||||
def load_reduce_tma_sync_wait(self, phase: Int32):
|
||||
phase_mod = phase % 4
|
||||
if phase_mod == 0:
|
||||
self.load_reduce_tma_sync_barrier_0.arrive_and_wait()
|
||||
elif phase_mod == 1:
|
||||
self.load_reduce_tma_sync_barrier_1.arrive_and_wait()
|
||||
elif phase_mod == 2:
|
||||
self.load_reduce_tma_sync_barrier_2.arrive_and_wait()
|
||||
else:
|
||||
self.load_reduce_tma_sync_barrier_3.arrive_and_wait()
|
||||
|
||||
@cute.jit
|
||||
def mma(
|
||||
self,
|
||||
@@ -1909,11 +1962,10 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# (mma_compute_S_pipeline, compute_mma_P_pipeline, load_compute_LSE_pipeline, load_compute_sum_OdO_pipeline, mma_compute_dP_pipeline, compute_mma_dS_pipeline, mma_compute_dKdV_pipeline)
|
||||
pipeline_args: tuple,
|
||||
):
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
tidx = cute.arch.thread_idx()[0]
|
||||
|
||||
Q, K, D, HB = problem_shape
|
||||
blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord
|
||||
Q, K = problem_shape[0], problem_shape[1]
|
||||
blk_coord_k, blk_coord_batch = blk_coord[1], blk_coord[3]
|
||||
iter_index = iter_start
|
||||
(
|
||||
mma_compute_S_pipeline,
|
||||
@@ -2205,10 +2257,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# (mma_reduce_dQ_pipeline, reduce_tma_store_pipeline)
|
||||
pipeline_args: tuple,
|
||||
):
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
tidx = cute.arch.thread_idx()[0]
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
Q, K, D, HB = problem_shape
|
||||
blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord
|
||||
blk_coord_batch = blk_coord[3]
|
||||
|
||||
blk_coord_h, blk_coord_b = blk_coord_batch
|
||||
blk_coord_h_r, blk_coord_h_k = blk_coord_h
|
||||
@@ -2241,7 +2292,6 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
thr_t2r = tiled_t2r.get_slice(thread_idx)
|
||||
|
||||
tTR_cdQ = thr_t2r.partition_D(cdQ)
|
||||
tTR_gdQ = thr_t2r.partition_D(gdQ)
|
||||
tTR_sdQ = thr_t2r.partition_D(sdQ)
|
||||
tTR_tdQ = thr_t2r.partition_S(tdQtdQ)
|
||||
|
||||
@@ -2253,6 +2303,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
cute.group_modes(gdQ, 0, 2),
|
||||
)
|
||||
|
||||
load_reduce_tma_sync_phase = Int32(0)
|
||||
while iter_count > 0:
|
||||
mma_reduce_dQ_pipeline.consumer_wait(mma_reduce_dQ_consumer_state)
|
||||
|
||||
@@ -2286,6 +2337,10 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
self.reduce_sync_barrier.arrive_and_wait()
|
||||
|
||||
if warp_idx == 0:
|
||||
if iter_count > 1 and i == 0:
|
||||
self.load_reduce_tma_sync_wait(load_reduce_tma_sync_phase)
|
||||
load_reduce_tma_sync_phase += 1
|
||||
|
||||
cute.copy(
|
||||
tma_atom_dQ_acc,
|
||||
tdQsdQ[None, reduce_tma_store_producer_state.index],
|
||||
@@ -2346,7 +2401,6 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
input: cute.Tensor,
|
||||
frg_cnt: Int32,
|
||||
) -> cute.Tensor:
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
output = cute.make_rmem_tensor(input.shape, self.element_dtype)
|
||||
frg_tile = cute.size(input) // frg_cnt
|
||||
t_frg = cute.logical_divide(input, cute.make_layout(frg_cnt))
|
||||
@@ -2406,9 +2460,9 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
# (mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state)
|
||||
pipeline_args: tuple,
|
||||
):
|
||||
tidx, tidy, tidz = cute.arch.thread_idx()
|
||||
Q, K, D, HB = problem_shape
|
||||
blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_batch = blk_coord
|
||||
tidx = cute.arch.thread_idx()[0]
|
||||
K, D, HB = problem_shape[1], problem_shape[2], problem_shape[3]
|
||||
blk_coord_k, blk_coord_batch = blk_coord[1], blk_coord[3]
|
||||
mma_compute_dKdV_pipeline, mma_compute_dKdV_consumer_state = pipeline_args
|
||||
|
||||
load_op = cute.make_copy_atom(
|
||||
@@ -2511,7 +2565,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
workspace: cute.Tensor,
|
||||
acc_dtype: Type[cutlass.Numeric],
|
||||
) -> Tuple[cute.Tensor, cute.Tensor, cute.Tensor]:
|
||||
Q, D, HB = (
|
||||
Q, D, _HB = (
|
||||
problem_shape[0],
|
||||
problem_shape[2],
|
||||
problem_shape[3],
|
||||
@@ -2524,7 +2578,7 @@ class BlackwellFusedMultiHeadAttentionBackward:
|
||||
acc_bytes = acc_dtype.width // 8
|
||||
sum_OdO_bytes = cute.assume(B * H * Q * acc_bytes, divby=acc_bytes)
|
||||
scaled_lse_bytes = cute.assume(B * H * Q * acc_bytes, divby=acc_bytes)
|
||||
dQ_acc_bytes = cute.assume(B * H * Q * D * acc_bytes, divby=acc_bytes)
|
||||
cute.assume(B * H * Q * D * acc_bytes, divby=acc_bytes)
|
||||
|
||||
sum_OdO_iter = workspace.iterator
|
||||
scaled_lse_iter = sum_OdO_iter + sum_OdO_bytes
|
||||
@@ -2899,7 +2953,9 @@ def run(
|
||||
window_size_right,
|
||||
bottom_right_align,
|
||||
):
|
||||
raise ValueError("sliding window doesn't support current setting")
|
||||
raise testing.CantImplementError(
|
||||
"sliding window doesn't support current setting"
|
||||
)
|
||||
|
||||
# create sequence lengths for variable length inputs
|
||||
cumulative_s_q = [0]
|
||||
@@ -2985,10 +3041,10 @@ def run(
|
||||
|
||||
lse_ref = cutlass_torch.create_and_permute_torch_tensor(
|
||||
(b, h_k, h_r, s_q),
|
||||
cutlass.torch.dtype(acc_dtype),
|
||||
cutlass_torch.dtype(acc_dtype),
|
||||
permute_order=(3, 2, 1, 0),
|
||||
init_type=cutlass.torch.TensorInitType.RANDOM,
|
||||
init_config=cutlass.torch.RandomInitConfig(min_val=10, max_val=11),
|
||||
init_type=cutlass_torch.TensorInitType.RANDOM,
|
||||
init_config=cutlass_torch.RandomInitConfig(min_val=10, max_val=11),
|
||||
)
|
||||
lse_torch = lse_ref.cuda()
|
||||
lse_tensor = from_dlpack(lse_torch, assumed_align=16)
|
||||
@@ -2997,7 +3053,11 @@ def run(
|
||||
mma_tiler = (*mma_tiler_mn, d)
|
||||
|
||||
fmha_bwd = BlackwellFusedMultiHeadAttentionBackward(
|
||||
element_dtype, acc_dtype, mma_tiler, varlen, mask_type
|
||||
element_dtype,
|
||||
acc_dtype,
|
||||
mma_tiler,
|
||||
varlen,
|
||||
mask_type,
|
||||
)
|
||||
|
||||
workspace_size = BlackwellFusedMultiHeadAttentionBackward._get_workspace_size(
|
||||
@@ -3112,11 +3172,11 @@ def run(
|
||||
torch.cuda.synchronize()
|
||||
print("Verifying results...")
|
||||
|
||||
q_ref = q_ref.cuda().to(cutlass.torch.dtype(element_dtype))
|
||||
k_ref = k_ref.cuda().to(cutlass.torch.dtype(element_dtype))
|
||||
v_ref = v_ref.cuda().to(cutlass.torch.dtype(element_dtype))
|
||||
o_ref = o_ref.cuda().to(cutlass.torch.dtype(element_dtype))
|
||||
do_ref = do_ref.cuda().to(cutlass.torch.dtype(element_dtype))
|
||||
q_ref = q_ref.cuda().to(cutlass_torch.dtype(element_dtype))
|
||||
k_ref = k_ref.cuda().to(cutlass_torch.dtype(element_dtype))
|
||||
v_ref = v_ref.cuda().to(cutlass_torch.dtype(element_dtype))
|
||||
o_ref = o_ref.cuda().to(cutlass_torch.dtype(element_dtype))
|
||||
do_ref = do_ref.cuda().to(cutlass_torch.dtype(element_dtype))
|
||||
dv = dv_torch.to(dtype=torch.float32)
|
||||
dk = dk_torch.to(dtype=torch.float32)
|
||||
dq = dq_torch.to(dtype=torch.float32)
|
||||
@@ -3227,10 +3287,10 @@ def run(
|
||||
|
||||
lse_ref_new = cutlass_torch.create_and_permute_torch_tensor(
|
||||
(b, h_k, h_r, s_q),
|
||||
cutlass.torch.dtype(acc_dtype),
|
||||
cutlass_torch.dtype(acc_dtype),
|
||||
permute_order=(3, 2, 1, 0),
|
||||
init_type=cutlass.torch.TensorInitType.RANDOM,
|
||||
init_config=cutlass.torch.RandomInitConfig(min_val=10, max_val=11),
|
||||
init_type=cutlass_torch.TensorInitType.RANDOM,
|
||||
init_config=cutlass_torch.RandomInitConfig(min_val=10, max_val=11),
|
||||
)
|
||||
lse_torch_new = lse_ref_new.cuda()
|
||||
lse_tensor_new = from_dlpack(lse_torch_new, assumed_align=16)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,54 +27,21 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
import cuda.bindings.driver as cuda
|
||||
try:
|
||||
from cuda.core import Device
|
||||
except ImportError:
|
||||
from cuda.core.experimental import Device
|
||||
from cuda.pathfinder import load_nvidia_dynamic_lib
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass import testing
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cutlass_dsl import T
|
||||
from cutlass._mlir.dialects import vector
|
||||
|
||||
try:
|
||||
import nvshmem.core
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem4py is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvshmem4py-cu12\n"
|
||||
" For CUDA 13: pip install nvshmem4py-cu13\n"
|
||||
"Note: nvshmem4py version >= 0.1.3 is recommended."
|
||||
) from None
|
||||
|
||||
try:
|
||||
load_nvidia_dynamic_lib("nvshmem_host")
|
||||
except RuntimeError as exc:
|
||||
raise ImportError(
|
||||
"nvshmem lib is required but not installed. Please install it using:\n"
|
||||
" For CUDA 12: pip install nvidia-nvshmem-cu12\n"
|
||||
" For CUDA 13: pip install nvidia-nvshmem-cu13\n"
|
||||
) from None
|
||||
|
||||
"""
|
||||
A Distributed One-Shot All-Reduce Example using CuTe DSL and fine-grained memory control. This is a mirrored version of the
|
||||
existing tensorrt_llm kernel:
|
||||
https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
|
||||
|
||||
In Lamport terminology this is a classic flag-based busy-wait: every participant keeps polling the shared slot until the
|
||||
flag changes from the sentinel (negative zero) to real data, which indicates that the Lamport-style logical ordering has
|
||||
advanced and the payload is safe to consume.
|
||||
|
||||
This example kernel demonstrates a one-shot all-reduce operation using the CuTe DSL with fine-grained memory control.
|
||||
It uses dedicated communication buffers for data exchange, and these buffers act as ping-pong buffers. During the
|
||||
process, the kernel uses one buffer for communication and initializes the next buffer to all negative zeros.
|
||||
@@ -90,8 +57,8 @@ The .SYS memory scope and .VOLATILE memory order are used to ensure that the dat
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_one_shot_lamport.py --M 8192 --N 8192
|
||||
torchrun --nproc-per-node 8 examples/distributed/all_reduce_one_shot_lamport.py \
|
||||
torchrun --nproc-per-node 8 examples/cute/blackwell/distributed/all_reduce_one_shot_lamport.py --M 8192 --N 8192
|
||||
torchrun --nproc-per-node 8 examples/cute/blackwell/distributed/all_reduce_one_shot_lamport.py \
|
||||
--M 8192 --N 8192 --benchmark --warmup_iterations 2 --iterations 10
|
||||
"""
|
||||
|
||||
@@ -174,14 +141,14 @@ class AllReduceOneShotLamportKernel:
|
||||
|
||||
# assume all buffers have the same element type with input
|
||||
copy_atom_load = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
cute.nvgpu.CopyG2ROp(),
|
||||
buffers[0].element_type,
|
||||
num_bits_per_copy=128,
|
||||
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
|
||||
memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
|
||||
)
|
||||
copy_atom_store = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
cute.nvgpu.CopyR2GOp(),
|
||||
buffers[0].element_type,
|
||||
num_bits_per_copy=128,
|
||||
memory_scope=cute.nvgpu.common.MemoryScope.SYS,
|
||||
@@ -276,6 +243,10 @@ def run_all_reduce_one_shot(
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
):
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
if rank == 0:
|
||||
@@ -284,8 +255,20 @@ def run_all_reduce_one_shot(
|
||||
print(f"GPU count: {world_size}")
|
||||
|
||||
# init buffer tensors to be neg 0
|
||||
local_buffer_tensor = nvshmem.core.tensor([PING_PONG_SIZE, world_size, M, N,], dtype=torch.float32).neg_()
|
||||
buffer_tensor_list = [nvshmem.core.get_peer_tensor(local_buffer_tensor, rank).permute(2, 3, 1, 0) for rank in range(world_size)]
|
||||
t = symm_mem.empty(
|
||||
[
|
||||
PING_PONG_SIZE,
|
||||
world_size,
|
||||
M,
|
||||
N,
|
||||
],
|
||||
device="cuda",
|
||||
).neg_()
|
||||
hdl = symm_mem.rendezvous(t, dist.group.WORLD)
|
||||
buffer_tensor_list = [
|
||||
hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 3, 1, 0)
|
||||
for rank in range(world_size)
|
||||
]
|
||||
signal = cutlass.Int32(0)
|
||||
input_tensor = torch.randn([M, N], device=f"cuda:{rank}")
|
||||
output_tensor = torch.zeros([M, N], device=f"cuda:{rank}")
|
||||
@@ -319,33 +302,40 @@ def run_all_reduce_one_shot(
|
||||
if rank == 0:
|
||||
print("Results verified successfully!")
|
||||
|
||||
for t in buffer_tensor_list:
|
||||
nvshmem.core.free_tensor(t)
|
||||
|
||||
if not benchmark:
|
||||
return
|
||||
|
||||
free_func_and_tensor_pairs = []
|
||||
def add_free_func_and_tensor(free_func, tensor):
|
||||
free_func_and_tensor_pairs.append((free_func, tensor))
|
||||
|
||||
def generate_tensors():
|
||||
local_buffer = nvshmem.core.tensor([PING_PONG_SIZE, world_size, M, N,], dtype=torch.float32).neg_()
|
||||
buffer_tensor_list = [nvshmem.core.get_peer_tensor(local_buffer, rank).permute(2, 3, 1, 0) for rank in range(world_size)]
|
||||
input_tensor = torch.randn([M, N], device=f"cuda:{rank}")
|
||||
output_tensor = torch.zeros([M, N], device=f"cuda:{rank}")
|
||||
t = symm_mem.empty(
|
||||
[
|
||||
PING_PONG_SIZE,
|
||||
world_size,
|
||||
M * N,
|
||||
],
|
||||
device="cuda",
|
||||
).neg_()
|
||||
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD.group_name)
|
||||
# get tensors from other devices from the symmetric memory
|
||||
buffers = [
|
||||
hdl.get_buffer(rank, t.shape, t.dtype).permute(2, 1, 0)
|
||||
for rank in range(world_size)
|
||||
]
|
||||
input_tensor = torch.randn(M * N, device=f"cuda:{rank}")
|
||||
output_tensor = torch.zeros(M * N, device=f"cuda:{rank}")
|
||||
|
||||
ja = testing.JitArguments(
|
||||
cutlass.Int32(0),
|
||||
from_dlpack(input_tensor, assumed_align=32),
|
||||
from_dlpack(output_tensor, assumed_align=32),
|
||||
[from_dlpack(t, assumed_align=32) for t in buffer_tensor_list],
|
||||
stream=stream
|
||||
[from_dlpack(t, assumed_align=32) for t in buffers],
|
||||
stream=stream,
|
||||
)
|
||||
for tensor in buffer_tensor_list:
|
||||
add_free_func_and_tensor(nvshmem.core.free_tensor, tensor)
|
||||
|
||||
ja._hdl = (
|
||||
hdl # in order to extend the life cycle of hdl for the kernel execution
|
||||
)
|
||||
ja._t = t # same reason
|
||||
return ja
|
||||
|
||||
avg_time_us = testing.benchmark(
|
||||
compiled_func,
|
||||
workspace_generator=generate_tensors,
|
||||
@@ -361,47 +351,39 @@ def run_all_reduce_one_shot(
|
||||
f"Achieved memory throughput: {((world_size + 1) * output_tensor.numel() * 32 // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s"
|
||||
)
|
||||
|
||||
for free_func, tensor in free_func_and_tensor_pairs:
|
||||
free_func(tensor)
|
||||
|
||||
def torchrun_uid_init_bcast():
|
||||
"""
|
||||
Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
|
||||
def run(
|
||||
M,
|
||||
N,
|
||||
warmup_iterations=2,
|
||||
iterations=10,
|
||||
skip_ref_check=False,
|
||||
benchmark=True,
|
||||
):
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed._symmetric_memory as symm_mem
|
||||
|
||||
It uses torch.distributed.broadcast on a NumPy array to handle the broadcasting
|
||||
"""
|
||||
# Set Torch device
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
globals()["torch"] = torch
|
||||
globals()["dist"] = dist
|
||||
globals()["symm_mem"] = symm_mem
|
||||
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
# nvshmem4py requires a cuda.core Device at init time
|
||||
dev = Device(local_rank)
|
||||
dev.set_current()
|
||||
global stream
|
||||
stream = dev.create_stream()
|
||||
|
||||
# Initialize torch.distributed process group
|
||||
dist.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
)
|
||||
|
||||
# Extract rank, nranks from process group
|
||||
num_ranks = dist.get_world_size()
|
||||
|
||||
# Create an empty uniqueid for all ranks
|
||||
uid = nvshmem.core.get_unique_id(empty=(local_rank != 0))
|
||||
uid_bytes = uid._data.view(np.uint8).copy()
|
||||
uid_tensor = torch.from_numpy(uid_bytes).cuda()
|
||||
dist.broadcast(uid_tensor, src=0)
|
||||
dist.barrier()
|
||||
uid._data[:] = uid_tensor.cpu().numpy().view(uid._data.dtype)
|
||||
|
||||
nvshmem.core.init(device=dev, uid=uid, rank=local_rank, nranks=num_ranks, initializer_method="uid")
|
||||
|
||||
|
||||
def torchrun_finalize():
|
||||
nvshmem.core.finalize()
|
||||
dist.destroy_process_group()
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(backend="nccl")
|
||||
try:
|
||||
run_all_reduce_one_shot(
|
||||
M,
|
||||
N,
|
||||
warmup_iterations,
|
||||
iterations,
|
||||
skip_ref_check,
|
||||
benchmark,
|
||||
)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def main():
|
||||
@@ -414,13 +396,17 @@ def main():
|
||||
parser.add_argument("--iterations", default=10, type=int)
|
||||
parser.add_argument("--skip_ref_check", action="store_true")
|
||||
parser.add_argument("--benchmark", action="store_true")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
torchrun_uid_init_bcast()
|
||||
|
||||
run_all_reduce_one_shot(args.M, args.N, args.warmup_iterations, args.iterations, args.skip_ref_check, args.benchmark)
|
||||
|
||||
torchrun_finalize()
|
||||
run(
|
||||
args.M,
|
||||
args.N,
|
||||
args.warmup_iterations,
|
||||
args.iterations,
|
||||
args.skip_ref_check,
|
||||
args.benchmark,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -269,8 +269,8 @@ class GroupedGemmKernel:
|
||||
c_bytes = c_bytes_per_stage * self.num_c_stage
|
||||
|
||||
self.num_sched_stages = 2
|
||||
sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32)
|
||||
sched_bytes = sched_work_tile_bytes_per_stage * self.num_sched_stages
|
||||
self.sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32)
|
||||
sched_bytes = self.sched_work_tile_bytes_per_stage * self.num_sched_stages
|
||||
|
||||
fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes
|
||||
|
||||
@@ -774,7 +774,11 @@ class GroupedGemmKernel:
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_acc_stage * 2
|
||||
]
|
||||
sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4]
|
||||
sched_buf_align_bytes = self.sched_work_tile_bytes_per_stage
|
||||
sched_buf: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4],
|
||||
sched_buf_align_bytes,
|
||||
]
|
||||
sched_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_sched_stages * 2
|
||||
]
|
||||
@@ -2010,9 +2014,6 @@ if __name__ == "__main__":
|
||||
compare_with_bmm=args.compare_with_bmm,
|
||||
compare_with_sol=args.compare_with_sol,
|
||||
)
|
||||
if misc.no_torch_210:
|
||||
misc.compare_with_bmm = True
|
||||
print("Override to set --compare_with_bmm to avoid possible torch crash.")
|
||||
|
||||
tester = GroupedGemmTester(problem, impl, misc)
|
||||
tester.run()
|
||||
|
||||
@@ -359,6 +359,7 @@ class ScaledGroupedGemmKernel:
|
||||
)
|
||||
|
||||
self.num_sched_stages = 2
|
||||
self.sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32)
|
||||
|
||||
# ── SMEM layouts ──
|
||||
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
||||
@@ -1371,7 +1372,11 @@ class ScaledGroupedGemmKernel:
|
||||
acc_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_acc_pipeline_stages * 2
|
||||
]
|
||||
sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4]
|
||||
sched_buf_align_bytes = self.sched_work_tile_bytes_per_stage
|
||||
sched_buf: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4],
|
||||
sched_buf_align_bytes,
|
||||
]
|
||||
sched_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_sched_stages * 2
|
||||
]
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# This is the third tutorial GEMM. It further enhances the second tutorial by adding warp
|
||||
# specialization for TMA, MMA, and epilogue warps.
|
||||
@@ -32,7 +50,7 @@ The dynamic scheduler is more flexible than the static scheduler, as it can hand
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_3_1.py \
|
||||
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_3_1.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
@@ -72,29 +90,32 @@ class SharedStorage:
|
||||
tmem_holding_buffer: cutlass.Int32
|
||||
# Only for CLC Dynamic Scheduler
|
||||
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
|
||||
clc_response_align_bytes = num_clc_response_bytes
|
||||
clc_response: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, 4],
|
||||
clc_response_align_bytes,
|
||||
]
|
||||
|
||||
|
||||
@cute.kernel()
|
||||
def kernel(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a: cute.CopyAtom,
|
||||
mA_mkl: cute.Tensor,
|
||||
tma_atom_b: cute.CopyAtom,
|
||||
mB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
tma_a: cpasync.TmaInfo,
|
||||
tma_b: cpasync.TmaInfo,
|
||||
tma_c: cpasync.TmaInfo,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
tile_sched_params: Union[
|
||||
utils.ClcDynamicPersistentTileSchedulerParams,
|
||||
utils.PersistentTileSchedulerParams,
|
||||
],
|
||||
):
|
||||
):
|
||||
# Extract tma_tensor from TmaInfo
|
||||
mA_mkl = tma_a.tma_tensor
|
||||
mB_nkl = tma_b.tma_tensor
|
||||
mC_mnl = tma_c.tma_tensor
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
@@ -123,9 +144,9 @@ def kernel(
|
||||
|
||||
# Prefetch tma descriptor
|
||||
if warp_idx == tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
cpasync.prefetch_descriptor(tma_a.atom)
|
||||
cpasync.prefetch_descriptor(tma_b.atom)
|
||||
cpasync.prefetch_descriptor(tma_c.atom)
|
||||
|
||||
# As many participants as the number of threads issuing the MMA in the same row and column
|
||||
# Substract one to not count twice the same thread
|
||||
@@ -167,8 +188,8 @@ def kernel(
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
|
||||
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
|
||||
) * cute.size(cta_layout_vmnk, mode=[0])
|
||||
|
||||
# Threads/warps participating in the mainloop pipeline
|
||||
@@ -248,21 +269,21 @@ def kernel(
|
||||
# Allocate SMEM
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=a_smem_layout.outer,
|
||||
layout=tma_a.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout.inner,
|
||||
swizzle=tma_a.smem_layout.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=b_smem_layout.outer,
|
||||
layout=tma_b.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout.inner,
|
||||
swizzle=tma_b.smem_layout.inner,
|
||||
)
|
||||
sC = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=epi_smem_layout_staged.outer,
|
||||
layout=tma_c.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=epi_smem_layout_staged.inner,
|
||||
swizzle=tma_c.smem_layout.inner,
|
||||
)
|
||||
|
||||
# Partition tensors for MMA and make fragments
|
||||
@@ -301,7 +322,7 @@ def kernel(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestM, RestK)
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
cta_in_cluster_coord_vmnk[2],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
|
||||
cute.group_modes(sA, 0, 3),
|
||||
@@ -311,7 +332,7 @@ def kernel(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestN, RestK)
|
||||
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
cta_in_cluster_coord_vmnk[1],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
|
||||
cute.group_modes(sB, 0, 3),
|
||||
@@ -321,7 +342,7 @@ def kernel(
|
||||
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
|
||||
|
||||
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
@@ -379,14 +400,14 @@ def kernel(
|
||||
|
||||
# Issue TMA loads
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
tAgA_slice[(None, k_tile_idx)],
|
||||
tAsA[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_a,
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
tBgB_slice[(None, k_tile_idx)],
|
||||
tBsB[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
@@ -447,23 +468,14 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
@@ -496,10 +508,10 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Initialize TMA store pipeline for epilogue
|
||||
# Initialize TMA store pipeline for epilogue with 4 warps
|
||||
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=128,
|
||||
pipeline.Agent.Warp,
|
||||
size=4,
|
||||
)
|
||||
epilogue_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=epi_stages,
|
||||
@@ -594,7 +606,7 @@ def kernel(
|
||||
# SMEM -> GMEM
|
||||
if warp_idx == epilogue_warp_ids[0]:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
tCsC[(None, c_buffer)],
|
||||
tCgC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
@@ -678,8 +690,8 @@ def host_function(
|
||||
mma_inst_shape_mnk,
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
|
||||
tcgen05.OperandSource.SMEM,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
@@ -717,21 +729,18 @@ def host_function(
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
)
|
||||
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
|
||||
a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
tma_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_slice,
|
||||
a_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
|
||||
)
|
||||
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
|
||||
b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
tma_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
b,
|
||||
b_smem_layout_slice,
|
||||
b_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
@@ -757,11 +766,10 @@ def host_function(
|
||||
epi_stages,
|
||||
)
|
||||
|
||||
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
||||
c_tma_atom, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
@@ -779,16 +787,10 @@ def host_function(
|
||||
|
||||
kernel(
|
||||
tiled_mma,
|
||||
a_tma_atom,
|
||||
a_tma_tensor,
|
||||
b_tma_atom,
|
||||
b_tma_tensor,
|
||||
c_tma_atom,
|
||||
c_tma_tensor,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
tma_a,
|
||||
tma_b,
|
||||
tma_c,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
cta_layout_vmnk,
|
||||
tile_sched_params,
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# This is the fourth tutorial GEMM (4). It extends fp16_gemm_3_1.py by adding TMA prefetch.
|
||||
# TMA prefetch uses cute.prefetch() to bring data into L2 cache before TMA copy needs it,
|
||||
@@ -55,7 +73,7 @@ CuTe DSL Blackwell SM100 kernels. Users can specify preferred and fallback clust
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_4.py \
|
||||
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_4.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
@@ -96,22 +114,20 @@ class SharedStorage:
|
||||
tmem_holding_buffer: cutlass.Int32
|
||||
# Only for CLC Dynamic Scheduler
|
||||
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
|
||||
clc_response_align_bytes = num_clc_response_bytes
|
||||
clc_response: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, 4],
|
||||
clc_response_align_bytes,
|
||||
]
|
||||
|
||||
|
||||
@cute.jit
|
||||
def cluster_specific_kernel(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a: cute.CopyAtom,
|
||||
mA_mkl: cute.Tensor,
|
||||
tma_atom_b: cute.CopyAtom,
|
||||
mB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
tma_a: cpasync.TmaInfo,
|
||||
tma_b: cpasync.TmaInfo,
|
||||
tma_c: cpasync.TmaInfo,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
cluster_shape_mnk: Tuple[int, int, int],
|
||||
@@ -120,6 +136,11 @@ def cluster_specific_kernel(
|
||||
utils.PersistentTileSchedulerParams,
|
||||
],
|
||||
):
|
||||
# Extract tma_tensor from TmaInfo
|
||||
mA_mkl = tma_a.tma_tensor
|
||||
mB_nkl = tma_b.tma_tensor
|
||||
mC_mnl = tma_c.tma_tensor
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
@@ -148,9 +169,9 @@ def cluster_specific_kernel(
|
||||
|
||||
# Prefetch tma descriptor
|
||||
if warp_idx == tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
cpasync.prefetch_descriptor(tma_a.atom)
|
||||
cpasync.prefetch_descriptor(tma_b.atom)
|
||||
cpasync.prefetch_descriptor(tma_c.atom)
|
||||
|
||||
# As many participants as the number of threads issuing the MMA in the same row and column
|
||||
# Substract one to not count twice the same thread
|
||||
@@ -192,8 +213,8 @@ def cluster_specific_kernel(
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
|
||||
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
|
||||
) * cute.size(cta_layout_vmnk, mode=[0])
|
||||
|
||||
# Threads/warps participating in the mainloop pipeline
|
||||
@@ -273,21 +294,21 @@ def cluster_specific_kernel(
|
||||
# Allocate SMEM
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=a_smem_layout.outer,
|
||||
layout=tma_a.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout.inner,
|
||||
swizzle=tma_a.smem_layout.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=b_smem_layout.outer,
|
||||
layout=tma_b.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout.inner,
|
||||
swizzle=tma_b.smem_layout.inner,
|
||||
)
|
||||
sC = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=epi_smem_layout_staged.outer,
|
||||
layout=tma_c.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=epi_smem_layout_staged.inner,
|
||||
swizzle=tma_c.smem_layout.inner,
|
||||
)
|
||||
|
||||
# Partition tensors for MMA and make fragments
|
||||
@@ -326,7 +347,7 @@ def cluster_specific_kernel(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestM, RestK)
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
cta_in_cluster_coord_vmnk[2],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
|
||||
cute.group_modes(sA, 0, 3),
|
||||
@@ -336,7 +357,7 @@ def cluster_specific_kernel(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestN, RestK)
|
||||
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
cta_in_cluster_coord_vmnk[1],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
|
||||
cute.group_modes(sB, 0, 3),
|
||||
@@ -346,7 +367,7 @@ def cluster_specific_kernel(
|
||||
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
|
||||
|
||||
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
@@ -403,14 +424,14 @@ def cluster_specific_kernel(
|
||||
|
||||
# Issue TMA loads
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
tAgA_slice[(None, k_tile_idx)],
|
||||
tAsA[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_a,
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
tBgB_slice[(None, k_tile_idx)],
|
||||
tBsB[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
@@ -471,23 +492,14 @@ def cluster_specific_kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
@@ -520,10 +532,10 @@ def cluster_specific_kernel(
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Initialize TMA store pipeline for epilogue
|
||||
# Initialize TMA store pipeline for epilogue with 4 warps
|
||||
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=128,
|
||||
pipeline.Agent.Warp,
|
||||
size=4,
|
||||
)
|
||||
epilogue_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=epi_stages,
|
||||
@@ -618,7 +630,7 @@ def cluster_specific_kernel(
|
||||
# SMEM -> GMEM
|
||||
if warp_idx == epilogue_warp_ids[0]:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
tCsC[(None, c_buffer)],
|
||||
tCgC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
@@ -652,20 +664,12 @@ def cluster_specific_kernel(
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a_preferred: cute.CopyAtom,
|
||||
mA_mkl_preferred: cute.Tensor,
|
||||
tma_atom_b_preferred: cute.CopyAtom,
|
||||
mB_nkl_preferred: cute.Tensor,
|
||||
tma_atom_a_fallback: cute.CopyAtom,
|
||||
mA_mkl_fallback: cute.Tensor,
|
||||
tma_atom_b_fallback: cute.CopyAtom,
|
||||
mB_nkl_fallback: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
tma_a_preferred: cpasync.TmaInfo,
|
||||
tma_b_preferred: cpasync.TmaInfo,
|
||||
tma_a_fallback: cpasync.TmaInfo,
|
||||
tma_b_fallback: cpasync.TmaInfo,
|
||||
tma_c: cpasync.TmaInfo,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
preferred_cta_layout_vmnk: cute.Layout,
|
||||
fallback_cta_layout_vmnk: cute.Layout,
|
||||
@@ -687,19 +691,15 @@ def kernel(
|
||||
)
|
||||
|
||||
# As for now, only support preferred cluster kernel via the mega-kernel approach
|
||||
# mega-kernel approach has 2 mutually exclusive code branches, only one path runs per launch,
|
||||
# specify `smem_merge_branch_allocs=True` at launch to enables shared memory reuse between two paths
|
||||
if is_preferred_cluster:
|
||||
cluster_specific_kernel(
|
||||
tiled_mma,
|
||||
tma_atom_a_preferred,
|
||||
mA_mkl_preferred,
|
||||
tma_atom_b_preferred,
|
||||
mB_nkl_preferred,
|
||||
tma_atom_c,
|
||||
mC_mnl,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
tma_a_preferred,
|
||||
tma_b_preferred,
|
||||
tma_c,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
preferred_cta_layout_vmnk,
|
||||
preferred_cluster_shape_mnk,
|
||||
@@ -708,16 +708,10 @@ def kernel(
|
||||
else:
|
||||
cluster_specific_kernel(
|
||||
tiled_mma,
|
||||
tma_atom_a_fallback,
|
||||
mA_mkl_fallback,
|
||||
tma_atom_b_fallback,
|
||||
mB_nkl_fallback,
|
||||
tma_atom_c,
|
||||
mC_mnl,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
tma_a_fallback,
|
||||
tma_b_fallback,
|
||||
tma_c,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
fallback_cta_layout_vmnk,
|
||||
fallback_cluster_shape_mnk,
|
||||
@@ -814,8 +808,8 @@ def host_function(
|
||||
mma_inst_shape_mnk,
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
|
||||
tcgen05.OperandSource.SMEM,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
@@ -860,37 +854,35 @@ def host_function(
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
)
|
||||
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
|
||||
tma_atom_a_fallback, a_tma_tensor_fallback = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
tma_a_fallback = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_slice,
|
||||
a_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
fallback_cta_layout_vmnk.shape,
|
||||
)
|
||||
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
|
||||
tma_atom_b_fallback, b_tma_tensor_fallback = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
tma_b_fallback = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
b,
|
||||
b_smem_layout_slice,
|
||||
b_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
fallback_cta_layout_vmnk.shape,
|
||||
)
|
||||
|
||||
tma_atom_a_preferred, a_tma_tensor_preferred = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
tma_a_preferred = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_slice,
|
||||
a_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
preferred_cta_layout_vmnk.shape,
|
||||
)
|
||||
tma_atom_b_preferred, b_tma_tensor_preferred = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
tma_b_preferred = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
b,
|
||||
b_smem_layout_slice,
|
||||
b_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
preferred_cta_layout_vmnk.shape,
|
||||
@@ -915,12 +907,11 @@ def host_function(
|
||||
epi_tile,
|
||||
epi_stages,
|
||||
)
|
||||
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
||||
|
||||
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
@@ -945,20 +936,12 @@ def host_function(
|
||||
|
||||
kernel(
|
||||
tiled_mma,
|
||||
tma_atom_a_preferred,
|
||||
a_tma_tensor_preferred,
|
||||
tma_atom_b_preferred,
|
||||
b_tma_tensor_preferred,
|
||||
tma_atom_a_fallback,
|
||||
a_tma_tensor_fallback,
|
||||
tma_atom_b_fallback,
|
||||
b_tma_tensor_fallback,
|
||||
tma_atom_c,
|
||||
c_tma_tensor,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
tma_a_preferred,
|
||||
tma_b_preferred,
|
||||
tma_a_fallback,
|
||||
tma_b_fallback,
|
||||
tma_c,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
preferred_cta_layout_vmnk,
|
||||
fallback_cta_layout_vmnk,
|
||||
@@ -969,6 +952,7 @@ def host_function(
|
||||
block=[224, 1, 1] if use_clc_dynamic_scheduler else [192, 1, 1],
|
||||
cluster=preferred_cluster_shape_mnk,
|
||||
fallback_cluster=fallback_cluster_shape_mnk,
|
||||
smem_merge_branch_allocs=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -981,7 +965,7 @@ def run_dense_gemm(
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
print("===================================================================")
|
||||
print("Running Blackwell fp16 GEMM example 4 (with MIX cluster size support):")
|
||||
print("Running Blackwell fp16 GEMM example 4 (with MIX cluster support):")
|
||||
print(f" mnk: {mnk}")
|
||||
print(f" tolerance: {tolerance}")
|
||||
print(f" Preferred cluster shape: {preferred_cluster_shape_mnk}")
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# This is the fifth tutorial GEMM (5). It extends fp16_gemm_3_1.py by adding TMA prefetch.
|
||||
# TMA prefetch uses cute.prefetch() to bring data into L2 cache before TMA copy needs it,
|
||||
@@ -44,7 +62,7 @@ Key differences from fp16_gemm_3_1.py:
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_5.py \
|
||||
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_5.py \
|
||||
--mnk 8192,8192,8192
|
||||
|
||||
Constraints for this example:
|
||||
@@ -84,22 +102,20 @@ class SharedStorage:
|
||||
tmem_holding_buffer: cutlass.Int32
|
||||
# Only for CLC Dynamic Scheduler
|
||||
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
|
||||
clc_response_align_bytes = num_clc_response_bytes
|
||||
clc_response: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, 4],
|
||||
clc_response_align_bytes,
|
||||
]
|
||||
|
||||
|
||||
@cute.kernel()
|
||||
def kernel(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a: cute.CopyAtom,
|
||||
mA_mkl: cute.Tensor,
|
||||
tma_atom_b: cute.CopyAtom,
|
||||
mB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
tma_a: cpasync.TmaInfo,
|
||||
tma_b: cpasync.TmaInfo,
|
||||
tma_c: cpasync.TmaInfo,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
tile_sched_params: Union[
|
||||
@@ -107,6 +123,11 @@ def kernel(
|
||||
utils.PersistentTileSchedulerParams,
|
||||
],
|
||||
):
|
||||
# Extract tma_tensor from TmaInfo
|
||||
mA_mkl = tma_a.tma_tensor
|
||||
mB_nkl = tma_b.tma_tensor
|
||||
mC_mnl = tma_c.tma_tensor
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
@@ -135,9 +156,9 @@ def kernel(
|
||||
|
||||
# Prefetch tma descriptor
|
||||
if warp_idx == tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
cpasync.prefetch_descriptor(tma_a.atom)
|
||||
cpasync.prefetch_descriptor(tma_b.atom)
|
||||
cpasync.prefetch_descriptor(tma_c.atom)
|
||||
|
||||
# As many participants as the number of threads issuing the MMA in the same row and column
|
||||
# Substract one to not count twice the same thread
|
||||
@@ -179,8 +200,8 @@ def kernel(
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
|
||||
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
|
||||
) * cute.size(cta_layout_vmnk, mode=[0])
|
||||
|
||||
# Threads/warps participating in the mainloop pipeline
|
||||
@@ -260,21 +281,21 @@ def kernel(
|
||||
# Allocate SMEM
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=a_smem_layout.outer,
|
||||
layout=tma_a.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout.inner,
|
||||
swizzle=tma_a.smem_layout.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=b_smem_layout.outer,
|
||||
layout=tma_b.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout.inner,
|
||||
swizzle=tma_b.smem_layout.inner,
|
||||
)
|
||||
sC = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=epi_smem_layout_staged.outer,
|
||||
layout=tma_c.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=epi_smem_layout_staged.inner,
|
||||
swizzle=tma_c.smem_layout.inner,
|
||||
)
|
||||
|
||||
# Partition tensors for MMA and make fragments
|
||||
@@ -313,7 +334,7 @@ def kernel(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestM, RestK)
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
cta_in_cluster_coord_vmnk[2],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
|
||||
cute.group_modes(sA, 0, 3),
|
||||
@@ -323,7 +344,7 @@ def kernel(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestN, RestK)
|
||||
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
cta_in_cluster_coord_vmnk[1],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
|
||||
cute.group_modes(sB, 0, 3),
|
||||
@@ -333,7 +354,7 @@ def kernel(
|
||||
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
|
||||
|
||||
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
@@ -396,8 +417,8 @@ def kernel(
|
||||
for pf_k_tile in cutlass.range(
|
||||
cutlass.min(prefetch_dist, num_k_tiles), unroll=1
|
||||
):
|
||||
cute.prefetch(tma_atom_a, tAgA_slice[(None, pf_k_tile)])
|
||||
cute.prefetch(tma_atom_b, tBgB_slice[(None, pf_k_tile)])
|
||||
cute.prefetch(tma_a.atom, tAgA_slice[(None, pf_k_tile)])
|
||||
cute.prefetch(tma_b.atom, tBgB_slice[(None, pf_k_tile)])
|
||||
|
||||
# =========================================================
|
||||
# TMA Load Loop with Rolling Prefetch
|
||||
@@ -408,14 +429,14 @@ def kernel(
|
||||
|
||||
# Issue TMA loads (use k_tile_idx like fp16_gemm_3_1.py)
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
tAgA_slice[(None, k_tile_idx)],
|
||||
tAsA[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_a,
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
tBgB_slice[(None, k_tile_idx)],
|
||||
tBsB[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
@@ -426,8 +447,8 @@ def kernel(
|
||||
# This keeps the L2 primed as we progress through the K dimension
|
||||
if k_tile_idx + prefetch_dist < num_k_tiles:
|
||||
future_k_tile = k_tile_idx + prefetch_dist
|
||||
cute.prefetch(tma_atom_a, tAgA_slice[(None, future_k_tile)])
|
||||
cute.prefetch(tma_atom_b, tBgB_slice[(None, future_k_tile)])
|
||||
cute.prefetch(tma_a.atom, tAgA_slice[(None, future_k_tile)])
|
||||
cute.prefetch(tma_b.atom, tBgB_slice[(None, future_k_tile)])
|
||||
|
||||
# Advance to next tile
|
||||
if cutlass.const_expr(use_clc_dynamic_scheduler):
|
||||
@@ -483,24 +504,14 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
@@ -533,10 +544,10 @@ def kernel(
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Initialize TMA store pipeline for epilogue
|
||||
# Initialize TMA store pipeline for epilogue with 4 warps
|
||||
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=128,
|
||||
pipeline.Agent.Warp,
|
||||
size=4,
|
||||
)
|
||||
epilogue_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=epi_stages,
|
||||
@@ -631,7 +642,7 @@ def kernel(
|
||||
# SMEM -> GMEM
|
||||
if warp_idx == epilogue_warp_ids[0]:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
tCsC[(None, c_buffer)],
|
||||
tCgC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
@@ -715,8 +726,8 @@ def host_function(
|
||||
mma_inst_shape_mnk,
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
|
||||
tcgen05.OperandSource.SMEM,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
@@ -754,20 +765,18 @@ def host_function(
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
)
|
||||
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
|
||||
tma_atom_a, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
tma_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_slice,
|
||||
a_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
)
|
||||
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
|
||||
tma_atom_b, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
tma_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
b,
|
||||
b_smem_layout_slice,
|
||||
b_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
@@ -793,11 +802,10 @@ def host_function(
|
||||
epi_stages,
|
||||
)
|
||||
|
||||
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
@@ -815,16 +823,10 @@ def host_function(
|
||||
|
||||
kernel(
|
||||
tiled_mma,
|
||||
tma_atom_a,
|
||||
a_tma_tensor,
|
||||
tma_atom_b,
|
||||
b_tma_tensor,
|
||||
tma_atom_c,
|
||||
c_tma_tensor,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
tma_a,
|
||||
tma_b,
|
||||
tma_c,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
cta_layout_vmnk,
|
||||
tile_sched_params,
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
# This is the sixth tutorial GEMM. It enables programmatic dependent launch (PDL) features.
|
||||
|
||||
@@ -57,7 +75,7 @@ For --mnk 256,8192,128, the speedup pdl v.s. no pdl can be up to 1.16x.
|
||||
|
||||
To run this example:
|
||||
.. code-block:: bash
|
||||
python examples/blackwell/tutorial_gemm/fp16_gemm_6.py \
|
||||
python examples/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_6.py \
|
||||
--mnk 256,8192,128
|
||||
|
||||
Constraints for this example:
|
||||
@@ -100,7 +118,11 @@ class SharedStorage:
|
||||
tmem_holding_buffer: cutlass.Int32
|
||||
# Only for CLC Dynamic Scheduler
|
||||
clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
clc_response: cute.struct.MemRange[cutlass.Int32, 4]
|
||||
clc_response_align_bytes = num_clc_response_bytes
|
||||
clc_response: cute.struct.Align[
|
||||
cute.struct.MemRange[cutlass.Int32, 4],
|
||||
clc_response_align_bytes,
|
||||
]
|
||||
|
||||
|
||||
@cute.kernel()
|
||||
@@ -133,16 +155,10 @@ def dequantize(
|
||||
@cute.kernel()
|
||||
def gemm(
|
||||
tiled_mma: cute.TiledMma,
|
||||
tma_atom_a: cute.CopyAtom,
|
||||
mA_mkl: cute.Tensor,
|
||||
tma_atom_b: cute.CopyAtom,
|
||||
mB_nkl: cute.Tensor,
|
||||
tma_atom_c: cute.CopyAtom,
|
||||
mC_mnl: cute.Tensor,
|
||||
a_smem_layout: cute.ComposedLayout,
|
||||
b_smem_layout: cute.ComposedLayout,
|
||||
tma_a: cpasync.TmaInfo,
|
||||
tma_b: cpasync.TmaInfo,
|
||||
tma_c: cpasync.TmaInfo,
|
||||
c_smem_layout_kind: cutlass.Constexpr,
|
||||
epi_smem_layout_staged: cute.ComposedLayout,
|
||||
epi_tile: cute.Tile,
|
||||
cta_layout_vmnk: cute.Layout,
|
||||
tile_sched_params: Union[
|
||||
@@ -150,6 +166,11 @@ def gemm(
|
||||
utils.PersistentTileSchedulerParams,
|
||||
],
|
||||
):
|
||||
# Extract tma_tensor from TmaInfo
|
||||
mA_mkl = tma_a.tma_tensor
|
||||
mB_nkl = tma_b.tma_tensor
|
||||
mC_mnl = tma_c.tma_tensor
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
@@ -178,9 +199,9 @@ def gemm(
|
||||
|
||||
# Prefetch tma descriptor
|
||||
if warp_idx == tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
cpasync.prefetch_descriptor(tma_atom_b)
|
||||
cpasync.prefetch_descriptor(tma_atom_c)
|
||||
cpasync.prefetch_descriptor(tma_a.atom)
|
||||
cpasync.prefetch_descriptor(tma_b.atom)
|
||||
cpasync.prefetch_descriptor(tma_c.atom)
|
||||
|
||||
# As many participants as the number of threads issuing the MMA in the same row and column
|
||||
# Substract one to not count twice the same thread
|
||||
@@ -222,8 +243,8 @@ def gemm(
|
||||
)
|
||||
|
||||
num_tma_copy_bytes = (
|
||||
cute.size_in_bytes(io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2]))
|
||||
cute.size_in_bytes(io_dtype, cute.select(tma_a.smem_layout, mode=[0, 1, 2]))
|
||||
+ cute.size_in_bytes(io_dtype, cute.select(tma_b.smem_layout, mode=[0, 1, 2]))
|
||||
) * cute.size(cta_layout_vmnk, mode=[0])
|
||||
|
||||
# Threads/warps participating in the mainloop pipeline
|
||||
@@ -303,21 +324,21 @@ def gemm(
|
||||
# Allocate SMEM
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=a_smem_layout.outer,
|
||||
layout=tma_a.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout.inner,
|
||||
swizzle=tma_a.smem_layout.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=b_smem_layout.outer,
|
||||
layout=tma_b.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout.inner,
|
||||
swizzle=tma_b.smem_layout.inner,
|
||||
)
|
||||
sC = smem.allocate_tensor(
|
||||
element_type=io_dtype,
|
||||
layout=epi_smem_layout_staged.outer,
|
||||
layout=tma_c.smem_layout.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=epi_smem_layout_staged.inner,
|
||||
swizzle=tma_c.smem_layout.inner,
|
||||
)
|
||||
|
||||
# Partition tensors for MMA and make fragments
|
||||
@@ -356,7 +377,7 @@ def gemm(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestM, RestK)
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
cta_in_cluster_coord_vmnk[2],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[2])),
|
||||
cute.group_modes(sA, 0, 3),
|
||||
@@ -366,7 +387,7 @@ def gemm(
|
||||
# ((atom_v, rest_v), STAGE)
|
||||
# ((atom_v, rest_v), RestN, RestK)
|
||||
tBsB, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
cta_in_cluster_coord_vmnk[1],
|
||||
cute.make_layout(cute.size(cta_layout_vmnk, mode=[1])),
|
||||
cute.group_modes(sB, 0, 3),
|
||||
@@ -376,7 +397,7 @@ def gemm(
|
||||
gC_epi = cute.flat_divide(tCgC[((None, None), 0, 0, None, None)], epi_tile)
|
||||
|
||||
tCsC, tCgC_tma = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
@@ -440,14 +461,14 @@ def gemm(
|
||||
|
||||
# Issue TMA loads
|
||||
cute.copy(
|
||||
tma_atom_a,
|
||||
tma_a.atom,
|
||||
tAgA_slice[(None, k_tile_idx)],
|
||||
tAsA[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
mcast_mask=tma_mcast_mask_a,
|
||||
)
|
||||
cute.copy(
|
||||
tma_atom_b,
|
||||
tma_b.atom,
|
||||
tBgB_slice[(None, k_tile_idx)],
|
||||
tBsB[(None, handle.index)],
|
||||
tma_bar_ptr=handle.barrier,
|
||||
@@ -508,23 +529,14 @@ def gemm(
|
||||
# (MMA, MMA_M, MMA_N)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, acc_empty.index)]
|
||||
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for k_tile_idx in range(num_k_tiles):
|
||||
# Wait for TMA copies to complete
|
||||
handle = ab_consumer.wait_and_advance()
|
||||
|
||||
# Execute one K-block worth of MMA instructions
|
||||
num_k_blocks = cute.size(tCrA, mode=[2])
|
||||
for k_block_idx in cutlass.range_constexpr(num_k_blocks):
|
||||
k_block_coord = (None, None, k_block_idx, handle.index)
|
||||
cute.gemm(
|
||||
tiled_mma,
|
||||
tCtAcc,
|
||||
tCrA[k_block_coord],
|
||||
tCrB[k_block_coord],
|
||||
tCtAcc,
|
||||
)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile_idx != 0)
|
||||
tile_crd = (None, None, None, handle.index)
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc)
|
||||
|
||||
# Signal that the A/B buffers have been consumed and are ready for the next load
|
||||
handle.release()
|
||||
@@ -557,10 +569,10 @@ def gemm(
|
||||
# (MMA, MMA_M, MMA_N, STAGE)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# Initialize TMA store pipeline for epilogue
|
||||
# Initialize TMA store pipeline for epilogue with 4 warps
|
||||
epilogue_pipeline_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread,
|
||||
size=128,
|
||||
pipeline.Agent.Warp,
|
||||
size=4,
|
||||
)
|
||||
epilogue_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=epi_stages,
|
||||
@@ -654,7 +666,7 @@ def gemm(
|
||||
# SMEM -> GMEM
|
||||
if warp_idx == epilogue_warp_ids[0]:
|
||||
cute.copy(
|
||||
tma_atom_c,
|
||||
tma_c.atom,
|
||||
tCsC[(None, c_buffer)],
|
||||
tCgC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
@@ -750,8 +762,8 @@ def host_function(
|
||||
mma_inst_shape_mnk,
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE,
|
||||
tcgen05.OperandSource.SMEM,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
tcgen05.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
)
|
||||
tiled_mma = cute.make_tiled_mma(op)
|
||||
|
||||
@@ -789,20 +801,18 @@ def host_function(
|
||||
op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp(
|
||||
tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
)
|
||||
a_smem_layout_slice = cute.slice_(a_smem_layout, (None, None, None, 0))
|
||||
tma_atom_a, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
tma_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
op,
|
||||
a,
|
||||
a_smem_layout_slice,
|
||||
a_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
)
|
||||
b_smem_layout_slice = cute.slice_(b_smem_layout, (None, None, None, 0))
|
||||
tma_atom_b, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
tma_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
op,
|
||||
b_dequantized,
|
||||
b_smem_layout_slice,
|
||||
b_smem_layout,
|
||||
mma_tiler_mnk,
|
||||
tiled_mma,
|
||||
cta_layout_vmnk.shape,
|
||||
@@ -828,11 +838,10 @@ def host_function(
|
||||
epi_stages,
|
||||
)
|
||||
|
||||
epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, c_tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
tma_c = cute.nvgpu.cpasync.make_tiled_tma_atom(
|
||||
cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(),
|
||||
c,
|
||||
epi_smem_layout,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
)
|
||||
|
||||
@@ -862,16 +871,10 @@ def host_function(
|
||||
|
||||
gemm(
|
||||
tiled_mma,
|
||||
tma_atom_a,
|
||||
a_tma_tensor,
|
||||
tma_atom_b,
|
||||
b_tma_tensor,
|
||||
tma_atom_c,
|
||||
c_tma_tensor,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
tma_a,
|
||||
tma_b,
|
||||
tma_c,
|
||||
c_smem_layout_kind,
|
||||
epi_smem_layout_staged,
|
||||
epi_tile,
|
||||
cta_layout_vmnk,
|
||||
tile_sched_params,
|
||||
|
||||
@@ -34,7 +34,7 @@ import cuda.bindings.driver as cuda
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
import cutlass.cute.testing as testing
|
||||
from cutlass import testing
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
@@ -42,6 +42,11 @@ import cutlass.utils.hopper_helpers as sm90_utils
|
||||
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||||
import cutlass.utils.blackwell_helpers as sm120_utils
|
||||
|
||||
# SM120 block-scaled GEMM dispatch helpers (sibling utility module). Try the
|
||||
# namespace-package import path first (used under pytest, where
|
||||
# `python/examples/CuTeDSL/cute` is on sys.path); fall back to the bare local
|
||||
# import for standalone-script invocation, where Python places only the
|
||||
# script's own directory on sys.path[0].
|
||||
try:
|
||||
from blackwell_geforce.kernel.blockscaled_gemm.blockscaled_gemm_dispatch import (
|
||||
FP4_SHIFT_BITS,
|
||||
@@ -49,8 +54,8 @@ try:
|
||||
make_sm120_blockscaled_mma_op,
|
||||
validate_blockscaled_args,
|
||||
)
|
||||
except ImportError:
|
||||
from blockscaled_gemm_dispatch import (
|
||||
except ImportError: # pragma: no cover - exercised only via standalone-script invocation
|
||||
from blockscaled_gemm_dispatch import ( # noqa: E402
|
||||
FP4_SHIFT_BITS,
|
||||
make_ldmatrix_atom,
|
||||
make_sm120_blockscaled_mma_op,
|
||||
@@ -705,9 +710,6 @@ class Sm120BlockScaledGemmKernel:
|
||||
tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0])
|
||||
tCrSFA = sm120_utils.partition_fragment_SFA(sSFA[None, None, 0], thr_mma, tidx)
|
||||
tCrSFB = sm120_utils.partition_fragment_SFB(sSFB[None, None, 0], thr_mma, tidx)
|
||||
# Keep residual K modes nested to match the C++ SM120 block-scaled mainloop.
|
||||
tCrSFA = cute.group_modes(tCrSFA, 2, cute.rank(tCrSFA))
|
||||
tCrSFB = cute.group_modes(tCrSFB, 2, cute.rank(tCrSFB))
|
||||
|
||||
tCgC = thr_mma.partition_C(gC_mnl)
|
||||
acc_shape = tCgC.shape[:3]
|
||||
@@ -879,8 +881,6 @@ class Sm120BlockScaledGemmKernel:
|
||||
tCsSFB_copy_view = thr_copy_ldmatrix_SFB.partition_S(sSFB)
|
||||
tCrSFB_copy_view = thr_copy_ldmatrix_SFB.retile(tCrSFB)
|
||||
|
||||
epi_buffer = cutlass.Int32(0)
|
||||
|
||||
if warp_group_idx == 1:
|
||||
tile_sched.advance_to_next_work()
|
||||
mainloop_consumer_state = self.advance(
|
||||
@@ -1051,7 +1051,6 @@ class Sm120BlockScaledGemmKernel:
|
||||
)
|
||||
|
||||
if k_block_idx == num_k_blocks - 1:
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
mainloop_pipeline.consumer_release(mainloop_consumer_state)
|
||||
mainloop_consumer_state.advance()
|
||||
|
||||
@@ -1210,8 +1209,7 @@ class Sm120BlockScaledGemmKernel:
|
||||
tRS_rD_out.store(acc_vec.to(self.c_dtype))
|
||||
|
||||
# Register to shared memory
|
||||
epi_buffer = epi_buffer + 1
|
||||
epi_buffer = epi_buffer % cute.size(
|
||||
epi_buffer = (epi_m * epi_rest_n + epi_n) % cute.size(
|
||||
tRS_sD, mode=[3]
|
||||
)
|
||||
self.epilog_sync_barrier.arrive_and_wait()
|
||||
@@ -1242,6 +1240,7 @@ class Sm120BlockScaledGemmKernel:
|
||||
tile_sched.advance_to_next_work()
|
||||
tile_sched.advance_to_next_work()
|
||||
work_tile = tile_sched.get_current_work()
|
||||
tma_store_pipeline.producer_tail()
|
||||
math_wg_order_state = math_wg_order_barrier.arrive(math_wg_order_state)
|
||||
# End of for k_tile loop
|
||||
# End of while loop
|
||||
@@ -1886,24 +1885,6 @@ def run_bs(
|
||||
cute.testing.convert(ref_f8, ref_tensor)
|
||||
ref = ref_device.cpu()
|
||||
torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02)
|
||||
elif c_dtype is cutlass.Float4E2M1FN:
|
||||
# Convert ref : f32 -> f4 -> f32
|
||||
ref_f4_ = torch.empty(*(l, m, n), dtype=torch.uint8, device="cuda").permute(
|
||||
1, 2, 0
|
||||
)
|
||||
ref_f4 = from_dlpack(ref_f4_, assumed_align=16).mark_layout_dynamic(
|
||||
leading_dim=1
|
||||
)
|
||||
ref_f4.element_type = c_dtype
|
||||
ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda()
|
||||
ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic(
|
||||
leading_dim=1
|
||||
)
|
||||
cute.testing.convert(ref_tensor, ref_f4)
|
||||
cute.testing.convert(ref_f4, ref_tensor)
|
||||
ref = ref_device.cpu()
|
||||
torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02)
|
||||
|
||||
def generate_tensors():
|
||||
a_tensor, _ = cutlass_torch.cute_tensor_like(
|
||||
a_ref, a_dtype, is_dynamic_layout=True, assumed_align=16
|
||||
@@ -1933,7 +1914,7 @@ def run_bs(
|
||||
|
||||
_, sfa_tensor, _ = create_scale_factor_tensor(l, m, k, sf_vec_size, sf_dtype)
|
||||
_, sfb_tensor, _ = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype)
|
||||
return cute.testing.JitArguments(
|
||||
return cutlass.testing.JitArguments(
|
||||
a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, stream
|
||||
)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from functools import partial\n",
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"import cutlass\n",
|
||||
@@ -332,8 +331,8 @@
|
||||
"\n",
|
||||
" # Print results\n",
|
||||
" # ------------\n",
|
||||
" print(f\"Performance Metrics:\")\n",
|
||||
" print(f\"-------------------\")\n",
|
||||
" print(\"Performance Metrics:\")\n",
|
||||
" print(\"-------------------\")\n",
|
||||
" print(f\"Kernel execution time: {avg_time_us:.4f} us\")\n",
|
||||
" print(f\"Memory throughput: {achieved_bandwidth:.2f} GB/s\")"
|
||||
]
|
||||
@@ -1082,7 +1081,7 @@
|
||||
" ###############################################################################\n",
|
||||
" # Compute predicate for out of boundary checks\n",
|
||||
" ###############################################################################\n",
|
||||
" frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean)\n",
|
||||
" frgPred = cute.make_rmem_tensor(thrCrd.shape, cutlass.Boolean)\n",
|
||||
" print(f\"[DSL INFO] frgPred = {frgPred.type}\")\n",
|
||||
"\n",
|
||||
" for i in cutlass.range_constexpr(cute.size(frgPred)):\n",
|
||||
|
||||
106
examples/python/CuTeDSL/dsl_tutorials/README.md
Normal file
106
examples/python/CuTeDSL/dsl_tutorials/README.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# DSL Feature Examples
|
||||
|
||||
This directory demonstrates **CuTe DSL capabilities** beyond kernel authoring itself:
|
||||
exporting compiled kernels for deployment, integrating with ML frameworks, using
|
||||
foreign function interfaces, and accessing low-level DSL features like inline PTX
|
||||
and shared memory allocation.
|
||||
|
||||
---
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
dsl/
|
||||
export/ Exporting kernels to C shared libraries
|
||||
export_to_c.py Compile a kernel and export as .so/.dylib
|
||||
load_in_python.py Load and call the exported library from Python
|
||||
run_with_dynamic_loading.cpp C++ driver using dlopen
|
||||
run_with_dynamic_loading.sh Build/run script for dynamic loading
|
||||
run_with_static_linking.cpp C++ driver using static linking
|
||||
run_with_static_linking.sh Build/run script for static linking
|
||||
ffi/ Foreign function interface
|
||||
jit_argument.py JIT compilation with argument passing
|
||||
tensor.cpp C++ tensor interop implementation
|
||||
CMakeLists.txt CMake build for FFI examples
|
||||
jax/ JAX integration
|
||||
cutlass_call_basic.py Basic CUTLASS kernel call from JAX
|
||||
cutlass_call_export.py Export a CUTLASS kernel for JAX
|
||||
cutlass_call_sharding.py Multi-device sharding with CUTLASS kernels
|
||||
elementwise_apply_example.py Elementwise apply via JAX
|
||||
tvm_ffi/ TVM FFI integration
|
||||
jit_and_use_in_torch.py JIT compile and call from PyTorch
|
||||
jit_and_use_in_jax.py JIT compile and call from JAX
|
||||
aot_export.py Ahead-of-time export
|
||||
aot_use_in_torch.py Use AOT-exported kernel in PyTorch
|
||||
aot_use_in_jax.py Use AOT-exported kernel in JAX
|
||||
aot_use_in_cpp_bundle.cpp Use AOT-exported kernel in C++
|
||||
aot_use_in_cpp_bundle.sh Build/run script for C++ AOT usage
|
||||
compile_with_fake_tensor.py Compile using fake tensors
|
||||
compile_with_symint_arg.py Compile with symbolic integer arguments
|
||||
ampere_gemm_with_fake_tensor.py Ampere GEMM with fake tensor compilation
|
||||
error_reporting.py Error reporting and diagnostics
|
||||
call_bypass_dlpack.py Calling kernels bypassing DLPack
|
||||
call_from_jit.py Calling conventions from JIT-compiled code
|
||||
cooperative_launch.py Cooperative kernel launch (multi-CTA)
|
||||
dynamic_smem_size.py Dynamic shared memory allocation
|
||||
inline_ptx.py Embedding inline PTX assembly
|
||||
launch_completion_and_programmatic_events.py
|
||||
Launch completion / programmatic events with cudaEvent_t and CUevent
|
||||
pointer.py Pointer manipulation in DSL
|
||||
print_latex.py LaTeX rendering of CuTe layouts
|
||||
programmatic_dependent_launch.py Programmatic dependent launch (PDL)
|
||||
smem_allocator.py Shared memory allocator usage
|
||||
torch_fake_tensor.py PyTorch fake tensor integration
|
||||
torch_fp4.py PyTorch FP4 tensor support
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Subdirectory Guides
|
||||
|
||||
### `export/` -- Kernel Export
|
||||
|
||||
Shows how to compile a CuTe DSL kernel into a standalone C shared library (`.so`)
|
||||
that can be loaded and called from C++ or Python without any CuTe DSL dependency
|
||||
at runtime. Includes complete examples for both dynamic loading (`dlopen`) and
|
||||
static linking workflows.
|
||||
|
||||
### `ffi/` -- Foreign Function Interface
|
||||
|
||||
Demonstrates how to pass arguments between Python/CuTe DSL and C++ code using
|
||||
the FFI layer. Useful for integrating CuTe DSL kernels into existing C++
|
||||
applications.
|
||||
|
||||
### `jax/` -- JAX Integration
|
||||
|
||||
Shows how to call CuTe DSL kernels from JAX using `cutlass_call`, including
|
||||
basic invocation, kernel export for JAX, multi-device sharding, and elementwise
|
||||
application patterns.
|
||||
|
||||
### `tvm_ffi/` -- TVM FFI Integration
|
||||
|
||||
Comprehensive examples for using CuTe DSL kernels through TVM's foreign function
|
||||
interface. Covers both JIT and AOT (ahead-of-time) compilation workflows, with
|
||||
usage examples for PyTorch, JAX, and C++. Also demonstrates fake-tensor
|
||||
compilation (no GPU required at compile time) and symbolic integer arguments.
|
||||
|
||||
---
|
||||
|
||||
## Top-Level Files
|
||||
|
||||
The top-level Python files demonstrate individual DSL features:
|
||||
|
||||
- **`call_bypass_dlpack.py`** / **`call_from_jit.py`** -- Kernel calling conventions
|
||||
- **`inline_ptx.py`** -- Embedding inline PTX assembly in CuTe DSL kernels
|
||||
- **`launch_completion_and_programmatic_events.py`** -- Examples of
|
||||
``launch_completion_event`` and ``programmatic_event`` launch attributes,
|
||||
using events created via ``torch.cuda.Event(enable_timing=False)`` and
|
||||
presented as either ``cudaEvent_t`` (`cuda.bindings.runtime`) or ``CUevent`` (`cuda.bindings.driver`). The
|
||||
stream is passed as a ``cudaStream_t`` (`cuda.bindings.runtime`)
|
||||
- **`programmatic_dependent_launch.py`** -- Programmatic dependent launch for
|
||||
chaining kernels with data dependencies
|
||||
- **`cooperative_launch.py`** -- Cooperative launch for multi-CTA kernels
|
||||
- **`dynamic_smem_size.py`** / **`smem_allocator.py`** -- Shared memory allocation
|
||||
- **`torch_fake_tensor.py`** / **`torch_fp4.py`** -- PyTorch integration features
|
||||
- **`pointer.py`** -- Pointer manipulation within DSL kernels
|
||||
- **`print_latex.py`** -- Render CuTe layouts as LaTeX for visualization
|
||||
@@ -0,0 +1,367 @@
|
||||
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Launch Completion Events and Programmatic Events Example
|
||||
=======================================================
|
||||
|
||||
This module demonstrates the two CUDA kernel-launch attributes that record a
|
||||
``cudaEvent_t`` / ``CUevent`` as part of a launch:
|
||||
|
||||
1. ``cudaLaunchAttributeLaunchCompletionEvent``
|
||||
The event is recorded when all blocks of the grid have begun executing
|
||||
(best-effort, used for launch ordering). This attribute is
|
||||
usable on **any** compute capability supported by CuTeDSL.
|
||||
|
||||
2. ``cudaLaunchAttributeProgrammaticEvent``
|
||||
Part of the Programmatic Dependent Launch (PDL) model. The event is recorded
|
||||
either:
|
||||
|
||||
* after **every block** in the grid has called
|
||||
``cute.arch.griddepcontrol_launch_dependents()`` (or terminated) - this is
|
||||
selected with ``trigger_at_block_start=0``. The kernel must call the trigger
|
||||
itself.
|
||||
* automatically at each block start - selected with ``trigger_at_block_start=1``.
|
||||
The timing is similar to the launch-completion event, but the resulting
|
||||
event remains part of the programmatic dependency model and is visible to
|
||||
the next kernel's ``cute.arch.griddepcontrol_wait()``.
|
||||
|
||||
``programmatic_event`` requires sm_90 (Hopper) or newer.
|
||||
|
||||
The example demonstrates each attribute by launching a kernel with the attribute
|
||||
attached, passing the stream as a ``cudaStream_t`` (runtime bindings) and the event
|
||||
either as a ``cudaEvent_t`` (runtime) or ``CUevent`` (driver). The events
|
||||
themselves are created from PyTorch with ``torch.cuda.Event(enable_timing=False)``
|
||||
so they carry ``cudaEventDisableTiming``, which is required by both launch
|
||||
attributes.
|
||||
|
||||
Usage::
|
||||
|
||||
python launch_completion_and_programmatic_events.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import Literal, Tuple, Union
|
||||
|
||||
import cuda.bindings.driver as cuda_driver
|
||||
import cuda.bindings.runtime as cuda_runtime
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Feature gate - queried via torch
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def supports_programmatic_event() -> bool:
|
||||
"""``programmatic_event`` is part of the programmatic dependency model and requires Hopper (sm_90+)."""
|
||||
return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kernels
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def simple_kernel(out: cute.Tensor, value: cutlass.Int32):
|
||||
"""Write ``value`` into each element of ``out``.
|
||||
|
||||
Used to demonstrate ``launch_completion_event``: the event fires
|
||||
automatically when all blocks have begun executing.
|
||||
"""
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
if tidx < cute.size(out, [0]) and bidx < cute.size(out, [1]):
|
||||
out[tidx, bidx] = value
|
||||
|
||||
|
||||
@cute.kernel
|
||||
def programmatic_trigger_kernel(
|
||||
out: cute.Tensor,
|
||||
value: cutlass.Int32,
|
||||
trigger_at_block_start: cutlass.Constexpr[bool],
|
||||
):
|
||||
"""Write ``value`` into each element of ``out``, then trigger the
|
||||
programmatic launch-completion signal.
|
||||
|
||||
Used to demonstrate ``programmatic_event``.
|
||||
|
||||
With ``trigger_at_block_start=False``, every block must execute
|
||||
``cute.arch.griddepcontrol_launch_dependents()`` (from the DSL:
|
||||
``cute.arch.griddepcontrol_launch_dependents()``) for the event to fire.
|
||||
With ``trigger_at_block_start=True`` the event fires automatically at the
|
||||
block start.
|
||||
"""
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
bidx, _, _ = cute.arch.block_idx()
|
||||
|
||||
cute.arch.griddepcontrol_wait()
|
||||
|
||||
if cutlass.const_expr(not trigger_at_block_start):
|
||||
cute.arch.griddepcontrol_launch_dependents()
|
||||
|
||||
if tidx < cute.size(out, [0]) and bidx < cute.size(out, [1]):
|
||||
out[tidx, bidx] = value
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JIT host functions - each exercises a single launch attribute
|
||||
# =============================================================================
|
||||
|
||||
|
||||
THREADS_PER_BLOCK = 128
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_with_launch_completion_event(
|
||||
out: cute.Tensor,
|
||||
value: cutlass.Int32,
|
||||
stream: cuda_runtime.cudaStream_t,
|
||||
launch_completion_event: Union[cuda_runtime.cudaEvent_t, cuda_driver.CUevent],
|
||||
):
|
||||
"""Launch ``simple_kernel`` with ``launch_completion_event=launch_completion_event``."""
|
||||
simple_kernel(out, value).launch(
|
||||
grid=(cute.size(out, [1]), 1, 1),
|
||||
block=(cute.size(out, [0]), 1, 1),
|
||||
stream=stream,
|
||||
launch_completion_event=launch_completion_event,
|
||||
launch_completion_event_flags=0, # Optional flags
|
||||
)
|
||||
|
||||
|
||||
@cute.jit
|
||||
def launch_with_programmatic_event(
|
||||
out: cute.Tensor,
|
||||
value: cutlass.Int32,
|
||||
stream: cuda_runtime.cudaStream_t,
|
||||
programmatic_event: Union[cuda_runtime.cudaEvent_t, cuda_driver.CUevent],
|
||||
trigger_at_block_start: cutlass.Constexpr[int] = 0,
|
||||
):
|
||||
"""Launch ``programmatic_trigger_kernel`` with ``programmatic_event=programmatic_event``."""
|
||||
programmatic_trigger_kernel(out, value, trigger_at_block_start == 1).launch(
|
||||
grid=(cute.size(out, [1]), 1, 1),
|
||||
block=(cute.size(out, [0]), 1, 1),
|
||||
stream=stream,
|
||||
programmatic_event=programmatic_event,
|
||||
programmatic_event_flags=0, # Optional flags
|
||||
programmatic_event_trigger_at_block_start=trigger_at_block_start, # Defaults to zero
|
||||
use_pdl=True,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helpers for event creation and synchronization
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _make_event(
|
||||
kind: Literal["runtime", "driver"], init_stream: torch.cuda.Stream
|
||||
) -> Tuple[torch.cuda.Event, Union[cuda_runtime.cudaEvent_t, cuda_driver.CUevent]]:
|
||||
"""Create a CUDA event from torch with timing disabled and wrap it as the
|
||||
requested low-level API type.
|
||||
|
||||
Using ``torch.cuda.Event(enable_timing=False)`` guarantees the underlying
|
||||
CUDA event is created with ``cudaEventDisableTiming``, which is required
|
||||
for ``launch_completion_event`` and ``programmatic_event`` launch attributes.
|
||||
|
||||
PyTorch lazily initializes the underlying ``cudaEvent_t`` on the first
|
||||
``record()`` call. We force initialization by recording the event.
|
||||
|
||||
Returns ``(torch_event, cuda_event)``. The torch event
|
||||
must be kept alive for the lifetime of the wrapped cuda event. Torch will
|
||||
destroy the event when ``torch_event`` is garbage-collected.
|
||||
"""
|
||||
torch_event = torch.cuda.Event(enable_timing=False)
|
||||
torch_event.record(init_stream)
|
||||
raw_event = int(torch_event.cuda_event)
|
||||
if raw_event == 0:
|
||||
raise RuntimeError("torch.cuda.Event was not created")
|
||||
if kind == "runtime":
|
||||
return torch_event, cuda_runtime.cudaEvent_t(raw_event)
|
||||
elif kind == "driver":
|
||||
return torch_event, cuda_driver.CUevent(raw_event)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown event kind: {kind!r}; expected 'runtime' or 'driver'"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Run functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def run_launch_completion_event_example(
|
||||
n_elements_per_block: int = 128,
|
||||
n_blocks: int = 48,
|
||||
launch_completion_event_kind: Literal["runtime", "driver"] = "runtime",
|
||||
) -> None:
|
||||
"""Run the ``launch_completion_event`` demonstration.
|
||||
|
||||
Allocates an output tensor, launches ``simple_kernel`` with
|
||||
``launch_completion_event`` attached, then blocks on the event using the
|
||||
matching API and verifies the output.
|
||||
"""
|
||||
print(
|
||||
f"\n[Launch Completion Event] Running launch_completion_event example "
|
||||
f"(launch_completion_event_kind={launch_completion_event_kind!r}, "
|
||||
f"n_elements_per_block={n_elements_per_block}, n_blocks={n_blocks})"
|
||||
)
|
||||
|
||||
out = torch.full(
|
||||
(n_elements_per_block, n_blocks), -1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
expected = torch.full(
|
||||
(n_elements_per_block, n_blocks), 0, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
torch_stream = torch.cuda.default_stream()
|
||||
stream = cuda_runtime.cudaStream_t(torch_stream.cuda_stream)
|
||||
|
||||
torch_event, launch_completion_event = _make_event(
|
||||
launch_completion_event_kind, torch_stream
|
||||
)
|
||||
|
||||
out_tensor = from_dlpack(out)
|
||||
|
||||
launch_with_launch_completion_event(out_tensor, 0, stream, launch_completion_event)
|
||||
|
||||
torch_event.synchronize()
|
||||
|
||||
torch.testing.assert_close(out, expected)
|
||||
print(
|
||||
f"[Launch Completion Event] {launch_completion_event_kind} event fired and output verified."
|
||||
)
|
||||
|
||||
|
||||
def run_programmatic_event_example(
|
||||
n_elements_per_block: int = 128,
|
||||
n_blocks: int = 48,
|
||||
programmatic_event_kind: Literal["runtime", "driver"] = "runtime",
|
||||
trigger_at_block_start: int = 0,
|
||||
) -> None:
|
||||
"""Run the ``programmatic_event`` demonstration.
|
||||
|
||||
Allocates an output tensor, launches ``programmatic_trigger_kernel`` with
|
||||
``programmatic_event`` attached, then blocks on the event using the matching
|
||||
API and verifies the output.
|
||||
"""
|
||||
if not supports_programmatic_event():
|
||||
raise RuntimeError("programmatic_event requires Hopper (sm_90) or newer")
|
||||
|
||||
print(
|
||||
f"\n[Programmatic Event] Running programmatic_event example "
|
||||
f"(programmatic_event_kind={programmatic_event_kind!r}, "
|
||||
f"n_elements_per_block={n_elements_per_block}, n_blocks={n_blocks}, "
|
||||
f"trigger_at_block_start={trigger_at_block_start})"
|
||||
)
|
||||
|
||||
out = torch.full(
|
||||
(n_elements_per_block, n_blocks), -1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
expected = torch.full(
|
||||
(n_elements_per_block, n_blocks), 1, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
torch_stream = torch.cuda.default_stream()
|
||||
stream = cuda_runtime.cudaStream_t(torch_stream.cuda_stream)
|
||||
|
||||
out_tensor = from_dlpack(out)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
torch_event = None
|
||||
|
||||
with torch.cuda.graph(graph):
|
||||
stream = cuda_runtime.cudaStream_t(torch.cuda.current_stream().cuda_stream)
|
||||
torch_event, programmatic_event = _make_event(
|
||||
programmatic_event_kind, torch_stream
|
||||
)
|
||||
launch_with_programmatic_event(
|
||||
out_tensor,
|
||||
0,
|
||||
stream,
|
||||
programmatic_event,
|
||||
trigger_at_block_start=trigger_at_block_start,
|
||||
)
|
||||
|
||||
# Overlaps with the prior launch after every CTA has been launched
|
||||
launch_with_programmatic_event(
|
||||
out_tensor,
|
||||
1,
|
||||
stream,
|
||||
programmatic_event,
|
||||
trigger_at_block_start=trigger_at_block_start,
|
||||
)
|
||||
|
||||
graph.replay()
|
||||
assert torch_event is not None, "torch.cuda.Event was not created"
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(out, expected)
|
||||
print(
|
||||
f"[Programmatic Event] {programmatic_event_kind} event fired and output verified."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Demonstrate launch_completion_event and programmatic_event launch "
|
||||
"attributes with both cudaEvent_t and CUevent."
|
||||
)
|
||||
)
|
||||
parser.add_argument("--n-blocks", default=48, type=int)
|
||||
parser.add_argument("--n-elements-per-block", default=128, type=int)
|
||||
parser.add_argument("--use-driver-api", action="store_true")
|
||||
parser.add_argument("--trigger-at-block-start", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
run_launch_completion_event_example(
|
||||
n_elements_per_block=args.n_elements_per_block,
|
||||
n_blocks=args.n_blocks,
|
||||
launch_completion_event_kind="driver" if args.use_driver_api else "runtime",
|
||||
)
|
||||
|
||||
if supports_programmatic_event():
|
||||
run_programmatic_event_example(
|
||||
n_elements_per_block=args.n_elements_per_block,
|
||||
n_blocks=args.n_blocks,
|
||||
programmatic_event_kind="driver" if args.use_driver_api else "runtime",
|
||||
trigger_at_block_start=1 if args.trigger_at_block_start else 0,
|
||||
)
|
||||
else:
|
||||
print("\nSkipping programmatic_event: requires Hopper (sm_90) or newer.")
|
||||
|
||||
print("\nPASS")
|
||||
Reference in New Issue
Block a user