Mirror of https://github.com/NVIDIA/cutlass.git, synced 2026-05-11 00:40:03 +00:00
v4.5 tag update (#3202)
* Python DSL examples reorganization.
* v4.5 tag update.
CHANGELOG.md (43 changed lines)
@@ -2,18 +2,59 @@

# CUTLASS 4.x

## [4.5.0](https://github.com/NVIDIA/cutlass/tree/main) (2026-03-27)
## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01)

### CuTe DSL

* New features
  - New Block API `block_copy()` to simplify TMA and S2T copies. With `block_copy()`, users can ignore the details of multicast and 2-CTA partitioning for TMA and need not invoke `tma_partition()`; they can also drop the bulk of the S2T initialization.
  - MXF8F6F4 mixed-precision support
    - BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6
  - Block Scaled MMA for SM120 now works on Spark
  - EFC broadcast semantics support
    - EFC epilogue functions can now broadcast and remap tensor modes via the `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). This covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations.
  - Initial linter support: improved type hints on CuTe DSL APIs to support static type checkers like MyPy
  - `dataclasses.dataclass` is now supported for JIT compilation and `cute.compile`, on both the plain and tvm-ffi paths
  - `cute.copy` now supports user-specified loop unrolling

* Bug fixes and improvements
  - Improved source-code correlation for profiling/debugging
  - Fixed an aarch64 segfault issue with tvm-ffi
  - Reorganized the CuTe DSL examples/tutorials for better discoverability

* More examples of authoring peak-performance kernels
  - MoE examples
    - A new style of grouped GEMM that aligns with torch's `grouped_mm` and `scaled_grouped_mm` interfaces.
    - Expert-wise tensormap descriptor setup is done by a cheap helper kernel (~2 µs) to avoid long latency when switching tiles; the kernel structure is much closer to a normal GEMM.
    - Compared to torch_210_cu13, very few problems have worse performance on B200:
      - mxfp8_2dx3d: avg 1.29x speedup
      - mxfp8_2dx2d: avg 1.41x speedup
      - nvfp4_2dx3d: avg 1.11x speedup
      - nvfp4_2dx2d: avg 1.12x speedup (worst case 0.98x)
      - bf16_2dx3d: avg 1.15x speedup (worst case 0.98x)
      - bf16_2dx2d: avg 1.17x speedup (worst case 0.96x)
    - Note: performance is measured with the torch profiler; this implementation includes the helper kernel plus the main kernel, while torch's includes its setup kernel and the cutlass_cpp main kernel.

* API changes
  - `ab_dtype` is deprecated in `make_trivial_tiled_mma` and `make_blockscaled_trivial_tiled_mma` from `blackwell_helpers.py`. Please specify `a_dtype` and `b_dtype` separately instead (see the sketch below).
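    A minimal before/after sketch of this deprecation, based on the call sites updated later in this commit; the import path and exact argument order are assumptions drawn from those call sites, not documented API:

    ```python
    import cutlass
    from cutlass.cute.nvgpu import tcgen05
    import cutlass.utils.blackwell_helpers as sm100_utils  # assumed import path for blackwell_helpers.py

    # Deprecated: a single ab_dtype described both operands.
    # tiled_mma = sm100_utils.make_trivial_tiled_mma(
    #     cutlass.Float16,             # ab_dtype (deprecated)
    #     tcgen05.OperandMajorMode.K,  # A major mode
    #     tcgen05.OperandMajorMode.K,  # B major mode
    #     cutlass.Float32,             # accumulator dtype
    #     tcgen05.CtaGroup.ONE,
    #     (128, 128),                  # MMA tile (M, N)
    # )

    # New: pass a_dtype and b_dtype separately (they may differ).
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        cutlass.Float16,             # a_dtype
        cutlass.Float16,             # b_dtype
        tcgen05.OperandMajorMode.K,  # A major mode
        tcgen05.OperandMajorMode.K,  # B major mode
        cutlass.Float32,             # accumulator dtype
        tcgen05.CtaGroup.ONE,
        (128, 128),                  # MMA tile (M, N)
    )
    ```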

### CUTLASS C++

* Add 2SM MMA instruction support to the mixed TMA+CpAsync SM100 vanilla GEMM kernels.
  - Mixed TMA+CpAsync can now accept static but non-trivial cluster shapes.
  - Uses TMA multicast for the A tile when using a non-trivial cluster size along the N mode.
  - Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs.
  - Changes are included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm).
* Add support for 128x32xK and 128x64xK tile sizes in the SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement for Blackwell SM121-related kernels.
* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation.
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green-context SM partitioning.
  - Enables launching GEMM on a stream with a partial SM allocation.
* Fix some kernel issues:
  - Fix l2_capacity=0 handling in the Blackwell SM100/SM120 kernel templates
  - Fix CUTLASS clang build issues
  - Fix the atomicCAS read-modify-write loop in `ConstSubbyteReference`
  - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in the subbyte reference
  - Remove `PipelineStorage` shadowing in the SM100 complex epilogue
  - Fix a build issue in the SM90 epilogue fusion visitor (TMA warp-specialized)
* Fix some profiler issues:
  - Add missing reference kernels for the blockwise GEMM profiler
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!

@@ -30,5 +30,5 @@ Certain files within this repository are subject to separate licensing terms:

- The files located in the `python/CuTeDSL` directory are licensed under the
  NVIDIA End User License Agreement (EULA). Please refer to
  https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
  https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
  for the full terms.
README.md (47 changed lines)
@@ -3,7 +3,7 @@

# CUTLASS 4.5.0

_CUTLASS 4.5.0 - March 2026_
_CUTLASS 4.5.0 - May 2026_

CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM)
and related computations at all levels and scales within CUDA. It incorporates strategies for

@@ -45,16 +45,57 @@ To get started quickly - please refer :

# What's New in CUTLASS 4.5

### CuTe DSL
## CuTe DSL
* New features
  - New Block API `block_copy()` to simplify TMA and S2T copies. With `block_copy()`, users can ignore the details of multicast and 2-CTA partitioning for TMA and need not invoke `tma_partition()`; they can also drop the bulk of the S2T initialization.
  - MXF8F6F4 mixed-precision support
    - BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6
  - Block Scaled MMA for SM120 now works on Spark
  - EFC broadcast semantics support
    - EFC epilogue functions can now broadcast and remap tensor modes via the `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). This covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations.
  - Initial linter support: improved type hints on CuTe DSL APIs to support static type checkers like MyPy
  - `dataclasses.dataclass` is now supported for JIT compilation and `cute.compile`, on both the plain and tvm-ffi paths
  - `cute.copy` now supports user-specified loop unrolling

* Bug fixes and improvements
  - Improved source-code correlation for profiling/debugging
  - Fixed an aarch64 segfault issue with tvm-ffi
  - Reorganized the CuTe DSL examples/tutorials for better discoverability

### CUTLASS C++
* More examples of authoring peak-performance kernels
  - MoE examples
    - A new style of grouped GEMM that aligns with torch's `grouped_mm` and `scaled_grouped_mm` interfaces.
    - Expert-wise tensormap descriptor setup is done by a cheap helper kernel (~2 µs) to avoid long latency when switching tiles; the kernel structure is much closer to a normal GEMM.
    - Compared to torch_210_cu13, very few problems have worse performance on B200:
      - mxfp8_2dx3d: avg 1.29x speedup
      - mxfp8_2dx2d: avg 1.41x speedup
      - nvfp4_2dx3d: avg 1.11x speedup
      - nvfp4_2dx2d: avg 1.12x speedup (worst case 0.98x)
      - bf16_2dx3d: avg 1.15x speedup (worst case 0.98x)
      - bf16_2dx2d: avg 1.17x speedup (worst case 0.96x)
    - Note: performance is measured with the torch profiler; this implementation includes the helper kernel plus the main kernel, while torch's includes its setup kernel and the cutlass_cpp main kernel.

* API changes
  - `ab_dtype` is deprecated in `make_trivial_tiled_mma` and `make_blockscaled_trivial_tiled_mma` from `blackwell_helpers.py`. Please specify `a_dtype` and `b_dtype` separately instead.

## CUTLASS C++
* Add 2SM MMA instruction support to the mixed TMA+CpAsync SM100 vanilla GEMM kernels.
  - Mixed TMA+CpAsync can now accept static but non-trivial cluster shapes.
  - Uses TMA multicast for the A tile when using a non-trivial cluster size along the N mode.
  - Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs.
  - Changes are included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm).
* Add support for 128x32xK and 128x64xK tile sizes in the SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement for Blackwell SM121-related kernels.
* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/).
* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation.
* Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green-context SM partitioning.
  - Enables launching GEMM on a stream with a partial SM allocation.
* Fix some kernel issues:
  - Fix l2_capacity=0 handling in the Blackwell SM100/SM120 kernel templates
  - Fix CUTLASS clang build issues
  - Fix the atomicCAS read-modify-write loop in `ConstSubbyteReference`
  - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in the subbyte reference
  - Remove `PipelineStorage` shadowing in the SM100 complex epilogue
  - Fix a build issue in the SM90 epilogue fusion visitor (TMA warp-specialized)
* Fix some profiler issues:
  - Add missing reference kernels for the blockwise GEMM profiler
* Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs!

@@ -48,7 +48,7 @@ foreach(FUSION_CONV_EXAMPLE
fused_two_convs_f16_sm80_shmem
fused_two_convs_s8_sm75_rf
fused_two_convs_s8_sm75_shmem
fused_two_convs_s8_sm80_rf
# fused_two_convs_s8_sm80_rf # disabled: fails to build
fused_two_convs_s8_sm80_shmem
)

@@ -195,7 +195,7 @@ public:
}
implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop);
implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue);
implementable &= TileScheduler::can_implement(args.scheduler);
implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info);

return implementable;
}

@@ -216,7 +216,7 @@ struct Options {
float alpha, beta;
int iterations;
int m, n, k;
int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n;
int cluster_m, cluster_n;
using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode;
DecompositionMode decomposition_mode;

@@ -240,10 +240,8 @@ struct Options {
m(256), n(256), k(16384),
alpha(1.f), beta(0.f),
iterations(10),
preferred_cluster_m(4),
preferred_cluster_n(4),
fallback_cluster_m(2),
fallback_cluster_n(1),
cluster_m(2),
cluster_n(1),
decomposition_mode(DecompositionMode::Heuristic),
reduction_mode(ReductionMode::Deterministic),
splits(1)

@@ -265,10 +263,8 @@ struct Options {
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations);
cmd.get_cmd_line_argument("splits", splits, 1);
cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4);
cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4);
cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2);
cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1);
cmd.get_cmd_line_argument("cluster_m", cluster_m, 2);
cmd.get_cmd_line_argument("cluster_n", cluster_n, 1);

// Parse decompsition mode
std::string decomp_mode;

@@ -303,10 +299,8 @@ struct Options {
<< " --k=<int> Sets the K extent of the GEMM\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n"
<< " --preferred_cluster_m=<str> Sets the M extent of preferred cluster shape\n"
<< " --preferred_cluster_n=<str> Sets the N extent of preferred cluster shape\n"
<< " --fallback_cluster_m=<str> Sets the M extent of fallback cluster shape\n"
<< " --fallback_cluster_n=<str> Sets the N extent of fallback cluster shape\n"
<< " --cluster_m=<str> Sets the M extent of the cluster shape\n"
<< " --cluster_n=<str> Sets the N extent of the cluster shape\n"
<< " --decomposition=<str> Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n"
<< " --reduction=<str> Mode in which the stream-K kernel's reduction should be performed. Options: Deterministic (default), Nondeterministic\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";

@@ -424,8 +418,8 @@ typename Gemm::Arguments args_from_options(const Options &options) {
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};

arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1);
arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1);
arguments.hw_info.cluster_shape = dim3(options.cluster_m, options.cluster_n, 1);
arguments.hw_info.cluster_shape_fallback = dim3(options.cluster_m, options.cluster_n, 1);

arguments.scheduler.splits = options.splits;
arguments.scheduler.decomposition_mode = options.decomposition_mode;

@@ -498,8 +492,7 @@ int run(Options &options) {

std::cout << "Stream-K GEMM with"
<< " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k
<< " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)"
<< " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n"
<< " Cluster = (" << options.cluster_m << ", " << options.cluster_n << ", 1)\n"
<< " Decomposition_mode=" << options.decomposition_mode_str()
<< " Split_count=" << options.splits
<< " Reduction_mode=" << options.reduction_mode_str()

@@ -536,7 +536,11 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));

// Each thread owns a single row
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x;
#else
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
#endif
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem

@@ -573,6 +577,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
}

ElementQK old_row_max = row_max;
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
auto pos = tTMEM_LOADcS(0);
if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) {
float curr_max = tiled_tmem_load.get_max();
row_max = ::fmax(row_max, curr_max);
}
else
#endif
{
// compute rowmax
float row_max_0 = row_max;

@@ -540,7 +540,11 @@ struct Sm100FmhaGenMainloopWarpspecialized {
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));

// Each thread owns a single row
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x;
#else
using TMEM_LOAD = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
#endif
using TMEM_STORE = conditional_t<size<1>(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem

@@ -577,6 +581,14 @@ struct Sm100FmhaGenMainloopWarpspecialized {
}

ElementQK old_row_max = row_max;
#if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED
auto pos = tTMEM_LOADcS(0);
if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) {
float curr_max = tiled_tmem_load.get_max();
row_max = ::fmax(row_max, curr_max);
}
else
#endif
{
// compute rowmax
float row_max_0 = row_max;

@@ -213,12 +213,15 @@ auto make_iterator(T* ptr) {

///////////////////////////////////////////////////////////////////////////////////////////////////

struct ExampleRunner {
template <
// Type of kernel schedule to generate
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100;
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
// Type of epilogue schedule to generate
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
static constexpr bool FuseQuantization = false;
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
class ClusterShapeMNK = Shape<_1, _1, _1>,
bool FuseQuantization = false
>
struct ExampleRunner {

using LayoutATag = cutlass::layout::RowMajor;
using LayoutBTag = cutlass::layout::ColumnMajor;

@@ -238,10 +241,8 @@ struct ExampleRunner {
using ElementCompute = float;
using ElementScalar = float;

using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int TileM = cute::is_base_of_v<cutlass::gemm::KernelSchedule2Sm, MainloopScheduleType> ? 256 : 128;
using MmaTileMNK = Shape<Int<TileM>,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)

static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;

@@ -712,10 +713,34 @@ int main(int argc, char const **args) {
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl;
ExampleRunner runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);

std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load, 1SM, 2x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_1sm_2x2;
runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);

std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _1, _1>
> runner_mixed_tma_cpasync_2sm_2x1;
runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info);

std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_2sm_2x4;
runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);

#endif

return 0;

@@ -220,6 +220,7 @@ using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig<OutputSFVectorSize>
template <
// Type of kernel schedule to generate
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
class ClusterShapeMNK = Shape<_1, _1, _1>,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
bool FuseQuantization = false

@@ -246,8 +247,8 @@ struct ExampleRunner {

using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)
static constexpr int TileM = cute::is_base_of_v<cutlass::gemm::KernelSchedule2Sm, MainloopScheduleType> ? 256 : 128;
using MmaTileMNK = Shape<Int<TileM>,_64,_64>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage)

static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;

@@ -485,6 +486,24 @@ struct ExampleRunner {
},
hw_info
};
}
else if constexpr (std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100>){
return typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{ // Mainloop arguments
block_A.device_data(),
block_B.device_data(),
block_SFA.device_data(),
block_SFB.device_data()
},
{ // Epilogue arguments
{},
block_C.device_data(), stride_C,
block_D.device_data(), stride_D
},
hw_info
};
}
else {
return typename Gemm::Arguments {

@@ -654,10 +673,80 @@ int main(int argc, char const **args) {
ExampleRunner<cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100> runner_tma;
runner_tma.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 1x1 cluster:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100> runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_4, _1, _1>
> runner_mixed_tma_cpasync_1sm_4x1;
runner_mixed_tma_cpasync_1sm_4x1.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 1x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_1, _4, _1>
> runner_mixed_tma_cpasync_1sm_1x4;
runner_mixed_tma_cpasync_1sm_1x4.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_1sm_2x2;
runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_4, _2, _1>
> runner_mixed_tma_cpasync_1sm_4x2;
runner_mixed_tma_cpasync_1sm_4x2.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_1sm_2x4;
runner_mixed_tma_cpasync_1sm_2x4.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_1sm_4x4;
runner_mixed_tma_cpasync_1sm_4x4.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_4, _1, _1>
> runner_mixed_tma_cpasync_2sm_4x1;
runner_mixed_tma_cpasync_2sm_4x1.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 2x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_2sm_2x4;
runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_4, _2, _1>
> runner_mixed_tma_cpasync_2sm_4x2;
runner_mixed_tma_cpasync_2sm_4x2.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_2sm_4x4;
runner_mixed_tma_cpasync_2sm_4x4.run(options, hw_info);

#endif

return 0;

@@ -70,7 +70,6 @@

#include "helper.h"

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -169,12 +168,14 @@ bool initialize_block(

///////////////////////////////////////////////////////////////////////////////////////////////////

struct ExampleRunner {

template<
// Type of kernel schedule to generate
using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100;
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
// Type of epilogue schedule to generate
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto,
class ClusterShapeMNK = Shape<_1, _1, _1>
>
struct ExampleRunner {

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;

@@ -189,7 +190,6 @@ struct ExampleRunner {
using ElementCompute = float;
using ElementScalar = float;

using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)

// 16B alignment lets us use TMA

@@ -219,7 +219,7 @@ struct ExampleRunner {
MainloopScheduleType
>::CollectiveOp;

using ProblemShape = cutlass::gemm::MoEProblemShape<Shape<int,int,int>>; // <M,N,K> per group
using ProblemShape = typename cutlass::gemm::MoEProblemShape<Shape<int,int,int>>; // <M,N,K> per group

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,

@@ -272,10 +272,10 @@ struct ExampleRunner {
auto [M, N, K] = problem;
printf("group [%d] : M = %d, N = %d, K = %d\n", i, M, N, K);

cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, Gemm::LayoutA(maxK));
cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, Gemm::LayoutB(maxK));
cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutC(maxM));
cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutD(maxM));
cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, typename Gemm::LayoutA(maxK));
cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, typename Gemm::LayoutB(maxK));
cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, typename Gemm::LayoutC(maxM));
cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, typename Gemm::LayoutD(maxM));

using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,

@@ -561,10 +561,62 @@ int main(int argc, char const **args) {
hw_info.device_id = 0;
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl;
ExampleRunner runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);

std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load and 1x1 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_1, _1, _1>
> runner_mixed_tma_cpasync_1sm;
runner_mixed_tma_cpasync_1sm.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x2 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_1sm_2x2;
runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info);

std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load and 4x4 cluster:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_1sm_4x4;
runner_mixed_tma_cpasync_1sm_4x4.run(options, hw_info);

std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 2x1:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _1, _1>
> runner_mixed_tma_cpasync_2sm_2x1;
runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info);

std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 4x1:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_4, _1, _1>
> runner_mixed_tma_cpasync_2sm_4x1;
runner_mixed_tma_cpasync_2sm_4x1.run(options, hw_info);

std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 2x4:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_2, _4, _1>
> runner_mixed_tma_cpasync_2sm_2x4;
runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info);

std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 4x4:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
cutlass::epilogue::collective::EpilogueScheduleAuto,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_2sm_4x4;
runner_mixed_tma_cpasync_2sm_4x4.run(options, hw_info);

#endif

return 0;

@@ -171,6 +171,8 @@ bool initialize_block(
template <
// Type of kernel schedule to generate
class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
// Static cluster shape
class ClusterShapeMNK = Shape<_1, _1, _1>,
// Type of epilogue schedule to generate
class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto
>

@@ -189,7 +191,6 @@ struct ExampleRunner {
using ElementCompute = float;
using ElementScalar = float;

using ClusterShapeMNK = Shape<_1,_1,_1>;
using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage)

// 16B alignment lets us use TMA

@@ -334,6 +335,16 @@ struct ExampleRunner {
hw_info
};
}
else if constexpr (std::is_same_v<MainloopScheduleType, cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100>){
return typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), block_B.get()},
{{}, // epilogue.thread
block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
}
else {
return typename Gemm::Arguments {
cutlass::gemm::GemmUniversalMode::kGemm,

@@ -490,13 +501,33 @@ int main(int argc, char const **args) {
runner_tma.run(options, hw_info);

std::cout << "Running kernel with CPASYNC load:" << std::endl;
ExampleRunner runner_cpasync;
ExampleRunner<cutlass::gemm::KernelWarpSpecialized1SmSm100> runner_cpasync;
runner_cpasync.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl;
std::cout << "Running kernel with mixed TMA+CPASYNC load, cluster shape 1x1:" << std::endl;
ExampleRunner<cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100> runner_mixed_tma_cpasync;
runner_mixed_tma_cpasync.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 2x2:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
Shape<_2, _2, _1>
> runner_mixed_tma_cpasync_2x2;
runner_mixed_tma_cpasync_2x2.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 4x4, 1SM instructions:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_4x4;
runner_mixed_tma_cpasync_4x4.run(options, hw_info);

std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 4x4, 2SM instructions:" << std::endl;
ExampleRunner<
cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100,
Shape<_4, _4, _1>
> runner_mixed_tma_cpasync_4x4_2sm;
runner_mixed_tma_cpasync_4x4_2sm.run(options, hw_info);
#endif

return 0;

@@ -48,7 +48,7 @@ from cutlass.cute.typing import Int32, Int64, Float32

if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, ".."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))

from helpers import fmha_helpers as fmha_utils

@@ -51,7 +51,7 @@ from cutlass.cute.typing import Int32, Float32, Float8E4M3FN, Float16, BFloat16,

if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, ".."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))

from helpers import fmha_helpers as fmha_utils

@@ -45,14 +45,14 @@ from cutlass.cute.runtime import from_dlpack

if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../.."))
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))

from blackwell.mamba2_ssd.mamba2_ssd_reference import (
from cute.blackwell.kernel.attention.mamba2_ssd.mamba2_ssd_reference import (
ssd_reference_fp32_all,
ssd_reference_lowprecision_intermediates,
analyze_relative_diffs,
)
from blackwell.mamba2_ssd.mamba2_ssd_tile_scheduler import (
from cute.blackwell.kernel.attention.mamba2_ssd.mamba2_ssd_tile_scheduler import (
Mamba2SSDTileSchedulerParams,
Mamba2SSDTileScheduler,
)

@@ -27,14 +27,11 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import enum
import math
import time
from typing import Type, Tuple
from functools import partial

import torch
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel

@@ -42,7 +39,7 @@ import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.cute.nvgpu.tcgen05 as tcgen05
from cutlass.cute.nvgpu import tcgen05, OperandMajorMode
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.torch as cutlass_torch

@@ -51,9 +48,6 @@ import cutlass.cute.testing as testing
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import *

from cutlass._mlir.dialects import llvm
from cutlass.cute.arch.nvvm_wrappers import mapa

# Kernel invariants
mma_modes = (0, 1, 2)
mma_dice = (None, None, None) # (MMA, #MMA_M, #MMA_K)

@@ -65,19 +59,22 @@ warpgroup_threads = 128
# Math helpers
log2_e = math.log2(math.e) # change exponential base
use_tensor_ssa_math = False # experimental
fadd2 = partial(cute.arch.add_packed_f32x2, ftz=False, rnd="rn")
fmul2 = partial(cute.arch.mul_packed_f32x2, ftz=False, rnd="rn")
ffma2 = partial(cute.arch.fma_packed_f32x2, ftz=False, rnd="rn")
fadd2 = cute.arch.add_packed_f32x2
fmul2 = cute.arch.mul_packed_f32x2
ffma2 = cute.arch.fma_packed_f32x2
exp2 = partial(cute.math.exp2, fastmath=True)
warp_fmax = partial(cute.arch.warp_redux_sync, kind="fmax", nan=True)
smem_fmax = partial(cute.arch.atomic_fmax, sem="relaxed", scope="cta")
gmem_fmax = partial(cute.arch.atomic_fmax, sem="relaxed", scope="gpu")


class MixedInputFusedMultiHeadAttentionDecode:
def __init__(
self,
headdim,
block_scaledim, # headdim per scale factor; scale factor shape is (batches, heads_k, seqlen, headdim / block_scaledim)
grouped_head_tile, # GQA packing tile size, can be less than group size
convert_warpgroups = 1, # Multiple warpgroups striding on convert stages
block_scaledim, # headdim per scale factor; scale factor shape is (batches, heads_k, seqlen, headdim / block_scaledim)
grouped_head_tile, # GQA packing tile size, can be less than group size
convert_warpgroups=1, # Multiple warpgroups striding on convert stages
):
self.headdim = headdim
self.grouped_head_tile = grouped_head_tile
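# (Illustrative note, not part of the original file.) With the constructor
# arguments above, e.g. headdim=128 and block_scaledim=32, each K/V scale
# factor covers 32 elements of the head dimension, so per the comment the
# scale-factor tensor has shape (batches, heads_k, seqlen, 128 // 32) = (b, h_k, s_k, 4).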

@@ -93,7 +90,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
self.softmax_warpgroup_id = warpgroup_id
warpgroup_id += 1

self.cvt_warpgroup_ids = tuple(range(warpgroup_id, warpgroup_id+convert_warpgroups))
self.cvt_warpgroup_ids = tuple(
range(warpgroup_id, warpgroup_id + convert_warpgroups)
)
warpgroup_id += convert_warpgroups

# Why 2 MMA+TMA warps when not MMA bound?

@@ -114,12 +113,15 @@ class MixedInputFusedMultiHeadAttentionDecode:
max_regs_per_wg_thread = 64 * 1024 // warpgroup_threads # 64K regs per SM
self.mma_tma_regs = 72
self.cvt_regs = 112
self.softmax_regs = (max_regs_per_wg_thread
- self.mma_tma_regs
- self.cvt_regs * convert_warpgroups)
self.softmax_regs = (
max_regs_per_wg_thread
- self.mma_tma_regs
- self.cvt_regs * convert_warpgroups
)
self.softmax_regs = max(128, min(256, self.softmax_regs))
assert (self.mma_tma_regs + self.softmax_regs +
self.cvt_regs * convert_warpgroups) <= max_regs_per_wg_thread or not self.use_reg_reconfig
assert (
self.mma_tma_regs + self.softmax_regs + self.cvt_regs * convert_warpgroups
) <= max_regs_per_wg_thread or not self.use_reg_reconfig

self.bs_stages = 2
self.sp_stages = 2
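# (Illustrative note, not part of the original file.) Worked register budget for
# the computation above, with warpgroup_threads = 128:
#   max_regs_per_wg_thread = 64 * 1024 // 128 = 512
#   convert_warpgroups = 1: softmax_regs = 512 - 72 - 112 = 328, clamped to 256,
#     and 72 + 256 + 112 = 440 <= 512 satisfies the assert.
#   convert_warpgroups = 2: softmax_regs = 512 - 72 - 224 = 216 (already in [128, 256]),
#     and 72 + 216 + 224 = 512 <= 512 still holds.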

@@ -140,25 +142,33 @@ class MixedInputFusedMultiHeadAttentionDecode:
raise ValueError("use Float8E4M3FN instead of Float8E4M3")

if d % 64 != 0:
raise ValueError(f"headdim({d}) must be multiple of 64")
raise testing.CantImplementError(f"headdim({d}) must be multiple of 64")

if h_q % h_k != 0:
raise ValueError(f"heads_q({h_q}) must be a multiple of heads_k({h_k})")
raise testing.CantImplementError(
f"heads_q({h_q}) must be a multiple of heads_k({h_k})"
)

align_scale_bits = 128 # TMA requirement
if self.scaledim * q_dtype.width < align_scale_bits:
align_seq = align_scale_bits // (self.scaledim * q_dtype.width)
if s_k % align_seq != 0:
raise ValueError(f"seqlen({s_k}) must be a multiple of {align_seq}")
raise testing.CantImplementError(
f"seqlen({s_k}) must be a multiple of {align_seq}"
)

if kv_dtype.width < 8 and d % 128 != 0: # TMA requirement
raise ValueError(f"headdim({d}) must be multiple of 128 for {kv_dtype} KV")
raise testing.CantImplementError(
f"headdim({d}) must be multiple of 128 for {kv_dtype} KV"
)

@cute.jit
def __call__(
self,
problem_shape: Tuple[Int32, Int32, Int32, Int32, Int32], # b, h_q, h_k, s_k, d
kv_splits: Int32, # threadblocks per sequence
problem_shape: Tuple[
cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32
], # b, h_q, h_k, s_k, d
kv_splits: cutlass.Int32, # threadblocks per sequence
q_iter: cute.Pointer,
k_iter: cute.Pointer,
v_iter: cute.Pointer,

@@ -170,8 +180,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
o_partial_iter: cute.Pointer, # partial O per kv split
m_partial_iter: cute.Pointer, # partial colmax_s per kv split
l_partial_iter: cute.Pointer, # partial colsum_p per kv split
scale_qs: Float32,
scale_o: Float32,
scale_qs: cutlass.Float32,
scale_o: cutlass.Float32,
stream: cuda.CUstream,
):
##############################

@@ -179,7 +189,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
##############################
mma_dtype = q_iter.dtype
acc_dtype = o_partial_iter.dtype
assert acc_dtype is Float32 # don't support other acc types for now
assert acc_dtype is cutlass.Float32 # don't support other acc types for now

# Block tile sets the granularity at which threadblocks consume work
blk_tile_s = 128

@@ -197,8 +207,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
# GEMM1: (S_K, H_R, D, (H_K, B))
tiled_mma_kq = sm100_utils.make_trivial_tiled_mma(
mma_dtype,
tcgen05.OperandMajorMode.K, # K
tcgen05.OperandMajorMode.K, # Q
mma_dtype,
OperandMajorMode.K, # K
OperandMajorMode.K, # Q
acc_dtype,
tcgen05.CtaGroup.ONE,
mma_tile_mnk[:2],

@@ -208,8 +219,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
# GEMM2: (D, H_R, S_K, (H_K, B))
tiled_mma_vp = sm100_utils.make_trivial_tiled_mma( #
mma_dtype,
tcgen05.OperandMajorMode.K, # V
tcgen05.OperandMajorMode.MN, # P
mma_dtype,
OperandMajorMode.K, # V
OperandMajorMode.MN, # P
acc_dtype,
tcgen05.CtaGroup.ONE,
mma_tile_mnk[:2],

@@ -523,8 +535,8 @@ class MixedInputFusedMultiHeadAttentionDecode:
mM: cute.Tensor,
mM_partial: cute.Tensor,
mL_partial: cute.Tensor,
scale_qs: Float32,
scale_qs_log2_e: Float32,
scale_qs: cutlass.Float32,
scale_qs_log2_e: cutlass.Float32,
):
# Read special registers
kv_splits, tiles_hr, tiles_hb = cute.arch.grid_dim()

@@ -586,7 +598,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
##############################
# Tmem Allocation
##############################
tmem_ptr_smem_ptr = smem.allocate_array(Int32)
tmem_ptr_smem_ptr = smem.allocate_array(cutlass.Int32)
if warp_idx == init_warp and not exit_early:
cute.arch.alloc_tmem(self.tmem_alloc_cols, tmem_ptr_smem_ptr)
init_warp += 1

@@ -595,26 +607,32 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Pipeline Allocation + Init
##############################
# Allocate Mbarriers
q_pipeline_ptr = smem.allocate_array(Int64, self.q_stages * 2)
kv_pipeline_ptr = smem.allocate_array(Int64, self.kv_stages * 2)
bs_pipeline_ptr = smem.allocate_array(Int64, self.bs_stages * 2)
cvt_pipeline_ptr = smem.allocate_array(Int64, self.cvt_stages * 2)
s_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2)
p_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2)
o_pipeline_ptr = smem.allocate_array(Int64, self.o_stages * 2)
q_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.q_stages * 2)
kv_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.kv_stages * 2)
bs_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.bs_stages * 2)
cvt_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.cvt_stages * 2)
s_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.sp_stages * 2)
p_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.sp_stages * 2)
o_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.o_stages * 2)

# Declare named barriers
softmax_nbar_id = 1
mma_kq_nbar_id = 2
mma_vp_nbar_id = 3
softmax_nbar = pipeline.NamedBarrier(
barrier_id=1, num_threads=warpgroup_threads
)
mma_kq_nbar = pipeline.NamedBarrier(barrier_id=2, num_threads=64)
mma_vp_nbar = pipeline.NamedBarrier(barrier_id=3, num_threads=64)

# Alias thread cooperatives
elect_one_cooperative = pipeline.CooperativeGroup(pipeline.Agent.Thread)
warpgroup_cooperative = pipeline.CooperativeGroup(pipeline.Agent.Thread, warpgroup_threads)
warpgroup_cooperative = pipeline.CooperativeGroup(
pipeline.Agent.Thread, warpgroup_threads
)
mma_group = elect_one_cooperative
tma_group = elect_one_cooperative
cvt_group = warpgroup_cooperative
cvt_groups = pipeline.CooperativeGroup(pipeline.Agent.Thread, warpgroup_threads * self.convert_warpgroups)
cvt_groups = pipeline.CooperativeGroup(
pipeline.Agent.Thread, warpgroup_threads * self.convert_warpgroups
)
softmax_group = warpgroup_cooperative

# Initialize pipelines

@@ -642,7 +660,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
num_stages=self.bs_stages,
producer_group=tma_group,
consumer_group=cvt_groups,
tx_count=cute.size_in_bytes(mma_dtype, cute.select(smem_layout_bs, mode=[0,1])),
tx_count=cute.size_in_bytes(
mma_dtype, cute.select(smem_layout_bs, mode=[0, 1])
),
barrier_storage=bs_pipeline_ptr,
tidx=mcast_coord,
cta_layout_vmnk=mcast_layout,

@@ -772,7 +792,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
tCtO = thrblk_mma_vp.make_fragment_C(tCsO.shape)

# Tmem tensor allocation
tmem_ptr = cute.arch.retrieve_tmem_ptr(Int32, 16, tmem_ptr_smem_ptr)
tmem_ptr = cute.arch.retrieve_tmem_ptr(cutlass.Int32, 16, tmem_ptr_smem_ptr)
tmem_offset = 0

tAtK_cvt = cute.make_tensor(

@@ -802,7 +822,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Exit early
##############################
if exit_early:
noop = None # early return not supported
noop = None # early return not supported # noqa: F841

##############################
# TMA KV Dispatch

@@ -991,9 +1011,11 @@ class MixedInputFusedMultiHeadAttentionDecode:
cute.arch.setmaxregister_decrease(self.cvt_regs)

# Intermediate convert type
cvt_type = Float32
if cutlass.const_expr(mma_dtype is cutlass.BFloat16 and
k_dtype in (cutlass.Int4, cutlass.Int8)):
cvt_type = cutlass.Float32
if cutlass.const_expr(
mma_dtype is cutlass.BFloat16
and k_dtype in (cutlass.Int4, cutlass.Int8)
):
cvt_type = mma_dtype

# Initialize for multiple warpgroups if necessary

@@ -1015,12 +1037,13 @@ class MixedInputFusedMultiHeadAttentionDecode:
smem_load_atom_k = cute.make_copy_atom(
cute.nvgpu.warp.LdMatrix8x16x8bOp(
num_matrices=4,
unpack_bits=(k_dtype.width if k_dtype.width < 8 else None)),
unpack_bits=(k_dtype.width if k_dtype.width < 8 else None),
),
kv_smem_dtype,
)

tmem_store_k = tcgen05.make_tmem_copy(
tmem_store_atom_k, tAtK_cvt[mma_dice+(0,)]
tmem_store_atom_k, tAtK_cvt[mma_dice + (0,)]
)
thr_store_k = tmem_store_k.get_slice(warpgroup_tidx)
tKrK_cvt_shape = thr_store_k.partition_S(tAtK_cvt).shape[:-1]

@@ -1043,7 +1066,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
smem_load_atom_v = cute.make_copy_atom(smem_load_op_v, kv_smem_dtype)

tmem_store_v = tcgen05.make_tmem_copy(
tmem_store_atom_v, tAtV_cvt[mma_dice+(0,)]
tmem_store_atom_v, tAtV_cvt[mma_dice + (0,)]
)
thr_store_v = tmem_store_v.get_slice(warpgroup_tidx)
tVrV_cvt_shape = thr_store_v.partition_S(tAtV_cvt).shape[:-1]

@@ -1113,7 +1136,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
bs_handle.release()

# Convert and scale K
for dk in cutlass.range(tiles_dk // self.convert_warpgroups, unroll=2):
for dk in cutlass.range(
tiles_dk // self.convert_warpgroups, unroll=2
):
tKrK = cute.make_rmem_tensor(tKrK_shape, kv_smem_dtype)
tKrK_cvt = cute.make_rmem_tensor(tKrK_cvt_shape, mma_dtype)

@@ -1139,7 +1164,9 @@ class MixedInputFusedMultiHeadAttentionDecode:

cvt_handle = cvt_producer.acquire_and_advance()
cute.copy(
thr_store_k, tKrK_cvt, tKtK_cvt[cpy_dice + (cvt_handle.index,)]
thr_store_k,
tKrK_cvt,
tKtK_cvt[cpy_dice + (cvt_handle.index,)],
)
cute.arch.fence_view_async_tmem_store()
cvt_handle.commit()

@@ -1160,7 +1187,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
bs_handle.release()

# Convert and scale V
for dmsk in cutlass.range(tiles_dm * tiles_sk // self.convert_warpgroups, unroll=2):
for dmsk in cutlass.range(
tiles_dm * tiles_sk // self.convert_warpgroups, unroll=2
):
tVrV = cute.make_rmem_tensor(tVrV_shape, kv_smem_dtype)
tVrV_cvt = cute.make_rmem_tensor(tVrV_cvt_shape, mma_dtype)

@@ -1186,7 +1215,9 @@ class MixedInputFusedMultiHeadAttentionDecode:

cvt_handle = cvt_producer.acquire_and_advance()
cute.copy(
thr_store_v, tVrV_cvt, tVtV_cvt[cpy_dice + (cvt_handle.index,)]
thr_store_v,
tVrV_cvt,
tVtV_cvt[cpy_dice + (cvt_handle.index,)],
)
cute.arch.fence_view_async_tmem_store()
cvt_handle.commit()

@@ -1223,9 +1254,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
k_handle = cvt_consumer.wait_and_advance(k_token)
# Signal BMM2 to start
if is_last_iter:
cute.arch.barrier_arrive(
barrier_id=mma_kq_nbar_id, number_of_threads=64
)
mma_kq_nbar.arrive()
for mma_k in cutlass.range_constexpr(tAtK_cvt.shape[2]):
cute.gemm(
tiled_mma_kq,

@@ -1245,7 +1274,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
if s > 0:
for _ in cutlass.range_constexpr(tiles_dm * tiles_sk):
cvt_consumer.advance()
cute.arch.barrier(barrier_id=mma_vp_nbar_id, number_of_threads=64)
mma_vp_nbar.arrive_and_wait()
s_token = s_producer.try_acquire()

##############################

@@ -1263,7 +1292,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Advance and wait for BMM1
for _ in cutlass.range_constexpr(tiles_dk):
cvt_consumer.advance()
cute.arch.barrier(barrier_id=mma_kq_nbar_id, number_of_threads=64)
mma_kq_nbar.arrive_and_wait()

# Sequence loop
p_token = False

@@ -1273,7 +1302,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
if s < iters_s - 1:
for _ in cutlass.range_constexpr(tiles_dk):
cvt_consumer.advance()
cute.arch.barrier(barrier_id=mma_kq_nbar_id, number_of_threads=64)
mma_kq_nbar.arrive_and_wait()
p_token = p_consumer.try_wait()

# BMM2

@@ -1286,9 +1315,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
v_handle = cvt_consumer.wait_and_advance(v_token)
# Signal BMM1 to start
if is_last_iter:
cute.arch.barrier_arrive(
barrier_id=mma_vp_nbar_id, number_of_threads=64
)
mma_vp_nbar.arrive()
for mma_k in cutlass.range_constexpr(tAtV_cvt.shape[2]):
cute.gemm(
tiled_mma_vp,

@@ -1325,7 +1352,9 @@ class MixedInputFusedMultiHeadAttentionDecode:
tmem_load_atom_s = cute.make_copy_atom(
tcgen05.Ld32x32bOp(tmem_op_repeat), acc_dtype
)
tmem_load_s = tcgen05.make_tmem_copy(tmem_load_atom_s, tCtS[mma_dice + (0,)])
tmem_load_s = tcgen05.make_tmem_copy(
tmem_load_atom_s, tCtS[mma_dice + (0,)]
)
thr_load_s = tmem_load_s.get_slice(warpgroup_tidx)

tmem_store_atom_o = cute.make_copy_atom(

@@ -1356,7 +1385,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
# Partition colmax and initialize in RF
tSsM = thr_load_s.partition_D(tCsM) # (CPY, #CPY_MMA, #CPY_M, #CPY_N)
tSrM_prev = cute.make_rmem_tensor_like(tSsM)
tSrM_prev.fill(-Float32.inf)
tSrM_prev.fill(-cutlass.Float32.inf)

# Partition colsum and initialize in RF
# Each thread maintains a local colsum in RF, smem reduction happens after loop

@@ -1364,7 +1393,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
tCsL
) # (CPY, #CPY_MMA, #CPY_M, #CPY_N, WARPS)
tSrL = cute.make_rmem_tensor_like(tSsL[cpy_dice + (0,)])
tSrL.fill(Float32(0))
tSrL.fill(cutlass.Float32(0))

assert warp_threads >= cute.size(tSsM)

@@ -1385,17 +1414,15 @@ class MixedInputFusedMultiHeadAttentionDecode:
)

# Initialize O
tSrO.fill(Float32(0))
tSrO.fill(cutlass.Float32(0))
cute.copy(thr_store_o, tSrO, tStO)

# Initialize colsum and colmax in smem and wait
if warpgroup_widx == 0 and lane_store_max:
tSsM[lane_idx] = -Float32.inf
tSsM[lane_idx] = -cutlass.Float32.inf
if warpgroup_widx == 1 and lane_store_max:
tSsL[lane_idx] = Float32(0)
cute.arch.barrier(
barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads
)
tSsL[lane_idx] = cutlass.Float32(0)
softmax_nbar.arrive_and_wait()

#
# Sequence loop

@@ -1410,20 +1437,18 @@ class MixedInputFusedMultiHeadAttentionDecode:

# Reduce colmax in warp RF
tSrM = cute.make_rmem_tensor_like(tSsM)
tSrM_lane = Float32(0) # Avoid dynamic register indexing
tSrM_lane = cutlass.Float32(0) # Avoid dynamic register indexing
for i in cutlass.range_constexpr(cute.size(tSrS)):
tSrM[i] = cute.arch.warp_redux_sync(tSrS[i], kind="fmax", nan=True)
tSrM[i] = warp_fmax(tSrS[i])
if i == lane_idx:
tSrM_lane = tSrM[i]

# Reduce colmax in smem
if lane_store_max:
self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)
smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane)

# Wait for colmax then load
cute.arch.barrier(
barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads
)
softmax_nbar.arrive_and_wait()
cute.autovec_copy(tSsM, tSrM)

# Compute online softmax
@@ -1504,7 +1529,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
#
|
||||
|
||||
# Reduce colsum in warp RF
|
||||
tSrL_lane = Float32(0.0)
|
||||
tSrL_lane = cutlass.Float32(0.0)
|
||||
for i in cutlass.range_constexpr(cute.size(tSrL)):
|
||||
tSrL[i] = cute.arch.warp_reduction_sum(tSrL[i])
|
||||
if i == lane_idx:
|
||||
@@ -1515,16 +1540,12 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
tSsL[cpy_dice + (warpgroup_widx,)][lane_idx] = tSrL_lane
|
||||
|
||||
# Wait for colsum
|
||||
cute.arch.barrier(
|
||||
barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads
|
||||
)
|
||||
softmax_nbar.arrive_and_wait()
|
||||
|
||||
if warpgroup_widx == 0 and lane_store_max and inbound_hr:
|
||||
# Load colsum and colmax
|
||||
sL_lane_wg = sL[0, lane_idx, None]
|
||||
sL_lane = (
|
||||
sL_lane_wg[0] + sL_lane_wg[1] + sL_lane_wg[2] + sL_lane_wg[3]
|
||||
)
|
||||
sL_lane = sL_lane_wg[0] + sL_lane_wg[1] + sL_lane_wg[2] + sL_lane_wg[3]
|
||||
sM_lane = sM[0, lane_idx]
|
||||
|
||||
# Scale colmax
|
||||
@@ -1533,7 +1554,7 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
# Store colsum and colmax
|
||||
gL_partial[lane_idx] = sL_lane
|
||||
gM_partial[lane_idx] = sM_lane
|
||||
self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane)
|
||||
|
||||
o_handle = o_consumer.wait_and_advance()
|
||||
cute.copy(thr_load_s, tStO, tSrO)
|
||||
@@ -1563,16 +1584,16 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
o_partial: cute.Tensor,
|
||||
m_partial: cute.Tensor,
|
||||
l_partial: cute.Tensor,
|
||||
scale_o: Float32,
|
||||
scale_o: cutlass.Float32,
|
||||
):
|
||||
d_blk_idx, coord_h, coord_b = cute.arch.block_idx()
|
||||
d_per_blk, _, _ = cute.arch.block_dim()
|
||||
d_idx, _, _ = cute.arch.thread_idx()
|
||||
coord_d = d_blk_idx * d_per_blk + d_idx
|
||||
|
||||
o_dhb = Float32(0)
|
||||
o_dhb = cutlass.Float32(0)
|
||||
m_hb = m[coord_h, coord_b]
|
||||
l_hb = Float32(0)
|
||||
l_hb = cutlass.Float32(0)
|
||||
|
||||
o_partial_dhb = o_partial[coord_d, coord_h, coord_b, None]
|
||||
m_partial_hb = m_partial[coord_h, coord_b, None]
|
||||
@@ -1591,44 +1612,6 @@ class MixedInputFusedMultiHeadAttentionDecode:
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def smem_fmax(ptr: Pointer, val: Float32):
|
||||
# https://stackoverflow.com/a/72461459
|
||||
# Works with canonical NaN which warp_redux_sync(kind="fmax") should return
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[ptr.llvm_ptr, val.ir_value()],
|
||||
"""{\n\t
|
||||
.reg .pred p;\n\t
|
||||
setp.lt.s32 p, $1, 0x0;
|
||||
@p red.relaxed.shared::cta.min.u32 [$0], $1;\n\t
|
||||
@!p red.relaxed.shared::cta.max.s32 [$0], $1;\n\t
|
||||
}\n\t""",
|
||||
"r,r",
|
||||
has_side_effects=True,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@cute.jit
|
||||
def gmem_fmax(ptr: Pointer, val: Float32):
|
||||
llvm.inline_asm(
|
||||
None,
|
||||
[ptr.llvm_ptr, val.ir_value()],
|
||||
"""{\n\t
|
||||
.reg .pred p;\n\t
|
||||
setp.lt.s32 p, $1, 0x0;
|
||||
@p red.relaxed.gpu.global.min.u32 [$0], $1;\n\t
|
||||
@!p red.relaxed.gpu.global.max.s32 [$0], $1;\n\t
|
||||
}\n\t""",
|
||||
"l,r",
|
||||
has_side_effects=True,
|
||||
is_align_stack=False,
|
||||
asm_dialect=llvm.AsmDialect.AD_ATT,
|
||||
)
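The predicated min/max trick used by smem_fmax and gmem_fmax above relies on how IEEE-754 bit patterns order: non-negative floats compare like signed 32-bit integers, while negative floats compare in reverse as unsigned integers. The stand-alone Python sketch below is illustrative only (it is not part of the kernel) and emulates the same dispatch rule to show that it yields a correct floating-point max:

# Illustrative emulation of the predicated red.min.u32 / red.max.s32 dispatch
# (plain Python, not kernel code).
import random
import struct

def f2u(x):
    # Raw 32-bit pattern of a float32 value.
    return struct.unpack("<I", struct.pack("<f", x))[0]

def u2f(u):
    return struct.unpack("<f", struct.pack("<I", u))[0]

def as_s32(u):
    return u - (1 << 32) if u & 0x80000000 else u

def emulated_red_fmax(dst_bits, val):
    v_bits = f2u(val)
    if val < 0.0:
        return min(dst_bits, v_bits)  # negative operand: red.min.u32
    return dst_bits if as_s32(dst_bits) >= as_s32(v_bits) else v_bits  # red.max.s32

vals = [u2f(f2u(random.uniform(-4.0, 4.0))) for _ in range(1000)]  # float32-exact samples
acc = f2u(-float("inf"))
for v in vals:
    acc = emulated_red_fmax(acc, v)
assert u2f(acc) == max(vals)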
|
||||
|
||||
|
||||
def run(
|
||||
batches: int = 1,
|
||||
@@ -1638,10 +1621,10 @@ def run(
|
||||
headdim: int = 512,
|
||||
block_scaledim: int = 512,
|
||||
kv_splits: int = 0,
|
||||
q_dtype: Type[cutlass.Numeric] = BFloat16,
|
||||
kv_dtype: Type[cutlass.Numeric] = Int8,
|
||||
o_dtype: Type[cutlass.Numeric] = BFloat16,
|
||||
acc_dtype: Type[cutlass.Numeric] = Float32,
|
||||
q_dtype: Type[cutlass.Numeric] = cutlass.BFloat16,
|
||||
kv_dtype: Type[cutlass.Numeric] = cutlass.Int8,
|
||||
o_dtype: Type[cutlass.Numeric] = cutlass.BFloat16,
|
||||
acc_dtype: Type[cutlass.Numeric] = cutlass.Float32,
|
||||
tolerance: float = 0.1,
|
||||
scale_q: float = 1.0,
|
||||
scale_o: float = 1.0,
|
||||
@@ -1652,7 +1635,7 @@ def run(
|
||||
use_cold_l2: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
print(f"Running Blackwell SM100 Mixed Input FMHA Decode test with:")
|
||||
print("Running Blackwell SM100 Mixed Input FMHA Decode test with:")
|
||||
print(f"\tbatches: {batches}, seqlen: {seqlen}")
|
||||
print(f"\theads_q: {heads_q}, heads_k: {heads_k}")
|
||||
print(f"\theaddim: {headdim}, block_scaledim: {block_scaledim}")
|
||||
@@ -1709,9 +1692,7 @@ def run(
|
||||
|
||||
problem_shape = (batches, heads_q, heads_k, seqlen_k, headdim)
|
||||
|
||||
fmha.can_implement(
|
||||
problem_shape, kv_splits, q_dtype, kv_dtype, o_dtype, acc_dtype
|
||||
)
|
||||
fmha.can_implement(problem_shape, kv_splits, q_dtype, kv_dtype, o_dtype, acc_dtype)
|
||||
|
||||
#
|
||||
# Allocate Tensors
|
||||
@@ -1719,24 +1700,24 @@ def run(
|
||||
torch.manual_seed(1111)
|
||||
|
||||
def create_tensor(shape, dtype, init=True):
|
||||
init_type = cutlass.torch.TensorInitType.RANDOM
|
||||
init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2)
|
||||
init_type = cutlass_torch.TensorInitType.RANDOM
|
||||
init_config = cutlass_torch.RandomInitConfig(min_val=-2, max_val=2)
|
||||
|
||||
if init is False or init is None:
|
||||
init_type = cutlass.torch.TensorInitType.SKIP
|
||||
init_type = cutlass_torch.TensorInitType.SKIP
|
||||
init_config = None
|
||||
elif isinstance(init, int) or isinstance(init, float):
|
||||
init_type = cutlass.torch.TensorInitType.SCALAR
|
||||
init_config = cutlass.torch.ScalarInitConfig(value=init)
|
||||
init_type = cutlass_torch.TensorInitType.SCALAR
|
||||
init_config = cutlass_torch.ScalarInitConfig(value=init)
|
||||
elif isinstance(init, tuple) or isinstance(init, list):
|
||||
if len(init) == 2:
|
||||
init_type = cutlass.torch.TensorInitType.RANDOM
|
||||
init_config = cutlass.torch.RandomInitConfig(
|
||||
init_type = cutlass_torch.TensorInitType.RANDOM
|
||||
init_config = cutlass_torch.RandomInitConfig(
|
||||
min_val=init[0], max_val=init[1]
|
||||
)
|
||||
if len(init) == 3:
|
||||
init_type = cutlass.torch.TensorInitType.GAUSSIAN
|
||||
init_config = cutlass.torch.RandomInitConfig(
|
||||
init_type = cutlass_torch.TensorInitType.GAUSSIAN
|
||||
init_config = cutlass_torch.RandomInitConfig(
|
||||
mean=init[0], std=init[1], scale=init[2]
|
||||
)
|
||||
|
||||
@@ -1809,7 +1790,7 @@ def run(
|
||||
scale_qs,
|
||||
scale_o,
|
||||
current_stream,
|
||||
options=f"--opt-level 2",
|
||||
options="--opt-level 2",
|
||||
)
|
||||
print("Finished Compiling")
|
||||
|
||||
@@ -1980,7 +1961,7 @@ if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str):
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
@@ -2052,7 +2033,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--acc_dtype",
|
||||
type=cutlass.dtype,
|
||||
default=Float32,
|
||||
default=cutlass.Float32,
|
||||
help="accumulator/reduction data type",
|
||||
)
|
||||
|
||||
@@ -47,10 +47,10 @@ from cutlass.cute.typing import Int32, Int64, Float32
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
|
||||
|
||||
from helpers import fmha_helpers as fmha_utils
|
||||
from blackwell.mixed_input_fmha import prefill_helpers as prefill_utils
|
||||
from cute.blackwell.kernel.attention.mixed_input_fmha import prefill_helpers as prefill_utils
|
||||
|
||||
|
||||
class MixedInputFusedMultiHeadAttentionPrefillD256:
|
||||
@@ -49,10 +49,10 @@ from cutlass.cute.typing import Int32, Int64, Float32
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
|
||||
|
||||
from helpers import fmha_helpers as fmha_utils
|
||||
from blackwell.mixed_input_fmha import prefill_helpers as prefill_utils
|
||||
from cute.blackwell.kernel.attention.mixed_input_fmha import prefill_helpers as prefill_utils
|
||||
|
||||
|
||||
class MixedInputFusedMultiHeadAttentionPrefillD512:
|
||||
@@ -51,9 +51,9 @@ from cutlass.cutlass_dsl import BaseDSL
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
|
||||
|
||||
from blackwell.mla.mla_helpers import (
|
||||
from cute.blackwell.kernel.attention.mla.mla_helpers import (
|
||||
ceil_div,
|
||||
MAX_SPLITS,
|
||||
LOG2_E,
|
||||
@@ -51,9 +51,9 @@ from cutlass.cutlass_dsl import BaseDSL
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../../.."))
|
||||
|
||||
from blackwell.mla.mla_helpers import (
|
||||
from cute.blackwell.kernel.attention.mla.mla_helpers import (
|
||||
ceil_div,
|
||||
MAX_SPLITS,
|
||||
LOG2_E,
|
||||
@@ -47,9 +47,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
|
||||
|
||||
from blackwell.mixed_input_gemm.mixed_input_host_utils import (
|
||||
from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import (
|
||||
create_tensors_for_contiguous_grouped_mixed_input_gemm as create_tensors,
|
||||
run_contiguous_grouped_ref_and_compare as run_ref_and_compare,
|
||||
)
|
||||
@@ -48,9 +48,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
|
||||
|
||||
from blackwell.mixed_input_gemm.mixed_input_host_utils import (
|
||||
from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import (
|
||||
create_tensors_for_contiguous_grouped_mixed_input_gemm as create_tensors,
|
||||
run_contiguous_grouped_ref_and_compare as run_ref_and_compare,
|
||||
)
|
||||
@@ -47,9 +47,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../.."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
|
||||
|
||||
from blackwell.mixed_input_gemm.mixed_input_host_utils import (
|
||||
from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import (
|
||||
create_tensors_for_batched_mixed_input_gemm as create_tensors,
|
||||
run_batched_mixed_input_ref_and_compare as run_ref_and_compare,
|
||||
)
|
||||
@@ -0,0 +1,695 @@
|
||||
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
MoE Persistent Tile Scheduler
|
||||
|
||||
A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations.
|
||||
This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo
|
||||
(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile.
|
||||
|
||||
Scenarios:
|
||||
- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate)
|
||||
- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden)
|
||||
|
||||
Key design principle:
|
||||
- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic)
|
||||
- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension
|
||||
- TMA descriptor management is handled by OnlineTensormapDescCreator
|
||||
- The kernel orchestrates all three components
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Literal
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import (
|
||||
Boolean,
|
||||
Int32,
|
||||
Integer,
|
||||
extract_mlir_values,
|
||||
new_from_mlir_values,
|
||||
const_expr,
|
||||
dsl_user_op,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
|
||||
# =============================================================================
|
||||
# Work Tile Info
|
||||
# =============================================================================
|
||||
|
||||
class MoEWorkTileInfo:
|
||||
"""
|
||||
Work tile information for MoE scheduler.
|
||||
|
||||
Contains CTA-level tile information for executor warps:
|
||||
- expert_idx: Which expert (-1 means invalid/done)
|
||||
- tile_m_idx: CTA tile index along GEMM M dimension
|
||||
- tile_n_idx: CTA tile index along GEMM N dimension
|
||||
- k_tile_cnt: Number of CTA tiles along K dimension
|
||||
|
||||
Note: These are CTA-level indices, not cluster-level.
|
||||
tile_l_idx is always 0 for MoE, executor can hardcode it.
|
||||
|
||||
For 2Dx3D (Forward):
|
||||
M = tokens_i (dynamic), N = intermediate (fixed), K = hidden (fixed)
|
||||
|
||||
For 2Dx2D (Backward):
|
||||
M = intermediate (fixed), N = hidden (fixed), K = tokens_i (dynamic)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
expert_idx: Int32, # -1 means invalid tile
|
||||
tile_m_idx: Int32,
|
||||
tile_n_idx: Int32,
|
||||
k_tile_cnt: Int32,
|
||||
):
|
||||
self.expert_idx = expert_idx
|
||||
self.tile_m_idx = tile_m_idx
|
||||
self.tile_n_idx = tile_n_idx
|
||||
self.k_tile_cnt = k_tile_cnt
|
||||
|
||||
@property
|
||||
def is_valid_tile(self) -> Boolean:
|
||||
"""Check if this is a valid work tile (expert_idx >= 0)."""
|
||||
return self.expert_idx >= Int32(0)
|
||||
|
||||
def __extract_mlir_values__(self) -> List[ir.Value]:
|
||||
values = extract_mlir_values(self.expert_idx)
|
||||
values.extend(extract_mlir_values(self.tile_m_idx))
|
||||
values.extend(extract_mlir_values(self.tile_n_idx))
|
||||
values.extend(extract_mlir_values(self.k_tile_cnt))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo":
|
||||
assert len(values) == 4
|
||||
return MoEWorkTileInfo(
|
||||
expert_idx=new_from_mlir_values(self.expert_idx, [values[0]]),
|
||||
tile_m_idx=new_from_mlir_values(self.tile_m_idx, [values[1]]),
|
||||
tile_n_idx=new_from_mlir_values(self.tile_n_idx, [values[2]]),
|
||||
k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [values[3]]),
|
||||
)
|
||||
|
||||
def to_rmem_tensor(self):
|
||||
"""Pack work tile info fields into an rmem tensor of shape (4,) for vectorized smem copy."""
|
||||
rmem = cute.make_rmem_tensor((4,), Int32)
|
||||
rmem[0] = self.expert_idx
|
||||
rmem[1] = self.tile_m_idx
|
||||
rmem[2] = self.tile_n_idx
|
||||
rmem[3] = self.k_tile_cnt
|
||||
return rmem
|
||||
|
||||
@staticmethod
|
||||
def from_rmem_tensor(rmem) -> "MoEWorkTileInfo":
|
||||
"""Unpack work tile info from an rmem tensor of shape (4,)."""
|
||||
return MoEWorkTileInfo(
|
||||
expert_idx=rmem[0], # type: ignore[arg-type]
|
||||
tile_m_idx=rmem[1], # type: ignore[arg-type]
|
||||
tile_n_idx=rmem[2], # type: ignore[arg-type]
|
||||
k_tile_cnt=rmem[3], # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scheduler Parameters
|
||||
# =============================================================================
|
||||
|
||||
class MoEStaticSchedulerParams:
|
||||
"""
|
||||
Parameters for MoE tile scheduler.
|
||||
|
||||
Uses unified semantics for both scenarios:
|
||||
- expert_shape: (expert_cnt, intermediate, hidden)
|
||||
|
||||
For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert
|
||||
For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert
|
||||
|
||||
Tile hierarchy:
|
||||
- cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k)
|
||||
- cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n)
|
||||
- cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape
|
||||
|
||||
This class is used both on host (for grid shape calculation) and on device
|
||||
(stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk,
|
||||
cluster_shape_mn) are NOT serialized to MLIR values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scenario: Literal["2Dx3D", "2Dx2D"],
|
||||
expert_shape: Tuple[int | Int32, int | Int32, int | Int32], # (expert_cnt, intermediate, hidden)
|
||||
cta_tile_shape_mnk: Tuple[int, int, int], # (tile_m, tile_n, tile_k)
|
||||
cluster_shape_mn: Tuple[int, int], # (cluster_m, cluster_n)
|
||||
):
|
||||
self.scenario = scenario
|
||||
e, i, h = expert_shape
|
||||
self.expert_cnt = e if isinstance(e, Int32) else Int32(e)
|
||||
self.intermediate = i if isinstance(i, Int32) else Int32(i)
|
||||
self.hidden = h if isinstance(h, Int32) else Int32(h)
|
||||
self.cta_tile_shape_mnk = cta_tile_shape_mnk
|
||||
self.cluster_shape_mn = cluster_shape_mn
|
||||
|
||||
@property
|
||||
def cluster_tile_m(self) -> int:
|
||||
"""Cluster tile size along M = cta_tile_m * cluster_m."""
|
||||
return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0]
|
||||
|
||||
@property
|
||||
def cluster_tile_n(self) -> int:
|
||||
"""Cluster tile size along N = cta_tile_n * cluster_n."""
|
||||
return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1]
|
||||
|
||||
@property
|
||||
def cta_tile_k(self) -> int:
|
||||
"""CTA tile size along K (same as cluster since cluster_k = 1)."""
|
||||
return self.cta_tile_shape_mnk[2]
|
||||
|
||||
def __extract_mlir_values__(self) -> List[ir.Value]:
|
||||
"""Only serialize runtime values, not codegen-time constants."""
|
||||
values = []
|
||||
values.extend(extract_mlir_values(self.expert_cnt))
|
||||
values.extend(extract_mlir_values(self.intermediate))
|
||||
values.extend(extract_mlir_values(self.hidden))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams":
|
||||
assert len(values) == 3
|
||||
return MoEStaticSchedulerParams(
|
||||
scenario=self.scenario,
|
||||
expert_shape=(
|
||||
new_from_mlir_values(self.expert_cnt, [values[0]]),
|
||||
new_from_mlir_values(self.intermediate, [values[1]]),
|
||||
new_from_mlir_values(self.hidden, [values[2]]),
|
||||
),
|
||||
cta_tile_shape_mnk=self.cta_tile_shape_mnk,
|
||||
cluster_shape_mn=self.cluster_shape_mn,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_grid_shape(
|
||||
params: "MoEStaticSchedulerParams",
|
||||
max_active_clusters: int,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Compute grid shape for kernel launch.
|
||||
|
||||
Since host doesn't know token distribution across experts,
|
||||
we launch max_active_clusters and let device-side scheduler
|
||||
determine which tiles are valid.
|
||||
"""
|
||||
return (
|
||||
params.cluster_shape_mn[0],
|
||||
params.cluster_shape_mn[1],
|
||||
max_active_clusters,
|
||||
)
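For example, with the illustrative params sketched above (cluster_shape_mn = (2, 1)) and a hypothetical 148 active clusters on the device:

# The grid z dimension carries the persistent clusters (values illustrative).
grid = MoEStaticSchedulerParams.get_grid_shape(params, max_active_clusters=148)
assert grid == (2, 1, 148)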
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scheduler (Device-side)
|
||||
# =============================================================================
|
||||
|
||||
class MoEStaticPersistentTileScheduler:
|
||||
"""
|
||||
Persistent tile scheduler specialized for MoE grouped GEMM.
|
||||
|
||||
This scheduler is ONLY responsible for tile iteration. It does NOT know
|
||||
about tensor types, TMA descriptors, or domain conversion. Those concerns
|
||||
are handled by MoESchedExtension and OnlineTensormapDescCreator respectively.
|
||||
|
||||
Architecture:
|
||||
- Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info
|
||||
- Executor warps: Read work_tile_info from smem, use MoESchedExtension for
|
||||
domain conversion and TMA desc selection
|
||||
|
||||
The scheduler handles:
|
||||
- 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden)
|
||||
- 2Dx2D: Fixed M (intermediate) and N (hidden), dynamic K per expert (reduction axis)
|
||||
|
||||
Usage (Scheduler warp):
|
||||
scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim)
|
||||
work_tile_info = scheduler.initial_work_tile_info()
|
||||
# Broadcast work_tile_info to smem...
|
||||
|
||||
while work_tile_info.is_valid_tile:
|
||||
# ... do work ...
|
||||
work_tile_info = scheduler.advance_to_next_work()
|
||||
# Broadcast work_tile_info to smem...
|
||||
|
||||
Usage (Executor warps - via MoESchedExtension):
|
||||
# Read work_tile_info from smem...
|
||||
real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
|
||||
real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
|
||||
real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes)
|
||||
params: MoEStaticSchedulerParams,
|
||||
# Runtime tensor for scheduling
|
||||
offs: cute.Tensor, # (experts,) cumsum of token counts
|
||||
# Scheduling state
|
||||
num_persistent_clusters: Int32,
|
||||
current_work_linear_idx: Int32,
|
||||
cta_id_in_cluster: cute.Coord,
|
||||
# Expert tracking state (for O(1) advance within same expert)
|
||||
current_expert_idx: Int32,
|
||||
expert_tile_start: Int32, # cumsum of tiles before current expert
|
||||
expert_tile_end: Int32, # cumsum of tiles including current expert
|
||||
):
|
||||
self.params = params
|
||||
self.offs = offs
|
||||
self.num_persistent_clusters = num_persistent_clusters
|
||||
self._current_work_linear_idx = current_work_linear_idx
|
||||
self.cta_id_in_cluster = cta_id_in_cluster
|
||||
# Expert tracking
|
||||
self.current_expert_idx = current_expert_idx
|
||||
self.expert_tile_start = expert_tile_start
|
||||
self.expert_tile_end = expert_tile_end
|
||||
|
||||
# =========================================================================
|
||||
# Convenience accessors for params
|
||||
# =========================================================================
|
||||
|
||||
@property
|
||||
def scenario(self) -> Literal["2Dx3D", "2Dx2D"]:
|
||||
return self.params.scenario
|
||||
|
||||
@property
|
||||
def expert_cnt(self) -> Int32:
|
||||
return self.params.expert_cnt
|
||||
|
||||
@property
|
||||
def intermediate(self) -> Int32:
|
||||
return self.params.intermediate
|
||||
|
||||
@property
|
||||
def hidden(self) -> Int32:
|
||||
return self.params.hidden
|
||||
|
||||
@property
|
||||
def cta_tile_shape_mnk(self) -> Tuple[int, int, int]:
|
||||
return self.params.cta_tile_shape_mnk
|
||||
|
||||
@property
|
||||
def cluster_shape_mn(self) -> Tuple[int, int]:
|
||||
return self.params.cluster_shape_mn
|
||||
|
||||
@property
|
||||
def cluster_tile_m(self) -> int:
|
||||
return self.params.cluster_tile_m
|
||||
|
||||
@property
|
||||
def cluster_tile_n(self) -> int:
|
||||
return self.params.cluster_tile_n
|
||||
|
||||
@property
|
||||
def cta_tile_k(self) -> int:
|
||||
return self.params.cta_tile_k
|
||||
|
||||
# =========================================================================
|
||||
# MLIR value serialization (for SSA value passing in device code)
|
||||
# =========================================================================
|
||||
|
||||
def __extract_mlir_values__(self) -> List[ir.Value]:
|
||||
values = []
|
||||
# Params (only runtime values are extracted)
|
||||
values.extend(extract_mlir_values(self.params))
|
||||
# Runtime tensor for scheduling
|
||||
values.extend(extract_mlir_values(self.offs))
|
||||
# Scheduling state
|
||||
values.extend(extract_mlir_values(self.num_persistent_clusters))
|
||||
values.extend(extract_mlir_values(self._current_work_linear_idx))
|
||||
values.extend(extract_mlir_values(self.cta_id_in_cluster))
|
||||
# Expert tracking state
|
||||
values.extend(extract_mlir_values(self.current_expert_idx))
|
||||
values.extend(extract_mlir_values(self.expert_tile_start))
|
||||
values.extend(extract_mlir_values(self.expert_tile_end))
|
||||
return values
|
||||
|
||||
def __new_from_mlir_values__(
|
||||
self, values: List[ir.Value]
|
||||
) -> "MoEStaticPersistentTileScheduler":
|
||||
idx = 0
|
||||
|
||||
# Params (3 values: expert_cnt, intermediate, hidden)
|
||||
new_params = new_from_mlir_values(self.params, values[idx:idx + 3])
|
||||
idx += 3
|
||||
|
||||
# Runtime tensor for scheduling (variable size)
|
||||
offs_len = len(extract_mlir_values(self.offs))
|
||||
new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len])
|
||||
idx += offs_len
|
||||
|
||||
# Scheduling state
|
||||
new_num_persistent_clusters = new_from_mlir_values(
|
||||
self.num_persistent_clusters, [values[idx]]
|
||||
)
|
||||
idx += 1
|
||||
new_current_work_linear_idx = new_from_mlir_values(
|
||||
self._current_work_linear_idx, [values[idx]]
|
||||
)
|
||||
idx += 1
|
||||
|
||||
# cta_id_in_cluster (3 values for Coord)
|
||||
new_cta_id_in_cluster = new_from_mlir_values(
|
||||
self.cta_id_in_cluster, values[idx:idx + 3]
|
||||
)
|
||||
idx += 3
|
||||
|
||||
# Expert tracking state
|
||||
new_current_expert_idx = new_from_mlir_values(
|
||||
self.current_expert_idx, [values[idx]]
|
||||
)
|
||||
idx += 1
|
||||
new_expert_tile_start = new_from_mlir_values(
|
||||
self.expert_tile_start, [values[idx]]
|
||||
)
|
||||
idx += 1
|
||||
new_expert_tile_end = new_from_mlir_values(
|
||||
self.expert_tile_end, [values[idx]]
|
||||
)
|
||||
idx += 1
|
||||
|
||||
return MoEStaticPersistentTileScheduler(
|
||||
params=new_params,
|
||||
offs=new_offs,
|
||||
num_persistent_clusters=new_num_persistent_clusters,
|
||||
current_work_linear_idx=new_current_work_linear_idx,
|
||||
cta_id_in_cluster=new_cta_id_in_cluster,
|
||||
current_expert_idx=new_current_expert_idx,
|
||||
expert_tile_start=new_expert_tile_start,
|
||||
expert_tile_end=new_expert_tile_end,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Factory method
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
@dsl_user_op
|
||||
def create(
|
||||
params: MoEStaticSchedulerParams,
|
||||
offs: cute.Tensor,
|
||||
block_idx: Tuple[Integer, Integer, Integer],
|
||||
grid_dim: Tuple[Integer, Integer, Integer],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> "MoEStaticPersistentTileScheduler":
|
||||
"""
|
||||
Create a MoE persistent tile scheduler.
|
||||
|
||||
:param params: Scheduler parameters (from host)
|
||||
:param offs: Cumsum tensor of token counts per expert, shape (experts,)
|
||||
:param block_idx: CUDA block index
|
||||
:param grid_dim: CUDA grid dimensions
|
||||
"""
|
||||
num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
|
||||
params.cluster_shape_mn, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
bidx, bidy, bidz = block_idx
|
||||
current_work_linear_idx = Int32(bidz)
|
||||
|
||||
cta_id_in_cluster = (
|
||||
Int32(bidx % params.cluster_shape_mn[0]),
|
||||
Int32(bidy % params.cluster_shape_mn[1]),
|
||||
Int32(0),
|
||||
)
|
||||
|
||||
# Initialize expert tracking to "before expert 0"
|
||||
# The first call to _get_work_tile_for_linear_idx will advance to the correct expert
|
||||
current_expert_idx = Int32(0)
|
||||
expert_tile_start = Int32(0)
|
||||
expert_tile_end = Int32(0) # Will be computed on first access
|
||||
|
||||
return MoEStaticPersistentTileScheduler(
|
||||
params=params,
|
||||
offs=offs,
|
||||
num_persistent_clusters=num_persistent_clusters,
|
||||
current_work_linear_idx=current_work_linear_idx,
|
||||
cta_id_in_cluster=cta_id_in_cluster,
|
||||
current_expert_idx=current_expert_idx,
|
||||
expert_tile_start=expert_tile_start,
|
||||
expert_tile_end=expert_tile_end,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Tile iteration methods
|
||||
# =========================================================================
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
|
||||
"""Get the initial work tile info."""
|
||||
return self._get_work_tile_for_linear_idx(
|
||||
self._current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
|
||||
"""Advance to the next work tile and return its info."""
|
||||
self._current_work_linear_idx += self.num_persistent_clusters
|
||||
return self._get_work_tile_for_linear_idx(
|
||||
self._current_work_linear_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def _get_work_tile_for_linear_idx(
|
||||
self,
|
||||
cluster_linear_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None
|
||||
) -> MoEWorkTileInfo:
|
||||
"""
|
||||
Convert a linear cluster index to MoEWorkTileInfo.
|
||||
|
||||
Uses cached expert tracking state for O(1) fast path when staying
|
||||
within the same expert. Advances expert state when needed.
|
||||
|
||||
Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range.
|
||||
"""
|
||||
# Ensure expert tracking is initialized and up-to-date
|
||||
self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip)
|
||||
|
||||
# Check if valid (still within expert range after advancing)
|
||||
is_valid = self.current_expert_idx < self.expert_cnt
|
||||
|
||||
work_tile_info = MoEWorkTileInfo(
|
||||
expert_idx=Int32(-1),
|
||||
tile_m_idx=Int32(0),
|
||||
tile_n_idx=Int32(0),
|
||||
k_tile_cnt=Int32(0),
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
# Compute local cluster tile indices within current expert
|
||||
local_idx = cluster_linear_idx - self.expert_tile_start
|
||||
cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx(
|
||||
local_idx, self.current_expert_idx, loc=loc, ip=ip
|
||||
)
|
||||
|
||||
# Convert cluster tile indices to CTA tile indices
|
||||
# cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster
|
||||
cta_tile_m_idx = (
|
||||
cluster_tile_m_idx * self.cluster_shape_mn[0]
|
||||
+ self.cta_id_in_cluster[0] # type: ignore[index]
|
||||
)
|
||||
cta_tile_n_idx = (
|
||||
cluster_tile_n_idx * self.cluster_shape_mn[1]
|
||||
+ self.cta_id_in_cluster[1] # type: ignore[index]
|
||||
)
|
||||
# Compute k_tile_cnt
|
||||
k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip)
|
||||
|
||||
work_tile_info = MoEWorkTileInfo(
|
||||
expert_idx=self.current_expert_idx,
|
||||
tile_m_idx=cta_tile_m_idx,
|
||||
tile_n_idx=cta_tile_n_idx,
|
||||
k_tile_cnt=k_tile_cnt,
|
||||
)
|
||||
return work_tile_info
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def _advance_expert_to_contain(
|
||||
self,
|
||||
cluster_linear_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> None:
|
||||
"""
|
||||
Advance expert tracking state until current expert contains cluster_linear_idx,
|
||||
or we run out of experts.
|
||||
|
||||
Fast path: If already in correct expert, no work needed.
|
||||
"""
|
||||
# Initialize expert_tile_end if this is the first call (expert_tile_end == 0)
|
||||
if self.expert_tile_end == Int32(0):
|
||||
tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip)
|
||||
self.expert_tile_end = tiles_for_expert_0
|
||||
|
||||
# Advance until cluster_linear_idx < expert_tile_end or no more experts
|
||||
while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt:
|
||||
self.current_expert_idx = self.current_expert_idx + 1
|
||||
self.expert_tile_start = self.expert_tile_end
|
||||
|
||||
if self.current_expert_idx < self.expert_cnt:
|
||||
tiles_for_expert = self._compute_tiles_for_expert(
|
||||
self.current_expert_idx, loc=loc, ip=ip
|
||||
)
|
||||
self.expert_tile_end = self.expert_tile_end + tiles_for_expert
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def _compute_tiles_for_expert(
|
||||
self,
|
||||
expert_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Int32:
|
||||
"""Compute total cluster tiles for a given expert."""
|
||||
if const_expr(self.scenario == "2Dx2D"):
|
||||
# Fixed M=hidden, N=intermediate
|
||||
cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
|
||||
cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
|
||||
return cluster_tile_m_cnt * cluster_tile_n_cnt
|
||||
else: # 2Dx3D
|
||||
# Variable M (tokens), fixed N
|
||||
tokens_i = self.offs[expert_idx]
|
||||
if expert_idx > 0:
|
||||
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
|
||||
cluster_tile_m_cnt = (
|
||||
tokens_i + self.cluster_tile_m - 1 # type: ignore[operator]
|
||||
) // self.cluster_tile_m
|
||||
cluster_tile_n_cnt = (
|
||||
self.intermediate + self.cluster_tile_n - 1
|
||||
) // self.cluster_tile_n
|
||||
return cluster_tile_m_cnt * cluster_tile_n_cnt
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def _decompose_local_idx(
|
||||
self,
|
||||
local_idx: Int32,
|
||||
expert_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Int32, Int32]:
|
||||
"""
|
||||
Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx).
|
||||
|
||||
Uses "short side first" strategy: the shorter dimension changes faster.
|
||||
This maximizes overlap between adjacent clusters for better L2 cache utilization.
|
||||
|
||||
For example, if m_cnt=2, n_cnt=8:
|
||||
- N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx
|
||||
- Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ...
|
||||
"""
|
||||
# Get tile counts for M and N
|
||||
cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts(
|
||||
expert_idx, loc=loc, ip=ip
|
||||
)
|
||||
cluster_tile_m_idx = -1
|
||||
cluster_tile_n_idx = -1
|
||||
|
||||
# Short side first: shorter dimension changes faster
|
||||
# If m_cnt <= n_cnt: m is shorter, m changes faster
|
||||
# local_idx = n_idx * m_cnt + m_idx
|
||||
# If n_cnt < m_cnt: n is shorter, n changes faster
|
||||
# local_idx = m_idx * n_cnt + n_idx
|
||||
if cluster_tile_m_cnt <= cluster_tile_n_cnt:
|
||||
# M is shorter or equal, M changes faster
|
||||
cluster_tile_m_idx = local_idx % cluster_tile_m_cnt
|
||||
cluster_tile_n_idx = local_idx // cluster_tile_m_cnt
|
||||
else:
|
||||
# N is shorter, N changes faster
|
||||
cluster_tile_n_idx = local_idx % cluster_tile_n_cnt
|
||||
cluster_tile_m_idx = local_idx // cluster_tile_n_cnt
|
||||
|
||||
return (cluster_tile_m_idx, cluster_tile_n_idx)
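For reference, a plain-Python version of the same decomposition reproduces the linearization order quoted in the docstring above (illustrative only, not DSL code):

# "Short side first" decomposition, plain Python check.
def decompose(local_idx, m_cnt, n_cnt):
    if m_cnt <= n_cnt:                                # M shorter: M changes faster
        return local_idx % m_cnt, local_idx // m_cnt
    return local_idx // n_cnt, local_idx % n_cnt      # N shorter: N changes faster

print([decompose(i, 2, 8) for i in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2)]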
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def _get_cluster_tile_counts(
|
||||
self,
|
||||
expert_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Int32, Int32]:
|
||||
"""Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert."""
|
||||
if const_expr(self.scenario == "2Dx2D"):
|
||||
# Fixed M=hidden, N=intermediate
|
||||
cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
|
||||
cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
|
||||
else: # 2Dx3D
|
||||
# Variable M (tokens), fixed N
|
||||
tokens_i = self.offs[expert_idx]
|
||||
if expert_idx > 0:
|
||||
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
|
||||
cluster_tile_m_cnt = (
|
||||
tokens_i + self.cluster_tile_m - 1 # type: ignore[operator]
|
||||
) // self.cluster_tile_m
|
||||
cluster_tile_n_cnt = (
|
||||
self.intermediate + self.cluster_tile_n - 1
|
||||
) // self.cluster_tile_n
|
||||
return (cluster_tile_m_cnt, cluster_tile_n_cnt)
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def _compute_k_tile_cnt(
|
||||
self,
|
||||
expert_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Int32:
|
||||
"""
|
||||
Compute the number of K tiles for this expert.
|
||||
|
||||
2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k)
|
||||
2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k)
|
||||
"""
|
||||
if const_expr(self.scenario == "2Dx3D"):
|
||||
# K is hidden (fixed)
|
||||
return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k
|
||||
else: # 2Dx2D
|
||||
# K is tokens_i (variable per expert)
|
||||
tokens_i = self.offs[expert_idx]
|
||||
if expert_idx > cutlass.Int32(0):
|
||||
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
|
||||
return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k # type: ignore[return-value, operator]
|
||||
@@ -0,0 +1,443 @@
|
||||
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
MoE Scheduler Extension.
|
||||
|
||||
Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level
|
||||
domain conversion and TMA descriptor selection. This is the "glue" layer between:
|
||||
|
||||
- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt)
|
||||
- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace
|
||||
- Kernel: orchestrates everything
|
||||
|
||||
Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own
|
||||
MoESchedExtension subclass with kernel-specific domain conversion logic.
|
||||
|
||||
Key design principles:
|
||||
- Unified interface: get_gmem_tensor() for all tensor types
|
||||
- Free implementation: no role-based templates, each subclass writes its own logic
|
||||
- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc.
|
||||
are available as tools but not mandatory
|
||||
|
||||
Architecture:
|
||||
|
||||
Scheduler ──(produces)──> MoEWorkTileInfo
|
||||
│
|
||||
expert_idx, tile_m, tile_n, k_cnt
|
||||
│
|
||||
v
|
||||
Extension ──(uses)──> OnlineTensormapDescCreator
|
||||
│ │
|
||||
│ get_gmem_tensor() │ get_desc_ptr()
|
||||
│ prefetch_for_expert() │ construct_and_write()
|
||||
│ │
|
||||
└── internal calls ───────┘
|
||||
|
||||
Kernel (caller): the only place that knows all three exist
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal, Tuple, Union
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.typing import Pointer
|
||||
from cutlass.cutlass_dsl import Int32
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
|
||||
from blackwell.kernel.moe.moe_utils import (
|
||||
OnlineTensormapDescCreator,
|
||||
tensormap_ptr_for_copy,
|
||||
compute_expert_token_range,
|
||||
rewrite_tensor_shape,
|
||||
prefetch_tma_descriptor,
|
||||
)
|
||||
from blackwell.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MoESchedExtension(ABC):
|
||||
"""
|
||||
Abstract base class for MoE scheduler extensions.
|
||||
|
||||
Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA
|
||||
descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.)
|
||||
provides its own subclass with kernel-specific logic.
|
||||
|
||||
The extension:
|
||||
- Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval
|
||||
- Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors
|
||||
- Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors
|
||||
|
||||
Subclasses are free to add any additional attributes in __init__ (scenario,
|
||||
codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic
|
||||
per tensor_name. No role-based templates or rigid patterns are imposed.
|
||||
|
||||
Usage in kernel (caller):
|
||||
ext = ConcreteSchedExtension(tensormap_ctor, scenario=...)
|
||||
|
||||
while work_tile_info.is_valid_tile:
|
||||
real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
|
||||
real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
|
||||
# Use real_a, desc_a in cute.copy ...
|
||||
"""
|
||||
|
||||
def __init__(self, tensormap_ctor: OnlineTensormapDescCreator):
|
||||
super().__init__()
|
||||
self.tensormap_ctor = tensormap_ctor
|
||||
|
||||
@abstractmethod
|
||||
def get_gmem_tensor(
|
||||
self,
|
||||
tensor_name: str,
|
||||
gmem_tensor_in_moe_view: cute.Tensor,
|
||||
offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]],
|
||||
work_tile_info: MoEWorkTileInfo,
|
||||
) -> Tuple[cute.Tensor, "Pointer | None"]:
|
||||
"""
|
||||
Convert an MoE-view tensor to the real per-expert tensor for the
|
||||
current work tile, and return the appropriate TMA descriptor pointer.
|
||||
|
||||
The MoE-view tensor uses "fake" GEMM domain dimensions that span all
|
||||
experts (e.g., fake_m = tokens_sum). This method slices/offsets it
|
||||
to the current expert's actual region.
|
||||
|
||||
:param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa")
|
||||
:param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain
|
||||
:param offs: Either a single cumsum tensor (experts,), or a tuple of
|
||||
(offs_token, offs_padded) where offs_padded provides
|
||||
padded offsets for scale-factor domain conversion.
|
||||
:param work_tile_info: Current work tile from the scheduler
|
||||
:return: (real_tensor, tma_desc_ptr_or_none)
|
||||
- real_tensor: domain-offset and shape-rewritten tensor for this expert
|
||||
- tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy),
|
||||
or None if the caller should use the global TMA descriptor
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def prefetch_for_expert(self, expert_idx: Int32) -> None:
|
||||
"""
|
||||
Prefetch expert-wise TMA descriptors for the given expert.
|
||||
|
||||
Called when the scheduler advances to a new expert, allowing the TMA
|
||||
descriptor cache to be warmed up before the descriptors are needed.
|
||||
|
||||
:param expert_idx: Index of the expert whose descriptors to prefetch
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Grouped MM Extension
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GroupedMmSchedExtension(MoESchedExtension):
|
||||
"""
|
||||
MoE scheduler extension for grouped_mm: handles tensors a, b, c.
|
||||
|
||||
Domain conversion logic per scenario:
|
||||
|
||||
2Dx3D:
|
||||
A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
|
||||
B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
|
||||
C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc
|
||||
|
||||
2Dx2D:
|
||||
A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc
|
||||
B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc
|
||||
C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scenario: Literal["2Dx3D", "2Dx2D"],
|
||||
tensormap_ctor: OnlineTensormapDescCreator,
|
||||
):
|
||||
super().__init__(tensormap_ctor)
|
||||
self.scenario = scenario
|
||||
|
||||
@cute.jit
|
||||
def get_gmem_tensor(
|
||||
self,
|
||||
tensor_name: str,
|
||||
gmem_tensor_in_moe_view: cute.Tensor,
|
||||
offs: cute.Tensor,
|
||||
work_tile_info: MoEWorkTileInfo,
|
||||
):
|
||||
expert_idx = work_tile_info.expert_idx
|
||||
token_offset, tokens_i = compute_expert_token_range(offs, expert_idx)
|
||||
|
||||
shape = gmem_tensor_in_moe_view.shape
|
||||
c1 = cutlass.Int32(1)
|
||||
|
||||
if cutlass.const_expr(self.scenario == "2Dx3D"):
|
||||
if cutlass.const_expr(tensor_name == "a"):
|
||||
# A: (fake_m, k, 1) -> offset fake_m, global desc
|
||||
real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
|
||||
real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index]
|
||||
return (real, None)
|
||||
elif cutlass.const_expr(tensor_name == "b"):
|
||||
# B: (n, k, fake_l) -> offset fake_l, global desc
|
||||
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
|
||||
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
|
||||
return (real, None)
|
||||
elif cutlass.const_expr(tensor_name == "c"):
|
||||
# C: (fake_m, n, 1) -> expert-wise desc, no offset
|
||||
real = rewrite_tensor_shape(
|
||||
gmem_tensor_in_moe_view,
|
||||
(tokens_i, shape[1], c1), # type: ignore[index]
|
||||
)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("c", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
|
||||
elif cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
if cutlass.const_expr(tensor_name == "a"):
|
||||
# A: (m, fake_k, 1) -> expert-wise desc, no offset
|
||||
real = rewrite_tensor_shape(
|
||||
gmem_tensor_in_moe_view,
|
||||
(shape[0], tokens_i, c1), # type: ignore[index]
|
||||
)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("a", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
elif cutlass.const_expr(tensor_name == "b"):
|
||||
# B: (n, fake_k, 1) -> expert-wise desc, no offset
|
||||
real = rewrite_tensor_shape(
|
||||
gmem_tensor_in_moe_view,
|
||||
(shape[0], tokens_i, c1), # type: ignore[index]
|
||||
)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("b", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
elif cutlass.const_expr(tensor_name == "c"):
|
||||
# C: (m, n, fake_l) -> offset fake_l, global desc
|
||||
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
|
||||
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
|
||||
return (real, None)
|
||||
|
||||
raise ValueError("Invalid scenario or GEMM tensor name.")
|
||||
|
||||
@cute.jit
|
||||
def prefetch_for_expert(self, expert_idx: Int32) -> None:
|
||||
if cutlass.const_expr(self.scenario == "2Dx3D"):
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
|
||||
elif cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
|
||||
else:
|
||||
raise ValueError("Invalid scenario.")
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scaled Grouped MM Extension (block-scaled MoE)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ScaledGroupedMmSchedExtension(MoESchedExtension):
|
||||
"""
|
||||
MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb.
|
||||
|
||||
Extends GroupedMmSchedExtension with scale-factor tensor support.
|
||||
SFA/SFB are passed as flat GEMM-domain tensors and atom-tiled per expert
|
||||
via tile_atom_to_shape_SF.
|
||||
|
||||
The offs parameter is always a tuple (offs_token, offs_padded):
|
||||
- offs_token: cumsum offsets in data (activation) domain
|
||||
- offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity)
|
||||
|
||||
sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size.
|
||||
|
||||
Domain conversion logic per scenario:
|
||||
|
||||
2Dx3D:
|
||||
A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
|
||||
B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
|
||||
C: (fake_m, n, 1) -> rewrite shape, expert-wise desc
|
||||
SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset,
|
||||
atom-tile, global desc
|
||||
SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx,
|
||||
atom-tile, global desc
|
||||
|
||||
2Dx2D:
|
||||
A: (m, fake_k, 1) -> rewrite shape, expert-wise desc
|
||||
B: (n, fake_k, 1) -> rewrite shape, expert-wise desc
|
||||
C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
|
||||
SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset,
|
||||
atom-tile, expert-wise desc
|
||||
SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset,
|
||||
atom-tile, expert-wise desc
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scenario: Literal["2Dx3D", "2Dx2D"],
|
||||
tensormap_ctor: OnlineTensormapDescCreator,
|
||||
):
|
||||
super().__init__(tensormap_ctor)
|
||||
self.scenario = scenario
|
||||
|
||||
@cute.jit
|
||||
def get_gmem_tensor(
|
||||
self,
|
||||
tensor_name: str,
|
||||
gmem_tensor_in_moe_view: cute.Tensor,
|
||||
offs: Tuple[cute.Tensor, cute.Tensor],
|
||||
work_tile_info: MoEWorkTileInfo,
|
||||
):
|
||||
# Unpack the offs tuple
|
||||
offs_token, offs_padded = offs
|
||||
|
||||
expert_idx = work_tile_info.expert_idx
|
||||
token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx)
|
||||
padded_offset, padded_size_i = compute_expert_token_range(
|
||||
offs_padded, expert_idx
|
||||
)
|
||||
|
||||
shape = gmem_tensor_in_moe_view.shape
|
||||
stride = gmem_tensor_in_moe_view.stride
|
||||
c1 = cutlass.Int32(1)
|
||||
sf_vec_size = self.tensormap_ctor.sf_vec_size
|
||||
|
||||
if cutlass.const_expr(self.scenario == "2Dx3D"):
|
||||
if cutlass.const_expr(tensor_name == "a"):
|
||||
# A: (fake_m, k, 1) -> offset fake_m, global desc
|
||||
real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
|
||||
real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index]
|
||||
return (real, None)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "b"):
|
||||
# B: (n, k, fake_l) -> offset fake_l, global desc
|
||||
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
|
||||
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
|
||||
return (real, None)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "c"):
|
||||
# C: (fake_m, n, 1) -> expert-wise desc
|
||||
real = rewrite_tensor_shape(
|
||||
gmem_tensor_in_moe_view,
|
||||
(tokens_i, shape[1], c1), # type: ignore[index]
|
||||
)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("c", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "sfa"):
|
||||
# SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc
|
||||
real = cute.domain_offset(
|
||||
(padded_offset, 0, 0), gmem_tensor_in_moe_view
|
||||
)
|
||||
per_expert_shape = (padded_size_i, shape[1], c1) # type: ignore[index]
|
||||
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
|
||||
real = cute.make_tensor(
|
||||
real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
|
||||
)
|
||||
return (real, None)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "sfb"):
|
||||
# SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc
|
||||
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
|
||||
per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index]
|
||||
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
|
||||
real = cute.make_tensor(
|
||||
real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
|
||||
)
|
||||
return (real, None)
|
||||
|
||||
elif cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
if cutlass.const_expr(tensor_name == "a"):
|
||||
# A: (m, fake_k, 1) -> expert-wise desc
|
||||
real = rewrite_tensor_shape(
|
||||
gmem_tensor_in_moe_view,
|
||||
(shape[0], tokens_i, c1), # type: ignore[index]
|
||||
)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("a", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "b"):
|
||||
# B: (n, fake_k, 1) -> expert-wise desc
|
||||
real = rewrite_tensor_shape(
|
||||
gmem_tensor_in_moe_view,
|
||||
(shape[0], tokens_i, c1), # type: ignore[index]
|
||||
)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("b", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "c"):
|
||||
# C: (m, n, fake_l) -> offset fake_l, global desc
|
||||
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
|
||||
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
|
||||
return (real, None)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "sfa"):
|
||||
# SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc
|
||||
per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index]
|
||||
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
|
||||
real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
|
||||
elif cutlass.const_expr(tensor_name == "sfb"):
|
||||
# SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc
|
||||
per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index]
|
||||
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
|
||||
real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
|
||||
desc = tensormap_ptr_for_copy(
|
||||
self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)
|
||||
)
|
||||
return (real, desc)
|
||||
|
||||
raise ValueError("Invalid scenario or tensor name.")
|
||||
|
||||
@cute.jit
|
||||
def prefetch_for_expert(self, expert_idx: Int32) -> None:
|
||||
if cutlass.const_expr(self.scenario == "2Dx3D"):
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
|
||||
elif cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx))
|
||||
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx))
|
||||
else:
|
||||
raise ValueError("Invalid scenario.")
|
||||
910
examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_utils.py
Normal file
@@ -0,0 +1,910 @@
|
||||
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
Online TMA Descriptor Construction Utilities.
|
||||
|
||||
Provides utilities for dynamically creating TMA descriptors at kernel runtime
|
||||
based on runtime-provided information (problem sizes, pointers, etc.).
|
||||
|
||||
Key components:
|
||||
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
|
||||
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
|
||||
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
|
||||
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
|
||||
- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.)
|
||||
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
|
||||
- compute_expert_token_range: Compute per-expert token offset and count from offs
|
||||
- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Literal, Tuple, Union
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.typing import AddressSpace, Pointer
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
from cutlass.cutlass_dsl import dsl_user_op, Int32
|
||||
from cutlass._mlir import ir
|
||||
from cutlass._mlir.dialects import llvm
|
||||
from cutlass._mlir.dialects import cute as _cute_ir
|
||||
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
|
||||
from dataclasses import dataclass
|
||||
|
||||
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
|
||||
|
||||
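# Size of one TMA descriptor (CUtensorMap) in bytes, fixed by hardware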
TensormapDescBytes = 128
|
||||
|
||||
# =============================================================================
|
||||
# Pointer Utilities
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def spin_wait(
|
||||
ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
|
||||
) -> None:
|
||||
"""
|
||||
Generic spin-wait.
|
||||
Example usage:
|
||||
# Wait until counter >= total_blocks
|
||||
spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)
|
||||
|
||||
# Wait until flag == 1
|
||||
spin_wait(flag_ptr, lambda x: x == 1)
|
||||
"""
|
||||
current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
|
||||
while not condition(current):
|
||||
        # Back off briefly before polling again to reduce memory traffic
|
||||
if cutlass.const_expr(fail_sleep_cycles > 0):
|
||||
cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
|
||||
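        # Re-load with L1 cache bypass (ld.global.cg) and re-check the condition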
current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def gmem_ptr_to_generic(
|
||||
gmem_ptr: Pointer,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
) -> Pointer:
|
||||
if gmem_ptr.memspace != AddressSpace.gmem:
|
||||
raise ValueError(
|
||||
f"gmem_ptr_to_generic requires pointer in gmem address space, "
|
||||
f"got {gmem_ptr.memspace}"
|
||||
)
|
||||
# Get LLVM pointer and cast to generic address space
|
||||
llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip)
|
||||
generic_llvm_ptr = llvm.addrspacecast(
|
||||
llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip
|
||||
)
|
||||
# Create a new cute.Pointer with generic address space, preserving alignment
|
||||
return cute.make_ptr(
|
||||
gmem_ptr.dtype,
|
||||
generic_llvm_ptr,
|
||||
AddressSpace.generic,
|
||||
assumed_align=gmem_ptr.alignment,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def generic_ptr_to_gmem(
|
||||
generic_ptr: Pointer,
|
||||
*,
|
||||
loc: Optional[ir.Location] = None,
|
||||
ip: Optional[ir.InsertionPoint] = None,
|
||||
) -> Pointer:
|
||||
if generic_ptr.memspace != AddressSpace.generic:
|
||||
raise ValueError(
|
||||
f"generic_ptr_to_gmem requires pointer in generic address space, "
|
||||
f"got {generic_ptr.memspace}"
|
||||
)
|
||||
# Get LLVM pointer and cast to gmem address space
|
||||
llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip)
|
||||
gmem_llvm_ptr = llvm.addrspacecast(
|
||||
llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip
|
||||
)
|
||||
# Create a new cute.Pointer with gmem address space, preserving alignment
|
||||
return cute.make_ptr(
|
||||
generic_ptr.dtype,
|
||||
gmem_llvm_ptr,
|
||||
AddressSpace.gmem,
|
||||
assumed_align=generic_ptr.alignment,
|
||||
loc=loc,
|
||||
ip=ip,
|
||||
)
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
|
||||
"""
|
||||
Prefetch a TMA descriptor from global memory.
|
||||
|
||||
This function prefetches the TMA descriptor pointed to by tma_desc_ptr
|
||||
into the TMA descriptor cache. The pointer must be in generic or global
|
||||
address space. If a gmem pointer is passed, it will be automatically
|
||||
converted to generic address space.
|
||||
|
||||
:param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
|
||||
:type tma_desc_ptr: Pointer
|
||||
:raises ValueError: If pointer is not in generic or global address space
|
||||
"""
|
||||
if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic):
|
||||
raise ValueError(
|
||||
f"prefetch_tma_descriptor requires pointer in gmem or generic address space, "
|
||||
f"got {tma_desc_ptr.memspace}"
|
||||
)
|
||||
# Convert gmem pointer to generic if needed
|
||||
if tma_desc_ptr.memspace == AddressSpace.gmem:
|
||||
tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)
|
||||
# Convert cute.Pointer to LLVM pointer for prefetch
|
||||
llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)
|
||||
from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch
|
||||
|
||||
nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip)
|
||||
|
||||
|
||||
def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer:
|
||||
"""Offset a pointer by a given number of bytes."""
|
||||
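    # ptr.dtype.width is in bits, so convert the byte offset into an element count
    # (assumes byte_offset is a multiple of the element size)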
element_offset = byte_offset * 8 // ptr.dtype.width
|
||||
return ptr + element_offset
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
|
||||
"""
|
||||
Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy.
|
||||
|
||||
cute.copy requires the tma_desc_ptr to be in generic address space and
|
||||
recast to TmaDescriptorTiledType. This utility performs both conversions.
|
||||
|
||||
:param raw_ptr: Raw pointer to TMA descriptor in gmem
|
||||
:type raw_ptr: Pointer
|
||||
:return: Pointer compatible with cute.copy's tma_desc_ptr parameter
|
||||
:rtype: Pointer
|
||||
"""
|
||||
generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
|
||||
tma_desc_ptr_ty = _cute_ir.PtrType.get(
|
||||
_cute_nvgpu_ir.TmaDescriptorTiledType.get(),
|
||||
generic_ptr.memspace,
|
||||
generic_ptr.alignment,
|
||||
)
|
||||
return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MoE Utilities
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
@cute.jit
|
||||
def compute_expert_token_range(
|
||||
offs: cute.Tensor,
|
||||
expert_idx: Int32,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> Tuple[Int32, Int32]:
|
||||
"""
|
||||
Compute token offset and count for a given expert from the cumsum offs tensor.
|
||||
|
||||
:param offs: Cumulative sum tensor of token counts per expert, shape (experts,)
|
||||
:param expert_idx: Index of the expert
|
||||
:return: (token_offset, tokens_i) where token_offset is the start position
|
||||
and tokens_i is the number of tokens for this expert
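
    Example: with ``offs = [3, 7, 10]`` and ``expert_idx = 1`` this returns ``(3, 4)``.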
|
||||
"""
|
||||
token_offset = Int32(0)
|
||||
if expert_idx > Int32(0):
|
||||
token_offset = offs[expert_idx - 1] # type: ignore[assignment]
|
||||
tokens_i = offs[expert_idx] - token_offset
|
||||
return token_offset, tokens_i
|
||||
|
||||
|
||||
@dsl_user_op
|
||||
def rewrite_tensor_shape(
|
||||
tensor: cute.Tensor,
|
||||
new_shape: Tuple,
|
||||
*,
|
||||
loc=None,
|
||||
ip=None,
|
||||
) -> cute.Tensor:
|
||||
"""
|
||||
Rewrite tensor shape while keeping the same stride and iterator.
|
||||
|
||||
This is primarily for debug friendliness - shows the actual expert's shape
|
||||
instead of the fake global shape. No runtime overhead as it becomes
|
||||
dead code in non-debug builds.
|
||||
|
||||
:param tensor: Source tensor whose stride and iterator to preserve
|
||||
:param new_shape: New shape to apply
|
||||
:return: New tensor with the given shape but original stride and iterator
|
||||
"""
|
||||
new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
|
||||
return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TMA Descriptor Workspace Helper
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TensormapWorkspace:
|
||||
"""
|
||||
Helper for linear workspace layout of TMA descriptors.
|
||||
|
||||
Manages address calculation for a workspace buffer containing TMA descriptors
|
||||
organized as: for each executor (e.g., expert or group), a fixed set of
|
||||
named descriptor slots.
|
||||
|
||||
Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]
|
||||
|
||||
Example:
|
||||
# 2Dx3D MoE: only C is expert-wise
|
||||
workspace = TensormapWorkspace(workspace_ptr, ["c"])
|
||||
|
||||
# 2Dx2D MoE: A and B are expert-wise
|
||||
workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
|
||||
|
||||
# General grouped GEMM: all three tensors
|
||||
workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
|
||||
"""
|
||||
|
||||
def __init__(self, workspace_ptr: Pointer, slot_names: list):
|
||||
"""
|
||||
:param workspace_ptr: Pointer to the beginning of the workspace buffer
|
||||
:param slot_names: Ordered list of tensor names, defining the slot layout
|
||||
per executor. e.g., ["a", "b", "c"]
|
||||
"""
|
||||
self.workspace_ptr = workspace_ptr
|
||||
self._name_to_slot = {name: i for i, name in enumerate(slot_names)}
|
||||
self._slots_per_executor = len(slot_names)
|
||||
|
||||
@cute.jit
|
||||
def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
|
||||
"""
|
||||
Get the workspace pointer for a specific TMA descriptor.
|
||||
|
||||
:param tensor_name: Name of the tensor (must be one of the slot_names)
|
||||
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
|
||||
:return: Aligned pointer to the TMA descriptor in workspace
|
||||
"""
|
||||
if cutlass.const_expr(tensor_name not in self._name_to_slot):
|
||||
raise ValueError(
|
||||
f"Invalid tensor_name '{tensor_name}', "
|
||||
f"expected one of {list(self._name_to_slot.keys())}"
|
||||
)
|
||||
slot = self._name_to_slot[tensor_name]
|
||||
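        # e.g., slot_names ["a", "b"], tensor_name "b", executor_idx 3
        # -> byte_offset = (3 * 2 + 1) * 128 = 896 bytes into the workspace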
byte_offset = (
|
||||
executor_idx * self._slots_per_executor + slot
|
||||
) * TensormapDescBytes
|
||||
return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
|
||||
TensormapDescBytes
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def size_bytes(num_slots: int, num_executors: int) -> int:
|
||||
"""
|
||||
Calculate workspace size in bytes.
|
||||
|
||||
:param num_slots: Number of descriptor slots per executor
|
||||
:param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
|
||||
:return: Total workspace size in bytes
|
||||
"""
|
||||
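        # e.g., 4 slots (A, B, SFA, SFB) for 64 experts: 4 * 64 * 128 B = 32 KiB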
return num_slots * num_executors * TensormapDescBytes
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Online TMA Descriptor Creator (Abstract Base Class)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OnlineTensormapDescCreator(ABC):
|
||||
"""
|
||||
Abstract base class for building TMA descriptors online (at kernel runtime).
|
||||
|
||||
Subclasses store all needed parameters (both codegen-time configs and runtime
|
||||
values) as explicit instance attributes in __init__. No dict-based APIs.
|
||||
|
||||
Subclasses must implement exactly 2 abstract methods:
|
||||
- construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
|
||||
- get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace
|
||||
|
||||
To convert the raw pointer for use with cute.copy, callers should use the
|
||||
standalone tensormap_ptr_for_copy() utility.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
|
||||
"""
|
||||
Build TMA descriptor(s) for one executor and write to workspace.
|
||||
|
||||
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
|
||||
Semantics may vary by subclass when ``dependency`` is provided.
|
||||
:param dependency: Optional pipeline consumer for inter-warp-group
|
||||
synchronization. When provided, the subclass decides when to wait
|
||||
(via ``dependency.wait_and_advance()``) and release. The subclass
|
||||
also decides how to interpret ``executor_idx`` in this mode.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
|
||||
"""
|
||||
Get the raw gmem pointer to a specific TMA descriptor in workspace.
|
||||
|
||||
:param tensor_name: Name identifying which tensor's descriptor
|
||||
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
|
||||
:return: Raw pointer (gmem) to the TMA descriptor
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MoE Grouped GEMM Tensormap Constructor
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
|
||||
"""
|
||||
Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).
|
||||
|
||||
Non-expert-wise descriptors are passed directly at kernel launch.
|
||||
This class only handles:
|
||||
- 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
|
||||
- 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis)
|
||||
|
||||
All parameters are stored as explicit instance attributes (no dicts).
|
||||
|
||||
Workspace layout:
|
||||
- 2Dx3D: [C_0, C_1, ..., C_{n-1}]
|
||||
- 2Dx2D: [A_0, A_1, ..., A_{n-1}, B_0, B_1, ..., B_{n-1}]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scenario: Literal["2Dx3D", "2Dx2D"],
|
||||
# Codegen-time configs
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
c_dtype,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
epi_smem_layout,
|
||||
a_tma_op,
|
||||
b_tma_op,
|
||||
c_tma_op,
|
||||
tiled_mma,
|
||||
mma_tiler,
|
||||
cluster_layout_vmnk_shape,
|
||||
epi_tile,
|
||||
# Runtime params
|
||||
a_tensor: cute.Tensor, # fake GEMM domain A
|
||||
b_tensor: cute.Tensor, # fake GEMM domain B
|
||||
c_tensor: cute.Tensor, # fake GEMM domain C
|
||||
offs: cute.Tensor, # (experts,) cumsum
|
||||
workspace_ptr: Pointer,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.scenario = scenario
|
||||
# Codegen-time configs
|
||||
self.a_dtype = a_dtype
|
||||
self.b_dtype = b_dtype
|
||||
self.c_dtype = c_dtype
|
||||
self.a_smem_layout = a_smem_layout
|
||||
self.b_smem_layout = b_smem_layout
|
||||
self.epi_smem_layout = epi_smem_layout
|
||||
self.a_tma_op = a_tma_op
|
||||
self.b_tma_op = b_tma_op
|
||||
self.c_tma_op = c_tma_op
|
||||
self.tiled_mma = tiled_mma
|
||||
self.mma_tiler = mma_tiler
|
||||
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
|
||||
self.epi_tile = epi_tile
|
||||
# Runtime params
|
||||
self.a_tensor = a_tensor
|
||||
self.b_tensor = b_tensor
|
||||
self.c_tensor = c_tensor
|
||||
self.offs = offs
|
||||
# Workspace with scenario-specific slot layout
|
||||
if scenario == "2Dx3D":
|
||||
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
|
||||
else:
|
||||
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
|
||||
|
||||
@staticmethod
|
||||
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
|
||||
"""Calculate workspace size in bytes for tensormap descriptors."""
|
||||
if scenario == "2Dx3D":
|
||||
return TensormapWorkspace.size_bytes(1, expert_cnt) # only C
|
||||
else:
|
||||
return TensormapWorkspace.size_bytes(2, expert_cnt) # A and B
|
||||
|
||||
@cute.jit
|
||||
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
|
||||
return self.workspace.get_ptr(tensor_name, executor_idx)
|
||||
|
||||
@cute.jit
|
||||
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
|
||||
"""
|
||||
Create expert-wise tensormap descriptors for the given expert.
|
||||
|
||||
- 2Dx3D: Creates C descriptor for this expert
|
||||
- 2Dx2D: Creates A and B descriptors for this expert
|
||||
"""
|
||||
if cutlass.const_expr(self.scenario == "2Dx3D"):
|
||||
self._construct_c_desc_2dx3d(executor_idx)
|
||||
else: # 2Dx2D
|
||||
self._construct_ab_descs_2dx2d(executor_idx)
|
||||
|
||||
@cute.jit
|
||||
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
|
||||
"""
|
||||
2Dx3D: Create expert-wise C descriptor.
|
||||
C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
|
||||
Slice fake_m -> (tokens_i, intermediate, 1) per expert.
|
||||
"""
|
||||
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
|
||||
|
||||
c_ptr = self.c_tensor.iterator
|
||||
c_stride = self.c_tensor.stride
|
||||
intermediate = self.c_tensor.shape[1] # type: ignore[index]
|
||||
|
||||
c1 = cutlass.Int32(1)
|
||||
c0 = cutlass.Int32(0)
|
||||
|
||||
c_ptr_i = c_ptr + token_offset * c_stride[0] # type: ignore[index]
|
||||
c_layout_i = cute.make_layout(
|
||||
(tokens_i, intermediate, c1),
|
||||
stride=(c_stride[0], c_stride[1], c0), # type: ignore[index]
|
||||
)
|
||||
c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)
|
||||
|
||||
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
|
||||
self.c_tma_op,
|
||||
c_tensor_i,
|
||||
self.epi_smem_layout,
|
||||
self.epi_tile,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
|
||||
|
||||
@cute.jit
|
||||
def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
|
||||
"""
|
||||
2Dx2D: Create expert-wise A and B descriptors.
|
||||
A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
|
||||
B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
|
||||
"""
|
||||
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
|
||||
|
||||
c1 = cutlass.Int32(1)
|
||||
c0 = cutlass.Int32(0)
|
||||
|
||||
# A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
|
||||
a_ptr = self.a_tensor.iterator
|
||||
a_stride = self.a_tensor.stride
|
||||
a_m = self.a_tensor.shape[0] # type: ignore[index]
|
||||
|
||||
a_ptr_i = a_ptr + token_offset * a_stride[1] # type: ignore[index]
|
||||
a_layout_i = cute.make_layout(
|
||||
(a_m, tokens_i, c1),
|
||||
stride=(a_stride[0], a_stride[1], c0), # type: ignore[index]
|
||||
)
|
||||
a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)
|
||||
|
||||
tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
self.a_tma_op,
|
||||
a_tensor_i,
|
||||
self.a_smem_layout,
|
||||
self.mma_tiler,
|
||||
self.tiled_mma,
|
||||
self.cluster_layout_vmnk_shape,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
|
||||
|
||||
# B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
|
||||
b_ptr = self.b_tensor.iterator
|
||||
b_stride = self.b_tensor.stride
|
||||
b_n = self.b_tensor.shape[0] # type: ignore[index]
|
||||
|
||||
b_ptr_i = b_ptr + token_offset * b_stride[1] # type: ignore[index]
|
||||
b_layout_i = cute.make_layout(
|
||||
(b_n, tokens_i, c1),
|
||||
stride=(b_stride[0], b_stride[1], c0), # type: ignore[index]
|
||||
)
|
||||
b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)
|
||||
|
||||
tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
self.b_tma_op,
|
||||
b_tensor_i,
|
||||
self.b_smem_layout,
|
||||
self.mma_tiler,
|
||||
self.tiled_mma,
|
||||
self.cluster_layout_vmnk_shape,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MoE Scaled Grouped GEMM Tensormap Constructor
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
|
||||
"""
|
||||
Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).
|
||||
|
||||
.. py:attribute:: ChunkSize
|
||||
:value: 128
|
||||
|
||||
Number of experts processed per chunk in the desc_init_kernel.
|
||||
Must match the warp-group width (4 warps × 32 threads).
|
||||
|
||||
    Mirrors MoEGroupedGemmTensormapConstructor and adds SFA/SFB descriptor support.
|
||||
|
||||
Expert-wise descriptors only — non-expert-wise descriptors are passed
|
||||
directly at kernel launch.
|
||||
|
||||
Workspace layout:
|
||||
- 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert)
|
||||
- 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert)
|
||||
|
||||
:param scenario: "2Dx3D" or "2Dx2D"
|
||||
:param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
|
||||
:param a_dtype: Data type for tensor A
|
||||
:param b_dtype: Data type for tensor B
|
||||
:param c_dtype: Data type for tensor C
|
||||
:param sf_dtype: Data type for scale factors (SFA/SFB)
|
||||
:param a_smem_layout: SMEM layout for A TMA
|
||||
:param b_smem_layout: SMEM layout for B TMA
|
||||
:param epi_smem_layout: SMEM layout for epilogue (C) TMA
|
||||
:param sfa_smem_layout: SMEM layout for SFA TMA
|
||||
:param sfb_smem_layout: SMEM layout for SFB TMA
|
||||
:param a_tma_op: TMA operation for A
|
||||
:param b_tma_op: TMA operation for B
|
||||
:param c_tma_op: TMA operation for C (S2G store or reduce)
|
||||
:param sfa_tma_op: TMA operation for SFA
|
||||
:param sfb_tma_op: TMA operation for SFB
|
||||
:param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
|
||||
:param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
|
||||
:param mma_tiler: MMA tiler shape (M, N, K)
|
||||
:param mma_tiler_sfb: MMA tiler shape for SFB
|
||||
:param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
|
||||
:param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
|
||||
:param epi_tile: Epilogue tile shape
|
||||
:param a_tensor: Fake GEMM domain A tensor
|
||||
:param b_tensor: Fake GEMM domain B tensor
|
||||
:param c_tensor: Fake GEMM domain C tensor
|
||||
:param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
|
||||
:param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
|
||||
:param offs: (experts,) cumsum offsets in data domain
|
||||
:param offs_padded: (experts,) cumsum offsets in padded scale domain
|
||||
:param workspace_ptr: Pointer to workspace for TMA descriptors
|
||||
:param expert_cnt: Total number of experts
|
||||
"""
|
||||
|
||||
ChunkSize = 128
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scenario: Literal["2Dx3D", "2Dx2D"],
|
||||
sf_vec_size: int,
|
||||
# Codegen-time configs: dtypes
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
c_dtype,
|
||||
sf_dtype,
|
||||
# Codegen-time configs: SMEM layouts
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
epi_smem_layout,
|
||||
sfa_smem_layout,
|
||||
sfb_smem_layout,
|
||||
# Codegen-time configs: TMA ops
|
||||
a_tma_op,
|
||||
b_tma_op,
|
||||
c_tma_op,
|
||||
sfa_tma_op,
|
||||
sfb_tma_op,
|
||||
# Codegen-time configs: MMA / cluster / tile
|
||||
tiled_mma,
|
||||
tiled_mma_sfb,
|
||||
mma_tiler,
|
||||
mma_tiler_sfb,
|
||||
cluster_layout_vmnk_shape,
|
||||
cluster_layout_sfb_vmnk_shape,
|
||||
epi_tile,
|
||||
# Runtime params
|
||||
a_tensor: cute.Tensor,
|
||||
b_tensor: cute.Tensor,
|
||||
c_tensor: cute.Tensor,
|
||||
sfa_tensor: cute.Tensor,
|
||||
sfb_tensor: cute.Tensor,
|
||||
offs: cute.Tensor,
|
||||
offs_padded: cute.Tensor,
|
||||
workspace_ptr: Pointer,
|
||||
expert_cnt: Optional[Union[Int32, int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.scenario = scenario
|
||||
self.sf_vec_size = sf_vec_size
|
||||
# Dtypes
|
||||
self.a_dtype = a_dtype
|
||||
self.b_dtype = b_dtype
|
||||
self.c_dtype = c_dtype
|
||||
self.sf_dtype = sf_dtype
|
||||
# SMEM layouts
|
||||
self.a_smem_layout = a_smem_layout
|
||||
self.b_smem_layout = b_smem_layout
|
||||
self.epi_smem_layout = epi_smem_layout
|
||||
self.sfa_smem_layout = sfa_smem_layout
|
||||
self.sfb_smem_layout = sfb_smem_layout
|
||||
# TMA ops
|
||||
self.a_tma_op = a_tma_op
|
||||
self.b_tma_op = b_tma_op
|
||||
self.c_tma_op = c_tma_op
|
||||
self.sfa_tma_op = sfa_tma_op
|
||||
self.sfb_tma_op = sfb_tma_op
|
||||
# MMA / cluster / tile
|
||||
self.tiled_mma = tiled_mma
|
||||
self.tiled_mma_sfb = tiled_mma_sfb
|
||||
self.mma_tiler = mma_tiler
|
||||
self.mma_tiler_sfb = mma_tiler_sfb
|
||||
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
|
||||
self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
|
||||
self.epi_tile = epi_tile
|
||||
# Runtime params
|
||||
self.a_tensor = a_tensor
|
||||
self.b_tensor = b_tensor
|
||||
self.c_tensor = c_tensor
|
||||
self.sfa_tensor = sfa_tensor
|
||||
self.sfb_tensor = sfb_tensor
|
||||
self.offs = offs
|
||||
self.offs_padded = offs_padded
|
||||
self.expert_cnt = expert_cnt
|
||||
# Workspace with scenario-specific slot layout
|
||||
if scenario == "2Dx3D":
|
||||
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
|
||||
else:
|
||||
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])
|
||||
|
||||
@staticmethod
|
||||
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
|
||||
"""Calculate workspace size in bytes for tensormap descriptors."""
|
||||
if scenario == "2Dx3D":
|
||||
return TensormapWorkspace.size_bytes(1, expert_cnt) # C only
|
||||
else:
|
||||
return TensormapWorkspace.size_bytes(4, expert_cnt) # A, B, SFA, SFB
|
||||
|
||||
@cute.jit
|
||||
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
|
||||
return self.workspace.get_ptr(tensor_name, executor_idx)
|
||||
|
||||
@cute.jit
|
||||
def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
|
||||
"""
|
||||
Create expert-wise tensormap descriptors for all experts.
|
||||
|
||||
``lane_in_group`` is the thread's position within its warp group
|
||||
(0..ChunkSize-1). The method loops internally over all experts in
|
||||
chunks of ``ChunkSize``, with two-phase pipeline synchronization
|
||||
per chunk.
|
||||
|
||||
Per-chunk execution:
|
||||
|
||||
1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
|
||||
(A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum.
|
||||
2. Barrier: ``consumer.wait_and_advance()`` — all threads participate.
|
||||
3. Phase 2: Build descriptors that depend on ``offs_padded``
|
||||
(SFA/SFB for 2Dx2D). Reads padded offsets from SMEM buffer.
|
||||
4. Release: ``handle.release()`` — all threads participate.
|
||||
|
||||
:param lane_in_group: Thread's position within the warp group (0..127).
|
||||
:param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the
|
||||
consumer for mbarrier sync, and the SMEM tensor of shape
|
||||
``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
|
||||
"""
|
||||
consumer, smem_offs_padded = dependency
|
||||
assert self.expert_cnt is not None
|
||||
num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize
|
||||
|
||||
chunk_idx = cutlass.Int32(0)
|
||||
while chunk_idx < num_chunks:
|
||||
expert_idx = chunk_idx * self.ChunkSize + lane_in_group
|
||||
in_bounds = expert_idx < self.expert_cnt
|
||||
|
||||
# Phase 1: non-dependent descriptors
|
||||
if in_bounds:
|
||||
if cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
self._construct_ab_descs_2dx2d(expert_idx)
|
||||
else:
|
||||
self._construct_c_desc_2dx3d(expert_idx)
|
||||
|
||||
# All threads participate in barrier (fixed arrive count)
|
||||
handle = consumer.wait_and_advance()
|
||||
|
||||
# Phase 2: dependent descriptors (read padded offsets from SMEM)
|
||||
if in_bounds:
|
||||
if cutlass.const_expr(self.scenario == "2Dx2D"):
|
||||
# smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
|
||||
# padded_offset = smem[lane] (prev expert's cumulative)
|
||||
# padded_end = smem[lane + 1] (this expert's cumulative)
|
||||
padded_offset = smem_offs_padded[lane_in_group]
|
||||
padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
|
||||
self._construct_sf_descs_2dx2d_direct(
|
||||
expert_idx, padded_offset, padded_size_i
|
||||
)
|
||||
|
||||
# All threads release (fixed arrive count)
|
||||
handle.release()
|
||||
|
||||
chunk_idx += 1
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor)
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
@cute.jit
|
||||
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
|
||||
"""
|
||||
2Dx3D: Create expert-wise C descriptor.
|
||||
C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
|
||||
"""
|
||||
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
|
||||
c1 = cutlass.Int32(1)
|
||||
|
||||
c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
|
||||
c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1)) # type: ignore[index]
|
||||
|
||||
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
|
||||
self.c_tma_op,
|
||||
c_i,
|
||||
self.epi_smem_layout,
|
||||
self.epi_tile,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor)
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
@cute.jit
|
||||
def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
|
||||
"""
|
||||
2Dx2D: Create expert-wise A and B descriptors.
|
||||
A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
|
||||
B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
|
||||
"""
|
||||
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
|
||||
c1 = cutlass.Int32(1)
|
||||
|
||||
# A: (m, fake_k, 1) -> domain_offset + rewrite shape
|
||||
a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
|
||||
a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1)) # type: ignore[index]
|
||||
|
||||
tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
self.a_tma_op,
|
||||
a_i,
|
||||
self.a_smem_layout,
|
||||
self.mma_tiler,
|
||||
self.tiled_mma,
|
||||
self.cluster_layout_vmnk_shape,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
|
||||
|
||||
# B: (n, fake_k, 1) -> domain_offset + rewrite shape
|
||||
b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
|
||||
b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1)) # type: ignore[index]
|
||||
|
||||
tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
self.b_tma_op,
|
||||
b_i,
|
||||
self.b_smem_layout,
|
||||
self.mma_tiler,
|
||||
self.tiled_mma,
|
||||
self.cluster_layout_vmnk_shape,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# 2Dx2D: SFA, SFB descriptors (new for block-scaled)
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
@cute.jit
|
||||
def _construct_sf_descs_2dx2d_direct(
|
||||
self,
|
||||
expert_idx: Int32,
|
||||
padded_offset: Int32,
|
||||
padded_size_i: Int32,
|
||||
) -> None:
|
||||
"""
|
||||
2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed
|
||||
padded offset and size.
|
||||
|
||||
This variant allows the caller to supply padded offsets from SMEM
|
||||
(in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM.
|
||||
"""
|
||||
c1 = cutlass.Int32(1)
|
||||
|
||||
a_chunks_to_move = (
|
||||
padded_offset
|
||||
// self.sf_vec_size
|
||||
* cute.size(self.sfa_tensor, mode=[0])
|
||||
// 128
|
||||
)
|
||||
a_elems_to_move = (
|
||||
cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
|
||||
)
|
||||
b_chunks_to_move = (
|
||||
padded_offset
|
||||
// self.sf_vec_size
|
||||
* cute.size(self.sfb_tensor, mode=[0])
|
||||
// 128
|
||||
)
|
||||
b_elems_to_move = (
|
||||
cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
|
||||
)
|
||||
|
||||
per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1) # type: ignore[index]
|
||||
sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
|
||||
sfa_i = cute.make_tensor(
|
||||
self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
|
||||
)
|
||||
|
||||
tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
self.sfa_tma_op,
|
||||
sfa_i,
|
||||
self.sfa_smem_layout,
|
||||
self.mma_tiler,
|
||||
self.tiled_mma,
|
||||
self.cluster_layout_vmnk_shape,
|
||||
internal_type=cutlass.Uint64,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))
|
||||
|
||||
per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1) # type: ignore[index]
|
||||
sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
|
||||
sfb_i = cute.make_tensor(
|
||||
self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
|
||||
)
|
||||
|
||||
tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
self.sfb_tma_op,
|
||||
sfb_i,
|
||||
self.sfb_smem_layout,
|
||||
self.mma_tiler_sfb,
|
||||
self.tiled_mma_sfb,
|
||||
self.cluster_layout_sfb_vmnk_shape,
|
||||
internal_type=cutlass.Uint64,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -47,11 +47,11 @@ from cutlass.cute.runtime import make_ptr
 
 if __name__ == "__main__":
     current_dir = os.path.dirname(os.path.abspath(__file__))
-    examples_dir = os.path.join(current_dir, "..", "..")
+    examples_dir = os.path.join(current_dir, "..", "..", "..", "..")
     if examples_dir not in sys.path:
         sys.path.insert(0, examples_dir)
 
-    from blackwell.tutorial_gemm.utils import create_parser, run
+    from cute.blackwell.tutorial.tutorial_gemm.utils import create_parser, run
 
     mma_tiler_mn = (128, 256)
     mma_inst_shape_k = 64
@@ -49,11 +49,11 @@ from cutlass.cute.runtime import from_dlpack, make_ptr
 
 if __name__ == "__main__":
     current_dir = os.path.dirname(os.path.abspath(__file__))
-    examples_dir = os.path.join(current_dir, "..", "..")
+    examples_dir = os.path.join(current_dir, "..", "..", "..", "..")
     if examples_dir not in sys.path:
         sys.path.insert(0, examples_dir)
 
-    from blackwell.tutorial_gemm.utils import create_parser, run
+    from cute.blackwell.tutorial.tutorial_gemm.utils import create_parser, run
 
     mma_tiler_mn = (256, 256)
     mma_inst_shape_k = 64
@@ -0,0 +1,385 @@
|
||||
# CUTLASS Tutorial Examples for Blackwell TMA
|
||||
|
||||
## TMA V0: Understanding tma_partition - The Foundation of TMA Operations
|
||||
|
||||
This example demonstrates the fundamental building blocks of Tensor Memory Accelerator (TMA) operations on NVIDIA Blackwell (SM100) architecture using CuTe DSL. It focuses on understanding the `tma_partition` interface, which is essential for all TMA operations.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
* **TMA Load (Global → Shared)**: Asynchronous bulk data transfer from Global Memory to Shared Memory using TMA hardware
|
||||
* **TMA Store (Shared → Global)**: Asynchronous bulk data transfer from Shared Memory back to Global Memory
|
||||
* **tma_partition**: Core interface that prepares tensors for TMA operations by partitioning them according to TMA atom layout requirements
|
||||
* **group_modes**: Tensor mode grouping to define the TMA atom shape - crucial for proper tma_partition usage
|
||||
* **mbarrier Synchronization**: Hardware barriers (`mbarrier_init`, `mbarrier_arrive`, `mbarrier_wait`) for synchronizing asynchronous TMA operations
|
||||
|
||||
### Kernel Architecture
|
||||
|
||||
The `Sm100SimpleCopyKernel` performs a simple tile-based copy operation to illustrate TMA fundamentals:
|
||||
|
||||
#### Configuration
|
||||
|
||||
* **tile_shape**: Fixed at (128, 128) - Tile dimensions (M, N)
|
||||
* **cluster_shape_mn**: Fixed at (1, 1) - Single CTA execution (no cluster parallelism)
|
||||
* **Shared Memory**: Single buffer sized to hold one tile (tile_m × tile_n elements)
|
||||
* **Synchronization**: Single mbarrier for TMA load completion
|
||||
|
||||
#### Key Components
|
||||
|
||||
1. **TMA Descriptor Creation**: Creates TMA atoms (`tma_atom_src`, `tma_atom_dst`) that encapsulate TMA hardware instructions
|
||||
2. **Shared Memory Layout**: Row-major layout `(tile_m, tile_n):(tile_n, 1)` for simplicity
|
||||
3. **Barrier Management**: Single barrier coordinates TMA load completion before processing
|
||||
|
||||
### Execution Flow
|
||||
|
||||
1. **Initialization**:
|
||||
* Allocate shared memory buffer for one tile
|
||||
* Initialize mbarrier with `elect_one()` to ensure proper synchronization semantics
|
||||
* Set barrier to expect TMA transaction bytes (`tile_m × tile_n × element_size`)
|
||||
* Synchronize all threads after barrier initialization (a minimal setup sketch with placeholder names follows this list)
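
A minimal sketch of this setup, assuming CuTe DSL's `cute.arch` barrier helpers (the helper names and `tma_bar_ptr`, `tile_m`, `tile_n`, `dtype` below are assumptions for illustration, not quoted from `tma_v0.py`; the tutorial itself guards initialization with `elect_one()` rather than a plain thread-0 check):

```python
# Conceptual barrier setup inside the kernel (placeholder names)
tidx, _, _ = cute.arch.thread_idx()
if tidx == 0:
    # Exactly one producer (the TMA load) arrives on this barrier ...
    cute.arch.mbarrier_init(tma_bar_ptr, 1)
    # ... and the barrier is transaction-counted: expect one full tile in bytes
    cute.arch.mbarrier_arrive_and_expect_tx(
        tma_bar_ptr, tile_m * tile_n * dtype.width // 8
    )
# Make the initialized barrier visible to every thread before issuing the load
cute.arch.barrier()
# After cute.copy(..., tma_bar_ptr=tma_bar_ptr), all threads call
# cute.arch.mbarrier_wait(tma_bar_ptr, 0) to wait for the bytes to land.
```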
|
||||
|
||||
2. **Tensor Preparation**:
|
||||
```
|
||||
Tile global tensors into (tile_m, tile_n) blocks
|
||||
Apply group_modes to combine tile dimensions into Mode 0 (TMA atom)
|
||||
|
||||
Example: gSrc_tiled with shape (128, 128, 4, 2)
|
||||
After group_modes(_, 0, 2): ((128, 128), 4, 2)
|
||||
└───┬────┘ └─┬─┘
|
||||
Mode 0 Rest modes
|
||||
```
|
||||
|
||||
3. **TMA Partition**:
|
||||
```python
|
||||
tAsA, tAgA = tma_partition(
|
||||
tma_atom_src, # TMA operation atom
|
||||
cta_id=0, # CTA ID within cluster
|
||||
cta_layout, # Cluster layout
|
||||
group_modes(smem_tensor, 0, 2), # SMEM view (Mode 0 = atom)
|
||||
group_modes(gSrc_tiled, 0, 2) # Global view (Mode 0 = atom)
|
||||
)
|
||||
# Returns:
|
||||
# tAsA: SMEM view with TMA internal layout
|
||||
# tAgA: Global view with shape ((TMA_Layout), grid_m, grid_n)
|
||||
```
|
||||
|
||||
4. **Tile Selection**:
|
||||
```python
|
||||
# Select specific tile for this CTA
|
||||
tAgA_cta = tAgA[(None, bidx, bidy)]
|
||||
# None: keep entire TMA atom (Mode 0)
|
||||
# bidx, bidy: index into grid dimensions
|
||||
```
|
||||
|
||||
5. **TMA Load** (Global → Shared):
|
||||
```python
|
||||
cute.copy(tma_atom_src, tAgA_cta, tAsA, tma_bar_ptr=barrier_ptr)
|
||||
# Arrive on barrier (producer signals completion)
|
||||
# Wait on barrier (all threads wait for TMA completion)
|
||||
```
|
||||
|
||||
6. **TMA Store** (Shared → Global):
|
||||
```python
|
||||
cute.copy(tma_atom_dst, tAsA, tBgB_cta)
|
||||
# Synchronous store completes before kernel exit
|
||||
```
|
||||
|
||||
### Understanding tma_partition
|
||||
|
||||
The `tma_partition` function is the key to TMA operations. The tutorial includes comprehensive inline diagrams showing:
|
||||
|
||||
* **Input Tensor Shapes**: How tensors are organized before partitioning
|
||||
* **group_modes Effect**: How mode grouping creates the required atom structure
|
||||
* **Partition Output**: The structure of partitioned tensors for TMA operations
|
||||
* **Indexing Pattern**: How to select CTA-specific tiles from partitioned views
|
||||
|
||||
See lines ~155-255 in `tma_v0.py` for detailed visual explanations.
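
As a shape-level aid (plain Python rather than DSL code, since the real call operates on `cute.Tensor` objects), the grouping performed by `group_modes(_, 0, 2)` can be pictured as:

```python
# Plain-Python picture of what group_modes(_, 0, 2) does to the tiled shape
shape = (128, 128, 4, 2)                  # (tile_m, tile_n, grid_m, grid_n)
grouped = ((shape[0], shape[1]),) + shape[2:]
print(grouped)                            # ((128, 128), 4, 2) -> mode 0 is the TMA atom
```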
|
||||
|
||||
### Configuration Parameters
|
||||
|
||||
* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N)
|
||||
* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution
|
||||
* `threads_per_cta`: 32 - Single warp (all threads participate in barriers)
|
||||
* `buffer_align_bytes`: 1024 - Shared memory alignment
|
||||
|
||||
### Usage
|
||||
|
||||
Run the copy kernel with custom matrix dimensions:
|
||||
|
||||
```bash
|
||||
# Basic usage (default: 512×128 matrix)
|
||||
python tma_v0.py
|
||||
|
||||
# Custom dimensions
|
||||
python tma_v0.py --M 1024 --N 2048
|
||||
|
||||
# With custom benchmark iterations
|
||||
python tma_v0.py --M 4096 --N 4096 --num_warmup 10 --num_iters 50
|
||||
```
|
||||
|
||||
#### Example Code
|
||||
|
||||
```python
|
||||
from tma_v0 import run_tma_copy
|
||||
|
||||
# Run copy on 1024×2048 matrix
|
||||
run_tma_copy(M=1024, N=2048)
|
||||
# Output: Performance metrics and verification result
|
||||
```
|
||||
|
||||
### Performance Considerations
|
||||
|
||||
* **Single-stage operation**: No pipelining - simple sequential load → store pattern
|
||||
* **Educational focus**: Designed for understanding TMA fundamentals, not peak performance
|
||||
* **Barrier synchronization**: Demonstrates proper mbarrier usage patterns
|
||||
* **Foundation for V1/V2**: Concepts learned here are essential for understanding multi-stage pipelines and warp specialization in V1 and V2
|
||||
|
||||
## TMA V1: Matrix Transpose with Producer-Consumer Pattern
|
||||
|
||||
This example demonstrates a TMA-based matrix transpose using producer-consumer synchronization with mbarriers on NVIDIA Blackwell (SM100) architecture.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
* **Producer-Consumer Pattern**: Different warps coordinate through mbarrier synchronization
|
||||
* **TMA Operations**: Asynchronous bulk data transfer between Global and Shared Memory
|
||||
* **Shared Memory Swizzle**: Optimized layouts to avoid bank conflicts during transpose
|
||||
* **Warp Specialization**: Dedicated warps for loading, transposing, and storing
|
||||
* **mbarrier Synchronization**: Hardware barriers coordinate asynchronous operations
|
||||
|
||||
### Kernel Architecture
|
||||
|
||||
The `Sm100MatrixTransposeKernel` performs a tiled matrix transpose (M×N → N×M) with the following design:
|
||||
|
||||
#### Warp Roles
|
||||
|
||||
1. **TMA Load Warp** (Warp 4, Producer): Issues TMA load operations from Global Memory to Shared Memory buffer `sA`
|
||||
2. **Transpose Warps** (Warps 0-3, 4 warps, Consumer/Producer):
|
||||
* Wait for `sA` to be filled by TMA load (consumer of load_mbar)
|
||||
* Transpose data from `sA` → Registers → `sB`
|
||||
* Signal completion to TMA Store warp (producer for store_mbar)
|
||||
3. **TMA Store Warp** (Warp 5, Consumer):
|
||||
* Wait for `sB` to be ready (consumer of store_mbar)
|
||||
* Issue TMA store operations from `sB` to Global Memory
|
||||
|
||||
#### Synchronization Barriers
|
||||
|
||||
The kernel uses two mbarrier instances for producer-consumer coordination:
|
||||
|
||||
1. **load_mbar_ptr**: Synchronizes TMA Load → Transpose Warps
|
||||
* Producer: TMA Load Warp (arrives after TMA completes)
|
||||
* Consumer: Transpose Warps (wait before reading `sA`)
|
||||
* Expected arrivals: 1 (from TMA load warp)
|
||||
* Expected transactions: `tile_m × tile_n × element_size` bytes
|
||||
|
||||
2. **store_mbar_ptr**: Synchronizes Transpose Warps → TMA Store
|
||||
* Producer: Transpose Warps (each warp arrives after writing to `sB`)
|
||||
* Consumer: TMA Store Warp (waits before issuing TMA store)
|
||||
* Expected arrivals: 4 (one from each transpose warp)
|
||||
|
||||
#### Execution Flow
|
||||
|
||||
1. **Initialization**:
|
||||
* All warps participate in barrier initialization (thread 0 initializes)
|
||||
* Allocate two shared memory buffers: `sA` (row-major) and `sB` (column-major)
|
||||
* Create TMA descriptors for source and transposed destination
|
||||
* Initialize `load_mbar` with expected count of 1 and transaction bytes
|
||||
* Initialize `store_mbar` with expected count of 4 (number of transpose warps)
|
||||
|
||||
2. **TMA Load Warp** (Producer for Load Pipeline):
|
||||
```
|
||||
partition source tensor by tile shape
|
||||
issue TMA load: Global[block_tile] → sA
|
||||
arrive on load_mbar to signal completion
|
||||
```
|
||||
|
||||
3. **Transpose Warps** (Consumer for Load, Producer for Store):
|
||||
```
|
||||
wait on load_mbar for TMA load to complete
|
||||
|
||||
partition sA for reading (each thread handles subset)
|
||||
copy data: sA → Registers
|
||||
partition sB for writing
|
||||
copy data: Registers → sB (transpose happens via layout)
|
||||
|
||||
fence to ensure smem writes are visible
|
||||
synchronize with trans_sync_barrier
|
||||
[elect one thread] arrive on store_mbar to signal completion
|
||||
```
|
||||
|
||||
4. **TMA Store Warp** (Consumer for Store Pipeline):
|
||||
```
|
||||
wait on store_mbar for transpose to complete
|
||||
partition destination tensor by tile shape
|
||||
issue TMA store: sB → Global[block_tile]
|
||||
```
|
||||
|
||||
### Key Features
|
||||
|
||||
* **Simple Producer-Consumer Model**: Clear separation of concerns with dedicated warps
|
||||
* **Efficient Synchronization**: Hardware mbarriers minimize synchronization overhead
|
||||
* **Memory Layout Optimization**: Swizzled layouts prevent bank conflicts
|
||||
* **Transposition via Layouts**: Transpose is achieved through different memory layouts for `sA` and `sB`
|
||||
|
||||
### Configuration Parameters
|
||||
|
||||
* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N)
|
||||
* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution
|
||||
* Warp count: 6 warps (1 TMA Load + 4 Transpose + 1 TMA Store)
|
||||
|
||||
### Usage
|
||||
|
||||
Run the transpose kernel with custom matrix dimensions:
|
||||
|
||||
```bash
|
||||
# Basic usage (128×128 matrix)
|
||||
python tma_v1.py
|
||||
|
||||
# Custom dimensions
|
||||
python tma_v1.py --M 1024 --N 2048
|
||||
|
||||
# With custom benchmark iterations
|
||||
python tma_v1.py --M 4096 --N 4096 --num_warmup 10 --num_iters 50
|
||||
```
|
||||
|
||||
#### Example Code
|
||||
|
||||
```python
|
||||
from tma_v1 import run_transpose
|
||||
|
||||
# Run transpose on 1024×2048 matrix
|
||||
run_transpose(M=1024, N=2048)
|
||||
# Output: Performance metrics and verification result
|
||||
```
|
||||
|
||||
### Performance Considerations
|
||||
|
||||
* **Single-stage pipeline**: Simpler than multi-stage but with potential for idle warps
|
||||
* **Warp specialization**: Clear roles minimize synchronization complexity
|
||||
* **Good for learning**: Demonstrates fundamental TMA and mbarrier concepts
|
||||
|
||||
## TMA V2: Transpose with Multi-Stage Pipeline
|
||||
|
||||
This example demonstrates a TMA implementation with multi-stage pipelining for efficient matrix transpose operations on NVIDIA Blackwell (SM100) architecture.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
* **Multi-Stage Pipeline**: Multiple buffers enable overlapping TMA loads, computation (transpose), and TMA stores to hide memory latency
|
||||
* **Pipeline Abstractions**: `PipelineTmaAsync` and `PipelineTmaStore` provide producer-consumer synchronization
|
||||
* **Persistent Tile Scheduler**: Efficient work distribution across CTAs for dynamic load balancing
|
||||
* **Shared Memory Swizzle**: Optimized shared memory layouts to avoid bank conflicts
|
||||
* **Warp Specialization**: Different warps handle loading and transposing; first transpose warp also handles storing
|
||||
|
||||
### Kernel Architecture
|
||||
|
||||
The `Sm100MatrixTransposeKernelV2` performs a tiled matrix transpose (M×N → N×M) with the following design:
|
||||
|
||||
#### Warp Roles
|
||||
|
||||
1. **TMA Load Warp** (Producer): Loads tiles from Global Memory to Shared Memory buffer `sA` using TMA load operations
|
||||
2. **Transpose Warps** (4 warps, Consumer/Producer):
|
||||
* Wait for data in `sA` from load pipeline
|
||||
* Transpose data from `sA` → Registers → `sB`
|
||||
* Synchronize via named barrier to ensure all transpose warps complete
|
||||
* First transpose warp (`trans_warp_id[0]`) issues TMA store operations from `sB` to Global Memory
|
||||
|
||||
#### Pipeline Stages
|
||||
|
||||
The kernel uses two separate pipelines for maximum parallelism:
|
||||
|
||||
1. **Load Pipeline**: `TMA Load Warp` (producer) → `Transpose Warps` (consumer)
|
||||
* Multi-stage buffer `sA` (automatically computed based on available SMEM)
|
||||
* Enables prefetching multiple tiles while processing current tile
|
||||
2. **Store Pipeline**: `Transpose Warps` (producer) → `First Transpose Warp` (consumer, issues TMA store)
|
||||
* Multi-stage buffer `sB` (automatically computed based on available SMEM)
|
||||
* First transpose warp handles TMA store after all transpose warps finish writing to `sB`
|
||||
|
||||
#### Stage Calculation
|
||||
|
||||
The kernel automatically computes the optimal number of pipeline stages using `_compute_stages()` (a minimal sketch follows the list below):
|
||||
|
||||
* Calculates bytes needed per stage for `sA` (tile_m × tile_n) and `sB` (tile_m × tile_n)
|
||||
* Reserves space for pipeline barriers and metadata (~1KB)
|
||||
* Divides remaining shared memory by bytes per stage
|
||||
* Clamps the result to 2-8 stages (at least 2 for double buffering; beyond 8 the returns diminish)
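
A minimal Python sketch of that calculation (`smem_capacity_bytes`, `dtype_bytes`, and `reserved_bytes` are illustrative placeholders; the actual `_compute_stages()` in `tma_v2.py` may differ in detail):

```python
def compute_stages(smem_capacity_bytes: int, tile_m: int, tile_n: int,
                   dtype_bytes: int, reserved_bytes: int = 1024) -> int:
    """Pick how many (sA, sB) stage pairs fit in shared memory."""
    # Each stage needs one sA tile plus one sB tile.
    bytes_per_stage = 2 * tile_m * tile_n * dtype_bytes
    # Reserve ~1KB for pipeline barriers and metadata.
    usable = smem_capacity_bytes - reserved_bytes
    stages = usable // bytes_per_stage
    # Clamp: at least 2 (double buffering), at most 8 (diminishing returns).
    return max(2, min(8, stages))

# Example: ~228 KB of SMEM with 128x128 fp16 tiles
print(compute_stages(228 * 1024, 128, 128, 2))  # -> 3
```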
|
||||
|
||||
#### Execution Flow
|
||||
|
||||
1. **Initialization**:
|
||||
* Allocate multi-stage shared memory buffers (`sA_staged`, `sB_staged`)
|
||||
* Create TMA descriptors for source and transposed destination
|
||||
* Initialize load and store pipelines with barrier synchronization
|
||||
* Use persistent tile scheduler to distribute work tiles across CTAs
|
||||
|
||||
2. **TMA Load Warp** (Producer for Load Pipeline):
|
||||
```
|
||||
for each tile assigned by scheduler:
|
||||
acquire next available stage in load pipeline
|
||||
issue TMA load: Global[tile] → sA[stage]
|
||||
advance to next stage
|
||||
```
|
||||
|
||||
3. **Transpose Warps** (Consumer for Load, Producer for Store):
|
||||
```
|
||||
for each tile assigned by scheduler:
|
||||
wait for load pipeline to fill current stage
|
||||
copy data: sA[load_stage] → Registers
|
||||
release load pipeline stage
|
||||
|
||||
transpose and write: Registers → sB[store_stage]
|
||||
fence to ensure smem writes are visible
|
||||
synchronize all transpose warps with barrier
|
||||
|
||||
[first transpose warp only] issue TMA store: sB[stage] → Global[tile]
|
||||
[first transpose warp only] commit to store pipeline
|
||||
[first transpose warp only] acquire next store pipeline stage
|
||||
synchronize all transpose warps with barrier
|
||||
```
|
||||
|
||||
4. **Pipeline Teardown**:
|
||||
* Producer/consumer tail operations ensure all in-flight operations complete
|
||||
|
||||
### Key Features
|
||||
|
||||
* **Automatic Stage Optimization**: Kernel calculates optimal number of stages based on tile size, data type, and available shared memory
|
||||
* **Persistent Tile Scheduler**: Efficient work distribution across CTAs for dynamic load balancing
|
||||
* **Memory Layout Optimization**: Uses appropriate swizzling for row-major input and column-major transposed output
|
||||
* **Efficient Synchronization**: Named barriers for intra-CTA coordination, mbarriers for pipeline stages
|
||||
|
||||
### Configuration Parameters
|
||||
|
||||
* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N)
|
||||
* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution
|
||||
* Number of pipeline stages: Automatically computed based on available SMEM
|
||||
|
||||
### Usage
|
||||
|
||||
Run the transpose kernel with custom matrix dimensions:
|
||||
|
||||
```bash
|
||||
# Basic usage (128×128 matrix)
|
||||
python tma_v2.py
|
||||
|
||||
# Custom dimensions
|
||||
python tma_v2.py --M 1024 --N 2048
|
||||
|
||||
python tma_v2.py --M 1024 --N 2048 --num_warmup 10 --num_iters 50
|
||||
```
|
||||
|
||||
#### Example Code
|
||||
|
||||
```python
|
||||
from tma_v2 import run_transpose
|
||||
|
||||
# Run transpose on 1024×2048 matrix
|
||||
run_transpose(M=1024, N=2048)
|
||||
# Output: "TransposeSuccess!" if verification passes
|
||||
```
|
||||
|
||||
### Performance Considerations
|
||||
|
||||
* **Multi-stage pipelining** hides memory latency by overlapping loads, computation, and stores
|
||||
* **Persistent scheduling** provides better load balancing for irregular matrix sizes
|
||||
* **Warp specialization** maximizes throughput: the TMA load warp handles all loads, the transpose warps perform the transpose, and the first transpose warp issues the TMA stores
|
||||
|
||||
## TMA V3: TMA With MMA (Tensor Cores)
|
||||
|
||||
TBD
|
||||
@@ -0,0 +1,409 @@
|
||||
# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
from typing import Type, Union
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import torch
|
||||
|
||||
"""
|
||||
TMA V0: Understanding tma_partition - The Foundation of TMA Operations
|
||||
|
||||
This tutorial demonstrates TMA (Tensor Memory Accelerator) operations through
|
||||
a simple copy kernel. It focuses on understanding the fundamental tma_partition
|
||||
interface, which is the key to all TMA operations.
|
||||
|
||||
What This Tutorial Covers:
|
||||
1. TMA Load (Global Memory -> Shared Memory)
|
||||
2. TMA Store (Shared Memory -> Global Memory)
|
||||
3. Barrier synchronization for TMA (elect_one, mbarrier_init, mbarrier_arrive, mbarrier_wait)
|
||||
4. Detailed explanation of tma_partition with visual diagrams
|
||||
|
||||
Key Learning Points:
|
||||
- tma_partition: How it transforms tensors for TMA operations
|
||||
- group_modes: Why and how to group tensor modes to define TMA atom shape
|
||||
- Indexing: How to select specific tiles from partitioned tensors
|
||||
- Data flow: Complete visualization from input tensors to TMA copy
|
||||
|
||||
Visual Diagrams:
|
||||
See line ~155 for comprehensive diagrams showing:
|
||||
- Input tensor shapes and transformations
|
||||
- group_modes effect on tensor layouts
|
||||
- tma_partition output structure
|
||||
- Complete data flow from global/shared memory to TMA copy
|
||||
- Indexing pattern for CTA-specific tiles
|
||||
|
||||
Example Usage:
|
||||
```bash
|
||||
python cutlass_ir/compiler/python/examples/cute/blackwell/tutorial/tutorial_tma/tma_v0.py
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class Sm100SimpleCopyKernel:
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the configuration for a Blackwell TMA copy kernel.
|
||||
"""
|
||||
self.tile_shape = (128, 128)
|
||||
self.tile_m, self.tile_n = self.tile_shape
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.threads_per_cta = 32
|
||||
self.buffer_align_bytes = 1024
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, src: cute.Tensor, dst: cute.Tensor):
|
||||
if cutlass.const_expr(src.element_type != dst.element_type):
|
||||
raise TypeError("Source and destination element types must match")
|
||||
|
||||
self.dtype: Type[cutlass.Numeric] = src.element_type
|
||||
|
||||
# layout for each cta: (tile_m, tile_n):(tile_n, 1)
|
||||
smem_layout = cute.make_layout(
|
||||
(self.tile_m, self.tile_n), stride=(self.tile_n, 1)
|
||||
)
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
barrier_storage: cute.struct.MemRange[cutlass.Int64, 1]
|
||||
smem_data: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
|
||||
self.shared_storage = SharedStorage
|
||||
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout)
|
||||
|
||||
# cta_tiler: the per-CTA tile extents (M, N) used by TMA.
|
||||
# Note: smem_layout may include swizzle or composed layout,
|
||||
# so we use product_each(...) to take the product along each logical dimension and get
|
||||
# the final (tile_m, tile_n) extents expected by TMA.
|
||||
# In this simple example, smem_layout.shape == (tile_m, tile_n), so product_each(...) is
|
||||
# just (tile_m, tile_n).
|
||||
cta_tiler = cute.product_each(smem_layout.shape)
|
||||
|
||||
tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileG2SOp(), src, smem_layout, cta_tiler
|
||||
)
|
||||
|
||||
tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(), dst, smem_layout, cta_tiler
|
||||
)
|
||||
|
||||
# Grid shape: (ceil(M/TileM), ceil(N/TileN), 1)
|
||||
grid_shape = cute.ceil_div((*src.layout.shape, 1), self.tile_shape)
|
||||
|
||||
self.kernel(
|
||||
tma_atom_src, tma_tensor_src, tma_atom_dst, tma_tensor_dst, smem_layout
|
||||
).launch(
|
||||
grid=grid_shape,
|
||||
block=(self.threads_per_cta, 1, 1),
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
tma_atom_src: cute.CopyAtom,
|
||||
tma_tensor_src: cute.Tensor,
|
||||
tma_atom_dst: cute.CopyAtom,
|
||||
tma_tensor_dst: cute.Tensor,
|
||||
smem_layout: Union[cute.Layout, cute.ComposedLayout],
|
||||
):
|
||||
bidx, bidy, _ = cute.arch.block_idx()
|
||||
|
||||
# Allocate Shared Memory
|
||||
smem = utils.SmemAllocator()
|
||||
storage = smem.allocate(self.shared_storage)
|
||||
|
||||
# Initialize barrier for TMA synchronization
|
||||
barrier_ptr = storage.barrier_storage.data_ptr()
|
||||
|
||||
# Initialize the barrier: elect_one ensures only one thread executes this
|
||||
# Note: We must use elect_one() instead of "if tid == 0" because:
|
||||
# - elect_one() provides proper synchronization semantics
|
||||
# - It ensures all threads are aware that exactly one thread is executing
|
||||
# - It prevents race conditions and provides memory ordering guarantees
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_init(barrier_ptr, 1)
|
||||
cute.arch.mbarrier_expect_tx(barrier_ptr, self.num_tma_load_bytes)
|
||||
|
||||
# Fence ensures init/expect_tx are visible before proceeding
|
||||
cute.arch.mbarrier_init_fence()
|
||||
cute.arch.barrier()
|
||||
|
||||
# Tile the (M, N) tensor: ((TileM, TileN), M/TileM, N/TileN)
|
||||
gSrc_tiled = cute.local_tile(
|
||||
tma_tensor_src, (self.tile_m, self.tile_n), (None, None)
|
||||
)
|
||||
gDst_tiled = cute.local_tile(
|
||||
tma_tensor_dst, (self.tile_m, self.tile_n), (None, None)
|
||||
)
|
||||
|
||||
smem_tensor = storage.smem_data.get_tensor(smem_layout)
|
||||
|
||||
# ======================================================================
|
||||
# TMA Partition: Tensor Preparation for TMA Operations
|
||||
# ======================================================================
|
||||
#
|
||||
# tma_partition prepares tensors for TMA copy by partitioning them
|
||||
# according to the TMA atom's internal layout requirements.
|
||||
#
|
||||
# Signature:
|
||||
# tma_partition(atom, cta_id, cta_layout, smem_tensor, gmem_tensor)
|
||||
# -> (smem_view, gmem_view)
|
||||
#
|
||||
# Key Requirement: Mode 0 of both tensors must represent the TMA atom
|
||||
#
|
||||
# Example: M=512, N=128, TileM=128, TileN=64
|
||||
#
|
||||
# Input Tensors:
|
||||
# gSrc_tiled: (128, 64, 4, 2) # 4 separate modes
|
||||
# └──┬──┘ └──┬──┘
|
||||
# Tile Grid
|
||||
#
|
||||
# smem_tensor: (128, 64) # 2 separate modes
|
||||
# └──┬──┘
|
||||
# Tile
|
||||
#
|
||||
# Apply group_modes(tensor, 0, 2) to group first 2 modes:
|
||||
#
|
||||
# group_modes(gSrc_tiled, 0, 2) => ((128, 64), 4, 2)
|
||||
# └───┬───┘
|
||||
# Mode 0 = Atom
|
||||
#
|
||||
# group_modes(smem_tensor, 0, 2) => ((128, 64),)
|
||||
# └───┬───┘
|
||||
# Mode 0 = Atom
|
||||
#
|
||||
# After tma_partition:
|
||||
#
|
||||
# tAsA: SMEM view with TMA internal layout
|
||||
# Shape: ((TMA_Layout),)
|
||||
# - TMA_Layout: Swizzled/banked layout for efficient SMEM access
|
||||
#
|
||||
# tAgA: Global view preserving rest modes
|
||||
# Shape: ((TMA_Layout), 4, 2)
|
||||
# └─────┬─────┘ └──┬──┘
|
||||
# TMA atom Rest modes (grid)
|
||||
#
|
||||
# Usage Pattern:
|
||||
# 1. Group modes to define atom: group_modes(tensor, 0, 2)
|
||||
# 2. Call tma_partition: tAsA, tAgA = tma_partition(...)
|
||||
# 3. Select tile for CTA: tAgA_cta = tAgA[(None, bidx, bidy)]
|
||||
# - None: keep entire atom
|
||||
# - bidx, bidy: index into rest modes
|
||||
# 4. Issue TMA copy: cute.copy(atom, tAgA_cta, tAsA)
|
||||
#
|
||||
# Visual Data Flow:
|
||||
#
|
||||
# Global Memory (512x128) Shared Memory (128x64)
|
||||
# ┌────────────────────┐ ┌──────────────┐
|
||||
# │ ┌───┬───┐ │ │ │
|
||||
# │ │0,0│0,1│ │ │ smem_tensor │
|
||||
# │ ├───┼───┤ │ │ (128, 64) │
|
||||
# │ │1,0│1,1│ 4x2 │ │ │
|
||||
# │ ├───┼───┤ tiles │ └──────────────┘
|
||||
# │ │2,0│2,1│ │ │
|
||||
# │ ├───┼───┤ │ │ group_modes
|
||||
# │ │3,0│3,1│ │ ↓
|
||||
# │ └───┴───┘ │ ((128, 64),)
|
||||
# └────────────────────┘ │
|
||||
# │ │
|
||||
# │ gSrc_tiled │
|
||||
# │ (128, 64, 4, 2) │
|
||||
# ↓ │
|
||||
# group_modes(_, 0, 2) │
|
||||
# ↓ │
|
||||
# ((128, 64), 4, 2) │
|
||||
# │ │
|
||||
# └────────┬────────────────────────┘
|
||||
# ↓
|
||||
# tma_partition
|
||||
# ↓
|
||||
# ┌──────────────┴──────────────┐
|
||||
# │ │
|
||||
# tAgA tAsA
|
||||
# ((TMA_Layout), 4, 2) ((TMA_Layout),)
|
||||
# │
|
||||
# ↓ tAgA[(None, bidx, bidy)]
|
||||
# tAgA_cta
|
||||
# ((TMA_Layout),)
|
||||
#
|
||||
# ======================================================================
|
||||
|
||||
# TMA Load partition
|
||||
# Here we only use 1x1 cluster, so cta_id is 0 and cta_layout is (1).
|
||||
# More details about how to set cta_coord and cta_layout can be found in the tma_v4.py
|
||||
# Note: the smem and gmem views must have the same size (the TMA atom extent) in mode 0
|
||||
tAsA, tAgA = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_src,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(smem_tensor, 0, 2),
|
||||
cute.group_modes(gSrc_tiled, 0, 2),
|
||||
)
|
||||
|
||||
# TMA Store partition
|
||||
# Same process as TMA Load, but for destination tensor
|
||||
# Partitions gDst_tiled and smem_tensor according to TMA Store atom
|
||||
_, tBgB = cute.nvgpu.cpasync.tma_partition(
|
||||
tma_atom_dst,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(smem_tensor, 0, 2),
|
||||
cute.group_modes(gDst_tiled, 0, 2),
|
||||
)
|
||||
|
||||
# Select specific tile for this CTA from partitioned global views
|
||||
# Input: tAgA with shape ((TMA_Layout), 4, 2)
|
||||
# Output: tAgA_cta with shape ((TMA_Layout),)
|
||||
# The (None, bidx, bidy) indexing:
|
||||
# - None: keeps the entire TMA atom layout (mode 0)
|
||||
# - bidx: selects from rest mode 1 (M dimension grid)
|
||||
# - bidy: selects from rest mode 2 (N dimension grid)
|
||||
tAgA_cta = tAgA[(None, bidx, bidy)]
|
||||
tBgB_cta = tBgB[(None, bidx, bidy)]
|
||||
|
||||
# ---------- TMA Load: Global -> Shared ----------
|
||||
cute.copy(
|
||||
tma_atom_src,
|
||||
tAgA_cta, # Source (TMA Tensor View)
|
||||
tAsA, # Dest (SMEM Tensor View)
|
||||
tma_bar_ptr=barrier_ptr,
|
||||
)
|
||||
|
||||
# Signal arrival on the barrier after TMA is issued
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_arrive(barrier_ptr)
|
||||
|
||||
# Wait for TMA to complete
|
||||
cute.arch.mbarrier_wait(barrier_ptr, 0)
|
||||
|
||||
# ---------- TMA Store: Shared -> Global ----------
|
||||
cute.copy(
|
||||
tma_atom_dst,
|
||||
tAsA, # Source (SMEM Tensor View)
|
||||
tBgB_cta, # Dest (Global Tensor View)
|
||||
)
|
||||
|
||||
|
||||
def run_tma_copy(M, N, num_warmup=5, num_iters=20):
|
||||
"""
|
||||
Run TMA copy kernel with performance measurement.
|
||||
|
||||
Args:
|
||||
M: Matrix dimension M
|
||||
N: Matrix dimension N
|
||||
num_warmup: Number of warmup iterations
|
||||
num_iters: Number of timing iterations
|
||||
"""
|
||||
# Create tensors with shape (M, N)
|
||||
a = torch.randn((M, N), dtype=torch.float16, device="cuda")
|
||||
b = torch.zeros((M, N), dtype=torch.float16, device="cuda")
|
||||
|
||||
# Note: N is declared as the leading dimension and must be divisible by 16
|
||||
a_cute = (
|
||||
from_dlpack(a, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=16)
|
||||
)
|
||||
b_cute = (
|
||||
from_dlpack(b, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=16)
|
||||
)
|
||||
|
||||
copy_kernel = Sm100SimpleCopyKernel()
|
||||
compiled_kernel = cute.compile(copy_kernel, a_cute, b_cute)
|
||||
|
||||
# Warmup runs
|
||||
for _ in range(num_warmup):
|
||||
compiled_kernel(a_cute, b_cute)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timed runs
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(num_iters):
|
||||
compiled_kernel(a_cute, b_cute)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Calculate performance metrics
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
avg_time_ms = elapsed_time_ms / num_iters
|
||||
|
||||
# Calculate throughput
|
||||
# For copy: read M*N elements + write M*N elements
|
||||
bytes_per_element = a.element_size()
|
||||
total_bytes = 2 * M * N * bytes_per_element # Read + Write
|
||||
throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000)
|
||||
|
||||
# Print performance metrics
|
||||
print(f"Matrix size: {M}×{N}")
|
||||
print(f"Tile shape: {copy_kernel.tile_shape}")
|
||||
print(f"Average time: {avg_time_ms:.4f} ms")
|
||||
print(f"Throughput: {throughput_gb_s:.2f} GB/s")
|
||||
|
||||
# Verify
|
||||
if torch.allclose(a, b, atol=1e-3):
|
||||
print("Verification: PASSED ✓")
|
||||
else:
|
||||
print("Verification: FAILED ✗")
|
||||
diff = (a - b).abs()
|
||||
print(f"Max diff: {diff.max()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="TMA V0: Understanding tma_partition - The Foundation of TMA Operations"
|
||||
)
|
||||
parser.add_argument("--M", type=int, default=512, help="Matrix dimension M")
|
||||
parser.add_argument("--N", type=int, default=128, help="Matrix dimension N")
|
||||
parser.add_argument(
|
||||
"--num_warmup", type=int, default=5, help="Number of warmup iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=20, help="Number of timing iterations"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_tma_copy(
|
||||
M=args.M,
|
||||
N=args.N,
|
||||
num_warmup=args.num_warmup,
|
||||
num_iters=args.num_iters,
|
||||
)
|
||||
@@ -0,0 +1,462 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
from typing import Tuple, Type
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.pipeline as pipeline
|
||||
import torch
|
||||
|
||||
"""
|
||||
TMA Matrix Transpose with Producer-Consumer Pattern: TMA load -> S2R --> R2S -> TMA store
|
||||
|
||||
Warp Roles (Producer-Consumer Pattern):
|
||||
- Producer: TMA Load Warp (Warp 4) - Loads from Global to Shared A
|
||||
- Consumer: Transpose Warps (Warp 0-3) - Wait for load, then transpose sA -> sB
|
||||
- Consumer: TMA Store Warp (Warp 5) - Wait for transpose, then store sB to Global
|
||||
|
||||
Synchronization:
|
||||
1. load_mbar_ptr: TMA Load (producer) -> Transpose Warps (consumer)
|
||||
2. store_mbar_ptr: Transpose Warps (producer) -> TMA Store (consumer)
|
||||
|
||||
This demonstrates how different warps can use shared memory barriers to coordinate
|
||||
producer-consumer relationships.
|
||||
"""
|
||||
|
||||
|
||||
class Sm100MatrixTransposeKernelV1:
|
||||
def __init__(self):
|
||||
self.tile_shape = (128, 128)
|
||||
self.tile_m, self.tile_n = self.tile_shape
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
|
||||
|
||||
# Set specialized warp ids based on tile_shape
|
||||
self.num_trans_warps = 4 # Maximum number of transpose warps
|
||||
self.trans_warp_id = tuple(range(self.num_trans_warps))
|
||||
self.tma_load_warp_id = self.num_trans_warps
|
||||
self.tma_store_warp_id = self.num_trans_warps + 1
|
||||
self.threads_per_cta = 32 * len(
|
||||
(self.tma_store_warp_id, self.tma_load_warp_id, *self.trans_warp_id)
|
||||
)
|
||||
self.num_trans_threads = 32 * len(self.trans_warp_id)
|
||||
|
||||
self.trans_tile = (self.tile_shape[0] // self.num_trans_warps, 8)
|
||||
|
||||
# Set barriers for producer-consumer sync
|
||||
# Barrier 1: Trans warps sync (for internal coordination)
|
||||
self.trans_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=1,
|
||||
num_threads=32 * len(self.trans_warp_id),
|
||||
)
|
||||
# Barrier 2: TMA Store warp waits after Trans warps finish
|
||||
self.store_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=2,
|
||||
num_threads=32, # Only TMA store warp
|
||||
)
|
||||
self.buffer_align_bytes = 1024
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, src: cute.Tensor, dst: cute.Tensor):
|
||||
if cutlass.const_expr(src.element_type != dst.element_type):
|
||||
raise TypeError("Source and destination element types must match")
|
||||
|
||||
self.dtype: Type[cutlass.Numeric] = src.element_type
|
||||
|
||||
# Create transposed view of dst for TMA descriptor
|
||||
# dst is (N, M), we want to view it as (M, N) transposed
|
||||
transed_dst = cute.make_tensor(
|
||||
dst.iterator,
|
||||
cute.make_layout(
|
||||
(dst.shape[1], dst.shape[0]), stride=(dst.stride[1], dst.stride[0])
|
||||
),
|
||||
)
|
||||
|
||||
# row-major smem layout for sA (tile_m, tile_n)
|
||||
smem_layout_sA = sm100_utils.make_smem_layout(
|
||||
utils.LayoutEnum.from_tensor(src).mma_major_mode(),
|
||||
(self.tile_m, self.tile_n),
|
||||
self.dtype,
|
||||
1,
|
||||
)
|
||||
|
||||
# col-major smem layout for sB (tile_n, tile_m)
|
||||
# sB should match the transposed destination layout
|
||||
smem_layout_sB = sm100_utils.make_smem_layout(
|
||||
utils.LayoutEnum.from_tensor(transed_dst).mma_major_mode(),
|
||||
(self.tile_m, self.tile_n),
|
||||
self.dtype,
|
||||
1,
|
||||
)
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
# Barrier for TMA Load: producer (TMA) -> consumer (Trans warps)
|
||||
load_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
|
||||
# Barrier for TMA Store: producer (Trans warps) -> consumer (TMA Store)
|
||||
store_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
|
||||
# Two shared memory buffers: sA holds the loaded tile, sB holds the transposed tile
|
||||
sA: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sA)], 128
|
||||
]
|
||||
sB: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sB)], 128
|
||||
]
|
||||
|
||||
self.shared_storage = SharedStorage
|
||||
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout_sA)
|
||||
|
||||
# TMA Atoms
|
||||
# Use swizzled layout for TMA atom to handle swizzling during load
|
||||
tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileG2SOp(),
|
||||
src,
|
||||
smem_layout_sA,
|
||||
(self.tile_m, self.tile_n),
|
||||
)
|
||||
|
||||
tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(),
|
||||
transed_dst,
|
||||
smem_layout_sB,
|
||||
(self.tile_m, self.tile_n),
|
||||
)
|
||||
|
||||
grid_shape = cute.ceil_div((*src.layout.shape, 1), self.tile_shape)
|
||||
self.kernel(
|
||||
tma_atom_src,
|
||||
tma_tensor_src,
|
||||
tma_atom_dst,
|
||||
tma_tensor_dst,
|
||||
smem_layout_sA,
|
||||
smem_layout_sB,
|
||||
).launch(
|
||||
grid=grid_shape,
|
||||
block=(self.threads_per_cta, 1, 1),
|
||||
cluster=self.cluster_shape_mnk,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
tma_atom_load: cute.CopyAtom,
|
||||
tma_tensor_src: cute.Tensor,
|
||||
tma_atom_store: cute.CopyAtom,
|
||||
tma_tensor_dst: cute.Tensor,
|
||||
smem_layout_sA: cute.ComposedLayout,
|
||||
smem_layout_sB: cute.ComposedLayout,
|
||||
):
|
||||
bidx, bidy, _ = cute.arch.block_idx()
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
# Allocate Shared Memory
|
||||
# We need two buffers for transpose:
|
||||
# sA: Source Tile (swizzled)
|
||||
# sB: Destination Tile (swizzled) for TMA Store
|
||||
smem = utils.SmemAllocator()
|
||||
storage = smem.allocate(self.shared_storage)
|
||||
sA = storage.sA.get_tensor(smem_layout_sA.outer, swizzle=smem_layout_sA.inner)
|
||||
# sA = cute.make_tensor(storage.sA.iterator, smem_layout_sA)
|
||||
sB = storage.sB.get_tensor(smem_layout_sB.outer, swizzle=smem_layout_sB.inner)
|
||||
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout_sA)
|
||||
load_mbar_ptr = storage.load_mbar_ptr.data_ptr()
|
||||
store_mbar_ptr = storage.store_mbar_ptr.data_ptr()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Initialize Barriers (thread 0 initializes; all warps then sync on the barrier below)
|
||||
# ------------------------------------------------------------------
|
||||
if tidx == 0:
|
||||
# Barrier for TMA Load: expect 1 arrive (from TMA warp after TMA completes)
|
||||
cute.arch.mbarrier_init(load_mbar_ptr, 1)
|
||||
cute.arch.mbarrier_expect_tx(load_mbar_ptr, self.num_tma_load_bytes)
|
||||
|
||||
# Barrier for TMA Store: expect arrival from Trans warps
|
||||
cute.arch.mbarrier_init(store_mbar_ptr, len(self.trans_warp_id))
|
||||
cute.arch.mbarrier_init_fence()
|
||||
|
||||
# Sync all warps after barrier initialization
|
||||
cute.arch.barrier()
|
||||
# ------------------------------------------------------------------
|
||||
# PRODUCER: TMA Load Warp (G -> sA)
|
||||
# ------------------------------------------------------------------
|
||||
if warp_idx == self.tma_load_warp_id:
|
||||
# Issue TMA Load
|
||||
# gA: ((TileM, TileN), loopM, loopN)
|
||||
gA = cute.local_tile(tma_tensor_src, self.tile_shape, (None, None))
|
||||
# Mode 0 of the grouped tensors is the TMA atom; the remaining modes index the tile grid
|
||||
tAsA, tAgA = cpasync.tma_partition(
|
||||
tma_atom_load,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sA, 0, 2),
|
||||
cute.group_modes(gA, 0, 2),
|
||||
)
|
||||
|
||||
cute.copy(
|
||||
tma_atom_load,
|
||||
tAgA[(None, bidx, bidy)],
|
||||
tAsA[(None, 0)],
|
||||
tma_bar_ptr=load_mbar_ptr,
|
||||
)
|
||||
|
||||
# Arrive on mbarrier to satisfy the init count of 1
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_arrive(load_mbar_ptr)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CONSUMER: Transpose Warps (sA -> Reg -> sB)
|
||||
# ------------------------------------------------------------------
|
||||
if warp_idx < self.tma_load_warp_id:
|
||||
trans_tid = tidx % self.num_trans_threads
|
||||
|
||||
# Wait for TMA Load to complete (consumer wait on load_mbar)
|
||||
cute.arch.mbarrier_wait(load_mbar_ptr, 0)
|
||||
|
||||
atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.dtype,
|
||||
num_bits_per_copy=self.dtype.width, # Copy one element at a time
|
||||
)
|
||||
|
||||
copy_elems = 1
|
||||
|
||||
# Use SAME thread layout for both read and write
|
||||
# Transpose happens through sB_transposed layout view
|
||||
|
||||
# TV layout notation: T = thread-id, V = value-lane within that thread.
|
||||
# In this example `copy_elems = 1` and `thread_layout` has shape (T, V) = (num_trans_threads, 1),
|
||||
# so V is always 0 (only one value-lane per thread).
|
||||
#
|
||||
# Linearization rule (row-major by strides):
|
||||
# idx(Ti, Vj) = i + j * num_trans_threads
|
||||
# Therefore here:
|
||||
# idx(Ti, V0) = i
|
||||
#
|
||||
# Diagram (V is the column, T is the row; showing the first two threads):
|
||||
#
|
||||
# V0
|
||||
# ┌──────┐
|
||||
# T0 │ T0V0 │ -> idx 0
|
||||
# T1 │ T1V0 │ -> idx 1
|
||||
# ... │ ... │
|
||||
# └──────┘
|
||||
#
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_trans_threads, 1),
|
||||
stride=(1, self.num_trans_threads),
|
||||
)
|
||||
value_layout = cute.make_layout((1, copy_elems))
|
||||
# Build a "tiled copy" operator that defines the per-thread copy mapping (T,V) for this warp-group:
|
||||
# - It is used twice below via `thr_copy.partition_S(...)` and `thr_copy.partition_D(...)` to
|
||||
# create matching per-thread views of the source and destination tensors.
|
||||
# - With the SAME (T,V) mapping, the actual transpose is achieved by changing the tensor view
|
||||
# (`sA` vs `sB` / `sB_transposed`), not by changing which threads perform the copies.
|
||||
tiled_copy = cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
||||
thr_copy = tiled_copy.get_slice(trans_tid)
|
||||
|
||||
# Partition sA (source) for reading
|
||||
tCsA = thr_copy.partition_S(sA)
|
||||
# When to use `tiled_copy.retile(...)`:
|
||||
# - Use it when you allocate/build a register tensor yourself (or slice/reshape it) and its
|
||||
# internal layout doesn't match the TV layout expected by `tiled_copy` for copy-in/out.
|
||||
# Why we don't use it here:
|
||||
# - `cute.make_fragment_like(tCsA)` creates an rmem fragment with the same per-thread shape/layout
|
||||
# as `tCsA`, so it already matches `tiled_copy`'s view and can be copied into directly.
|
||||
tCrA = cute.make_fragment_like(tCsA)
|
||||
cute.copy(tiled_copy, tCsA, tCrA)
|
||||
|
||||
# Partition sB for writing
|
||||
tCsB = thr_copy.partition_D(sB)
|
||||
|
||||
# Write from register to sB
|
||||
cute.copy(tiled_copy, tCrA, tCsB)
|
||||
|
||||
# Fence and barrier to make sure shared memory store is visible to TMA store
|
||||
cute.arch.fence_proxy(
|
||||
"async.shared",
|
||||
space="cta",
|
||||
)
|
||||
self.trans_sync_barrier.arrive_and_wait()
|
||||
|
||||
# Trans warps signal TMA Store warp: "sB is ready!"
|
||||
with cute.arch.elect_one():
|
||||
cute.arch.mbarrier_arrive(store_mbar_ptr)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CONSUMER: TMA Store Warp (sB -> G)
|
||||
# ------------------------------------------------------------------
|
||||
if warp_idx == self.tma_store_warp_id:
|
||||
# Wait for Trans warp to complete (consumer wait on store_mbar)
|
||||
cute.arch.mbarrier_wait(store_mbar_ptr, 0)
|
||||
|
||||
gDst_cta = cute.local_tile(
|
||||
tma_tensor_dst, (self.tile_m, self.tile_n), (None, None)
|
||||
)
|
||||
tBsB, tBgB = cpasync.tma_partition(
|
||||
tma_atom_store,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sB, 0, 2),
|
||||
cute.group_modes(gDst_cta, 0, 2),
|
||||
)
|
||||
cute.copy(tma_atom_store, tBsB[(None, 0)], tBgB[(None, bidx, bidy)])
|
||||
|
||||
|
||||
def run_transpose(M, N, num_warmup=5, num_iters=20):
|
||||
"""
|
||||
Run TMA transpose kernel with performance measurement.
|
||||
|
||||
Args:
|
||||
M: Matrix dimension M
|
||||
N: Matrix dimension N
|
||||
num_warmup: Number of warmup iterations
|
||||
num_iters: Number of timing iterations
|
||||
|
||||
Performance Metrics:
|
||||
- Throughput: Actual achieved bandwidth in GB/s
|
||||
- Theoretical BW: Peak memory bandwidth (2048 B/clk × 4000 MHz = 8.192 TB/s)
|
||||
- Bandwidth Efficiency: Percentage of theoretical peak achieved
|
||||
"""
|
||||
torch.manual_seed(1111)
|
||||
# Input (M, N)
|
||||
input_data = torch.randn((M, N), device="cuda", dtype=torch.float16)
|
||||
# Output (N, M)
|
||||
output_data = torch.zeros((N, M), device="cuda", dtype=torch.float16)
|
||||
|
||||
# CuTe Wrappers
|
||||
tensor_src = (
|
||||
from_dlpack(input_data, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=16)
|
||||
)
|
||||
tensor_dst = (
|
||||
from_dlpack(output_data, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=16)
|
||||
)
|
||||
|
||||
transpose_kernel = Sm100MatrixTransposeKernelV1()
|
||||
|
||||
print("Start kernel compilation...")
|
||||
|
||||
# Compile and Run
|
||||
compiled_kernel = cute.compile(
|
||||
transpose_kernel, tensor_src, tensor_dst, options="--generate-line-info"
|
||||
)
|
||||
|
||||
print("Start kernel warmup...")
|
||||
# Warmup runs
|
||||
for _ in range(num_warmup):
|
||||
compiled_kernel(tensor_src, tensor_dst)
|
||||
torch.cuda.synchronize()
|
||||
print("Kernel warmup completed.")
|
||||
|
||||
# Timed runs
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(num_iters):
|
||||
compiled_kernel(tensor_src, tensor_dst)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Calculate performance metrics
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
avg_time_ms = elapsed_time_ms / num_iters
|
||||
|
||||
# Calculate throughput
|
||||
# For transpose: read M*N elements + write M*N elements
|
||||
bytes_per_element = input_data.element_size()
|
||||
total_bytes = 2 * M * N * bytes_per_element # Read + Write
|
||||
throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000)
|
||||
|
||||
# Theoretical bandwidth limit
|
||||
# Blackwell: 2048 B/clk at 4000 MHz
|
||||
bytes_per_clk = 2048
|
||||
freq_mhz = 4000
|
||||
theoretical_bw_gb_s = bytes_per_clk * freq_mhz * 1e6 / 1e9 # Convert to GB/s
|
||||
theoretical_bw_tb_s = theoretical_bw_gb_s / 1000 # Convert to TB/s
|
||||
bandwidth_efficiency = (throughput_gb_s / theoretical_bw_gb_s) * 100 # Percentage
|
||||
|
||||
# Print performance metrics
|
||||
print(f"Matrix size: {M}×{N}")
|
||||
print(f"Tile shape: {transpose_kernel.tile_shape}")
|
||||
print(f"Average time: {avg_time_ms:.4f} ms")
|
||||
print(f"Throughput: {throughput_gb_s:.2f} GB/s")
|
||||
print(
|
||||
f"Theoretical BW: {theoretical_bw_tb_s:.2f} TB/s ({theoretical_bw_gb_s:.2f} GB/s)"
|
||||
)
|
||||
print(f"Bandwidth Efficiency: {bandwidth_efficiency:.2f}%")
|
||||
|
||||
# Verification
|
||||
expected = input_data.t()
|
||||
if torch.allclose(output_data, expected, atol=1e-2):
|
||||
print("Verification: PASSED ✓")
|
||||
else:
|
||||
print("Verification: FAILED ✗")
|
||||
print(f"Max diff: {(output_data - expected).abs().max()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="TMA Matrix Transpose with Producer-Consumer Pattern"
|
||||
)
|
||||
parser.add_argument("--M", type=int, default=128, help="Matrix dimension M")
|
||||
parser.add_argument("--N", type=int, default=128, help="Matrix dimension N")
|
||||
parser.add_argument(
|
||||
"--num_warmup", type=int, default=5, help="Number of warmup iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=20, help="Number of timing iterations"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_transpose(
|
||||
args.M,
|
||||
args.N,
|
||||
num_warmup=args.num_warmup,
|
||||
num_iters=args.num_iters,
|
||||
)
|
||||
@@ -0,0 +1,648 @@
|
||||
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
from typing import Tuple, Type, Union
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.nvgpu import cpasync
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
import cutlass.pipeline as pipeline
|
||||
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
|
||||
import torch
|
||||
|
||||
"""
|
||||
TMA Matrix Transpose with Multi-Stage Pipeline(v2)
|
||||
|
||||
This version extends tma_v1.py with:
|
||||
1. Multi-stage pipeline: Multiple buffers for pipelining TMA loads and stores
|
||||
2. Pipeline abstraction: Using PipelineTmaAsync for proper producer-consumer coordination
|
||||
3. Persistent tile scheduler: Efficient work distribution across CTAs
|
||||
|
||||
Key Improvements over v1:
|
||||
- Multi-stage buffers enable overlapping TMA loads, computation, and TMA stores
|
||||
- Pipeline objects provide cleaner synchronization semantics
|
||||
|
||||
Note: TMA multicast is NOT used because each CTA must process different input tiles.
|
||||
|
||||
Warp Roles:
|
||||
- Producer: TMA Load Warp - Loads from Global to Shared memory (multi-stage, multi-tile)
|
||||
- Consumer: Transpose Warps - Wait for load, transpose sA -> sB (multi-tile)
|
||||
- Consumer: TMA Store Warp - Wait for transpose, store sB to Global (multi-stage, multi-tile)
|
||||
|
||||
Pipeline Stages:
|
||||
1. Load Pipeline: TMA Load (producer) -> Transpose Warps (consumer)
|
||||
2. Store Pipeline: Transpose Warps (producer) -> TMA Store (consumer)
|
||||
"""
|
||||
|
||||
|
||||
class Sm100MatrixTransposeKernelV2:
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
Initialize the TMA transpose kernel with multi-stage pipeline support.
|
||||
|
||||
Fixed configuration (this constructor takes no arguments):
    tile_shape: Tile dimensions (M, N), fixed at (128, 128)
    cluster_shape_mn: Cluster shape for CTA execution, fixed at (1, 1)
|
||||
|
||||
Note:
|
||||
- Each CTA processes different tiles independently
|
||||
- Stage counts are automatically computed based on available shared memory
|
||||
- Persistent scheduler distributes work across CTAs in the cluster
|
||||
"""
|
||||
self.tile_shape = (128, 128)
|
||||
self.tile_m, self.tile_n = self.tile_shape
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cluster_shape_mnl = (*self.cluster_shape_mn, 1)
|
||||
|
||||
# Set specialized warp ids based on tile_shape
|
||||
# For 128x128 tile, use 4 transpose warps (same as v1)
|
||||
self.max_trans_warps = 4 # Maximum number of transpose warps
|
||||
self.num_trans_warps = self.max_trans_warps # Use all transpose warps
|
||||
self.trans_warp_id = tuple(range(self.num_trans_warps))
|
||||
self.tma_load_warp_id = self.num_trans_warps
|
||||
self.threads_per_cta = 32 * len((self.tma_load_warp_id, *self.trans_warp_id))
|
||||
self.num_trans_threads = 32 * len(self.trans_warp_id)
|
||||
|
||||
# Set barriers for producer-consumer sync
|
||||
# Barrier 1: Trans warps sync (for internal coordination)
|
||||
self.trans_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=1,
|
||||
num_threads=32 * len(self.trans_warp_id),
|
||||
)
|
||||
self.buffer_align_bytes = 128
|
||||
|
||||
# Get shared memory capacity for stage computation
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
@staticmethod
|
||||
def _compute_stages(
|
||||
tile_m: int,
|
||||
tile_n: int,
|
||||
dtype: Type[cutlass.Numeric],
|
||||
smem_capacity: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Compute the number of load and store stages based on shared memory capacity.
|
||||
|
||||
Strategy:
|
||||
1. Calculate bytes per stage for load (sA) and store (sB) buffers
|
||||
2. Reserve space for barriers and alignment
|
||||
3. Divide remaining smem by bytes per stage to get max stages
|
||||
4. Clamp to reasonable min/max values
|
||||
|
||||
Args:
|
||||
tile_m: Tile dimension M
|
||||
tile_n: Tile dimension N
|
||||
dtype: Data type of the tensors
|
||||
smem_capacity: Total shared memory capacity in bytes
|
||||
|
||||
Returns:
|
||||
Tuple of (num_load_stages, num_store_stages)
|
||||
"""
|
||||
# Calculate bytes per tile (assuming row-major and col-major layouts)
|
||||
bytes_per_element = dtype.width // 8
|
||||
|
||||
# For sA (load buffer): tile_m x tile_n elements
|
||||
sA_bytes_per_stage = tile_m * tile_n * bytes_per_element
|
||||
|
||||
# For sB (store buffer): tile_m x tile_n elements (transposed)
|
||||
sB_bytes_per_stage = tile_m * tile_n * bytes_per_element
|
||||
|
||||
# Reserve space for barriers and other metadata
|
||||
# Each barrier: 8 bytes (Int64)
|
||||
# Estimate: max 16 barriers (load + store stages * 2) + alignment
|
||||
reserved_bytes = 1024 # Conservative estimate
|
||||
|
||||
# Available space for staging buffers
|
||||
available_smem = smem_capacity - reserved_bytes
|
||||
|
||||
# Calculate max stages we can fit
|
||||
# We need space for both load and store stages
|
||||
total_bytes_per_stage_pair = sA_bytes_per_stage + sB_bytes_per_stage
|
||||
|
||||
# Max stages (same for load and store for simplicity)
|
||||
max_stages = available_smem // total_bytes_per_stage_pair
|
||||
|
||||
# Clamp to reasonable values
|
||||
# Min: 2 stages for basic double buffering
|
||||
# Max: 8 stages (diminishing returns beyond this)
|
||||
num_stages = max(2, min(max_stages, 8))
|
||||
|
||||
return num_stages, num_stages
|
||||
|
||||
@staticmethod
|
||||
def _compute_grid(
|
||||
c: cute.Tensor,
|
||||
cta_tile_shape_mn: Tuple[int, int],
|
||||
cluster_shape_mn: Tuple[int, int],
|
||||
max_active_clusters: cutlass.Constexpr,
|
||||
) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
|
||||
"""Use persistent tile scheduler to compute the grid size for the output tensor C.
|
||||
|
||||
:param c: The output tensor C
|
||||
:type c: cute.Tensor
|
||||
:param cta_tile_shape_mn: The shape (M, N) of the CTA tile.
|
||||
:type cta_tile_shape_mn: tuple[int, int]
|
||||
:param cluster_shape_mn: Shape of each cluster in M, N dimensions.
|
||||
:type cluster_shape_mn: tuple[int, int]
|
||||
:param max_active_clusters: Maximum number of active clusters.
|
||||
:type max_active_clusters: cutlass.Constexpr
|
||||
|
||||
:return: A tuple containing:
|
||||
- tile_sched_params: Parameters for the persistent tile scheduler.
|
||||
- grid: Grid shape for kernel launch.
|
||||
:rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]
|
||||
"""
|
||||
c_shape = cute.slice_(cta_tile_shape_mn, (None, None))
|
||||
gc = cute.zipped_divide(c, tiler=c_shape)
|
||||
num_ctas_mn = gc[(0, (None, None))].shape
|
||||
cluster_shape_mnl = (*cluster_shape_mn, 1)
|
||||
num_ctas_mnl = (*num_ctas_mn, 1)
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||||
num_ctas_mnl, cluster_shape_mnl
|
||||
)
|
||||
grid = utils.StaticPersistentTileScheduler.get_grid_shape(
|
||||
tile_sched_params, max_active_clusters
|
||||
)
|
||||
|
||||
return tile_sched_params, grid
|
||||
|
||||
@cute.jit
|
||||
def __call__(
|
||||
self, src: cute.Tensor, dst: cute.Tensor, max_active_clusters: cutlass.Constexpr
|
||||
):
|
||||
if cutlass.const_expr(src.element_type != dst.element_type):
|
||||
raise TypeError("Source and destination element types must match")
|
||||
|
||||
self.dtype: Type[cutlass.Numeric] = src.element_type
|
||||
|
||||
# Compute optimal stage counts based on tile size and dtype
|
||||
self.num_load_stages, self.num_store_stages = self._compute_stages(
|
||||
self.tile_m,
|
||||
self.tile_n,
|
||||
self.dtype,
|
||||
self.smem_capacity,
|
||||
)
|
||||
|
||||
# Create transposed view of dst for TMA descriptor
|
||||
# dst is (N, M), we want to view it as (M, N) transposed
|
||||
transed_dst = cute.make_tensor(
|
||||
dst.iterator,
|
||||
cute.make_layout(
|
||||
(dst.shape[1], dst.shape[0]), stride=(dst.stride[1], dst.stride[0])
|
||||
),
|
||||
)
|
||||
|
||||
# Create multi-stage layouts for load and store buffers
|
||||
# row-major smem layout for sA (tile_m, tile_n)
|
||||
smem_layout_sA_staged = sm100_utils.make_smem_layout(
|
||||
utils.LayoutEnum.from_tensor(src).mma_major_mode(),
|
||||
(self.tile_m, self.tile_n),
|
||||
self.dtype,
|
||||
self.num_load_stages,
|
||||
)
|
||||
|
||||
# col-major smem layout for sB (tile_n, tile_m)
|
||||
# sB should match the transposed destination layout
|
||||
smem_layout_sB_staged = sm100_utils.make_smem_layout(
|
||||
utils.LayoutEnum.from_tensor(transed_dst).mma_major_mode(),
|
||||
(self.tile_m, self.tile_n),
|
||||
self.dtype,
|
||||
self.num_store_stages,
|
||||
)
|
||||
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
# Pipeline barriers for multi-stage load
|
||||
load_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_load_stages
|
||||
]
|
||||
load_empty_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_load_stages
|
||||
]
|
||||
|
||||
# Pipeline barriers for multi-stage store
|
||||
store_full_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_store_stages
|
||||
]
|
||||
store_empty_mbar_ptr: cute.struct.MemRange[
|
||||
cutlass.Int64, self.num_store_stages
|
||||
]
|
||||
|
||||
# Multi-stage shared memory buffers
|
||||
sA: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sA_staged)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
sB: cute.struct.Align[
|
||||
cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sB_staged)],
|
||||
self.buffer_align_bytes,
|
||||
]
|
||||
|
||||
self.shared_storage = SharedStorage
|
||||
a_smem_layout = cute.slice_(smem_layout_sA_staged, (None, None, 0))
|
||||
self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, a_smem_layout)
|
||||
|
||||
# TMA Atoms
|
||||
# Each CTA loads its own tile independently
|
||||
tma_load_op = cpasync.CopyBulkTensorTileG2SOp()
|
||||
|
||||
cluster_layout_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((*self.cluster_shape_mn, 1)), (1,)
|
||||
)
|
||||
|
||||
tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom(
|
||||
tma_load_op,
|
||||
src,
|
||||
smem_layout_sA_staged,
|
||||
(self.tile_m, self.tile_n),
|
||||
)
|
||||
|
||||
tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(),
|
||||
transed_dst,
|
||||
smem_layout_sB_staged,
|
||||
(self.tile_m, self.tile_n),
|
||||
)
|
||||
tile_sched_params, grid_shape = self._compute_grid(
|
||||
transed_dst,
|
||||
(self.tile_m, self.tile_n),
|
||||
self.cluster_shape_mn,
|
||||
max_active_clusters,
|
||||
)
|
||||
|
||||
self.kernel(
|
||||
tma_atom_src,
|
||||
tma_tensor_src,
|
||||
tma_atom_dst,
|
||||
tma_tensor_dst,
|
||||
smem_layout_sA_staged,
|
||||
smem_layout_sB_staged,
|
||||
cluster_layout_vmnk,
|
||||
tile_sched_params,
|
||||
).launch(
|
||||
grid=grid_shape,
|
||||
block=(self.threads_per_cta, 1, 1),
|
||||
cluster=self.cluster_shape_mnl,
|
||||
)
|
||||
|
||||
@cute.kernel
|
||||
def kernel(
|
||||
self,
|
||||
tma_atom_load: cute.CopyAtom,
|
||||
tma_tensor_src: cute.Tensor,
|
||||
tma_atom_store: cute.CopyAtom,
|
||||
tma_tensor_dst: cute.Tensor,
|
||||
smem_layout_sA_staged: Union[cute.Layout, cute.ComposedLayout],
|
||||
smem_layout_sB_staged: Union[cute.Layout, cute.ComposedLayout],
|
||||
cluster_layout_vmnk: cute.Layout,
|
||||
tile_sched_params: utils.PersistentTileSchedulerParams,
|
||||
):
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
warp_idx = cute.arch.warp_idx()
|
||||
warp_idx = cute.arch.make_warp_uniform(warp_idx)
|
||||
|
||||
# ---------------- Shared mem & staged buffers ----------------
|
||||
smem = utils.SmemAllocator()
|
||||
storage = smem.allocate(self.shared_storage)
|
||||
|
||||
sA_staged = storage.sA.get_tensor(
|
||||
smem_layout_sA_staged.outer, swizzle=smem_layout_sA_staged.inner
|
||||
)
|
||||
sB_staged = storage.sB.get_tensor(
|
||||
smem_layout_sB_staged.outer, swizzle=smem_layout_sB_staged.inner
|
||||
)
|
||||
|
||||
load_mbar_ptr = storage.load_full_mbar_ptr.data_ptr()
|
||||
_store_mbar_ptr = storage.store_full_mbar_ptr.data_ptr()
|
||||
|
||||
load_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1)
|
||||
load_consumer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, self.num_trans_warps
|
||||
)
|
||||
|
||||
load_pipeline = pipeline.PipelineTmaAsync.create(
|
||||
barrier_storage=load_mbar_ptr,
|
||||
num_stages=self.num_load_stages,
|
||||
producer_group=load_producer_group,
|
||||
consumer_group=load_consumer_group,
|
||||
tx_count=self.num_tma_load_bytes,
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
defer_sync=True,
|
||||
)
|
||||
|
||||
store_producer_group = pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, self.num_trans_threads
|
||||
)
|
||||
store_pipeline = pipeline.PipelineTmaStore.create(
|
||||
num_stages=self.num_store_stages,
|
||||
producer_group=store_producer_group,
|
||||
)
|
||||
|
||||
# Critical: Initialize pipeline barriers across cluster
|
||||
# This must happen after pipeline creation and before any producer/consumer work
|
||||
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
|
||||
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
|
||||
|
||||
gA = cute.local_tile(tma_tensor_src, self.tile_shape, (None, None))
|
||||
|
||||
_cta_layout = cute.make_layout(
|
||||
cute.slice_(cluster_layout_vmnk, (0, None, None, 0)).shape
|
||||
)
|
||||
# Mode 0 of the grouped tensors is the TMA atom; tAgA keeps the (loopM, loopN) grid and tAsA keeps the stage mode
|
||||
tAsA, tAgA = cpasync.tma_partition(
|
||||
tma_atom_load,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sA_staged, 0, 2),
|
||||
cute.group_modes(gA, 0, 2),
|
||||
)
|
||||
|
||||
gDst_cta = cute.local_tile(tma_tensor_dst, self.tile_shape, (None, None))
|
||||
tBsB, tBgB = cpasync.tma_partition(
|
||||
tma_atom_store,
|
||||
0,
|
||||
cute.make_layout(1),
|
||||
cute.group_modes(sB_staged, 0, 2),
|
||||
cute.group_modes(gDst_cta, 0, 2),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# PRODUCER: TMA Load Warp (G -> sA)
|
||||
# ------------------------------------------------------------------
|
||||
if warp_idx == self.tma_load_warp_id:
|
||||
tile_sched = utils.StaticPersistentTileScheduler.create(
|
||||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
load_producer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_load_stages
|
||||
)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
tAgA_slice = tAgA[(None, cur_tile_coord[0], cur_tile_coord[1])]
|
||||
|
||||
load_pipeline.producer_acquire(load_producer_state)
|
||||
|
||||
cute.copy(
|
||||
tma_atom_load,
|
||||
tAgA_slice,
|
||||
tAsA[(None, load_producer_state.index)],
|
||||
tma_bar_ptr=load_pipeline.producer_get_barrier(load_producer_state),
|
||||
)
|
||||
|
||||
load_producer_state.advance()
|
||||
tile_sched.advance_to_next_work()
|
||||
work_tile = tile_sched.get_current_work()
|
||||
|
||||
load_pipeline.producer_tail(load_producer_state)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CONSUMER: Transpose Warps (sA -> Reg -> sB)
|
||||
# ------------------------------------------------------------------
|
||||
if warp_idx < self.tma_load_warp_id:
|
||||
tile_sched = utils.StaticPersistentTileScheduler.create(
|
||||
tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim()
|
||||
)
|
||||
work_tile = tile_sched.initial_work_tile_info()
|
||||
|
||||
trans_tid = tidx % self.num_trans_threads
|
||||
|
||||
load_consumer_state = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_load_stages
|
||||
)
|
||||
|
||||
while work_tile.is_valid_tile:
|
||||
# Get tile coord from tile scheduler
|
||||
cur_tile_coord = work_tile.tile_idx
|
||||
|
||||
# Wait for load pipeline to have data
|
||||
load_pipeline.consumer_wait(load_consumer_state)
|
||||
|
||||
atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.dtype,
|
||||
num_bits_per_copy=self.dtype.width,
|
||||
)
|
||||
copy_elems = 1
|
||||
|
||||
thread_layout = cute.make_layout(
|
||||
(self.num_trans_threads, 1),
|
||||
stride=(1, self.num_trans_threads),
|
||||
)
|
||||
value_layout = cute.make_layout((1, copy_elems))
|
||||
tiled_copy = cute.make_tiled_copy_tv(atom, thread_layout, value_layout)
|
||||
thr_copy = tiled_copy.get_slice(trans_tid)
|
||||
|
||||
# sA -> Reg
|
||||
tCsA = thr_copy.partition_S(sA_staged)
|
||||
|
||||
tCrA = cute.make_rmem_tensor(
|
||||
tCsA[(None, None, None, 0)].shape, self.dtype
|
||||
)
|
||||
tCrA = tiled_copy.retile(tCrA)
|
||||
|
||||
cute.copy(
|
||||
tiled_copy,
|
||||
tCsA[(None, None, None, load_consumer_state.index)],
|
||||
tCrA,
|
||||
)
|
||||
|
||||
# release load pipeline
|
||||
load_pipeline.consumer_release(load_consumer_state)
|
||||
load_consumer_state.advance()
|
||||
|
||||
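# Pick the store stage round-robin from the number of tiles this CTA has already processed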
index = tile_sched.num_tiles_executed % self.num_store_stages
|
||||
# Reg -> sB
|
||||
tCsB = thr_copy.partition_D(sB_staged)
|
||||
cute.copy(tiled_copy, tCrA, tCsB[(None, None, None, index)])
|
||||
|
||||
# Fence to ensure smem writes are visible
|
||||
cute.arch.fence_proxy(
|
||||
"async.shared",
|
||||
space="cta",
|
||||
)
|
||||
self.trans_sync_barrier.arrive_and_wait()
|
||||
|
||||
if warp_idx == self.trans_warp_id[0]:
|
||||
cute.copy(
|
||||
tma_atom_store,
|
||||
tBsB[(None, index)],
|
||||
tBgB[(None, cur_tile_coord[0], cur_tile_coord[1])],
|
||||
)
|
||||
store_pipeline.producer_commit()
|
||||
store_pipeline.producer_acquire()
|
||||
self.trans_sync_barrier.arrive_and_wait()
|
||||
|
||||
tile_sched.advance_to_next_work()
|
||||
work_tile = tile_sched.get_current_work()
|
||||
|
||||
self.trans_sync_barrier.arrive_and_wait()
|
||||
|
||||
store_pipeline.producer_tail()
|
||||
|
||||
|
||||
def run_transpose(M, N, max_active_clusters=0, num_warmup=5, num_iters=20):
|
||||
"""
|
||||
Run TMA transpose kernel with automatic stage calculation and performance measurement.
|
||||
|
||||
Args:
|
||||
M: Matrix dimension M
|
||||
N: Matrix dimension N
|
||||
max_active_clusters: Maximum number of active clusters (currently ignored; the value is re-queried from HardwareInfo inside this function)
|
||||
num_warmup: Number of warmup iterations
|
||||
num_iters: Number of timing iterations
|
||||
|
||||
Performance Metrics:
|
||||
- Throughput: Actual achieved bandwidth in GB/s
|
||||
- Theoretical BW: Peak memory bandwidth (2048 B/clk × 4000 MHz = 8.192 TB/s)
|
||||
- Bandwidth Efficiency: Percentage of theoretical peak achieved
|
||||
"""
|
||||
torch.manual_seed(1111)
|
||||
# Input (M, N)
|
||||
input_data = torch.randn((M, N), device="cuda", dtype=torch.float16)
|
||||
# Output (N, M)
|
||||
output_data = torch.zeros((N, M), device="cuda", dtype=torch.float16)
|
||||
|
||||
# CuTe Wrappers
|
||||
tensor_src = (
|
||||
from_dlpack(input_data, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=16)
|
||||
)
|
||||
tensor_dst = (
|
||||
from_dlpack(output_data, assumed_align=16)
|
||||
.mark_layout_dynamic(leading_dim=1)
|
||||
.mark_compact_shape_dynamic(mode=1, divisibility=16)
|
||||
)
|
||||
|
||||
transpose_kernel = Sm100MatrixTransposeKernelV2()
|
||||
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(1)
|
||||
# Compile and Run
|
||||
compiled_kernel = cute.compile(
|
||||
transpose_kernel,
|
||||
tensor_src,
|
||||
tensor_dst,
|
||||
max_active_clusters,
|
||||
options="--generate-line-info",
|
||||
)
|
||||
|
||||
# Warmup runs
|
||||
for _ in range(num_warmup):
|
||||
compiled_kernel(tensor_src, tensor_dst)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timed runs
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for _ in range(num_iters):
|
||||
compiled_kernel(tensor_src, tensor_dst)
|
||||
end_event.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Calculate performance metrics
|
||||
elapsed_time_ms = start_event.elapsed_time(end_event)
|
||||
avg_time_ms = elapsed_time_ms / num_iters
|
||||
|
||||
# Calculate throughput
|
||||
# For transpose: read M*N elements + write M*N elements
|
||||
bytes_per_element = input_data.element_size()
|
||||
total_bytes = 2 * M * N * bytes_per_element # Read + Write
|
||||
throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000)
|
||||
|
||||
# Theoretical bandwidth limit
|
||||
# Blackwell: 2048 B/clk at 4000 MHz
|
||||
bytes_per_clk = 2048
|
||||
freq_mhz = 4000
|
||||
theoretical_bw_gb_s = bytes_per_clk * freq_mhz * 1e6 / 1e9 # Convert to GB/s
|
||||
theoretical_bw_tb_s = theoretical_bw_gb_s / 1000 # Convert to TB/s
|
||||
bandwidth_efficiency = (throughput_gb_s / theoretical_bw_gb_s) * 100 # Percentage
|
||||
|
||||
# Print computed stage counts after compilation
|
||||
print(f"Matrix size: {M}×{N}")
|
||||
print(f"Tile shape: {transpose_kernel.tile_shape}")
|
||||
print(
|
||||
f"Computed stages: Load={transpose_kernel.num_load_stages}, Store={transpose_kernel.num_store_stages}"
|
||||
)
|
||||
print(f"Average time: {avg_time_ms:.4f} ms")
|
||||
print(f"Throughput: {throughput_gb_s:.2f} GB/s")
|
||||
print(
|
||||
f"Theoretical BW: {theoretical_bw_tb_s:.2f} TB/s ({theoretical_bw_gb_s:.2f} GB/s)"
|
||||
)
|
||||
print(f"Bandwidth Efficiency: {bandwidth_efficiency:.2f}%")
|
||||
|
||||
# Verification
|
||||
expected = input_data.t()
|
||||
if torch.allclose(output_data, expected, atol=1e-2):
|
||||
print("Verification: PASSED ✓")
|
||||
else:
|
||||
print("Verification: FAILED ✗")
|
||||
print(f"Max diff: {(output_data - expected).abs().max()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_comma_separated_ints(s: str) -> Tuple[int, ...]:
|
||||
try:
|
||||
return tuple(int(x.strip()) for x in s.split(","))
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format. Expected comma-separated integers."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="TMA Matrix Transpose with Multi-Stage Pipeline and Cluster Support (v2)"
|
||||
)
|
||||
parser.add_argument("--M", type=int, default=128, help="Matrix dimension M")
|
||||
parser.add_argument("--N", type=int, default=128, help="Matrix dimension N")
|
||||
parser.add_argument(
|
||||
"--num_warmup", type=int, default=5, help="Number of warmup iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_iters", type=int, default=20, help="Number of timing iterations"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
run_transpose(
|
||||
args.M,
|
||||
args.N,
|
||||
num_warmup=args.num_warmup,
|
||||
num_iters=args.num_iters,
|
||||
)
|
||||
@@ -100,7 +100,7 @@ from cutlass.cute.runtime import from_dlpack
|
||||
|
||||
if __name__ == "__main__":
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.insert(0, os.path.join(current_dir, ".."))
|
||||
sys.path.insert(0, os.path.join(current_dir, "../../../.."))
|
||||
|
||||
from helpers import fmha_helpers as fmha_utils
|
||||
|
||||