v4.5 tag update (#3202)

* Python DSL examples reorganization.

* v4.5 tag update.
This commit is contained in:
Junkai-Wu
2026-05-06 08:55:27 +08:00
committed by GitHub
parent f74fea9ce3
commit cb37157db5
351 changed files with 36688 additions and 8117 deletions

View File

@@ -2,18 +2,59 @@
# CUTLASS 4.x
## [4.5.0](https://github.com/NVIDIA/cutlass/tree/main) (2026-03-27)
## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01)
### CuTe DSL
* New features
- New Block API `block_copy()` to simplify TMA and S2T copies. With `block_copy()`, users can ignore the details of multicast and 2-CTA partitioning for TMA and need not invoke `tma_partition()`; the bulk of S2T initialization can likewise be removed.
- MXF8F6F4 mixed precision support
- BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6
- Block Scaled MMA for SM120 now works on Spark
- EFC broadcast semantics support
- EFC epilogue functions can now broadcast and remap tensor modes via `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). Covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose); a sketch of the subscript forms appears after this section. The PyTorch reference evaluator mirrors the same transformations.
- Initial linter support: Improved type hints on CuTe DSL APIs to support static type checkers like MyPy
- `dataclasses.dataclass` is now supported for JIT compilation and `cute.compile`, for both the plain and tvm-ffi paths (see the sketch after this section)
- `cute.copy` now supports user-specified loop unrolling
* Bug fixes and improvements
- Improved source code correlation for profiling/debugging
- Fixed an aarch64 segfault issue with tvm-ffi
- Reorganized CuTe DSL examples/tutorials for better discoverability
* More examples of authoring peak-performance kernels
- MoE examples
- A new style of grouped GEMM that aligns with torch's grouped_mm and scaled_grouped_mm interfaces.
- Expert-wise tensormap descriptors are set up by a cheap helper kernel (~2us) to avoid long tile-switching latency, so the kernel structure is much closer to a normal GEMM.
- Compared to torch_210_cu13, very few problems show worse perf on B200.
- mxfp8_2dx3d: avg 1.29x speedup
- mxfp8_2dx2d: avg 1.41x speedup
- nvfp4_2dx3d: avg 1.11x speedup
- nvfp4_2dx2d: avg 1.12x speedup (worst case 0.98x)
- bf16_2dx3d: avg 1.15x speedup (worst case 0.98x)
- bf16_2dx2d: avg 1.17x speedup (worst case 0.96x)
- Note: Perf is measured with the torch profiler; this implementation's time includes the helper kernel plus the main kernel, while torch's includes its setup kernel and the cutlass_cpp main kernel.
* API changes
- `ab_dtype` is deprecated in `make_trivial_tiled_mma` and `make_blockscaled_trivial_tiled_mma` from `blackwell_helpers.py`. Please specify `a_dtype` and `b_dtype` separately instead (see the migration sketch below).
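
A minimal migration sketch for the `ab_dtype` deprecation above. The a_dtype/b_dtype split and the argument order mirror the updated call sites elsewhere in this commit; the import alias, concrete dtypes, and tile shape are illustrative assumptions.

```python
# Hedged migration sketch: the a_dtype/b_dtype split comes from the release
# note above and the updated call sites in this commit; the import alias,
# dtypes, and tile shape below are illustrative assumptions.
import cutlass
import cutlass.utils.blackwell_helpers as sm100_utils  # assumed import path
from cutlass.cute.nvgpu import tcgen05, OperandMajorMode

tiled_mma = sm100_utils.make_trivial_tiled_mma(
    cutlass.Float8E4M3FN,  # a_dtype (was the single ab_dtype)
    cutlass.Float8E4M3FN,  # b_dtype (new: may now differ from a_dtype)
    OperandMajorMode.K,    # A major mode
    OperandMajorMode.K,    # B major mode
    cutlass.Float32,       # accumulator dtype
    tcgen05.CtaGroup.ONE,  # CTA group (1SM)
    (128, 64),             # MMA tile (M, N); any further arguments are unchanged
)
```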
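A minimal sketch of the `remap_modes` subscript forms from the EFC bullet above. Only the `[:, 0, 1]` form is quoted from the note; the other subscripts and the operand name `C` are illustrative assumptions, and the surrounding EFC epilogue definition is omitted.

```python
# Hedged sketch of the remap_modes subscript syntax, where ':' marks a
# broadcast dimension and integers select source mode indices. `C` stands for
# an EFC epilogue operand; the epilogue plumbing around it is omitted.
C_bcast = C.remap_modes[:, 0, 1]  # add a broadcast leading mode, keep source modes 0 and 1
C_col = C.remap_modes[0, :]       # broadcast a 1-mode source along the second output mode
C_t = C.remap_modes[1, 0]         # arbitrary mode permutation, e.g. a transpose
```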
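A minimal sketch of passing a `dataclasses.dataclass` through the JIT path. Only the dataclass support itself comes from the note above; the class, field names, and function body are illustrative assumptions.

```python
# Hedged sketch: only the fact that dataclasses are accepted by @cute.jit and
# cute.compile comes from the release notes; everything else is illustrative.
import dataclasses

import cutlass
import cutlass.cute as cute


@dataclasses.dataclass
class TileConfig:
    tile_m: int
    tile_n: int


@cute.jit
def tile_elems(cfg: TileConfig):
    # Dataclass fields are usable like any other JIT argument.
    return cfg.tile_m * cfg.tile_n


# Trace/compile with a concrete dataclass instance (the tvm-ffi path is also
# supported per the note above).
compiled = cute.compile(tile_elems, TileConfig(tile_m=128, tile_n=64))
```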
### CUTLASS C++
* Add 2SM MMA instruction support to mixed TMA+CpAsync SM100 vanilla GEMM kernels.
- Mixed TMA+CpAsync can now accept static but non-trivial cluster shapes.
- Uses TMA multicast for A tile when using non-trivial cluster size along N mode.
- Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs.
- Changes included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm).
* Add support for 128x32xK and 128x64xK tile sizes for SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement on Blackwell SM121 related kernels.
* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation.
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) demonstrating green context SM partitioning
- Enables launching GEMM on a stream with a partial SM allocation.
* Fix some kernel issues:
- Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates
- Fix CUTLASS clang build issues
- Fix atomicCAS read-modify-write loop in `ConstSubbyteReference`
- Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference
- Remove `PipelineStorage` shadowing in SM100 complex epilogue
- Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized
* Fix some profiler issues:
- Add missing reference kernels for blockwise GEMM profiler
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!

View File

@@ -30,5 +30,5 @@ Certain files within this repository are subject to separate licensing terms:
- The files located in the `python/CuTeDSL` directory are licensed under the
NVIDIA End User License Agreement (EULA). Please refer to
https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
for the full terms.

View File

@@ -3,7 +3,7 @@
# CUTLASS 4.5.0
_CUTLASS 4.5.0 - March 2026_
_CUTLASS 4.5.0 - May 2026_
CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for
@@ -45,16 +45,57 @@ To get started quickly - please refer :
# What's New in CUTLASS 4.5
### CuTe DSL
## CuTe DSL
* New features
- New Block API `block_copy()` to simplify TMA and S2T copies. With `block_copy()`, users can ignore the details of multicast and 2-CTA partitioning for TMA and need not invoke `tma_partition()`; the bulk of S2T initialization can likewise be removed.
- MXF8F6F4 mixed precision support
- BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6
- Block Scaled MMA for SM120 now works on Spark
- EFC broadcast semantics support
- EFC epilogue functions can now broadcast and remap tensor modes via `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). Covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations.
- Initial linter support: Improved type hints on CuTe DSL APIs to support static type checkers like MyPy
- `dataclasses.dataclass` is now supported for JIT compilation and `cute.compile`, for both the plain and tvm-ffi paths
- `cute.copy` now supports user-specified loop unrolling
* Bug fixes and improvements
- Improved source code correlation for profiling/debugging
- Fixed an aarch64 segfault issue with tvm-ffi
- Reorganized CuTe DSL examples/tutorials for better discoverability
### CUTLASS C++
* More examples of authoring peak-performance kernels
- MoE examples
- A new style of grouped GEMM that aligns with torch's grouped_mm and scaled_grouped_mm interfaces.
- Expert-wise tensormap descriptors are set up by a cheap helper kernel (~2us) to avoid long tile-switching latency, so the kernel structure is much closer to a normal GEMM.
- Compared to torch_210_cu13, very few problems show worse perf on B200.
- mxfp8_2dx3d: avg 1.29x speedup
- mxfp8_2dx2d: avg 1.41x speedup
- nvfp4_2dx3d: avg 1.11x speedup
- nvfp4_2dx2d: avg 1.12x speedup (worst case 0.98x)
- bf16_2dx3d: avg 1.15x speedup (worst case 0.98x)
- bf16_2dx2d: avg 1.17x speedup (worst case 0.96x)
- Note: Perf is measured with the torch profiler; this implementation's time includes the helper kernel plus the main kernel, while torch's includes its setup kernel and the cutlass_cpp main kernel.
* API changes
- `ab_dtype` is deprecated in `make_trivial_tiled_mma` and `make_blockscaled_trivial_tiled_mma` from `blackwell_helpers.py`. Please specify `a_dtype` and `b_dtype` separately instead.
## CUTLASS C++
* Add 2SM MMA instruction support to mixed TMA+CpAsync SM100 vanilla GEMM kernels.
- Mixed TMA+CpAsync can now accept static but non-trivial cluster shapes.
- Uses TMA multicast for A tile when using non-trivial cluster size along N mode.
- Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs.
- Changes included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm).
* Add support for 128x32xK and 128x64xK tile sizes for SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement on Blackwell SM121 related kernels.
* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation.
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) demonstrating green context SM partitioning
- Enables launching GEMM on a stream with a partial SM allocation.
* Fix some kernel issues:
- Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates
- Fix CUTLASS clang build issues
- Fix atomicCAS read-modify-write loop in `ConstSubbyteReference`
- Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference
- Remove `PipelineStorage` shadowing in SM100 complex epilogue
- Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized
* Fix some profiler issues:
- Add missing reference kernels for blockwise GEMM profiler
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!

View File

@@ -48,7 +48,7 @@ foreach(FUSION_CONV_EXAMPLE
fused_two_convs_f16_sm80_shmem
fused_two_convs_s8_sm75_rf
fused_two_convs_s8_sm75_shmem
fused_two_convs_s8_sm80_rf
# fused_two_convs_s8_sm80_rf # disabled: fails to build
fused_two_convs_s8_sm80_shmem
)

View File

@@ -195,7 +195,7 @@ public:
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info);
return implementable;
}

View File

@@ -216,7 +216,7 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
int cluster_m, cluster_n;
using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
DecompositionMode decomposition_mode;
@@ -240,10 +240,8 @@ struct Options {
m(256), n(256), k(16384),
alpha(1.f), beta(0.f),
iterations(10),
preferred_cluster_m(4),
preferred_cluster_n(4),
fallback_cluster_m(2),
fallback_cluster_n(1),
cluster_m(2),
cluster_n(1),
decomposition_mode(DecompositionMode::Heuristic),
reduction_mode(ReductionMode::Deterministic),
splits(1)
@@ -265,10 +263,8 @@ struct Options {
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("splits", splits, 1);
cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
cmd.get_cmd_line_argument("cluster_m", cluster_m, 2);
cmd.get_cmd_line_argument("cluster_n", cluster_n, 1);
// Parse decompsition mode
std::string decomp_mode;
@@ -303,10 +299,8 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
<< " --fallback_cluster_n=<str> Sets the N extent of fallback cluster shape\n"
<< " --cluster_m=<str> Sets the M extent of the cluster shape\n"
<< " --cluster_n=<str> Sets the N extent of the cluster shape\n"
<< " --decomposition=<str> Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n"
<< " --reduction=<str> Mode in which the stream-K kernel's reduction should be performed. Options: Deterministic (default), Nondeterministic\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
@@ -424,8 +418,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
arguments.hw_info.cluster_shape = dim3(options.cluster_m, options.cluster_n, 1);
arguments.hw_info.cluster_shape_fallback = dim3(options.cluster_m, options.cluster_n, 1);
arguments.scheduler.splits = options.splits;
arguments.scheduler.decomposition_mode = options.decomposition_mode;
@@ -498,8 +492,7 @@ int run(Options &options) {
std::cout << "Stream-K GEMM with"
<< " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
<< " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
<< " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n"
<< " Cluster = (" << options.cluster_m << ", " << options.cluster_n << ", 1)\n"
<< " Decomposition_mode=" << options.decomposition_mode_str()
<< " Split_count=" << options.splits
<< " Reduction_mode=" << options.reduction_mode_str()

View File

@@ -536,7 +536,11 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x;
#else
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
#endif
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
@@ -573,6 +577,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}
ElementQK old_row_max = row_max;
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
auto pos = tTMEM_LOADcS(0);
if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) {
float curr_max = tiled_tmem_load.get_max();
row_max = ::fmax(row_max, curr_max);
}
else
#endif
{
// compute rowmax
float row_max_0 = row_max;

View File

@@ -540,7 +540,11 @@ struct Sm100FmhaGenMainloopWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x;
#else
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
#endif
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
@@ -577,6 +581,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
}
ElementQK old_row_max = row_max;
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
auto pos = tTMEM_LOADcS(0);
if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) {
float curr_max = tiled_tmem_load.get_max();
row_max = ::fmax(row_max, curr_max);
}
else
#endif
{
// compute rowmax
float row_max_0 = row_max;

View File

@@ -213,12 +213,15 @@ auto make_iterator(T* ptr) {
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleRunner {
template <
// Type of kernel schedule to generate
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100;
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
// Type of epilogue schedule to generate
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
static constexpr bool FuseQuantization = false;
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
class ClusterShapeMNK = Shape<_1, _1, _1>,
bool FuseQuantization = false
>
struct ExampleRunner {
using LayoutATag = cutlass::layout::RowMajor;
using LayoutBTag = cutlass::layout::ColumnMajor;
@@ -238,10 +241,8 @@ struct ExampleRunner {
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int TileM = cute::is_base_of_v<cutlass::gemm::KernelSchedule2Sm, MainloopScheduleType> ? 256 : 128;
using MmaTileMNK = Shape<Int<TileM>,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
@@ -712,10 +713,34 @@ int main(int argc, char const **args) {
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl;
ExampleRunner runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load, 1SM, 2x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_1sm_2x2;
runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);
std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _1, _1>
> runner_mixed_tma_cpasync_2sm_2x1;
runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info);
std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_2sm_2x4;
runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);
#endif
return 0;

View File

@@ -220,6 +220,7 @@ using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVecto
template <
// Type of kernel schedule to generate
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
class ClusterShapeMNK = Shape<_1, _1, _1>,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
bool FuseQuantization = false
@@ -246,8 +247,8 @@ struct ExampleRunner {
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int TileM = cute::is_base_of_v<cutlass::gemm::KernelSchedule2Sm, MainloopScheduleType> ? 256 : 128;
using MmaTileMNK = Shape<Int<TileM>,_64,_64>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
@@ -485,6 +486,24 @@ struct ExampleRunner {
},
hw_info
};
}
else if constexpr (std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100>){
return typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{ // Mainloop arguments
block_A.device_data(),
block_B.device_data(),
block_SFA.device_data(),
block_SFB.device_data()
},
{ // Epilogue arguments
{},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
},
hw_info
};
}
else {
return typename Gemm::Arguments {
@@ -654,10 +673,80 @@ int main(int argc, char const **args) {
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100> runner_tma;
runner_tma.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 1x1 cluster:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100> runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_4, _1, _1>
> runner_mixed_tma_cpasync_1sm_4x1;
runner_mixed_tma_cpasync_1sm_4x1.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 1x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_1, _4, _1>
> runner_mixed_tma_cpasync_1sm_1x4;
runner_mixed_tma_cpasync_1sm_1x4.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_1sm_2x2;
runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_4, _2, _1>
> runner_mixed_tma_cpasync_1sm_4x2;
runner_mixed_tma_cpasync_1sm_4x2.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_1sm_2x4;
runner_mixed_tma_cpasync_1sm_2x4.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_1sm_4x4;
runner_mixed_tma_cpasync_1sm_4x4.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_4, _1, _1>
> runner_mixed_tma_cpasync_2sm_4x1;
runner_mixed_tma_cpasync_2sm_4x1.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 2x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_2sm_2x4;
runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_4, _2, _1>
> runner_mixed_tma_cpasync_2sm_4x2;
runner_mixed_tma_cpasync_2sm_4x2.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_2sm_4x4;
runner_mixed_tma_cpasync_2sm_4x4.run(options, hw_info);
#endif
return 0;

View File

@@ -70,7 +70,6 @@
#include "helper.h"
using namespace cute;
///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -169,12 +168,14 @@ bool initialize_block(
///////////////////////////////////////////////////////////////////////////////////////////////////
struct ExampleRunner {
template<
// Type of kernel schedule to generate
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100;
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
// Type of epilogue schedule to generate
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
class ClusterShapeMNK = Shape<_1, _1, _1>
>
struct ExampleRunner {
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
@@ -189,7 +190,6 @@ struct ExampleRunner {
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)
// 16B alignment lets us use TMA
@@ -219,7 +219,7 @@ struct ExampleRunner {
MainloopScheduleType
>::CollectiveOp;
using ProblemShape = cutlass::gemm::MoEProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ProblemShape = typename cutlass::gemm::MoEProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
@@ -272,10 +272,10 @@ struct ExampleRunner {
auto [M, N, K] = problem;
printf("group [%d] : M = %d, N = %d, K = %d\n", i, M, N, K);
cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, Gemm::LayoutA(maxK));
cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, Gemm::LayoutB(maxK));
cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutC(maxM));
cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutD(maxM));
cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, typename Gemm::LayoutA(maxK));
cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, typename Gemm::LayoutB(maxK));
cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, typename Gemm::LayoutC(maxM));
cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, typename Gemm::LayoutD(maxM));
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
@@ -561,10 +561,62 @@ int main(int argc, char const **args) {
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl;
ExampleRunner runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load and 1x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_1, _1, _1>
> runner_mixed_tma_cpasync_1sm;
runner_mixed_tma_cpasync_1sm.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_1sm_2x2;
runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);
std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load and 4x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_1sm_4x4;
runner_mixed_tma_cpasync_1sm_4x4.run(options, hw_info);
std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 2x1:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _1, _1>
> runner_mixed_tma_cpasync_2sm_2x1;
runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info);
std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 4x1:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_4, _1, _1>
> runner_mixed_tma_cpasync_2sm_4x1;
runner_mixed_tma_cpasync_2sm_4x1.run(options, hw_info);
std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 2x4:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_2sm_2x4;
runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);
std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 4x4:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_2sm_4x4;
runner_mixed_tma_cpasync_2sm_4x4.run(options, hw_info);
#endif
return 0;

View File

@@ -171,6 +171,8 @@ bool initialize_block(
template <
// Type of kernel schedule to generate
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
// Static cluster shape
class ClusterShapeMNK = Shape<_1, _1, _1>,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto
>
@@ -189,7 +191,6 @@ struct ExampleRunner {
using ElementCompute = float;
using ElementScalar = float;
using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)
// 16B alignment lets us use TMA
@@ -334,6 +335,16 @@ struct ExampleRunner {
hw_info
};
}
else if constexpr (std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100>){
return typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), block_B.get()},
{{}, // epilogue.thread
block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
}
else {
return typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
@@ -490,13 +501,33 @@ int main(int argc, char const **args) {
runner_tma.run(options, hw_info);
std::cout << "Running kernel with CPASYNC load:" << std::endl;
ExampleRunner runner_cpasync;
ExampleRunner<cutlass::gemm::KernelWarpSpecialized1SmSm100> runner_cpasync;
runner_cpasync.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load, cluster shape 1x1:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100> runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 2x2:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_2x2;
runner_mixed_tma_cpasync_2x2.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 4x4, 1SM instructions:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_4x4;
runner_mixed_tma_cpasync_4x4.run(options, hw_info);
std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 4x4, 2SM instructions:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_4x4_2sm;
runner_mixed_tma_cpasync_4x4_2sm.run(options, hw_info);
#endif
return 0;

View File

@@ -48,7 +48,7 @@ from cutlass.cute.typing import Int32, Int64, Float32
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, ".."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from helpers import fmha_helpers as fmha_utils

View File

@@ -51,7 +51,7 @@ from cutlass.cute.typing import Int32, Float32, Float8E4M3FN, Float16, BFloat16,
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, ".."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from helpers import fmha_helpers as fmha_utils

View File

@@ -45,14 +45,14 @@ from cutlass.cute.runtime import from_dlpack
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from blackwell.mamba2_ssd.mamba2_ssd_reference import (
from cute.blackwell.kernel.attention.mamba2_ssd.mamba2_ssd_reference import (
ssd_reference_fp32_all,
ssd_reference_lowprecision_intermediates,
analyze_relative_diffs,
)
from blackwell.mamba2_ssd.mamba2_ssd_tile_scheduler import (
from cute.blackwell.kernel.attention.mamba2_ssd.mamba2_ssd_tile_scheduler import (
Mamba2SSDTileSchedulerParams,
Mamba2SSDTileScheduler,
)

View File

@@ -27,14 +27,11 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import enum
import math
import time
from typing import Type, Tuple
from functools import partial
import torch
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -42,7 +39,7 @@ import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05
from cutlass.cute.nvgpu import tcgen05, OperandMajorMode
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch
@@ -51,9 +48,6 @@ import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import *
from cutlass._mlir.dialects import llvm
from cutlass.cute.arch.nvvm_wrappers import mapa
# Kernel invariants
mma_modes = (0, 1, 2)
mma_dice = (None, None, None) # (MMA, #MMA_M, #MMA_K)
@@ -65,19 +59,22 @@ warpgroup_threads = 128
# Math helpers
log2_e = math.log2(math.e) # change exponential base
use_tensor_ssa_math = False # experimental
fadd2 = partial(cute.arch.add_packed_f32x2, ftz=False, rnd="rn")
fmul2 = partial(cute.arch.mul_packed_f32x2, ftz=False, rnd="rn")
ffma2 = partial(cute.arch.fma_packed_f32x2, ftz=False, rnd="rn")
fadd2 = cute.arch.add_packed_f32x2
fmul2 = cute.arch.mul_packed_f32x2
ffma2 = cute.arch.fma_packed_f32x2
exp2 = partial(cute.math.exp2, fastmath=True)
warp_fmax = partial(cute.arch.warp_redux_sync, kind="fmax", nan=True)
smem_fmax = partial(cute.arch.atomic_fmax, sem="relaxed", scope="cta")
gmem_fmax = partial(cute.arch.atomic_fmax, sem="relaxed", scope="gpu")
class MixedInputFusedMultiHeadAttentionDecode:
def __init__(
self,
headdim,
block_scaledim, # headdim per scale factor; scale factor shape is (batches, heads_k, seqlen, headdim / block_scaledim)
grouped_head_tile, # GQA packing tile size, can be less than group size
convert_warpgroups = 1, # Multiple warpgroups striding on convert stages
block_scaledim, # headdim per scale factor; scale factor shape is (batches, heads_k, seqlen, headdim / block_scaledim)
grouped_head_tile, # GQA packing tile size, can be less than group size
convert_warpgroups=1, # Multiple warpgroups striding on convert stages
):
self.headdim = headdim
self.grouped_head_tile = grouped_head_tile
@@ -93,7 +90,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
self.softmax_warpgroup_id = warpgroup_id
warpgroup_id += 1
self.cvt_warpgroup_ids = tuple(range(warpgroup_id, warpgroup_id+convert_warpgroups))
self.cvt_warpgroup_ids = tuple(
range(warpgroup_id, warpgroup_id + convert_warpgroups)
)
warpgroup_id += convert_warpgroups
# Why 2 MMA+TMA warps when not MMA bound?
@@ -114,12 +113,15 @@ class MixedInputFusedMultiHeadAttentionDecode:
max_regs_per_wg_thread = 64 * 1024 // warpgroup_threads # 64K regs per SM
self.mma_tma_regs = 72
self.cvt_regs = 112
self.softmax_regs = (max_regs_per_wg_thread
- self.mma_tma_regs
- self.cvt_regs * convert_warpgroups)
self.softmax_regs = (
max_regs_per_wg_thread
- self.mma_tma_regs
- self.cvt_regs * convert_warpgroups
)
self.softmax_regs = max(128, min(256, self.softmax_regs))
assert (self.mma_tma_regs + self.softmax_regs +
self.cvt_regs * convert_warpgroups) <= max_regs_per_wg_thread or not self.use_reg_reconfig
assert (
self.mma_tma_regs + self.softmax_regs + self.cvt_regs * convert_warpgroups
) <= max_regs_per_wg_thread or not self.use_reg_reconfig
self.bs_stages = 2
self.sp_stages = 2
@@ -140,25 +142,33 @@ class MixedInputFusedMultiHeadAttentionDecode:
raise ValueError("use Float8E4M3FN instead of Float8E4M3")
if d % 64 != 0:
raise ValueError(f"headdim({d}) must be multiple of 64")
raise testing.CantImplementError(f"headdim({d}) must be multiple of 64")
if h_q % h_k != 0:
raise ValueError(f"heads_q({h_q}) must be a multiple of heads_k({h_k})")
raise testing.CantImplementError(
f"heads_q({h_q}) must be a multiple of heads_k({h_k})"
)
align_scale_bits = 128 # TMA requirement
if self.scaledim * q_dtype.width < align_scale_bits:
align_seq = align_scale_bits // (self.scaledim * q_dtype.width)
if s_k % align_seq != 0:
raise ValueError(f"seqlen({s_k}) must be a multiple of {align_seq}")
raise testing.CantImplementError(
f"seqlen({s_k}) must be a multiple of {align_seq}"
)
if kv_dtype.width < 8 and d % 128 != 0: # TMA requirement
raise ValueError(f"headdim({d}) must be multiple of 128 for {kv_dtype} KV")
raise testing.CantImplementError(
f"headdim({d}) must be multiple of 128 for {kv_dtype} KV"
)
@cute.jit
def __call__(
self,
problem_shape: Tuple[Int32, Int32, Int32, Int32, Int32], # b, h_q, h_k, s_k, d
kv_splits: Int32, # threadblocks per sequence
problem_shape: Tuple[
cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32
], # b, h_q, h_k, s_k, d
kv_splits: cutlass.Int32, # threadblocks per sequence
q_iter: cute.Pointer,
k_iter: cute.Pointer,
v_iter: cute.Pointer,
@@ -170,8 +180,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
o_partial_iter: cute.Pointer, # partial O per kv split
m_partial_iter: cute.Pointer, # partial colmax_s per kv split
l_partial_iter: cute.Pointer, # partial colsum_p per kv split
scale_qs: Float32,
scale_o: Float32,
scale_qs: cutlass.Float32,
scale_o: cutlass.Float32,
stream: cuda.CUstream,
):
##############################
@@ -179,7 +189,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
##############################
mma_dtype = q_iter.dtype
acc_dtype = o_partial_iter.dtype
assert acc_dtype is Float32 # don't support other acc types for now
assert acc_dtype is cutlass.Float32 # don't support other acc types for now
# Block tile sets the granularity at which threadblocks consume work
blk_tile_s = 128
@@ -197,8 +207,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
# GEMM1: (S_K, H_R, D, (H_K, B))
tiled_mma_kq = sm100_utils.make_trivial_tiled_mma(
mma_dtype,
tcgen05.OperandMajorMode.K, # K
tcgen05.OperandMajorMode.K, # Q
mma_dtype,
OperandMajorMode.K, # K
OperandMajorMode.K, # Q
acc_dtype,
tcgen05.CtaGroup.ONE,
mma_tile_mnk[:2],
@@ -208,8 +219,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
# GEMM2: (D, H_R, S_K, (H_K, B))
tiled_mma_vp = sm100_utils.make_trivial_tiled_mma( #
mma_dtype,
tcgen05.OperandMajorMode.K, # V
tcgen05.OperandMajorMode.MN, # P
mma_dtype,
OperandMajorMode.K, # V
OperandMajorMode.MN, # P
acc_dtype,
tcgen05.CtaGroup.ONE,
mma_tile_mnk[:2],
@@ -523,8 +535,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
mM: cute.Tensor,
mM_partial: cute.Tensor,
mL_partial: cute.Tensor,
scale_qs: Float32,
scale_qs_log2_e: Float32,
scale_qs: cutlass.Float32,
scale_qs_log2_e: cutlass.Float32,
):
# Read special registers
kv_splits, tiles_hr, tiles_hb = cute.arch.grid_dim()
@@ -586,7 +598,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
##############################
# Tmem Allocation
##############################
tmem_ptr_smem_ptr = smem.allocate_array(Int32)
tmem_ptr_smem_ptr = smem.allocate_array(cutlass.Int32)
if warp_idx == init_warp and not exit_early:
cute.arch.alloc_tmem(self.tmem_alloc_cols, tmem_ptr_smem_ptr)
init_warp += 1
@@ -595,26 +607,32 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Pipeline Allocation + Init
##############################
# Allocate Mbarriers
q_pipeline_ptr = smem.allocate_array(Int64, self.q_stages * 2)
kv_pipeline_ptr = smem.allocate_array(Int64, self.kv_stages * 2)
bs_pipeline_ptr = smem.allocate_array(Int64, self.bs_stages * 2)
cvt_pipeline_ptr = smem.allocate_array(Int64, self.cvt_stages * 2)
s_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2)
p_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2)
o_pipeline_ptr = smem.allocate_array(Int64, self.o_stages * 2)
q_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.q_stages * 2)
kv_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.kv_stages * 2)
bs_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.bs_stages * 2)
cvt_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.cvt_stages * 2)
s_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.sp_stages * 2)
p_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.sp_stages * 2)
o_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.o_stages * 2)
# Declare named barriers
softmax_nbar_id = 1
mma_kq_nbar_id = 2
mma_vp_nbar_id = 3
softmax_nbar = pipeline.NamedBarrier(
barrier_id=1, num_threads=warpgroup_threads
)
mma_kq_nbar = pipeline.NamedBarrier(barrier_id=2, num_threads=64)
mma_vp_nbar = pipeline.NamedBarrier(barrier_id=3, num_threads=64)
# Alias thread cooperatives
elect_one_cooperative = pipeline.CooperativeGroup(pipeline.Agent.Thread)
warpgroup_cooperative = pipeline.CooperativeGroup(pipeline.Agent.Thread, warpgroup_threads)
warpgroup_cooperative = pipeline.CooperativeGroup(
pipeline.Agent.Thread, warpgroup_threads
)
mma_group = elect_one_cooperative
tma_group = elect_one_cooperative
cvt_group = warpgroup_cooperative
cvt_groups = pipeline.CooperativeGroup(pipeline.Agent.Thread, warpgroup_threads * self.convert_warpgroups)
cvt_groups = pipeline.CooperativeGroup(
pipeline.Agent.Thread, warpgroup_threads * self.convert_warpgroups
)
softmax_group = warpgroup_cooperative
# Initialize pipelines
@@ -642,7 +660,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
num_stages=self.bs_stages,
producer_group=tma_group,
consumer_group=cvt_groups,
tx_count=cute.size_in_bytes(mma_dtype, cute.select(smem_layout_bs, mode=[0,1])),
tx_count=cute.size_in_bytes(
mma_dtype, cute.select(smem_layout_bs, mode=[0, 1])
),
barrier_storage=bs_pipeline_ptr,
tidx=mcast_coord,
cta_layout_vmnk=mcast_layout,
@@ -772,7 +792,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
tCtO = thrblk_mma_vp.make_fragment_C(tCsO.shape)
# Tmem tensor allocation
tmem_ptr = cute.arch.retrieve_tmem_ptr(Int32, 16, tmem_ptr_smem_ptr)
tmem_ptr = cute.arch.retrieve_tmem_ptr(cutlass.Int32, 16, tmem_ptr_smem_ptr)
tmem_offset = 0
tAtK_cvt = cute.make_tensor(
@@ -802,7 +822,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Exit early
##############################
if exit_early:
noop = None # early return not supported
noop = None # early return not supported # noqa: F841
##############################
# TMA KV Dispatch
@@ -991,9 +1011,11 @@ class MixedInputFusedMultiHeadAttentionDecode:
cute.arch.setmaxregister_decrease(self.cvt_regs)
# Intermediate convert type
cvt_type = Float32
if cutlass.const_expr(mma_dtype is cutlass.BFloat16 and
k_dtype in (cutlass.Int4, cutlass.Int8)):
cvt_type = cutlass.Float32
if cutlass.const_expr(
mma_dtype is cutlass.BFloat16
and k_dtype in (cutlass.Int4, cutlass.Int8)
):
cvt_type = mma_dtype
# Initialize for multiple warpgroups if necessary
@@ -1015,12 +1037,13 @@ class MixedInputFusedMultiHeadAttentionDecode:
smem_load_atom_k = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x16x8bOp(
num_matrices=4,
unpack_bits=(k_dtype.width if k_dtype.width < 8 else None)),
unpack_bits=(k_dtype.width if k_dtype.width < 8 else None),
),
kv_smem_dtype,
)
tmem_store_k = tcgen05.make_tmem_copy(
tmem_store_atom_k, tAtK_cvt[mma_dice+(0,)]
tmem_store_atom_k, tAtK_cvt[mma_dice + (0,)]
)
thr_store_k = tmem_store_k.get_slice(warpgroup_tidx)
tKrK_cvt_shape = thr_store_k.partition_S(tAtK_cvt).shape[:-1]
@@ -1043,7 +1066,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
smem_load_atom_v = cute.make_copy_atom(smem_load_op_v, kv_smem_dtype)
tmem_store_v = tcgen05.make_tmem_copy(
tmem_store_atom_v, tAtV_cvt[mma_dice+(0,)]
tmem_store_atom_v, tAtV_cvt[mma_dice + (0,)]
)
thr_store_v = tmem_store_v.get_slice(warpgroup_tidx)
tVrV_cvt_shape = thr_store_v.partition_S(tAtV_cvt).shape[:-1]
@@ -1113,7 +1136,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
bs_handle.release()
# Convert and scale K
for dk in cutlass.range(tiles_dk // self.convert_warpgroups, unroll=2):
for dk in cutlass.range(
tiles_dk // self.convert_warpgroups, unroll=2
):
tKrK = cute.make_rmem_tensor(tKrK_shape, kv_smem_dtype)
tKrK_cvt = cute.make_rmem_tensor(tKrK_cvt_shape, mma_dtype)
@@ -1139,7 +1164,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
cvt_handle = cvt_producer.acquire_and_advance()
cute.copy(
thr_store_k, tKrK_cvt, tKtK_cvt[cpy_dice + (cvt_handle.index,)]
thr_store_k,
tKrK_cvt,
tKtK_cvt[cpy_dice + (cvt_handle.index,)],
)
cute.arch.fence_view_async_tmem_store()
cvt_handle.commit()
@@ -1160,7 +1187,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
bs_handle.release()
# Convert and scale V
for dmsk in cutlass.range(tiles_dm * tiles_sk // self.convert_warpgroups, unroll=2):
for dmsk in cutlass.range(
tiles_dm * tiles_sk // self.convert_warpgroups, unroll=2
):
tVrV = cute.make_rmem_tensor(tVrV_shape, kv_smem_dtype)
tVrV_cvt = cute.make_rmem_tensor(tVrV_cvt_shape, mma_dtype)
@@ -1186,7 +1215,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
cvt_handle = cvt_producer.acquire_and_advance()
cute.copy(
thr_store_v, tVrV_cvt, tVtV_cvt[cpy_dice + (cvt_handle.index,)]
thr_store_v,
tVrV_cvt,
tVtV_cvt[cpy_dice + (cvt_handle.index,)],
)
cute.arch.fence_view_async_tmem_store()
cvt_handle.commit()
@@ -1223,9 +1254,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
k_handle = cvt_consumer.wait_and_advance(k_token)
# Signal BMM2 to start
if is_last_iter:
cute.arch.barrier_arrive(
barrier_id=mma_kq_nbar_id, number_of_threads=64
)
mma_kq_nbar.arrive()
for mma_k in cutlass.range_constexpr(tAtK_cvt.shape[2]):
cute.gemm(
tiled_mma_kq,
@@ -1245,7 +1274,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
if s > 0:
for _ in cutlass.range_constexpr(tiles_dm * tiles_sk):
cvt_consumer.advance()
cute.arch.barrier(barrier_id=mma_vp_nbar_id, number_of_threads=64)
mma_vp_nbar.arrive_and_wait()
s_token = s_producer.try_acquire()
##############################
@@ -1263,7 +1292,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Advance and wait for BMM1
for _ in cutlass.range_constexpr(tiles_dk):
cvt_consumer.advance()
cute.arch.barrier(barrier_id=mma_kq_nbar_id, number_of_threads=64)
mma_kq_nbar.arrive_and_wait()
# Sequence loop
p_token = False
@@ -1273,7 +1302,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
if s < iters_s - 1:
for _ in cutlass.range_constexpr(tiles_dk):
cvt_consumer.advance()
cute.arch.barrier(barrier_id=mma_kq_nbar_id, number_of_threads=64)
mma_kq_nbar.arrive_and_wait()
p_token = p_consumer.try_wait()
# BMM2
@@ -1286,9 +1315,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
v_handle = cvt_consumer.wait_and_advance(v_token)
# Signal BMM1 to start
if is_last_iter:
cute.arch.barrier_arrive(
barrier_id=mma_vp_nbar_id, number_of_threads=64
)
mma_vp_nbar.arrive()
for mma_k in cutlass.range_constexpr(tAtV_cvt.shape[2]):
cute.gemm(
tiled_mma_vp,
@@ -1325,7 +1352,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
tmem_load_atom_s = cute.make_copy_atom(
tcgen05.Ld32x32bOp(tmem_op_repeat), acc_dtype
)
tmem_load_s = tcgen05.make_tmem_copy(tmem_load_atom_s, tCtS[mma_dice + (0,)])
tmem_load_s = tcgen05.make_tmem_copy(
tmem_load_atom_s, tCtS[mma_dice + (0,)]
)
thr_load_s = tmem_load_s.get_slice(warpgroup_tidx)
tmem_store_atom_o = cute.make_copy_atom(
@@ -1356,7 +1385,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Partition colmax and initialize in RF
tSsM = thr_load_s.partition_D(tCsM) # (CPY, #CPY_MMA, #CPY_M, #CPY_N)
tSrM_prev = cute.make_rmem_tensor_like(tSsM)
tSrM_prev.fill(-Float32.inf)
tSrM_prev.fill(-cutlass.Float32.inf)
# Partition colsum and initialize in RF
# Each thread maintains a local colsum in RF, smem reduction happens after loop
@@ -1364,7 +1393,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
tCsL
) # (CPY, #CPY_MMA, #CPY_M, #CPY_N, WARPS)
tSrL = cute.make_rmem_tensor_like(tSsL[cpy_dice + (0,)])
tSrL.fill(Float32(0))
tSrL.fill(cutlass.Float32(0))
assert warp_threads >= cute.size(tSsM)
@@ -1385,17 +1414,15 @@ class MixedInputFusedMultiHeadAttentionDecode:
)
# Initialize O
tSrO.fill(Float32(0))
tSrO.fill(cutlass.Float32(0))
cute.copy(thr_store_o, tSrO, tStO)
# Initialize colsum and colmax in smem and wait
if warpgroup_widx == 0 and lane_store_max:
tSsM[lane_idx] = -Float32.inf
tSsM[lane_idx] = -cutlass.Float32.inf
if warpgroup_widx == 1 and lane_store_max:
tSsL[lane_idx] = Float32(0)
cute.arch.barrier(
barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads
)
tSsL[lane_idx] = cutlass.Float32(0)
softmax_nbar.arrive_and_wait()
#
# Sequence loop
@@ -1410,20 +1437,18 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Reduce colmax in warp RF
tSrM = cute.make_rmem_tensor_like(tSsM)
tSrM_lane = Float32(0) # Avoid dynamic register indexing
tSrM_lane = cutlass.Float32(0) # Avoid dynamic register indexing
for i in cutlass.range_constexpr(cute.size(tSrS)):
tSrM[i] = cute.arch.warp_redux_sync(tSrS[i], kind="fmax", nan=True)
tSrM[i] = warp_fmax(tSrS[i])
if i == lane_idx:
tSrM_lane = tSrM[i]
# Reduce colmax in smem
if lane_store_max:
self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
# Wait for colmax then load
cute.arch.barrier(
barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads
)
softmax_nbar.arrive_and_wait()
cute.autovec_copy(tSsM, tSrM)
# Compute online softmax
@@ -1504,7 +1529,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
#
# Reduce colsum in warp RF
tSrL_lane = Float32(0.0)
tSrL_lane = cutlass.Float32(0.0)
for i in cutlass.range_constexpr(cute.size(tSrL)):
tSrL[i] = cute.arch.warp_reduction_sum(tSrL[i])
if i == lane_idx:
@@ -1515,16 +1540,12 @@ class MixedInputFusedMultiHeadAttentionDecode:
tSsL[cpy_dice + (warpgroup_widx,)][lane_idx] = tSrL_lane
# Wait for colsum
cute.arch.barrier(
barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads
)
softmax_nbar.arrive_and_wait()
if warpgroup_widx == 0 and lane_store_max and inbound_hr:
# Load colsum and colmax
sL_lane_wg = sL[0, lane_idx, None]
sL_lane = (
sL_lane_wg[0] + sL_lane_wg[1] + sL_lane_wg[2] + sL_lane_wg[3]
)
sL_lane = sL_lane_wg[0] + sL_lane_wg[1] + sL_lane_wg[2] + sL_lane_wg[3]
sM_lane = sM[0, lane_idx]
# Scale colmax
@@ -1533,7 +1554,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Store colsum and colmax
gL_partial[lane_idx] = sL_lane
gM_partial[lane_idx] = sM_lane
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
o_handle = o_consumer.wait_and_advance()
cute.copy(thr_load_s, tStO, tSrO)
@@ -1563,16 +1584,16 @@ class MixedInputFusedMultiHeadAttentionDecode:
o_partial: cute.Tensor,
m_partial: cute.Tensor,
l_partial: cute.Tensor,
scale_o: Float32,
scale_o: cutlass.Float32,
):
d_blk_idx, coord_h, coord_b = cute.arch.block_idx()
d_per_blk, _, _ = cute.arch.block_dim()
d_idx, _, _ = cute.arch.thread_idx()
coord_d = d_blk_idx * d_per_blk + d_idx
o_dhb = Float32(0)
o_dhb = cutlass.Float32(0)
m_hb = m[coord_h, coord_b]
l_hb = Float32(0)
l_hb = cutlass.Float32(0)
o_partial_dhb = o_partial[coord_d, coord_h, coord_b, None]
m_partial_hb = m_partial[coord_h, coord_b, None]
@@ -1591,44 +1612,6 @@ class MixedInputFusedMultiHeadAttentionDecode:
return
@staticmethod
@cute.jit
def smem_fmax(ptr: Pointer, val: Float32):
# https://stackoverflow.com/a/72461459
# Works with canonical NaN which warp_redux_sync(kind="fmax") should return
llvm.inline_asm(
None,
[ptr.llvm_ptr, val.ir_value()],
"""{\n\t
.reg .pred p;\n\t
setp.lt.s32 p, $1, 0x0;
@p red.relaxed.shared::cta.min.u32 [$0], $1;\n\t
@!p red.relaxed.shared::cta.max.s32 [$0], $1;\n\t
}\n\t""",
"r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
@staticmethod
@cute.jit
def gmem_fmax(ptr: Pointer, val: Float32):
llvm.inline_asm(
None,
[ptr.llvm_ptr, val.ir_value()],
"""{\n\t
.reg .pred p;\n\t
setp.lt.s32 p, $1, 0x0;
@p red.relaxed.gpu.global.min.u32 [$0], $1;\n\t
@!p red.relaxed.gpu.global.max.s32 [$0], $1;\n\t
}\n\t""",
"l,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
def run(
batches: int = 1,
@@ -1638,10 +1621,10 @@ def run(
headdim: int = 512,
block_scaledim: int = 512,
kv_splits: int = 0,
q_dtype: Type[cutlass.Numeric] = BFloat16,
kv_dtype: Type[cutlass.Numeric] = Int8,
o_dtype: Type[cutlass.Numeric] = BFloat16,
acc_dtype: Type[cutlass.Numeric] = Float32,
q_dtype: Type[cutlass.Numeric] = cutlass.BFloat16,
kv_dtype: Type[cutlass.Numeric] = cutlass.Int8,
o_dtype: Type[cutlass.Numeric] = cutlass.BFloat16,
acc_dtype: Type[cutlass.Numeric] = cutlass.Float32,
tolerance: float = 0.1,
scale_q: float = 1.0,
scale_o: float = 1.0,
@@ -1652,7 +1635,7 @@ def run(
use_cold_l2: bool = False,
**kwargs,
):
print(f"Running Blackwell SM100 Mixed Input FMHA Decode test with:")
print("Running Blackwell SM100 Mixed Input FMHA Decode test with:")
print(f"\tbatches: {batches}, seqlen: {seqlen}")
print(f"\theads_q: {heads_q}, heads_k: {heads_k}")
print(f"\theaddim: {headdim}, block_scaledim: {block_scaledim}")
@@ -1709,9 +1692,7 @@ def run(
problem_shape = (batches, heads_q, heads_k, seqlen_k, headdim)
fmha.can_implement(
problem_shape, kv_splits, q_dtype, kv_dtype, o_dtype, acc_dtype
)
fmha.can_implement(problem_shape, kv_splits, q_dtype, kv_dtype, o_dtype, acc_dtype)
#
# Allocate Tensors
@@ -1719,24 +1700,24 @@ def run(
torch.manual_seed(1111)
def create_tensor(shape, dtype, init=True):
init_type = cutlass.torch.TensorInitType.RANDOM
init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2)
init_type = cutlass_torch.TensorInitType.RANDOM
init_config = cutlass_torch.RandomInitConfig(min_val=-2, max_val=2)
if init is False or init is None:
init_type = cutlass.torch.TensorInitType.SKIP
init_type = cutlass_torch.TensorInitType.SKIP
init_config = None
elif isinstance(init, int) or isinstance(init, float):
init_type = cutlass.torch.TensorInitType.SCALAR
init_config = cutlass.torch.ScalarInitConfig(value=init)
init_type = cutlass_torch.TensorInitType.SCALAR
init_config = cutlass_torch.ScalarInitConfig(value=init)
elif isinstance(init, tuple) or isinstance(init, list):
if len(init) == 2:
init_type = cutlass.torch.TensorInitType.RANDOM
init_config = cutlass.torch.RandomInitConfig(
init_type = cutlass_torch.TensorInitType.RANDOM
init_config = cutlass_torch.RandomInitConfig(
min_val=init[0], max_val=init[1]
)
if len(init) == 3:
init_type = cutlass.torch.TensorInitType.GAUSSIAN
init_config = cutlass.torch.RandomInitConfig(
init_type = cutlass_torch.TensorInitType.GAUSSIAN
init_config = cutlass_torch.RandomInitConfig(
mean=init[0], std=init[1], scale=init[2]
)
@@ -1809,7 +1790,7 @@ def run(
scale_qs,
scale_o,
current_stream,
options=f"--opt-level 2",
options="--opt-level 2",
)
print("Finished Compiling")
@@ -1980,7 +1961,7 @@ if __name__ == "__main__":
def parse_comma_separated_ints(s: str):
try:
return tuple(int(x.strip()) for x in s.split(","))
return tuple(cutlass.int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
@@ -2052,7 +2033,7 @@ if __name__ == "__main__":
parser.add_argument(
"--acc_dtype",
type=cutlass.dtype,
default=Float32,
default=cutlass.Float32,
help="accumulator/reduction data type",
)

View File

@@ -47,10 +47,10 @@ from cutlass.cute.typing import Int32, Int64, Float32
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from helpers import fmha_helpers as fmha_utils
from blackwell.mixed_input_fmha import prefill_helpers as prefill_utils
from cute.blackwell.kernel.attention.mixed_input_fmha import prefill_helpers as prefill_utils
class MixedInputFusedMultiHeadAttentionPrefillD256:

View File

@@ -49,10 +49,10 @@ from cutlass.cute.typing import Int32, Int64, Float32
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from helpers import fmha_helpers as fmha_utils
from blackwell.mixed_input_fmha import prefill_helpers as prefill_utils
from cute.blackwell.kernel.attention.mixed_input_fmha import prefill_helpers as prefill_utils
class MixedInputFusedMultiHeadAttentionPrefillD512:


@@ -51,9 +51,9 @@ from cutlass.cutlass_dsl import BaseDSL
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from blackwell.mla.mla_helpers import (
from cute.blackwell.kernel.attention.mla.mla_helpers import (
ceil_div,
MAX_SPLITS,
LOG2_E,


@@ -51,9 +51,9 @@ from cutlass.cutlass_dsl import BaseDSL
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
from blackwell.mla.mla_helpers import (
from cute.blackwell.kernel.attention.mla.mla_helpers import (
ceil_div,
MAX_SPLITS,
LOG2_E,


@@ -47,9 +47,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
from blackwell.mixed_input_gemm.mixed_input_host_utils import (
from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import (
create_tensors_for_contiguous_grouped_mixed_input_gemm as create_tensors,
run_contiguous_grouped_ref_and_compare as run_ref_and_compare,
)


@@ -48,9 +48,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
from blackwell.mixed_input_gemm.mixed_input_host_utils import (
from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import (
create_tensors_for_contiguous_grouped_mixed_input_gemm as create_tensors,
run_contiguous_grouped_ref_and_compare as run_ref_and_compare,
)


@@ -47,9 +47,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
from blackwell.mixed_input_gemm.mixed_input_host_utils import (
from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import (
create_tensors_for_batched_mixed_input_gemm as create_tensors,
run_batched_mixed_input_ref_and_compare as run_ref_and_compare,
)


@@ -0,0 +1,695 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
MoE Persistent Tile Scheduler
A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations.
This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo
(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile.
Scenarios:
- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate)
- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden)
Key design principle:
- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic)
- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension
- TMA descriptor management is handled by OnlineTensormapDescCreator
- The kernel orchestrates all three components
"""
from typing import List, Tuple, Literal
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import (
Boolean,
Int32,
Integer,
extract_mlir_values,
new_from_mlir_values,
const_expr,
dsl_user_op,
)
from cutlass._mlir import ir
# =============================================================================
# Work Tile Info
# =============================================================================
class MoEWorkTileInfo:
"""
Work tile information for MoE scheduler.
Contains CTA-level tile information for executor warps:
- expert_idx: Which expert (-1 means invalid/done)
- tile_m_idx: CTA tile index along GEMM M dimension
- tile_n_idx: CTA tile index along GEMM N dimension
- k_tile_cnt: Number of CTA tiles along K dimension
Note: These are CTA-level indices, not cluster-level.
tile_l_idx is always 0 for MoE, executor can hardcode it.
For 2Dx3D (Forward):
M = tokens_i (dynamic), N = intermediate (fixed), K = hidden (fixed)
For 2Dx2D (Backward):
M = intermediate (fixed), N = hidden (fixed), K = tokens_i (dynamic)
"""
def __init__(
self,
expert_idx: Int32, # -1 means invalid tile
tile_m_idx: Int32,
tile_n_idx: Int32,
k_tile_cnt: Int32,
):
self.expert_idx = expert_idx
self.tile_m_idx = tile_m_idx
self.tile_n_idx = tile_n_idx
self.k_tile_cnt = k_tile_cnt
@property
def is_valid_tile(self) -> Boolean:
"""Check if this is a valid work tile (expert_idx >= 0)."""
return self.expert_idx >= Int32(0)
def __extract_mlir_values__(self) -> List[ir.Value]:
values = extract_mlir_values(self.expert_idx)
values.extend(extract_mlir_values(self.tile_m_idx))
values.extend(extract_mlir_values(self.tile_n_idx))
values.extend(extract_mlir_values(self.k_tile_cnt))
return values
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo":
assert len(values) == 4
return MoEWorkTileInfo(
expert_idx=new_from_mlir_values(self.expert_idx, [values[0]]),
tile_m_idx=new_from_mlir_values(self.tile_m_idx, [values[1]]),
tile_n_idx=new_from_mlir_values(self.tile_n_idx, [values[2]]),
k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [values[3]]),
)
def to_rmem_tensor(self):
"""Pack work tile info fields into an rmem tensor of shape (4,) for vectorized smem copy."""
rmem = cute.make_rmem_tensor((4,), Int32)
rmem[0] = self.expert_idx
rmem[1] = self.tile_m_idx
rmem[2] = self.tile_n_idx
rmem[3] = self.k_tile_cnt
return rmem
@staticmethod
def from_rmem_tensor(rmem) -> "MoEWorkTileInfo":
"""Unpack work tile info from an rmem tensor of shape (4,)."""
return MoEWorkTileInfo(
expert_idx=rmem[0], # type: ignore[arg-type]
tile_m_idx=rmem[1], # type: ignore[arg-type]
tile_n_idx=rmem[2], # type: ignore[arg-type]
k_tile_cnt=rmem[3], # type: ignore[arg-type]
)
# =============================================================================
# Scheduler Parameters
# =============================================================================
class MoEStaticSchedulerParams:
"""
Parameters for MoE tile scheduler.
Uses unified semantics for both scenarios:
- expert_shape: (expert_cnt, intermediate, hidden)
For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert
For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert
Tile hierarchy:
- cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k)
- cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n)
- cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape
This class is used both on host (for grid shape calculation) and on device
(stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk,
cluster_shape_mn) are NOT serialized to MLIR values.
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
expert_shape: Tuple[int | Int32, int | Int32, int | Int32], # (expert_cnt, intermediate, hidden)
cta_tile_shape_mnk: Tuple[int, int, int], # (tile_m, tile_n, tile_k)
cluster_shape_mn: Tuple[int, int], # (cluster_m, cluster_n)
):
self.scenario = scenario
e, i, h = expert_shape
self.expert_cnt = e if isinstance(e, Int32) else Int32(e)
self.intermediate = i if isinstance(i, Int32) else Int32(i)
self.hidden = h if isinstance(h, Int32) else Int32(h)
self.cta_tile_shape_mnk = cta_tile_shape_mnk
self.cluster_shape_mn = cluster_shape_mn
@property
def cluster_tile_m(self) -> int:
"""Cluster tile size along M = cta_tile_m * cluster_m."""
return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0]
@property
def cluster_tile_n(self) -> int:
"""Cluster tile size along N = cta_tile_n * cluster_n."""
return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1]
@property
def cta_tile_k(self) -> int:
"""CTA tile size along K (same as cluster since cluster_k = 1)."""
return self.cta_tile_shape_mnk[2]
def __extract_mlir_values__(self) -> List[ir.Value]:
"""Only serialize runtime values, not codegen-time constants."""
values = []
values.extend(extract_mlir_values(self.expert_cnt))
values.extend(extract_mlir_values(self.intermediate))
values.extend(extract_mlir_values(self.hidden))
return values
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams":
assert len(values) == 3
return MoEStaticSchedulerParams(
scenario=self.scenario,
expert_shape=(
new_from_mlir_values(self.expert_cnt, [values[0]]),
new_from_mlir_values(self.intermediate, [values[1]]),
new_from_mlir_values(self.hidden, [values[2]]),
),
cta_tile_shape_mnk=self.cta_tile_shape_mnk,
cluster_shape_mn=self.cluster_shape_mn,
)
@staticmethod
def get_grid_shape(
params: "MoEStaticSchedulerParams",
max_active_clusters: int,
) -> Tuple[int, int, int]:
"""
Compute grid shape for kernel launch.
Since the host doesn't know the token distribution across experts,
we launch max_active_clusters persistent clusters and let the
device-side scheduler determine which tiles are valid.
"""
return (
params.cluster_shape_mn[0],
params.cluster_shape_mn[1],
max_active_clusters,
)
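# Hedged, illustrative sketch of the tile-hierarchy arithmetic above (values are
# hypothetical, not taken from any shipped config): with cta_tile_shape_mnk =
# (128, 128, 64) and cluster_shape_mn = (2, 1), the cluster tile is 256 x 128 and
# each CTA steps K in chunks of 64, so
#   params = MoEStaticSchedulerParams("2Dx3D", (8, 4096, 7168), (128, 128, 64), (2, 1))
#   params.cluster_tile_m  -> 256   # 128 * 2
#   params.cluster_tile_n  -> 128   # 128 * 1
#   params.cta_tile_k      -> 64
#   MoEStaticSchedulerParams.get_grid_shape(params, max_active_clusters=148) -> (2, 1, 148)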
# =============================================================================
# Scheduler (Device-side)
# =============================================================================
class MoEStaticPersistentTileScheduler:
"""
Persistent tile scheduler specialized for MoE grouped GEMM.
This scheduler is ONLY responsible for tile iteration. It does NOT know
about tensor types, TMA descriptors, or domain conversion. Those concerns
are handled by MoESchedExtension and OnlineTensormapDescCreator respectively.
Architecture:
- Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info
- Executor warps: Read work_tile_info from smem, use MoESchedExtension for
domain conversion and TMA desc selection
The scheduler handles:
- 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden)
- 2Dx2D: Fixed M (intermediate) and N (hidden), dynamic K per expert (reduction axis)
Usage (Scheduler warp):
scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim)
work_tile_info = scheduler.initial_work_tile_info()
# Broadcast work_tile_info to smem...
while work_tile_info.is_valid_tile:
# ... do work ...
work_tile_info = scheduler.advance_to_next_work()
# Broadcast work_tile_info to smem...
Usage (Executor warps - via MoESchedExtension):
# Read work_tile_info from smem...
real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info)
"""
def __init__(
self,
# Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes)
params: MoEStaticSchedulerParams,
# Runtime tensor for scheduling
offs: cute.Tensor, # (experts,) cumsum of token counts
# Scheduling state
num_persistent_clusters: Int32,
current_work_linear_idx: Int32,
cta_id_in_cluster: cute.Coord,
# Expert tracking state (for O(1) advance within same expert)
current_expert_idx: Int32,
expert_tile_start: Int32, # cumsum of tiles before current expert
expert_tile_end: Int32, # cumsum of tiles including current expert
):
self.params = params
self.offs = offs
self.num_persistent_clusters = num_persistent_clusters
self._current_work_linear_idx = current_work_linear_idx
self.cta_id_in_cluster = cta_id_in_cluster
# Expert tracking
self.current_expert_idx = current_expert_idx
self.expert_tile_start = expert_tile_start
self.expert_tile_end = expert_tile_end
# =========================================================================
# Convenience accessors for params
# =========================================================================
@property
def scenario(self) -> Literal["2Dx3D", "2Dx2D"]:
return self.params.scenario
@property
def expert_cnt(self) -> Int32:
return self.params.expert_cnt
@property
def intermediate(self) -> Int32:
return self.params.intermediate
@property
def hidden(self) -> Int32:
return self.params.hidden
@property
def cta_tile_shape_mnk(self) -> Tuple[int, int, int]:
return self.params.cta_tile_shape_mnk
@property
def cluster_shape_mn(self) -> Tuple[int, int]:
return self.params.cluster_shape_mn
@property
def cluster_tile_m(self) -> int:
return self.params.cluster_tile_m
@property
def cluster_tile_n(self) -> int:
return self.params.cluster_tile_n
@property
def cta_tile_k(self) -> int:
return self.params.cta_tile_k
# =========================================================================
# MLIR value serialization (for SSA value passing in device code)
# =========================================================================
def __extract_mlir_values__(self) -> List[ir.Value]:
values = []
# Params (only runtime values are extracted)
values.extend(extract_mlir_values(self.params))
# Runtime tensor for scheduling
values.extend(extract_mlir_values(self.offs))
# Scheduling state
values.extend(extract_mlir_values(self.num_persistent_clusters))
values.extend(extract_mlir_values(self._current_work_linear_idx))
values.extend(extract_mlir_values(self.cta_id_in_cluster))
# Expert tracking state
values.extend(extract_mlir_values(self.current_expert_idx))
values.extend(extract_mlir_values(self.expert_tile_start))
values.extend(extract_mlir_values(self.expert_tile_end))
return values
def __new_from_mlir_values__(
self, values: List[ir.Value]
) -> "MoEStaticPersistentTileScheduler":
idx = 0
# Params (3 values: expert_cnt, intermediate, hidden)
new_params = new_from_mlir_values(self.params, values[idx:idx + 3])
idx += 3
# Runtime tensor for scheduling (variable size)
offs_len = len(extract_mlir_values(self.offs))
new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len])
idx += offs_len
# Scheduling state
new_num_persistent_clusters = new_from_mlir_values(
self.num_persistent_clusters, [values[idx]]
)
idx += 1
new_current_work_linear_idx = new_from_mlir_values(
self._current_work_linear_idx, [values[idx]]
)
idx += 1
# cta_id_in_cluster (3 values for Coord)
new_cta_id_in_cluster = new_from_mlir_values(
self.cta_id_in_cluster, values[idx:idx + 3]
)
idx += 3
# Expert tracking state
new_current_expert_idx = new_from_mlir_values(
self.current_expert_idx, [values[idx]]
)
idx += 1
new_expert_tile_start = new_from_mlir_values(
self.expert_tile_start, [values[idx]]
)
idx += 1
new_expert_tile_end = new_from_mlir_values(
self.expert_tile_end, [values[idx]]
)
idx += 1
return MoEStaticPersistentTileScheduler(
params=new_params,
offs=new_offs,
num_persistent_clusters=new_num_persistent_clusters,
current_work_linear_idx=new_current_work_linear_idx,
cta_id_in_cluster=new_cta_id_in_cluster,
current_expert_idx=new_current_expert_idx,
expert_tile_start=new_expert_tile_start,
expert_tile_end=new_expert_tile_end,
)
# =========================================================================
# Factory method
# =========================================================================
@staticmethod
@dsl_user_op
def create(
params: MoEStaticSchedulerParams,
offs: cute.Tensor,
block_idx: Tuple[Integer, Integer, Integer],
grid_dim: Tuple[Integer, Integer, Integer],
*,
loc=None,
ip=None,
) -> "MoEStaticPersistentTileScheduler":
"""
Create a MoE persistent tile scheduler.
:param params: Scheduler parameters (from host)
:param offs: Cumsum tensor of token counts per expert, shape (experts,)
:param block_idx: CUDA block index
:param grid_dim: CUDA grid dimensions
"""
num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
params.cluster_shape_mn, loc=loc, ip=ip
)
bidx, bidy, bidz = block_idx
current_work_linear_idx = Int32(bidz)
cta_id_in_cluster = (
Int32(bidx % params.cluster_shape_mn[0]),
Int32(bidy % params.cluster_shape_mn[1]),
Int32(0),
)
# Initialize expert tracking to "before expert 0"
# The first call to _get_work_tile_for_linear_idx will advance to the correct expert
current_expert_idx = Int32(0)
expert_tile_start = Int32(0)
expert_tile_end = Int32(0) # Will be computed on first access
return MoEStaticPersistentTileScheduler(
params=params,
offs=offs,
num_persistent_clusters=num_persistent_clusters,
current_work_linear_idx=current_work_linear_idx,
cta_id_in_cluster=cta_id_in_cluster,
current_expert_idx=current_expert_idx,
expert_tile_start=expert_tile_start,
expert_tile_end=expert_tile_end,
)
# =========================================================================
# Tile iteration methods
# =========================================================================
@dsl_user_op
@cute.jit
def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
"""Get the initial work tile info."""
return self._get_work_tile_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)
@dsl_user_op
@cute.jit
def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
"""Advance to the next work tile and return its info."""
self._current_work_linear_idx += self.num_persistent_clusters
return self._get_work_tile_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)
@dsl_user_op
@cute.jit
def _get_work_tile_for_linear_idx(
self,
cluster_linear_idx: Int32,
*,
loc=None,
ip=None
) -> MoEWorkTileInfo:
"""
Convert a linear cluster index to MoEWorkTileInfo.
Uses cached expert tracking state for O(1) fast path when staying
within the same expert. Advances expert state when needed.
Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range.
"""
# Ensure expert tracking is initialized and up-to-date
self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip)
# Check if valid (still within expert range after advancing)
is_valid = self.current_expert_idx < self.expert_cnt
work_tile_info = MoEWorkTileInfo(
expert_idx=Int32(-1),
tile_m_idx=Int32(0),
tile_n_idx=Int32(0),
k_tile_cnt=Int32(0),
)
if is_valid:
# Compute local cluster tile indices within current expert
local_idx = cluster_linear_idx - self.expert_tile_start
cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx(
local_idx, self.current_expert_idx, loc=loc, ip=ip
)
# Convert cluster tile indices to CTA tile indices
# cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster
cta_tile_m_idx = (
cluster_tile_m_idx * self.cluster_shape_mn[0]
+ self.cta_id_in_cluster[0] # type: ignore[index]
)
cta_tile_n_idx = (
cluster_tile_n_idx * self.cluster_shape_mn[1]
+ self.cta_id_in_cluster[1] # type: ignore[index]
)
# Compute k_tile_cnt
k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip)
work_tile_info = MoEWorkTileInfo(
expert_idx=self.current_expert_idx,
tile_m_idx=cta_tile_m_idx,
tile_n_idx=cta_tile_n_idx,
k_tile_cnt=k_tile_cnt,
)
return work_tile_info
@dsl_user_op
@cute.jit
def _advance_expert_to_contain(
self,
cluster_linear_idx: Int32,
*,
loc=None,
ip=None,
) -> None:
"""
Advance expert tracking state until current expert contains cluster_linear_idx,
or we run out of experts.
Fast path: If already in correct expert, no work needed.
"""
# Initialize expert_tile_end if this is the first call (expert_tile_end == 0)
if self.expert_tile_end == Int32(0):
tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip)
self.expert_tile_end = tiles_for_expert_0
# Advance until cluster_linear_idx < expert_tile_end or no more experts
while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt:
self.current_expert_idx = self.current_expert_idx + 1
self.expert_tile_start = self.expert_tile_end
if self.current_expert_idx < self.expert_cnt:
tiles_for_expert = self._compute_tiles_for_expert(
self.current_expert_idx, loc=loc, ip=ip
)
self.expert_tile_end = self.expert_tile_end + tiles_for_expert
@dsl_user_op
@cute.jit
def _compute_tiles_for_expert(
self,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Int32:
"""Compute total cluster tiles for a given expert."""
if const_expr(self.scenario == "2Dx2D"):
# Fixed M=hidden, N=intermediate
cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
return cluster_tile_m_cnt * cluster_tile_n_cnt
else: # 2Dx3D
# Variable M (tokens), fixed N
tokens_i = self.offs[expert_idx]
if expert_idx > 0:
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
cluster_tile_m_cnt = (
tokens_i + self.cluster_tile_m - 1 # type: ignore[operator]
) // self.cluster_tile_m
cluster_tile_n_cnt = (
self.intermediate + self.cluster_tile_n - 1
) // self.cluster_tile_n
return cluster_tile_m_cnt * cluster_tile_n_cnt
@dsl_user_op
@cute.jit
def _decompose_local_idx(
self,
local_idx: Int32,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Int32, Int32]:
"""
Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx).
Uses "short side first" strategy: the shorter dimension changes faster.
This maximizes overlap between adjacent clusters for better L2 cache utilization.
For example, if m_cnt=2, n_cnt=8:
- N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx
- Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ...
"""
# Get tile counts for M and N
cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts(
expert_idx, loc=loc, ip=ip
)
cluster_tile_m_idx = -1
cluster_tile_n_idx = -1
# Short side first: shorter dimension changes faster
# If m_cnt <= n_cnt: m is shorter, m changes faster
# local_idx = n_idx * m_cnt + m_idx
# If n_cnt < m_cnt: n is shorter, n changes faster
# local_idx = m_idx * n_cnt + n_idx
if cluster_tile_m_cnt <= cluster_tile_n_cnt:
# M is shorter or equal, M changes faster
cluster_tile_m_idx = local_idx % cluster_tile_m_cnt
cluster_tile_n_idx = local_idx // cluster_tile_m_cnt
else:
# N is shorter, N changes faster
cluster_tile_n_idx = local_idx % cluster_tile_n_cnt
cluster_tile_m_idx = local_idx // cluster_tile_n_cnt
return (cluster_tile_m_idx, cluster_tile_n_idx)
@dsl_user_op
@cute.jit
def _get_cluster_tile_counts(
self,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Int32, Int32]:
"""Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert."""
if const_expr(self.scenario == "2Dx2D"):
# Fixed M=hidden, N=intermediate
cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
else: # 2Dx3D
# Variable M (tokens), fixed N
tokens_i = self.offs[expert_idx]
if expert_idx > 0:
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
cluster_tile_m_cnt = (
tokens_i + self.cluster_tile_m - 1 # type: ignore[operator]
) // self.cluster_tile_m
cluster_tile_n_cnt = (
self.intermediate + self.cluster_tile_n - 1
) // self.cluster_tile_n
return (cluster_tile_m_cnt, cluster_tile_n_cnt)
@dsl_user_op
@cute.jit
def _compute_k_tile_cnt(
self,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Int32:
"""
Compute the number of K tiles for this expert.
2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k)
2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k)
"""
if const_expr(self.scenario == "2Dx3D"):
# K is hidden (fixed)
return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k
else: # 2Dx2D
# K is tokens_i (variable per expert)
tokens_i = self.offs[expert_idx]
if expert_idx > cutlass.Int32(0):
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k # type: ignore[return-value, operator]
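# Hedged usage sketch for the scheduler warp, restating the class docstring above;
# `params`, `offs`, and the smem broadcast of the packed work tile are supplied by
# the surrounding kernel and are placeholders here.
#
#   scheduler = MoEStaticPersistentTileScheduler.create(
#       params, offs, cute.arch.block_idx(), cute.arch.grid_dim()
#   )
#   work = scheduler.initial_work_tile_info()
#   while work.is_valid_tile:
#       # broadcast work.to_rmem_tensor() (shape (4,), Int32) to smem for executor warps
#       work = scheduler.advance_to_next_work()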



@@ -0,0 +1,443 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
MoE Scheduler Extension.
Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level
domain conversion and TMA descriptor selection. This is the "glue" layer between:
- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt)
- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace
- Kernel: orchestrates everything
Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own
MoESchedExtension subclass with kernel-specific domain conversion logic.
Key design principles:
- Unified interface: get_gmem_tensor() for all tensor types
- Free implementation: no role-based templates, each subclass writes its own logic
- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc.
are available as tools but not mandatory
Architecture:
    Scheduler ──(produces)──> MoEWorkTileInfo
                              (expert_idx, tile_m, tile_n, k_cnt)
                                        │
                                        v
    Extension ──(uses)──> OnlineTensormapDescCreator
        │                              │
        │  get_gmem_tensor()           │  get_desc_ptr()
        │  prefetch_for_expert()       │  construct_and_write()
        │                              │
        └──────── internal calls ──────┘
    Kernel (caller): the only place that knows all three exist
"""
from abc import ABC, abstractmethod
from typing import Literal, Tuple, Union
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Pointer
from cutlass.cutlass_dsl import Int32
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
from blackwell.kernel.moe.moe_utils import (
OnlineTensormapDescCreator,
tensormap_ptr_for_copy,
compute_expert_token_range,
rewrite_tensor_shape,
prefetch_tma_descriptor,
)
from blackwell.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo
@dataclass(frozen=True)
class MoESchedExtension(ABC):
"""
Abstract base class for MoE scheduler extensions.
Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA
descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.)
provides its own subclass with kernel-specific logic.
The extension:
- Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval
- Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors
- Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors
Subclasses are free to add any additional attributes in __init__ (scenario,
codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic
per tensor_name. No role-based templates or rigid patterns are imposed.
Usage in kernel (caller):
ext = ConcreteSchedExtension(tensormap_ctor, scenario=...)
while work_tile_info.is_valid_tile:
real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
# Use real_a, desc_a in cute.copy ...
"""
def __init__(self, tensormap_ctor: OnlineTensormapDescCreator):
super().__init__()
self.tensormap_ctor = tensormap_ctor
@abstractmethod
def get_gmem_tensor(
self,
tensor_name: str,
gmem_tensor_in_moe_view: cute.Tensor,
offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]],
work_tile_info: MoEWorkTileInfo,
) -> Tuple[cute.Tensor, "Pointer | None"]:
"""
Convert an MoE-view tensor to the real per-expert tensor for the
current work tile, and return the appropriate TMA descriptor pointer.
The MoE-view tensor uses "fake" GEMM domain dimensions that span all
experts (e.g., fake_m = tokens_sum). This method slices/offsets it
to the current expert's actual region.
:param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa")
:param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain
:param offs: Either a single cumsum tensor (experts,), or a tuple of
(offs_token, offs_padded) where offs_padded provides
padded offsets for scale-factor domain conversion.
:param work_tile_info: Current work tile from the scheduler
:return: (real_tensor, tma_desc_ptr_or_none)
- real_tensor: domain-offset and shape-rewritten tensor for this expert
- tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy),
or None if the caller should use the global TMA descriptor
"""
...
@abstractmethod
def prefetch_for_expert(self, expert_idx: Int32) -> None:
"""
Prefetch expert-wise TMA descriptors for the given expert.
Called when the scheduler advances to a new expert, allowing the TMA
descriptor cache to be warmed up before the descriptors are needed.
:param expert_idx: Index of the expert whose descriptors to prefetch
"""
...
# =============================================================================
# Grouped MM Extension
# =============================================================================
class GroupedMmSchedExtension(MoESchedExtension):
"""
MoE scheduler extension for grouped_mm: handles tensors a, b, c.
Domain conversion logic per scenario:
2Dx3D:
A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc
2Dx2D:
A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc
B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc
C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
tensormap_ctor: OnlineTensormapDescCreator,
):
super().__init__(tensormap_ctor)
self.scenario = scenario
@cute.jit
def get_gmem_tensor(
self,
tensor_name: str,
gmem_tensor_in_moe_view: cute.Tensor,
offs: cute.Tensor,
work_tile_info: MoEWorkTileInfo,
):
expert_idx = work_tile_info.expert_idx
token_offset, tokens_i = compute_expert_token_range(offs, expert_idx)
shape = gmem_tensor_in_moe_view.shape
c1 = cutlass.Int32(1)
if cutlass.const_expr(self.scenario == "2Dx3D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (fake_m, k, 1) -> offset fake_m, global desc
real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, k, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "c"):
# C: (fake_m, n, 1) -> expert-wise desc, no offset
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(tokens_i, shape[1], c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("c", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(self.scenario == "2Dx2D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (m, fake_k, 1) -> expert-wise desc, no offset
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("a", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, fake_k, 1) -> expert-wise desc, no offset
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("b", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "c"):
# C: (m, n, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
raise ValueError("Invalid scenario or GEMM tensor name.")
@cute.jit
def prefetch_for_expert(self, expert_idx: Int32) -> None:
if cutlass.const_expr(self.scenario == "2Dx3D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
elif cutlass.const_expr(self.scenario == "2Dx2D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
else:
raise ValueError("Invalid scenario.")
# =============================================================================
# Scaled Grouped MM Extension (block-scaled MoE)
# =============================================================================
class ScaledGroupedMmSchedExtension(MoESchedExtension):
"""
MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb.
Extends GroupedMmSchedExtension with scale-factor tensor support.
SFA/SFB are passed as flat GEMM-domain tensors and atom-tiled per expert
via tile_atom_to_shape_SF.
The offs parameter is always a tuple (offs_token, offs_padded):
- offs_token: cumsum offsets in data (activation) domain
- offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity)
sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size.
Domain conversion logic per scenario:
2Dx3D:
A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
C: (fake_m, n, 1) -> rewrite shape, expert-wise desc
SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset,
atom-tile, global desc
SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx,
atom-tile, global desc
2Dx2D:
A: (m, fake_k, 1) -> rewrite shape, expert-wise desc
B: (n, fake_k, 1) -> rewrite shape, expert-wise desc
C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset,
atom-tile, expert-wise desc
SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset,
atom-tile, expert-wise desc
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
tensormap_ctor: OnlineTensormapDescCreator,
):
super().__init__(tensormap_ctor)
self.scenario = scenario
@cute.jit
def get_gmem_tensor(
self,
tensor_name: str,
gmem_tensor_in_moe_view: cute.Tensor,
offs: Tuple[cute.Tensor, cute.Tensor],
work_tile_info: MoEWorkTileInfo,
):
# Unpack the offs tuple
offs_token, offs_padded = offs
expert_idx = work_tile_info.expert_idx
token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx)
padded_offset, padded_size_i = compute_expert_token_range(
offs_padded, expert_idx
)
shape = gmem_tensor_in_moe_view.shape
stride = gmem_tensor_in_moe_view.stride
c1 = cutlass.Int32(1)
sf_vec_size = self.tensormap_ctor.sf_vec_size
if cutlass.const_expr(self.scenario == "2Dx3D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (fake_m, k, 1) -> offset fake_m, global desc
real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, k, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "c"):
# C: (fake_m, n, 1) -> expert-wise desc
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(tokens_i, shape[1], c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("c", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "sfa"):
# SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc
real = cute.domain_offset(
(padded_offset, 0, 0), gmem_tensor_in_moe_view
)
per_expert_shape = (padded_size_i, shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = cute.make_tensor(
real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
)
return (real, None)
elif cutlass.const_expr(tensor_name == "sfb"):
# SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = cute.make_tensor(
real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
)
return (real, None)
elif cutlass.const_expr(self.scenario == "2Dx2D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (m, fake_k, 1) -> expert-wise desc
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("a", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, fake_k, 1) -> expert-wise desc
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("b", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "c"):
# C: (m, n, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "sfa"):
# SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc
per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "sfb"):
# SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc
per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)
)
return (real, desc)
raise ValueError("Invalid scenario or tensor name.")
@cute.jit
def prefetch_for_expert(self, expert_idx: Int32) -> None:
if cutlass.const_expr(self.scenario == "2Dx3D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
elif cutlass.const_expr(self.scenario == "2Dx2D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx))
else:
raise ValueError("Invalid scenario.")
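# Hedged usage sketch for an executor warp, following the MoESchedExtension docstring;
# `tensormap_ctor`, `tma_tensor_a`, `tma_tensor_c`, `offs`, and `work` come from the
# kernel and are placeholders here. For the 2Dx3D scenario the A descriptor stays
# global (desc is None) while C uses an expert-wise descriptor.
#
#   ext = GroupedMmSchedExtension("2Dx3D", tensormap_ctor)
#   real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work)   # desc_a is None
#   real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work)   # expert-wise desc
#   ext.prefetch_for_expert(work.expert_idx)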


@@ -0,0 +1,910 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Online TMA Descriptor Construction Utilities.
Provides utilities for dynamically creating TMA descriptors at kernel runtime
based on runtime-provided information (problem sizes, pointers, etc.).
Key components:
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.)
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
- compute_expert_token_range: Compute per-expert token offset and count from offs
- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility
"""
from abc import ABC, abstractmethod
from typing import Optional, Literal, Tuple, Union
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import AddressSpace, Pointer
from cutlass.cute.nvgpu import cpasync
from cutlass.cutlass_dsl import dsl_user_op, Int32
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
TensormapDescBytes = 128
# =============================================================================
# Pointer Utilities
# =============================================================================
@dsl_user_op
@cute.jit
def spin_wait(
ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
) -> None:
"""
Generic spin-wait.
Example usage:
# Wait until counter >= total_blocks
spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)
# Wait until flag == 1
spin_wait(flag_ptr, lambda x: x == 1)
"""
current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
while not condition(current):
# Load with L1 cache bypass (ld.global.cg)
if cutlass.const_expr(fail_sleep_cycles > 0):
cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
@dsl_user_op
def gmem_ptr_to_generic(
gmem_ptr: Pointer,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
if gmem_ptr.memspace != AddressSpace.gmem:
raise ValueError(
f"gmem_ptr_to_generic requires pointer in gmem address space, "
f"got {gmem_ptr.memspace}"
)
# Get LLVM pointer and cast to generic address space
llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip)
generic_llvm_ptr = llvm.addrspacecast(
llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip
)
# Create a new cute.Pointer with generic address space, preserving alignment
return cute.make_ptr(
gmem_ptr.dtype,
generic_llvm_ptr,
AddressSpace.generic,
assumed_align=gmem_ptr.alignment,
loc=loc,
ip=ip,
)
@dsl_user_op
def generic_ptr_to_gmem(
generic_ptr: Pointer,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
if generic_ptr.memspace != AddressSpace.generic:
raise ValueError(
f"generic_ptr_to_gmem requires pointer in generic address space, "
f"got {generic_ptr.memspace}"
)
# Get LLVM pointer and cast to gmem address space
llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip)
gmem_llvm_ptr = llvm.addrspacecast(
llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip
)
# Create a new cute.Pointer with gmem address space, preserving alignment
return cute.make_ptr(
generic_ptr.dtype,
gmem_llvm_ptr,
AddressSpace.gmem,
assumed_align=generic_ptr.alignment,
loc=loc,
ip=ip,
)
@dsl_user_op
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
"""
Prefetch a TMA descriptor from global memory.
This function prefetches the TMA descriptor pointed to by tma_desc_ptr
into the TMA descriptor cache. The pointer must be in generic or global
address space. If a gmem pointer is passed, it will be automatically
converted to generic address space.
:param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
:type tma_desc_ptr: Pointer
:raises ValueError: If pointer is not in generic or global address space
"""
if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic):
raise ValueError(
f"prefetch_tma_descriptor requires pointer in gmem or generic address space, "
f"got {tma_desc_ptr.memspace}"
)
# Convert gmem pointer to generic if needed
if tma_desc_ptr.memspace == AddressSpace.gmem:
tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)
# Convert cute.Pointer to LLVM pointer for prefetch
llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)
from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch
nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip)
def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer:
"""Offset a pointer by a given number of bytes."""
element_offset = byte_offset * 8 // ptr.dtype.width
return ptr + element_offset
@dsl_user_op
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
"""
Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy.
cute.copy requires the tma_desc_ptr to be in generic address space and
recast to TmaDescriptorTiledType. This utility performs both conversions.
:param raw_ptr: Raw pointer to TMA descriptor in gmem
:type raw_ptr: Pointer
:return: Pointer compatible with cute.copy's tma_desc_ptr parameter
:rtype: Pointer
"""
generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
tma_desc_ptr_ty = _cute_ir.PtrType.get(
_cute_nvgpu_ir.TmaDescriptorTiledType.get(),
generic_ptr.memspace,
generic_ptr.alignment,
)
return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value)
# =============================================================================
# MoE Utilities
# =============================================================================
@dsl_user_op
@cute.jit
def compute_expert_token_range(
offs: cute.Tensor,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Int32, Int32]:
"""
Compute token offset and count for a given expert from the cumsum offs tensor.
:param offs: Cumulative sum tensor of token counts per expert, shape (experts,)
:param expert_idx: Index of the expert
:return: (token_offset, tokens_i) where token_offset is the start position
and tokens_i is the number of tokens for this expert
"""
token_offset = Int32(0)
if expert_idx > Int32(0):
token_offset = offs[expert_idx - 1] # type: ignore[assignment]
tokens_i = offs[expert_idx] - token_offset
return token_offset, tokens_i
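# Worked example: with offs = [3, 7, 12] (cumulative token counts over 3 experts),
# expert 0 -> (token_offset=0, tokens_i=3), expert 1 -> (3, 4), expert 2 -> (7, 5).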
@dsl_user_op
def rewrite_tensor_shape(
tensor: cute.Tensor,
new_shape: Tuple,
*,
loc=None,
ip=None,
) -> cute.Tensor:
"""
Rewrite tensor shape while keeping the same stride and iterator.
This is primarily for debug friendliness - shows the actual expert's shape
instead of the fake global shape. No runtime overhead as it becomes
dead code in non-debug builds.
:param tensor: Source tensor whose stride and iterator to preserve
:param new_shape: New shape to apply
:return: New tensor with the given shape but original stride and iterator
"""
new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)
# =============================================================================
# TMA Descriptor Workspace Helper
# =============================================================================
class TensormapWorkspace:
"""
Helper for linear workspace layout of TMA descriptors.
Manages address calculation for a workspace buffer containing TMA descriptors
organized as: for each executor (e.g., expert or group), a fixed set of
named descriptor slots.
Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]
Example:
# 2Dx3D MoE: only C is expert-wise
workspace = TensormapWorkspace(workspace_ptr, ["c"])
# 2Dx2D MoE: A and B are expert-wise
workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
# General grouped GEMM: all three tensors
workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
"""
def __init__(self, workspace_ptr: Pointer, slot_names: list):
"""
:param workspace_ptr: Pointer to the beginning of the workspace buffer
:param slot_names: Ordered list of tensor names, defining the slot layout
per executor. e.g., ["a", "b", "c"]
"""
self.workspace_ptr = workspace_ptr
self._name_to_slot = {name: i for i, name in enumerate(slot_names)}
self._slots_per_executor = len(slot_names)
@cute.jit
def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
"""
Get the workspace pointer for a specific TMA descriptor.
:param tensor_name: Name of the tensor (must be one of the slot_names)
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
:return: Aligned pointer to the TMA descriptor in workspace
"""
if cutlass.const_expr(tensor_name not in self._name_to_slot):
raise ValueError(
f"Invalid tensor_name '{tensor_name}', "
f"expected one of {list(self._name_to_slot.keys())}"
)
slot = self._name_to_slot[tensor_name]
byte_offset = (
executor_idx * self._slots_per_executor + slot
) * TensormapDescBytes
return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
TensormapDescBytes
)
@staticmethod
def size_bytes(num_slots: int, num_executors: int) -> int:
"""
Calculate workspace size in bytes.
:param num_slots: Number of descriptor slots per executor
:param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
:return: Total workspace size in bytes
"""
return num_slots * num_executors * TensormapDescBytes
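# Worked example: a 2Dx2D MoE with 64 experts keeps expert-wise A and B descriptors,
# so TensormapWorkspace.size_bytes(num_slots=2, num_executors=64) = 2 * 64 * 128 = 16384 bytes.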
# =============================================================================
# Online TMA Descriptor Creator (Abstract Base Class)
# =============================================================================
@dataclass(frozen=True)
class OnlineTensormapDescCreator(ABC):
"""
Abstract base class for building TMA descriptors online (at kernel runtime).
Subclasses store all needed parameters (both codegen-time configs and runtime
values) as explicit instance attributes in __init__. No dict-based APIs.
Subclasses must implement exactly 2 abstract methods:
- construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
- get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace
To convert the raw pointer for use with cute.copy, callers should use the
standalone tensormap_ptr_for_copy() utility.
"""
@abstractmethod
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
"""
Build TMA descriptor(s) for one executor and write to workspace.
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
Semantics may vary by subclass when ``dependency`` is provided.
:param dependency: Optional pipeline consumer for inter-warp-group
synchronization. When provided, the subclass decides when to wait
(via ``dependency.wait_and_advance()``) and release. The subclass
also decides how to interpret ``executor_idx`` in this mode.
"""
...
@abstractmethod
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
"""
Get the raw gmem pointer to a specific TMA descriptor in workspace.
:param tensor_name: Name identifying which tensor's descriptor
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
:return: Raw pointer (gmem) to the TMA descriptor
"""
...
# =============================================================================
# MoE Grouped GEMM Tensormap Constructor
# =============================================================================
class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
"""
Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).
Non-expert-wise descriptors are passed directly at kernel launch.
This class only handles:
- 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
- 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis)
All parameters are stored as explicit instance attributes (no dicts).
Workspace layout:
- 2Dx3D: [C_0, C_1, ..., C_{n-1}]
- 2Dx2D: [A_0, A_1, ..., A_{n-1}, B_0, B_1, ..., B_{n-1}]
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
# Codegen-time configs
a_dtype,
b_dtype,
c_dtype,
a_smem_layout,
b_smem_layout,
epi_smem_layout,
a_tma_op,
b_tma_op,
c_tma_op,
tiled_mma,
mma_tiler,
cluster_layout_vmnk_shape,
epi_tile,
# Runtime params
a_tensor: cute.Tensor, # fake GEMM domain A
b_tensor: cute.Tensor, # fake GEMM domain B
c_tensor: cute.Tensor, # fake GEMM domain C
offs: cute.Tensor, # (experts,) cumsum
workspace_ptr: Pointer,
) -> None:
super().__init__()
self.scenario = scenario
# Codegen-time configs
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.a_smem_layout = a_smem_layout
self.b_smem_layout = b_smem_layout
self.epi_smem_layout = epi_smem_layout
self.a_tma_op = a_tma_op
self.b_tma_op = b_tma_op
self.c_tma_op = c_tma_op
self.tiled_mma = tiled_mma
self.mma_tiler = mma_tiler
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
self.epi_tile = epi_tile
# Runtime params
self.a_tensor = a_tensor
self.b_tensor = b_tensor
self.c_tensor = c_tensor
self.offs = offs
# Workspace with scenario-specific slot layout
if scenario == "2Dx3D":
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
else:
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
@staticmethod
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
"""Calculate workspace size in bytes for tensormap descriptors."""
if scenario == "2Dx3D":
return TensormapWorkspace.size_bytes(1, expert_cnt) # only C
else:
return TensormapWorkspace.size_bytes(2, expert_cnt) # A and B
@cute.jit
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
return self.workspace.get_ptr(tensor_name, executor_idx)
@cute.jit
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
"""
Create expert-wise tensormap descriptors for the given expert.
- 2Dx3D: Creates C descriptor for this expert
- 2Dx2D: Creates A and B descriptors for this expert
"""
if cutlass.const_expr(self.scenario == "2Dx3D"):
self._construct_c_desc_2dx3d(executor_idx)
else: # 2Dx2D
self._construct_ab_descs_2dx2d(executor_idx)
@cute.jit
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
"""
2Dx3D: Create expert-wise C descriptor.
C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
Slice fake_m -> (tokens_i, intermediate, 1) per expert.
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c_ptr = self.c_tensor.iterator
c_stride = self.c_tensor.stride
intermediate = self.c_tensor.shape[1] # type: ignore[index]
c1 = cutlass.Int32(1)
c0 = cutlass.Int32(0)
c_ptr_i = c_ptr + token_offset * c_stride[0] # type: ignore[index]
c_layout_i = cute.make_layout(
(tokens_i, intermediate, c1),
stride=(c_stride[0], c_stride[1], c0), # type: ignore[index]
)
c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
self.c_tma_op,
c_tensor_i,
self.epi_smem_layout,
self.epi_tile,
)
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
@cute.jit
def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
"""
2Dx2D: Create expert-wise A and B descriptors.
A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c1 = cutlass.Int32(1)
c0 = cutlass.Int32(0)
# A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
a_ptr = self.a_tensor.iterator
a_stride = self.a_tensor.stride
a_m = self.a_tensor.shape[0] # type: ignore[index]
a_ptr_i = a_ptr + token_offset * a_stride[1] # type: ignore[index]
a_layout_i = cute.make_layout(
(a_m, tokens_i, c1),
stride=(a_stride[0], a_stride[1], c0), # type: ignore[index]
)
a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)
tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
self.a_tma_op,
a_tensor_i,
self.a_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
# B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
b_ptr = self.b_tensor.iterator
b_stride = self.b_tensor.stride
b_n = self.b_tensor.shape[0] # type: ignore[index]
b_ptr_i = b_ptr + token_offset * b_stride[1] # type: ignore[index]
b_layout_i = cute.make_layout(
(b_n, tokens_i, c1),
stride=(b_stride[0], b_stride[1], c0), # type: ignore[index]
)
b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)
tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
self.b_tma_op,
b_tensor_i,
self.b_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
# =============================================================================
# MoE Scaled Grouped GEMM Tensormap Constructor
# =============================================================================
class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
"""
Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).
.. py:attribute:: ChunkSize
:value: 128
Number of experts processed per chunk in the desc_init_kernel.
Must match the warp-group width (4 warps × 32 threads).
Extends MoEGroupedGemmTensormapConstructor with SFA/SFB descriptor support.
Expert-wise descriptors only — non-expert-wise descriptors are passed
directly at kernel launch.
Workspace layout:
- 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert)
- 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert)
:param scenario: "2Dx3D" or "2Dx2D"
:param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
:param a_dtype: Data type for tensor A
:param b_dtype: Data type for tensor B
:param c_dtype: Data type for tensor C
:param sf_dtype: Data type for scale factors (SFA/SFB)
:param a_smem_layout: SMEM layout for A TMA
:param b_smem_layout: SMEM layout for B TMA
:param epi_smem_layout: SMEM layout for epilogue (C) TMA
:param sfa_smem_layout: SMEM layout for SFA TMA
:param sfb_smem_layout: SMEM layout for SFB TMA
:param a_tma_op: TMA operation for A
:param b_tma_op: TMA operation for B
:param c_tma_op: TMA operation for C (S2G store or reduce)
:param sfa_tma_op: TMA operation for SFA
:param sfb_tma_op: TMA operation for SFB
:param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
:param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
:param mma_tiler: MMA tiler shape (M, N, K)
:param mma_tiler_sfb: MMA tiler shape for SFB
:param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
:param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
:param epi_tile: Epilogue tile shape
:param a_tensor: Fake GEMM domain A tensor
:param b_tensor: Fake GEMM domain B tensor
:param c_tensor: Fake GEMM domain C tensor
:param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
:param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
:param offs: (experts,) cumsum offsets in data domain
:param offs_padded: (experts,) cumsum offsets in padded scale domain
:param workspace_ptr: Pointer to workspace for TMA descriptors
:param expert_cnt: Total number of experts
"""
ChunkSize = 128
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
sf_vec_size: int,
# Codegen-time configs: dtypes
a_dtype,
b_dtype,
c_dtype,
sf_dtype,
# Codegen-time configs: SMEM layouts
a_smem_layout,
b_smem_layout,
epi_smem_layout,
sfa_smem_layout,
sfb_smem_layout,
# Codegen-time configs: TMA ops
a_tma_op,
b_tma_op,
c_tma_op,
sfa_tma_op,
sfb_tma_op,
# Codegen-time configs: MMA / cluster / tile
tiled_mma,
tiled_mma_sfb,
mma_tiler,
mma_tiler_sfb,
cluster_layout_vmnk_shape,
cluster_layout_sfb_vmnk_shape,
epi_tile,
# Runtime params
a_tensor: cute.Tensor,
b_tensor: cute.Tensor,
c_tensor: cute.Tensor,
sfa_tensor: cute.Tensor,
sfb_tensor: cute.Tensor,
offs: cute.Tensor,
offs_padded: cute.Tensor,
workspace_ptr: Pointer,
expert_cnt: Optional[Union[Int32, int]] = None,
) -> None:
super().__init__()
self.scenario = scenario
self.sf_vec_size = sf_vec_size
# Dtypes
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.sf_dtype = sf_dtype
# SMEM layouts
self.a_smem_layout = a_smem_layout
self.b_smem_layout = b_smem_layout
self.epi_smem_layout = epi_smem_layout
self.sfa_smem_layout = sfa_smem_layout
self.sfb_smem_layout = sfb_smem_layout
# TMA ops
self.a_tma_op = a_tma_op
self.b_tma_op = b_tma_op
self.c_tma_op = c_tma_op
self.sfa_tma_op = sfa_tma_op
self.sfb_tma_op = sfb_tma_op
# MMA / cluster / tile
self.tiled_mma = tiled_mma
self.tiled_mma_sfb = tiled_mma_sfb
self.mma_tiler = mma_tiler
self.mma_tiler_sfb = mma_tiler_sfb
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
self.epi_tile = epi_tile
# Runtime params
self.a_tensor = a_tensor
self.b_tensor = b_tensor
self.c_tensor = c_tensor
self.sfa_tensor = sfa_tensor
self.sfb_tensor = sfb_tensor
self.offs = offs
self.offs_padded = offs_padded
self.expert_cnt = expert_cnt
# Workspace with scenario-specific slot layout
if scenario == "2Dx3D":
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
else:
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])
@staticmethod
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
"""Calculate workspace size in bytes for tensormap descriptors."""
if scenario == "2Dx3D":
return TensormapWorkspace.size_bytes(1, expert_cnt) # C only
else:
return TensormapWorkspace.size_bytes(4, expert_cnt) # A, B, SFA, SFB
@cute.jit
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
return self.workspace.get_ptr(tensor_name, executor_idx)
@cute.jit
def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
"""
Create expert-wise tensormap descriptors for all experts.
``lane_in_group`` is the thread's position within its warp group
(0..ChunkSize-1). The method loops internally over all experts in
chunks of ``ChunkSize``, with two-phase pipeline synchronization
per chunk.
Per-chunk execution:
1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
(A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum.
2. Barrier: ``consumer.wait_and_advance()`` — all threads participate.
3. Phase 2: Build descriptors that depend on ``offs_padded``
(SFA/SFB for 2Dx2D). Reads padded offsets from SMEM buffer.
4. Release: ``handle.release()`` — all threads participate.
:param lane_in_group: Thread's position within the warp group (0..127).
:param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the
consumer for mbarrier sync, and the SMEM tensor of shape
``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
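Example (illustrative): with ``expert_cnt = 200`` and ``ChunkSize = 128``,
``num_chunks = 2``; lane 5 builds descriptors for expert 5 in chunk 0 and
expert 133 in chunk 1. Lanes whose ``expert_idx`` exceeds 199 skip the
builds but still participate in the barrier wait and release.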
"""
consumer, smem_offs_padded = dependency
assert self.expert_cnt is not None
num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize
chunk_idx = cutlass.Int32(0)
while chunk_idx < num_chunks:
expert_idx = chunk_idx * self.ChunkSize + lane_in_group
in_bounds = expert_idx < self.expert_cnt
# Phase 1: non-dependent descriptors
if in_bounds:
if cutlass.const_expr(self.scenario == "2Dx2D"):
self._construct_ab_descs_2dx2d(expert_idx)
else:
self._construct_c_desc_2dx3d(expert_idx)
# All threads participate in barrier (fixed arrive count)
handle = consumer.wait_and_advance()
# Phase 2: dependent descriptors (read padded offsets from SMEM)
if in_bounds:
if cutlass.const_expr(self.scenario == "2Dx2D"):
# smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
# padded_offset = smem[lane] (prev expert's cumulative)
# padded_end = smem[lane + 1] (this expert's cumulative)
padded_offset = smem_offs_padded[lane_in_group]
padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
self._construct_sf_descs_2dx2d_direct(
expert_idx, padded_offset, padded_size_i
)
# All threads release (fixed arrive count)
handle.release()
chunk_idx += 1
# -----------------------------------------------------------------
# 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor)
# -----------------------------------------------------------------
@cute.jit
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
"""
2Dx3D: Create expert-wise C descriptor.
C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c1 = cutlass.Int32(1)
c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1)) # type: ignore[index]
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
self.c_tma_op,
c_i,
self.epi_smem_layout,
self.epi_tile,
)
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
# -----------------------------------------------------------------
# 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor)
# -----------------------------------------------------------------
@cute.jit
def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
"""
2Dx2D: Create expert-wise A and B descriptors.
A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c1 = cutlass.Int32(1)
# A: (m, fake_k, 1) -> domain_offset + rewrite shape
a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1)) # type: ignore[index]
tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
self.a_tma_op,
a_i,
self.a_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
# B: (n, fake_k, 1) -> domain_offset + rewrite shape
b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1)) # type: ignore[index]
tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
self.b_tma_op,
b_i,
self.b_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
# -----------------------------------------------------------------
# 2Dx2D: SFA, SFB descriptors (new for block-scaled)
# -----------------------------------------------------------------
@cute.jit
def _construct_sf_descs_2dx2d_direct(
self,
expert_idx: Int32,
padded_offset: Int32,
padded_size_i: Int32,
) -> None:
"""
2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed
padded offset and size.
This variant allows the caller to supply padded offsets from SMEM
(in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM.
"""
c1 = cutlass.Int32(1)
a_chunks_to_move = (
padded_offset
// self.sf_vec_size
* cute.size(self.sfa_tensor, mode=[0])
// 128
)
a_elems_to_move = (
cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
)
b_chunks_to_move = (
padded_offset
// self.sf_vec_size
* cute.size(self.sfb_tensor, mode=[0])
// 128
)
b_elems_to_move = (
cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
)
per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1) # type: ignore[index]
sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
sfa_i = cute.make_tensor(
self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
)
tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
self.sfa_tma_op,
sfa_i,
self.sfa_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
internal_type=cutlass.Uint64,
)
cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))
per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1) # type: ignore[index]
sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
sfb_i = cute.make_tensor(
self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
)
tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
self.sfb_tma_op,
sfb_i,
self.sfb_smem_layout,
self.mma_tiler_sfb,
self.tiled_mma_sfb,
self.cluster_layout_sfb_vmnk_shape,
internal_type=cutlass.Uint64,
)
cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -47,11 +47,11 @@ from cutlass.cute.runtime import make_ptr
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
-examples_dir = os.path.join(current_dir, "..", "..")
+examples_dir = os.path.join(current_dir, "..", "..", "..", "..")
if examples_dir not in sys.path:
sys.path.insert(0, examples_dir)
-from blackwell.tutorial_gemm.utils import create_parser, run
+from cute.blackwell.tutorial.tutorial_gemm.utils import create_parser, run
mma_tiler_mn = (128, 256)
mma_inst_shape_k = 64

View File

@@ -49,11 +49,11 @@ from cutlass.cute.runtime import from_dlpack, make_ptr
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
-examples_dir = os.path.join(current_dir, "..", "..")
+examples_dir = os.path.join(current_dir, "..", "..", "..", "..")
if examples_dir not in sys.path:
sys.path.insert(0, examples_dir)
-from blackwell.tutorial_gemm.utils import create_parser, run
+from cute.blackwell.tutorial.tutorial_gemm.utils import create_parser, run
mma_tiler_mn = (256, 256)
mma_inst_shape_k = 64

View File

@@ -0,0 +1,385 @@
# CUTLASS Tutorial Examples for Blackwell TMA
## TMA V0: Understanding tma_partition - The Foundation of TMA Operations
This example demonstrates the fundamental building blocks of Tensor Memory Accelerator (TMA) operations on NVIDIA Blackwell (SM100) architecture using CuTe DSL. It focuses on understanding the `tma_partition` interface, which is essential for all TMA operations.
### Key Concepts
* **TMA Load (Global → Shared)**: Asynchronous bulk data transfer from Global Memory to Shared Memory using TMA hardware
* **TMA Store (Shared → Global)**: Asynchronous bulk data transfer from Shared Memory back to Global Memory
* **tma_partition**: Core interface that prepares tensors for TMA operations by partitioning them according to TMA atom layout requirements
* **group_modes**: Tensor mode grouping to define the TMA atom shape - crucial for proper tma_partition usage
* **mbarrier Synchronization**: Hardware barriers (`mbarrier_init`, `mbarrier_arrive`, `mbarrier_wait`) for synchronizing asynchronous TMA operations
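The mbarrier pattern recurs throughout these tutorials. A minimal sketch of the kernel-side sequence is shown below; it mirrors the code in `tma_v0.py` and assumes `barrier_ptr`, `num_tma_load_bytes`, the TMA atom, and the partitioned views are already set up as in that file:
```python
# Initialization: one elected thread sets up the barrier, all threads sync after.
with cute.arch.elect_one():
    cute.arch.mbarrier_init(barrier_ptr, 1)                        # 1 expected arrival
    cute.arch.mbarrier_expect_tx(barrier_ptr, num_tma_load_bytes)  # expected TMA bytes
cute.arch.mbarrier_init_fence()
cute.arch.barrier()

# Load: issue the TMA copy, arrive as producer, then wait for phase 0 to flip.
cute.copy(tma_atom_src, tAgA_cta, tAsA, tma_bar_ptr=barrier_ptr)
with cute.arch.elect_one():
    cute.arch.mbarrier_arrive(barrier_ptr)
cute.arch.mbarrier_wait(barrier_ptr, 0)
```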
### Kernel Architecture
The `Sm100SimpleCopyKernel` performs a simple tile-based copy operation to illustrate TMA fundamentals:
#### Configuration
* **tile_shape**: Fixed at (128, 128) - Tile dimensions (M, N)
* **cluster_shape_mn**: Fixed at (1, 1) - Single CTA execution (no cluster parallelism)
* **Shared Memory**: Single buffer sized to hold one tile (tile_m × tile_n elements)
* **Synchronization**: Single mbarrier for TMA load completion
#### Key Components
1. **TMA Descriptor Creation**: Creates TMA atoms (`tma_atom_src`, `tma_atom_dst`) that encapsulate TMA hardware instructions
2. **Shared Memory Layout**: Row-major layout `(tile_m, tile_n):(tile_n, 1)` for simplicity
3. **Barrier Management**: Single barrier coordinates TMA load completion before processing
### Execution Flow
1. **Initialization**:
* Allocate shared memory buffer for one tile
* Initialize mbarrier with `elect_one()` to ensure proper synchronization semantics
* Set barrier to expect TMA transaction bytes (`tile_m × tile_n × element_size`)
* Synchronize all threads after barrier initialization
2. **Tensor Preparation**:
```
Tile global tensors into (tile_m, tile_n) blocks
Apply group_modes to combine tile dimensions into Mode 0 (TMA atom)
Example: gSrc_tiled with shape (128, 128, 4, 2)
After group_modes(_, 0, 2): ((128, 128), 4, 2)
└───┬────┘ └─┬─┘
Mode 0 Rest modes
```
3. **TMA Partition**:
```python
tAsA, tAgA = tma_partition(
tma_atom_src, # TMA operation atom
cta_id=0, # CTA ID within cluster
cta_layout, # Cluster layout
group_modes(smem_tensor, 0, 2), # SMEM view (Mode 0 = atom)
group_modes(gSrc_tiled, 0, 2) # Global view (Mode 0 = atom)
)
# Returns:
# tAsA: SMEM view with TMA internal layout
# tAgA: Global view with shape ((TMA_Layout), grid_m, grid_n)
```
4. **Tile Selection**:
```python
# Select specific tile for this CTA
tAgA_cta = tAgA[(None, bidx, bidy)]
# None: keep entire TMA atom (Mode 0)
# bidx, bidy: index into grid dimensions
```
5. **TMA Load** (Global → Shared):
```python
cute.copy(tma_atom_src, tAgA_cta, tAsA, tma_bar_ptr=barrier_ptr)
# Arrive on barrier (producer signals completion)
# Wait on barrier (all threads wait for TMA completion)
```
6. **TMA Store** (Shared → Global):
```python
cute.copy(tma_atom_dst, tAsA, tBgB_cta)
# Synchronous store completes before kernel exit
```
### Understanding tma_partition
The `tma_partition` function is the key to TMA operations. The tutorial includes comprehensive inline diagrams showing:
* **Input Tensor Shapes**: How tensors are organized before partitioning
* **group_modes Effect**: How mode grouping creates the required atom structure
* **Partition Output**: The structure of partitioned tensors for TMA operations
* **Indexing Pattern**: How to select CTA-specific tiles from partitioned views
See lines ~155-255 in `tma_v0.py` for detailed visual explanations.
### Configuration Parameters
* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N)
* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution
* `threads_per_cta`: 32 - Single warp (all threads participate in barriers)
* `buffer_align_bytes`: 1024 - Shared memory alignment
### Usage
Run the copy kernel with custom matrix dimensions:
```bash
# Basic usage (default: 512×128 matrix)
python tma_v0.py
# Custom dimensions
python tma_v0.py --M 1024 --N 2048
# With custom benchmark iterations
python tma_v0.py --M 4096 --N 4096 --num_warmup 10 --num_iters 50
```
#### Example Code
```python
from tma_v0 import run_tma_copy
# Run copy on 1024×2048 matrix
run_tma_copy(M=1024, N=2048)
# Output: Performance metrics and verification result
```
### Performance Considerations
* **Single-stage operation**: No pipelining - simple sequential load → store pattern
* **Educational focus**: Designed for understanding TMA fundamentals, not peak performance
* **Barrier synchronization**: Demonstrates proper mbarrier usage patterns
* **Foundation for V1/V2**: Concepts learned here are essential for understanding multi-stage pipelines and warp specialization in V1 and V2
## TMA V1: Matrix Transpose with Producer-Consumer Pattern
This example demonstrates a TMA-based matrix transpose using producer-consumer synchronization with mbarriers on NVIDIA Blackwell (SM100) architecture.
### Key Concepts
* **Producer-Consumer Pattern**: Different warps coordinate through mbarrier synchronization
* **TMA Operations**: Asynchronous bulk data transfer between Global and Shared Memory
* **Shared Memory Swizzle**: Optimized layouts to avoid bank conflicts during transpose
* **Warp Specialization**: Dedicated warps for loading, transposing, and storing
* **mbarrier Synchronization**: Hardware barriers coordinate asynchronous operations
### Kernel Architecture
The `Sm100MatrixTransposeKernelV1` performs a tiled matrix transpose (M×N → N×M) with the following design:
#### Warp Roles
1. **TMA Load Warp** (Warp 4, Producer): Issues TMA load operations from Global Memory to Shared Memory buffer `sA`
2. **Transpose Warps** (Warps 0-3, 4 warps, Consumer/Producer):
* Wait for `sA` to be filled by TMA load (consumer of load_mbar)
* Transpose data from `sA` → Registers → `sB`
* Signal completion to TMA Store warp (producer for store_mbar)
3. **TMA Store Warp** (Warp 5, Consumer):
* Wait for `sB` to be ready (consumer of store_mbar)
* Issue TMA store operations from `sB` to Global Memory
#### Synchronization Barriers
The kernel uses two mbarrier instances for producer-consumer coordination:
1. **load_mbar_ptr**: Synchronizes TMA Load → Transpose Warps
* Producer: TMA Load Warp (arrives after TMA completes)
* Consumer: Transpose Warps (wait before reading `sA`)
* Expected arrivals: 1 (from TMA load warp)
* Expected transactions: `tile_m × tile_n × element_size` bytes
2. **store_mbar_ptr**: Synchronizes Transpose Warps → TMA Store
* Producer: Transpose Warps (each warp arrives after writing to `sB`)
* Consumer: TMA Store Warp (waits before issuing TMA store)
* Expected arrivals: 4 (one from each transpose warp)
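A condensed sketch of the barrier setup (mirroring the kernel code in `tma_v1.py`; names follow that file):
```python
# One thread initializes both barriers; every warp then syncs on the CTA barrier.
if tidx == 0:
    # Load barrier: 1 producer arrival (TMA load warp) plus the expected TMA bytes
    cute.arch.mbarrier_init(load_mbar_ptr, 1)
    cute.arch.mbarrier_expect_tx(load_mbar_ptr, num_tma_load_bytes)
    # Store barrier: 4 producer arrivals, one per transpose warp
    cute.arch.mbarrier_init(store_mbar_ptr, 4)
cute.arch.mbarrier_init_fence()
cute.arch.barrier()
```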
#### Execution Flow
1. **Initialization**:
* All warps participate in barrier initialization (thread 0 initializes)
* Allocate two shared memory buffers: `sA` (row-major) and `sB` (column-major)
* Create TMA descriptors for source and transposed destination
* Initialize `load_mbar` with expected count of 1 and transaction bytes
* Initialize `store_mbar` with expected count of 4 (number of transpose warps)
2. **TMA Load Warp** (Producer for Load Pipeline):
```
partition source tensor by tile shape
issue TMA load: Global[block_tile] → sA
arrive on load_mbar to signal completion
```
3. **Transpose Warps** (Consumer for Load, Producer for Store):
```
wait on load_mbar for TMA load to complete
partition sA for reading (each thread handles subset)
copy data: sA → Registers
partition sB for writing
copy data: Registers → sB (transpose happens via layout)
fence to ensure smem writes are visible
synchronize with trans_sync_barrier
[elect one thread] arrive on store_mbar to signal completion
```
4. **TMA Store Warp** (Consumer for Store Pipeline):
```
wait on store_mbar for transpose to complete
partition destination tensor by tile shape
issue TMA store: sB → Global[block_tile]
```
### Key Features
* **Simple Producer-Consumer Model**: Clear separation of concerns with dedicated warps
* **Efficient Synchronization**: Hardware mbarriers minimize synchronization overhead
* **Memory Layout Optimization**: Swizzled layouts prevent bank conflicts
* **Transposition via Layouts**: Transpose is achieved through different memory layouts for `sA` and `sB`
### Configuration Parameters
* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N)
* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution
* Warp count: 6 warps (1 TMA Load + 4 Transpose + 1 TMA Store)
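The warp-role assignment is fixed in the kernel constructor; a condensed, self-contained sketch of that mapping (mirroring `tma_v1.py`):
```python
# Warp-role mapping used by the V1 kernel (sketch of the __init__ logic).
num_trans_warps = 4
trans_warp_id = tuple(range(num_trans_warps))   # warps 0-3: transpose
tma_load_warp_id = num_trans_warps              # warp 4: TMA load
tma_store_warp_id = num_trans_warps + 1         # warp 5: TMA store
threads_per_cta = 32 * (num_trans_warps + 2)    # 192 threads = 6 warps
print(trans_warp_id, tma_load_warp_id, tma_store_warp_id, threads_per_cta)
```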
### Usage
Run the transpose kernel with custom matrix dimensions:
```bash
# Basic usage (128×128 matrix)
python tma_v1.py
# Custom dimensions
python tma_v1.py --M 1024 --N 2048
# With custom benchmark iterations
python tma_v1.py --M 4096 --N 4096 --num_warmup 10 --num_iters 50
```
#### Example Code
```python
from tma_v1 import run_transpose
# Run transpose on 1024×2048 matrix
run_transpose(M=1024, N=2048)
# Output: Performance metrics and verification result
```
### Performance Considerations
* **Single-stage pipeline**: Simpler than multi-stage but with potential for idle warps
* **Warp specialization**: Clear roles minimize synchronization complexity
* **Good for learning**: Demonstrates fundamental TMA and mbarrier concepts
## TMA V2: Transpose with Multi-Stage Pipeline
This example demonstrates a TMA implementation with multi-stage pipelining for efficient matrix transpose operations on NVIDIA Blackwell (SM100) architecture.
### Key Concepts
* **Multi-Stage Pipeline**: Multiple buffers enable overlapping TMA loads, computation (transpose), and TMA stores to hide memory latency
* **Pipeline Abstractions**: `PipelineTmaAsync` and `PipelineTmaStore` provide producer-consumer synchronization
* **Persistent Tile Scheduler**: Efficient work distribution across CTAs for dynamic load balancing
* **Shared Memory Swizzle**: Optimized shared memory layouts to avoid bank conflicts
* **Warp Specialization**: Different warps handle loading and transposing; first transpose warp also handles storing
### Kernel Architecture
The `Sm100MatrixTransposeKernelV2` performs a tiled matrix transpose (M×N → N×M) with the following design:
#### Warp Roles
1. **TMA Load Warp** (Producer): Loads tiles from Global Memory to Shared Memory buffer `sA` using TMA load operations
2. **Transpose Warps** (4 warps, Consumer/Producer):
* Wait for data in `sA` from load pipeline
* Transpose data from `sA` → Registers → `sB`
* Synchronize via named barrier to ensure all transpose warps complete
* First transpose warp (`trans_warp_id[0]`) issues TMA store operations from `sB` to Global Memory
#### Pipeline Stages
The kernel uses two separate pipelines for maximum parallelism:
1. **Load Pipeline**: `TMA Load Warp` (producer) → `Transpose Warps` (consumer)
* Multi-stage buffer `sA` (stage count computed automatically from available SMEM)
* Enables prefetching multiple tiles while processing current tile
2. **Store Pipeline**: `Transpose Warps` (producer) → `First Transpose Warp` (consumer, issues TMA store)
* Multi-stage buffer `sB` (stage count computed automatically from available SMEM)
* First transpose warp handles TMA store after all transpose warps finish writing to `sB`
#### Stage Calculation
The kernel automatically computes the optimal number of pipeline stages using `_compute_stages()`:
* Calculates bytes needed per stage for `sA` (tile_m × tile_n) and `sB` (tile_m × tile_n)
* Reserves space for pipeline barriers and metadata (~1KB)
* Divides remaining shared memory by bytes per stage
* Clamps to 2-8 stages (2 for double buffering minimum, 8 for diminishing returns)
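`_compute_stages()` itself is not reproduced here; the following is a rough illustration of the calculation the bullets describe. The helper name, its body, and the SMEM budget in the example are assumptions for the sketch, not the tutorial's actual code:
```python
def compute_stages_sketch(tile_m, tile_n, dtype_bytes, smem_capacity_bytes):
    # Illustrative only: one sA tile plus one sB tile per pipeline stage.
    bytes_per_stage = 2 * tile_m * tile_n * dtype_bytes
    reserved = 1024  # ~1 KB for pipeline barriers and metadata
    stages = (smem_capacity_bytes - reserved) // bytes_per_stage
    return max(2, min(stages, 8))  # >= 2 for double buffering, <= 8 for diminishing returns

# Hypothetical example: fp16 128x128 tiles with a ~228 KB SMEM budget -> 3 stages
print(compute_stages_sketch(128, 128, 2, 228 * 1024))
```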
#### Execution Flow
1. **Initialization**:
* Allocate multi-stage shared memory buffers (`sA_staged`, `sB_staged`)
* Create TMA descriptors for source and transposed destination
* Initialize load and store pipelines with barrier synchronization
* Use persistent tile scheduler to distribute work tiles across CTAs
2. **TMA Load Warp** (Producer for Load Pipeline):
```
for each tile assigned by scheduler:
acquire next available stage in load pipeline
issue TMA load: Global[tile] → sA[stage]
advance to next stage
```
3. **Transpose Warps** (Consumer for Load, Producer for Store):
```
for each tile assigned by scheduler:
wait for load pipeline to fill current stage
copy data: sA[load_stage] → Registers
release load pipeline stage
transpose and write: Registers → sB[store_stage]
fence to ensure smem writes are visible
synchronize all transpose warps with barrier
[first transpose warp only] issue TMA store: sB[stage] → Global[tile]
[first transpose warp only] commit to store pipeline
[first transpose warp only] acquire next store pipeline stage
synchronize all transpose warps with barrier
```
4. **Pipeline Teardown**:
* Producer/consumer tail operations ensure all in-flight operations complete
### Key Features
* **Automatic Stage Optimization**: Kernel calculates optimal number of stages based on tile size, data type, and available shared memory
* **Persistent Tile Scheduler**: Efficient work distribution across CTAs for dynamic load balancing
* **Memory Layout Optimization**: Uses appropriate swizzling for row-major input and column-major transposed output
* **Efficient Synchronization**: Named barriers for intra-CTA coordination, mbarriers for pipeline stages
### Configuration Parameters
* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N)
* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution
* Number of pipeline stages: Automatically computed based on available SMEM
### Usage
Run the transpose kernel with custom matrix dimensions:
```bash
# Basic usage (128×128 matrix)
python tma_v2.py
# Custom dimensions
python tma_v2.py --M 1024 --N 2048
python tma_v2.py --M 1024 --N 2048 --num_warmup 10 --num_iters 50
```
#### Example Code
```python
from tma_v2 import run_transpose
# Run transpose on 1024×2048 matrix
run_transpose(M=1024, N=2048)
# Output: "TransposeSuccess!" if verification passes
```
### Performance Considerations
* **Multi-stage pipelining** hides memory latency by overlapping loads, computation, and stores
* **Persistent scheduling** provides better load balancing for irregular matrix sizes
* **Warp specialization** maximizes throughput: TMA load warp handles all loads, transpose warps handle computation, and first transpose warp handles stores
## TMA V3: TMA With MMA (Tensor Cores)
TBD

View File

@@ -0,0 +1,409 @@
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Type, Union
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
import torch
"""
TMA V0: Understanding tma_partition - The Foundation of TMA Operations
This tutorial demonstrates TMA (Tensor Memory Accelerator) operations through
a simple copy kernel. It focuses on understanding the fundamental tma_partition
interface, which is the key to all TMA operations.
What This Tutorial Covers:
1. TMA Load (Global Memory -> Shared Memory)
2. TMA Store (Shared Memory -> Global Memory)
3. Barrier synchronization for TMA (elect_one, mbarrier_init, mbarrier_arrive, mbarrier_wait)
4. Detailed explanation of tma_partition with visual diagrams
Key Learning Points:
- tma_partition: How it transforms tensors for TMA operations
- group_modes: Why and how to group tensor modes to define TMA atom shape
- Indexing: How to select specific tiles from partitioned tensors
- Data flow: Complete visualization from input tensors to TMA copy
Visual Diagrams:
See line ~155 for comprehensive diagrams showing:
- Input tensor shapes and transformations
- group_modes effect on tensor layouts
- tma_partition output structure
- Complete data flow from global/shared memory to TMA copy
- Indexing pattern for CTA-specific tiles
Example Usage:
```bash
python cutlass_ir/compiler/python/examples/cute/blackwell/tutorial/tutorial_tma/tma_v0.py
```
"""
class Sm100SimpleCopyKernel:
def __init__(self):
"""
Initializes the configuration for a Blackwell TMA copy kernel.
"""
self.tile_shape = (128, 128)
self.tile_m, self.tile_n = self.tile_shape
self.cluster_shape_mn = (1, 1)
self.threads_per_cta = 32
self.buffer_align_bytes = 1024
@cute.jit
def __call__(self, src: cute.Tensor, dst: cute.Tensor):
if cutlass.const_expr(src.element_type != dst.element_type):
raise TypeError("Source and destination element types must match")
self.dtype: Type[cutlass.Numeric] = src.element_type
# layout for each cta: (tile_m, tile_n):(tile_n, 1)
smem_layout = cute.make_layout(
(self.tile_m, self.tile_n), stride=(self.tile_n, 1)
)
@cute.struct
class SharedStorage:
barrier_storage: cute.struct.MemRange[cutlass.Int64, 1]
smem_data: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout)],
self.buffer_align_bytes,
]
self.shared_storage = SharedStorage
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout)
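# e.g. a 128 x 128 fp16 tile expects 128 * 128 * 2 B = 32 KiB per TMA load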
# cta_tiler: the per-CTA tile extents (M, N) used by TMA.
# Note: smem_layout may include swizzle or composed layout,
# so we use product_each(...) to take the product along each logical dimension and get
# the final (tile_m, tile_n) extents expected by TMA.
# In this simple example, smem_layout.shape == (tile_m, tile_n), so product_each(...) is
# just (tile_m, tile_n).
cta_tiler = cute.product_each(smem_layout.shape)
tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileG2SOp(), src, smem_layout, cta_tiler
)
tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), dst, smem_layout, cta_tiler
)
# Grid shape is now (M/TileM, N/TileN)
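# e.g. the default 512 x 128 problem with (128, 128) tiles launches a (4, 1, 1) grid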
grid_shape = cute.ceil_div((*src.layout.shape, 1), self.tile_shape)
self.kernel(
tma_atom_src, tma_tensor_src, tma_atom_dst, tma_tensor_dst, smem_layout
).launch(
grid=grid_shape,
block=(self.threads_per_cta, 1, 1),
cluster=(*self.cluster_shape_mn, 1),
)
@cute.kernel
def kernel(
self,
tma_atom_src: cute.CopyAtom,
tma_tensor_src: cute.Tensor,
tma_atom_dst: cute.CopyAtom,
tma_tensor_dst: cute.Tensor,
smem_layout: Union[cute.Layout, cute.ComposedLayout],
):
bidx, bidy, _ = cute.arch.block_idx()
# Allocate Shared Memory
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
# Initialize barrier for TMA synchronization
barrier_ptr = storage.barrier_storage.data_ptr()
# Initialize the barrier: elect_one ensures only one thread executes this
# Note: We must use elect_one() instead of "if tid == 0" because:
# - elect_one() provides proper synchronization semantics
# - It ensures all threads are aware that exactly one thread is executing
# - It prevents race conditions and provides memory ordering guarantees
with cute.arch.elect_one():
cute.arch.mbarrier_init(barrier_ptr, 1)
cute.arch.mbarrier_expect_tx(barrier_ptr, self.num_tma_load_bytes)
# Fence ensures init/expect_tx are visible before proceeding
cute.arch.mbarrier_init_fence()
cute.arch.barrier()
# Tile the (M, N) tensor: ((TileM, TileN), M/TileM, N/TileN)
gSrc_tiled = cute.local_tile(
tma_tensor_src, (self.tile_m, self.tile_n), (None, None)
)
gDst_tiled = cute.local_tile(
tma_tensor_dst, (self.tile_m, self.tile_n), (None, None)
)
smem_tensor = storage.smem_data.get_tensor(smem_layout)
# ======================================================================
# TMA Partition: Tensor Preparation for TMA Operations
# ======================================================================
#
# tma_partition prepares tensors for TMA copy by partitioning them
# according to the TMA atom's internal layout requirements.
#
# Signature:
# tma_partition(atom, cta_id, cta_layout, smem_tensor, gmem_tensor)
# -> (smem_view, gmem_view)
#
# Key Requirement: Mode 0 of both tensors must represent the TMA atom
#
# Example: M=512, N=128, TileM=128, TileN=64
#
# Input Tensors:
# gSrc_tiled: (128, 64, 4, 2) # 4 separate modes
# └──┬──┘ └──┬──┘
# Tile Grid
#
# smem_tensor: (128, 64) # 2 separate modes
# └──┬──┘
# Tile
#
# Apply group_modes(tensor, 0, 2) to group first 2 modes:
#
# group_modes(gSrc_tiled, 0, 2) => ((128, 64), 4, 2)
# └───┬───┘
# Mode 0 = Atom
#
# group_modes(smem_tensor, 0, 2) => ((128, 64),)
# └───┬───┘
# Mode 0 = Atom
#
# After tma_partition:
#
# tAsA: SMEM view with TMA internal layout
# Shape: ((TMA_Layout),)
# - TMA_Layout: Swizzled/banked layout for efficient SMEM access
#
# tAgA: Global view preserving rest modes
# Shape: ((TMA_Layout), 4, 2)
# └─────┬─────┘ └──┬──┘
# TMA atom Rest modes (grid)
#
# Usage Pattern:
# 1. Group modes to define atom: group_modes(tensor, 0, 2)
# 2. Call tma_partition: tAsA, tAgA = tma_partition(...)
# 3. Select tile for CTA: tAgA_cta = tAgA[(None, bidx, bidy)]
# - None: keep entire atom
# - bidx, bidy: index into rest modes
# 4. Issue TMA copy: cute.copy(atom, tAgA_cta, tAsA)
#
# Visual Data Flow:
#
# Global Memory (512x128) Shared Memory (128x64)
# ┌────────────────────┐ ┌──────────────┐
# │ ┌───┬───┐ │ │ │
# │ │0,0│0,1│ │ │ smem_tensor │
# │ ├───┼───┤ │ │ (128, 64) │
# │ │1,0│1,1│ 4x2 │ │ │
# │ ├───┼───┤ tiles │ └──────────────┘
# │ │2,0│2,1│ │ │
# │ ├───┼───┤ │ │ group_modes
# │ │3,0│3,1│ │ ↓
# │ └───┴───┘ │ ((128, 64),)
# └────────────────────┘ │
# │ │
# │ gSrc_tiled │
# │ (128, 64, 4, 2) │
# ↓ │
# group_modes(_, 0, 2) │
# ↓ │
# ((128, 64), 4, 2) │
# │ │
# └────────┬────────────────────────┘
# ↓
# tma_partition
# ↓
# ┌──────────────┴──────────────┐
# │ │
# tAgA tAsA
# ((TMA_Layout), 4, 2) ((TMA_Layout),)
# │
# ↓ tAgA[(None, bidx, bidy)]
# tAgA_cta
# ((TMA_Layout),)
#
# ======================================================================
# TMA Load partition
# Here we only use a 1x1 cluster, so cta_id is 0 and cta_layout is (1).
# More details about how to set cta_coord and cta_layout can be found in tma_v4.py.
# Note: the smem and gmem views must have the same size (the atom element count) in the first mode
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
tma_atom_src,
0,
cute.make_layout(1),
cute.group_modes(smem_tensor, 0, 2),
cute.group_modes(gSrc_tiled, 0, 2),
)
# TMA Store partition
# Same process as TMA Load, but for destination tensor
# Partitions gDst_tiled and smem_tensor according to TMA Store atom
_, tBgB = cute.nvgpu.cpasync.tma_partition(
tma_atom_dst,
0,
cute.make_layout(1),
cute.group_modes(smem_tensor, 0, 2),
cute.group_modes(gDst_tiled, 0, 2),
)
# Select specific tile for this CTA from partitioned global views
# Input: tAgA with shape ((TMA_Layout), 4, 2)
# Output: tAgA_cta with shape ((TMA_Layout),)
# The (None, bidx, bidy) indexing:
# - None: keeps the entire TMA atom layout (mode 0)
# - bidx: selects from rest mode 1 (M dimension grid)
# - bidy: selects from rest mode 2 (N dimension grid)
tAgA_cta = tAgA[(None, bidx, bidy)]
tBgB_cta = tBgB[(None, bidx, bidy)]
# ---------- TMA Load: Global -> Shared ----------
cute.copy(
tma_atom_src,
tAgA_cta, # Source (TMA Tensor View)
tAsA, # Dest (SMEM Tensor View)
tma_bar_ptr=barrier_ptr,
)
# Signal arrival on the barrier after TMA is issued
with cute.arch.elect_one():
cute.arch.mbarrier_arrive(barrier_ptr)
# Wait for TMA to complete
cute.arch.mbarrier_wait(barrier_ptr, 0)
# ---------- TMA Store: Shared -> Global ----------
cute.copy(
tma_atom_dst,
tAsA, # Source (SMEM Tensor View)
tBgB_cta, # Dest (Global Tensor View)
)
def run_tma_copy(M, N, num_warmup=5, num_iters=20):
"""
Run TMA copy kernel with performance measurement.
Args:
M: Matrix dimension M
N: Matrix dimension N
num_warmup: Number of warmup iterations
num_iters: Number of timing iterations
"""
# Create tensors with shape (M, N)
a = torch.randn((M, N), dtype=torch.float16, device="cuda")
b = torch.zeros((M, N), dtype=torch.float16, device="cuda")
# Note: we declare the N dimension as the leading dimension; its extent must be divisible by 16
a_cute = (
from_dlpack(a, assumed_align=16)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=16)
)
b_cute = (
from_dlpack(b, assumed_align=16)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=16)
)
copy_kernel = Sm100SimpleCopyKernel()
compiled_kernel = cute.compile(copy_kernel, a_cute, b_cute)
# Warmup runs
for _ in range(num_warmup):
compiled_kernel(a_cute, b_cute)
torch.cuda.synchronize()
# Timed runs
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_iters):
compiled_kernel(a_cute, b_cute)
end_event.record()
torch.cuda.synchronize()
# Calculate performance metrics
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_ms = elapsed_time_ms / num_iters
# Calculate throughput
# For copy: read M*N elements + write M*N elements
bytes_per_element = a.element_size()
total_bytes = 2 * M * N * bytes_per_element # Read + Write
throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000)
# Print performance metrics
print(f"Matrix size: {M}×{N}")
print(f"Tile shape: {copy_kernel.tile_shape}")
print(f"Average time: {avg_time_ms:.4f} ms")
print(f"Throughput: {throughput_gb_s:.2f} GB/s")
# Verify
if torch.allclose(a, b, atol=1e-3):
print("Verification: PASSED ✓")
else:
print("Verification: FAILED ✗")
diff = (a - b).abs()
print(f"Max diff: {diff.max()}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="TMA V0: Understanding tma_partition - The Foundation of TMA Operations"
)
parser.add_argument("--M", type=int, default=512, help="Matrix dimension M")
parser.add_argument("--N", type=int, default=128, help="Matrix dimension N")
parser.add_argument(
"--num_warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--num_iters", type=int, default=20, help="Number of timing iterations"
)
args = parser.parse_args()
run_tma_copy(
M=args.M,
N=args.N,
num_warmup=args.num_warmup,
num_iters=args.num_iters,
)

View File

@@ -0,0 +1,462 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Tuple, Type
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
import cutlass.pipeline as pipeline
import torch
"""
TMA Matrix Transpose with Producer-Consumer Pattern: TMA load -> S2R -> R2S -> TMA store
Warp Roles (Producer-Consumer Pattern):
- Producer: TMA Load Warp (Warp 4) - Loads from Global to Shared A
- Consumer: Transpose Warps (Warp 0-3) - Wait for load, then transpose sA -> sB
- Consumer: TMA Store Warp (Warp 5) - Wait for transpose, then store sB to Global
Synchronization:
1. load_mbar_ptr: TMA Load (producer) -> Transpose Warps (consumer)
2. store_mbar_ptr: Transpose Warps (producer) -> TMA Store (consumer)
This demonstrates how different warps can use shared memory barriers to coordinate
producer-consumer relationships.
"""
class Sm100MatrixTransposeKernelV1:
def __init__(self):
self.tile_shape = (128, 128)
self.tile_m, self.tile_n = self.tile_shape
self.cluster_shape_mn = (1, 1)
self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
# Set specialized warp ids based on tile_shape
self.num_trans_warps = 4 # Maximum number of transpose warps
self.trans_warp_id = tuple(range(self.num_trans_warps))
self.tma_load_warp_id = self.num_trans_warps
self.tma_store_warp_id = self.num_trans_warps + 1
self.threads_per_cta = 32 * len(
(self.tma_store_warp_id, self.tma_load_warp_id, *self.trans_warp_id)
)
self.num_trans_threads = 32 * len(self.trans_warp_id)
self.trans_tile = (self.tile_shape[0] // self.num_trans_warps, 8)
# Set barriers for producer-consumer sync
# Barrier 1: Trans warps sync (for internal coordination)
self.trans_sync_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=32 * len(self.trans_warp_id),
)
# Barrier 2: TMA Store warp waits after Trans warps finish
self.store_barrier = pipeline.NamedBarrier(
barrier_id=2,
num_threads=32, # Only TMA store warp
)
self.buffer_align_bytes = 1024
@cute.jit
def __call__(self, src: cute.Tensor, dst: cute.Tensor):
if cutlass.const_expr(src.element_type != dst.element_type):
raise TypeError("Source and destination element types must match")
self.dtype: Type[cutlass.Numeric] = src.element_type
# Create transposed view of dst for TMA descriptor
# dst is (N, M), we want to view it as (M, N) transposed
transed_dst = cute.make_tensor(
dst.iterator,
cute.make_layout(
(dst.shape[1], dst.shape[0]), stride=(dst.stride[1], dst.stride[0])
),
)
# row-major smem layout for sA (tile_m, tile_n)
smem_layout_sA = sm100_utils.make_smem_layout(
utils.LayoutEnum.from_tensor(src).mma_major_mode(),
(self.tile_m, self.tile_n),
self.dtype,
1,
)
# col-major smem layout for sB, shape (tile_m, tile_n)
# sB should match the transposed destination layout
smem_layout_sB = sm100_utils.make_smem_layout(
utils.LayoutEnum.from_tensor(transed_dst).mma_major_mode(),
(self.tile_m, self.tile_n),
self.dtype,
1,
)
@cute.struct
class SharedStorage:
# Barrier for TMA Load: producer (TMA) -> consumer (Trans warps)
load_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
# Barrier for TMA Store: producer (Trans warps) -> consumer (TMA Store)
store_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
# Two shared memory buffers: sA holds the loaded tile, sB holds the tile to be stored (transposed layout)
sA: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sA)], 128
]
sB: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sB)], 128
]
self.shared_storage = SharedStorage
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout_sA)
# TMA Atoms
# Use swizzled layout for TMA atom to handle swizzling during load
tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileG2SOp(),
src,
smem_layout_sA,
(self.tile_m, self.tile_n),
)
tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
transed_dst,
smem_layout_sB,
(self.tile_m, self.tile_n),
)
grid_shape = cute.ceil_div((*src.layout.shape, 1), self.tile_shape)
self.kernel(
tma_atom_src,
tma_tensor_src,
tma_atom_dst,
tma_tensor_dst,
smem_layout_sA,
smem_layout_sB,
).launch(
grid=grid_shape,
block=(self.threads_per_cta, 1, 1),
cluster=self.cluster_shape_mnk,
)
@cute.kernel
def kernel(
self,
tma_atom_load: cute.CopyAtom,
tma_tensor_src: cute.Tensor,
tma_atom_store: cute.CopyAtom,
tma_tensor_dst: cute.Tensor,
smem_layout_sA: cute.ComposedLayout,
smem_layout_sB: cute.ComposedLayout,
):
bidx, bidy, _ = cute.arch.block_idx()
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
# Allocate Shared Memory
# We need two buffers for transpose:
# sA: Source Tile (swizzled)
# sB: Destination Tile (swizzled) for TMA Store
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
sA = storage.sA.get_tensor(smem_layout_sA.outer, swizzle=smem_layout_sA.inner)
# sA = cute.make_tensor(storage.sA.iterator, smem_layout_sA)
sB = storage.sB.get_tensor(smem_layout_sB.outer, swizzle=smem_layout_sB.inner)
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout_sA)
load_mbar_ptr = storage.load_mbar_ptr.data_ptr()
store_mbar_ptr = storage.store_mbar_ptr.data_ptr()
# ------------------------------------------------------------------
# Initialize Barriers (thread 0 initializes; all warps sync afterwards)
# ------------------------------------------------------------------
if tidx == 0:
# Barrier for TMA Load: expect 1 arrive (from TMA warp after TMA completes)
cute.arch.mbarrier_init(load_mbar_ptr, 1)
cute.arch.mbarrier_expect_tx(load_mbar_ptr, self.num_tma_load_bytes)
# Barrier for TMA Store: expect arrival from Trans warps
cute.arch.mbarrier_init(store_mbar_ptr, len(self.trans_warp_id))
cute.arch.mbarrier_init_fence()
# Sync all warps after barrier initialization
cute.arch.barrier()
# ------------------------------------------------------------------
# PRODUCER: TMA Load Warp (G -> sA)
# ------------------------------------------------------------------
if warp_idx == self.tma_load_warp_id:
# Issue TMA Load
# gA: ((TileM, TileN), loopM, loopN)
gA = cute.local_tile(tma_tensor_src, self.tile_shape, (None, None))
# tAsA: SMEM view; tAgA: ((TMA_Layout), loopM, loopN)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_load,
0,
cute.make_layout(1),
cute.group_modes(sA, 0, 2),
cute.group_modes(gA, 0, 2),
)
cute.copy(
tma_atom_load,
tAgA[(None, bidx, bidy)],
tAsA[(None, 0)],
tma_bar_ptr=load_mbar_ptr,
)
# Arrive on mbarrier to satisfy the init count of 1
with cute.arch.elect_one():
cute.arch.mbarrier_arrive(load_mbar_ptr)
# ------------------------------------------------------------------
# CONSUMER: Transpose Warps (sA -> Reg -> sB)
# ------------------------------------------------------------------
if warp_idx < self.tma_load_warp_id:
trans_tid = tidx % self.num_trans_threads
# Wait for TMA Load to complete (consumer wait on load_mbar)
cute.arch.mbarrier_wait(load_mbar_ptr, 0)
atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.dtype,
num_bits_per_copy=self.dtype.width, # Copy one element at a time
)
copy_elems = 1
# Use SAME thread layout for both read and write
# Transpose happens through sB_transposed layout view
# TV layout notation: T = thread-id, V = value-lane within that thread.
# In this example `copy_elems = 1` and `thread_layout` has shape (T, V) = (num_trans_threads, 1),
# so V is always 0 (only one value-lane per thread).
#
# Linearization rule (row-major by strides):
# idx(Ti, Vj) = i + j * num_trans_threads
# Therefore here:
# idx(Ti, V0) = i
#
# Diagram (V is the column, T is the row; showing the first two threads):
#
# V0
# ┌──────┐
# T0 │ T0V0 │ -> idx 0
# T1 │ T1V0 │ -> idx 1
# ... │ ... │
# └──────┘
#
thread_layout = cute.make_layout(
(self.num_trans_threads, 1),
stride=(1, self.num_trans_threads),
)
value_layout = cute.make_layout((1, copy_elems))
# Build a "tiled copy" operator that defines the per-thread copy mapping (T,V) for this warp-group:
# - It is used twice below via `thr_copy.partition_S(...)` and `thr_copy.partition_D(...)` to
# create matching per-thread views of the source and destination tensors.
# - With the SAME (T,V) mapping, the actual transpose is achieved by changing the tensor view
# (`sA` vs `sB` / `sB_transposed`), not by changing which threads perform the copies.
tiled_copy = cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
thr_copy = tiled_copy.get_slice(trans_tid)
# Partition sA (source) for reading
tCsA = thr_copy.partition_S(sA)
# When to use `tiled_copy.retile(...)`:
# - Use it when you allocate/build a register tensor yourself (or slice/reshape it) and its
# internal layout doesn't match the TV layout expected by `tiled_copy` for copy-in/out.
# Why we don't use it here:
# - `cute.make_fragment_like(tCsA)` creates an rmem fragment with the same per-thread shape/layout
# as `tCsA`, so it already matches `tiled_copy`'s view and can be copied into directly.
tCrA = cute.make_fragment_like(tCsA)
cute.copy(tiled_copy, tCsA, tCrA)
# Partition sB for writing
tCsB = thr_copy.partition_D(sB)
# Write from register to sB
cute.copy(tiled_copy, tCrA, tCsB)
# Fence and barrier to make sure shared memory store is visible to TMA store
cute.arch.fence_proxy(
"async.shared",
space="cta",
)
self.trans_sync_barrier.arrive_and_wait()
# Trans warps signal TMA Store warp: "sB is ready!"
with cute.arch.elect_one():
cute.arch.mbarrier_arrive(store_mbar_ptr)
# ------------------------------------------------------------------
# CONSUMER: TMA Store Warp (sB -> G)
# ------------------------------------------------------------------
if warp_idx == self.tma_store_warp_id:
# Wait for Trans warp to complete (consumer wait on store_mbar)
cute.arch.mbarrier_wait(store_mbar_ptr, 0)
gDst_cta = cute.local_tile(
tma_tensor_dst, (self.tile_m, self.tile_n), (None, None)
)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_store,
0,
cute.make_layout(1),
cute.group_modes(sB, 0, 2),
cute.group_modes(gDst_cta, 0, 2),
)
cute.copy(tma_atom_store, tBsB[(None, 0)], tBgB[(None, bidx, bidy)])
def run_transpose(M, N, num_warmup=5, num_iters=20):
"""
Run TMA transpose kernel with performance measurement.
Args:
M: Matrix dimension M
N: Matrix dimension N
num_warmup: Number of warmup iterations
num_iters: Number of timing iterations
Performance Metrics:
- Throughput: Actual achieved bandwidth in GB/s
- Theoretical BW: Peak memory bandwidth (2048 B/clk × 4000 MHz = 8.192 TB/s)
- Bandwidth Efficiency: Percentage of theoretical peak achieved
"""
torch.manual_seed(1111)
# Input (M, N)
input_data = torch.randn((M, N), device="cuda", dtype=torch.float16)
# Output (N, M)
output_data = torch.zeros((N, M), device="cuda", dtype=torch.float16)
# CuTe Wrappers
tensor_src = (
from_dlpack(input_data, assumed_align=16)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=16)
)
tensor_dst = (
from_dlpack(output_data, assumed_align=16)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=16)
)
transpose_kernel = Sm100MatrixTransposeKernelV1()
print("Start kernel compilation...")
# Compile and Run
compiled_kernel = cute.compile(
transpose_kernel, tensor_src, tensor_dst, options="--generate-line-info"
)
print("Start kernel warmup...")
# Warmup runs
for _ in range(num_warmup):
compiled_kernel(tensor_src, tensor_dst)
torch.cuda.synchronize()
print("Kernel warmup completed.")
# Timed runs
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_iters):
compiled_kernel(tensor_src, tensor_dst)
end_event.record()
torch.cuda.synchronize()
# Calculate performance metrics
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_ms = elapsed_time_ms / num_iters
# Calculate throughput
# For transpose: read M*N elements + write M*N elements
bytes_per_element = input_data.element_size()
total_bytes = 2 * M * N * bytes_per_element # Read + Write
throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000)
# Theoretical bandwidth limit
# Blackwell: 2048 B/clk at 4000 MHz
bytes_per_clk = 2048
freq_mhz = 4000
theoretical_bw_gb_s = bytes_per_clk * freq_mhz * 1e6 / 1e9 # Convert to GB/s
theoretical_bw_tb_s = theoretical_bw_gb_s / 1000 # Convert to TB/s
bandwidth_efficiency = (throughput_gb_s / theoretical_bw_gb_s) * 100 # Percentage
# Print performance metrics
print(f"Matrix size: {M}×{N}")
print(f"Tile shape: {transpose_kernel.tile_shape}")
print(f"Average time: {avg_time_ms:.4f} ms")
print(f"Throughput: {throughput_gb_s:.2f} GB/s")
print(
f"Theoretical BW: {theoretical_bw_tb_s:.2f} TB/s ({theoretical_bw_gb_s:.2f} GB/s)"
)
print(f"Bandwidth Efficiency: {bandwidth_efficiency:.2f}%")
# Verification
expected = input_data.t()
if torch.allclose(output_data, expected, atol=1e-2):
print("Verification: PASSED ✓")
else:
print("Verification: FAILED ✗")
print(f"Max diff: {(output_data - expected).abs().max()}")
if __name__ == "__main__":
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
parser = argparse.ArgumentParser(
description="TMA Matrix Transpose with Producer-Consumer Pattern"
)
parser.add_argument("--M", type=int, default=128, help="Matrix dimension M")
parser.add_argument("--N", type=int, default=128, help="Matrix dimension N")
parser.add_argument(
"--num_warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--num_iters", type=int, default=20, help="Number of timing iterations"
)
args = parser.parse_args()
run_transpose(
args.M,
args.N,
num_warmup=args.num_warmup,
num_iters=args.num_iters,
)
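# Example invocation (script name assumed; the multi-stage variant in this
# change refers to this file as tma_v1.py):
#   python tma_v1.py --M 4096 --N 4096 --num_warmup 5 --num_iters 20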


@@ -0,0 +1,648 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
from typing import Tuple, Type, Union
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
import torch
"""
TMA Matrix Transpose with Multi-Stage Pipeline (v2)
This version extends tma_v1.py with:
1. Multi-stage pipeline: Multiple buffers for pipelining TMA loads and stores
2. Pipeline abstraction: Using PipelineTmaAsync for proper producer-consumer coordination
3. Persistent tile scheduler: Efficient work distribution across CTAs
Key Improvements over v1:
- Multi-stage buffers enable overlapping TMA loads, computation, and TMA stores
- Pipeline objects provide cleaner synchronization semantics
Note: TMA multicast is NOT used because each CTA must process different input tiles.
Warp Roles:
- Producer: TMA Load Warp - Loads from Global to Shared memory (multi-stage, multi-tile)
- Consumer: Transpose Warps - Wait for load, transpose sA -> sB (multi-tile)
- Consumer: Transpose warp 0 additionally issues the TMA store of sB to Global through the store pipeline (this version has no dedicated store warp)
Pipeline Stages:
1. Load Pipeline: TMA Load (producer) -> Transpose Warps (consumer)
2. Store Pipeline: Transpose warp 0 (producer) -> TMA store, with completion tracked by PipelineTmaStore
"""
class Sm100MatrixTransposeKernelV2:
def __init__(
self,
):
"""
Initialize the TMA transpose kernel with multi-stage pipeline support.
This constructor takes no arguments: the tile shape is fixed to (128, 128)
and the cluster shape to (1, 1) for this example.
Note:
- Each CTA processes different tiles independently
- Stage counts are automatically computed based on available shared memory
- Persistent scheduler distributes work across CTAs in the cluster
"""
self.tile_shape = (128, 128)
self.tile_m, self.tile_n = self.tile_shape
self.cluster_shape_mn = (1, 1)
self.cluster_shape_mnl = (*self.cluster_shape_mn, 1)
# Set specialized warp ids based on tile_shape
# For 128x128 tile, use 4 transpose warps (same as v1)
self.max_trans_warps = 4 # Maximum number of transpose warps
self.num_trans_warps = self.max_trans_warps # Use all transpose warps
self.trans_warp_id = tuple(range(self.num_trans_warps))
self.tma_load_warp_id = self.num_trans_warps
self.threads_per_cta = 32 * len((self.tma_load_warp_id, *self.trans_warp_id))
self.num_trans_threads = 32 * len(self.trans_warp_id)
# Set barriers for producer-consumer sync
# Barrier 1: Trans warps sync (for internal coordination)
self.trans_sync_barrier = pipeline.NamedBarrier(
barrier_id=1,
num_threads=32 * len(self.trans_warp_id),
)
self.buffer_align_bytes = 128
# Get shared memory capacity for stage computation
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
@staticmethod
def _compute_stages(
tile_m: int,
tile_n: int,
dtype: Type[cutlass.Numeric],
smem_capacity: int,
) -> Tuple[int, int]:
"""
Compute the number of load and store stages based on shared memory capacity.
Strategy:
1. Calculate bytes per stage for load (sA) and store (sB) buffers
2. Reserve space for barriers and alignment
3. Divide remaining smem by bytes per stage to get max stages
4. Clamp to reasonable min/max values
Args:
tile_m: Tile dimension M
tile_n: Tile dimension N
dtype: Data type of the tensors
smem_capacity: Total shared memory capacity in bytes
Returns:
Tuple of (num_load_stages, num_store_stages)
"""
# Calculate bytes per tile (assuming row-major and col-major layouts)
bytes_per_element = dtype.width // 8
# For sA (load buffer): tile_m x tile_n elements
sA_bytes_per_stage = tile_m * tile_n * bytes_per_element
# For sB (store buffer): tile_m x tile_n elements (transposed)
sB_bytes_per_stage = tile_m * tile_n * bytes_per_element
# Reserve space for barriers and other metadata
# Each barrier: 8 bytes (Int64)
# Estimate: max 16 barriers (load + store stages * 2) + alignment
reserved_bytes = 1024 # Conservative estimate
# Available space for staging buffers
available_smem = smem_capacity - reserved_bytes
# Calculate max stages we can fit
# We need space for both load and store stages
total_bytes_per_stage_pair = sA_bytes_per_stage + sB_bytes_per_stage
# Max stages (same for load and store for simplicity)
max_stages = available_smem // total_bytes_per_stage_pair
# Clamp to reasonable values
# Min: 2 stages for basic double buffering
# Max: 8 stages (diminishing returns beyond this)
num_stages = max(2, min(max_stages, 8))
return num_stages, num_stages
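# Worked example (a sketch; the exact shared memory capacity is hardware
# dependent): for fp16 with 128x128 tiles, each load/store stage pair needs
# 2 * 128 * 128 * 2 B = 64 KiB, so with roughly 224-232 KiB of usable shared
# memory this typically yields 3 stages after the [2, 8] clamp.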
@staticmethod
def _compute_grid(
c: cute.Tensor,
cta_tile_shape_mn: Tuple[int, int],
cluster_shape_mn: Tuple[int, int],
max_active_clusters: cutlass.Constexpr,
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
:param c: The output tensor C
:type c: cute.Tensor
:param cta_tile_shape_mn: The shape (M, N) of the CTA tile.
:type cta_tile_shape_mn: tuple[int, int]
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
:type cluster_shape_mn: tuple[int, int]
:param max_active_clusters: Maximum number of active clusters.
:type max_active_clusters: cutlass.Constexpr
:return: A tuple containing:
- tile_sched_params: Parameters for the persistent tile scheduler.
- grid: Grid shape for kernel launch.
:rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
"""
c_shape = cute.slice_(cta_tile_shape_mn, (None, None))
gc = cute.zipped_divide(c, tiler=c_shape)
num_ctas_mn = gc[(0, (None, None))].shape
cluster_shape_mnl = (*cluster_shape_mn, 1)
num_ctas_mnl = (*num_ctas_mn, 1)
tile_sched_params = utils.PersistentTileSchedulerParams(
num_ctas_mnl, cluster_shape_mnl
)
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
tile_sched_params, max_active_clusters
)
return tile_sched_params, grid
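# Worked example (illustrative): a 1024x1024 output with 128x128 CTA tiles has
# 8x8 = 64 tiles. get_grid_shape then sizes the launch by max_active_clusters
# (typically far fewer CTAs than tiles), and each CTA iterates over its share
# of the 64 tiles via advance_to_next_work() in the kernel below.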
@cute.jit
def __call__(
self, src: cute.Tensor, dst: cute.Tensor, max_active_clusters: cutlass.Constexpr
):
if cutlass.const_expr(src.element_type != dst.element_type):
raise TypeError("Source and destination element types must match")
self.dtype: Type[cutlass.Numeric] = src.element_type
# Compute optimal stage counts based on tile size and dtype
self.num_load_stages, self.num_store_stages = self._compute_stages(
self.tile_m,
self.tile_n,
self.dtype,
self.smem_capacity,
)
# Create transposed view of dst for TMA descriptor
# dst is (N, M), we want to view it as (M, N) transposed
transed_dst = cute.make_tensor(
dst.iterator,
cute.make_layout(
(dst.shape[1], dst.shape[0]), stride=(dst.stride[1], dst.stride[0])
),
)
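# Illustrative layout math (assuming the compact row-major (N, M) output
# created in run_transpose below): dst has strides (M, 1); reusing its
# iterator with shape (M, N) and strides (1, M) gives an M-major view of the
# same memory, so a TMA store of an (M, N)-shaped tile lands transposed in the
# (N, M) output.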
# Create multi-stage layouts for load and store buffers
# row-major smem layout for sA (tile_m, tile_n)
smem_layout_sA_staged = sm100_utils.make_smem_layout(
utils.LayoutEnum.from_tensor(src).mma_major_mode(),
(self.tile_m, self.tile_n),
self.dtype,
self.num_load_stages,
)
# M-major smem layout for sB: same (tile_m, tile_n) shape as sA, but laid out
# with the major mode of the transposed destination so the TMA store writes it
# straight into the (N, M) output
smem_layout_sB_staged = sm100_utils.make_smem_layout(
utils.LayoutEnum.from_tensor(transed_dst).mma_major_mode(),
(self.tile_m, self.tile_n),
self.dtype,
self.num_store_stages,
)
@cute.struct
class SharedStorage:
# Pipeline barriers for multi-stage load
load_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_load_stages
]
load_empty_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_load_stages
]
# Pipeline barriers for multi-stage store
store_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_store_stages
]
store_empty_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_store_stages
]
# Multi-stage shared memory buffers
sA: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sA_staged)],
self.buffer_align_bytes,
]
sB: cute.struct.Align[
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sB_staged)],
self.buffer_align_bytes,
]
self.shared_storage = SharedStorage
a_smem_layout = cute.slice_(smem_layout_sA_staged, (None, None, 0))
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, a_smem_layout)
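# num_tma_load_bytes is the byte count of a single sA stage; it is passed as
# tx_count to PipelineTmaAsync below so the load mbarrier's expected
# transaction count matches exactly one TMA load. For fp16 128x128 tiles this
# is 32 KiB (illustrative; it scales with tile shape and dtype).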
# TMA Atoms
# Each CTA loads its own tile independently
tma_load_op = cpasync.CopyBulkTensorTileG2SOp()
cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)), (1,)
)
tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom(
tma_load_op,
src,
smem_layout_sA_staged,
(self.tile_m, self.tile_n),
)
tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(),
transed_dst,
smem_layout_sB_staged,
(self.tile_m, self.tile_n),
)
tile_sched_params, grid_shape = self._compute_grid(
transed_dst,
(self.tile_m, self.tile_n),
self.cluster_shape_mn,
max_active_clusters,
)
self.kernel(
tma_atom_src,
tma_tensor_src,
tma_atom_dst,
tma_tensor_dst,
smem_layout_sA_staged,
smem_layout_sB_staged,
cluster_layout_vmnk,
tile_sched_params,
).launch(
grid=grid_shape,
block=(self.threads_per_cta, 1, 1),
cluster=self.cluster_shape_mnl,
)
@cute.kernel
def kernel(
self,
tma_atom_load: cute.CopyAtom,
tma_tensor_src: cute.Tensor,
tma_atom_store: cute.CopyAtom,
tma_tensor_dst: cute.Tensor,
smem_layout_sA_staged: Union[cute.Layout, cute.ComposedLayout],
smem_layout_sB_staged: Union[cute.Layout, cute.ComposedLayout],
cluster_layout_vmnk: cute.Layout,
tile_sched_params: utils.PersistentTileSchedulerParams,
):
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
# ---------------- Shared mem & staged buffers ----------------
smem = utils.SmemAllocator()
storage = smem.allocate(self.shared_storage)
sA_staged = storage.sA.get_tensor(
smem_layout_sA_staged.outer, swizzle=smem_layout_sA_staged.inner
)
sB_staged = storage.sB.get_tensor(
smem_layout_sB_staged.outer, swizzle=smem_layout_sB_staged.inner
)
load_mbar_ptr = storage.load_full_mbar_ptr.data_ptr()
_store_mbar_ptr = storage.store_full_mbar_ptr.data_ptr()
load_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1)
load_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.num_trans_warps
)
load_pipeline = pipeline.PipelineTmaAsync.create(
barrier_storage=load_mbar_ptr,
num_stages=self.num_load_stages,
producer_group=load_producer_group,
consumer_group=load_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
store_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, self.num_trans_threads
)
store_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_store_stages,
producer_group=store_producer_group,
)
# Critical: Initialize pipeline barriers across cluster
# This must happen after pipeline creation and before any producer/consumer work
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
gA = cute.local_tile(tma_tensor_src, self.tile_shape, (None, None))
_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, None, None, 0)).shape
)
# ((TileM, TileN), LoopM, LoopN)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_load,
0,
cute.make_layout(1),
cute.group_modes(sA_staged, 0, 2),
cute.group_modes(gA, 0, 2),
)
gDst_cta = cute.local_tile(tma_tensor_dst, self.tile_shape, (None, None))
tBsB, tBgB = cpasync.tma_partition(
tma_atom_store,
0,
cute.make_layout(1),
cute.group_modes(sB_staged, 0, 2),
cute.group_modes(gDst_cta, 0, 2),
)
# ------------------------------------------------------------------
# PRODUCER: TMA Load Warp (G -> sA)
# ------------------------------------------------------------------
if warp_idx == self.tma_load_warp_id:
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
load_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_load_stages
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
tAgA_slice = tAgA[(None, cur_tile_coord[0], cur_tile_coord[1])]
load_pipeline.producer_acquire(load_producer_state)
cute.copy(
tma_atom_load,
tAgA_slice,
tAsA[(None, load_producer_state.index)],
tma_bar_ptr=load_pipeline.producer_get_barrier(load_producer_state),
)
load_producer_state.advance()
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
load_pipeline.producer_tail(load_producer_state)
# ------------------------------------------------------------------
# CONSUMER: Transpose Warps (sA -> Reg -> sB)
# ------------------------------------------------------------------
if warp_idx < self.tma_load_warp_id:
tile_sched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
)
work_tile = tile_sched.initial_work_tile_info()
trans_tid = tidx % self.num_trans_threads
load_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_load_stages
)
while work_tile.is_valid_tile:
# Get tile coord from tile scheduler
cur_tile_coord = work_tile.tile_idx
# Wait for load pipeline to have data
load_pipeline.consumer_wait(load_consumer_state)
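# The copy built below is a plain element-wise universal copy: each of the
# num_trans_threads threads moves one element at a time. The copy itself is a
# logical identity over (m, n) coordinates; the transposition is realized by
# sB's M-major layout together with the transposed TMA-store destination view
# built in __call__ above.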
atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.dtype,
num_bits_per_copy=self.dtype.width,
)
copy_elems = 1
thread_layout = cute.make_layout(
(self.num_trans_threads, 1),
stride=(1, self.num_trans_threads),
)
value_layout = cute.make_layout((1, copy_elems))
tiled_copy = cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
thr_copy = tiled_copy.get_slice(trans_tid)
# sA -> Reg
tCsA = thr_copy.partition_S(sA_staged)
tCrA = cute.make_rmem_tensor(
tCsA[(None, None, None, 0)].shape, self.dtype
)
tCrA = tiled_copy.retile(tCrA)
cute.copy(
tiled_copy,
tCsA[(None, None, None, load_consumer_state.index)],
tCrA,
)
# release load pipeline
load_pipeline.consumer_release(load_consumer_state)
load_consumer_state.advance()
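# Pick the sB stage in round-robin order based on how many tiles this CTA has
# already processed; store_pipeline.producer_acquire() below blocks until
# enough earlier TMA stores have drained that this buffer is safe to reuse.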
index = tile_sched.num_tiles_executed % self.num_store_stages
# Reg -> sB
tCsB = thr_copy.partition_D(sB_staged)
cute.copy(tiled_copy, tCrA, tCsB[(None, None, None, index)])
# Fence to ensure smem writes are visible
cute.arch.fence_proxy(
"async.shared",
space="cta",
)
self.trans_sync_barrier.arrive_and_wait()
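# Only the first transpose warp issues the bulk TMA store for this tile;
# producer_commit() commits the outstanding store group and producer_acquire()
# throttles the number of in-flight store groups so an sB stage has drained
# before it is overwritten. The named barrier afterwards keeps the other
# transpose warps from racing ahead to the next tile.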
if warp_idx == self.trans_warp_id[0]:
cute.copy(
tma_atom_store,
tBsB[(None, index)],
tBgB[(None, cur_tile_coord[0], cur_tile_coord[1])],
)
store_pipeline.producer_commit()
store_pipeline.producer_acquire()
self.trans_sync_barrier.arrive_and_wait()
tile_sched.advance_to_next_work()
work_tile = tile_sched.get_current_work()
self.trans_sync_barrier.arrive_and_wait()
store_pipeline.producer_tail()
def run_transpose(M, N, max_active_clusters=0, num_warmup=5, num_iters=20):
"""
Run TMA transpose kernel with automatic stage calculation and performance measurement.
Args:
M: Matrix dimension M
N: Matrix dimension N
max_active_clusters: Maximum number of active clusters (0 for auto)
num_warmup: Number of warmup iterations
num_iters: Number of timing iterations
Performance Metrics:
- Throughput: Actual achieved bandwidth in GB/s
- Theoretical BW: Peak memory bandwidth (2048 B/clk × 4000 MHz = 8.192 TB/s)
- Bandwidth Efficiency: Percentage of theoretical peak achieved
"""
torch.manual_seed(1111)
# Input (M, N)
input_data = torch.randn((M, N), device="cuda", dtype=torch.float16)
# Output (N, M)
output_data = torch.zeros((N, M), device="cuda", dtype=torch.float16)
# CuTe Wrappers
tensor_src = (
from_dlpack(input_data, assumed_align=16)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=16)
)
tensor_dst = (
from_dlpack(output_data, assumed_align=16)
.mark_layout_dynamic(leading_dim=1)
.mark_compact_shape_dynamic(mode=1, divisibility=16)
)
transpose_kernel = Sm100MatrixTransposeKernelV2()
if max_active_clusters == 0:
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(1)
# Compile and Run
compiled_kernel = cute.compile(
transpose_kernel,
tensor_src,
tensor_dst,
max_active_clusters,
options="--generate-line-info",
)
# Warmup runs
for _ in range(num_warmup):
compiled_kernel(tensor_src, tensor_dst)
torch.cuda.synchronize()
# Timed runs
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(num_iters):
compiled_kernel(tensor_src, tensor_dst)
end_event.record()
torch.cuda.synchronize()
# Calculate performance metrics
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_ms = elapsed_time_ms / num_iters
# Calculate throughput
# For transpose: read M*N elements + write M*N elements
bytes_per_element = input_data.element_size()
total_bytes = 2 * M * N * bytes_per_element # Read + Write
throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000)
# Theoretical bandwidth limit
# Blackwell: 2048 B/clk at 4000 MHz
bytes_per_clk = 2048
freq_mhz = 4000
theoretical_bw_gb_s = bytes_per_clk * freq_mhz * 1e6 / 1e9 # Convert to GB/s
theoretical_bw_tb_s = theoretical_bw_gb_s / 1000 # Convert to TB/s
bandwidth_efficiency = (throughput_gb_s / theoretical_bw_gb_s) * 100 # Percentage
# Print computed stage counts after compilation
print(f"Matrix size: {M}×{N}")
print(f"Tile shape: {transpose_kernel.tile_shape}")
print(
f"Computed stages: Load={transpose_kernel.num_load_stages}, Store={transpose_kernel.num_store_stages}"
)
print(f"Average time: {avg_time_ms:.4f} ms")
print(f"Throughput: {throughput_gb_s:.2f} GB/s")
print(
f"Theoretical BW: {theoretical_bw_tb_s:.2f} TB/s ({theoretical_bw_gb_s:.2f} GB/s)"
)
print(f"Bandwidth Efficiency: {bandwidth_efficiency:.2f}%")
# Verification
expected = input_data.t()
if torch.allclose(output_data, expected, atol=1e-2):
print("Verification: PASSED ✓")
else:
print("Verification: FAILED ✗")
print(f"Max diff: {(output_data - expected).abs().max()}")
if __name__ == "__main__":
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
try:
return tuple(int(x.strip()) for x in s.split(","))
except ValueError:
raise argparse.ArgumentTypeError(
"Invalid format. Expected comma-separated integers."
)
parser = argparse.ArgumentParser(
description="TMA Matrix Transpose with Multi-Stage Pipeline and Cluster Support (v2)"
)
parser.add_argument("--M", type=int, default=128, help="Matrix dimension M")
parser.add_argument("--N", type=int, default=128, help="Matrix dimension N")
parser.add_argument(
"--num_warmup", type=int, default=5, help="Number of warmup iterations"
)
parser.add_argument(
"--num_iters", type=int, default=20, help="Number of timing iterations"
)
args = parser.parse_args()
run_transpose(
args.M,
args.N,
num_warmup=args.num_warmup,
num_iters=args.num_iters,
)


@@ -100,7 +100,7 @@ from cutlass.cute.runtime import from_dlpack
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, ".."))
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
from helpers import fmha_helpers as fmha_utils
