diff --git a/CHANGELOG.md b/CHANGELOG.md index b6af0b9d0..568379566 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,18 +2,59 @@ # CUTLASS 4.x -## [4.5.0](https://github.com/NVIDIA/cutlass/tree/main) (2026-03-27) +## [4.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v4.5.0) (2026-05-01) ### CuTe DSL +* New features + - New Block API `block_copy()` to simplify TMA and S2T copy. Users can ignore detail about multicast and 2CTA partition for TMA by `block_copy()` and need not to invoke `tma_partition()`. And users can remove bulk of S2T initialization to simplify S2T copy. + - MXF8F6F4 mixed precision supoort + - BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6 + - Block Scaled MMA for SM120 now works on Spark + - EFC broadcast semantics support + - EFC epilogue functions can now broadcast and remap tensor modes via `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). Covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations. + - Initial linter support: Improved type hints on CuTe DSL APIs to support static type checkers like MyPy + - dataclasses.dataclass is now supported for JIT compilaton and cute.compile for both plain and tvm-ffi path + - cute.copy now supports user specified loop unrolling + * Bug fixing and improvements - Improved source code correlation for profiling/debugging + - Fixed an aarch64 segfault issue with tvm-ffi + - Re-organization for CuTe DSL examples/tutorials for better discoverability + +* More examples of authorizing peak-performance kernels + - MOE examles + - A new style of grouped-gemm that aligns to torch's grouped_mm and scaled_groued_mm interface. + - Expert-wise tensormap descriptor setup by a cheap helper kernel (~2us) to avoid long latency in tile switching, kernel structure is much more closer to a normal GEMM. + - Compared to torch_210_cu13, very few problem has worse perf in B200. + - mxfp8_2dx3d: avg 1.29 speedup; + - mxfp8_2dx2d: avg 1.41 speedup; + - nvfp4_2dx3d: avg 1.11 speedup; + - nvfp4_2dx2d: avg 1.12 speedup (worst case 0.98) + - bf16_2dx3d: avg 1.15 speedup (worst case 0.98) + - bf16_2dx2d: avg 1.17 speedup (worst case 0.96) + - Note: The perf is measured from torch profiler, this impl includes the helper kernel + main kernel, while torch's includes its setup kernel and cutlass_cpp main kernel. + +* API changes + - ab_dtype is deprecated in make_trivial_tiled_mma and make_blockscaled_trivial_tiled_mma from blackwell_helpers.py. Please specify a_dtype and b_dtype separately instead. ### CUTLASS C++ +* Add 2SM MMA instruction support to mixed TMA+CpAsync SM100 vanilla GEMM kernels. + - Mixed TMA+CpAsync can now accept static, but non trivial cluster shapes. + - Uses TMA multicast for A tile when using non-trivial cluster size along N mode. + - Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs. + - Changes included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm). +* Add support for 128x32xK and 128x64xK tile sizes for SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement on Blackwell SM121 related kernels. +* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). +* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation. * Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition - Enables launching GEMM on stream with partial SM allocation. * Fix some kernel issues: - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates - Fix CUTLASS clang build issues + - Fix atomicCAS read-modify-write loop in `ConstSubbyteReference` + - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference + - Remove `PipelineStorage` shadowing in SM100 complex epilogue + - Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized * Fix some profiler issues: - Add missing reference kernels for blockwise GEMM profiler * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! diff --git a/LICENSE.txt b/LICENSE.txt index a7389952d..e971c5d22 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -30,5 +30,5 @@ Certain files within this repository are subject to separate licensing terms: - The files located in the `python/CuTeDSL` directory are licensed under the NVIDIA End User License Agreement (EULA). Please refer to - https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html + https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html for the full terms. diff --git a/README.md b/README.md index c410237e8..43925736e 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ # CUTLASS 4.5.0 -_CUTLASS 4.5.0 - March 2026_ +_CUTLASS 4.5.0 - May 2026_ CUTLASS is a collection of abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for @@ -45,16 +45,57 @@ To get started quickly - please refer : # What's New in CUTLASS 4.5 -### CuTe DSL +## CuTe DSL +* New features + - New Block API `block_copy()` to simplify TMA and S2T copy. Users can ignore detail about multicast and 2CTA partition for TMA by `block_copy()` and need not to invoke `tma_partition()`. And users can remove bulk of S2T initialization to simplify S2T copy. + - MXF8F6F4 mixed precision supoort + - BlockScaled MMA now supports MXF8*MXF4 or MXF8*MXF6 + - Block Scaled MMA for SM120 now works on Spark + - EFC broadcast semantics support + - EFC epilogue functions can now broadcast and remap tensor modes via `C.remap_modes[:, 0, 1]` subscript syntax (where `:` marks a broadcast dimension and integers select source mode indices). Covers scalar broadcast, row/column broadcast, and arbitrary mode permutations (e.g. transpose). The PyTorch reference evaluator mirrors the same transformations. + - Initial linter support: Improved type hints on CuTe DSL APIs to support static type checkers like MyPy + - dataclasses.dataclass is now supported for JIT compilaton and cute.compile for both plain and tvm-ffi path + - cute.copy now supports user specified loop unrolling + * Bug fixing and improvements - Improved source code correlation for profiling/debugging + - Fixed an aarch64 segfault issue with tvm-ffi + - Re-organization for CuTe DSL examples/tutorials for better discoverability -### CUTLASS C++ +* More examples of authorizing peak-performance kernels + - MOE examles + - A new style of grouped-gemm that aligns to torch's grouped_mm and scaled_groued_mm interface. + - Expert-wise tensormap descriptor setup by a cheap helper kernel (~2us) to avoid long latency in tile switching, kernel structure is much more closer to a normal GEMM. + - Compared to torch_210_cu13, very few problem has worse perf in B200. + - mxfp8_2dx3d: avg 1.29 speedup; + - mxfp8_2dx2d: avg 1.41 speedup; + - nvfp4_2dx3d: avg 1.11 speedup; + - nvfp4_2dx2d: avg 1.12 speedup (worst case 0.98) + - bf16_2dx3d: avg 1.15 speedup (worst case 0.98) + - bf16_2dx2d: avg 1.17 speedup (worst case 0.96) + - Note: The perf is measured from torch profiler, this impl includes the helper kernel + main kernel, while torch's includes its setup kernel and cutlass_cpp main kernel. + +* API changes + - ab_dtype is deprecated in make_trivial_tiled_mma and make_blockscaled_trivial_tiled_mma from blackwell_helpers.py. Please specify a_dtype and b_dtype separately instead. + +## CUTLASS C++ +* Add 2SM MMA instruction support to mixed TMA+CpAsync SM100 vanilla GEMM kernels. + - Mixed TMA+CpAsync can now accept static, but non trivial cluster shapes. + - Uses TMA multicast for A tile when using non-trivial cluster size along N mode. + - Uses an additional barrier (mma_trampoline_barrier) to track cp.async arrivals in both CTAs. + - Changes included in [example 92](https://github.com/NVIDIA/cutlass/tree/main/examples/92_blackwell_moe_gemm). +* Add support for 128x32xK and 128x64xK tile sizes for SM120 blockscaled MMA collective builders, yielding up to 30% performance improvement on Blackwell SM121 related kernels. +* Add static load to tensor memory support, included in [example 77](https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha/). +* Use 64-bit adds for SM100 MMA descriptor offsets and reduce move instructions for improved code generation. * Add [example 95](https://github.com/NVIDIA/cutlass/tree/main/examples/95_blackwell_gemm_green_context) to support green context SM partition - Enables launching GEMM on stream with partial SM allocation. * Fix some kernel issues: - Fix l2_capacity=0 handling in Blackwell SM100/SM120 kernel templates - Fix CUTLASS clang build issues + - Fix atomicCAS read-modify-write loop in `ConstSubbyteReference` + - Replace `__nv_atomic_load_n` with `volatile` for CUDA 11.4 compatibility in subbyte reference + - Remove `PipelineStorage` shadowing in SM100 complex epilogue + - Fix build issue in SM90 epilogue fusion visitor TMA warpspecialized * Fix some profiler issues: - Add missing reference kernels for blockwise GEMM profiler * Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! diff --git a/examples/13_two_tensor_op_fusion/CMakeLists.txt b/examples/13_two_tensor_op_fusion/CMakeLists.txt index 2cd849c4d..f44c44120 100644 --- a/examples/13_two_tensor_op_fusion/CMakeLists.txt +++ b/examples/13_two_tensor_op_fusion/CMakeLists.txt @@ -48,7 +48,7 @@ foreach(FUSION_CONV_EXAMPLE fused_two_convs_f16_sm80_shmem fused_two_convs_s8_sm75_rf fused_two_convs_s8_sm75_shmem - fused_two_convs_s8_sm80_rf + # fused_two_convs_s8_sm80_rf # disabled: fails to build fused_two_convs_s8_sm80_shmem ) diff --git a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp index 381c39cf1..2ee53b7e4 100644 --- a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp +++ b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -195,7 +195,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu index f702693ce..60e30219c 100644 --- a/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu +++ b/examples/74_blackwell_gemm_streamk/blackwell_gemm_streamk.cu @@ -216,7 +216,7 @@ struct Options { float alpha, beta; int iterations; int m, n, k; - int preferred_cluster_m, preferred_cluster_n, fallback_cluster_m, fallback_cluster_n; + int cluster_m, cluster_n; using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode; DecompositionMode decomposition_mode; @@ -240,10 +240,8 @@ struct Options { m(256), n(256), k(16384), alpha(1.f), beta(0.f), iterations(10), - preferred_cluster_m(4), - preferred_cluster_n(4), - fallback_cluster_m(2), - fallback_cluster_n(1), + cluster_m(2), + cluster_n(1), decomposition_mode(DecompositionMode::Heuristic), reduction_mode(ReductionMode::Deterministic), splits(1) @@ -265,10 +263,8 @@ struct Options { cmd.get_cmd_line_argument("beta", beta, 0.f); cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("splits", splits, 1); - cmd.get_cmd_line_argument("preferred_cluster_m", preferred_cluster_m, 4); - cmd.get_cmd_line_argument("preferred_cluster_n", preferred_cluster_n, 4); - cmd.get_cmd_line_argument("fallback_cluster_m", fallback_cluster_m, 2); - cmd.get_cmd_line_argument("fallback_cluster_n", fallback_cluster_n, 1); + cmd.get_cmd_line_argument("cluster_m", cluster_m, 2); + cmd.get_cmd_line_argument("cluster_n", cluster_n, 1); // Parse decompsition mode std::string decomp_mode; @@ -303,10 +299,8 @@ struct Options { << " --k= Sets the K extent of the GEMM\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n" - << " --preferred_cluster_m= Sets the M extent of preferred cluster shape\n" - << " --preferred_cluster_n= Sets the N extent of preferred cluster shape\n" - << " --fallback_cluster_m= Sets the M extent of fallback cluster shape\n" - << " --fallback_cluster_n= Sets the N extent of fallback cluster shape\n" + << " --cluster_m= Sets the M extent of the cluster shape\n" + << " --cluster_n= Sets the N extent of the cluster shape\n" << " --decomposition= Mode in which the stream-K kernel should decompose the problem. Options: Heuristic (default), SplitK, StreamK, DataParallel\n" << " --reduction= Mode in which the stream-K kernel's reduction should be performed. Options: Deterministic (default), Nondeterministic\n" << " --iterations= Number of profiling iterations to perform.\n\n"; @@ -424,8 +418,8 @@ typename Gemm::Arguments args_from_options(const Options &options) { {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} }; - arguments.hw_info.cluster_shape = dim3(options.preferred_cluster_m, options.preferred_cluster_n,1); - arguments.hw_info.cluster_shape_fallback = dim3(options.fallback_cluster_m, options.fallback_cluster_n,1); + arguments.hw_info.cluster_shape = dim3(options.cluster_m, options.cluster_n, 1); + arguments.hw_info.cluster_shape_fallback = dim3(options.cluster_m, options.cluster_n, 1); arguments.scheduler.splits = options.splits; arguments.scheduler.decomposition_mode = options.decomposition_mode; @@ -498,8 +492,7 @@ int run(Options &options) { std::cout << "Stream-K GEMM with" << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k - << " Preferred Cluster = (" << options.preferred_cluster_m << ", " << options.preferred_cluster_n << ", 1)" - << " Fallback Cluster = (" << options.fallback_cluster_m << ", " << options.fallback_cluster_n << ", 1)\n" + << " Cluster = (" << options.cluster_m << ", " << options.cluster_n << ", 1)\n" << " Decomposition_mode=" << options.decomposition_mode_str() << " Split_count=" << options.splits << " Reduction_mode=" << options.reduction_mode_str() diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 01ca1a0e4..0c2007b67 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -536,7 +536,11 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row + #if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED + using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x; + #else using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + #endif using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem @@ -573,6 +577,14 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } ElementQK old_row_max = row_max; + #if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED + auto pos = tTMEM_LOADcS(0); + if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) { + float curr_max = tiled_tmem_load.get_max(); + row_max = ::fmax(row_max, curr_max); + } + else + #endif { // compute rowmax float row_max_0 = row_max; diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp index dc0f5fc9c..2f9070176 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -540,7 +540,11 @@ struct Sm100FmhaGenMainloopWarpspecialized { Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); // Each thread owns a single row + #if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED + using TMEM_LOAD = SM100_TMEM_LOAD_STAT_32dp32b32x; + #else using TMEM_LOAD = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + #endif using TMEM_STORE = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem @@ -577,6 +581,14 @@ struct Sm100FmhaGenMainloopWarpspecialized { } ElementQK old_row_max = row_max; + #if defined CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED + auto pos = tTMEM_LOADcS(0); + if (!need_apply_mask || (need_apply_mask && (get<0>(pos) >= get<1>(pos) + 12) && (get<1>(pos) < get<1>(problem_shape)))) { + float curr_max = tiled_tmem_load.get_max(); + row_max = ::fmax(row_max, curr_max); + } + else + #endif { // compute rowmax float row_max_0 = row_max; diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu index 83cf624a9..118b08977 100644 --- a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_grouped.cu @@ -213,12 +213,15 @@ auto make_iterator(T* ptr) { /////////////////////////////////////////////////////////////////////////////////////////////////// -struct ExampleRunner { +template < // Type of kernel schedule to generate - using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100; + class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, // Type of epilogue schedule to generate - using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; - static constexpr bool FuseQuantization = false; + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, + class ClusterShapeMNK = Shape<_1, _1, _1>, + bool FuseQuantization = false +> +struct ExampleRunner { using LayoutATag = cutlass::layout::RowMajor; using LayoutBTag = cutlass::layout::ColumnMajor; @@ -238,10 +241,8 @@ struct ExampleRunner { using ElementCompute = float; using ElementScalar = float; - - - using ClusterShapeMNK = Shape<_1,_1,_1>; - using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage) + static constexpr int TileM = cute::is_base_of_v ? 256 : 128; + using MmaTileMNK = Shape,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage) static constexpr int AlignmentA = 32; static constexpr int AlignmentB = 32; @@ -712,10 +713,34 @@ int main(int argc, char const **args) { hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl; ExampleRunner runner_mixed_tma_cpasync; runner_mixed_tma_cpasync.run(options, hw_info); + std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load, 1SM, 2x2 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + cutlass::epilogue::collective::EpilogueScheduleAuto, + Shape<_2, _2, _1> + > runner_mixed_tma_cpasync_1sm_2x2; + runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info); + + std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x1 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100, + cutlass::epilogue::collective::EpilogueScheduleAuto, + Shape<_2, _1, _1> + > runner_mixed_tma_cpasync_2sm_2x1; + runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info); + + std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load, 2x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100, + cutlass::epilogue::collective::EpilogueScheduleAuto, + Shape<_2, _4, _1> + > runner_mixed_tma_cpasync_2sm_2x4; + runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info); + #endif return 0; diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu index 3360cb456..011ad0c9e 100644 --- a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_fp4_regular.cu @@ -220,6 +220,7 @@ using SfdOutputCfg = cutlass::detail::Sm1xxBlockScaledOutputConfig, // Type of epilogue schedule to generate class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, bool FuseQuantization = false @@ -246,8 +247,8 @@ struct ExampleRunner { - using ClusterShapeMNK = Shape<_1,_1,_1>; - using MmaTileMNK = Shape<_128,_64,_256>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage) + static constexpr int TileM = cute::is_base_of_v ? 256 : 128; + using MmaTileMNK = Shape,_64,_64>; // use tile size of N=64 to match real use cases (N is typically very small in decoding stage) static constexpr int AlignmentA = 32; static constexpr int AlignmentB = 32; @@ -485,6 +486,24 @@ struct ExampleRunner { }, hw_info }; + } + else if constexpr (std::is_same_v){ + return typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { // Mainloop arguments + block_A.device_data(), + block_B.device_data(), + block_SFA.device_data(), + block_SFB.device_data() + }, + { // Epilogue arguments + {}, + block_C.device_data(), stride_C, + block_D.device_data(), stride_D + }, + hw_info + }; } else { return typename Gemm::Arguments { @@ -654,10 +673,80 @@ int main(int argc, char const **args) { ExampleRunner runner_tma; runner_tma.run(options, hw_info); - std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 1x1 cluster:" << std::endl; ExampleRunner runner_mixed_tma_cpasync; runner_mixed_tma_cpasync.run(options, hw_info); + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x1 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + Shape<_4, _1, _1> + > runner_mixed_tma_cpasync_1sm_4x1; + runner_mixed_tma_cpasync_1sm_4x1.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 1x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + Shape<_1, _4, _1> + > runner_mixed_tma_cpasync_1sm_1x4; + runner_mixed_tma_cpasync_1sm_1x4.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x2 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + Shape<_2, _2, _1> + > runner_mixed_tma_cpasync_1sm_2x2; + runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x2 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + Shape<_4, _2, _1> + > runner_mixed_tma_cpasync_1sm_4x2; + runner_mixed_tma_cpasync_1sm_4x2.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + Shape<_2, _4, _1> + > runner_mixed_tma_cpasync_1sm_2x4; + runner_mixed_tma_cpasync_1sm_2x4.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 4x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100, + Shape<_4, _4, _1> + > runner_mixed_tma_cpasync_1sm_4x4; + runner_mixed_tma_cpasync_1sm_4x4.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x1 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100, + Shape<_4, _1, _1> + > runner_mixed_tma_cpasync_2sm_4x1; + runner_mixed_tma_cpasync_2sm_4x1.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 2x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100, + Shape<_2, _4, _1> + > runner_mixed_tma_cpasync_2sm_2x4; + runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x2 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100, + Shape<_4, _2, _1> + > runner_mixed_tma_cpasync_2sm_4x2; + runner_mixed_tma_cpasync_2sm_4x2.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 2SM instr, 4x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100, + Shape<_4, _4, _1> + > runner_mixed_tma_cpasync_2sm_4x4; + runner_mixed_tma_cpasync_2sm_4x4.run(options, hw_info); + #endif return 0; diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu index a286acb8a..13673fda8 100644 --- a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_grouped.cu @@ -70,7 +70,6 @@ #include "helper.h" - using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -169,12 +168,14 @@ bool initialize_block( /////////////////////////////////////////////////////////////////////////////////////////////////// -struct ExampleRunner { - +template< // Type of kernel schedule to generate - using MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100; + class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, // Type of epilogue schedule to generate - using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto; + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, + class ClusterShapeMNK = Shape<_1, _1, _1> +> +struct ExampleRunner { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -189,7 +190,6 @@ struct ExampleRunner { using ElementCompute = float; using ElementScalar = float; - using ClusterShapeMNK = Shape<_1,_1,_1>; using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage) // 16B alignment lets us use TMA @@ -219,7 +219,7 @@ struct ExampleRunner { MainloopScheduleType >::CollectiveOp; - using ProblemShape = cutlass::gemm::MoEProblemShape>; // per group + using ProblemShape = typename cutlass::gemm::MoEProblemShape>; // per group using GemmKernel = cutlass::gemm::kernel::GemmUniversal< ProblemShape, @@ -272,10 +272,10 @@ struct ExampleRunner { auto [M, N, K] = problem; printf("group [%d] : M = %d, N = %d, K = %d\n", i, M, N, K); - cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, Gemm::LayoutA(maxK)); - cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, Gemm::LayoutB(maxK)); - cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutC(maxM)); - cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, Gemm::LayoutD(maxM)); + cutlass::TensorRef ref_A(block_A.get() + size_t(1) * i * maxM * maxK, typename Gemm::LayoutA(maxK)); + cutlass::TensorRef ref_B(block_B.get() + size_t(1) * i * maxN * maxK, typename Gemm::LayoutB(maxK)); + cutlass::TensorRef ref_C(block_C.get() + size_t(1) * i * maxN * maxM, typename Gemm::LayoutC(maxM)); + cutlass::TensorRef ref_D(block_ref_D.get() + size_t(1) * i * maxN * maxM, typename Gemm::LayoutD(maxM)); using DeviceGemmReference = cutlass::reference::device::Gemm< ElementA, @@ -561,10 +561,62 @@ int main(int argc, char const **args) { hw_info.device_id = 0; hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); - std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + std::cout << "Running kernel with mixed TMA+CPASYNC load, 1SM:" << std::endl; ExampleRunner runner_mixed_tma_cpasync; runner_mixed_tma_cpasync.run(options, hw_info); + std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load and 1x1 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + cutlass::epilogue::collective::EpilogueScheduleAuto, + Shape<_1, _1, _1> + > runner_mixed_tma_cpasync_1sm; + runner_mixed_tma_cpasync_1sm.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load with 1SM instr, 2x2 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + cutlass::epilogue::collective::EpilogueScheduleAuto, + Shape<_2, _2, _1> + > runner_mixed_tma_cpasync_1sm_2x2; + runner_mixed_tma_cpasync_1sm_2x2.run(options, hw_info); + + std::cout << "\n\n\nRunning kernel with mixed TMA+CPASYNC load and 4x4 cluster:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + cutlass::epilogue::collective::EpilogueScheduleAuto, + Shape<_4, _4, _1> + > runner_mixed_tma_cpasync_1sm_4x4; + runner_mixed_tma_cpasync_1sm_4x4.run(options, hw_info); + + std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 2x1:" << std::endl; + ExampleRunner + > runner_mixed_tma_cpasync_2sm_2x1; + runner_mixed_tma_cpasync_2sm_2x1.run(options, hw_info); + + std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 4x1:" << std::endl; + ExampleRunner + > runner_mixed_tma_cpasync_2sm_4x1; + runner_mixed_tma_cpasync_2sm_4x1.run(options, hw_info); + + std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 2x4:" << std::endl; + ExampleRunner + > runner_mixed_tma_cpasync_2sm_2x4; + runner_mixed_tma_cpasync_2sm_2x4.run(options, hw_info); + + std::cout << "\n\n\nRunning 2SM kernel with mixed TMA+CPASYNC load 4x4:" << std::endl; + ExampleRunner + > runner_mixed_tma_cpasync_2sm_4x4; + runner_mixed_tma_cpasync_2sm_4x4.run(options, hw_info); + #endif return 0; diff --git a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu index 4e2335f04..1703edf15 100644 --- a/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu +++ b/examples/92_blackwell_moe_gemm/92_blackwell_moe_gemm_regular.cu @@ -171,6 +171,8 @@ bool initialize_block( template < // Type of kernel schedule to generate class MainloopScheduleType = cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + // Static cluster shape + class ClusterShapeMNK = Shape<_1, _1, _1>, // Type of epilogue schedule to generate class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto > @@ -189,7 +191,6 @@ struct ExampleRunner { using ElementCompute = float; using ElementScalar = float; - using ClusterShapeMNK = Shape<_1,_1,_1>; using MmaTileMNK = Shape<_128,_16,Int<128 / sizeof(ElementA)>>; // use tile size of N=16 to match real use cases (N is typically very small in decoding stage) // 16B alignment lets us use TMA @@ -334,6 +335,16 @@ struct ExampleRunner { hw_info }; } + else if constexpr (std::is_same_v){ + return typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), block_B.get()}, + {{}, // epilogue.thread + block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + } else { return typename Gemm::Arguments { cutlass::gemm::GemmUniversalMode::kGemm, @@ -490,13 +501,33 @@ int main(int argc, char const **args) { runner_tma.run(options, hw_info); std::cout << "Running kernel with CPASYNC load:" << std::endl; - ExampleRunner runner_cpasync; + ExampleRunner runner_cpasync; runner_cpasync.run(options, hw_info); - std::cout << "Running kernel with mixed TMA+CPASYNC load:" << std::endl; + std::cout << "Running kernel with mixed TMA+CPASYNC load, cluster shape 1x1:" << std::endl; ExampleRunner runner_mixed_tma_cpasync; runner_mixed_tma_cpasync.run(options, hw_info); + std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 2x2:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + Shape<_2, _2, _1> + > runner_mixed_tma_cpasync_2x2; + runner_mixed_tma_cpasync_2x2.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 4x4, 1SM instructions:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized1SmSm100, + Shape<_4, _4, _1> + > runner_mixed_tma_cpasync_4x4; + runner_mixed_tma_cpasync_4x4.run(options, hw_info); + + std::cout << "Running kernel with mixed TMA+CPASYNC load w/ multicast, cluster shape 4x4, 2SM instructions:" << std::endl; + ExampleRunner< + cutlass::gemm::KernelMixedTmaCpAsyncWarpSpecialized2SmSm100, + Shape<_4, _4, _1> + > runner_mixed_tma_cpasync_4x4_2sm; + runner_mixed_tma_cpasync_4x4_2sm.run(options, hw_info); #endif return 0; diff --git a/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/examples/python/CuTeDSL/cute/ampere/kernel/attention/flash_attention_v2.py similarity index 100% rename from examples/python/CuTeDSL/ampere/flash_attention_v2.py rename to examples/python/CuTeDSL/cute/ampere/kernel/attention/flash_attention_v2.py diff --git a/examples/python/CuTeDSL/ampere/hstu_attention.py b/examples/python/CuTeDSL/cute/ampere/kernel/attention/hstu_attention.py similarity index 100% rename from examples/python/CuTeDSL/ampere/hstu_attention.py rename to examples/python/CuTeDSL/cute/ampere/kernel/attention/hstu_attention.py diff --git a/examples/python/CuTeDSL/ampere/sgemm.py b/examples/python/CuTeDSL/cute/ampere/kernel/dense_gemm/sgemm.py similarity index 100% rename from examples/python/CuTeDSL/ampere/sgemm.py rename to examples/python/CuTeDSL/cute/ampere/kernel/dense_gemm/sgemm.py diff --git a/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/examples/python/CuTeDSL/cute/ampere/kernel/dense_gemm/tensorop_gemm.py similarity index 100% rename from examples/python/CuTeDSL/ampere/tensorop_gemm.py rename to examples/python/CuTeDSL/cute/ampere/kernel/dense_gemm/tensorop_gemm.py diff --git a/examples/python/CuTeDSL/ampere/elementwise_add.py b/examples/python/CuTeDSL/cute/ampere/kernel/elementwise/elementwise_add.py similarity index 100% rename from examples/python/CuTeDSL/ampere/elementwise_add.py rename to examples/python/CuTeDSL/cute/ampere/kernel/elementwise/elementwise_add.py diff --git a/examples/python/CuTeDSL/ampere/elementwise_apply.py b/examples/python/CuTeDSL/cute/ampere/kernel/elementwise/elementwise_apply.py similarity index 100% rename from examples/python/CuTeDSL/ampere/elementwise_apply.py rename to examples/python/CuTeDSL/cute/ampere/kernel/elementwise/elementwise_apply.py diff --git a/examples/python/CuTeDSL/ampere/elementwise_add_autotune.py b/examples/python/CuTeDSL/cute/ampere/tutorial/elementwise_add_autotune.py similarity index 100% rename from examples/python/CuTeDSL/ampere/elementwise_add_autotune.py rename to examples/python/CuTeDSL/cute/ampere/tutorial/elementwise_add_autotune.py diff --git a/examples/python/CuTeDSL/blackwell/epilogue/activation_custom_epilogue_dense_gemm.py b/examples/python/CuTeDSL/cute/blackwell/efc/activation_custom_epilogue_dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/epilogue/activation_custom_epilogue_dense_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/efc/activation_custom_epilogue_dense_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/epilogue/common_dense_gemm_efc.py b/examples/python/CuTeDSL/cute/blackwell/efc/common_dense_gemm_efc.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/epilogue/common_dense_gemm_efc.py rename to examples/python/CuTeDSL/cute/blackwell/efc/common_dense_gemm_efc.py diff --git a/examples/python/CuTeDSL/blackwell/epilogue/common_efc.py b/examples/python/CuTeDSL/cute/blackwell/efc/common_efc.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/epilogue/common_efc.py rename to examples/python/CuTeDSL/cute/blackwell/efc/common_efc.py diff --git a/examples/python/CuTeDSL/blackwell/epilogue/custom_epilogue_dense_gemm.py b/examples/python/CuTeDSL/cute/blackwell/efc/custom_epilogue_dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/epilogue/custom_epilogue_dense_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/efc/custom_epilogue_dense_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/epilogue/synthetic_custom_epilogue_dense_gemm.py b/examples/python/CuTeDSL/cute/blackwell/efc/synthetic_custom_epilogue_dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/epilogue/synthetic_custom_epilogue_dense_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/efc/synthetic_custom_epilogue_dense_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/fmha.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/fmha.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py index 5d4ddd7b4..539585e28 100644 --- a/examples/python/CuTeDSL/blackwell/fmha.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py @@ -48,7 +48,7 @@ from cutlass.cute.typing import Int32, Int64, Float32 if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) from helpers import fmha_helpers as fmha_utils diff --git a/examples/python/CuTeDSL/blackwell/fmha_bwd.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/fmha_bwd.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py index 9d415423a..cb94029af 100644 --- a/examples/python/CuTeDSL/blackwell/fmha_bwd.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py @@ -51,7 +51,7 @@ from cutlass.cute.typing import Int32, Float32, Float8E4M3FN, Float16, BFloat16, if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) from helpers import fmha_helpers as fmha_utils diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py index bd709f837..a9e2a5d56 100644 --- a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd.py @@ -45,14 +45,14 @@ from cutlass.cute.runtime import from_dlpack if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) -from blackwell.mamba2_ssd.mamba2_ssd_reference import ( +from cute.blackwell.kernel.attention.mamba2_ssd.mamba2_ssd_reference import ( ssd_reference_fp32_all, ssd_reference_lowprecision_intermediates, analyze_relative_diffs, ) -from blackwell.mamba2_ssd.mamba2_ssd_tile_scheduler import ( +from cute.blackwell.kernel.attention.mamba2_ssd.mamba2_ssd_tile_scheduler import ( Mamba2SSDTileSchedulerParams, Mamba2SSDTileScheduler, ) diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd_reference.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd_reference.py diff --git a/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd_tile_scheduler.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mamba2_ssd/mamba2_ssd_tile_scheduler.py diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py similarity index 90% rename from examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py index 579b3cd0e..8c4f7157a 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_decode.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_decode.py @@ -27,14 +27,11 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse -import enum import math -import time from typing import Type, Tuple from functools import partial import torch -import torch.nn.functional as F from torch.nn.functional import scaled_dot_product_attention from torch.nn.attention import SDPBackend, sdpa_kernel @@ -42,7 +39,7 @@ import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute -import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.cute.nvgpu import tcgen05, OperandMajorMode import cutlass.utils as utils import cutlass.pipeline as pipeline import cutlass.torch as cutlass_torch @@ -51,9 +48,6 @@ import cutlass.cute.testing as testing from cutlass.cute.runtime import from_dlpack from cutlass.cute.typing import * -from cutlass._mlir.dialects import llvm -from cutlass.cute.arch.nvvm_wrappers import mapa - # Kernel invariants mma_modes = (0, 1, 2) mma_dice = (None, None, None) # (MMA, #MMA_M, #MMA_K) @@ -65,19 +59,22 @@ warpgroup_threads = 128 # Math helpers log2_e = math.log2(math.e) # change exponential base use_tensor_ssa_math = False # experimental -fadd2 = partial(cute.arch.add_packed_f32x2, ftz=False, rnd="rn") -fmul2 = partial(cute.arch.mul_packed_f32x2, ftz=False, rnd="rn") -ffma2 = partial(cute.arch.fma_packed_f32x2, ftz=False, rnd="rn") +fadd2 = cute.arch.add_packed_f32x2 +fmul2 = cute.arch.mul_packed_f32x2 +ffma2 = cute.arch.fma_packed_f32x2 exp2 = partial(cute.math.exp2, fastmath=True) +warp_fmax = partial(cute.arch.warp_redux_sync, kind="fmax", nan=True) +smem_fmax = partial(cute.arch.atomic_fmax, sem="relaxed", scope="cta") +gmem_fmax = partial(cute.arch.atomic_fmax, sem="relaxed", scope="gpu") class MixedInputFusedMultiHeadAttentionDecode: def __init__( self, headdim, - block_scaledim, # headdim per scale factor; scale factor shape is (batches, heads_k, seqlen, headdim / block_scaledim) - grouped_head_tile, # GQA packing tile size, can be less than group size - convert_warpgroups = 1, # Multiple warpgroups striding on convert stages + block_scaledim, # headdim per scale factor; scale factor shape is (batches, heads_k, seqlen, headdim / block_scaledim) + grouped_head_tile, # GQA packing tile size, can be less than group size + convert_warpgroups=1, # Multiple warpgroups striding on convert stages ): self.headdim = headdim self.grouped_head_tile = grouped_head_tile @@ -93,7 +90,9 @@ class MixedInputFusedMultiHeadAttentionDecode: self.softmax_warpgroup_id = warpgroup_id warpgroup_id += 1 - self.cvt_warpgroup_ids = tuple(range(warpgroup_id, warpgroup_id+convert_warpgroups)) + self.cvt_warpgroup_ids = tuple( + range(warpgroup_id, warpgroup_id + convert_warpgroups) + ) warpgroup_id += convert_warpgroups # Why 2 MMA+TMA warps when not MMA bound? @@ -114,12 +113,15 @@ class MixedInputFusedMultiHeadAttentionDecode: max_regs_per_wg_thread = 64 * 1024 // warpgroup_threads # 64K regs per SM self.mma_tma_regs = 72 self.cvt_regs = 112 - self.softmax_regs = (max_regs_per_wg_thread - - self.mma_tma_regs - - self.cvt_regs * convert_warpgroups) + self.softmax_regs = ( + max_regs_per_wg_thread + - self.mma_tma_regs + - self.cvt_regs * convert_warpgroups + ) self.softmax_regs = max(128, min(256, self.softmax_regs)) - assert (self.mma_tma_regs + self.softmax_regs + - self.cvt_regs * convert_warpgroups) <= max_regs_per_wg_thread or not self.use_reg_reconfig + assert ( + self.mma_tma_regs + self.softmax_regs + self.cvt_regs * convert_warpgroups + ) <= max_regs_per_wg_thread or not self.use_reg_reconfig self.bs_stages = 2 self.sp_stages = 2 @@ -140,25 +142,33 @@ class MixedInputFusedMultiHeadAttentionDecode: raise ValueError("use Float8E4M3FN instead of Float8E4M3") if d % 64 != 0: - raise ValueError(f"headdim({d}) must be multiple of 64") + raise testing.CantImplementError(f"headdim({d}) must be multiple of 64") if h_q % h_k != 0: - raise ValueError(f"heads_q({h_q}) must be a multiple of heads_k({h_k})") + raise testing.CantImplementError( + f"heads_q({h_q}) must be a multiple of heads_k({h_k})" + ) align_scale_bits = 128 # TMA requirement if self.scaledim * q_dtype.width < align_scale_bits: align_seq = align_scale_bits // (self.scaledim * q_dtype.width) if s_k % align_seq != 0: - raise ValueError(f"seqlen({s_k}) must be a multiple of {align_seq}") + raise testing.CantImplementError( + f"seqlen({s_k}) must be a multiple of {align_seq}" + ) if kv_dtype.width < 8 and d % 128 != 0: # TMA requirement - raise ValueError(f"headdim({d}) must be multiple of 128 for {kv_dtype} KV") + raise testing.CantImplementError( + f"headdim({d}) must be multiple of 128 for {kv_dtype} KV" + ) @cute.jit def __call__( self, - problem_shape: Tuple[Int32, Int32, Int32, Int32, Int32], # b, h_q, h_k, s_k, d - kv_splits: Int32, # threadblocks per sequence + problem_shape: Tuple[ + cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32, cutlass.Int32 + ], # b, h_q, h_k, s_k, d + kv_splits: cutlass.Int32, # threadblocks per sequence q_iter: cute.Pointer, k_iter: cute.Pointer, v_iter: cute.Pointer, @@ -170,8 +180,8 @@ class MixedInputFusedMultiHeadAttentionDecode: o_partial_iter: cute.Pointer, # partial O per kv split m_partial_iter: cute.Pointer, # partial colmax_s per kv split l_partial_iter: cute.Pointer, # partial colsum_p per kv split - scale_qs: Float32, - scale_o: Float32, + scale_qs: cutlass.Float32, + scale_o: cutlass.Float32, stream: cuda.CUstream, ): ############################## @@ -179,7 +189,7 @@ class MixedInputFusedMultiHeadAttentionDecode: ############################## mma_dtype = q_iter.dtype acc_dtype = o_partial_iter.dtype - assert acc_dtype is Float32 # don't support other acc types for now + assert acc_dtype is cutlass.Float32 # don't support other acc types for now # Block tile sets the granularity at which threadblocks consume work blk_tile_s = 128 @@ -197,8 +207,9 @@ class MixedInputFusedMultiHeadAttentionDecode: # GEMM1: (S_K, H_R, D, (H_K, B)) tiled_mma_kq = sm100_utils.make_trivial_tiled_mma( mma_dtype, - tcgen05.OperandMajorMode.K, # K - tcgen05.OperandMajorMode.K, # Q + mma_dtype, + OperandMajorMode.K, # K + OperandMajorMode.K, # Q acc_dtype, tcgen05.CtaGroup.ONE, mma_tile_mnk[:2], @@ -208,8 +219,9 @@ class MixedInputFusedMultiHeadAttentionDecode: # GEMM2: (D, H_R, S_K, (H_K, B)) tiled_mma_vp = sm100_utils.make_trivial_tiled_mma( # mma_dtype, - tcgen05.OperandMajorMode.K, # V - tcgen05.OperandMajorMode.MN, # P + mma_dtype, + OperandMajorMode.K, # V + OperandMajorMode.MN, # P acc_dtype, tcgen05.CtaGroup.ONE, mma_tile_mnk[:2], @@ -523,8 +535,8 @@ class MixedInputFusedMultiHeadAttentionDecode: mM: cute.Tensor, mM_partial: cute.Tensor, mL_partial: cute.Tensor, - scale_qs: Float32, - scale_qs_log2_e: Float32, + scale_qs: cutlass.Float32, + scale_qs_log2_e: cutlass.Float32, ): # Read special registers kv_splits, tiles_hr, tiles_hb = cute.arch.grid_dim() @@ -586,7 +598,7 @@ class MixedInputFusedMultiHeadAttentionDecode: ############################## # Tmem Allocation ############################## - tmem_ptr_smem_ptr = smem.allocate_array(Int32) + tmem_ptr_smem_ptr = smem.allocate_array(cutlass.Int32) if warp_idx == init_warp and not exit_early: cute.arch.alloc_tmem(self.tmem_alloc_cols, tmem_ptr_smem_ptr) init_warp += 1 @@ -595,26 +607,32 @@ class MixedInputFusedMultiHeadAttentionDecode: # Pipeline Allocation + Init ############################## # Allocate Mbarriers - q_pipeline_ptr = smem.allocate_array(Int64, self.q_stages * 2) - kv_pipeline_ptr = smem.allocate_array(Int64, self.kv_stages * 2) - bs_pipeline_ptr = smem.allocate_array(Int64, self.bs_stages * 2) - cvt_pipeline_ptr = smem.allocate_array(Int64, self.cvt_stages * 2) - s_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2) - p_pipeline_ptr = smem.allocate_array(Int64, self.sp_stages * 2) - o_pipeline_ptr = smem.allocate_array(Int64, self.o_stages * 2) + q_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.q_stages * 2) + kv_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.kv_stages * 2) + bs_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.bs_stages * 2) + cvt_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.cvt_stages * 2) + s_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.sp_stages * 2) + p_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.sp_stages * 2) + o_pipeline_ptr = smem.allocate_array(cutlass.Int64, self.o_stages * 2) # Declare named barriers - softmax_nbar_id = 1 - mma_kq_nbar_id = 2 - mma_vp_nbar_id = 3 + softmax_nbar = pipeline.NamedBarrier( + barrier_id=1, num_threads=warpgroup_threads + ) + mma_kq_nbar = pipeline.NamedBarrier(barrier_id=2, num_threads=64) + mma_vp_nbar = pipeline.NamedBarrier(barrier_id=3, num_threads=64) # Alias thread cooperatives elect_one_cooperative = pipeline.CooperativeGroup(pipeline.Agent.Thread) - warpgroup_cooperative = pipeline.CooperativeGroup(pipeline.Agent.Thread, warpgroup_threads) + warpgroup_cooperative = pipeline.CooperativeGroup( + pipeline.Agent.Thread, warpgroup_threads + ) mma_group = elect_one_cooperative tma_group = elect_one_cooperative cvt_group = warpgroup_cooperative - cvt_groups = pipeline.CooperativeGroup(pipeline.Agent.Thread, warpgroup_threads * self.convert_warpgroups) + cvt_groups = pipeline.CooperativeGroup( + pipeline.Agent.Thread, warpgroup_threads * self.convert_warpgroups + ) softmax_group = warpgroup_cooperative # Initialize pipelines @@ -642,7 +660,9 @@ class MixedInputFusedMultiHeadAttentionDecode: num_stages=self.bs_stages, producer_group=tma_group, consumer_group=cvt_groups, - tx_count=cute.size_in_bytes(mma_dtype, cute.select(smem_layout_bs, mode=[0,1])), + tx_count=cute.size_in_bytes( + mma_dtype, cute.select(smem_layout_bs, mode=[0, 1]) + ), barrier_storage=bs_pipeline_ptr, tidx=mcast_coord, cta_layout_vmnk=mcast_layout, @@ -772,7 +792,7 @@ class MixedInputFusedMultiHeadAttentionDecode: tCtO = thrblk_mma_vp.make_fragment_C(tCsO.shape) # Tmem tensor allocation - tmem_ptr = cute.arch.retrieve_tmem_ptr(Int32, 16, tmem_ptr_smem_ptr) + tmem_ptr = cute.arch.retrieve_tmem_ptr(cutlass.Int32, 16, tmem_ptr_smem_ptr) tmem_offset = 0 tAtK_cvt = cute.make_tensor( @@ -802,7 +822,7 @@ class MixedInputFusedMultiHeadAttentionDecode: # Exit early ############################## if exit_early: - noop = None # early return not supported + noop = None # early return not supported # noqa: F841 ############################## # TMA KV Dispatch @@ -991,9 +1011,11 @@ class MixedInputFusedMultiHeadAttentionDecode: cute.arch.setmaxregister_decrease(self.cvt_regs) # Intermediate convert type - cvt_type = Float32 - if cutlass.const_expr(mma_dtype is cutlass.BFloat16 and - k_dtype in (cutlass.Int4, cutlass.Int8)): + cvt_type = cutlass.Float32 + if cutlass.const_expr( + mma_dtype is cutlass.BFloat16 + and k_dtype in (cutlass.Int4, cutlass.Int8) + ): cvt_type = mma_dtype # Initialize for multiple warpgroups if necessary @@ -1015,12 +1037,13 @@ class MixedInputFusedMultiHeadAttentionDecode: smem_load_atom_k = cute.make_copy_atom( cute.nvgpu.warp.LdMatrix8x16x8bOp( num_matrices=4, - unpack_bits=(k_dtype.width if k_dtype.width < 8 else None)), + unpack_bits=(k_dtype.width if k_dtype.width < 8 else None), + ), kv_smem_dtype, ) tmem_store_k = tcgen05.make_tmem_copy( - tmem_store_atom_k, tAtK_cvt[mma_dice+(0,)] + tmem_store_atom_k, tAtK_cvt[mma_dice + (0,)] ) thr_store_k = tmem_store_k.get_slice(warpgroup_tidx) tKrK_cvt_shape = thr_store_k.partition_S(tAtK_cvt).shape[:-1] @@ -1043,7 +1066,7 @@ class MixedInputFusedMultiHeadAttentionDecode: smem_load_atom_v = cute.make_copy_atom(smem_load_op_v, kv_smem_dtype) tmem_store_v = tcgen05.make_tmem_copy( - tmem_store_atom_v, tAtV_cvt[mma_dice+(0,)] + tmem_store_atom_v, tAtV_cvt[mma_dice + (0,)] ) thr_store_v = tmem_store_v.get_slice(warpgroup_tidx) tVrV_cvt_shape = thr_store_v.partition_S(tAtV_cvt).shape[:-1] @@ -1113,7 +1136,9 @@ class MixedInputFusedMultiHeadAttentionDecode: bs_handle.release() # Convert and scale K - for dk in cutlass.range(tiles_dk // self.convert_warpgroups, unroll=2): + for dk in cutlass.range( + tiles_dk // self.convert_warpgroups, unroll=2 + ): tKrK = cute.make_rmem_tensor(tKrK_shape, kv_smem_dtype) tKrK_cvt = cute.make_rmem_tensor(tKrK_cvt_shape, mma_dtype) @@ -1139,7 +1164,9 @@ class MixedInputFusedMultiHeadAttentionDecode: cvt_handle = cvt_producer.acquire_and_advance() cute.copy( - thr_store_k, tKrK_cvt, tKtK_cvt[cpy_dice + (cvt_handle.index,)] + thr_store_k, + tKrK_cvt, + tKtK_cvt[cpy_dice + (cvt_handle.index,)], ) cute.arch.fence_view_async_tmem_store() cvt_handle.commit() @@ -1160,7 +1187,9 @@ class MixedInputFusedMultiHeadAttentionDecode: bs_handle.release() # Convert and scale V - for dmsk in cutlass.range(tiles_dm * tiles_sk // self.convert_warpgroups, unroll=2): + for dmsk in cutlass.range( + tiles_dm * tiles_sk // self.convert_warpgroups, unroll=2 + ): tVrV = cute.make_rmem_tensor(tVrV_shape, kv_smem_dtype) tVrV_cvt = cute.make_rmem_tensor(tVrV_cvt_shape, mma_dtype) @@ -1186,7 +1215,9 @@ class MixedInputFusedMultiHeadAttentionDecode: cvt_handle = cvt_producer.acquire_and_advance() cute.copy( - thr_store_v, tVrV_cvt, tVtV_cvt[cpy_dice + (cvt_handle.index,)] + thr_store_v, + tVrV_cvt, + tVtV_cvt[cpy_dice + (cvt_handle.index,)], ) cute.arch.fence_view_async_tmem_store() cvt_handle.commit() @@ -1223,9 +1254,7 @@ class MixedInputFusedMultiHeadAttentionDecode: k_handle = cvt_consumer.wait_and_advance(k_token) # Signal BMM2 to start if is_last_iter: - cute.arch.barrier_arrive( - barrier_id=mma_kq_nbar_id, number_of_threads=64 - ) + mma_kq_nbar.arrive() for mma_k in cutlass.range_constexpr(tAtK_cvt.shape[2]): cute.gemm( tiled_mma_kq, @@ -1245,7 +1274,7 @@ class MixedInputFusedMultiHeadAttentionDecode: if s > 0: for _ in cutlass.range_constexpr(tiles_dm * tiles_sk): cvt_consumer.advance() - cute.arch.barrier(barrier_id=mma_vp_nbar_id, number_of_threads=64) + mma_vp_nbar.arrive_and_wait() s_token = s_producer.try_acquire() ############################## @@ -1263,7 +1292,7 @@ class MixedInputFusedMultiHeadAttentionDecode: # Advance and wait for BMM1 for _ in cutlass.range_constexpr(tiles_dk): cvt_consumer.advance() - cute.arch.barrier(barrier_id=mma_kq_nbar_id, number_of_threads=64) + mma_kq_nbar.arrive_and_wait() # Sequence loop p_token = False @@ -1273,7 +1302,7 @@ class MixedInputFusedMultiHeadAttentionDecode: if s < iters_s - 1: for _ in cutlass.range_constexpr(tiles_dk): cvt_consumer.advance() - cute.arch.barrier(barrier_id=mma_kq_nbar_id, number_of_threads=64) + mma_kq_nbar.arrive_and_wait() p_token = p_consumer.try_wait() # BMM2 @@ -1286,9 +1315,7 @@ class MixedInputFusedMultiHeadAttentionDecode: v_handle = cvt_consumer.wait_and_advance(v_token) # Signal BMM1 to start if is_last_iter: - cute.arch.barrier_arrive( - barrier_id=mma_vp_nbar_id, number_of_threads=64 - ) + mma_vp_nbar.arrive() for mma_k in cutlass.range_constexpr(tAtV_cvt.shape[2]): cute.gemm( tiled_mma_vp, @@ -1325,7 +1352,9 @@ class MixedInputFusedMultiHeadAttentionDecode: tmem_load_atom_s = cute.make_copy_atom( tcgen05.Ld32x32bOp(tmem_op_repeat), acc_dtype ) - tmem_load_s = tcgen05.make_tmem_copy(tmem_load_atom_s, tCtS[mma_dice + (0,)]) + tmem_load_s = tcgen05.make_tmem_copy( + tmem_load_atom_s, tCtS[mma_dice + (0,)] + ) thr_load_s = tmem_load_s.get_slice(warpgroup_tidx) tmem_store_atom_o = cute.make_copy_atom( @@ -1356,7 +1385,7 @@ class MixedInputFusedMultiHeadAttentionDecode: # Partition colmax and initialize in RF tSsM = thr_load_s.partition_D(tCsM) # (CPY, #CPY_MMA, #CPY_M, #CPY_N) tSrM_prev = cute.make_rmem_tensor_like(tSsM) - tSrM_prev.fill(-Float32.inf) + tSrM_prev.fill(-cutlass.Float32.inf) # Partition colsum and initialize in RF # Each thread maintains a local colsum in RF, smem reduction happens after loop @@ -1364,7 +1393,7 @@ class MixedInputFusedMultiHeadAttentionDecode: tCsL ) # (CPY, #CPY_MMA, #CPY_M, #CPY_N, WARPS) tSrL = cute.make_rmem_tensor_like(tSsL[cpy_dice + (0,)]) - tSrL.fill(Float32(0)) + tSrL.fill(cutlass.Float32(0)) assert warp_threads >= cute.size(tSsM) @@ -1385,17 +1414,15 @@ class MixedInputFusedMultiHeadAttentionDecode: ) # Initialize O - tSrO.fill(Float32(0)) + tSrO.fill(cutlass.Float32(0)) cute.copy(thr_store_o, tSrO, tStO) # Initialize colsum and colmax in smem and wait if warpgroup_widx == 0 and lane_store_max: - tSsM[lane_idx] = -Float32.inf + tSsM[lane_idx] = -cutlass.Float32.inf if warpgroup_widx == 1 and lane_store_max: - tSsL[lane_idx] = Float32(0) - cute.arch.barrier( - barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads - ) + tSsL[lane_idx] = cutlass.Float32(0) + softmax_nbar.arrive_and_wait() # # Sequence loop @@ -1410,20 +1437,18 @@ class MixedInputFusedMultiHeadAttentionDecode: # Reduce colmax in warp RF tSrM = cute.make_rmem_tensor_like(tSsM) - tSrM_lane = Float32(0) # Avoid dynamic register indexing + tSrM_lane = cutlass.Float32(0) # Avoid dynamic register indexing for i in cutlass.range_constexpr(cute.size(tSrS)): - tSrM[i] = cute.arch.warp_redux_sync(tSrS[i], kind="fmax", nan=True) + tSrM[i] = warp_fmax(tSrS[i]) if i == lane_idx: tSrM_lane = tSrM[i] # Reduce colmax in smem if lane_store_max: - self.smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane) + smem_fmax(tSsM.iterator + tSsM.layout(lane_idx), tSrM_lane) # Wait for colmax then load - cute.arch.barrier( - barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads - ) + softmax_nbar.arrive_and_wait() cute.autovec_copy(tSsM, tSrM) # Compute online softmax @@ -1504,7 +1529,7 @@ class MixedInputFusedMultiHeadAttentionDecode: # # Reduce colsum in warp RF - tSrL_lane = Float32(0.0) + tSrL_lane = cutlass.Float32(0.0) for i in cutlass.range_constexpr(cute.size(tSrL)): tSrL[i] = cute.arch.warp_reduction_sum(tSrL[i]) if i == lane_idx: @@ -1515,16 +1540,12 @@ class MixedInputFusedMultiHeadAttentionDecode: tSsL[cpy_dice + (warpgroup_widx,)][lane_idx] = tSrL_lane # Wait for colsum - cute.arch.barrier( - barrier_id=softmax_nbar_id, number_of_threads=warpgroup_threads - ) + softmax_nbar.arrive_and_wait() if warpgroup_widx == 0 and lane_store_max and inbound_hr: # Load colsum and colmax sL_lane_wg = sL[0, lane_idx, None] - sL_lane = ( - sL_lane_wg[0] + sL_lane_wg[1] + sL_lane_wg[2] + sL_lane_wg[3] - ) + sL_lane = sL_lane_wg[0] + sL_lane_wg[1] + sL_lane_wg[2] + sL_lane_wg[3] sM_lane = sM[0, lane_idx] # Scale colmax @@ -1533,7 +1554,7 @@ class MixedInputFusedMultiHeadAttentionDecode: # Store colsum and colmax gL_partial[lane_idx] = sL_lane gM_partial[lane_idx] = sM_lane - self.gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane) + gmem_fmax(gM.iterator + gM.layout(lane_idx), sM_lane) o_handle = o_consumer.wait_and_advance() cute.copy(thr_load_s, tStO, tSrO) @@ -1563,16 +1584,16 @@ class MixedInputFusedMultiHeadAttentionDecode: o_partial: cute.Tensor, m_partial: cute.Tensor, l_partial: cute.Tensor, - scale_o: Float32, + scale_o: cutlass.Float32, ): d_blk_idx, coord_h, coord_b = cute.arch.block_idx() d_per_blk, _, _ = cute.arch.block_dim() d_idx, _, _ = cute.arch.thread_idx() coord_d = d_blk_idx * d_per_blk + d_idx - o_dhb = Float32(0) + o_dhb = cutlass.Float32(0) m_hb = m[coord_h, coord_b] - l_hb = Float32(0) + l_hb = cutlass.Float32(0) o_partial_dhb = o_partial[coord_d, coord_h, coord_b, None] m_partial_hb = m_partial[coord_h, coord_b, None] @@ -1591,44 +1612,6 @@ class MixedInputFusedMultiHeadAttentionDecode: return - @staticmethod - @cute.jit - def smem_fmax(ptr: Pointer, val: Float32): - # https://stackoverflow.com/a/72461459 - # Works with canonical NaN which warp_redux_sync(kind="fmax") should return - llvm.inline_asm( - None, - [ptr.llvm_ptr, val.ir_value()], - """{\n\t - .reg .pred p;\n\t - setp.lt.s32 p, $1, 0x0; - @p red.relaxed.shared::cta.min.u32 [$0], $1;\n\t - @!p red.relaxed.shared::cta.max.s32 [$0], $1;\n\t - }\n\t""", - "r,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - @staticmethod - @cute.jit - def gmem_fmax(ptr: Pointer, val: Float32): - llvm.inline_asm( - None, - [ptr.llvm_ptr, val.ir_value()], - """{\n\t - .reg .pred p;\n\t - setp.lt.s32 p, $1, 0x0; - @p red.relaxed.gpu.global.min.u32 [$0], $1;\n\t - @!p red.relaxed.gpu.global.max.s32 [$0], $1;\n\t - }\n\t""", - "l,r", - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - def run( batches: int = 1, @@ -1638,10 +1621,10 @@ def run( headdim: int = 512, block_scaledim: int = 512, kv_splits: int = 0, - q_dtype: Type[cutlass.Numeric] = BFloat16, - kv_dtype: Type[cutlass.Numeric] = Int8, - o_dtype: Type[cutlass.Numeric] = BFloat16, - acc_dtype: Type[cutlass.Numeric] = Float32, + q_dtype: Type[cutlass.Numeric] = cutlass.BFloat16, + kv_dtype: Type[cutlass.Numeric] = cutlass.Int8, + o_dtype: Type[cutlass.Numeric] = cutlass.BFloat16, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, tolerance: float = 0.1, scale_q: float = 1.0, scale_o: float = 1.0, @@ -1652,7 +1635,7 @@ def run( use_cold_l2: bool = False, **kwargs, ): - print(f"Running Blackwell SM100 Mixed Input FMHA Decode test with:") + print("Running Blackwell SM100 Mixed Input FMHA Decode test with:") print(f"\tbatches: {batches}, seqlen: {seqlen}") print(f"\theads_q: {heads_q}, heads_k: {heads_k}") print(f"\theaddim: {headdim}, block_scaledim: {block_scaledim}") @@ -1709,9 +1692,7 @@ def run( problem_shape = (batches, heads_q, heads_k, seqlen_k, headdim) - fmha.can_implement( - problem_shape, kv_splits, q_dtype, kv_dtype, o_dtype, acc_dtype - ) + fmha.can_implement(problem_shape, kv_splits, q_dtype, kv_dtype, o_dtype, acc_dtype) # # Allocate Tensors @@ -1719,24 +1700,24 @@ def run( torch.manual_seed(1111) def create_tensor(shape, dtype, init=True): - init_type = cutlass.torch.TensorInitType.RANDOM - init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) + init_type = cutlass_torch.TensorInitType.RANDOM + init_config = cutlass_torch.RandomInitConfig(min_val=-2, max_val=2) if init is False or init is None: - init_type = cutlass.torch.TensorInitType.SKIP + init_type = cutlass_torch.TensorInitType.SKIP init_config = None elif isinstance(init, int) or isinstance(init, float): - init_type = cutlass.torch.TensorInitType.SCALAR - init_config = cutlass.torch.ScalarInitConfig(value=init) + init_type = cutlass_torch.TensorInitType.SCALAR + init_config = cutlass_torch.ScalarInitConfig(value=init) elif isinstance(init, tuple) or isinstance(init, list): if len(init) == 2: - init_type = cutlass.torch.TensorInitType.RANDOM - init_config = cutlass.torch.RandomInitConfig( + init_type = cutlass_torch.TensorInitType.RANDOM + init_config = cutlass_torch.RandomInitConfig( min_val=init[0], max_val=init[1] ) if len(init) == 3: - init_type = cutlass.torch.TensorInitType.GAUSSIAN - init_config = cutlass.torch.RandomInitConfig( + init_type = cutlass_torch.TensorInitType.GAUSSIAN + init_config = cutlass_torch.RandomInitConfig( mean=init[0], std=init[1], scale=init[2] ) @@ -1809,7 +1790,7 @@ def run( scale_qs, scale_o, current_stream, - options=f"--opt-level 2", + options="--opt-level 2", ) print("Finished Compiling") @@ -1980,7 +1961,7 @@ if __name__ == "__main__": def parse_comma_separated_ints(s: str): try: - return tuple(int(x.strip()) for x in s.split(",")) + return tuple(cutlass.int(x.strip()) for x in s.split(",")) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." @@ -2052,7 +2033,7 @@ if __name__ == "__main__": parser.add_argument( "--acc_dtype", type=cutlass.dtype, - default=Float32, + default=cutlass.Float32, help="accumulator/reduction data type", ) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_prefill_d256.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_prefill_d256.py index 07a2a73da..f92964837 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d256.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_prefill_d256.py @@ -47,10 +47,10 @@ from cutlass.cute.typing import Int32, Int64, Float32 if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) from helpers import fmha_helpers as fmha_utils -from blackwell.mixed_input_fmha import prefill_helpers as prefill_utils +from cute.blackwell.kernel.attention.mixed_input_fmha import prefill_helpers as prefill_utils class MixedInputFusedMultiHeadAttentionPrefillD256: diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_prefill_d512.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_prefill_d512.py index 89713fef5..a11032091 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/mixed_input_fmha_prefill_d512.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/mixed_input_fmha_prefill_d512.py @@ -49,10 +49,10 @@ from cutlass.cute.typing import Int32, Int64, Float32 if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) from helpers import fmha_helpers as fmha_utils -from blackwell.mixed_input_fmha import prefill_helpers as prefill_utils +from cute.blackwell.kernel.attention.mixed_input_fmha import prefill_helpers as prefill_utils class MixedInputFusedMultiHeadAttentionPrefillD512: diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_fmha/prefill_helpers.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/prefill_helpers.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/mixed_input_fmha/prefill_helpers.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mixed_input_fmha/prefill_helpers.py diff --git a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py index 58e4a7f85..8d75ffa67 100644 --- a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp16.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp16.py @@ -51,9 +51,9 @@ from cutlass.cutlass_dsl import BaseDSL if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) -from blackwell.mla.mla_helpers import ( +from cute.blackwell.kernel.attention.mla.mla_helpers import ( ceil_div, MAX_SPLITS, LOG2_E, diff --git a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py index 2ae63af0f..6693521cf 100644 --- a/examples/python/CuTeDSL/blackwell/mla/mla_decode_fp8.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_decode_fp8.py @@ -51,9 +51,9 @@ from cutlass.cutlass_dsl import BaseDSL if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../../..")) -from blackwell.mla.mla_helpers import ( +from cute.blackwell.kernel.attention.mla.mla_helpers import ( ceil_div, MAX_SPLITS, LOG2_E, diff --git a/examples/python/CuTeDSL/blackwell/mla/mla_helpers.py b/examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_helpers.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/mla/mla_helpers.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/attention/mla/mla_helpers.py diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent.py diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_amax.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent_amax.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_amax.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent_amax.py diff --git a/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_prefetch.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent_prefetch.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent_prefetch.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/dense_blockscaled_gemm_persistent_prefetch.py diff --git a/examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/sm103_dense_blockscaled_gemm_persistent.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/sm103_dense_blockscaled_gemm_persistent.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/sm103_dense_blockscaled_gemm_persistent.py diff --git a/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_grouped_gemm/grouped_blockscaled_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_grouped_gemm/grouped_blockscaled_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockwise_gemm/blockwise_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockwise_gemm/blockwise_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockwise_gemm/contiguous_grouped_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/blockwise_gemm/contiguous_grouped_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockwise_gemm/contiguous_grouped_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockwise_gemm/masked_grouped_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/blockwise_gemm/masked_grouped_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/blockwise_gemm/masked_grouped_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py b/examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_alpha_beta_persistent.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_gemm_alpha_beta_persistent.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_alpha_beta_persistent.py diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent.py diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py b/examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent_dynamic.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_gemm_persistent_dynamic.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent_dynamic.py diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_persistent_prefetch.py b/examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent_prefetch.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_gemm_persistent_prefetch.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_persistent_prefetch.py diff --git a/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_software_pipeline.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/dense_gemm/dense_gemm_software_pipeline.py diff --git a/examples/python/CuTeDSL/distributed/README.md b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/README.md similarity index 100% rename from examples/python/CuTeDSL/distributed/README.md rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/README.md diff --git a/examples/python/CuTeDSL/distributed/all_reduce_one_shot_lamport.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_one_shot_lamport.py similarity index 100% rename from examples/python/CuTeDSL/distributed/all_reduce_one_shot_lamport.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_one_shot_lamport.py diff --git a/examples/python/CuTeDSL/distributed/all_reduce_simple.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_simple.py similarity index 100% rename from examples/python/CuTeDSL/distributed/all_reduce_simple.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_simple.py diff --git a/examples/python/CuTeDSL/distributed/all_reduce_tma.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_tma.py similarity index 100% rename from examples/python/CuTeDSL/distributed/all_reduce_tma.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_tma.py diff --git a/examples/python/CuTeDSL/distributed/all_reduce_two_shot_multimem.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_two_shot_multimem.py similarity index 100% rename from examples/python/CuTeDSL/distributed/all_reduce_two_shot_multimem.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/all_reduce_two_shot_multimem.py diff --git a/examples/python/CuTeDSL/distributed/distributed_all_gather_gemm_blackwell.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/distributed_all_gather_gemm_blackwell.py similarity index 100% rename from examples/python/CuTeDSL/distributed/distributed_all_gather_gemm_blackwell.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/distributed_all_gather_gemm_blackwell.py diff --git a/examples/python/CuTeDSL/distributed/distributed_gemm_all_reduce_blackwell.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/distributed_gemm_all_reduce_blackwell.py similarity index 100% rename from examples/python/CuTeDSL/distributed/distributed_gemm_all_reduce_blackwell.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/distributed_gemm_all_reduce_blackwell.py diff --git a/examples/python/CuTeDSL/distributed/distributed_gemm_reduce_scatter_blackwell.py b/examples/python/CuTeDSL/cute/blackwell/kernel/distributed/distributed_gemm_reduce_scatter_blackwell.py similarity index 100% rename from examples/python/CuTeDSL/distributed/distributed_gemm_reduce_scatter_blackwell.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/distributed/distributed_gemm_reduce_scatter_blackwell.py diff --git a/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/grouped_gemm/grouped_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/grouped_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/grouped_gemm/grouped_gemm.py diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/grouped_mixed_input_gemm.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/grouped_mixed_input_gemm.py index acab8a761..c806d547b 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/grouped_mixed_input_gemm.py @@ -47,9 +47,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../..")) -from blackwell.mixed_input_gemm.mixed_input_host_utils import ( +from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import ( create_tensors_for_contiguous_grouped_mixed_input_gemm as create_tensors, run_contiguous_grouped_ref_and_compare as run_ref_and_compare, ) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py index 43322e283..22e06df11 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/grouped_mixed_input_gemm_acc_scale.py @@ -48,9 +48,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../..")) -from blackwell.mixed_input_gemm.mixed_input_host_utils import ( +from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import ( create_tensors_for_contiguous_grouped_mixed_input_gemm as create_tensors, run_contiguous_grouped_ref_and_compare as run_ref_and_compare, ) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py index ac0dd51f4..d4bd60ab0 100644 --- a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_gemm.py @@ -47,9 +47,9 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "../..")) + sys.path.insert(0, os.path.join(current_dir, "../../../..")) -from blackwell.mixed_input_gemm.mixed_input_host_utils import ( +from cute.blackwell.kernel.mixed_input_gemm.mixed_input_host_utils import ( create_tensors_for_batched_mixed_input_gemm as create_tensors, run_batched_mixed_input_ref_and_compare as run_ref_and_compare, ) diff --git a/examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_host_utils.py b/examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_host_utils.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_host_utils.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/mixed_input_gemm/mixed_input_host_utils.py diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_persistent_scheduler.py b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_persistent_scheduler.py new file mode 100644 index 000000000..6cf6c0c9a --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_persistent_scheduler.py @@ -0,0 +1,695 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +MoE Persistent Tile Scheduler + +A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations. +This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo +(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile. + +Scenarios: +- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate) +- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden) + +Key design principle: +- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic) +- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension +- TMA descriptor management is handled by OnlineTensormapDescCreator +- The kernel orchestrates all three components +""" + +from typing import List, Tuple, Literal + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import ( + Boolean, + Int32, + Integer, + extract_mlir_values, + new_from_mlir_values, + const_expr, + dsl_user_op, +) +from cutlass._mlir import ir + +# ============================================================================= +# Work Tile Info +# ============================================================================= + +class MoEWorkTileInfo: + """ + Work tile information for MoE scheduler. + + Contains CTA-level tile information for executor warps: + - expert_idx: Which expert (-1 means invalid/done) + - tile_m_idx: CTA tile index along GEMM M dimension + - tile_n_idx: CTA tile index along GEMM N dimension + - k_tile_cnt: Number of CTA tiles along K dimension + + Note: These are CTA-level indices, not cluster-level. + tile_l_idx is always 0 for MoE, executor can hardcode it. + + For 2Dx3D (Forward): + M = tokens_i (dynamic), N = intermediate (fixed), K = hidden (fixed) + + For 2Dx2D (Backward): + M = intermediate (fixed), N = hidden (fixed), K = tokens_i (dynamic) + """ + + def __init__( + self, + expert_idx: Int32, # -1 means invalid tile + tile_m_idx: Int32, + tile_n_idx: Int32, + k_tile_cnt: Int32, + ): + self.expert_idx = expert_idx + self.tile_m_idx = tile_m_idx + self.tile_n_idx = tile_n_idx + self.k_tile_cnt = k_tile_cnt + + @property + def is_valid_tile(self) -> Boolean: + """Check if this is a valid work tile (expert_idx >= 0).""" + return self.expert_idx >= Int32(0) + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = extract_mlir_values(self.expert_idx) + values.extend(extract_mlir_values(self.tile_m_idx)) + values.extend(extract_mlir_values(self.tile_n_idx)) + values.extend(extract_mlir_values(self.k_tile_cnt)) + return values + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo": + assert len(values) == 4 + return MoEWorkTileInfo( + expert_idx=new_from_mlir_values(self.expert_idx, [values[0]]), + tile_m_idx=new_from_mlir_values(self.tile_m_idx, [values[1]]), + tile_n_idx=new_from_mlir_values(self.tile_n_idx, [values[2]]), + k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [values[3]]), + ) + + def to_rmem_tensor(self): + """Pack work tile info fields into an rmem tensor of shape (4,) for vectorized smem copy.""" + rmem = cute.make_rmem_tensor((4,), Int32) + rmem[0] = self.expert_idx + rmem[1] = self.tile_m_idx + rmem[2] = self.tile_n_idx + rmem[3] = self.k_tile_cnt + return rmem + + @staticmethod + def from_rmem_tensor(rmem) -> "MoEWorkTileInfo": + """Unpack work tile info from an rmem tensor of shape (4,).""" + return MoEWorkTileInfo( + expert_idx=rmem[0], # type: ignore[arg-type] + tile_m_idx=rmem[1], # type: ignore[arg-type] + tile_n_idx=rmem[2], # type: ignore[arg-type] + k_tile_cnt=rmem[3], # type: ignore[arg-type] + ) + + +# ============================================================================= +# Scheduler Parameters +# ============================================================================= + +class MoEStaticSchedulerParams: + """ + Parameters for MoE tile scheduler. + + Uses unified semantics for both scenarios: + - expert_shape: (expert_cnt, intermediate, hidden) + + For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert + For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert + + Tile hierarchy: + - cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k) + - cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n) + - cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape + + This class is used both on host (for grid shape calculation) and on device + (stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk, + cluster_shape_mn) are NOT serialized to MLIR values. + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + expert_shape: Tuple[int | Int32, int | Int32, int | Int32], # (expert_cnt, intermediate, hidden) + cta_tile_shape_mnk: Tuple[int, int, int], # (tile_m, tile_n, tile_k) + cluster_shape_mn: Tuple[int, int], # (cluster_m, cluster_n) + ): + self.scenario = scenario + e, i, h = expert_shape + self.expert_cnt = e if isinstance(e, Int32) else Int32(e) + self.intermediate = i if isinstance(i, Int32) else Int32(i) + self.hidden = h if isinstance(h, Int32) else Int32(h) + self.cta_tile_shape_mnk = cta_tile_shape_mnk + self.cluster_shape_mn = cluster_shape_mn + + @property + def cluster_tile_m(self) -> int: + """Cluster tile size along M = cta_tile_m * cluster_m.""" + return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0] + + @property + def cluster_tile_n(self) -> int: + """Cluster tile size along N = cta_tile_n * cluster_n.""" + return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1] + + @property + def cta_tile_k(self) -> int: + """CTA tile size along K (same as cluster since cluster_k = 1).""" + return self.cta_tile_shape_mnk[2] + + def __extract_mlir_values__(self) -> List[ir.Value]: + """Only serialize runtime values, not codegen-time constants.""" + values = [] + values.extend(extract_mlir_values(self.expert_cnt)) + values.extend(extract_mlir_values(self.intermediate)) + values.extend(extract_mlir_values(self.hidden)) + return values + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams": + assert len(values) == 3 + return MoEStaticSchedulerParams( + scenario=self.scenario, + expert_shape=( + new_from_mlir_values(self.expert_cnt, [values[0]]), + new_from_mlir_values(self.intermediate, [values[1]]), + new_from_mlir_values(self.hidden, [values[2]]), + ), + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + ) + + @staticmethod + def get_grid_shape( + params: "MoEStaticSchedulerParams", + max_active_clusters: int, + ) -> Tuple[int, int, int]: + """ + Compute grid shape for kernel launch. + + Since host doesn't know token distribution across experts, + we launch max_active_clusters and let device-side scheduler + determine which tiles are valid. + """ + return ( + params.cluster_shape_mn[0], + params.cluster_shape_mn[1], + max_active_clusters, + ) + + +# ============================================================================= +# Scheduler (Device-side) +# ============================================================================= + +class MoEStaticPersistentTileScheduler: + """ + Persistent tile scheduler specialized for MoE grouped GEMM. + + This scheduler is ONLY responsible for tile iteration. It does NOT know + about tensor types, TMA descriptors, or domain conversion. Those concerns + are handled by MoESchedExtension and OnlineTensormapDescCreator respectively. + + Architecture: + - Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info + - Executor warps: Read work_tile_info from smem, use MoESchedExtension for + domain conversion and TMA desc selection + + The scheduler handles: + - 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden) + - 2Dx2D: Fixed M (intermediate) and N (hidden), dynamic K per expert (reduction axis) + + Usage (Scheduler warp): + scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim) + work_tile_info = scheduler.initial_work_tile_info() + # Broadcast work_tile_info to smem... + + while work_tile_info.is_valid_tile: + # ... do work ... + work_tile_info = scheduler.advance_to_next_work() + # Broadcast work_tile_info to smem... + + Usage (Executor warps - via MoESchedExtension): + # Read work_tile_info from smem... + real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info) + real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info) + real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info) + """ + + def __init__( + self, + # Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes) + params: MoEStaticSchedulerParams, + # Runtime tensor for scheduling + offs: cute.Tensor, # (experts,) cumsum of token counts + # Scheduling state + num_persistent_clusters: Int32, + current_work_linear_idx: Int32, + cta_id_in_cluster: cute.Coord, + # Expert tracking state (for O(1) advance within same expert) + current_expert_idx: Int32, + expert_tile_start: Int32, # cumsum of tiles before current expert + expert_tile_end: Int32, # cumsum of tiles including current expert + ): + self.params = params + self.offs = offs + self.num_persistent_clusters = num_persistent_clusters + self._current_work_linear_idx = current_work_linear_idx + self.cta_id_in_cluster = cta_id_in_cluster + # Expert tracking + self.current_expert_idx = current_expert_idx + self.expert_tile_start = expert_tile_start + self.expert_tile_end = expert_tile_end + + # ========================================================================= + # Convenience accessors for params + # ========================================================================= + + @property + def scenario(self) -> Literal["2Dx3D", "2Dx2D"]: + return self.params.scenario + + @property + def expert_cnt(self) -> Int32: + return self.params.expert_cnt + + @property + def intermediate(self) -> Int32: + return self.params.intermediate + + @property + def hidden(self) -> Int32: + return self.params.hidden + + @property + def cta_tile_shape_mnk(self) -> Tuple[int, int, int]: + return self.params.cta_tile_shape_mnk + + @property + def cluster_shape_mn(self) -> Tuple[int, int]: + return self.params.cluster_shape_mn + + @property + def cluster_tile_m(self) -> int: + return self.params.cluster_tile_m + + @property + def cluster_tile_n(self) -> int: + return self.params.cluster_tile_n + + @property + def cta_tile_k(self) -> int: + return self.params.cta_tile_k + + # ========================================================================= + # MLIR value serialization (for SSA value passing in device code) + # ========================================================================= + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = [] + # Params (only runtime values are extracted) + values.extend(extract_mlir_values(self.params)) + # Runtime tensor for scheduling + values.extend(extract_mlir_values(self.offs)) + # Scheduling state + values.extend(extract_mlir_values(self.num_persistent_clusters)) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self.cta_id_in_cluster)) + # Expert tracking state + values.extend(extract_mlir_values(self.current_expert_idx)) + values.extend(extract_mlir_values(self.expert_tile_start)) + values.extend(extract_mlir_values(self.expert_tile_end)) + return values + + def __new_from_mlir_values__( + self, values: List[ir.Value] + ) -> "MoEStaticPersistentTileScheduler": + idx = 0 + + # Params (3 values: expert_cnt, intermediate, hidden) + new_params = new_from_mlir_values(self.params, values[idx:idx + 3]) + idx += 3 + + # Runtime tensor for scheduling (variable size) + offs_len = len(extract_mlir_values(self.offs)) + new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len]) + idx += offs_len + + # Scheduling state + new_num_persistent_clusters = new_from_mlir_values( + self.num_persistent_clusters, [values[idx]] + ) + idx += 1 + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[idx]] + ) + idx += 1 + + # cta_id_in_cluster (3 values for Coord) + new_cta_id_in_cluster = new_from_mlir_values( + self.cta_id_in_cluster, values[idx:idx + 3] + ) + idx += 3 + + # Expert tracking state + new_current_expert_idx = new_from_mlir_values( + self.current_expert_idx, [values[idx]] + ) + idx += 1 + new_expert_tile_start = new_from_mlir_values( + self.expert_tile_start, [values[idx]] + ) + idx += 1 + new_expert_tile_end = new_from_mlir_values( + self.expert_tile_end, [values[idx]] + ) + idx += 1 + + return MoEStaticPersistentTileScheduler( + params=new_params, + offs=new_offs, + num_persistent_clusters=new_num_persistent_clusters, + current_work_linear_idx=new_current_work_linear_idx, + cta_id_in_cluster=new_cta_id_in_cluster, + current_expert_idx=new_current_expert_idx, + expert_tile_start=new_expert_tile_start, + expert_tile_end=new_expert_tile_end, + ) + + # ========================================================================= + # Factory method + # ========================================================================= + + @staticmethod + @dsl_user_op + def create( + params: MoEStaticSchedulerParams, + offs: cute.Tensor, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ) -> "MoEStaticPersistentTileScheduler": + """ + Create a MoE persistent tile scheduler. + + :param params: Scheduler parameters (from host) + :param offs: Cumsum tensor of token counts per expert, shape (experts,) + :param block_idx: CUDA block index + :param grid_dim: CUDA grid dimensions + """ + num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size( + params.cluster_shape_mn, loc=loc, ip=ip + ) + + bidx, bidy, bidz = block_idx + current_work_linear_idx = Int32(bidz) + + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + + # Initialize expert tracking to "before expert 0" + # The first call to _get_work_tile_for_linear_idx will advance to the correct expert + current_expert_idx = Int32(0) + expert_tile_start = Int32(0) + expert_tile_end = Int32(0) # Will be computed on first access + + return MoEStaticPersistentTileScheduler( + params=params, + offs=offs, + num_persistent_clusters=num_persistent_clusters, + current_work_linear_idx=current_work_linear_idx, + cta_id_in_cluster=cta_id_in_cluster, + current_expert_idx=current_expert_idx, + expert_tile_start=expert_tile_start, + expert_tile_end=expert_tile_end, + ) + + # ========================================================================= + # Tile iteration methods + # ========================================================================= + + @dsl_user_op + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo: + """Get the initial work tile info.""" + return self._get_work_tile_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + @cute.jit + def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo: + """Advance to the next work tile and return its info.""" + self._current_work_linear_idx += self.num_persistent_clusters + return self._get_work_tile_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + @cute.jit + def _get_work_tile_for_linear_idx( + self, + cluster_linear_idx: Int32, + *, + loc=None, + ip=None + ) -> MoEWorkTileInfo: + """ + Convert a linear cluster index to MoEWorkTileInfo. + + Uses cached expert tracking state for O(1) fast path when staying + within the same expert. Advances expert state when needed. + + Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range. + """ + # Ensure expert tracking is initialized and up-to-date + self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip) + + # Check if valid (still within expert range after advancing) + is_valid = self.current_expert_idx < self.expert_cnt + + work_tile_info = MoEWorkTileInfo( + expert_idx=Int32(-1), + tile_m_idx=Int32(0), + tile_n_idx=Int32(0), + k_tile_cnt=Int32(0), + ) + + if is_valid: + # Compute local cluster tile indices within current expert + local_idx = cluster_linear_idx - self.expert_tile_start + cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx( + local_idx, self.current_expert_idx, loc=loc, ip=ip + ) + + # Convert cluster tile indices to CTA tile indices + # cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster + cta_tile_m_idx = ( + cluster_tile_m_idx * self.cluster_shape_mn[0] + + self.cta_id_in_cluster[0] # type: ignore[index] + ) + cta_tile_n_idx = ( + cluster_tile_n_idx * self.cluster_shape_mn[1] + + self.cta_id_in_cluster[1] # type: ignore[index] + ) + # Compute k_tile_cnt + k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip) + + work_tile_info = MoEWorkTileInfo( + expert_idx=self.current_expert_idx, + tile_m_idx=cta_tile_m_idx, + tile_n_idx=cta_tile_n_idx, + k_tile_cnt=k_tile_cnt, + ) + return work_tile_info + + @dsl_user_op + @cute.jit + def _advance_expert_to_contain( + self, + cluster_linear_idx: Int32, + *, + loc=None, + ip=None, + ) -> None: + """ + Advance expert tracking state until current expert contains cluster_linear_idx, + or we run out of experts. + + Fast path: If already in correct expert, no work needed. + """ + # Initialize expert_tile_end if this is the first call (expert_tile_end == 0) + if self.expert_tile_end == Int32(0): + tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip) + self.expert_tile_end = tiles_for_expert_0 + + # Advance until cluster_linear_idx < expert_tile_end or no more experts + while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt: + self.current_expert_idx = self.current_expert_idx + 1 + self.expert_tile_start = self.expert_tile_end + + if self.current_expert_idx < self.expert_cnt: + tiles_for_expert = self._compute_tiles_for_expert( + self.current_expert_idx, loc=loc, ip=ip + ) + self.expert_tile_end = self.expert_tile_end + tiles_for_expert + + @dsl_user_op + @cute.jit + def _compute_tiles_for_expert( + self, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Int32: + """Compute total cluster tiles for a given expert.""" + if const_expr(self.scenario == "2Dx2D"): + # Fixed M=hidden, N=intermediate + cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m + cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n + return cluster_tile_m_cnt * cluster_tile_n_cnt + else: # 2Dx3D + # Variable M (tokens), fixed N + tokens_i = self.offs[expert_idx] + if expert_idx > 0: + tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator] + cluster_tile_m_cnt = ( + tokens_i + self.cluster_tile_m - 1 # type: ignore[operator] + ) // self.cluster_tile_m + cluster_tile_n_cnt = ( + self.intermediate + self.cluster_tile_n - 1 + ) // self.cluster_tile_n + return cluster_tile_m_cnt * cluster_tile_n_cnt + + @dsl_user_op + @cute.jit + def _decompose_local_idx( + self, + local_idx: Int32, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32]: + """ + Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx). + + Uses "short side first" strategy: the shorter dimension changes faster. + This maximizes overlap between adjacent clusters for better L2 cache utilization. + + For example, if m_cnt=2, n_cnt=8: + - N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx + - Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ... + """ + # Get tile counts for M and N + cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts( + expert_idx, loc=loc, ip=ip + ) + cluster_tile_m_idx = -1 + cluster_tile_n_idx = -1 + + # Short side first: shorter dimension changes faster + # If m_cnt <= n_cnt: m is shorter, m changes faster + # local_idx = n_idx * m_cnt + m_idx + # If n_cnt < m_cnt: n is shorter, n changes faster + # local_idx = m_idx * n_cnt + n_idx + if cluster_tile_m_cnt <= cluster_tile_n_cnt: + # M is shorter or equal, M changes faster + cluster_tile_m_idx = local_idx % cluster_tile_m_cnt + cluster_tile_n_idx = local_idx // cluster_tile_m_cnt + else: + # N is shorter, N changes faster + cluster_tile_n_idx = local_idx % cluster_tile_n_cnt + cluster_tile_m_idx = local_idx // cluster_tile_n_cnt + + return (cluster_tile_m_idx, cluster_tile_n_idx) + + @dsl_user_op + @cute.jit + def _get_cluster_tile_counts( + self, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32]: + """Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert.""" + if const_expr(self.scenario == "2Dx2D"): + # Fixed M=hidden, N=intermediate + cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m + cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n + else: # 2Dx3D + # Variable M (tokens), fixed N + tokens_i = self.offs[expert_idx] + if expert_idx > 0: + tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator] + cluster_tile_m_cnt = ( + tokens_i + self.cluster_tile_m - 1 # type: ignore[operator] + ) // self.cluster_tile_m + cluster_tile_n_cnt = ( + self.intermediate + self.cluster_tile_n - 1 + ) // self.cluster_tile_n + return (cluster_tile_m_cnt, cluster_tile_n_cnt) + + @dsl_user_op + @cute.jit + def _compute_k_tile_cnt( + self, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Int32: + """ + Compute the number of K tiles for this expert. + + 2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k) + 2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k) + """ + if const_expr(self.scenario == "2Dx3D"): + # K is hidden (fixed) + return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k + else: # 2Dx2D + # K is tokens_i (variable per expert) + tokens_i = self.offs[expert_idx] + if expert_idx > cutlass.Int32(0): + tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator] + return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k # type: ignore[return-value, operator] diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_sched_extension.py b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_sched_extension.py new file mode 100644 index 000000000..e07ce22b2 --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_sched_extension.py @@ -0,0 +1,443 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +MoE Scheduler Extension. + +Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level +domain conversion and TMA descriptor selection. This is the "glue" layer between: + +- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt) +- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace +- Kernel: orchestrates everything + +Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own +MoESchedExtension subclass with kernel-specific domain conversion logic. + +Key design principles: +- Unified interface: get_gmem_tensor() for all tensor types +- Free implementation: no role-based templates, each subclass writes its own logic +- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc. + are available as tools but not mandatory + +Architecture: + + Scheduler ──(produces)──> MoEWorkTileInfo + │ + expert_idx, tile_m, tile_n, k_cnt + │ + v + Extension ──(uses)──> OnlineTensormapDescCreator + │ │ + │ get_gmem_tensor() │ get_desc_ptr() + │ prefetch_for_expert() │ construct_and_write() + │ │ + └── internal calls ───────┘ + + Kernel (caller): the only place that knows all three exist +""" + +from abc import ABC, abstractmethod +from typing import Literal, Tuple, Union + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Pointer +from cutlass.cutlass_dsl import Int32 + +from dataclasses import dataclass + +from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF +from blackwell.kernel.moe.moe_utils import ( + OnlineTensormapDescCreator, + tensormap_ptr_for_copy, + compute_expert_token_range, + rewrite_tensor_shape, + prefetch_tma_descriptor, +) +from blackwell.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo + + +@dataclass(frozen=True) +class MoESchedExtension(ABC): + """ + Abstract base class for MoE scheduler extensions. + + Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA + descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.) + provides its own subclass with kernel-specific logic. + + The extension: + - Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval + - Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors + - Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors + + Subclasses are free to add any additional attributes in __init__ (scenario, + codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic + per tensor_name. No role-based templates or rigid patterns are imposed. + + Usage in kernel (caller): + ext = ConcreteSchedExtension(tensormap_ctor, scenario=...) + + while work_tile_info.is_valid_tile: + real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info) + real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info) + # Use real_a, desc_a in cute.copy ... + """ + + def __init__(self, tensormap_ctor: OnlineTensormapDescCreator): + super().__init__() + self.tensormap_ctor = tensormap_ctor + + @abstractmethod + def get_gmem_tensor( + self, + tensor_name: str, + gmem_tensor_in_moe_view: cute.Tensor, + offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]], + work_tile_info: MoEWorkTileInfo, + ) -> Tuple[cute.Tensor, "Pointer | None"]: + """ + Convert an MoE-view tensor to the real per-expert tensor for the + current work tile, and return the appropriate TMA descriptor pointer. + + The MoE-view tensor uses "fake" GEMM domain dimensions that span all + experts (e.g., fake_m = tokens_sum). This method slices/offsets it + to the current expert's actual region. + + :param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa") + :param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain + :param offs: Either a single cumsum tensor (experts,), or a tuple of + (offs_token, offs_padded) where offs_padded provides + padded offsets for scale-factor domain conversion. + :param work_tile_info: Current work tile from the scheduler + :return: (real_tensor, tma_desc_ptr_or_none) + - real_tensor: domain-offset and shape-rewritten tensor for this expert + - tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy), + or None if the caller should use the global TMA descriptor + """ + ... + + @abstractmethod + def prefetch_for_expert(self, expert_idx: Int32) -> None: + """ + Prefetch expert-wise TMA descriptors for the given expert. + + Called when the scheduler advances to a new expert, allowing the TMA + descriptor cache to be warmed up before the descriptors are needed. + + :param expert_idx: Index of the expert whose descriptors to prefetch + """ + ... + + +# ============================================================================= +# Grouped MM Extension +# ============================================================================= + + +class GroupedMmSchedExtension(MoESchedExtension): + """ + MoE scheduler extension for grouped_mm: handles tensors a, b, c. + + Domain conversion logic per scenario: + + 2Dx3D: + A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc + B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc + C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc + + 2Dx2D: + A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc + B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc + C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + tensormap_ctor: OnlineTensormapDescCreator, + ): + super().__init__(tensormap_ctor) + self.scenario = scenario + + @cute.jit + def get_gmem_tensor( + self, + tensor_name: str, + gmem_tensor_in_moe_view: cute.Tensor, + offs: cute.Tensor, + work_tile_info: MoEWorkTileInfo, + ): + expert_idx = work_tile_info.expert_idx + token_offset, tokens_i = compute_expert_token_range(offs, expert_idx) + + shape = gmem_tensor_in_moe_view.shape + c1 = cutlass.Int32(1) + + if cutlass.const_expr(self.scenario == "2Dx3D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (fake_m, k, 1) -> offset fake_m, global desc + real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index] + return (real, None) + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, k, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + elif cutlass.const_expr(tensor_name == "c"): + # C: (fake_m, n, 1) -> expert-wise desc, no offset + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (tokens_i, shape[1], c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("c", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(self.scenario == "2Dx2D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (m, fake_k, 1) -> expert-wise desc, no offset + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("a", expert_idx) + ) + return (real, desc) + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, fake_k, 1) -> expert-wise desc, no offset + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("b", expert_idx) + ) + return (real, desc) + elif cutlass.const_expr(tensor_name == "c"): + # C: (m, n, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + + raise ValueError("Invalid scenario or GEMM tensor name.") + + @cute.jit + def prefetch_for_expert(self, expert_idx: Int32) -> None: + if cutlass.const_expr(self.scenario == "2Dx3D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx)) + elif cutlass.const_expr(self.scenario == "2Dx2D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx)) + else: + raise ValueError("Invalid scenario.") + + + +# ============================================================================= +# Scaled Grouped MM Extension (block-scaled MoE) +# ============================================================================= + + +class ScaledGroupedMmSchedExtension(MoESchedExtension): + """ + MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb. + + Extends GroupedMmSchedExtension with scale-factor tensor support. + SFA/SFB are passed as flat GEMM-domain tensors and atom-tiled per expert + via tile_atom_to_shape_SF. + + The offs parameter is always a tuple (offs_token, offs_padded): + - offs_token: cumsum offsets in data (activation) domain + - offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity) + + sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size. + + Domain conversion logic per scenario: + + 2Dx3D: + A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc + B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc + C: (fake_m, n, 1) -> rewrite shape, expert-wise desc + SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset, + atom-tile, global desc + SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx, + atom-tile, global desc + + 2Dx2D: + A: (m, fake_k, 1) -> rewrite shape, expert-wise desc + B: (n, fake_k, 1) -> rewrite shape, expert-wise desc + C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc + SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset, + atom-tile, expert-wise desc + SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset, + atom-tile, expert-wise desc + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + tensormap_ctor: OnlineTensormapDescCreator, + ): + super().__init__(tensormap_ctor) + self.scenario = scenario + + @cute.jit + def get_gmem_tensor( + self, + tensor_name: str, + gmem_tensor_in_moe_view: cute.Tensor, + offs: Tuple[cute.Tensor, cute.Tensor], + work_tile_info: MoEWorkTileInfo, + ): + # Unpack the offs tuple + offs_token, offs_padded = offs + + expert_idx = work_tile_info.expert_idx + token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx) + padded_offset, padded_size_i = compute_expert_token_range( + offs_padded, expert_idx + ) + + shape = gmem_tensor_in_moe_view.shape + stride = gmem_tensor_in_moe_view.stride + c1 = cutlass.Int32(1) + sf_vec_size = self.tensormap_ctor.sf_vec_size + + if cutlass.const_expr(self.scenario == "2Dx3D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (fake_m, k, 1) -> offset fake_m, global desc + real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index] + return (real, None) + + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, k, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + + elif cutlass.const_expr(tensor_name == "c"): + # C: (fake_m, n, 1) -> expert-wise desc + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (tokens_i, shape[1], c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("c", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "sfa"): + # SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc + real = cute.domain_offset( + (padded_offset, 0, 0), gmem_tensor_in_moe_view + ) + per_expert_shape = (padded_size_i, shape[1], c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = cute.make_tensor( + real.iterator, cute.make_layout(sf_layout.shape, stride=stride) + ) + return (real, None) + + elif cutlass.const_expr(tensor_name == "sfb"): + # SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = cute.make_tensor( + real.iterator, cute.make_layout(sf_layout.shape, stride=stride) + ) + return (real, None) + + elif cutlass.const_expr(self.scenario == "2Dx2D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (m, fake_k, 1) -> expert-wise desc + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("a", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, fake_k, 1) -> expert-wise desc + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("b", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "c"): + # C: (m, n, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + + elif cutlass.const_expr(tensor_name == "sfa"): + # SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc + per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("sfa", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "sfb"): + # SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc + per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("sfb", expert_idx) + ) + return (real, desc) + + raise ValueError("Invalid scenario or tensor name.") + + @cute.jit + def prefetch_for_expert(self, expert_idx: Int32) -> None: + if cutlass.const_expr(self.scenario == "2Dx3D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx)) + elif cutlass.const_expr(self.scenario == "2Dx2D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)) + else: + raise ValueError("Invalid scenario.") diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_utils.py b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_utils.py new file mode 100644 index 000000000..e21d0389f --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/moe_utils.py @@ -0,0 +1,910 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Online TMA Descriptor Construction Utilities. + +Provides utilities for dynamically creating TMA descriptors at kernel runtime +based on runtime-provided information (problem sizes, pointers, etc.). + +Key components: +- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods) +- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors +- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM +- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM +- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.) +- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type +- compute_expert_token_range: Compute per-expert token offset and count from offs +- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility +""" + +from abc import ABC, abstractmethod +from typing import Optional, Literal, Tuple, Union + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import AddressSpace, Pointer +from cutlass.cute.nvgpu import cpasync +from cutlass.cutlass_dsl import dsl_user_op, Int32 +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm +from cutlass._mlir.dialects import cute as _cute_ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir +from dataclasses import dataclass + +from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF + +TensormapDescBytes = 128 + +# ============================================================================= +# Pointer Utilities +# ============================================================================= + + +@dsl_user_op +@cute.jit +def spin_wait( + ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None +) -> None: + """ + Generic spin-wait. + Example usage: + # Wait until counter >= total_blocks + spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100) + + # Wait until flag == 1 + spin_wait(flag_ptr, lambda x: x == 1) + """ + current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip) + while not condition(current): + # Load with L1 cache bypass (ld.global.cg) + if cutlass.const_expr(fail_sleep_cycles > 0): + cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip) + current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip) + + +@dsl_user_op +def gmem_ptr_to_generic( + gmem_ptr: Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Pointer: + if gmem_ptr.memspace != AddressSpace.gmem: + raise ValueError( + f"gmem_ptr_to_generic requires pointer in gmem address space, " + f"got {gmem_ptr.memspace}" + ) + # Get LLVM pointer and cast to generic address space + llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip) + generic_llvm_ptr = llvm.addrspacecast( + llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip + ) + # Create a new cute.Pointer with generic address space, preserving alignment + return cute.make_ptr( + gmem_ptr.dtype, + generic_llvm_ptr, + AddressSpace.generic, + assumed_align=gmem_ptr.alignment, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def generic_ptr_to_gmem( + generic_ptr: Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Pointer: + if generic_ptr.memspace != AddressSpace.generic: + raise ValueError( + f"generic_ptr_to_gmem requires pointer in generic address space, " + f"got {generic_ptr.memspace}" + ) + # Get LLVM pointer and cast to gmem address space + llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip) + gmem_llvm_ptr = llvm.addrspacecast( + llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip + ) + # Create a new cute.Pointer with gmem address space, preserving alignment + return cute.make_ptr( + generic_ptr.dtype, + gmem_llvm_ptr, + AddressSpace.gmem, + assumed_align=generic_ptr.alignment, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None: + """ + Prefetch a TMA descriptor from global memory. + + This function prefetches the TMA descriptor pointed to by tma_desc_ptr + into the TMA descriptor cache. The pointer must be in generic or global + address space. If a gmem pointer is passed, it will be automatically + converted to generic address space. + + :param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory + :type tma_desc_ptr: Pointer + :raises ValueError: If pointer is not in generic or global address space + """ + if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic): + raise ValueError( + f"prefetch_tma_descriptor requires pointer in gmem or generic address space, " + f"got {tma_desc_ptr.memspace}" + ) + # Convert gmem pointer to generic if needed + if tma_desc_ptr.memspace == AddressSpace.gmem: + tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip) + # Convert cute.Pointer to LLVM pointer for prefetch + llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip) + from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch + + nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip) + + +def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer: + """Offset a pointer by a given number of bytes.""" + element_offset = byte_offset * 8 // ptr.dtype.width + return ptr + element_offset + + +@dsl_user_op +def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer: + """ + Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy. + + cute.copy requires the tma_desc_ptr to be in generic address space and + recast to TmaDescriptorTiledType. This utility performs both conversions. + + :param raw_ptr: Raw pointer to TMA descriptor in gmem + :type raw_ptr: Pointer + :return: Pointer compatible with cute.copy's tma_desc_ptr parameter + :rtype: Pointer + """ + generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip) + tma_desc_ptr_ty = _cute_ir.PtrType.get( + _cute_nvgpu_ir.TmaDescriptorTiledType.get(), + generic_ptr.memspace, + generic_ptr.alignment, + ) + return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value) + + +# ============================================================================= +# MoE Utilities +# ============================================================================= + + +@dsl_user_op +@cute.jit +def compute_expert_token_range( + offs: cute.Tensor, + expert_idx: Int32, + *, + loc=None, + ip=None, +) -> Tuple[Int32, Int32]: + """ + Compute token offset and count for a given expert from the cumsum offs tensor. + + :param offs: Cumulative sum tensor of token counts per expert, shape (experts,) + :param expert_idx: Index of the expert + :return: (token_offset, tokens_i) where token_offset is the start position + and tokens_i is the number of tokens for this expert + """ + token_offset = Int32(0) + if expert_idx > Int32(0): + token_offset = offs[expert_idx - 1] # type: ignore[assignment] + tokens_i = offs[expert_idx] - token_offset + return token_offset, tokens_i + + +@dsl_user_op +def rewrite_tensor_shape( + tensor: cute.Tensor, + new_shape: Tuple, + *, + loc=None, + ip=None, +) -> cute.Tensor: + """ + Rewrite tensor shape while keeping the same stride and iterator. + + This is primarily for debug friendliness - shows the actual expert's shape + instead of the fake global shape. No runtime overhead as it becomes + dead code in non-debug builds. + + :param tensor: Source tensor whose stride and iterator to preserve + :param new_shape: New shape to apply + :return: New tensor with the given shape but original stride and iterator + """ + new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip) + return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip) + + +# ============================================================================= +# TMA Descriptor Workspace Helper +# ============================================================================= + + +class TensormapWorkspace: + """ + Helper for linear workspace layout of TMA descriptors. + + Manages address calculation for a workspace buffer containing TMA descriptors + organized as: for each executor (e.g., expert or group), a fixed set of + named descriptor slots. + + Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...] + + Example: + # 2Dx3D MoE: only C is expert-wise + workspace = TensormapWorkspace(workspace_ptr, ["c"]) + + # 2Dx2D MoE: A and B are expert-wise + workspace = TensormapWorkspace(workspace_ptr, ["a", "b"]) + + # General grouped GEMM: all three tensors + workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"]) + """ + + def __init__(self, workspace_ptr: Pointer, slot_names: list): + """ + :param workspace_ptr: Pointer to the beginning of the workspace buffer + :param slot_names: Ordered list of tensor names, defining the slot layout + per executor. e.g., ["a", "b", "c"] + """ + self.workspace_ptr = workspace_ptr + self._name_to_slot = {name: i for i, name in enumerate(slot_names)} + self._slots_per_executor = len(slot_names) + + @cute.jit + def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + """ + Get the workspace pointer for a specific TMA descriptor. + + :param tensor_name: Name of the tensor (must be one of the slot_names) + :param executor_idx: Index of the executor (e.g., group_idx or expert_idx) + :return: Aligned pointer to the TMA descriptor in workspace + """ + if cutlass.const_expr(tensor_name not in self._name_to_slot): + raise ValueError( + f"Invalid tensor_name '{tensor_name}', " + f"expected one of {list(self._name_to_slot.keys())}" + ) + slot = self._name_to_slot[tensor_name] + byte_offset = ( + executor_idx * self._slots_per_executor + slot + ) * TensormapDescBytes + return ptr_offset_bytes(self.workspace_ptr, byte_offset).align( + TensormapDescBytes + ) + + @staticmethod + def size_bytes(num_slots: int, num_executors: int) -> int: + """ + Calculate workspace size in bytes. + + :param num_slots: Number of descriptor slots per executor + :param num_executors: Total number of executors (e.g., expert_cnt or group_cnt) + :return: Total workspace size in bytes + """ + return num_slots * num_executors * TensormapDescBytes + + +# ============================================================================= +# Online TMA Descriptor Creator (Abstract Base Class) +# ============================================================================= + + +@dataclass(frozen=True) +class OnlineTensormapDescCreator(ABC): + """ + Abstract base class for building TMA descriptors online (at kernel runtime). + + Subclasses store all needed parameters (both codegen-time configs and runtime + values) as explicit instance attributes in __init__. No dict-based APIs. + + Subclasses must implement exactly 2 abstract methods: + - construct_and_write: Build TMA descriptor(s) for one executor and write to workspace + - get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace + + To convert the raw pointer for use with cute.copy, callers should use the + standalone tensormap_ptr_for_copy() utility. + """ + + @abstractmethod + def construct_and_write(self, executor_idx: Int32, dependency=None) -> None: + """ + Build TMA descriptor(s) for one executor and write to workspace. + + :param executor_idx: Index of the executor (e.g., group_idx or expert_idx). + Semantics may vary by subclass when ``dependency`` is provided. + :param dependency: Optional pipeline consumer for inter-warp-group + synchronization. When provided, the subclass decides when to wait + (via ``dependency.wait_and_advance()``) and release. The subclass + also decides how to interpret ``executor_idx`` in this mode. + """ + ... + + @abstractmethod + def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + """ + Get the raw gmem pointer to a specific TMA descriptor in workspace. + + :param tensor_name: Name identifying which tensor's descriptor + :param executor_idx: Index of the executor (e.g., group_idx or expert_idx) + :return: Raw pointer (gmem) to the TMA descriptor + """ + ... + + +# ============================================================================= +# MoE Grouped GEMM Tensormap Constructor +# ============================================================================= + + +class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator): + """ + Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only). + + Non-expert-wise descriptors are passed directly at kernel launch. + This class only handles: + - 2Dx3D: C descriptors (expert-wise, to avoid write conflicts) + - 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis) + + All parameters are stored as explicit instance attributes (no dicts). + + Workspace layout: + - 2Dx3D: [C_0, C_1, ..., C_{n-1}] + - 2Dx2D: [A_0, A_1, ..., A_{n-1}, B_0, B_1, ..., B_{n-1}] + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + # Codegen-time configs + a_dtype, + b_dtype, + c_dtype, + a_smem_layout, + b_smem_layout, + epi_smem_layout, + a_tma_op, + b_tma_op, + c_tma_op, + tiled_mma, + mma_tiler, + cluster_layout_vmnk_shape, + epi_tile, + # Runtime params + a_tensor: cute.Tensor, # fake GEMM domain A + b_tensor: cute.Tensor, # fake GEMM domain B + c_tensor: cute.Tensor, # fake GEMM domain C + offs: cute.Tensor, # (experts,) cumsum + workspace_ptr: Pointer, + ) -> None: + super().__init__() + self.scenario = scenario + # Codegen-time configs + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.c_dtype = c_dtype + self.a_smem_layout = a_smem_layout + self.b_smem_layout = b_smem_layout + self.epi_smem_layout = epi_smem_layout + self.a_tma_op = a_tma_op + self.b_tma_op = b_tma_op + self.c_tma_op = c_tma_op + self.tiled_mma = tiled_mma + self.mma_tiler = mma_tiler + self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape + self.epi_tile = epi_tile + # Runtime params + self.a_tensor = a_tensor + self.b_tensor = b_tensor + self.c_tensor = c_tensor + self.offs = offs + # Workspace with scenario-specific slot layout + if scenario == "2Dx3D": + self.workspace = TensormapWorkspace(workspace_ptr, ["c"]) + else: + self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"]) + + @staticmethod + def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int: + """Calculate workspace size in bytes for tensormap descriptors.""" + if scenario == "2Dx3D": + return TensormapWorkspace.size_bytes(1, expert_cnt) # only C + else: + return TensormapWorkspace.size_bytes(2, expert_cnt) # A and B + + @cute.jit + def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + return self.workspace.get_ptr(tensor_name, executor_idx) + + @cute.jit + def construct_and_write(self, executor_idx: Int32, dependency=None) -> None: + """ + Create expert-wise tensormap descriptors for the given expert. + + - 2Dx3D: Creates C descriptor for this expert + - 2Dx2D: Creates A and B descriptors for this expert + """ + if cutlass.const_expr(self.scenario == "2Dx3D"): + self._construct_c_desc_2dx3d(executor_idx) + else: # 2Dx2D + self._construct_ab_descs_2dx2d(executor_idx) + + @cute.jit + def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None: + """ + 2Dx3D: Create expert-wise C descriptor. + C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1) + Slice fake_m -> (tokens_i, intermediate, 1) per expert. + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + + c_ptr = self.c_tensor.iterator + c_stride = self.c_tensor.stride + intermediate = self.c_tensor.shape[1] # type: ignore[index] + + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + c_ptr_i = c_ptr + token_offset * c_stride[0] # type: ignore[index] + c_layout_i = cute.make_layout( + (tokens_i, intermediate, c1), + stride=(c_stride[0], c_stride[1], c0), # type: ignore[index] + ) + c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i) + + tma_atom_c, _ = cpasync.make_tiled_tma_atom( + self.c_tma_op, + c_tensor_i, + self.epi_smem_layout, + self.epi_tile, + ) + cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx)) + + @cute.jit + def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None: + """ + 2Dx2D: Create expert-wise A and B descriptors. + A: (m, fake_k, 1) -> slice to (m, tokens_i, 1) + B: (n, fake_k, 1) -> slice to (n, tokens_i, 1) + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + # A tensor: (m, fake_k, 1) -> (m, tokens_i, 1) + a_ptr = self.a_tensor.iterator + a_stride = self.a_tensor.stride + a_m = self.a_tensor.shape[0] # type: ignore[index] + + a_ptr_i = a_ptr + token_offset * a_stride[1] # type: ignore[index] + a_layout_i = cute.make_layout( + (a_m, tokens_i, c1), + stride=(a_stride[0], a_stride[1], c0), # type: ignore[index] + ) + a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i) + + tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A( + self.a_tma_op, + a_tensor_i, + self.a_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx)) + + # B tensor: (n, fake_k, 1) -> (n, tokens_i, 1) + b_ptr = self.b_tensor.iterator + b_stride = self.b_tensor.stride + b_n = self.b_tensor.shape[0] # type: ignore[index] + + b_ptr_i = b_ptr + token_offset * b_stride[1] # type: ignore[index] + b_layout_i = cute.make_layout( + (b_n, tokens_i, c1), + stride=(b_stride[0], b_stride[1], c0), # type: ignore[index] + ) + b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i) + + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + self.b_tma_op, + b_tensor_i, + self.b_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx)) + + +# ============================================================================= +# MoE Scaled Grouped GEMM Tensormap Constructor +# ============================================================================= + + +class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator): + """ + Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled). + + .. py:attribute:: ChunkSize + :value: 128 + + Number of experts processed per chunk in the desc_init_kernel. + Must match the warp-group width (4 warps × 32 threads). + + Extends MoEGroupedGemmTensormapConstructor with SFA/SFB descriptor support. + + Expert-wise descriptors only — non-expert-wise descriptors are passed + directly at kernel launch. + + Workspace layout: + - 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert) + - 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert) + + :param scenario: "2Dx3D" or "2Dx2D" + :param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4) + :param a_dtype: Data type for tensor A + :param b_dtype: Data type for tensor B + :param c_dtype: Data type for tensor C + :param sf_dtype: Data type for scale factors (SFA/SFB) + :param a_smem_layout: SMEM layout for A TMA + :param b_smem_layout: SMEM layout for B TMA + :param epi_smem_layout: SMEM layout for epilogue (C) TMA + :param sfa_smem_layout: SMEM layout for SFA TMA + :param sfb_smem_layout: SMEM layout for SFB TMA + :param a_tma_op: TMA operation for A + :param b_tma_op: TMA operation for B + :param c_tma_op: TMA operation for C (S2G store or reduce) + :param sfa_tma_op: TMA operation for SFA + :param sfb_tma_op: TMA operation for SFB + :param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction + :param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication) + :param mma_tiler: MMA tiler shape (M, N, K) + :param mma_tiler_sfb: MMA tiler shape for SFB + :param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast + :param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast + :param epi_tile: Epilogue tile shape + :param a_tensor: Fake GEMM domain A tensor + :param b_tensor: Fake GEMM domain B tensor + :param c_tensor: Fake GEMM domain C tensor + :param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout) + :param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout) + :param offs: (experts,) cumsum offsets in data domain + :param offs_padded: (experts,) cumsum offsets in padded scale domain + :param workspace_ptr: Pointer to workspace for TMA descriptors + :param expert_cnt: Total number of experts + """ + + ChunkSize = 128 + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + sf_vec_size: int, + # Codegen-time configs: dtypes + a_dtype, + b_dtype, + c_dtype, + sf_dtype, + # Codegen-time configs: SMEM layouts + a_smem_layout, + b_smem_layout, + epi_smem_layout, + sfa_smem_layout, + sfb_smem_layout, + # Codegen-time configs: TMA ops + a_tma_op, + b_tma_op, + c_tma_op, + sfa_tma_op, + sfb_tma_op, + # Codegen-time configs: MMA / cluster / tile + tiled_mma, + tiled_mma_sfb, + mma_tiler, + mma_tiler_sfb, + cluster_layout_vmnk_shape, + cluster_layout_sfb_vmnk_shape, + epi_tile, + # Runtime params + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + c_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + offs: cute.Tensor, + offs_padded: cute.Tensor, + workspace_ptr: Pointer, + expert_cnt: Optional[Union[Int32, int]] = None, + ) -> None: + super().__init__() + self.scenario = scenario + self.sf_vec_size = sf_vec_size + # Dtypes + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.c_dtype = c_dtype + self.sf_dtype = sf_dtype + # SMEM layouts + self.a_smem_layout = a_smem_layout + self.b_smem_layout = b_smem_layout + self.epi_smem_layout = epi_smem_layout + self.sfa_smem_layout = sfa_smem_layout + self.sfb_smem_layout = sfb_smem_layout + # TMA ops + self.a_tma_op = a_tma_op + self.b_tma_op = b_tma_op + self.c_tma_op = c_tma_op + self.sfa_tma_op = sfa_tma_op + self.sfb_tma_op = sfb_tma_op + # MMA / cluster / tile + self.tiled_mma = tiled_mma + self.tiled_mma_sfb = tiled_mma_sfb + self.mma_tiler = mma_tiler + self.mma_tiler_sfb = mma_tiler_sfb + self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape + self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape + self.epi_tile = epi_tile + # Runtime params + self.a_tensor = a_tensor + self.b_tensor = b_tensor + self.c_tensor = c_tensor + self.sfa_tensor = sfa_tensor + self.sfb_tensor = sfb_tensor + self.offs = offs + self.offs_padded = offs_padded + self.expert_cnt = expert_cnt + # Workspace with scenario-specific slot layout + if scenario == "2Dx3D": + self.workspace = TensormapWorkspace(workspace_ptr, ["c"]) + else: + self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"]) + + @staticmethod + def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int: + """Calculate workspace size in bytes for tensormap descriptors.""" + if scenario == "2Dx3D": + return TensormapWorkspace.size_bytes(1, expert_cnt) # C only + else: + return TensormapWorkspace.size_bytes(4, expert_cnt) # A, B, SFA, SFB + + @cute.jit + def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + return self.workspace.get_ptr(tensor_name, executor_idx) + + @cute.jit + def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None: + """ + Create expert-wise tensormap descriptors for all experts. + + ``lane_in_group`` is the thread's position within its warp group + (0..ChunkSize-1). The method loops internally over all experts in + chunks of ``ChunkSize``, with two-phase pipeline synchronization + per chunk. + + Per-chunk execution: + + 1. Phase 1: Build descriptors that do NOT depend on ``offs_padded`` + (A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum. + 2. Barrier: ``consumer.wait_and_advance()`` — all threads participate. + 3. Phase 2: Build descriptors that depend on ``offs_padded`` + (SFA/SFB for 2Dx2D). Reads padded offsets from SMEM buffer. + 4. Release: ``handle.release()`` — all threads participate. + + :param lane_in_group: Thread's position within the warp group (0..127). + :param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the + consumer for mbarrier sync, and the SMEM tensor of shape + ``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``. + """ + consumer, smem_offs_padded = dependency + assert self.expert_cnt is not None + num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize + + chunk_idx = cutlass.Int32(0) + while chunk_idx < num_chunks: + expert_idx = chunk_idx * self.ChunkSize + lane_in_group + in_bounds = expert_idx < self.expert_cnt + + # Phase 1: non-dependent descriptors + if in_bounds: + if cutlass.const_expr(self.scenario == "2Dx2D"): + self._construct_ab_descs_2dx2d(expert_idx) + else: + self._construct_c_desc_2dx3d(expert_idx) + + # All threads participate in barrier (fixed arrive count) + handle = consumer.wait_and_advance() + + # Phase 2: dependent descriptors (read padded offsets from SMEM) + if in_bounds: + if cutlass.const_expr(self.scenario == "2Dx2D"): + # smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]] + # padded_offset = smem[lane] (prev expert's cumulative) + # padded_end = smem[lane + 1] (this expert's cumulative) + padded_offset = smem_offs_padded[lane_in_group] + padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset + self._construct_sf_descs_2dx2d_direct( + expert_idx, padded_offset, padded_size_i + ) + + # All threads release (fixed arrive count) + handle.release() + + chunk_idx += 1 + + # ----------------------------------------------------------------- + # 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor) + # ----------------------------------------------------------------- + + @cute.jit + def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None: + """ + 2Dx3D: Create expert-wise C descriptor. + C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert. + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + c1 = cutlass.Int32(1) + + c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor) + c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1)) # type: ignore[index] + + tma_atom_c, _ = cpasync.make_tiled_tma_atom( + self.c_tma_op, + c_i, + self.epi_smem_layout, + self.epi_tile, + ) + cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx)) + + # ----------------------------------------------------------------- + # 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor) + # ----------------------------------------------------------------- + + @cute.jit + def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None: + """ + 2Dx2D: Create expert-wise A and B descriptors. + A: (m, fake_k, 1) -> slice to (m, tokens_i, 1) + B: (n, fake_k, 1) -> slice to (n, tokens_i, 1) + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + c1 = cutlass.Int32(1) + + # A: (m, fake_k, 1) -> domain_offset + rewrite shape + a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor) + a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1)) # type: ignore[index] + + tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A( + self.a_tma_op, + a_i, + self.a_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx)) + + # B: (n, fake_k, 1) -> domain_offset + rewrite shape + b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor) + b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1)) # type: ignore[index] + + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + self.b_tma_op, + b_i, + self.b_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx)) + + # ----------------------------------------------------------------- + # 2Dx2D: SFA, SFB descriptors (new for block-scaled) + # ----------------------------------------------------------------- + + @cute.jit + def _construct_sf_descs_2dx2d_direct( + self, + expert_idx: Int32, + padded_offset: Int32, + padded_size_i: Int32, + ) -> None: + """ + 2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed + padded offset and size. + + This variant allows the caller to supply padded offsets from SMEM + (in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM. + """ + c1 = cutlass.Int32(1) + + a_chunks_to_move = ( + padded_offset + // self.sf_vec_size + * cute.size(self.sfa_tensor, mode=[0]) + // 128 + ) + a_elems_to_move = ( + cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size + ) + b_chunks_to_move = ( + padded_offset + // self.sf_vec_size + * cute.size(self.sfb_tensor, mode=[0]) + // 128 + ) + b_elems_to_move = ( + cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size + ) + + per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1) # type: ignore[index] + sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size) + sfa_i = cute.make_tensor( + self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i + ) + + tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A( + self.sfa_tma_op, + sfa_i, + self.sfa_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + internal_type=cutlass.Uint64, + ) + cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx)) + + per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1) # type: ignore[index] + sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size) + sfb_i = cute.make_tensor( + self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i + ) + + tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B( + self.sfb_tma_op, + sfb_i, + self.sfb_smem_layout, + self.mma_tiler_sfb, + self.tiled_mma_sfb, + self.cluster_layout_sfb_vmnk_shape, + internal_type=cutlass.Uint64, + ) + cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx)) diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_grouped_mm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_grouped_mm.py new file mode 100644 index 000000000..13c6352e2 --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_grouped_mm.py @@ -0,0 +1,2019 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +from typing import Optional, Tuple, Literal, Type, Union + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Pointer +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +if __name__ == "__main__": + current_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, os.path.join(current_dir, "../../..")) + +from blackwell.kernel.moe.moe_utils import ( + MoEGroupedGemmTensormapConstructor, +) +from blackwell.kernel.moe.moe_persistent_scheduler import ( + MoEStaticSchedulerParams, + MoEStaticPersistentTileScheduler, + MoEWorkTileInfo, +) +from blackwell.kernel.moe.moe_sched_extension import GroupedMmSchedExtension +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) + + +class GroupedGemmKernel: + """ + Grouped GEMM kernel for MoE operations. + + PyTorch interface (from torch.nn.functional.grouped_mm): + - 2Dx3D (Forward): mat_a(tokens_sum, K) x mat_b(experts, K, N) -> out(tokens_sum, N) + - 2Dx2D (Weight grad): mat_a(hidden, tokens_sum) x mat_b(tokens_sum, intermediate) -> out(experts, hidden, intermediate) + + Kernel interface uses "fake" GEMM MNKL domain: + + 2Dx3D: + A_cute: (gemm_fake_m, gemm_k, 1) # fake_m = tokens_sum, scheduler will offset + B_cute: (gemm_n, gemm_k, gemm_fake_l) # fake_l = expert_idx, scheduler will select + C_cute: (gemm_fake_m, gemm_n, 1) # fake_m = tokens_sum, scheduler will offset + + 2Dx2D: + A_cute: (gemm_m, gemm_fake_k, 1) # fake_k = tokens_sum, scheduler will offset + B_cute: (gemm_n, gemm_fake_k, 1) # fake_k = tokens_sum, scheduler will offset + C_cute: (gemm_m, gemm_n, gemm_fake_l) # fake_l = expert_idx, scheduler will select + + The scheduler handles the fake dimensions by: + - For fake_m/fake_k: Computing token_offset from offs and adjusting tensor coord + - For fake_l: Selecting expert slice via L coordinate + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + out_dtype: Type[cutlass.Numeric], + accumulate_on_output: bool, + separate_tensormap_init: bool = True, + fixed_expert_cnt: Optional[int] = None, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64), + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1), + use_2cta_instrs: bool = False, + ): + # User-provided configs + self.scenario = scenario + self.out_dtype = out_dtype + self.accumulate_on_output = accumulate_on_output + self.separate_tensormap_init = separate_tensormap_init + self.fixed_expert_cnt = fixed_expert_cnt # Not used yet... + self.acc_dtype = acc_dtype + self.mma_tiler_mnk = mma_tiler_mnk + self.cluster_shape_mn = (cluster_shape_mnk[0], cluster_shape_mnk[1]) + self.use_2cta_instrs = use_2cta_instrs + self.arch = "sm_100" + + if accumulate_on_output and scenario == "2Dx3D": + raise ValueError( + "Non-sense config: grad accumulate should only happens in 2Dx2D." + ) + + self._validate_mma_tiler_and_cluster_shape() + + # K dimension is deferred in _setup_attributes + self.mma_tiler = (mma_tiler_mnk[0], mma_tiler_mnk[1], 1) + + # CTA group for tcgen05 MMA + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # Occupancy and warp specialization + self.occupancy = 1 + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.sched_warp_id = 6 + self.threads_per_cta = 32 * len( + ( + self.mma_warp_id, + self.tma_warp_id, + self.sched_warp_id, + *self.epilogue_warp_id, + ) + ) + + # Barrier IDs for synchronization + self.epilog_sync_bar_id = 1 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 + + def _validate_mma_tiler_and_cluster_shape(self): + """Validate codegen-time MMA tiler and cluster shape constraints.""" + m, n = self.mma_tiler_mnk[0], self.mma_tiler_mnk[1] + cm, cn = self.cluster_shape_mn + + if self.use_2cta_instrs: + valid_m = [128, 256] + else: + valid_m = [64, 128] + if m not in valid_m: + raise ValueError( + f"mma_tiler M ({m}) must be one of {valid_m} " + f"(use_2cta_instrs={self.use_2cta_instrs})" + ) + + if n not in range(32, 257, 32): + raise ValueError(f"mma_tiler N ({n}) must be a multiple of 32 in [32, 256]") + + if cm % (2 if self.use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape M ({cm}) must be even when use_2cta_instrs=True" + ) + + is_pow2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if cm * cn > 16 or not is_pow2(cm) or not is_pow2(cn): + raise ValueError( + f"Invalid cluster_shape ({cm}, {cn}): each dim must be " + f"a power of 2, and product must be <= 16" + ) + + def _create_tiled_mma(self) -> cute.TiledMma: + """Create tiled MMA atom based on input dtypes and major modes.""" + return utils.sm100.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + def _setup_attributes(self) -> None: + """ + Set up configurations that depend on GEMM inputs. + + This method configures: + - tiled_mma with correct dtypes and major modes + - MMA/cluster/tile shapes + - Cluster layout + - Multicast CTA counts + - Epilogue tile shape + - Stage counts (ACC, A/B, C) + - SMEM layouts for A/B/C + - Tensor memory allocation columns + - TMA load bytes + """ + tiled_mma = self._create_tiled_mma() + + # Use user-specified K dimension directly from mma_tiler_mnk + # Verify K is a multiple of the MMA instruction's native K size + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + assert self.mma_tiler_mnk[2] % mma_inst_shape_k == 0, ( + f"mma_tiler K ({self.mma_tiler_mnk[2]}) must be a multiple of " + f"MMA instruction K ({mma_inst_shape_k})" + ) + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + self.mma_tiler_mnk[2], + ) + + # CTA tile shape + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Multicast CTA counts + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Epilogue tile shape (always use TMA store for MoE) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # C SMEM layout for epilogue + c_smem_layout = utils.sm100.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, 1 + ) + + self.smem_capacity = utils.get_smem_capacity_in_bytes() + + # Compute stage counts + self.num_acc_stage = 2 + self.num_c_stage = 2 # Always use TMA store for MoE + + a_smem_layout_stage_one = utils.sm100.make_smem_layout_a( + tiled_mma, self.mma_tiler, self.a_dtype, 1 + ) + b_smem_layout_stage_one = utils.sm100.make_smem_layout_b( + tiled_mma, self.mma_tiler, self.b_dtype, 1 + ) + + ab_bytes_per_stage = cute.size_in_bytes( + self.a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout_stage_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(self.c_dtype, c_smem_layout) + c_bytes = c_bytes_per_stage * self.num_c_stage + + self.num_sched_stages = 2 + sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32) + sched_bytes = sched_work_tile_bytes_per_stage * self.num_sched_stages + + fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes + + self.num_ab_stage = ( + self.smem_capacity // self.occupancy - fixed_overhead + ) // ab_bytes_per_stage + + # Refine epilogue stages with remaining SMEM + self.num_c_stage += ( + self.smem_capacity + - self.occupancy * ab_bytes_per_stage * self.num_ab_stage + - self.occupancy * fixed_overhead + ) // (self.occupancy * c_bytes_per_stage) + + # SMEM layouts + self.a_smem_layout_staged = utils.sm100.make_smem_layout_a( + tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage + ) + self.b_smem_layout_staged = utils.sm100.make_smem_layout_b( + tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage + ) + self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage + ) + + # Tensor memory allocation columns + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols( + tCtAcc_fake, arch=self.arch + ) + + # TMA load bytes + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + def get_workspace_size(self, expert_cnt: int) -> int: + """ + Workspace size for expert-wise TMA descriptors. + + 2Dx3D: Need C desc per expert -> expert_cnt * TensormapDescBytes + 2Dx2D: Need A and B desc per expert -> 2 * expert_cnt * TensormapDescBytes + """ + return MoEGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + + @cute.jit + def __call__( + self, + mat_a: cute.Tensor, # PyTorch mat_a + mat_b: cute.Tensor, # PyTorch mat_b + out: cute.Tensor, # PyTorch output + offs: cute.Tensor, # (experts,) cumsum + bias: Optional[cute.Tensor], + workspace: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ) -> None: + """ + Launch the grouped GEMM kernel. + + This method: + 1. Transforms PyTorch tensors to GEMM domain tensors + 2. Infers dtypes and major modes from GEMM domain tensors + 3. Sets up kernel attributes + 4. Creates TMA atoms for A, B, C + 5. Creates scheduler parameters + 6. Launches the kernel + """ + if cutlass.const_expr(bias is not None): + raise NotImplementedError("bias is not supported yet (align with torch).") + + # ===================================================================== + # Step 1: Transform PyTorch tensors to GEMM domain (fake MNKL) + # ===================================================================== + + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(self.scenario == "2Dx3D"): + # mat_a: (tokens_sum, hidden) -> A_cute: (fake_m, k, 1) + tokens_sum, hidden = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (tokens_sum, hidden, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + + # mat_b: (experts, hidden, intermediate) -> B_cute: (n, k, fake_l) + experts, hidden_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, hidden_b, experts), + stride=(mat_b.stride[2], mat_b.stride[1], mat_b.stride[0]), + ), + ) + + # out: (tokens_sum, intermediate) -> C_cute: (fake_m, n, 1) + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (tokens_sum, intermediate, c1), + stride=(out.stride[0], out.stride[1], c0), + ), + ) + + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + else: # 2Dx2D + # mat_a: (hidden, tokens_sum) -> A_cute: (m, fake_k, 1) + hidden, tokens_sum = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (hidden, tokens_sum, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + + # mat_b: (tokens_sum, intermediate) -> B_cute: (n, fake_k, 1) + tokens_sum_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, tokens_sum_b, c1), + stride=(mat_b.stride[1], mat_b.stride[0], c0), + ), + ) + + # out: (experts, hidden, intermediate) -> C_cute: (m, n, fake_l) + experts, hidden_c, intermediate_c = out.shape + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (hidden_c, intermediate_c, experts), + stride=(out.stride[1], out.stride[2], out.stride[0]), + ), + ) + + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + # ===================================================================== + # Step 2: Infer dtypes and major modes from GEMM domain tensors + # ===================================================================== + + self.a_dtype: Type[cutlass.Numeric] = a_gemm.element_type + self.b_dtype: Type[cutlass.Numeric] = b_gemm.element_type + self.c_dtype: Type[cutlass.Numeric] = c_gemm.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_gemm).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_gemm).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_gemm) + + # ===================================================================== + # Step 3: Setup kernel attributes + # ===================================================================== + + k = self.mma_tiler_mnk[2] + a_tile_bits = self.a_dtype.width * k + b_tile_bits = self.b_dtype.width * k + if cutlass.const_expr(a_tile_bits % 256 != 0): + raise ValueError( + f"a_dtype ({self.a_dtype.width}b) * mma_tiler K ({k}) = " + f"{a_tile_bits}b, must be a multiple of 256b (MMA instruction K width)" + ) + if cutlass.const_expr(b_tile_bits % 256 != 0): + raise ValueError( + f"b_dtype ({self.b_dtype.width}b) * mma_tiler K ({k}) = " + f"{b_tile_bits}b, must be a multiple of 256b (MMA instruction K width)" + ) + + self._setup_attributes() + tiled_mma = self._create_tiled_mma() + + # ===================================================================== + # Step 4: Create TMA atoms for A, B, C + # ===================================================================== + + # TMA load for A + a_op = utils.sm100.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_gemm, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for B + b_op = utils.sm100.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_gemm, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA store for C (or TMA reduce for accumulate_on_output) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + c_tma_op, c_gemm, epi_smem_layout, self.epi_tile + ) + + # ===================================================================== + # Step 5: Create MoEStaticSchedulerParams and compute grid + # ===================================================================== + + sched_params = MoEStaticSchedulerParams( + scenario=self.scenario, + expert_shape=(expert_cnt, intermediate_dim, hidden_dim), + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + ) + + grid = MoEStaticSchedulerParams.get_grid_shape( + sched_params, max_active_clusters + ) + + # ===================================================================== + # Step 5.5: Launch desc init kernel (if separate_tensormap_init) + # ===================================================================== + # + # Pre-initialize expert-wise TMA descriptors in workspace before + # the main kernel. Stream ordering guarantees completion before + # the main kernel starts. + # + # 2Dx3D: C desc per expert (C has dynamic fake_m per expert) + # 2Dx2D: A,B desc per expert (A,B have dynamic fake_k per expert) + # + + if cutlass.const_expr(self.separate_tensormap_init): + self.desc_init_kernel( + tiled_mma, + a_gemm, + b_gemm, + c_gemm, + offs, + expert_cnt, + workspace.iterator, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + ).launch( + grid=(expert_cnt, 1, 1), + block=[32, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + # ===================================================================== + # Step 6: Launch kernel + # ===================================================================== + + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + a_gemm, + b_gemm, + c_gemm, + offs, + sched_params, + workspace.iterator, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=self.occupancy, + ) + + # GPU device kernel: TMA descriptor initialization + @cute.kernel + def desc_init_kernel( + self, + tiled_mma: cute.TiledMma, + a_gemm: cute.Tensor, # GEMM domain A (fake MNKL) + b_gemm: cute.Tensor, # GEMM domain B (fake MNKL) + c_gemm: cute.Tensor, # GEMM domain C (fake MNKL) + offs: cute.Tensor, # (experts,) cumsum + expert_cnt: Union[cutlass.Int32, int], + workspace_ptr: Pointer, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + ): + """ + Separate kernel to pre-initialize expert-wise TMA descriptors. + + Grid: (expert_cnt, 1, 1) - one block per expert + Block: (32, 1, 1) - one warp per block + + Each block constructs and writes TMA descriptors for one expert + to the pre-allocated workspace buffer. + + 2Dx3D: Creates C descriptor per expert (C has dynamic fake_m per expert) + 2Dx2D: Creates A and B descriptors per expert (A/B have dynamic fake_k per expert) + """ + # ================================================================= + # Reconstruct TMA constructor with explicit attributes + # ================================================================= + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = utils.sm100.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = utils.sm100.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + tensormap_ctor = MoEGroupedGemmTensormapConstructor( + scenario=self.scenario, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + tiled_mma=tiled_mma, + mma_tiler=self.mma_tiler, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + offs=offs, + workspace_ptr=workspace_ptr, + ) + + # ================================================================= + # Each block constructs descriptors for one expert + # ================================================================= + + expert_idx, _, _ = cute.arch.block_idx() + tensormap_ctor.construct_and_write(expert_idx) + + # GPU device kernel: main GEMM kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + a_gemm: cute.Tensor, # GEMM domain A (fake MNKL) + b_gemm: cute.Tensor, # GEMM domain B (fake MNKL) + c_gemm: cute.Tensor, # GEMM domain C (fake MNKL) + offs: cute.Tensor, # (experts,) cumsum + sched_params: MoEStaticSchedulerParams, + workspace_ptr: Pointer, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + ): + """ + GPU device kernel for MoE Grouped GEMM. + + Warp specialization: + - Warps 0-3: Epilogue warps (TMEM -> RMEM -> SMEM -> GMEM) + - Warp 4: MMA warp (tcgen05.mma) + - Warp 5: TMA load warp (also prefetches expert-wise TMA descriptors) + + The kernel uses MoEStaticPersistentTileScheduler to iterate over tiles + across all experts. For each tile: + 1. TMA load warp fetches A/B tiles using get_gmem_tensor + 2. MMA warp performs matrix multiply-accumulate + 3. Epilogue warps store results using TMA store/reduce + + Note: Python objects holding MLIR values cannot be kernel params. + The following are constructed inside the kernel from individually-passed params: + - tensormap_ctor: MoEGroupedGemmTensormapConstructor (online tensormap builder) + - ext: GroupedMmSchedExtension (domain conversion + TMA desc selection) + """ + # ================================================================= + # Reconstruct dicts that can't be passed as kernel params + # ================================================================= + + # Construct TMA descriptor creator and scheduler extension + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = utils.sm100.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = utils.sm100.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + tensormap_ctor = MoEGroupedGemmTensormapConstructor( + scenario=self.scenario, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + tiled_mma=tiled_mma, + mma_tiler=self.mma_tiler, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + offs=offs, + workspace_ptr=workspace_ptr, + ) + ext = GroupedMmSchedExtension( + scenario=self.scenario, tensormap_ctor=tensormap_ctor + ) + + # ================================================================= + # Kernel setup + # ================================================================= + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # CTA/thread coordinates + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + tidx, _, _ = cute.arch.thread_idx() + + # ================================================================= + # SharedStorage + # ================================================================= + + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_stage * 2 + ] + sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4] + sched_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_sched_stages * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ================================================================= + # Pipelines + # ================================================================= + + # AB pipeline (TMA load → MMA) + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + # ACC pipeline (MMA → epilogue) + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = ( + len(self.epilogue_warp_id) * 32 * (2 if use_2cta_instrs else 1) + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Scheduler pipeline (sched warp → tma/mma/epi warps) + sched_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32) + num_sched_consumer_threads = 32 * len( + (self.tma_warp_id, self.mma_warp_id, *self.epilogue_warp_id) + ) + sched_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_sched_consumer_threads + ) + sched_pipeline = pipeline.PipelineAsync.create( + num_stages=self.num_sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + barrier_storage=storage.sched_mbar_ptr.data_ptr(), + defer_sync=True, + ) + + # TMEM allocator + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr.ptr, + ) + + # Cluster barrier sync after init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # ================================================================= + # SMEM tensors A/B + # ================================================================= + + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + + # Multicast masks + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # MMA fragments (SMEM → TMEM partitions) + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # Cluster wait before TMEM alloc + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # ================================================================= + # Scheduler warp (warp 6) + # ================================================================= + + sched_buf_ptr = storage.sched_buf.data_ptr() + sched_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=128 + ) + sched_buf_tensor = cute.make_tensor( + sched_buf_ptr, cute.make_layout((4, self.num_sched_stages), stride=(1, 4)) + ) + + if warp_idx == self.sched_warp_id: + scheduler = MoEStaticPersistentTileScheduler.create( + sched_params, offs, cute.arch.block_idx(), cute.arch.grid_dim() + ) + + sched_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_sched_stages + ) + + # Always produce the initial work_tile_info first + work_tile_info = scheduler.initial_work_tile_info() + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + # Iterate remaining tiles starting from the first advance + work_tile_info = scheduler.advance_to_next_work() + while work_tile_info.is_valid_tile: + ext.prefetch_for_expert(work_tile_info.expert_idx) + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + work_tile_info = scheduler.advance_to_next_work() + + # Write invalid sentinel (expert_idx = -1) so consumers exit + sched_pipeline.producer_acquire(sched_producer_state) + sentinel = MoEWorkTileInfo( + cutlass.Int32(-1), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0) + ) + rmem = sentinel.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + + sched_pipeline.producer_tail(sched_producer_state) + + # ================================================================= + # TMA load warp (warp 5) + # ================================================================= + + if warp_idx == self.tma_warp_id: + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get real GEMM domain tensors + TMA desc ptrs via extension + real_a, desc_ptr_a = ext.get_gmem_tensor( + "a", + tma_tensor_a, + offs, + work_tile_info, + ) + real_b, desc_ptr_b = ext.get_gmem_tensor( + "b", + tma_tensor_b, + offs, + work_tile_info, + ) + + # local_tile for this tile's A and B + gA_mkl = cute.local_tile( + real_a, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + gB_nkl = cute.local_tile( + real_b, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + + # MMA partition for TMA + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgA = thr_mma.partition_A(gA_mkl) + tCgB = thr_mma.partition_B(gB_nkl) + + # TMA partition + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # Slice to current tile coords (L=0 for MoE, expert already selected) + mma_tile_m = work_tile_info.tile_m_idx // cute.size( + tiled_mma.thr_id.shape + ) + tAgA_slice = tAgA[(None, mma_tile_m, None, 0)] + tBgB_slice = tBgB[(None, work_tile_info.tile_n_idx, None, 0)] + + # TMA load loop + ab_producer.reset() + peek_ab_empty_status = ab_producer.try_acquire() + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_producer.acquire_and_advance(peek_ab_empty_status) + peek_ab_empty_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_empty_status = ab_producer.try_acquire() + cute.copy( + tma_atom_a, + tAgA_slice[(None, handle.count)], + tAsA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_a, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, handle.count)], + tBsB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_b, + mcast_mask=b_full_mcast_mask, + ) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + ab_producer.tail() + + # ================================================================= + # MMA warp (warp 4) + # ================================================================= + + if warp_idx == self.mma_warp_id: + # Retrieve TMEM + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + if is_leader_cta: + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # AB consumer mainloop + ab_consumer.reset() + peek_ab_full_status = cutlass.Boolean(1) + if k_tile_cnt > 0: + peek_ab_full_status = ab_consumer.try_wait() + acc_pipeline.producer_acquire(acc_producer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_consumer.wait_and_advance(peek_ab_full_status) + peek_ab_full_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0) + tile_crd = (None, None, None, handle.index) + cute.gemm( + tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc + ) + handle.release() + + if k_tile_cnt > 0: + acc_pipeline.producer_commit(acc_producer_state) + if k_tile_cnt > 0: + acc_producer_state.advance() + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + acc_pipeline.producer_tail(acc_producer_state) + + # ================================================================= + # SMEM tensor C (allocated after MMA section, same as dense) + # ================================================================= + + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + + # ================================================================= + # Epilogue warps (warps 0-3) + # ================================================================= + + if warp_idx < self.mma_warp_id: + # Allocate TMEM + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilogue_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + + # Epilogue copy setup (same for all tiles, depends only on shapes) + # Transform ACC layout: ((ATOM_M, ATOM_N), MMA_M, MMA_N, STAGE) + # -> ((ATOM_M, MMA_M), (ATOM_N, MMA_N), STAGE) + tCtAcc_transformed = transform_partitioned_tensor_layout(tCtAcc_base) + + num_tiles_executed = cutlass.Int32(0) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + # Get real C tensor + TMA desc ptr via extension + real_c, desc_ptr_c = ext.get_gmem_tensor( + "c", + tma_tensor_c, + offs, + work_tile_info, + ) + + # local_tile + partition for C + gC_mnl = cute.local_tile( + real_c, + cute.slice_(self.mma_tiler, (None, None, 0)), + (None, None, None), + ) + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgC = thr_mma.partition_C(gC_mnl) + tCgC_transformed = transform_partitioned_tensor_layout(tCgC) + + mma_tile_coord_mnl = ( + work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + work_tile_info.tile_n_idx, + cutlass.Int32(0), + ) + + # Partition for TMEM → RMEM copy + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + epilogue_tmem_copy_and_partition( + self, + tidx, + tCtAcc_transformed, + tCgC_transformed, + epi_tile, + use_2cta_instrs, + ) + ) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( + self, tiled_copy_t2r, tTR_rC, tidx, sC + ) + + # TMA partition for C store (with expert-wise desc_ptr) + tCgC_epi = cute.flat_divide(tCgC_transformed, epi_tile) + bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + + # Set TMEM buffer for current tile + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # Wait for accumulator buffer full + if k_tile_cnt > 0: + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Store accumulator to global memory in subtiles + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = num_tiles_executed * subtile_cnt + + for subtile_idx in range(subtile_cnt): + # TMEM → RMEM + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Convert to output dtype + acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc)) + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + else: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = acc_vec.to(self.c_dtype) + tRS_rC.store(acc_vec) + + # RMEM → SMEM + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)] + ) + cute.arch.fence_proxy("async.shared", space="cta") + epilog_sync_barrier.arrive_and_wait() + + # SMEM → GMEM (TMA store or TMA reduce) + if warp_idx == self.epilogue_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=desc_ptr_c, + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + # Release accumulator buffer + if k_tile_cnt > 0: + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + num_tiles_executed += cutlass.Int32(1) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + # Wait for C store complete + c_pipeline.producer_tail() + + # Free TMEM + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +# ============================================================================= +# Host Validation +# ============================================================================= + +from dataclasses import dataclass, field +import re + +import numpy as np +import torch +import cutlass.torch as cutlass_torch + + +def torch_version_lt(major: int, minor: int) -> bool: + """Best-effort torch version check that tolerates local build suffixes.""" + match = re.match(r"^\s*(\d+)\.(\d+)", torch.__version__) + if match is None: + print( + "WARNING: failed to parse torch.__version__, " + "falling back to torch._grouped_mm host reference." + ) + return True + version = (int(match.group(1)), int(match.group(2))) + return version < (major, minor) + + +@dataclass +class ProblemDesc: + tokens: int + experts: int + top_k_select: int + balance_route: bool + hidden: int + intermediate: int + scenario: Literal["2Dx3D", "2Dx2D"] + ab_dtype: torch.dtype + out_dtype: torch.dtype + acc_dtype: torch.dtype + grad_accumulate: bool = False + # GEMM-domain layout control (which axis is stride-1) + # A (M, K): "k_major" (default) or "m_major" + # B (N, K): "n_major" (default) or "k_major" + # C (M, N): "n_major" (default) or "m_major" + a_layout: Literal["k_major", "m_major"] = "k_major" + b_layout: Literal["k_major", "n_major"] = "n_major" + c_layout: Literal["m_major", "n_major"] = "n_major" + + def __str__(self) -> str: + d = lambda t: str(t).split(".")[-1] + route = "balanced" if self.balance_route else "random" + return ( + f"ProblemDesc: {self.scenario} | tokens={self.tokens} experts={self.experts} " + f"top_k={self.top_k_select} route={route} | hidden={self.hidden} intermediate={self.intermediate} | " + f"{d(self.ab_dtype)}->{d(self.out_dtype)}(acc={d(self.acc_dtype)}) grad_acc={self.grad_accumulate} | " + f"layout: A={self.a_layout} B={self.b_layout} C={self.c_layout}" + ) + + +@dataclass +class ImplDesc: + mma_tiler_mnk: Tuple[int, int, int] + cluster_shape_mnk: Tuple[int, int, int] + use_2cta_instrs: bool + static_expert_cnt: Optional[int] = None + separate_tensormap_init: bool = True + + def __str__(self) -> str: + tile = ",".join(map(str, self.mma_tiler_mnk)) + cluster = ",".join(map(str, self.cluster_shape_mnk)) + static_e = ( + self.static_expert_cnt if self.static_expert_cnt is not None else "dynamic" + ) + return ( + f"ImplDesc: tile={tile} cluster={cluster} 2cta={self.use_2cta_instrs} | " + f"static_E={static_e} sep_tmap={self.separate_tensormap_init}" + ) + + +@dataclass +class MiscDesc: + perf_run: bool = False + perf_e2e: bool = False + compare_with_bmm: bool = False + compare_with_sol: bool = False + no_torch_210: bool = field(init=False) + + def __post_init__(self): + self.no_torch_210 = torch_version_lt(2, 10) + if self.perf_e2e and not self.perf_run: + raise ValueError("--perf_e2e requires --perf_run to be enabled.") + if self.perf_e2e and self.compare_with_sol: + raise ValueError( + "--perf_e2e and --compare_with_sol are mutually exclusive." + ) + + def __str__(self) -> str: + ref = "bmm" if self.compare_with_bmm else "grouped_mm" + return ( + f"MiscDesc: perf={self.perf_run} perf_e2e={self.perf_e2e} " + f"ref={ref} sol={self.compare_with_sol} no_torch_210={self.no_torch_210}" + ) + + +def l2_flush(size_mb: int = 400) -> None: + """Best-effort L2 flush by touching a large temporary tensor.""" + num_bytes = size_mb * 1024 * 1024 + flush_buf = torch.randint(0, 256, (num_bytes,), dtype=torch.uint8, device="cuda") + del flush_buf + + +class GroupedGemmTester: + def __init__(self, problem: ProblemDesc, impl: ImplDesc, misc: MiscDesc): + self.problem = problem + self.impl = impl + self.misc = misc + + self.tokens_after_repeat = problem.tokens * problem.top_k_select + self.expert_cnt = problem.experts + self.hidden = problem.hidden + self.intermediate = problem.intermediate + + self.A_tensor: torch.Tensor = None + self.B_tensor: torch.Tensor = None + self.C_tensor: torch.Tensor = None + self.C_ref_tensor: torch.Tensor = None + self.offs_tensor: torch.Tensor = None + self.workspace_tensor: torch.Tensor = None + + # This should be a common func + self.temp_type_mapping = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + } + + def _generate_offs(self) -> torch.Tensor: + """Generate group-end offsets. + + Some experts may receive 0 tokens (valid in real MoE routing). + """ + total = self.tokens_after_repeat + expert_cnt = self.expert_cnt + + if self.problem.balance_route: + base = total // expert_cnt + remainder = total % expert_cnt + sizes = [base + (1 if i < remainder else 0) for i in range(expert_cnt)] + else: + proportions = np.random.dirichlet([0.5] * expert_cnt) + raw = np.floor(proportions * total).astype(int) + deficit = total - raw.sum() + while deficit > 0: + idx = int(np.argmin(raw / (proportions * total + 1e-12))) + raw[idx] += 1 + deficit -= 1 + while deficit < 0: + ratios = np.where( + raw > 0, + raw / (proportions * total + 1e-12), + -np.inf, + ) + idx = int(np.argmax(ratios)) + raw[idx] -= 1 + deficit += 1 + sizes = raw.tolist() + + assert sum(sizes) == total + + cum = 0 + offsets = [] + for s in sizes: + cum += s + offsets.append(cum) + return torch.tensor(offsets, dtype=torch.int32, device="cuda") + + def _generate_tensor(self, shape: Tuple) -> torch.Tensor: + if self.misc.perf_run: + return torch.randn(shape, dtype=self.problem.ab_dtype, device="cuda") + else: + return torch.randint(-1, 2, shape, device="cuda", dtype=torch.int8).to( + self.problem.ab_dtype + ) + + def _get_stream(self) -> cuda.CUstream: + return cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + def generate_inputs(self) -> None: + self.offs_tensor = self._generate_offs() + + tokens = self.tokens_after_repeat + hidden = self.hidden + intermediate = self.intermediate + expert_cnt = self.expert_cnt + + if self.problem.scenario == "2Dx3D": + # PyTorch shape: A (tokens, hidden), B (expert_cnt, hidden, intermediate), C (tokens, intermediate) + # GEMM domain: A (M=tokens, K=hidden), B (N=intermediate, K=hidden), C (M=tokens, N=intermediate) + + # GEMM A: k_major → K(hidden) stride-1; m_major → M(tokens) stride-1 + if self.problem.a_layout == "k_major": + self.A_tensor = self._generate_tensor((tokens, hidden)) + else: + self.A_tensor = self._generate_tensor((hidden, tokens)).T + + # GEMM B: n_major → N(intermediate) stride-1; k_major → K(hidden) stride-1 + if self.problem.b_layout == "n_major": + self.B_tensor = self._generate_tensor( + (expert_cnt, hidden, intermediate) + ) + else: + self.B_tensor = self._generate_tensor( + (expert_cnt, intermediate, hidden) + ).transpose(1, 2) + + # GEMM C: n_major → N(intermediate) stride-1; m_major → M(tokens) stride-1 + if self.problem.c_layout == "n_major": + self.C_tensor = torch.full( + (tokens, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: + self.C_tensor = torch.full( + (intermediate, tokens), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).T + + elif self.problem.scenario == "2Dx2D": + # PyTorch shape: mat_a (hidden, tokens), mat_b (tokens, intermediate), out (expert_cnt, hidden, intermediate) + # out matches weight shape (expert_cnt, hidden, intermediate) for weight gradient + # GEMM domain: A (M=hidden, K=tokens), B (N=intermediate, K=tokens), C (M=hidden, N=intermediate) + + # GEMM A: k_major → K(tokens) stride-1; m_major → M(hidden) stride-1 + if self.problem.a_layout == "k_major": + self.A_tensor = self._generate_tensor((hidden, tokens)) + else: + self.A_tensor = self._generate_tensor((tokens, hidden)).T + + # GEMM B: n_major → N(intermediate) stride-1; k_major → K(tokens) stride-1 + if self.problem.b_layout == "n_major": + self.B_tensor = self._generate_tensor((tokens, intermediate)) + else: + self.B_tensor = self._generate_tensor((intermediate, tokens)).T + + # GEMM C: n_major → N(intermediate) stride-1; m_major → M(hidden) stride-1 + if self.problem.c_layout == "n_major": + self.C_tensor = torch.full( + (expert_cnt, hidden, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: + self.C_tensor = torch.full( + (expert_cnt, intermediate, hidden), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).transpose(1, 2) + if self.problem.grad_accumulate: + self.C_tensor *= 0 + else: + raise ValueError(f"Unknown scenario: {self.problem.scenario}") + + def compute_reference(self) -> None: + if self.misc.perf_run: + return + if self.misc.compare_with_bmm: + self._compute_reference_bmm() + else: + self._compute_reference_grouped_mm() + + def _compute_reference_grouped_mm(self) -> None: + grouped_mm_op = ( + torch._grouped_mm + if self.misc.no_torch_210 + else torch.nn.functional.grouped_mm + ) + self.C_ref_tensor = grouped_mm_op( + self.A_tensor, + self.B_tensor, + offs=self.offs_tensor, + out_dtype=self.problem.out_dtype, + ) + + def _compute_reference_bmm(self) -> None: + """Manual per-expert torch.mm loop as reference (avoids grouped_mm bugs on small cases).""" + # Preallocate the full reference output to avoid keeping both the per-expert + # results list and the final cat/stack result alive at the same time. + self.C_ref_tensor = torch.empty_like(self.C_tensor) + + prev = 0 + for i in range(self.expert_cnt): + cur = self.offs_tensor[i].item() + if self.problem.scenario == "2Dx3D": + # A (tokens, hidden), B (E, hidden, intermediate) → C_i (tokens_i, intermediate) + a_slice = self.A_tensor[prev:cur, :] + b_slice = self.B_tensor[i] + self.C_ref_tensor[prev:cur, :].copy_(torch.mm(a_slice, b_slice)) + else: # 2Dx2D + # A (hidden, tokens), B (tokens, intermediate) → C_i (hidden, intermediate) + a_slice = self.A_tensor[:, prev:cur] + b_slice = self.B_tensor[prev:cur, :] + self.C_ref_tensor[i, :, :].copy_(torch.mm(a_slice, b_slice)) + prev = cur + + def create_kernel(self) -> GroupedGemmKernel: + return GroupedGemmKernel( + scenario=self.problem.scenario, + out_dtype=self.temp_type_mapping[self.problem.out_dtype], + accumulate_on_output=self.problem.grad_accumulate + and self.problem.scenario == "2Dx2D", + separate_tensormap_init=self.impl.separate_tensormap_init, + fixed_expert_cnt=self.impl.static_expert_cnt, + acc_dtype=self.temp_type_mapping[self.problem.acc_dtype], + mma_tiler_mnk=self.impl.mma_tiler_mnk, + cluster_shape_mnk=self.impl.cluster_shape_mnk, + use_2cta_instrs=self.impl.use_2cta_instrs, + ) + + def run_kernel(self, kernel: GroupedGemmKernel) -> Optional[float]: + """Run our CuTe kernel. + + Returns: + Average kernel time in ms when perf_e2e is enabled, None otherwise. + """ + workspace_size = kernel.get_workspace_size(self.expert_cnt) + self.workspace_tensor = torch.full( + (workspace_size,), 255, dtype=torch.uint8, device="cuda" + ) + + torch.cuda.synchronize() + + ab_cutlass_dtype = self.temp_type_mapping[self.problem.ab_dtype] + out_cutlass_dtype = self.temp_type_mapping[self.problem.out_dtype] + + a_cute, self.A_tensor = cutlass_torch.cute_tensor_like( + self.A_tensor, ab_cutlass_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_cute, self.B_tensor = cutlass_torch.cute_tensor_like( + self.B_tensor, ab_cutlass_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_cute, self.C_tensor = cutlass_torch.cute_tensor_like( + self.C_tensor, out_cutlass_dtype, is_dynamic_layout=True, assumed_align=16 + ) + is_dynamic_expert_cnt = self.impl.static_expert_cnt is None + offs_cute, self.offs_tensor = cutlass_torch.cute_tensor_like( + self.offs_tensor, + cutlass.Int32, + is_dynamic_layout=is_dynamic_expert_cnt, + assumed_align=16, + ) + workspace_cute, self.workspace_tensor = cutlass_torch.cute_tensor_like( + self.workspace_tensor, + cutlass.Uint8, + is_dynamic_layout=is_dynamic_expert_cnt, + assumed_align=128, + ) + + # Query max active clusters from hardware + cluster_size = self.impl.cluster_shape_mnk[0] * self.impl.cluster_shape_mnk[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + print(f"A_tensor: {tuple(self.A_tensor.shape)}:{self.A_tensor.stride()}") + print(f"B_tensor: {tuple(self.B_tensor.shape)}:{self.B_tensor.stride()}") + print( + f"offset_tensor: {tuple(self.offs_tensor.shape)}:{self.offs_tensor.stride()}" + ) + print(f"C_tensor: {tuple(self.C_tensor.shape)}:{self.C_tensor.stride()}") + + stream = self._get_stream() + + if self.misc.perf_e2e: + compiled = cute.compile( + kernel, + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + max_active_clusters, + stream, + ) + + warmup_iters = 4 + timed_iters = 4 + + for _ in range(warmup_iters): + l2_flush() + compiled( + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + stream, + ) + torch.cuda.synchronize() + + times = [] + for _ in range(timed_iters): + l2_flush() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + compiled( + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + stream, + ) + end_evt.record() + torch.cuda.synchronize() + times.append(start_evt.elapsed_time(end_evt)) + + avg_ms = sum(times) / len(times) + print(f"[perf_e2e] Individual times (ms): {[f'{t:.4f}' for t in times]}") + print(f"[perf_e2e] Average kernel time: {avg_ms:.4f} ms") + return avg_ms + else: + l2_flush() + kernel( + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + max_active_clusters, + stream, + ) + torch.cuda.synchronize() + return None + + def validate(self) -> None: + if not self.misc.perf_run: + assert torch.equal(self.C_tensor, self.C_ref_tensor), ( + "Validation failed: C_tensor != C_ref_tensor" + ) + + def run_sol_comparison(self) -> None: + """Run a dense batched GEMM as Speed-of-Light reference. + + Reuses the same tensor memory from the grouped run by + view/reshape/permute -- zero GPU allocation. + """ + import sys, os + + _examples_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..") + ) + if _examples_root not in sys.path: + sys.path.insert(0, _examples_root) + + from blackwell.kernel.dense_gemm.dense_gemm_persistent import ( + PersistentDenseGemmKernel, + ) + + tokens = self.tokens_after_repeat + experts = self.expert_cnt + assert tokens % experts == 0, ( + f"compare_with_sol requires tokens*top_k ({tokens}) " + f"evenly divisible by experts ({experts}) so every group " + f"has exactly the same size" + ) + tpe = tokens // experts + + if self.problem.scenario == "2Dx3D": + M, N, K, L = tpe, self.intermediate, self.hidden, experts + else: # 2Dx2D + M, N, K, L = self.hidden, self.intermediate, tpe, experts + + # Reshape into GEMM-domain batch-last: A(M,K,L), B(N,K,L), C(M,N,L). + # Data values are irrelevant (perf only) — just need correct shape + # and stride pattern so the dense kernel sees the right major mode. + if self.problem.a_layout == "k_major": + a_sol = self.A_tensor.contiguous().view(L, M, K).permute(1, 2, 0) + leading_dim_a = 1 + else: + a_sol = self.A_tensor.contiguous().view(L, K, M).permute(2, 1, 0) + leading_dim_a = 0 + + if self.problem.b_layout == "n_major": + b_sol = self.B_tensor.contiguous().view(L, K, N).permute(2, 1, 0) + leading_dim_b = 0 + else: + b_sol = self.B_tensor.contiguous().view(L, N, K).permute(1, 2, 0) + leading_dim_b = 1 + + if self.problem.c_layout == "n_major": + c_sol = self.C_tensor.contiguous().view(L, M, N).permute(1, 2, 0) + leading_dim_c = 1 + else: + c_sol = self.C_tensor.contiguous().view(L, N, M).permute(2, 1, 0) + leading_dim_c = 0 + + from cutlass.cute.runtime import from_dlpack + + a_cute_sol = from_dlpack(a_sol, assumed_align=16).mark_layout_dynamic( + leading_dim=leading_dim_a + ) + b_cute_sol = from_dlpack(b_sol, assumed_align=16).mark_layout_dynamic( + leading_dim=leading_dim_b + ) + c_cute_sol = from_dlpack(c_sol, assumed_align=16).mark_layout_dynamic( + leading_dim=leading_dim_c + ) + + mma_tiler_mn = self.impl.mma_tiler_mnk[:2] + cluster_shape_mn = self.impl.cluster_shape_mnk[:2] + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + sol_kernel = PersistentDenseGemmKernel( + acc_dtype=self.temp_type_mapping[self.problem.acc_dtype], + use_2cta_instrs=self.impl.use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + use_tma_store=True, + ) + + print(f"\n[SOL] Dense BMM: M={M} N={N} K={K} L={L}") + print(f"[SOL] a_sol: {tuple(a_sol.shape)}:{a_sol.stride()}") + print(f"[SOL] b_sol: {tuple(b_sol.shape)}:{b_sol.stride()}") + print(f"[SOL] c_sol: {tuple(c_sol.shape)}:{c_sol.stride()}") + + l2_flush() + sol_kernel( + a_cute_sol, + b_cute_sol, + c_cute_sol, + max_active_clusters, + self._get_stream(), + ) + torch.cuda.synchronize() + + def run(self) -> None: + from torch.profiler import profile, ProfilerActivity + + print(self.problem) + print(self.impl) + print(self.misc) + self.generate_inputs() + kernel = self.create_kernel() + + if self.misc.perf_e2e: + self.run_kernel(kernel) + else: + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: + self.compute_reference() + self.run_kernel(kernel) + if ( + self.misc.compare_with_sol + and self.misc.perf_run + and self.problem.balance_route + ): + self.run_sol_comparison() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + + self.validate() + + +if __name__ == "__main__": + import argparse + + def parse_dtype(s: str) -> torch.dtype: + return getattr(torch, s) + + def parse_tuple(s: str) -> Tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + parser = argparse.ArgumentParser() + parser.add_argument("--tokens", type=int, default=128) + parser.add_argument("--experts", type=int, default=128) + parser.add_argument("--top_k_select", type=int, default=8) + parser.add_argument("--balance_route", action="store_true", default=False) + parser.add_argument("--hidden", type=int, default=2048) + parser.add_argument("--intermediate", type=int, default=7168) + parser.add_argument( + "--scenario", type=str, default="2Dx3D", choices=["2Dx3D", "2Dx2D"] + ) + parser.add_argument("--ab_dtype", type=str, default="bfloat16") + parser.add_argument("--out_dtype", type=str, default="bfloat16") + parser.add_argument("--acc_dtype", type=str, default="float32") + parser.add_argument("--grad_accumulate", action="store_true", default=False) + parser.add_argument( + "--a_layout", type=str, default="k_major", choices=["k_major", "m_major"] + ) + parser.add_argument( + "--b_layout", type=str, default="n_major", choices=["k_major", "n_major"] + ) + parser.add_argument( + "--c_layout", type=str, default="n_major", choices=["m_major", "n_major"] + ) + parser.add_argument("--mma_tiler_mnk", type=str, default="128,128,64") + parser.add_argument("--cluster_shape_mnk", type=str, default="1,1,1") + parser.add_argument("--use_2cta_instrs", action="store_true", default=False) + parser.add_argument("--static_expert_cnt", type=int, default=None) + parser.add_argument("--separate_tensormap_init", action="store_true", default=False) + parser.add_argument("--perf_run", action="store_true", default=False) + parser.add_argument("--perf_e2e", action="store_true", default=False) + parser.add_argument("--compare_with_bmm", action="store_true", default=False) + parser.add_argument("--compare_with_sol", action="store_true", default=False) + args = parser.parse_args() + + problem = ProblemDesc( + tokens=args.tokens, + experts=args.experts, + top_k_select=args.top_k_select, + balance_route=args.balance_route, + hidden=args.hidden, + intermediate=args.intermediate, + scenario=args.scenario, + ab_dtype=parse_dtype(args.ab_dtype), + out_dtype=parse_dtype(args.out_dtype), + acc_dtype=parse_dtype(args.acc_dtype), + grad_accumulate=args.grad_accumulate, + a_layout=args.a_layout, + b_layout=args.b_layout, + c_layout=args.c_layout, + ) + if not args.separate_tensormap_init: + print( + "Change separate_tensormap_init to True as current the fused version not implmented yet." + ) + args.separate_tensormap_init = True + impl = ImplDesc( + mma_tiler_mnk=parse_tuple(args.mma_tiler_mnk), + cluster_shape_mnk=parse_tuple(args.cluster_shape_mnk), + use_2cta_instrs=args.use_2cta_instrs, + static_expert_cnt=args.static_expert_cnt, + separate_tensormap_init=args.separate_tensormap_init, + ) + + misc = MiscDesc( + perf_run=args.perf_run, + perf_e2e=args.perf_e2e, + compare_with_bmm=args.compare_with_bmm, + compare_with_sol=args.compare_with_sol, + ) + if misc.no_torch_210: + misc.compare_with_bmm = True + print("Override to set --compare_with_bmm to avoid possible torch crash.") + + tester = GroupedGemmTester(problem, impl, misc) + tester.run() + print("PASS") diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_scaled_grouped_mm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_scaled_grouped_mm.py new file mode 100644 index 000000000..5a2d98fe3 --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_scaled_grouped_mm.py @@ -0,0 +1,3901 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Scaled Grouped GEMM for MoE operations with block scaling (MXFP8, MXFP4, NVFP4). + +PyTorch interface (from torch.nn.functional.scaled_grouped_mm): +- 2Dx3D (Forward): mat_a(tokens_sum, K) x mat_b(experts, K, N) -> out(tokens_sum, N) +- 2Dx2D (Weight grad): mat_a(M, tokens_sum) x mat_b(tokens_sum, N) -> out(experts, M, N) + +Kernel interface uses GEMM MNKL domain (same as torch_grouped_mm.py): + A_cute: (M, K, L) + B_cute: (N, K, L) + C_cute: (M, N, L) + SFA_cute, SFB_cute: scale factors with block-scaled atom layout + +The scheduler handles fake dimensions by computing token_offset from offs. +""" + +import os +import sys +from typing import Optional, Tuple, Literal, Type, Union + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Pointer +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +if __name__ == "__main__": + current_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, os.path.join(current_dir, "../../..")) + +from blackwell.kernel.moe.moe_utils import ( + MoEScaledGroupedGemmTensormapConstructor, +) +from blackwell.kernel.moe.moe_persistent_scheduler import ( + MoEStaticSchedulerParams, + MoEStaticPersistentTileScheduler, + MoEWorkTileInfo, +) +from blackwell.kernel.moe.moe_sched_extension import ScaledGroupedMmSchedExtension +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) + +# ============================================================================= +# ScaledGroupedGemmKernel +# ============================================================================= + + +class ScaledGroupedGemmKernel: + """ + Scaled Grouped GEMM kernel for MoE operations with block scaling. + + Combines: + - MoE grouped structure from GroupedGemmKernel (scheduler warp, expert-wise + TMA descriptors, MoEStaticPersistentTileScheduler) + - Block-scaled MMA from Sm100BlockScaledPersistentDenseGemmKernel (SFA/SFB + tensors, blockscaled tiled_mma, SMEM→TMEM SF copy) + + Warp specialization (7 warps): + - Warps 0-3: Epilogue (TMEM → RMEM → SMEM → GMEM, global_scale multiply) + - Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM) + - Warp 5: TMA load (A, B, SFA, SFB from GMEM → SMEM) + - Warp 6: Scheduler (MoEStaticPersistentTileScheduler, produces work tiles) + + __init__ parameters are codegen-time configuration only. + Runtime dtypes (a_dtype, b_dtype, sf_dtype, c_dtype) and layout modes + (a_major_mode, b_major_mode, c_layout) are inferred from input tensors + in __call__. + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + sf_vec_size: int, + accumulate_on_output: bool, + separate_tensormap_init: bool, + consistent_token_padding: bool, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64), + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1), + use_2cta_instrs: bool = False, + fixed_expert_cnt: Optional[int] = None, + ): + # ── User-provided codegen-time configuration ── + self.scenario = scenario + self.sf_vec_size = sf_vec_size + self.accumulate_on_output = accumulate_on_output + self.separate_tensormap_init = separate_tensormap_init + self.consistent_token_padding = consistent_token_padding + self.acc_dtype = acc_dtype + self.mma_tiler_mnk = mma_tiler_mnk + self.cluster_shape_mn = (cluster_shape_mnk[0], cluster_shape_mnk[1]) + self.use_2cta_instrs = use_2cta_instrs + self.fixed_expert_cnt = fixed_expert_cnt + self.arch = "sm_100" + + if accumulate_on_output and scenario == "2Dx3D": + raise ValueError( + "accumulate_on_output only makes sense for 2Dx2D (weight grad)." + ) + + self._validate_mma_tiler_and_cluster_shape() + + # ── MMA tiler — K is refined in _setup_attributes ── + self.mma_tiler = (mma_tiler_mnk[0], mma_tiler_mnk[1], 1) + + # ── CTA group for tcgen05 MMA ── + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # ── Warp specialization (7 warps) ── + self.occupancy = 1 + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.sched_warp_id = 6 + self.threads_per_cta = 32 * len( + ( + self.mma_warp_id, + self.tma_warp_id, + self.sched_warp_id, + *self.epilogue_warp_id, + ) + ) + + # ── Barrier IDs for synchronization ── + self.epilog_sync_bar_id = 1 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 + + self.smem_capacity = utils.get_smem_capacity_in_bytes(self.arch) + self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols(self.arch) + + # ----------------------------------------------------------------- + # Workspace size + # ----------------------------------------------------------------- + + def get_workspace_size(self, expert_cnt: int) -> int: + """Workspace size for the aux init kernel. + + Layout: [TMA descriptors (managed by tensormap ctor)] [padded scale offsets] + """ + desc_bytes = MoEScaledGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + padded_offs_bytes = expert_cnt * 4 if not self.consistent_token_padding else 0 + return desc_bytes + padded_offs_bytes + + # ----------------------------------------------------------------- + # Static validation + # ----------------------------------------------------------------- + + def _validate_mma_tiler_and_cluster_shape(self): + """Validate codegen-time MMA tiler and cluster shape constraints.""" + m, n, k = self.mma_tiler_mnk + cm, cn = self.cluster_shape_mn + + if m not in [128, 256]: + raise ValueError(f"mma_tiler M ({m}) must be one of [128, 256]") + + per_cta_m = m // (2 if self.use_2cta_instrs else 1) + if per_cta_m != 128: + raise ValueError( + f"per-CTA mma_tiler M must be 128, got {per_cta_m} " + f"(mma_tiler_m={m}, use_2cta_instrs={self.use_2cta_instrs})" + ) + + if n not in [64, 128, 256]: + raise ValueError(f"mma_tiler N ({n}) must be one of [64, 128, 256]") + + sf_k_granularity = self.sf_vec_size * 4 + if k % sf_k_granularity != 0: + raise ValueError( + f"mma_tiler K ({k}) must be a multiple of " + f"sf_vec_size * 4 = {sf_k_granularity}" + ) + + if cm % (2 if self.use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape M ({cm}) must be even when use_2cta_instrs=True" + ) + + is_pow2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if cm * cn > 16 or not is_pow2(cm) or not is_pow2(cn) or cm > 4 or cn > 4: + raise ValueError( + f"Invalid cluster_shape ({cm}, {cn}): each dim must be " + f"a power of 2 and <= 4, product must be <= 16" + ) + + if self.sf_vec_size not in {16, 32}: + raise ValueError(f"sf_vec_size ({self.sf_vec_size}) must be 16 or 32") + + # ----------------------------------------------------------------- + # _create_tiled_mma / _create_tiled_mma_sfb + # ----------------------------------------------------------------- + + def _create_tiled_mma(self) -> cute.TiledMma: + """Create blockscaled tiled MMA atom.""" + return sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + def _create_tiled_mma_sfb(self) -> cute.TiledMma: + """Create blockscaled tiled MMA atom for SFB (always CtaGroup.ONE).""" + return sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # ----------------------------------------------------------------- + # _setup_attributes + # ----------------------------------------------------------------- + + def _setup_attributes(self) -> None: + """ + Set up configurations that depend on GEMM inputs. + + Configures: + - tiled_mma / tiled_mma_sfb with correct dtypes and major modes + - MMA/cluster/tile shapes + - Cluster layouts (main + sfb) + - Multicast CTA counts + - Epilogue tile shape + - Stage counts (ACC, AB+SF, C) + - SMEM layouts for A/B/SFA/SFB/C + - TMEM column counts (accumulator + SFA + SFB) + - TMA load bytes + - Overlapping accumulator support + """ + # ── MMA instruction shapes ── + self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = self._create_tiled_mma() + tiled_mma_sfb = self._create_tiled_mma_sfb() + + # ── MMA / cluster / tile shapes ── + # Use user-specified K dimension from mma_tiler_mnk + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + assert self.mma_tiler_mnk[2] % mma_inst_shape_k == 0, ( + f"mma_tiler K ({self.mma_tiler_mnk[2]}) must be a multiple of " + f"MMA instruction K ({mma_inst_shape_k})" + ) + mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + self.mma_tiler_mnk[2], + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + self.mma_tiler_mnk[2], + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + # ── Cluster layouts ── + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # ── Multicast CTA counts ── + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # ── Epilogue tile shape ── + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + self.epi_tile_n = cute.size(self.epi_tile[1]) + + # ── Stage counts ── + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + self.num_sched_stages = 2 + + # ── SMEM layouts ── + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + # ── Overlapping accumulator ── + # N=256: TMEM can't fit 2 full acc buffers + SF, so acc and SF share columns. + # The acc pipeline uses 1 barrier stage with phase-based toggling. + # N<256: TMEM fits 2 independent acc buffers, normal 2-stage pipeline. + self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256 + self.num_acc_pipeline_stages = ( + 1 if self.overlapping_accum else self.num_acc_stage + ) + + # ── TMEM column counts ── + sf_atom_mn = 32 + self.num_sfa_tmem_cols = ( + self.cta_tile_shape_mnk[0] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sfb_tmem_cols = ( + self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[ + 1 + ] * self.num_acc_stage - ( + self.num_sf_tmem_cols if self.overlapping_accum else 0 + ) + + # Only when overlapping_accum, release accumulator buffer early in epilogue + self.iter_acc_early_release_in_epilogue = ( + self.num_sf_tmem_cols // self.epi_tile_n + ) + + # ── TMA load bytes (A + B + SFA + SFB per stage) ── + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # ----------------------------------------------------------------- + # _compute_stages (static) + # ----------------------------------------------------------------- + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Compute stage counts for ACC, A/B/SFA/SFB, and C.""" + num_acc_stage = 2 + num_c_stage = 2 + + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32) + num_sched_stages = 2 + sched_bytes = sched_work_tile_bytes_per_stage * num_sched_stages + + fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes + + num_ab_stage = ( + smem_capacity // occupancy - fixed_overhead + ) // ab_bytes_per_stage + + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * fixed_overhead + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + # ----------------------------------------------------------------- + # mainloop_s2t_copy_and_partition (from dense_blockscaled) + # ----------------------------------------------------------------- + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem → tmem load of a scale factor tensor, + then partition smem (source) and tmem (destination). + """ + tCsSF_compact = cute.filter_zeros(sSF) + tCtSF_compact = cute.filter_zeros(tSF) + + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + # ----------------------------------------------------------------- + # __call__ (JIT entry point) + # ----------------------------------------------------------------- + + @cute.jit + def __call__( + self, + mat_a: cute.Tensor, # PyTorch mat_a (data) + mat_b: cute.Tensor, # PyTorch mat_b (data) + scale_a: cute.Tensor, # SFA (assembled block-scaled layout) + scale_b: cute.Tensor, # SFB (assembled block-scaled layout) + out: cute.Tensor, # Output C + offs: cute.Tensor, # (experts,) cumsum end offsets, int32 + workspace: cute.Tensor, # Expert-wise TMA desc + padded offs + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + global_scale_a: Optional[cute.Tensor] = None, # NVFP4: per-expert f32 scalar + global_scale_b: Optional[cute.Tensor] = None, # NVFP4: per-expert f32 scalar + bias: Optional[cute.Tensor] = None, + ) -> None: + """Launch the scaled grouped GEMM kernel.""" + if cutlass.const_expr(bias is not None): + raise NotImplementedError("bias is not supported yet (align with torch).") + + # ================================================================= + # Step 1: Transform PyTorch tensors to GEMM domain (fake MNKL) + # ================================================================= + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(self.scenario == "2Dx3D"): + # mat_a: (tokens_sum, hidden) -> A: (fake_m, k, 1) + tokens_sum, hidden = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (tokens_sum, hidden, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + # mat_b: (experts, hidden, intermediate) -> B: (n, k, fake_l) + experts, hidden_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, hidden_b, experts), + stride=(mat_b.stride[2], mat_b.stride[1], mat_b.stride[0]), + ), + ) + # out: (tokens_sum, intermediate) -> C: (fake_m, n, 1) + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (tokens_sum, intermediate, c1), + stride=(out.stride[0], out.stride[1], c0), + ), + ) + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + # SFA/SFB: scale tensors have host-padded dimensions. + # Use their own shape as the "data shape" for atom tiling. + tokens_sum_padded = scale_a.shape[0] + hidden_padded = scale_a.shape[1] * self.sf_vec_size + sfa_gemm = cute.make_tensor( + scale_a.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (tokens_sum_padded, hidden_padded, c1), self.sf_vec_size + ), + ) + intermediate_padded_mul_hidden_padded = scale_b.shape[1] + intermediate_padded = ( + intermediate_padded_mul_hidden_padded * self.sf_vec_size + ) // hidden_padded + sfb_gemm = cute.make_tensor( + scale_b.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (intermediate_padded, hidden_padded, experts), self.sf_vec_size + ), + ) + + else: # 2Dx2D + # mat_a: (hidden, tokens_sum) -> A: (m, fake_k, 1) + hidden, tokens_sum = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (hidden, tokens_sum, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + # mat_b: (tokens_sum, intermediate) -> B: (n, fake_k, 1) + tokens_sum_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, tokens_sum_b, c1), + stride=(mat_b.stride[1], mat_b.stride[0], c0), + ), + ) + # out: (experts, hidden, intermediate) -> C: (m, n, fake_l) + experts, hidden_c, intermediate_c = out.shape + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (hidden_c, intermediate_c, experts), + stride=(out.stride[1], out.stride[2], out.stride[0]), + ), + ) + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + # SFA/SFB: scale tensors have host-padded dimensions. + hidden_padded = scale_a.shape[0] + tokens_sum_padded = scale_a.shape[1] * self.sf_vec_size + sfa_gemm = cute.make_tensor( + scale_a.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (hidden_padded, tokens_sum_padded, c1), self.sf_vec_size + ), + ) + intermediate_padded = scale_b.shape[0] + sfb_gemm = cute.make_tensor( + scale_b.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (intermediate_padded, tokens_sum_padded, c1), self.sf_vec_size + ), + ) + + # ================================================================= + # Step 2: Infer dtypes and major modes + # ================================================================= + + self.a_dtype: Type[cutlass.Numeric] = a_gemm.element_type + self.b_dtype: Type[cutlass.Numeric] = b_gemm.element_type + self.c_dtype: Type[cutlass.Numeric] = c_gemm.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_gemm.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_gemm).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_gemm).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_gemm) + + # ================================================================= + # Step 3: Setup kernel attributes + # ================================================================= + + self._setup_attributes() + tiled_mma = self._create_tiled_mma() + tiled_mma_sfb = self._create_tiled_mma_sfb() + + # ================================================================= + # Step 4: Create TMA atoms for A, B, SFA, SFB, C + # ================================================================= + + # ── TMA load A ── + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_gemm, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # ── TMA load B ── + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_gemm, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # ── TMA load SFA ── + # sfa_gemm is already atom-tiled from tile_atom_to_shape_SF + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_gemm, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + # ── TMA load SFB ── + # sfb_gemm is already atom-tiled from tile_atom_to_shape_SF + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_gemm, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + # ── TMA store/reduce C ── + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + c_tma_op, c_gemm, epi_smem_layout, self.epi_tile + ) + + # ================================================================= + # Step 5: offs_padded tensor (written by desc_init_kernel) + # ================================================================= + + # consistent_token_padding=True → offs_padded=None, main kernel reuses offs + # consistent_token_padding=False → offs_padded in GMEM workspace, written by desc_init + if cutlass.const_expr(self.consistent_token_padding): + offs_padded = None + else: + desc_bytes = MoEScaledGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + offs_padded = cute.make_tensor( + cute.recast_ptr(workspace.iterator + desc_bytes, dtype=offs.dtype), + cute.make_layout((expert_cnt,)), + ) + + # ================================================================= + # Step 6: Create MoEStaticSchedulerParams and compute grid + # ================================================================= + + sched_params = MoEStaticSchedulerParams( + scenario=self.scenario, + expert_shape=(expert_cnt, intermediate_dim, hidden_dim), + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + ) + + grid = MoEStaticSchedulerParams.get_grid_shape( + sched_params, max_active_clusters + ) + + # ================================================================= + # Step 7: Launch desc_init_kernel (if separate_tensormap_init) + # ================================================================= + + if cutlass.const_expr(self.separate_tensormap_init): + self.desc_init_kernel( + tiled_mma, + tiled_mma_sfb, + a_gemm, + b_gemm, + c_gemm, + sfa_gemm, + sfb_gemm, + offs, + expert_cnt, + workspace.iterator, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + ).launch( + grid=(1, 1, 1), + block=[self._desc_init_block_threads, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + # ================================================================= + # Step 8: Launch main kernel + # ================================================================= + + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + a_gemm, + b_gemm, + c_gemm, + sfa_gemm, + sfb_gemm, + offs, + sched_params, + workspace.iterator, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + offs_padded, + global_scale_a, + global_scale_b, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=self.occupancy, + ) + + # ----------------------------------------------------------------- + # desc_init_kernel (GPU device kernel) + # ----------------------------------------------------------------- + + # Number of warps per warp-group in desc_init_kernel. + _desc_init_warps_per_group = 4 + # Threads per warp-group (must equal MoEScaledGroupedGemmTensormapConstructor.ChunkSize). + _desc_init_group_threads = _desc_init_warps_per_group * 32 # 128 + # Total threads in desc_init_kernel (2 warp-groups × 4 warps each). + _desc_init_block_threads = _desc_init_group_threads * 2 # 256 + # Named barrier ID for warp-group-internal sync within Group A. + _desc_init_group_a_bar_id = 1 + + @cute.kernel + def desc_init_kernel( + self, + # ── MMA atoms ── + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + # ── GEMM domain tensors (fake MNKL) ── + a_gemm: cute.Tensor, + b_gemm: cute.Tensor, + c_gemm: cute.Tensor, + sfa_gemm: cute.Tensor, + sfb_gemm: cute.Tensor, + # ── Scheduling / workspace ── + offs: cute.Tensor, + expert_cnt: Union[cutlass.Int32, int], + workspace_ptr: Pointer, + # ── Cluster layouts ── + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + # ── SMEM layouts ── + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + ): + """ + Pre-initialize expert-wise TMA descriptors and compute padded scale + offsets (``offs_padded``). + + Grid: (1, 1, 1) + Block: (256, 1, 1) — 8 warps split into two groups of 4: + + - **Group A** (warps 0-3, threads 0..127): Compute ``offs_padded`` + prefix sum, write to SMEM + GMEM. + - **Group B** (warps 4-7, threads 128..255): Create TMA descriptors + via ``construct_and_write`` (chunked, with pipeline sync). + + Synchronization: + - Group A internal: NamedBarrier (for cross-warp prefix sum) + - Group A → Group B: PipelineAsync (mbarrier producer-consumer) + """ + chunk_size = self._desc_init_group_threads # 128 + full_mask = 0xFFFFFFFF + warp_size = 32 + + # ================================================================= + # Thread identity + # ================================================================= + + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + lane_in_group = tidx % chunk_size # 0..127 within each group + + # ================================================================= + # Reconstruct TMA ops (same as before) + # ================================================================= + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)) + sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_tma_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + # ================================================================= + # GMEM offs_padded tensor (written by Group A, read by main kernel) + # Only allocated when consistent_token_padding=False. + # ================================================================= + + if cutlass.const_expr(not self.consistent_token_padding): + desc_bytes = MoEScaledGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + gmem_offs_padded = cute.make_tensor( + cute.recast_ptr(workspace_ptr + desc_bytes, dtype=offs.dtype), + cute.make_layout((expert_cnt,)), + ) + + # ================================================================= + # SMEM allocation + # ================================================================= + + smem = utils.SmemAllocator() + + @cute.struct + class DescInitStorage: + # offs_padded SMEM buffer: [carry, chunk[0..127]] + offs_padded_buf: cute.struct.MemRange[cutlass.Int32, chunk_size + 1] + # Cross-warp prefix sum scratch (one per warp in Group A) + warp_sums: cute.struct.MemRange[ + cutlass.Int32, self._desc_init_warps_per_group + ] + # Pipeline mbarrier storage (PipelineAsync with 1 stage needs 2 mbarriers) + pipeline_mbar: cute.struct.MemRange[cutlass.Int64, 2] + + storage = smem.allocate(DescInitStorage) + + # Make a tensor view for the SMEM offs_padded buffer + smem_offs_padded = cute.make_tensor( + storage.offs_padded_buf.data_ptr(), + cute.make_layout((chunk_size + 1,)), + ) + smem_warp_sums = cute.make_tensor( + storage.warp_sums.data_ptr(), + cute.make_layout((self._desc_init_warps_per_group,)), + ) + + # ================================================================= + # Pipeline: Group A (producer) → Group B (consumer) + # ================================================================= + + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, chunk_size) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, chunk_size) + pipe = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=producer_group, + consumer_group=consumer_group, + barrier_storage=storage.pipeline_mbar.data_ptr(), + ) + producer, consumer = pipe.make_participants() + + # Named barrier for Group A internal sync (cross-warp prefix sum) + group_a_sync = pipeline.NamedBarrier( + barrier_id=self._desc_init_group_a_bar_id, + num_threads=chunk_size, + ) + + # ================================================================= + # Padding granularity P + # ================================================================= + + if cutlass.const_expr(self.scenario == "2Dx2D"): + # tokens = K (reduce dim): pad scale cols → P = sf_vec_size × 4 + pad_granularity = self.sf_vec_size * 4 + else: + # tokens = M (non-reduce dim): pad scale rows → P = 128 + pad_granularity = 128 + + # ================================================================= + # Tensormap constructor (for Group B) + # ================================================================= + + tensormap_ctor = MoEScaledGroupedGemmTensormapConstructor( + scenario=self.scenario, + sf_vec_size=self.sf_vec_size, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + sf_dtype=self.sf_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + sfa_smem_layout=sfa_smem_layout, + sfb_smem_layout=sfb_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + sfa_tma_op=sfa_tma_op, + sfb_tma_op=sfb_tma_op, + tiled_mma=tiled_mma, + tiled_mma_sfb=tiled_mma_sfb, + mma_tiler=self.mma_tiler, + mma_tiler_sfb=self.mma_tiler_sfb, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + cluster_layout_sfb_vmnk_shape=cluster_layout_sfb_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + sfa_tensor=sfa_gemm, + sfb_tensor=sfb_gemm, + offs=offs, + offs_padded=offs + if cutlass.const_expr(self.consistent_token_padding) + else gmem_offs_padded, + workspace_ptr=workspace_ptr, + expert_cnt=expert_cnt, + ) + + # ================================================================= + # Warp-group split + # ================================================================= + + num_chunks = (expert_cnt + chunk_size - 1) // chunk_size + + if warp_idx < self._desc_init_warps_per_group: + # ============================================================= + # Group A: produce offs_padded into SMEM (+ GMEM if needed) + # ============================================================= + + warp_in_group = warp_idx # 0..3 + lane_in_warp = tidx % warp_size + + carry = cutlass.Int32(0) + chunk_idx = cutlass.Int32(0) + + while chunk_idx < num_chunks: + expert_idx = chunk_idx * chunk_size + lane_in_group + + if cutlass.const_expr(self.consistent_token_padding): + # ── Fast path: offs_padded == offs, just load ── + offs_val = cutlass.Int32(0) + if expert_idx < expert_cnt: + offs_val = offs[expert_idx] + + # Wait for consumer to release SMEM from previous chunk + producer.acquire_and_advance() + + # Write SMEM: [carry, offs[chunk_base..chunk_base+127]] + if lane_in_group == cutlass.Int32(0): + smem_offs_padded[0] = carry + smem_offs_padded[lane_in_group + 1] = offs_val + + # Ensure all SMEM writes visible, then signal consumer + group_a_sync.arrive_and_wait() + producer.commit() + + # Only thread 0 needs carry (to write smem[0] next iteration) + if lane_in_group == cutlass.Int32(0): + carry = smem_offs_padded[chunk_size] + + else: + # ── Full path: compute prefix sum of padded sizes ── + + # Load and compute per-thread padded size + padded_size = cutlass.Int32(0) + if expert_idx < expert_cnt: + prev_off = cutlass.Int32(0) + if expert_idx > cutlass.Int32(0): + prev_off = offs[expert_idx - 1] + size_i = offs[expert_idx] - prev_off + padded_size = ( + (size_i + pad_granularity - 1) // pad_granularity + ) * pad_granularity + + # Stage 1: warp-level inclusive prefix sum (shfl_up) + val = padded_size + for d in [1, 2, 4, 8, 16]: + n = cute.arch.shuffle_sync_up( + val, d, mask=full_mask, mask_and_clamp=0 + ) + if lane_in_warp >= d: + val = val + n + + # Lane 31 of each warp holds the warp total + if lane_in_warp == warp_size - 1: + smem_warp_sums[warp_in_group] = val + + # Group A internal sync (warp_sums visible) + group_a_sync.arrive_and_wait() + + # Stage 2: cross-warp correction + cross_warp_prefix = cutlass.Int32(0) + if warp_in_group >= 1: + cross_warp_prefix = smem_warp_sums[0] + if warp_in_group >= 2: + cross_warp_prefix = cross_warp_prefix + smem_warp_sums[1] + if warp_in_group >= 3: + cross_warp_prefix = cross_warp_prefix + smem_warp_sums[2] + + offs_padded_val = carry + val + cross_warp_prefix + + # Wait for consumer to release SMEM from previous chunk + producer.acquire_and_advance() + + # Write SMEM: [carry, offs_padded[chunk_base..chunk_base+127]] + if lane_in_group == cutlass.Int32(0): + smem_offs_padded[0] = carry + smem_offs_padded[lane_in_group + 1] = offs_padded_val + + # Ensure all SMEM writes visible, then signal consumer + group_a_sync.arrive_and_wait() + producer.commit() + + # Write GMEM (overlaps with Group B's phase 2) + if expert_idx < expert_cnt: + gmem_offs_padded[expert_idx] = offs_padded_val + + # Update carry + carry = smem_offs_padded[chunk_size] + + chunk_idx += 1 + + else: + # ============================================================= + # Group B: create TMA descriptors (chunked, with pipeline sync) + # ============================================================= + + tensormap_ctor.construct_and_write( + lane_in_group, + dependency=(consumer, smem_offs_padded), + ) + + # ----------------------------------------------------------------- + # kernel (GPU device kernel) + # ----------------------------------------------------------------- + + @cute.kernel + def kernel( + self, + # ── MMA atoms ── + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + # ── TMA atoms and tensors: A ── + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + # ── TMA atoms and tensors: B ── + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + # ── TMA atoms and tensors: SFA ── + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + # ── TMA atoms and tensors: SFB ── + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + # ── TMA atoms and tensors: C ── + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + # ── GEMM domain tensors ── + a_gemm: cute.Tensor, + b_gemm: cute.Tensor, + c_gemm: cute.Tensor, + sfa_gemm: cute.Tensor, + sfb_gemm: cute.Tensor, + # ── Scheduling / workspace ── + offs: cute.Tensor, + sched_params: MoEStaticSchedulerParams, + workspace_ptr: Pointer, + # ── Cluster layouts ── + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + # ── SMEM layouts ── + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + # ── Optional: padded offsets ── + offs_padded: Optional[cute.Tensor], + # ── Optional: NVFP4 per-expert global scales ── + global_scale_a: Optional[cute.Tensor], + global_scale_b: Optional[cute.Tensor], + ): + """ + GPU device kernel for MoE Scaled Grouped GEMM with block scaling. + + Backbone: torch_grouped_mm.py (7-warp MoE scheduler structure) + GEMM internals: dense_blockscaled_gemm_persistent.py + """ + # ================================================================= + # Reconstruct objects that can't be passed as kernel params + # ================================================================= + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)) + sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_tma_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + # Build offs tuple for the extension + if cutlass.const_expr(offs_padded is not None): + offs_for_ext = (offs, offs_padded) + else: + offs_for_ext = (offs, offs) + + tensormap_ctor = MoEScaledGroupedGemmTensormapConstructor( + scenario=self.scenario, + sf_vec_size=self.sf_vec_size, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + sf_dtype=self.sf_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + sfa_smem_layout=sfa_smem_layout, + sfb_smem_layout=sfb_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + sfa_tma_op=sfa_tma_op, + sfb_tma_op=sfb_tma_op, + tiled_mma=tiled_mma, + tiled_mma_sfb=tiled_mma_sfb, + mma_tiler=self.mma_tiler, + mma_tiler_sfb=self.mma_tiler_sfb, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + cluster_layout_sfb_vmnk_shape=cluster_layout_sfb_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + sfa_tensor=sfa_gemm, + sfb_tensor=sfb_gemm, + offs=offs, + offs_padded=offs_padded if offs_padded is not None else offs, + workspace_ptr=workspace_ptr, + ) + ext = ScaledGroupedMmSchedExtension( + scenario=self.scenario, tensormap_ctor=tensormap_ctor + ) + + # ================================================================= + # Kernel setup + # ================================================================= + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + tidx, _, _ = cute.arch.thread_idx() + + # ================================================================= + # SharedStorage + # ================================================================= + + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_pipeline_stages * 2 + ] + sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4] + sched_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_sched_stages * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ================================================================= + # Pipelines + # ================================================================= + + # AB pipeline (TMA load → MMA) — same as grouped_mm + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + # ACC pipeline (MMA → epilogue) + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = ( + len(self.epilogue_warp_id) * 32 * (2 if use_2cta_instrs else 1) + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_pipeline_stages, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Scheduler pipeline (sched warp → tma/mma/epi warps) + sched_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32) + num_sched_consumer_threads = 32 * len( + (self.tma_warp_id, self.mma_warp_id, *self.epilogue_warp_id) + ) + sched_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_sched_consumer_threads + ) + sched_pipeline = pipeline.PipelineAsync.create( + num_stages=self.num_sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + barrier_storage=storage.sched_mbar_ptr.data_ptr(), + defer_sync=True, + ) + + # TMEM allocator + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr.ptr, + ) + + # Cluster barrier sync after init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # ================================================================= + # SMEM tensors A/B/SFA/SFB + # ================================================================= + + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + sSFA = smem.allocate_tensor( + element_type=self.sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + sSFB = smem.allocate_tensor( + element_type=self.sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + + # (MMA, MMA_M, MMA_N, STAGE=2) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + if cutlass.const_expr(self.overlapping_accum): + # Overlapping: two acc buffers share TMEM with SF columns, + # so the stage stride is smaller than a full N-width. + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + + # Cluster wait before TMEM alloc + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # ================================================================= + # Scheduler warp (warp 6) — same as grouped_mm + # ================================================================= + + sched_buf_ptr = storage.sched_buf.data_ptr() + sched_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=128 + ) + sched_buf_tensor = cute.make_tensor( + sched_buf_ptr, cute.make_layout((4, self.num_sched_stages), stride=(1, 4)) + ) + + if warp_idx == self.sched_warp_id: + scheduler = MoEStaticPersistentTileScheduler.create( + sched_params, offs, cute.arch.block_idx(), cute.arch.grid_dim() + ) + + sched_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_sched_stages + ) + + work_tile_info = scheduler.initial_work_tile_info() + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + work_tile_info = scheduler.advance_to_next_work() + while work_tile_info.is_valid_tile: + ext.prefetch_for_expert(work_tile_info.expert_idx) + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + work_tile_info = scheduler.advance_to_next_work() + + sched_pipeline.producer_acquire(sched_producer_state) + sentinel = MoEWorkTileInfo( + cutlass.Int32(-1), + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int32(0), + ) + rmem = sentinel.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + + sched_pipeline.producer_tail(sched_producer_state) + + # ================================================================= + # TMA load warp (warp 5) + # ================================================================= + + if warp_idx == self.tma_warp_id: + # Multicast masks, only used in TMA load warp + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr( + self.is_a_mcast or self.is_b_mcast or use_2cta_instrs + ): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, + block_in_cluster_coord_sfb_vmnk, + mcast_mode=1, + ) + + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get real GEMM domain tensors + TMA desc ptrs via extension + real_a, desc_ptr_a = ext.get_gmem_tensor( + "a", + tma_tensor_a, + offs_for_ext, + work_tile_info, + ) + real_b, desc_ptr_b = ext.get_gmem_tensor( + "b", + tma_tensor_b, + offs_for_ext, + work_tile_info, + ) + real_sfa, desc_ptr_sfa = ext.get_gmem_tensor( + "sfa", + tma_tensor_sfa, + offs_for_ext, + work_tile_info, + ) + real_sfb, desc_ptr_sfb = ext.get_gmem_tensor( + "sfb", + tma_tensor_sfb, + offs_for_ext, + work_tile_info, + ) + + # local_tile for A, B + gA_mkl = cute.local_tile( + real_a, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + gB_nkl = cute.local_tile( + real_b, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + + # local_tile for SFA, SFB + gSFA_mkl = cute.local_tile( + real_sfa, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + gSFB_nkl = cute.local_tile( + real_sfb, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + + # MMA partition for TMA + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + tCgA = thr_mma.partition_A(gA_mkl) + tCgB = thr_mma.partition_B(gB_nkl) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + + # TMA partition A + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA partition B + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + # TMA partition SFA + sfa_cta_layout = a_cta_layout + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + # TMA partition SFB + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # Slice to current tile coords (L=0, expert already selected) + mma_tile_m = work_tile_info.tile_m_idx // cute.size( + tiled_mma.thr_id.shape + ) + tAgA_slice = tAgA[(None, mma_tile_m, None, 0)] + tBgB_slice = tBgB[(None, work_tile_info.tile_n_idx, None, 0)] + tAgSFA_slice = tAgSFA[(None, mma_tile_m, None, 0)] + + # SFB slice — N=64 + slice_n = work_tile_info.tile_n_idx + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = work_tile_info.tile_n_idx // 2 + tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)] + + # TMA load loop + ab_producer.reset() + peek_ab_empty_status = ab_producer.try_acquire() + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_producer.acquire_and_advance(peek_ab_empty_status) + peek_ab_empty_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_empty_status = ab_producer.try_acquire() + # TMA load A + cute.copy( + tma_atom_a, + tAgA_slice[(None, handle.count)], + tAsA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_a, + mcast_mask=a_full_mcast_mask, + ) + # TMA load B + cute.copy( + tma_atom_b, + tBgB_slice[(None, handle.count)], + tBsB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_b, + mcast_mask=b_full_mcast_mask, + ) + # TMA load SFA + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, handle.count)], + tAsSFA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_sfa, + mcast_mask=sfa_full_mcast_mask, + ) + # TMA load SFB + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, handle.count)], + tBsSFB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_sfb, + mcast_mask=sfb_full_mcast_mask, + ) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + ab_producer.tail() + + # ================================================================= + # MMA warp (warp 4) + # ================================================================= + + if warp_idx == self.mma_warp_id: + # MMA fragments (SMEM → TMEM partitions), only used in this warp + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # SFA TMEM tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # SFB TMEM tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # S2T copy partitions for SFA/SFB + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_pipeline_stages + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_producer_state.phase ^ 1 + else: + acc_stage_index = acc_producer_state.index + + if is_leader_cta: + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + # SFB TMEM pointer offset for N=64 + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + offset = cutlass.Int32((work_tile_info.tile_n_idx % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + self.num_accumulator_tmem_cols + + self.num_sfa_tmem_cols + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + # AB consumer mainloop + ab_consumer.reset() + peek_ab_full_status = cutlass.Boolean(1) + if k_tile_cnt > 0: + peek_ab_full_status = ab_consumer.try_wait() + acc_pipeline.producer_acquire(acc_producer_state) + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_consumer.wait_and_advance(peek_ab_full_status) + peek_ab_full_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() + + # S2T copy SFA/SFB from SMEM to TMEM + s2t_stage_coord = ( + None, + None, + None, + None, + handle.index, + ) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + + # Block-scaled GEMM with paired operands + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0) + tile_crd = (None, None, None, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + [tCrA[tile_crd], tCtSFA], + [tCrB[tile_crd], tCtSFB_mma], + tCtAcc, + ) + handle.release() + + if k_tile_cnt > 0: + acc_pipeline.producer_commit(acc_producer_state) + if k_tile_cnt > 0: + acc_producer_state.advance() + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + acc_pipeline.producer_tail(acc_producer_state) + + # ================================================================= + # SMEM tensor C (allocated after MMA section) + # ================================================================= + + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + + # ================================================================= + # Epilogue warps (warps 0-3) + # ================================================================= + + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_pipeline_stages + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilogue_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + + # Layout transformation for epilogue + tCtAcc_transformed = transform_partitioned_tensor_layout(tCtAcc_base) + + num_tiles_executed = cutlass.Int32(0) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get real C tensor + TMA desc ptr + real_c, desc_ptr_c = ext.get_gmem_tensor( + "c", + tma_tensor_c, + offs_for_ext, + work_tile_info, + ) + # local_tile + partition for C + gC_mnl = cute.local_tile( + real_c, + cute.slice_(self.mma_tiler, (None, None, 0)), + (None, None, None), + ) + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgC = thr_mma.partition_C(gC_mnl) + tCgC_transformed = transform_partitioned_tensor_layout(tCgC) + + mma_tile_coord_mnl = ( + work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + work_tile_info.tile_n_idx, + cutlass.Int32(0), + ) + + # Partition for TMEM → RMEM copy + tiled_copy_t2r, tTR_tAcc_base_epi, tTR_rAcc = ( + epilogue_tmem_copy_and_partition( + self, + tidx, + tCtAcc_transformed, + tCgC_transformed, + epi_tile, + use_2cta_instrs, + ) + ) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( + self, tiled_copy_t2r, tTR_rC, tidx, sC + ) + + # TMA partition for C store + tCgC_epi = cute.flat_divide(tCgC_transformed, epi_tile) + bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = True if acc_stage_index == 0 else False + else: + acc_stage_index = acc_consumer_state.index + + # Set TMEM buffer for current tile + tTR_tAcc = tTR_tAcc_base_epi[ + (None, None, None, None, None, acc_stage_index) + ] + + # Wait for accumulator buffer full + if k_tile_cnt > 0: + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Compute per-expert global_scale alpha for NVFP4 + if cutlass.const_expr(global_scale_a is not None): + expert_idx = work_tile_info.expert_idx + alpha = cute.arch.load( + global_scale_a.iterator + expert_idx, + cutlass.Float32, + ) * cute.arch.load( + global_scale_b.iterator + expert_idx, + cutlass.Float32, + ) + else: + alpha = None + + # Store accumulator to global memory in subtiles + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = num_tiles_executed * subtile_cnt + + for subtile_idx in cutlass.range(subtile_cnt): + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = ( + self.cta_tile_shape_mnk[1] // self.epi_tile_n + - 1 + - subtile_idx + ) + + # TMEM → RMEM + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Early release for overlapping_accum + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + cute.arch.fence_view_async_tmem_load() + if k_tile_cnt > 0: + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # Convert to output dtype, apply global_scale + acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc)) + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + else: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + if cutlass.const_expr(global_scale_a is not None): + acc_vec = acc_vec * alpha + acc_vec = acc_vec.to(self.c_dtype) + tRS_rC.store(acc_vec) + + # RMEM → SMEM + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)] + ) + cute.arch.fence_proxy("async.shared", space="cta") + epilog_sync_barrier.arrive_and_wait() + + # SMEM → GMEM (TMA store or TMA reduce) + if warp_idx == self.epilogue_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, real_subtile_idx)], + tma_desc_ptr=desc_ptr_c, + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + # Release accumulator buffer (non-overlapping path) + if cutlass.const_expr(not self.overlapping_accum): + if k_tile_cnt > 0: + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + num_tiles_executed += cutlass.Int32(1) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + # Wait for C store complete + c_pipeline.producer_tail() + + # Free TMEM + tmem.relinquish_alloc_permit() + epilog_sync_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) + + +# ============================================================================= +# Non-Kernel Part +# ============================================================================= + +from dataclasses import dataclass, field +import re + +import numpy as np +import torch +import cutlass.torch as cutlass_torch + +# ============================================================================= +# Utility functions +# ============================================================================= + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def round_up(a: int, b: int) -> int: + return ceil_div(a, b) * b + + +def torch_version_lt(major: int, minor: int) -> bool: + """Best-effort torch version check that tolerates local build suffixes.""" + match = re.match(r"^\s*(\d+)\.(\d+)", torch.__version__) + if match is None: + print( + "WARNING: failed to parse torch.__version__, " + "falling back to manual host reference." + ) + return True + version = (int(match.group(1)), int(match.group(2))) + return version < (major, minor) + + +def offs_to_group_sizes(offs: torch.Tensor) -> list[int]: + """Convert cumulative end offsets to per-group sizes.""" + offs_cpu = offs.cpu().tolist() + prev = 0 + sizes = [] + for end in offs_cpu: + sizes.append(end - prev) + prev = end + return sizes + + +def l2_flush(size_mb: int = 400) -> None: + """Best-effort L2 flush by touching a large temporary tensor.""" + num_bytes = size_mb * 1024 * 1024 + flush_buf = torch.randint(0, 256, (num_bytes,), dtype=torch.uint8, device="cuda") + del flush_buf + + +# ============================================================================= +# Format configuration +# +# Note: For all current formats, sf_vec_size == blocksize. +# The kernel can derive sf_vec_size from blocksize directly. +# ============================================================================= + +_FORMAT_CONFIG = { + "mxfp8": { + "data_dtype": torch.float8_e4m3fn, + "blocksize": 32, + "scale_dtype": torch.float8_e8m0fnu, + "has_global_scale": False, + }, + "mxfp4": { + "data_dtype": torch.float4_e2m1fn_x2, + "blocksize": 32, + "scale_dtype": torch.float8_e8m0fnu, + "has_global_scale": False, + }, + "nvfp4": { + "data_dtype": torch.float4_e2m1fn_x2, + "blocksize": 16, + "scale_dtype": torch.float8_e4m3fn, + "has_global_scale": True, + }, +} + +# FP4 nibble encoding: value → 4-bit nibble (float4 e2m1 format) +# 0 → 0x0 +# 0.5 → 0x1 1.0 → 0x2 1.5 → 0x3 +# 2.0 → 0x4 3.0 → 0x5 4.0 → 0x6 6.0 → 0x7 +# -0 → 0x8 -0.5 → 0x9 -1.0 → 0xA -1.5 → 0xB +# -2.0 → 0xC -3.0 → 0xD -4.0 → 0xE -6.0 → 0xF + +# Correctness-friendly: only {0, 1, -1} → nibbles {0x0, 0x2, 0xA} +_FP4_CORRECTNESS_NIBBLES = torch.tensor([0x0, 0x2, 0xA], dtype=torch.uint8) +# Perf: all 16 valid nibbles (index == nibble value) +_FP4_PERF_NIBBLES = torch.arange(16, dtype=torch.uint8) +_FP4_DECODE_TABLE = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, +) + + +# ============================================================================= +# Scale shape computation +# ============================================================================= + + +def compute_scale_shape( + scenario: str, + operand: str, + group_sizes: list[int], + hidden: int, + intermediate: int, + K_fixed: int, + blocksize: int, + expert_cnt: int, +) -> tuple[int, ...]: + """ + Compute the assembled (swizzled 32_4_4) scale tensor shape. + + Swizzle 32_4_4 pads each group's scale to rows=round_up(non_K, 128), + cols=round_up(ceil_div(K, blocksize), 4), then flattens per group. + + Scale layout per scenario/operand: + 2Dx3D A: groups along M (variable per expert), K fixed + -> (sum(round_up(M_g, 128)), round_up(ceil_div(K, bs), 4)) + 2Dx3D B: per-expert (K, N same for all) + -> (G, round_up(N, 128) * round_up(ceil_div(K, bs), 4)) + 2Dx2D A: M fixed, groups along K (variable per expert) + -> (round_up(M, 128), sum(round_up(ceil_div(K_g, bs), 4))) + 2Dx2D B: N fixed, groups along K (variable per expert) + -> (round_up(N, 128), sum(round_up(ceil_div(K_g, bs), 4))) + + Args: + scenario: "2Dx3D" or "2Dx2D" + operand: "a" or "b" + group_sizes: per-expert sizes of the grouped dimension + (M sizes for 2Dx3D, K sizes for 2Dx2D) + hidden: M dimension (hidden_size) + intermediate: N dimension (intermediate_size) + K_fixed: K dimension (used where K is fixed across experts) + blocksize: 32 for MXFP8/MXFP4, 16 for NVFP4 + expert_cnt: number of experts (G) + """ + if scenario == "2Dx3D": + # group_sizes = per-expert M sizes; K is fixed for all experts + if operand == "a": + total_rows = sum(round_up(mg, 128) for mg in group_sizes) + total_cols = round_up(ceil_div(K_fixed, blocksize), 4) + return (total_rows, total_cols) + else: + padded_N = round_up(intermediate, 128) + padded_K_scale = round_up(ceil_div(K_fixed, blocksize), 4) + return (expert_cnt, padded_N * padded_K_scale) + else: # 2Dx2D + # group_sizes = per-expert K sizes; M and N are fixed + if operand == "a": + padded_M = round_up(hidden, 128) + total_cols = sum(round_up(ceil_div(kg, blocksize), 4) for kg in group_sizes) + return (padded_M, total_cols) + else: + padded_N = round_up(intermediate, 128) + total_cols = sum(round_up(ceil_div(kg, blocksize), 4) for kg in group_sizes) + return (padded_N, total_cols) + + +def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor: + """Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor.""" + if scale_2d.dim() != 2: + raise ValueError(f"Expected 2D scale tensor, got {scale_2d.dim()}D.") + rows, cols = scale_2d.shape + if rows == 0 or cols == 0: + return scale_2d.new_empty((0,)) + + row_blocks = ceil_div(rows, 128) + col_blocks = ceil_div(cols, 4) + padded_rows = row_blocks * 128 + padded_cols = col_blocks * 4 + + padded = scale_2d + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), dtype=scale_2d.dtype, device=scale_2d.device + ) + padded[:rows, :cols] = scale_2d + + blocks = padded.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + return rearranged.flatten() + + +def pad_and_swizzle_single(raw_scale_2d: torch.Tensor) -> torch.Tensor: + if raw_scale_2d.dim() != 2: + raise ValueError(f"Expected 2D scale tensor, got {raw_scale_2d.dim()}D.") + return to_blocked(raw_scale_2d) + + +def create_raw_scale_tensor( + non_k_size: int, + k_size: int, + blocksize: int, + scale_dtype: torch.dtype, + device: str = "cuda", +) -> torch.Tensor: + """Create one raw, non-swizzled scale tensor with exact values in {1, 2}.""" + scale_cols = ceil_div(k_size, blocksize) + return ( + torch.randint( + 1, + 3, + (non_k_size, scale_cols), + dtype=torch.float32, + device=device, + ) + .to(scale_dtype) + .reshape(non_k_size, scale_cols) + ) + + +def cat_byte_reinterpretable_tensors( + tensors: list[torch.Tensor], dim: int = 0 +) -> torch.Tensor: + """Concatenate byte-backed float tensors via uint8 view when native cat is unsupported.""" + if not tensors: + raise ValueError("Expected at least one tensor to concatenate.") + first = tensors[0] + if first.is_floating_point() and first.element_size() == 1: + concatenated = torch.cat( + [tensor.view(torch.uint8) for tensor in tensors], dim=dim + ) + return concatenated.view(first.dtype) + return torch.cat(tensors, dim=dim) + + +def stack_byte_reinterpretable_tensors( + tensors: list[torch.Tensor], dim: int = 0 +) -> torch.Tensor: + """Stack byte-backed float tensors via uint8 view when native stack is unsupported.""" + if not tensors: + raise ValueError("Expected at least one tensor to stack.") + first = tensors[0] + if first.is_floating_point() and first.element_size() == 1: + stacked = torch.stack([tensor.view(torch.uint8) for tensor in tensors], dim=dim) + return stacked.view(first.dtype) + return torch.stack(tensors, dim=dim) + + +def assemble_raw_scales_2d2d( + raw_scales: list[torch.Tensor], non_k_size: int +) -> torch.Tensor: + flat_parts = [pad_and_swizzle_single(scale) for scale in raw_scales] + all_flat = cat_byte_reinterpretable_tensors(flat_parts, dim=0) + return all_flat.reshape(round_up(non_k_size, 128), -1) + + +def assemble_raw_scales_2d3d_3d_side(raw_scales: list[torch.Tensor]) -> torch.Tensor: + flat_parts = [pad_and_swizzle_single(scale) for scale in raw_scales] + return stack_byte_reinterpretable_tensors(flat_parts, dim=0) + + +def assemble_raw_scales_2d3d_2d_side(raw_scales: list[torch.Tensor]) -> torch.Tensor: + flat_parts = [pad_and_swizzle_single(scale) for scale in raw_scales] + all_flat = cat_byte_reinterpretable_tensors(flat_parts, dim=0) + total_rows = sum(round_up(scale.shape[0], 128) for scale in raw_scales) + return all_flat.reshape(total_rows, -1) + + +def fp4_packed_dim(tensor: torch.Tensor) -> int: + positive_strides = [ + (abs(stride), idx) for idx, stride in enumerate(tensor.stride()) if stride > 0 + ] + if not positive_strides: + return tensor.dim() - 1 + return min(positive_strides)[1] + + +def unpack_fp4_to_f32(packed: torch.Tensor) -> torch.Tensor: + """Unpack a float4_e2m1fn_x2 tensor into float32 along the packed dimension.""" + packed_dim = fp4_packed_dim(packed) + raw = packed.view(torch.uint8) + + if packed_dim != raw.dim() - 1: + perm = list(range(raw.dim())) + perm[packed_dim], perm[-1] = perm[-1], perm[packed_dim] + raw = raw.permute(perm).contiguous() + else: + perm = None + + lo = (raw & 0x0F).to(torch.int64) + hi = (raw >> 4).to(torch.int64) + lut = _FP4_DECODE_TABLE.to(raw.device) + + unpacked_shape = list(raw.shape) + unpacked_shape[-1] *= 2 + unpacked = torch.empty(unpacked_shape, dtype=torch.float32, device=raw.device) + unpacked[..., ::2] = lut[lo] + unpacked[..., 1::2] = lut[hi] + + if perm is not None: + unpacked = unpacked.permute(perm) + return unpacked + + +def slice_tensor_logical_dim( + tensor: torch.Tensor, dim: int, start: int, end: int +) -> torch.Tensor: + """Slice along the logical dimension, compensating for FP4 packing when needed.""" + if tensor.dtype == torch.float4_e2m1fn_x2 and dim == fp4_packed_dim(tensor): + if start % 2 != 0 or end % 2 != 0: + raise ValueError( + f"FP4 packed slicing requires even indices, got start={start}, end={end}." + ) + start = start // 2 + end = end // 2 + return tensor.narrow(dim, start, end - start) + + +def dequant_block_scale_to_fp32( + data: torch.Tensor, + raw_scale: torch.Tensor, + blocksize: int, + global_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Dequantize a single 2D tensor using raw block scales into fp32.""" + if data.dtype == torch.float4_e2m1fn_x2: + data_fp32 = unpack_fp4_to_f32(data) + else: + data_fp32 = data.to(torch.float32) + + if data_fp32.dim() != 2 or raw_scale.dim() != 2: + raise ValueError( + f"Expected 2D tensors, got data={data_fp32.dim()}D raw_scale={raw_scale.dim()}D." + ) + + expected_scale_shape = (data_fp32.shape[0], ceil_div(data_fp32.shape[1], blocksize)) + if tuple(raw_scale.shape) != expected_scale_shape: + raise ValueError( + f"Scale shape mismatch: expected {expected_scale_shape}, got {tuple(raw_scale.shape)}." + ) + + scale_fp32 = raw_scale.to(torch.float32) + expanded_scale = scale_fp32.repeat_interleave(blocksize, dim=-1)[ + :, : data_fp32.shape[1] + ] + result = data_fp32 * expanded_scale + + if global_scale is not None: + result = result * global_scale.to(torch.float32).reshape(1, 1) + return result + + +def transpose_rhs_for_block_dequant(data: torch.Tensor) -> torch.Tensor: + """Convert a (K, N) RHS slice into an (N, K) tensor for block dequant along K.""" + if data.dim() != 2: + raise ValueError(f"Expected 2D RHS tensor, got {data.dim()}D.") + if data.dtype == torch.float4_e2m1fn_x2: + # Avoid contiguous()/copy_ on FP4 tensors; unpack first, then transpose in fp32. + return unpack_fp4_to_f32(data).transpose(0, 1) + return data.transpose(0, 1) + + +# ============================================================================= +# Host Validation +# ============================================================================= + + +@dataclass +class ProblemDesc: + tokens: int + experts: int + top_k_select: int + balance_route: bool + hidden: int + intermediate: int + scenario: Literal["2Dx3D", "2Dx2D"] + kind: Literal["mxfp8", "mxfp4", "nvfp4"] + out_dtype: torch.dtype = torch.bfloat16 + acc_dtype: torch.dtype = torch.float32 + grad_accumulate: bool = False + # If True, the user guarantees activation tensors (with tokens_sum dim) + # are padded per-group to the same granularity as the block-scale layout: + # 2Dx3D (groups along M): each group's M_g padded to 128 + # 2Dx2D (groups along K): each group's K_g padded to sf_vec_size * 4 + # This enables the kernel to skip padded-offset computation. + # Currently NOT implemented — forced to False at CLI level. + consistent_token_padding: bool = False + # GEMM-domain layout control (which axis is stride-1) + # Only effective for FP8. FP4 always uses the torch-expected layout + # (K stride-1 for both A and B). + # A (M, K): "k_major" → K stride-1 (default) | "m_major" → M stride-1 + # B (N, K): "k_major" → K stride-1 (default) | "n_major" → N stride-1 + # C (M, N): "n_major" → N stride-1 (default) | "m_major" → M stride-1 + # Note: default b_layout is "k_major" (unlike torch_grouped_mm.py's "n_major") + # because torch.nn.functional.scaled_grouped_mm expects K stride-1 for B. + a_layout: Literal["k_major", "m_major"] = "k_major" + b_layout: Literal["k_major", "n_major"] = "k_major" + c_layout: Literal["n_major", "m_major"] = "n_major" + + def __str__(self) -> str: + d = lambda t: str(t).split(".")[-1] + route = "balanced" if self.balance_route else "random" + return ( + f"ProblemDesc: {self.scenario} | kind={self.kind} | " + f"tokens={self.tokens} experts={self.experts} " + f"top_k={self.top_k_select} route={route} | " + f"hidden={self.hidden} intermediate={self.intermediate} | " + f"out={d(self.out_dtype)} acc={d(self.acc_dtype)} " + f"grad_acc={self.grad_accumulate} " + f"consistent_pad={self.consistent_token_padding} | " + f"layout: A={self.a_layout} B={self.b_layout} C={self.c_layout}" + ) + + +@dataclass +class ImplDesc: + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64) + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1) + use_2cta_instrs: bool = False + static_expert_cnt: Optional[int] = None + separate_tensormap_init: bool = True + + def __str__(self) -> str: + tile = ",".join(map(str, self.mma_tiler_mnk)) + cluster = ",".join(map(str, self.cluster_shape_mnk)) + static_e = ( + self.static_expert_cnt if self.static_expert_cnt is not None else "dynamic" + ) + return ( + f"ImplDesc: tile={tile} cluster={cluster} " + f"2cta={self.use_2cta_instrs} | " + f"static_E={static_e} sep_tmap={self.separate_tensormap_init}" + ) + + +@dataclass +class MiscDesc: + perf_run: bool = False + perf_e2e: bool = False + compare_with_sol: bool = False + no_torch_210: bool = field(init=False) + + def __post_init__(self): + self.no_torch_210 = torch_version_lt(2, 10) + if self.perf_e2e and not self.perf_run: + raise ValueError("--perf_e2e requires --perf_run to be enabled.") + if self.perf_e2e and self.compare_with_sol: + raise ValueError( + "--perf_e2e and --compare_with_sol are mutually exclusive." + ) + + def __str__(self) -> str: + return ( + f"MiscDesc: perf={self.perf_run} perf_e2e={self.perf_e2e} " + f"sol={self.compare_with_sol} no_torch_210={self.no_torch_210}" + ) + + +class ScaledGroupedGemmTester: + def __init__(self, problem: ProblemDesc, impl: ImplDesc, misc: MiscDesc): + self.problem = problem + self.impl = impl + self.misc = misc + + self.cfg = _FORMAT_CONFIG[problem.kind] + self.tokens_after_repeat = problem.tokens * problem.top_k_select + self.expert_cnt = problem.experts + self.hidden = problem.hidden + self.intermediate = problem.intermediate + + self.A_tensor: Optional[torch.Tensor] = None + self.B_tensor: Optional[torch.Tensor] = None + self.C_tensor: Optional[torch.Tensor] = None + self.C_ref_tensor: Optional[torch.Tensor] = None + self.scale_a_tensor: Optional[torch.Tensor] = None + self.scale_b_tensor: Optional[torch.Tensor] = None + self.raw_scale_a_tensors: Optional[list[torch.Tensor]] = None + self.raw_scale_b_tensors: Optional[list[torch.Tensor]] = None + self.global_scale_a: Optional[torch.Tensor] = None + self.global_scale_b: Optional[torch.Tensor] = None + self.offs_tensor: Optional[torch.Tensor] = None + self.workspace_tensor: Optional[torch.Tensor] = None + + if problem.grad_accumulate and problem.scenario == "2Dx3D": + raise ValueError( + "grad_accumulate only makes sense for 2Dx2D (weight grad) scenario." + ) + + # ----------------------------------------------------------------- + # Offs generation (aligned to blocksize) + # ----------------------------------------------------------------- + + def _generate_offs(self) -> torch.Tensor: + """Generate group-end offsets aligned to blocksize. + + Some experts may receive 0 tokens (valid in real MoE routing). + Each non-empty group's size is a multiple of blocksize. + """ + blocksize = self.cfg["blocksize"] + total = self.tokens_after_repeat + expert_cnt = self.expert_cnt + + assert total % blocksize == 0, ( + f"tokens_after_repeat ({total}) must be divisible by " + f"blocksize ({blocksize})" + ) + n_slots = total // blocksize + + if self.problem.balance_route: + # Distribute as evenly as possible; some experts get 0 if n_slots < expert_cnt + base = n_slots // expert_cnt + remainder = n_slots % expert_cnt + slots = [base + (1 if i < remainder else 0) for i in range(expert_cnt)] + else: + # Dirichlet distribution: naturally allows 0-size groups + # alpha=1.0 → uniform on simplex (moderate variation) + # alpha<1.0 → skewed (few experts get most tokens) + # alpha>1.0 → more uniform + proportions = np.random.dirichlet([0.5] * expert_cnt) + raw = np.floor(proportions * n_slots).astype(int) + deficit = n_slots - raw.sum() + while deficit > 0: + idx = int(np.argmin(raw / (proportions * n_slots + 1e-12))) + raw[idx] += 1 + deficit -= 1 + while deficit < 0: + ratios = np.where( + raw > 0, + raw / (proportions * n_slots + 1e-12), + -np.inf, + ) + idx = int(np.argmax(ratios)) + raw[idx] -= 1 + deficit += 1 + slots = raw.tolist() + + assert sum(slots) == n_slots + + cum = 0 + offsets = [] + for s in slots: + cum += s * blocksize + offsets.append(cum) + return torch.tensor(offsets, dtype=torch.int32, device="cuda") + + # ----------------------------------------------------------------- + # Tensor creation helpers + # ----------------------------------------------------------------- + + def _create_fp8_tensor(self, shape: tuple) -> torch.Tensor: + """Create FP8 tensor. + + - correctness mode: randint {-1, 0, 1} via bf16 cast + - perf mode: random valid fp8 bit patterns via uint8 + (float8_e4m3fn NaN encodings 0x7F/0xFF are replaced) + """ + data_dtype = self.cfg["data_dtype"] + elem_cnt = 1 + for s in shape: + elem_cnt *= s + if self.misc.perf_run: + raw = torch.randint(0, 256, (elem_cnt,), dtype=torch.uint8, device="cuda") + # float8_e4m3fn: 0x7F and 0xFF are NaN → clamp to valid max + if data_dtype == torch.float8_e4m3fn: + raw[raw == 0x7F] = 0x7E + raw[raw == 0xFF] = 0xFE + return raw.view(data_dtype).reshape(shape) + else: + return ( + torch.randint(-1, 2, (elem_cnt,), dtype=torch.bfloat16, device="cuda") + .to(data_dtype) + .reshape(shape) + ) + + def _create_fp4_tensor( + self, logical_shape: tuple, packed_dim: int = -1 + ) -> torch.Tensor: + """Create FP4 tensor. + + Args: + logical_shape: shape in FP4 elements (packed_dim size must be even). + packed_dim: dimension to pack (halve). This dim becomes stride-1. + + - perf mode: random uint8 bytes (all 256 values are valid FP4 pairs, + FP4 e2m1 has no NaN/inf). No nibble mapping needed. + - correctness mode: index→nibble mapping for values {0, 1, -1}, + then explicit nibble packing. + + Returns: + float4_e2m1fn_x2 tensor with packed_dim halved and stride-1. + """ + ndim = len(logical_shape) + packed_dim = packed_dim % ndim + assert logical_shape[packed_dim] % 2 == 0, ( + f"packed_dim {packed_dim} size ({logical_shape[packed_dim]}) must be even" + ) + + if self.misc.perf_run: + # All 256 byte values are valid FP4 pairs — just random bytes + elem_cnt = 1 + for s in logical_shape: + elem_cnt *= s + byte_cnt = elem_cnt // 2 + + flat = torch.randint(0, 256, (byte_cnt,), dtype=torch.uint8, device="cuda") + + # Build shape with packed dim moved to last and halved + shape_reordered = list(logical_shape) + need_perm = packed_dim != ndim - 1 + if need_perm: + shape_reordered[packed_dim], shape_reordered[-1] = ( + shape_reordered[-1], + shape_reordered[packed_dim], + ) + shape_reordered[-1] //= 2 + + tensor = flat.view(torch.float4_e2m1fn_x2).reshape(shape_reordered) + + if need_perm: + perm = list(range(ndim)) + perm[packed_dim], perm[-1] = perm[-1], perm[packed_dim] + tensor = tensor.permute(perm) + return tensor + + # ── Correctness mode: index→nibble mapping + explicit pack ── + # Use uint8 + masked_fill_ instead of int64 fancy indexing to avoid + # 16x memory overhead (int64 = 8 bytes vs FP4 = 0.5 bytes per element). + + nibbles = torch.randint(0, 3, logical_shape, dtype=torch.uint8, device="cuda") + nibbles.masked_fill_(nibbles == 2, 0xA) + nibbles.masked_fill_(nibbles == 1, 0x2) + + # Move packed_dim to last for packing + need_perm = packed_dim != ndim - 1 + if need_perm: + perm_to_last = list(range(ndim)) + perm_to_last[packed_dim], perm_to_last[-1] = ( + perm_to_last[-1], + perm_to_last[packed_dim], + ) + nibbles = nibbles.permute(perm_to_last).contiguous() + + # Pack pairs along last dim: byte = (odd_nibble << 4) | even_nibble + even = nibbles[..., ::2] + odd = nibbles[..., 1::2] + packed_uint8 = (odd << 4) | even + + tensor = packed_uint8.view(torch.float4_e2m1fn_x2) + + if need_perm: + inv_perm = list(range(ndim)) + inv_perm[packed_dim], inv_perm[-1] = inv_perm[-1], inv_perm[packed_dim] + tensor = tensor.permute(inv_perm) + + return tensor + + def _create_scale_tensor(self, shape: tuple) -> torch.Tensor: + """Scale tensor: random values {1, 2} (exact in all scale dtypes).""" + elem_cnt = 1 + for s in shape: + elem_cnt *= s + return ( + torch.randint(1, 3, (elem_cnt,), dtype=torch.float32, device="cuda") + .to(self.cfg["scale_dtype"]) + .reshape(shape) + ) + + def _generate_raw_scales( + self, group_sizes: list[int] + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + blocksize = self.cfg["blocksize"] + scale_dtype = self.cfg["scale_dtype"] + device = self.A_tensor.device.type if self.A_tensor is not None else "cuda" + + if self.problem.scenario == "2Dx3D": + raw_scale_a = [ + create_raw_scale_tensor( + non_k_size=group_size, + k_size=self.hidden, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for group_size in group_sizes + ] + raw_scale_b = [ + create_raw_scale_tensor( + non_k_size=self.intermediate, + k_size=self.hidden, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for _ in range(self.expert_cnt) + ] + else: + raw_scale_a = [ + create_raw_scale_tensor( + non_k_size=self.hidden, + k_size=group_size, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for group_size in group_sizes + ] + raw_scale_b = [ + create_raw_scale_tensor( + non_k_size=self.intermediate, + k_size=group_size, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for group_size in group_sizes + ] + + return raw_scale_a, raw_scale_b + + def _assemble_scales_from_raw( + self, raw_scale_a: list[torch.Tensor], raw_scale_b: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.problem.scenario == "2Dx3D": + scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a) + scale_b = assemble_raw_scales_2d3d_3d_side(raw_scale_b) + else: + scale_a = assemble_raw_scales_2d2d(raw_scale_a, self.hidden) + scale_b = assemble_raw_scales_2d2d(raw_scale_b, self.intermediate) + return scale_a, scale_b + + # ----------------------------------------------------------------- + # generate_inputs + # ----------------------------------------------------------------- + + def generate_inputs(self) -> None: + self.offs_tensor = self._generate_offs() + group_sizes = offs_to_group_sizes(self.offs_tensor) + + tokens = self.tokens_after_repeat + hidden = self.hidden + intermediate = self.intermediate + expert_cnt = self.expert_cnt + blocksize = self.cfg["blocksize"] + is_fp4 = self.cfg["data_dtype"] == torch.float4_e2m1fn_x2 + + if is_fp4: + if self.problem.a_layout != "k_major": + print("WARNING: FP4 ignores a_layout, always uses k_major (K stride-1)") + if self.problem.b_layout != "k_major": + print("WARNING: FP4 ignores b_layout, always uses k_major (K stride-1)") + + if self.problem.scenario == "2Dx3D": + # ── Data tensors ── + # PyTorch: A (tokens, hidden), B (expert_cnt, hidden, intermediate) + # GEMM: A (M=tokens, K=hidden), B (N=intermediate, K=hidden, L=expert_cnt) + + # A: (tokens, hidden) — K=hidden is last dim + if is_fp4: + self.A_tensor = self._create_fp4_tensor((tokens, hidden), packed_dim=-1) + elif self.problem.a_layout == "k_major": + self.A_tensor = self._create_fp8_tensor((tokens, hidden)) + else: # m_major + self.A_tensor = self._create_fp8_tensor((hidden, tokens)).T + + # B: (expert_cnt, hidden, intermediate) — K=hidden is dim 1 + if is_fp4: + self.B_tensor = self._create_fp4_tensor( + (expert_cnt, hidden, intermediate), packed_dim=1 + ) + elif self.problem.b_layout == "k_major": + self.B_tensor = self._create_fp8_tensor( + (expert_cnt, intermediate, hidden) + ).transpose(1, 2) + else: # n_major + self.B_tensor = self._create_fp8_tensor( + (expert_cnt, hidden, intermediate) + ) + + # C: (tokens, intermediate) + # GEMM C (M=tokens, N=intermediate): n_major → N stride-1; m_major → M stride-1 + if self.problem.c_layout == "n_major": + self.C_tensor = torch.full( + (tokens, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: # m_major + self.C_tensor = torch.full( + (intermediate, tokens), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).T + + # ── Scale tensors ── + K_fixed = hidden + sfa_shape = compute_scale_shape( + "2Dx3D", + "a", + group_sizes, + hidden, + intermediate, + K_fixed, + blocksize, + expert_cnt, + ) + sfb_shape = compute_scale_shape( + "2Dx3D", + "b", + group_sizes, + hidden, + intermediate, + K_fixed, + blocksize, + expert_cnt, + ) + + elif self.problem.scenario == "2Dx2D": + # ── Data tensors ── + # PyTorch: A (hidden, tokens), B (tokens, intermediate) + # GEMM: A (M=hidden, K=tokens), B (N=intermediate, K=tokens, L=expert_cnt) + + # A: (hidden, tokens) — K=tokens is last dim + if is_fp4: + self.A_tensor = self._create_fp4_tensor((hidden, tokens), packed_dim=-1) + elif self.problem.a_layout == "k_major": + self.A_tensor = self._create_fp8_tensor((hidden, tokens)) + else: # m_major + self.A_tensor = self._create_fp8_tensor((tokens, hidden)).T + + # B: (tokens, intermediate) — K=tokens is dim 0 + if is_fp4: + self.B_tensor = self._create_fp4_tensor( + (tokens, intermediate), packed_dim=0 + ) + elif self.problem.b_layout == "k_major": + self.B_tensor = self._create_fp8_tensor((intermediate, tokens)).T + else: # n_major + self.B_tensor = self._create_fp8_tensor((tokens, intermediate)) + + # C: (expert_cnt, hidden, intermediate) + # GEMM C (M=hidden, N=intermediate): n_major → N stride-1; m_major → M stride-1 + if self.problem.c_layout == "n_major": + if self.problem.grad_accumulate: + self.C_tensor = torch.zeros( + (expert_cnt, hidden, intermediate), + dtype=self.problem.out_dtype, + device="cuda", + ) + else: + self.C_tensor = torch.full( + (expert_cnt, hidden, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: # m_major + if self.problem.grad_accumulate: + self.C_tensor = torch.zeros( + (expert_cnt, intermediate, hidden), + dtype=self.problem.out_dtype, + device="cuda", + ).transpose(1, 2) + else: + self.C_tensor = torch.full( + (expert_cnt, intermediate, hidden), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).transpose(1, 2) + + # ── Scale tensors ── + K_total = tokens + sfa_shape = compute_scale_shape( + "2Dx2D", + "a", + group_sizes, + hidden, + intermediate, + K_total, + blocksize, + expert_cnt, + ) + sfb_shape = compute_scale_shape( + "2Dx2D", + "b", + group_sizes, + hidden, + intermediate, + K_total, + blocksize, + expert_cnt, + ) + else: + raise ValueError(f"Unknown scenario: {self.problem.scenario}") + + self.raw_scale_a_tensors, self.raw_scale_b_tensors = self._generate_raw_scales( + group_sizes + ) + self.scale_a_tensor, self.scale_b_tensor = self._assemble_scales_from_raw( + self.raw_scale_a_tensors, self.raw_scale_b_tensors + ) + assert tuple(self.scale_a_tensor.shape) == tuple(sfa_shape), ( + f"scale_a shape mismatch: expected {sfa_shape}, " + f"got {tuple(self.scale_a_tensor.shape)}" + ) + assert tuple(self.scale_b_tensor.shape) == tuple(sfb_shape), ( + f"scale_b shape mismatch: expected {sfb_shape}, " + f"got {tuple(self.scale_b_tensor.shape)}" + ) + + # NVFP4: per-expert global scales + if self.cfg["has_global_scale"]: + self.global_scale_a = torch.randint( + 1, 3, (expert_cnt,), dtype=torch.float32, device="cuda" + ) + self.global_scale_b = torch.randint( + 1, 3, (expert_cnt,), dtype=torch.float32, device="cuda" + ) + + # ----------------------------------------------------------------- + # Reference preparation + # ----------------------------------------------------------------- + + @staticmethod + def _prepare_ref_ab( + tensor: torch.Tensor, + k_dim: int, + pad_k_size: Optional[int] = None, + pad_non_k_size: Optional[int] = None, + ) -> torch.Tensor: + """Prepare a ref tensor: make ``k_dim`` stride-1 and optionally pad. + + Args: + tensor: input data tensor (A or B). + k_dim: which dimension is K (must become stride-1). + pad_k_size: zero-pad K to this size (workaround: PyTorch 3D + scale validation uses floor division for K // blocksize). + pad_non_k_size: zero-pad the trailing dim (N) to this size + (workaround: PyTorch requires trailing dim % 16 == 0). + Only effective when ``k_dim`` is not the trailing dim. + + All padding happens in the permuted-contiguous space (standard layout) + so it is safe for packed sub-byte types like float4_e2m1fn_x2. + After permute(k_dim↔last), K is last and N is second-to-last: + F.pad(t, (0, k_pad)) -> pads K (last dim) + F.pad(t, (0, 0, 0, n_pad)) -> pads N (second-to-last dim) + The final permute restores the original dim order with K stride-1. + """ + ndim = tensor.dim() + k_dim = k_dim % ndim + needs_k_pad = pad_k_size is not None and pad_k_size > tensor.shape[k_dim] + needs_n_pad = ( + pad_non_k_size is not None + and k_dim != ndim - 1 + and pad_non_k_size > tensor.shape[-1] + ) + if tensor.stride(k_dim) == 1 and not needs_k_pad and not needs_n_pad: + return tensor + print( + f"WARNING: _prepare_ref_ab is copying/padding k_dim={k_dim} " + f"(stride={tensor.stride(k_dim)}, " + f"pad_k={'yes' if needs_k_pad else 'no'}, " + f"pad_n={'yes' if needs_n_pad else 'no'}); " + f"perf comparison with the kernel is not apples-to-apples." + ) + perm = list(range(ndim)) + perm[k_dim], perm[-1] = perm[-1], perm[k_dim] + orig_dtype = tensor.dtype + t = tensor.permute(perm).contiguous() + if needs_k_pad or needs_n_pad: + t = t.view(torch.uint8) + if needs_k_pad: + t = torch.nn.functional.pad(t, (0, pad_k_size - t.shape[-1])) + if needs_n_pad: + t = torch.nn.functional.pad(t, (0, 0, 0, pad_non_k_size - t.shape[-2])) + t = t.view(orig_dtype) + res = t.permute(perm) + return res + + def _prepare_ref_tensors( + self, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare A and B for torch.nn.functional.scaled_grouped_mm. + + The torch API requires K to be stride-1 for both A and B. + For FP8 with non-standard layout, we permute+contiguous. + For FP4, tensors are already created with K stride-1. + + WORKAROUND (two PyTorch bugs in scaled_grouped_mm): + 1. 3D scale validation uses K // blocksize (floor) instead of ceil_div, + producing zero-sized expectations when K < blocksize. + Fix: zero-pad data along K to the next blocksize multiple. + Safe because K is the reduction dimension (zero * scale = zero). + 2. Requires mat_a.size(-1) % 16 == 0 and mat_b.size(-1) % 16 == 0 + regardless of which dimension is stride-1. + Fix: zero-pad B's trailing dim (N=intermediate) to next 16-multiple. + Safe because padded N columns produce zero output columns; the + reference output is sliced back in compute_reference. + """ + blocksize = self.cfg["blocksize"] + # For the torch's incomplete and unreasonable check. + N_padded = round_up(self.problem.intermediate, 16) + + if self.problem.scenario == "2Dx3D": + K_padded = round_up(self.problem.hidden, blocksize) + if self.problem.kind in ["nvfp4", "mxfp4"]: + K_padded = K_padded // 2 + # A: (tokens, hidden) — K=hidden is dim -1 + ref_a = self._prepare_ref_ab(self.A_tensor, k_dim=-1, pad_k_size=K_padded) + # B: (expert_cnt, hidden, intermediate) — K=hidden dim 1, N=intermediate dim -1 + ref_b = self._prepare_ref_ab( + self.B_tensor, k_dim=1, pad_k_size=K_padded, pad_non_k_size=N_padded + ) + else: + # A: (hidden, tokens) — K=tokens is dim -1 + # 2Dx2D: K=total_tokens, already blocksize-aligned by _generate_offs + ref_a = self._prepare_ref_ab(self.A_tensor, k_dim=-1) + # B: (tokens, intermediate) — K=tokens dim 0, N=intermediate dim -1 + ref_b = self._prepare_ref_ab( + self.B_tensor, k_dim=0, pad_non_k_size=N_padded + ) + return ref_a, ref_b + + def _compute_reference_manual_2d2d(self) -> torch.Tensor: + group_sizes = offs_to_group_sizes(self.offs_tensor) + results = [] + prev = 0 + blocksize = self.cfg["blocksize"] + + for expert_idx, group_size in enumerate(group_sizes): + cur = prev + group_size + a_slice = slice_tensor_logical_dim( + self.A_tensor, dim=1, start=prev, end=cur + ) + b_slice = slice_tensor_logical_dim( + self.B_tensor, dim=0, start=prev, end=cur + ) + + global_scale_a = ( + self.global_scale_a[expert_idx : expert_idx + 1] + if self.global_scale_a is not None + else None + ) + global_scale_b = ( + self.global_scale_b[expert_idx : expert_idx + 1] + if self.global_scale_b is not None + else None + ) + + a_fp32 = dequant_block_scale_to_fp32( + a_slice, + self.raw_scale_a_tensors[expert_idx], + blocksize, + global_scale_a, + ) + b_fp32_t = dequant_block_scale_to_fp32( + transpose_rhs_for_block_dequant(b_slice), + self.raw_scale_b_tensors[expert_idx], + blocksize, + global_scale_b, + ) + b_fp32 = b_fp32_t.transpose(0, 1) + results.append((a_fp32 @ b_fp32).to(self.problem.out_dtype)) + prev = cur + + return torch.stack(results, dim=0) + + def _compute_reference_manual_2d3d(self) -> torch.Tensor: + group_sizes = offs_to_group_sizes(self.offs_tensor) + results = [] + prev = 0 + blocksize = self.cfg["blocksize"] + + for expert_idx, group_size in enumerate(group_sizes): + cur = prev + group_size + a_slice = slice_tensor_logical_dim( + self.A_tensor, dim=0, start=prev, end=cur + ) + b_slice = self.B_tensor[expert_idx] + + global_scale_a = ( + self.global_scale_a[expert_idx : expert_idx + 1] + if self.global_scale_a is not None + else None + ) + global_scale_b = ( + self.global_scale_b[expert_idx : expert_idx + 1] + if self.global_scale_b is not None + else None + ) + + a_fp32 = dequant_block_scale_to_fp32( + a_slice, + self.raw_scale_a_tensors[expert_idx], + blocksize, + global_scale_a, + ) + b_fp32_t = dequant_block_scale_to_fp32( + transpose_rhs_for_block_dequant(b_slice), + self.raw_scale_b_tensors[expert_idx], + blocksize, + global_scale_b, + ) + b_fp32 = b_fp32_t.transpose(0, 1) + results.append((a_fp32 @ b_fp32).to(self.problem.out_dtype)) + prev = cur + + return torch.cat(results, dim=0) + + def _compute_reference_manual(self) -> None: + if self.raw_scale_a_tensors is None or self.raw_scale_b_tensors is None: + raise RuntimeError("Raw scale tensors must be generated before manual ref.") + + if self.problem.scenario == "2Dx2D": + self.C_ref_tensor = self._compute_reference_manual_2d2d() + else: + self.C_ref_tensor = self._compute_reference_manual_2d3d() + + def _compute_reference_torch(self) -> None: + from torch.nn.functional import scaled_grouped_mm, ScalingType, SwizzleType + + ref_a, ref_b = self._prepare_ref_tensors() + + if self.problem.kind in ("mxfp8", "mxfp4"): + scale_a_arg = self.scale_a_tensor + scale_b_arg = self.scale_b_tensor + recipe_a = ScalingType.BlockWise1x32 + recipe_b = ScalingType.BlockWise1x32 + else: # nvfp4 + scale_a_arg = [self.scale_a_tensor, self.global_scale_a] + scale_b_arg = [self.scale_b_tensor, self.global_scale_b] + recipe_a = [ScalingType.BlockWise1x16, ScalingType.TensorWise] + recipe_b = [ScalingType.BlockWise1x16, ScalingType.TensorWise] + + swizzle = SwizzleType.SWIZZLE_32_4_4 + ref_result = scaled_grouped_mm( + ref_a, + ref_b, + scale_a=scale_a_arg, + scale_recipe_a=recipe_a, + scale_b=scale_b_arg, + scale_recipe_b=recipe_b, + swizzle_a=swizzle, + swizzle_b=swizzle, + offs=self.offs_tensor, + output_dtype=self.problem.out_dtype, + ) + + self.C_ref_tensor = ref_result[..., : self.problem.intermediate] + + # ----------------------------------------------------------------- + # compute_reference + # ----------------------------------------------------------------- + + def compute_reference(self) -> None: + if self.misc.perf_run: + return + if self.misc.no_torch_210: + self._compute_reference_manual() + else: + self._compute_reference_torch() + + # ----------------------------------------------------------------- + # Kernel execution (stub — to be filled when kernel is implemented) + # ----------------------------------------------------------------- + + def create_kernel(self) -> ScaledGroupedGemmKernel: + _torch_to_cutlass = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + } + return ScaledGroupedGemmKernel( + scenario=self.problem.scenario, + sf_vec_size=self.cfg["blocksize"], + accumulate_on_output=( + self.problem.grad_accumulate and self.problem.scenario == "2Dx2D" + ), + separate_tensormap_init=self.impl.separate_tensormap_init, + consistent_token_padding=self.problem.consistent_token_padding, + acc_dtype=_torch_to_cutlass[self.problem.acc_dtype], + mma_tiler_mnk=self.impl.mma_tiler_mnk, + cluster_shape_mnk=self.impl.cluster_shape_mnk, + use_2cta_instrs=self.impl.use_2cta_instrs, + fixed_expert_cnt=self.impl.static_expert_cnt, + ) + + def run_kernel(self, kernel: ScaledGroupedGemmKernel) -> Optional[float]: + """Run our CuTe kernel. + + Returns: + Average kernel time in ms when perf_e2e is enabled, None otherwise. + """ + _torch_to_cutlass = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, + torch.float4_e2m1fn_x2: cutlass.Float4E2M1FN, + } + if hasattr(torch, "float8_e8m0fnu"): + _torch_to_cutlass[torch.float8_e8m0fnu] = cutlass.Float8E8M0FNU + + # Allocate workspace + workspace_size = kernel.get_workspace_size(self.expert_cnt) + self.workspace_tensor = torch.full( + (workspace_size,), 255, dtype=torch.uint8, device="cuda" + ) + torch.cuda.synchronize() + + # Convert torch tensors → cute tensors + data_dtype = _torch_to_cutlass[self.cfg["data_dtype"]] + sf_cutlass_dtype = _torch_to_cutlass[self.cfg["scale_dtype"]] + out_cutlass_dtype = _torch_to_cutlass[self.problem.out_dtype] + + is_dynamic_expert_cnt = self.impl.static_expert_cnt is None + + def torch_tensor_to_cute_tensor_with_dyn_layout( + torch_tensor: torch.Tensor, + ) -> cute.Tensor: + cute_tensor = cutlass_torch.from_dlpack(torch_tensor) + leading_dim = cutlass_torch.get_leading_dim(torch_tensor) + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + return cute_tensor + + a_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.A_tensor) + b_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.B_tensor) + scale_a_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.scale_a_tensor) + scale_b_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.scale_b_tensor) + c_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.C_tensor) + offs_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.offs_tensor) + workspace_cute = torch_tensor_to_cute_tensor_with_dyn_layout( + self.workspace_tensor + ) + + # Query max active clusters from hardware + cluster_size = self.impl.cluster_shape_mnk[0] * self.impl.cluster_shape_mnk[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + # Prepare optional NVFP4 global scales + global_scale_a_cute = None + global_scale_b_cute = None + if self.global_scale_a is not None: + global_scale_a_cute = torch_tensor_to_cute_tensor_with_dyn_layout( + self.global_scale_a + ) + global_scale_b_cute = torch_tensor_to_cute_tensor_with_dyn_layout( + self.global_scale_b + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + if self.misc.perf_e2e: + compiled = cute.compile( + kernel, + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + max_active_clusters, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + + warmup_iters = 4 + timed_iters = 4 + + for _ in range(warmup_iters): + compiled( + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + torch.cuda.synchronize() + + times = [] + for _ in range(timed_iters): + l2_flush() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + compiled( + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + end_evt.record() + torch.cuda.synchronize() + times.append(start_evt.elapsed_time(end_evt)) + + avg_ms = sum(times) / len(times) + print(f"[perf_e2e] Individual times (ms): {[f'{t:.4f}' for t in times]}") + print(f"[perf_e2e] Average kernel time: {avg_ms:.4f} ms") + return avg_ms + else: + l2_flush() + kernel( + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + max_active_clusters, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + torch.cuda.synchronize() + return None + + # ----------------------------------------------------------------- + # Validation + # ----------------------------------------------------------------- + + def validate(self) -> None: + if self.misc.perf_run: + return + using_torch_ref = not self.misc.no_torch_210 + if using_torch_ref and self.problem.scenario == "2Dx2D": + # Pytorch bug: zero token does not write out due to the incorrect arg setting. + self.C_ref_tensor = self.C_ref_tensor.contiguous() + group_sizes = offs_to_group_sizes(self.offs_tensor) + for i, g in enumerate(group_sizes): + if g == 0: + self.C_ref_tensor[i].zero_() + if using_torch_ref and ( + self.problem.scenario == "2Dx3D" + and self.tokens_after_repeat // self.expert_cnt == 0 + ): + print( + "Warning: Due to the Pytorch 2.10 FBGEMM bug (incorrect `M/G` early exit), ref tensor will be all 0 in this case, skip ref check." + ) + return + try: + diff = (self.C_tensor - self.C_ref_tensor).float().abs() + max_diff = diff.max().item() + if max_diff == 0.0: + print("Validation PASSED (exact match)") + else: + print( + f"C_tensor: shape={tuple(self.C_tensor.shape)} " + f"stride={self.C_tensor.stride()} dtype={self.C_tensor.dtype}" + ) + print( + f"C_ref_tensor: shape={tuple(self.C_ref_tensor.shape)} " + f"stride={self.C_ref_tensor.stride()} dtype={self.C_ref_tensor.dtype}" + ) + print( + f"Validation FAILED: " + f"max_diff={max_diff} " + f"mean_diff={diff.mean().item()}" + ) + assert False, "C_tensor != C_ref_tensor" + except torch.cuda.OutOfMemoryError: + print("OOM during diff computation, falling back to torch.equal") + assert torch.equal(self.C_tensor, self.C_ref_tensor), ( + "C_tensor != C_ref_tensor" + ) + + # ----------------------------------------------------------------- + # SOL comparison + # ----------------------------------------------------------------- + + def run_sol_comparison(self) -> None: + """Run a dense batched block-scaled GEMM as Speed-of-Light reference. + + Reuses the same tensor memory from the grouped run by passing + raw pointers with a batched problem_mnkl -- zero GPU allocation. + """ + import sys, os + + _examples_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..") + ) + if _examples_root not in sys.path: + sys.path.insert(0, _examples_root) + + from blackwell.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent import ( + Sm100BlockScaledPersistentDenseGemmKernel, + ) + from cutlass.cute.nvgpu import OperandMajorMode + from cutlass.cute.runtime import make_ptr + + tokens = self.tokens_after_repeat + experts = self.expert_cnt + blocksize = self.cfg["blocksize"] + n_slots = tokens // blocksize + assert tokens % blocksize == 0 and n_slots % experts == 0, ( + f"compare_with_sol requires tokens*top_k ({tokens}) to be " + f"divisible by blocksize ({blocksize}), and the resulting " + f"n_slots ({n_slots}) evenly divisible by experts ({experts}) " + f"so every group has exactly the same size" + ) + tpe = tokens // experts + + if self.problem.scenario == "2Dx3D": + M, N, K, L = tpe, self.intermediate, self.hidden, experts + else: # 2Dx2D + M, N, K, L = self.hidden, self.intermediate, tpe, experts + + # Dtype mapping + _torch_to_cutlass = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, + torch.float4_e2m1fn_x2: cutlass.Float4E2M1FN, + } + if hasattr(torch, "float8_e8m0fnu"): + _torch_to_cutlass[torch.float8_e8m0fnu] = cutlass.Float8E8M0FNU + + data_dtype = _torch_to_cutlass[self.cfg["data_dtype"]] + sf_dtype = _torch_to_cutlass[self.cfg["scale_dtype"]] + out_dtype = _torch_to_cutlass[self.problem.out_dtype] + + # Layout mapping + a_major = ( + OperandMajorMode.K + if self.problem.a_layout == "k_major" + else OperandMajorMode.MN + ) + b_major = ( + OperandMajorMode.K + if self.problem.b_layout == "k_major" + else OperandMajorMode.MN + ) + c_layout = ( + utils.LayoutEnum.ROW_MAJOR + if self.problem.c_layout == "n_major" + else utils.LayoutEnum.COL_MAJOR + ) + layouts = (a_major, b_major, c_layout) + + # Construct pointers from existing grouped tensors + a_ptr = make_ptr( + data_dtype, + self.A_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + data_dtype, + self.B_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + sfa_ptr = make_ptr( + sf_dtype, + self.scale_a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + sfb_ptr = make_ptr( + sf_dtype, + self.scale_b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + c_ptr = make_ptr( + out_dtype, + self.C_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + + mma_tiler_mn = self.impl.mma_tiler_mnk[:2] + cluster_shape_mn = self.impl.cluster_shape_mnk[:2] + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + sol_kernel = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size=self.cfg["blocksize"], + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + + problem_mnkl = ( + cutlass.Int32(M), + cutlass.Int32(N), + cutlass.Int32(K), + cutlass.Int32(L), + ) + + print(f"\n[SOL] Dense block-scaled BMM: M={M} N={N} K={K} L={L}") + print(f"[SOL] kind={self.problem.kind} sf_vec_size={self.cfg['blocksize']}") + + l2_flush() + sol_kernel( + a_ptr, + b_ptr, + sfa_ptr, + sfb_ptr, + c_ptr, + layouts, + problem_mnkl, + max_active_clusters, + cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + torch.cuda.synchronize() + + # ----------------------------------------------------------------- + # Run + # ----------------------------------------------------------------- + + def run(self) -> None: + print(self.problem) + print(self.impl) + print(self.misc) + + self.generate_inputs() + + group_sizes = offs_to_group_sizes(self.offs_tensor) + print( + f"A: shape={tuple(self.A_tensor.shape)} " + f"stride={self.A_tensor.stride()} dtype={self.A_tensor.dtype}" + ) + print( + f"B: shape={tuple(self.B_tensor.shape)} " + f"stride={self.B_tensor.stride()} dtype={self.B_tensor.dtype}" + ) + print( + f"C: shape={tuple(self.C_tensor.shape)} " + f"stride={self.C_tensor.stride()} dtype={self.C_tensor.dtype}" + ) + print( + f"scale_a: shape={tuple(self.scale_a_tensor.shape)} " + f"stride={self.scale_a_tensor.stride()} dtype={self.scale_a_tensor.dtype}" + ) + print( + f"scale_b: shape={tuple(self.scale_b_tensor.shape)} " + f"stride={self.scale_b_tensor.stride()} dtype={self.scale_a_tensor.dtype}" + ) + if self.cfg["has_global_scale"]: + print(f"global_scale_a: {self.global_scale_a.cpu().tolist()}") + print(f"global_scale_b: {self.global_scale_b.cpu().tolist()}") + print(f"offs: {self.offs_tensor.cpu().tolist()} group_sizes={group_sizes}") + + kernel = self.create_kernel() + + if self.misc.perf_e2e: + self.run_kernel(kernel) + else: + from torch.profiler import profile, ProfilerActivity + + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: + self.compute_reference() + self.run_kernel(kernel) + if ( + self.misc.compare_with_sol + and self.misc.perf_run + and self.problem.balance_route + ): + self.run_sol_comparison() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + + self.validate() + print("PASS") + + +# ============================================================================= +# CLI entry point +# ============================================================================= + +if __name__ == "__main__": + import argparse + + def parse_tuple(s: str) -> Tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + parser = argparse.ArgumentParser( + description="Scaled Grouped GEMM for MoE (MXFP8 / MXFP4 / NVFP4)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # ── Problem ── + parser.add_argument("--tokens", type=int, default=128) + parser.add_argument("--experts", type=int, default=4) + parser.add_argument("--top_k_select", type=int, default=2) + parser.add_argument("--balance_route", action="store_true", default=False) + parser.add_argument("--hidden", type=int, default=512) + parser.add_argument("--intermediate", type=int, default=384) + parser.add_argument( + "--scenario", type=str, default="2Dx3D", choices=["2Dx3D", "2Dx2D"] + ) + parser.add_argument( + "--kind", type=str, default="mxfp8", choices=["mxfp8", "mxfp4", "nvfp4"] + ) + parser.add_argument("--out_dtype", type=str, default="bfloat16") + parser.add_argument("--acc_dtype", type=str, default="float32") + parser.add_argument("--grad_accumulate", action="store_true", default=False) + parser.add_argument( + "--consistent_token_padding", action="store_true", default=False + ) + parser.add_argument( + "--a_layout", type=str, default="k_major", choices=["k_major", "m_major"] + ) + parser.add_argument( + "--b_layout", type=str, default="k_major", choices=["k_major", "n_major"] + ) + parser.add_argument( + "--c_layout", type=str, default="n_major", choices=["n_major", "m_major"] + ) + + # ── Impl ── + parser.add_argument("--mma_tiler_mnk", type=str, default="128,128,128") + parser.add_argument("--cluster_shape_mnk", type=str, default="1,1,1") + parser.add_argument("--use_2cta_instrs", action="store_true", default=False) + parser.add_argument("--static_expert_cnt", type=int, default=None) + parser.add_argument("--separate_tensormap_init", action="store_true", default=True) + + # ── Misc ── + parser.add_argument("--perf_run", action="store_true", default=False) + parser.add_argument("--perf_e2e", action="store_true", default=False) + parser.add_argument("--compare_with_sol", action="store_true", default=False) + + args = parser.parse_args() + + if args.consistent_token_padding: + print( + "WARNING: Overriding consistent_token_padding to False " + "(not implemented yet)." + ) + args.consistent_token_padding = False + + problem = ProblemDesc( + tokens=args.tokens, + experts=args.experts, + top_k_select=args.top_k_select, + balance_route=args.balance_route, + hidden=args.hidden, + intermediate=args.intermediate, + scenario=args.scenario, + kind=args.kind, + out_dtype=getattr(torch, args.out_dtype), + acc_dtype=getattr(torch, args.acc_dtype), + grad_accumulate=args.grad_accumulate, + consistent_token_padding=args.consistent_token_padding, + a_layout=args.a_layout, + b_layout=args.b_layout, + c_layout=args.c_layout, + ) + + if not args.separate_tensormap_init: + print( + "Overriding separate_tensormap_init to True " + "(fused version not implemented yet)." + ) + args.separate_tensormap_init = True + + impl = ImplDesc( + mma_tiler_mnk=parse_tuple(args.mma_tiler_mnk), + cluster_shape_mnk=parse_tuple(args.cluster_shape_mnk), + use_2cta_instrs=args.use_2cta_instrs, + static_expert_cnt=args.static_expert_cnt, + separate_tensormap_init=args.separate_tensormap_init, + ) + misc = MiscDesc( + perf_run=args.perf_run, + perf_e2e=args.perf_e2e, + compare_with_sol=args.compare_with_sol, + ) + + tester = ScaledGroupedGemmTester(problem, impl, misc) + tester.run() + print("DONE") diff --git a/examples/python/CuTeDSL/blackwell/reduce.py b/examples/python/CuTeDSL/cute/blackwell/kernel/reduce/reduce.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/reduce.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/reduce/reduce.py diff --git a/examples/python/CuTeDSL/blackwell/rmsnorm.py b/examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/rmsnorm.py rename to examples/python/CuTeDSL/cute/blackwell/kernel/rmsnorm/rmsnorm.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/README.md b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/README.md similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/README.md rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/README.md diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_0.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_0.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_0.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_1.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_1.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_1.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_2.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_2.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_2.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_3.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_3.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_3_1.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_3_1.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_3_1.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_4.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_4.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_4.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_5.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_5.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_5.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_6.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/fp16_gemm_6.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/fp16_gemm_6.py diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py index d2631f1dd..07eaa4763 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_0.py +++ b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_0.py @@ -47,11 +47,11 @@ from cutlass.cute.runtime import make_ptr if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - examples_dir = os.path.join(current_dir, "..", "..") + examples_dir = os.path.join(current_dir, "..", "..", "..", "..") if examples_dir not in sys.path: sys.path.insert(0, examples_dir) -from blackwell.tutorial_gemm.utils import create_parser, run +from cute.blackwell.tutorial.tutorial_gemm.utils import create_parser, run mma_tiler_mn = (128, 256) mma_inst_shape_k = 64 diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_1.py similarity index 99% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_1.py index bebe522c4..8b163bc2f 100644 --- a/examples/python/CuTeDSL/blackwell/tutorial_gemm/nvfp4_gemm_1.py +++ b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/nvfp4_gemm_1.py @@ -49,11 +49,11 @@ from cutlass.cute.runtime import from_dlpack, make_ptr if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - examples_dir = os.path.join(current_dir, "..", "..") + examples_dir = os.path.join(current_dir, "..", "..", "..", "..") if examples_dir not in sys.path: sys.path.insert(0, examples_dir) -from blackwell.tutorial_gemm.utils import create_parser, run +from cute.blackwell.tutorial.tutorial_gemm.utils import create_parser, run mma_tiler_mn = (256, 256) mma_inst_shape_k = 64 diff --git a/examples/python/CuTeDSL/blackwell/tutorial_gemm/utils.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/utils.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/tutorial_gemm/utils.py rename to examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_gemm/utils.py diff --git a/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/README.md b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/README.md new file mode 100644 index 000000000..d71898a8f --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/README.md @@ -0,0 +1,385 @@ +# CUTLASS Tutorial Examples for Blackwell TMA + +## TMA V0: Understanding tma_partition - The Foundation of TMA Operations + +This example demonstrates the fundamental building blocks of Tensor Memory Accelerator (TMA) operations on NVIDIA Blackwell (SM100) architecture using CuTe DSL. It focuses on understanding the `tma_partition` interface, which is essential for all TMA operations. + +### Key Concepts + +* **TMA Load (Global → Shared)**: Asynchronous bulk data transfer from Global Memory to Shared Memory using TMA hardware +* **TMA Store (Shared → Global)**: Asynchronous bulk data transfer from Shared Memory back to Global Memory +* **tma_partition**: Core interface that prepares tensors for TMA operations by partitioning them according to TMA atom layout requirements +* **group_modes**: Tensor mode grouping to define the TMA atom shape - crucial for proper tma_partition usage +* **mbarrier Synchronization**: Hardware barriers (`mbarrier_init`, `mbarrier_arrive`, `mbarrier_wait`) for synchronizing asynchronous TMA operations + +### Kernel Architecture + +The `Sm100SimpleCopyKernel` performs a simple tile-based copy operation to illustrate TMA fundamentals: + +#### Configuration + +* **tile_shape**: Fixed at (128, 128) - Tile dimensions (M, N) +* **cluster_shape_mn**: Fixed at (1, 1) - Single CTA execution (no cluster parallelism) +* **Shared Memory**: Single buffer sized to hold one tile (tile_m × tile_n elements) +* **Synchronization**: Single mbarrier for TMA load completion + +#### Key Components + +1. **TMA Descriptor Creation**: Creates TMA atoms (`tma_atom_src`, `tma_atom_dst`) that encapsulate TMA hardware instructions +2. **Shared Memory Layout**: Row-major layout `(tile_m, tile_n):(tile_n, 1)` for simplicity +3. **Barrier Management**: Single barrier coordinates TMA load completion before processing + +### Execution Flow + +1. **Initialization**: + * Allocate shared memory buffer for one tile + * Initialize mbarrier with `elect_one()` to ensure proper synchronization semantics + * Set barrier to expect TMA transaction bytes (`tile_m × tile_n × element_size`) + * Synchronize all threads after barrier initialization + +2. **Tensor Preparation**: + ``` + Tile global tensors into (tile_m, tile_n) blocks + Apply group_modes to combine tile dimensions into Mode 0 (TMA atom) + + Example: gSrc_tiled with shape (128, 128, 4, 2) + After group_modes(_, 0, 2): ((128, 128), 4, 2) + └───┬────┘ └─┬─┘ + Mode 0 Rest modes + ``` + +3. **TMA Partition**: + ```python + tAsA, tAgA = tma_partition( + tma_atom_src, # TMA operation atom + cta_id=0, # CTA ID within cluster + cta_layout, # Cluster layout + group_modes(smem_tensor, 0, 2), # SMEM view (Mode 0 = atom) + group_modes(gSrc_tiled, 0, 2) # Global view (Mode 0 = atom) + ) + # Returns: + # tAsA: SMEM view with TMA internal layout + # tAgA: Global view with shape ((TMA_Layout), grid_m, grid_n) + ``` + +4. **Tile Selection**: + ```python + # Select specific tile for this CTA + tAgA_cta = tAgA[(None, bidx, bidy)] + # None: keep entire TMA atom (Mode 0) + # bidx, bidy: index into grid dimensions + ``` + +5. **TMA Load** (Global → Shared): + ```python + cute.copy(tma_atom_src, tAgA_cta, tAsA, tma_bar_ptr=barrier_ptr) + # Arrive on barrier (producer signals completion) + # Wait on barrier (all threads wait for TMA completion) + ``` + +6. **TMA Store** (Shared → Global): + ```python + cute.copy(tma_atom_dst, tAsA, tBgB_cta) + # Synchronous store completes before kernel exit + ``` + +### Understanding tma_partition + +The `tma_partition` function is the key to TMA operations. The tutorial includes comprehensive inline diagrams showing: + +* **Input Tensor Shapes**: How tensors are organized before partitioning +* **group_modes Effect**: How mode grouping creates the required atom structure +* **Partition Output**: The structure of partitioned tensors for TMA operations +* **Indexing Pattern**: How to select CTA-specific tiles from partitioned views + +See lines ~155-255 in `tma_v0.py` for detailed visual explanations. + +### Configuration Parameters + +* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N) +* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution +* `threads_per_cta`: 32 - Single warp (all threads participate in barriers) +* `buffer_align_bytes`: 1024 - Shared memory alignment + +### Usage + +Run the copy kernel with custom matrix dimensions: + +```bash +# Basic usage (default: 512×128 matrix) +python tma_v0.py + +# Custom dimensions +python tma_v0.py --M 1024 --N 2048 + +# With custom benchmark iterations +python tma_v0.py --M 4096 --N 4096 --num_warmup 10 --num_iters 50 +``` + +#### Example Code + +```python +from tma_v0 import run_tma_copy + +# Run copy on 1024×2048 matrix +run_tma_copy(M=1024, N=2048) +# Output: Performance metrics and verification result +``` + +### Performance Considerations + +* **Single-stage operation**: No pipelining - simple sequential load → store pattern +* **Educational focus**: Designed for understanding TMA fundamentals, not peak performance +* **Barrier synchronization**: Demonstrates proper mbarrier usage patterns +* **Foundation for V1/V2**: Concepts learned here are essential for understanding multi-stage pipelines and warp specialization in V1 and V2 + +## TMA V1: Matrix Transpose with Producer-Consumer Pattern + +This example demonstrates a TMA-based matrix transpose using producer-consumer synchronization with mbarriers on NVIDIA Blackwell (SM100) architecture. + +### Key Concepts + +* **Producer-Consumer Pattern**: Different warps coordinate through mbarrier synchronization +* **TMA Operations**: Asynchronous bulk data transfer between Global and Shared Memory +* **Shared Memory Swizzle**: Optimized layouts to avoid bank conflicts during transpose +* **Warp Specialization**: Dedicated warps for loading, transposing, and storing +* **mbarrier Synchronization**: Hardware barriers coordinate asynchronous operations + +### Kernel Architecture + +The `Sm100MatrixTransposeKernel` performs a tiled matrix transpose (M×N → N×M) with the following design: + +#### Warp Roles + +1. **TMA Load Warp** (Warp 4, Producer): Issues TMA load operations from Global Memory to Shared Memory buffer `sA` +2. **Transpose Warps** (Warps 0-3, 4 warps, Consumer/Producer): + * Wait for `sA` to be filled by TMA load (consumer of load_mbar) + * Transpose data from `sA` → Registers → `sB` + * Signal completion to TMA Store warp (producer for store_mbar) +3. **TMA Store Warp** (Warp 5, Consumer): + * Wait for `sB` to be ready (consumer of store_mbar) + * Issue TMA store operations from `sB` to Global Memory + +#### Synchronization Barriers + +The kernel uses two mbarrier instances for producer-consumer coordination: + +1. **load_mbar_ptr**: Synchronizes TMA Load → Transpose Warps + * Producer: TMA Load Warp (arrives after TMA completes) + * Consumer: Transpose Warps (wait before reading `sA`) + * Expected arrivals: 1 (from TMA load warp) + * Expected transactions: `tile_m × tile_n × element_size` bytes + +2. **store_mbar_ptr**: Synchronizes Transpose Warps → TMA Store + * Producer: Transpose Warps (each warp arrives after writing to `sB`) + * Consumer: TMA Store Warp (waits before issuing TMA store) + * Expected arrivals: 4 (one from each transpose warp) + +#### Execution Flow + +1. **Initialization**: + * All warps participate in barrier initialization (thread 0 initializes) + * Allocate two shared memory buffers: `sA` (row-major) and `sB` (column-major) + * Create TMA descriptors for source and transposed destination + * Initialize `load_mbar` with expected count of 1 and transaction bytes + * Initialize `store_mbar` with expected count of 4 (number of transpose warps) + +2. **TMA Load Warp** (Producer for Load Pipeline): + ``` + partition source tensor by tile shape + issue TMA load: Global[block_tile] → sA + arrive on load_mbar to signal completion + ``` + +3. **Transpose Warps** (Consumer for Load, Producer for Store): + ``` + wait on load_mbar for TMA load to complete + + partition sA for reading (each thread handles subset) + copy data: sA → Registers + partition sB for writing + copy data: Registers → sB (transpose happens via layout) + + fence to ensure smem writes are visible + synchronize with trans_sync_barrier + [elect one thread] arrive on store_mbar to signal completion + ``` + +4. **TMA Store Warp** (Consumer for Store Pipeline): + ``` + wait on store_mbar for transpose to complete + partition destination tensor by tile shape + issue TMA store: sB → Global[block_tile] + ``` + +### Key Features + +* **Simple Producer-Consumer Model**: Clear separation of concerns with dedicated warps +* **Efficient Synchronization**: Hardware mbarriers minimize synchronization overhead +* **Memory Layout Optimization**: Swizzled layouts prevent bank conflicts +* **Transposition via Layouts**: Transpose is achieved through different memory layouts for `sA` and `sB` + +### Configuration Parameters + +* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N) +* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution +* Warp count: 6 warps (1 TMA Load + 4 Transpose + 1 TMA Store) + +### Usage + +Run the transpose kernel with custom matrix dimensions: + +```bash +# Basic usage (128×128 matrix) +python tma_v1.py + +# Custom dimensions +python tma_v1.py --M 1024 --N 2048 + +# With custom benchmark iterations +python tma_v1.py --M 4096 --N 4096 --num_warmup 10 --num_iters 50 +``` + +#### Example Code + +```python +from tma_v1 import run_transpose + +# Run transpose on 1024×2048 matrix +run_transpose(M=1024, N=2048) +# Output: Performance metrics and verification result +``` + +### Performance Considerations + +* **Single-stage pipeline**: Simpler than multi-stage but with potential for idle warps +* **Warp specialization**: Clear roles minimize synchronization complexity +* **Good for learning**: Demonstrates fundamental TMA and mbarrier concepts + +## TMA V2: Transpose with Multi-Stage Pipeline + +This example demonstrates a TMA implementation with multi-stage pipelining for efficient matrix transpose operations on NVIDIA Blackwell (SM100) architecture. + +### Key Concepts + +* **Multi-Stage Pipeline**: Multiple buffers enable overlapping TMA loads, computation (transpose), and TMA stores to hide memory latency +* **Pipeline Abstractions**: `PipelineTmaAsync` and `PipelineTmaStore` provide producer-consumer synchronization +* **Persistent Tile Scheduler**: Efficient work distribution across CTAs for dynamic load balancing +* **Shared Memory Swizzle**: Optimized shared memory layouts to avoid bank conflicts +* **Warp Specialization**: Different warps handle loading and transposing; first transpose warp also handles storing + +### Kernel Architecture + +The `Sm100MatrixTransposeKernelV2` performs a tiled matrix transpose (M×N → N×M) with the following design: + +#### Warp Roles + +1. **TMA Load Warp** (Producer): Loads tiles from Global Memory to Shared Memory buffer `sA` using TMA load operations +2. **Transpose Warps** (4 warps, Consumer/Producer): + * Wait for data in `sA` from load pipeline + * Transpose data from `sA` → Registers → `sB` + * Synchronize via named barrier to ensure all transpose warps complete + * First transpose warp (`trans_warp_id[0]`) issues TMA store operations from `sB` to Global Memory + +#### Pipeline Stages + +The kernel uses two separate pipelines for maximum parallelism: + +1. **Load Pipeline**: `TMA Load Warp` (producer) → `Transpose Warps` (consumer) + * Multi-stage buffer `sA` (automatically computed based on available SMEM) + * Enables prefetching multiple tiles while processing current tile +2. **Store Pipeline**: `Transpose Warps` (producer) → `First Transpose Warp` (consumer, issues TMA store) + * Multi-stage buffer `sB` (automatically computed based on available SMEM) + * First transpose warp handles TMA store after all transpose warps finish writing to `sB` + +#### Stage Calculation + +The kernel automatically computes the optimal number of pipeline stages using `_compute_stages()`: + +* Calculates bytes needed per stage for `sA` (tile_m × tile_n) and `sB` (tile_m × tile_n) +* Reserves space for pipeline barriers and metadata (~1KB) +* Divides remaining shared memory by bytes per stage +* Clamps to 2-8 stages (2 for double buffering minimum, 8 for diminishing returns) + +#### Execution Flow + +1. **Initialization**: + * Allocate multi-stage shared memory buffers (`sA_staged`, `sB_staged`) + * Create TMA descriptors for source and transposed destination + * Initialize load and store pipelines with barrier synchronization + * Use persistent tile scheduler to distribute work tiles across CTAs + +2. **TMA Load Warp** (Producer for Load Pipeline): + ``` + for each tile assigned by scheduler: + acquire next available stage in load pipeline + issue TMA load: Global[tile] → sA[stage] + advance to next stage + ``` + +3. **Transpose Warps** (Consumer for Load, Producer for Store): + ``` + for each tile assigned by scheduler: + wait for load pipeline to fill current stage + copy data: sA[load_stage] → Registers + release load pipeline stage + + transpose and write: Registers → sB[store_stage] + fence to ensure smem writes are visible + synchronize all transpose warps with barrier + + [first transpose warp only] issue TMA store: sB[stage] → Global[tile] + [first transpose warp only] commit to store pipeline + [first transpose warp only] acquire next store pipeline stage + synchronize all transpose warps with barrier + ``` + +4. **Pipeline Teardown**: + * Producer/consumer tail operations ensure all in-flight operations complete + +### Key Features + +* **Automatic Stage Optimization**: Kernel calculates optimal number of stages based on tile size, data type, and available shared memory +* **Persistent Tile Scheduler**: Efficient work distribution across CTAs for dynamic load balancing +* **Memory Layout Optimization**: Uses appropriate swizzling for row-major input and column-major transposed output +* **Efficient Synchronization**: Named barriers for intra-CTA coordination, mbarriers for pipeline stages + +### Configuration Parameters + +* `tile_shape`: Fixed at (128, 128) - Tile dimensions (M, N) +* `cluster_shape_mn`: Fixed at (1, 1) - Single CTA execution +* Number of pipeline stages: Automatically computed based on available SMEM + +### Usage + +Run the transpose kernel with custom matrix dimensions: + +```bash +# Basic usage (128×128 matrix) +python tma_v2.py + +# Custom dimensions +python tma_v2.py --M 1024 --N 2048 + +python tma_v2.py --M 1024 --N 2048 --num_warmup 10 --num_iters 50 +``` + +#### Example Code + +```python +from tma_v2 import run_transpose + +# Run transpose on 1024×2048 matrix +run_transpose(M=1024, N=2048) +# Output: "TransposeSuccess!" if verification passes +``` + +### Performance Considerations + +* **Multi-stage pipelining** hides memory latency by overlapping loads, computation, and stores +* **Persistent scheduling** provides better load balancing for irregular matrix sizes +* **Warp specialization** maximizes throughput: TMA load warp handles all loads, transpose warps handle computation, and first transpose warp handles stores + +## TMA V3: TMA With MMA (Tensor Cores) + +TBD diff --git a/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v0.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v0.py new file mode 100644 index 000000000..3c5198c6f --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v0.py @@ -0,0 +1,409 @@ +# Copyright (c) 2024 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Type, Union + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +from cutlass.cute.nvgpu import cpasync +from cutlass.cute.runtime import from_dlpack +import torch + +""" +TMA V0: Understanding tma_partition - The Foundation of TMA Operations + +This tutorial demonstrates TMA (Tensor Memory Accelerator) operations through +a simple copy kernel. It focuses on understanding the fundamental tma_partition +interface, which is the key to all TMA operations. + +What This Tutorial Covers: +1. TMA Load (Global Memory -> Shared Memory) +2. TMA Store (Shared Memory -> Global Memory) +3. Barrier synchronization for TMA (elect_one, mbarrier_init, mbarrier_arrive, mbarrier_wait) +4. Detailed explanation of tma_partition with visual diagrams + +Key Learning Points: +- tma_partition: How it transforms tensors for TMA operations +- group_modes: Why and how to group tensor modes to define TMA atom shape +- Indexing: How to select specific tiles from partitioned tensors +- Data flow: Complete visualization from input tensors to TMA copy + +Visual Diagrams: +See line ~155 for comprehensive diagrams showing: +- Input tensor shapes and transformations +- group_modes effect on tensor layouts +- tma_partition output structure +- Complete data flow from global/shared memory to TMA copy +- Indexing pattern for CTA-specific tiles + +Example Usage: +```bash +python cutlass_ir/compiler/python/examples/cute/blackwell/tutorial/tutorial_tma/tma_v0.py +``` +""" + + +class Sm100SimpleCopyKernel: + def __init__(self): + """ + Initializes the configuration for a Blackwell TMA copy kernel. + """ + self.tile_shape = (128, 128) + self.tile_m, self.tile_n = self.tile_shape + self.cluster_shape_mn = (1, 1) + self.threads_per_cta = 32 + self.buffer_align_bytes = 1024 + + @cute.jit + def __call__(self, src: cute.Tensor, dst: cute.Tensor): + if cutlass.const_expr(src.element_type != dst.element_type): + raise TypeError("Source and destination element types must match") + + self.dtype: Type[cutlass.Numeric] = src.element_type + + # layout for each cta: (tile_m, tile_n):(tile_n, 1) + smem_layout = cute.make_layout( + (self.tile_m, self.tile_n), stride=(self.tile_n, 1) + ) + + @cute.struct + class SharedStorage: + barrier_storage: cute.struct.MemRange[cutlass.Int64, 1] + smem_data: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout) + + # cta_tiler: the per-CTA tile extents (M, N) used by TMA. + # Note: smem_layout may include swizzle or composed layout, + # so we use product_each(...) to take the product along each logical dimension and get + # the final (tile_m, tile_n) extents expected by TMA. + # In this simple example, smem_layout.shape == (tile_m, tile_n), so product_each(...) is + # just (tile_m, tile_n). + cta_tiler = cute.product_each(smem_layout.shape) + + tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), src, smem_layout, cta_tiler + ) + + tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), dst, smem_layout, cta_tiler + ) + + # Grid shape is now (M/TileM, N/TileN) + grid_shape = cute.ceil_div((*src.layout.shape, 1), self.tile_shape) + + self.kernel( + tma_atom_src, tma_tensor_src, tma_atom_dst, tma_tensor_dst, smem_layout + ).launch( + grid=grid_shape, + block=(self.threads_per_cta, 1, 1), + cluster=(*self.cluster_shape_mn, 1), + ) + + @cute.kernel + def kernel( + self, + tma_atom_src: cute.CopyAtom, + tma_tensor_src: cute.Tensor, + tma_atom_dst: cute.CopyAtom, + tma_tensor_dst: cute.Tensor, + smem_layout: Union[cute.Layout, cute.ComposedLayout], + ): + bidx, bidy, _ = cute.arch.block_idx() + + # Allocate Shared Memory + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize barrier for TMA synchronization + barrier_ptr = storage.barrier_storage.data_ptr() + + # Initialize the barrier: elect_one ensures only one thread executes this + # Note: We must use elect_one() instead of "if tid == 0" because: + # - elect_one() provides proper synchronization semantics + # - It ensures all threads are aware that exactly one thread is executing + # - It prevents race conditions and provides memory ordering guarantees + with cute.arch.elect_one(): + cute.arch.mbarrier_init(barrier_ptr, 1) + cute.arch.mbarrier_expect_tx(barrier_ptr, self.num_tma_load_bytes) + + # Fence ensures init/expect_tx are visible before proceeding + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + + # Tile the (M, N) tensor: ((TileM, TileN), M/TileM, N/TileN) + gSrc_tiled = cute.local_tile( + tma_tensor_src, (self.tile_m, self.tile_n), (None, None) + ) + gDst_tiled = cute.local_tile( + tma_tensor_dst, (self.tile_m, self.tile_n), (None, None) + ) + + smem_tensor = storage.smem_data.get_tensor(smem_layout) + + # ====================================================================== + # TMA Partition: Tensor Preparation for TMA Operations + # ====================================================================== + # + # tma_partition prepares tensors for TMA copy by partitioning them + # according to the TMA atom's internal layout requirements. + # + # Signature: + # tma_partition(atom, cta_id, cta_layout, smem_tensor, gmem_tensor) + # -> (smem_view, gmem_view) + # + # Key Requirement: Mode 0 of both tensors must represent the TMA atom + # + # Example: M=512, N=128, TileM=128, TileN=64 + # + # Input Tensors: + # gSrc_tiled: (128, 64, 4, 2) # 4 separate modes + # └──┬──┘ └──┬──┘ + # Tile Grid + # + # smem_tensor: (128, 64) # 2 separate modes + # └──┬──┘ + # Tile + # + # Apply group_modes(tensor, 0, 2) to group first 2 modes: + # + # group_modes(gSrc_tiled, 0, 2) => ((128, 64), 4, 2) + # └───┬───┘ + # Mode 0 = Atom + # + # group_modes(smem_tensor, 0, 2) => ((128, 64),) + # └───┬───┘ + # Mode 0 = Atom + # + # After tma_partition: + # + # tAsA: SMEM view with TMA internal layout + # Shape: ((TMA_Layout),) + # - TMA_Layout: Swizzled/banked layout for efficient SMEM access + # + # tAgA: Global view preserving rest modes + # Shape: ((TMA_Layout), 4, 2) + # └─────┬─────┘ └──┬──┘ + # TMA atom Rest modes (grid) + # + # Usage Pattern: + # 1. Group modes to define atom: group_modes(tensor, 0, 2) + # 2. Call tma_partition: tAsA, tAgA = tma_partition(...) + # 3. Select tile for CTA: tAgA_cta = tAgA[(None, bidx, bidy)] + # - None: keep entire atom + # - bidx, bidy: index into rest modes + # 4. Issue TMA copy: cute.copy(atom, tAgA_cta, tAsA) + # + # Visual Data Flow: + # + # Global Memory (512x128) Shared Memory (128x64) + # ┌────────────────────┐ ┌──────────────┐ + # │ ┌───┬───┐ │ │ │ + # │ │0,0│0,1│ │ │ smem_tensor │ + # │ ├───┼───┤ │ │ (128, 64) │ + # │ │1,0│1,1│ 4x2 │ │ │ + # │ ├───┼───┤ tiles │ └──────────────┘ + # │ │2,0│2,1│ │ │ + # │ ├───┼───┤ │ │ group_modes + # │ │3,0│3,1│ │ ↓ + # │ └───┴───┘ │ ((128, 64),) + # └────────────────────┘ │ + # │ │ + # │ gSrc_tiled │ + # │ (128, 64, 4, 2) │ + # ↓ │ + # group_modes(_, 0, 2) │ + # ↓ │ + # ((128, 64), 4, 2) │ + # │ │ + # └────────┬────────────────────────┘ + # ↓ + # tma_partition + # ↓ + # ┌──────────────┴──────────────┐ + # │ │ + # tAgA tAsA + # ((TMA_Layout), 4, 2) ((TMA_Layout),) + # │ + # ↓ tAgA[(None, bidx, bidy)] + # tAgA_cta + # ((TMA_Layout),) + # + # ====================================================================== + + # TMA Load partition + # Here we only use 1x1 cluster, so cta_id is 0 and cta_layout is (1). + # More details about how to set cta_coord and cta_layout can be found in the tma_v4.py + # Note: Smem and gemm should have the same size (atom element size) in the first rank + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_src, + 0, + cute.make_layout(1), + cute.group_modes(smem_tensor, 0, 2), + cute.group_modes(gSrc_tiled, 0, 2), + ) + + # TMA Store partition + # Same process as TMA Load, but for destination tensor + # Partitions gDst_tiled and smem_tensor according to TMA Store atom + _, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_atom_dst, + 0, + cute.make_layout(1), + cute.group_modes(smem_tensor, 0, 2), + cute.group_modes(gDst_tiled, 0, 2), + ) + + # Select specific tile for this CTA from partitioned global views + # Input: tAgA with shape ((TMA_Layout), 4, 2) + # Output: tAgA_cta with shape ((TMA_Layout),) + # The (None, bidx, bidy) indexing: + # - None: keeps the entire TMA atom layout (mode 0) + # - bidx: selects from rest mode 1 (M dimension grid) + # - bidy: selects from rest mode 2 (N dimension grid) + tAgA_cta = tAgA[(None, bidx, bidy)] + tBgB_cta = tBgB[(None, bidx, bidy)] + + # ---------- TMA Load: Global -> Shared ---------- + cute.copy( + tma_atom_src, + tAgA_cta, # Source (TMA Tensor View) + tAsA, # Dest (SMEM Tensor View) + tma_bar_ptr=barrier_ptr, + ) + + # Signal arrival on the barrier after TMA is issued + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(barrier_ptr) + + # Wait for TMA to complete + cute.arch.mbarrier_wait(barrier_ptr, 0) + + # ---------- TMA Store: Shared -> Global ---------- + cute.copy( + tma_atom_dst, + tAsA, # Source (SMEM Tensor View) + tBgB_cta, # Dest (Global Tensor View) + ) + + +def run_tma_copy(M, N, num_warmup=5, num_iters=20): + """ + Run TMA copy kernel with performance measurement. + + Args: + M: Matrix dimension M + N: Matrix dimension N + num_warmup: Number of warmup iterations + num_iters: Number of timing iterations + """ + # Create tensors with shape (M, N) + a = torch.randn((M, N), dtype=torch.float16, device="cuda") + b = torch.zeros((M, N), dtype=torch.float16, device="cuda") + + # Notice: We declare N-dimension as the leading dimension should be divisible by 16 + a_cute = ( + from_dlpack(a, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=16) + ) + b_cute = ( + from_dlpack(b, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=16) + ) + + copy_kernel = Sm100SimpleCopyKernel() + compiled_kernel = cute.compile(copy_kernel, a_cute, b_cute) + + # Warmup runs + for _ in range(num_warmup): + compiled_kernel(a_cute, b_cute) + torch.cuda.synchronize() + + # Timed runs + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + compiled_kernel(a_cute, b_cute) + end_event.record() + torch.cuda.synchronize() + + # Calculate performance metrics + elapsed_time_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_time_ms / num_iters + + # Calculate throughput + # For copy: read M*N elements + write M*N elements + bytes_per_element = a.element_size() + total_bytes = 2 * M * N * bytes_per_element # Read + Write + throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000) + + # Print performance metrics + print(f"Matrix size: {M}×{N}") + print(f"Tile shape: {copy_kernel.tile_shape}") + print(f"Average time: {avg_time_ms:.4f} ms") + print(f"Throughput: {throughput_gb_s:.2f} GB/s") + + # Verify + if torch.allclose(a, b, atol=1e-3): + print("Verification: PASSED ✓") + else: + print("Verification: FAILED ✗") + diff = (a - b).abs() + print(f"Max diff: {diff.max()}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="TMA V0: Understanding tma_partition - The Foundation of TMA Operations" + ) + parser.add_argument("--M", type=int, default=512, help="Matrix dimension M") + parser.add_argument("--N", type=int, default=128, help="Matrix dimension N") + parser.add_argument( + "--num_warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--num_iters", type=int, default=20, help="Number of timing iterations" + ) + args = parser.parse_args() + + run_tma_copy( + M=args.M, + N=args.N, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + ) diff --git a/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v1.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v1.py new file mode 100644 index 000000000..e0fd741d6 --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v1.py @@ -0,0 +1,462 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Tuple, Type +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync +from cutlass.cute.runtime import from_dlpack +import cutlass.pipeline as pipeline +import torch + +""" +TMA Matrix Transpose with Producer-Consumer Pattern: TMA load -> S2R --> R2S -> TMA store + +Warp Roles (Producer-Consumer Pattern): +- Producer: TMA Load Warp (Warp 4) - Loads from Global to Shared A +- Consumer: Transpose Warps (Warp 0-3) - Wait for load, then transpose sA -> sB +- Consumer: TMA Store Warp (Warp 5) - Wait for transpose, then store sB to Global + +Synchronization: +1. load_mbar_ptr: TMA Load (producer) -> Transpose Warps (consumer) +2. store_mbar_ptr: Transpose Warps (producer) -> TMA Store (consumer) + +This demonstrates how different warps can use shared memory barriers to coordinate +producer-consumer relationships. +""" + + +class Sm100MatrixTransposeKernelV1: + def __init__(self): + self.tile_shape = (128, 128) + self.tile_m, self.tile_n = self.tile_shape + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + + # Set specialized warp ids based on tile_shape + self.num_trans_warps = 4 # Maximum number of transpose warps + self.trans_warp_id = tuple(range(self.num_trans_warps)) + self.tma_load_warp_id = self.num_trans_warps + self.tma_store_warp_id = self.num_trans_warps + 1 + self.threads_per_cta = 32 * len( + (self.tma_store_warp_id, self.tma_load_warp_id, *self.trans_warp_id) + ) + self.num_trans_threads = 32 * len(self.trans_warp_id) + + self.trans_tile = (self.tile_shape[0] // self.num_trans_warps, 8) + + # Set barriers for producer-consumer sync + # Barrier 1: Trans warps sync (for internal coordination) + self.trans_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.trans_warp_id), + ) + # Barrier 2: TMA Store warp waits after Trans warps finish + self.store_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32, # Only TMA store warp + ) + self.buffer_align_bytes = 1024 + + @cute.jit + def __call__(self, src: cute.Tensor, dst: cute.Tensor): + if cutlass.const_expr(src.element_type != dst.element_type): + raise TypeError("Source and destination element types must match") + + self.dtype: Type[cutlass.Numeric] = src.element_type + + # Create transposed view of dst for TMA descriptor + # dst is (N, M), we want to view it as (M, N) transposed + transed_dst = cute.make_tensor( + dst.iterator, + cute.make_layout( + (dst.shape[1], dst.shape[0]), stride=(dst.stride[1], dst.stride[0]) + ), + ) + + # row-major smem layout for sA (tile_m, tile_n) + smem_layout_sA = sm100_utils.make_smem_layout( + utils.LayoutEnum.from_tensor(src).mma_major_mode(), + (self.tile_m, self.tile_n), + self.dtype, + 1, + ) + + # col-major smem layout for sB (tile_n, tile_m) + # sB should match the transposed destination layout + smem_layout_sB = sm100_utils.make_smem_layout( + utils.LayoutEnum.from_tensor(transed_dst).mma_major_mode(), + (self.tile_m, self.tile_n), + self.dtype, + 1, + ) + + @cute.struct + class SharedStorage: + # Barrier for TMA Load: producer (TMA) -> consumer (Trans warps) + load_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + # Barrier for TMA Store: producer (Trans warps) -> consumer (TMA Store) + store_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + # Single shared memory buffer (sA and sB are different views of this) + sA: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sA)], 128 + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sB)], 128 + ] + + self.shared_storage = SharedStorage + self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout_sA) + + # TMA Atoms + # Use swizzled layout for TMA atom to handle swizzling during load + tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + src, + smem_layout_sA, + (self.tile_m, self.tile_n), + ) + + tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + transed_dst, + smem_layout_sB, + (self.tile_m, self.tile_n), + ) + + grid_shape = cute.ceil_div((*src.layout.shape, 1), self.tile_shape) + self.kernel( + tma_atom_src, + tma_tensor_src, + tma_atom_dst, + tma_tensor_dst, + smem_layout_sA, + smem_layout_sB, + ).launch( + grid=grid_shape, + block=(self.threads_per_cta, 1, 1), + cluster=self.cluster_shape_mnk, + ) + + @cute.kernel + def kernel( + self, + tma_atom_load: cute.CopyAtom, + tma_tensor_src: cute.Tensor, + tma_atom_store: cute.CopyAtom, + tma_tensor_dst: cute.Tensor, + smem_layout_sA: cute.ComposedLayout, + smem_layout_sB: cute.ComposedLayout, + ): + bidx, bidy, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Allocate Shared Memory + # We need two buffers for transpose: + # sA: Source Tile (swizzled) + # sB: Destination Tile (swizzled) for TMA Store + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sA = storage.sA.get_tensor(smem_layout_sA.outer, swizzle=smem_layout_sA.inner) + # sA = cute.make_tensor(storage.sA.iterator, smem_layout_sA) + sB = storage.sB.get_tensor(smem_layout_sB.outer, swizzle=smem_layout_sB.inner) + self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_layout_sA) + load_mbar_ptr = storage.load_mbar_ptr.data_ptr() + store_mbar_ptr = storage.store_mbar_ptr.data_ptr() + + # ------------------------------------------------------------------ + # Initialize Barriers (all warps participate in initialization) + # ------------------------------------------------------------------ + if tidx == 0: + # Barrier for TMA Load: expect 1 arrive (from TMA warp after TMA completes) + cute.arch.mbarrier_init(load_mbar_ptr, 1) + cute.arch.mbarrier_expect_tx(load_mbar_ptr, self.num_tma_load_bytes) + + # Barrier for TMA Store: expect arrival from Trans warps + cute.arch.mbarrier_init(store_mbar_ptr, len(self.trans_warp_id)) + cute.arch.mbarrier_init_fence() + + # Sync all warps after barrier initialization + cute.arch.barrier() + # ------------------------------------------------------------------ + # PRODUCER: TMA Load Warp (G -> sA) + # ------------------------------------------------------------------ + if warp_idx == self.tma_load_warp_id: + # Issue TMA Load + # ((TileM, TileK), loopM, LoopK) + gA = cute.local_tile(tma_tensor_src, self.tile_shape, (None, None)) + # ((TileM, TileK), loopM, LoopK) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_load, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 2), + cute.group_modes(gA, 0, 2), + ) + + cute.copy( + tma_atom_load, + tAgA[(None, bidx, bidy)], + tAsA[(None, 0)], + tma_bar_ptr=load_mbar_ptr, + ) + + # Arrive on mbarrier to satisfy the init count of 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(load_mbar_ptr) + + # ------------------------------------------------------------------ + # CONSUMER: Transpose Warps (sA -> Reg -> sB) + # ------------------------------------------------------------------ + if warp_idx < self.tma_load_warp_id: + trans_tid = tidx % self.num_trans_threads + + # Wait for TMA Load to complete (consumer wait on load_mbar) + cute.arch.mbarrier_wait(load_mbar_ptr, 0) + + atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=self.dtype.width, # Copy one element at a time + ) + + copy_elems = 1 + + # Use SAME thread layout for both read and write + # Transpose happens through sB_transposed layout view + + # TV layout notation: T = thread-id, V = value-lane within that thread. + # In this example `copy_elems = 1` and `thread_layout` has shape (T, V) = (num_trans_threads, 1), + # so V is always 0 (only one value-lane per thread). + # + # Linearization rule (row-major by strides): + # idx(Ti, Vj) = i + j * num_trans_threads + # Therefore here: + # idx(Ti, V0) = i + # + # Diagram (V is the column, T is the row; showing the first two threads): + # + # V0 + # ┌──────┐ + # T0 │ T0V0 │ -> idx 0 + # T1 │ T1V0 │ -> idx 1 + # ... │ ... │ + # └──────┘ + # + thread_layout = cute.make_layout( + (self.num_trans_threads, 1), + stride=(1, self.num_trans_threads), + ) + value_layout = cute.make_layout((1, copy_elems)) + # Build a "tiled copy" operator that defines the per-thread copy mapping (T,V) for this warp-group: + # - It is used twice below via `thr_copy.partition_S(...)` and `thr_copy.partition_D(...)` to + # create matching per-thread views of the source and destination tensors. + # - With the SAME (T,V) mapping, the actual transpose is achieved by changing the tensor view + # (`sA` vs `sB` / `sB_transposed`), not by changing which threads perform the copies. + tiled_copy = cute.make_tiled_copy_tv(atom, thread_layout, value_layout) + thr_copy = tiled_copy.get_slice(trans_tid) + + # Partition sA (source) for reading + tCsA = thr_copy.partition_S(sA) + # When to use `tiled_copy.retile(...)`: + # - Use it when you allocate/build a register tensor yourself (or slice/reshape it) and its + # internal layout doesn't match the TV layout expected by `tiled_copy` for copy-in/out. + # Why we don't use it here: + # - `cute.make_fragment_like(tCsA)` creates an rmem fragment with the same per-thread shape/layout + # as `tCsA`, so it already matches `tiled_copy`'s view and can be copied into directly. + tCrA = cute.make_fragment_like(tCsA) + cute.copy(tiled_copy, tCsA, tCrA) + + # Partition sB for writing + tCsB = thr_copy.partition_D(sB) + + # Write from register to sB + cute.copy(tiled_copy, tCrA, tCsB) + + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + self.trans_sync_barrier.arrive_and_wait() + + # Trans warps signal TMA Store warp: "sB is ready!" + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(store_mbar_ptr) + + # ------------------------------------------------------------------ + # CONSUMER: TMA Store Warp (sB -> G) + # ------------------------------------------------------------------ + if warp_idx == self.tma_store_warp_id: + # Wait for Trans warp to complete (consumer wait on store_mbar) + cute.arch.mbarrier_wait(store_mbar_ptr, 0) + + gDst_cta = cute.local_tile( + tma_tensor_dst, (self.tile_m, self.tile_n), (None, None) + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_store, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 2), + cute.group_modes(gDst_cta, 0, 2), + ) + cute.copy(tma_atom_store, tBsB[(None, 0)], tBgB[(None, bidx, bidy)]) + + +def run_transpose(M, N, num_warmup=5, num_iters=20): + """ + Run TMA transpose kernel with performance measurement. + + Args: + M: Matrix dimension M + N: Matrix dimension N + num_warmup: Number of warmup iterations + num_iters: Number of timing iterations + + Performance Metrics: + - Throughput: Actual achieved bandwidth in GB/s + - Theoretical BW: Peak memory bandwidth (2048 B/clk × 4000 MHz = 8.192 TB/s) + - Bandwidth Efficiency: Percentage of theoretical peak achieved + """ + torch.manual_seed(1111) + # Input (M, N) + input_data = torch.randn((M, N), device="cuda", dtype=torch.float16) + # Output (N, M) + output_data = torch.zeros((N, M), device="cuda", dtype=torch.float16) + + # CuTe Wrappers + tensor_src = ( + from_dlpack(input_data, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=16) + ) + tensor_dst = ( + from_dlpack(output_data, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=16) + ) + + transpose_kernel = Sm100MatrixTransposeKernelV1() + + print("Start kernel compilation...") + + # Compile and Run + compiled_kernel = cute.compile( + transpose_kernel, tensor_src, tensor_dst, options="--generate-line-info" + ) + + print("Start kernel warmup...") + # Warmup runs + for _ in range(num_warmup): + compiled_kernel(tensor_src, tensor_dst) + torch.cuda.synchronize() + print("Kernel warmup completed.") + + # Timed runs + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + compiled_kernel(tensor_src, tensor_dst) + end_event.record() + torch.cuda.synchronize() + + # Calculate performance metrics + elapsed_time_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_time_ms / num_iters + + # Calculate throughput + # For transpose: read M*N elements + write M*N elements + bytes_per_element = input_data.element_size() + total_bytes = 2 * M * N * bytes_per_element # Read + Write + throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000) + + # Theoretical bandwidth limit + # Blackwell: 2048 B/clk at 4000 MHz + bytes_per_clk = 2048 + freq_mhz = 4000 + theoretical_bw_gb_s = bytes_per_clk * freq_mhz * 1e6 / 1e9 # Convert to GB/s + theoretical_bw_tb_s = theoretical_bw_gb_s / 1000 # Convert to TB/s + bandwidth_efficiency = (throughput_gb_s / theoretical_bw_gb_s) * 100 # Percentage + + # Print performance metrics + print(f"Matrix size: {M}×{N}") + print(f"Tile shape: {transpose_kernel.tile_shape}") + print(f"Average time: {avg_time_ms:.4f} ms") + print(f"Throughput: {throughput_gb_s:.2f} GB/s") + print( + f"Theoretical BW: {theoretical_bw_tb_s:.2f} TB/s ({theoretical_bw_gb_s:.2f} GB/s)" + ) + print(f"Bandwidth Efficiency: {bandwidth_efficiency:.2f}%") + + # Verification + expected = input_data.t() + if torch.allclose(output_data, expected, atol=1e-2): + print("Verification: PASSED ✓") + else: + print("Verification: FAILED ✗") + print(f"Max diff: {(output_data - expected).abs().max()}") + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="TMA Matrix Transpose with Producer-Consumer Pattern" + ) + parser.add_argument("--M", type=int, default=128, help="Matrix dimension M") + parser.add_argument("--N", type=int, default=128, help="Matrix dimension N") + parser.add_argument( + "--num_warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--num_iters", type=int, default=20, help="Number of timing iterations" + ) + args = parser.parse_args() + + run_transpose( + args.M, + args.N, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + ) diff --git a/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v2.py b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v2.py new file mode 100644 index 000000000..9918e752c --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/tutorial/tutorial_tma/tma_v2.py @@ -0,0 +1,648 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Tuple, Type, Union +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync +from cutlass.cute.runtime import from_dlpack +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import torch + +""" +TMA Matrix Transpose with Multi-Stage Pipeline(v2) + +This version extends tma_v1.py with: +1. Multi-stage pipeline: Multiple buffers for pipelining TMA loads and stores +2. Pipeline abstraction: Using PipelineTmaAsync for proper producer-consumer coordination +3. Persistent tile scheduler: Efficient work distribution across CTAs + +Key Improvements over v1: +- Multi-stage buffers enable overlapping TMA loads, computation, and TMA stores +- Pipeline objects provide cleaner synchronization semantics + +Note: TMA multicast is NOT used because each CTA must process different input tiles. + +Warp Roles: +- Producer: TMA Load Warp - Loads from Global to Shared memory (multi-stage, multi-tile) +- Consumer: Transpose Warps - Wait for load, transpose sA -> sB (multi-tile) +- Consumer: TMA Store Warp - Wait for transpose, store sB to Global (multi-stage, multi-tile) + +Pipeline Stages: +1. Load Pipeline: TMA Load (producer) -> Transpose Warps (consumer) +2. Store Pipeline: Transpose Warps (producer) -> TMA Store (consumer) +""" + + +class Sm100MatrixTransposeKernelV2: + def __init__( + self, + ): + """ + Initialize the TMA transpose kernel with multi-stage pipeline support. + + Args: + tile_shape: Tile dimensions (M, N) + cluster_shape_mn: Cluster shape for parallel CTA execution (M, N) + + Note: + - Each CTA processes different tiles independently + - Stage counts are automatically computed based on available shared memory + - Persistent scheduler distributes work across CTAs in the cluster + """ + self.tile_shape = (128, 128) + self.tile_m, self.tile_n = self.tile_shape + self.cluster_shape_mn = (1, 1) + self.cluster_shape_mnl = (*self.cluster_shape_mn, 1) + + # Set specialized warp ids based on tile_shape + # For 128x128 tile, use 4 transpose warps (same as v1) + self.max_trans_warps = 4 # Maximum number of transpose warps + self.num_trans_warps = self.max_trans_warps # Use all transpose warps + self.trans_warp_id = tuple(range(self.num_trans_warps)) + self.tma_load_warp_id = self.num_trans_warps + self.threads_per_cta = 32 * len((self.tma_load_warp_id, *self.trans_warp_id)) + self.num_trans_threads = 32 * len(self.trans_warp_id) + + # Set barriers for producer-consumer sync + # Barrier 1: Trans warps sync (for internal coordination) + self.trans_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.trans_warp_id), + ) + self.buffer_align_bytes = 128 + + # Get shared memory capacity for stage computation + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + @staticmethod + def _compute_stages( + tile_m: int, + tile_n: int, + dtype: Type[cutlass.Numeric], + smem_capacity: int, + ) -> Tuple[int, int]: + """ + Compute the number of load and store stages based on shared memory capacity. + + Strategy: + 1. Calculate bytes per stage for load (sA) and store (sB) buffers + 2. Reserve space for barriers and alignment + 3. Divide remaining smem by bytes per stage to get max stages + 4. Clamp to reasonable min/max values + + Args: + tile_m: Tile dimension M + tile_n: Tile dimension N + dtype: Data type of the tensors + smem_capacity: Total shared memory capacity in bytes + + Returns: + Tuple of (num_load_stages, num_store_stages) + """ + # Calculate bytes per tile (assuming row-major and col-major layouts) + bytes_per_element = dtype.width // 8 + + # For sA (load buffer): tile_m x tile_n elements + sA_bytes_per_stage = tile_m * tile_n * bytes_per_element + + # For sB (store buffer): tile_m x tile_n elements (transposed) + sB_bytes_per_stage = tile_m * tile_n * bytes_per_element + + # Reserve space for barriers and other metadata + # Each barrier: 8 bytes (Int64) + # Estimate: max 16 barriers (load + store stages * 2) + alignment + reserved_bytes = 1024 # Conservative estimate + + # Available space for staging buffers + available_smem = smem_capacity - reserved_bytes + + # Calculate max stages we can fit + # We need space for both load and store stages + total_bytes_per_stage_pair = sA_bytes_per_stage + sB_bytes_per_stage + + # Max stages (same for load and store for simplicity) + max_stages = available_smem // total_bytes_per_stage_pair + + # Clamp to reasonable values + # Min: 2 stages for basic double buffering + # Max: 8 stages (diminishing returns beyond this) + num_stages = max(2, min(max_stages, 8)) + + return num_stages, num_stages + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mn: The shape (M, N) of the CTA tile. + :type cta_tile_shape_mn: tuple[int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mn, (None, None)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mn = gc[(0, (None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + num_ctas_mnl = (*num_ctas_mn, 1) + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @cute.jit + def __call__( + self, src: cute.Tensor, dst: cute.Tensor, max_active_clusters: cutlass.Constexpr + ): + if cutlass.const_expr(src.element_type != dst.element_type): + raise TypeError("Source and destination element types must match") + + self.dtype: Type[cutlass.Numeric] = src.element_type + + # Compute optimal stage counts based on tile size and dtype + self.num_load_stages, self.num_store_stages = self._compute_stages( + self.tile_m, + self.tile_n, + self.dtype, + self.smem_capacity, + ) + + # Create transposed view of dst for TMA descriptor + # dst is (N, M), we want to view it as (M, N) transposed + transed_dst = cute.make_tensor( + dst.iterator, + cute.make_layout( + (dst.shape[1], dst.shape[0]), stride=(dst.stride[1], dst.stride[0]) + ), + ) + + # Create multi-stage layouts for load and store buffers + # row-major smem layout for sA (tile_m, tile_n) + smem_layout_sA_staged = sm100_utils.make_smem_layout( + utils.LayoutEnum.from_tensor(src).mma_major_mode(), + (self.tile_m, self.tile_n), + self.dtype, + self.num_load_stages, + ) + + # col-major smem layout for sB (tile_n, tile_m) + # sB should match the transposed destination layout + smem_layout_sB_staged = sm100_utils.make_smem_layout( + utils.LayoutEnum.from_tensor(transed_dst).mma_major_mode(), + (self.tile_m, self.tile_n), + self.dtype, + self.num_store_stages, + ) + + @cute.struct + class SharedStorage: + # Pipeline barriers for multi-stage load + load_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_load_stages + ] + load_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_load_stages + ] + + # Pipeline barriers for multi-stage store + store_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_store_stages + ] + store_empty_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_store_stages + ] + + # Multi-stage shared memory buffers + sA: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sA_staged)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.dtype, cute.cosize(smem_layout_sB_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + a_smem_layout = cute.slice_(smem_layout_sA_staged, (None, None, 0)) + self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, a_smem_layout) + + # TMA Atoms + # Each CTA loads its own tile independently + tma_load_op = cpasync.CopyBulkTensorTileG2SOp() + + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), (1,) + ) + + tma_atom_src, tma_tensor_src = cpasync.make_tiled_tma_atom( + tma_load_op, + src, + smem_layout_sA_staged, + (self.tile_m, self.tile_n), + ) + + tma_atom_dst, tma_tensor_dst = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + transed_dst, + smem_layout_sB_staged, + (self.tile_m, self.tile_n), + ) + tile_sched_params, grid_shape = self._compute_grid( + transed_dst, + (self.tile_m, self.tile_n), + self.cluster_shape_mn, + max_active_clusters, + ) + + self.kernel( + tma_atom_src, + tma_tensor_src, + tma_atom_dst, + tma_tensor_dst, + smem_layout_sA_staged, + smem_layout_sB_staged, + cluster_layout_vmnk, + tile_sched_params, + ).launch( + grid=grid_shape, + block=(self.threads_per_cta, 1, 1), + cluster=self.cluster_shape_mnl, + ) + + @cute.kernel + def kernel( + self, + tma_atom_load: cute.CopyAtom, + tma_tensor_src: cute.Tensor, + tma_atom_store: cute.CopyAtom, + tma_tensor_dst: cute.Tensor, + smem_layout_sA_staged: Union[cute.Layout, cute.ComposedLayout], + smem_layout_sB_staged: Union[cute.Layout, cute.ComposedLayout], + cluster_layout_vmnk: cute.Layout, + tile_sched_params: utils.PersistentTileSchedulerParams, + ): + tidx, _, _ = cute.arch.thread_idx() + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # ---------------- Shared mem & staged buffers ---------------- + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + sA_staged = storage.sA.get_tensor( + smem_layout_sA_staged.outer, swizzle=smem_layout_sA_staged.inner + ) + sB_staged = storage.sB.get_tensor( + smem_layout_sB_staged.outer, swizzle=smem_layout_sB_staged.inner + ) + + load_mbar_ptr = storage.load_full_mbar_ptr.data_ptr() + _store_mbar_ptr = storage.store_full_mbar_ptr.data_ptr() + + load_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + load_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_trans_warps + ) + + load_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=load_mbar_ptr, + num_stages=self.num_load_stages, + producer_group=load_producer_group, + consumer_group=load_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_trans_threads + ) + store_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_store_stages, + producer_group=store_producer_group, + ) + + # Critical: Initialize pipeline barriers across cluster + # This must happen after pipeline creation and before any producer/consumer work + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + gA = cute.local_tile(tma_tensor_src, self.tile_shape, (None, None)) + + _cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, None, 0)).shape + ) + # ((TileM, TileK), loopM, LoopK) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_load, + 0, + cute.make_layout(1), + cute.group_modes(sA_staged, 0, 2), + cute.group_modes(gA, 0, 2), + ) + + gDst_cta = cute.local_tile(tma_tensor_dst, self.tile_shape, (None, None)) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_store, + 0, + cute.make_layout(1), + cute.group_modes(sB_staged, 0, 2), + cute.group_modes(gDst_cta, 0, 2), + ) + + # ------------------------------------------------------------------ + # PRODUCER: TMA Load Warp (G -> sA) + # ------------------------------------------------------------------ + if warp_idx == self.tma_load_warp_id: + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + load_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_load_stages + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + tAgA_slice = tAgA[(None, cur_tile_coord[0], cur_tile_coord[1])] + + load_pipeline.producer_acquire(load_producer_state) + + cute.copy( + tma_atom_load, + tAgA_slice, + tAsA[(None, load_producer_state.index)], + tma_bar_ptr=load_pipeline.producer_get_barrier(load_producer_state), + ) + + load_producer_state.advance() + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_pipeline.producer_tail(load_producer_state) + + # ------------------------------------------------------------------ + # CONSUMER: Transpose Warps (sA -> Reg -> sB) + # ------------------------------------------------------------------ + if warp_idx < self.tma_load_warp_id: + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + trans_tid = tidx % self.num_trans_threads + + load_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_load_stages + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + + # Wait for load pipeline to have data + load_pipeline.consumer_wait(load_consumer_state) + + atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=self.dtype.width, + ) + copy_elems = 1 + + thread_layout = cute.make_layout( + (self.num_trans_threads, 1), + stride=(1, self.num_trans_threads), + ) + value_layout = cute.make_layout((1, copy_elems)) + tiled_copy = cute.make_tiled_copy_tv(atom, thread_layout, value_layout) + thr_copy = tiled_copy.get_slice(trans_tid) + + # sA -> Reg + tCsA = thr_copy.partition_S(sA_staged) + + tCrA = cute.make_rmem_tensor( + tCsA[(None, None, None, 0)].shape, self.dtype + ) + tCrA = tiled_copy.retile(tCrA) + + cute.copy( + tiled_copy, + tCsA[(None, None, None, load_consumer_state.index)], + tCrA, + ) + + # release load pipeline + load_pipeline.consumer_release(load_consumer_state) + load_consumer_state.advance() + + index = tile_sched.num_tiles_executed % self.num_store_stages + # Reg -> sB + tCsB = thr_copy.partition_D(sB_staged) + cute.copy(tiled_copy, tCrA, tCsB[(None, None, None, index)]) + + # Fence to ensure smem writes are visible + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + self.trans_sync_barrier.arrive_and_wait() + + if warp_idx == self.trans_warp_id[0]: + cute.copy( + tma_atom_store, + tBsB[(None, index)], + tBgB[(None, cur_tile_coord[0], cur_tile_coord[1])], + ) + store_pipeline.producer_commit() + store_pipeline.producer_acquire() + self.trans_sync_barrier.arrive_and_wait() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + self.trans_sync_barrier.arrive_and_wait() + + store_pipeline.producer_tail() + + +def run_transpose(M, N, max_active_clusters=0, num_warmup=5, num_iters=20): + """ + Run TMA transpose kernel with automatic stage calculation and performance measurement. + + Args: + M: Matrix dimension M + N: Matrix dimension N + max_active_clusters: Maximum number of active clusters (0 for auto) + num_warmup: Number of warmup iterations + num_iters: Number of timing iterations + + Performance Metrics: + - Throughput: Actual achieved bandwidth in GB/s + - Theoretical BW: Peak memory bandwidth (2048 B/clk × 4000 MHz = 8.192 TB/s) + - Bandwidth Efficiency: Percentage of theoretical peak achieved + """ + torch.manual_seed(1111) + # Input (M, N) + input_data = torch.randn((M, N), device="cuda", dtype=torch.float16) + # Output (N, M) + output_data = torch.zeros((N, M), device="cuda", dtype=torch.float16) + + # CuTe Wrappers + tensor_src = ( + from_dlpack(input_data, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=16) + ) + tensor_dst = ( + from_dlpack(output_data, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=16) + ) + + transpose_kernel = Sm100MatrixTransposeKernelV2() + + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(1) + # Compile and Run + compiled_kernel = cute.compile( + transpose_kernel, + tensor_src, + tensor_dst, + max_active_clusters, + options="--generate-line-info", + ) + + # Warmup runs + for _ in range(num_warmup): + compiled_kernel(tensor_src, tensor_dst) + torch.cuda.synchronize() + + # Timed runs + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(num_iters): + compiled_kernel(tensor_src, tensor_dst) + end_event.record() + torch.cuda.synchronize() + + # Calculate performance metrics + elapsed_time_ms = start_event.elapsed_time(end_event) + avg_time_ms = elapsed_time_ms / num_iters + + # Calculate throughput + # For transpose: read M*N elements + write M*N elements + bytes_per_element = input_data.element_size() + total_bytes = 2 * M * N * bytes_per_element # Read + Write + throughput_gb_s = (total_bytes / 1e9) / (avg_time_ms / 1000) + + # Theoretical bandwidth limit + # Blackwell: 2048 B/clk at 4000 MHz + bytes_per_clk = 2048 + freq_mhz = 4000 + theoretical_bw_gb_s = bytes_per_clk * freq_mhz * 1e6 / 1e9 # Convert to GB/s + theoretical_bw_tb_s = theoretical_bw_gb_s / 1000 # Convert to TB/s + bandwidth_efficiency = (throughput_gb_s / theoretical_bw_gb_s) * 100 # Percentage + + # Print computed stage counts after compilation + print(f"Matrix size: {M}×{N}") + print(f"Tile shape: {transpose_kernel.tile_shape}") + print( + f"Computed stages: Load={transpose_kernel.num_load_stages}, Store={transpose_kernel.num_store_stages}" + ) + print(f"Average time: {avg_time_ms:.4f} ms") + print(f"Throughput: {throughput_gb_s:.2f} GB/s") + print( + f"Theoretical BW: {theoretical_bw_tb_s:.2f} TB/s ({theoretical_bw_gb_s:.2f} GB/s)" + ) + print(f"Bandwidth Efficiency: {bandwidth_efficiency:.2f}%") + + # Verification + expected = input_data.t() + if torch.allclose(output_data, expected, atol=1e-2): + print("Verification: PASSED ✓") + else: + print("Verification: FAILED ✗") + print(f"Max diff: {(output_data - expected).abs().max()}") + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="TMA Matrix Transpose with Multi-Stage Pipeline and Cluster Support (v2)" + ) + parser.add_argument("--M", type=int, default=128, help="Matrix dimension M") + parser.add_argument("--N", type=int, default=128, help="Matrix dimension N") + parser.add_argument( + "--num_warmup", type=int, default=5, help="Number of warmup iterations" + ) + parser.add_argument( + "--num_iters", type=int, default=20, help="Number of timing iterations" + ) + args = parser.parse_args() + + run_transpose( + args.M, + args.N, + num_warmup=args.num_warmup, + num_iters=args.num_iters, + ) diff --git a/examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py b/examples/python/CuTeDSL/cute/blackwell_geforce/kernel/dense_gemm/dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/blackwell_geforce/dense_gemm.py rename to examples/python/CuTeDSL/cute/blackwell_geforce/kernel/dense_gemm/dense_gemm.py diff --git a/examples/python/CuTeDSL/hopper/fmha.py b/examples/python/CuTeDSL/cute/hopper/kernel/attention/fmha.py similarity index 99% rename from examples/python/CuTeDSL/hopper/fmha.py rename to examples/python/CuTeDSL/cute/hopper/kernel/attention/fmha.py index 1892e848c..6a4c35e59 100644 --- a/examples/python/CuTeDSL/hopper/fmha.py +++ b/examples/python/CuTeDSL/cute/hopper/kernel/attention/fmha.py @@ -100,7 +100,7 @@ from cutlass.cute.runtime import from_dlpack if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) - sys.path.insert(0, os.path.join(current_dir, "..")) + sys.path.insert(0, os.path.join(current_dir, "../../../..")) from helpers import fmha_helpers as fmha_utils diff --git a/examples/python/CuTeDSL/hopper/cta_norm.py b/examples/python/CuTeDSL/cute/hopper/kernel/cta_norm/cta_norm.py similarity index 100% rename from examples/python/CuTeDSL/hopper/cta_norm.py rename to examples/python/CuTeDSL/cute/hopper/kernel/cta_norm/cta_norm.py diff --git a/examples/python/CuTeDSL/hopper/dense_gemm.py b/examples/python/CuTeDSL/cute/hopper/kernel/dense_gemm/dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/hopper/dense_gemm.py rename to examples/python/CuTeDSL/cute/hopper/kernel/dense_gemm/dense_gemm.py diff --git a/examples/python/CuTeDSL/hopper/dense_gemm_fp8_2xacc.py b/examples/python/CuTeDSL/cute/hopper/kernel/dense_gemm/dense_gemm_fp8_2xacc.py similarity index 100% rename from examples/python/CuTeDSL/hopper/dense_gemm_fp8_2xacc.py rename to examples/python/CuTeDSL/cute/hopper/kernel/dense_gemm/dense_gemm_fp8_2xacc.py diff --git a/examples/python/CuTeDSL/hopper/dense_gemm_persistent.py b/examples/python/CuTeDSL/cute/hopper/kernel/dense_gemm/dense_gemm_persistent.py similarity index 100% rename from examples/python/CuTeDSL/hopper/dense_gemm_persistent.py rename to examples/python/CuTeDSL/cute/hopper/kernel/dense_gemm/dense_gemm_persistent.py diff --git a/examples/python/CuTeDSL/hopper/grouped_gemm.py b/examples/python/CuTeDSL/cute/hopper/kernel/grouped_gemm/grouped_gemm.py similarity index 100% rename from examples/python/CuTeDSL/hopper/grouped_gemm.py rename to examples/python/CuTeDSL/cute/hopper/kernel/grouped_gemm/grouped_gemm.py diff --git a/examples/python/CuTeDSL/notebooks/README.md b/examples/python/CuTeDSL/cute/notebooks/README.md similarity index 100% rename from examples/python/CuTeDSL/notebooks/README.md rename to examples/python/CuTeDSL/cute/notebooks/README.md diff --git a/examples/python/CuTeDSL/notebooks/async_pipeline.ipynb b/examples/python/CuTeDSL/cute/notebooks/async_pipeline.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/async_pipeline.ipynb rename to examples/python/CuTeDSL/cute/notebooks/async_pipeline.ipynb diff --git a/examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb b/examples/python/CuTeDSL/cute/notebooks/benchmark_autotune.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/benchmark_autotune.ipynb rename to examples/python/CuTeDSL/cute/notebooks/benchmark_autotune.ipynb diff --git a/examples/python/CuTeDSL/notebooks/composed_layout.ipynb b/examples/python/CuTeDSL/cute/notebooks/composed_layout.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/composed_layout.ipynb rename to examples/python/CuTeDSL/cute/notebooks/composed_layout.ipynb diff --git a/examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb b/examples/python/CuTeDSL/cute/notebooks/cuda_graphs.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/cuda_graphs.ipynb rename to examples/python/CuTeDSL/cute/notebooks/cuda_graphs.ipynb diff --git a/examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb b/examples/python/CuTeDSL/cute/notebooks/cute_layout_algebra.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/cute_layout_algebra.ipynb rename to examples/python/CuTeDSL/cute/notebooks/cute_layout_algebra.ipynb diff --git a/examples/python/CuTeDSL/notebooks/data_types.ipynb b/examples/python/CuTeDSL/cute/notebooks/data_types.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/data_types.ipynb rename to examples/python/CuTeDSL/cute/notebooks/data_types.ipynb diff --git a/examples/python/CuTeDSL/notebooks/elementwise_add.ipynb b/examples/python/CuTeDSL/cute/notebooks/elementwise_add.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/elementwise_add.ipynb rename to examples/python/CuTeDSL/cute/notebooks/elementwise_add.ipynb diff --git a/examples/python/CuTeDSL/notebooks/hello_world.ipynb b/examples/python/CuTeDSL/cute/notebooks/hello_world.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/hello_world.ipynb rename to examples/python/CuTeDSL/cute/notebooks/hello_world.ipynb diff --git a/examples/python/CuTeDSL/notebooks/images/blocked_gemm.svg b/examples/python/CuTeDSL/cute/notebooks/images/blocked_gemm.svg similarity index 100% rename from examples/python/CuTeDSL/notebooks/images/blocked_gemm.svg rename to examples/python/CuTeDSL/cute/notebooks/images/blocked_gemm.svg diff --git a/examples/python/CuTeDSL/notebooks/images/cuda_graphs_image.png b/examples/python/CuTeDSL/cute/notebooks/images/cuda_graphs_image.png similarity index 100% rename from examples/python/CuTeDSL/notebooks/images/cuda_graphs_image.png rename to examples/python/CuTeDSL/cute/notebooks/images/cuda_graphs_image.png diff --git a/examples/python/CuTeDSL/notebooks/images/software_pipelining_ab_stages_minus_1.svg b/examples/python/CuTeDSL/cute/notebooks/images/software_pipelining_ab_stages_minus_1.svg similarity index 100% rename from examples/python/CuTeDSL/notebooks/images/software_pipelining_ab_stages_minus_1.svg rename to examples/python/CuTeDSL/cute/notebooks/images/software_pipelining_ab_stages_minus_1.svg diff --git a/examples/python/CuTeDSL/notebooks/images/software_pipelining_ab_stages_minus_2.svg b/examples/python/CuTeDSL/cute/notebooks/images/software_pipelining_ab_stages_minus_2.svg similarity index 100% rename from examples/python/CuTeDSL/notebooks/images/software_pipelining_ab_stages_minus_2.svg rename to examples/python/CuTeDSL/cute/notebooks/images/software_pipelining_ab_stages_minus_2.svg diff --git a/examples/python/CuTeDSL/notebooks/print.ipynb b/examples/python/CuTeDSL/cute/notebooks/print.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/print.ipynb rename to examples/python/CuTeDSL/cute/notebooks/print.ipynb diff --git a/examples/python/CuTeDSL/notebooks/tensor.ipynb b/examples/python/CuTeDSL/cute/notebooks/tensor.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/tensor.ipynb rename to examples/python/CuTeDSL/cute/notebooks/tensor.ipynb diff --git a/examples/python/CuTeDSL/notebooks/tensorssa.ipynb b/examples/python/CuTeDSL/cute/notebooks/tensorssa.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/tensorssa.ipynb rename to examples/python/CuTeDSL/cute/notebooks/tensorssa.ipynb diff --git a/examples/python/CuTeDSL/notebooks/tour_to_sol_gemm.ipynb b/examples/python/CuTeDSL/cute/notebooks/tour_to_sol_gemm.ipynb similarity index 100% rename from examples/python/CuTeDSL/notebooks/tour_to_sol_gemm.ipynb rename to examples/python/CuTeDSL/cute/notebooks/tour_to_sol_gemm.ipynb diff --git a/examples/python/CuTeDSL/experimental/ampere/memcpy_simt_universal_copy.py b/examples/python/CuTeDSL/cute_ext/ampere/memcpy_simt_universal_copy.py similarity index 100% rename from examples/python/CuTeDSL/experimental/ampere/memcpy_simt_universal_copy.py rename to examples/python/CuTeDSL/cute_ext/ampere/memcpy_simt_universal_copy.py diff --git a/examples/python/CuTeDSL/experimental/blackwell/dense_block_scaled_gemm.py b/examples/python/CuTeDSL/cute_ext/blackwell/dense_block_scaled_gemm.py similarity index 100% rename from examples/python/CuTeDSL/experimental/blackwell/dense_block_scaled_gemm.py rename to examples/python/CuTeDSL/cute_ext/blackwell/dense_block_scaled_gemm.py diff --git a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm.py b/examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm.py similarity index 100% rename from examples/python/CuTeDSL/experimental/blackwell/dense_gemm.py rename to examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm.py diff --git a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_2sm.py b/examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_2sm.py similarity index 100% rename from examples/python/CuTeDSL/experimental/blackwell/dense_gemm_2sm.py rename to examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_2sm.py diff --git a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_cute_pipeline.py b/examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_cute_pipeline.py similarity index 100% rename from examples/python/CuTeDSL/experimental/blackwell/dense_gemm_cute_pipeline.py rename to examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_cute_pipeline.py diff --git a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_ptr_array.py b/examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_ptr_array.py similarity index 78% rename from examples/python/CuTeDSL/experimental/blackwell/dense_gemm_ptr_array.py rename to examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_ptr_array.py index bd4e09c19..066bd9d6f 100755 --- a/examples/python/CuTeDSL/experimental/blackwell/dense_gemm_ptr_array.py +++ b/examples/python/CuTeDSL/cute_ext/blackwell/dense_gemm_ptr_array.py @@ -1,4 +1,4 @@ -# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without @@ -37,10 +37,13 @@ from cutlass.base_dsl.typing import Numeric from cutlass import cute as cute from cutlass import utils from cutlass import torch as cutlass_torch +from cutlass.cute.experimental.host_runtime import QueryDeviceWorkspaceFunc +from cutlass.cute.runtime import from_dlpack import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.cute.testing as testing + class DenseGemmPtrArrayKernel: def __init__( self, @@ -115,7 +118,6 @@ class DenseGemmPtrArrayKernel: # Get pointers for the first batch to perform shape and stage calculations A_0_ptr = self._get_pointer(mA_tensor[0], self.ab_dtype) B_0_ptr = self._get_pointer(mB_tensor[0], self.ab_dtype) - D_0_ptr = self._get_pointer(mD_tensor[0], self.d_dtype) mA = cute.make_tensor( A_0_ptr, layout=cute.make_layout(self.A_shape, stride=self.A_stride) @@ -125,11 +127,8 @@ class DenseGemmPtrArrayKernel: B_0_ptr, layout=cute.make_layout(self.B_shape, stride=self.B_stride) ) - mD = cute.make_tensor( - D_0_ptr, layout=cute.make_layout(self.D_shape, stride=self.D_stride) - ) - tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.ab_dtype, self.ab_dtype, utils.LayoutEnum.from_tensor(mA).mma_major_mode(), utils.LayoutEnum.from_tensor(mB).mma_major_mode(), @@ -146,26 +145,19 @@ class DenseGemmPtrArrayKernel: mma_inst_shape_k * mma_inst_tile_k, ) - d_layout = utils.LayoutEnum.from_tensor(mD) - d_dtype = mD.element_type - tiler_mk = (mnk_tiler[0], mnk_tiler[2]) tiler_nk = (mnk_tiler[1], mnk_tiler[2]) - tiler_mn = (mnk_tiler[0], mnk_tiler[1]) gA = cute.zipped_divide(mA, tiler_mk) gB = cute.zipped_divide(mB, tiler_nk) - gD = cute.zipped_divide(mD, tiler_mn) mainloop_stage = 2 acc_stage = 2 cta_m, cta_n, cta_l = cute.arch.block_idx() - tid_x, _, _ = cute.arch.thread_idx() gA_tile = gA[(None, None), (cta_m, None, cta_l)] gB_tile = gB[(None, None), (cta_n, None, cta_l)] - gD_tile = gD[(None, None), (cta_m, cta_n, cta_l)] # Compute A/B/C shared memory layout a_smem_layout_staged = sm100_utils.make_smem_layout_a( @@ -184,18 +176,6 @@ class DenseGemmPtrArrayKernel: cta_tile_shape_mnk = cute.shape_div( mnk_tiler, (cute.size(tiled_mma.thr_id.shape), 1, 1) ) - epi_tile = sm100_utils.compute_epilogue_tile_shape( - cta_tile_shape_mnk, - self.use_2cta_instrs, - d_layout, - d_dtype, - ) - sc_smem_layout_staged = sm100_utils.make_smem_layout_epi( - d_dtype, - d_layout, - epi_tile, - self.TMA_STORE_STAGE, - ) # UMMA ACC TMEM Layout tmem_layout = cute_ext.make_tmem_layout_acc(tiled_mma, mnk_tiler, acc_stage) @@ -222,51 +202,6 @@ class DenseGemmPtrArrayKernel: alignment=16, ) - # Allocate SMEM buffer for C - bufferC = cute_ext.allocate( - d_dtype, - cute.AddressSpace.smem, - sc_smem_layout_staged, - alignment=1024, - ) - - # Create the TMEM load atom - copy_atom_t2r = sm100_utils.get_tmem_load_op( - cta_tile_shape_mnk, - d_layout, - self.tmem_output_dtype, - self.acc_dtype, - epi_tile, - self.use_2cta_instrs, - ) - - # Take only one stage of the TMEM buffer - accumulators = cute.zipped_divide(bufferAcc, ((epi_tile), 1)) - acc_epi_div = accumulators[((None, None), 0), 0] - - # Create the TMEM copy atom based on the size of transfer within one iteration of epilogue - tiled_copy_t2r = cute.nvgpu.tcgen05.make_tmem_copy(copy_atom_t2r, acc_epi_div) - - # Calculate the per thread destination size per iteration for output of TMEM and input of SMEM - gC_mnl_epi = cute.flat_divide(gD_tile, epi_tile) - acc_d_rmem_layout = cute_ext.make_t2r_rmem_layout( - tiled_copy_t2r, gC_mnl_epi, tid_x - ) - - # Allocate RMEM buffers - bufferRAcc = cute_ext.allocate( - self.acc_dtype, - cute.AddressSpace.rmem, - acc_d_rmem_layout, - alignment=32, - ) - bufferRD = cute_ext.allocate( - d_dtype, - cute.AddressSpace.rmem, - acc_d_rmem_layout, - alignment=32, - ) - # TMA -> UMMA mainloop_pipe = cute_ext.TMAToUMMAPipeline.create( num_stages=mainloop_stage, @@ -363,42 +298,21 @@ class DenseGemmPtrArrayKernel: # MMA section remains same as a regular GEMM if is_mma_thr: producer_stage_token, idx = acc_pipe.producer_acquire_and_get_stage() - ## acc_producer_body begin ## accumulators_sliced = bufferAcc[None, None, None, idx] - mma_atom = cute.make_mma_atom(tiled_mma.op) - mma_atom.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, False) - for k_tile in cutlass.range(0, k_tile_size, 1, unroll=1): - # Scoped state management - pipeline object manages consumer state internally - ( - _, - mainloop_idx, - ) = mainloop_pipe.consumer_wait_and_get_stage() - ## tma_consumer_body begin ## + (updated_a_pipe, _updated_b_pipe) = cute_ext.mainloop_mma( + tiled_mma, + bufferA, + bufferB, + accumulators_sliced, + 0, + k_tile_size, + mma_inst_tile_k, + mainloop_pipe, + mainloop_pipe, + ) + mainloop_pipe = updated_a_pipe - bufferA_sliced_stage = cute.core.slice_( - bufferA, (None, None, None, mainloop_idx) - ) - bufferB_sliced_stage = cute.core.slice_( - bufferB, (None, None, None, mainloop_idx) - ) - - for k_block in cutlass.range(mma_inst_tile_k, unroll_full=True): - bufferA_sliced = bufferA_sliced_stage[None, None, k_block] - bufferB_sliced = bufferB_sliced_stage[None, None, k_block] - - cute_ext.dot( - mma_atom, - cute.append_ones(bufferA_sliced, up_to_rank=3), - cute.append_ones(bufferB_sliced, up_to_rank=3), - accumulators_sliced, - ) - mma_atom.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) - - ## tma_consumer_body end ## - mainloop_pipe.consumer_release_and_advance() - - ## acc_producer_body end ## acc_pipe.producer_commit_and_advance() if is_epi_thr: @@ -406,57 +320,22 @@ class DenseGemmPtrArrayKernel: mD = cute.make_tensor( ptr_D, layout=cute.make_layout(self.D_shape, stride=self.D_stride) ) - gD = cute.zipped_divide(mD, tiler_mn) - gD_tile = gD[(None, None), (cta_m, cta_n, cta_l)] - gC_mnl_epi = cute.flat_divide(gD_tile, epi_tile) + _, idx = acc_pipe.consumer_wait_and_get_stage() - ## acc_consume_body begin ## accumulators_sliced = bufferAcc[(None, None), 0, 0, idx] - acc_epi_div_tiled = cute.flat_divide(accumulators_sliced, epi_tile) + cta_d_tile_coord = (cta_m, cta_n, cta_l) - subtile_cnt = cute.size(acc_epi_div_tiled.shape, mode=[3]) - for mn in range(subtile_cnt): - # TMEM -> RMEM - cute_ext.partition_and_copy( - tiled_copy_t2r.get_slice(tid_x), - acc_epi_div_tiled[None, None, 0, mn], - bufferRAcc, - ) + tma_store_pipe = cute_ext.epilogue_tma_store( + cta_tile_shape_mnk, + self.use_2cta_instrs, + accumulators_sliced, + mD, + cta_d_tile_coord, + tma_store_pipe, + tma_store_warp_id, + self.epilogue_op, + ) - # RMEM -> RMEM - bufferRD.store(self.epilogue_op(bufferRAcc.load().to(self.d_dtype))) - - # Acquire pipeline stage and synchronize before RMEM->SMEM copy - tma_store_pipe.acquire_sync() - idx = tma_store_pipe.get_index() - - # RMEM -> SMEM - tiled_copy_r2s = cute.make_tiled_copy_D( - cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.d_dtype), - tiled_copy_t2r, - ) - cute_ext.partition_and_copy( - tiled_copy_r2s.get_slice(tid_x), - bufferRD, - bufferC[None, None, idx], - ) - - # Fence SMEM writes and synchronize before TMA store - tma_store_pipe.commit_sync() - - # SMEM -> GMEM (only designated TMA store warp performs TMA store) - if warp_idx == tma_store_warp_id: - c_cta_v_map = cute_ext.get_cta_v_map_c(mD, epi_tile) - cute_ext.tma_store( - bufferC[None, None, idx], - gC_mnl_epi[None, None, 0, mn], - cta_v_map=c_cta_v_map, - ) - - # Release pipeline stage and advance - tma_store_pipe.release_advance() - - tma_store_pipe.tail() acc_pipe.consumer_release_and_advance() @@ -639,6 +518,19 @@ def run( d_major = c_major d_dtype = c_dtype + sm100_utils.check_gemm_tma_alignment( + m, + n, + k, + ab_dtype, + ab_dtype, + d_dtype, + a_major, + b_major, + d_major, + output_tensor_name="D", + ) + # a_tensor, b_tensor, d_tensor are cute Tensors where each element is an Int64 pointer to global memory # A_cutes, B_cutes, D_cutes are lists of cute Tensors for each batch of A/B/D ( @@ -674,7 +566,15 @@ def run( compiled_dense_gemm = cute_ext.compile( ptr_array_dense_gemm, a_tensor, b_tensor, d_tensor ) - compiled_dense_gemm(a_tensor, b_tensor, d_tensor) + + query = compiled_dense_gemm.get_aux_func( + QueryDeviceWorkspaceFunc, kernel=ptr_array_dense_gemm.kernel + ) + req = query(a_tensor, b_tensor, d_tensor) + workspace = torch.empty(req.size_in_bytes, dtype=torch.uint8, device="cuda") + workspace_cute = from_dlpack(workspace) + + compiled_dense_gemm(a_tensor, b_tensor, d_tensor, workspace_cute) if not skip_ref_check: for batch_idx in range(l): @@ -704,7 +604,11 @@ def run( ) = create_tensors_for_ptr_array( l, m, n, k, a_major, b_major, d_major, ab_dtype, d_dtype ) - args = testing.JitArguments(a_tensor, b_tensor, d_tensor) + + ws = torch.empty(req.size_in_bytes, dtype=torch.uint8, device="cuda") + ws_cute = from_dlpack(ws) + + args = testing.JitArguments(a_tensor, b_tensor, d_tensor, ws_cute) args.add_to_scope([A_cutes, B_cutes, D_cutes]) return args diff --git a/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py b/examples/python/CuTeDSL/dsl_tutorials/call_bypass_dlpack.py similarity index 98% rename from examples/python/CuTeDSL/ampere/call_bypass_dlpack.py rename to examples/python/CuTeDSL/dsl_tutorials/call_bypass_dlpack.py index 1a21d37a6..291ed5ef9 100644 --- a/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py +++ b/examples/python/CuTeDSL/dsl_tutorials/call_bypass_dlpack.py @@ -77,7 +77,7 @@ if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(current_dir, "..")) -from ampere.tensorop_gemm import TensorOpGemm +from cute.ampere.kernel.dense_gemm.tensorop_gemm import TensorOpGemm @cute.jit diff --git a/examples/python/CuTeDSL/ampere/call_from_jit.py b/examples/python/CuTeDSL/dsl_tutorials/call_from_jit.py similarity index 99% rename from examples/python/CuTeDSL/ampere/call_from_jit.py rename to examples/python/CuTeDSL/dsl_tutorials/call_from_jit.py index f4fa4339d..8aa147a42 100644 --- a/examples/python/CuTeDSL/ampere/call_from_jit.py +++ b/examples/python/CuTeDSL/dsl_tutorials/call_from_jit.py @@ -68,7 +68,7 @@ if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(current_dir, "..")) -from ampere.tensorop_gemm import TensorOpGemm +from cute.ampere.kernel.dense_gemm.tensorop_gemm import TensorOpGemm class BufferWithLayout: diff --git a/examples/python/CuTeDSL/ampere/cooperative_launch.py b/examples/python/CuTeDSL/dsl_tutorials/cooperative_launch.py similarity index 100% rename from examples/python/CuTeDSL/ampere/cooperative_launch.py rename to examples/python/CuTeDSL/dsl_tutorials/cooperative_launch.py diff --git a/examples/python/CuTeDSL/cute/dataclass_immutable.py b/examples/python/CuTeDSL/dsl_tutorials/dataclass_immutable.py similarity index 100% rename from examples/python/CuTeDSL/cute/dataclass_immutable.py rename to examples/python/CuTeDSL/dsl_tutorials/dataclass_immutable.py diff --git a/examples/python/CuTeDSL/ampere/dynamic_smem_size.py b/examples/python/CuTeDSL/dsl_tutorials/dynamic_smem_size.py similarity index 100% rename from examples/python/CuTeDSL/ampere/dynamic_smem_size.py rename to examples/python/CuTeDSL/dsl_tutorials/dynamic_smem_size.py diff --git a/examples/python/CuTeDSL/cute/export/export_to_c.py b/examples/python/CuTeDSL/dsl_tutorials/export/export_to_c.py similarity index 100% rename from examples/python/CuTeDSL/cute/export/export_to_c.py rename to examples/python/CuTeDSL/dsl_tutorials/export/export_to_c.py diff --git a/examples/python/CuTeDSL/cute/export/load_in_python.py b/examples/python/CuTeDSL/dsl_tutorials/export/load_in_python.py similarity index 100% rename from examples/python/CuTeDSL/cute/export/load_in_python.py rename to examples/python/CuTeDSL/dsl_tutorials/export/load_in_python.py diff --git a/examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.cpp b/examples/python/CuTeDSL/dsl_tutorials/export/run_with_dynamic_loading.cpp similarity index 100% rename from examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.cpp rename to examples/python/CuTeDSL/dsl_tutorials/export/run_with_dynamic_loading.cpp diff --git a/examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.sh b/examples/python/CuTeDSL/dsl_tutorials/export/run_with_dynamic_loading.sh similarity index 100% rename from examples/python/CuTeDSL/cute/export/run_with_dynamic_loading.sh rename to examples/python/CuTeDSL/dsl_tutorials/export/run_with_dynamic_loading.sh diff --git a/examples/python/CuTeDSL/cute/export/run_with_static_linking.cpp b/examples/python/CuTeDSL/dsl_tutorials/export/run_with_static_linking.cpp similarity index 100% rename from examples/python/CuTeDSL/cute/export/run_with_static_linking.cpp rename to examples/python/CuTeDSL/dsl_tutorials/export/run_with_static_linking.cpp diff --git a/examples/python/CuTeDSL/cute/export/run_with_static_linking.sh b/examples/python/CuTeDSL/dsl_tutorials/export/run_with_static_linking.sh similarity index 100% rename from examples/python/CuTeDSL/cute/export/run_with_static_linking.sh rename to examples/python/CuTeDSL/dsl_tutorials/export/run_with_static_linking.sh diff --git a/examples/python/CuTeDSL/cute/ffi/CMakeLists.txt b/examples/python/CuTeDSL/dsl_tutorials/ffi/CMakeLists.txt similarity index 100% rename from examples/python/CuTeDSL/cute/ffi/CMakeLists.txt rename to examples/python/CuTeDSL/dsl_tutorials/ffi/CMakeLists.txt diff --git a/examples/python/CuTeDSL/cute/ffi/jit_argument.py b/examples/python/CuTeDSL/dsl_tutorials/ffi/jit_argument.py similarity index 100% rename from examples/python/CuTeDSL/cute/ffi/jit_argument.py rename to examples/python/CuTeDSL/dsl_tutorials/ffi/jit_argument.py diff --git a/examples/python/CuTeDSL/cute/ffi/tensor.cpp b/examples/python/CuTeDSL/dsl_tutorials/ffi/tensor.cpp similarity index 100% rename from examples/python/CuTeDSL/cute/ffi/tensor.cpp rename to examples/python/CuTeDSL/dsl_tutorials/ffi/tensor.cpp diff --git a/examples/python/CuTeDSL/ampere/inline_ptx.py b/examples/python/CuTeDSL/dsl_tutorials/inline_ptx.py similarity index 98% rename from examples/python/CuTeDSL/ampere/inline_ptx.py rename to examples/python/CuTeDSL/dsl_tutorials/inline_ptx.py index 0eeedeb36..9f66b067c 100644 --- a/examples/python/CuTeDSL/ampere/inline_ptx.py +++ b/examples/python/CuTeDSL/dsl_tutorials/inline_ptx.py @@ -30,9 +30,10 @@ from functools import partial from typing import Union import cutlass.cute as cute +from cutlass import Constexpr from cutlass.cute.runtime import from_dlpack from cutlass._mlir.dialects import llvm -from cutlass.cute.typing import Boolean, Int32, Int, Constexpr +from cutlass.cute.typing import Boolean, Int32, Int from cutlass.cutlass_dsl import T, dsl_user_op from cutlass.cute.arch.nvvm_wrappers import FULL_MASK, WARP_SIZE @@ -42,7 +43,7 @@ A simple example to show how to wrap PTX instructions by using inline_asm op in Situations like: 1. Instructions that are not already exposed by CuTe DSL via `nvvm` module -2. Sequences of instructions that the compiler otherwise does not generate optimally +2. Sequences of instructions that the compiler otherwise does not generate optimally motivate developers to inline PTX themselves. @@ -57,7 +58,7 @@ To run this example: .. code-block:: bash - python examples/ampere/inline_ptx.py + python examples/dsl/inline_ptx.py The example will run the vote kernel with inline ptx and nvvm dialect separately. The results from inline ptx and nvvm dialect will be verified correspondingly. diff --git a/examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb b/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax.ipynb similarity index 100% rename from examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb rename to examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax.ipynb diff --git a/examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py b/examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax_kernels.py similarity index 100% rename from examples/python/CuTeDSL/jax/cute_dsl_jax_kernels.py rename to examples/python/CuTeDSL/dsl_tutorials/jax/cute_dsl_jax_kernels.py diff --git a/examples/python/CuTeDSL/jax/cutlass_call_basic.py b/examples/python/CuTeDSL/dsl_tutorials/jax/cutlass_call_basic.py similarity index 100% rename from examples/python/CuTeDSL/jax/cutlass_call_basic.py rename to examples/python/CuTeDSL/dsl_tutorials/jax/cutlass_call_basic.py diff --git a/examples/python/CuTeDSL/jax/cutlass_call_export.py b/examples/python/CuTeDSL/dsl_tutorials/jax/cutlass_call_export.py similarity index 100% rename from examples/python/CuTeDSL/jax/cutlass_call_export.py rename to examples/python/CuTeDSL/dsl_tutorials/jax/cutlass_call_export.py diff --git a/examples/python/CuTeDSL/jax/cutlass_call_sharding.py b/examples/python/CuTeDSL/dsl_tutorials/jax/cutlass_call_sharding.py similarity index 100% rename from examples/python/CuTeDSL/jax/cutlass_call_sharding.py rename to examples/python/CuTeDSL/dsl_tutorials/jax/cutlass_call_sharding.py diff --git a/examples/python/CuTeDSL/jax/elementwise_apply_example.py b/examples/python/CuTeDSL/dsl_tutorials/jax/elementwise_apply_example.py similarity index 100% rename from examples/python/CuTeDSL/jax/elementwise_apply_example.py rename to examples/python/CuTeDSL/dsl_tutorials/jax/elementwise_apply_example.py diff --git a/examples/python/CuTeDSL/cute/print_latex.py b/examples/python/CuTeDSL/dsl_tutorials/print_latex.py similarity index 100% rename from examples/python/CuTeDSL/cute/print_latex.py rename to examples/python/CuTeDSL/dsl_tutorials/print_latex.py diff --git a/examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py b/examples/python/CuTeDSL/dsl_tutorials/programmatic_dependent_launch.py similarity index 100% rename from examples/python/CuTeDSL/blackwell/programmatic_dependent_launch.py rename to examples/python/CuTeDSL/dsl_tutorials/programmatic_dependent_launch.py diff --git a/examples/python/CuTeDSL/ampere/smem_allocator.py b/examples/python/CuTeDSL/dsl_tutorials/smem_allocator.py similarity index 100% rename from examples/python/CuTeDSL/ampere/smem_allocator.py rename to examples/python/CuTeDSL/dsl_tutorials/smem_allocator.py diff --git a/examples/python/CuTeDSL/cute/torch_fake_tensor.py b/examples/python/CuTeDSL/dsl_tutorials/torch_fake_tensor.py similarity index 100% rename from examples/python/CuTeDSL/cute/torch_fake_tensor.py rename to examples/python/CuTeDSL/dsl_tutorials/torch_fake_tensor.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/ampere_gemm_with_fake_tensor.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/ampere_gemm_with_fake_tensor.py similarity index 98% rename from examples/python/CuTeDSL/cute/tvm_ffi/ampere_gemm_with_fake_tensor.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/ampere_gemm_with_fake_tensor.py index 2453e000b..5dd987e01 100644 --- a/examples/python/CuTeDSL/cute/tvm_ffi/ampere_gemm_with_fake_tensor.py +++ b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/ampere_gemm_with_fake_tensor.py @@ -43,7 +43,7 @@ if __name__ == "__main__": # Add the current directory to sys.path current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.join(current_dir, "..", "..")) -from ampere.tensorop_gemm import TensorOpGemm +from cute.ampere.kernel.dense_gemm.tensorop_gemm import TensorOpGemm @cute.jit diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/aot_export.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_export.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/aot_export.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_export.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_cpp_bundle.cpp b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_cpp_bundle.cpp rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.cpp diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_cpp_bundle.sh b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.sh similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_cpp_bundle.sh rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_cpp_bundle.sh diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_jax.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_jax.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_jax.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_jax.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_torch.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_torch.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/aot_use_in_torch.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/aot_use_in_torch.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/compile_with_fake_tensor.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/compile_with_fake_tensor.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/compile_with_fake_tensor.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/compile_with_fake_tensor.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/error_reporting.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/error_reporting.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/error_reporting.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/error_reporting.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/jit_and_use_in_jax.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/jit_and_use_in_jax.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/jit_and_use_in_jax.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/jit_and_use_in_jax.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/jit_and_use_in_torch.py b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/jit_and_use_in_torch.py similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/jit_and_use_in_torch.py rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/jit_and_use_in_torch.py diff --git a/examples/python/CuTeDSL/cute/tvm_ffi/requirements.txt b/examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/requirements.txt similarity index 100% rename from examples/python/CuTeDSL/cute/tvm_ffi/requirements.txt rename to examples/python/CuTeDSL/dsl_tutorials/tvm_ffi/requirements.txt diff --git a/examples/python/CuTeDSL/advanced_compiler_control/gemm0.bin b/examples/python/advanced_compiler_control/gemm0.bin similarity index 100% rename from examples/python/CuTeDSL/advanced_compiler_control/gemm0.bin rename to examples/python/advanced_compiler_control/gemm0.bin diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp index d9cecf9ff..0559917da 100644 --- a/include/cute/arch/config.hpp +++ b/include/cute/arch/config.hpp @@ -209,3 +209,8 @@ # define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED #endif +#if (defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED)) + #define CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED +#endif + + diff --git a/include/cute/arch/copy_sm100.hpp b/include/cute/arch/copy_sm100.hpp index c8109b372..fcbe88b48 100644 --- a/include/cute/arch/copy_sm100.hpp +++ b/include/cute/arch/copy_sm100.hpp @@ -7612,4 +7612,173 @@ struct SM100_TMEM_STORE_32dp32b128x_16b //////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::LOAD_STAT { + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_STAT_32dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + float& row_max) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED) + asm volatile ("tcgen05.ld.red.sync.aligned.32x32b.x32.max.f32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}, %32," + "[%33];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=f"(row_max) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD_STAT without CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_STAT_32dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127, float& row_max) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED) + asm volatile ("tcgen05.ld.red.sync.aligned.32x32b.x128.max.f32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}, %128," + "[%129];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127), "=f"(row_max) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD_STAT without CUTE_ARCH_TCGEN05_TMEM_STAT_ENABLED."); +#endif + } +//#endif +}; + +} // end namespace SM100::TMEM::LOAD_STAT + } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm100.hpp b/include/cute/atom/copy_traits_sm100.hpp index 7996fa3c0..be67a795c 100644 --- a/include/cute/atom/copy_traits_sm100.hpp +++ b/include/cute/atom/copy_traits_sm100.hpp @@ -3506,6 +3506,105 @@ tmem_load_to_store(CopyOp) { //////////////////////////////////////////////////////////////////////////////////////////////////// +namespace SM100::TMEM::LOAD_STAT { + +// +// Specialized copy_unpack implementation for SM100::TMEM::LOAD_STAT instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_tmem::value, "Expected TMEM src."); + static_assert(is_rmem::value, "Expected RMEM dst."); + + using SrcType = typename TS::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected src to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(src.data()); + const float& row_max = traits.get_max(); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); + + // thread idx <=> DP lane assert. + // ASSERT thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + float* row_max_ = const_cast(&row_max); + float tmp_row_max = row_max_[0]; + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}, + &tmp_row_max, seq<0>{}); + + row_max_[0] = fmax(row_max_[0], tmp_row_max); +} + +} // end namespace SM100::TMEM::LOAD_STAT + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD_STAT::SM100_TMEM_LOAD_STAT_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout>, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout>, + Stride, _1>>; + using RefLayout = SrcLayout; + + float row_max = -cutlass::platform::numeric_limits::infinity(); + + CUTE_HOST_DEVICE constexpr + float const& get_max() const { + return row_max; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD_STAT::SM100_TMEM_LOAD_STAT_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; + + float row_max = -cutlass::platform::numeric_limits::infinity(); + + CUTE_HOST_DEVICE constexpr + float const& get_max() const { + return row_max; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// // // UTCCP Copy Traits diff --git a/include/cute/atom/mma_traits_sm100.hpp b/include/cute/atom/mma_traits_sm100.hpp index c9b42a8f2..5b6af4218 100644 --- a/include/cute/atom/mma_traits_sm100.hpp +++ b/include/cute/atom/mma_traits_sm100.hpp @@ -333,11 +333,7 @@ struct DescriptorIterator CUTE_HOST_DEVICE constexpr DescriptorIterator operator+(Index const& offset) const { - // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc - SmemDescriptor ret; - ret.lo = desc_.lo + uint32_t(offset); - ret.hi = desc_.hi; - return { ret }; + return { desc_ + uint64_t(offset)}; } }; diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 5f26e1689..e288cb91a 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -49,8 +49,10 @@ #endif #if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +#if !(defined(__QNX__) && __QNX__ >= 800 && defined(NV_IS_SAFETY)) # define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED #endif +#endif #if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) # define CUDA_ENABLE_PREFERRED_CLUSTER diff --git a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp index aae691260..5a82b7554 100644 --- a/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm100_implicit_gemm_tma_warpspecialized.hpp @@ -269,7 +269,7 @@ public: bool implementable = true; implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp index 8cf41c1b8..ed0c0df6f 100644 --- a/include/cutlass/cuda_host_adapter.hpp +++ b/include/cutlass/cuda_host_adapter.hpp @@ -87,7 +87,9 @@ namespace cutlass { #if ((__CUDACC_VER_MAJOR__ >= 12) || \ ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +#if !(defined(__QNX__) && __QNX__ >= 800 && defined(NV_IS_SAFETY)) #include +#endif #endif // (__CUDACC_VERSION__ >= 11.8) #include diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl index 72ef10e34..3dd4241e3 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_mixed_tma_cpasync_umma_builder.inl @@ -121,7 +121,8 @@ struct CollectiveBuilder< StageCountType, BuilderScheduleTag, cute::enable_if_t< - cute::is_same_v && + (cute::is_same_v || + cute::is_same_v) && (cute::is_same_v ) > @@ -139,7 +140,7 @@ struct CollectiveBuilder< static_assert(cute::is_static_v, "TileShape has to be static"); static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); - static constexpr bool is_2sm = false; // detail::blockscaled::is_2sm(); + static constexpr bool is_2sm = detail::blockscaled::is_2sm(); static constexpr auto Instr = detail::blockscaled::select_instr(); using TiledMma = typename cutlass::gemm::collective::detail::TrivialBlockscaledMma(TileShape_MNK{})))); // Assigning 4 warps for mainloop load of B - static constexpr int NumLoadThreadsCpAsync = 128; + static constexpr int NumLoadThreadsCpAsync = 128 / size(AtomThrID{}); using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); @@ -196,7 +197,7 @@ struct CollectiveBuilder< using GmemCopyAtomB = cute::Copy_Atom, ElementB>; using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, - decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK{}) / size(AtomThrID{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); @@ -233,15 +234,24 @@ struct CollectiveBuilder< static constexpr uint32_t SchedulerPipelineStageCount = AccumulatorPipelineStageCount + 1; // AccumulatorPipeline = PipelineUmmaAsync - static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); + static constexpr auto AccumulatorPipelineStorage = sizeof(typename cutlass::PipelineUmmaAsync::SharedStorage); // CLCPipeline = PipelineCLCFetchAsync static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); // CLC (scheduler) response static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Tmem dealloc barrier + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // MMA trampoline barrier (for 2SM synchronization) + static constexpr auto MmaTrampolineBarrierStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem base pointer storage + static constexpr auto TmemBasePtrStorage = sizeof(uint32_t); // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + CLCPipelineStorage + - CLCResponseStorage); + CLCResponseStorage + + TmemDeallocStorage + + MmaTrampolineBarrierStorage + + TmemBasePtrStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); diff --git a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl index 3a45b1cd9..506fd092d 100644 --- a/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_blockscaled_umma_builder.inl @@ -124,6 +124,7 @@ struct CollectiveBuilder< ) && // Blockscaled Gemm (not cute::is_same_v) && + (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v) && diff --git a/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl index b2648e098..8eb347f71 100644 --- a/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_mixed_tma_cpasync_umma_builder.inl @@ -65,10 +65,11 @@ struct CollectiveBuilder< AlignmentB, ElementAccumulator, TileShape_MNK, // (MmaAtomShapeM, MmaAtomShapeN, TileK) - ClusterShape_MNK, // Static cluster shape (_1, _1, _1) + ClusterShape_MNK, // Static cluster shape, but can be non-trivial StageCountType, BuilderScheduleTag, - cute::enable_if_t && + cute::enable_if_t<(cute::is_same_v || + cute::is_same_v) && (cute::is_same_v )> > @@ -100,7 +101,7 @@ struct CollectiveBuilder< cute::size<2>(TileShape_MNK{})))); // Assigning 4 warps for mainloop load of B - static constexpr int NumLoadThreadsCpAsync = 128; + static constexpr int NumLoadThreadsCpAsync = 128 / size(AtomThrID{}); using SmemShapeA_M = decltype(shape_div(shape<0>(TileShape_MNK{}), shape_div(shape<0>(TileShape_MNK{}), size<0>(TileShape_MNK{}) / size(AtomThrID{})))); @@ -116,7 +117,7 @@ struct CollectiveBuilder< using GmemCopyAtomB = cute::Copy_Atom, ElementB>; using GmemTiledCopyB = decltype(detail::make_simt_gmem_tiled_copy< GmemCopyAtomB, NumLoadThreadsCpAsync, AlignmentB, TagToStrideB_t, - decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK{}) / size(AtomThrID{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using BlockTileB_N = decltype(cute::size<0,0>(MmaShapeB_NK{}) * cute::size<1>(MmaShapeB_NK{})); using BlockTileB_K = decltype(cute::size<0,1>(MmaShapeB_NK{}) * cute::size<2>(MmaShapeB_NK{})); @@ -133,10 +134,19 @@ struct CollectiveBuilder< static constexpr auto CLCPipelineStorage = sizeof(typename cutlass::PipelineCLCFetchAsync::SharedStorage); // CLC (scheduler) response static constexpr auto CLCResponseStorage = SchedulerPipelineStageCount * detail::CLCResponseSize; + // Tmem dealloc barrier + static constexpr auto TmemDeallocStorage = sizeof(cutlass::arch::ClusterBarrier); + // MMA trampoline barrier (for 2SM synchronization) + static constexpr auto MmaTrampolineBarrierStorage = sizeof(cutlass::arch::ClusterBarrier); + // Tmem base pointer storage + static constexpr auto TmemBasePtrStorage = sizeof(uint32_t); // Smem usage that's not part of CollectiveEpilogue::SharedStorage & CollectiveMainloop::SharedStorage static constexpr auto KernelSmemCarveout = static_cast( AccumulatorPipelineStorage + CLCPipelineStorage + - CLCResponseStorage); + CLCResponseStorage + + TmemDeallocStorage + + MmaTrampolineBarrierStorage + + TmemBasePtrStorage); // Reduce SMEM capacity available for buffers considering barrier allocations. static constexpr int ReducedSmemCapacityBytes = detail::sm100_reduced_smem_capacity_bytes(); using SmemTileShape = cute::Shape; diff --git a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl index a8d27d527..0aba04d99 100644 --- a/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm100_umma_builder.inl @@ -188,6 +188,7 @@ struct CollectiveBuilder< // Dense Gemm / PtrArrayDenseGemm ( (not cute::is_same_v) && + (not cute::is_same_v) && (not cute::is_same_v) && (cute::is_base_of_v || cute::is_same_v)) && diff --git a/include/cutlass/gemm/collective/builders/sm1xx_common.inl b/include/cutlass/gemm/collective/builders/sm1xx_common.inl index 2dd6fdafc..73bd6a34d 100644 --- a/include/cutlass/gemm/collective/builders/sm1xx_common.inl +++ b/include/cutlass/gemm/collective/builders/sm1xx_common.inl @@ -503,6 +503,7 @@ check_input_datatypes() { || (cute::is_same_v) || (cute::is_same_v) || (cute::is_same_v) + || (cute::is_same_v) // SM100 BS ptr_array || (cute::is_same_v) || (cute::is_same_v) diff --git a/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp index bbd4a4920..f67704925 100644 --- a/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_blockscaled_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -103,11 +103,6 @@ struct CollectiveMma< using TiledMma = TiledMma_; using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; - // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received - static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); - static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); - - static_assert(size(typename TiledMma::AtomThrID{}) == 1); using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled< Stages, @@ -132,10 +127,13 @@ struct CollectiveMma< using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); - using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + using LoadShapeB_NK = decltype(make_shape( + get<1>(TileShape{}) / size(AtomThrShapeMNK{}), get<2>(TileShape{}) + )); // CtaShape_MNK is queried from collective in all kernel layers - using CtaShape_MNK = TileShape; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); + static_assert(shape<1>(CtaShape_MNK{}) == 192 or shape<1>(CtaShape_MNK{}) == 64 or shape<1>(CtaShape_MNK{}) == 128 or shape<1>(CtaShape_MNK{}) == 256, "Cta N should be one of 64/128/192/256"); @@ -320,6 +318,7 @@ struct CollectiveMma< // Device side kernel params struct Params { + static_assert(cute::is_static_v, "`ClusterShape` must be static in mixed TMA cpasync kernel."); using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{}))); using ClusterLayoutSfb_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), @@ -451,6 +450,25 @@ struct CollectiveMma< bool implementable = true; + static constexpr bool IsDynamicCluster = !cute::is_static_v; + + constexpr bool IsBlockscaled = !cute::is_void_v; + if constexpr (IsBlockscaled) { + if constexpr (IsDynamicCluster) { + implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); + // Special cluster shape check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= (args.hw_info.cluster_shape.x <= 4 && args.hw_info.cluster_shape.y <= 4 && + args.hw_info.cluster_shape_fallback.x <= 4 && args.hw_info.cluster_shape_fallback.y <= 4); + } + else { + // Special cluster shape check for scale factor multicasts. Due to limited size of scale factors, we can't multicast among + // more than 4 CTAs + implementable &= ((size<0>(ClusterShape{}) <= 4) && (size<1>(ClusterShape{}) <= 4)); + } + } + + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); @@ -567,7 +585,7 @@ struct CollectiveMma< Tensor gSFB_nkl = local_tile(mSFB_nkl, TileShape_SF{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (TILE_N,TILE_K,n,k,l) - ThrMMA cta_mma = TiledMma{}.get_slice(0); + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) @@ -582,9 +600,10 @@ struct CollectiveMma< // Define the CTA-in-cluster Layout and Coord Layout cta_layout_mnk = make_layout(ClusterShape{}); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + uint32_t cta_rank_in_cluster = static_cast(cute::block_rank_in_cluster()); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster); Layout cta_layout_sfb_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma_SF::AtomThrID{})); - auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(0); + auto cta_coord_sfb_vmnk = cta_layout_sfb_vmnk.get_flat_coord(cta_rank_in_cluster); // Project the cta_layout for tma_a along the n-modes auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, @@ -599,10 +618,15 @@ struct CollectiveMma< get<1>(cta_coord_sfb_vmnk), make_layout(size<1>(cta_layout_sfb_vmnk)), group_modes<0,3>(sSFB), group_modes<0,3>(tCgSFB_nkl)); + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfa = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t mcast_mask_sfb = create_tma_multicast_mask<1>(cta_layout_sfb_vmnk, cta_coord_sfb_vmnk); + return cute::make_tuple( shape<3>(gA_mkl), // for scheduler tAgA_mkl, tAsA, // for input tensor values - tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB // for input scale factor tensor values + tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB, // for input scale factor tensor values + mcast_mask_a, mcast_mask_sfa, mcast_mask_sfb ); } @@ -628,11 +652,13 @@ struct CollectiveMma< Tensor mB_nkl = make_tensor(make_gmem_ptr(ptr_B), shape_b, stride_b); //(n,k,l) // Partition for cpasync Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor tBgB_nkl = flatten(flat_divide(gB_nkl, make_shape(safe_div(size(get<1>(TileShape{})), size(AtomThrShapeMNK{}))))); // Build the coordinate tensors with the same shape as input matrices Tensor cB_nk = make_identity_tensor(make_shape(N,K)); // Slice the coordinate tensors in the same way as A/B tensor partitioning Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + Tensor ctBgB_nk = flatten(flat_divide(cgB_nk, make_shape(safe_div(size(get<1>(TileShape{})), size(AtomThrShapeMNK{}))))); Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), LoadSmemLayoutB{}); @@ -642,7 +668,7 @@ struct CollectiveMma< auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); return cute::make_tuple( - gB_nkl, cgB_nk, sB, + tBgB_nkl, ctBgB_nk, sB, // problem_shape_MNKL, gmem_to_smem_b_tiled_copy, thr_copy_b); } @@ -687,13 +713,13 @@ struct CollectiveMma< auto tiled_copy_s2t_SFA = make_utccp_copy(UtccpOp{}, tCtSFA_compact); auto tiled_copy_s2t_SFB = make_utccp_copy(UtccpOp{}, tCtSFB_compact); - auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(0); + auto thr_copy_s2t_SFA = tiled_copy_s2t_SFA.get_slice(blockIdx.x % size(AtomThrID{})); auto thr_tCsSFA_compact_s2t_ = thr_copy_s2t_SFA.partition_S(tCsSFA_compact); // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor auto thr_tCsSFA_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFA_compact_s2t_); auto thr_tCtSFA_compact_s2t = thr_copy_s2t_SFA.partition_D(tCtSFA_compact); - auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(0); + auto thr_copy_s2t_SFB = tiled_copy_s2t_SFB.get_slice(blockIdx.x % size(AtomThrID{})); auto thr_tCsSFB_compact_s2t_ = thr_copy_s2t_SFB.partition_S(tCsSFB_compact); // SMEM to TMEM copy operation requires source SMEM operand to be an SMEM descriptor auto thr_tCsSFB_compact_s2t = get_utccp_smem_desc_tensor(thr_tCsSFB_compact_s2t_); @@ -745,7 +771,9 @@ struct CollectiveMma< auto [k_tiles, tAgA_mkl, tAsA, - tAgSFA_mkl, tBgSFB_nkl, tAsSFA, tBsSFB] = load_inputs; + tAgSFA_mkl, tBgSFB_nkl, + tAsSFA, tBsSFB, + mcast_mask_a, mcast_mask_sfa, mcast_mask_sfb] = load_inputs; // slice out the work coord from partitioned tensors Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); @@ -768,9 +796,9 @@ struct CollectiveMma< barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); if (cute::elect_one_sync()) { - copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); - copy(observed_tma_load_sfa_->with(*tma_barrier), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); - copy(observed_tma_load_sfb_->with(*tma_barrier), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_sfa_->with(*tma_barrier, mcast_mask_sfa), tAgSFA(_,*k_tile_iter), tAsSFA(_,write_stage)); + copy(observed_tma_load_sfb_->with(*tma_barrier, mcast_mask_sfb), tBgSFB(_,*k_tile_iter), tBsSFB(_,write_stage)); } --k_tile_count; @@ -821,10 +849,12 @@ struct CollectiveMma< auto [M,N,K,L] = effective_shape; + auto peer_cta_idx = get<0>(cta_coord_mnkl) % size(AtomThrShapeMNK{}); + // Slice out the work coord from partitioned tensors - Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor gB_in = tBgB_nkl(_, peer_cta_idx, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); // Repeat slicing out coordinate tensor exactly the same as input tensor does - Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + Tensor cgB_nk_in = cgB_nk(_, peer_cta_idx, _, get<1>(cta_coord_mnkl), _); auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative @@ -865,7 +895,7 @@ struct CollectiveMma< copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + mainloop_pipeline.producer_commit_local(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); --k_tile_count; ++k_tile_iter; ++mainloop_pipe_producer_state; @@ -892,7 +922,7 @@ struct CollectiveMma< --k_tile_count; // UNLOCK mainloop_pipe_producer_state - mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + mainloop_pipeline.producer_commit_local(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); // Advance mainloop_pipe_producer_state ++mainloop_pipe_producer_state; @@ -935,7 +965,11 @@ struct CollectiveMma< cute::tuple> const& accumulators_pair, cute::tuple const& mma_inputs, CtaTileCoord cta_tile_coord, - int k_tile_count + int k_tile_count, + bool is_mma_leader_cta, + uint32_t mma_peer_cta_rank, + arch::ClusterBarrier& mma_trampoline_barrier, + uint32_t mma_trampoline_barrier_phase ) { static_assert(is_tmem::value, "Accumulator must be tmem resident."); static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); @@ -952,6 +986,8 @@ struct CollectiveMma< auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + constexpr bool is_2sm = size(AtomThrShapeMNK{}) > 1; + auto tCtSFB_mma = [tCtSFB = tCtSFB, cta_tile_coord]() { if constexpr (IsCtaN192) { // If this is an ODD tile, shift the TMEM start address for N=192 case by two words (ignores first 64 columns of SFB) @@ -973,15 +1009,18 @@ struct CollectiveMma< }(); // Wait for tmem accumulator buffer to become empty with a flipped phase - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); - + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } // // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + if (is_mma_leader_cta) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + } mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); @@ -992,26 +1031,47 @@ struct CollectiveMma< copy(tiled_copy_s2t_SFB, thr_tCsSFB_s2t(_,_,_,_,read_stage_tma), thr_tCtSFB_s2t); } - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma.with(tiled_mma.accumulate_, - tCtSFA(_,_,k_block), - tCtSFB_mma(_,_,k_block)), - tCrA(_,_,k_block,read_stage_tma), - tCrB(_,_,k_block,read_stage_cpasync), - accumulators); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; + if (is_mma_leader_cta) { + if constexpr (is_2sm) { + mma_trampoline_barrier.wait(mma_trampoline_barrier_phase); + } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma.with(tiled_mma.accumulate_, + tCtSFA(_,_,k_block), + tCtSFB_mma(_,_,k_block)), + tCrA(_,_,k_block,read_stage_tma), + tCrB(_,_,k_block,read_stage_cpasync), + accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + } else { + if constexpr (is_2sm) { + mma_trampoline_barrier.arrive(mma_peer_cta_rank); + } } - mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); - mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + if constexpr (is_2sm) { + if (is_mma_leader_cta) { + mma_trampoline_barrier.arrive(mma_peer_cta_rank); + } else { + mma_trampoline_barrier.wait(mma_trampoline_barrier_phase); + } + } + + if (is_mma_leader_cta) { + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + } --k_tile_count; ++mainloop_pipe_tma_consumer_state; ++mainloop_pipe_cpasync_consumer_state; + + mma_trampoline_barrier_phase ^= 1; } - return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, mma_trampoline_barrier_phase); } protected: diff --git a/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp index 945e5feee..a8eac4aa8 100644 --- a/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm100_mma_mixed_tma_cpasync_warpspecialized.hpp @@ -56,6 +56,8 @@ using namespace cute; ///////////////////////////////////////////////////////////////////////////////////////////////// + + // WarpSpecialized Mainloop // Both DMA Load and MMA methods of this class must be run by a single thread that's picked by elect_one template < @@ -63,7 +65,7 @@ template < int SchedulerPipelineStageCount, int AccumulatorPipelineStageCount, class ArchTag_, - class ClusterShape, // Static cluster shape or dynamic (int, int, _1) + class ClusterShape, // Static cluster shape class TileShape_, // (MmaAtomShapeM, MmaAtomShapeN, TileK) class ElementA_, class StrideA_, @@ -103,12 +105,6 @@ struct CollectiveMma< using TiledMma = TiledMma_; using AtomThrShapeMNK = Shape(typename TiledMma::ThrLayoutVMNK{})), _1, _1>; - // Statically asserting to ensure only 1x1x1 cluster shape & 1sm setup is received - static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment SM100 GEMM only supports 1SM MMA"); - static_assert(size(ClusterShape{}) == 1, "CPASYNC does not support multicast so the cluster shape is restricted to 1, 1, 1"); - - static_assert(size(typename TiledMma::AtomThrID{}) == 1); - using DispatchPolicy = MainloopSm100UmmaMixedTmaCpAsyncWarpSpecialized< Stages, SchedulerPipelineStageCount, @@ -124,11 +120,12 @@ struct CollectiveMma< // Define A and B block shapes using MmaShapeA_MK = decltype(partition_shape_A(TiledMma{}, make_shape(size<0>(TileShape{}), size<2>(TileShape{})))); using MmaShapeB_NK = decltype(partition_shape_B(TiledMma{}, make_shape(size<1>(TileShape{}), size<2>(TileShape{})))); - // using LoadShapeA_MK = decltype(select<0,2>(TileShape{})); - using LoadShapeB_NK = decltype(select<1,2>(TileShape{})); + using LoadShapeB_NK = decltype(make_shape( + get<1>(TileShape{}) / size(AtomThrShapeMNK{}), get<2>(TileShape{}) + )); // CtaShape_MNK is queried from collective in all kernel layers - using CtaShape_MNK = TileShape; + using CtaShape_MNK = decltype(shape_div(TileShape{}, AtomThrShapeMNK{})); using ElementA = ElementA_; using ElementAMma = typename TiledMma::ValTypeA; @@ -257,6 +254,7 @@ struct CollectiveMma< // Device side kernel params struct Params { + static_assert(cute::is_static_v, "`ClusterShape` must be static in mixed TMA cpasync kernel."); using ClusterLayout_VMNK = decltype(tiled_divide(make_layout(ClusterShape{}), make_tile(typename TiledMma::AtomThrID{}))); @@ -281,8 +279,10 @@ struct CollectiveMma< CollectiveMma(Params const& params) : runtime_data_type_a_(params.runtime_data_type_a) , runtime_data_type_b_(params.runtime_data_type_b) { - + + observed_tma_load_a_ = ¶ms.tma_load_a; + } template @@ -394,9 +394,8 @@ struct CollectiveMma< /// Set up the data needed by this collective for load. /// Return tuple element contain /// gA_mkl - The tiled tensor for input A - /// gB_nkl - The tiled tensor for input B /// tAsA - partitioned smem tensor for A - /// tBsB - partitioned smem tensor for B + /// mcast_mask_a - tma multicast mask for A template CUTLASS_DEVICE auto load_init_tma( @@ -410,7 +409,7 @@ struct CollectiveMma< Tensor mA_mkl = observed_tma_load_a_->get_tma_tensor(make_shape(M,K,L)); Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M, BLK_K, m, k, l) - ThrMMA cta_mma = TiledMma{}.get_slice(0); + ThrMMA cta_mma = TiledMma{}.get_slice(blockIdx.x % size(typename TiledMma::AtomThrID{})); Tensor tCgA_mkl = cta_mma.partition_A(gA_mkl); // (MMA, MMA_M, MMA_K, m, k, l) Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (MMA,MMA_M,MMA_K,PIPE) @@ -418,16 +417,20 @@ struct CollectiveMma< // Define the CTA-in-cluster Layout and Coord Layout cta_layout_mnk = make_layout(ClusterShape{}); Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMma::AtomThrID{})); - auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(0); + uint32_t cta_rank_in_cluster = static_cast(cute::block_rank_in_cluster()); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster); // Project the cta_layout for tma_a along the n-modes auto [tAgA_mkl, tAsA] = tma_partition(*observed_tma_load_a_, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), group_modes<0,3>(sA), group_modes<0,3>(tCgA_mkl)); - + + uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + return cute::make_tuple( shape<3>(gA_mkl), // for scheduler - tAgA_mkl, tAsA // for input tensor values + tAgA_mkl, tAsA, // for input tensor values + mcast_mask_a // for TMA multicast ); } @@ -451,11 +454,13 @@ struct CollectiveMma< Tensor mB_nkl = make_tensor(make_gmem_ptr(params.ptr_B), shape_b, stride_b); //(n,k,l) // Partition for cpasync Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + Tensor tBgB_nkl = flatten(flat_divide(gB_nkl, make_shape(safe_div(size(get<1>(TileShape{})), size(AtomThrShapeMNK{}))))); // Build the coordinate tensors with the same shape as input matrices Tensor cB_nk = make_identity_tensor(make_shape(N,K)); // Slice the coordinate tensors in the same way as A/B tensor partitioning Tensor cgB_nk = local_tile(cB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + Tensor ctBgB_nk = flatten(flat_divide(cgB_nk, make_shape(safe_div(size(get<1>(TileShape{})), size(AtomThrShapeMNK{}))))); Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), LoadSmemLayoutB{}); @@ -465,7 +470,7 @@ struct CollectiveMma< auto thr_copy_b = gmem_to_smem_b_tiled_copy.get_slice(thread_idx); return cute::make_tuple( - gB_nkl, cgB_nk, sB, + tBgB_nkl, ctBgB_nk, sB, gmem_to_smem_b_tiled_copy, thr_copy_b); } @@ -514,7 +519,8 @@ struct CollectiveMma< MainloopPipelineTMAState mainloop_pipe_producer_state, cute::tuple const& load_inputs, + STensorA, + uint16_t> const& load_inputs, TileCoordMNKL const& cta_coord_mnkl, KTileIterator k_tile_iter, int k_tile_count) { @@ -522,10 +528,11 @@ struct CollectiveMma< KTileCount k_tiles = get<0>(load_inputs); GTensorPartitionedA tAgA_mkl = get<1>(load_inputs); STensorA tAsA = get<2>(load_inputs); + uint16_t mcast_mask_a = get<3>(load_inputs); // slice out the work coord from partitioned tensors Tensor tAgA = tAgA_mkl(_, get<0>(cta_coord_mnkl) / size(typename TiledMma::AtomThrID{}), _, get<3>(cta_coord_mnkl)); - + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); // Issue the Mainloop loads @@ -542,7 +549,7 @@ struct CollectiveMma< barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state); if (cute::elect_one_sync()) { - copy(observed_tma_load_a_->with(*tma_barrier), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); + copy(observed_tma_load_a_->with(*tma_barrier, mcast_mask_a), tAgA(_,*k_tile_iter), tAsA(_,write_stage)); } --k_tile_count; @@ -583,13 +590,15 @@ struct CollectiveMma< auto [M,N,K,L] = effective_shape; + auto peer_cta_idx = get<0>(cta_coord_mnkl) % size(AtomThrShapeMNK{}); + // Slice out the work coord from partitioned tensors - Tensor gB_in = tBgB_nkl(_, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); - // Repeat slicing out coordinate tensor exactly the same as input tensor does - Tensor cgB_nk_in = cgB_nk(_, _, get<1>(cta_coord_mnkl), _); + Tensor gB_in = tBgB_nkl(_, peer_cta_idx, _, get<1>(cta_coord_mnkl), _, get<3>(cta_coord_mnkl)); + Tensor cgB_nk_in = cgB_nk(_, peer_cta_idx, _, get<1>(cta_coord_mnkl), _); auto k_residue = K - size<1>(gB_in) * size<2>(gB_in); // K - BLK_K * k is negative + // Repeat slicing out coordinate tensor exactly the same as input tensor does Tensor gB = gB_in; Tensor cB = cgB_nk_in; @@ -627,12 +636,11 @@ struct CollectiveMma< copy_if(gmem_to_smem_b_tiled_copy, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); - mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + mainloop_pipeline.producer_commit_local(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); --k_tile_count; ++k_tile_iter; ++mainloop_pipe_producer_state; } - // last tile with predication on k to account for residue // For performance consideration, // this predicated block for K-tail is only activated when there is k-residue @@ -654,7 +662,7 @@ struct CollectiveMma< --k_tile_count; // UNLOCK mainloop_pipe_producer_state - mainloop_pipeline.producer_commit(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); + mainloop_pipeline.producer_commit_local(mainloop_pipe_producer_state, cutlass::arch::cpasync_barrier_arrive); // Advance mainloop_pipe_producer_state ++mainloop_pipe_producer_state; @@ -666,12 +674,6 @@ struct CollectiveMma< /// Perform a Producer Epilogue to prevent early exit of ctas in a Cluster CUTLASS_DEVICE void load_tail_tma(MainloopPipelineTMA mainloop_pipeline, MainloopPipelineTMAState mainloop_pipe_producer_state) { - // Issue the epilogue waits - // This helps avoid early exit of ctas in Cluster - // Waits for all stages to either be released (all - // Consumer UNLOCKs), or if the stage was never used - // then would just be acquired since the phase was - // still inverted from make_producer_start_state mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); } CUTLASS_DEVICE void @@ -697,7 +699,11 @@ struct CollectiveMma< cute::tuple> const& accumulators_pair, cute::tuple const& mma_inputs, CtaTileCoord cta_tile_coord, - int k_tile_count + int k_tile_count, + bool is_mma_leader_cta, + uint32_t mma_peer_cta_rank, + arch::ClusterBarrier& mma_trampoline_barrier, + uint32_t mma_trampoline_barrier_phase ) { static_assert(is_tmem::value, "Accumulator must be tmem resident."); static_assert(rank(FrgLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA, MMA_M, MMA_N)"); @@ -707,37 +713,63 @@ struct CollectiveMma< auto [mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline] = pipelines; auto [mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state] = pipeline_states; + constexpr bool is_2sm = size(AtomThrShapeMNK{}) > 1; + // // PIPELINED MAIN LOOP // tiled_mma.accumulate_ = UMMA::ScaleOut::Zero; // Wait for tmem accumulator buffer to become empty with a flipped phase - accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + if (is_mma_leader_cta) { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + } CUTLASS_PRAGMA_NO_UNROLL while (k_tile_count > 0) { - mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + if (is_mma_leader_cta) { + mainloop_pipeline_tma.consumer_wait(mainloop_pipe_tma_consumer_state); + } mainloop_pipeline_cpasync.consumer_wait(mainloop_pipe_cpasync_consumer_state); int read_stage_tma = mainloop_pipe_tma_consumer_state.index(); int read_stage_cpasync = mainloop_pipe_cpasync_consumer_state.index(); - - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M) x (V,N) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators); - tiled_mma.accumulate_ = UMMA::ScaleOut::One; + if (is_mma_leader_cta) { + if constexpr (is_2sm) { + mma_trampoline_barrier.wait(mma_trampoline_barrier_phase); + } + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage_tma), tCrB(_,_,k_block,read_stage_cpasync), accumulators); + tiled_mma.accumulate_ = UMMA::ScaleOut::One; + } + } else { + if constexpr (is_2sm) { + mma_trampoline_barrier.arrive(mma_peer_cta_rank); + } } - mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); - mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + if constexpr (is_2sm) { + if (is_mma_leader_cta) { + mma_trampoline_barrier.arrive(mma_peer_cta_rank); + } else { + mma_trampoline_barrier.wait(mma_trampoline_barrier_phase); + } + } + + if (is_mma_leader_cta) { + mainloop_pipeline_tma.consumer_release(mainloop_pipe_tma_consumer_state); + mainloop_pipeline_cpasync.consumer_release(mainloop_pipe_cpasync_consumer_state); + } --k_tile_count; ++mainloop_pipe_tma_consumer_state; ++mainloop_pipe_cpasync_consumer_state; + + mma_trampoline_barrier_phase ^= 1; } - return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state); + return cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, mma_trampoline_barrier_phase); } protected: @@ -745,7 +777,6 @@ protected: typename Params::TMA_A const* observed_tma_load_a_{nullptr}; RuntimeDataTypeA runtime_data_type_a_{}; RuntimeDataTypeB runtime_data_type_b_{}; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 7d3043e09..9cd9d2569 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -715,6 +715,7 @@ struct KernelTmaWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelSchedul struct KernelTmaWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; // Use for 2SM Dense GEMM Kernels for Collective Mainloop Builder struct KernelWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; // Use for 1SM Dense GEMM Kernels for Collective Mainloop Builder Without TMA struct KernelMixedTmaCpAsyncWarpSpecialized1SmSm100 final : KernelSchedule1Sm, KernelScheduleSm100DenseGemm {}; +struct KernelMixedTmaCpAsyncWarpSpecialized2SmSm100 final : KernelSchedule2Sm, KernelScheduleSm100DenseGemm {}; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 Ptr-Array Dense GEMM Dispatch Policies @@ -829,6 +830,7 @@ struct KernelTmaWarpSpecialized2SmMxf4Sm100 final : KernelSchedule2 struct KernelTmaWarpSpecialized1SmMxf8f6f4Sm100 final : KernelSchedule1Sm, KernelScheduleMxf8f6f4Sm100 { }; struct KernelTmaWarpSpecialized2SmMxf8f6f4Sm100 final : KernelSchedule2Sm, KernelScheduleMxf8f6f4Sm100 { }; struct KernelMixedTmaCpAsyncWarpSpecialized1SmBlockScaledSm100 final : KernelSchedule1Sm, KernelScheduleBlockScaledGemmSm100 {}; +struct KernelMixedTmaCpAsyncWarpSpecialized2SmBlockScaledSm100 final : KernelSchedule2Sm, KernelScheduleBlockScaledGemmSm100 {}; /////////////////////////////////////////////////////////////////////////////////////////////////////// // SM100 BlockScaled Ptr Array Dense GEMM Dispatch Policies diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp index ca5597946..4b0a07381 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized.hpp @@ -416,7 +416,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); return implementable; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp index 25d5c6e32..0d2bd831e 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_input_transform.hpp @@ -324,7 +324,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp index 83ae76ac5..acbdc8cf1 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_array_tma_warpspecialized_mma_transform.hpp @@ -357,7 +357,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); return implementable; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp index d0fd31a5b..ea9369385 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_cpasync_warpspecialized.hpp @@ -244,7 +244,8 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; // Calculate workspace pointers uint8_t* workspace_ptr = reinterpret_cast(workspace); @@ -286,7 +287,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); static constexpr int MaxClusterSize = 16; implementable &= size(ClusterShape{}) <= MaxClusterSize; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp index 8b0fc43ad..a0ea7c870 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_mixed_tma_cpasync_warpspecialized.hpp @@ -59,6 +59,16 @@ namespace cutlass::gemm::kernel { /////////////////////////////////////////////////////////////////////////////// +namespace detail { +template +struct is_blockscaled_mixed_tma_cpasync : cute::false_type {}; + +template +struct is_blockscaled_mixed_tma_cpasync< + MainloopSm100UmmaMixedTmaCpAsyncWarpSpecializedBlockScaled +> : cute::true_type {}; +} // namespace detail + template < class ProblemShape_, class CollectiveMainloop_, @@ -72,14 +82,18 @@ class GemmUniversal< TileSchedulerTag_, cute::enable_if_t< cutlass::detail::is_kernel_tag_of_v>> + KernelMixedTmaCpAsyncWarpSpecializedSm100> + >> { public: using ProblemShape = ProblemShape_; static constexpr bool IsGroupedGemmKernel = cutlass::gemm::detail::is_moe_problem_shape::value; static constexpr bool IsMoEScheduler = false; // stub for MoE scheduler, which accepts a MoEProblemShape instead of GroupProblemShape - + static constexpr bool IsBlockscaled = detail::is_blockscaled_mixed_tma_cpasync< + typename CollectiveMainloop_::DispatchPolicy + >::value; + CUTLASS_HOST_DEVICE static auto get_problem_shape_gemm(ProblemShape const& shape) { if constexpr (IsGroupedGemmKernel) { @@ -156,7 +170,6 @@ public: using CtaShape_MNK = typename CollectiveMainloop::CtaShape_MNK; using AtomThrShapeMNK = typename CollectiveMainloop::AtomThrShapeMNK; - static_assert(size(AtomThrShapeMNK{}) == 1, "Lower alignment kernel only supports 1x1x1 cluster shape."); using TileSchedulerTag = cute::conditional_t; using TileScheduler = typename detail::TileSchedulerSelector< TileSchedulerTag, ArchTag, CtaShape_MNK, ClusterShape, SchedulerPipelineStageCount, ProblemShape>::Scheduler; @@ -210,7 +223,8 @@ public: cutlass::PipelineAsync>; using CLCPipelineState = typename CLCPipeline::PipelineState; - using TmemAllocator = cute::TMEM::Allocator1Sm; + using TmemAllocator = cute::conditional_t(typename TiledMma::ThrLayoutVMNK{})) == 1, + cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>; // Kernel level shared memory storage struct SharedStorage { @@ -225,6 +239,7 @@ public: alignas(16) CLCPipelineStorage clc; alignas(16) AccumulatorPipelineStorage accumulator; alignas(16) arch::ClusterBarrier tmem_dealloc; + alignas(16) arch::ClusterBarrier mma_trampoline_barrier; } pipelines; alignas(16) typename TileScheduler::CLCResponse clc_response[SchedulerPipelineStageCount]; @@ -305,7 +320,8 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; // Calculate workspace pointers uint8_t* workspace_ptr = reinterpret_cast(workspace); @@ -371,7 +387,7 @@ public: auto problem_shape_gemm = get_problem_shape_gemm(args.problem_shape); implementable &= CollectiveMainloop::can_implement(problem_shape_gemm, args.mainloop); implementable &= CollectiveEpilogue::can_implement(problem_shape_gemm, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); static constexpr int MaxClusterSize = 16; implementable &= size(ClusterShape{}) <= MaxClusterSize; @@ -484,11 +500,13 @@ public: auto cluster_shape = ClusterShape{}; constexpr int cluster_size = size(ClusterShape{}); int cta_rank_in_cluster = cute::block_rank_in_cluster(); - bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; int cta_coord_v = cta_rank_in_cluster % size<0>(typename TiledMma::AtomThrID{}); - bool is_mma_leader_cta = cta_coord_v == 0; int mma_leader_ctas = size(shape_div(cluster_shape, AtomThrShapeMNK{})); - [[maybe_unused]] uint32_t mma_peer_cta_rank = cta_rank_in_cluster; + constexpr bool has_mma_peer_cta = size(AtomThrShapeMNK{}) == 2; + uint32_t mma_peer_cta_rank = has_mma_peer_cta ? cta_rank_in_cluster ^ 1 : cta_rank_in_cluster; + bool is_mma_leader_cta = cta_coord_v == 0; + [[maybe_unused]] bool is_first_cta_in_cluster = cta_rank_in_cluster == 0; + [[maybe_unused]] uint32_t mma_leader_cta_rank = is_mma_leader_cta? cta_rank_in_cluster : mma_peer_cta_rank; // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); @@ -497,12 +515,19 @@ public: CollectiveMainloop collective_mainloop(params.mainloop); CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + arch::ClusterBarrier& mma_trampoline_barrier = shared_storage.pipelines.mma_trampoline_barrier; + if (WarpCategory::MMA == warp_category && lane_predicate) { + mma_trampoline_barrier.init(NumMMAThreads); + } + + // Do we load source tensor C or other aux inputs bool is_epi_load_needed = collective_epilogue.is_producer_load_needed(); IsParticipant is_participant = { - (warp_category == WarpCategory::MMA) && is_mma_leader_cta, // mma - (warp_category == WarpCategory::Sched) && is_first_cta_in_cluster, // sched + (warp_category == WarpCategory::MMA), // mma + (warp_category == WarpCategory::Sched) + && (!IsSchedDynamicPersistent || is_first_cta_in_cluster), // sched (warp_category == WarpCategory::MainloopLoadTMA), // main_load_tma (warp_category == WarpCategory::EpilogueLoad) && is_epi_load_needed, // epi_load (warp_category == WarpCategory::Epilogue), // epilogue @@ -521,11 +546,27 @@ public: mainloop_pipeline_tma_params.is_leader = lane_predicate && is_mma_leader_cta && is_participant.main_load_tma; mainloop_pipeline_tma_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; mainloop_pipeline_tma_params.initializing_warp = 0; - MainloopPipelineTMA mainloop_pipeline_tma(shared_storage.pipelines.mainloop.tma, - mainloop_pipeline_tma_params, - cluster_shape, - cute::true_type{}, // Perform barrier init - cute::false_type{}); // Delay mask calculation + MainloopPipelineTMA mainloop_pipeline_tma = [&] () { + if constexpr (IsBlockscaled) { + // If blockscaled, SFB is also multicasted, so we need to wait on the row and column CTAs. + return MainloopPipelineTMA(shared_storage.pipelines.mainloop.tma, + mainloop_pipeline_tma_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + } + else { + // If not blockscaled, there is no multicast across M mode (i.e. across columsn), so we + // don't need to wait on anything except the row CTAs. + return MainloopPipelineTMA(shared_storage.pipelines.mainloop.tma, + mainloop_pipeline_tma_params, + cluster_shape, + McastDirection::kRow, + cute::true_type{}, // Perform barrier init + cute::false_type{}); // Delay mask calculation + } + }(); + // Mainloop Load pipeline (CpAsync) typename MainloopPipelineCpAsync::Params mainloop_pipeline_cpasync_params; @@ -606,7 +647,13 @@ public: accumulator_pipeline_params.producer_arv_count = 1; accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * NumEpilogueThreads; accumulator_pipeline_params.initializing_warp = 2; - AccumulatorPipeline accumulator_pipeline(shared_storage.pipelines.accumulator, accumulator_pipeline_params, cluster_shape); + AccumulatorPipeline accumulator_pipeline( + shared_storage.pipelines.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::false_type{} // Delay mask init + ); // Tmem allocator TmemAllocator tmem_allocator{}; @@ -617,6 +664,11 @@ public: arch::ClusterBarrier& tmem_deallocation_result_barrier = shared_storage.pipelines.tmem_dealloc; [[maybe_unused]] uint32_t dealloc_barrier_phase = 0; + if (WarpCategory::MMA == warp_category) { + if (has_mma_peer_cta && lane_predicate) { + tmem_deallocation_result_barrier.init(NumMMAThreads); + } + } MainloopPipelineTMAState mainloop_pipe_tma_consumer_state; MainloopPipelineTMAState mainloop_pipe_tma_producer_state = cutlass::make_producer_start_state(); MainloopPipelineCpAsyncState mainloop_pipe_cpasync_consumer_state; @@ -639,6 +691,13 @@ public: pipeline_init_arrive_relaxed(cluster_size); dim3 block_id_in_cluster = cute::block_id_in_cluster(); + if constexpr (IsBlockscaled) { + mainloop_pipeline_tma.init_masks(cluster_shape); + } else { + mainloop_pipeline_tma.init_masks(cluster_shape, McastDirection::kRow); + } + accumulator_pipeline.init_masks(cluster_shape, block_id_in_cluster); + // TileID scheduler TileScheduler scheduler(&shared_storage.clc_response[0], params.scheduler, block_id_in_cluster); typename TileScheduler::WorkTileInfo work_tile_info = scheduler.initial_work_tile_info(cluster_shape); @@ -713,7 +772,7 @@ public: auto load_inputs = collective_mainloop.load_init_cpasync( problem_shape_MNKL, params.mainloop, shared_storage.tensors.mainloop, scheduler, work_tile_info); - Tensor gA_mkl = get<0>(load_inputs); + Tensor tBgB_nkl = get<0>(load_inputs); do { // Get current work tile and fetch next work tile @@ -722,7 +781,7 @@ public: auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. - auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, shape<3>(gA_mkl)); + auto k_tile_iter = scheduler.get_k_tile_iterator(work_tile_info, effective_shape, CtaShape_MNK{}, shape<4>(tBgB_nkl)); auto k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, effective_shape, CtaShape_MNK{}); auto [mainloop_producer_state_next, unused_] = collective_mainloop.load_cpasync( @@ -756,7 +815,7 @@ public: } else if (is_participant.sched) { - + if constexpr (IsSchedDynamicPersistent) { // Whether a new CLC query must be performed. // See comment below where this variable is updated for a description of @@ -815,13 +874,13 @@ public: __syncwarp(); tmem_allocation_result_barrier.arrive(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; - // bulk_tmem.data() = tmem_base_ptr; collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); + uint32_t mma_trampoline_barrier_phase = 0; // Pass the acc with tuple type since the bgrad kernel change the mma_init API - auto mma_inputs = collective_mainloop.mma_init(params.mainloop, - tmem_storage, + auto mma_inputs = collective_mainloop.mma_init(params.mainloop, + tmem_storage, shared_storage.tensors.mainloop); do { auto effective_shape = get_effective_shape(params.problem_shape, work_tile_info); @@ -842,8 +901,8 @@ public: // accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); int acc_stage = accumulator_pipe_producer_state.index(); - // Tensor accumulators = bulk_tmem(_,_,_,acc_stage); - auto [mainloop_pipe_tma_consumer_state_next_, mainloop_pipe_cpasync_consumer_state_next_] = collective_mainloop.mma( + + auto [mainloop_pipe_tma_consumer_state_next_, mainloop_pipe_cpasync_consumer_state_next_, mma_trampoline_barrier_phase_next_] = collective_mainloop.mma( cute::make_tuple(mainloop_pipeline_tma, mainloop_pipeline_cpasync, accumulator_pipeline), cute::make_tuple(mainloop_pipe_tma_consumer_state, mainloop_pipe_cpasync_consumer_state, accumulator_pipe_producer_state), // Pass the acc with tuple type since the bgrad kernel change the mma API @@ -851,12 +910,20 @@ public: collective_mainloop.slice_accumulator(tmem_storage, acc_stage), mma_inputs, cta_coord_mnkl, - k_tile_count + k_tile_count, + is_mma_leader_cta, + mma_peer_cta_rank, + mma_trampoline_barrier, + mma_trampoline_barrier_phase ); + mainloop_pipe_tma_consumer_state = mainloop_pipe_tma_consumer_state_next_; mainloop_pipe_cpasync_consumer_state = mainloop_pipe_cpasync_consumer_state_next_; + mma_trampoline_barrier_phase = mma_trampoline_barrier_phase_next_; - accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + if (is_mma_leader_cta) { + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + } ++accumulator_pipe_producer_state; work_tile_info = next_work_tile_info; @@ -865,7 +932,15 @@ public: // Release the right to allocate before deallocations so that the next CTA can rasterize tmem_allocator.release_allocation_lock(); - accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + if (is_mma_leader_cta) { + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + } + if constexpr (has_mma_peer_cta) { + // Leader does wait + arrive, follower does arrive + wait + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, not is_mma_leader_cta); + tmem_deallocation_result_barrier.wait(dealloc_barrier_phase); + tmem_deallocation_result_barrier.arrive(mma_peer_cta_rank, is_mma_leader_cta); + } // Free entire tmem allocation tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); @@ -924,7 +999,6 @@ public: tmem_allocation_result_barrier.arrive_and_wait(); uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; collective_mainloop.set_tmem_offsets(tmem_storage, tmem_base_ptr); - // bulk_tmem.data() = tmem_base_ptr; bool do_tail_store = false; do { diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp index 4ea3f0232..b8695e73c 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized.hpp @@ -307,7 +307,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp index c82e084f0..806cb3aac 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_input_transform.hpp @@ -311,7 +311,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp index 5533f109f..18846eb24 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mixed_input_transform.hpp @@ -310,7 +310,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp index 9a6b10ac8..69bb68ecc 100644 --- a/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp +++ b/include/cutlass/gemm/kernel/sm100_gemm_tma_warpspecialized_mma_transform.hpp @@ -317,7 +317,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp index d87ac8f7e..5d491a13d 100644 --- a/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm100_sparse_gemm_tma_warpspecialized.hpp @@ -335,7 +335,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { static constexpr int MaxClusterSize = 16; diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp index 739010c35..637fd4b9b 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler.hpp @@ -35,6 +35,7 @@ #include "cute/int_tuple.hpp" +#include "cutlass/kernel_hardware_info.hpp" #include "cutlass/arch/config.h" #include "cutlass/arch/barrier.h" #include "cutlass/detail/cluster.hpp" @@ -337,7 +338,7 @@ public: } static bool - can_implement(Arguments const& args) { + can_implement(Arguments const& args, KernelHardwareInfo const&) { return true; } diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp index 2d8728a99..d4e049269 100755 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_group.hpp @@ -32,6 +32,7 @@ #pragma once +#include "cutlass/kernel_hardware_info.hpp" #include "cutlass/arch/barrier.h" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp" @@ -110,7 +111,7 @@ public: } static bool - can_implement(Arguments const& args) { + can_implement(Arguments const& args, KernelHardwareInfo const&) { return true; } diff --git a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp index 54d81cabb..e1ba75810 100644 --- a/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm100_tile_scheduler_stream_k.hpp @@ -218,8 +218,14 @@ public: } static bool - can_implement(Arguments const& args) { - return UnderlyingStreamKScheduler::can_implement(args); + can_implement(Arguments const& args, KernelHardwareInfo const& hw_info) { + if (hw_info.cluster_shape.x != hw_info.cluster_shape_fallback.x || + hw_info.cluster_shape.y != hw_info.cluster_shape_fallback.y || + hw_info.cluster_shape.z != hw_info.cluster_shape_fallback.z) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Stream-K scheduler requires cluster shape and fallback cluster shape to be the same.\n"); + return false; + } + return UnderlyingStreamKScheduler::can_implement(args, hw_info); } CUTLASS_DEVICE diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp index 7416f417a..cec1bd101 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_array_tma_warpspecialized.hpp @@ -355,7 +355,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if (!implementable) { CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Mainloop, Epilogue or Scheduler don't meet the requirements for Ptr Array Gemm or Grouped Gemm.\n"); return implementable; diff --git a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp index 455464e66..14cb858b9 100644 --- a/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm103_blockscaled_gemm_tma_warpspecialized.hpp @@ -314,7 +314,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); if constexpr (IsDynamicCluster) { implementable &= cutlass::detail::preferred_cluster_can_implement(args.hw_info.cluster_shape, args.hw_info.cluster_shape_fallback); diff --git a/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp b/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp index 5f4e19364..761bbc4f7 100644 --- a/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp +++ b/include/cutlass/gemm/kernel/sm120_gemm_tma_warpspecialized_cooperative_asymmetric_dma.hpp @@ -229,7 +229,8 @@ public: CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + KernelHardwareInfo hw_info = args.hw_info; + hw_info.sm_count = sm_count; // Calculate workspace pointers uint8_t* workspace_ptr = reinterpret_cast(workspace); @@ -275,7 +276,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index faa9b1cd7..84d21f1e0 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -149,7 +149,7 @@ static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializin can_implement(Arguments const& args) { bool mode_implementable = args.mode == GemmUniversalMode::kGemm or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); - return mode_implementable && TileScheduler::can_implement(args.scheduler); + return mode_implementable && TileScheduler::can_implement(args.scheduler, args.hw_info); } static size_t diff --git a/include/cutlass/gemm/kernel/sm70_gemm_array.hpp b/include/cutlass/gemm/kernel/sm70_gemm_array.hpp index 409ecda15..3b385fe2d 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm_array.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm_array.hpp @@ -159,7 +159,7 @@ static_assert(is_valid_tile_scheduler, "SM70 kernel does not support specializin return implementable; } typename ProblemShape::UnderlyingProblemShape problem_shape = args.problem_shape.get_host_problem_shape(); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp index 3fec578ba..1b9fcbf72 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp @@ -336,7 +336,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp index 451e92511..c828e8295 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp @@ -348,7 +348,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 899ad0198..d39d430e1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -167,7 +167,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index 52904eddf..bb2d88f19 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -233,7 +233,7 @@ public: implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(transformed_problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 3f669e1b2..b38e3750a 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -289,7 +289,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index d86235aad..818b8b1a1 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -287,7 +287,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp index 584178d54..ec4faba85 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -188,7 +188,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 2eaab2458..61490a734 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -220,7 +220,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index d44839884..a2cbc5ec8 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -233,7 +233,7 @@ public: } implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); + implementable &= TileScheduler::can_implement(args.scheduler, args.hw_info); return implementable; } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp index d746ca735..0e879f1cf 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_group.hpp @@ -241,7 +241,7 @@ public: } static bool - can_implement(Arguments const& args) { + can_implement(Arguments const& args, KernelHardwareInfo const&) { return true; } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index c874d638a..f2a640327 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -245,7 +245,7 @@ public: } static bool - can_implement(Arguments const& args) { + can_implement(Arguments const& args, KernelHardwareInfo const&) { // Split count must be positive, and > 1 is only valid for heuristic and split-K decomposition modes return args.splits >= 1 && (args.splits == 1 || diff --git a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp index 0fb6b5298..f384dfa90 100644 --- a/include/cutlass/gemm/kernel/static_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/static_tile_scheduler.hpp @@ -126,7 +126,7 @@ public: CUTLASS_HOST_DEVICE static bool - can_implement(Arguments const& args) { + can_implement(Arguments const& args, KernelHardwareInfo const&) { return args.max_swizzle_size >= 0; } diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index 6b8c2427b..f99f9336b 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -91,6 +91,7 @@ struct KernelHardwareInfo { void const* kernel_ptr, cudaStream_t stream = nullptr) { int max_active_clusters = 0; +#if !(defined(__QNX__) && __QNX__ >= 800 && defined(NV_IS_SAFETY)) #if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) ClusterLauncher::LaunchConfig cluster_launch_config = ClusterLauncher::make_cluster_launch_config( cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1}, @@ -110,6 +111,10 @@ struct KernelHardwareInfo { #else CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster occupancy query."); return max_active_clusters; +#endif +#else + CUTLASS_TRACE_HOST("ClusterLauncher: cluster launch disabled for QNX 8+ safety builds"); + return max_active_clusters; #endif } diff --git a/include/cutlass/pipeline/sm100_pipeline.hpp b/include/cutlass/pipeline/sm100_pipeline.hpp index f02ae23a2..3892e4294 100644 --- a/include/cutlass/pipeline/sm100_pipeline.hpp +++ b/include/cutlass/pipeline/sm100_pipeline.hpp @@ -550,7 +550,7 @@ public: using ThreadCategory = typename Impl::ThreadCategory; using Params = typename Impl::Params; - using McastDirection = McastDirection; + using McastDirection = cutlass::McastDirection; // Helper function to initialize barriers static @@ -820,6 +820,18 @@ public: impl_.producer_acquire(state, barrier_token); } + template + CUTLASS_DEVICE + void producer_commit_local(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { + cute::forward(user_defined_arrive_op)(producer_get_barrier(state)); + producer_commit_local(state); + } + + CUTLASS_DEVICE + void producer_commit_local(PipelineState state) { + impl_.producer_commit(state); + } + template CUTLASS_DEVICE void producer_commit(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { diff --git a/include/cutlass/subbyte_reference.h b/include/cutlass/subbyte_reference.h index 543089d5d..83bcbd786 100644 --- a/include/cutlass/subbyte_reference.h +++ b/include/cutlass/subbyte_reference.h @@ -454,18 +454,21 @@ public: // // Homebrew read-modify-write // - Storage original; - Storage updated; + Storage assumed; +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) + Storage original = __nv_atomic_load_n(ptr_, __NV_ATOMIC_RELAXED); +#else + Storage original = *const_cast(ptr_); +#endif do { - original = (*ptr_); + assumed = original; + Storage updated = Storage((assumed & kUpdateMask) | new_bits); - updated = Storage((original & kUpdateMask) | new_bits); + original = atomicCAS(ptr_, assumed, updated); - original = atomicCAS(ptr_, original, updated); - - } while (updated != original); + } while (original != assumed); #else diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h index 4a08ba9e3..324093473 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -204,8 +204,7 @@ public: void store(Fragment const &frag, TensorCoord const & tile_offset) { store_with_pointer_offset( frag, - tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + - tile_offset.strided() * Shape::kStrided * stride_ + tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + tile_offset.strided() * Shape::kStrided * stride_ ); } diff --git a/include/cutlass/version.h b/include/cutlass/version.h index 5c30d8c6a..f388aa75e 100644 --- a/include/cutlass/version.h +++ b/include/cutlass/version.h @@ -35,8 +35,8 @@ #include #define CUTLASS_MAJOR 4 -#define CUTLASS_MINOR 5 -#define CUTLASS_PATCH 0 +#define CUTLASS_MINOR 4 +#define CUTLASS_PATCH 2 #ifdef CUTLASS_VERSIONS_GENERATED #include "cutlass/version_extended.h" diff --git a/media/docs/pythonDSL/cute_dsl.rst b/media/docs/pythonDSL/cute_dsl.rst index 0dfbf0656..50a7341d2 100644 --- a/media/docs/pythonDSL/cute_dsl.rst +++ b/media/docs/pythonDSL/cute_dsl.rst @@ -11,6 +11,7 @@ CuTe DSL Control Flow JIT Argument Generation JIT Argument: Layouts + Struct-like JIT Arguments JIT Caching JIT Compilation Options JIT Types diff --git a/media/docs/pythonDSL/cute_dsl_api/changelog.rst b/media/docs/pythonDSL/cute_dsl_api/changelog.rst index e4200e074..3c5d4b761 100644 --- a/media/docs/pythonDSL/cute_dsl_api/changelog.rst +++ b/media/docs/pythonDSL/cute_dsl_api/changelog.rst @@ -2,6 +2,24 @@ Changelog for CuTe DSL API changes ====================================== +`4.4.0 `_ (2026-03-24) +============================================================================== + +* Added native support for ``typing.NamedTuple`` as a JIT function argument. + + - A NamedTuple whose fields are DSL scalar types (``Int32``, ``Float32``, …) + can be passed directly to ``@cute.jit`` / ``cute.compile`` without any + protocol implementation. + - Fields are flattened field-by-field through the existing pytree system and + reconstructed via the NamedTuple constructor on entry to the kernel body. + Field attribute access (``tup.a``, ``tup.b``, …) works as in native Python. + - NamedTuple fields are **immutable** (tuple subclass). To replace a field, + construct a new NamedTuple inside the kernel. Use ``@native_struct`` when + mutable fields are required. + - See :doc:`../cute_dsl_general/dsl_struct_types` for a guide to NamedTuple, + ``@native_struct``, and other struct-like JIT argument types. + + `4.3.0 `_ (2025-10-20) ============================================================================== @@ -73,7 +91,7 @@ Changelog for CuTe DSL API changes - Introduce S2T CopyOps in `tcgen05/copy.py `_. - Introduce BlockScaled layout utilities in `blockscaled_layout.py `_ for creating the required scale factor layouts in global memory, shared memory and tensor memory. -* ``cutlass.cute.compile`` now supports compilation options. Refer to `JIT compilation options `_ for more details. +* ``cutlass.cute.compile`` now supports compilation options. Refer to `JIT compilation options `_ for more details. * ``cutlass.cute.testing.assert_`` now works for device JIT function. Specify ``--enable-assertions`` as compilation option to enable. * ``cutlass.cute.make_tiled_copy`` is now deprecated. Please use ``cutlass.cute.make_tiled_copy_tv`` instead. * Shared memory capacity query diff --git a/media/docs/pythonDSL/cute_dsl_general/debugging.rst b/media/docs/pythonDSL/cute_dsl_general/debugging.rst index 759086e3a..8b46dfd50 100644 --- a/media/docs/pythonDSL/cute_dsl_general/debugging.rst +++ b/media/docs/pythonDSL/cute_dsl_general/debugging.rst @@ -39,10 +39,11 @@ CuTe DSL provides environment variables to control logging level: # Enable console logging (default: False) export CUTE_DSL_LOG_TO_CONSOLE=1 - # Log to file instead of console (default: False) - export CUTE_DSL_LOG_TO_FILE=my_log.txt + # Log to file instead of console (default: False). + # Set to 1/True to enable; the log file path is chosen automatically by the DSL. + export CUTE_DSL_LOG_TO_FILE=1 - # Control log verbosity (0, 10, 20, 30, 40, 50, default: 10) + # Control log verbosity (0=disabled, 1=all messages (debug and above), 10=debug, 20=info, 30=warning, 40=error, 50=critical; default: 1) export CUTE_DSL_LOG_LEVEL=20 @@ -68,40 +69,53 @@ Similar to standard Python logging, different log levels provide varying degrees +--------+-------------+ -Dump the generated IR -~~~~~~~~~~~~~~~~~~~~~ +Save generated artifacts to files +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For users familiar with MLIR and compilers, CuTe DSL supports dumping the Intermediate Representation (IR). -This helps you verify whether the IR is generated as expected. +CuTe DSL can save generated artifacts (IR, PTX, CUBIN, …) to files for offline inspection. +Use ``CUTE_DSL_KEEP`` with a comma-separated list of artifact tokens: .. code:: bash - # Dump Generated CuTe IR (default: False) + # Save clean IR (after canonicalize+cse, human-readable) to a .mlir file + export CUTE_DSL_KEEP=ir + + # Save raw IR (before any passes) to a .mlir file + export CUTE_DSL_KEEP=ir-debug + + # Save PTX assembly to a .ptx file + export CUTE_DSL_KEEP=ptx + + # Save CUBIN binary to a .cubin file + export CUTE_DSL_KEEP=cubin + + # Save LLVM IR to a file + export CUTE_DSL_KEEP=llvm + + # Save multiple artifacts at once + export CUTE_DSL_KEEP=ir,ptx,cubin + + # Save all supported artifacts + export CUTE_DSL_KEEP=all + +Files are written to the current working directory by default. Use ``CUTE_DSL_DUMP_DIR`` +to redirect them (see `Change the dump directory`_ below). + +.. note:: + + The ``sass`` token requires ``nvdisasm`` (or ``nvdisasm_internal``) to be available + in your ``PATH``. It is usually installed with the CUDA toolkit. + +Print the generated IR to the console +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To print the IR directly to the console (without writing a file): + +.. code:: bash + + # Print generated IR to stdout (default: False) export CUTE_DSL_PRINT_IR=1 - # Keep Generated CuTe IR in a file (default: False) - export CUTE_DSL_KEEP_IR=1 - - -Dump the generated PTX & CUBIN -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For users familiar with PTX and SASS, CuTe DSL supports dumping the generated PTX and CUBIN. - -.. code:: bash - - # Dump generated PTX in a .ptx file (default: False) - export CUTE_DSL_KEEP_PTX=1 - - # Dump generated cubin in a .cubin file (default: False) - export CUTE_DSL_KEEP_CUBIN=1 - -To further get SASS from cubin, users can use ``nvdisasm`` (usually installed with CUDA toolkit) to disassemble the cubin. - -.. code:: bash - - nvdisasm your_dsl_code.cubin > your_dsl_code.sass - Access the dumped contents programmatically ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst index baeb6a7ce..284962772 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_code_generation.rst @@ -101,18 +101,11 @@ The result: |DSL| bridges Python and GPU hardware through a three-stage pipeline. -.. _fig-dsl-modes: - -.. figure:: dsl_modes.png - :width: 400 +.. figure:: dsl_compilation.png + :width: 600 :align: center - *Left*: tracing mode records only the path that executed. - *Right*: preprocessor mode emits structured |IR| for every branch and loop - before tracing the arithmetic. - - - The default |DSL| compilation pipeline (mode 2): Python source flows through AST preprocessing + The |DSL| compilation pipeline: Python source flows through AST preprocessing and interpreter-driven tracing to produce |IR|, which is then lowered and compiled to device code. @@ -258,8 +251,8 @@ Practical Implications 4. |DSL| Code-Generation Modes ------------------------------ -CuTe's Python front-end combines the techniques above into **two mutually -exclusive modes** (see :ref:`fig-dsl-modes`), selectable with the ``preprocessor`` flag of the +CuTe’s Python front-end combines the techniques above into **two mutually +exclusive modes**, selectable with the ``preprocessor`` flag of the ``@jit`` decorator: 1. Tracing mode ``@jit(preprocess=False)`` – tracing only. @@ -272,3 +265,10 @@ optimisation problems of pure tracing; tracing then fills in the arithmetic. This hybrid “preprocessor” pipeline is unique to |DSL| and was designed specifically to overcome the disadvantages identified above. +.. figure:: dsl_modes.png + :width: 400 + :align: center + + *Left*: tracing mode records only the path that executed. + *Right*: preprocessor mode emits structured |IR| for every branch and loop + before tracing the arithmetic. diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst index de1f088ab..116cb86cd 100644 --- a/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_introduction.rst @@ -117,6 +117,12 @@ Defines GPU kernel functions, compiled as specialized GPU symbols through |DC|. - ``False`` (default) — Standard kernel launch. - ``True`` — Cooperative kernel launch. +- ``smem_merge_branch_allocs`` + Enables mutually exclusive control flow branches (sequentially executed if-else) to reuse the same shared memory. + + - ``False`` (default) — Shared memory is allocated additively across all branches (default CUDA C++ behavior). + - ``True`` — Merge shared-memory allocations across branches (experimental feature, recommended for mega-kernels). + Calling Conventions ------------------- diff --git a/media/docs/pythonDSL/cute_dsl_general/dsl_struct_types.rst b/media/docs/pythonDSL/cute_dsl_general/dsl_struct_types.rst new file mode 100644 index 000000000..6f8d2d17c --- /dev/null +++ b/media/docs/pythonDSL/cute_dsl_general/dsl_struct_types.rst @@ -0,0 +1,134 @@ +.. _dsl_struct_types: + +Struct-like JIT Arguments +========================= + +|DSL| supports several struct-like Python types as JIT function arguments. +Each provides a different trade-off between mutability, syntax convenience, +and low-level control. + +.. |DSL| replace:: CuTe DSL + +.. contents:: On this page + :local: + :depth: 2 + + +Overview +-------- + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Type + - Mutable fields? + - Notes + * - ``typing.NamedTuple`` + - **No** + - Tuple subclass — fields fixed at construction. + Flattened field-by-field through the pytree system. + * - ``@dataclass(frozen=True)`` + - **No** + - Frozen dataclass — treated as a read-only pytree container, + similar to ``NamedTuple``. + + +NamedTuple +---------- + +A ``typing.NamedTuple`` whose fields are DSL scalar types (``Int32``, +``Float32``, etc.) can be passed directly to ``@cute.jit`` / +``cute.compile`` without any boilerplate or protocol implementation. + +**How it works.** NamedTuples are registered as pytree containers in the DSL +tree system. Each field is flattened individually through the existing DSL +type paths and reconstructed by calling the NamedTuple constructor on the way +into the kernel body. Field attribute access (``tup.a``, ``tup.b``, …) +works exactly as in native Python. + +Basic usage +^^^^^^^^^^^ + +.. code-block:: python + + from typing import NamedTuple + import cutlass + import cutlass.cute as cute + + class Vec3(NamedTuple): + x: cutlass.Int32 + y: cutlass.Int32 + z: cutlass.Int32 + + @cute.jit + def print_vec(v: Vec3): + cute.printf("x=%d y=%d z=%d\n", v.x, v.y, v.z) + + v = Vec3(x=cutlass.Int32(1), y=cutlass.Int32(2), z=cutlass.Int32(3)) + cute.compile(print_vec, v)(v) + +Control flow on fields +^^^^^^^^^^^^^^^^^^^^^^ + +Fields are DSL values inside the kernel, so they work in ``if``/``else`` +branches and ``for`` loops: + +.. code-block:: python + + @cute.jit + def clamp_positive(v: Vec3, out: cute.Tensor): + """Write max(field, 0) for each component.""" + out[0] = cutlass.Int32(0) if v.x < cutlass.Int32(0) else v.x + out[1] = cutlass.Int32(0) if v.y < cutlass.Int32(0) else v.y + out[2] = cutlass.Int32(0) if v.z < cutlass.Int32(0) else v.z + + @cute.jit + def triangular_sum(v: Vec3, out: cute.Tensor): + """Sum 0..v.x-1 into out[0], and so on.""" + s = cutlass.Int32(0) + for i in range(v.x): + s = s + i + out[0] = s + +Creating a new NamedTuple value inside the kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +NamedTuple fields are **immutable** — the same constraint as native Python +tuples. Assigning ``tup.x = ...`` inside a kernel raises ``AttributeError``. +To "update" a field, construct a replacement NamedTuple: + +.. code-block:: python + + @cute.jit + def scale(v: Vec3, factor: cutlass.Int32, out: cute.Tensor): + # Construct a new Vec3 with all fields scaled + scaled = Vec3(x=v.x * factor, y=v.y * factor, z=v.z * factor) + out[0] = scaled.x + out[1] = scaled.y + out[2] = scaled.z + +Choosing the right type +----------------------- + +.. list-table:: + :header-rows: 1 + :widths: 35 65 + + * - Use case + - Recommended type + * - Read-only config / parameters passed into a kernel + - ``NamedTuple`` or ``@dataclass(frozen=True)`` + * - Accumulator or running state updated inside a kernel + - ``@native_struct`` + * - Want Python-native immutable semantics (hashable, unpackable) + - ``NamedTuple`` + * - Need fine-grained LLVM struct control (packing, zero-init) + - ``@native_struct`` + + +See also +-------- + +* :doc:`dsl_jit_arg_generation` — overview of JIT function argument protocols +* :doc:`dsl_dynamic_layout` — passing ``Layout`` objects as JIT arguments diff --git a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst index 4253bdd30..974777b48 100644 --- a/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst +++ b/media/docs/pythonDSL/cute_dsl_general/framework_integration.rst @@ -79,6 +79,12 @@ by reducing register usage and the number of address calculation instructions. W to True, a runtime check is performed to ensure that the layout does not overflow. Please note that this parameter only has an effect when the tensor's layout is marked as dynamic. +For packed subbyte torch dtypes such as ``torch.float4_e2m1fn_x2``, ``from_dlpack`` exposes the +logical element layout expected by CuTe instead of the packed storage layout. For example, a torch +tensor with shape ``(128, 128)`` and dtype ``torch.float4_e2m1fn_x2`` is exposed as a logical FP4 +tensor with shape ``(128, 256)``. The same logical reinterpretation also applies when the leading +dimension is not the last mode. + Code Example ~~~~~~~~~~~~ diff --git a/media/docs/pythonDSL/cute_dsl_general/resources.rst b/media/docs/pythonDSL/cute_dsl_general/resources.rst index 0a63a87ae..7b4c608f8 100644 --- a/media/docs/pythonDSL/cute_dsl_general/resources.rst +++ b/media/docs/pythonDSL/cute_dsl_general/resources.rst @@ -15,7 +15,7 @@ Conference Talks An introduction to the |DSL| architecture, covering the hybrid AST-rewrite and tracing approach, MLIR code generation, and integration with CUTLASS. -* `LLVM Video `_ +* `Video `__ * `Slides (PDF) `_ ---- @@ -25,4 +25,4 @@ tracing approach, MLIR code generation, and integration with CUTLASS. Learn how to leverage Tensor Cores directly from Python using CUTLASS 4.0's new DSL front-end, enabling rapid kernel development without writing CUDA C++. -* `GTC Video `_ +* `Video `__ diff --git a/media/docs/pythonDSL/functionality.rst b/media/docs/pythonDSL/functionality.rst index b3575dd7a..72b4e3ca9 100644 --- a/media/docs/pythonDSL/functionality.rst +++ b/media/docs/pythonDSL/functionality.rst @@ -3,11 +3,7 @@ Functionality ==================== -The CUTLASS DSL 4.0 release supports **Python 3.12** only. It shares the same driver requirements -as the `CUDA Toolkit 12.9 `__. -Specifically, the driver version must be 575.51.03 or later. - -Currently, only Linux x86_64 is supported. Additional platform support will be added in future releases. +For dependency version requirements, refer to the :doc:`quick_start` section. Supported MMA Operations --------------------------------- diff --git a/media/docs/pythonDSL/limitations.rst b/media/docs/pythonDSL/limitations.rst index f396e4f5e..34ee08339 100644 --- a/media/docs/pythonDSL/limitations.rst +++ b/media/docs/pythonDSL/limitations.rst @@ -217,18 +217,68 @@ Programming Model **CuTe Layout algebra in native Python** - Entirety of CuTe Layout algebra operations and APIs require JIT compilation. These - functionalities are exclusively available within JIT-compiled functions and cannot be + Entirety of CuTe Layout algebra operations and APIs require JIT compilation. These + functionalities are exclusively available within JIT-compiled functions and cannot be accessed in standard Python execution environments. - - Additionally, there exists a restricted set of data types that can be passed as arguments - to JIT-compiled functions, which further constrains their usage in native Python contexts. - Only following CuTe algebra types are supported as JIT function arguments: ``Tensor``, ``Pointer``, + + Additionally, there exists a restricted set of data types that can be passed as arguments + to JIT-compiled functions, which further constrains their usage in native Python contexts. + Only following CuTe algebra types are supported as JIT function arguments: ``Tensor``, ``Pointer``, ``Shape``, ``Stride``, ``Coord`` and ``IntTuple``. For ``Stride``, we don't support ``ScacledBasis`` - from native Python Context. Unfortunately, in the first release, we don't support + from native Python Context. Unfortunately, in the first release, we don't support passing ``Layout`` under native Python Context. +**Block-level Utilities (block_copy)** + The block-level utility ``block_copy`` provides a high-level abstraction + for common copy patterns, but has the following limitations: + + **block_copy limitations:** + + - **Limited copy op support**: Currently only ``TmaCopyOp``-based tiled copies + (TMA loads/stores) and S2T copies (SMEM to TMEM, e.g., ``tcgen05.Cp*Op``) are + supported. Other ``TiledCopy`` ops will raise ``NotImplementedError``. Support + for additional copy ops may be added in future releases. + + +**Global variables** + CuTe DSL does not support global variables. + It is not allowed to use ``global`` in the DSL. + The following example illustrates functionality in Python that is not supported in the DSL: + + .. code:: python + + @cute.jit + def foo(): + global x + x = 1 + + foo() + + The example above fails to compile because ``global x`` is not supported in the DSL. + + +**Nonlocal variables** + The use of the ``nonlocal`` keyword is restricted in CuTe DSL. CuTe DSL does not support capturing variables + from an outer (enclosing) scope that is outside of the JIT-compiled function. If you try to use ``nonlocal`` + to refer to a variable defined in Python code that is not tracked by current JIT context, a runtime error will be raised. + + .. code:: python + + def outer(): + x = 1 + + @cute.jit + def inner(): + nonlocal x # Not supported + x = 2 + + inner() + + The above code will fail with a runtime error because ``x`` is defined in a scope not managed + by the CuTe DSL's JIT compilation. Nonlocal variables must be managed within the same JIT context; + otherwise, a runtime error will be raised. + Suggestions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/media/docs/pythonDSL/quick_start.rst b/media/docs/pythonDSL/quick_start.rst index e6392c78b..1a71826de 100644 --- a/media/docs/pythonDSL/quick_start.rst +++ b/media/docs/pythonDSL/quick_start.rst @@ -3,19 +3,19 @@ Quick Start Guide ======================= -The CUTLASS DSL 4.4 release currently supports **Linux** and **Python 3.10 - 3.14** only. To install CUTLASS DSLs (limited to CuTe DSL for now), use the following command +Compatibility Requirements +--------------------------------- + +The CUTLASS DSL 4.4 release currently supports **Linux** and **Python 3.10 - 3.14** only. + +Only Linux x86_64 and aarch64 are supported. Additional platform support will be added in future releases. + +CUTLASS DSL supports the same NVIDIA driver version as the corresponding `CUDA Toolkit `_ +(CUDA Toolkit 12.9 or CUDA Toolkit 13.1). Specifically, for 12.9, the driver version must be 575.51.03 or later. Installation ----------------------- -Before installing the latest version, you need to uninstall any previous CUTLASS DSL Installation. - -.. code-block:: bash - - pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu13 -y - - - To ensure compatibility with the examples and code on `GitHub `_, use the `setup.sh `_ file from the corresponding commit in the repository. @@ -38,12 +38,10 @@ If you just want to try out the last known stable release of the CUTLASS DSL (ma pip install nvidia-cutlass-dsl # For CUDA Toolkit 13.1: - pip install nvidia-cutlass-dsl[cu13] + pip install "nvidia-cutlass-dsl[cu13]" -The ``nvidia-cutlass-dsl`` wheel includes everything needed to generate GPU kernels. It requires -the same NVIDIA driver version as the corresponding `CUDA Toolkit `_ -(CUDA Toolkit 12.9 or CUDA Toolkit 13.1). +The ``nvidia-cutlass-dsl`` wheel includes everything needed to generate GPU kernels. Recommended Dependencies --------------------------------- @@ -52,9 +50,7 @@ To run examples and begin development, we recommend installing: .. code-block:: bash - pip install torch jupyter - -We recommend installing JAX with CUDA support at version 0.8.1 to run JAX examples. + pip install torch jupyter mypy==1.19.1 Recommended Python environment variables for jupyter notebooks -------------------------------------------------------------- diff --git a/python/CuTeDSL/cutlass/__init__.py b/python/CuTeDSL/cutlass/__init__.py index 3d00944bf..1d8c1ea37 100644 --- a/python/CuTeDSL/cutlass/__init__.py +++ b/python/CuTeDSL/cutlass/__init__.py @@ -13,6 +13,27 @@ from ._mlir._mlir_libs import _cutlass_ir _cutlass_ir.populate(_cutlass_ir) + +def _ensure_mlir_type_compat() -> None: + """Patch `.isinstance()` onto MLIR type classes that no longer expose it.""" + try: + from ._mlir import ir as _mlir_ir + except Exception: + return + for name in dir(_mlir_ir): + if not name.endswith("Type"): + continue + cls = getattr(_mlir_ir, name) + if not isinstance(cls, type) or hasattr(cls, "isinstance"): + continue + try: + cls.isinstance = staticmethod(lambda ty, _cls=cls: isinstance(ty, _cls)) # type: ignore[attr-defined] + except Exception: + continue + + +_ensure_mlir_type_compat() +del _ensure_mlir_type_compat __version__ = "@CUTLASS_IR_WHEEL_RELEASE_VERSION@" # Monkey patch CUDA version query function from ._mlir._mlir_libs._cutlass_ir._base_dsl import ( @@ -52,6 +73,7 @@ from .cutlass_dsl import ( # Data types dtype, # Provides conversions to types inheriting from NumericType DSLRuntimeError, + DSLAstPreprocessorError, JitArgAdapterRegistry, # Construction utilities for user-defined classes extract_mlir_values, @@ -77,3 +99,5 @@ cuda = _dsl.cuda_helpers # Jax Framework support from . import jax as jax + +CACHE_FILE = "compiled_cache.db" diff --git a/python/CuTeDSL/cutlass/_pth_hook.py b/python/CuTeDSL/cutlass/_pth_hook.py new file mode 100644 index 000000000..d2b5e96b0 --- /dev/null +++ b/python/CuTeDSL/cutlass/_pth_hook.py @@ -0,0 +1,69 @@ +"""Hook script loaded by cutlass-dsl-dev.pth at Python startup. + +This script sets up the editable install environment: +1. Sets CUTE_DSL_LIBS environment variable +2. Installs the custom editable finder for cutlass._mlir and DSL modules + +The .pth file calls setup() with paths configured during installation. +""" + +from __future__ import annotations + +import os +from importlib.util import module_from_spec, spec_from_file_location +from pathlib import Path + + +def setup( + cutlass_source_dir: str | Path, + vendored_mlir_dir: str | Path, + lib_so_path: str | Path, + finder_module_path: str | Path, + root_dir: str | Path | None = None, +) -> None: + """Set up the editable install environment. + + This function is called by the .pth file at Python startup with paths + configured during pip installation. + + :param cutlass_source_dir: Path to cutlass source package directory + :param vendored_mlir_dir: Path to vendored _mlir directory + :param lib_so_path: Path to libcute_dsl_runtime.so + :param finder_module_path: Path to _editable_finder.py module + :param root_dir: Path to DSL root directory (optional) + """ + # Convert to Path objects + cutlass_source_dir = Path(cutlass_source_dir) + vendored_mlir_dir = Path(vendored_mlir_dir) + lib_so_path = Path(lib_so_path) + finder_module_path = Path(finder_module_path) + if root_dir is not None: + root_dir = Path(root_dir) + + # Set CUTE_DSL_LIBS environment variable + os.environ.setdefault("CUTE_DSL_LIBS", str(lib_so_path)) + + # Load and configure the custom editable finder module + spec = spec_from_file_location("_editable_finder", finder_module_path) + if spec is None: + raise ImportError( + f"Failed to obtain module spec for '_editable_finder' at {finder_module_path}. " + f"Ensure the file exists and is a valid Python module." + ) + if spec.loader is None: + raise ImportError( + f"Failed to obtain loader for '_editable_finder' from spec at {finder_module_path}. " + f"The module spec was created but has no loader." + ) + + finder_mod = module_from_spec(spec) + spec.loader.exec_module(finder_mod) + + # Configure the finder's path variables + finder_mod.CUTLASS_SOURCE_DIR = cutlass_source_dir + finder_mod.VENDORED_MLIR_DIR = vendored_mlir_dir + if root_dir is not None: + finder_mod.ROOT_DIR = root_dir + + # Install the finder into sys.meta_path + finder_mod.install() diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py index db064f198..b963a0ad3 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/__init__.py @@ -14,10 +14,11 @@ This module provides MLIR Dialect helper functions """ from . import arith +from .dialect_proxy import DialectAutoConvertProxy from .lru_cache_ir import lru_cache_ir from .op import dsl_user_op -__all__ = ["arith", "lru_cache_ir", "dsl_user_op"] +__all__ = ["arith", "DialectAutoConvertProxy", "lru_cache_ir", "dsl_user_op"] try: from . import gpu diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py index b02c79d72..0732e32a5 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/arith.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # # Use of this software is governed by the terms and conditions of the @@ -14,12 +14,17 @@ This module provides MLIR Arith Dialect helper functions """ import array +import builtins +from typing import Any, Callable, Optional, Union + import numpy as np from ..common import * -from ..._mlir import ir # type: ignore -from ..._mlir.extras import types as T # type: ignore -from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore +from ..common import DSLRuntimeError, DSLNotImplemented +from ..._mlir import ir +from ..._mlir.extras import types as T +from ..._mlir.dialects import arith, math, builtin +from ..._mlir.dialects import nvgpu, vector, llvm from .op import dsl_user_op from .lru_cache_ir import lru_cache_ir @@ -29,7 +34,7 @@ from .lru_cache_ir import lru_cache_ir # ============================================================================= -def recast_type(src_type, res_elem_type) -> ir.Type: +def recast_type(src_type: ir.Type, res_elem_type: ir.Type) -> ir.Type: if isinstance(src_type, T.VectorType): if src_type.scalable: res_type = T.vector( @@ -55,20 +60,20 @@ def recast_type(src_type, res_elem_type) -> ir.Type: return res_type -def is_scalar(ty) -> bool: +def is_scalar(ty: ir.Type) -> bool: return not isinstance( ty, (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType) ) -def element_type(ty) -> ir.Type: +def element_type(ty: ir.Type) -> ir.Type: if not is_scalar(ty): return ty.element_type else: return ty -def is_narrow_precision(ty) -> bool: +def is_narrow_precision(ty: ir.Type) -> bool: narrow_types = { T.f8E8M0FNU(), T.f8E4M3FN(), @@ -82,7 +87,7 @@ def is_narrow_precision(ty) -> bool: return ty in narrow_types -def is_float_type(ty) -> bool: +def is_float_type(ty: ir.Type) -> bool: return ( arith._is_float_type(ty) # TODO-upstream: prediction is not correct. Patch here and fix in upstream later @@ -91,33 +96,28 @@ def is_float_type(ty) -> bool: ) -def truncf_to_narrow(res_ty, src, loc, ip): - res_elem_ty = element_type(res_ty) - if res_elem_ty == T.f8E8M0FNU(): - rnd = nvgpu.RoundingMode.RP - else: - rnd = nvgpu.RoundingMode.RN - return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip) +def is_integer_like_type(ty: ir.Type) -> bool: + return isinstance(ty, (ir.IntegerType, ir.IndexType)) -def extf_from_narrow(res_ty, src, loc, ip): - src_elem_ty = element_type(src.type) - - # When source type is E8M0, temporary element type has to be bf16 - tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16() - tmp_ty = recast_type(src.type, tmp_elem_ty) - - # narrow -> bf16/f16 -> target type - tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip) - return arith.extf(res_ty, tmp, loc=loc, ip=ip) - - -def bitcast(src, res_elem_type, *, loc=None, ip=None): +def bitcast( + src: ir.Value, + res_elem_type: ir.Type, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: res_type = recast_type(src.type, res_elem_type) return arith.bitcast(res_type, src, loc=loc, ip=ip) -def cvtf(src, res_elem_type, *, loc=None, ip=None): +def cvtf( + src: ir.Value, + res_elem_type: ir.Type, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: src_elem_type = element_type(src.type) if res_elem_type == src_elem_type: @@ -145,41 +145,51 @@ def cvtf(src, res_elem_type, *, loc=None, ip=None): return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip) if res_elem_type.width > src_elem_type.width: - if is_narrow_precision(src_elem_type): - return extf_from_narrow(res_type, src, loc, ip) - else: - return arith.extf(res_type, src, loc=loc, ip=ip) + return arith.extf(res_type, src, loc=loc, ip=ip) else: - tmp_mlir_type = recast_type(src.type, T.f32()) - - # f16 -- extf -> f32 -- truncf -> bf16 - # TODO-upstream: update arith to support bf16 <-> f16 conversion? + # bf16 <-> f16: both are 16-bit, arith.truncf requires strict narrowing. + # Route through f32 intermediate. if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or ( src_elem_type == T.bf16() and res_elem_type == T.f16() ): - tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip) + tmp_type = recast_type(src.type, T.f32()) + tmp = arith.extf(tmp_type, src, loc=loc, ip=ip) return arith.truncf(res_type, tmp, loc=loc, ip=ip) - # {f8, f6, f4} -> f16, f32, ... - elif is_narrow_precision(res_elem_type): - return truncf_to_narrow(res_type, src, loc, ip) - else: - return arith.truncf(res_type, src, loc=loc, ip=ip) + # E8M0 requires upward rounding; all others default to to_nearest_even + roundingmode = ( + arith.RoundingMode.upward if res_elem_type == T.f8E8M0FNU() else None + ) + return arith.truncf(res_type, src, roundingmode=roundingmode, loc=loc, ip=ip) -def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): +def fptoi( + src: ir.Value, + signed: Union[bool, None], + res_elem_type: ir.Type, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: res_type = recast_type(src.type, res_elem_type) # TODO-upstream: update arith to support this kind of conversion if element_type(src.type) in (T.tf32(), T.bf16()): src = cvtf(src, T.f32(), loc=loc, ip=ip) - if signed: + if signed != False: # noqa: E712 return arith.fptosi(res_type, src, loc=loc, ip=ip) else: return arith.fptoui(res_type, src, loc=loc, ip=ip) -def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): +def itofp( + src: ir.Value, + signed: Union[bool, None], + res_elem_type: ir.Type, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: res_type = recast_type(src.type, res_elem_type) orig_res_type = res_type @@ -187,7 +197,7 @@ def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): if res_elem_type in (T.tf32(), T.bf16()): res_type = recast_type(src.type, T.f32()) - if signed and element_type(src.type).width > 1: + if signed != False and element_type(src.type).width > 1: # noqa: E712 res = arith.sitofp(res_type, src, loc=loc, ip=ip) else: res = arith.uitofp(res_type, src, loc=loc, ip=ip) @@ -198,7 +208,13 @@ def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): return cvtf(res, element_type(orig_res_type), loc=loc, ip=ip) -def int_to_int(a, dst_elem_type, *, loc=None, ip=None): +def int_to_int( + a: ir.Value, + dst_elem_type: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: src_signed = a.signed dst_signed = dst_elem_type.signed src_width = element_type(a.type).width @@ -208,7 +224,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None): if dst_width == src_width: return a - elif src_signed != False and not dst_signed: + elif src_signed != False and not dst_signed: # noqa: E712 # Signed -> Unsigned if dst_width > src_width: return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) @@ -217,7 +233,7 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None): elif src_signed == dst_signed: # Same signedness if dst_width > src_width: - if src_signed != False and src_width > 1: + if src_signed != False and src_width > 1: # noqa: E712 return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip) else: return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) @@ -244,7 +260,14 @@ def int_to_int(a, dst_elem_type, *, loc=None, ip=None): # ============================================================================= -def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None): +def _cast( + res_elem_ty: ir.Type, + src: ir.Value, + is_signed: Optional[bool] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ This function provides simplified interface to upstream op builder arith.truncf(T.vector(shape, new_type), src) @@ -271,7 +294,7 @@ def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None): if src_elem_ty.width >= res_elem_ty.width: cast_op = arith.trunci else: - if is_signed: + if is_signed != False: # noqa: E712 cast_op = arith.extsi else: cast_op = arith.extui @@ -289,7 +312,13 @@ def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None): @lru_cache_ir() -def const(value, ty=None, *, loc=None, ip=None): +def const( + value: Union[int, float, bool, np.ndarray], + ty: Optional[Union[ir.Type, "NumericMeta"]] = None, # type: ignore[name-defined] + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Generates dynamic expression for constant values. """ @@ -302,7 +331,8 @@ def const(value, ty=None, *, loc=None, ip=None): # Early return if is_dynamic_expression(value) and ( - value.type.isinstance(value.type) or T.bool().isinstance(value.type) + isinstance(value.type, type(value.type)) # type: ignore[union-attr] + or isinstance(value.type, type(T.bool())) # type: ignore[union-attr] ): return value @@ -316,7 +346,7 @@ def const(value, ty=None, *, loc=None, ip=None): ty = T.i32() elif isinstance(value, np.ndarray): ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype)) - value = array.array(value.dtype.kind, value.flatten().tolist()) + value = array.array(value.dtype.kind, value.flatten().tolist()) # type: ignore[assignment] else: raise DSLNotImplemented(f"{type(value)} is not supported") elif isinstance(ty, NumericMeta): @@ -339,17 +369,18 @@ def const(value, ty=None, *, loc=None, ip=None): return arith.constant(ty, value, loc=loc, ip=ip) -def _dispatch_to_rhs_r_op(op): +def _dispatch_to_rhs_r_op(op: Callable[..., "ArithValue"]) -> Callable[..., Any]: """Decorator that dispatches to the right-hand-side's reverse operation. If the other operand is not an ArithValue or is a subclass (more specific) of ArithValue, this allows proper method resolution for binary operations. """ - def wrapper(self, other, **kwargs): + def wrapper( + self: "ArithValue", other: Union[int, float, bool, "ArithValue"], **kwargs: Any + ) -> Any: if not isinstance(other, ArithValue): if not isinstance(other, (int, float, bool)): - # allows to call other.__rmul__ return NotImplemented return op(self, other, **kwargs) @@ -357,19 +388,18 @@ def _dispatch_to_rhs_r_op(op): return wrapper -def _binary_op(op): +def _binary_op(op: Callable[..., "ArithValue"]) -> Callable[..., "ArithValue"]: """ Decorator to check if the 'other' argument is an ArithValue. If not, returns NotImplemented. """ - def wrapper(self, other, **kwargs): - # When reach this point, `self` must be cast to base `ArithValue` type + def wrapper( + self: "ArithValue", other: Union[int, float, bool, "ArithValue"], **kwargs: Any + ) -> "ArithValue": if isinstance(other, (int, float, bool)): other = const(other, self.type).with_signedness(self.signed) - # Call the original function - # If sub-class doesn't implement overloaded arithmetic, cast to base class return op(self, other, **kwargs) return wrapper @@ -390,13 +420,19 @@ def _binary_op(op): @ir.register_value_caster(ir.F32Type.static_typeid) @ir.register_value_caster(ir.F64Type.static_typeid) @ir.register_value_caster(ir.IntegerType.static_typeid) -@ir.register_value_caster(ir.VectorType.static_typeid) @ir.register_value_caster(ir.RankedTensorType.static_typeid) class ArithValue(ir.Value): """Overloads operators for MLIR's Arith dialects binary operations.""" @dsl_user_op - def __init__(self, v, signed: Union[bool, None] = None, *, loc=None, ip=None): + def __init__( + self, + v: Union[int, ir.Value], + signed: Union[bool, None] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: if isinstance(v, int): v = arith.constant(self.type, v, loc=loc, ip=ip) super().__init__(v) @@ -406,11 +442,25 @@ class ArithValue(ir.Value): # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL self.signed = signed and elem_ty.width > 1 - def with_signedness(self, signed: Union[bool, None]): + @dsl_user_op + def ir_value( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: + return self + + def with_signedness(self, signed: Union[bool, None]) -> "ArithValue": return type(self)(self, signed) @dsl_user_op - def __neg__(self, *, loc=None, ip=None): + def __neg__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.type == T.bool(): raise TypeError( "Negation, the operator `-` is not supported for boolean type" @@ -424,7 +474,13 @@ class ArithValue(ir.Value): @dsl_user_op @_binary_op - def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __pow__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float and other.is_float: return math.powf(self, other, loc=loc, ip=ip) elif self.is_float and not other.is_float: @@ -440,14 +496,26 @@ class ArithValue(ir.Value): @dsl_user_op @_binary_op - def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rpow__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__pow__(self, loc=loc, ip=ip) # arith operators @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __add__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __add__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.addf(self, other, loc=loc, ip=ip) else: @@ -456,7 +524,13 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __sub__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.subf(self, other, loc=loc, ip=ip) else: @@ -465,7 +539,13 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __mul__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.mulf(self, other, loc=loc, ip=ip) else: @@ -474,7 +554,13 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __truediv__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.divf(self, other, loc=loc, ip=ip) else: @@ -485,11 +571,17 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __floordiv__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: q = arith.divf(self, other, loc=loc, ip=ip) return math.floor(q, loc=loc, ip=ip) - elif self.signed != False: + elif self.signed != False: # noqa: E712 return arith.floordivsi(self, other, loc=loc, ip=ip) else: return arith.divui(self, other, loc=loc, ip=ip) @@ -497,52 +589,100 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __mod__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.remf(self, other, loc=loc, ip=ip) - elif self.signed != False: + elif self.signed != False: # noqa: E712 return arith.remsi(self, other, loc=loc, ip=ip) else: return arith.remui(self, other, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __radd__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__add__(self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rsub__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__sub__(self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rmul__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__mul__(self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rtruediv__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rtruediv__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__truediv__(self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rfloordiv__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__floordiv__(self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rmod__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__mod__(self, loc=loc, ip=ip) # Comparison operators (comparison doesn't have right-hand-side variants) @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __lt__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip) - elif self.signed != False: + elif self.signed != False: # noqa: E712 return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip) @@ -550,10 +690,16 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __le__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __le__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip) - elif self.signed != False: + elif self.signed != False: # noqa: E712 return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip) @@ -561,7 +707,13 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __eq__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip) else: @@ -570,7 +722,13 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __ne__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: # In Python, bool(float("nan")) is True, so use unordered comparison here return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip) @@ -580,10 +738,16 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __gt__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OGT, self, other, loc=loc, ip=ip) - elif self.signed != False: + elif self.signed != False: # noqa: E712 return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip) @@ -591,43 +755,94 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __ge__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": if self.is_float: return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip) - elif self.signed != False: + elif self.signed != False: # noqa: E712 return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip) else: return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip) # Unary operators @dsl_user_op - def __invert__(self, *, loc=None, ip=None) -> "ArithValue": - return arith.xori(self, arith.constant(self.type, -1)) + def __abs__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": + if self.is_float: + return math.absf(self, loc=loc, ip=ip) + else: + return math.absi(self, loc=loc, ip=ip) + + @dsl_user_op + def __invert__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": + # For i1 (boolean) types, the all-ones value is 1, not -1. + # Using -1 with i1 vectors causes arith.constant to produce a + # type mismatch (e.g. vector<32xi0> instead of vector<32xi1>). + all_ones = 1 if element_type(self.type).width == 1 else -1 + return arith.xori(self, const(all_ones, self.type)) # Bitwise operations @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __and__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __and__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.andi(self, other, loc=loc, ip=ip) @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __or__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __or__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.ori(self, other, loc=loc, ip=ip) @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __xor__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.xori(self, other, loc=loc, ip=ip) @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue": - if self.signed != False: + def __rshift__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": + if self.signed != False: # noqa: E712 return arith.shrsi(self, other, loc=loc, ip=ip) else: return arith.shrui(self, other, loc=loc, ip=ip) @@ -635,45 +850,87 @@ class ArithValue(ir.Value): @dsl_user_op @_dispatch_to_rhs_r_op @_binary_op - def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __lshift__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.shli(self, other, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rand__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.andi(other, self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __ror__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.ori(other, self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rxor__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return arith.xori(other, self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rrshift__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__rshift__(self, loc=loc, ip=ip) @dsl_user_op @_binary_op - def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + def __rlshift__( + self, + other: "ArithValue", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ArithValue": return other.__lshift__(self, loc=loc, ip=ip) - def __hash__(self): + def __hash__(self) -> int: return super().__hash__() - def __str__(self): + def __str__(self) -> str: return "?" - def __repr__(self): + def __repr__(self) -> str: return self.__str__() -def _min(lhs, rhs, *, loc=None, ip=None): +def _min( + lhs: ir.Value, + rhs: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ This function provides a unified interface for building arith min @@ -697,7 +954,7 @@ def _min(lhs, rhs, *, loc=None, ip=None): assert hasattr(lhs, "signed"), ( "Should have attribute `signed`, must be a bug" ) - if lhs.signed != False: + if lhs.signed != False: # noqa: E712 return arith.minsi(lhs, rhs, loc=loc, ip=ip) else: return arith.minui(lhs, rhs, loc=loc, ip=ip) @@ -705,7 +962,7 @@ def _min(lhs, rhs, *, loc=None, ip=None): return arith.minimumf(lhs, rhs, loc=loc, ip=ip) elif arith._is_integer_like_type(lhs.type): assert hasattr(lhs, "signed"), "Should have attribute `signed`, must be a bug" - if lhs.signed != False: + if lhs.signed != False: # noqa: E712 return arith.minsi(lhs, rhs, loc=loc, ip=ip) else: return arith.minui(lhs, rhs, loc=loc, ip=ip) @@ -713,7 +970,13 @@ def _min(lhs, rhs, *, loc=None, ip=None): return arith.minimumf(lhs, rhs, loc=loc, ip=ip) -def _max(lhs, rhs, *, loc=None, ip=None): +def _max( + lhs: ir.Value, + rhs: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ This function provides a unified interface for building arith max @@ -736,7 +999,7 @@ def _max(lhs, rhs, *, loc=None, ip=None): assert hasattr(lhs, "signed"), ( "Should have attribute `signed`, must be a bug" ) - if lhs.signed != False: + if lhs.signed != False: # noqa: E712 return arith.maxsi(lhs, rhs, loc=loc, ip=ip) else: return arith.maxui(lhs, rhs, loc=loc, ip=ip) @@ -744,9 +1007,548 @@ def _max(lhs, rhs, *, loc=None, ip=None): return arith.maximumf(lhs, rhs, loc=loc, ip=ip) elif arith._is_integer_like_type(lhs.type): assert hasattr(lhs, "signed"), "Should have attribute `signed`, must be a bug" - if lhs.signed != False: + if lhs.signed != False: # noqa: E712 return arith.maxsi(lhs, rhs, loc=loc, ip=ip) else: return arith.maxui(lhs, rhs, loc=loc, ip=ip) else: return arith.maximumf(lhs, rhs, loc=loc, ip=ip) + + +# ============================================================================= +# Vector Type - Immutable on registers +# ============================================================================= + + +@ir.register_value_caster(ir.VectorType.static_typeid) +class Vector(ArithValue): + """Wrap an MLIR ``vector`` register value with DSL type information. + + Provides element extraction (``vec[i]`` / ``vec[a:b]``), element-wise + arithmetic (``+``, ``-``, ``*``, ``/``), type conversion (:meth:`to`), + and bit-reinterpretation (:meth:`bitcast`) on top of a raw MLIR vector. + + Vectors live entirely in registers — they carry no memory address and do + not support in-place element assignment. + + Registered as the MLIR value caster for :class:`ir.VectorType`, so any + op that returns a vector automatically produces a ``Vector`` instance. + + :param v: Underlying MLIR vector value. + :type v: ir.Value + :param dtype: DSL element type (e.g. ``Float32``, ``Int32``). + Inferred from the MLIR element type when omitted. + :type dtype: type, optional + """ + + _dtype: "type" + _mlir_type: ir.Type + _shape: "tuple[int, ...]" + + # Cache parameterized subclasses so ``Vector[T, N] is Vector[T, N]``. + _parameterized_cache: "dict[tuple, type]" = {} + + @classmethod + def __class_getitem__(cls, params: "tuple[type, int | tuple[int, ...]]") -> type: + """Return a parameterized Vector subclass with a ``mlir_type`` descriptor. + + ``Vector[Int32, 4].mlir_type`` returns ``vector<4xi32>`` and + ``Vector[Float32, (4, 8)].mlir_type`` returns ``vector<4x8xf32>``, + matching the dual type-descriptor / value-constructor pattern of + scalar types like ``Int32``. Follows the same approach as + ``Pointer.__class_getitem__``. + """ + element_type, shape = params + + from ..typing import Numeric + + if not (isinstance(element_type, type) and issubclass(element_type, Numeric)): + raise TypeError( + f"Vector element type must be a Numeric type (Integer or Float), " + f"got {element_type!r}" + ) + if isinstance(shape, int): + shape = (shape,) + shape = tuple(shape) + if not shape or any(d <= 0 for d in shape): + raise ValueError( + f"Vector shape dimensions must be positive integers, got {shape}" + ) + key = (element_type, shape) + if key not in cls._parameterized_cache: + shape_str = "x".join(str(d) for d in shape) + + class _Parameterized(cls): # type: ignore[valid-type, misc] + """Vector subclass with an ``mlir_type`` type descriptor.""" + + class mlir_type: # noqa: N801 + def __get__( + self, obj: object, objtype: Optional[type] = None + ) -> ir.Type: + return ir.VectorType.get(list(shape), element_type.mlir_type) # type: ignore[arg-type, attr-defined] + + mlir_type = mlir_type() # type: ignore[misc, assignment] + + @staticmethod + def __get_mlir_types__() -> list: + """Return MLIR types list — compatible with FFI ``_to_mlir_types``.""" + return [ir.VectorType.get(list(shape), element_type.mlir_type)] # type: ignore[arg-type, attr-defined] + + @staticmethod + def isinstance(value: object) -> bool: + """Check if an ``ir.Value`` matches this parameterized vector type.""" + if not builtins.isinstance(value, ir.Value): + return False + ty = value.type + if not builtins.isinstance(ty, ir.VectorType): + return False + return ( + list(ty.shape) == list(shape) # type: ignore[arg-type] + and ty.element_type == element_type.mlir_type # type: ignore[attr-defined] + ) + + _Parameterized.__name__ = f"Vector[{element_type.__name__}, {shape_str}]" + _Parameterized.__qualname__ = _Parameterized.__name__ + cls._parameterized_cache[key] = _Parameterized + return cls._parameterized_cache[key] + + def __init__( + self, + v: "ir.Value", + *, + dtype: "type | None" = None, + loc: "ir.Location | None" = None, + ip: "ir.InsertionPoint | None" = None, + ) -> None: + # Derive signedness from dtype for ArithValue base + signed = getattr(dtype, "signed", None) + super().__init__(v, signed, loc=loc, ip=ip) + + # Infer dtype from MLIR element type if not provided + if dtype is None: + from ..typing import Numeric + + dtype = Numeric.from_mlir_type(self.type.element_type) + self._dtype = dtype + self._mlir_type = dtype.mlir_type # type: ignore[attr-defined] + + # Shape is always derived from MLIR vector type + self._shape = tuple(self.type.shape) + + # ========================================================================= + # DSL Infrastructure + # ========================================================================= + + def __extract_mlir_values__(self) -> list: + return [self] + + def __extract_mlir_attributes__(self) -> list: + return [ir.DictAttr.get({})] + + def __new_from_mlir_values__(self, values: list) -> "Vector": + return Vector(values[0], dtype=self._dtype) + + def with_signedness(self, signed: Union[bool, None]) -> "Vector": + """Override ArithValue.with_signedness for keyword-only __init__.""" + new_vec = Vector(self, dtype=self._dtype) + elem_ty = self.type.element_type + new_vec.signed = ( + signed + and ir.IntegerType.isinstance(elem_ty) + and ir.IntegerType(elem_ty).width > 1 + ) + return new_vec + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def dtype(self) -> "type": + """The DSL element type (e.g., Float32, Int32).""" + return self._dtype + + @property + def shape(self) -> "tuple[int, ...]": + """The logical shape of the vector array (1D, 2D, or 3D).""" + return self._shape + + @property + def _count(self) -> int: + """Total number of elements (product of shape dimensions).""" + result = 1 + for dim in self._shape: + result *= dim + return result + + def numel(self) -> int: + """Total number of elements (product of all shape dimensions).""" + return self._count + + # Vector has no memory space - it's always in registers + # The space property is intentionally not exposed on Vector + + def ir_value( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: + """Return the underlying MLIR vector value.""" + return self + + # ========================================================================= + # Indexing Operations + # ========================================================================= + + def _compute_linear_index( + self, + indices: "tuple[Union[int, Int32], ...]", # type: ignore[name-defined] + ) -> "Union[int, Int32]": # type: ignore[name-defined] + """Compute linear index from multi-dimensional indices (row-major order).""" + if len(indices) != len(self._shape): + raise IndexError( + f"Expected {len(self._shape)} indices for shape {self._shape}, " + f"got {len(indices)}" + ) + + # Check if all indices are static (Python ints) + all_static = all(isinstance(i, int) for i in indices) + + if all_static: + # Static computation + linear = 0 + stride = 1 + for i in range(len(self._shape) - 1, -1, -1): + linear += indices[i] * stride + stride *= self._shape[i] + return linear + else: + from ..typing import Int32 + + # Dynamic computation using Int32 arithmetic + linear = Int32(0) # type: ignore[assignment] + stride = 1 + for i in range(len(self._shape) - 1, -1, -1): + idx = indices[i] if isinstance(indices[i], Int32) else Int32(indices[i]) + linear = linear + idx * Int32(stride) + stride *= self._shape[i] + return linear + + def __getitem__( + self, + idx: "Union[int, Int32, tuple, slice]", # type: ignore[name-defined] + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> object: + """Extract an element or a contiguous sub-vector. + + Supports three indexing modes: + + * **Scalar index** — returns a single DSL scalar value:: + + elem = vec[i] # static int or Int32 + + * **1-D slice** — all bounds must be static Python ``int``s:: + + sub = vec[start:stop] # stride defaults to 1 + sub = vec[start:stop:stride] # explicit stride + + * **Multi-dimensional slice** — one entry per dimension, all bounds + must be static ``int``s. An integer in a multi-dim slice is treated + as a size-1 slice (the dimension is kept):: + + sub = mat[r0:r1, c0:c1] # 2-D: rows r0:r1, cols c0:c1 + sub = mat[:, c0:c1] # 2-D: all rows, cols c0:c1 + sub = mat[0, c0:c1] # 2-D: row 0 (size 1), cols c0:c1 + + Slices use ``vector.extract_strided_slice`` internally; dynamic + (MLIR-value) slice bounds are **not** supported. + + :param idx: Element index (int or Int32), a slice, or a tuple of + ints/slices for multi-dimensional access. + :type idx: int or Int32 or tuple or slice + :return: A scalar DSL value (for element indexing) or a new + :class:`Vector` (for slice indexing). + :rtype: Numeric or Vector + :raises TypeError: If slice bounds are not static Python ints. + :raises IndexError: If the number of dimensions in a multi-dim index + does not match :attr:`shape`. + """ + from ..utils.logger import log + + # Slice → extract_strided_slice (step==1) or vector.shuffle (step>1) + if isinstance(idx, slice): + start = idx.start if idx.start is not None else 0 + step = idx.step if idx.step is not None else 1 + stop = idx.stop if idx.stop is not None else self._count + if not all(isinstance(v, int) for v in (start, stop, step)): + raise TypeError( + "Vector slice indices must be static ints; " + f"got start={start}, stop={stop}, step={step}" + ) + size = (stop - start + step - 1) // step + result_ty = ir.VectorType.get([size], self._mlir_type) + if step == 1: + result = vector.extract_strided_slice( + result_ty, self, [start], [size], [step], loc=loc, ip=ip + ) + else: + # vector.extract_strided_slice requires stride==1; use shuffle instead + mask = list(range(start, stop, step)) + result = vector.shuffle(self, self, mask, loc=loc, ip=ip) + return Vector(result, dtype=self._dtype) + + # Multi-dimensional slice: tuple containing at least one slice object + if isinstance(idx, tuple) and any(isinstance(i, slice) for i in idx): + if len(idx) != len(self._shape): + raise IndexError( + f"Expected {len(self._shape)} indices for shape {self._shape}, " + f"got {len(idx)}" + ) + offsets: "list[int]" = [] + sizes: "list[int]" = [] + strides: "list[int]" = [] + for dim, (i, dim_size) in enumerate(zip(idx, self._shape)): + if isinstance(i, slice): + start = i.start if i.start is not None else 0 + stop = i.stop if i.stop is not None else dim_size + step = i.step if i.step is not None else 1 + if not all(isinstance(v, int) for v in (start, stop, step)): + raise TypeError( + f"Vector slice indices must be static ints in dimension {dim}; " + f"got start={start}, stop={stop}, step={step}" + ) + if step != 1: + raise NotImplementedError( + f"Multi-dimensional strided slice (step={step}) is not supported; " + "use step=1 for multi-dimensional slices" + ) + offsets.append(start) + sizes.append(stop - start) + strides.append(1) + elif isinstance(i, int): + # Integer index: treated as a size-1 slice (rank is preserved) + if i < 0: + i += dim_size + offsets.append(i) + sizes.append(1) + strides.append(1) + else: + raise TypeError( + f"Vector multi-dimensional slice: dimension {dim} index must be " + f"a static int or slice, got {type(i).__name__}" + ) + result_ty = ir.VectorType.get(sizes, self._mlir_type) + result = vector.extract_strided_slice( + result_ty, self, offsets, sizes, strides, loc=loc, ip=ip + ) + return Vector(result, dtype=self._dtype) + + # Normalize to tuple + if not isinstance(idx, tuple): + indices = (idx,) + else: + indices = idx + + # Compute linear index + linear_idx = self._compute_linear_index(indices) + + log().info( + f"Vector.__getitem__: idx={idx}, linear={linear_idx}, " + f"dtype={self._dtype}, shape={self._shape}" + ) + + # For dynamic indices, we use llvm.extractelement instead of vector.extract + # because vector.extract has issues with dynamic positions + if isinstance(linear_idx, int): + # Static index - use vector.extract with static position + elem = vector.extract(self, [], [linear_idx]) + else: + # Dynamic index - use llvm.extractelement + elem = llvm.extractelement(self, linear_idx.ir_value()) + + return self._dtype(elem) + + def __setitem__( + self, + idx: "Union[int, Int32, tuple]", # type: ignore[name-defined] + value: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + """ + Vector element assignment is not supported. + + Vectors are immutable register values. Use one of these alternatives: + + 1. Use Array for mutable memory-backed storage: + arr = ctm.allocate_memory_local(ctm.Float32, 4) + arr[0] = value # This works + + 2. Use full() to create vectors with initial values: + vec = ctm.full((4,), 1.0, ctm.Float32) + """ + raise TypeError( + "Vector is immutable. Element assignment (vec[i] = value) is not supported. " + ) + + # ========================================================================= + # Arithmetic Operations + # ========================================================================= + + def _is_float_type(self) -> bool: + """Check if this vector contains floating-point elements.""" + return arith._is_float_type(self._mlir_type) + + # Arithmetic operators (+, -, *, /, -x) are inherited from ArithValue. + # Results are automatically wrapped as Vector via the value caster. + + def to( + self, + dtype: "type", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + """Convert the vector elements to a different numeric type. + + :param dtype: Target DSL element type (e.g. ``Float16``, ``Int32``). + :type dtype: Type[Numeric] + :return: A new :class:`Vector` with the same shape and elements cast + to ``dtype``. + :rtype: Vector + :raises TypeError: If ``dtype`` is not a subclass of ``Numeric``. + + Example:: + + vec_f32 = ctm.full([4], 1.5, dtype=ctm.Float32) + vec_i32 = vec_f32.to(ctm.Int32) # fp → int truncation + vec_f16 = vec_f32.to(ctm.Float16) # fp32 → fp16 narrowing + """ + from inspect import isclass + from ..typing import Numeric, Integer + + if dtype is ir.Value: + return self + + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {type(dtype)}") + + src_dtype = self._dtype + if src_dtype == dtype: + return self + + # maybe_downcast handles narrow precision types, with_signedness sets signedness + src = self.maybe_downcast().with_signedness(self.signed) + + if src_dtype.is_float and dtype.is_float: # type: ignore[attr-defined] + res_vect = cvtf(src, dtype.mlir_type, loc=loc, ip=ip) + elif src_dtype.is_float and issubclass(dtype, Integer): # type: ignore[attr-defined] + res_vect = fptoi(src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip) + elif issubclass(src_dtype, Integer) and dtype.is_float: + res_vect = itofp(src, src_dtype.signed, dtype.mlir_type, loc=loc, ip=ip) + else: + res_vect = int_to_int(src, dtype, loc=loc, ip=ip) + + return Vector(res_vect, dtype=dtype) + + @dsl_user_op + def bitcast( + self, + dtype: "type", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + """Reinterpret the vector bits as a different element type. + + The total bit width is preserved; the element count adjusts + proportionally. For example, ``vector<4xi32>`` bitcast to + ``Float16`` yields ``vector<8xf16>`` (4 × 32 = 8 × 16 bits). + + :param dtype: Target DSL element type (e.g. ``Float32``, ``Float16``). + :type dtype: Type[Numeric] + :return: A new :class:`Vector` with bits reinterpreted as ``dtype``. + :rtype: Vector + :raises TypeError: If ``dtype`` is not a subclass of ``Numeric``. + """ + from inspect import isclass + from ..typing import Numeric + + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a Numeric type, but got {dtype}") + if dtype is self._dtype: + return self + new_count = self._count * self._dtype.width // dtype.width # type: ignore[attr-defined] + target_vec_ty = T.vector(new_count, dtype.mlir_type) + res_vec = vector.bitcast(target_vec_ty, self, loc=loc, ip=ip) + return Vector(res_vec, dtype=dtype, loc=loc, ip=ip) + + @dsl_user_op + def __add__( + self, + other: "Vector", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + result = super().__add__(other, loc=loc, ip=ip) + return Vector(result, dtype=self.dtype, loc=loc, ip=ip) + + @dsl_user_op + def __radd__( + self, + other: "Vector", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + result = super().__radd__(other, loc=loc, ip=ip) + return Vector(result, dtype=self.dtype, loc=loc, ip=ip) + + @dsl_user_op + def __sub__( + self, + other: "Vector", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + result = super().__sub__(other, loc=loc, ip=ip) + return Vector(result, dtype=self.dtype, loc=loc, ip=ip) + + @dsl_user_op + def __rsub__( + self, + other: "Vector", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + result = super().__rsub__(other, loc=loc, ip=ip) + return Vector(result, dtype=self.dtype, loc=loc, ip=ip) + + @dsl_user_op + def __mul__( + self, + other: "Vector", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + result = super().__mul__(other, loc=loc, ip=ip) + return Vector(result, dtype=self.dtype, loc=loc, ip=ip) + + @dsl_user_op + def __rmul__( + self, + other: "Vector", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Vector": + result = super().__rmul__(other, loc=loc, ip=ip) + return Vector(result, dtype=self.dtype, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/dialect_proxy.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/dialect_proxy.py new file mode 100644 index 000000000..c0087e00f --- /dev/null +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/dialect_proxy.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +import types +from collections.abc import Callable +from typing import Any + + +class DialectAutoConvertProxy: + """ + Proxy that wraps a raw MLIR dialect module, auto-converting DSL types + (anything with an ``.ir_value()`` method) to ``ir.Value`` when calling + dialect operations. + + This enables users to write cleaner code without explicit + ``.ir_value()`` calls:: + + # Before (raw dialect module): + nvvm.shfl_sync(T.i32(), Int32(mask).ir_value(), ...) + + # After (proxied): + nvvm.shfl_sync(T.i32(), Int32(mask), ...) + + Non-callable attributes and enum classes are passed through unchanged + so that attribute access like ``nvvm.ShflKind.idx`` still works. + + Parameters + ---------- + dialect_module + The raw MLIR dialect module to wrap + (e.g. ``cutlass._mlir.dialects.nvvm``). + """ + + def __init__(self, dialect_module: types.ModuleType) -> None: + self._module = dialect_module + self._wrapped_cache: dict[str, Callable[..., object]] = {} + + @staticmethod + def _convert_arg( + arg: object, + loc: object | None, + ip: object | None, + ) -> object: + """Recursively convert DSL objects to ir.Value.""" + if hasattr(arg, "ir_value") and callable(arg.ir_value): + try: + return arg.ir_value(loc=loc, ip=ip) + except TypeError: + # Some ir_value() methods (e.g. Array) don't accept loc/ip. + return arg.ir_value() + if isinstance(arg, (list, tuple)): + converted = [ + DialectAutoConvertProxy._convert_arg(item, loc, ip) for item in arg + ] + return type(arg)(converted) + return arg + + def __getattr__(self, name: str) -> Any: + attr = getattr(self._module, name) + + # Non-callable attributes and enum classes pass through + # unchanged. Enum classes need attribute access (e.g. + # ShflKind.idx), but MLIR operation classes should be + # wrapped for argument conversion. + if not callable(attr) or isinstance(attr, enum.EnumMeta): + return attr + + # Use cache for wrapped callables + if name not in self._wrapped_cache: + + def _make_wrapper( + func: Callable[..., object], + ) -> Callable[..., object]: + def wrapped( + *args: object, + loc: object | None = None, + ip: object | None = None, + **kwargs: object, + ) -> object: + converted_args = tuple( + DialectAutoConvertProxy._convert_arg(arg, loc, ip) + for arg in args + ) + converted_kwargs = { + k: DialectAutoConvertProxy._convert_arg(v, loc, ip) + for k, v in kwargs.items() + } + return func( + *converted_args, + loc=loc, + ip=ip, + **converted_kwargs, + ) + + return wrapped + + self._wrapped_cache[name] = _make_wrapper(attr) + + return self._wrapped_cache[name] + + def __dir__(self) -> list[str]: + return dir(self._module) diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py index b21e44e9d..b9d05e10b 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/gpu.py @@ -15,7 +15,7 @@ This module provides MLIR GPU Dialect helper functions from ..._mlir import ir from ..._mlir.dialects import gpu, arith, scf -from ..._mlir.extras import types as T +from ..._mlir.extras import types as _T from ..common import * @@ -24,13 +24,13 @@ from ..common import * # ============================================================================= -def create_async_token(): +def create_async_token() -> ir.Value: token_ty = gpu.AsyncTokenType.get() token = gpu.wait(token_ty, []) return token -def printf(fmt, *args, threadNumber=-1): +def printf(fmt: str, *args: ir.Value, threadNumber: int = -1) -> None: """Generate gpu.printf OP predicated on threadNumber""" type_formats = [] for arg in args: diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py index eed81e192..a7353b781 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/lru_cache_ir.py @@ -24,11 +24,12 @@ def make_layout(...): """ from functools import lru_cache, wraps +from typing import Any, Callable -from ..._mlir import ir # type: ignore +from ..._mlir import ir -def get_ir_context(func): +def get_ir_context(func: Any) -> Any: """ Return the context for given func called under ir. Currently the context includes MLIRContext and InsertionPoint. @@ -42,7 +43,7 @@ def get_ir_context(func): return None -def lru_cache_ir(maxsize=128, typed=True): +def lru_cache_ir(maxsize: int = 128, typed: bool = True) -> Callable[..., Any]: """ Applies an LRU cache to a given function, with awareness of IR context. @@ -53,14 +54,14 @@ def lru_cache_ir(maxsize=128, typed=True): :param typed: Whether params are type-sensitive, default to True as IR is type-sensitive """ - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: # Use functools.lru_cache with a custom wrapper to control the key generation @lru_cache(maxsize=maxsize, typed=typed) - def cached_func(context, *args, **kwargs): + def cached_func(context: Any, *args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: try: # Call the cached function with the context return cached_func(get_ir_context(func), *args, **kwargs) @@ -68,8 +69,8 @@ def lru_cache_ir(maxsize=128, typed=True): return func(*args, **kwargs) # Expose cache-related methods for introspection - wrapper.cache_clear = cached_func.cache_clear - wrapper.cache_info = cached_func.cache_info + wrapper.cache_clear = cached_func.cache_clear # type: ignore[attr-defined] + wrapper.cache_info = cached_func.cache_info # type: ignore[attr-defined] return wrapper return decorator diff --git a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py index 4ef4c3c75..4e2af99c6 100644 --- a/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py +++ b/python/CuTeDSL/cutlass/base_dsl/_mlir_helpers/op.py @@ -17,14 +17,23 @@ import inspect import os import types from functools import wraps +from typing import Any, Callable from ..._mlir import ir -from ..common import DSLRuntimeError +from ..common import DSLRuntimeError, DSLOperationBuildError from ..utils.stacktrace import walk_to_top_module - # The DSL package root is empty by default. -_DSL_PACKAGE_ROOT = "" +_DSL_PACKAGE_ROOT: str | None = "" + +# Whether location tracking is enabled. +_ENABLE_FRAME_FILTERING: bool = False + + +def _set_enable_frame_filtering(enable: bool) -> None: + """Set whether location tracking is enabled.""" + global _ENABLE_FRAME_FILTERING + _ENABLE_FRAME_FILTERING = enable def _is_framework_frame(filename: str) -> bool: @@ -58,7 +67,7 @@ def _find_user_frame(start_frame: types.FrameType | None) -> types.FrameType | N return start_frame -def dsl_user_op(opFunc): +def dsl_user_op(opFunc: Callable[..., Any]) -> Callable[..., Any]: """ This is a decorator that needs to be used in each user-facing API to manage source location for toolchain. @@ -70,14 +79,16 @@ def dsl_user_op(opFunc): """ @wraps(opFunc) - def wrapper(*args, **kwargs): - loc = kwargs.pop("loc", None) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Pop loc= from kwargs so callers that still pass it don't break. + # We no longer forward it — LOC_TRACEBACKS captures full stacks automatically. + loc: Any = kwargs.pop("loc", None) frameInfo = None verifier_error = False if loc is None and ir.Context.current is not None: - frame = _find_user_frame(inspect.currentframe().f_back) - frameInfo = inspect.getframeinfo(frame) + frame = _find_user_frame(inspect.currentframe().f_back) # type: ignore[union-attr] + frameInfo = inspect.getframeinfo(frame) # type: ignore[arg-type] try: # In Python < 3.11, getframeinfo returns a NamedTuple without positions if not hasattr(frameInfo, "positions"): @@ -89,8 +100,8 @@ def dsl_user_op(opFunc): else: file_loc = ir.Location.file( frameInfo.filename, - frameInfo.positions.lineno, - frameInfo.positions.col_offset or 0, + frameInfo.positions.lineno, # type: ignore[attr-defined] + frameInfo.positions.col_offset or 0, # type: ignore[attr-defined] ) loc = ir.Location.name( ( @@ -108,8 +119,18 @@ def dsl_user_op(opFunc): try: res_or_list = opFunc(*args, **kwargs, loc=loc) - except TypeError as e: - # Provide a helpful error message when function doesn't accept 'loc' + verifier_error = True + # Verify the operation + if hasattr(res_or_list, "verify"): + res_or_list.verify() + + except DSLOperationBuildError as e: + # Nested DSLOperationError + raise DSLOperationBuildError( + message=e.message, cause=e, frameInfo=frameInfo + ) + except Exception as e: + # Check if it's a decorator config error first func_name = getattr(opFunc, "__name__", str(opFunc)) if "unexpected keyword argument 'loc'" in str(e): raise DSLRuntimeError( @@ -118,13 +139,20 @@ def dsl_user_op(opFunc): f"1. Add 'loc=None' as a keyword-only parameter to {func_name}:", f" def {func_name}(..., *, loc=None):", "", - f"2. Remove the @dsl_user_op decorator if location tracking is not needed", + "2. Remove the @dsl_user_op decorator if location tracking is not needed", ], cause=e, ) from e - else: - # Re-raise other TypeErrors as-is - raise + if verifier_error: + raise DSLOperationBuildError( + message="Operation verification failed", + cause=e, + frameInfo=frameInfo, + auto_translate=False, + ) + + raise e + return res_or_list return wrapper diff --git a/python/CuTeDSL/cutlass/base_dsl/arch.py b/python/CuTeDSL/cutlass/base_dsl/arch.py index 070a4d2fb..174048634 100644 --- a/python/CuTeDSL/cutlass/base_dsl/arch.py +++ b/python/CuTeDSL/cutlass/base_dsl/arch.py @@ -9,12 +9,60 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from enum import Enum +from collections.abc import Callable +from enum import Enum, EnumMeta import re -from typing import Callable, List, Tuple +from typing import Any -class Arch(Enum): +class ArchMeta(EnumMeta): + """ + Custom metaclass for Arch enum that supports dynamic aliases based on CUDA version. + + - If cuda_version >= 13.0: sm_101/sm_101a/sm_101f are aliases of sm_110/sm_110a/sm_110f, use sm_110 as the canonical name + - Otherwise: sm_110/sm_110a/sm_110f are aliases of sm_101/sm_101a/sm_101f, use sm_101 as the canonical name + """ + + _arch_aliases: dict[str, str] = {} + + def __new__( + mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any] + ) -> "ArchMeta": + cls = super().__new__(mcs, name, bases, namespace) # type: ignore[arg-type] + from .version_info import CUDA_VERSION + + if CUDA_VERSION.major >= 13: + # sm_101 -> sm_110, use sm_110 as the canonical name + mcs._arch_aliases = { + "sm_101": "sm_110", + "sm_101a": "sm_110a", + "sm_101f": "sm_110f", + } + else: + # sm_110 -> sm_101, use sm_101 as the canonical name + mcs._arch_aliases = { + "sm_110": "sm_101", + "sm_110a": "sm_101a", + "sm_110f": "sm_101f", + } + return cls + + def __getattribute__(cls, name: str) -> Any: + # Use type.__getattribute__ to avoid recursion when accessing _arch_aliases + aliases = type.__getattribute__(cls, "_arch_aliases") + if name in aliases: + # Redirect to the target member + return type.__getattribute__(cls, aliases[name]) + return super().__getattribute__(name) + + def __getitem__(cls, name: str) -> "Arch": # type: ignore[override] + # Support Arch["sm_101"] style access + if name in cls._arch_aliases: + return super().__getitem__(cls._arch_aliases[name]) + return super().__getitem__(name) + + +class Arch(Enum, metaclass=ArchMeta): # sm_arch = (major, minor, suffix) # Ampere sm_80 = (8, 0, "") @@ -44,36 +92,26 @@ class Arch(Enum): sm_121 = (12, 1, "") sm_121a = (12, 1, "a") sm_121f = (12, 1, "f") - def __init__(self, major, minor, suffix): + def __init__(self, major: int, minor: int, suffix: str) -> None: self.major = major self.minor = minor self.suffix = suffix - @classmethod - def _missing_(cls, value): - if isinstance(value, tuple) and len(value) == 2: - # Support creating Arch enum from (major, minor) tuple - # Arch(major, minor) is equivalent to Arch(major, minor, "") - major, minor, suffix = *value, "" - return cls((major, minor, suffix)) - else: - raise ValueError(f"invalid arguments for Arch: {value}") - # attributes to get arch list of specific families @classmethod - def AmpereArchs(cls) -> Tuple["Arch"]: + def AmpereArchs(cls) -> tuple["Arch", ...]: return (Arch.sm_80, Arch.sm_86, Arch.sm_87) @classmethod - def AdaArchs(cls) -> Tuple["Arch"]: + def AdaArchs(cls) -> tuple["Arch", ...]: return (Arch.sm_89,) @classmethod - def HopperArchs(cls) -> Tuple["Arch"]: + def HopperArchs(cls) -> tuple["Arch", ...]: return (Arch.sm_90, Arch.sm_90a) @classmethod - def BlackwellArchs(cls) -> Tuple["Arch"]: + def BlackwellArchs(cls) -> tuple["Arch", ...]: return ( Arch.sm_100, Arch.sm_100a, @@ -95,20 +133,21 @@ class Arch(Enum): Arch.sm_121f, ) - def __repr__(self): - return self.__str__() + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return f"Arch.{self.name}" @classmethod - def from_string(cls, arch_str): - pattern = r"^(?:sm_?|SM_?)?(\d+)(\d)([af]?)$" - match = re.match(pattern, arch_str) - if not match: - raise ValueError(f"Invalid architecture string format: {arch_str}") - major, minor, suffix = match.groups() - return cls((int(major), int(minor), suffix)) + def from_string(cls, arch_str: str) -> "Arch": + return cls[arch_str] + + def to_string(self) -> str: + return self.name @classmethod - def filter(cls, criterion: Callable[["Arch"], bool]) -> List["Arch"]: + def filter(cls, criterion: Callable[["Arch"], bool]) -> list["Arch"]: """ Filter the archs by the given criterion. """ @@ -129,7 +168,7 @@ class Arch(Enum): """ # sm_101 is renamed to sm_110, sm_101f is family of sm_110f, but is not family of sm_100f if self in [Arch.sm_101a, Arch.sm_101f]: - return arch.major == 11 and arch.minor == 0 + return arch.major == 11 and arch.minor >= 0 return ( self.major == arch.major @@ -137,22 +176,22 @@ class Arch(Enum): and self.suffix in ["a", "f"] ) - def __lt__(self, other): + def __lt__(self, other: object) -> bool: if not isinstance(other, Arch): return NotImplemented return (self.major, self.minor) < (other.major, other.minor) - def __le__(self, other): + def __le__(self, other: object) -> bool: if not isinstance(other, Arch): return NotImplemented return (self.major, self.minor) <= (other.major, other.minor) - def __gt__(self, other): + def __gt__(self, other: object) -> bool: if not isinstance(other, Arch): return NotImplemented return (self.major, self.minor) > (other.major, other.minor) - def __ge__(self, other): + def __ge__(self, other: object) -> bool: if not isinstance(other, Arch): return NotImplemented return (self.major, self.minor) >= (other.major, other.minor) diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py b/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py index 56a46004c..a8de32369 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_helpers.py @@ -14,14 +14,17 @@ This module provides helper functions that are generated by the preprocessor. The preprocessor read through python's ast and changes the input code. """ +from collections.abc import Callable, Iterator from dataclasses import dataclass -from typing import Callable, Iterator, Optional, overload, List +from typing import Any, overload from typing_extensions import deprecated import warnings import inspect +import types from types import BuiltinFunctionType from functools import lru_cache from inspect import getmembers +import builtins from .utils.logger import log from .common import * @@ -43,42 +46,41 @@ class Executor: if_execute: generate MLIR if OP """ - def __init__(self): - self._is_dynamic_expression = None - self._loop_execute_range_dynamic = None - self._if_dynamic = None - self._while_dynamic = None - self._compare_executor = None - self._any_executor = None - self._all_executor = None - self._builtin_redirector = None - self._ifexp_dynamic = None + def __init__(self) -> None: + self._is_dynamic_expression: Callable[..., Any] | None = None + self._loop_execute_range_dynamic: Callable[..., Any] | None = None + self._if_dynamic: Callable[..., Any] | None = None + self._while_dynamic: Callable[..., Any] | None = None + self._compare_executor: Callable[..., Any] | None = None + self._builtin_redirector: Callable[..., Any] | None = None + self._ifexp_dynamic: Callable[..., Any] | None = None + + @staticmethod + def _default_builtin_redirector(fcn: Callable[..., Any]) -> Callable[..., Any]: + # Default to no redirect + return fcn def set_functions( self, *, - is_dynamic_expression: Callable, - loop_execute_range_dynamic: Callable, - if_dynamic: Callable, - while_dynamic: Callable, - compare_executor: Callable, - any_executor: Callable = None, - all_executor: Callable = None, - builtin_redirector: Callable = None, - ifexp_dynamic: Callable = None, - ): + is_dynamic_expression: Callable[..., Any], + loop_execute_range_dynamic: Callable[..., Any], + if_dynamic: Callable[..., Any], + while_dynamic: Callable[..., Any], + compare_executor: Callable[..., Any], + builtin_redirector: Callable[..., Any] = _default_builtin_redirector, + ifexp_dynamic: Callable[..., Any] | None = None, + ) -> None: self._is_dynamic_expression = is_dynamic_expression self._loop_execute_range_dynamic = loop_execute_range_dynamic self._if_dynamic = if_dynamic self._while_dynamic = while_dynamic self._compare_executor = compare_executor - self._any_executor = any_executor - self._all_executor = all_executor self._builtin_redirector = builtin_redirector self._ifexp_dynamic = ifexp_dynamic @staticmethod - def convert_to_list(x): + def convert_to_list(x: Any) -> list[Any]: """This function is used to convert x to a list. If x is None, return an empty list. If x is not a list, return a list containing x. @@ -91,7 +93,7 @@ class Executor: return x @staticmethod - def converge_ret_val(res): + def converge_ret_val(res: Any) -> Any: """This function is used to converge res (the return value) of the function. If res is None, return None. If res is a list and has only one element, return the element. @@ -105,18 +107,19 @@ class Executor: def for_execute( self, - func, - start, - stop, - step, - write_args=[], - full_write_args_count=0, - write_args_names=[], - unroll=-1, - unroll_full=False, - prefetch_stages=None, - vectorize=None, - ): + func: Callable[..., Any], + start: Any, + stop: Any, + step: Any, + write_args: list[Any] = [], + full_write_args_count: int = 0, + write_args_names: list[str] = [], + unroll: int = -1, + unroll_full: bool = False, + prefetch_stages: int | None = None, + vectorize: int | None = None, + at_least_once: bool = False, + ) -> Any: assert self._loop_execute_range_dynamic, ( "Functions must be set before execution." ) @@ -127,24 +130,25 @@ class Executor: start, stop, step, - write_args, - full_write_args_count, - write_args_names, - unroll, - unroll_full, - prefetch_stages, - vectorize, + write_args=write_args, + full_write_args_count=full_write_args_count, + write_args_names=write_args_names, + unroll=unroll, + unroll_full=unroll_full, + prefetch_stages=prefetch_stages, + vectorize=vectorize, + at_least_once=at_least_once, ) def if_execute( self, - pred, - then_block: Callable, - else_block: Optional[Callable] = None, - write_args=[], - full_write_args_count=0, - write_args_names=[], - ): + pred: Any, + then_block: Callable[..., Any], + else_block: Callable[..., Any] | None = None, + write_args: list[Any] = [], + full_write_args_count: int = 0, + write_args_names: list[str] = [], + ) -> Any: assert self._if_dynamic, "Functions must be set before execution." # MLIR generation @@ -159,12 +163,12 @@ class Executor: def while_execute( self, - while_before_block: Callable, - while_after_block: Callable, - write_args=[], - full_write_args_count=0, - write_args_names=[], - ): + while_before_block: Callable[..., Any], + while_after_block: Callable[..., Any], + write_args: list[Any] = [], + full_write_args_count: int = 0, + write_args_names: list[str] = [], + ) -> Any: assert self._while_dynamic, "Functions must be set before execution." # MLIR generation @@ -178,11 +182,11 @@ class Executor: def ifexp_execute( self, - pred, - block_args: tuple, - then_block: Callable, - else_block: Callable, - ): + pred: Any, + block_args: tuple[Any, ...], + then_block: Callable[..., Any], + else_block: Callable[..., Any], + ) -> Any: assert self._ifexp_dynamic, "Functions must be set before execution." return self._ifexp_dynamic(pred, block_args, then_block, else_block) @@ -195,20 +199,21 @@ executor = Executor() def loop_selector( - start, - stop, - step, + start: Any, + stop: Any, + step: Any, *, - write_args=[], - full_write_args_count=0, - write_args_names=[], - unroll=-1, - unroll_full=False, - prefetch_stages=None, - vectorize=None, -): + write_args: list[Any] = [], + full_write_args_count: int = 0, + write_args_names: list[str] = [], + unroll: int = -1, + unroll_full: bool = False, + prefetch_stages: int | None = None, + vectorize: int | None = None, + at_least_once: bool = False, +) -> Callable[..., Any]: log().debug( - "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s] vectorize [%s]", + "start [%s] stop [%s] step [%s] write_args [%s] full_write_args_count [%s] write_args_names [%s] unroll [%s] unroll_full [%s] prefetch_stages [%s] vectorize [%s] at_least_once [%s]", start, stop, step, @@ -219,10 +224,11 @@ def loop_selector( unroll_full, prefetch_stages, vectorize, + at_least_once, ) from .typing import Integer, Numeric - def _maybe_upcast(value): + def _maybe_upcast(value: Any) -> Any: if isinstance(value, Integer): value = value.ir_value() @@ -232,25 +238,26 @@ def loop_selector( stop = _maybe_upcast(stop) step = _maybe_upcast(step) - def ir_loop(func): + def ir_loop(func: Callable[..., Any]) -> Any: return executor.for_execute( func, start, stop, step, - write_args, - full_write_args_count, - write_args_names, - unroll, - unroll_full, - prefetch_stages, - vectorize, + write_args=write_args, + full_write_args_count=full_write_args_count, + write_args_names=write_args_names, + unroll=unroll, + unroll_full=unroll_full, + prefetch_stages=prefetch_stages, + vectorize=vectorize, + at_least_once=at_least_once, ) return ir_loop -def if_selector(pred, write_args=[]): +def if_selector(pred: Any, write_args: list[Any] = []) -> Callable[..., Any]: log().debug("pred [%s] write_args [%s]", pred, write_args) # Handle Numeric types here? @@ -259,26 +266,26 @@ def if_selector(pred, write_args=[]): if isinstance(pred, Numeric): pred = pred.value - def ir_loop(func): + def ir_loop(func: Callable[..., Any]) -> Any: return func(pred, *write_args) return ir_loop -def while_selector(*, write_args=[]): - def ir_while_loop(func): +def while_selector(*, write_args: list[Any] = []) -> Callable[..., Any]: + def ir_while_loop(func: Callable[..., Any]) -> Any: return func(*write_args) return ir_while_loop def while_executor( - while_before_block: Callable, - while_after_block: Callable, - write_args=[], - full_write_args_count=0, - write_args_names=[], -): + while_before_block: Callable[..., Any], + while_after_block: Callable[..., Any], + write_args: list[Any] = [], + full_write_args_count: int = 0, + write_args_names: list[str] = [], +) -> Any: return executor.while_execute( while_before_block, while_after_block, @@ -289,13 +296,13 @@ def while_executor( def if_executor( - pred, - then_block: Callable, - else_block: Optional[Callable] = None, - write_args=[], - full_write_args_count=0, - write_args_names=[], -): + pred: Any, + then_block: Callable[..., Any], + else_block: Callable[..., Any] | None = None, + write_args: list[Any] = [], + full_write_args_count: int = 0, + write_args_names: list[str] = [], +) -> Any: return executor.if_execute( pred, then_block, @@ -308,12 +315,12 @@ def if_executor( def ifExp_executor( *, - pred, - block_args: tuple, - then_block: Callable, - else_block: Callable, -): - if not executor._is_dynamic_expression(pred): + pred: Any, + block_args: tuple[Any, ...], + then_block: Callable[..., Any], + else_block: Callable[..., Any], +) -> Any: + if not executor._is_dynamic_expression(pred): # type: ignore[misc] return then_block(*block_args) if pred else else_block(*block_args) else: return executor.ifexp_execute(pred, block_args, then_block, else_block) @@ -340,28 +347,35 @@ class range: - unroll_full: Whether to fully unroll the loop - prefetch_stages: Number of prefetch stages to generate - vectorize: Whether to vectorize the loop (default: None) + - at_least_once: Annotate the loop as executing at least one iteration, + suppressing the zero-iteration bypass edge in analyses (default: False) """ @overload def __new__( - cls, stop, unroll=0, unroll_full=False, prefetch_stages=None, vectorize=None - ): - pass + cls, + stop: Any, + unroll: int = 0, + unroll_full: bool = False, + prefetch_stages: int | None = None, + vectorize: int | None = None, + at_least_once: bool = False, + ) -> "range": ... @overload def __new__( cls, - start, - stop, - step, - unroll=0, - unroll_full=False, - prefetch_stages=None, - vectorize=None, - ): - pass + start: Any, + stop: Any, + step: Any, + unroll: int = 0, + unroll_full: bool = False, + prefetch_stages: int | None = None, + vectorize: int | None = None, + at_least_once: bool = False, + ) -> "range": ... - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> "range": raise DSLRuntimeError("dynamic range should be always preprocessed to IR") def __iter__(self) -> Iterator[int]: @@ -371,11 +385,11 @@ class range: @deprecated( "range_dynamic is deprecated and will be removed in the future, please remove it." ) -def range_dynamic(*args, **kwargs): +def range_dynamic(*args: Any, **kwargs: Any) -> None: raise DSLRuntimeError("range_dynamic should be always preprocessed to IR") -def range_constexpr(*args): +def range_constexpr(*args: Any) -> None: raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.") @@ -384,7 +398,7 @@ def range_constexpr(*args): # ============================================================================= -def const_expr(expression): +def const_expr(expression: Any) -> Any: """ This function is used to check if the expression is a python value. If the expression is a python value, return the boolean value of the expression. @@ -399,7 +413,7 @@ def const_expr(expression): return expression.value else: failed = True - elif executor._is_dynamic_expression(expression): + elif executor._is_dynamic_expression(expression): # type: ignore[misc] failed = True if failed: @@ -415,7 +429,7 @@ def const_expr(expression): @deprecated( "dynamic_expr is deprecated and will be removed in the future, please remove it." ) -def dynamic_expr(expression): +def dynamic_expr(expression: Any) -> Any: return expression @@ -424,13 +438,13 @@ def dynamic_expr(expression): # ============================================================================= -def assert_executor(test, msg=None): +def assert_executor(test: Any, msg: str | None = None) -> None: from .typing import Numeric fail = False # Implicit convert dynamic expression to bool is not allowed # So here explicitly do a None check - if test is not None and executor._is_dynamic_expression(test): + if test is not None and executor._is_dynamic_expression(test): # type: ignore[misc] if isinstance(test, Numeric): try: test = test.to(bool) @@ -448,16 +462,16 @@ def assert_executor(test, msg=None): ) -def bool_cast(value): - if executor._is_dynamic_expression(value): +def bool_cast(value: Any) -> bool: + if executor._is_dynamic_expression(value): # type: ignore[misc] raise DSLRuntimeError( "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.", - suggestion="Please explicitly convert to boolean with expressions like comparision.", + suggestion="Please explicitly convert to boolean with expressions like comparison.", ) return bool(value) -def compare_executor(left, comparators, ops): +def compare_executor(left: Any, comparators: list[Any], ops: list[Any]) -> Any: """ Executes comparison operations with a left operand and a list of comparators. @@ -478,35 +492,6 @@ def compare_executor(left, comparators, ops): return executor._compare_executor(left, comparators, ops) -def any_executor(iterable): - """Executes the 'any' operation on an iterable, handling both dynamic and static expressions. - - :param iterable: An iterable to check if any elements evaluate to True - :type iterable: Iterable - :return: boolean of Python value or IR value - :rtype: bool or cutlass.Boolean - - """ - if executor._any_executor and executor._is_dynamic_expression(iterable): - return executor._any_executor(iterable) - else: - return any(iterable) - - -def all_executor(iterable): - """Executes the 'all' operation on an iterable, handling both dynamic and static expressions. - - :param iterable: An iterable to check if all elements evaluate to True - :type iterable: Iterable - :return: boolean of Python value or IR value - :rtype: bool or cutlass.Boolean - """ - if executor._all_executor and executor._is_dynamic_expression(iterable): - return executor._all_executor(iterable) - else: - return all(iterable) - - # ============================================================================= # Control flow checks # ============================================================================= @@ -515,15 +500,15 @@ class DSLOptimizationWarning(Warning): This warning is used to warn the user about the optimization related issues in DSL. """ - def __init__(self, message): + def __init__(self, message: str) -> None: self.message = message super().__init__() - def __str__(self): + def __str__(self) -> str: return self.message -def range_value_check(*args): +def range_value_check(*args: Any) -> tuple[int, int, int]: """ Ensure all `range_constexpr` bounds are compile-time constants (Python ints). """ @@ -561,7 +546,7 @@ def range_value_check(*args): @lru_cache(maxsize=1) -def _get_self_module(): +def _get_self_module() -> types.ModuleType | None: """ This function is used to get the owning module of this function. """ @@ -569,7 +554,7 @@ def _get_self_module(): @lru_cache(maxsize=16) -def cf_symbol_check(symbol): +def cf_symbol_check(symbol: Any) -> None: """ Check if the symbol is control flow symbol from current module. """ @@ -578,12 +563,12 @@ def cf_symbol_check(symbol): name = symbol.__name__ self_module = _get_self_module() if inspect.ismodule(symbol): - if not self_module.__name__.startswith(name): + if not self_module.__name__.startswith(name): # type: ignore[union-attr] failed = True else: owning_module = inspect.getmodule(symbol) - root_module = owning_module.__name__.split(".")[0] - self_root_module = self_module.__name__.split(".")[0] + root_module = owning_module.__name__.split(".")[0] # type: ignore[union-attr] + self_root_module = self_module.__name__.split(".")[0] # type: ignore[union-attr] if root_module != self_root_module: failed = True @@ -594,18 +579,26 @@ def cf_symbol_check(symbol): ) -def redirect_builtin_function(fcn): +def redirect_builtin_function(fcn: Any) -> Any: """ This function is used to redirect built-in function call to the function defined in DSL package. """ # Only redirect if it's a built-in - if isinstance(fcn, BuiltinFunctionType) and executor._builtin_redirector: - return executor._builtin_redirector(fcn) + if fcn is builtins.bool: + return bool_cast + + if isinstance(fcn, BuiltinFunctionType): + if fcn in [builtins.exec, builtins.eval]: + raise DSLRuntimeError( + f"Built-in function `{fcn.__name__}` is not supported in DSL.", + ) + if executor._builtin_redirector: + return executor._builtin_redirector(fcn) return fcn -def copy_members(dest, src): +def copy_members(dest: object, src: object) -> None: """ Copies all non-callable, non-dunder members from src to dest if they exist in src. Skips members that are callables or have names starting with double underscores. @@ -617,7 +610,7 @@ def copy_members(dest, src): for name, value in members: if ( name.startswith("__") - or isinstance(value, Callable) + or isinstance(value, Callable) # type: ignore[arg-type] or not hasattr(src, name) ): continue @@ -628,7 +621,7 @@ def copy_members(dest, src): pass -def get_locals_or_none(locals, symbols): +def get_locals_or_none(locals: dict[str, Any], symbols: list[str]) -> list[Any]: """ Given a locals() dictionary and a list of symbol names, return a list of their values in the same order as the symbols list. If a symbol is not present in locals, None is returned @@ -643,15 +636,33 @@ def get_locals_or_none(locals, symbols): return variables -def closure_check(closures): +def closure_check( + closures: list[Any], _visited: set[tuple[str, int]] | None = None +) -> None: """ - Check if the closures have any captures + Check if the closures have any unsupported capture """ + if _visited is None: + _visited = set() + for closure in closures: - if closure.__closure__: + # Use (function name, id) as identity to skip already-processed closures + # and prevent infinite recursion with mutually-recursive captured functions + closure_identity = (closure.__name__, id(closure)) + if closure_identity in _visited: + continue + _visited.add(closure_identity) + + closure_vars = inspect.getclosurevars(closure) + for name, value in closure_vars.nonlocals.items(): + if inspect.ismodule(value): + continue + if inspect.isfunction(value) or inspect.ismethod(value): + closure_check([value], _visited) + continue raise DSLRuntimeError( - f"Function `{closure.__name__}` is a closure that captures variables and is not supported in dynamic control flow", - suggestion="Please implicitly pass in captured variables as arguments", + f"Function `{closure.__name__}` is a closure that captures variable `{name}` and is not supported in dynamic control flow", + suggestion="Please explicitly pass in captured variables as arguments", ) @@ -675,9 +686,9 @@ class FormattedValue: value: Any conversion: int = -1 - format_spec: Optional[List[str]] = None + format_spec: list[str] | None = None - def to_str(self): + def to_str(self) -> tuple[str, Any | None]: """ Converts the FormattedValue into a format string component and (optional) dynamic argument. @@ -692,7 +703,7 @@ class FormattedValue: Raises: DSLRuntimeError: If the format spec is invalid or not supported for dynamic expressions. """ - if executor._is_dynamic_expression(self.value): + if executor._is_dynamic_expression(self.value): # type: ignore[misc] if self.conversion != -1: warnings.warn( "Conversion may not be honored for dynamic expressions", @@ -724,7 +735,9 @@ class FormattedValue: return f_str.format(self.value), None -def fstring_decompose(joinedStrComponent): +def fstring_decompose( + joinedStrComponent: list[str | FormattedValue], +) -> tuple[Any, ...]: """ Decomposes a joined f-string component list into a format string and dynamic arguments. diff --git a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py index b66744634..0ae0dc283 100644 --- a/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py +++ b/python/CuTeDSL/cutlass/base_dsl/ast_preprocessor.py @@ -39,11 +39,12 @@ import inspect import os import sys import textwrap +import types import warnings +from collections.abc import Callable, Generator, Iterable, Iterator from dataclasses import dataclass, field -from typing import List, Set, Dict, Any, Callable, Optional +from typing import Any, TypeVar from types import ModuleType -from collections import OrderedDict from copy import deepcopy from itertools import chain @@ -56,34 +57,37 @@ class OrderedSet: A deterministic set implementation for ordered operations. """ - def __init__(self, iterable=None): - self._dict = dict.fromkeys(iterable or []) + def __init__(self, iterable: Iterable[str] | None = None) -> None: + self._dict: dict[str, None] = dict.fromkeys(iterable or []) - def add(self, item): + def add(self, item: str) -> None: self._dict[item] = None - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self._dict) - def __and__(self, other): + def __contains__(self, item: object) -> bool: + return item in self._dict + + def __and__(self, other: "OrderedSet") -> "OrderedSet": return OrderedSet(key for key in self._dict if key in other) - def __or__(self, other): + def __or__(self, other: "OrderedSet") -> "OrderedSet": new_dict = self._dict.copy() new_dict.update(dict.fromkeys(other)) return OrderedSet(new_dict) - def __sub__(self, other): + def __sub__(self, other: "OrderedSet") -> "OrderedSet": return OrderedSet(key for key in self._dict if key not in other) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._dict) - def intersections(self, others): + def intersections(self, others: list[set[str]]) -> "OrderedSet": """Compute the intersection of this set with multiple other sets. :param others: A list of sets to compute intersections with - :type others: List[Set[str]] + :type others: list[set[str]] :return: A new ordered set containing elements that appear in this set and at least one of the other sets """ @@ -103,7 +107,7 @@ class ImportInfo: """ module_path: str - attr_name: Optional[str] + attr_name: str | None alias_name: str @@ -127,10 +131,10 @@ class TryImportInfo: executed depending on exception handling logic. """ - try_imports: list - except_imports: list - else_imports: list - finally_imports: list + try_imports: "list[ImportInfo | TryImportInfo]" + except_imports: "list[ImportInfo | TryImportInfo]" + else_imports: "list[ImportInfo | TryImportInfo]" + finally_imports: "list[ImportInfo | TryImportInfo]" @dataclass @@ -140,8 +144,8 @@ class ScopeManager: Manage nested scopes during transformations. """ - scopes: List[Set[str]] - callables: List[Set[str]] + scopes: list[set[str]] + callables: list[set[str]] @classmethod def create(cls) -> "ScopeManager": @@ -157,14 +161,14 @@ class ScopeManager: return self.callables[-1].add(name) - def get_active_symbols(self) -> List[Set[str]]: + def get_active_symbols(self) -> list[set[str]]: return self.scopes.copy() - def get_active_callables(self) -> List[Set[str]]: + def get_active_callables(self) -> list[set[str]]: return self.callables.copy() @contextlib.contextmanager - def enter_local_scope(self): + def enter_local_scope(self) -> Generator[None, None, None]: """ Context manager for entering a new local variable and callable scope. @@ -191,7 +195,7 @@ class ScopeManager: self.callables.pop() @contextlib.contextmanager - def enter_control_flow_scope(self): + def enter_control_flow_scope(self) -> Generator[None, None, None]: """ Context manager for entering a new dynamic control-flow scope. @@ -252,27 +256,32 @@ class Region: self, session_data: "SessionData", *, - owning_node: ast.stmt = None, - new_value: list[ast.stmt] = None, - ): + owning_node: ast.stmt | None = None, + new_value: list[ast.stmt] | None = None, + ) -> None: self.session_data = session_data self.owning_node = owning_node self.new_value = new_value - def __enter__(self): + def __enter__(self) -> "Region": if self.new_value is not None or isinstance(self.owning_node, ast.stmt): self.session_data.region_stack.append(self) if self.owning_node is not None: - self.owning_node._new_value = [] + self.owning_node._new_value = [] # type: ignore[attr-defined] return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: if self.new_value is not None or isinstance(self.owning_node, ast.stmt): self.session_data.region_stack.pop() if self.owning_node is not None: delattr(self.owning_node, "_new_value") - def append_new_stmts(self, stmts: list[ast.stmt]): + def append_new_stmts(self, stmts: list[ast.stmt]) -> None: """ Append a list of statements to the region's collection. @@ -282,8 +291,9 @@ class Region: The AST statements to append to this region. """ if self.owning_node is not None: - self.owning_node._new_value.extend(stmts) + self.owning_node._new_value.extend(stmts) # type: ignore[attr-defined] else: + assert self.new_value is not None self.new_value.extend(stmts) @@ -297,23 +307,25 @@ class SessionData: scope_manager: ScopeManager = field(default_factory=ScopeManager.create) function_counter: int = 0 function_name: str = "" - class_name: Optional[str] = None + class_name: str | None = None file_name: str = "" - function_globals: Optional[dict[str, Any]] = None + function_globals: dict[str, Any] | None = None import_top_module: bool = False region_stack: list[Region] = field(default_factory=list) generator_targets: list[str] = field(default_factory=list) lambda_args: list[str] = field(default_factory=list) @contextlib.contextmanager - def set_current_class_name(self, class_name: str): + def set_current_class_name(self, class_name: str) -> Generator[None, None, None]: old_class_name = self.class_name self.class_name = class_name yield self.class_name = old_class_name @contextlib.contextmanager - def set_current_function_name(self, function_name: str): + def set_current_function_name( + self, function_name: str + ) -> Generator[None, None, None]: old_function_name = self.function_name self.function_name = function_name yield @@ -321,13 +333,13 @@ class SessionData: def _create_module_attribute( - func_name, + func_name: str, *, - use_base_dsl=True, - submodule_name="ast_helpers", - lineno=None, - col_offset=None, -): + use_base_dsl: bool = True, + submodule_name: str | None = "ast_helpers", + lineno: int | None = None, + col_offset: int | None = None, +) -> ast.Attribute: """Creates an AST node representing a qualified attribute access to a function in a module or submodule. :param func_name: The attribute or function name to access @@ -347,7 +359,9 @@ def _create_module_attribute( """ # If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong. - def set_location(node, lineno, col_offset): + def set_location( + node: ast.expr, lineno: int | None, col_offset: int | None + ) -> None: if lineno is None or col_offset is None: return node.lineno = lineno @@ -355,16 +369,22 @@ def _create_module_attribute( node.col_offset = col_offset node.end_col_offset = col_offset - base = ast.Name( + base: ast.expr = ast.Name( id="__base_dsl__" if use_base_dsl else "__module_dsl__", ctx=ast.Load() ) set_location(base, lineno, col_offset) if submodule_name: base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load()) set_location(base, lineno, col_offset) - node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load()) - set_location(node, lineno, col_offset) - return node + result = ast.Attribute(value=base, attr=func_name, ctx=ast.Load()) + set_location(result, lineno, col_offset) + return result + + +_ComprehensionT = TypeVar( + "_ComprehensionT", ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp +) + class DSLPreprocessor(ast.NodeTransformer): """ @@ -382,15 +402,13 @@ class DSLPreprocessor(ast.NodeTransformer): IFEXP_EXECUTOR = "ifExp_executor" WHILE_EXECUTOR = "while_executor" ASSERT_EXECUTOR = "assert_executor" - BOOL_CAST = "bool_cast" IMPLICIT_DOWNCAST_NUMERIC_TYPE = "implicitDowncastNumericType" SUPPORTED_FOR_RANGE_STATEMENTS = {"range", "range_dynamic", "range_constexpr"} CONST_EXPR_NAME = {"const_expr", "target_version"} COMPARE_EXECUTOR = "compare_executor" - ANY_EXECUTOR = "any_executor" - ALL_EXECUTOR = "all_executor" + BUILTIN_REDIRECTOR = "redirect_builtin_function" - def generic_visit(self, node): + def generic_visit(self, node: ast.AST) -> ast.AST: """ Copy of :meth:`ast.NodeTransformer.generic_visit` with support for inserting statements during expression visits. @@ -406,17 +424,17 @@ class DSLPreprocessor(ast.NodeTransformer): """ for field, old_value in ast.iter_fields(node): if isinstance(old_value, list): - with Region(self.session_data, owning_node=node): + with Region(self.session_data, owning_node=node): # type: ignore[arg-type] for value in old_value: if isinstance(value, ast.AST): value = self.visit(value) if value is None: continue elif not isinstance(value, ast.AST): - node._new_value.extend(value) + node._new_value.extend(value) # type: ignore[attr-defined] continue - node._new_value.append(value) - old_value[:] = node._new_value + node._new_value.append(value) # type: ignore[attr-defined] + old_value[:] = node._new_value # type: ignore[attr-defined] elif isinstance(old_value, ast.AST): new_node = self.visit(old_value) if new_node is None: @@ -425,17 +443,16 @@ class DSLPreprocessor(ast.NodeTransformer): setattr(node, field, new_node) return node - def __init__(self, client_module_name): + def __init__(self, client_module_name: list[str]) -> None: super().__init__() # Persistent state - self.processed_functions = set() + self.processed_functions: set[Callable[..., Any]] = set() self.client_module_name = client_module_name - self.module_cache = {} - self._session_data = None - + self.module_cache: dict[ModuleType, list[ImportInfo | TryImportInfo]] = {} + self._session_data: SessionData | None = None @contextlib.contextmanager - def get_session(self): + def get_session(self) -> Generator["DSLPreprocessor", None, None]: try: self._session_data = SessionData() yield self @@ -443,13 +460,15 @@ class DSLPreprocessor(ast.NodeTransformer): self._session_data = None @property - def session_data(self): + def session_data(self) -> SessionData: assert self._session_data is not None, ( "Please start a session before accessing session data" ) return self._session_data - def _get_imports_from_ast(self, node, module): + def _get_imports_from_ast( + self, node: ast.AST, module: ModuleType + ) -> list[ImportInfo | TryImportInfo]: """ Recursively extracts all import statements from the given AST node. @@ -466,8 +485,8 @@ class DSLPreprocessor(ast.NodeTransformer): Returns: A list of ImportInfo and TryImportInfo objects representing all discovered imports in the AST. """ - imports = [] - alias = lambda n: n.asname if n.asname else n.name + imports: list[ImportInfo | TryImportInfo] = [] + alias: Callable[[ast.alias], str] = lambda n: n.asname if n.asname else n.name for child_node in ast.iter_child_nodes(node): if isinstance(child_node, ast.Import): for name in child_node.names: @@ -500,7 +519,7 @@ class DSLPreprocessor(ast.NodeTransformer): for name in child_node.names: imports.append( ImportInfo( - module_path=module_name, + module_path=module_name or "", attr_name=name.name, alias_name=alias(name), ) @@ -509,25 +528,28 @@ class DSLPreprocessor(ast.NodeTransformer): elif isinstance(child_node, (ast.Try, getattr(ast, "TryStar", ast.Try))): # Handle try-catch try_imports = self._get_imports_from_ast( - ast.Module(body=child_node.body), module + ast.Module(body=child_node.body, type_ignores=[]), # type: ignore[attr-defined] + module, ) # search handler for ImportError or ModuleNotFoundError - except_imports = [] - for handler in child_node.handlers: + except_imports: list[ImportInfo | TryImportInfo] = [] + for handler in child_node.handlers: # type: ignore[attr-defined] if handler.type == None or handler.type.id in [ "ImportError", "ModuleNotFoundError", "Exception", ]: except_imports = self._get_imports_from_ast( - ast.Module(body=handler.body), module + ast.Module(body=handler.body, type_ignores=[]), module ) break else_imports = self._get_imports_from_ast( - ast.Module(body=child_node.orelse), module + ast.Module(body=child_node.orelse, type_ignores=[]), # type: ignore[attr-defined] + module, ) finally_imports = self._get_imports_from_ast( - ast.Module(body=child_node.finalbody), module + ast.Module(body=child_node.finalbody, type_ignores=[]), # type: ignore[attr-defined] + module, ) imports.append( TryImportInfo( @@ -536,9 +558,11 @@ class DSLPreprocessor(ast.NodeTransformer): ) return imports - def _get_module_imports(self, decorated_func): + def _get_module_imports( + self, decorated_func: Callable[..., Any] + ) -> list[ImportInfo | TryImportInfo]: """Extract imports from the module containing the decorated function""" - imports = [] + imports: list[ImportInfo | TryImportInfo] = [] # Get the module containing the decorated function if module := inspect.getmodule(decorated_func): @@ -556,9 +580,9 @@ class DSLPreprocessor(ast.NodeTransformer): return imports - def try_import_first_and_then_local_import(self, module_path): + def try_import_first_and_then_local_import(self, module_path: str) -> ModuleType: @contextlib.contextmanager - def local_import(module_path): + def local_import(module_path: str) -> Generator[ModuleType, None, None]: # Directory where some local import might happen: local_dir = os.path.dirname(self.session_data.file_name) # Momentarily insert the directory where the local import @@ -579,7 +603,9 @@ class DSLPreprocessor(ast.NodeTransformer): with local_import(module_path) as module: return module - def exec_import(self, import_info, exec_globals): + def exec_import( + self, import_info: ImportInfo, exec_globals: dict[str, Any] + ) -> None: module_path, attr_name, alias_name = ( import_info.module_path, import_info.attr_name, @@ -601,7 +627,11 @@ class DSLPreprocessor(ast.NodeTransformer): else: exec_globals[alias_name] = module - def exec_imports(self, import_infos, exec_globals): + def exec_imports( + self, + import_infos: list[ImportInfo | TryImportInfo], + exec_globals: dict[str, Any], + ) -> None: for import_info in import_infos: if isinstance(import_info, ImportInfo): try: @@ -620,7 +650,13 @@ class DSLPreprocessor(ast.NodeTransformer): finally: self.exec_imports(import_info.finally_imports, exec_globals) - def exec(self, function_name, original_function, code_object, exec_globals): + def exec( + self, + function_name: str, + original_function: Callable[..., Any], + code_object: types.CodeType, + exec_globals: dict[str, Any], + ) -> Callable[..., Any] | None: """Requires an active DSL preprocessor session.""" # Get imports from the original module module_imports = self._get_module_imports(original_function) @@ -637,13 +673,13 @@ class DSLPreprocessor(ast.NodeTransformer): return exec_globals.get(function_name) @staticmethod - def print_ast(transformed_tree=None): + def print_ast(transformed_tree: ast.AST | None = None) -> None: print("#", "-" * 40, "Transformed AST", "-" * 40) - unparsed_code = ast.unparse(transformed_tree) + unparsed_code = ast.unparse(transformed_tree) # type: ignore[arg-type] print(unparsed_code) print("#", "-" * 40, "End Transformed AST", "-" * 40) - def make_func_param_name(self, base_name, used_names): + def make_func_param_name(self, base_name: str, used_names: Iterable[str]) -> str: """Generate a unique parameter name that doesn't collide with existing names.""" if base_name not in used_names: return base_name @@ -653,7 +689,48 @@ class DSLPreprocessor(ast.NodeTransformer): i += 1 return f"{base_name}_{i}" - def transform_function(self, func_name, function_pointer): + def _inject_default_arg_values( + self, + function_pointer: Callable[..., Any], + func_ast: ast.FunctionDef, + ) -> None: + """Inject default-argument values whose source-level names are unresolvable. + + When a decorated function uses ``_param=name`` where ``name`` is a local + in an enclosing factory, ``exec()`` needs ``name`` in its namespace. + We use ``inspect.signature`` for runtime default values and the + already-parsed *func_ast* for the source-level name each default + references. + """ + exec_globals = self.session_data.function_globals + if exec_globals is None: + return + sig = inspect.signature(function_pointer) + params_with_defaults = { + name: param.default + for name, param in sig.parameters.items() + if param.default is not inspect.Parameter.empty + } + if not params_with_defaults: + return + # Build map: parameter name → AST default node + # (covers both positional and keyword-only parameters) + ast_defaults: dict[str, ast.expr] = {} + all_args = func_ast.args.posonlyargs + func_ast.args.args + offset = len(all_args) - len(func_ast.args.defaults) + for i, default_node in enumerate(func_ast.args.defaults): + ast_defaults[all_args[offset + i].arg] = default_node + for kwarg, kw_default in zip(func_ast.args.kwonlyargs, func_ast.args.kw_defaults): + if kw_default is not None: + ast_defaults[kwarg.arg] = kw_default + for param_name, default_val in params_with_defaults.items(): + ast_node = ast_defaults.get(param_name) + if isinstance(ast_node, ast.Name) and ast_node.id not in exec_globals: + exec_globals[ast_node.id] = default_val + + def transform_function( + self, func_name: str, function_pointer: Callable[..., Any] + ) -> list[ast.stmt]: """ Transforms a function. """ @@ -666,7 +743,7 @@ class DSLPreprocessor(ast.NodeTransformer): # Step 1. Parse the given function try: - file_name = inspect.getsourcefile(function_pointer) + file_name = inspect.getsourcefile(function_pointer) or "" lines, start_line = inspect.getsourcelines(function_pointer) dedented_source = textwrap.dedent("".join(lines)) tree = ast.parse(dedented_source, filename=file_name) @@ -690,6 +767,15 @@ class DSLPreprocessor(ast.NodeTransformer): self.processed_functions.add(function_pointer) log().info("ASTPreprocessor Transforming function [%s]", func_name) + # Step 1.3 Inject default-argument values from enclosing scopes. + # When a decorated function uses `_param=name` where `name` is a + # local in the enclosing factory, exec() needs `name` in its + # namespace. We use the already-parsed AST to find source-level + # names and inspect.signature to get runtime values. + func_def = tree.body[0] + assert isinstance(func_def, ast.FunctionDef) + self._inject_default_arg_values(function_pointer, func_def) + # Step 2. Transform the function transformed_tree = self.visit(tree) @@ -767,54 +853,53 @@ class DSLPreprocessor(ast.NodeTransformer): # Step 5. Return the transformed tree return combined_body - def check_early_exit(self, tree, kind): + def check_early_exit(self, tree: ast.AST, kind: str) -> None: """ Checks if a given region or scope in the provided Python code has early exits. """ class EarlyExitChecker(ast.NodeVisitor): - def __init__(self, kind): + def __init__(self, kind: str) -> None: self.has_early_exit = False - self.early_exit_node = None - self.early_exit_type = None + self.early_exit_node: ast.AST | None = None + self.early_exit_type: str | None = None self.kind = kind self.loop_nest_level = 0 - # Early exit is not allowed in any level of dynamic control flow - def visit_Return(self, node): + def visit_Return(self, node: ast.Return) -> None: self.has_early_exit = True self.early_exit_node = node self.early_exit_type = "return" - def visit_Raise(self, node): + def visit_Raise(self, node: ast.Raise) -> None: self.has_early_exit = True self.early_exit_node = node self.early_exit_type = "raise" - def visit_Break(self, node): + def visit_Break(self, node: ast.Break) -> None: # For break/continue in inner loops, we don't consider it as early exit if self.loop_nest_level == 0 and self.kind != "if": self.has_early_exit = True self.early_exit_node = node self.early_exit_type = "break" - def visit_Continue(self, node): + def visit_Continue(self, node: ast.Continue) -> None: if self.loop_nest_level == 0 and self.kind != "if": self.has_early_exit = True self.early_exit_node = node self.early_exit_type = "continue" - def visit_For(self, node): + def visit_For(self, node: ast.For) -> None: self.loop_nest_level += 1 self.generic_visit(node) self.loop_nest_level -= 1 - def visit_While(self, node): + def visit_While(self, node: ast.While) -> None: self.loop_nest_level += 1 self.generic_visit(node) self.loop_nest_level -= 1 - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # Stop at nested function def return @@ -839,7 +924,7 @@ class DSLPreprocessor(ast.NodeTransformer): ), ) - def is_node_constexpr(self, node) -> bool: + def is_node_constexpr(self, node: ast.If | ast.While) -> bool: """ Determines if the node is a constexpr. Supported nodes are if, while statements. @@ -858,7 +943,9 @@ class DSLPreprocessor(ast.NodeTransformer): return True return False - def _get_range_kind(self, iter_node): + def _get_range_kind( + self, iter_node: ast.expr + ) -> tuple[str | None, bool | None, bool | None]: """ Return "range", "range_dynamic", "range_constexpr" or None for the iterable """ @@ -876,12 +963,19 @@ class DSLPreprocessor(ast.NodeTransformer): return func.attr, False, len(iter_node.keywords) != 0 return None, None, None - def transform(self, original_function, exec_globals): + def transform( + self, + original_function: Callable[..., Any], + exec_globals: dict[str, Any], + callee_rewrite: bool = False, + ) -> ast.Module: """ Transforms the provided function using the preprocessor. Requires an active DSL preprocessor session. """ - self.session_data.file_name = inspect.getsourcefile(original_function) + self.session_data.file_name = ( + inspect.getsourcefile(original_function) or "" + ) self.session_data.function_globals = exec_globals transformed_tree = self.transform_function( original_function.__name__, original_function @@ -894,10 +988,10 @@ class DSLPreprocessor(ast.NodeTransformer): def analyze_region_variables( self, - node: Union[ast.For, ast.If, ast.While], - active_symbols: List[Set[str]], - active_callables: List[Set[str]], - ): + node: ast.For | ast.If | ast.While, + active_symbols: list[set[str]], + active_callables: list[set[str]], + ) -> tuple[list[str], int, list[str]]: """ Analyze variables in different code regions to identify read-only, write-only, and active variables for DSL constructs. @@ -911,14 +1005,14 @@ class DSLPreprocessor(ast.NodeTransformer): class RegionAnalyzer(ast.NodeVisitor): force_store = False - def visit_Name(self, node): + def visit_Name(self, node: ast.Name) -> None: """ Mark every store as write. """ if isinstance(node.ctx, ast.Store) or self.force_store: write_args.add(node.id) - def visit_Subscript(self, node): + def visit_Subscript(self, node: ast.Subscript) -> None: # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`. # We need to force the store for the `Name` to be marked as write. if isinstance(node.ctx, ast.Store): @@ -929,22 +1023,22 @@ class DSLPreprocessor(ast.NodeTransformer): else: self.generic_visit(node) - def visit_Assign(self, node): + def visit_Assign(self, node: ast.Assign) -> None: self.force_store = True [self.visit(target) for target in node.targets] self.force_store = False self.visit(node.value) - def visit_AugAssign(self, node): + def visit_AugAssign(self, node: ast.AugAssign) -> None: self.force_store = True self.visit(node.target) self.force_store = False self.visit(node.value) @staticmethod - def get_call_base(func_node): + def get_call_base(func_node: ast.expr) -> str | None: + # If the .value is another Attribute, keep digging if isinstance(func_node, ast.Attribute): - # If the .value is another Attribute, keep digging if isinstance(func_node.value, ast.Attribute): return RegionAnalyzer.get_call_base(func_node.value) # If the .value is a Name, that's our base @@ -958,7 +1052,7 @@ class DSLPreprocessor(ast.NodeTransformer): return None @staticmethod - def get_function_name(func_node: ast.Call): + def get_function_name(func_node: ast.Call) -> str | None: if isinstance(func_node.func, ast.Name): function_name = func_node.func.id # Check if it's a method or attribute call @@ -968,7 +1062,7 @@ class DSLPreprocessor(ast.NodeTransformer): function_name = None return function_name - def visit_Call(self, node): + def visit_Call(self, node: ast.Call) -> None: base_name = RegionAnalyzer.get_call_base(node.func) if isinstance(node.func, ast.Name): @@ -983,24 +1077,32 @@ class DSLPreprocessor(ast.NodeTransformer): self.generic_visit(node) analyzer = RegionAnalyzer() - analyzer.visit(ast.Module(body=node.body)) + analyzer.visit(ast.Module(body=node.body, type_ignores=[])) if node.orelse: - analyzer.visit(ast.Module(body=node.orelse)) + analyzer.visit(ast.Module(body=node.orelse, type_ignores=[])) # While's loop condition is executed n times, as loop body # So collect the variables used in the loop condition if isinstance(node, ast.While): - analyzer.visit(ast.Module(body=node.test)) + analyzer.visit(ast.Module(body=node.test, type_ignores=[])) # type: ignore[arg-type] # If arg is both write and invoke, remove from invoked_args invoked_args = invoked_args - write_args - write_args = list(write_args.intersections(active_symbols)) - invoked_args = list(invoked_args.intersections(active_symbols)) - called_functions = list(called_functions.intersections(active_callables)) - return write_args + invoked_args, len(write_args), called_functions + write_args_list: list[str] = list(write_args.intersections(active_symbols)) + invoked_args_list: list[str] = list(invoked_args.intersections(active_symbols)) + called_functions_list: list[str] = list( + called_functions.intersections(active_callables) + ) + return ( + write_args_list + invoked_args_list, + len(write_args_list), + called_functions_list, + ) - def extract_range_args(self, iter_node): + def extract_range_args( + self, iter_node: ast.Call + ) -> tuple[ast.expr, ast.expr, ast.expr, bool]: args = iter_node.args if len(args) == 1: return ( @@ -1024,21 +1126,23 @@ class DSLPreprocessor(ast.NodeTransformer): filename=self.session_data.file_name, ) - def extract_unroll_args(self, iter_node): + def extract_unroll_args(self, iter_node: ast.Call) -> tuple[ast.expr, ast.expr]: keywords = {kw.arg: kw.value for kw in iter_node.keywords} return ( keywords.get("unroll", ast.Constant(value=-1)), keywords.get("unroll_full", ast.Constant(value=False)), ) - def issue_deprecation_warning(self, *, message, category, filename, lineno): + def issue_deprecation_warning( + self, *, message: str, category: type[Warning], filename: str, lineno: int + ) -> None: warnings.simplefilter("always", category) # turn off filter warnings.warn_explicit( message, category=category, filename=filename, lineno=lineno ) warnings.simplefilter("default", category) # reset filter - def extract_prefetch_stages_args(self, iter_node): + def extract_prefetch_stages_args(self, iter_node: ast.Call) -> ast.expr: keywords = {kw.arg: kw.value for kw in iter_node.keywords} if "pipelining" in keywords: self.issue_deprecation_warning( @@ -1050,33 +1154,39 @@ class DSLPreprocessor(ast.NodeTransformer): return keywords.get("pipelining", ast.Constant(value=None)) return keywords.get("prefetch_stages", ast.Constant(value=None)) - def extract_vectorize_args(self, iter_node): + def extract_vectorize_args(self, iter_node: ast.Call) -> ast.expr: keywords = {kw.arg: kw.value for kw in iter_node.keywords} return keywords.get("vectorize", ast.Constant(value=None)) + def extract_at_least_once_args(self, iter_node: ast.Call) -> ast.expr: + keywords = {kw.arg: kw.value for kw in iter_node.keywords} + return keywords.get("at_least_once", ast.Constant(value=False)) + def create_loop_function( self, - func_name, - node, - start, - stop, - step, - unroll, - unroll_full, - prefetch_stages, - vectorize, - write_args, - full_write_args_count, - ): + func_name: str, + node: ast.For, + start: ast.expr, + stop: ast.expr, + step: ast.expr, + unroll: ast.expr, + unroll_full: ast.expr, + prefetch_stages: ast.expr, + vectorize: ast.expr, + at_least_once: ast.expr, + write_args: list[str], + full_write_args_count: int, + ) -> ast.FunctionDef: """ Creates a loop body function with the `loop_selector` decorator. """ + assert isinstance(node.target, ast.Name) func_args = [ast.arg(arg=node.target.id, annotation=None)] func_args += [ast.arg(arg=var, annotation=None) for var in write_args] # Create the loop body - transformed_body = [] + transformed_body: list[ast.stmt] = [] with Region(self.session_data, new_value=transformed_body): for stmt in node.body: transformed_stmt = self.visit( @@ -1114,6 +1224,7 @@ class DSLPreprocessor(ast.NodeTransformer): ast.keyword(arg="unroll_full", value=unroll_full), ast.keyword(arg="prefetch_stages", value=prefetch_stages), ast.keyword(arg="vectorize", value=vectorize), + ast.keyword(arg="at_least_once", value=at_least_once), ast.keyword( arg="write_args", value=self.generate_get_locals_or_none_call(write_args), @@ -1150,7 +1261,7 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) - def visit_BoolOp(self, node): + def visit_BoolOp(self, node: ast.BoolOp) -> ast.expr: # Visit child nodes first self.generic_visit(node) @@ -1197,44 +1308,52 @@ class DSLPreprocessor(ast.NodeTransformer): snippet=ast.unparse(node), ) - def short_circuit_eval(value, short_circuit_value): - return ast.BoolOp( - op=ast.And(), - values=[ - ast.Compare( - left=ast.Call( - func=ast.Name(id="type", ctx=ast.Load()), - args=[value], - keywords=[], + def short_circuit_eval( + value: ast.expr, short_circuit_value: ast.Constant + ) -> ast.BoolOp: + return ast.copy_location( + ast.BoolOp( + op=ast.And(), + values=[ + ast.Compare( + left=ast.Call( + func=ast.Name(id="type", ctx=ast.Load()), + args=[value], + keywords=[], + ), + ops=[ast.Eq()], + comparators=[ast.Name(id="bool", ctx=ast.Load())], ), - ops=[ast.Eq()], - comparators=[ast.Name(id="bool", ctx=ast.Load())], - ), - ast.Compare( - left=value, - ops=[ast.Eq()], - comparators=[short_circuit_value], - ), - ], + ast.Compare( + left=value, + ops=[ast.Eq()], + comparators=[short_circuit_value], + ), + ], + ), + node, ) lhs = node.values[0] for i in range(1, len(node.values)): test = short_circuit_eval(lhs, short_circuit_value) - lhs = ast.IfExp( - test=test, - body=lhs, - orelse=ast.Call( - func=helper_func, - args=[lhs, node.values[i]], - keywords=[], + lhs = ast.copy_location( + ast.IfExp( + test=test, + body=lhs, + orelse=ast.Call( + func=helper_func, + args=[lhs, node.values[i]], + keywords=[], + ), ), + node, ) - return ast.copy_location(lhs, node) + return lhs - def visit_UnaryOp(self, node): + def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.expr: # Visit child nodes first self.generic_visit(node) @@ -1254,10 +1373,11 @@ class DSLPreprocessor(ast.NodeTransformer): return node - def _insert_range_value_check(self, node): + def _insert_range_value_check(self, node: ast.For) -> None: """ Insert a check for range arguments """ + assert isinstance(node.iter, ast.Call) range_inputs = node.iter.args check_call = ast.copy_location( ast.Call( @@ -1278,7 +1398,7 @@ class DSLPreprocessor(ast.NodeTransformer): node.iter, ) - def _insert_cf_symbol_check(self, func): + def _insert_cf_symbol_check(self, func: ast.expr) -> ast.Expr: """ Insert a check for range symbol """ @@ -1294,12 +1414,13 @@ class DSLPreprocessor(ast.NodeTransformer): ) return ast.Expr(check_call) - def visit_For(self, node): + def visit_For(self, node: ast.For) -> ast.For | list[ast.stmt]: # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop. range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter) if range_kind == "range_constexpr" or range_kind == None: self.generic_visit(node) if range_kind == "range_constexpr": + assert isinstance(node.iter, ast.Call) check_call = self._insert_cf_symbol_check(node.iter.func) # Rewrite range_constexpr to range node.iter.func = ast.Name(id="range", ctx=ast.Load()) @@ -1324,13 +1445,13 @@ class DSLPreprocessor(ast.NodeTransformer): ) is_prefixed_range = range_kind == "range" and not is_builtin_range - check_call = None + check_call: ast.Expr | None = None # type: ignore[no-redef] if range_kind == "range_dynamic" or is_prefixed_range: - # Insert a check for range symbol + assert isinstance(node.iter, ast.Call) if not is_prefixed_range: check_call = self._insert_cf_symbol_check(node.iter.func) else: - # Get toplevel module + assert isinstance(node.iter.func, ast.Attribute) check_call = self._insert_cf_symbol_check(node.iter.func.value) new_for_node = self.transform_for_loop( @@ -1342,12 +1463,20 @@ class DSLPreprocessor(ast.NodeTransformer): return new_for_node @staticmethod - def _hoist_expr_to_assignments(expr, name): + def _hoist_expr_to_assignments(expr: ast.expr, name: str) -> ast.Assign: return ast.copy_location( ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr ) - def _build_select_and_assign(self, *, name, test, body, orelse, location): + def _build_select_and_assign( + self, + *, + name: str, + test: ast.expr, + body: ast.expr, + orelse: ast.expr, + location: ast.AST, + ) -> ast.Assign: node = ast.copy_location( ast.Assign( targets=[ast.Name(id=name, ctx=ast.Store())], @@ -1361,7 +1490,13 @@ class DSLPreprocessor(ast.NodeTransformer): ) return node - def _handle_negative_step(self, node, start_expr, stop_expr, step_expr): + def _handle_negative_step( + self, + node: ast.For, + start_expr: ast.expr, + stop_expr: ast.expr, + step_expr: ast.expr, + ) -> tuple[ast.Name, ast.Name, ast.Name, list[ast.stmt]]: # hoist start, stop, step to assignments start_ori_name = f"start_ori_{self.session_data.counter}" start = self._hoist_expr_to_assignments(start_expr, start_ori_name) @@ -1370,7 +1505,7 @@ class DSLPreprocessor(ast.NodeTransformer): step_ori_name = f"step_ori_{self.session_data.counter}" step = self._hoist_expr_to_assignments(step_expr, step_ori_name) - extra_exprs = [start, stop, step] + extra_exprs: list[ast.stmt] = [start, stop, step] # Handle possible negative step, generates the following code in Python: # isNegative = step < 0 @@ -1434,11 +1569,11 @@ class DSLPreprocessor(ast.NodeTransformer): ) with Region(self.session_data, new_value=extra_exprs): - extra_exprs.append(self.generic_visit(isNegative)) - extra_exprs.append(self.generic_visit(start)) - extra_exprs.append(self.generic_visit(stop)) - extra_exprs.append(self.generic_visit(step)) - extra_exprs.append(self.generic_visit(offset)) + extra_exprs.append(self.generic_visit(isNegative)) # type: ignore[arg-type] + extra_exprs.append(self.generic_visit(start)) # type: ignore[arg-type] + extra_exprs.append(self.generic_visit(stop)) # type: ignore[arg-type] + extra_exprs.append(self.generic_visit(step)) # type: ignore[arg-type] + extra_exprs.append(self.generic_visit(offset)) # type: ignore[arg-type] # Add this to begining of loop body # for i in range(start, stop, step): @@ -1446,19 +1581,22 @@ class DSLPreprocessor(ast.NodeTransformer): assert isinstance(node.target, ast.Name) target_name = node.target.id - target = self._build_select_and_assign( - name=target_name, - test=ast.Name(id=isNegative_name, ctx=ast.Load()), - body=ast.BinOp( - op=ast.Sub(), - left=ast.Name(id=offset_name, ctx=ast.Load()), - right=ast.Name(id=target_name, ctx=ast.Load()), - ), - orelse=ast.Name(id=target_name, ctx=ast.Load()), - location=node.target, - ) - node.body.insert(0, target) + if target_name != "_": + # if target_name is _, skip the assign back + target = self._build_select_and_assign( + name=target_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.BinOp( + op=ast.Sub(), + left=ast.Name(id=offset_name, ctx=ast.Load()), + right=ast.Name(id=target_name, ctx=ast.Load()), + ), + orelse=ast.Name(id=target_name, ctx=ast.Load()), + location=node.target, + ) + + node.body.insert(0, target) return ( ast.Name(id=start_name, ctx=ast.Load()), @@ -1467,7 +1605,9 @@ class DSLPreprocessor(ast.NodeTransformer): extra_exprs, ) - def _create_closure_check_call(self, called_closures, node): + def _create_closure_check_call( + self, called_closures: list[str], node: ast.stmt + ) -> ast.Expr: return ast.Expr( ast.Call( func=_create_module_attribute( @@ -1485,7 +1625,12 @@ class DSLPreprocessor(ast.NodeTransformer): ) ) - def transform_for_loop(self, node, active_symbols, active_callables): + def transform_for_loop( + self, + node: ast.For, + active_symbols: list[set[str]], + active_callables: list[set[str]], + ) -> list[ast.stmt]: # Check for early exit and raise exception self.check_early_exit(node, "for") if node.orelse: @@ -1496,7 +1641,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) # Get loop target variable name - target_var_name = None + target_var_name: str | None = None target_var_is_active_before_loop = False if isinstance(node.target, ast.Name): target_var_name = node.target.id @@ -1508,7 +1653,7 @@ class DSLPreprocessor(ast.NodeTransformer): # Add necessary exprs to handle this if target_var_is_active_before_loop: - # Initialize an extra loop carried variable + assert target_var_name is not None loop_carried_var_name = f"loop_carried_var_{self.session_data.counter}" pre_loop_expr = ast.copy_location( ast.Assign( @@ -1529,20 +1674,36 @@ class DSLPreprocessor(ast.NodeTransformer): ) active_symbols.append({loop_carried_var_name}) + assert isinstance(node.iter, ast.Call) start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter) unroll, unroll_full = self.extract_unroll_args(node.iter) prefetch_stages = self.extract_prefetch_stages_args(node.iter) vectorize = self.extract_vectorize_args(node.iter) + at_least_once = self.extract_at_least_once_args(node.iter) write_args, full_write_args_count, called_closures = ( self.analyze_region_variables(node, active_symbols, active_callables) ) - if has_step and self.client_module_name[0] == "cutlass": - start, stop, step, exprs = self._handle_negative_step( + has_positive_step = ( + isinstance(step_expr, ast.Constant) + and isinstance(step_expr.value, (int, float)) + and step_expr.value > 0 + ) + + if ( + has_step + and self.client_module_name[0] == "cutlass" + and not has_positive_step + ): + start_n, stop_n, step_n, exprs = self._handle_negative_step( node, start_expr, stop_expr, step_expr ) + start: ast.expr = start_n + stop: ast.expr = stop_n + step: ast.expr = step_n else: - start, stop, step, exprs = start_expr, stop_expr, step_expr, [] + start, stop, step = start_expr, stop_expr, step_expr + exprs: list[ast.stmt] = [] # type: ignore[no-redef] if target_var_is_active_before_loop: exprs.append(pre_loop_expr) @@ -1563,6 +1724,7 @@ class DSLPreprocessor(ast.NodeTransformer): unroll_full, prefetch_stages, vectorize, + at_least_once, write_args, full_write_args_count, ) @@ -1573,7 +1735,7 @@ class DSLPreprocessor(ast.NodeTransformer): exprs = exprs + [func_def] + assign if target_var_is_active_before_loop: - # Create a new assignment to the target variable + assert target_var_name is not None exprs.append( ast.copy_location( ast.Assign( @@ -1586,7 +1748,7 @@ class DSLPreprocessor(ast.NodeTransformer): return exprs - def visit_Assert(self, node): + def visit_Assert(self, node: ast.Assert) -> ast.Expr: test = self.visit(node.test) args = [ast.keyword(arg="test", value=test)] @@ -1609,18 +1771,19 @@ class DSLPreprocessor(ast.NodeTransformer): ast.copy_location(new_node, node) return new_node - def processFormattedValue(self, node): + def processFormattedValue(self, node: ast.FormattedValue) -> ast.Call: """ Converts an ast.FormattedValue node into a runtime representation of an ast.FormattedValue. This function takes an ast.FormattedValue node and converts it into a runtime representation of ast.FormattedValue. """ - keywords = [] + keywords: list[ast.keyword] = [] if node.conversion != -1: keywords.append( ast.keyword(arg="conversion", value=ast.Constant(value=node.conversion)) ) if node.format_spec: + assert isinstance(node.format_spec, ast.JoinedStr) keywords.append( ast.keyword( arg="format_spec", @@ -1638,7 +1801,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) return ast.copy_location(call, node) - def processFString(self, node): + def processFString(self, node: ast.Call) -> ast.Call: """ Converts an f-string node into a runtime representation of an f-string. @@ -1646,8 +1809,9 @@ class DSLPreprocessor(ast.NodeTransformer): where each element is either a literal string or a FormattedValue. The FormattedValue is converted into a runtime representation of ast.FormattedValue. """ - elements = [] + elements: list[ast.expr] = [] joinedStr = node.args[0] + assert isinstance(joinedStr, ast.JoinedStr) for component in joinedStr.values: if isinstance(component, ast.Constant): elements.append(component) @@ -1670,58 +1834,33 @@ class DSLPreprocessor(ast.NodeTransformer): ) return ast.copy_location(call, node) - def visit_Call(self, node): - func = node.func + def visit_Call(self, node: ast.Call) -> ast.Call: + func = self.visit(node.func) # Visit args and kwargs node.args = [self.visit(arg) for arg in node.args] node.keywords = [self.visit(kwarg) for kwarg in node.keywords] # Rewrite call to some built-in functions if isinstance(func, ast.Name): - # Check if the function is 'bool' + # AST rewrite only redirect call to bool to bool_cast + # If `bool` escapes as a symbol, usually it means type check, do not rewrite it if func.id == "bool": return ast.copy_location( ast.Call( - func=_create_module_attribute( - self.BOOL_CAST, - lineno=node.lineno, - col_offset=node.col_offset, + func=ast.Call( + func=_create_module_attribute( + self.BUILTIN_REDIRECTOR, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[func], + keywords=[], ), args=[node.args[0]], keywords=[], ), node, ) - elif func.id in ["any", "all"]: - helper_func = ( - self.ANY_EXECUTOR if func.id == "any" else self.ALL_EXECUTOR - ) - return ast.copy_location( - ast.Call( - func=_create_module_attribute( - helper_func, lineno=node.lineno, col_offset=node.col_offset - ), - args=[node.args[0]], - keywords=[], - ), - node, - ) - elif func.id in ["min", "max"]: - self.session_data.import_top_module = True - return ast.copy_location( - ast.Call( - func=_create_module_attribute( - func.id, - use_base_dsl=False, - submodule_name=None, - lineno=node.lineno, - col_offset=node.col_offset, - ), - args=node.args, - keywords=[], - ), - node, - ) elif func.id == "super" and node.args == [] and node.keywords == []: # If it's a Python3 argument free super(), rewrite to old style super with args # So if this call is under dynamic control flow, it still works. @@ -1742,7 +1881,7 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) elif ( - func.id == "printf" + func.id in ("printf", "print_runtime") and len(node.args) > 0 and isinstance(node.args[0], ast.JoinedStr) ): @@ -1751,7 +1890,7 @@ class DSLPreprocessor(ast.NodeTransformer): ] elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): if ( - func.attr == "printf" + func.attr in ("printf", "print_runtime") and len(node.args) > 0 and isinstance(node.args[0], ast.JoinedStr) ): @@ -1760,7 +1899,7 @@ class DSLPreprocessor(ast.NodeTransformer): ] else: - def create_downcast_call(arg): + def create_downcast_call(arg: ast.expr) -> ast.Call: return ast.copy_location( ast.Call( func=_create_module_attribute( @@ -1775,15 +1914,16 @@ class DSLPreprocessor(ast.NodeTransformer): arg, ) - module = self.session_data.function_globals.get(func.value.id) - if isinstance(module, ModuleType) and module.__package__.endswith( - "._mlir.dialects" - ): + fn_globals = self.session_data.function_globals + module = fn_globals.get(func.value.id) if fn_globals else None + if isinstance(module, ModuleType) and ( + module.__package__ or "" + ).endswith("._mlir.dialects"): # Check if argument is Numeric, if so, call ir_value() - args = [] + args: list[ast.expr] = [] for arg in node.args: args.append(create_downcast_call(arg)) - kwargs = [] + kwargs: list[ast.keyword] = [] for kwarg in node.keywords: kwargs.append( ast.copy_location( @@ -1802,11 +1942,11 @@ class DSLPreprocessor(ast.NodeTransformer): return node - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST: with self.session_data.set_current_class_name(node.name): return self.generic_visit(node) - def _visit_target(self, target): + def _visit_target(self, target: ast.expr) -> None: if isinstance(target, ast.Name): self.session_data.scope_manager.add_to_scope(target.id) elif isinstance(target, ast.Tuple): @@ -1814,29 +1954,37 @@ class DSLPreprocessor(ast.NodeTransformer): if isinstance(t, ast.Name): self.session_data.scope_manager.add_to_scope(t.id) - def visit_Assign(self, node): + def visit_Assign(self, node: ast.Assign) -> ast.stmt | list[ast.stmt]: for target in node.targets: self._visit_target(target) self.generic_visit(node) return node - def visit_AugAssign(self, node): + def visit_AugAssign(self, node: ast.AugAssign) -> ast.AugAssign | list[ast.stmt]: self._visit_target(node.target) self.generic_visit(node) return node - def visit_AnnAssign(self, node): + def visit_Return(self, node: ast.Return) -> ast.stmt | list[ast.stmt]: + self.generic_visit(node) + return node + + def visit_Expr(self, node: ast.Expr) -> ast.stmt | list[ast.stmt]: + self.generic_visit(node) + return node + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: self._visit_target(node.target) self.generic_visit(node) return node - def visit_Name(self, node): + def visit_Name(self, node: ast.Name) -> ast.Name | ast.Call: isLoad = isinstance(node.ctx, ast.Load) - if node.id in ["max", "min", "any", "all"] and isLoad: + if node.id in ["max", "min", "any", "all", "exec", "eval"] and isLoad: return ast.copy_location( ast.Call( func=_create_module_attribute( - "redirect_builtin_function", + self.BUILTIN_REDIRECTOR, lineno=node.lineno, col_offset=node.col_offset, ), @@ -1851,13 +1999,15 @@ class DSLPreprocessor(ast.NodeTransformer): self.generic_visit(node) return node - def get_dsl_decorator_index(self, decorator_list): + def get_dsl_decorator_index(self, decorator_list: list[ast.expr]) -> Any: for i, d in enumerate(decorator_list): if isinstance(d, ast.Call): if isinstance(d.func, ast.Attribute): if d.func.attr in ["jit", "kernel"]: if d.keywords == []: return i + + # Keep existing preprocess behavior unchanged. for keyword in d.keywords: if keyword.arg == "preprocess": try: @@ -1868,6 +2018,18 @@ class DSLPreprocessor(ast.NodeTransformer): except: pass + keyword_names = { + keyword.arg + for keyword in d.keywords + if keyword.arg is not None + } + + # New behavior for kernel function attributes. + # Limit this expansion to kernel decorator with + # an explicit `attributes=` keyword. + if d.func.attr == "kernel" and "attributes" in keyword_names: + return i + elif isinstance(d, ast.Attribute): if d.attr in ["jit", "kernel"]: return i @@ -1897,7 +2059,7 @@ class DSLPreprocessor(ast.NodeTransformer): return dsl_decorator_index is not None - def remove_dsl_decorator(self, decorator_list): + def remove_dsl_decorator(self, decorator_list: list[ast.expr]) -> list[ast.expr]: """ Remove .jit and .kernel decorators The decorator can be in two forms: @@ -1920,12 +2082,35 @@ class DSLPreprocessor(ast.NodeTransformer): new_decorator_list.append(d) return new_decorator_list - def visit_FunctionDef(self, node): + def visit_Global(self, node: ast.Global) -> None: + raise DSLAstPreprocessorError( + "`global` is not supported in DSL", + suggestion="Please explicitly pass in global variables as arguments", + ) + + def visit_Nonlocal(self, node: ast.Nonlocal) -> ast.Nonlocal: + active_symbols = self.session_data.scope_manager.get_active_symbols() + nonlocal_names = OrderedSet(node.names) + intersect = nonlocal_names.intersections(active_symbols) + for name in node.names: + if name not in intersect: + raise DSLRuntimeError( + ( + f"`{ast.unparse(node)}` is referring to `{name}` which is not tracked by current JIT context, " + "this is not supported in DSL" + ), + suggestion="Please explicitly pass in nonlocal variables as arguments", + ) + self.generic_visit(node) + return node + + def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: # Add self to active symbols of parent scope self.session_data.scope_manager.add_to_callables(node.name) - with self.session_data.scope_manager.enter_local_scope(), self.session_data.set_current_function_name( - node.name + with ( + self.session_data.scope_manager.enter_local_scope(), + self.session_data.set_current_function_name(node.name), ): self.session_data.function_counter += 1 @@ -1945,7 +2130,6 @@ class DSLPreprocessor(ast.NodeTransformer): self.generic_visit(node) - # Remove .jit and .kernel decorators node.decorator_list = self.remove_dsl_decorator(node.decorator_list) @@ -1953,16 +2137,17 @@ class DSLPreprocessor(ast.NodeTransformer): node.returns = None return node - def visit_With(self, node): + def visit_With(self, node: ast.With) -> ast.AST: for item in node.items: if isinstance(item.optional_vars, ast.Name): self.session_data.scope_manager.add_to_scope(item.optional_vars.id) return self.generic_visit(node) - def visit_While(self, node): + def visit_While(self, node: ast.While) -> ast.While | list[ast.stmt]: # Constexpr doesn't get preprocessed if self.is_node_constexpr(node): self.generic_visit(node) + assert isinstance(node.test, ast.Call) check = self._insert_cf_symbol_check(node.test.func) return [check, node] @@ -1970,7 +2155,6 @@ class DSLPreprocessor(ast.NodeTransformer): active_callables = self.session_data.scope_manager.get_active_callables() with self.session_data.scope_manager.enter_control_flow_scope(): - # Check for early exit and raise exception self.check_early_exit(node, "while") write_args, full_write_args_count, called_closures = ( @@ -1990,7 +2174,9 @@ class DSLPreprocessor(ast.NodeTransformer): return exprs + [func_def] + assign - def create_cf_call(self, func_name, yield_args, node): + def create_cf_call( + self, func_name: str, yield_args: list[str], node: ast.stmt + ) -> list[ast.stmt]: """Creates the assignment statement for the if function call""" if not yield_args: return [ @@ -2037,13 +2223,15 @@ class DSLPreprocessor(ast.NodeTransformer): else: return [ast.copy_location(assign, node)] - def _visit_Comprehension(self, node, ele_visitor): + def _visit_Comprehension( + self, node: _ComprehensionT, ele_visitor: Callable[..., Any] + ) -> _ComprehensionT: node.generators = [self.visit(generator) for generator in node.generators] - targets = [] + targets: list[str] = [] class NameCollector(ast.NodeVisitor): - def visit_Name(self, node): + def visit_Name(self, node: ast.Name) -> None: if isinstance(node.ctx, ast.Store): targets.append(node.id) @@ -2058,14 +2246,14 @@ class DSLPreprocessor(ast.NodeTransformer): self.session_data.generator_targets = [] return node - def visit_DictComp(self, node): - def key_value_visitor(n): + def visit_DictComp(self, node: ast.DictComp) -> ast.DictComp: + def key_value_visitor(n: ast.DictComp) -> None: n.key = self.visit(n.key) n.value = self.visit(n.value) return self._visit_Comprehension(node, key_value_visitor) - def visit_Lambda(self, node): + def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda: current_lambda_args = len(self.session_data.lambda_args) for arg in node.args.args: self.session_data.lambda_args.append(arg.arg) @@ -2078,22 +2266,22 @@ class DSLPreprocessor(ast.NodeTransformer): return node - def visit_ListComp(self, node): + def visit_ListComp(self, node: ast.ListComp) -> ast.ListComp: return self._visit_Comprehension( node, lambda n: setattr(n, "elt", self.visit(n.elt)) ) - def visit_GeneratorExp(self, node): + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.GeneratorExp: return self._visit_Comprehension( node, lambda n: setattr(n, "elt", self.visit(n.elt)) ) - def visit_SetComp(self, node): + def visit_SetComp(self, node: ast.SetComp) -> ast.SetComp: return self._visit_Comprehension( node, lambda n: setattr(n, "elt", self.visit(n.elt)) ) - def visit_IfExp(self, node): + def visit_IfExp(self, node: ast.IfExp) -> ast.Call: """ Transforms an inline if-else (ternary) expression into runtime-dispatched control flow using synthesized function definitions for each branch. @@ -2209,13 +2397,13 @@ class DSLPreprocessor(ast.NodeTransformer): "NotIn": "not in", } - def compare_ops_to_str(self, node): - names = [ + def compare_ops_to_str(self, node: ast.Compare) -> ast.List: + names: list[ast.expr] = [ ast.Constant(value=self.cmpops[op.__class__.__name__]) for op in node.ops ] return ast.List(elts=names, ctx=ast.Load()) - def visit_Compare(self, node): + def visit_Compare(self, node: ast.Compare) -> ast.Call: self.generic_visit(node) comparator_strs = self.compare_ops_to_str(node) @@ -2239,10 +2427,11 @@ class DSLPreprocessor(ast.NodeTransformer): return call - def visit_If(self, node): + def visit_If(self, node: ast.If) -> ast.If | list[ast.stmt]: # const_expr doesn't get preprocessed if self.is_node_constexpr(node): self.generic_visit(node) + assert isinstance(node.test, ast.Call) check = self._insert_cf_symbol_check(node.test.func) return [check, node] @@ -2250,7 +2439,6 @@ class DSLPreprocessor(ast.NodeTransformer): active_callables = self.session_data.scope_manager.get_active_callables() with self.session_data.scope_manager.enter_control_flow_scope(): - # Check for early exit and raise exception self.check_early_exit(node, "if") yield_args, full_write_args_count, called_closures = ( @@ -2270,7 +2458,7 @@ class DSLPreprocessor(ast.NodeTransformer): return exprs + [func_def] + assign - def generate_get_locals_or_none_call(self, write_args): + def generate_get_locals_or_none_call(self, write_args: list[str]) -> ast.Call: return ast.Call( func=_create_module_attribute("get_locals_or_none"), args=[ @@ -2285,14 +2473,20 @@ class DSLPreprocessor(ast.NodeTransformer): keywords=[], ) - def create_if_function(self, func_name, node, write_args, full_write_args_count): + def create_if_function( + self, + func_name: str, + node: ast.If, + write_args: list[str], + full_write_args_count: int, + ) -> ast.FunctionDef: test_expr = self.visit(node.test) pred_name = self.make_func_param_name("pred", write_args) func_args = [ast.arg(arg=pred_name, annotation=None)] func_args += [ast.arg(arg=var, annotation=None) for var in write_args] func_args_then_else = [ast.arg(arg=var, annotation=None) for var in write_args] - then_body = [] + then_body: list[ast.stmt] = [] with ( Region(self.session_data, new_value=then_body), self.session_data.scope_manager.enter_control_flow_scope(), @@ -2420,6 +2614,7 @@ class DSLPreprocessor(ast.NodeTransformer): # And under both cases, the `pred` can be a const_expr, so we need to handle it here. if self.is_node_constexpr(elif_node): self.generic_visit(elif_node) + assert isinstance(elif_node.test, ast.Call) check = self._insert_cf_symbol_check(elif_node.test.func) else_block = ast.FunctionDef( name=else_block_name, @@ -2448,7 +2643,7 @@ class DSLPreprocessor(ast.NodeTransformer): decorator_list=[], ) else: - else_body = [] + else_body: list[ast.stmt] = [] with ( Region(self.session_data, new_value=else_body), self.session_data.scope_manager.enter_control_flow_scope(), @@ -2511,7 +2706,13 @@ class DSLPreprocessor(ast.NodeTransformer): node, ) - def create_while_function(self, func_name, node, write_args, full_write_args_count): + def create_while_function( + self, + func_name: str, + node: ast.While, + write_args: list[str], + full_write_args_count: int, + ) -> ast.FunctionDef: """Create a while function that looks like: @while_selector(pred, write_args=[]) @@ -2585,7 +2786,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) # Section: while_before_block FunctionDef, which contains condition - while_before_stmts = [] + while_before_stmts: list[ast.stmt] = [] with Region(self.session_data, new_value=while_before_stmts): test_expr = ast.copy_location(self.visit(node.test), node.test) @@ -2605,7 +2806,7 @@ class DSLPreprocessor(ast.NodeTransformer): ) # Section: while_after_block FunctionDef, which contains loop body - while_after_stmts = [] + while_after_stmts: list[ast.stmt] = [] with Region(self.session_data, new_value=while_after_stmts): for stmt in node.body: transformed_stmt = self.visit( diff --git a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py index 167c94b32..7c34f7749 100644 --- a/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py +++ b/python/CuTeDSL/cutlass/base_dsl/cache_helpers.py @@ -21,7 +21,8 @@ import uuid import random import tempfile import time -from typing import Any, Optional +from typing import Any +from collections.abc import Callable from pathlib import Path import hashlib from functools import lru_cache @@ -40,7 +41,7 @@ from .._mlir import ir -def get_current_user(): +def get_current_user() -> str: """ Get the current user. This is used to determine the path to the cache directory. """ @@ -56,34 +57,43 @@ def get_current_user(): raise -def normalize_path(path): +def normalize_path(path: str | Path) -> Path: """ Normalize a path to its full long form. """ return Path(path).resolve() -# default_generated_ir_path is the path to the cache directory. -# If `CUTE_DSL_CACHE_DIR` is set, it is used as the cache directory. -# Otherwise, it is set to a directory controled by TMPDIR defaulting -# to /tmp/${USER}/cutlass_python_cache. -if not (default_generated_ir_path := os.getenv("CUTE_DSL_CACHE_DIR", None)): +def get_default_generated_ir_path(dsl_name: str = "CUTE_DSL") -> str: + """ + Return the cache directory path. + """ + if path := os.getenv(f"{dsl_name}_CACHE_DIR", None): + return path tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir())) - def get_reusable_temp_dir(name): - path = tmp_dir / f"{get_current_user()}/{name}" - path.mkdir(parents=True, exist_ok=True) - return str(path) + def get_reusable_temp_dir(name: str) -> str: + p = tmp_dir / f"{get_current_user()}/{name}" + p.mkdir(parents=True, exist_ok=True) + return str(p) try: default_generated_ir_path = get_reusable_temp_dir("cutlass_python_cache") except Exception as e: - default_generated_ir_path = str(tmp_dir / "cutlass_python_cache") - print(f"Could not determine user, using default path. Error: {e}") + fallback = str(tmp_dir / "cutlass_python_cache") + log().warning( + f"Could not determine user or create cache directory, using fallback path {fallback}. Error: {e}" + ) + return fallback + + return default_generated_ir_path + + +default_generated_ir_path = get_default_generated_ir_path() @lru_cache(maxsize=1) -def get_default_file_dump_root(): +def get_default_file_dump_root() -> Path: """ Get the default file dump root. """ @@ -91,7 +101,7 @@ def get_default_file_dump_root(): return dump_root -def write_bytecode_with_crc32(f, module): +def write_bytecode_with_crc32(f: io.BufferedIOBase, module: ir.Module) -> None: """Write the bytecode to the file and calculate the crc32 checksum. :param f: The file to write the bytecode to. @@ -108,7 +118,7 @@ def write_bytecode_with_crc32(f, module): return -def read_bytecode_and_check_crc32(f): +def read_bytecode_and_check_crc32(f: io.BufferedReader) -> ir.Module: """ Read the bytecode from the file and check the crc32 checksum. @@ -134,7 +144,11 @@ def read_bytecode_and_check_crc32(f): return ir.Module.parse(bytecode) -def load_ir(file, asBytecode=False, bytecode_reader=None): +def load_ir( + file: str, + asBytecode: bool = False, + bytecode_reader: Callable[..., Any] | None = None, +) -> tuple[str, ir.Module]: """Load generated IR from a file. :param file: The path to the file to load. @@ -157,7 +171,7 @@ def load_ir(file, asBytecode=False, bytecode_reader=None): return func_name, module -def make_unique_filename(fpath: Path, new_ext: str = None) -> Path: +def make_unique_filename(fpath: Path, new_ext: str | None = None) -> Path: """ Generate a unique filename with an optional new extension. @@ -178,12 +192,13 @@ def make_unique_filename(fpath: Path, new_ext: str = None) -> Path: def save_ir( dsl_name: str, - module: object, + module: ir.Module, fname: str, output_dir: str | None = None, as_bytecode: bool = False, - bytecode_writer: callable = None, -) -> str: + bytecode_writer: Callable[..., Any] | None = None, + enable_debug_info: bool = True, +) -> Path: """Save generated IR to a file. :param dsl_name: The name of the DSL. @@ -198,6 +213,8 @@ def save_ir( :type as_bytecode: bool, optional :param bytecode_writer: The bytecode writer to use, defaults to None :type bytecode_writer: callable, optional + :param enable_debug_info: Whether to include location info in the IR, defaults to True + :type enable_debug_info: bool, optional :return: The path to the saved file :rtype: str """ @@ -221,9 +238,7 @@ def save_ir( module.operation.write_bytecode(f) else: with open(temp_fname, "w") as f: - # Always save with the locations in the MLIR assembly textual - # representation. - print(module.operation.get_asm(enable_debug_info=True), file=f) + print(module.operation.get_asm(enable_debug_info=enable_debug_info), file=f) # os.replace is guaranteed to be atomic on POSIX systems if it succeeds # so filepath cannot see a partial write os.replace(temp_fname, save_fname) @@ -233,8 +248,11 @@ def save_ir( def load_cache_from_path( - dsl_name, file, path=default_generated_ir_path, bytecode_reader=None -): + dsl_name: str, + file: str, + path: str | None = None, + bytecode_reader: Callable[..., Any] | None = None, +) -> JitCompiledFunction | None: """Load cache from a directory path. :param dsl_name: The name of the DSL. @@ -248,6 +266,8 @@ def load_cache_from_path( :return: The cache :rtype: dict """ + if path is None: + path = get_default_generated_ir_path(dsl_name) if not os.path.exists(path): return None ret = None @@ -259,7 +279,7 @@ def load_cache_from_path( asBytecode=True, bytecode_reader=bytecode_reader, ) - ret = JitCompiledFunction(module, None, None, None, None, [], False, None) + ret = JitCompiledFunction(module, None, None, None, None, [], False, None) # type: ignore[arg-type] except Exception as e: log().warning( f"{dsl_name} failed with loading generated IR cache for {file}.", e @@ -268,12 +288,12 @@ def load_cache_from_path( def dump_cache_to_path( - dsl_name, - jit_function, - file, - path=default_generated_ir_path, - bytecode_writer=None, -): + dsl_name: str, + jit_function: JitCompiledFunction, + file: str, + path: str | None = None, + bytecode_writer: Callable[..., Any] | None = None, +) -> None: """Dump the cache to a directory path. :param dsl_name: The name of the DSL. @@ -288,8 +308,8 @@ def dump_cache_to_path( :type bytecode_writer: callable, optional """ log().info("JIT cache : dumping [%s] file=[%s]", dsl_name, file) - if not path: - path = default_generated_ir_path + if path is None: + path = get_default_generated_ir_path(dsl_name) os.makedirs(path, exist_ok=True) try: save_ir( @@ -320,7 +340,9 @@ class JitCacheDict: If None, the cache is unlimited. Default is None. :type max_elems: int | None """ - self._dict = OrderedDict() if max_elems is not None else dict() + self._dict: OrderedDict[Any, tuple[Any, weakref.finalize | None]] = ( + OrderedDict() if max_elems is not None else dict() # type: ignore[assignment] + ) self.max_elems = max_elems def get(self, key: Any) -> Any | None: @@ -398,7 +420,9 @@ class JitCacheDict: if old_finalize is not None: old_finalize.detach() - def _remove_entry(k: Any, self_ref=weakref.ref(self)) -> None: + def _remove_entry( + k: Any, self_ref: weakref.ref[JitCacheDict] = weakref.ref(self) + ) -> None: # Called from GC/finalizer; be defensive and avoid raising. self_obj = self_ref() if self_obj is not None: diff --git a/python/CuTeDSL/cutlass/base_dsl/common.py b/python/CuTeDSL/cutlass/base_dsl/common.py index eb0685456..713bc7178 100644 --- a/python/CuTeDSL/cutlass/base_dsl/common.py +++ b/python/CuTeDSL/cutlass/base_dsl/common.py @@ -9,7 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +import inspect import os +import subprocess +import sys +import types from typing import Any, Dict, Optional, Union from functools import total_ordering from dataclasses import dataclass @@ -19,6 +23,55 @@ This module provides a Exception classes DSL class for any Dialect. """ +# Store the original exception hook +_original_excepthook = sys.excepthook + +# Store registered environment manager (set by DSL singleton) +_registered_env_manager = None + + +def register_env_manager(env_manager: Any) -> None: + """Register an EnvironmentVarManager instance for use by exception handling. + + Called by DSL singleton when it initializes. + """ + global _registered_env_manager + _registered_env_manager = env_manager + + +def _dsl_excepthook( + exc_type: type, + exc_value: BaseException, + exc_traceback: Optional[types.TracebackType], +) -> None: + """ + Custom exception hook that shows clean error messages for DSL exceptions. + For DSLOperationError, shows only the formatted message without traceback. + For other exceptions, uses the default Python traceback. + """ + # Check if show_stacktrace is enabled via registered env manager + show_stacktrace = False + if _registered_env_manager is not None: + show_stacktrace = getattr(_registered_env_manager, "show_stacktrace", False) + + # Check if it's a DSL operation error (by name to avoid circular import issues) + if exc_type.__name__ in ("DSLOperationError", "DSLOperationBuildError"): + if show_stacktrace: + # Show full traceback in verbose mode + _original_excepthook(exc_type, exc_value, exc_traceback) + else: + # Just print the formatted message (which is in __str__) + print(str(exc_value), file=sys.stderr) + sys.exit(1) + else: + # Use the original exception hook for other exceptions + _original_excepthook(exc_type, exc_value, exc_traceback) + + +# Install the custom exception hook +sys.excepthook = _dsl_excepthook + + # Add color codes at the top of the file after imports class Colors: """ANSI color codes for error messages""" @@ -50,7 +103,7 @@ class DSLBaseError(Exception): filename: Optional[str] = None, error_code: Optional[Union[str, int]] = None, context: Optional[Union[Dict[str, Any], str]] = None, - suggestion: Optional[str] = None, + suggestion: Union[str, list[str], tuple[str, ...], None] = None, cause: Optional[BaseException] = None, ) -> None: self.message = message @@ -64,7 +117,15 @@ class DSLBaseError(Exception): super().__init__(self._format_message()) - def _format_message(self): + def _generate_cause(self) -> str: + """ + Generates a string representation of the cause of the error, if available. + """ + if self.cause: + return f"Caused exception: {self.cause}" + return "" + + def _format_message(self) -> str: """ Formats the complete error message with available metadata. Override this in subclasses if you want to change formatting logic. @@ -84,8 +145,9 @@ class DSLBaseError(Exception): # Optionally truncate long snippets for readability parts.append(f" Snippet: \n {self.snippet}") - if self.cause: - parts.append(f" Caused exception: {self.cause}") + cause = self._generate_cause() + if cause: + parts.append(cause) if self.context: if isinstance(self.context, dict): @@ -108,6 +170,26 @@ class DSLBaseError(Exception): return "\n".join(parts) +class DSLSubprocessCallError(DSLBaseError): + """ + Raised when an error occurs during a subprocess call in the DSL. + """ + + def _generate_cause(self) -> str: + assert isinstance(self.cause, subprocess.CalledProcessError), ( + "cause must be a subprocess.CalledProcessError" + ) + cause = [] + cause.append(f" Caused exception: {self.cause}") + cause.append( + f" Command: \033[93m{' '.join(str(item) for item in self.cause.cmd)}\033[0m" + ) + cause.append(f" Return code: {self.cause.returncode}") + cause.append(f" stdout: {self.cause.stdout}") + cause.append(f" stderr: {Colors.BOLD}{self.cause.stderr}{Colors.RESET}") + return "\n".join(cause) + + class DSLRuntimeError(DSLBaseError): """ Raised when an error occurs during JIT-time code generation in the DSL. @@ -118,7 +200,9 @@ class DSLRuntimeError(DSLBaseError): pass -def _get_friendly_cuda_error_message(error_code, error_name): +def _get_friendly_cuda_error_message( + error_code: int, error_name: Union[str, bytes] +) -> tuple[str, str, Union[str, tuple[str, ...]]]: # Avoid circular dependency from .runtime.cuda import get_device_info @@ -166,40 +250,40 @@ def _get_friendly_cuda_error_message(error_code, error_name): error_suggestions = { "CUDA_ERROR_INVALID_CONTEXT": ( - f"1. Check if CUDA context is properly initialized under your environment", - f"2. Initialize CUDA context with `cuda.cuInit(0)` or `cutlass.cuda.initialize_cuda_context()`", + "1. Check if CUDA context is properly initialized under your environment", + "2. Initialize CUDA context with `cuda.cuInit(0)` or `cutlass.cuda.initialize_cuda_context()`", ), "CUDA_ERROR_INVALID_SOURCE": ( - f"1. Ensure env CUTE_DSL_ARCH matches your GPU architecture", - f"2. Clear the compilation cache and regenerate the kernel", - f"3. Check CUDA toolkit installation", + "1. Ensure env CUTE_DSL_ARCH matches your GPU architecture", + "2. Clear the compilation cache and regenerate the kernel", + "3. Check CUDA toolkit installation", ), "CUDA_ERROR_NO_BINARY_FOR_GPU": ( - f"Set env CUTE_DSL_ARCH to match your GPU architecture", + "Set env CUTE_DSL_ARCH to match your GPU architecture", ), "CUDA_ERROR_OUT_OF_MEMORY": ( - f"1. Reduce batch size", - f"2. Reduce model size", - f"3. Free unused GPU memory", + "1. Reduce batch size", + "2. Reduce model size", + "3. Free unused GPU memory", ), "CUDA_ERROR_INVALID_DEVICE": ( - f"1. Check if CUDA device is properly initialized", - f"2. Verify GPU is detected: nvidia-smi", - f"3. Check CUDA_VISIBLE_DEVICES environment variable", + "1. Check if CUDA device is properly initialized", + "2. Verify GPU is detected: nvidia-smi", + "3. Check CUDA_VISIBLE_DEVICES environment variable", ), "CUDA_ERROR_NOT_INITIALIZED": ( - f"1. Check CUDA driver installation", - f"2. call `cuda.cuInit(0)` before any other CUDA operation", - f"3. Run nvidia-smi to confirm GPU status", + "1. Check CUDA driver installation", + "2. call `cuda.cuInit(0)` before any other CUDA operation", + "3. Run nvidia-smi to confirm GPU status", ), "CUDA_ERROR_INVALID_VALUE": ( - f"1. Your GPU model", - f"2. SM ARCH setting", - f"3. Steps to reproduce", + "1. Your GPU model", + "2. SM ARCH setting", + "3. Steps to reproduce", ), "cudaErrorInsufficientDriver": ( - f"1. Run nvidia-smi to confirm CUDA driver version", - f"2. Ensure the CUDA driver version meets the requirement of the installed cuda-python package", + "1. Run nvidia-smi to confirm CUDA driver version", + "2. Ensure the CUDA driver version meets the requirement of the installed cuda-python package", ), } @@ -255,7 +339,7 @@ class DSLCudaRuntimeError(DSLBaseError): # Inherits all logic from DSLRuntimeError; override methods if you need # specialized behavior or formatting for runtime errors. - def __init__(self, error_code, error_name) -> None: + def __init__(self, error_code: int, error_name: Union[str, bytes]) -> None: self._error_code = error_code self._error_name = error_name message, debug_info, suggestion = _get_friendly_cuda_error_message( @@ -286,46 +370,298 @@ class DSLNotImplemented(DSLBaseError): pass -class CudaDriverDependencyError(DSLRuntimeError): - """Custom error class for CUDA driver dependency issues""" +def translate_mlir_nanobind_error(exc: BaseException) -> str: + """ + Translate nanobind/MLIR exceptions into user-friendly messages. + + Nanobind exceptions from MLIR C++ bindings: + - nb::value_error -> ValueError + - nb::type_error -> TypeError + - nb::cast_error -> RuntimeError (usually) + - nb::python_error -> Various Python exceptions + + Returns: + tuple of (translated_message, None, original_message) + Note: suggestions are None - only show if explicitly provided + """ + exc_type = type(exc).__name__ + error_msg = str(exc).lower() + original = str(exc) + + # Type casting errors (nb::cast_error, std::bad_cast) + if "std::bad_cast" in error_msg or "cast" in exc_type.lower(): + if "must be a type" in error_msg: + return "Type mismatch: The operation expected a different type than what was provided" + + return "Type casting failed: Cannot convert between incompatible types" + + # Value errors (nb::value_error) + if exc_type == "ValueError": + if "verification" in error_msg or "failed to verify" in error_msg: + return "MLIR operation verification failed: The operation constraints are not satisfied" + + if "result" in error_msg and "operation" in error_msg: + return "Invalid operation result type: The operation produced an unexpected type" + + if "attribute" in error_msg: + return "Invalid attribute: Attribute value or type is incorrect" + + # Type errors (nb::type_error) + if exc_type == "TypeError": + if "argument" in error_msg or "parameter" in error_msg: + return "Wrong argument type: Function received an incompatible type" + + # Runtime errors (often from nb::cast_error) + if exc_type == "RuntimeError": + if "operand" in error_msg: + return ( + "Invalid operand: Operation received wrong number or type of operands" + ) + + if "not registered" in error_msg or "unknown" in error_msg: + return "Operation or dialect not found" + + # Generic fallback + return f"{exc_type}: {original}" + + +class DSLUserCodeError(DSLBaseError): + """Raised when an error is detected in user DSL code. + + Covers mutation violations, scope errors, type mismatches, and similar + user-facing diagnostics. Takes explicit ``filename`` and ``lineno`` -- + no ``inspect.stack()`` magic inside the class. + + Usage:: + + raise DSLUserCodeError( + "Scope Error: variable `a` escapes its scope", + filename="/path/to/user.py", + lineno=42, + suggestion="Define the variable before the loop.", + ) + """ def __init__( self, message: str, - ): - # Create a detailed error message with instructions - detailed_message = f"""CUDA Driver Dependency Error + filename: Optional[str] = None, + lineno: Optional[int] = None, + col_offset: Optional[int] = None, + cause: Optional[BaseException] = None, + suggestion: Optional[Union[str, list]] = None, + context: Optional[Union[Dict[str, Any], str]] = None, + ) -> None: + snippet = None + if filename and lineno: + snippet = self._read_source_snippet(filename, lineno, col_offset) -{message} - -This error typically occurs when: -• NVIDIA GPU drivers are not installed on your system -• The installed drivers are incompatible with CUDA Toolkit 12.9 or latest version -• The libcuda.so.1 library is not accessible""" - - # Use DSLRuntimeError's structured approach super().__init__( - detailed_message, - suggestion=[ - "Install or update NVIDIA GPU drivers:", - " • Visit: https://www.nvidia.com/Download/index.aspx", - " • Download drivers compatible with CUDA Toolkit 12.9 or latest version", - " • Follow the installation instructions for your OS", - "", - "Verify driver installation:", - " • Run: nvidia-smi", - " • This should display GPU information without errors", - "", - "Check CUDA library availability:", - " • Run: ldconfig -p | grep libcuda", - " • This should show libcuda.so.1 in the output", - "", - "For more information, see:", - " • CUDA Toolkit documentation: https://docs.nvidia.com/cuda/", - " • CUTLASS DSL requirements: nvidia-cutlass-dsl documentation", - ], + message, + line=lineno, + filename=filename, + snippet=snippet, + cause=cause, + suggestion=suggestion, + context=context, ) + @staticmethod + def _read_source_snippet( + filename: str, + lineno: int, + col_offset: Optional[int] = None, + ) -> Optional[str]: + """Read a single source line and format it as a snippet.""" + try: + import linecache + + code_line = linecache.getline(filename, lineno).rstrip() + if not code_line: + return None + snippet = f" {lineno:4d} | {code_line}" + if col_offset is not None: + snippet += f"\n | {' ' * col_offset}^" + return snippet + except Exception: # noqa: BLE001 — best-effort snippet + return None + + def _format_message(self) -> str: + """Format a rich error message with code snippet and suggestions.""" + parts = [] + + parts.append( + f"\n{Colors.RED}{Colors.BOLD}[Error] {self.message}{Colors.RESET}\n" + ) + + if self.snippet and self.filename: + loc = f"{self.filename}:{self.line}" if self.line else self.filename + parts.append(f"{Colors.BLUE}Code:{Colors.RESET}") + parts.append(f"--> {Colors.BLUE}{loc}{Colors.RESET}") + parts.append(self.snippet) + parts.append("") + + if self.cause: + parts.append(f"{Colors.BLUE}Cause:{Colors.RESET} {self.cause}") + parts.append("") + + if self.context: + if isinstance(self.context, dict): + parts.append(f"{Colors.BLUE}Additional Context:{Colors.RESET}") + for key, value in self.context.items(): + parts.append(f" {key}: {value}") + else: + parts.append( + f"{Colors.BLUE}Additional Context:{Colors.RESET} {self.context}" + ) + parts.append("") + + if self.suggestion: + parts.append(f"{Colors.GREEN}Suggestion:{Colors.RESET}") + if isinstance(self.suggestion, (list, tuple)): + for s in self.suggestion: + parts.append(f" {Colors.GREEN}{s}{Colors.RESET}") + else: + parts.append(f" {self.suggestion}") + parts.append("") + + parts.append("=" * 100) + return "\n".join(parts) + + +class DSLOperationBuildError(DSLBaseError): + """ + Raised when an error occurs during a DSL operation with formatted source location. + This exception provides a nicely formatted error message showing the exact line + of user code that caused the error. + """ + + def __init__( + self, + message: str, + cause: Optional[BaseException] = None, + frameInfo: Optional[inspect.Traceback] = None, + auto_translate: bool = True, + ) -> None: + """ + Args: + message: The error message to display + cause: The underlying exception that caused this error + frameInfo: Optional frame info from inspect.getframeinfo() - if not provided, + automatically captures the caller's frame + auto_translate: If True, attempt to translate MLIR/nanobind errors + """ + import inspect + + # If frameInfo not provided, capture the caller's frame information + if frameInfo is None: + current_frame = inspect.currentframe() + frame = current_frame.f_back if current_frame else None + frameInfo = inspect.getframeinfo(frame) if frame else None + + # Try to translate MLIR/nanobind errors if no custom message provided + self.original_error = str(message) + if auto_translate and cause: + translated_msg = translate_mlir_nanobind_error(cause) + if translated_msg != str(cause): + message = translated_msg + + self.frameInfo = frameInfo + + # Extract line and filename from frameInfo + line = frameInfo.lineno if frameInfo else None + filename = frameInfo.filename if frameInfo else None + snippet = None + if frameInfo and frameInfo.code_context: + lineno = frameInfo.lineno + code_line = frameInfo.code_context[0].rstrip() + snippet = f" {lineno:4d} | {code_line}" + + # Add column pointer if available (Python 3.11+) + if ( + hasattr(frameInfo, "positions") + and frameInfo.positions.col_offset is not None # type: ignore[attr-defined] + ): + col = frameInfo.positions.col_offset # type: ignore[attr-defined] + snippet += f"\n | {' ' * col}^" + + super().__init__( + message, + line=line, + filename=filename, + snippet=snippet, + cause=cause, + ) + + def _collect_dsl_errors( + self, + ) -> tuple[ + list[tuple["DSLOperationBuildError", str, str]], Optional[BaseException] + ]: + """ + Recursively collect all DSLOperationErrors in the exception chain. + Returns a tuple of (list of unique errors with snippets, final non-DSLOperationError cause). + Deduplicates by (filename, lineno) to avoid redundant output. + """ + errors_with_snippets = [] + seen_locations = set() + current = self + + while current: + # Add error if it has a snippet to show and location is unique + if current.snippet and current.filename and current.line: + location_key = (current.filename, current.line) + if location_key not in seen_locations: + seen_locations.add(location_key) + errors_with_snippets.append( + (current, current.snippet, current.filename) + ) + # Check if cause is also a DSLOperationError + if current.cause and isinstance(current.cause, DSLOperationBuildError): + current = current.cause + else: + # Found the final cause (not a DSLOperationError) + break + + return errors_with_snippets, current.cause if current else None + + def _format_message(self) -> str: + """Formats the error message with nice visual presentation.""" + parts = [] + + # Collect all DSLOperationErrors in the chain recursively + dsl_errors, final_cause = self._collect_dsl_errors() + + # Show error header with the root cause message + error_msg = self.message + if final_cause: + error_msg = f"{type(final_cause).__name__}: {final_cause}" + parts.append(f"\n{Colors.RED}{Colors.BOLD}[Error] {error_msg}{Colors.RESET}\n") + + # Show the actual traceback first (where the error originated) + if final_cause: + import traceback + + tb = final_cause.__traceback__ + if tb: + parts.append(f"{Colors.BLUE}📍 Exception Origin:{Colors.RESET}") + # Format the traceback from the original exception + tb_lines = traceback.format_tb(tb) + for line in tb_lines: + parts.append(line.rstrip()) + parts.append("") + + # Show unique code snippets from DSL call chain (user code locations) + if dsl_errors: + parts.append(f"{Colors.BLUE}📋 DSL Call Stack:{Colors.RESET}") + for error, snippet, filename in dsl_errors: + parts.append(f"--> {Colors.BLUE}{filename}{Colors.RESET}") + parts.append(snippet) + parts.append("") + + parts.append("=" * 100) + return "\n".join(parts) + def _get_cuda_version() -> str: # Client of this module should implement this function @@ -356,10 +692,12 @@ class DSLCudaVersion: object.__setattr__(self, "major", int(parts[0])) object.__setattr__(self, "minor", int(parts[1])) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, DSLCudaVersion): + return NotImplemented return self.major == other.major and self.minor == other.minor - def __lt__(self, other): + def __lt__(self, other: "DSLCudaVersion") -> bool: return [self.major, self.minor] < [other.major, other.minor] diff --git a/python/CuTeDSL/cutlass/base_dsl/compiler.py b/python/CuTeDSL/cutlass/base_dsl/compiler.py index 41dbb957c..2edacfb34 100644 --- a/python/CuTeDSL/cutlass/base_dsl/compiler.py +++ b/python/CuTeDSL/cutlass/base_dsl/compiler.py @@ -15,11 +15,13 @@ and executes it using MLIR's ExecutionEngine. """ -from typing import Sequence, Optional, Tuple, Callable +from typing import Any +import collections.abc import os import sys import inspect -from .common import DSLRuntimeError, CudaDriverDependencyError +import types +from .common import DSLRuntimeError from .utils.logger import log from .env_manager import EnvironmentVarManager @@ -48,14 +50,12 @@ class CompilationError(RuntimeError): def __init__( self, message: str, - nvvm_error: Optional[str] = None, - ir_context: Optional[str] = None, - cuda_toolkit: Optional[str] = None, - arch: Optional[str] = None, - ): + nvvm_error: str | None = None, + ir_context: str | None = None, + arch: str | None = None, + ) -> None: self.nvvm_error = nvvm_error self.ir_context = ir_context - self.cuda_toolkit = cuda_toolkit self.arch = arch # Call parent with formatted error to avoid showing class name super().__init__("") # Empty string to avoid class name @@ -79,8 +79,7 @@ class CompilationError(RuntimeError): ---------------------- {self.BLUE}⚙️ Current Settings:{self.RESET} -{self.BOLD}- CUDA Toolkit Path: {self.cuda_toolkit or "Not Set"} -- Target Architecture: {self.arch}{self.RESET} +{self.BOLD}- Target Architecture: {self.arch}{self.RESET} IR Context (truncated): {self.ir_context} @@ -94,15 +93,12 @@ IR Context (truncated): class Compiler: """Compiler class for compiling and building MLIR modules.""" - def __init__(self, passmanager, execution_engine): + def __init__(self, passmanager: Any, execution_engine: Any) -> None: self.passmanager = passmanager self.execution_engine = execution_engine - # Flag to track if CUDA dependencies have been checked once in this process - self._cuda_dependencies_checked = False - # Post-compile hook to run on Module - self._post_compile_hook: Optional[Callable[[ir.Module], None]] = None + self._post_compile_hook: collections.abc.Callable[[Any], None] | None = None - def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]: + def _process_error(self, error_msg: str) -> tuple[str | None, str | None]: """Process error message to extract NVVM error and IR context""" nvvm_error = None ir_msg = "" @@ -135,16 +131,17 @@ class Compiler: def compile( self, - module, + module: ir.Module, pipeline: str, - cuda_toolkit: str = "", arch: str = "", - enable_verifier=False, - ): + enable_debug_info: bool = False, + enable_verifier: bool = False, + ) -> None: """Compiles the module by invoking the pipeline.""" try: pm = self.passmanager.PassManager.parse(pipeline) pm.enable_verifier(enable_verifier) + pm.run(module.operation) except Exception as e: error_msg = str(e) @@ -155,85 +152,69 @@ class Compiler: error_msg, nvvm_error=nvvm_error, ir_context=ir_msg, - cuda_toolkit=cuda_toolkit, arch=arch, ) from e raise e + finally: + pass if self._post_compile_hook: self._post_compile_hook(module) - def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()): + def jit( + self, + module: ir.Module, + opt_level: int = 2, + shared_libs: collections.abc.Sequence[str] = (), + ) -> Any: """Wraps the module in a JIT execution engine.""" - # Check CUDA driver and GPU dependencies before JIT execution (once per process) - self._check_cuda_dependencies_once(shared_libs) - # If pre-checks passed, attempt to create ExecutionEngine - # Any failures at this point are likely non-CUDA related return self.execution_engine.ExecutionEngine( module, opt_level=opt_level, shared_libs=shared_libs ) def compile_and_jit( self, - module, + module: ir.Module, pipeline: str, - shared_libs: Sequence[str] = (), + shared_libs: collections.abc.Sequence[str] = (), opt_level: int = 2, - cuda_toolkit: str = "", arch: str = "", - ): + enable_debug_info: bool = False, + ) -> Any: """Compiles and jits the module.""" self.compile( module, pipeline, - cuda_toolkit, arch, + enable_debug_info=enable_debug_info, ) return self.jit(module, opt_level, shared_libs) - def _check_cuda_dependencies_once(self, shared_libs: Sequence[str]) -> None: - """ - Check CUDA dependencies only once per process lifecycle. - After the first check (success or failure), skip all subsequent checks - as the runtime environment doesn't change during process execution. - """ - if self._cuda_dependencies_checked: - return # Already checked in this process, skip - - # Mark as checked to skip all future validations - self._cuda_dependencies_checked = True - - # Simple CUDA driver check - just call cuInit(0) - try: - import cuda.bindings.driver as cuda - - cuda.cuInit(0) - except Exception as e: - # Create a comprehensive error message for CUDA driver issues - error_message = ( - "CUDA runtime initialization failed during dependency check." - ) - - raise CudaDriverDependencyError( - message=error_message, - ) - class PostCompileHookContext: """Context manager for post-compile hook for a compiler.""" - def __init__(self, compiler: Compiler, hook: Callable[[ir.Module], None]): + def __init__( + self, + compiler: Compiler, + hook: collections.abc.Callable[[Any], None], + ) -> None: self.compiler = compiler self.hook = hook - self.prev_post_compile_hook = None + self.prev_post_compile_hook: collections.abc.Callable[[Any], None] | None = None - def __enter__(self): + def __enter__(self) -> "PostCompileHookContext": self.prev_post_compile_hook = self.compiler._post_compile_hook self.compiler._post_compile_hook = self.hook return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: self.compiler._post_compile_hook = self.prev_post_compile_hook @@ -242,33 +223,33 @@ class CompileOption: Base class for compile options. """ - option_name = "" # name of the compile option in the pipeline + option_name: str = "" - def __init__(self, val): - self._value = val + def __init__(self, val: Any) -> None: + self._value: Any = val - def serialize(self): + def serialize(self) -> str: return f"{self.__class__.option_name}={self._value}" @property - def value(self): + def value(self) -> Any: return self._value @value.setter - def value(self, value): + def value(self, value: Any) -> None: self._value = value class BooleanCompileOption(CompileOption): - def __init__(self, val: bool = True): + def __init__(self, val: bool = True) -> None: super().__init__(val) - def serialize(self): + def serialize(self) -> str: return f"{self.__class__.option_name}={'true' if self._value else 'false'}" class StringCompileOption(CompileOption): - def serialize(self): + def serialize(self) -> str: if self._value: self._value = self._value.strip("'") return f"{self.__class__.option_name}='{self._value}'" @@ -276,19 +257,19 @@ class StringCompileOption(CompileOption): class BooleanBasedFileDumpOption(CompileOption): - def __init__(self, val: bool = True): + def __init__(self, val: bool = True) -> None: super().__init__(val) - self._dump_path = "" + self._dump_path: str = "" @property - def dump_path(self): + def dump_path(self) -> str: return self._dump_path @dump_path.setter - def dump_path(self, path): + def dump_path(self, path: str) -> None: self._dump_path = path - def serialize(self): + def serialize(self) -> str: if self._value: assert self._dump_path, ( f"Dump path is not set for {self.__class__.__name__}" @@ -298,14 +279,14 @@ class BooleanBasedFileDumpOption(CompileOption): class EmptyCompileOption(CompileOption): - def serialize(self): + def serialize(self) -> str: return "" class OptLevel(CompileOption): option_name = "opt-level" - def __init__(self, val: int): + def __init__(self, val: int) -> None: if val < 0 or val > 3: raise DSLRuntimeError(f"Invalid OPT_LEVEL: {val}, valid range is [0, 3]") super().__init__(val) @@ -340,6 +321,29 @@ class LinkLibraries(StringCompileOption): class GPUArch(StringCompileOption): option_name = "cubin-chip" + def __init__(self, val: str) -> None: + if val == "": + super().__init__(val) + else: + # Avoid circular dependency + from .arch import Arch + + super().__init__(Arch.from_string(val).to_string()) + + @property + def value(self) -> str: + return self._value + + @value.setter + def value(self, value: str) -> None: + if value == "": + self._value = value + else: + # Avoid circular dependency + from .arch import Arch + + self._value = Arch.from_string(value).to_string() + class EnableTVMFFI(EmptyCompileOption): pass @@ -349,6 +353,7 @@ class DumpDir(EmptyCompileOption): option_name = "dump-dir" + class CompileOptions: """ This class encapsulates compilation options to configure the JIT compilation. @@ -357,8 +362,10 @@ class CompileOptions: compilation parameters such as optimization level, debugging control, etc. """ - def __init__(self, options=None): - self.options = { + def __init__( + self, options: "CompileOption | tuple[CompileOption, ...] | None" = None + ) -> None: + self.options: dict[type[CompileOption], CompileOption] = { # Compilation control options OptLevel: OptLevel(3), PtxasOptions: PtxasOptions(""), @@ -376,8 +383,8 @@ class CompileOptions: if options is not None: self._update(options) - def _update(self, options): - def _validate_and_update_option(option): + def _update(self, options: "CompileOption | tuple[CompileOption, ...]") -> None: + def _validate_and_update_option(option: CompileOption) -> None: if type(option) not in self.options: raise DSLRuntimeError(f"Invalid compile option: {option}") self.options[type(option)] = option @@ -388,7 +395,9 @@ class CompileOptions: else: _validate_and_update_option(options) - def apply_envar_settings(self, envar: EnvironmentVarManager, function_name: str): + def apply_envar_settings( + self, envar: EnvironmentVarManager, function_name: str + ) -> None: # Honor the settings from environment variables as well if envar.keep_ptx: self.options[KeepPTX].value = True @@ -413,16 +422,19 @@ class CompileOptions: else self.options[DumpDir].value ) if self.options[KeepPTX].value: - self.options[KeepPTX].dump_path = os.path.join(dump_dir, f"{function_name}") - self.options[KeepPTX].full_ptx_path = os.path.join( - dump_dir, f"{function_name}.{arch}.ptx" + self.options[KeepPTX].dump_path = os.path.join(dump_dir, f"{function_name}") # type: ignore[attr-defined, arg-type] + self.options[KeepPTX].full_ptx_path = os.path.join( # type: ignore[attr-defined] + dump_dir, # type: ignore[arg-type] + f"{function_name}.{arch}.ptx", ) if self.options[KeepCUBIN].value: - self.options[KeepCUBIN].dump_path = os.path.join( - dump_dir, f"{function_name}" + self.options[KeepCUBIN].dump_path = os.path.join( # type: ignore[attr-defined] + dump_dir, # type: ignore[arg-type] + f"{function_name}", ) - self.options[KeepCUBIN].full_cubin_path = os.path.join( - dump_dir, f"{function_name}.{arch}.cubin" + self.options[KeepCUBIN].full_cubin_path = os.path.join( # type: ignore[attr-defined] + dump_dir, # type: ignore[arg-type] + f"{function_name}.{arch}.cubin", ) @property def generate_line_info(self) -> bool: @@ -434,24 +446,28 @@ class CompileOptions: @property def dump_ptx_path(self) -> str | None: - return self.options[KeepPTX].dump_path if self.options[KeepPTX].value else None + return self.options[KeepPTX].dump_path if self.options[KeepPTX].value else None # type: ignore[attr-defined] @property def full_ptx_path(self) -> str | None: return ( - self.options[KeepPTX].full_ptx_path if self.options[KeepPTX].value else None + self.options[KeepPTX].full_ptx_path # type: ignore[attr-defined] + if self.options[KeepPTX].value + else None ) @property def dump_cubin_path(self) -> str | None: return ( - self.options[KeepCUBIN].dump_path if self.options[KeepCUBIN].value else None + self.options[KeepCUBIN].dump_path # type: ignore[attr-defined] + if self.options[KeepCUBIN].value + else None ) @property def full_cubin_path(self) -> str | None: return ( - self.options[KeepCUBIN].full_cubin_path + self.options[KeepCUBIN].full_cubin_path # type: ignore[attr-defined] if self.options[KeepCUBIN].value else None ) @@ -481,14 +497,16 @@ class CompileOptions: return flattend_options -def _parse_compile_options_from_str(options: str) -> CompileOptions: - """ - Parse the compile options from a string. - Deprecated and will be removed in the future. - """ - def _get_compile_option_from_str(option_str: str): - mapping = { +# This is a temp function to preserve backward compatibility. +# To be removed in the future. +def _parse_compile_options_from_str(options: str) -> CompileOptions: + """Parse the compile options from a string.""" + import shlex as _shlex + + _base_compile_options: "CompileOptions | None" = None + def _get_compile_option_from_str(option_str: str) -> type[CompileOption]: + mapping: dict[str, type[CompileOption]] = { "opt_level": OptLevel, "ptxas_options": PtxasOptions, "enable_assertions": EnableAssertions, @@ -516,18 +534,20 @@ def _parse_compile_options_from_str(options: str) -> CompileOptions: parser.add_argument("--gpu-arch", type=str, default="") parser.add_argument("--enable-tvm-ffi", action="store_true", default=False) parser.add_argument("--dump-dir", type=str, default="") - compile_options = CompileOptions() + compile_options = ( + _base_compile_options if _base_compile_options is not None else CompileOptions() + ) try: # Use shlex to properly handle options with spaces - parsed_options = shlex.split(options) if options else [] + parsed_options = _shlex.split(options) if options else [] # Avoid parsing the ptxas-options value as a hyphen key for i in range(1, len(parsed_options)): if parsed_options[i - 1] in ["--ptxas-options"]: parsed_options[i] = f"'{parsed_options[i]}'" option_dict = vars(parser.parse_args(parsed_options)) - for option, value in option_dict.items(): - option = _get_compile_option_from_str(option) - compile_options.options[option].value = value + for option_name, value in option_dict.items(): + option_cls = _get_compile_option_from_str(option_name) + compile_options.options[option_cls].value = value except SystemExit as e: # catch argparse error and raise as DSLRuntimeError raise DSLRuntimeError( @@ -538,8 +558,8 @@ def _parse_compile_options_from_str(options: str) -> CompileOptions: class CompileCallable: - def __init__(self, options=None): - def preprocess_options(option): + def __init__(self, options: Any = None) -> None: + def preprocess_options(option: Any) -> Any: if type(option) is type and issubclass( option, (BooleanCompileOption, BooleanBasedFileDumpOption, EnableTVMFFI) ): @@ -551,17 +571,17 @@ class CompileCallable: self._compile_options = CompileOptions(preprocess_options(options)) - def __getitem__(self, options): + def __getitem__(self, options: Any) -> "CompileCallable": """ Get a new CompileCallable object with the specified options. """ new_callable_with_options = CompileCallable(options) return new_callable_with_options - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self._compile(*args, **kwargs) - def _compile(self, func, *args, **kwargs): + def _compile(self, func: Any, *args: Any, **kwargs: Any) -> Any: """ This function is used to compile a `cute.jit` decorated function. It will process the compile options and input parameters, do explicit compilation and return the jit executor. @@ -589,7 +609,7 @@ class CompileCallable: pass elif inspect.ismethod(func): # if it's a method, add the instance to the first argument - args = [func.__self__] + list(args) + args = [func.__self__] + list(args) # type: ignore[assignment] func = func.__func__ elif ( inspect.isclass(type(func)) @@ -597,7 +617,7 @@ class CompileCallable: and hasattr(func.__call__, "__func__") ): # If it's a class instance, get the class's __call__ method - args = [func] + list(args) + args = [func] + list(args) # type: ignore[assignment] # Get the actual function from the class definition func = func.__call__.__func__ else: diff --git a/python/CuTeDSL/cutlass/base_dsl/dsl.py b/python/CuTeDSL/cutlass/base_dsl/dsl.py index 4b9c50a4a..6d6ec9e57 100644 --- a/python/CuTeDSL/cutlass/base_dsl/dsl.py +++ b/python/CuTeDSL/cutlass/base_dsl/dsl.py @@ -27,19 +27,24 @@ import re import inspect import argparse import hashlib -import weakref from functools import lru_cache, wraps from collections import namedtuple, OrderedDict from abc import ABC, abstractmethod -from typing import Any, Callable, List, ClassVar +from typing import Annotated, Any, ClassVar, TYPE_CHECKING, get_args, get_origin +from collections.abc import Callable from types import SimpleNamespace + +if TYPE_CHECKING: + import hashlib + from .arch import Arch import warnings import threading from . import typing as t -from .env_manager import EnvironmentVarManager -from .compiler import CompileOptions +from .env_manager import EnvironmentVarManager, is_cutlass_family_dsl_prefix +from .compiler import CompileOptions, LinkLibraries from .ast_helpers import DSLOptimizationWarning +from .common import register_env_manager # ============================================================================= # Local module imports @@ -52,14 +57,20 @@ from .utils.logger import log from .utils.stacktrace import filter_exception, walk_to_top_module, filter_stackframe from .runtime.jit_arg_adapters import ( is_argument_constexpr, - is_arg_spec_constexpr, + is_arg_annotation_constexpr, JitArgAdapterRegistry, ) from .ast_preprocessor import DSLPreprocessor from .common import * -from .typing import get_c_pointers, get_mlir_types, Integer -from .arch import Arch +from .typing import ( + get_c_pointers, + get_mlir_types, + Integer, + implements_dynamic_expression, + implements_jit_argument, +) +from ._mlir_helpers.op import _set_enable_frame_filtering # ============================================================================= # MLIR modules @@ -74,13 +85,12 @@ from .._mlir.dialects import func MLIR_DYNAMIC = -9223372036854775808 - # ============================================================================= # Main DSL Class # ============================================================================= -def is_dynamic_expression(value): +def is_dynamic_expression(value: object) -> bool: """ Given the `value`, check if itself is an IR value or recursively go through it to check if it contains IR value """ @@ -95,35 +105,309 @@ def is_dynamic_expression(value): return False -def extract_mlir_values(obj): +def extract_mlir_values(obj: object, *, structured: bool = False) -> Any: """ - Given the `obj`, recursively go through it to extract all contained IR values as list of MLIR values + Given the `obj`, recursively go through it to extract all contained IR values. + + Args: + obj: The object to extract MLIR values from + structured: If False (default), returns a flat list of MLIR values. + If True, returns whatever __extract_mlir_values__ returns directly + (for tree-based debugging approach). + + Returns: + If structured=False: list[ir.Value] - flat list of MLIR values + If structured=True: the direct result of __extract_mlir_values__ (dict/list/ir.Value) """ - res = [] - if hasattr(obj, "__extract_mlir_values__"): - res = obj.__extract_mlir_values__() - elif isinstance(obj, (tuple, list)): - res = sum((extract_mlir_values(x) for x in obj), []) - elif isinstance(obj, SimpleNamespace): + import dataclasses + + if structured: + # Tree-structured mode: return __extract_mlir_values__ result directly + if hasattr(obj, "__extract_mlir_values__"): + return obj.__extract_mlir_values__() + elif dataclasses.is_dataclass(obj) and not isinstance(obj, type): + return { + field.name: extract_mlir_values( + getattr(obj, field.name), structured=True + ) + for field in dataclasses.fields(obj) + } + elif isinstance(obj, (tuple, list)): + return [extract_mlir_values(x, structured=True) for x in obj] + elif isinstance(obj, SimpleNamespace): + return { + k: extract_mlir_values(v, structured=True) + for k, v in obj.__dict__.items() + } + elif isinstance(obj, ir.Value): + return obj + elif isinstance(obj, ir.BlockArgumentList): + return list(obj) + else: + return None + else: + # Flat list mode (original behavior) res = [] - for k, v in obj.__dict__.items(): - res.extend(extract_mlir_values(v)) - # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in extract_mlir_values to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) + if hasattr(obj, "__extract_mlir_values__"): + # Flatten whatever __extract_mlir_values__ returns to ensure we always get a flat list + res = flatten_mlir_values(obj.__extract_mlir_values__()) + elif isinstance(obj, (tuple, list)): + res = sum((extract_mlir_values(x) for x in obj), []) + elif isinstance(obj, SimpleNamespace): + res = [] + for k, v in obj.__dict__.items(): + res.extend(extract_mlir_values(v)) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in extract_mlir_values to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + elif isinstance(obj, ir.Value): + res = [obj] + elif isinstance(obj, ir.BlockArgumentList): + res = list(obj) + + return res + + +def flatten_mlir_values(values: Any) -> list[ir.Value]: + """ + Flatten a nested dict/list structure of MLIR values into a flat list. + + This is used when we need to pass values to MLIR operations that expect + a flat list of values (e.g., function arguments, yield operands). + + Args: + values: A nested structure (dict, list, ir.Value, or None) + + Returns: + list[ir.Value]: A flat list of all MLIR values in depth-first order + """ + if values is None: + return [] + elif isinstance(values, ir.Value): + return [values] + elif isinstance(values, dict): + result = [] + for v in values.values(): + result.extend(flatten_mlir_values(v)) + return result + elif isinstance(values, list): + result = [] + for v in values: + result.extend(flatten_mlir_values(v)) + return result + else: + return [] + + +def unflatten_mlir_values(flat_values: Any, template: Any) -> Any: + """ + Reconstruct a nested dict/list structure from a flat list of MLIR values. + + This is the inverse of flatten_mlir_values. It uses a template structure + to know how to rebuild the nested structure. + + Args: + flat_values: Iterator or list of MLIR values + template: A nested structure (dict, list, ir.Value, or None) that + defines the shape to reconstruct + + Returns: + A nested structure matching the template shape, filled with values + from flat_values + """ + if not hasattr(flat_values, "__next__"): + flat_values = iter(flat_values) + + if template is None: + return None + elif isinstance(template, ir.Value): + return next(flat_values) + elif isinstance(template, dict): + return {k: unflatten_mlir_values(flat_values, v) for k, v in template.items()} + elif isinstance(template, list): + return [unflatten_mlir_values(flat_values, v) for v in template] + else: + return None + + +# ============================================================================= +# Dynamic Debug Control +# ============================================================================= + + +class _DynamicDebugState: + """ + Global state for controlling dynamic loop debug output. + """ + + def __init__(self) -> None: + self.enabled = False + self.max_depth: int | None = None + self.current_depth = 0 + + def should_print(self) -> bool: + if not self.enabled: + return False + if self.max_depth is None: + return True + return self.current_depth <= self.max_depth + + def enter_level(self) -> None: + self.current_depth += 1 + + def exit_level(self) -> None: + self.current_depth = max(0, self.current_depth - 1) + + def reset_depth(self) -> None: + self.current_depth = 0 + + +_dynamic_debug = _DynamicDebugState() + + +def set_dynamic_debug(enabled: bool, max_depth: int | None = None) -> None: + """ + Enable or disable dynamic loop debug output. + + Args: + enabled: Whether to enable debug output + max_depth: Maximum nesting depth to print. None means unlimited. + """ + _dynamic_debug.enabled = enabled + _dynamic_debug.max_depth = max_depth + _dynamic_debug.current_depth = 0 + + +def get_dynamic_debug() -> tuple[bool, int | None, int]: + """ + Get the current dynamic debug state. + + Returns: + Tuple of (enabled, max_depth, current_depth) + """ + return ( + _dynamic_debug.enabled, + _dynamic_debug.max_depth, + _dynamic_debug.current_depth, + ) + + +def should_print_dynamic_debug() -> bool: + """ + Check if dynamic debug output should be printed at the current level. + + Returns: + True if debug output is enabled and within max_depth limit. + """ + return _dynamic_debug.should_print() + + +def get_dynamic_debug_level() -> int: + """ + Get the current dynamic debug nesting level. + + Returns: + Current nesting depth. + """ + return _dynamic_debug.current_depth + + +class dynamic_debug_level: + """ + Context manager for tracking nesting depth in dynamic debug output. + + Usage: + with dynamic_debug_level(): + # Code at increased nesting level + if should_print_dynamic_debug(): + print(f"Level {get_dynamic_debug_level()}: ...") + """ + + def __enter__(self) -> None: + _dynamic_debug.enter_level() + + def __exit__(self, *args: object) -> None: + _dynamic_debug.exit_level() + + +def reset_dynamic_debug_depth() -> None: + """Reset the dynamic debug depth counter to 0.""" + _dynamic_debug.reset_depth() + + +def debug_print_mlir_values( + obj: object, indent: int = 0, name: str | None = None, types_only: bool = False +) -> str: + """ + Print a structured tree of MLIR values for debugging. + + Args: + obj: The object to print + indent: Current indentation level + name: Optional name to display for this node + types_only: If True, show MLIR types instead of full values + + Returns: + str: A formatted string representation of the MLIR values tree + """ + lines = [] + prefix = " " * indent + + if name: + type_name = name + elif hasattr(obj, "__class__"): + type_name = obj.__class__.__name__ + else: + type_name = str(type(obj).__name__) + + if obj is None: + lines.append(f"{prefix}{type_name}: (none)") elif isinstance(obj, ir.Value): - res = [obj] - elif isinstance(obj, ir.BlockArgumentList): - res = list(obj) # type: ignore + if types_only: + lines.append(f"{prefix}{type_name}: {obj.type}") + else: + lines.append(f"{prefix}{type_name}: {obj} : {obj.type}") + elif hasattr(obj, "__extract_mlir_values__"): + values = obj.__extract_mlir_values__() + lines.append(f"{prefix}{type_name}:") + if isinstance(values, dict): + for key, val in values.items(): + if val is None: + lines.append(f"{prefix} {key}: (static/none)") + elif isinstance(val, ir.Value): + if types_only: + lines.append(f"{prefix} {key}: {val.type}") + else: + lines.append(f"{prefix} {key}: {val} : {val.type}") + elif isinstance(val, (dict, list)): + lines.append( + f"{prefix} {key}: {type(val).__name__} with {len(val)} items" + ) + else: + lines.append(f"{prefix} {key}: {val}") + elif isinstance(values, ir.Value): + if types_only: + lines.append(f"{prefix} value: {values.type}") + else: + lines.append(f"{prefix} value: {values} : {values.type}") + elif isinstance(values, list): + lines.append(f"{prefix} [{len(values)} values]") + else: + lines.append(f"{prefix} {values}") + elif isinstance(obj, dict): + lines.append(f"{prefix}{type_name}: dict with {len(obj)} items") + elif isinstance(obj, (list, tuple)): + lines.append(f"{prefix}{type_name}: [{len(obj)} items]") + else: + lines.append(f"{prefix}{type_name}: {obj}") - return res + return "\n".join(filter(None, lines)) -def extract_mlir_attributes(obj): +def extract_mlir_attributes(obj: object) -> list[Any]: """ Given the `obj`, recursively go through it to extract all contained IR attributes as list of MLIR attributes. This is used for generating kernel function argument attributes. @@ -155,42 +439,89 @@ def extract_mlir_attributes(obj): return res -def new_from_mlir_values(obj, values): +def new_from_mlir_values(obj: Any, values: Any, *, structured: bool = False) -> Any: """ - Create a new python object by populating containing MLIR values with list of new values + Create a new python object by populating containing MLIR values with new values. + + Args: + obj: The original object to use as a template + values: Either a flat list of MLIR values (structured=False) or + a nested structure matching __extract_mlir_values__ output (structured=True) + structured: If False (default), values is a flat list sliced by type counts. + If True, values is passed directly to __new_from_mlir_values__. + + Returns: + A new object of the same type as obj, with MLIR values replaced """ + # Objects with __new_from_mlir_values__ always receive values directly if hasattr(obj, "__new_from_mlir_values__"): return obj.__new_from_mlir_values__(values) - elif isinstance(obj, (tuple, list)): - res = [] - for x in obj: - n_items = len(get_mlir_types(x)) - res.append(new_from_mlir_values(x, values[:n_items])) - values = values[n_items:] - obj_ty = type(obj) - return obj_ty(res) - elif isinstance(obj, SimpleNamespace): - res = SimpleNamespace() - for k, v in obj.__dict__.items(): - n_items = len(get_mlir_types(v)) - res.__dict__[k] = new_from_mlir_values(v, values[:n_items]) - values = values[n_items:] - return res - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in new_from_mlir_values to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - elif is_dynamic_expression(obj): - if len(values) == 0: - return obj - assert len(values) == 1 - return values[0] + import dataclasses + + if structured: + # Tree-structured mode + if dataclasses.is_dataclass(obj) and not isinstance(obj, type): + new_field_values = { + field.name: new_from_mlir_values( + getattr(obj, field.name), values[field.name], structured=True + ) + for field in dataclasses.fields(obj) + } + return type(obj)(**new_field_values) + elif isinstance(obj, (tuple, list)): + res = [ + new_from_mlir_values(x, v, structured=True) for x, v in zip(obj, values) + ] + obj_ty = type(obj) + if hasattr(obj_ty, '_make'): + return obj_ty._make(res) + return obj_ty(res) + elif isinstance(obj, SimpleNamespace): + ns = SimpleNamespace() + for k, v in obj.__dict__.items(): + ns.__dict__[k] = new_from_mlir_values(v, values[k], structured=True) + return ns + elif isinstance(obj, ir.Value): + return values + elif is_dynamic_expression(obj): + return values + else: + return obj else: - assert len(values) == 0, f"{obj} expects 0 values, but got {values}" - return obj + # Flat list mode (original behavior) + if isinstance(obj, (tuple, list)): + res = [] + for x in obj: + n_items = len(get_mlir_types(x)) + res.append(new_from_mlir_values(x, values[:n_items])) + values = values[n_items:] + obj_ty = type(obj) + if hasattr(obj_ty, '_make'): + return obj_ty._make(res) + return obj_ty(res) + elif isinstance(obj, SimpleNamespace): + ns = SimpleNamespace() + for k, v in obj.__dict__.items(): + n_items = len(get_mlir_types(v)) + ns.__dict__[k] = new_from_mlir_values(v, values[:n_items]) + values = values[n_items:] + return ns + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in new_from_mlir_values to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + elif is_dynamic_expression(obj): + if len(values) == 0: + return obj + + assert len(values) == 1 + return values[0] + else: + assert len(values) == 0, f"{obj} expects 0 values, but got {values}" + return obj class DSLSingletonMeta(type): @@ -220,7 +551,7 @@ class DSLSingletonMeta(type): _instances: ClassVar[dict] = {} _lock: ClassVar[threading.Lock] = threading.Lock() - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> Any: with cls._lock: log().info(f"DSLSingletonMeta __call__ for {cls}") if cls is BaseDSL: @@ -236,7 +567,7 @@ class DSLSingletonMeta(type): log().info(f"Active DSL singleton instances: {cls._instances}") return cls._instances[cls] - def clear_instances(cls): + def clear_instances(cls) -> None: log().info( f"Clearing DSL singleton instances for {cls}, current instances: {cls._instances}" ) @@ -267,19 +598,19 @@ class DSLLocation: class BaseDSL(metaclass=DSLSingletonMeta): - gpu_module = None - _env_class = EnvironmentVarManager + gpu_module: Any = None + _env_class: type[EnvironmentVarManager] = EnvironmentVarManager def __init__( self, *, name: str, - dsl_package_name: List[str], + dsl_package_name: list[str], compiler_provider: Any, pass_sm_arch_name: str, - device_compilation_only=False, - preprocess=False, - ): + device_compilation_only: bool = False, + preprocess: bool = False, + ) -> None: """ Constructor for initializing the class with required providers and environment settings. @@ -303,21 +634,24 @@ class BaseDSL(metaclass=DSLSingletonMeta): "All required parameters must be provided and non-empty" ) - self.name = name - self.compiler_provider = compiler_provider - self.pass_sm_arch_name = pass_sm_arch_name - self.decorator_location = None - self.no_cache = False - self.device_compilation_only = device_compilation_only - self.num_kernels = 0 + self.name: str = name + self.compiler_provider: Any = compiler_provider + self.pass_sm_arch_name: str = pass_sm_arch_name + self.decorator_location: DSLLocation | None = None + self.no_cache: bool = False + self.device_compilation_only: bool = device_compilation_only + self.num_kernels: int = 0 # Read environment variables - self.envar = self._env_class(self.name) - self.enable_preprocessor = preprocess + self.envar: EnvironmentVarManager = self._env_class(self.name) + register_env_manager(self.envar) + self.enable_preprocessor: bool = preprocess # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default - self.jit_cache = JitCacheDict(max_elems=self.envar.jit_cache_max_elems) + self.jit_cache: JitCacheDict = JitCacheDict( + max_elems=self.envar.jit_cache_max_elems + ) - self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}" - self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}" + self.host_jit_decorator_name: str = f"@{BaseDSL.jit.__name__}" + self.device_jit_decorator_name: str = f"@{BaseDSL.kernel.__name__}" # set warning if not self.envar.enable_optimization_warnings: @@ -331,69 +665,73 @@ class BaseDSL(metaclass=DSLSingletonMeta): # kernel info contains per kernel info including symbol string and CUfunction attributes to set # It's valid until the compilation is done. # {symbol_string: {CUfunction_attribute: value}} - self.kernel_info = OrderedDict() + self.kernel_info: OrderedDict[str, Any] = OrderedDict() # used to generate unique name for gpu.launch - self.launch_inner_count = 0 + self.launch_inner_count: int = 0 # initialize default compile options - self.compile_options = CompileOptions() + self.compile_options: CompileOptions = CompileOptions() if preprocess: - self.preprocessor = DSLPreprocessor(dsl_package_name) + self.preprocessor: DSLPreprocessor = DSLPreprocessor(dsl_package_name) log().info(f"Initializing {name} DSL") log().debug(f"Logger initialized for {self.name}") if self.envar.jit_time_profiling: - self.profiler = timer(enable=True) - self.cache_hits = 0 - self.cache_misses = 0 + self.profiler: Any = timer(enable=True) + self.cache_hits: int = 0 + self.cache_misses: int = 0 # Hook excepthook if self.envar.filter_stacktrace: origin_excepthook = sys.excepthook module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__))) - def excepthook(excep_type, value, traceback): - filter_exception(value, module_dir) + def excepthook( + excep_type: type, value: BaseException, traceback: Any + ) -> None: + filter_exception(value, module_dir) # type: ignore[arg-type] if hasattr(value, "__traceback__"): origin_excepthook(excep_type, value, value.__traceback__) else: origin_excepthook( - excep_type, value, filter_stackframe(traceback, module_dir) + excep_type, + value, + filter_stackframe(traceback, module_dir), # type: ignore[arg-type] ) sys.excepthook = excepthook # Restore original excepthook - def restore_excepthook(hook): + def restore_excepthook(hook: Any) -> None: sys.excepthook = hook atexit.register(restore_excepthook, origin_excepthook) @lru_cache(maxsize=1) - def print_warning_once(self, message): + def print_warning_once(self, message: str) -> None: log().warning(f"Warning: {message}") warnings.warn(message, UserWarning) - def print_warning(self, message): + def print_warning(self, message: str) -> None: log().warning(f"Warning: {message}") warnings.warn(message, UserWarning) @classmethod - def _get_dsl(cls): - # Instantiate the DSL Class once - main_dsl = cls() + def _get_dsl(cls) -> "BaseDSL": + # Instantiate the DSL Class once (singleton metaclass returns existing instance) + main_dsl = cls() # type: ignore[call-arg] return main_dsl @staticmethod - def _can_preprocess(**dkwargs): + def _can_preprocess(**dkwargs: Any) -> bool: """ Check if AST transformation is enabled or not for `jit` and `kernel` decorators. """ return dkwargs.pop("preprocess", True) @staticmethod - def _lazy_initialize_dsl(func): + def _lazy_initialize_dsl(func: Any) -> None: """ Lazy initialization of DSL object if has not been initialized """ @@ -402,7 +740,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): delattr(func, "_dsl_cls") @staticmethod - def _preprocess_and_replace_code(func): + def _preprocess_and_replace_code(func: Any) -> None: """ Run ast transformation and return the materialized function pointer """ @@ -429,13 +767,19 @@ class BaseDSL(metaclass=DSLSingletonMeta): ) @staticmethod - def jit_runner(cls, executor_name, frame, *dargs, **dkwargs): + def jit_runner( + cls: type["BaseDSL"], + executor_name: str, + frame: Any, + *dargs: Any, + **dkwargs: Any, + ) -> Any: """ Decorator to mark a function for JIT compilation. """ log().info("jit_runner") - def jit_runner_decorator(func): + def jit_runner_decorator(func: Any) -> Any: # Run preprocessor that alters AST func._dsl_cls = cls func._decorator_location = BaseDSL.get_location_from_frame(frame) @@ -445,7 +789,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): func._preprocessed = True @wraps(func) - def jit_wrapper(*args, **kwargs): + def jit_wrapper(*args: Any, **kwargs: Any) -> Any: BaseDSL._preprocess_and_replace_code(func) custom_name = getattr(jit_wrapper, "_name_prefix", None) @@ -458,10 +802,10 @@ class BaseDSL(metaclass=DSLSingletonMeta): func, *args, **kwargs ) - def set_name_prefix(name: str): - jit_wrapper._name_prefix = name + def set_name_prefix(name: str) -> None: + jit_wrapper._name_prefix = name # type: ignore[attr-defined] - jit_wrapper.set_name_prefix = set_name_prefix + jit_wrapper.set_name_prefix = set_name_prefix # type: ignore[attr-defined] return jit_wrapper @@ -471,30 +815,30 @@ class BaseDSL(metaclass=DSLSingletonMeta): return jit_runner_decorator @classmethod - def jit(cls, *dargs, **dkwargs): + def jit(cls, *dargs: Any, **dkwargs: Any) -> Any: """ Decorator to mark a function for JIT compilation for Host code. """ - frame = inspect.currentframe().f_back + frame = inspect.currentframe().f_back # type: ignore[union-attr] return BaseDSL.jit_runner(cls, "_func", frame, *dargs, **dkwargs) @classmethod - def kernel(cls, *dargs, **dkwargs): + def kernel(cls, *dargs: Any, **dkwargs: Any) -> Any: """ Decorator to mark a function for JIT compilation for GPU. """ - frame = inspect.currentframe().f_back + frame = inspect.currentframe().f_back # type: ignore[union-attr] return BaseDSL.jit_runner(cls, "_kernel_helper", frame, *dargs, **dkwargs) @abstractmethod - def _kernel_helper(self, func, *args, **kwargs): + def _kernel_helper(self, func: Any, *args: Any, **kwargs: Any) -> Any: """ Helper function to handle kernel generation logic """ pass @abstractmethod - def _build_gpu_module(self, attrs, loc=None): + def _build_gpu_module(self, attrs: dict[str, Any], loc: Any = None) -> None: """ Build the module op that contains the kernels. """ @@ -502,7 +846,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): pass @abstractmethod - def _get_pipeline(self, pipeline): + def _get_pipeline(self, pipeline: str | None) -> str | None: """ Get the pipeline from the other configuration options. """ @@ -511,7 +855,9 @@ class BaseDSL(metaclass=DSLSingletonMeta): return None @staticmethod - def log_additions(func_type, operands=None, types=None, arg_attrs=None): + def log_additions( + func_type: Any, operands: Any = None, types: Any = None, arg_attrs: Any = None + ) -> None: if operands is not None and operands != []: log().debug( f"Added {func_type} operands: [%s]", ", ".join(map(str, operands)) @@ -525,12 +871,16 @@ class BaseDSL(metaclass=DSLSingletonMeta): f"Added {func_type} arg_attrs: [%s]", ", ".join(map(str, arg_attrs)) ) - def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): + def mangle_name( + self, function_name: str, args: tuple[Any, ...], sig: inspect.Signature + ) -> str: """Does simple name mangling""" - for spec_arg, arg in zip(args_spec.args, args): - spec_ty = args_spec.annotations.get(spec_arg, None) - if spec_ty != None: + # sig.parameters maybe longer than args, but since canonicalized_args + # only contains positional arguments, we can rely on zip to truncate + for param, arg in zip(sig.parameters.values(), args): + spec_ty = param.annotation + if spec_ty != inspect.Parameter.empty: if issubclass(type(spec_ty), (t.IRValue, t.IRVariadic)): continue if isinstance(spec_ty, (ir.Type, ir.Value)): @@ -541,7 +891,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): continue if self._is_tensor_descriptor(arg): continue - if inspect.isclass(spec_ty): + if spec_ty != inspect.Parameter.empty and inspect.isclass(spec_ty): class_name = str(arg).replace("class", "") class_name = class_name.replace(" ", "") function_name = f"{function_name}_{class_name}" @@ -565,8 +915,14 @@ class BaseDSL(metaclass=DSLSingletonMeta): return function_name def _generate_execution_arguments_for_known_types( - self, arg, arg_spec, arg_name, i, fop_args, iv_block_args - ): + self, + arg: Any, + arg_spec: Any, + arg_name: str, + i: int, + fop_args: list[Any], + iv_block_args: int, + ) -> tuple[list[Any], int]: """ Generate MLIR arguments for known types. @@ -581,69 +937,68 @@ class BaseDSL(metaclass=DSLSingletonMeta): def generate_execution_arguments( self, - args, - kwargs, - fop, - args_spec: inspect.FullArgSpec, - ): + args: tuple[Any, ...], + kwonlyargs: dict[str, Any], + fop: Any, + sig: inspect.Signature, + ) -> tuple[list[Any], dict[str, Any]]: """Create list of arguments that will be passed to MLIR's func.func op""" - def gen_exec_args(input_args, arg_names, annotations, fop_args): - assert len(input_args) == len(arg_names) + def gen_exec_arg( + idx: int, + arg: Any, + parameter: inspect.Parameter, + fop_args: list[Any], + iv_block_args: int, + ) -> tuple[Any, int]: + arg_name = parameter.name + arg_spec = parameter.annotation + log().debug("Processing [%d] Argument [%s : %s]", idx, arg_name, arg_spec) - ir_args = [] - iv_block_args = 0 - for i, arg in enumerate(input_args): - arg_name = arg_names[i] - arg_spec = annotations.get(arg_name, None) - log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec) + # Implicit cast to NumericMeta + if isinstance(arg_spec, t.NumericMeta) and not isinstance(arg, arg_spec): + arg = t.cast(arg, arg_spec) # type: ignore[arg-type] - # Implicit cast to NumericMeta - if isinstance(arg_spec, t.NumericMeta) and not isinstance( - arg, arg_spec - ): - arg = t.cast(arg, arg_spec) + ir_arg, iv_block_args = self._generate_execution_arguments_for_known_types( + arg, arg_spec, arg_name, idx, fop_args, iv_block_args + ) - ir_arg, iv_block_args = ( - self._generate_execution_arguments_for_known_types( - arg, arg_spec, arg_name, i, fop_args, iv_block_args - ) - ) + if not ir_arg: + # If it's not a known type, try JIT argument adapter + # to convert the argument if possible + adapter = JitArgAdapterRegistry.get_registered_adapter(arg) + arg = adapter(arg) if adapter else arg - if not ir_arg: - # If it's not a known type, try JIT argument adapter - # to convert the argument if possible - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) - arg = adapter(arg) if adapter else arg + n_args = len(get_mlir_types(arg)) + blk_args = fop_args[iv_block_args : iv_block_args + n_args] + ir_arg = new_from_mlir_values(arg, blk_args) + iv_block_args += n_args + else: + ir_arg = ir_arg[0] - n_args = len(get_mlir_types(arg)) - blk_args = fop_args[iv_block_args : iv_block_args + n_args] - ir_arg.append(new_from_mlir_values(arg, blk_args)) - iv_block_args += n_args - - self.log_additions(ir_arg) - ir_args.extend(ir_arg) - - return ir_args, iv_block_args + self.log_additions(ir_arg) + return ir_arg, iv_block_args fop_args = list(fop.regions[0].blocks[0].arguments) - ir_args, iv_block_args = gen_exec_args( - args, args_spec.args, args_spec.annotations, fop_args - ) - ir_kwargs, _ = gen_exec_args( - [kwargs[arg] for arg in args_spec.kwonlyargs], - args_spec.kwonlyargs, - args_spec.annotations, - fop_args[iv_block_args:], - ) - ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)} + ir_args = [] + ir_kwargs = {} + iv_block_args = 0 + for i, (arg, param) in enumerate(zip(args, sig.parameters.values())): + ir_arg, iv_block_args = gen_exec_arg(i, arg, param, fop_args, iv_block_args) + ir_args.append(ir_arg) + + for i, (name, arg) in enumerate(kwonlyargs.items()): + ir_arg, iv_block_args = gen_exec_arg( + i, arg, sig.parameters[name], fop_args, iv_block_args + ) + ir_kwargs[name] = ir_arg log().debug("execution args: %s", ", ".join(map(str, ir_args))) log().debug("execution kwargs: %s", ", ".join(map(str, ir_kwargs))) return ir_args, ir_kwargs @abstractmethod - def _generate_mlir_type_for_tensor_descriptor(self, tensor): + def _generate_mlir_type_for_tensor_descriptor(self, tensor: Any) -> Any: """ Generate MLIR type for the tensor descriptor. """ @@ -651,24 +1006,52 @@ class BaseDSL(metaclass=DSLSingletonMeta): @abstractmethod def _generate_executable_arg_for_tensor_descriptor( - self, mlir_value=None, ptr_tensor_ty=None, tensor=None - ): + self, mlir_value: Any = None, ptr_tensor_ty: Any = None, tensor: Any = None + ) -> Any: """ Generates executable value for the given tensor descriptor. """ pass @abstractmethod - def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: + def _is_tensor_descriptor(self, maybe_tensor_descriptor: object) -> bool: pass @abstractmethod def _handle_tensor_descriptor( - self, maybe_tensor, arg_name: str, need_gpu_memory: bool + self, maybe_tensor: Any, arg_name: str, need_gpu_memory: bool ) -> Any: pass - def _validate_arg(self, arg, arg_index, arg_name, arg_spec): + def _should_remove_empty_gpu_modules(self) -> bool: + """ + Returns whether empty gpu.module instances should be removed from + the final generated IR. + """ + return True + + @staticmethod + def __remove_empty_gpu_modules(module: ir.Module) -> None: + """ + Removes empty gpu.module instances from the given module. + """ + + def delete_empty_gpu_module_op(op: Any) -> ir.WalkResult: + if op.name != "gpu.module": + return ir.WalkResult.ADVANCE + if ( + len(op.regions) == 0 + or len(op.regions[0].blocks) == 0 + or len(op.regions[0].blocks[0].operations) == 0 + ): + op.erase() + return ir.WalkResult.ADVANCE + + module.operation.walk(delete_empty_gpu_module_op) + + def _validate_arg( + self, arg: Any, arg_index: int, arg_name: str, arg_spec: Any + ) -> Any: """ Validates if the arg is really of the annotated type for type safety. @@ -679,14 +1062,14 @@ class BaseDSL(metaclass=DSLSingletonMeta): def _generate_jit_func_args_for_known_types( self, - func, - arg, - arg_name, - arg_spec, - arg_index, + func: Any, + arg: Any, + arg_name: str, + arg_spec: Any, + arg_index: int, *, - is_host=True, - ): + is_host: bool = True, + ) -> tuple[list[Any] | None, list[Any] | None, list[Any] | None]: """ Generate JIT function arguments for known types. @@ -694,7 +1077,9 @@ class BaseDSL(metaclass=DSLSingletonMeta): natively supported by the Base DSL. """ - jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] + jit_arg_type: list[Any] | None = [] + jit_arg_attr: list[Any] | None = [] + jit_exec_arg: list[Any] | None = [] if is_argument_constexpr(arg, arg_spec, arg_name, arg_index, func): jit_exec_arg = jit_arg_type = jit_arg_attr = None @@ -703,40 +1088,61 @@ class BaseDSL(metaclass=DSLSingletonMeta): def _generate_jit_func_args( self, - func, - function_name, - args, - kwargs, - args_spec: inspect.FullArgSpec, + func: Any, + function_name: str, + args: tuple[Any, ...] | list[Any], + kwonlyargs: dict[str, Any], + sig: inspect.Signature, *, - is_host=True, - compile_only=False, - ): + is_host: bool = True, + compile_only: bool = False, + ) -> tuple[list[Any], list[Any], list[Any], list[Any]]: """Generate JIT function arguments.""" - assert len(args) == len(args_spec.args) and len(kwargs) == len( - args_spec.kwonlyargs + positional_names = [] + kwonly_names = [] + for name, param in sig.parameters.items(): + if param.kind == inspect.Parameter.KEYWORD_ONLY: + kwonly_names.append(name) + else: + positional_names.append(name) + + assert len(args) == len(positional_names) and len(kwonlyargs) == len( + kwonly_names ), ( - f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args " - f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}" + f"Input args {len(args)=} and kwonlyargs {len(kwonlyargs)=} must match positional params " + f"{len(positional_names)=} and keyword-only params {len(kwonly_names)=}" ) - jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], [] + jit_arg_types: list[Any] = [] + jit_arg_attrs: list[Any] = [] + jit_exec_args: list[Any] = [] jit_adapted_args = [] default_attr = ir.DictAttr.get({}) - input_args = [*args, *kwargs.values()] - input_arg_names = [*args_spec.args, *args_spec.kwonlyargs] + input_args = [*args, *kwonlyargs.values()] + input_arg_names = [*positional_names, *kwonly_names] for i, (arg_name, arg) in enumerate(zip(input_arg_names, input_args)): - spec_ty = args_spec.annotations.get(arg_name, None) + spec_ty = sig.parameters[arg_name].annotation + + # Unwrap Annotated[T, marker1, ...] → base type T + markers. + annotation_markers = () + if ( + spec_ty is not inspect.Parameter.empty + and get_origin(spec_ty) is Annotated + ): + type_args = get_args(spec_ty) + spec_ty = type_args[0] + annotation_markers = type_args[1:] + log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty) # Implicitly convert into Numeric type if possible if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty): - arg = t.cast(arg, spec_ty) + arg = t.cast(arg, spec_ty) # type: ignore[arg-type] # Type safety check - if spec_ty is not None: + if spec_ty is not inspect.Parameter.empty: err = self._validate_arg(arg, i, arg_name, spec_ty) if err is not None: raise err @@ -755,32 +1161,28 @@ class BaseDSL(metaclass=DSLSingletonMeta): if jit_arg_type is not None and len(jit_arg_type) == 0: # If not any known type, try JIT argument adapter # to convert the argument - adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + adapter = JitArgAdapterRegistry.get_registered_adapter(arg) if adapter: arg = adapter(arg) jit_adapted_args.append(arg) if is_host: if self.envar.enable_tvm_ffi: - jit_exec_arg.extend([arg]) + jit_exec_arg.extend([arg]) # type: ignore[union-attr] else: - jit_exec_arg.extend(get_c_pointers(arg)) + jit_exec_arg.extend(get_c_pointers(arg)) # type: ignore[union-attr] jit_arg_type.extend(get_mlir_types(arg)) - jit_arg_attr.extend([default_attr] * len(get_mlir_types(arg))) + jit_arg_attr.extend([default_attr] * len(get_mlir_types(arg))) # type: ignore[union-attr] else: dyn_vals = extract_mlir_values(arg) - jit_exec_arg.extend(dyn_vals) + jit_exec_arg.extend(dyn_vals) # type: ignore[union-attr] jit_arg_type.extend([v.type for v in dyn_vals]) - jit_arg_attr.extend(extract_mlir_attributes(arg)) + jit_arg_attr.extend(extract_mlir_attributes(arg)) # type: ignore[union-attr] if not jit_arg_type or not jit_exec_arg: # when it is compile only, we don't have to prepare the executable arguments. - if ( - is_host and (compile_only or hasattr(arg, "__c_pointers__")) - ) or ( - not is_host - and hasattr(arg, "__extract_mlir_values__") - and hasattr(arg, "__new_from_mlir_values__") + if (is_host and (compile_only or implements_jit_argument(arg))) or ( + not is_host and implements_dynamic_expression(arg) ): pass else: @@ -788,8 +1190,8 @@ class BaseDSL(metaclass=DSLSingletonMeta): f"failed to generate argument #{i + 1} ({arg_name}) for JIT function '{function_name}'.", context={ f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.", - f"Call-site argument value": arg, - f"Call-site argument type": type(arg), + "Call-site argument value": arg, + "Call-site argument type": type(arg), }, suggestion=f"Consider annotating the argument with `{arg_name} : Constexpr` " "if it's a value known at compile-time. " @@ -799,29 +1201,29 @@ class BaseDSL(metaclass=DSLSingletonMeta): ) if jit_arg_type is not None: - jit_exec_args.extend(jit_exec_arg) + jit_exec_args.extend(jit_exec_arg) # type: ignore[arg-type] jit_arg_types.extend(jit_arg_type) - jit_arg_attrs.extend(jit_arg_attr) + jit_arg_attrs.extend(jit_arg_attr) # type: ignore[arg-type] return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args def generate_mlir_function_types( self, - func, - function_name, - input_args, - kwargs, - args_spec: inspect.FullArgSpec, - compile_only=False, - ): + func: Any, + function_name: str, + args: tuple[Any, ...] | list[Any], + kwonlyargs: dict[str, Any], + sig: inspect.Signature, + compile_only: bool = False, + ) -> tuple[list[Any], list[Any], list[Any]]: """Convert input arguments to MLIR function signature also convert numpy arrays to memref.""" exe_args, types, attrs, adapted_args = self._generate_jit_func_args( func, function_name, - input_args, - kwargs, - args_spec, + args, + kwonlyargs, + sig, is_host=True, compile_only=compile_only, ) @@ -837,13 +1239,13 @@ class BaseDSL(metaclass=DSLSingletonMeta): @dataclass class LaunchConfig: - cluster: list = None - fallback_cluster: list = None - grid: list = field(default_factory=lambda: [1, 1, 1]) - block: list = field(default_factory=lambda: [1, 1, 1]) - max_number_threads: list = field(default_factory=lambda: [0, 0, 0]) - smem: int = None - async_deps: list = field(default_factory=list) + cluster: list[Any] | None = None + fallback_cluster: list[Any] | None = None + grid: list[Any] = field(default_factory=lambda: [1, 1, 1]) + block: list[Any] = field(default_factory=lambda: [1, 1, 1]) + max_number_threads: list[Any] = field(default_factory=lambda: [0, 0, 0]) + smem: int | None = None + async_deps: list[Any] = field(default_factory=list) has_cluster: bool = False has_fallback_cluster: bool = False min_blocks_per_mp: int = 0 @@ -852,7 +1254,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): cooperative: bool = False @staticmethod - def _check_and_canonicalize_dim(dim, name): + def _check_and_canonicalize_dim(dim: Any, name: str) -> list[Any]: if not isinstance(dim, (list, tuple)): dim = [dim] @@ -869,7 +1271,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): # Pad with 1s to 3-dim vector for grid or block dimensions return list(dim) + [1] * (3 - len(dim)) - def __post_init__(self): + def __post_init__(self) -> None: self.grid = self._check_and_canonicalize_dim(self.grid, "grid") self.block = self._check_and_canonicalize_dim(self.block, "block") @@ -881,22 +1283,22 @@ class BaseDSL(metaclass=DSLSingletonMeta): if self.cluster is None: self.cluster = [None, None, None] elif len(self.cluster) != 3: - raise DSLRuntimeError(f"Expect 3d cluster!") + raise DSLRuntimeError("Expect 3d cluster!") self.has_fallback_cluster = self.fallback_cluster is not None if self.fallback_cluster is None: self.fallback_cluster = [None, None, None] elif len(self.fallback_cluster) != 3: - raise DSLRuntimeError(f"Expect 3d fallback_cluster!") + raise DSLRuntimeError("Expect 3d fallback_cluster!") - def has_max_number_threads(self): + def has_max_number_threads(self) -> bool: """Check if max_number_threads is given by user""" return all( value == 0 if not is_dynamic_expression(value) else False for value in self.max_number_threads ) - def diagnostic(self): + def diagnostic(self) -> None: """Check command line parameters and enables diagnostic""" # Check command line arguments "-diagnostic" parser = argparse.ArgumentParser(description="Process diagnostic status.") @@ -911,7 +1313,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): args, _ = parser.parse_known_args() ctx = ir.Context.current - def callback(d): + def callback(d: Any) -> None: print(f" [{self.name} Diagnostic] : {d.message}") ctx.attach_diagnostic_handler(callback) @@ -929,15 +1331,15 @@ class BaseDSL(metaclass=DSLSingletonMeta): ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}") @staticmethod - def get_location_from_frame(frame): + def get_location_from_frame(frame: Any) -> DSLLocation: return DSLLocation( - filename=inspect.getsourcefile(frame), + filename=inspect.getsourcefile(frame), # type: ignore[arg-type] lineno=frame.f_lineno, col_offset=0, function_name=frame.f_code.co_name, ) - def get_ir_location(self, location: DSLLocation = None): + def get_ir_location(self, location: DSLLocation | None = None) -> Any: """ Get python location information and generate MLIR location """ @@ -959,7 +1361,13 @@ class BaseDSL(metaclass=DSLSingletonMeta): ) return loc - def compile_and_jit(self, module, pipeline, shared_libs, function_name=""): + def compile_and_jit( + self, + module: ir.Module, + pipeline: str, + shared_libs: list[str], + function_name: str = "", + ) -> Any: """ Compile and JIT an MLIR module. """ @@ -973,12 +1381,13 @@ class BaseDSL(metaclass=DSLSingletonMeta): sys.stdout = redirect_stdout = io.StringIO() try: + enable_debug_info = self.envar.lineinfo kernel = self.compiler_provider.compile_and_jit( module, pipeline, shared_libs=shared_libs, - cuda_toolkit=self.envar.cuda_toolkit, arch=self.envar.arch, + enable_debug_info=enable_debug_info, ) finally: @@ -997,7 +1406,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): finally: pass - def preprocess_pipeline(self, pipeline, arch) -> str: + def preprocess_pipeline(self, pipeline: str, arch: str) -> str: options = { self.pass_sm_arch_name: arch, } @@ -1030,17 +1439,32 @@ class BaseDSL(metaclass=DSLSingletonMeta): ) shared_libs.append(lib) else: - self.print_warning(f"{self.name}_LIBS environment variable is not set") + if is_cutlass_family_dsl_prefix(self.name): + self.print_warning( + f"{self.name}_LIBS environment variable is not set and " + "CuTe-family auto-discovery failed. Set " + f"{self.name}_LIBS to the path of libcute_dsl_runtime.so " + "(e.g. /lib/libcute_dsl_runtime.so). For pip " + "editable installs, setting CUTE_DSL_LIBS or ensuring " + "PYTHONPATH includes /cutlass_ir/python_packages " + "also works." + ) + else: + self.print_warning( + f"{self.name}_LIBS environment variable is not set and " + f"auto-discovery failed. Set {self.name}_LIBS explicitly " + "for this DSL runtime." + ) return shared_libs @lru_cache(maxsize=1) - def get_version(self): + def get_version(self) -> "hashlib._Hash": version_hash = hashlib.sha256() return version_hash - def get_module_hash(self, module, function_name): + def get_module_hash(self, module: ir.Module, function_name: str) -> str: s = io.BytesIO() module.operation.write_bytecode(s) for attr, value in self.envar.__dict__.items(): @@ -1048,9 +1472,9 @@ class BaseDSL(metaclass=DSLSingletonMeta): s.write(str(value).encode()) # Add compile options to the hash s.write(self.compile_options.to_str().encode()) - module_hash = self.get_version().copy() - module_hash.update(s.getvalue()) - module_hash = module_hash.hexdigest() + hash_obj = self.get_version().copy() + hash_obj.update(s.getvalue()) + module_hash = hash_obj.hexdigest() log().debug("Bytecode=[%s]", s.getvalue().hex()) log().debug("Version=[%s]", self.get_version().hexdigest()) @@ -1059,18 +1483,34 @@ class BaseDSL(metaclass=DSLSingletonMeta): ) return module_hash - def build_module(self, module, function_name: str): + def build_module(self, module: ir.Module, function_name: str) -> ir.Module: """ Build the MLIR module, verify and return the module """ - # Save IR in a file + # Save IR in a file (raw, before any passes) — triggered by KEEP=ir-debug if self.envar.keep_ir: self.dump_mlir_path = save_ir( self.name, module, function_name, output_dir=self.envar.dump_dir, + enable_debug_info=self.envar.lineinfo, + ) + + # Save clean IR (after canonicalize+cse) — triggered by KEEP=ir + # Clone before compiling so the original module is not mutated. + if self.envar.keep_ir_clean: + module_clone = ir.Module.parse(str(module)) + self.compiler_provider.compile( + module_clone, "builtin.module(canonicalize,cse)" + ) + self.dump_mlir_path = save_ir( + self.name, + module_clone, + f"{function_name}_clean", + output_dir=self.envar.dump_dir, + enable_debug_info=self.envar.lineinfo, ) if self.envar.print_ir: @@ -1087,13 +1527,13 @@ class BaseDSL(metaclass=DSLSingletonMeta): return module - def get_return_types(self) -> List[ir.Type]: + def get_return_types(self) -> list[Any]: """ Get the return types of the host function. """ return [] - def generate_default_return_values(self, ip=None) -> List[ir.Value]: + def generate_default_return_values(self, ip: Any = None) -> list[Any]: """ Generate the default return values of the host function. """ @@ -1101,25 +1541,25 @@ class BaseDSL(metaclass=DSLSingletonMeta): def generate_original_ir( self, - ir, - func, - funcBody, - kwargs, - function_name, - func_types, - gpu_module_attrs, - args, - args_spec, - location=None, - ): - def build_ir_module(): + ir: Any, + func: Any, + funcBody: Callable[..., Any], + function_name: str, + func_types: list[Any], + gpu_module_attrs: dict[str, Any], + args: tuple[Any, ...], + kwonlyargs: dict[str, Any], + sig: inspect.Signature, + location: DSLLocation | None = None, + ) -> tuple[ir.Module, str, Any]: + def build_ir_module() -> tuple[ir.Module, Any]: loc = self.get_ir_location(location) module = ir.Module.create(loc=loc) unit_attr = ir.UnitAttr.get() module.operation.attributes["gpu.container_module"] = unit_attr with ir.InsertionPoint(module.body): - # Always generate gpu module. It's canonicalized by the compiler when it's not used. + # Always generate gpu module. We will remove it later if it is empty. self._build_gpu_module(gpu_module_attrs, loc=loc) ret_types = self.get_return_types() @@ -1131,7 +1571,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): entry_block = fop.add_entry_block(arg_locs=arg_locs) with ir.InsertionPoint(entry_block): ir_args, ir_kwargs = self.generate_execution_arguments( - args, kwargs, fop, args_spec + args, kwonlyargs, fop, sig ) # Call user function body try: @@ -1149,6 +1589,10 @@ class BaseDSL(metaclass=DSLSingletonMeta): except DSLRuntimeError as dsl_error: # Throw it's already a DSL error raise dsl_error + + if self._should_remove_empty_gpu_modules(): + BaseDSL.__remove_empty_gpu_modules(module) + return module, result # Build IR module @@ -1164,22 +1608,22 @@ class BaseDSL(metaclass=DSLSingletonMeta): def compile_and_cache( self, - module, - module_hash, - function_name, - pipeline, - args_spec, - no_cache, - no_jit_engine, - func_type=JitCompiledFunction, + module: ir.Module, + module_hash: str, + function_name: str, + pipeline: str | None, + sig: inspect.Signature, + no_cache: bool, + no_jit_engine: bool, + func_type: type[JitCompiledFunction] = JitCompiledFunction, *, - full_args=None, - full_kwargs=None, - dynamic_args=None, - dynamic_kwargs=None, - original_function_name=None, - funcBody=None, - ): + full_args: Any = None, + full_kwargs: Any = None, + dynamic_args: Any = None, + dynamic_kwargs: Any = None, + original_function_name: str | None = None, + funcBody: Callable[..., Any] | None = None, + ) -> JitCompiledFunction: # If `gpu-arch` is set by compile_options, use it. Otherwise, use the arch from the environment variable. compile_gpu_arch = ( self.envar.arch @@ -1192,7 +1636,8 @@ class BaseDSL(metaclass=DSLSingletonMeta): gen_jit_engine = False # Preprocess the pipeline. pipeline = self.preprocess_pipeline( - self._get_pipeline(pipeline), compile_gpu_arch + self._get_pipeline(pipeline), # type: ignore[arg-type] + compile_gpu_arch, # type: ignore[arg-type] ) shared_libs = self.get_shared_libs() # try load the file cache @@ -1265,14 +1710,19 @@ class BaseDSL(metaclass=DSLSingletonMeta): module, engine, capi_func, - args_spec, + sig, function_name, self.kernel_info, jit_time_profiling=self.envar.jit_time_profiling, + has_gpu_module=self.num_kernels > 0, jit_function_artifacts=JitFunctionArtifacts( PTX=self.compile_options.full_ptx_path, CUBIN=self.compile_options.full_cubin_path, - MLIR=(self.dump_mlir_path if self.envar.keep_ir else None), + MLIR=( + str(self.dump_mlir_path) + if (self.envar.keep_ir or self.envar.keep_ir_clean) + else None + ), ), # set dynamic arguments if the jit_function is a JitCompiledFunction for AOT generation. dynamic_args=dynamic_args, @@ -1295,10 +1745,10 @@ class BaseDSL(metaclass=DSLSingletonMeta): return fn - def post_compilation_cleanup(self): + def post_compilation_cleanup(self) -> None: """Clean up some internal state after one compilation is completed.""" # clear the kernel info after the compilation is done. - self.kernel_info = {} + self.kernel_info = OrderedDict() self.launch_inner_count = 0 # reset num_kernels to 0 for next compilation. self.num_kernels = 0 @@ -1307,52 +1757,80 @@ class BaseDSL(metaclass=DSLSingletonMeta): # reset decorator location after the compilation is done. self.decorator_location = None - def extract_dynamic_args(self, funcBody, args, kwargs, args_spec): + def extract_dynamic_args( + self, + funcBody: Callable[..., Any], + args: tuple[Any, ...], + kwonlyargs: dict[str, Any], + sig: inspect.Signature, + ) -> tuple[list[Any], OrderedDict[str, Any]]: """This function is used to extract the original dynamic arguments for AOT C header generation. The dynamic argument is the argument which is not marked as `Constexpr` in the function signature. """ dynamic_args = [] - dynamic_kwargs = OrderedDict() - for i, arg in enumerate(args): - if not is_arg_spec_constexpr( - args_spec.annotations.get(args_spec.args[i], None), - args_spec.args[i], + dynamic_kwonlyargs = OrderedDict() + for i, (arg, param) in enumerate(zip(args, sig.parameters.values())): + arg_name = param.name + arg_annotation = param.annotation + if not is_arg_annotation_constexpr( + arg_annotation, + arg_name, i, funcBody, ): dynamic_args.append(arg) - for i, (k, v) in enumerate(kwargs.items()): - if not is_arg_spec_constexpr(args_spec.kwonlyargs[i], k, i, funcBody): - dynamic_kwargs[k] = v - return dynamic_args, dynamic_kwargs + for i, (arg_name, v) in enumerate(kwonlyargs.items()): + arg_annotation = sig.parameters[arg_name].annotation + if not is_arg_annotation_constexpr( + arg_annotation, + arg_name, + i + len(args), + funcBody, + ): + dynamic_kwonlyargs[arg_name] = v + + return dynamic_args, dynamic_kwonlyargs def generate_mlir( self, - funcBody, - kwargs, - function_name, - gpu_module_attrs, - args, - args_spec, - pipeline, - no_cache, - no_jit_engine, - compile_only, - location=None, - ): + funcBody: Callable[..., Any], + function_name: str, + gpu_module_attrs: dict[str, Any], + args: tuple[Any, ...], + kwonlyargs: dict[str, Any], + sig: inspect.Signature, + pipeline: str | None, + no_cache: bool, + no_jit_engine: bool, + compile_only: bool, + location: DSLLocation | None = None, + ) -> Any: """Generate MLIR module and compile iself.T_provider.""" with ir.Context() as ctx, self.get_ir_location(location): # If threading is enabled, each MLIR context will keep alive a thread pool. # When we cache MLIR compilation results, we also cache its context thus accumulating #(compilations) * thread_pool_size threads. # Disable threading to avoid such excessive number of threads. ctx.enable_multithreading(False) + # Optional: capture full Python call stacks on every MLIR op. + # Enable via CUTE_DSL_LOC_TRACEBACKS=N (e.g. 128 for full stacks). + # Default OFF — deep tracebacks + LINEINFO causes segfault. + _loc_tb_depth = self.envar.loc_tracebacks + _loc_tb_ctx = None + if _loc_tb_depth: + try: + _depth = int(_loc_tb_depth) + _loc_tb_ctx = ir.loc_tracebacks(max_depth=_depth) + _loc_tb_ctx.__enter__() + except (ValueError, TypeError, AttributeError): + pass + try: # Convert input arguments to MLIR arguments exe_args, func_types, adapted_args = self.generate_mlir_function_types( - funcBody, function_name, args, kwargs, args_spec, compile_only + funcBody, function_name, args, kwonlyargs, sig, compile_only ) dynamic_args, dynamic_kwargs = self.extract_dynamic_args( - funcBody, args, kwargs, args_spec + funcBody, args, kwonlyargs, sig ) original_function_name = funcBody.__name__ @@ -1361,15 +1839,39 @@ class BaseDSL(metaclass=DSLSingletonMeta): ir, func, funcBody, - kwargs, function_name, func_types, gpu_module_attrs, args, - args_spec, + kwonlyargs, + sig, location=location, ) + # add ffi bitcode sources to link options + for gpu_module in module.body.operations: + if gpu_module.name != "gpu.module": + continue + if gpu_module is not None: + link_libraries = self.compile_options.options[ + LinkLibraries + ].value + try: + link_libraries_attributes = gpu_module.attributes[ + "link-libraries" + ] + except KeyError: + link_libraries_attributes = set() + sources = set(x.value for x in link_libraries_attributes) + link_libraries = ( + link_libraries + + ("," if len(link_libraries) > 0 else "") + + ",".join(sources) + ) + self.compile_options.options[LinkLibraries] = LinkLibraries( + link_libraries + ) + # dryrun is used to only generate IR if self.envar.dryrun: return result @@ -1388,11 +1890,11 @@ class BaseDSL(metaclass=DSLSingletonMeta): module_hash, function_name, pipeline, - args_spec, + sig, no_cache, no_jit_engine, full_args=args, - full_kwargs=kwargs, + full_kwargs=kwonlyargs, dynamic_args=dynamic_args, dynamic_kwargs=dynamic_kwargs, original_function_name=original_function_name, @@ -1408,6 +1910,11 @@ class BaseDSL(metaclass=DSLSingletonMeta): jit_function = cached_jit_func finally: + if _loc_tb_ctx is not None: + try: + _loc_tb_ctx.__exit__(None, None, None) + except Exception: + pass self.post_compilation_cleanup() # If compile_only is set, bypass execution return the jit_executor directly @@ -1419,13 +1926,42 @@ class BaseDSL(metaclass=DSLSingletonMeta): return result - def run_preprocessor(self, original_function): + @staticmethod + def _inject_closure_cells( + original_function: Any, exec_globals: dict[str, Any] + ) -> None: + """Inject closure cell values into *exec_globals*. + + When a decorated function captures variables from an enclosing scope, + those names are absent from ``__globals__``. The AST preprocessor + re-parses the source and ``exec()``s it, which requires those names + to be resolvable in *exec_globals*. + + This mirrors the injection already done by + ``function_compiler._rewrite_callee``. + """ + if original_function.__closure__: + for name, cell in zip( + original_function.__code__.co_freevars, + original_function.__closure__, + ): + try: + exec_globals[name] = cell.cell_contents + except ValueError: + # Cell may be empty if the variable was never assigned + # in the enclosing scope; safe to skip. + pass + + def run_preprocessor( + self, original_function: Any, callee_rewrite: bool = False + ) -> Any: function_name = original_function.__name__ self.funcBody = original_function log().info("Started preprocessing [%s]", function_name) exec_globals = {} if original_function.__globals__ is not None: exec_globals.update(original_function.__globals__) + self._inject_closure_cells(original_function, exec_globals) with self.preprocessor.get_session() as preprocessor_session: transformed_ast = preprocessor_session.transform( original_function, exec_globals @@ -1438,7 +1974,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): file_name = inspect.getsourcefile(original_function) code_object = compile( transformed_ast, - filename=file_name, + filename=file_name or "", mode="exec", ) @@ -1451,7 +1987,9 @@ class BaseDSL(metaclass=DSLSingletonMeta): exec_globals, ) - def _get_function_bound_args(self, sig, func_name, *args, **kwargs): + def _get_function_bound_args( + self, sig: inspect.Signature, func_name: str, *args: Any, **kwargs: Any + ) -> inspect.BoundArguments: """ Binds provided arguments to a function's signature and applies default values. @@ -1471,34 +2009,31 @@ class BaseDSL(metaclass=DSLSingletonMeta): ) return bound_args - def _canonicalize_args(self, sig, *args, **kwargs): + def _canonicalize_args( + self, bound_args: inspect.BoundArguments + ) -> tuple[tuple[Any, ...], dict[str, Any]]: """ Canonicalize the input arguments so that returned args only contain positional arguments and kwargs only contain keyword arguments. """ - function_name = self.funcBody.__name__ - bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) canonicalized_args = bound_args.args - canonicalized_kwargs = bound_args.kwargs - return canonicalized_args, canonicalized_kwargs - - def _check_arg_count(self, *args, **kwargs): - if not self.funcBody: - raise DSLRuntimeError("Function body is not set.") - - # Pass the actual function object to inspect.signature to get the signature. - sig = inspect.signature(self.funcBody) - - function_name = self.funcBody.__name__ - - bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) + canonicalized_kwonlyargs = bound_args.kwargs + return canonicalized_args, canonicalized_kwonlyargs + def _check_arg_count( + self, + sig: inspect.Signature, + bound_args: inspect.BoundArguments, + function_name: str, + ) -> bool: # Check if all non-default arguments are provided + has_varargs = False for param in sig.parameters.values(): if param.kind in ( inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD, ): + has_varargs = True continue if ( param.default is inspect.Parameter.empty @@ -1507,99 +2042,104 @@ class BaseDSL(metaclass=DSLSingletonMeta): raise DSLRuntimeError( f"Missing required argument in `{function_name}`: '{param.name}'" ) + return has_varargs - return sig - - def _get_full_arg_spec(self, funcBody): + def _get_signature(self, funcBody: Callable[..., Any]) -> inspect.Signature: """ - Returns the full argument specification for a given function, handling PEP-563 - (postponed evaluation of type annotations) if necessary. - - If the function's annotations are provided as strings (which occurs when PEP-563 - is enabled), this method evaluates those annotations so they are returned as objects - instead of strings. + Returns the signature for a given function, handling PEP-563 + (postponed evaluation of type annotations) via eval_str=True. Parameters ---------- funcBody : function - The function whose argument specification is to be retrieved. + The function whose signature is to be retrieved. Returns ------- - inspect.FullArgSpec - The complete argument specification of the function, with its annotations - properly evaluated and resolved where relevant. + inspect.Signature + The signature of the function with annotations properly resolved. """ - args_spec = inspect.getfullargspec(funcBody) - # Set `eval_str = True` to make it work when PEP-563 is enabled - if args_spec.annotations and all( - type(arg_type) is str for arg_type in args_spec.annotations.values() - ): - eval_annotations = inspect.get_annotations(funcBody, eval_str=True) - args_spec = inspect.FullArgSpec( - args_spec.args, - args_spec.varargs, - args_spec.varkw, - args_spec.defaults, - args_spec.kwonlyargs, - args_spec.kwonlydefaults, - eval_annotations, - ) - return args_spec + return inspect.signature(funcBody, eval_str=True) @staticmethod def _expand_varargs_varkw( canonicalized_args: tuple, canonicalized_kwargs: dict, - args_spec: inspect.FullArgSpec, - ) -> inspect.FullArgSpec: - """Expand *args and **kwargs into concrete named parameters in the FullArgSpec. - - When a JIT function uses *args or **kwargs, the concrete call-site values - are known. This method synthesizes named parameters for them so the rest - of the pipeline (which expects fixed-arity signatures) works unchanged. - - For *args: extra positional arguments beyond ``args_spec.args`` get - synthetic names ``_vararg_0``, ``_vararg_1``, etc. - - For **kwargs: extra keyword arguments beyond ``args_spec.kwonlyargs`` - are appended as keyword-only parameters. + signature: inspect.Signature, + ) -> inspect.Signature: """ - if not args_spec.varargs and not args_spec.varkw: - return args_spec + Expands ``*args`` and ``**kwargs`` into concrete named parameters in the function's signature. - expanded_args = list(args_spec.args) - expanded_annotations = dict(args_spec.annotations) - expanded_defaults = list(args_spec.defaults) if args_spec.defaults else [] + This is used when a JIT function employs ``*args`` or ``**kwargs``. At the call site, + concrete values are provided; thus, this method generates explicit parameter names + so downstream components expecting fixed-arity signatures can function as usual. - if args_spec.varargs: - n_regular = len(args_spec.args) - n_extra = len(canonicalized_args) - n_regular - for i in range(n_extra): - expanded_args.append(f"varargs_{i}") + For ``*args``: additional positional arguments beyond those specified in the original + parameters are given synthetic parameter names such as ``_vararg_0``, ``_vararg_1``, etc. - expanded_kwonlyargs = list(args_spec.kwonlyargs) - expanded_kwonlydefaults = ( - dict(args_spec.kwonlydefaults) if args_spec.kwonlydefaults else {} - ) + For ``**kwargs``: extra keyword arguments beyond those already declared as keyword-only + in the signature will be appended as new keyword-only parameters. - if args_spec.varkw: - existing_kwonly = set(args_spec.kwonlyargs) - for key in canonicalized_kwargs: - if key not in existing_kwonly: - expanded_kwonlyargs.append(key) + Parameters + ---------- + canonicalized_args : tuple + The tuple of canonicalized positional arguments as passed to the function. + canonicalized_kwargs : dict + The dictionary of canonicalized keyword arguments as passed to the function. + signature : inspect.Signature + The function's signature object. - return inspect.FullArgSpec( - args=expanded_args, - varargs=None, - varkw=None, - defaults=tuple(expanded_defaults) if expanded_defaults else None, - kwonlyargs=expanded_kwonlyargs, - kwonlydefaults=expanded_kwonlydefaults if expanded_kwonlydefaults else None, - annotations=expanded_annotations, - ) + Returns + ------- + inspect.Signature + A new signature with expanded ``*args`` and ``**kwargs`` into named parameters. - def _func(self, funcBody, *args, **kwargs): + Notes + ----- + - Positional parameters are expanded first, followed by ``*args``, then keyword-only or default parameters, + and finally expansion of ``**kwargs``. + - This allows downstream function argument inspection and manipulation to treat all arguments as if they + were declared explicitly by name. + """ + + # Signature order + # positional → *args → keyword-only/default → **kwargs + new_params = [] + visited_kwonly_args = set() + for idx, (name, param) in enumerate(signature.parameters.items()): + if param.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + new_params.append(param) + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + # Expand *args into concrete named parameters + for vararg_idx in range(idx, len(canonicalized_args)): + new_params.append( + inspect.Parameter( + name=f"_vararg_{vararg_idx}", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ) + elif param.kind == inspect.Parameter.KEYWORD_ONLY: + new_params.append(param) + visited_kwonly_args.add(name) + elif param.kind == inspect.Parameter.VAR_KEYWORD: + # Expand **kwargs into concrete named parameters + for kwarg_name in canonicalized_kwargs: + if kwarg_name not in visited_kwonly_args: + new_params.append( + inspect.Parameter( + name=kwarg_name, + kind=inspect.Parameter.KEYWORD_ONLY, + ) + ) + else: + raise DSLRuntimeError(f"Invalid parameter kind: {param.kind}") + + return signature.replace(parameters=new_params) + + def _func(self, funcBody: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: """Decorator for MLIR functions. It cuts the boilerplate code, does the following: 1. Generates `func.func` @@ -1642,37 +2182,45 @@ class BaseDSL(metaclass=DSLSingletonMeta): no_cache = True self.print_warning("Cache is disabled as user wants to compile only.") - # Check the number of arguments - sig = self._check_arg_count(*args, **kwargs) + # Get signature of the function + sig = self._get_signature(funcBody) - args_spec = inspect.getfullargspec(funcBody) + # Get bound arguments + bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) + + # Check the number of arguments + has_varargs = self._check_arg_count(sig, bound_args, function_name) # Canonicalize the input arguments - canonicalized_args, canonicalized_kwargs = self._canonicalize_args( - sig, *args, **kwargs + canonicalized_args, canonicalized_kwonly_args = self._canonicalize_args( + bound_args ) + # Expand *args/**kwargs into concrete named parameters - args_spec = self._expand_varargs_varkw( - canonicalized_args, canonicalized_kwargs, args_spec - ) + if has_varargs: + sig = self._expand_varargs_varkw( + canonicalized_args, canonicalized_kwonly_args, sig + ) # Simple name mangling - function_name = self.mangle_name(function_name, canonicalized_args, args_spec) + function_name = self.mangle_name(function_name, canonicalized_args, sig) if func_name_prefix: function_name = f"{func_name_prefix}_{function_name}" self.compile_options.apply_envar_settings(self.envar, function_name) if not self.compile_options.generate_line_info: self.decorator_location = None + # Enable frame filtering if line info is enabled + _set_enable_frame_filtering(self.compile_options.generate_line_info) # Generate MLIR Context and start generating IR log().debug(f"Generating MLIR for function '{function_name}'") result = self.generate_mlir( funcBody, - canonicalized_kwargs, function_name, gpu_module_attrs, canonicalized_args, - args_spec, + canonicalized_kwonly_args, + sig, pipeline, no_cache, no_jit_engine, @@ -1682,43 +2230,55 @@ class BaseDSL(metaclass=DSLSingletonMeta): return result class _KernelGenHelper(ABC): - def __init__(self): - self.func_op = None - self.func_type = None + def __init__(self) -> None: + self.func_op: Any = None + self.func_type: Any = None @abstractmethod - def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): + def generate_func_op( + self, + arg_types: list[Any], + arg_attrs: list[Any], + kernel_name: str, + loc: Any = None, + ) -> Any: assert arg_types is not None, "Invalid arg_types!" assert kernel_name is not None, "kernel name is empty" pass @abstractmethod - def generate_func_ret_op(self): + def generate_func_ret_op(self) -> None: pass @abstractmethod - def generate_launch_op(self, *args, **kwargs): + def generate_launch_op(self, *args: Any, **kwargs: Any) -> Any: pass @abstractmethod - def get_func_body_start(self): + def get_func_body_start(self) -> Any: pass @abstractmethod - def enter_gpu_module(module): + def enter_gpu_module(module: Any) -> Any: """Compute the insertion point into the given module.""" pass @lru_cache(maxsize=1) - def _get_default_stream(self): + def _get_default_stream(self) -> Any: """Returns the default stream 0""" from .runtime import cuda as cuda_helpers return cuda_helpers.stream_create() def _execute_cuda( - self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None - ): + self, + fname_cubin: str, + kernel_name: str, + grid_size: Any, + block_size: Any, + smem_size: int, + stream: Any = None, + ) -> None: """ Executes a specified CUDA kernel from a cubin file, handling module loading, kernel retrieval, stream creation, kernel launch, and synchronization. @@ -1751,13 +2311,13 @@ class BaseDSL(metaclass=DSLSingletonMeta): def _execute_by_cuda_driver( self, - kernel_generator, - generate_cubin, - grid_size, - block_size, - smem_size, - stream=None, - ): + kernel_generator: Callable[..., Any], + generate_cubin: Callable[..., str], + grid_size: Any, + block_size: Any, + smem_size: int, + stream: Any = None, + ) -> Any: """ This function builds IR and execute the module using cuda driver. It doesn't use mlir's cuda runtime @@ -1778,7 +2338,9 @@ class BaseDSL(metaclass=DSLSingletonMeta): return ret - def _generate_kernel_module(self, kernel_generator): + def _generate_kernel_module( + self, kernel_generator: Callable[..., Any] + ) -> tuple[Any, str, ir.Module]: """ Generates a module marked as GPU module which contains the kernel generated by :param kernel_generator:. @@ -1800,13 +2362,20 @@ class BaseDSL(metaclass=DSLSingletonMeta): return ret, kernel_name, self.build_module(module, kernel_name) def generate_kernel_operands_and_types( - self, kernel_func, kernel_name, args_spec, args, kwargs - ): + self, + kernel_func: Callable[..., Any], + kernel_name: str, + signature: inspect.Signature, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[list[Any], list[Any], list[Any]]: """ Generate the operands and types for the kernel function """ - kernel_operands, kernel_arg_types, kernel_arg_attrs = [], [], [] + kernel_operands: list[Any] = [] + kernel_arg_types: list[Any] = [] + kernel_arg_attrs: list[Any] = [] log().debug( "Processing GPU kernel call in [%s] mode", @@ -1822,7 +2391,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = ( self._generate_jit_func_args( - kernel_func, kernel_name, args, kwargs, args_spec, is_host=False + kernel_func, kernel_name, args, kwargs, signature, is_host=False ) ) @@ -1836,10 +2405,10 @@ class BaseDSL(metaclass=DSLSingletonMeta): return kernel_operands, kernel_arg_types, kernel_arg_attrs - def kernel_launcher(self, *dargs, **dkwargs): - def decorator(funcBody): + def kernel_launcher(self, *dargs: Any, **dkwargs: Any) -> Any: + def decorator(funcBody: Callable[..., Any]) -> Callable[..., Any]: @wraps(funcBody) - def kernel_wrapper(*args, **kwargs): + def kernel_wrapper(*args: Any, **kwargs: Any) -> Any: """ Base decorator for generating kernel function @@ -1868,20 +2437,22 @@ class BaseDSL(metaclass=DSLSingletonMeta): kernelGenHelper = dkwargs.get("kernelGenHelper", None) kernel_name = funcBody.__name__ - args_spec = inspect.getfullargspec(funcBody) + signature = self._get_signature(funcBody) self.funcBody = funcBody # Give each kernel a unique name. (The same kernel may be # called multiple times, resulting in multiple kernel traces.) # The mangled name of Python function is part of the name to # improve readability. - kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}" + kernel_name = f"kernel_{self.mangle_name(kernel_name, args, signature)}_{self.num_kernels}" if hasattr(self, "_name_prefix") and self._name_prefix: kernel_name = f"{self._name_prefix}_{kernel_name}" self.num_kernels += 1 # Step 0. Preprocess the arguments - def extract_args(argNames, assertIfNone=False) -> list: + def extract_args( + argNames: list[str], assertIfNone: bool = False + ) -> list[Any]: extracted = [] for name in argNames: value = kwargs.pop(name, None) @@ -1893,13 +2464,13 @@ class BaseDSL(metaclass=DSLSingletonMeta): return extracted - RequiredArgs = namedtuple("RequiredArgs", requiredArgs) + RequiredArgs = namedtuple("RequiredArgs", requiredArgs) # type: ignore[misc] req_args = ( RequiredArgs._make(extract_args(requiredArgs, assertIfNone=True)) if requiredArgs else None ) - OptionalArgs = namedtuple("OptionalArgs", optionalArgs) + OptionalArgs = namedtuple("OptionalArgs", optionalArgs) # type: ignore[misc] opt_args = ( OptionalArgs._make(extract_args(optionalArgs)) if optionalArgs @@ -1909,26 +2480,31 @@ class BaseDSL(metaclass=DSLSingletonMeta): "kernelGenHelper should be explicitly specified!" ) + # Get bound arguments + bound_args = self._get_function_bound_args( + signature, kernel_name, *args, **kwargs + ) + # check arguments - sig = self._check_arg_count(*args, **kwargs) + self._check_arg_count(signature, bound_args, kernel_name) # Canonicalize the input arguments canonicalized_args, canonicalized_kwargs = self._canonicalize_args( - sig, *args, **kwargs + bound_args ) kernel_operands, kernel_types, kernel_arg_attrs = ( self.generate_kernel_operands_and_types( funcBody, kernel_name, - args_spec, + signature, canonicalized_args, canonicalized_kwargs, ) ) loc = self.get_ir_location() - with self._enter_gpu_module(): + with self._enter_gpu_module(): # type: ignore[attr-defined] log().debug("Generating device kernel") if self.device_compilation_only: log().debug("Generating cuda-python arguments") @@ -1939,7 +2515,7 @@ class BaseDSL(metaclass=DSLSingletonMeta): kernel_name, canonicalized_args, canonicalized_kwargs, - args_spec, + signature, ) ) @@ -1956,17 +2532,21 @@ class BaseDSL(metaclass=DSLSingletonMeta): fop.sym_visibility = ir.StringAttr.get("public") with ir.InsertionPoint(helper.get_func_body_start()): ir_args, ir_kwargs = self.generate_execution_arguments( - canonicalized_args, canonicalized_kwargs, fop, args_spec + canonicalized_args, canonicalized_kwargs, fop, signature ) log().debug( f"IR arguments - args: {ir_args} ; kwargs: {ir_kwargs}" ) - # Call user function body + kernel_ret = funcBody(*ir_args, **ir_kwargs) + if hasattr(helper, "set_kernel_ret"): + helper.set_kernel_ret(kernel_ret) helper.generate_func_ret_op() # Step 3. Generate call site `launch_func` kernel_sym = ir.SymbolRefAttr.get(["kernels", kernel_name]) + setattr(funcBody, "_dsl_kernel_sym", kernel_sym) + setattr(funcBody, "_dsl_kernel_name", kernel_name) launch_ret = helper.generate_launch_op( kernelSym=kernel_sym, kernelOperands=kernel_operands, @@ -1991,17 +2571,22 @@ class BaseDSL(metaclass=DSLSingletonMeta): else: return decorator - def get_arch_enum(self) -> Arch: + def get_arch_enum(self) -> "Arch": """ Get the arch enum from the environment variable """ - arch_option = self.compile_options.gpu_arch - return Arch.from_string(arch_option if arch_option else self.envar.arch) + from .arch import Arch - def check_arch(self, criterion: Callable[[Arch], bool]) -> None: + arch_option: str | None = self.compile_options.gpu_arch + return Arch.from_string(arch_option if arch_option else self.envar.arch) # type: ignore[arg-type] + + def check_arch(self, criterion: Callable[["Arch"], bool]) -> None: """ Check the arch enum by criterion, raise DSLRuntimeError if the arch enum does not satisfy the criterion """ + # Avoid circular dependency + from .arch import Arch + arch = self.get_arch_enum() if not criterion(arch): raise DSLRuntimeError( diff --git a/python/CuTeDSL/cutlass/base_dsl/env_manager.py b/python/CuTeDSL/cutlass/base_dsl/env_manager.py index 414227731..49a53daa2 100644 --- a/python/CuTeDSL/cutlass/base_dsl/env_manager.py +++ b/python/CuTeDSL/cutlass/base_dsl/env_manager.py @@ -23,24 +23,83 @@ import os import sys import shutil import glob +import warnings from pathlib import Path from functools import lru_cache -from typing import Any from ..base_dsl.runtime.cuda import get_compute_capability_major_minor +from .common import DSLRuntimeError from .utils.logger import log from .cache_helpers import get_default_file_dump_root IS_WINDOWS = sys.platform == "win32" CLIB_EXT = ".dll" if IS_WINDOWS else ".so" +# ============================================================================= +# [DSL]_KEEP token definitions +# ============================================================================= + +# All individual artifact tokens accepted by [DSL]_KEEP. +_KEEP_ALL_TOKENS: frozenset[str] = frozenset( + { + "ir", + "ir-debug", + "ptx", + "cubin", + } +) +# "all" is a convenience alias that expands to every token above. +_KEEP_VALID_TOKENS: frozenset[str] = _KEEP_ALL_TOKENS | {"all"} + +CUTLASS_FAMILY_DSL_PREFIXES: frozenset[str] = frozenset( + { + "CUTE_DSL", + "CUTE_EXPERIMENTAL_DSL", + } +) + + +def is_cutlass_family_dsl_prefix(prefix: str) -> bool: + """Return whether the prefix uses the shared CuTe DSL runtime.""" + return prefix in CUTLASS_FAMILY_DSL_PREFIXES + + +def _parse_keep_tokens(raw: str, prefix: str = "") -> frozenset[str]: + """ + Parse the value of [DSL]_KEEP into a frozenset of canonical artifact tokens. + + Accepts a comma-separated list of tokens (case-insensitive). + The special value ``all`` expands to every token in ``_KEEP_ALL_TOKENS``. + Unknown tokens are logged as warnings and ignored. + + Token semantics: + ir — IR after canonicalize+cse (clean, human-readable) + ir-debug — Raw IR before any passes (old KEEP_IR=1 behaviour) + ptx — PTX assembly + cubin — CUBIN binary + """ + tokens = frozenset(t.strip().lower() for t in raw.split(",") if t.strip()) + if "all" in tokens: + return _KEEP_ALL_TOKENS + unknown = tokens - _KEEP_VALID_TOKENS + if unknown: + message = f"{prefix}_KEEP" if prefix else "[DSL]_KEEP" + log().warning( + "%s: unknown token(s) %s will be ignored. Valid tokens: %s", + message, + sorted(unknown), + sorted(_KEEP_VALID_TOKENS), + ) + return tokens - unknown + + # ============================================================================= # Environment Variable Helpers # ============================================================================= @lru_cache(maxsize=None) -def get_str_env_var(var_name, default_value=None): +def get_str_env_var(var_name: str, default_value: str | None = None) -> str | None: """ Get the string value of an environment variable. Note that the value is cached after the first call. @@ -50,7 +109,7 @@ def get_str_env_var(var_name, default_value=None): @lru_cache(maxsize=None) -def get_bool_env_var(var_name, default_value=False): +def get_bool_env_var(var_name: str, default_value: bool = False) -> bool: """ Get the value of a boolean environment variable. If the value it not in False, 0, or empty string, it is considered True. @@ -63,7 +122,7 @@ def get_bool_env_var(var_name, default_value=False): @lru_cache(maxsize=None) -def get_int_env_var(var_name, default_value=0): +def get_int_env_var(var_name: str, default_value: int = 0) -> int: """ Get the value of an integer environment variable. If the value is not a valid integer, the default value 0 is returned. @@ -74,7 +133,9 @@ def get_int_env_var(var_name, default_value=0): @lru_cache(maxsize=None) -def get_int_or_none_env_var(var_name, default_value=None): +def get_int_or_none_env_var( + var_name: str, default_value: int | None = None +) -> int | None: """ Get the value of an integer or None union environment variable. If the value is not a valid integer, the default value 0 is returned. @@ -95,7 +156,7 @@ def get_int_or_none_env_var(var_name, default_value=None): @lru_cache(maxsize=None) -def has_env_var(var_name): +def has_env_var(var_name: str) -> bool: """ Check if an environment variable is set. Note that the value is cached after the first call. @@ -103,7 +164,7 @@ def has_env_var(var_name): return os.getenv(var_name) is not None -def detect_gpu_arch(prefix): +def detect_gpu_arch(prefix: str) -> str: """ Attempts to detect the machine's GPU architecture. @@ -111,11 +172,11 @@ def detect_gpu_arch(prefix): A string representing the GPU architecture (e.g. "70" for compute capability 7.0), or a default value(e.g. "sm_100") if the GPU architecture cannot be determined. """ - arch = (None, None) + arch: tuple[int | None, int | None] = (None, None) try: arch = get_compute_capability_major_minor() except Exception as e: - log().info(f"Failed to get CUDA compute capability: {e}") + log().info("Failed to get CUDA compute capability: %s", e) if arch == (None, None): # default to sm_100 @@ -123,13 +184,15 @@ def detect_gpu_arch(prefix): major, minor = arch suffix = "" - if major >= 9: + if major >= 9: # type: ignore[operator] suffix = "a" return f"sm_{major}{minor}{suffix}" -def find_libs_in_ancestors(start, target_libs, lib_folder_guesses): +def find_libs_in_ancestors( + start: str | Path, target_libs: set[str], lib_folder_guesses: list[str] +) -> list[str] | None: """ Search ancestor directories for a candidate library folder containing all required libraries. @@ -181,7 +244,7 @@ def find_libs_in_ancestors(start, target_libs, lib_folder_guesses): return None -def _find_cuda_home(): +def _find_cuda_home() -> str | None: """Find the CUDA installation path using a series of heuristic methods. Methods below are checked in order, and the function returns on first match: 1. Checking the environment variables CUDA_HOME and CUDA_PATH. @@ -221,27 +284,7 @@ def _find_cuda_home(): return cuda_home -def get_cuda_toolkit_path(): - """ - Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if - set. Otherwise, attempts to discover a valid CUDA toolkit location and - return. If not found, return None. - """ - # Check if the environment variable is already set, if so, return it immediately. - try: - cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH") - if cuda_toolkit_path_existing: - return cuda_toolkit_path_existing - - found_cuda_home = _find_cuda_home() - if found_cuda_home: - return found_cuda_home - except Exception as e: - log().info("default_env: exception on get_cuda_toolkit_path", e) - return None - - -def get_prefix_dsl_libs(prefix: str): +def get_prefix_dsl_libs(prefix: str) -> str | None: """ Returns get_str_env_var('{prefix}_LIBS') if set. Otherwise, attempts to discover libs based on heuristics and return @@ -253,7 +296,7 @@ def get_prefix_dsl_libs(prefix: str): if prefix_libs_existing: return prefix_libs_existing - def get_libs_cand(start): + def get_libs_cand(start: str | Path) -> str | None: target_dsl_runtime_libs = { "cute_dsl_runtime", } @@ -279,15 +322,29 @@ def get_prefix_dsl_libs(prefix: str): # try to find from build folder structure dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve()) - return dsl_libs + if dsl_libs: + return dsl_libs + + # Known CuTe-family DSLs share libcute_dsl_runtime.so. With pip + # editable installs (`pip install -e`), the startup hook in + # cutlass/_pth_hook.py sets CUTE_DSL_LIBS but not the per-prefix + # variants, and the ancestor walk from the source tree cannot reach + # the build/vendored lib directory. Fall back to CUTE_DSL_LIBS for + # those aliases when their prefix-specific lookup fails. + if is_cutlass_family_dsl_prefix(prefix) and prefix != "CUTE_DSL": + fallback = os.getenv("CUTE_DSL_LIBS") + if fallback: + return fallback + + return None except Exception as e: - log().info(f"default_env: exception on get_prefix_dsl_libs", e) + log().info("default_env: exception on get_prefix_dsl_libs", e) return None class LogEnvironmentManager: - def __init__(self, prefix="DSL"): + def __init__(self, prefix: str = "DSL") -> None: self.prefix = prefix # Logging options @@ -302,7 +359,10 @@ class LogEnvironmentManager: and not self.log_to_file ): log().warning( - f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!" + "Log level was set, but neither logging to file (%s_LOG_TO_FILE) nor" + " logging to console (%s_LOG_TO_CONSOLE) is enabled!", + prefix, + prefix, ) self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1) @@ -318,11 +378,21 @@ class EnvironmentVarManager(LogEnvironmentManager): File options: - [DSL_NAME]_DUMP_DIR: Directory to dump the generated files (default: current working directory) - [DSL_NAME]_CACHE_DIR: Cache directory (default: /tmp/{dsl_name}_python_cache_{tmpfile_suffix}) - - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) - - [DSL_NAME]_KEEP_PTX: Save generated PTX in a file (default: False) - - [DSL_NAME]_KEEP_CUBIN: Save generated CUBIN in a file (default: False) - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False) + - [DSL_NAME]_KEEP: Comma-separated list of artifacts to save to DUMP_DIR (default: ""). + Tokens: + ir — IR after canonicalize+cse (clean, human-readable) + ir-debug — Raw IR before any passes + ptx — PTX assembly + cubin — CUBIN binary + all — all of the above + Example: CUTE_DSL_KEEP=ir,ptx + # Deprecated — use [DSL_NAME]_KEEP instead: + - [DSL_NAME]_KEEP_IR: (deprecated) use KEEP=ir-debug + - [DSL_NAME]_KEEP_PTX: (deprecated) use KEEP=ptx + - [DSL_NAME]_KEEP_CUBIN: (deprecated) use KEEP=cubin Other options: + - [DSL_NAME]_SHOW_STACKTRACE: Show full stack traces on failure (default: False) - [DSL_NAME]_LINEINFO: Compile with `--lineinfo` enabling developer tools such as the profiler and debugger (default: False) - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1). - [DSL_NAME]_DRYRUN: Generates IR only (default: False) @@ -336,9 +406,10 @@ class EnvironmentVarManager(LogEnvironmentManager): - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) - [DSL_NAME]_ENABLE_TVM_FFI: Enable TVM-FFI or not (default: False) + - [DSL_NAME]_LOC_TRACEBACKS: Maximum depth of location tracebacks (default: 0) """ - def __init__(self, prefix="DSL"): + def __init__(self, prefix: str = "DSL") -> None: super().__init__(prefix) # Printing options @@ -347,6 +418,7 @@ class EnvironmentVarManager(LogEnvironmentManager): ) self.print_ir = get_bool_env_var(f"{prefix}_PRINT_IR", False) self.filter_stacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) + self.show_stacktrace = get_bool_env_var(f"{prefix}_SHOW_STACKTRACE", False) self.lineinfo = get_bool_env_var(f"{prefix}_LINEINFO", False) self.no_cache = get_bool_env_var(f"{prefix}_NO_CACHE", False) self.jit_cache_max_elems = get_int_or_none_env_var( @@ -357,11 +429,48 @@ class EnvironmentVarManager(LogEnvironmentManager): self.dump_dir = get_str_env_var( f"{prefix}_DUMP_DIR", get_default_file_dump_root() ) - self.keep_ptx = get_bool_env_var(f"{prefix}_KEEP_PTX", False) - self.keep_cubin = get_bool_env_var(f"{prefix}_KEEP_CUBIN", False) # File options - self.keep_ir = get_bool_env_var(f"{prefix}_KEEP_IR", False) self.cache_dir = get_str_env_var(f"{prefix}_CACHE_DIR", None) + + # ------------------------------------------------------------------ # + # Artifact keep — [DSL]_KEEP= # + # ------------------------------------------------------------------ # + _keep_raw = get_str_env_var(f"{prefix}_KEEP", "") + _keep_tokens: set[str] = set( + _parse_keep_tokens(_keep_raw, prefix) if _keep_raw else frozenset() + ) + + if get_bool_env_var(f"{prefix}_KEEP_IR", False): + warnings.warn( + f"{prefix}_KEEP_IR is deprecated; use {prefix}_KEEP=ir-debug instead.", + DeprecationWarning, + stacklevel=2, + ) + _keep_tokens.add("ir-debug") + if get_bool_env_var(f"{prefix}_KEEP_PTX", False): + warnings.warn( + f"{prefix}_KEEP_PTX is deprecated; use {prefix}_KEEP=ptx instead.", + DeprecationWarning, + stacklevel=2, + ) + _keep_tokens.add("ptx") + if get_bool_env_var(f"{prefix}_KEEP_CUBIN", False): + warnings.warn( + f"{prefix}_KEEP_CUBIN is deprecated; use {prefix}_KEEP=cubin instead.", + DeprecationWarning, + stacklevel=2, + ) + _keep_tokens.add("cubin") + + self.keep_tokens: frozenset[str] = frozenset(_keep_tokens) + + # Derived boolean attributes — used by compiler.py and dsl.py. + # keep_ir_clean: save IR after canonicalize+cse (the readable form). + self.keep_ir_clean: bool = "ir" in self.keep_tokens + # keep_ir: save raw IR before any passes (old KEEP_IR=1 semantics). + self.keep_ir: bool = "ir-debug" in self.keep_tokens + self.keep_ptx: bool = "ptx" in self.keep_tokens + self.keep_cubin: bool = "cubin" in self.keep_tokens # Other options self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) @@ -375,9 +484,6 @@ class EnvironmentVarManager(LogEnvironmentManager): self.disable_file_caching = get_bool_env_var( f"{prefix}_DISABLE_FILE_CACHING", False ) - # set cuda - self.cuda_toolkit = get_cuda_toolkit_path() - # set mlir shared libraries self.shared_libs = get_prefix_dsl_libs(prefix) @@ -385,3 +491,5 @@ class EnvironmentVarManager(LogEnvironmentManager): self.enable_assertions = get_bool_env_var(f"{prefix}_ENABLE_ASSERTIONS", False) self.enable_tvm_ffi = get_bool_env_var(f"{prefix}_ENABLE_TVM_FFI", False) + + self.loc_tracebacks = get_int_env_var(f"{prefix}_LOC_TRACEBACKS", 0) diff --git a/python/CuTeDSL/cutlass/base_dsl/export/__init__.py b/python/CuTeDSL/cutlass/base_dsl/export/__init__.py index d06596efe..afa15ce7d 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/__init__.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/__init__.py @@ -16,7 +16,7 @@ from .export import ( decode_metadata_from_execution_engine, ) -from .export import ArgsSpecProcessor +from .export import SignatureProcessor from .external_binary_module import ExternalBinaryModule, LoadProvider __all__ = [ @@ -25,7 +25,7 @@ __all__ = [ "get_export_module", "encode_metadata_into_ir_module", "decode_metadata_from_execution_engine", - "ArgsSpecProcessor", + "SignatureProcessor", "ExternalBinaryModule", "LoadProvider", ] diff --git a/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py b/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py index 6a07c2b88..33cdc5963 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/c_header_generator.py @@ -31,7 +31,9 @@ from ..common import DSLRuntimeError from ..jit_executor import ExecutionArgs from ..._mlir import ir -from typing import Type, List, Any, Dict +from dataclasses import dataclass +from typing import Any, Union, get_origin, get_args +import inspect from inspect import isclass import cuda.bindings.driver as cuda @@ -41,34 +43,24 @@ import cuda.bindings.driver as cuda cubin_suffix = "cubin" +@dataclass class CHeaderArguments: """This class is used to store the arguments generation of the wrapper function. The arguments are generated when the JitCompiledFunction is created and used to avoid the long-term reference of the arguments by the JitCompiledFunction. """ - def __init__( - self, - dummy_prefix_name: str, - arguments: List[str], - packed_args: List[str], - declarations: str, - error_msg: str = None, - ): - self.dummy_prefix_name = dummy_prefix_name - self.arguments = arguments - self.packed_args = packed_args - self.declarations = declarations - self.error_msg = error_msg + dummy_prefix_name: str + arguments: list[str] + packed_args: list[str] + declarations: list[str] + error_msg: str | None = None - def __bool__(self): + def __bool__(self) -> bool: return self.error_msg is None - def __str__(self): - return self.error_msg - - def __repr__(self): - return self.error_msg + def __str__(self) -> str: + return self.error_msg or "" class CHeaderGenerator: @@ -116,7 +108,7 @@ class CHeaderGenerator: Float16: "__half_raw ", } - def _count_dynamic_expression(self, arg): + def _count_dynamic_expression(self, arg: Any) -> int: """ Count the number of dynamic values in the argument. """ @@ -126,7 +118,7 @@ class CHeaderGenerator: return 1 return 0 - def _generate_numeric_argument(self, arg_name: str, arg_type: Type[Numeric]): + def _generate_numeric_argument(self, arg_name: str, arg_type: type[Numeric]) -> str: """ Generate the argument of the wrapper function. """ @@ -136,7 +128,7 @@ class CHeaderGenerator: f"Unsupported argument type for c function argument generation: {arg_type}" ) - def _generate_check_cuda(self, dsl_name: str): + def _generate_check_cuda(self, dsl_name: str) -> str: check_cuda = ( f""" // Macro to check for cuda errors. @@ -151,8 +143,11 @@ class CHeaderGenerator: return check_cuda def _generate_kernel_module( - self, symbol_prefix: str, kernel_info: Dict[str, List], dsl_name: str - ): + self, + symbol_prefix: str, + kernel_info: dict[str, list[Any]], + dsl_name: str, + ) -> str: """ Generate the kernel module for the compiled function. """ @@ -208,28 +203,34 @@ static inline void {symbol_prefix}_Kernel_Module_Unload({symbol_prefix}_Kernel_M self, symbol_prefix: str, args_spec: ExecutionArgs, - args: List[Any], - kwargs: Dict[str, Any], - ): + args: tuple[Any], + kwargs: dict[str, Any], + ) -> tuple[list[str], list[str], list[str]]: """ Generate the arguments of the wrapper function. """ - arguments = [] - packed_args = [] - declarations = [] + arguments: list[str] = [] + packed_args: list[str] = [] + declarations: list[str] = [] # traverse the runtime args_spec and generate the arguments rectified_args = args_spec.get_rectified_args(args, kwargs) - input_arg_names = args_spec.args_spec.args + args_spec.args_spec.kwonlyargs - for arg_name, arg in zip(input_arg_names, rectified_args): - arg_type = args_spec.args_spec.annotations.get(arg_name, None) + for param, arg in zip(args_spec.signature.parameters.values(), rectified_args): + arg_type = param.annotation + arg_name = param.name # process optional argument if arg is None: continue + # Unwrap Optional[X] (i.e. Union[X, None]) to X when arg is not None + if get_origin(arg_type) is Union: + inner_types = [t for t in get_args(arg_type) if t is not type(None)] + if len(inner_types) == 1: + arg_type = inner_types[0] + # Generate basic numeric types if isinstance(arg_type, NumericMeta): - arguments.append(self._generate_numeric_argument(arg_name, arg_type)) + arguments.append(self._generate_numeric_argument(arg_name, arg_type)) # type: ignore[arg-type] packed_args.append("&" + arg_name) elif isclass(arg_type) and issubclass(arg_type, cuda.CUstream): arguments.append("CUstream " + arg_name) @@ -247,9 +248,9 @@ static inline void {symbol_prefix}_Kernel_Module_Unload({symbol_prefix}_Kernel_M symbol_prefix: str, args_spec: ExecutionArgs, function_name: str, - kernel_info: Dict[str, List], + kernel_info: dict[str, list[Any]], c_header_arguments: CHeaderArguments, - ): + ) -> str: """ Generate the wrapper function for the compiled function which is provided to users as the entry point. It uses the `symbol_prefix` as the function name for identification. The host/device symbols are hidden under the bytecode. @@ -270,17 +271,17 @@ static inline void {symbol_prefix}_Kernel_Module_Unload({symbol_prefix}_Kernel_M arg.replace(c_header_arguments.dummy_prefix_name, symbol_prefix) for arg in c_header_arguments.packed_args ] - declarations = [ + declaration_lines = [ declaration.replace(c_header_arguments.dummy_prefix_name, symbol_prefix) for declaration in c_header_arguments.declarations ] - declarations = "\n".join(declarations) + declarations_joined = "\n".join(declaration_lines) # 3. Generate the wrapper function kernel_symbols = tuple(kernel_info.keys()) kernel_symbols_str = ", ".join([f"&module->{sym}" for sym in kernel_symbols]) function = ( - declarations + declarations_joined + f""" #ifdef __cplusplus extern "C" @@ -297,7 +298,7 @@ static inline void {wrapper_function_name}({symbol_prefix}_Kernel_Module_t *modu ) return function - def _generate_binary_declaration(self, symbol_prefix: str): + def _generate_binary_declaration(self, symbol_prefix: str) -> str: """ Generate the binary of the compiled function. """ @@ -313,7 +314,7 @@ extern const unsigned char {varname}[]; export_module: ir.Module, args_spec: ExecutionArgs, function_name: str, - kernel_info: Dict[str, List], + kernel_info: dict[str, list[Any]], c_header_arguments: CHeaderArguments, dsl_name: str, ) -> str: diff --git a/python/CuTeDSL/cutlass/base_dsl/export/export.py b/python/CuTeDSL/cutlass/base_dsl/export/export.py index 9a4bd4b18..c99dc9ef7 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/export.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/export.py @@ -9,6 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +import inspect import io import os @@ -19,7 +20,6 @@ from ..._mlir.dialects import llvm import json import base64 import ctypes -from inspect import FullArgSpec args_spec_suffix = "args_spec" function_name_suffix = "function_name" @@ -29,8 +29,11 @@ c_string_suffix = "\0" def get_export_module( - ir_module: ir.Module, symbol_prefix: str, *, preserve_symbols=None -): + ir_module: ir.Module, + symbol_prefix: str, + *, + preserve_symbols: set[str] | None = None, +) -> ir.Module: """Get the export module which is cloned from the original compiled ir module, and add the prefix to avoid the symbol conflict. @@ -45,7 +48,7 @@ def get_export_module( if preserve_symbols is None: preserve_symbols = set() - def walk_llvm_func_op(op): + def walk_llvm_func_op(op: ir.Operation) -> ir.WalkResult: # not a declaration if ( op.name == "llvm.func" @@ -62,7 +65,7 @@ def get_export_module( ) return ir.WalkResult.ADVANCE - def walk_llvm_references(op): + def walk_llvm_references(op: ir.Operation) -> ir.WalkResult: # Rename function calls if op.name == "llvm.call" and op.attributes["callee"].value in defined_symbols: op.attributes["callee"] = ir.FlatSymbolRefAttr.get( @@ -113,26 +116,26 @@ def get_export_module( return export_module -class ArgsSpecProcessor: - """The args spec processor. The args_spec may contain the dsl specific types. The base processor - class is used to define an interface for dumping and loading the args_spec.""" +class SignatureProcessor: + """The signature processor. The signature may contain the dsl specific types. The base processor + class is used to define an interface for dumping and loading the signature.""" - def dumps(self, args_spec: FullArgSpec) -> bytes: - raise NotImplementedError("ArgsSpecProcessor does not support dumps") + def dumps(self, signature: inspect.Signature) -> bytes: + raise NotImplementedError("SignatureProcessor does not support dumps") - def loads(self, args_spec_bytes: bytes): - raise NotImplementedError("ArgsSpecProcessor does not support loads") + def loads(self, signature_bytes: bytes) -> inspect.Signature: + raise NotImplementedError("SignatureProcessor does not support loads") def encode_metadata_into_ir_module( prefix: str, ir_module: ir.Module, - args_spec: FullArgSpec, + signature: inspect.Signature, function_name: str, kernel_info: dict, - args_spec_processor: ArgsSpecProcessor, + signature_processor: SignatureProcessor, object_file_version: str, -): +) -> ir.Module: """Encode the executor metadata into the ir module. The metadata includes: 1. args_spec: The args_spec of the python function. 2. function_name: The name mangling function_name of the python host function. @@ -144,17 +147,17 @@ def encode_metadata_into_ir_module( @param args_spec: The args_spec of the python function. @param function_name: The name mangling function_name of the python host function. @param kernel_info: The kernel_info of the jit-compiled function including the kernel name and attributes. - @param args_spec_processor: The args spec processor. The args_spec may contain the dsl specific types. The processor will be used to dump and load the args_spec. + @param signature_processor: The signature processor. The signature may contain the dsl specific types. The processor will be used to dump and load the signature. @param object_file_version: The version of the object file. """ - if not args_spec: + if not signature: raise DSLRuntimeError( - "args_spec is empty, please set the args_spec for the python jit function." + "signature is empty, please set the signature for the python jit function." ) version = object_file_version + c_string_suffix - args_spec_bytes = args_spec_processor.dumps(args_spec) - args_spec_str = base64.b64encode(args_spec_bytes).decode("utf-8") + c_string_suffix + signature_bytes = signature_processor.dumps(signature) + signature_str = base64.b64encode(signature_bytes).decode("utf-8") + c_string_suffix packed_function_name = ( "_mlir_" + prefix + "__mlir_ciface_" + function_name + c_string_suffix ) @@ -162,9 +165,9 @@ def encode_metadata_into_ir_module( with ir.InsertionPoint(ir_module.body): args_spec_op = llvm.GlobalOp( sym_name="_".join([prefix, args_spec_suffix]), - global_type=ir.Type.parse(f"!llvm.array<{len(args_spec_str)} x i8>"), + global_type=ir.Type.parse(f"!llvm.array<{len(signature_str)} x i8>"), linkage=ir.Attribute.parse("#llvm.linkage"), - value=ir.StringAttr.get(args_spec_str), + value=ir.StringAttr.get(signature_str), ) function_name_op = llvm.GlobalOp( sym_name="_".join([prefix, function_name_suffix]), @@ -175,7 +178,7 @@ def encode_metadata_into_ir_module( value=ir.StringAttr.get(packed_function_name), ) # pack the kernel_info from a dict to a global op. - kernel_info = json.dumps(kernel_info) + c_string_suffix + kernel_info = json.dumps(kernel_info) + c_string_suffix # type: ignore[assignment] kernel_info_op = llvm.GlobalOp( sym_name="_".join([prefix, kernel_info_suffix]), global_type=ir.Type.parse(f"!llvm.array<{len(kernel_info)} x i8>"), @@ -194,19 +197,19 @@ def encode_metadata_into_ir_module( def decode_metadata_from_execution_engine( prefix: str, - execution_engine: "BinaryExecutionEngine", - args_spec_processor: ArgsSpecProcessor, -): + execution_engine: "BinaryExecutionEngine", # type: ignore[name-defined] + signature_processor: SignatureProcessor, +) -> tuple[inspect.Signature, str | None, dict, str | None]: """Decode the executor metadata from the execution engine. The metadata includes: - 1. args_spec: The args_spec of the python function. + 1. signature: The signature of the python function. 2. function_name: The name mangling function_name of the python host function. 3. kernel_info: The kernel_info of the jit-compiled function including the kernel name and attributes. 4. version: The version of the object file. @param prefix: The prefix name of the function. This is the unique identifier name of the function to avoid symbol conflict in the generated object file. @param execution_engine: The binary execution engine. This is the execution engine to load the cuda module. - @param args_spec_processor: The args spec processor. The args_spec may contain the dsl specific types. The processor will be used to dump and load the args_spec. - @return: The args_spec, function_name, and kernel_info. + @param signature_processor: The signature processor. The signature may contain the dsl specific types. The processor will be used to dump and load the signature. + @return: The signature, function_name, and kernel_info. """ args_spec_str_p = execution_engine.lookup("_".join([prefix, args_spec_suffix])) function_name_str_p = execution_engine.lookup( @@ -215,25 +218,25 @@ def decode_metadata_from_execution_engine( kernel_info_str_p = execution_engine.lookup("_".join([prefix, kernel_info_suffix])) version_str_p = execution_engine.lookup("_".join([prefix, version_suffix])) if args_spec_str_p: - args_spec_str = ctypes.c_char_p(args_spec_str_p).value.decode("utf-8") + args_spec_str = ctypes.c_char_p(args_spec_str_p).value.decode("utf-8") # type: ignore[union-attr] else: args_spec_str = None # The StringAttr encodes the string as utf-8 format. if function_name_str_p: - function_name_str = ctypes.c_char_p(function_name_str_p).value.decode("utf-8") + function_name_str = ctypes.c_char_p(function_name_str_p).value.decode("utf-8") # type: ignore[union-attr] else: function_name_str = None if kernel_info_str_p: - kernel_info_str = ctypes.c_char_p(kernel_info_str_p).value.decode("utf-8") + kernel_info_str = ctypes.c_char_p(kernel_info_str_p).value.decode("utf-8") # type: ignore[union-attr] else: kernel_info_str = None if version_str_p: - version_str = ctypes.c_char_p(version_str_p).value.decode("utf-8") + version_str = ctypes.c_char_p(version_str_p).value.decode("utf-8") # type: ignore[union-attr] else: version_str = None - args_spec_bytes = base64.b64decode(args_spec_str) - args_spec = args_spec_processor.loads(args_spec_bytes) + args_spec_bytes = base64.b64decode(args_spec_str) # type: ignore[arg-type] + args_spec = signature_processor.loads(args_spec_bytes) function_name = function_name_str - kernel_info = json.loads(kernel_info_str) + kernel_info = json.loads(kernel_info_str) # type: ignore[arg-type] return args_spec, function_name, kernel_info, version_str diff --git a/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py b/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py index 7119d5e0e..8f257fbf1 100644 --- a/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py +++ b/python/CuTeDSL/cutlass/base_dsl/export/external_binary_module.py @@ -13,19 +13,20 @@ import io import os import ctypes -from typing import Callable -from inspect import FullArgSpec +from collections.abc import Callable +from inspect import Signature, Parameter +from typing import Any, cast from ..common import DSLRuntimeError from ...base_dsl.dsl import BaseDSL from ...base_dsl.typing import Int32, Int64, Float32, Float64 -from .export import decode_metadata_from_execution_engine +from .export import SignatureProcessor, decode_metadata_from_execution_engine -def _get_ctypes_return_type(args_spec: FullArgSpec): - """Get the ctypes return type from the args_spec.""" - return_type = args_spec.annotations.get("return", None) - if return_type is None: +def _get_ctypes_return_type(signature: Signature) -> Any: + """Get the ctypes return type from the signature.""" + return_type = signature.return_annotation + if return_type is Parameter.empty: raise DSLRuntimeError("Return type is not specified for AOT compiled function.") type_to_ctype = { Int32: ctypes.c_int32, @@ -48,14 +49,14 @@ class LoadProvider: def __init__( self, - dsl: "Type[BaseDSL]", - args_spec_processor: "ArgsSpecProcessor", - version_checker: Callable, - execution_engine_constructor: Callable, - jit_function_constructor: Callable, - ): + dsl: type[BaseDSL], + signature_processor: SignatureProcessor, + version_checker: Callable[..., Any], + execution_engine_constructor: Callable[..., Any], + jit_function_constructor: Callable[..., Any], + ) -> None: self.dsl = dsl - self.args_spec_processor = args_spec_processor + self.signature_processor = signature_processor self.version_checker = version_checker self.execution_engine_constructor = execution_engine_constructor self.jit_function_constructor = jit_function_constructor @@ -63,12 +64,11 @@ class LoadProvider: class ExternalBinaryModule: """The exported binary module is a wrapper of the previous exported object files. It is used to load a object file - or a library in memory, allow function lookup and return the corresponding `JitCompiledFunction`. - """ + or a library in memory, allow function lookup and return the corresponding `JitCompiledFunction`.""" - load_provider: LoadProvider = None + load_provider: LoadProvider | None = None - def __init__(self, file_path: str, enable_tvm_ffi: bool = False): + def __init__(self, file_path: str, enable_tvm_ffi: bool = False) -> None: self.enable_tvm_ffi = enable_tvm_ffi assert self.load_provider is not None, ( "Load provider is not set for ExternalBinaryModule." @@ -83,14 +83,17 @@ class ExternalBinaryModule: object_file_content = f.read() except Exception as e: raise DSLRuntimeError(f"Failed to read object file {file_path}: {e}") - - useJitLink = not enable_tvm_ffi # Lifetime of the engine is same as the ExternalBinaryModule. + # Always use JITLink. MCJIT mishandles .o files with duplicate + # .text ELF sections (different permission flags), causing + # non-deterministic SIGSEGV in multi-process torchrun workloads. + # JITLink handles sections independently and correctly. + useJitLink = True self.engine = self.load_provider.execution_engine_constructor( object_file_content, shared_libs, useJitLink ) - def __getattr__(self, function_prefix: str) -> "JitCompiledFunction": + def __getattr__(self, function_prefix: str) -> Any: """Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error.""" if self.enable_tvm_ffi: try: @@ -105,42 +108,51 @@ class ExternalBinaryModule: f"Failed to load TVM FFI function {function_prefix}: {e}" ) + load_provider = cast(LoadProvider, self.load_provider) try: - args_spec, function_name, kernel_info, version_str = ( + signature, function_name, kernel_info, version_str = ( decode_metadata_from_execution_engine( - function_prefix, self.engine, self.load_provider.args_spec_processor + function_prefix, self.engine, load_provider.signature_processor ) ) except Exception as e: raise DSLRuntimeError( f"Function prefix {function_prefix} not found in the module.", cause=e ) - self.load_provider.version_checker(version_str) + load_provider.version_checker(version_str) capi_func_p = self.engine.lookup(function_name) if not capi_func_p: raise DSLRuntimeError( - "Unknown function: " + "Unknown function: " # type: ignore[operator] + "_mlir_" + function_prefix + "__mlir_ciface_" + function_name ) - return_type = _get_ctypes_return_type(args_spec) + return_type = _get_ctypes_return_type(signature) capi_func = ctypes.CFUNCTYPE(return_type, ctypes.c_void_p)(capi_func_p) - jit_function = self.load_provider.jit_function_constructor( + try: + has_gpu_module = ( + self.engine.lookup("_mlir_" + function_prefix + "_cuda_init") + is not None + ) + except Exception as e: + has_gpu_module = False + jit_function = load_provider.jit_function_constructor( ir_module=None, engine=self.engine, capi_func=capi_func, - args_spec=args_spec, + signature=signature, function_name=function_name, kernel_info=kernel_info, - jit_time_profiling=self.load_provider.dsl._get_dsl().envar.jit_time_profiling, + jit_time_profiling=load_provider.dsl._get_dsl().envar.jit_time_profiling, jit_function_artifacts=None, prefix=function_prefix, load_from_binary=True, + has_gpu_module=has_gpu_module, ) return jit_function - def __getitem__(self, function_prefix: str) -> "JitCompiledFunction": + def __getitem__(self, function_prefix: str) -> Any: """Get the jit_function from the `function_prefix`. The `function_prefix` is specified when users dump the object file. When there is no function_prefix found in the module, the function will raise an error.""" return self.__getattr__(function_prefix) diff --git a/python/CuTeDSL/cutlass/base_dsl/ffi.py b/python/CuTeDSL/cutlass/base_dsl/ffi.py new file mode 100644 index 000000000..ecab46590 --- /dev/null +++ b/python/CuTeDSL/cutlass/base_dsl/ffi.py @@ -0,0 +1,725 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +Foreign Function Interface (FFI) for Base DSL. + +This module provides infrastructure for calling external functions from DSL code, +with support for: +- Dynamic type dispatch via ExternCallHandler +- Multiple overload resolution (concrete types, TypeVars, Union types) +- Automatic MLIR prototype generation and caching +- Name mangling for overloaded functions +- Bitcode linking integration +- Implicit type conversions (e.g., ptr address space casts) + +Architecture: +- `extern`: Decorator creating ExternCallHandler instances for dynamic dispatch +- ExternCallHandler: Resolves concrete types at call time, creates FFI instances +- FFI: Core class managing MLIR prototypes and call emission +- ConstValue: Represents compile-time constants in type signatures +- BitCode: Specifies external bitcode sources for linking + +Usage: + `@extern`(source=BitCode("mylib.bc")) + def external_func(x: Int32) -> Float32: + ... +""" + +import typing +from types import UnionType +from typing import TypeVar, Any, Union +import inspect +from dataclasses import dataclass +import string + +from .._mlir import ir +from .._mlir.dialects import func, gpu, llvm + +from . import typing as t +from .typing import get_mlir_types, NumericMeta, as_numeric +from .dsl import extract_mlir_values +from .common import DSLRuntimeError + + +@dataclass(frozen=True) +class ConstValue: + """Represents a constant value and its MLIR types""" + + types: tuple[ir.Type] + value: Any + + +@dataclass(frozen=True) +class BitCode: + """ + Specifies an external bitcode file to link when compiling. + + Attributes: + path: Filesystem path to the .bc (LLVM bitcode) file. + """ + + path: str + + +ALLOWED = set(string.ascii_letters + string.digits + "_") + + +def mangle(name: str) -> str: + """Mangle a string to be a valid function symbol""" + return "".join((c if c in ALLOWED else f"_{ord(c):02X}") for c in name) + + +def to_types(t: Any) -> set[Any]: + """Convert a generic type to a set of possible types""" + if typing.get_origin(t) is None: + return {t} + if typing.get_origin(t) is UnionType or typing.get_origin(t) is Union: + return set(typing.get_args(t)) + return {t} + + +def default_name_mangler(self: "FFI") -> str: + """Given an ffi object, generate the mangled symbol name. Includes: + - Function name + - Types + - Values of constexprs + """ + parts = [] + for typ in self.params_types: + if isinstance(typ, ConstValue): + parts.append("_".join(mangle(str(x)) for x in typ.types)) + if not isinstance(typ.value.value, int): + raise DSLRuntimeError( + f"constexpr of type {type(typ.value)} not supported for ffi" + ) + parts.append(mangle(str(typ.value.value))) + else: + parts.append(mangle(str(typ))) + return f"{self.name}_{'_'.join(parts)}" + + +def type_is_concrete(typ: Any) -> bool: + """Determine if a type is concrete, i.e. equivalent to a single MLIR type""" + if isinstance(typ, UnionType): + return False + if isinstance(typ, TypeVar): + return False + if typing.get_origin(typ) is not None: + return False + if hasattr(typ, "mlir_type"): + return True + if hasattr(typ, "__get_mlir_types__"): + return True + raise DSLRuntimeError(f"cannot determine if type is concrete: {typ}") + + +def is_concrete(func: Any) -> bool: + signature = inspect.signature(func) + params_types = [ + param.annotation if param.annotation is not inspect.Parameter.empty else Any + for param in signature.parameters.values() + ] + + if Any in params_types: + return False + + for typ in params_types: + if not type_is_concrete(typ): + return False + + return True + + +def _arg_to_mlir_types(arg: Any) -> list[Any]: + """ + Helper method to convert an argument to its corresponding MLIR types. + + This method converts numeric meta types and types convertible via `get_mlir_types` + to their corresponding MLIR types. + :param arg: The argument to convert to MLIR types. + + :returns: + A list of MLIR types. + :rtype: list + """ + if isinstance(arg, ir.Type): + return [arg] + elif isinstance(arg, NumericMeta): + return [arg.mlir_type] + return get_mlir_types(arg) + + +def _args_to_mlir_types(args: list[Any]) -> list[Any]: + """ + Helper method to convert an arguments list to its corresponding MLIR types. + + This method converts numeric meta types and types convertible via `get_mlir_types` + to their corresponding MLIR types. + :param args: The arguments list to convert to MLIR types. + :type: list + + :returns: + A list of MLIR types. + :rtype: list + """ + result = [] + for arg in args: + result.extend(_arg_to_mlir_types(arg)) + return result + + +class ExternCallHandler: + """ + Dynamic dispatcher for FFI calls with runtime type resolution. + + Resolves the concrete FFI overload at call time based on argument types. + Supports: + - Multiple `@overload` variants + - TypeVar binding and unification + - Union type matching (tries each alternative) + - Constexpr parameters (compile-time constants in signatures) + - Implicit conversions via custom callback + + Caches FFI instances per concrete type signature to avoid re-creating + prototypes. + """ + + def __init__( + self, + func: Any, + name: str, + inline: bool, + source: BitCode | None, + name_mangler: Any, + overloaded: bool | None, + implicit_convert: Any, + ) -> None: + self.func = func + self.name = name + self.inline = inline + self.source = source + self.name_mangler = name_mangler + self.ffis: dict[Any, Any] = {} + + self.inited = False + self.overloads: list[Any] | None = None + self.overloaded = overloaded + self.implicit_convert = implicit_convert + + def _init(self) -> None: + if self.inited: + return + self.inited = True + + # Note: don't do this in the constructor as MLIR context doesn't exist yet + self.overloads = typing.get_overloads(self.func) # type: ignore[attr-defined] + assert isinstance(self.overloads, list) + if len(self.overloads) == 0: + self.overloads.append(self.func) + + if self.overloaded is None: + self.overloaded = False + if len(self.overloads) > 1: + self.overloaded = True + elif not is_concrete(self.overloads[0]): + self.overloaded = True + + def try_match( + self, args: tuple[Any, ...], overload: Any + ) -> tuple[bool, list[Any] | None, list[Any] | None]: + """ + Attempt to match runtime arguments against an overload signature. + + Args: + args: Runtime argument values + overload: Function overload to match against + + Returns: + Tuple (matched: bool, params_types: list | None, return_types: list | None) + + Matching rules: + - `Any` type always matches + - TypeVar creates binding on first occurrence, checks equality on subsequent + - Union types try each alternative + - Constexpr wraps parameter as ConstValue (compile-time constant) + - NumericMeta, types with .isinstance(), and Python types checked via isinstance + + Edge cases: + - Returns (False, None, None) if any parameter fails to match + - TypeVar bindings are per-call (not cached across calls) + """ + signature = inspect.signature(overload) + params_types = [ + param.annotation if param.annotation is not inspect.Parameter.empty else Any + for param in signature.parameters.values() + ] + return_type = ( + signature.return_annotation + if signature.return_annotation is not inspect.Parameter.empty + else None + ) + + type_var_mapping = {} + params_types_info = [] + + if len(args) != len(params_types): + return False, None, None + + for arg, typ in zip(args, params_types): + # no type always matches, same as a free type variable + if typ is Any: + params_types_info.append((_arg_to_mlir_types(arg), arg, False)) + continue + + match = False + type_set = to_types(typ) + for typ in type_set: + is_const_expr = False + if typing.get_origin(typ) == t.Constexpr: + typ = typing.get_args(typ)[0] + is_const_expr = True + + if isinstance(typ, NumericMeta): + if typ.isinstance(arg): + if isinstance(arg, (int, float)): + arg = typ(arg) + params_types_info.append( + (_arg_to_mlir_types(arg), arg, is_const_expr) + ) + match = True + elif isinstance(typ, TypeVar): + if typ not in type_var_mapping: + type_var_mapping[typ] = _arg_to_mlir_types(arg) + params_types_info.append( + (type_var_mapping[typ], arg, is_const_expr) + ) + match = True + else: + if type_var_mapping[typ] == _arg_to_mlir_types(arg): + params_types_info.append( + (type_var_mapping[typ], arg, is_const_expr) + ) + match = True + elif hasattr(typ, "isinstance"): + if typ.isinstance(arg): + params_types_info.append( + (_arg_to_mlir_types(arg), arg, is_const_expr) + ) + match = True + else: + if isinstance(arg, typ): + params_types_info.append( + (_arg_to_mlir_types(arg), arg, is_const_expr) + ) + match = True + if match: + break + + if not match: + return False, None, None + + concrete_params_types = [] + for types, arg, is_const_expr in params_types_info: + if not is_const_expr: + concrete_params_types.extend(types) + else: + concrete_params_types.append(ConstValue(tuple(types), arg)) + + if return_type in type_var_mapping: + concrete_return_types = type_var_mapping[return_type] + else: + concrete_return_types = _arg_to_mlir_types(return_type) + return True, concrete_params_types, concrete_return_types + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + self._init() + assert self.overloads is not None + + matched = False + params_types = None + return_types = None + for overload in self.overloads: + result, params_types, return_types = self.try_match(args, overload) + if result: + matched = True + break + + if not matched: + raise DSLRuntimeError("failed to find matching overload for call to ffi") + + assert params_types is not None + assert return_types is not None + + if len(return_types) == 1: + return_type = return_types[0] + elif len(return_types) == 0: + return_type = None + else: + raise DSLRuntimeError("multiple return types not supported") + + key = tuple(params_types) + if key not in self.ffis: + self.ffis[key] = FFI( + name=self.name, + overloaded=self.overloaded, # type: ignore[arg-type] + params_types=params_types, + return_type=return_type, + inline=self.inline, + source=self.source, + name_mangler=self.name_mangler, + implicit_convert=self.implicit_convert, + ) + + return self.ffis[key](*args, **kwargs) + + +def extern( + func: Any = None, + *, + name: str | None = None, + inline: bool = True, + source: BitCode | None = None, + name_mangler: Any = None, + overloaded: bool | None = None, + implicit_convert: Any = None, +) -> Any: + """ + Decorator to mark a function as an external FFI call. + + Calls to the function dynamically resolve to a concrete extern function based on runtime + argument types. + + Parameters + ---------- + name : str, optional + External symbol name. Defaults to Python function's name. + inline : bool, default=True + Whether to mark the function and call sites for inlining. + source : BitCode, optional + External bitcode file to link (e.g., BitCode("lib.bc")). + name_mangler : callable, optional + Custom name mangling function. Defaults to `default_name_mangler`. + overloaded : bool, optional + Whether to enable name mangling. Auto-detected if None (True if multiple + `@overload` variants or non-concrete signature). + implicit_convert : callable, optional + Custom callback for implicit type conversions (signature: (arg, typ) -> arg). + + Returns + ------- + A callable that dynamically dispatches to the correct FFI overload. + + Examples + -------- + Basic usage: + + >>> `@extern` + ... def my_func(x: Int32) -> Float32: + ... ... + + With bitcode linking: + + >>> `@extern`(source=BitCode("mylib.bc")) + ... def external_sqrt(x: Float32) -> Float32: + ... ... + + Multiple overloads: + + >>> `@extern` + ... `@overload` + ... def compute(x: Int32) -> Int32: + ... ... + >>> `@overload` + ... def compute(x: Float32) -> Float32: + ... ... + + TypeVar-based generic: + + >>> T = TypeVar('T') + >>> `@extern` + ... def identity(x: T) -> T: + ... ... + + """ + + def decorator(func: Any) -> ExternCallHandler: + return ExternCallHandler( + func, + name or func.__name__, + inline, + source, + name_mangler or default_name_mangler, + overloaded, + implicit_convert, + ) + + if func is None: + return decorator + return decorator(func) + + +class FFI: + """ + Foreign Function Interface (FFI) wrapper for external function invocation. + + This class enables calling external MLIR function prototypes from Python code, handling type conversion, + prototype registration, and dynamic insertion of function symbols into MLIR modules as needed. + + Parameters + ---------- + name : str + Name of the external function. This will be used as the symbol name when calling or registering a prototype in the MLIR module. + params_types : list, optional + List of argument types for the external function. These can be numeric types, numeric meta types, or types convertible via `get_mlir_types`. + return_type : optional + The return type of the external function. If not specified, the function is assumed to have no return value. + inline : bool + Whether the prototype in the MLIR module and all calls to it should be marked as inlined. Default is to inline. + source : optional + Optional source to link when compiling, that provides the implementation of the function. + + Methods + ------- + __call__(*args) + Calls the external function with the given arguments, ensuring argument and result types match the prototype. + """ + + def __init__( + self, + *, + name: str | None, + params_types: list[Any] | None = None, + return_type: Any = None, + inline: bool = True, + source: Any = None, + overloaded: bool = False, + name_mangler: Any = None, + implicit_convert: Any = None, + ) -> None: + self.name = name + self.params_types = params_types or [] + self.return_type = [return_type] if return_type else [] + self.inline = inline + self.source = source + self.overloaded = overloaded + self.name_mangler = name_mangler or default_name_mangler + self.implicit_convert = implicit_convert + + def _get_prototype_region(self, current_op: Any) -> tuple[Any, Any]: + """ + Helper method to determine the appropriate MLIR module and region for inserting a function prototype. + + This method recursively traverses the current operation's parent hierarchy to find the correct module + and region where the function prototype should be inserted. It supports both builtin.module and gpu.module. + :param current_op: The current operation to check. + :type current_op: Operation + + :returns: + A tuple containing the module operation and the insertion region. + :rtype: tuple + """ + if current_op is None: + raise DSLRuntimeError("current operation is unknown") + op_name = current_op.name + if op_name in ["builtin.module", "gpu.module"]: + return current_op, current_op.regions[0].blocks[0] + else: + return self._get_prototype_region(current_op.parent) + + @staticmethod + def _type_check( + callee: Any, exec_types: list[Any], returns_types: list[Any] + ) -> None: + """ + Helper method to check if the function prototype types match the expected types. + + This method compares the input and output types of the function prototype with the provided expected types. + :param callee: The function prototype operation to check. + :type callee: func.FuncOp + :param exec_types: The expected input types. + :type exec_types: list + :param returns_types: The expected output types. + :type returns_types: list + """ + if callee.type.inputs != exec_types or callee.type.results != returns_types: + raise DSLRuntimeError( + f"External prototype types mismatch, trying to call with ({exec_types}) -> ({returns_types}), got {callee.type}\n{callee}" + ) + + @property + def full_name(self) -> str | None: + if not self.overloaded: + return self.name + return self.name_mangler(self) + + @property + def dynamic_params_types(self) -> list[Any]: + return [x for x in self.params_types if not isinstance(x, ConstValue)] + + def _create_prototype_in_region( + self, op: Any, region: Any, exec_args: list[Any] + ) -> Any: + """ + Helper method to create or retrieve a function prototype in the current module. + + This method checks if a function prototype with the given name already exists in the symbol table of the current module. + If it does, it checks if the prototype's types match the expected types. If it does not, it raises an error. + If it does not exist, it creates a new function prototype and inserts it into the current region. + :param op: The module operation to check. + :type op: Operation + :param region: The region to insert the function prototype into. + :type region: Region + :param exec_args: The arguments to pass to the function prototype. + :type exec_args: list + """ + symbol_table = ir.SymbolTable(op.operation) + + if self.full_name in symbol_table: + callee = symbol_table[self.full_name] + else: + with ir.InsertionPoint(region): + if self.source is not None and not isinstance(self.source, BitCode): + raise DSLRuntimeError( + "Linking this kind of source is not supported", + ) + + if isinstance(self.source, BitCode): + # if extern function implementation is in bitcode file, + # add it to link-libraries list to be passed to compiler options later + sources = set() + if "link-libraries" in region.owner.attributes: + sources = set( + x.value for x in region.owner.attributes["link-libraries"] + ) + sources.add(self.source.path) + region.owner.attributes["link-libraries"] = ir.ArrayAttr.get( + [ir.StringAttr.get(x) for x in sorted(sources)] + ) + + callee = func.FuncOp( + self.full_name, + ( + _args_to_mlir_types(self.dynamic_params_types), + _args_to_mlir_types(self.return_type), + ), + ) + callee.sym_visibility = ir.StringAttr.get("private") + callee.no_inline = not self.inline + + # Sanity check the function prototype types match the expected types + self._type_check( + callee, + _args_to_mlir_types(exec_args), + _args_to_mlir_types(self.return_type), + ) + + return callee + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Calls the FFI function prototype with the provided arguments. + + This method ensures that an IR-level function prototype (external declaration) + with the given name and type signature exists in the current module. If it does not + exist, it will be created and inserted into the module. A call operation to this + function is then emitted using the arguments supplied by the caller. + + :param args: + The runtime arguments to pass to the FFI function. These will be converted to + their corresponding numeric types and lowered to MLIR values before being used as arguments. + :type args: tuple + + :returns: + The MLIR call operation created for this invocation. + :rtype: func.CallOp + + :raises DSLRuntimeError: + If there is no active MLIR insertion point or if the current operation + context cannot be determined. + """ + + if kwargs: + raise DSLRuntimeError( + "Keyword arguments are not supported for FFI calls", + suggestion="Use positional arguments only", + ) + + # Get the current insertion point and operation + try: + current_ip = ir.InsertionPoint.current + except Exception: + raise DSLRuntimeError( + "Failed to determine current insertion point", + suggestion="Make sure this is called under a jit context", + ) + current_op = current_ip.block.owner + module_op, insertion_region = self._get_prototype_region(current_op) + + if len(args) != len(self.params_types): + raise DSLRuntimeError( + f"Number of arguments mismatch, expected {len(self.params_types)}, got {len(args)}", + suggestion="Make sure the number of arguments matches the number of parameters", + ) + + # Extract the arguments to MLIR values + exec_args = [] + for arg, typ in zip(args, self.params_types): + if isinstance(typ, ConstValue): + continue + exec_arg = extract_mlir_values(arg) + if not exec_arg: + exec_arg = [as_numeric(arg).ir_value()] + exec_arg = self._do_implicit_conversion(exec_arg, _arg_to_mlir_types(typ)) + exec_args.extend(exec_arg) + + # Create the function prototype in module, so if it's under kernel function, prototype will be inserted into gpu.module + # If it's under gpu.module, prototype will be inserted into builtin.module + callee = self._create_prototype_in_region( + module_op, insertion_region, exec_args + ) + + # Emit the call operation + result = func.call( + callee.type.results, self.full_name, exec_args, no_inline=not self.inline + ) + + if self.return_type: + return result + + def _do_implicit_conversion(self, arg: list[Any], typ: list[Any]) -> list[Any]: + if self.implicit_convert is not None: + arg = self.implicit_convert(arg, typ) + + if len(arg) == 1 and len(typ) == 1: + arg_type = arg[0].type + typ_type = typ[0] + # implicitly cast !llvm.ptr -> !llvm.ptr<> + if ( + isinstance(typ_type, llvm.PointerType) + and isinstance(arg_type, llvm.PointerType) + and typ_type.address_space != arg_type.address_space + and typ_type.address_space == 0 + ): + llvm_ptr_ty = llvm.PointerType.get(0) + llvm_ptr = llvm.addrspacecast(llvm_ptr_ty, arg[0]) + arg = [llvm_ptr] + + return arg + + +__all__ = [ + "extern", + "FFI", + "BitCode", + "mangle", + "ConstValue", +] diff --git a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py index b702493f6..2c3d4ac2c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/jit_executor.py +++ b/python/CuTeDSL/cutlass/base_dsl/jit_executor.py @@ -13,11 +13,13 @@ This module provides jit executor related classes """ +import abc import array import ctypes import inspect import io -from typing import Union, Optional, NamedTuple, Any, Sequence +from typing import Any, NamedTuple, TYPE_CHECKING, ClassVar +from collections.abc import Callable, Sequence import weakref import threading import collections @@ -33,26 +35,31 @@ from .._mlir.dialects import llvm from . import typing as t from .common import DSLRuntimeError, DSLCudaRuntimeError from .runtime import cuda as cuda_helpers -from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr +from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_annotation_constexpr from .typing import get_c_pointers from .utils.logger import log from .utils.timer import timer +if TYPE_CHECKING: + from .dsl import BaseDSL + from .export.export import SignatureProcessor + from .export.c_header_generator import CHeaderGenerator + +@dataclass class CudaModuleAndKernel: """A loaded CUDA kernel and its metadata.""" - def __init__(self, sym, cuda_module, kernel, attrs): - self.sym = sym - self.cuda_module = cuda_module - self.kernel = kernel - self.attrs = attrs + sym: str + cuda_module: Any + kernel: Any + attrs: dict[Any, int] -def get_escaped_cubin_bytes(cubin_data): +def get_escaped_cubin_bytes(cubin_data: bytes) -> bytes: """This function escapes cubin data from mlir raw bytecode to executable binary bytes""" - def ishex(inp): + def ishex(inp: int) -> bool: return (0x30 <= inp < 0x3A) or (0x41 <= inp < 0x47) or (0x61 <= inp < 0x67) converted = bytearray() @@ -74,10 +81,12 @@ def get_escaped_cubin_bytes(cubin_data): return bytes(converted) -def walk_module_and_get_cubin_data(module, sym, callback): +def walk_module_and_get_cubin_data( + module: ir.Module, sym: str, callback: Callable[[str, str, bytes], None] +) -> None: """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.""" - def walk_gpu_binary_op(op): + def walk_gpu_binary_op(op: ir.Operation) -> ir.WalkResult: if op.name != "gpu.binary": return ir.WalkResult.ADVANCE s = io.BytesIO() @@ -101,7 +110,9 @@ def walk_module_and_get_cubin_data(module, sym, callback): module.operation.walk(walk_gpu_binary_op) -def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel]: +def load_kernels_from_ir_module( + module: ir.Module, kernel_info: dict[str, Any] | None +) -> list[CudaModuleAndKernel]: """Loads all kernels from the IR module that match the given set of symbols.""" if not kernel_info: return [] # no modules @@ -114,7 +125,7 @@ def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel for sym in kernel_symbols: log().debug(f"Loading CUDA module for symbol: {sym}") - def walk_callback(sym, func_sym, cubin_data): + def walk_callback(sym: str, func_sym: str, cubin_data: bytes) -> None: if sym in kernel_modules: log().debug(f"Skipping already loaded symbol: {sym}") @@ -136,15 +147,6 @@ def load_kernels_from_ir_module(module, kernel_info) -> list[CudaModuleAndKernel return list(kernel_modules.values()) -class KwargsWrapperSpec(NamedTuple): - """A specification for keyword arguments wrapper.""" - - arg_names: list[str] - arg_defaults: tuple[Any, ...] - kwonly_names: list[str] - kwonly_defaults: dict[str, Any] - - @dataclass class ArgMeta: """Metadata for function arguments.""" @@ -160,123 +162,77 @@ class ArgMeta: arg_count: int +class KwargsWrapperSpec(NamedTuple): + """A specification for keyword arguments wrapper.""" + + arg_names: list[str] + arg_defaults: tuple[Any, ...] + kwonly_names: list[str] + kwonly_defaults: dict[str, Any] + + class ExecutionArgs: """Helper that wraps the function signature spec to filter execution and compile time arguments.""" - def __init__(self, spec, function_name): + def __init__(self, signature: inspect.Signature, function_name: str) -> None: self.function_name = function_name - self.args_spec = spec - if spec is not None: - self.args_spec = self.filter_runtime_arg_spec(spec) - self.original_args_spec = spec + self.signature = self.filter_runtime_signature(signature) + self.original_signature = signature self._missing = object() self._meta = self._build_meta() self._tls = threading.local() - def _build_meta(self): + def _build_meta(self) -> ArgMeta: """ Precompute metadata for the fast-path execution. This metadata is static per function signature. """ - spec = self.args_spec - if spec is None: - return ArgMeta( - pos_names=[], - kwonly_names=[], - all_names=[], - annotated_types=[], - numeric_flags=[], - name_to_index={}, - pos_defaults=[], - kwonly_defaults=[], - arg_count=0, + sig = self.signature + + pos_names = [] + kwonly_names = [] + annotated_types = [] + numeric_flags = [] + name_to_index = {} + pos_defaults = [] + kwonly_defaults = [] + + for name, param in sig.parameters.items(): + if param.kind == inspect.Parameter.KEYWORD_ONLY: + kwonly_names.append(name) + if param.default is not inspect.Parameter.empty: + kwonly_defaults.append(param.default) + else: + kwonly_defaults.append(self._missing) + else: + pos_names.append(name) + if param.default is not inspect.Parameter.empty: + pos_defaults.append(param.default) + else: + pos_defaults.append(self._missing) + annotated_types.append( + param.annotation + if param.annotation is not inspect.Parameter.empty + else None ) - - pos_names = list(spec.args) - kwonly_names = list(spec.kwonlyargs) - all_names = pos_names + kwonly_names - - annotated_types = [spec.annotations.get(n) for n in all_names] - numeric_flags = [isinstance(typ, t.NumericMeta) for typ in annotated_types] - - name_to_index = {n: i for i, n in enumerate(all_names)} - - pos_defaults = [self._missing] * len(pos_names) - if spec.defaults: - start = len(pos_names) - len(spec.defaults) - for i, d in enumerate(spec.defaults): - pos_defaults[start + i] = d - - kwonly_defaults = [self._missing] * len(kwonly_names) - if spec.kwonlydefaults: - for i, n in enumerate(kwonly_names): - if n in spec.kwonlydefaults: - kwonly_defaults[i] = spec.kwonlydefaults[n] + numeric_flags.append(isinstance(param.annotation, t.NumericMeta)) + name_to_index[name] = len(pos_names) + len(kwonly_names) - 1 return ArgMeta( pos_names=pos_names, kwonly_names=kwonly_names, - all_names=all_names, + all_names=pos_names + kwonly_names, annotated_types=annotated_types, numeric_flags=numeric_flags, name_to_index=name_to_index, pos_defaults=pos_defaults, kwonly_defaults=kwonly_defaults, - arg_count=len(all_names), + arg_count=len(pos_names) + len(kwonly_names), ) - def generate_execution_args_positional(self, *args): - """Fast execution for positional-only arguments with exact count match. - - This method is optimized for the common case where: - - All arguments are positional (no kwargs) - - Number of arguments matches the function signature exactly - - :param args: The positional arguments tuple - :return: (exe_args, adapted_args) tuple - """ - exe_arg_chunks = [None] * self._meta.arg_count - adapted_args = [] - - tls = self._tls - adapter_caches = getattr(tls, "adapter_caches", None) - if adapter_caches is None: - adapter_caches = [dict() for _ in range(self._meta.arg_count)] - tls.adapter_caches = adapter_caches - - annotated_types = self._meta.annotated_types - numeric_flags = self._meta.numeric_flags - - for index in range(self._meta.arg_count): - arg = args[index] - - cptr_method = getattr(arg, "__c_pointers__", None) - if cptr_method is not None: - exe_arg_chunks[index] = cptr_method() - continue - - if numeric_flags[index]: - arg = t.cast(arg, annotated_types[index]) - exe_arg_chunks[index] = get_c_pointers(arg) - else: - arg_type = type(arg) - cache = adapter_caches[index] - adapter = cache.get(arg_type) - if adapter is None: - adapter = JitArgAdapterRegistry.get_registered_adapter(arg_type) - if adapter is not None: - cache[arg_type] = adapter - - if adapter is not None: - arg = adapter(arg) - adapted_args.append(arg) - - exe_arg_chunks[index] = get_c_pointers(arg) - - exe_args = [p for chunk in exe_arg_chunks for p in chunk] - return exe_args, adapted_args - - def get_rectified_args(self, args, kwargs): + def get_rectified_args( + self, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> list[Any]: """ This function is used to rectify the args and kwargs to a final runtime argument list according to the args_spec. """ @@ -315,7 +271,7 @@ class ExecutionArgs: # Fill keyword slots with the values from the caller for name, value in kwargs.items(): - idx = self._meta.name_to_index.get(name) + idx = self._meta.name_to_index.get(name) # type: ignore[assignment] if idx is None: raise DSLRuntimeError( "unexpected keyword argument", @@ -347,23 +303,31 @@ class ExecutionArgs: return rectified - def generate_execution_args(self, args, kwargs): + def generate_execution_args( + self, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[list[Any], list[Any]]: """ This function is the prune version of `generate_mlir_function_types` which only generates execution args to get rid of mlir context. """ - if not kwargs and len(args) == self._meta.arg_count: - return self.generate_execution_args_positional(*args) + n = self._meta.arg_count + extra_args = args[n:] if len(args) > n else () + args = args[:n] if len(args) > n else args + + exe_arg_chunks: list[list[ctypes.c_void_p] | None] = [None] * n - exe_arg_chunks = [None] * self._meta.arg_count adapted_args = [] tls = self._tls adapter_caches = getattr(tls, "adapter_caches", None) if adapter_caches is None: - adapter_caches = [dict() for _ in range(self._meta.arg_count)] + adapter_caches = [dict() for _ in range(n)] tls.adapter_caches = adapter_caches - input_args = self.get_rectified_args(args, kwargs) + input_args: Sequence[Any] + if not kwargs and len(args) == n: + input_args = args + else: + input_args = self.get_rectified_args(args, kwargs) for index, arg in enumerate(input_args): cptr_method = getattr(arg, "__c_pointers__", None) @@ -371,17 +335,15 @@ class ExecutionArgs: exe_arg_chunks[index] = cptr_method() continue - arg_type_anno = self._meta.annotated_types[index] - if self._meta.numeric_flags[index]: - arg = t.cast(arg, arg_type_anno) + arg = t.cast(arg, self._meta.annotated_types[index]) # type: ignore[arg-type] exe_arg_chunks[index] = get_c_pointers(arg) else: arg_type = type(arg) cache = adapter_caches[index] adapter = cache.get(arg_type) if adapter is None: - adapter = JitArgAdapterRegistry.get_registered_adapter(arg_type) + adapter = JitArgAdapterRegistry.get_registered_adapter(arg) if adapter is not None: cache[arg_type] = adapter @@ -391,7 +353,15 @@ class ExecutionArgs: exe_arg_chunks[index] = get_c_pointers(arg) - exe_args = [p for chunk in exe_arg_chunks for p in chunk] + exe_args = [p for chunk in exe_arg_chunks for p in chunk] # type: ignore[union-attr] + + # Insert extra auxiliary arguments if any. + for arg in extra_args: + cptr_method = getattr(arg, "__c_pointers__", None) + if cptr_method is not None: + exe_args.extend(cptr_method()) + else: + exe_args.append(arg) return exe_args, adapted_args @@ -399,49 +369,30 @@ class ExecutionArgs: self, exclude_arg_names: Sequence[str] = () ) -> KwargsWrapperSpec: """ - This function is used to get the kwargs wrapper spec from the original args_spec. + This function is used to get the kwargs wrapper spec from the original signature. """ excluded_arg_names = set(exclude_arg_names) - arg_spec = self.original_args_spec - - if arg_spec.defaults: - defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) - else: - defaults_start_idx = len(arg_spec.args) + sig = self.original_signature arg_names = [] arg_defaults = [] kwonly_names = [] kwonly_defaults = {} - # Filter arguments and maintain their properties - for i, arg_name in enumerate(arg_spec.args): - arg_type = arg_spec.annotations.get(arg_name, None) - - # Skip compile-time arguments - if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): + for i, (name, param) in enumerate(sig.parameters.items()): + is_kwonly = param.kind == inspect.Parameter.KEYWORD_ONLY + annotation = param.annotation + if ( + is_arg_annotation_constexpr(annotation, name, i, None) + or name in excluded_arg_names + ): continue - if arg_name in excluded_arg_names: - continue - arg_names.append(arg_name) - - if i >= defaults_start_idx: - arg_defaults.append(arg_spec.defaults[i - defaults_start_idx]) - - if arg_spec.kwonlyargs: - for i, kwarg in enumerate(arg_spec.kwonlyargs): - arg_type = arg_spec.annotations.get(kwarg, None) - - # Skip compile-time arguments - if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): - continue - - if kwarg in excluded_arg_names: - continue - - kwonly_names.append(kwarg) - if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: - kwonly_defaults[kwarg] = arg_spec.kwonlydefaults[kwarg] + arg_names.append(name) if not is_kwonly else kwonly_names.append(name) + if param.default is not inspect.Parameter.empty: + if is_kwonly: + kwonly_defaults[name] = param.default + else: + arg_defaults.append(param.default) return KwargsWrapperSpec( arg_names=arg_names, @@ -450,165 +401,100 @@ class ExecutionArgs: kwonly_defaults=kwonly_defaults, ) - def get_rectified_args_from_original_args(self, full_args, full_kwargs): + def get_rectified_args_from_original_args( + self, full_args: Sequence[Any], full_kwargs: dict[str, Any] + ) -> tuple[Any]: """ This function is used to rectify the original arguments to the runtime - arguments that matched the original args_spec. + arguments that matched the original signature. :param full_args: The original full arguments to filter. :param full_kwargs: The original full keyword arguments to filter. :return: The filtered arguments and keyword arguments. """ - arg_spec = self.original_args_spec + sig = self.original_signature + try: + bound_args = sig.bind_partial(*full_args, **full_kwargs) + bound_args.apply_defaults() + except Exception as e: + raise DSLRuntimeError( + "failed to bind arguments to function signature", + cause=e, + ) - if arg_spec.defaults: - defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) - else: - defaults_start_idx = len(arg_spec.args) - - runtime_args = [] - - # Filter arguments and maintain their properties - for i, arg_name in enumerate(arg_spec.args): - arg_type = arg_spec.annotations.get(arg_name, None) - - # Skip compile-time arguments - if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): + # Filter out the constexpr arguments + for i, (name, param) in enumerate(sig.parameters.items()): + if is_arg_annotation_constexpr(param.annotation, name, i, None): + bound_args.arguments.pop(name) continue - # Check if argument was provided by user, otherwise use default - if i < len(full_args): - # User provided this argument - use it - runtime_args.append(full_args[i]) - elif i >= defaults_start_idx: - # Argument not provided, but has default - use default - default_idx = i - defaults_start_idx - runtime_args.append(arg_spec.defaults[default_idx]) - else: - # Required argument missing - raise DSLRuntimeError( - f"Missing required argument '{arg_name}' at position {i}", - context={ - "function_name": self.function_name, - "expected_args": len(arg_spec.args), - "provided_args": len(full_args), - }, - ) + # Once the constexpr arguments are filtered out, we need to convert the bound arguments to the signature's type. + bound_args = inspect.BoundArguments(self.signature, bound_args.arguments) + runtime_args = bound_args.args + runtime_kwargs = bound_args.kwargs - # Filter keyword-only arguments - runtime_kwargs = {} - if arg_spec.kwonlyargs: - for i, kwarg in enumerate(arg_spec.kwonlyargs): - arg_type = arg_spec.annotations.get(kwarg, None) + pos_count = sum( + 1 + for param in self.signature.parameters.values() + if param.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ) + kw_count = sum( + 1 + for param in self.signature.parameters.values() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ) - # Skip compile-time arguments - if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): - continue - - # Keep runtime keyword-only arguments - if kwarg in full_kwargs: - runtime_kwargs[kwarg] = full_kwargs[kwarg] - elif arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: - runtime_kwargs[kwarg] = arg_spec.kwonlydefaults[kwarg] - - if len(runtime_args) != len(self.args_spec.args) or len(runtime_kwargs) != len( - self.args_spec.kwonlyargs - ): + if len(runtime_args) != pos_count or len(runtime_kwargs) != kw_count: raise DSLRuntimeError( "input args/kwargs length does not match runtime function signature!", context={ "input args length": len(runtime_args), "input kwargs length": len(runtime_kwargs), - "function signature args length": len(self.args_spec.args), - "function signature kwonlyargs length": len( - self.args_spec.kwonlyargs - ), + "function signature args length": pos_count, + "function signature kwonlyargs length": kw_count, }, ) - return runtime_args + list(runtime_kwargs.values()) + return runtime_args + tuple(runtime_kwargs.values()) - def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec): - runtime_args = [] - runtime_annotations = {} - runtime_defaults = [] - - # Calculate the offset where defaults start in the original args - if arg_spec.defaults: - defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) - else: - defaults_start_idx = len(arg_spec.args) - - # Filter arguments and maintain their properties - for i, arg_name in enumerate(arg_spec.args): - arg_type = arg_spec.annotations.get(arg_name, None) - - # Skip compile-time arguments - if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): + def filter_runtime_signature(self, sig: inspect.Signature) -> inspect.Signature: + filtered_params = [] + for i, (name, param) in enumerate(sig.parameters.items()): + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + filtered_params.append(param) continue - # Keep runtime arguments - runtime_args.append(arg_name) - if arg_name in arg_spec.annotations: - runtime_annotations[arg_name] = arg_type + annotation = param.annotation - # Keep corresponding default if it exists - if i >= defaults_start_idx: - default_idx = i - defaults_start_idx - runtime_defaults.append(arg_spec.defaults[default_idx]) + if is_arg_annotation_constexpr(annotation, name, i, None): + continue - # Filter kwonlyargs and their defaults - runtime_kwonlyargs = [] - runtime_kwonlydefaults = {} + filtered_params.append(param) - if arg_spec.kwonlyargs: - for i, kwarg in enumerate(arg_spec.kwonlyargs): - arg_type = arg_spec.annotations.get(kwarg, None) + return sig.replace(parameters=filtered_params) - # Apply same filtering logic - if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): - continue - - runtime_kwonlyargs.append(kwarg) - if kwarg in arg_spec.annotations: - runtime_annotations[kwarg] = arg_type - if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: - runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg] - - # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec) - runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None - - return inspect.FullArgSpec( - args=runtime_args, - varargs=arg_spec.varargs, # Keep original varargs - varkw=arg_spec.varkw, # Keep original varkw - defaults=runtime_defaults, - kwonlyargs=runtime_kwonlyargs, - kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None, - annotations=runtime_annotations, - ) - - def get_constexpr_args(self) -> list[dict[str, Union[int, str]]]: + def get_constexpr_args(self) -> list[dict[str, Any]]: """ This function returns the constexpr args that have been pruned from the original function signature. The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). - :rtype: list[dict[str, Union[int, str]]] + :rtype: list[dict[str, int | str | None]] """ - if self.original_args_spec is None: + if self.original_signature is None: return list() - constexpr_args = list() - for i, arg_name in enumerate(self.original_args_spec.args): - if arg_name not in self.args_spec.args: - constexpr_args.append({"argument_index": i, "argument_name": arg_name}) - if self.original_args_spec.kwonlyargs: - for kwarg in self.original_args_spec.kwonlyargs: - if kwarg not in self.args_spec.kwonlyargs: - constexpr_args.append( - {"argument_index": None, "argument_name": kwarg} - ) + constexpr_args = list() + for i, (name, param) in enumerate(self.original_signature.parameters.items()): + if name not in self.signature.parameters.keys(): + constexpr_args.append({"argument_index": i, "argument_name": name}) return constexpr_args @@ -618,13 +504,13 @@ class JitExecuteContext: def __init__( self, module: "JitModule", - kernel_fns=None, - context: Optional[cuda_helpers.DevicePrimaryContext] = None, - ): + kernel_fns: list[Any] | None = None, + context: cuda_helpers.DevicePrimaryContext | None = None, + ) -> None: if kernel_fns is None: kernel_fns = [] self.module = module - self.kernel_functions = kernel_fns + self.kernel_functions: list[Any] = kernel_fns self.kernel_functions_ptrs = [ctypes.c_void_p(k.getPtr()) for k in kernel_fns] self.context = context @@ -634,18 +520,18 @@ class JitModule: def __init__( self, - engine, - capi_func, - args_spec: ExecutionArgs, + engine: Any, + capi_func: Any, + execution_args: ExecutionArgs, modules: list[CudaModuleAndKernel], - ): + ) -> None: self.engine = engine self.capi_func = capi_func - self.args_spec = args_spec + self.execution_args = execution_args self.cuda_modules = modules self._unloaded = False - def get_device_execute_context(self, device=None) -> JitExecuteContext: + def get_device_execute_context(self, device: Any = None) -> JitExecuteContext: if self._unloaded: raise RuntimeError(f"Can not get executor for unloaded module.") @@ -682,7 +568,7 @@ class JitModule: # environment variable. return JitExecuteContext(self, kernel_fns, context) - def unload(self): + def unload(self) -> None: try: for m in set([m.cuda_module for m in self.cuda_modules]): cuda_helpers.unload_library(m) @@ -692,7 +578,7 @@ class JitModule: finally: self._unloaded = True - def __del__(self): + def __del__(self) -> None: self.unload() @@ -705,10 +591,10 @@ class JitExecutor: def __init__( self, - jit_module: Union[JitModule, "CudaDialectJitModule"], - exec_context: Optional[JitExecuteContext], + jit_module: JitModule | Any, + exec_context: JitExecuteContext | None, jit_time_profiling: bool, - ): + ) -> None: # JitExecutor will keep JitCompiledFunction alive so that the underlying # ExecutionEngine and module data is not discarded until runtime callables # are garbage collected. @@ -725,14 +611,14 @@ class JitExecutor: self._has_cuda_result = self.cuda_result is not None self._has_profiler = self.profiler is not None self._cuda_result_addr = ( - ctypes.addressof(self.cuda_result) if self._has_cuda_result else None + ctypes.addressof(self.cuda_result) if self._has_cuda_result else None # type: ignore[arg-type] ) if jit_time_profiling: - self._get_invoke_packed_args_func = self.profiler( + self._get_invoke_packed_args_func = self.profiler( # type: ignore[misc] self._get_invoke_packed_args ) - self.capi_func = self.profiler(self.jit_module.capi_func) + self.capi_func = self.profiler(self.jit_module.capi_func) # type: ignore[misc] else: self._get_invoke_packed_args_func = self._get_invoke_packed_args self.capi_func = self.jit_module.capi_func @@ -741,6 +627,7 @@ class JitExecutor: self._num_extra_args = 0 if self._has_cuda_result: self._num_extra_args += 1 + self._kernel_ptrs: list[ctypes.c_void_p] | None if self.exec_context is not None: self._kernel_ptrs = self.exec_context.kernel_functions_ptrs self._num_extra_args += len(self._kernel_ptrs) @@ -749,7 +636,9 @@ class JitExecutor: self._tls = threading.local() # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`. - def _get_invoke_packed_args(self, exe_args): + def _get_invoke_packed_args( + self, exe_args: list[Any] + ) -> ctypes.Array[ctypes.c_void_p]: # Pre-calculate sizes once during init and cache num_base_args = len(exe_args) total_args = num_base_args + self._num_extra_args @@ -781,28 +670,34 @@ class JitExecutor: return packed_args - def generate_execution_args(self, *args, **kwargs): - return self.jit_module.args_spec.generate_execution_args(args, kwargs) + def generate_execution_args( + self, *args: Any, **kwargs: Any + ) -> tuple[list[Any], list[Any]]: + return self.jit_module.execution_args.generate_execution_args(args, kwargs) - def run_compiled_program(self, exe_args): + def run_compiled_program(self, exe_args: list[Any]) -> int | None: try: packed_args = self._get_invoke_packed_args_func(exe_args) self.capi_func(packed_args) if not self._has_cuda_result: return None - error_code = self.cuda_result.value + error_code = self.cuda_result.value # type: ignore[union-attr] if error_code == 0: return error_code - error_name = cuda_helpers._cudaGetErrorEnum( - cuda_helpers.cuda.CUresult(error_code) - ) + # Try to get the error name, but handle unknown error codes gracefully + try: + cu_result = cuda_helpers.cuda.CUresult(error_code) + error_name = cuda_helpers._cudaGetErrorEnum(cu_result) + except (ValueError, AttributeError): + # Error code not recognized by the enum or other error getting the name + error_name = f"" raise DSLCudaRuntimeError(error_code, error_name) except DSLCudaRuntimeError as e: raise e except Exception as e: raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> int | None: exe_args, adapted_args = self.generate_execution_args(*args, **kwargs) return self.run_compiled_program(exe_args) @@ -811,11 +706,11 @@ class JitExecutor: class JitFunctionArtifacts: """Holds artifacts for a JIT-compiled function.""" - PTX: str - CUBIN: str - MLIR: str + PTX: str | None + CUBIN: str | bytes | None + MLIR: str | None - def __post_init__(self): + def __post_init__(self) -> None: if self.PTX is not None and os.path.exists(self.PTX): try: with open(self.PTX, "r") as f: @@ -827,7 +722,7 @@ class JitFunctionArtifacts: with open(self.CUBIN, "rb") as f: self.CUBIN = f.read() except (IOError, OSError) as e: - raise DSLRuntimeError(f"Failed to read CUBIN file '{self.CUBIN}': {e}") + raise DSLRuntimeError(f"Failed to read CUBIN file '{self.CUBIN}': {e}") # type: ignore[str-bytes-safe] if self.MLIR is not None and os.path.exists(self.MLIR): try: with open(self.MLIR, "r") as f: @@ -836,56 +731,65 @@ class JitFunctionArtifacts: raise DSLRuntimeError(f"Failed to read MLIR file '{self.MLIR}': {e}") +@dataclass class ExportProvider: """Holds the dsl specific settings for the export of the jit-compiled function.""" - dsl: "Type[BaseDSL]" = None - arg_spec_processor: "ArgsSpecProcessor" = None - c_header_generator: "CHeaderGenerator" = None - object_file_version: str = None + dsl: "type[BaseDSL]" + signature_processor: "SignatureProcessor" + c_header_generator: "CHeaderGenerator" + object_file_version: str + mlirExecutionEngine: Any - def __init__( - self, - dsl: "Type[BaseDSL]", - arg_spec_processor: "ArgsSpecProcessor", - c_header_generator: "CHeaderGenerator", - object_file_version: str, - mlirExecutionEngine: "MlirExecutionEngine", - ): - self.dsl = dsl - self.arg_spec_processor = arg_spec_processor - self.c_header_generator = c_header_generator - self.object_file_version = object_file_version - self.mlirExecutionEngine = mlirExecutionEngine + +class AuxRuntimeFunc(abc.ABC): + """Abstract base class for auxiliary runtime host functions compiled by the DSL. + + Subclasses declare a ``name`` class attribute matching the host-function + symbol suffix (e.g. ``"queryDeviceWorkspace"``) and implement ``__call__`` + with the user-facing argument signature. + + Instances are created by :meth:`JitCompiledFunction.get_aux_func`. + """ + + name: ClassVar[str] # subclasses must set this + + def __init__(self, func_ptr: int, execution_args: ExecutionArgs) -> None: + """Declare the constructor contract for :meth:`JitCompiledFunction.get_aux_func`. + + Subclasses replace this with an implementation; they are not required to + call ``super().__init__``. + """ class JitCompiledFunction: """Holds a compiled function.""" - export_provider: ExportProvider = None + export_provider: ExportProvider | None = None def __init__( self, - ir_module, - engine, - capi_func, - args_spec, - function_name, - kernel_info, - jit_time_profiling, - jit_function_artifacts, - prefix=None, - load_from_binary=False, - dynamic_args=None, - dynamic_kwargs=None, - ): + ir_module: ir.Module, + engine: Any, + capi_func: Any, + signature: inspect.Signature | None, + function_name: str, + kernel_info: dict[str, Any] | None, + jit_time_profiling: bool, + jit_function_artifacts: JitFunctionArtifacts | None, + prefix: str | None = None, + load_from_binary: bool = False, + dynamic_args: tuple[Any] = tuple[Any](), + dynamic_kwargs: dict[str, Any] = dict[str, Any](), + has_gpu_module: bool = True, + ) -> None: self.ir_module = ir_module self.engine = engine self.capi_func = capi_func self.function_name = function_name - self.kernel_info = kernel_info - if args_spec is not None: - self.args_spec = ExecutionArgs(args_spec, self.function_name) + self.kernel_info = kernel_info if kernel_info is not None else dict[str, Any]() + if signature is not None: + self.execution_args = ExecutionArgs(signature, self.function_name) self.jit_time_profiling = jit_time_profiling assert ( @@ -899,29 +803,31 @@ class JitCompiledFunction: # This runtime state is stored here so that we can preserve the module # in the compiler cache. Callers can extend the lifetime of the module # by creating and retaining the executor. - self.jit_module = None + self.jit_module: JitModule | None = None self._executor_lock = threading.RLock() - self._default_executor = None + self._default_executor: JitExecutor | None = None # This is used to do early generation of the c header arguments to release the reference to the dynamic arguments. self._generate_c_header_arguments(dynamic_args, dynamic_kwargs) + self.has_gpu_module = has_gpu_module + @property - def __ptx__(self): + def __ptx__(self) -> str | None: """Returns the PTX code of the JIT-compiled function.""" return self.artifacts.PTX if self.artifacts is not None else None @property - def __cubin__(self): + def __cubin__(self) -> str | bytes | None: """Returns the CUBIN data of the JIT-compiled function.""" return self.artifacts.CUBIN if self.artifacts is not None else None @property - def __mlir__(self): + def __mlir__(self) -> str | None: """Returns the MLIR code of the JIT-compiled function.""" return self.artifacts.MLIR if self.artifacts is not None else None - def _deserializer(self): + def _deserializer(self) -> list[CudaModuleAndKernel]: """Load the cuda module from the binary execution engine. This function will be injected as the JitCompiledFunction method which will be called by the jit executor to load the cuda module by AOT flow. @param self: The JitCompiledFunction object. This is the JitCompiledFunction object to load the cuda module. @@ -940,7 +846,10 @@ class JitCompiledFunction: ) cubin_module = cuda_helpers.load_library_data(cubin_data) # load cuda module/get function pointer from module and cache - kernel_modules = collections.OrderedDict() + kernel_modules: collections.OrderedDict[str, CudaModuleAndKernel] = ( + collections.OrderedDict() + ) + assert self.kernel_info is not None for sym, attrs in self.kernel_info.items(): kernel = cuda_helpers.get_library_kernel(cubin_module, sym) if cuda_helpers.get_driver_version() >= 11080: @@ -950,14 +859,14 @@ class JitCompiledFunction: kernel_modules[sym] = CudaModuleAndKernel(sym, cubin_module, kernel, attrs) return list(kernel_modules.values()) - def _validate_engine(self): + def _validate_engine(self) -> None: if self.engine is None: raise DSLRuntimeError( "The compiled function does not have a valid execution engine.", suggestion="For cross-compilation, please use `JitCompiledFunction.export_to_c` to serialize the compiled function and load/execute it on target device.", ) - def to(self, device=None) -> JitExecutor: + def to(self, device: Any = None) -> JitExecutor: """Returns an executable function bound to the given device. For multi-device execution this method can be called for each device where @@ -977,31 +886,95 @@ class JitCompiledFunction: self.ir_module, self.kernel_info ) self.jit_module = JitModule( - self.engine, self.capi_func, self.args_spec, cuda_modules + self.engine, self.capi_func, self.execution_args, cuda_modules ) # Create a new executor that will be tied to a device context - # n.b. host only moduels do not load device specific modules or context. + # n.b. host only modules do not load device specific modules or context. context = self.jit_module.get_device_execute_context(device) return JitExecutor(self.jit_module, context, self.jit_time_profiling) - def generate_execution_args(self, *args, **kwargs): - return self.args_spec.generate_execution_args(args, kwargs) + def generate_execution_args( + self, *args: Any, **kwargs: Any + ) -> tuple[list[Any], list[Any]]: + return self.execution_args.generate_execution_args(args, kwargs) - def __call__(self, *args, **kwargs): + def get_aux_func( + self, func_class: type[AuxRuntimeFunc], kernel: Callable[..., Any] + ) -> AuxRuntimeFunc: + """Look up and return an auxiliary runtime function for a specific kernel. + + ``kernel`` must be a ``@dsl_name.kernel``-annotated callable that was called + inside the ``@dsl_name.jit`` function that produced this compiled object. + The lookup resolves the symbol ``{kernel_name}_{func_class.name}`` for + that specific kernel. + + :param func_class: A subclass of :class:`AuxRuntimeFunc` whose + ``name`` class attribute identifies the host function suffix. + :param kernel: A ``@dsl_name.kernel``-annotated callable. Must have been + called at least once inside a ``@dsl_name.jit`` function so that + ``_dsl_kernel_name`` is set. + :return: An instance of ``func_class`` initialised with the matched + function pointer and ready to call. + :raises TypeError: If ``func_class`` is not a subclass of + :class:`AuxRuntimeFunc`. + :raises ValueError: If ``kernel`` has no ``_dsl_kernel_name`` attribute. + :raises DSLRuntimeError: If no matching symbol is found in the JIT engine. + """ + if not ( + isinstance(func_class, type) and issubclass(func_class, AuxRuntimeFunc) + ): + raise TypeError( + f"func_class must be a subclass of AuxRuntimeFunc, got {func_class!r}" + ) + + # Unwrap bound methods then @wraps wrappers to reach the original funcBody. + func_body = getattr(kernel, "__func__", kernel) # bound method → function + func_body = getattr(func_body, "__wrapped__", func_body) # jit_wrapper → func + kernel_name = getattr(func_body, "_dsl_kernel_name", None) + if kernel_name is None: + raise ValueError( + f"kernel {kernel!r} has no '_dsl_kernel_name' attribute. " + "Make sure it has been called at least once inside a @cute.jit function." + ) + + self._validate_engine() + + sym_name = func_class.name + candidate = f"{kernel_name}_{sym_name}" + candidates = [candidate] + if self.prefix is not None: + candidates = [f"_mlir_{self.prefix}_{candidate}"] + candidates + + fn_ptr = None + for c in candidates: + fn_ptr = self.engine.raw_lookup(c) + if fn_ptr: + break + + if not fn_ptr: + raise DSLRuntimeError( + f"Host function '{sym_name}' not found in JIT engine. " + f"Tried: {candidates}" + ) + return func_class(fn_ptr, self.execution_args) + + def __call__(self, *args: Any, **kwargs: Any) -> int | None: """Executes the jit-compiled function under the currently active CUDA context. Calling this method multiple devices is not allowed and will result in unexpected CUDA errors. If you need to call the kernel on multiple devices use `to` to return a per-device function. """ - exe_args, adapted_args = self.args_spec.generate_execution_args(args, kwargs) + exe_args, adapted_args = self.execution_args.generate_execution_args( + args, kwargs + ) executor = self._default_executor if executor is not None: # Only lock on first call return executor.run_compiled_program(exe_args) return self.run_compiled_program(exe_args) - def run_compiled_program(self, exe_args): + def run_compiled_program(self, exe_args: list[Any]) -> int | None: """Executes the jit-compiled function under the currently active CUDA context. Calling this method multiple devices is not allowed and will result in unexpected @@ -1015,21 +988,28 @@ class JitCompiledFunction: # object alive as it hold a reference to self. proxy_self = weakref.proxy(self) self._default_executor = proxy_self.to(None) + assert self._default_executor is not None return self._default_executor.run_compiled_program(exe_args) - def _generate_c_header_arguments(self, dynamic_args, dynamic_kwargs): + def _generate_c_header_arguments( + self, + dynamic_args: tuple[Any], + dynamic_kwargs: dict[str, Any], + ) -> None: """Generates the c header arguments for the AOT C header generation.""" self.c_header_arguments = None from .export import CHeaderArguments - if dynamic_args is not None or dynamic_kwargs is not None: + if dynamic_args or dynamic_kwargs: self.dummy_prefix_name = "dummy_prefix_name" try: # This arguments may be generated failure due to not all the arguments (e.g. custom types) are supported by the AOT C header generator. + assert self.export_provider is not None + assert self.export_provider.c_header_generator is not None c_header_arguments, packed_args, declarations = ( self.export_provider.c_header_generator._generate_arguments( self.dummy_prefix_name, - self.args_spec, + self.execution_args, dynamic_args, dynamic_kwargs, ) @@ -1043,7 +1023,11 @@ class JitCompiledFunction: ) except Exception as e: self.c_header_arguments = CHeaderArguments( - self.dummy_prefix_name, [], [], [], str(e) + self.dummy_prefix_name, + [], + [], + [], + str(e), ) def dump_to_object( @@ -1066,16 +1050,16 @@ class JitCompiledFunction: export_module = encode_metadata_into_ir_module( function_prefix, export_module, - self.args_spec.args_spec, + self.execution_args.signature, self.function_name, self.kernel_info, - self.export_provider.arg_spec_processor, + self.export_provider.signature_processor, self.export_provider.object_file_version, ) cubin_data = None - def strip_gpu_binary_op(op): + def strip_gpu_binary_op(op: ir.Operation) -> ir.WalkResult: if op.name == "gpu.binary": s = io.BytesIO() op.operation.write_bytecode(s) @@ -1127,7 +1111,7 @@ class JitCompiledFunction: file_path: str, file_name: str, function_prefix: str = "", - ): + ) -> None: """Exports the jit-compiled function to a C compatible files(header/library). This is used for c/cpp AOT support. The `file_path` will be used as the directory to save the header and object files. @@ -1159,13 +1143,15 @@ class JitCompiledFunction: export_module = get_export_module(self.ir_module, function_prefix) # Generate the c header file + assert self.export_provider.c_header_generator is not None + assert self.export_provider.dsl is not None header_file_content = self.export_provider.c_header_generator( function_prefix, export_module, - self.args_spec, + self.execution_args, self.function_name, self.kernel_info, - self.c_header_arguments, + self.c_header_arguments, # type: ignore[arg-type] self.export_provider.dsl._get_dsl().name, ) try: diff --git a/python/CuTeDSL/cutlass/base_dsl/leaf_utils.py b/python/CuTeDSL/cutlass/base_dsl/leaf_utils.py new file mode 100644 index 000000000..9f1340123 --- /dev/null +++ b/python/CuTeDSL/cutlass/base_dsl/leaf_utils.py @@ -0,0 +1,601 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import dataclasses +from types import SimpleNamespace +from typing import Any + +from cutlass._mlir import ir + + +# ============================================================================= +# Helper functions +# ============================================================================= + + +def is_frozen_dataclass(obj_or_cls: Any) -> bool: + """Check if an object or class is a frozen dataclass.""" + cls = obj_or_cls if isinstance(obj_or_cls, type) else type(obj_or_cls) + if not dataclasses.is_dataclass(cls): + return False + params = getattr(cls, "__dataclass_params__", None) + if params is not None: + return getattr(params, "frozen", False) + return False + + +def _is_dynamic_expression(obj: Any) -> bool: + """Check if object implements the DynamicExpression protocol.""" + if isinstance(obj, type): + return False + return hasattr(obj, "__extract_mlir_values__") and hasattr( + obj, "__new_from_mlir_values__" + ) + + +def _is_assignable_leaf(obj: Any) -> bool: + """ + Check if object is an assignable leaf. + + Assignable leaves are things the language allows us to assign to: + - ir.Value: directly assignable + - DynamicExpression: use __extract_mlir_values__ / __new_from_mlir_values__ + - DSL types with .value that is ir.Value: can update .value + + Objects whose class sets __cls_traversable_dict__ = True are always + treated as containers (not leaves), even if they implement the + DynamicExpression protocol. This allows gather_leaves to recurse + into their __dict__ while the class still provides + __extract_mlir_values__ / __new_from_mlir_values__ to the framework. + """ + # If the class declares itself as dict-traversable, treat as container, not leaf + if getattr(type(obj), "__cls_traversable_dict__", False): + return False + + if isinstance(obj, ir.Value): + return True + + if _is_dynamic_expression(obj): + return True + + if hasattr(obj, "value") and isinstance(getattr(obj, "value", None), ir.Value): + return True + + return False + + +def _flatten_to_ir_values(values_dict: Any) -> list[ir.Value]: + """Flatten a values_dict from __extract_mlir_values__ to list of ir.Values.""" + result = [] + + if isinstance(values_dict, ir.Value): + result.append(values_dict) + elif isinstance(values_dict, dict): + for v in values_dict.values(): + result.extend(_flatten_to_ir_values(v)) + elif isinstance(values_dict, (list, tuple)): + for v in values_dict: + result.extend(_flatten_to_ir_values(v)) + + return result + + +def _unflatten_ir_values( + template: Any, values: list[ir.Value], idx: list[int] | None = None +) -> Any: + """Unflatten ir.Values back into a values_dict structure matching template.""" + if idx is None: + idx = [0] + + if isinstance(template, ir.Value): + result = values[idx[0]] if idx[0] < len(values) else template + idx[0] += 1 + return result + elif isinstance(template, dict): + return {k: _unflatten_ir_values(v, values, idx) for k, v in template.items()} + elif isinstance(template, list): + return [_unflatten_ir_values(v, values, idx) for v in template] + elif isinstance(template, tuple): + return tuple(_unflatten_ir_values(v, values, idx) for v in template) + else: + return template + + +def _get_all_attrs(obj: Any) -> dict[str, Any]: + """Get all attributes from an object via __dict__ and __slots__.""" + attrs = {} + if hasattr(obj, "__dict__"): + attrs.update(obj.__dict__) + for cls in type(obj).__mro__: + if hasattr(cls, "__slots__"): + for slot in cls.__slots__: + if hasattr(obj, slot) and slot not in attrs: + try: + attrs[slot] = getattr(obj, slot) + except AttributeError: + pass + return attrs + + +def _unwrap_ir_value(val: Any) -> ir.Value | None: + """ + Extract the ir.Value from a value, handling value casters. + + MLIR Python bindings auto-cast certain types (like PtrType -> _Pointer). + These wrapped types store the ir.Value in .value attribute. + """ + if isinstance(val, ir.Value): + return val + if hasattr(val, "value") and isinstance(val.value, ir.Value): + return val.value + return None + + +# ============================================================================= +# LeafInfo +# ============================================================================= + + +class LeafInfo: + """ + Information about an assignable leaf and its location. + + A "leaf" is something the language allows us to assign to: + - ir.Value: can be replaced in parent container + - DynamicExpression: use __extract_mlir_values__ / __new_from_mlir_values__ + - DSL types with .value (Int32, _Pointer, etc.): can update .value + + We track where each leaf lives so we can update it after a loop/branch. + + Attributes: + obj: The leaf object itself (DSL type or ir.Value) + parent: The parent object containing this leaf (can be None for DSL types) + key: The field name or index to access this leaf from parent + key_type: 'attr', 'list', 'dict', 'root' + path: Human-readable path from root (for debugging) + """ + + def __init__(self, obj: Any, parent: Any, key: Any, key_type: str, path: str): + self.obj = obj + self.parent = parent + self.key = key + self.key_type = key_type + self.path = path + self._extracted_values = None + + def get_ir_values(self) -> list[ir.Value]: + """Get all ir.Values from this leaf (may be multiple for DynamicExpression).""" + if isinstance(self.obj, ir.Value): + return [self.obj] + + if _is_dynamic_expression(self.obj): + values_dict = self.obj.__extract_mlir_values__() + return _flatten_to_ir_values(values_dict) + + if hasattr(self.obj, "value") and isinstance(self.obj.value, ir.Value): + return [self.obj.value] + + return [] + + def get_ir_value(self) -> ir.Value | None: + """Get single ir.Value (for backward compat - returns first value).""" + values = self.get_ir_values() + return values[0] if values else None + + def set_ir_values(self, new_vals: list[ir.Value]) -> None: + """Set ir.Values at this location (may be multiple for DynamicExpression). + + With the mutable proxy pattern, the parent is always mutable (the capture + list, a proxy list for tuples, or a proxy SimpleNamespace for frozen DCs). + So replacement in parent always succeeds for leaves. + """ + # Case 1: DynamicExpression -> use protocol to reconstruct and replace + if _is_dynamic_expression(self.obj): + old_values_dict = self.obj.__extract_mlir_values__() + new_values_dict = _unflatten_ir_values(old_values_dict, new_vals) + new_obj = self.obj.__new_from_mlir_values__(new_values_dict) + + replaced = False + if self.parent is not None: + if self.key_type == "attr": + setattr(self.parent, self.key, new_obj) + replaced = True + elif self.key_type == "list" or self.key_type == "dict": + self.parent[self.key] = new_obj + replaced = True + + if replaced: + self.obj = new_obj + return + + # Fallback: try in-place update + if hasattr(self.obj, "__dict__") and hasattr(new_obj, "__dict__"): + self.obj.__dict__.update(new_obj.__dict__) + return + if hasattr(self.obj, "value") and len(new_vals) >= 1: + self.obj.value = new_vals[0] + return + + raise RuntimeError( + f"Cannot update DynamicExpression at '{self.path}'.\n" + f" Object type: {type(self.obj).__name__}\n" + f" Parent type: {type(self.parent).__name__ if self.parent else 'None'}\n" + f" Key type: {self.key_type}" + ) + + # Case 2: Simple .value attribute (Int32, Float32, Pointer wrappers) + if hasattr(self.obj, "value") and isinstance( + getattr(self.obj, "value", None), ir.Value + ): + if len(new_vals) >= 1: + self.obj.value = new_vals[0] + return + + # Case 3: raw ir.Value (including subclasses like ArithValue, ctm.Pointer) + if isinstance(self.obj, ir.Value): + if len(new_vals) >= 1: + self._replace_in_parent(new_vals[0]) + return + + print( + f"WARNING: Cannot set ir.Values at {self.path}: " + f"got {type(self.obj).__name__}" + ) + + def set_ir_value(self, new_val: ir.Value) -> None: + """Set single ir.Value (for backward compat).""" + self.set_ir_values([new_val]) + + def _replace_in_parent(self, new_val: ir.Value) -> None: + """Replace this object in its parent container.""" + if self.parent is None: + print(f"WARNING: Cannot replace root-level ir.Value at {self.path}") + return + if self.key_type == "attr": + setattr(self.parent, self.key, new_val) + elif self.key_type == "list" or self.key_type == "dict": + self.parent[self.key] = new_val + + def __repr__(self) -> str: + return f"LeafInfo({self.path}, {type(self.obj).__name__})" + + +# ============================================================================= +# gather_leaves +# ============================================================================= + + +def gather_leaves( + objects: list[Any], +) -> tuple[list[ir.Value], list[LeafInfo], list[tuple[Any, ...]]]: + """ + Recursively traverse the object graph and gather assignable leaves. + + Leaves are atomic values (ir.Value, Int32, DynamicExpression, etc.) that + carry ir.Values. Containers (list, dict, class, tuple, frozen dataclass) + hold leaves and other containers -- we recurse into them. + + The `objects` list itself serves as the root mutable parent, so every leaf + is always addressable via its parent container. For immutable containers + (tuples, frozen dataclasses), we create mutable proxies during gather and + reconstruct the immutables during inject_leaves. + + Args: + objects: List of Python objects to traverse (the capture container) + + Returns: + Tuple of: + - ir_values: Flat list of ir.Values from leaves + - leaf_infos: List of LeafInfo describing each leaf's location + - immutable_proxies: List of (original_obj, proxy, parent, key, key_type) + for each immutable container encountered, in DFS order. + Used by inject_leaves for bottom-up reconstruction. + """ + ir_values: list[ir.Value] = [] + leaf_infos: list[LeafInfo] = [] + immutable_proxies: list[tuple[Any, ...]] = [] + + visited = set() + + def _gather_recursive( + obj: Any, parent: Any, key: Any, key_type: str, path: str + ) -> None: + """Recursively find assignable leaves.""" + if obj is None: + return + + obj_id = id(obj) + + # Check if this is an assignable leaf + if _is_assignable_leaf(obj): + info = LeafInfo(obj, parent, key, key_type, path) + leaf_infos.append(info) + leaf_values = info.get_ir_values() + ir_values.extend(leaf_values) + return + + # Not a leaf - recurse into containers + if obj_id in visited: + return + visited.add(obj_id) + + # List (mutable container) + if isinstance(obj, list): + for i, item in enumerate(obj): + item_path = f"{path}[{i}]" if path else f"[{i}]" + _gather_recursive(item, obj, i, "list", item_path) + return + + # Dict (mutable container) + if isinstance(obj, dict): + for k, v in obj.items(): + key_str = repr(k) if not isinstance(k, str) else k + item_path = f"{path}[{key_str}]" if path else f"[{key_str}]" + _gather_recursive(v, obj, k, "dict", item_path) + return + + # Tuple (IMMUTABLE container -- create mutable proxy list) + if isinstance(obj, tuple): + proxy_list = list(obj) + immutable_proxies.append((obj, proxy_list, parent, key, key_type)) + if parent is not None: + if key_type == "list" or key_type == "dict": + parent[key] = proxy_list + elif key_type == "attr": + setattr(parent, key, proxy_list) + for i, item in enumerate(proxy_list): + item_path = f"{path}({i})" if path else f"({i})" + _gather_recursive(item, proxy_list, i, "list", item_path) + return + + # Frozen dataclass (IMMUTABLE container -- create mutable proxy) + if is_frozen_dataclass(obj): + fields = dataclasses.fields(obj) + proxy = SimpleNamespace(**{f.name: getattr(obj, f.name) for f in fields}) + immutable_proxies.append((obj, proxy, parent, key, key_type)) + if parent is not None: + if key_type == "list" or key_type == "dict": + parent[key] = proxy + elif key_type == "attr": + setattr(parent, key, proxy) + for f in fields: + attr_val = getattr(proxy, f.name) + if f.name.startswith("__") or callable(attr_val): + continue + attr_path = f"{path}.{f.name}" if path else f.name + _gather_recursive(attr_val, proxy, f.name, "attr", attr_path) + return + + # Object with __dict__ or __slots__ (mutable container) + if hasattr(obj, "__dict__") or hasattr(type(obj), "__slots__"): + attrs = _get_all_attrs(obj) + for attr_name, attr_val in attrs.items(): + if attr_name.startswith("__") or callable(attr_val): + continue + attr_path = f"{path}.{attr_name}" if path else attr_name + _gather_recursive(attr_val, obj, attr_name, "attr", attr_path) + return + + # Process each top-level object + for i, obj in enumerate(objects): + root_path = f"arg{i}" + _gather_recursive(obj, objects, i, "list", root_path) + + return ir_values, leaf_infos, immutable_proxies + + +# ============================================================================= +# inject_leaves +# ============================================================================= + + +def inject_leaves( + leaf_infos: list[LeafInfo], + new_values: list[Any], + immutable_proxies: list[tuple[Any, ...]] | None = None, +) -> None: + """ + Inject new ir.Values into the leaves, then reconstruct immutable containers. + + Two-phase injection: + Phase 1: Direct mutations into leaf parents (which are always mutable -- + either a real mutable container or a proxy list/SimpleNamespace). + Phase 2: Bottom-up reconstruction of immutable containers (tuples, frozen + dataclasses) from their proxies, replacing the proxy in its parent. + + Args: + leaf_infos: List of LeafInfo from gather_leaves + new_values: New values to inject (ir.Value or wrapped types) + immutable_proxies: List of (original_obj, proxy, parent, key, key_type) + from gather_leaves. If None, skip Phase 2 (backward compat). + """ + # Calculate expected total values + expected_count = sum(len(info.get_ir_values()) for info in leaf_infos) + if len(new_values) != expected_count: + print( + f"WARNING: inject_leaves: value count mismatch - " + f"expected {expected_count}, got {len(new_values)}" + ) + + # Unwrap all values first + unwrapped = [] + for v in new_values: + ir_val = _unwrap_ir_value(v) + if ir_val is not None: + unwrapped.append(ir_val) + else: + print(f"WARNING: inject_leaves: cannot unwrap {type(v).__name__}") + unwrapped.append(v) + + # Phase 1: Inject values into each leaf + idx = 0 + for info in leaf_infos: + num_values = len(info.get_ir_values()) + if num_values == 0: + continue + + leaf_values = unwrapped[idx : idx + num_values] + idx += num_values + + if len(leaf_values) != num_values: + print( + f"ERROR: inject_leaves: not enough values for {info.path} " + f"(need {num_values}, have {len(leaf_values)})" + ) + continue + + info.set_ir_values(leaf_values) + + # Phase 2: Reconstruct immutable containers bottom-up + if immutable_proxies: + for original_obj, proxy, parent, key, key_type in reversed(immutable_proxies): + new_obj: Any + if isinstance(original_obj, tuple): + new_obj = type(original_obj)(proxy) + elif dataclasses.is_dataclass(original_obj) and not isinstance( + original_obj, type + ): + fields = dataclasses.fields(original_obj) + field_vals = {f.name: getattr(proxy, f.name) for f in fields} + new_obj = type(original_obj)(**field_vals) + else: + continue + + if parent is not None: + if key_type == "list" or key_type == "dict": + parent[key] = new_obj + elif key_type == "attr": + setattr(parent, key, new_obj) + + +# ============================================================================= +# Debug utility +# ============================================================================= + + +def print_leaves_debug( + leaf_infos: list[LeafInfo], label: str = "", prefix: str = "CUTE_DSL" +) -> None: + """ + Print debug info about gathered assignable leaves. + + Enable with environment variable: {prefix}_DEBUG_LEAVES=1 + (e.g. CUTE_DSL_DEBUG_LEAVES=1 for the CuTe DSL) + """ + import os + + if os.environ.get(f"{prefix}_DEBUG_LEAVES", "0") != "1": + return + + print(f"\n{'=' * 80}") + print(f"LEAVES: {label}") + print(f"{'=' * 80}") + print(f"{'#':<5} {'TYPE':<20} {'#V':<4} {'IR_TYPES':<25} {'PATH'}") + print(f"{'-' * 5} {'-' * 20} {'-' * 4} {'-' * 25} {'-' * 40}") + + total_values = 0 + for i, info in enumerate(leaf_infos): + obj_type = type(info.obj).__name__ + ir_vals = info.get_ir_values() + num_vals = len(ir_vals) + total_values += num_vals + + if num_vals == 0: + ir_types_str = "None" + elif num_vals == 1: + ir_types_str = str(ir_vals[0].type)[:23] + else: + ir_types_str = f"{str(ir_vals[0].type)[:15]}..({num_vals})" + + print(f"{i:<5} {obj_type:<20} {num_vals:<4} {ir_types_str:<25} {info.path}") + + print(f"{'=' * 80}") + print(f"Total: {len(leaf_infos)} leaves, {total_values} ir.Values") + print(f"{'=' * 80}\n") + + +# ============================================================================= +# TraversableLeafMixin – generic DynamicExpression via gather/inject leaves +# ============================================================================= + + +class TraversableLeafMixin: + """ + Mixin that auto-implements the DynamicExpression protocol + (__extract_mlir_values__ / __new_from_mlir_values__) using + gather_leaves / inject_leaves. + + When gather_leaves encounters an object whose class has + ``__cls_traversable_dict__ = True``, it treats the object as a + **container** (recurses into ``__dict__``) rather than as a + DynamicExpression leaf — even if ``__extract_mlir_values__`` is + present. This eliminates the need to manually implement the + extract/new protocol for classes that simply need all their + ``ir.Values`` gathered and injected. + + The gather state is stored on the instance under dunder keys + (``__gather_infos``, ``__gather_proxies``) which ``gather_leaves`` + automatically skips (it ignores ``__``-prefixed attributes). + + Usage — zero boilerplate:: + + @dataclass + class MyTask(TraversableLeafMixin): + src_resources: List[MemoryResource] + dst_resources: List[MemoryResource] + # No manual __extract_mlir_values__ needed! + + If extra fixups are needed after inject (e.g. re-aliasing):: + + class MyTask(TraversableLeafMixin): + ... + def __new_from_mlir_values__(self, values): + super().__new_from_mlir_values__(values) + # custom fixup + self.work_queue = self.src_resources[self.work_queue_idx] + return self + """ + + __cls_traversable_dict__ = True + + def __extract_mlir_values__(self) -> dict: + ir_vals, infos, proxies = gather_leaves([self]) + # Store with __ prefix so gather_leaves skips them on next call + self.__dict__["__gather_infos"] = infos + self.__dict__["__gather_proxies"] = proxies + + # Return a flat dict keyed by the full leaf path. + # This makes framework error messages immediately show which + # attribute changed structure, e.g.: + # "src_resources[0].buf" -> i32 + # "dst_resources[0].ptr" -> !llvm.ptr + # Multi-value leaves (DynamicExpression) get indexed sub-keys. + result = {} + val_idx = 0 + for info in infos: + n_vals = len(info.get_ir_values()) + if n_vals == 1: + result[info.path] = ir_vals[val_idx] + else: + for i in range(n_vals): + result[f"{info.path}#{i}"] = ir_vals[val_idx + i] + val_idx += n_vals + return result + + def __new_from_mlir_values__( + self, values: dict[str, Any] + ) -> "TraversableLeafMixin": + infos = self.__dict__.pop("__gather_infos") + proxies = self.__dict__.pop("__gather_proxies") + # values is a flat {path: ir.Value} dict; just take values in order + new_vals = list(values.values()) + inject_leaves(infos, new_vals, proxies) + return self diff --git a/python/CuTeDSL/cutlass/base_dsl/native_struct.py b/python/CuTeDSL/cutlass/base_dsl/native_struct.py new file mode 100644 index 000000000..b076e98ba --- /dev/null +++ b/python/CuTeDSL/cutlass/base_dsl/native_struct.py @@ -0,0 +1,435 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import builtins +from collections.abc import Iterator +from typing import Any, get_origin, get_type_hints + +from .dsl import extract_mlir_values +from .typing import DslType + +from .._mlir import ir +from .._mlir.dialects import llvm + +from ._mlir_helpers import dsl_user_op + + +def _is_constexpr_annotation(ann: type) -> bool: + """True if the annotation is Constexpr or Constexpr[T].""" + from .typing import Constexpr + + return ann is Constexpr or get_origin(ann) is Constexpr + + +def _annotation_to_mlir_type(ann: type) -> ir.Type: + """Resolve a type annotation to an MLIR type for struct fields. + + Supports DSL types with a class-level mlir_type (e.g. Int32, Float32), + ir.Type instances, and other native_struct classes (for nested structs). + Called at init/use time when MLIR context is available. + """ + if isinstance(ann, ir.Type): + return ann + if hasattr(ann, "mlir_type"): + mt = ann.mlir_type + return mt() if callable(mt) else mt + if hasattr(ann, "_struct_type"): + return ann._struct_type + raise TypeError( + f"Struct field type must be an ir.Type, a DSL type with mlir_type, " + f"or a native_struct class; got {ann!r}" + ) + + +class _StructTypeDescriptor: + """Descriptor that resolves struct type from annotations on each access. + + Also provides :meth:`resolve` for use outside the descriptor protocol + (e.g. in static methods like ``__get_mlir_types__`` and ``isinstance``). + """ + + _field_types_attr = "_field_types" + + def __init__( + self, field_names: list, field_annotations: dict, packed: bool = False + ): + self._field_names = field_names + self._field_annotations = field_annotations + self._packed = packed + + def _resolve(self) -> tuple[ir.Type, list[ir.Type]]: + """Resolve field annotations to MLIR types and build the struct type. + + Requires an active MLIR context. Not cached — MLIR types are tied to + the context they were created in, and each JIT compilation may use a + different context. + """ + field_types = [ + _annotation_to_mlir_type(self._field_annotations[n]) + for n in self._field_names + ] + struct_type = llvm.StructType.get_literal(field_types, packed=self._packed) + return struct_type, field_types + + def resolve(self) -> ir.Type: + """Return the LLVM struct type (resolving and caching if needed).""" + struct_type, _ = self._resolve() + return struct_type + + def __get__(self, obj: Any, owner: type | None) -> ir.Type: + if owner is None: + owner = type(obj) + struct_type, field_types = self._resolve() + setattr(owner, self._field_types_attr, field_types) + return struct_type + + +def native_struct( + cls: type | None = None, *, zero_init: bool = True, packed: bool = False +) -> Any: + """Decorator that mimics dataclass behavior but generates a native MLIR struct type. + + Can be used as ``@native_struct``, ``@native_struct(zero_init=False)``, or + ``@native_struct(packed=True)``. + + ``zero_init`` (default True): if True, the struct is initialized with + ``llvm.mlir.zero`` before inserting field values; if False, with ``llvm.mlir.undef``. + + ``packed`` (default False): if True, the LLVM struct type is created with the + packed attribute (no padding between fields). + + The decorated class must use type annotations for all fields; each annotation + must resolve to an DSL type (e.g. Int32, Float32, Pointer, etc.) or Constexpr. The decorator: + + - Builds an LLVM literal struct type ``!llvm.struct<(t1, t2, ...)>`` from the non-Constexpr + field types. + - Adds ``__init__(self, *, loc=None, ip=None, **kwargs)`` to construct from keyword + arguments (one per MLIR field). Use ``cls.__new_from_mlir_values__([value])`` to wrap + an existing ``ir.Value``. + - Implements ``__extract_mlir_values__``, ``__new_from_mlir_values__``, and + ``__get_mlir_types__`` so the class works as a DynamicExpression and JitArgument. + - For each field ``name``, adds a property ``name`` to the class. + Accessing ``instance.name`` gets the value of the field, optionally wrapping it in the annotated + DSL type (for example, ``Int32``). + Assigning to ``instance.name`` is supported and replaces the value of that field. + - Fields annotated with ``Constexpr`` or ``Constexpr[T]`` are excluded from the + native struct and from getters/setters; they can be passed as keyword arguments + to ``__init__`` and are stored as normal Python attributes on the instance. + - ``__setattr__`` is overridden so that only ``_value`` and Constexpr field names + can be set; assigning any other attribute raises ``AttributeError`` (no new + fields after init). + + Instance data is stored in ``_value`` (a single ir.Value of the struct type). + + Example:: + + @native_struct + class Vec2: + x: Int32 + y: Int32 + + @native_struct(zero_init=False) + class Vec2Undef: + x: Int32 + y: Int32 + + @native_struct(packed=True) + class PackedVec2: + x: Int32 + y: Int32 + + # From keyword arguments (loc= and ip= required) + v = Vec2(x=x_val, y=y_val, loc=loc, ip=ip) + + # Get field (returns Int32 when annotation is Int32) + x_val = v.x + + # Replace field + v.x = new_x + """ + + def decorate(cls: type) -> type: + hints = get_type_hints(cls) + if not hints: + raise TypeError( + f"{cls.__name__}: @native_struct requires at least one type-annotated field" + ) + # Split into MLIR fields (included in struct) and Constexpr fields (skipped). + # Do not resolve annotations to MLIR types here; MLIR context may not exist yet. + field_names = [] + field_annotations = {} + constexpr_field_names = [] + for name, ann in hints.items(): + if _is_constexpr_annotation(ann): + constexpr_field_names.append(name) + else: + field_names.append(name) + field_annotations[name] = ann + if not field_names: + raise TypeError( + f"{cls.__name__}: @native_struct requires at least one non-Constexpr field" + ) + + struct_type_descriptor = _StructTypeDescriptor( + field_names, field_annotations, packed=packed + ) + + def __extract_mlir_values__(self: Any) -> list[ir.Value]: + return [self._value] + + def __new_from_mlir_values__(self: Any, values: list[ir.Value]) -> Any: + self._value = values[0] + return self + + @dsl_user_op + def __init__( + self: Any, + *args: Any, + loc: ir.Location | None = None, + ip: ir.InsertionPoint | None = None, + **kwargs: Any, + ) -> None: + # Wrapping mode: single positional ir.Value + # Note: this is builtins.isinstance, not the struct's own + # isinstance() staticmethod (which hasn't been added to the + # class yet at this point in decorate()). + if len(args) == 1 and not kwargs and isinstance(args[0], ir.Value): + struct_type = type(self)._struct_type + if args[0].type != struct_type: + raise TypeError( + f"{cls.__name__}(): expected ir.Value of type " + f"{struct_type}, got {args[0].type}" + ) + self._value = args[0] + return + # Keyword-arg construction mode + if len(args) > 0: + raise TypeError( + f"{cls.__name__}() takes a single ir.Value or keyword " + f"arguments, got {len(args)} positional argument(s)" + ) + # Populate Constexpr fields + for name in constexpr_field_names: + if name in kwargs: + setattr(self, name, kwargs.pop(name)) + extra = set(kwargs.keys()) - set(field_names) - set(constexpr_field_names) + if extra: + raise TypeError( + f"{cls.__name__}() got unexpected keyword argument(s): {sorted(extra)}" + ) + struct_type = type(self)._struct_type + if type(self)._struct_zero_init: + val = llvm.mlir_zero(struct_type, loc=loc, ip=ip) + else: + val = llvm.mlir_undef(struct_type, loc=loc, ip=ip) + for i, name in enumerate(field_names): + v = kwargs.pop(name, None) + if v is not None: + # Coerce Python literals (int, float, bool) using + # the field's type annotation (e.g. Int32(10)). + ann = field_annotations[name] + if isinstance(ann, DslType) and not hasattr( + v, "__extract_mlir_values__" + ): + v = ann(v) + elem = extract_mlir_values(v) + if len(elem) != 1: + raise TypeError( + f"Expected single value for field {name!r}, got {len(elem)}" + ) + val = llvm.insertvalue(val, elem[0], position=[i], loc=loc, ip=ip) + self._value = val + + # Build getter/setter for each field; need to capture in closure per field + field_annotations_for_getter = hints + + def _make_getter(idx: int, name: str) -> Any: + dsl_type = field_annotations_for_getter.get(name) + + def getter( + self: Any, + *, + loc: ir.Location | None = None, + ip: ir.InsertionPoint | None = None, + ) -> Any: + # Resolve struct type (and thus _field_types) on first use + type(self)._struct_type + elem_type = type(self)._field_types[idx] + extracted = llvm.extractvalue( + res=elem_type, + container=self._value, + position=[idx], + loc=loc, + ip=ip, + ) + # Wrap in DSL type if annotation is a callable type (e.g. Int32) + if isinstance(dsl_type, DslType): + return dsl_type(extracted) + if hasattr(dsl_type, "__new_from_mlir_values__"): + instance = dsl_type() # type: ignore[misc] + return instance.__new_from_mlir_values__([extracted]) + return extracted + + getter.__name__ = name + getter.__doc__ = f"Get the {name!r} field." + return dsl_user_op(getter) + + def _make_setter(idx: int, name: str) -> Any: + dsl_type = field_annotations_for_getter.get(name) + + def setter( + self: Any, + value: Any, + *, + loc: ir.Location | None = None, + ip: ir.InsertionPoint | None = None, + ) -> None: + # Coerce Python literals using the field's type annotation. + if isinstance(dsl_type, DslType) and not hasattr( + value, "__extract_mlir_values__" + ): + value = dsl_type(value) + elem = extract_mlir_values(value) + if len(elem) != 1: + raise TypeError( + f"Expected single value for field {name!r}, got {len(elem)}" + ) + elem = elem[0] + new_value = llvm.insertvalue( + self._value, elem, position=[idx], loc=loc, ip=ip + ) + self._value = new_value + + setter.__name__ = f"set_{name}" + setter.__doc__ = f"Set the {name!r} field." + return dsl_user_op(setter) + + _allowed_attr_names = frozenset( + field_names + constexpr_field_names + ["_value"] # _value is internal + ) + + def __setattr__(self: Any, name: str, value: Any) -> None: + if name not in _allowed_attr_names: + raise AttributeError( + f"{type(self).__name__!r} does not allow setting attribute {name!r}; " + f"only fields {field_names + constexpr_field_names} are settable" + ) + object.__setattr__(self, name, value) + + def __iter__(self: Any) -> Iterator[Any]: + """Yield each field as its DSL-typed value (e.g. Int32, Boolean). + + Enables tuple unpacking: ``a, b = my_struct``. + """ + for name in field_names: + yield getattr(self, name) + + def __get_mlir_types__() -> list[ir.Type]: + """Return MLIR types list — compatible with FFI ``_to_mlir_types``. + + Works on both the class and instances, so ``get_mlir_types(MyStruct)`` + and ``get_mlir_types(my_instance)`` both return ``[struct_type]``. + """ + return [struct_type_descriptor.resolve()] + + def _isinstance(value: Any) -> bool: + """Check if an ``ir.Value`` matches this struct type.""" + if not builtins.isinstance(value, ir.Value): + return False + return value.type == struct_type_descriptor.resolve() + + attrs = { + "_field_names": field_names, + "_field_annotations": field_annotations, + "_constexpr_field_names": constexpr_field_names, + "_struct_type": struct_type_descriptor, + "mlir_type": struct_type_descriptor, + "_struct_zero_init": zero_init, + "_struct_packed": packed, + "__init__": __init__, + "__iter__": __iter__, + "__setattr__": __setattr__, + "__get_mlir_types__": staticmethod(__get_mlir_types__), + "isinstance": staticmethod(_isinstance), + "__extract_mlir_values__": __extract_mlir_values__, + "__new_from_mlir_values__": __new_from_mlir_values__, + } + for idx, name in enumerate(field_names): + attrs[name] = property(_make_getter(idx, name), _make_setter(idx, name)) + + # Preserve existing methods and attributes that don't conflict + for key, value in cls.__dict__.items(): + if key not in attrs and not key.startswith("__"): + attrs[key] = value + new_cls = type(cls.__name__, cls.__bases__, attrs) + new_cls.__module__ = cls.__module__ + new_cls.__qualname__ = cls.__qualname__ + new_cls.__annotations__ = cls.__annotations__ + if cls.__doc__ is not None: + new_cls.__doc__ = cls.__doc__ + return new_cls + + if cls is None: + return decorate + return decorate(cls) + + +def make_native_struct( + name: str, *, zero_init: bool = True, packed: bool = False, **fields: Any +) -> type: + """Create a native struct class dynamically from field name/type pairs. + + Unlike the ``@native_struct`` decorator which requires a class definition + with static type annotations, this factory builds a struct class at runtime. + This is useful when the struct layout is determined dynamically — for example, + NVVM ops whose result struct depends on matrix dimensions or element types. + + The returned class behaves identically to a ``@native_struct``-decorated + class: it supports ``ir.Value`` wrapping, keyword-arg construction, named + field access, tuple unpacking via iteration, and the full DSL protocol. + + Example:: + + ResultType = make_native_struct("WmmaLoadResult", + d0=Int32, d1=Int32, d2=Int32, d3=Int32) + ResultType.mlir_type # → !llvm.struct<(i32, i32, i32, i32)> + + result = ResultType(raw_ir_value) + result.d0 # → Int32 (via extractvalue) + d0, d1, d2, d3 = result # tuple unpacking + + Parameters + ---------- + name : str + Name for the generated class. + zero_init : bool + If True (default), keyword-arg construction zero-initializes before + inserting fields; if False, uses ``llvm.mlir.undef``. + packed : bool + If True, create a packed LLVM struct (no padding). + **fields + Field names mapped to DSL types (e.g. ``d0=Int32, d1=Int32``). + Order is preserved (Python 3.7+). + + Returns + ------- + type + A ``@native_struct`` class with the given fields. + """ + if not fields: + raise TypeError(f"make_native_struct({name!r}) requires at least one field") + + # Build a bare class with the right annotations, then delegate to native_struct + cls = type(name, (), {"__annotations__": dict(fields)}) + return native_struct(cls, zero_init=zero_init, packed=packed) + + +__all__ = ["native_struct", "make_native_struct"] diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py index 55a23d9f8..37285456a 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/cuda.py @@ -15,7 +15,8 @@ This module provides CUDA Python helper functions from functools import lru_cache from dataclasses import dataclass -from typing import List, Optional +from typing import Any +from enum import IntEnum import numpy as np import os import ctypes @@ -28,12 +29,18 @@ import cuda.bindings.nvrtc as nvrtc from ..utils.logger import log as _log from ..common import * +# ============================================================================= +# Enums +# ============================================================================= + + + # ============================================================================= # Utils # ============================================================================= -def _cudaGetErrorEnum(error): +def _cudaGetErrorEnum(error: Any) -> Any: """ Get the error name of a CUDA error. :param error: The CUDA error. @@ -53,7 +60,7 @@ def _cudaGetErrorEnum(error): raise DSLRuntimeError("Unknown error type: {}".format(error)) -def _get_gpu_arch_info(major, minor): +def _get_gpu_arch_info(major: int, minor: int) -> tuple[str, str, list[str]]: """ Get GPU architecture information and compatibility details. Return [Unknown, f"sm_{major}{minor}", [f"sm_{major}{minor}"]] if the major and minor version is not in the map. @@ -75,13 +82,21 @@ def _get_gpu_arch_info(major, minor): (8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40 (9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100 (10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200 + (10, 3): ("Blackwell", "sm_103a", ["sm_103a"]), + (12, 0): ( + "Blackwell", + "sm_120a", + ["sm_120a"], + ), # RTX PRO 6000 / RTX 50 Series } return gpu_arch_map.get( (major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"]) ) -def get_compute_capability_major_minor(device_id: int = 0): +def get_compute_capability_major_minor( + device_id: int = 0, +) -> tuple[int | None, int | None]: """ Get the compute capability of the CUDA device. :param device_id: The ID of the CUDA device. @@ -146,15 +161,15 @@ class DeviceInfo: device_count: int = 0 current_device: int = 0 - device_name: Optional[str] = None - major_version: Optional[int] = None - minor_version: Optional[int] = None - arch_name: Optional[str] = None - sm_arch: Optional[str] = None - compatible_archs: Optional[List[str]] = None - memory_gb: Optional[float] = None - target_arch: Optional[str] = None - error_message: Optional[str] = None + device_name: str | None = None + major_version: int | None = None + minor_version: int | None = None + arch_name: str | None = None + sm_arch: str | None = None + compatible_archs: list[str] | None = None + memory_gb: float | None = None + target_arch: str | None = None + error_message: str | None = None initialization_failed: bool = False def pretty_str(self) -> str: @@ -277,7 +292,7 @@ def get_device_info() -> DeviceInfo: return device_info -def checkCudaErrors(result): +def checkCudaErrors(result: Any) -> Any: """Check CUDA errors and provide detailed error messages. :param result: The result of the CUDA operation. :type result: tuple(CUresult, ...) @@ -304,7 +319,7 @@ def checkCudaErrors(result): # ============================================================================= -def get_current_device(): +def get_current_device() -> Any: """ Gets the current device on the active context. :return: The current device. @@ -317,7 +332,7 @@ def get_current_device(): return dev -def get_device(device_id: int): +def get_device(device_id: int) -> Any: """ Gets a device given its ordinal. :param device_id: The ID of the device. @@ -332,8 +347,25 @@ def get_device(device_id: int): return dev + @lru_cache(maxsize=1) -def initialize_cuda_context(device_id: int = 0, flags: int = 0): +def _create_cuda_context(device_id: int = 0, flags: int = 0) -> Any: + """Creates and caches a new CUDA context. Cached to prevent duplicate + context creation, which would cause CUDA_ERROR_OUT_OF_MEMORY.""" + cuDevice = get_device(device_id) + _log().info(f"cuCtxCreate {0} {cuDevice}") + if cuda.CUDA_VERSION >= 13000: + # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2 + # and v3 API has been removed from CTK 13. + # See https://github.com/NVIDIA/cuda-python/pull/792 + context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice)) + else: + context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) + _log().info(f"{context} <-- cuCtxCreate") + return context + + +def initialize_cuda_context(device_id: int = 0, flags: int = 0) -> Any: """ Initializes the CUDA context for a specified device. :param device_id: The ID of the device. @@ -347,7 +379,7 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0): # Initialize CUDA Driver API _log().info(f"cuInit {flags}") checkCudaErrors(cuda.cuInit(flags)) - + driver_version = get_driver_version() # Check the CUDA driver version works for the installed cuda-python package @@ -357,23 +389,33 @@ def initialize_cuda_context(device_id: int = 0, flags: int = 0): suggestion=f"Consider updating your NVIDIA driver to version 580 or above. Or install cuda-python package with version 12.9 or below.", ) - # Retrieve handle for device - cuDevice = get_device(device_id) - # Create context - _log().info(f"cuCtxCreate {0} {cuDevice}") - if cuda.CUDA_VERSION >= 13000: - # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2 - # and v3 API has been removed from CTK 13. - # See https://github.com/NVIDIA/cuda-python/pull/792 - context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice)) - else: - context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) - _log().info(f"{context} <-- cuCtxCreate") + # Check if a valid CUDA context already exists (e.g., created by PyTorch or + # another framework). Reusing it avoids creating redundant contexts, which can + # cause CUDA_ERROR_OUT_OF_MEMORY in multi-process setups (e.g., pytest-xdist + # with many workers sharing a single GPU). This check is intentionally not + # cached so that it always reflects the current state of the CUDA context + # stack — an external framework may destroy or replace its context at any time. + try: + result = cuda.cuCtxGetCurrent() + if not result[0].value and result[1] is not None: + # Validate that the context is usable by querying its device + dev_result = cuda.cuCtxGetDevice() + if not dev_result[0].value: + # Only reuse if the context's device matches the requested one + if int(dev_result[1]) == device_id: + _log().info( + f"Reusing existing CUDA context: {result[1]} " + f"(device: {dev_result[1]})" + ) + return result[1] + except Exception: + pass - return context + # No usable external context — create one (cached to prevent duplicates). + return _create_cuda_context(device_id, flags) -def device_primary_context_retain(device): +def device_primary_context_retain(device: Any) -> Any: """ Retains the primary context on the device. :param device: The device. @@ -386,7 +428,7 @@ def device_primary_context_retain(device): return checkCudaErrors(cuda.cuDevicePrimaryCtxRetain(device)) -def device_primary_context_release(device): +def device_primary_context_release(device: Any) -> None: """ Releases the primary context on the device. :param device: The device. @@ -403,15 +445,15 @@ class DevicePrimaryContext: the object is no longer alive. """ - def __init__(self, device): + def __init__(self, device: Any) -> None: self.device = device self.context = device_primary_context_retain(self.device) - def __del__(self): + def __del__(self) -> None: device_primary_context_release(self.device) -def load_cubin_module(cubin_file): +def load_cubin_module(cubin_file: str) -> Any: """ Loads a CUBIN file and returns the module. :param cubin_file: The path to the CUBIN file. @@ -432,7 +474,7 @@ def load_cubin_module(cubin_file): return module -def unload_cubin_module(module): +def unload_cubin_module(module: Any) -> None: """ Unloads a CUBIN module. :param module: The module. @@ -443,7 +485,7 @@ def unload_cubin_module(module): checkCudaErrors(cuda.cuModuleUnload(module)) -def load_cubin_module_data(cubin_data): +def load_cubin_module_data(cubin_data: bytes) -> Any: """ Loads a CUBIN from data and returns the module. :param cubin_data: The binary data of the CUBIN. @@ -460,7 +502,7 @@ def load_cubin_module_data(cubin_data): return module -def get_kernel_function(module, kernel_name): +def get_kernel_function(module: Any, kernel_name: str) -> Any: """ Retrieves the kernel function from the module. :param module: The module. @@ -479,7 +521,7 @@ def get_kernel_function(module, kernel_name): return kernel -def load_library(cubin_file): +def load_library(cubin_file: str) -> Any: """ Loads a CUBIN file and returns the library. :param cubin_file: The path to the CUBIN file. @@ -495,7 +537,7 @@ def load_library(cubin_file): return load_library_data(cubin_data) -def unload_library(library): +def unload_library(library: Any) -> None: """ Unloads a CUBIN library. :param library: The library. @@ -507,7 +549,7 @@ def unload_library(library): _log().info(f"cuLibraryUnload done {library}") -def load_library_data(cubin_data): +def load_library_data(cubin_data: bytes | int) -> Any: """ Loads a CUBIN from data and returns the library. :param cubin_data: The binary data of the CUBIN. @@ -527,7 +569,7 @@ def load_library_data(cubin_data): return library -def get_library_kernel(library, kernel_name): +def get_library_kernel(library: Any, kernel_name: str) -> Any: """ Retrieves the kernel from the library. :param library: The library. @@ -546,7 +588,7 @@ def get_library_kernel(library, kernel_name): return kernel -def get_function_from_kernel(kernel): +def get_function_from_kernel(kernel: Any) -> Any: """ Retrieves the kernel function from the kernel. :param kernel: The kernel. @@ -562,7 +604,39 @@ def get_function_from_kernel(kernel): return kernel_fn -def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None): +def load_library_from_file(file_path: str | os.PathLike[str]) -> Any: + """ + Loads a file, e.g., cubin, and returns the library + :param file_path: The path to the file. + :type file_path: str or Path + :return: The library. + :rtype: cuda.CUlibrary + :raise DSLRuntimeError: If the CUDA operation fails. + """ + _log().info(f"cuLibraryLoadFromFile {file_path}") + library = checkCudaErrors( + cuda.cuLibraryLoadFromFile( + fileName=str(file_path).encode(), + jitOptions=None, + jitOptionsValues=None, + numJitOptions=0, + libraryOptions=None, + libraryOptionValues=None, + numLibraryOptions=0, + ) + ) + _log().info(f"{library} <-- cuLibraryLoadFromFile") + return library + + +def launch_kernel( + kernel: Any, + grid_dims: tuple[int, int, int], + block_dims: tuple[int, int, int], + stream: Any, + smem_size: int, + kernel_args: Any | None = None, +) -> None: """ Launches the CUDA kernel. :param kernel: The kernel. @@ -603,7 +677,7 @@ def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args= ) -def stream_sync(stream): +def stream_sync(stream: Any) -> None: """ Synchronizes the CUDA stream. :param stream: The stream. @@ -614,7 +688,7 @@ def stream_sync(stream): checkCudaErrors(cuda.cuStreamSynchronize(stream)) -def stream_create(id=0): +def stream_create(id: int = 0) -> Any: """ Creates the CUDA stream. :param id: The ID of the stream. @@ -629,7 +703,7 @@ def stream_create(id=0): return stream -def stream_destroy(stream): +def stream_destroy(stream: Any) -> None: """ Destroys the CUDA stream. :param stream: The stream. @@ -640,7 +714,7 @@ def stream_destroy(stream): checkCudaErrors(cuda.cuStreamDestroy(stream)) -def context_destroy(context): +def context_destroy(context: Any) -> None: """ Destroys the CUDA context. :param context: The context. @@ -651,7 +725,7 @@ def context_destroy(context): checkCudaErrors(cuda.cuCtxDestroy(context)) -def allocate(size_in_bytes: int, stream=None): +def allocate(size_in_bytes: int, stream: Any | None = None) -> Any: """ Allocate device memory based on numpy host array size. :param size_in_bytes: The size of the memory to allocate. @@ -671,7 +745,7 @@ def allocate(size_in_bytes: int, stream=None): return device_memory -def deallocate(device_pointer, stream=None): +def deallocate(device_pointer: Any, stream: Any | None = None) -> None: """ Deallocate the specified device memory pointer. :param device_pointer: The device memory pointer. @@ -689,7 +763,12 @@ def deallocate(device_pointer, stream=None): checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream)) -def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): +def memcpy_h2d( + host_pointer: int, + device_pointer: Any, + size_in_bytes: int, + stream: Any | None = None, +) -> None: """ Copy data from host to device memory if stream is None, the copy is synchronous otherwise it is asynchronous. @@ -718,7 +797,12 @@ def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): ) -def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): +def memcpy_d2h( + host_pointer: int, + device_pointer: Any, + size_in_bytes: int, + stream: Any | None = None, +) -> None: """ Copy data from device to host memory if stream is None, the copy is synchronous otherwise it is asynchronous. @@ -747,7 +831,7 @@ def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): ) -def default_stream(): +def default_stream() -> Any: """ Returns the default stream. :return: The default stream. @@ -757,7 +841,7 @@ def default_stream(): @lru_cache(maxsize=1) -def get_driver_version(): +def get_driver_version() -> Any: """ Returns the CUDA driver version. Note: the value is cached after the first call. @@ -771,7 +855,12 @@ def get_driver_version(): return checkCudaErrors(cuda.cuDriverGetVersion()) -def set_kernel_attribute(kernel, attribute, value, device=None): +def set_kernel_attribute( + kernel: Any, + attribute: Any, + value: int, + device: Any | None = None, +) -> Any: """ Sets a CUDA kernel attribute. If the device is not provided, the attribute is set for the current device. @@ -797,7 +886,7 @@ def set_kernel_attribute(kernel, attribute, value, device=None): ) -def get_device_attribute(attribute, device_id: int = 0): +def get_device_attribute(attribute: Any, device_id: int = 0) -> Any: """ Gets a CUDA device attribute. :param attribute: The attribute. diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py b/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py index ef89dbbee..86a0d2fac 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/device_tensor.py @@ -10,13 +10,14 @@ # is strictly prohibited. import copy +from typing import Any from . import cuda as cuda_helpers from .tensor_descriptor import * from ..common import * -def allocate(tensor: TensorDescriptor, stream=None): +def allocate(tensor: TensorDescriptor, stream: Any = None) -> None: """ Allocates GPU memory """ @@ -29,10 +30,10 @@ def allocate(tensor: TensorDescriptor, stream=None): tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream) - log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) # type: ignore[union-attr] -def deallocate(tensor: TensorDescriptor, stream=None): +def deallocate(tensor: TensorDescriptor, stream: Any = None) -> None: """ Deallocates GPU memory """ @@ -43,7 +44,7 @@ def deallocate(tensor: TensorDescriptor, stream=None): if tensor.device_pointer is None: raise DSLRuntimeError("Tensor is not allocated on the device.") - log().info( + log().info( # type: ignore[union-attr] "Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer ) @@ -51,27 +52,31 @@ def deallocate(tensor: TensorDescriptor, stream=None): tensor.device_pointer = None -def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None): +def copy_to_gpu( + tensor: TensorDescriptor, do_allocate: bool = True, stream: Any = None +) -> TensorDescriptor: """ Copies data from host memory to the GPU memory. If do_allocate is True, it first calls allocate """ - log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) # type: ignore[union-attr] if do_allocate: allocate(tensor, stream) cuda_helpers.memcpy_h2d( tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream ) - log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) # type: ignore[union-attr] return tensor -def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None): +def copy_from_gpu( + tensor: TensorDescriptor, do_deallocate: bool = True, stream: Any = None +) -> None: """ Copies data from GPU memory back to the host. If do_deallocate is True, it calls deallocate """ - log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) # type: ignore[union-attr] if tensor._check_is_managed_by_framework(): raise DSLRuntimeError( "GPU tensors are managed by the framework and cannot be modified." @@ -84,10 +89,12 @@ def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None): ) if do_deallocate: deallocate(tensor, stream) - log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + log().info( # type: ignore[union-attr] + "copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer + ) -def to_gpu(tensor, stream=None) -> TensorDescriptor: +def to_gpu(tensor: Any, stream: Any = None) -> TensorDescriptor: """ Copies the tensor to the GPU memory from Host memory """ @@ -104,7 +111,7 @@ def to_gpu(tensor, stream=None) -> TensorDescriptor: raise DSLRuntimeError("Unsupported type") -def from_gpu(tensor, stream=None) -> TensorDescriptor: +def from_gpu(tensor: Any, stream: Any = None) -> TensorDescriptor: """ Copies the tensor to the GPU memory from Host memory """ diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py b/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py index 84952c549..d42dbb6b7 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/jit_arg_adapters.py @@ -14,7 +14,11 @@ This module provides runtime utilities for JIT argument conversion in DSL. """ from functools import wraps -from typing import get_origin +from typing import Callable, Any, Optional, get_origin +from inspect import Parameter +from dataclasses import is_dataclass, fields +from itertools import chain + # Local modules imports from ..common import DSLRuntimeError @@ -23,15 +27,30 @@ from ..typing import ( Int32, Float32, Boolean, + NumericMeta, + cast, + get_c_pointers, + get_mlir_types, + implements_jit_argument, + implements_dynamic_expression, ) +from ..utils.tree_utils import is_constexpr_field +from ..._mlir import ir -def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func): +def is_arg_annotation_constexpr( + arg_annotation: Any, + arg_name: str, + arg_index: int, + owning_func: Optional[Callable[..., Any]], +) -> bool: """ - Check if the argument spec is a constexpr. + Check if the argument annotation is a constexpr. """ - def _is_reserved_python_func_arg(arg_index, arg_name, func): + def _is_reserved_python_func_arg( + arg_index: int, arg_name: str, func: Optional[Callable[..., Any]] + ) -> bool: """ Check if the argument is a reserved python function argument. """ @@ -42,35 +61,43 @@ def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func): if arg_name == "self": return True - is_classmethod = isinstance(func, classmethod) or ( - hasattr(func, "__func__") and isinstance(func.__func__, classmethod) - ) - return arg_name == "cls" and is_classmethod + if func: + is_classmethod = isinstance(func, classmethod) or ( + hasattr(func, "__func__") and isinstance(func.__func__, classmethod) + ) + return arg_name == "cls" and is_classmethod + return False return ( _is_reserved_python_func_arg(arg_index, arg_name, owning_func) - or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr)) - or (get_origin(arg_spec) is Constexpr) + or (isinstance(arg_annotation, type) and issubclass(arg_annotation, Constexpr)) + or (get_origin(arg_annotation) is Constexpr) ) -def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func): +def is_argument_constexpr( + arg: Any, + arg_annotation: Any, + arg_name: str, + arg_index: int, + owning_func: Callable[..., Any], +) -> bool: """ Check if the argument is a constexpr. """ - def _is_type_argument(arg, arg_annotation): + def _is_type_argument(arg: Any, arg_annotation: Any) -> bool: """ Check if the argument is a type argument like Type[X] """ return isinstance(arg, type) and ( - arg_annotation is None or get_origin(arg_annotation) is type + arg_annotation is Parameter.empty or get_origin(arg_annotation) is type ) return ( - is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func) - or _is_type_argument(arg, arg_spec) + is_arg_annotation_constexpr(arg_annotation, arg_name, arg_index, owning_func) + or _is_type_argument(arg, arg_annotation) or arg is None ) @@ -86,10 +113,14 @@ class JitArgAdapterRegistry: """ # A dictionary with key=type and value=callable - jit_arg_adapter_registry = {} + jit_arg_adapter_registry: dict[type, Any] = {} + + # Default adapters for arguments we don't know type names beforehand + # Default dataclass adapter + default_dataclass_adapter: Callable[[object], Any] | None = None @classmethod - def register_jit_arg_adapter(cls, *dargs, **dkwargs): + def register_jit_arg_adapter(cls, *dargs: Any, **dkwargs: Any) -> Any: """ Register a JIT argument adapter callable @@ -106,11 +137,11 @@ class JitArgAdapterRegistry: The adapters are registered per type. If a type is already registerd, an error will be raised. """ - def decorator(*dargs, **dkwargs): + def decorator(*dargs: Any, **dkwargs: Any) -> Any: darg_python_ty = dargs[0] @wraps(darg_python_ty) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: if len(args) != 1 or not callable(args[0]): raise DSLRuntimeError( "a callable must be provided for registering JIT argument adapter" @@ -140,11 +171,81 @@ class JitArgAdapterRegistry: ) @classmethod - def get_registered_adapter(cls, ty): + def get_registered_adapter(cls, arg: object) -> Any: """ - Get the registered JIT argument adapter for the given type. + Get the registered JIT argument adapter for the given argument. """ - return cls.jit_arg_adapter_registry.get(ty, None) + adapter = cls.jit_arg_adapter_registry.get(type(arg), None) + if adapter is None: + if (cls.default_dataclass_adapter + and not implements_jit_argument(arg, partial=True) + and not implements_dynamic_expression(arg, partial=True) + and is_dataclass(arg) + and len(vars(arg)) == len(fields(arg))): # no extra/missing instance attrs + adapter = cls.default_dataclass_adapter + return adapter + + @classmethod + def set_default_dataclass_adapter(cls, adapter: Callable[[object], Any]) -> None: + """ + Set up a default dataclass adapter. If any user defined dataclass implements the JitArgument/DynamicExpression protocol, + those impls will be honored instead of this default adapter. + """ + cls.default_dataclass_adapter = adapter + + +class DefaultDataclassAdapter: + """ + Adapter for dataclass typed JIT arguments. + """ + def __init__(self, arg: object) -> None: + self._ir_fields: dict[str, object] = {} + self._ir_fields_len: dict[str, int] = {} + self._arg = arg + for f in fields(arg): # type: ignore[arg-type] + arg_field = getattr(arg, f.name) + if not is_constexpr_field(f): + if isinstance(f.type, NumericMeta) and not isinstance(arg_field, f.type): + self._ir_fields[f.name] = cast(arg_field, f.type) # type: ignore[arg-type] + else: + # Allow the nested fields to be adapted + arg_adapter = JitArgAdapterRegistry.get_registered_adapter(arg_field) + if arg_adapter is not None: + self._ir_fields[f.name] = arg_adapter(arg_field) + else: + self._ir_fields[f.name] = arg_field + + def __c_pointers__(self) -> list[Any]: + return list(chain.from_iterable(get_c_pointers(v) for v in self._ir_fields.values())) + + def __get_mlir_types__(self) -> list[Any]: + ir_types = [] + for f, v in self._ir_fields.items(): + types = get_mlir_types(v) + self._ir_fields_len[f] = len(types) + ir_types.extend(types) + return ir_types + + def __new_from_mlir_values__(self, values: list[Any]) -> Any: + from ..dsl import new_from_mlir_values # deferred to avoid circular import + + kwargs = {} + idx = 0 + for f in fields(self._arg): # type: ignore[arg-type] + if is_constexpr_field(f): + kwargs[f.name] = getattr(self._arg, f.name) + else: + kwargs[f.name] = new_from_mlir_values(self._ir_fields[f.name], values[idx : idx + self._ir_fields_len[f.name]]) + idx += self._ir_fields_len[f.name] + return type(self._arg)(**kwargs) + + def __extract_mlir_values__(self) -> list[ir.Value]: + from ..dsl import extract_mlir_values # deferred to avoid circular import + + return list(chain.from_iterable(extract_mlir_values(v) for v in self._ir_fields.values())) + + +JitArgAdapterRegistry.set_default_dataclass_adapter(DefaultDataclassAdapter) # ============================================================================= @@ -155,7 +256,7 @@ class JitArgAdapterRegistry: @JitArgAdapterRegistry.register_jit_arg_adapter(int) @JitArgAdapterRegistry.register_jit_arg_adapter(float) @JitArgAdapterRegistry.register_jit_arg_adapter(bool) -def _convert_python_scalar(arg): +def _convert_python_scalar(arg: Any) -> Any: """ Convert a Python scalar to a DSL type. """ @@ -164,19 +265,19 @@ def _convert_python_scalar(arg): float: Float32, bool: Boolean, } - return conversion_map.get(type(arg))(arg) + return conversion_map.get(type(arg))(arg) # type: ignore[misc] @JitArgAdapterRegistry.register_jit_arg_adapter(tuple) @JitArgAdapterRegistry.register_jit_arg_adapter(list) -def _convert_python_sequence(arg): +def _convert_python_sequence(arg: Any) -> Any: """ Go through each element in the sequence and convert it to a type that can be further processed by DSL to generate the corresponding JIT argument(s). """ adapted_arg = [] for elem in arg: - adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem)) + adapter = JitArgAdapterRegistry.get_registered_adapter(elem) if adapter is not None: converted_elem = adapter(elem) adapted_arg.append(converted_elem) diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py b/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py index 2d43fd3ce..da6d27ff1 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/stream_adapter.py @@ -14,9 +14,9 @@ This module provides CUDA Python helper functions """ import cuda.bindings.driver as cuda +from typing import Any # MLIR imports -from ..._mlir import ir from ..._mlir.dialects import gpu from .jit_arg_adapters import JitArgAdapterRegistry @@ -28,16 +28,16 @@ class StreamAdapter: Convert a CUDA stream to a stream representation for JIT arg generation. """ - def __init__(self, arg): + def __init__(self, arg: Any) -> None: self._arg = arg self._c_pointer = self._arg.getPtr() - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[Any]) -> Any: assert len(values) == 1 return values[0] - def __c_pointers__(self): + def __c_pointers__(self) -> list[Any]: return [self._c_pointer] - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> list[Any]: return [gpu.AsyncTokenType.get()] diff --git a/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py b/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py index 4d311dacc..b8cb45d38 100644 --- a/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py +++ b/python/CuTeDSL/cutlass/base_dsl/runtime/tensor_descriptor.py @@ -12,8 +12,10 @@ # Helpers import itertools, operator import ctypes +from typing import Any + from . import dlpack_types as _dpack -from .dlpack_runtime import ( +from .dlpack_runtime import ( # type: ignore[import-not-found] dlpack_to_tensor_desc, get_tensor_desc_data_ptr, get_tensor_desc_is_in_device, @@ -49,7 +51,7 @@ from ..typing import ( class TensorDescriptor: - def __init__(self, tensor): + def __init__(self, tensor: Any) -> None: """Initialize with a tensor that supports the DLPack protocol. Args: @@ -69,13 +71,13 @@ class TensorDescriptor: self.device_pointer = None else: raise DSLRuntimeError( - f"DLPack device type is not supported {self.dl_tensor.device.device_type}" + f"DLPack device type is not supported {self.dl_tensor.device.device_type}" # type: ignore[attr-defined] ) - log().info("TensorDescriptor is created = [%s]", self) + log().info("TensorDescriptor is created = [%s]", self) # type: ignore[union-attr] @staticmethod - def can_transformed_to_dlpack(dl_tensor): + def can_transformed_to_dlpack(dl_tensor: object) -> bool: if not hasattr(dl_tensor, "__dlpack__") or not hasattr( dl_tensor, "__dlpack_device__" ): @@ -83,19 +85,19 @@ class TensorDescriptor: return True @property - def is_in_device(self): + def is_in_device(self) -> bool: """Check if the tensor is stored on a device.""" return not self.device_pointer is None @property - def device_id(self): + def device_id(self) -> int: """Return device id where tensor resides.""" if self.is_in_device: return get_tensor_desc_device_id(self._capsule) return -1 @property - def pointer(self): + def pointer(self) -> Any: """ Returns the pointer to the tensor data. This is either the device pointer or the data pointer if the data is not in a device. @@ -103,7 +105,7 @@ class TensorDescriptor: return self.device_pointer if self.device_pointer is not None else self.data_ptr @property - def element_type(self): + def element_type(self) -> type: """Return the corresponding Python type based on DLPack dtype metadata.""" str_element_type = get_tensor_desc_element_type(self._capsule) dtype_map = { @@ -132,27 +134,27 @@ class TensorDescriptor: return dtype_map[str_element_type] @property - def shape(self): + def shape(self) -> tuple[int, ...]: """Return the shape of the tensor.""" return get_tensor_desc_shape(self._capsule) @property - def rank(self): + def rank(self) -> int: """Return the rank of the tensor.""" return get_tensor_desc_ndim(self._capsule) @property - def strides(self): + def strides(self) -> tuple[int, ...]: """Return the rank of the tensor.""" return get_tensor_desc_stride(self._capsule) @property - def element_size_in_bytes(self): + def element_size_in_bytes(self) -> int: """Calculate the element size in bytes of the DLPack tensor.""" return get_tensor_desc_element_size_in_bytes(self._capsule) @property - def size_in_bytes(self): + def size_in_bytes(self) -> int: """Calculate the total size in bytes of the DLPack tensor.""" # Calculate the number of elements using the shape ndim = get_tensor_desc_ndim(self._capsule) @@ -165,7 +167,7 @@ class TensorDescriptor: total_bytes = self.element_size_in_bytes * num_elements return total_bytes - def __str__(self): + def __str__(self) -> str: """Return a compact string representation of the device_tensor with a tensor prefix.""" # Extract shape shape = "x".join(map(str, self.shape)) @@ -184,7 +186,7 @@ class TensorDescriptor: return f"tensor<{shape}x{dtype}>_{device_type}" - def _check_is_managed_by_framework(self): + def _check_is_managed_by_framework(self) -> bool: """ Ensure the tensor is not managed by the framework (e.g., GPU tensor). Raises an exception if the tensor is framework-managed. @@ -192,18 +194,18 @@ class TensorDescriptor: return self.device_type == _dpack.DLDeviceType.kDLGPU @staticmethod - def is_compatible(maybe_tensor_descriptor) -> bool: + def is_compatible(maybe_tensor_descriptor: object) -> bool: """Check if the object is a TensorDescriptor or can be converted to one.""" return isinstance( maybe_tensor_descriptor, TensorDescriptor ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor) -def from_tensor(tensor) -> TensorDescriptor: +def from_tensor(tensor: Any) -> TensorDescriptor: """Create a TensorDescriptor from a tensor object.""" return TensorDescriptor(tensor) -def to_tensor(tensor_descriptor: TensorDescriptor): +def to_tensor(tensor_descriptor: TensorDescriptor) -> Any: """Return tensor object from tensor descriptor.""" return tensor_descriptor.tensor diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py index 6b47e8de7..ddc6c7661 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/call_provider.py @@ -12,7 +12,7 @@ """Call provider that implements a specific calling convention.""" from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Optional, Union from . import spec from ..._mlir import ir @@ -125,7 +125,7 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder): return value def map_stride_for_tensor_dtype_f4x2_to_f4( - index, value: ir.Value + index: int, value: ir.Value ) -> ir.Value: if index != stride_one_index: with ir.InsertionPoint(current_block): @@ -262,7 +262,7 @@ class DynamicParamPackCallProvider(CallProvider, TVMFFIBuilder): call_operands += self.load_to_call_operands(struct_type, alloca) else: # pack the values to an alloca that we can pass as void** - all_values = [] + all_values: list[Any] = [] for _, value in packed_params: if isinstance(value, tuple): all_values.extend(value) diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py index 56f56cd41..b6a4c73d1 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/mlir_builder.py @@ -138,7 +138,7 @@ class MLIRBuilder(MLIRTypeBuilder): super().__init__() self.module: Optional[ir.Module] = None self.const_str_table: dict[str, ir.Value] = {} - self.const_func_ptr_table: dict[str, ir.Value] = {} + self.get_element_extra_kwargs: dict[str, Any] = {} # create constants @@ -307,7 +307,7 @@ class MLIRBuilder(MLIRTypeBuilder): true_block: ir.Block, false_block: ir.Block, *, - branch_weights=None, + branch_weights: Optional[tuple[int, int]] = None, true_dest_operands: Sequence[ir.Value] = (), false_dest_operands: Sequence[ir.Value] = (), ) -> None: @@ -371,65 +371,6 @@ class MLIRBuilder(MLIRTypeBuilder): self.const_str_table[content] = symbol return symbol - def get_or_load_global_func_ptr_from_text( - self, - current_block: ir.Block, - function_name: str, - ) -> ir.Value: - """Get or create a function pointer global in .text section and load it. - - This creates a constant global function pointer in the .text section - (for AArch64 ADRP range compatibility) and performs a volatile load - to prevent optimization. - - This forces the function pointer to be local to the code, bypassing GOT entry - ADRP lookup issues on AArch64 when GOT and .text section are more than 4GB - apart which can happen when ASLR is applied. - """ - # Check if we've already created this global - if function_name not in self.const_func_ptr_table: - symbol = f"__func_ptr_{function_name}" - - module_body = self.module.body - with ir.InsertionPoint(module_body): - # 1. Create the global constant - # We use 'private' linkage so it doesn't conflict across modules - global_ptr = llvm.GlobalOp( - self.ptr_type, - symbol, - ir.Attribute.parse("#llvm.linkage"), - # Initialization via block below - ) - - # 2. Set the necessary attributes for JIT safety and AArch64 range - # We use 'constant' to mark it as immutable - # We use 'section = ".text"' to force it into the code block - global_ptr.attributes["constant"] = ir.UnitAttr.get() - global_ptr.attributes["section"] = ir.StringAttr.get(".text") - - # 3. Add a constructor block to the GlobalOp to initialize it - # with the address of the target function - initializer_block = global_ptr.initializer.blocks.append() - with ir.InsertionPoint(initializer_block): - # Get the address of the external function - func_addr = llvm.AddressOfOp(self.ptr_type, function_name).res - # Return the address as the initial value of the global - llvm.return_(arg=func_addr) - - self.const_func_ptr_table[function_name] = symbol - else: - symbol = self.const_func_ptr_table[function_name] - - # Load it with volatile semantics in the current block - with ir.InsertionPoint(current_block): - symbol_addr = self.address_of(symbol, self.ptr_type) - # Perform a volatile load to prevent optimization - load_op = llvm.load(self.ptr_type, symbol_addr) - # Set volatile attribute to prevent optimization - load_op.owner.attributes["volatile_"] = ir.UnitAttr.get() - return load_op - - # function def function( self, @@ -478,7 +419,9 @@ class MLIRBuilder(MLIRTypeBuilder): ) func_op.attributes["llvm.linkage"] = ir.StringAttr.get("external") - def create_alloca(self, entry_block: ir.Block, alloca_type: ir.Type, array_size: int) -> ir.Value: + def create_alloca( + self, entry_block: ir.Block, alloca_type: ir.Type, array_size: int + ) -> ir.Value: """Create an alloca operation.""" with ir.InsertionPoint(entry_block.operations[0]): # declare the struct type @@ -534,7 +477,7 @@ class MLIRBuilder(MLIRTypeBuilder): ) -> list[ir.Operation]: """Find operations in the module by the operation name.""" operations = [] - for op in module.body: # type: ignore[union-attr] + for op in module.body: if op.name == name: operations.append(op) return operations @@ -543,7 +486,7 @@ class MLIRBuilder(MLIRTypeBuilder): self, module: ir.Module, name: str ) -> Optional[ir.Operation]: """Find a function in the module.""" - for op in module.body: # type: ignore[union-attr] + for op in module.body: if op.name == "llvm.func": # Get the function name from the sym_name attribute if "sym_name" in op.attributes: diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py index 1232c1780..53b0d89e5 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/spec.py @@ -448,7 +448,7 @@ def signature(name: str, params: list[Param]) -> str: continue param_type = format_param_type(param) - param_str = f"{param.name}: {param_type}" + param_str = f"{param.name}: {param_type}" # type: ignore[attr-defined] param_strs.append(param_str) return f"{name}({', '.join(param_strs)})" @@ -522,6 +522,6 @@ def create_map_tensor_dtype_f4x2_to_f4_spec(f4_tensor_spec: Tensor) -> Tensor: f4_tensor_spec.name, new_shape, dtype=tvm_ffi.dtype("float4_e2m1fnx2"), - strides=new_strides, + strides=new_strides, # type: ignore[arg-type] map_tensor_dtype_f4x2_to_f4=True, ) diff --git a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py index b8b0ebe49..6cfe2fc9c 100644 --- a/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py +++ b/python/CuTeDSL/cutlass/base_dsl/tvm_ffi_builder/tvm_ffi_builder.py @@ -1606,7 +1606,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): data_as_int = llvm.ptrtoint(self.i64_type, data) # Check if data pointer is divisible by alignment # (uses fast path for power-of-two alignments) - return self.i64_divisible_const(data_as_int, param.data_alignment) + return self.i64_divisible_const(data_as_int, param.data_alignment) # type: ignore[arg-type] current_block = self.check_condition( current_block, @@ -2016,7 +2016,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): self.current_fn_signature = spec.signature(fn_display_name, params_list) self._fn_call_context = f" when calling: `{self.current_fn_signature}`" - with ir.InsertionPoint(self.module.body): # type: ignore[union-attr] + with ir.InsertionPoint(self.module.body): # void TVMFFIErrorSetRaisedFromCStr( # const char* error_kind, const char* message); self.find_or_declare_extern_func( @@ -2072,7 +2072,7 @@ class TVMFFIFunctionBuilder(TVMFFIBuilder): continue arg_context = ArgContext( - param_name=param.name, + param_name=param.name, # type: ignore[attr-defined] arg_index=ffi_arg_index, tuple_indices=[], ) diff --git a/python/CuTeDSL/cutlass/base_dsl/typing.py b/python/CuTeDSL/cutlass/base_dsl/typing.py index b299b6dde..60108b5f4 100644 --- a/python/CuTeDSL/cutlass/base_dsl/typing.py +++ b/python/CuTeDSL/cutlass/base_dsl/typing.py @@ -13,33 +13,28 @@ import ctypes from itertools import chain import numpy as np import operator -from typing_extensions import deprecated -from functools import reduce from typing import ( + Callable, Generic, + Optional, Protocol, Union, Any, - List, Type, TypeVar, overload, runtime_checkable, - get_origin, ) -from types import FunctionType -from dataclasses import dataclass -from abc import ABC, abstractmethod from .common import * -from .ast_helpers import const_expr -from ._mlir_helpers import arith as arith_helper, lru_cache_ir +from .common import DSLRuntimeError as DSLRuntimeError +from ._mlir_helpers import arith as arith_helper from ._mlir_helpers.arith import ArithValue from ._mlir_helpers.op import dsl_user_op from .._mlir import ir from .._mlir.extras import types as T -from .._mlir.dialects import arith, math +from .._mlir.dialects import arith # ============================================================================= # Dynamic Expression Protocol @@ -103,7 +98,7 @@ class DynamicExpression(Protocol): } """ - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: """Extract MLIR values from this object. :return: List of MLIR values representing this object's data @@ -111,7 +106,7 @@ class DynamicExpression(Protocol): """ raise NotImplementedError - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "DynamicExpression": """Create a new instance from MLIR values. :param values: List of MLIR values to construct the object from @@ -194,7 +189,7 @@ class JitArgument(Protocol): jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...])) """ - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: """ Generate a list of ctypes pointers for the current object. @@ -203,7 +198,7 @@ class JitArgument(Protocol): """ raise NotImplementedError - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> list[ir.Type]: """ Generate a list of MLIR types for the current object. @@ -212,7 +207,7 @@ class JitArgument(Protocol): """ raise NotImplementedError - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "JitArgument": """ Create a new object from MLIR values. @@ -224,7 +219,7 @@ class JitArgument(Protocol): raise NotImplementedError -def get_c_pointers(obj): +def get_c_pointers(obj: Any) -> list[ctypes.c_void_p]: """ Given the `obj`, recursively go through it to extract all contained C pointers """ @@ -241,7 +236,7 @@ def get_c_pointers(obj): return [] -def get_mlir_types(obj): +def get_mlir_types(obj: Any) -> list[ir.Type]: """ Given the `obj`, recursively go through it to extract all contained MLIR types """ @@ -262,6 +257,30 @@ def get_mlir_types(obj): return [] +def implements_jit_argument(obj: Any, *, partial: bool = False) -> bool: + """ + Check if the object implements the JitArgument protocol. + When partial=True, returns True if any protocol method is present. + """ + check = any if partial else all + return check( + hasattr(obj, attr) + for attr in ("__c_pointers__", "__get_mlir_types__", "__new_from_mlir_values__") + ) + + +def implements_dynamic_expression(obj: Any, *, partial: bool = False) -> bool: + """ + Check if the object implements the DynamicExpression protocol. + When partial=True, returns True if any protocol method is present. + """ + check = any if partial else all + return check( + hasattr(obj, attr) + for attr in ("__extract_mlir_values__", "__new_from_mlir_values__") + ) + + class DslType(type): """Metaclass for all DSL types in the system. @@ -299,7 +318,14 @@ class DslType(type): _is_abstract: bool - def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs): + def __new__( + cls, + name: str, + bases: tuple, + attrs: dict, + is_abstract: bool = False, + **kwargs: Any, + ) -> Any: new_cls = super().__new__(cls, name, bases, attrs) new_cls._is_abstract = is_abstract @@ -307,7 +333,7 @@ class DslType(type): return new_cls @property - def is_abstract(cls): + def is_abstract(cls) -> bool: return cls._is_abstract @@ -333,26 +359,27 @@ class NumericMeta(DslType): """ width: int + bytes: int # Placeholder type _mlir_type = Any - _np_dtype: Union[np.dtype, None] + _np_dtype: Optional[type] def __new__( cls, - name, - bases, - attrs, - width=8, - np_dtype=None, - mlir_type=None, - is_abstract=False, - **kwargs, - ): - def _extract_mlir_values(self): + name: str, + bases: tuple, + attrs: dict, + width: int = 8, + np_dtype: Optional[type] = None, + mlir_type: Optional[Callable[[], ir.Type]] = None, + is_abstract: bool = False, + **kwargs: Any, + ) -> Any: + def _extract_mlir_values(self: "Numeric") -> list[ir.Value]: return [self.ir_value()] - def _new_from_mlir_values(self, values: list) -> "Numeric": + def _new_from_mlir_values(self: "Numeric", values: list[ir.Value]) -> "Numeric": res_ty = type(self) return res_ty(values[0]) @@ -373,24 +400,48 @@ class NumericMeta(DslType): new_cls._mlir_type = staticmethod(mlir_type) new_cls.width = width + new_cls.bytes = max(1, (width + 7) // 8) new_cls._np_dtype = np_dtype return new_cls @property - def numpy_dtype(cls): + def numpy_dtype(cls) -> Optional[type]: return cls._np_dtype @property - def is_integer(cls) -> bool: ... + def is_integer(cls) -> bool: ... # type: ignore[empty-body] @property - def is_float(cls) -> bool: ... + def is_float(cls) -> bool: ... # type: ignore[empty-body] def is_same_kind(cls, other: Type) -> bool: return cls.is_integer == other.is_integer or cls.is_float == other.is_float + def isinstance(cls, value: Any) -> bool: + """ + Check if the value is an compatible type with the numeric type. + + :param value: The value to check + :type value: Any + :return: True if the value is a compatible type with the numeric type, False otherwise + :rtype: bool + """ + if isinstance(value, Numeric): + return value.dtype is cls + elif isinstance(value, arith_helper.ArithValue): + elem_ty = arith_helper.element_type(value.type) + return Numeric.from_mlir_type(elem_ty) is cls + elif isinstance(value, int): + return cls.is_integer + elif isinstance(value, float): + return cls.is_float + elif isinstance(value, bool): + return cls.is_integer + else: + return False + @staticmethod - def from_python(value: Any) -> Type["Numeric"]: + def from_python(value: Union[bool, int, float]) -> Type["Numeric"]: """ Deduce the DSL type from a Python value. """ @@ -405,18 +456,20 @@ class NumericMeta(DslType): ) @property - def mlir_type(cls): - return cls._mlir_type() # type: ignore + def mlir_type(cls) -> ir.Type: + return cls._mlir_type() Value = TypeVar("Value") -def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric": +def cast( + obj: Union[bool, int, float, Value, "Numeric"], type_: Type["Numeric"] +) -> "Numeric": """Cast an object to the specified numeric type. :param obj: Object to be cast - :type obj: Union[bool, int, float, Value] + :type obj: Union[bool, int, float, Value, Numeric] :param type_: Target numeric type :type type_: Type[Numeric] :raises TypeError: If casting to an abstract type or unsupported type conversion @@ -427,6 +480,7 @@ def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeri >>> x = cast(5, Int32) # Cast integer to Int32 >>> y = cast(3.14, Float32) # Cast float to Float32 """ + res: "Numeric" if type_.is_abstract: if not isinstance(obj, type_): raise TypeError( @@ -435,10 +489,11 @@ def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeri ) # If target_type is abstract, and value is instance of target_type, # then we can return value as is + res = obj else: # Implicit cast based on using annotation type - obj = type_(obj) - return obj + res = type_(obj) # type: ignore[arg-type] + return res # Option 1: use ir.Value as base @@ -463,14 +518,14 @@ class IntegerMeta(NumericMeta): def __new__( cls, - name, - bases, - attrs, - width=32, - signed=True, - mlir_type=None, - is_abstract=False, - ): + name: str, + bases: tuple, + attrs: dict, + width: int = 32, + signed: bool = True, + mlir_type: Optional[Callable[[], ir.Type]] = None, + is_abstract: bool = False, + ) -> Any: if width == 1: np_dtype = np.bool_ elif width == 128: @@ -482,9 +537,9 @@ class IntegerMeta(NumericMeta): else: np_dtype = getattr(np, f"uint{width}") - def _c_pointers(self): + def _c_pointers(self: "Integer") -> list[ctypes.c_void_p]: if width == 1: - c_value = ctypes.c_bool(self.value) + c_value = ctypes.c_bool(self.value) # type: ignore[arg-type] elif signed: c_value = getattr(ctypes, f"c_int{width}")(self.value) else: @@ -501,7 +556,7 @@ class IntegerMeta(NumericMeta): new_cls.signed = signed return new_cls - def __str__(cls): + def __str__(cls) -> str: return f"{cls.__name__}" @property @@ -530,7 +585,7 @@ class IntegerMeta(NumericMeta): else: return 2**cls.width - 1 - def recast_width(cls, width): + def recast_width(cls, width: int) -> Type["Integer"]: type_map = { 8: Int8, 16: Int16, @@ -563,7 +618,15 @@ class FloatMeta(NumericMeta): _exponent_width: int _mantissa_width: int - def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False): + def __new__( + cls, + name: str, + bases: tuple, + attrs: dict, + width: int = 32, + mlir_type: Optional[Callable[[], ir.Type]] = None, + is_abstract: bool = False, + ) -> Any: np_dtype = getattr(np, name.lower(), None) new_cls = super().__new__( cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract @@ -584,7 +647,7 @@ class FloatMeta(NumericMeta): # Don't have 1-to-1 mapping of narrow precision types like bfloat16, tfloat32, etc. return new_cls - def __str__(cls): + def __str__(cls) -> str: return f"{cls.__name__}" @property @@ -615,7 +678,7 @@ class FloatMeta(NumericMeta): def mantissa_width(cls) -> int: return cls._mantissa_width - def recast_width(cls, width): + def recast_width(cls, width: int) -> Type["Float"]: type_map = { 16: Float16, 32: Float32, @@ -626,7 +689,7 @@ class FloatMeta(NumericMeta): return type_map[width] -def _arith_signless_to_int(a, target_type): +def _arith_signless_to_int(a: ir.Value, target_type: "IntegerMeta") -> ir.Value: # is_signed: sign of result type if target_type.width > a.type.width: # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL @@ -640,7 +703,9 @@ def _arith_signless_to_int(a, target_type): return a -def _binary_op_type_promote(a, b, promote_bool: bool = False): +def _binary_op_type_promote( + a: "Numeric", b: "Numeric", promote_bool: bool = False +) -> tuple["Numeric", "Numeric", Type["Numeric"]]: """Promote two numeric operands following type promotion rules. :param a: First numeric operand @@ -688,9 +753,9 @@ def _binary_op_type_promote(a, b, promote_bool: bool = False): # If one type is integer, convert it to the float type if a_type.is_float and not b_type.is_float: - b_type = a_type.recast_width(max(a_width, b_width)) + b_type = a_type.recast_width(max(a_width, b_width)) # type: ignore[attr-defined] elif b_type.is_float and not a_type.is_float: - a_type = b_type.recast_width(max(a_width, b_width)) + a_type = b_type.recast_width(max(a_width, b_width)) # type: ignore[attr-defined] # Both are float types - handle precision promotion if a_width > b_width and a_width >= 16: @@ -735,8 +800,8 @@ def _binary_op_type_promote(a, b, promote_bool: bool = False): if a_type == b_type: return a, b, a_type - a_signed = a_type.signed - b_signed = b_type.signed + a_signed = a_type.signed # type: ignore[attr-defined] + b_signed = b_type.signed # type: ignore[attr-defined] a_width = a_type.width b_width = b_type.width @@ -764,7 +829,12 @@ def _binary_op_type_promote(a, b, promote_bool: bool = False): return a.to(b.dtype), b, b.dtype -def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): +def _binary_op( + op: Callable[..., Any], + promote_operand: bool = True, + promote_bool: bool = False, + flip: bool = False, +) -> Callable[..., Any]: """Wrapper for binary operations on Numeric types. This wrapper handles type promotion, operation execution, and result type determination @@ -791,7 +861,13 @@ def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): - Division (truediv) with integer types is not fully supported and converts to Float32 """ - def wrapper(lhs, rhs, *, loc=None, ip=None): + def wrapper( + lhs: "Numeric", + rhs: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: orig_lhs_type = type(lhs) orig_rhs_type = type(rhs) @@ -815,7 +891,7 @@ def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): if promote_operand: lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool) else: - rhs = ty(rhs) + rhs = ty(rhs) # type: ignore[arg-type] if op in ( operator.lt, @@ -831,13 +907,15 @@ def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean: res_type = Boolean + lhs_val: Union[bool, int, float, ir.Value, ArithValue] if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer): - lhs_val = lhs.value.with_signedness(lhs.signed) + lhs_val = lhs.value.with_signedness(lhs.signed) # type: ignore[attr-defined] else: lhs_val = lhs.value + rhs_val: Union[bool, int, float, ir.Value, ArithValue] if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer): - rhs_val = rhs.value.with_signedness(rhs.signed) + rhs_val = rhs.value.with_signedness(rhs.signed) # type: ignore[attr-defined] else: rhs_val = rhs.value @@ -863,7 +941,13 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): :vartype value: Union[bool, int, float, Value] """ - def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None): + def __init__( + self, + value: Union[bool, int, float, Value], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self.value = value def __str__(self) -> str: @@ -878,7 +962,7 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): def __repr__(self) -> str: return f"{self.__class__.__name__}({repr(self.value)})" - def __hash__(self): + def __hash__(self) -> int: return hash(type(self).__class__) ^ hash(self.value) @property @@ -886,21 +970,57 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): return type(self) @overload - def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ... + def to( + self, + dtype: Type["Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": ... @overload - def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ... + def to( + self, + dtype: Type[int], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> int: ... @overload - def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ... + def to( + self, + dtype: Type[float], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> float: ... @overload - def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ... + def to( # type: ignore[overload-cannot-match] + self, + dtype: Type[bool], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: ... @overload - def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ... + def to( + self, + dtype: Type[ir.Value], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: ... - def to(self, dtype: Type, *, loc=None, ip=None): + def to( + self, + dtype: Type, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: """Convert this numeric value to another numeric type. If the target type is the same as the current type, returns self. @@ -958,7 +1078,7 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): else: raise ValueError( f"cannot convert {type(self)} to {dtype}, " - f"self.value is {self.value.type}" + f"self.value is {self.value.type}" # type: ignore[attr-defined] ) if not isinstance(res, ArithValue): @@ -974,13 +1094,45 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): else: raise ValueError(f"unable to convert {type(self)} to {dtype}") - def ir_value(self, *, loc=None, ip=None) -> ir.Value: + def ir_value( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: return self.to(ir.Value, loc=loc, ip=ip) - @property - def zero(self) -> "Numeric": ... + def bitcast( + self, + dtype: "Type[Numeric]", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + """Reinterpret the bits of this value as a different numeric type. - def __dsl_not__(self, *, loc=None, ip=None): + The source and target types must have the same bit width. + + :param dtype: Target DSL type (e.g., ``Float32`` when self is ``Int32``). + :return: A new instance of ``dtype`` with the same bit pattern. + """ + if not isinstance(dtype, NumericMeta): + raise TypeError(f"dtype must be a Numeric type, but got {dtype}") + if dtype is type(self): + return self + ir_val = self.ir_value(loc=loc, ip=ip) + result = arith.bitcast(dtype.mlir_type, ir_val, loc=loc, ip=ip) + return dtype(result) + + @property + def zero(self) -> "Numeric": ... # type: ignore[empty-body] + + def __dsl_not__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Union[bool, "Boolean"]: """DSL implementation of Python's `not` operator. Returns True if the value is equal to zero, False otherwise. @@ -1000,7 +1152,13 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): zero_val = arith.constant(ty.mlir_type, ty.zero) return self.__eq__(ty(zero_val), loc=loc, ip=ip) - def __dsl_and__(self, other, *, loc=None, ip=None): + def __dsl_and__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": """DSL implementation of Python's `and` operator. Returns the second operand if the first is truthy, otherwise returns the first operand. @@ -1027,11 +1185,14 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): # 6 unnecessary MLIR operations. For Boolean inputs the semantics of # `and` are identical to bitwise AND, so delegate directly to __and__. if isinstance(self, Boolean) and isinstance(other, Boolean): - return self.__and__(other, loc=loc, ip=ip) + return self.__and__(other, loc=loc, ip=ip) # type: ignore[call-arg] is_true = self.__dsl_bool__(loc=loc, ip=ip) - def and_op(lhs, rhs): + def and_op( + lhs: Union[bool, int, float, ir.Value], + rhs: Union[bool, int, float, ir.Value], + ) -> Union[bool, int, float, ir.Value]: if isinstance(lhs, (int, float, bool)): if isinstance(rhs, (int, float, bool)): return lhs and rhs @@ -1047,7 +1208,13 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip) - def __dsl_or__(self, other, *, loc=None, ip=None): + def __dsl_or__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": """DSL implementation of Python's `or` operator. Returns the first operand if it is truthy, otherwise returns the second operand. @@ -1070,7 +1237,10 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): """ is_true = self.__dsl_bool__(loc=loc, ip=ip) - def or_op(lhs, rhs): + def or_op( + lhs: Union[bool, int, float, ir.Value], + rhs: Union[bool, int, float, ir.Value], + ) -> Union[bool, int, float, ir.Value]: if isinstance(lhs, (int, float, bool)): if isinstance(rhs, (int, float, bool)): return lhs or rhs @@ -1086,7 +1256,12 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip) - def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean": + def __dsl_bool__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": """DSL implementation of Python's __bool__ method. Returns a Boolean indicating whether this value is considered truthy. @@ -1102,7 +1277,7 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): zero = type(self).zero return self.__ne__(zero, loc=loc, ip=ip) - def __bool__(self): + def __bool__(self) -> bool: if isinstance(self.value, (int, float, bool)): return bool(self.value) else: @@ -1115,30 +1290,47 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): ], ) - def __index__(self): + def __index__(self) -> int: if isinstance(self.value, (int, float, bool)): - return self.value + return self.value # type: ignore[return-value] else: raise DSLRuntimeError( f"'{type(self.value)}' object cannot be interpreted as an integer", suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator", ) - def __neg__(self, *, loc=None, ip=None): - if isinstance(self, (bool, int, float)): - return type(self)(-self.value) # type: ignore + def __neg__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + if isinstance(self.value, (bool, int, float)): + return type(self)(-self.value) else: - return type(self)(-self.value, loc=loc, ip=ip) # type: ignore + return type(self)(-self.value, loc=loc, ip=ip) # type: ignore[operator] + + def __abs__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + if isinstance(self.value, (bool, int, float)): + return type(self)(abs(self.value)) + else: + return type(self)(abs(self.value), loc=loc, ip=ip) # type: ignore[arg-type] @staticmethod - def _from_python_value(value): + def _from_python_value( + value: Union[bool, int, float, ArithValue, "Numeric"], + ) -> "Numeric": if isinstance(value, Numeric): return value if isinstance(value, bool): - res_type = Boolean + res_type: Type["Numeric"] = Boolean elif isinstance(value, int): - # Choose Int32 if it can represent the value, Int64 otherwise res_type = ( Int32 if (value <= 2147483647) and (value >= -2147483648) else Int64 ) @@ -1153,103 +1345,217 @@ class Numeric(metaclass=NumericMeta, is_abstract=True): return res_type(value) @dsl_user_op - def __add__(self, other, *, loc=None, ip=None) -> "Numeric": + def __add__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip) @dsl_user_op - def __sub__(self, other, *, loc=None, ip=None) -> "Numeric": + def __sub__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip) @dsl_user_op - def __mul__(self, other, *, loc=None, ip=None) -> "Numeric": + def __mul__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip) @dsl_user_op - def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric": + def __floordiv__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.floordiv, promote_bool=True)( self, other, loc=loc, ip=ip ) @dsl_user_op - def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric": + def __truediv__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.truediv, promote_bool=True)( self, other, loc=loc, ip=ip ) @dsl_user_op - def __mod__(self, other, *, loc=None, ip=None) -> "Numeric": + def __mod__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip) @dsl_user_op - def __radd__(self, other, *, loc=None, ip=None) -> "Numeric": + def __radd__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return self.__add__(other, loc=loc, ip=ip) @dsl_user_op - def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric": + def __rsub__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.sub, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) @dsl_user_op - def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric": + def __rmul__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return self.__mul__(other, loc=loc, ip=ip) @dsl_user_op - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric": + def __rfloordiv__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.floordiv, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) @dsl_user_op - def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric": + def __rtruediv__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.truediv, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) @dsl_user_op - def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric": + def __rmod__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.mod, promote_bool=True, flip=True)( self, other, loc=loc, ip=ip ) @dsl_user_op - def __eq__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.eq)(self, other, loc=loc, ip=ip) # type: ignore + def __eq__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": + return _binary_op(operator.eq)(self, other, loc=loc, ip=ip) @dsl_user_op - def __ne__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.ne)(self, other, loc=loc, ip=ip) # type: ignore + def __ne__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": + return _binary_op(operator.ne)(self, other, loc=loc, ip=ip) @dsl_user_op - def __lt__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.lt)(self, other, loc=loc, ip=ip) # type: ignore + def __lt__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": + return _binary_op(operator.lt)(self, other, loc=loc, ip=ip) @dsl_user_op - def __le__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.le)(self, other, loc=loc, ip=ip) # type: ignore + def __le__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": + return _binary_op(operator.le)(self, other, loc=loc, ip=ip) @dsl_user_op - def __gt__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.gt)(self, other, loc=loc, ip=ip) # type: ignore + def __gt__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": + return _binary_op(operator.gt)(self, other, loc=loc, ip=ip) @dsl_user_op - def __ge__(self, other, *, loc=None, ip=None) -> "Boolean": - return _binary_op(operator.ge)(self, other, loc=loc, ip=ip) # type: ignore + def __ge__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Boolean": + return _binary_op(operator.ge)(self, other, loc=loc, ip=ip) @dsl_user_op - def __pow__(self, other, *, loc=None, ip=None) -> "Numeric": - return _binary_op(operator.pow)(self, other, loc=loc, ip=ip) # type: ignore + def __pow__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + return _binary_op(operator.pow)(self, other, loc=loc, ip=ip) - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: raise ValueError( f"only support built-in types: bool, (u)int{8, 16, 32, 64}, float{32, 64}, but got {type(self)}" ) - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> list[ir.Type]: return [type(self).mlir_type] @staticmethod - def from_mlir_type(mlir_type): + def from_mlir_type(mlir_type: ir.Type) -> Type["Numeric"]: type_map = { T.bool(): Boolean, T.f64(): Float64, @@ -1363,7 +1669,13 @@ class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True) a = Int32(c5) # Treat c5 as int32 bitwise """ - def __init__(self, x, *, loc=None, ip=None): + def __init__( + self, + x: Union[bool, int, float, ir.Value, "Integer", "Float"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: ty = type(self) if isinstance(x, (bool, int, float)): @@ -1378,14 +1690,14 @@ class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True) assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" x_val = int(np.array(x).astype(np_dtype)) elif type(x) == ty: - x_val = x.value - elif isinstance(x, ir.Value): # type: ignore + x_val = x.value # type: ignore[assignment] + elif isinstance(x, ir.Value): x_val = x - if isinstance(x.type, ir.IntegerType): # type: ignore + if isinstance(x.type, ir.IntegerType): if x.type.width != ty.width: # signless -> (u)int x_val = _arith_signless_to_int(x, ty) - elif isinstance(x.type, ir.FloatType): # type: ignore + elif isinstance(x.type, ir.FloatType): # float -> (u)int x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip) elif isinstance(x, Integer): @@ -1404,47 +1716,112 @@ class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True) super().__init__(x_val) - def __invert__(self, *, loc=None, ip=None): + def __invert__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Integer": res_type = type(self) return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip)) - def __lshift__(self, other, *, loc=None, ip=None): + def __lshift__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip) - def __rlshift__(self, other, *, loc=None, ip=None): + def __rlshift__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": other_ = as_numeric(other) if not isinstance(other_, Integer): raise ValueError(f"Cannot left shift {other_} with {self}") - return other_.__lshift__(self, loc=loc, ip=ip) + return other_.__lshift__(self, loc=loc, ip=ip) # type: ignore[call-arg] - def __rshift__(self, other, *, loc=None, ip=None): + def __rshift__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.rshift)(self, other, loc=loc, ip=ip) - def __rrshift__(self, other, *, loc=None, ip=None): + def __rrshift__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": other_ = as_numeric(other) if not isinstance(other_, Integer): raise ValueError(f"Cannot right shift {other_} with {self}") - return other_.__rshift__(self, loc=loc, ip=ip) + return other_.__rshift__(self, loc=loc, ip=ip) # type: ignore[call-arg] - def __and__(self, other, *, loc=None, ip=None): + def __and__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.and_)(self, other, loc=loc, ip=ip) - def __rand__(self, other, *, loc=None, ip=None): - return self.__and__(other, loc=loc, ip=ip) + def __rand__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + return self.__and__(other, loc=loc, ip=ip) # type: ignore[call-arg] - def __or__(self, other, *, loc=None, ip=None): + def __or__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.or_)(self, other, loc=loc, ip=ip) - def __ror__(self, other, *, loc=None, ip=None): - return self.__or__(other, loc=loc, ip=ip) + def __ror__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + return self.__or__(other, loc=loc, ip=ip) # type: ignore[call-arg] - def __xor__(self, other, *, loc=None, ip=None): + def __xor__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": return _binary_op(operator.xor)(self, other, loc=loc, ip=ip) - def __rxor__(self, other, *, loc=None, ip=None): - return self.__xor__(other, loc=loc, ip=ip) + def __rxor__( + self, + other: Union[int, float, bool, "Numeric"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": + return self.__xor__(other, loc=loc, ip=ip) # type: ignore[call-arg] - def __tvm_ffi_int__(self): + def __tvm_ffi_int__(self) -> Union[int, ir.Value]: return self.value @@ -1498,36 +1875,42 @@ class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True): :raises ValueError: If conversion from the input type is not supported """ - def __init__(self, x, *, loc=None, ip=None): + def __init__( + self, + x: Union[bool, int, float, ir.Value, "Integer", "Float"], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: ty = type(self) - if isinstance(x, (bool, int, float)): # type: ignore + if isinstance(x, (bool, int, float)): # Why we need to convert x to with numpy? # np_dtype = ty.numpy_dtype # assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" # x = float(np.array(x).astype(np_dtype)) super().__init__(float(x)) - elif isinstance(x, ir.Value): # type: ignore - if isinstance(x.type, ir.IntegerType): # type: ignore + elif isinstance(x, ir.Value): + if isinstance(x.type, ir.IntegerType): raise DSLRuntimeError("signless to float conversion is not implemented") - elif isinstance(x.type, ir.FloatType): # type: ignore + elif isinstance(x.type, ir.FloatType): if x.type != ty.mlir_type: x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip) super().__init__(x) elif isinstance(x, Integer): - if isinstance(x.value, ir.Value): # type: ignore + if isinstance(x.value, ir.Value): x = arith_helper.itofp( x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip ) else: - x = float(x.value) + x = float(x.value) # type: ignore[arg-type] super().__init__(x) elif isinstance(x, Float): Float.__init__(self, x.value) else: raise DSLRuntimeError(f"{x} to Float conversion is not supported") - def __tvm_ffi_float__(self): + def __tvm_ffi_float__(self) -> Union[float, ir.Value]: return self.value @@ -1564,8 +1947,12 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. """ def __init__( - self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None - ): + self, + a: Union[bool, int, float, ir.Value, Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: value = None if isinstance(a, (bool, int, float)): value = bool(a) @@ -1582,7 +1969,12 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. super().__init__(value, loc=loc, ip=ip) self._value_int8 = None - def ir_value_int8(self, *, loc=None, ip=None): + def ir_value_int8( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: """ Returns int8 ir value of Boolean. When we need to store Boolean tensor element, use ir_value_int8(). @@ -1599,7 +1991,12 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value() return self._value_int8 - def __neg__(self, *, loc=None, ip=None): + def __neg__( # type: ignore[override] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Numeric": """Negation operator is not supported for boolean type. :param loc: Source location information, defaults to None @@ -1611,6 +2008,7 @@ class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T. raise TypeError("Negation, the operator `-` is not supported for boolean type") + class Int4( Integer, metaclass=IntegerMeta, @@ -1661,7 +2059,7 @@ class Uint128( class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64): - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: if not isinstance(self.value, float): raise ValueError("only float is supported") @@ -1672,10 +2070,10 @@ class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64): class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32): @staticmethod - def _get_c_pointer(value: float): + def _get_c_pointer(value: float) -> ctypes.c_void_p: return ctypes.cast(ctypes.pointer(ctypes.c_float(value)), ctypes.c_void_p) - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: if not isinstance(self.value, float): raise ValueError("only float is supported") @@ -1683,7 +2081,7 @@ class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32): class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32): - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: if not isinstance(self.value, float): raise ValueError("only float is supported") return [Float32._get_c_pointer(self.value)] @@ -1691,35 +2089,35 @@ class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32): class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16): @staticmethod - def _get_c_pointer(value: float): + def _get_c_pointer(value: float) -> ctypes.c_void_p: # Convert float to float16 binary representation # First convert to numpy float16 to handle the conversion f16_val = np.float16(value) # Get the raw bits as a 16-bit integer - bits = f16_val.view(np.uint16) + bits: np.uint16 = f16_val.view(np.uint16) # Create a short (16-bit int) with those bits - c_val = ctypes.c_short(bits) + c_val = ctypes.c_short(int(bits)) return ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p) - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: if not isinstance(self.value, float): raise ValueError("only float is supported") return [Float16._get_c_pointer(self.value)] class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16): - def __c_pointers__(self): + def __c_pointers__(self) -> list[ctypes.c_void_p]: if not isinstance(self.value, float): raise ValueError("only float is supported") # Convert float32 to bfloat16 representation # First convert the value to float32 bit representation f32_val = np.float32(self.value) # Get the 32-bit integer representation - bits = f32_val.view(np.uint32) + bits: np.uint32 = f32_val.view(np.uint32) # Truncate to 16 bits, keeping the high 16 bits bf16_bits = np.uint16(bits >> 16) # Create a short (16-bit int) with those bits - c_val = ctypes.c_short(bf16_bits) + c_val = ctypes.c_short(bf16_bits) # type: ignore[arg-type] c_pointer = ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p) return [c_pointer] @@ -1790,7 +2188,7 @@ ALL_DTYPES = { __STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES} -def dtype(dtype_) -> Type[Numeric]: +def dtype(dtype_: str) -> Type[Numeric]: t = None if isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__: t = __STR_TO_DTYPE__[dtype_] @@ -1817,7 +2215,14 @@ class TensorMeta(DslType): >>> Tensor[T, (3, 4, 5)] """ - def __new__(cls, name, bases, attrs, element_type=Any, shape=Any): + def __new__( + cls, + name: str, + bases: tuple, + attrs: dict, + element_type: Any = Any, + shape: Any = Any, + ) -> Any: new_cls = super().__new__(cls, name, bases, attrs) new_cls._element_type = element_type new_cls._shape = shape @@ -1834,18 +2239,28 @@ class Constexpr(Generic[TY]): pass -class align: - def __init__(self, value: int): +class align(int): + def __new__(cls, value: int) -> "align": if value <= 0 or (value & (value - 1)) != 0: raise DSLRuntimeError("expects align be power of 2 as positive value") - self._value = value + return super().__new__(cls, value) - def __str__(self): - return f"align({self._value})" + def __str__(self) -> str: + return f"align({super().__str__()})" class PointerMeta(DslType): - def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)): + _value_type: Any + _align: Any + + def __new__( + cls, + name: str, + bases: tuple, + attrs: dict, + value_type: Any = Int32, + align_: Any = align(1), + ) -> Any: new_cls = super().__new__( cls, name, @@ -1859,7 +2274,7 @@ class PointerMeta(DslType): new_cls._align = align_ return new_cls - def __eq__(cls, other): + def __eq__(cls, other: Any) -> bool: if not isinstance(other, PointerMeta): return False return ( @@ -1867,10 +2282,10 @@ class PointerMeta(DslType): and cls._align._value == other._align._value ) # Compare alignment values - def __hash__(cls): + def __hash__(cls) -> int: return hash((cls._value_type, cls._align._value)) # Hash alignment value - def __getitem__(cls, params) -> Type["Pointer"]: + def __getitem__(cls, params: Any) -> Type["Pointer"]: value_type, align_ = params if not isinstance(align_, align): @@ -1886,7 +2301,7 @@ class PointerMeta(DslType): ) return new_cls - def __str__(cls): + def __str__(cls) -> str: return f"ptr<{cls._value_type}, {cls._align}>" @@ -1901,10 +2316,10 @@ class Pointer(metaclass=PointerMeta): """ - def __init__(self, value): + def __init__(self, value: Any) -> None: self.value = value - def __str__(self): + def __str__(self) -> str: return f"{self.value} : {type(self)}" @@ -1927,19 +2342,19 @@ class IRVariadic: A helper class to pass a variadic number of arguments to a function. """ - def __init__(self, operands): + def __init__(self, operands: list[ir.Value]) -> None: """ Create a list of variadic operands. `operands` must be dynamic values. """ self.operands = operands - def block_arg_types(self): + def block_arg_types(self) -> list[ir.Type]: """ Return the list of block args types. """ return [operand.type for operand in self.operands] - def set_func_args(self, block_args): + def set_func_args(self, block_args: list[ir.Value]) -> None: """ This function is called after entering a function. `block_args` are the block arguments that correspond to the passed operands. Derived classes @@ -1948,7 +2363,7 @@ class IRVariadic: """ pass - def __len__(self): + def __len__(self) -> int: """ Return the length of variadic operands. """ @@ -1960,7 +2375,9 @@ class FuncArgWithAttr(IRValue): This derived class is specifically for func op arg with attr """ - def __init__(self, ty, attr_name, attr_ty, attr_value=None): + def __init__( + self, ty: Any, attr_name: str, attr_ty: Any, attr_value: Any = None + ) -> None: super().__init__(ty) assert attr_name is not None and ( attr_ty is not None or attr_value is not None @@ -1970,8 +2387,9 @@ class FuncArgWithAttr(IRValue): self.attr_value = attr_value - -def implicitDowncastNumericType(value): +def implicitDowncastNumericType( + value: Union[bool, int, float, "Numeric"], +) -> Union[bool, int, float, ir.Value]: if isinstance(value, Numeric): return value.ir_value() return value diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/logger.py b/python/CuTeDSL/cutlass/base_dsl/utils/logger.py index 2ce47c7cb..32576a99a 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/logger.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/logger.py @@ -15,16 +15,20 @@ This module provides logging helper functions import logging -logger = None +logger: logging.Logger -def log(): +def log() -> logging.Logger: return logger def setup_log( - name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1 -): + name: str, + log_to_console: bool = False, + log_to_file: bool = False, + log_file_path: str | None = None, + log_level: int = 1, +) -> logging.Logger: """Set up and configure a logger with console and/or file handlers. :param name: Name of the logger to create @@ -78,7 +82,7 @@ def setup_log( return logger -def _init_logger_with_client_name(prefix): +def _init_logger_with_client_name(prefix: str) -> None: from ..env_manager import LogEnvironmentManager log_env = LogEnvironmentManager(prefix) diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/numpy.py b/python/CuTeDSL/cutlass/base_dsl/utils/numpy.py index 61a181e3e..82b5b04cb 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/numpy.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/numpy.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -15,6 +15,8 @@ the DSL. """ import numpy as np +from typing import Any + from ..._mlir.extras import types as T # ============================================================================= @@ -22,7 +24,7 @@ from ..._mlir.extras import types as T # ============================================================================= -def _numpy_type_to_mlir_type(dtype): +def _numpy_type_to_mlir_type(dtype: type[np.generic] | np.dtype[Any]) -> Any: if dtype == np.float64: return T.f64() if dtype == np.float16: @@ -47,22 +49,10 @@ def _numpy_type_to_mlir_type(dtype): return T.ui8() if dtype == np.bool_: return T.bool() - if dtype == f8E5M2: - return T.f8E5M2() - if dtype == f8E4M3FN: - return T.f8E4M3FN() - if dtype == f8E8M0FNU: - return T.f8E8M0FNU() - if dtype == f6E3M2FN: - return T.f6E3M2FN() - if dtype == f6E2M3FN: - return T.f6E2M3FN() - if dtype == f4E2M1FN: - return T.f4E2M1FN() raise TypeError(f"Unknown NumPy dtype for MLIR conversion: {dtype!r}") -def _mlir_type_to_numpy_type(mlir_type): +def _mlir_type_to_numpy_type(mlir_type: Any) -> type[np.generic]: if mlir_type == T.f64(): return np.float64 if mlir_type == T.f16(): diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py b/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py index af899faa9..a97aa73c7 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/stacktrace.py @@ -15,9 +15,10 @@ This module provides stacktrace helper functions import os import re +import types -def walk_to_top_module(start_path): +def walk_to_top_module(start_path: str) -> str | None: """ Walk up from the start_path to find the top-level Python module. @@ -50,7 +51,9 @@ def walk_to_top_module(start_path): return current_path -def _filter_internal_frames(traceback, internal_path): +def _filter_internal_frames( + traceback: types.TracebackType | None, internal_path: str +) -> types.TracebackType | None: """ Filter out stack frames from the traceback that belong to the specified module path. @@ -75,12 +78,14 @@ def _filter_internal_frames(traceback, internal_path): return traceback -_generated_function_names = re.compile( +_generated_function_names: re.Pattern[str] = re.compile( r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$" ) -def _filter_duplicated_frames(traceback): +def _filter_duplicated_frames( + traceback: types.TracebackType | None, +) -> types.TracebackType | None: """ Filter out duplicated stack frames from the traceback. The function filters out consecutive frames that are in the same file and have the same line number. @@ -115,6 +120,7 @@ def _filter_duplicated_frames(traceback): else: traceback = iter_tb.tb_next elif skip_next: + assert iter_tb.tb_next is not None # if next is last frame, don't skip if iter_tb.tb_next.tb_next: iter_tb.tb_next = iter_tb.tb_next.tb_next @@ -126,7 +132,9 @@ def _filter_duplicated_frames(traceback): return traceback -def filter_stackframe(traceback, prefix_path): +def filter_stackframe( + traceback: types.TracebackType | None, prefix_path: str +) -> types.TracebackType | None: """ Filter out stack frames from the traceback that belong to the specified module path. @@ -145,7 +153,7 @@ def filter_stackframe(traceback, prefix_path): return _filter_duplicated_frames(traceback) -def filter_exception(value, module_dir): +def filter_exception(value: BaseException, module_dir: str) -> None: """ Filter out internal implementation details from exception traceback. diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/timer.py b/python/CuTeDSL/cutlass/base_dsl/utils/timer.py index bf4a2ce0b..5a5c95a62 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/timer.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/timer.py @@ -14,17 +14,18 @@ This module provides a timing helper functions """ from functools import wraps +from typing import Any from .logger import log # TODO: revisit this part when mlir timing manager is ready for pybind. -def timer(*dargs, **kwargs): +def timer(*dargs: Any, **kwargs: Any) -> Any: enable = kwargs.get("enable", True) - def decorator(func): + def decorator(func: Any) -> Any: @wraps(func) - def func_wrapper(*args, **kwargs): + def func_wrapper(*args: Any, **kwargs: Any) -> Any: if not enable: return func(*args, **kwargs) from time import time diff --git a/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py b/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py index 86fb2cd46..2c3f8c6b1 100644 --- a/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py +++ b/python/CuTeDSL/cutlass/base_dsl/utils/tree_utils.py @@ -9,16 +9,62 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Callable, Any, Iterable, Iterator, NamedTuple, Union, get_origin +from collections.abc import Callable, Iterable, Iterator +from typing import Any, NamedTuple, get_origin + import dataclasses import itertools as it from types import SimpleNamespace -from ..typing import as_numeric, Numeric, Constexpr +from ..typing import as_numeric, Numeric, Constexpr, implements_dynamic_expression from .._mlir_helpers.arith import ArithValue from ..common import DSLBaseError from ..._mlir import ir + +def _flatten_mlir_values(values: Any) -> list[ir.Value]: + """ + Flatten a nested dict/list structure of MLIR values into a flat list. + Local copy to avoid circular imports with dsl.py. + """ + if values is None: + return [] + elif isinstance(values, ir.Value): + return [values] + elif isinstance(values, dict): + result = [] + for v in values.values(): + result.extend(_flatten_mlir_values(v)) + return result + elif isinstance(values, list): + result = [] + for v in values: + result.extend(_flatten_mlir_values(v)) + return result + else: + return [] + + +def _unflatten_mlir_values(flat_values: Any, template: Any) -> Any: + """ + Reconstruct a nested dict/list structure from a flat list of MLIR values. + Local copy to avoid circular imports with dsl.py. + """ + if not hasattr(flat_values, "__next__"): + flat_values = iter(flat_values) + + if template is None: + return None + elif isinstance(template, ir.Value): + return next(flat_values) + elif isinstance(template, dict): + return {k: _unflatten_mlir_values(flat_values, v) for k, v in template.items()} + elif isinstance(template, list): + return [_unflatten_mlir_values(flat_values, v) for v in template] + else: + return None + + NoneType = type(None) # ============================================================================= @@ -99,27 +145,25 @@ def is_frozen_dataclass(obj_or_cls: Any) -> bool: return ( dataclasses.is_dataclass(cls) and getattr(cls, "__dataclass_params__", None) is not None - and cls.__dataclass_params__.frozen + and cls.__dataclass_params__.frozen # type: ignore[attr-defined] ) -def is_dynamic_expression(x: Any) -> bool: +def is_namedtuple_instance(x: Any) -> bool: """ - Check if an object implements the DynamicExpression protocol. - - Objects implementing this protocol must have both `__extract_mlir_values__` - and `__new_from_mlir_values__` methods. + Check if an object is an instance of a :class:`typing.NamedTuple` subclass. Args: x: Any object to check Returns: - bool: True if the object implements the DynamicExpression protocol, - False otherwise + bool: True if *x* is a NamedTuple instance, False otherwise """ - return all( - hasattr(x, attr) - for attr in ("__extract_mlir_values__", "__new_from_mlir_values__") + t = type(x) + return ( + issubclass(t, tuple) + and hasattr(t, "_fields") + and isinstance(t._fields, tuple) ) @@ -183,8 +227,8 @@ class Leaf: is_numeric: bool = False is_none: bool = False - node_metadata: SimpleNamespace = None - ir_type_str: str = None + node_metadata: SimpleNamespace | None = None + ir_type_str: str | None = None # ============================================================================= @@ -192,7 +236,7 @@ class Leaf: # ============================================================================= -def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[Any]]: +def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[str]]: """ Extract non-method, non-function attributes from a dataclass instance. @@ -223,7 +267,7 @@ def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any], list[Any]]: constexpr_fields.append(field.name) fields.remove(field.name) v = getattr(x, field.name) - if is_dynamic_expression(v): + if implements_dynamic_expression(v): raise DSLTreeFlattenError( f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`", type_str=get_fully_qualified_class_name(x), @@ -316,7 +360,43 @@ def default_dataclass_from_iterable( ) -def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: +def namedtuple_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a :class:`typing.NamedTuple` instance to iterable form. + + Args: + x: A NamedTuple instance + + Returns: + tuple: (metadata, field_values) where metadata stores the field names + and original object for later reconstruction + """ + fields = list(type(x)._fields) + return ( + SimpleNamespace( + type_str=get_fully_qualified_class_name(x), + fields=fields, + original_obj=x, + ), + list(x), + ) + + +def namedtuple_from_iterable(metadata: SimpleNamespace, children: Iterable[Any]) -> Any: + """ + Reconstruct a :class:`typing.NamedTuple` instance from iterable form. + + Args: + metadata: Metadata produced by :func:`namedtuple_to_iterable` + children: Iterable of reconstructed field values + + Returns: + A new NamedTuple instance of the original type + """ + return type(metadata.original_obj)(*children) + + +def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any] | None]: """ Convert a dynamic expression to iterable form. @@ -329,9 +409,31 @@ def dynamic_expression_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: tuple: (metadata, mlir_values) where metadata marks this as a dynamic expression and mlir_values are the extracted MLIR values """ + extracted = x.__extract_mlir_values__() + if extracted is None: + # Preserve None so the caller's "children is None" check still triggers + # DSLTreeFlattenError for types whose __extract_mlir_values__ returns None + # (e.g. runtime._Pointer inheriting an unimplemented ABC stub). + return ( + SimpleNamespace(is_dynamic_expression=1, original_obj=x, template=None), + None, + ) + + flattened = _flatten_mlir_values(extracted) + if not flattened and extracted: + # extracted is non-empty but flatten produced nothing -- this means + # __extract_mlir_values__ returned non-ir.Value items (e.g. Python ints + # before they are promoted to MLIR values). Fall back to passing the raw + # values as children, which preserves the old (pre-flatten) behavior and + # lets _tree_flatten raise DSLTreeFlattenError for unsupported types. + return ( + SimpleNamespace(is_dynamic_expression=1, original_obj=x, template=None), + extracted, + ) + return ( - SimpleNamespace(is_dynamic_expression=1, original_obj=x), - x.__extract_mlir_values__(), + SimpleNamespace(is_dynamic_expression=1, original_obj=x, template=extracted), + flattened, ) @@ -350,7 +452,13 @@ def dynamic_expression_from_iterable( Returns: The reconstructed dynamic expression object """ - return metadata.original_obj.__new_from_mlir_values__(list(children)) + children_list = list(children) + # If we have a template, unflatten the values back to the original structure + if hasattr(metadata, "template") and metadata.template is not None: + values = _unflatten_mlir_values(children_list, metadata.template) + else: + values = children_list + return metadata.original_obj.__new_from_mlir_values__(values) def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: @@ -508,7 +616,9 @@ unflattened_a should be structurally identical to a, and unflattened_b should be """ -def tree_flatten(x: Any) -> tuple[list[Any], list[ir.Attribute], PyTreeDef]: +def tree_flatten( + x: Any, return_ir_values: bool = True +) -> tuple[list[Any], list[ir.Attribute], PyTreeDef | Leaf]: """ Flatten a nested structure into a flat list of values and a tree definition. @@ -518,9 +628,10 @@ def tree_flatten(x: Any) -> tuple[list[Any], list[ir.Attribute], PyTreeDef]: Args: x: The nested structure to flatten - + return_ir_values: Whether to return ir.Values instead of original values Returns: - tuple: (flat_values, treedef) where flat_values is a list of leaf values + tuple: (flat_values, flat_attributes, treedef) where flat_values is a list of leaf values + and flat_attributes is a list of attributes for the leaf values and treedef is the tree structure definition Raises: @@ -530,11 +641,11 @@ def tree_flatten(x: Any) -> tuple[list[Any], list[ir.Attribute], PyTreeDef]: >>> tree_flatten([1, [2, 3], 4]) ([1, 2, 3, 4], PyTreeDef(...)) """ - children_iter, child_attrs_iter, treedef = _tree_flatten(x) + children_iter, child_attrs_iter, treedef = _tree_flatten(x, return_ir_values) return list(children_iter), list[ir.Attribute](child_attrs_iter), treedef -def get_registered_node_types_or_insert(x: Any) -> Union[NodeType, None]: +def get_registered_node_types_or_insert(x: Any) -> NodeType | None: """ Get the registered node type for an object, registering it if necessary. @@ -553,11 +664,17 @@ def get_registered_node_types_or_insert(x: Any) -> Union[NodeType, None]: node_type = _node_types.get(type(x)) if node_type: return node_type - elif is_dynamic_expression(x): + elif implements_dynamic_expression(x): # If a class implements DynamicExpression protocol, register it before default dataclass one return register_pytree_node( type(x), dynamic_expression_to_iterable, dynamic_expression_from_iterable ) + elif is_namedtuple_instance(x): + # NamedTuples are pytree containers: flatten to field values, rebuild via constructor. + # Checked before dataclass because NamedTuples are tuples, not dataclasses. + return register_pytree_node( + type(x), namedtuple_to_iterable, namedtuple_from_iterable + ) elif dataclasses.is_dataclass(x): return register_pytree_node( type(x), default_dataclass_to_iterable, default_dataclass_from_iterable @@ -567,11 +684,11 @@ def get_registered_node_types_or_insert(x: Any) -> Union[NodeType, None]: def create_leaf_for_value( - x: Any, + x: Any = None, is_numeric: bool = False, is_none: bool = False, - node_metadata: SimpleNamespace = None, - ir_type_str: str = None, + node_metadata: SimpleNamespace | None = None, + ir_type_str: str | None = None, ) -> Leaf: """ Create a Leaf node for a given value. @@ -596,7 +713,8 @@ def create_leaf_for_value( def _tree_flatten( x: Any, -) -> tuple[Iterable[Any], Iterable[ir.Attribute], Union[PyTreeDef, Leaf]]: + return_ir_values: bool = True, +) -> tuple[Iterable[Any], Iterable[ir.Attribute], PyTreeDef | Leaf]: """ Internal function to flatten a tree structure. @@ -606,10 +724,12 @@ def _tree_flatten( Args: x: The object to flatten + return_ir_values: Whether to return ir.Value instead of original values Returns: - tuple: (flattened_values, treedef) where flattened_values is an iterable - of leaf values and treedef is the tree structure + tuple: (flattened_values, flattened_attributes, treedef) where flattened_values is an iterable + of leaf values, flattened_attributes is an iterable of leaf attributes + and treedef is the tree structure Raises: DSLTreeFlattenError: If the object type is not supported @@ -617,8 +737,12 @@ def _tree_flatten( if x is None: return [], [], create_leaf_for_value(x, is_none=True) - elif isinstance(x, ArithValue) and is_dynamic_expression(x): - v = x.__extract_mlir_values__() + elif isinstance(x, ArithValue) and implements_dynamic_expression(x): + v = ( + _flatten_mlir_values(x.__extract_mlir_values__()) + if return_ir_values + else [x] + ) a = ( [ir.DictAttr.get({})] if not hasattr(x, "__extract_mlir_attributes__") @@ -630,18 +754,18 @@ def _tree_flatten( create_leaf_for_value( x, node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), - ir_type_str=str(v[0].type), + ir_type_str=str(x.type), ), ) elif isinstance(x, ArithValue): return [x], [ir.DictAttr.get({})], create_leaf_for_value(x, is_numeric=True) - elif isinstance(x, ir.Value): - return [x], [ir.DictAttr.get({})], create_leaf_for_value(x) - - elif isinstance(x, Numeric): - v = x.__extract_mlir_values__() + elif implements_dynamic_expression(x) and isinstance(x, ir.Value): + # Only for ir.Value subclasses (e.g. ctm.Pointer). Check before plain ir.Value + # so they are unflattened via __new_from_mlir_values__. Other dynamic + # expressions (e.g. TmemAllocator with 2 values) use the registered/node path. + v = _flatten_mlir_values(x.__extract_mlir_values__()) a = ( [ir.DictAttr.get({})] if not hasattr(x, "__extract_mlir_attributes__") @@ -653,7 +777,31 @@ def _tree_flatten( create_leaf_for_value( x, node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), - ir_type_str=str(v[0].type), + ir_type_str=str(v[0].type) if v else "unknown", + ), + ) + + elif isinstance(x, ir.Value): + return [x], [ir.DictAttr.get({})], create_leaf_for_value(x) + + elif isinstance(x, Numeric): + v = ( + _flatten_mlir_values(x.__extract_mlir_values__()) # type: ignore[attr-defined] + if return_ir_values + else [x] + ) + a = ( + [ir.DictAttr.get({})] + if not hasattr(x, "__extract_mlir_attributes__") + else x.__extract_mlir_attributes__() + ) + return ( + v, + a, + create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(type(x).mlir_type), ), ) @@ -667,7 +815,7 @@ def _tree_flatten( "Flatten Error: children is None", get_fully_qualified_class_name(x) ) children_flat, child_attrs_flat, child_trees = unzip3( - map(_tree_flatten, children) + map(lambda child: _tree_flatten(child, return_ir_values), children) ) flattened = it.chain.from_iterable(children_flat) @@ -685,11 +833,13 @@ def _tree_flatten( # Try to convert to numeric try: - nval = as_numeric(x).ir_value() + numeric = as_numeric(x) return ( - [nval], + [numeric.ir_value() if return_ir_values else x], [ir.DictAttr.get({})], - create_leaf_for_value(nval, is_numeric=True), + create_leaf_for_value( + is_numeric=True, ir_type_str=str(type(numeric).mlir_type) + ), ) except Exception: raise DSLTreeFlattenError( @@ -720,7 +870,7 @@ def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: return _tree_unflatten(treedef, iter(xs)) -def _tree_unflatten(treedef: Union[PyTreeDef, Leaf], xs: Iterator[Any]) -> Any: +def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: """ Internal function to reconstruct a tree structure. @@ -748,7 +898,7 @@ def _tree_unflatten(treedef: Union[PyTreeDef, Leaf], xs: Iterator[Any]) -> Any: return treedef.node_type.from_iterable(treedef.node_metadata, children) -def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool: +def _check_tree_equal(lhs: PyTreeDef | Leaf, rhs: PyTreeDef | Leaf) -> bool: """ Check if two tree definitions are structurally equal. diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index 94ac506f8..aaa7787e6 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -9,6 +9,9 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from collections.abc import Callable +from typing import Any + # Use the auto-generated enum AddressSpace from cutlass._mlir.dialects.cute import AddressSpace, CacheEvictionPriority @@ -75,6 +78,7 @@ from .core import ( slice_and_offset, crd2idx, idx2crd, + increment_coord, filter_zeros, filter, tile_to_shape, @@ -122,6 +126,7 @@ from .core import ( fast_divmod_create_divisor, basis_value, basis_get, + nullspace, ) from .tuple import ( @@ -137,6 +142,8 @@ from .tuple import ( tuple_cat, transform_apply, filter_tuple, + unwrap, + wrap, ) from .tensor import ( TensorSSA, @@ -189,6 +196,7 @@ from .algorithm import gemm, copy, basic_copy, basic_copy_if, autovec_copy, pref from . import typing as typing_module from . import core from . import arch + from . import export from . import nvgpu @@ -196,18 +204,17 @@ from . import testing from . import runtime from . import math - # Export all math ops without "math." from .math import * # Used as internal symbol from .. import cutlass_dsl as _dsl -from .ffi import ffi +from .ffi import ffi, extern, BitCode, ConstValue, mangle # Aliases -jit = _dsl.CuTeDSL.jit -kernel = _dsl.CuTeDSL.kernel +jit: Callable[..., Any] = _dsl.CuTeDSL.jit +kernel: Callable[..., Any] = _dsl.CuTeDSL.kernel register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter compile = _dsl.CompileCallable() OptLevel = _dsl.OptLevel @@ -217,8 +224,12 @@ GenerateLineInfo = _dsl.GenerateLineInfo KeepCUBIN = _dsl.KeepCUBIN KeepPTX = _dsl.KeepPTX GPUArch = _dsl.GPUArch +LinkLibraries = _dsl.LinkLibraries EnableTVMFFI = _dsl.EnableTVMFFI +native_struct = _dsl.native_struct +make_native_struct = _dsl.make_native_struct # factory for dynamic struct types + # attach the TVM FFI ABI interface postprocessor to the DSL from . import _tvm_ffi_args_spec_converter @@ -226,52 +237,16 @@ _tvm_ffi_args_spec_converter.attach_args_spec_converter(_dsl.CuTeDSL._get_dsl()) # Explicitly export all symbols for documentation generation __all__ = [ - # ==================== cutlass._mlir.dialects.cute ==================== + # Core types + *core.__all__, "AddressSpace", "CacheEvictionPriority", - # ==================== .typing ==================== "Tensor", "Layout", "ComposedLayout", - "SymInt", - "is_integer", - "is_int_tuple", - # ==================== .core ==================== - *core.__all__, - # ==================== .tuple ==================== - "transform_leaf", - "find_if", - "find", - "flatten_to_tuple", - "unflatten", - "product", - "product_like", - "product_each", - "elem_less", - "tuple_cat", - "transform_apply", - "filter_tuple", - # ==================== .tensor ==================== - "TensorSSA", - "ReductionOp", - "make_tensor", - "make_identity_tensor", - "make_fragment", - "make_fragment_like", - "make_rmem_tensor_like", - "make_rmem_tensor", - "recast_tensor", - "domain_offset", - "print_tensor", - "full", - "full_like", - "empty_like", - "ones_like", - "zeros_like", - "where", - "any_", - "all_", - # ==================== .atom ==================== + "Swizzle", + "E", + "ScaledBasis", "Atom", "MmaAtom", "CopyAtom", @@ -279,6 +254,114 @@ __all__ = [ "TiledMma", "ThrMma", "ThrCopy", + "TensorSSA", + "ReductionOp", + "SymInt", + # Basic utility functions + "assume", + "is_integer", + "is_int_tuple", + "is_static", + "size", + "has_underscore", + "slice_", + "depth", + "rank", + "shape", + "printf", + "print_tensor", + "pretty_str", + # Layout functions + "make_layout", + "recast_layout", + "make_identity_layout", + "make_ordered_layout", + "make_layout_like", + "make_composed_layout", + "make_layout_tv", + "make_layout_image_mask", + "get_nonswizzle_portion", + "get_swizzle_portion", + "nullspace", + # Tensor functions + "make_ptr", + "make_tensor", + "make_identity_tensor", + "make_fragment", + "make_fragment_like", + "make_rmem_tensor", + "make_rmem_tensor_like", + "recast_ptr", + "recast_tensor", + # Tensor manipulation + "get", + "select", + "front", + "is_major", + "leading_dim", + "find", + "find_if", + "transform_leaf", + "basis_value", + "basis_get", + "coalesce", + "group_modes", + "cosize", + "size_in_bytes", + # Tuple operations + "flatten_to_tuple", + "flatten", + "unflatten", + "product", + "product_like", + "product_each", + "prepend", + "append", + "prepend_ones", + "append_ones", + "elem_less", + "tuple_cat", + "transform_apply", + "filter_tuple", + "unwrap", + "wrap", + # Math operations + "ceil_div", + "round_up", + # Layout operations + "slice_and_offset", + "crd2idx", + "increment_coord", + "domain_offset", + "filter_zeros", + "filter", + "tile_to_shape", + "shape_div", + "dice", + # Layout algebra + "composition", + "complement", + "right_inverse", + "left_inverse", + "max_common_layout", + "max_common_vector", + "is_congruent", + "is_weakly_congruent", + # Product operations + "logical_product", + "zipped_product", + "tiled_product", + "flat_product", + "raked_product", + "blocked_product", + # Division operations + "flat_divide", + "logical_divide", + "zipped_divide", + "tiled_divide", + "local_partition", + "local_tile", + # MMA and Copy atom operations "make_atom", "make_mma_atom", "make_tiled_mma", @@ -294,26 +377,48 @@ __all__ = [ "make_cotiled_copy", "copy_atom_call", "mma_atom_call", - # ==================== .algorithm ==================== - "gemm", - "copy", + # Algorithm operations "basic_copy", "basic_copy_if", "autovec_copy", + "copy", "prefetch", - # ==================== .extension ==================== - # ==================== .math ==================== - *math.__all__, - # ==================== submodules ==================== + "gemm", + # Tensor SSA + "full", + "full_like", + "empty_like", + "ones_like", + "zeros_like", + "where", + "any_", + "all_", + "repeat_as_tuple", + "repeat", + "repeat_like", + # User defined struct + "struct", + "union", + # FastDivmod operations + "FastDivmodDivisor", + "fast_divmod_create_divisor", + # Modules "arch", "export", "nvgpu", "testing", "runtime", - # ==================== DSL (cutlass_dsl) ==================== + # Math utils + *math.__all__, + # Decorators and code generation "jit", "kernel", "register_jit_arg_adapter", "compile", + # Foreign function interface "ffi", + "extern", + "BitCode", + "ConstValue", + "mangle", ] diff --git a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py index e878d6e4a..579773cf1 100644 --- a/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py +++ b/python/CuTeDSL/cutlass/cute/_tvm_ffi_args_spec_converter.py @@ -9,9 +9,12 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. + +from dataclasses import is_dataclass, fields as dataclass_fields from cutlass.base_dsl.tvm_ffi_builder import spec from cutlass.base_dsl.jit_executor import ExecutionArgs from cutlass.base_dsl.common import DSLRuntimeError +from cutlass.base_dsl.utils.tree_utils import is_constexpr_field from cutlass.cutlass_dsl import is_cute_algebra_type from cutlass._mlir.dialects import cute as _cute_ir from .runtime import _FakeStream @@ -19,7 +22,6 @@ from .typing import Tensor, Pointer, SymInt from .typing import ( Numeric, Boolean, - Integer, Int4, Int8, Uint8, @@ -43,17 +45,17 @@ from .typing import ( ) import cuda.bindings.driver as cuda +from types import UnionType from typing import ( List, Dict, Any, Optional, - Tuple, + Union, get_origin, get_args, get_type_hints, ) -from types import UnionType import inspect NumericToTVMFFIDtype = { @@ -113,22 +115,24 @@ class SymIntId: def __init__(self, sym_int: SymInt): self.sym_int = sym_int - def __hash__(self): + def __hash__(self) -> int: return id(self.sym_int) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, SymIntId): + return NotImplemented return self.sym_int is other.sym_int class ConverterContext: """Context for managing variable allocation during TVM FFI args conversion.""" - def __init__(self): - self.num_dyn_shape_vars = 0 - self.num_dyn_stride_vars = 0 - self.num_device_id_vars = 0 - self.sym_int_id_mapping = {} - self.vdevice_to_device_id_mapping = {} + def __init__(self) -> None: + self.num_dyn_shape_vars: int = 0 + self.num_dyn_stride_vars: int = 0 + self.num_device_id_vars: int = 0 + self.sym_int_id_mapping: Dict[SymIntId, spec.Var] = {} + self.vdevice_to_device_id_mapping: Dict[tuple, spec.Var] = {} def alloc_shape_name(self) -> str: """Allocate a new dynamic shape variable name.""" @@ -142,7 +146,9 @@ class ConverterContext: self.num_dyn_stride_vars += 1 return name - def alloc_or_reuse_symint_var(self, value: SymInt, name_alloc_func): + def alloc_or_reuse_symint_var( + self, value: SymInt, name_alloc_func: Any + ) -> spec.Var: """Allocate or reuse a symbolic integer variable.""" sym_int_id = SymIntId(value) if sym_int_id in self.sym_int_id_mapping: @@ -179,7 +185,7 @@ class ConverterContext: def _convert_single_arg( - arg, arg_name: str, arg_type, ctx: ConverterContext + arg: Any, arg_name: str, arg_type: Any, ctx: ConverterContext ) -> spec.Param: """Convert a single argument to a spec.Param. @@ -219,14 +225,20 @@ def _convert_single_arg( spec.Var(ctx.alloc_shape_name(), NumericToTVMFFIDtype[arg[i].dtype]) ) return spec.Shape(arg_name, shape) + elif isinstance(arg, SymInt): + if arg.width == 32: + dtype = NumericToTVMFFIDtype[Int32] + else: + dtype = NumericToTVMFFIDtype[Int64] + return spec.Var(arg_name, dtype, divisibility=arg.divisibility) elif isinstance(arg, Tensor): shapes = [] - for i, dyn_mask in enumerate(arg.dynamic_shapes_mask): + for i, dyn_mask in enumerate(arg.dynamic_shapes_mask): # type: ignore[attr-defined] if not dyn_mask: - shapes.append(arg.shape[i]) - elif isinstance(arg.shape[i], SymInt): + shapes.append(arg.shape[i]) # type: ignore[index] + elif isinstance(arg.shape[i], SymInt): # type: ignore[index] shapes.append( - ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name) + ctx.alloc_or_reuse_symint_var(arg.shape[i], ctx.alloc_shape_name) # type: ignore[arg-type, index] ) else: shapes.append( @@ -234,12 +246,12 @@ def _convert_single_arg( ) strides = [] - for i, dyn_mask in enumerate(arg.dynamic_strides_mask): + for i, dyn_mask in enumerate(arg.dynamic_strides_mask): # type: ignore[attr-defined] if not dyn_mask: - strides.append(arg.stride[i]) - elif isinstance(arg.stride[i], SymInt): + strides.append(arg.stride[i]) # type: ignore[index] + elif isinstance(arg.stride[i], SymInt): # type: ignore[index] strides.append( - ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name) + ctx.alloc_or_reuse_symint_var(arg.stride[i], ctx.alloc_stride_name) # type: ignore[arg-type, index] ) else: if hasattr(arg, "_use_32bit_stride") and arg._use_32bit_stride: @@ -258,10 +270,10 @@ def _convert_single_arg( tvm_ffi_cute_tensor = spec.Tensor( arg_name, - shapes, + shapes, # type: ignore[arg-type] arg._tvm_ffi_tensor.dtype, - strides=strides, - data_alignment=arg._assumed_align, + strides=strides, # type: ignore[arg-type] + data_alignment=arg._assumed_align, # type: ignore[attr-defined] device_type=device_type, device_id=device_id, ) @@ -274,10 +286,10 @@ def _convert_single_arg( tvm_ffi_cute_tensor = spec.Tensor( arg_name, - shapes, - NumericToTVMFFIDtype[arg.element_type], - strides=strides, - data_alignment=arg._assumed_align, + shapes, # type: ignore[arg-type] + NumericToTVMFFIDtype[arg.element_type], # type: ignore[index] + strides=strides, # type: ignore[arg-type] + data_alignment=arg._assumed_align, # type: ignore[attr-defined] device_type=device_type, device_id=device_id, ) @@ -298,7 +310,7 @@ def _convert_single_arg( return spec.Stream(arg_name) elif isinstance(arg, cuda.CUstream): return spec.Stream(arg_name) - elif arg_type is not None and hasattr(arg_type, "_fields"): + elif arg_type is not inspect.Parameter.empty and hasattr(arg_type, "_fields"): # Handle NamedTuple - normalize to Tuple by order of fields, ignoring defaults # Get field types from annotations type_hints = get_type_hints(arg_type) @@ -323,9 +335,9 @@ def _convert_single_arg( tuple_params.append(elem_param) return spec.TupleParam(arg_name, tuple_params) - elif arg_type is not None and get_origin(arg_type) is tuple: + elif arg_type is not inspect.Parameter.empty and get_origin(arg_type) is tuple: # Handle Tuple[X, Y, ...] type annotations - tuple_element_types = get_args(arg_type) + tuple_element_types = get_args(arg_type) # type: ignore[assignment] if not isinstance(arg, (tuple, list)): raise DSLRuntimeError( f"Expected tuple for argument {arg_name}, got {type(arg)}" @@ -360,6 +372,49 @@ def _convert_single_arg( return spec.Var(arg_name, NumericToTVMFFIDtype[Int32]) elif isinstance(arg, float): return spec.Var(arg_name, NumericToTVMFFIDtype[Float32]) + elif ( + is_dataclass(arg_type) + if (arg_type is not None and arg_type is not inspect.Parameter.empty) + else is_dataclass(type(arg)) + ): + dc_type = ( + arg_type + if ( + arg_type is not None + and arg_type is not inspect.Parameter.empty + and is_dataclass(arg_type) + ) + else type(arg) + ) + if not isinstance(arg, dc_type): + raise DSLRuntimeError( + f"Expected {dc_type.__name__} for argument {arg_name}, got {type(arg)}" + ) + dc_fields = dataclass_fields(dc_type) + tuple_params = [] + for f in dc_fields: + if is_constexpr_field(f): + continue + field_value = getattr(arg, f.name) + field_name = f"{arg_name}.{f.name}" + field_type = f.type + tuple_params.append( + _convert_single_arg(field_value, field_name, field_type, ctx) + ) + return spec.TupleParam(arg_name, tuple_params) + elif arg_type is not None and ( + get_origin(arg_type) is UnionType or get_origin(arg_type) is Union + ): + member_types = get_args(arg_type) + for member_type in member_types: + try: + return _convert_single_arg(arg, arg_name, member_type, ctx) + except DSLRuntimeError: + continue + raise DSLRuntimeError( + f"Unsupported argument type: {type(arg)} for union type: {arg_type}. " + f"None of the union members matched: {member_types}" + ) else: raise DSLRuntimeError( f"Unsupported argument type: {type(arg)} for annotated type: {get_origin(arg_type)}" @@ -368,25 +423,25 @@ def _convert_single_arg( def _tvm_ffi_args_spec_converter( function_name: str, - args_spec: inspect.FullArgSpec, + signature: inspect.Signature, full_args: List[Any], full_kwargs: Dict[str, Any], -): +) -> tuple[List[spec.Param], Any]: """Convert cute algebra args to tvm ffi spec params. This function converts the cute arguments specs to tvm ffi spec params. """ - exec_args = ExecutionArgs(args_spec, function_name) + exec_args = ExecutionArgs(signature, function_name) rectified_args = exec_args.get_rectified_args_from_original_args( full_args, full_kwargs ) - arg_names = exec_args.args_spec.args + exec_args.args_spec.kwonlyargs params = [] ctx = ConverterContext() wrapper_extra_exclude_arg_names = [] - for arg, arg_name in zip(rectified_args, arg_names): - arg_type = args_spec.annotations.get(arg_name, None) + for arg, parameter in zip(rectified_args, exec_args.signature.parameters.values()): + arg_type = parameter.annotation + arg_name = parameter.name param = _convert_single_arg(arg, arg_name, arg_type, ctx) params.append(param) if isinstance(param, spec.EnvStream): @@ -397,6 +452,6 @@ def _tvm_ffi_args_spec_converter( return params, kwargs_wrapper_spec -def attach_args_spec_converter(dsl): +def attach_args_spec_converter(dsl: Any) -> None: """Attach TVM FFI ABI interface postprocessor to the DSL instance.""" dsl._tvm_ffi_args_spec_converter = _tvm_ffi_args_spec_converter diff --git a/python/CuTeDSL/cutlass/cute/algorithm.py b/python/CuTeDSL/cutlass/cute/algorithm.py index f982c7b44..a6a99ea0c 100644 --- a/python/CuTeDSL/cutlass/cute/algorithm.py +++ b/python/CuTeDSL/cutlass/cute/algorithm.py @@ -9,15 +9,22 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. + import math -from typing import Optional, Dict, Any, List, Tuple, Union +from typing import Optional, Dict, Any, List, Tuple, Type, Union from cutlass._mlir import ir -from cutlass.cutlass_dsl import for_generate, yield_out, if_generate, dsl_user_op +from cutlass.cutlass_dsl import ( + for_generate, + yield_out, + if_generate, + dsl_user_op, + LoopUnroll, +) import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from .typing import Tensor, Int64, Int16, AddressSpace +from .typing import Numeric, Tensor, Int64, Int16, AddressSpace from .core import ( rank, is_static, @@ -28,6 +35,7 @@ from .core import ( logical_divide, append_ones, group_modes, + slice_, ) from .atom import ( MmaAtom, @@ -36,20 +44,13 @@ from .atom import ( _normalize_variadic_tensor_operand, copy_atom_call, ) -from .nvgpu.common import CacheEvictionPriority - -def _normalize_gemm_operand_list( - x: Union["Tensor", List["Tensor"], Tuple["Tensor", ...]], name: str -) -> List["Tensor"]: - if isinstance(x, Tensor): - return [x] - if isinstance(x, (list, tuple)): - if len(x) == 0: - raise ValueError(f"`{name}` must contain at least one Tensor") - if not all(isinstance(t, Tensor) for t in x): - raise TypeError(f"All elements of `{name}` must be Tensor") - return list(x) # type: ignore - raise TypeError(f"`{name}` must be a Tensor or a sequence of Tensors") +from .nvgpu.common import ( + CacheEvictionPriority, + CopyG2ROp, + CopyR2GOp, + CopyS2ROp, + CopyR2SOp, +) @dsl_user_op @@ -60,9 +61,9 @@ def gemm( b: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], c: Tensor, *, - loc=None, - ip=None, - **kwargs, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> None: """The GEMM algorithm. @@ -82,18 +83,23 @@ def gemm( - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) - Operand flexibility: - - `a` and `b` can be a single Tensor (regular GEMM) or a sequence `[operand, scale_factor]` for block-scaled GEMM. + The operands `a` and `b` are variadic, each containing a variable number of tensors: + + - For regular GEMM, `a` and `b` contain the GEMM A and B tensors respectively. + - For GEMM with auxiliary operands, `a` and `b` contain the GEMM A and B tensors followed by + their respective auxiliary tensors. For example: + + - For BlockScaledGemm, `a` = [A, SFA] and `b` = [B, SFB]. :param atom: MMA atom :type atom: MmaAtom - :param d: Destination tensor + :param d: Destination tensor (output accumulator) :type d: Tensor - :param a: First source tensor or sequence for advanced modes (e.g., `[a, sfa]`) + :param a: A tensor or list of tensors containing the GEMM A tensor and optional auxiliary tensors :type a: Union[Tensor, List[Tensor], Tuple[Tensor, ...]] - :param b: Second source tensor or sequence for advanced modes (e.g., `[b, sfb]`) + :param b: B tensor or list of tensors containing the GEMM B tensor and optional auxiliary tensors :type b: Union[Tensor, List[Tensor], Tuple[Tensor, ...]] - :param c: Third source tensor + :param c: Input accumulator tensor :type c: Tensor :param loc: Source location for MLIR, defaults to None :type loc: Optional[Location], optional @@ -106,8 +112,8 @@ def gemm( """ # Normalize A/B to lists for variadic IR operands, while keeping old API working. - a_list = _normalize_gemm_operand_list(a, "a") - b_list = _normalize_gemm_operand_list(b, "b") + a_list = _normalize_variadic_tensor_operand(a, "a") + b_list = _normalize_variadic_tensor_operand(b, "b") # Rank validations based on the primary A/B tensors (guaranteed non-empty) a_rank = rank(a_list[0].shape) @@ -137,8 +143,74 @@ def gemm( return _cute_ir.gemm(value, d.value, a_vals, b_vals, c.value, loc=loc, ip=ip) +def _make_copy_atom( + copy_internal_type: Type[Numeric], + num_bits_per_copy: int, + src_memspace: AddressSpace, + dst_memspace: AddressSpace, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **mem_attrs: Any, +) -> ir.Value: + """Create a copy atom, using the universal copy by default. + + When no ``mem_attrs`` are provided, the universal copy atom is used. + Otherwise, the function dispatches to a specialized copy op based on the + source/destination memory spaces and forwards the memory attributes: + - gmem -> rmem: ``CopyG2ROp`` + - rmem -> gmem: ``CopyR2GOp`` + - smem -> rmem: ``CopyS2ROp`` + - rmem -> smem: ``CopyR2SOp`` + + A ``ValueError`` is raised if no specialized op matches the memory-space pair. + """ + + if not mem_attrs: + atom_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + ) + return make_atom(atom_type, loc=loc, ip=ip) + + # Specialized path: dispatch based on memory spaces. + op: Union[ + CopyG2ROp, + CopyR2GOp, + CopyS2ROp, + CopyR2SOp, + ] + if src_memspace == AddressSpace.gmem and dst_memspace == AddressSpace.rmem: + op = CopyG2ROp() + elif src_memspace == AddressSpace.rmem and dst_memspace == AddressSpace.gmem: + op = CopyR2GOp() + elif src_memspace == AddressSpace.smem and dst_memspace == AddressSpace.rmem: + op = CopyS2ROp() + elif src_memspace == AddressSpace.rmem and dst_memspace == AddressSpace.smem: + op = CopyR2SOp() + else: + raise ValueError( + f"Memory attributes {set(mem_attrs)} are not supported for " + f"{src_memspace} -> {dst_memspace} copies (no specialized op available)." + ) + trait = op._make_trait( + copy_internal_type, + num_bits_per_copy=num_bits_per_copy, + loc=loc, + ip=ip, + **mem_attrs, + ) + return trait.value + + @dsl_user_op -def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: +def basic_copy( + src: Tensor, + dst: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """Performs a basic element-wise copy. This functions **assumes** the following pre-conditions: @@ -159,7 +231,8 @@ def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: if is_static(src.shape) and is_static(dst.shape): simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - src.element_type.mlir_type, src.element_type.width + src.element_type.mlir_type, # type: ignore[union-attr] + src.element_type.width, # type: ignore[union-attr] ) simt_copy = make_atom(simt_copy_ty, loc=loc, ip=ip) return _cute_ir.copy(simt_copy, [src.value], [dst.value], loc=loc, ip=ip) @@ -172,7 +245,14 @@ def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: @dsl_user_op -def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: +def basic_copy_if( + pred: Tensor, + src: Tensor, + dst: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """Performs a basic predicated element-wise copy. This functions **assumes** the following pre-conditions: @@ -183,7 +263,7 @@ def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) is fully unrolled. """ - if src.element_type.width != dst.element_type.width: + if src.element_type.width != dst.element_type.width: # type: ignore[union-attr] raise NotImplementedError( "basic_copy_if currently only supports equal source and destination " "element type bit width" @@ -195,7 +275,7 @@ def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) s = size(dst, loc=loc, ip=ip) # Always generate an scf.for Op when one of the tensors is dynamic for i in for_generate(0, s, loc=loc, ip=ip): - if_generate(pred[i], lambda: dst.__setitem__(i, src[i]), loc=loc, ip=ip) # type: ignore + if_generate(pred[i], lambda: dst.__setitem__(i, src[i]), loc=loc, ip=ip) yield_out() @@ -203,7 +283,12 @@ def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) # - verify size(src) == size(dst) == size(prd) # - fully unroll the loop for now def _basic_copy_if_static( - pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None + pred: Tensor, + src: Tensor, + dst: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: assert is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape) if size(src, loc=loc, ip=ip) != size(dst, loc=loc, ip=ip): @@ -212,7 +297,7 @@ def _basic_copy_if_static( ) # Fully unrolled loop in the static case for now for i in range(size(dst, loc=loc, ip=ip)): - if_generate(pred[i], lambda: dst.__setitem__(i, src[i]), loc=loc, ip=ip) # type: ignore + if_generate(pred[i], lambda: dst.__setitem__(i, src[i]), loc=loc, ip=ip) @dsl_user_op @@ -221,19 +306,22 @@ def autovec_copy( dst: Tensor, *, l1c_evict_priority: CacheEvictionPriority = CacheEvictionPriority.EVICT_NORMAL, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Auto-vectorization SIMT copy policy. - Given a source and destination tensors that are statically shaped, this policy figures out the - largest safe vector width that the copy instruction can take and performs the copy. + Given a source and destination tensors that are statically shaped, this policy + figures out the largest safe vector width that the copy instruction can take + and performs the copy. Any extra memory attributes are forwarded to the specialized + copy op. + """ - if src.element_type.width != dst.element_type.width: + if src.element_type.width != dst.element_type.width: # type: ignore[union-attr] raise NotImplementedError( - "autovec_copy currently only supports equal source and destination " - "element type bit width" + "autovec_copy only supports equal source and destination " + f"element type bit widths, got {src.element_type} and {dst.element_type}" ) # We are going to dispatch to copy-with-atom which requires shapes to be static @@ -249,20 +337,20 @@ def autovec_copy( # - the maximum alignment of the layouts # - the maximum alignment of the pointers - upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) + upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) # type: ignore[union-attr] upper_bound = math.gcd(upper_bound, num_common_elements) - upper_bound *= src.element_type.width + upper_bound *= src.element_type.width # type: ignore[union-attr] # For our instructions, the alignment of the pointer is an upper bound to the vector width # max_alignment, as opposed to alignment, takes into account possible address swizzling - upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) - upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) + upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) # type: ignore[union-attr] + upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) # type: ignore[union-attr] # Finally, we put a cap at 256b num_bits_per_copy = math.gcd(upper_bound, 256) if (num_common_elements > 1) and (num_bits_per_copy % 8 == 0): - num_common_elements = num_bits_per_copy // src.element_type.width + num_common_elements = num_bits_per_copy // src.element_type.width # type: ignore[union-attr] # 2 step logical divides ensuring that the divides are valid at every step vec_src = logical_divide(src, vec_layout, loc=loc, ip=ip) @@ -274,72 +362,88 @@ def autovec_copy( vec_dst, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip ) - # Dispatch to copy with atom - simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( - src.element_type.mlir_type, + # Forward memory attributes that differ from their defaults so _make_copy_atom + # falls back to the universal copy when no specialisation is needed. + mem_attrs = {} + if l1c_evict_priority != CacheEvictionPriority.EVICT_NORMAL: + mem_attrs["l1c_evict_priority"] = l1c_evict_priority + + simt_copy_atom = _make_copy_atom( + src.element_type, # type: ignore[arg-type] num_bits_per_copy, - 0, - 0, - l1c_evict_priority._to_ir(), + src.iterator.memspace, # type: ignore[union-attr] + dst.iterator.memspace, # type: ignore[union-attr] + loc=loc, + ip=ip, + **mem_attrs, ) - simt_copy = make_atom(simt_type, loc=loc, ip=ip) return _cute_ir.copy( - simt_copy, [tiled_src.value], [tiled_dst.value], loc=loc, ip=ip + simt_copy_atom, + [tiled_src.value], + [tiled_dst.value], + loc=loc, + ip=ip, ) # Failed to vectorize, use a basic copy basic_copy(src, dst, loc=loc, ip=ip) -def _parse_auto_multicast_args( +def _parse_tma_multicast_args( kwargs: Dict[str, Any], ) -> List[Tuple[str, ir.Attribute]]: """ Parse multicast-related kwargs and return a list of (attr_name, attr) pairs. This function consumes the following key from kwargs if present: - - 'auto_multicast': dict - dict: { 'multicast_layout': str, 'use_2cta': bool, 'from_block_api': bool } + - 'tma_multicast': dict with keys: + { 'cluster_shape': (m, n), 'multicast_dim': 'M' or 'N', + 'use_2cta_mma_inst': bool, 'from_block_api': bool } Returns: List of (attr_name, ir.Attribute) pairs to be attached to the op. Recognized attributes: - - ('multicast_layout', #cute.layout<...>) when a layout string is provided - - ('use_2cta', unit) when use_2cta is True + - ('multicast_layout', #cute.layout<...>) + - ('use_2cta', unit) when use_2cta_mma_inst is True - ('from_block_api', unit) when from_block_api is True """ attr_pairs: List[Tuple[str, ir.Attribute]] = [] # Pop known keys to avoid leaking to trait unpack - auto_multicast = kwargs.pop("auto_multicast", None) + tma_multicast = kwargs.pop("tma_multicast", None) - from_block_api: bool = False - use_2cta: bool = False - layout_str: Optional[str] = None + if tma_multicast is None: + return attr_pairs - if auto_multicast is not None: - if not isinstance(auto_multicast, dict): - raise TypeError( - "auto_multicast must be a dict with keys 'multicast_layout' and optional 'use_2cta'" - ) - layout_str = auto_multicast.get("multicast_layout", None) - use_2cta = bool(auto_multicast.get("use_2cta", False)) - from_block_api = bool(auto_multicast.get("from_block_api", False)) + if not isinstance(tma_multicast, dict): + raise TypeError("tma_multicast must be a dict") - if layout_str is not None: - if not isinstance(layout_str, str): - raise TypeError( - "multicast_layout must be a string representing a CuTe layout, e.g. '(4,2):(1,0)'" - ) - attr_pairs.append( - ( - "multicast_layout", - ir.Attribute.parse(f'#cute.layout<"{layout_str}">'), + # Validate required keys + required_keys = ["cluster_shape", "multicast_dim"] + for key in required_keys: + if key not in tma_multicast: + raise KeyError( + f"tma_multicast is missing required key '{key}'. " + f"Expected keys: {required_keys}" ) + + multicast_dim = tma_multicast["multicast_dim"] + if multicast_dim not in ("M", "N"): + raise ValueError(f"multicast_dim must be 'M' or 'N', got '{multicast_dim}'") + + cluster_m, cluster_n = tma_multicast["cluster_shape"] + direction = "(1,0)" if multicast_dim == "M" else "(0,1)" + layout_str = f"({cluster_m},{cluster_n}):{direction}" + + attr_pairs.append( + ( + "multicast_layout", + ir.Attribute.parse(f'#cute.layout<"{layout_str}">'), ) - if from_block_api: + ) + if tma_multicast.get("from_block_api", False): attr_pairs.append(("from_block_api", ir.UnitAttr.get())) - if use_2cta: + if tma_multicast.get("use_2cta_mma_inst", False): attr_pairs.append(("use_2cta", ir.UnitAttr.get())) return attr_pairs @@ -352,9 +456,10 @@ def copy( dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], *, pred: Optional[Tensor] = None, - loc=None, - ip=None, - **kwargs, + unroll_factor: Optional[int] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> None: """Facilitates data transfer between two tensors conforming to layout profile ``(V, Rest...)``. @@ -366,6 +471,8 @@ def copy( :type dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]] :param pred: Optional predication tensor for conditional transfers, defaults to None :type pred: Optional[Tensor], optional + :param unroll_factor: Optional unroll count for loop over Rest... modes, defaults to None for fully unroll when Rest... modes are static + :type unroll_factor: Optional[int], optional :param loc: Source location information, defaults to None :type loc: Any, optional :param ip: Insertion point, defaults to None @@ -430,15 +537,28 @@ def copy( src_primary = src_list[0] dst_primary = dst_list[0] - if isinstance(src_primary.type, _cute_ir.MemRefType) and isinstance( - dst_primary.type, _cute_ir.MemRefType + if isinstance( + src_primary.type, # type: ignore[attr-defined] + _cute_ir.MemRefType, + ) and isinstance( + dst_primary.type, # type: ignore[attr-defined] + _cute_ir.MemRefType, ): - if src_primary.element_type.width != dst_primary.element_type.width: + if ( + len(dst_list) == 1 + and src_primary.element_type.width != dst_primary.element_type.width # type: ignore[union-attr] + ): raise TypeError( "`copy` currently only supports equal source and destination " "element type bit width" ) + if unroll_factor is not None: + if not isinstance(unroll_factor, int) or unroll_factor < 1: + raise ValueError( + f"unroll_factor must be a positive integer, but got {unroll_factor}" + ) + if rank(src_primary) != rank(dst_primary): raise ValueError( "Expected source and destination tensors to have the same rank, " @@ -446,39 +566,63 @@ def copy( ) # Canonicalize all tensors to at least rank-2 - src_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in src_list] - dst_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in dst_list] + src_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in src_list] # type: ignore[call-overload] + dst_list = [group_modes(append_ones(t, up_to_rank=2), 1) for t in dst_list] # type: ignore[call-overload] if pred is not None: - pred = group_modes(append_ones(pred, up_to_rank=2), 1) + pred = group_modes(append_ones(pred, up_to_rank=2), 1) # type: ignore[call-overload] # Recompute primary references after canonicalization src_primary = src_list[0] dst_primary = dst_list[0] - if is_static(src_primary.shape[1]) and is_static(dst_primary.shape[1]): + if is_static(src_primary.shape[1]) and is_static(dst_primary.shape[1]): # type: ignore[index] if size(src_primary, mode=[1]) != size(dst_primary, mode=[1]): raise ValueError( "Expected source and destination tensors to have the same size in mode-1, " f"but got {size(src_primary, mode=[1])} and {size(dst_primary, mode=[1])}" ) - multicast_attr_pairs = _parse_auto_multicast_args(kwargs) + multicast_attr_pairs = _parse_tma_multicast_args(kwargs) - value = atom._unpack(loc=loc, ip=ip, **kwargs) - pred_value = pred.value if isinstance(pred, Tensor) else pred + # Unroll the loop per specified unroll_factor for static RestM case + if is_static(src_primary.shape[1]) and unroll_factor is not None: # type: ignore[index] + unroll_factor = LoopUnroll(count=unroll_factor) + for i in for_generate( + 0, + stop=size(src_primary, mode=[1], loc=loc, ip=ip), + unroll=unroll_factor, + loc=loc, + ip=ip, + ): + src_atom = [slice_(src, (None, i), loc=loc, ip=ip) for src in src_list] + dst_atom = [slice_(dst, (None, i), loc=loc, ip=ip) for dst in dst_list] + pred_atom = ( + slice_(pred, (None, i), loc=loc, ip=ip) if pred is not None else None + ) + copy_atom_call( + atom, src_atom, dst_atom, pred=pred_atom, loc=loc, ip=ip, **kwargs + ) + yield_out() + else: + value = atom._unpack(loc=loc, ip=ip, **kwargs) + pred_value = pred.value if isinstance(pred, Tensor) else pred - src_vals = [t.value for t in src_list] - dst_vals = [t.value for t in dst_list] - op = _cute_ir.copy(value, src_vals, dst_vals, pred=pred_value, loc=loc, ip=ip) + src_vals = [t.value for t in src_list] + dst_vals = [t.value for t in dst_list] + op = _cute_ir.copy(value, src_vals, dst_vals, pred=pred_value, loc=loc, ip=ip) - for name, attr in multicast_attr_pairs: - op.attributes[name] = attr - - return op + for name, attr in multicast_attr_pairs: + op.attributes[name] = attr @dsl_user_op -def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: +def prefetch( + atom: CopyAtom, + src: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ The Prefetch algorithm. diff --git a/python/CuTeDSL/cutlass/cute/arch/__init__.py b/python/CuTeDSL/cutlass/cute/arch/__init__.py index 776398787..16a24da08 100644 --- a/python/CuTeDSL/cutlass/cute/arch/__init__.py +++ b/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -17,6 +17,15 @@ from .tmem import * from .numeric_conversion import * from .clc import * +import cutlass.cutlass_dsl as cutlass_dsl + +# Forward from auto-generated nvvm python: only export on 12.9 wheel +_nvvm_forward_exports_12_9 = ( + ["ProxyKind", "SharedSpace", "RoundingModeKind", "ReduxKind", "AtomicOpKind"] + if cutlass_dsl.target_version(exact_version="12.9") + else [] +) + # __all__ is required here for documentation generation __all__ = [ # @@ -40,6 +49,7 @@ __all__ = [ # "lane_idx", "warp_idx", + "physical_warp_id", "thread_idx", "block_dim", "block_idx", @@ -75,7 +85,7 @@ __all__ = [ "vote_all_sync", "vote_uni_sync", "warp_redux_sync", - "atomic_max_float32", + "atomic_max_float32", # Deprecated: use atomic_fmax "atomic_add", "atomic_and", "atomic_or", @@ -86,6 +96,7 @@ __all__ = [ "atomic_cas", "store", "load", + "red", "popc", "fence_proxy", "fence_view_async_tmem_load", @@ -97,7 +108,9 @@ __all__ = [ "fma_packed_f32x2", "mul_packed_f32x2", "add_packed_f32x2", + "sub_packed_f32x2", "fmax", + "fmin", "rcp_approx", "exp2", "cvt_i8x4_to_f32x4", @@ -106,9 +119,39 @@ __all__ = [ "cvt_i8x2_to_bf16x2", "cvt_i8x4_to_bf16x4", "cvt_f32x2_bf16x2", - "warp_redux_sync", + "smid", + "nsmid", + "clock", + "clock64", + "match_sync", + "clz", + "bfind", + "brev", + "bfe", + "bfi", + "mul_hi", + "mul_wide", + "mul24", + "mad24", + "add_cc", + "addc", + "sub_cc", + "subc", + "mad_cc", + "madc", + "activemask", + "lanemask_lt", + "lanemask_le", + "lanemask_eq", + "lanemask_ge", + "lanemask_gt", + "add_sat_int", + "sub_sat_int", + "lop3", + "shf", # Constants "WARP_SIZE", + *_nvvm_forward_exports_12_9, # # smem.py # diff --git a/python/CuTeDSL/cutlass/cute/arch/clc.py b/python/CuTeDSL/cutlass/cute/arch/clc.py index e99c8d54d..88caa9d73 100644 --- a/python/CuTeDSL/cutlass/cute/arch/clc.py +++ b/python/CuTeDSL/cutlass/cute/arch/clc.py @@ -9,12 +9,12 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Tuple +from typing import Optional, Tuple -from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass.cutlass_dsl import dsl_user_op from cutlass._mlir import ir -from cutlass._mlir.dialects import nvvm, llvm, vector, arith +from cutlass._mlir.dialects import nvvm as _nvvm, vector from ..typing import Int32, Pointer, Int128 @@ -23,8 +23,9 @@ from ..typing import Int32, Pointer, Int128 def issue_clc_query( mbar_ptr: Pointer, clc_response_ptr: Pointer, - loc=None, - ip=None, + multicast: bool = True, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ The clusterlaunchcontrol.try_cancel instruction requests atomically cancelling the launch @@ -39,17 +40,27 @@ def issue_clc_query( """ mbar_llvm_ptr = mbar_ptr.llvm_ptr clc_response_llvm_ptr = clc_response_ptr.llvm_ptr - nvvm.clusterlaunchcontrol_try_cancel_multicast( - clc_response_llvm_ptr, - mbar_llvm_ptr, - loc=loc, - ip=ip, - ) + if multicast: + _nvvm.clusterlaunchcontrol_try_cancel_multicast( + clc_response_llvm_ptr, + mbar_llvm_ptr, + loc=loc, + ip=ip, + ) + else: + _nvvm.clusterlaunchcontrol_try_cancel( + clc_response_llvm_ptr, + mbar_llvm_ptr, + loc=loc, + ip=ip, + ) @dsl_user_op def clc_response( - result_addr: Pointer, loc=None, ip=None + result_addr: Pointer, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Int32, Int32, Int32, Int32]: """ After loading response from clusterlaunchcontrol.try_cancel instruction into 16-byte @@ -80,7 +91,8 @@ def clc_response( [0], ) # Query if the cluster was canceled - pred = nvvm.clusterlaunchcontrol_query_cancel_is_canceled( + # res parameter expects an MLIR Type, and returns the actual OpResult value + pred = _nvvm.clusterlaunchcontrol_query_cancel_is_canceled( clc_result_i128, loc=loc, ip=ip, @@ -88,21 +100,21 @@ def clc_response( is_valid = Int32(pred) # Get first CTA ID x component - m_idx_i32 = nvvm.clusterlaunchcontrol_query_cancel_get_first_ctaid_x( + m_idx_i32 = _nvvm.clusterlaunchcontrol_query_cancel_get_first_ctaid_x( clc_result_i128, loc=loc, ip=ip, ) # Get first CTA ID y component - n_idx_i32 = nvvm.clusterlaunchcontrol_query_cancel_get_first_ctaid_y( + n_idx_i32 = _nvvm.clusterlaunchcontrol_query_cancel_get_first_ctaid_y( clc_result_i128, loc=loc, ip=ip, ) # Get first CTA ID z component - l_idx_i32 = nvvm.clusterlaunchcontrol_query_cancel_get_first_ctaid_z( + l_idx_i32 = _nvvm.clusterlaunchcontrol_query_cancel_get_first_ctaid_z( clc_result_i128, loc=loc, ip=ip, diff --git a/python/CuTeDSL/cutlass/cute/arch/elect.py b/python/CuTeDSL/cutlass/cute/arch/elect.py index d9a26db84..ba4947271 100644 --- a/python/CuTeDSL/cutlass/cute/arch/elect.py +++ b/python/CuTeDSL/cutlass/cute/arch/elect.py @@ -9,7 +9,9 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from cutlass.cutlass_dsl import BaseDSL, T, dsl_user_op +from typing import Optional + +from cutlass.cutlass_dsl import BaseDSL, dsl_user_op import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir.dialects import nvvm, scf @@ -19,7 +21,12 @@ from ..typing import Int, Int32 @dsl_user_op -def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32: +def make_warp_uniform( + value: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: """ Provides a compiler hint indicating that the specified value is invariant across all threads in the warp, which may enable performance optimizations. @@ -42,31 +49,97 @@ class IfOpRegion: Automatically inserts `scf.yield([])` when exiting the context. """ - def __init__(self, block, *, loc=None, ip=None): + def __init__( + self, + block: ir.Block, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self.block = block self.insert_point = ir.InsertionPoint(self.block) self.loc = loc self.ip = ip - def __enter__(self): + def __enter__(self) -> ir.BlockArgumentList: self.insert_point.__enter__() return self.block.arguments - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: object, + ) -> None: scf.yield_([], loc=self.loc, ip=self.ip) self.insert_point.__exit__(exc_type, exc_value, traceback) @dsl_user_op -def elect_one(*, loc=None, ip=None) -> IfOpRegion: +def elect_one( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> IfOpRegion: """ - Elects one thread within a warp. + Elects one thread within a warp to execute single-threaded operations. + + This function uses the PTX ``elect.sync`` instruction to select exactly one thread + per warp to execute the code within its context. All other threads in the warp skip + the block and reconverge after it. + + See the PTX ISA documentation on `elect.sync `__. + + **When to Use elect_one:** + + ``elect_one()`` is **required** for operations that must be executed by a single thread + for correctness, including: + + - **Barrier initialization and transaction setup** (``mbarrier_init``, ``mbarrier_expect_tx``, + ``mbarrier_arrive_and_expect_tx``) + - **tcgen05 commit operations** (``tcgen05.commit``) - DSL does NOT + automatically guard these, unlike C++ which uses ``elect_one_sync()`` internally + - **Single-thread state setup** + + **When NOT to Use elect_one:** + + Do NOT use ``elect_one()`` for operations that already handle single-threaded execution internally: + + - **TMA copy operations** (``cute.copy`` with TMA atoms) - TMA partitioning ensures only one + thread within a warp issues the operation automatically. Wrapping in ``elect_one()`` can cause GPU deadlock. .. code-block:: python + # CORRECT: Initialize barrier with elect_one with elect_one(): - # Only one thread in the warp executes the code in this context - pass + cute.arch.mbarrier_init(barrier_ptr, arrival_count) + cute.arch.mbarrier_expect_tx(barrier_ptr, num_bytes) + + # CORRECT: tcgen05.commit requires elect_one in DSL + with elect_one(): + tcgen05.commit(barrier_ptr, None, cta_group) + + # CORRECT: TMA copy does not need elect_one + cute.copy( + tma_atom, + gmem_tensor, # TMA handles single-thread internally + smem_tensor, + tma_bar_ptr=barrier_ptr + ) + + **PTX Programming Model:** + + In the PTX programming model, certain cluster-scoped and CTA-scoped operations must be + issued by a single thread to maintain correctness. The ``elect.sync`` instruction provides + a warp-uniform way to select this thread with proper synchronization. + + :return: A context manager that executes its block on exactly one thread per warp + :rtype: IfOpRegion + + .. seealso:: + - :func:`cute.arch.mbarrier_init` - Requires elect_one + - :func:`cute.arch.mbarrier_expect_tx` - Requires elect_one + - :func:`cute.arch.mbarrier_arrive_and_expect_tx` - Requires elect_one + - PTX ISA documentation on ``elect.sync`` + - Tutorial example: ``examples/blackwell/tutorial_tma/tma_v0.py`` """ from cutlass.base_dsl.arch import Arch diff --git a/python/CuTeDSL/cutlass/cute/arch/mbar.py b/python/CuTeDSL/cutlass/cute/arch/mbar.py index ca04dcc34..5a1c6de8e 100644 --- a/python/CuTeDSL/cutlass/cute/arch/mbar.py +++ b/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -8,14 +8,13 @@ # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Optional - from cutlass.base_dsl.arch import Arch -from cutlass.cutlass_dsl import BaseDSL, T, if_generate, dsl_user_op +from cutlass.cutlass_dsl import BaseDSL, if_generate, dsl_user_op +from cutlass._mlir import ir from cutlass._mlir.dialects import nvvm, llvm -from ..typing import Pointer, Int, Boolean, Int32, AddressSpace +from ..typing import Optional, Pointer, Int, Boolean, Int32, AddressSpace #################################################################################################### # @@ -25,17 +24,39 @@ from ..typing import Pointer, Int, Boolean, Int32, AddressSpace @dsl_user_op -def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: +def mbarrier_init( + mbar_ptr: Pointer, + cnt: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Initializes a mbarrier with the specified thread arrival count. + **Single-Thread Execution Required**: This operation **must** be executed by only one thread + per CTA. Use :func:`cute.arch.elect_one` to ensure proper synchronization: + + .. code-block:: python + + with cute.arch.elect_one(): + cute.arch.mbarrier_init(barrier_ptr, arrival_count) + + **PTX Mapping**: This operation maps to the PTX ``mbarrier.init.shared.b64`` instruction, + which must be issued by a single thread for correctness. + :param mbar_ptr: A pointer to the mbarrier in SMEM :type mbar_ptr: Pointer :param cnt: The arrival count of the mbarrier :type cnt: Int + + .. seealso:: + - :func:`cute.arch.elect_one` - Required wrapper for single-thread execution + - :func:`cute.arch.mbarrier_expect_tx` - Also requires elect_one + - PTX ISA documentation on ``mbarrier.init`` """ nvvm.mbarrier_init_shared( - mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), + mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), # type: ignore[attr-defined] Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip, @@ -43,21 +64,45 @@ def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: @dsl_user_op -def mbarrier_init_fence(*, loc=None, ip=None) -> None: +def mbarrier_init_fence( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ A fence operation that applies to the mbarrier initializations. """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) + nvvm.fence_mbarrier_init(loc=loc, ip=ip) @dsl_user_op def mbarrier_arrive_and_expect_tx( - mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None + mbar_ptr: Pointer, + bytes: Int, + peer_cta_rank_in_cluster: Optional[Int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Arrives on a mbarrier and expects a specified number of transaction bytes. + Each thread that executes this operation will increment the arrival count by 1 and + increment the expected transaction bytes by the specified number. + + To ensure proper synchronization, most calls to this function should be wrapped in :func:`cute.arch.elect_one`. + + .. code-block:: python + + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx(barrier_ptr, num_transaction_bytes) + + This is a combined operation that both arrives at the barrier (incrementing the arrival count) + and sets the expected transaction bytes. It is commonly used with TMA operations in pipelined + kernels. + + See the PTX ISA documentation on `mbarrier.arrive.expect_tx `__. + :param mbar_ptr: A pointer to the mbarrier in SMEM :type mbar_ptr: Pointer :param bytes: The number of transaction bytes @@ -65,10 +110,15 @@ def mbarrier_arrive_and_expect_tx( :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to the mbarrier is converted to a remote address in the peer CTA's SMEM. + + .. seealso:: + - :func:`cute.arch.elect_one` - Required wrapper for single-thread execution + - :func:`cute.arch.mbarrier_init` - Also requires elect_one + - :func:`cute.arch.mbarrier_expect_tx` - Expect_tx without arrive """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) # type: ignore[attr-defined] if peer_cta_rank_in_cluster is not None: mbar_cluster_type = llvm.PointerType.get(AddressSpace.dsmem) mbar_llvm_ptr = nvvm.mapa( @@ -96,11 +146,30 @@ def mbarrier_arrive_and_expect_tx( @dsl_user_op def mbarrier_expect_tx( - mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None + mbar_ptr: Pointer, + bytes: Int, + peer_cta_rank_in_cluster: Optional[Int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Expects a specified number of transaction bytes without an arrive. + Each thread that executes this operation will increment the expected transaction bytes by the specified number. + + To ensure proper synchronization, most calls to this function should be wrapped in :func:`cute.arch.elect_one`. + + .. code-block:: python + + with cute.arch.elect_one(): + cute.arch.mbarrier_expect_tx(barrier_ptr, num_transaction_bytes) + + This is commonly used with TMA operations to set the expected transaction size before + issuing a TMA load. + + See the PTX ISA documentation on `mbarrier.expect_tx `__. + :param mbar_ptr: A pointer to the mbarrier in SMEM :type mbar_ptr: Pointer :param bytes: The number of transaction bytes @@ -108,10 +177,15 @@ def mbarrier_expect_tx( :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to the mbarrier is converted to a remote address in the peer CTA's SMEM. + + .. seealso:: + - :func:`cute.arch.elect_one` - Recommended wrapper for single-thread execution + - :func:`cute.arch.mbarrier_init` - initialize mbarrier + - :func:`cute.arch.mbarrier_arrive_and_expect_tx` - Combined arrive and expect_tx """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) # type: ignore[attr-defined] if peer_cta_rank_in_cluster is not None: mbar_cluster_type = llvm.PointerType.get(AddressSpace.dsmem) mbar_llvm_ptr = nvvm.mapa( @@ -138,7 +212,13 @@ def mbarrier_expect_tx( @dsl_user_op -def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: +def mbarrier_wait( + mbar_ptr: Pointer, + phase: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Waits on a mbarrier with a specified phase. @@ -150,10 +230,11 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) timeout_ns = 10000000 + # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX # The timeout in ns only applies to the latter and this call is truly blocking nvvm.mbarrier_try_wait_parity_shared( - mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), + mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), # type: ignore[attr-defined] Int32(phase).ir_value(loc=loc, ip=ip), Int32(timeout_ns).ir_value(loc=loc, ip=ip), loc=loc, @@ -162,7 +243,13 @@ def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: @dsl_user_op -def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean: +def mbarrier_try_wait( + mbar_ptr: Pointer, + phase: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Boolean: """ Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. @@ -177,7 +264,7 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo return Boolean( nvvm.mbarrier_wait_parity( - mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), + mbar_ptr.to_llvm_ptr(loc=loc, ip=ip), # type: ignore[attr-defined] Int32(phase).ir_value(loc=loc, ip=ip), nvvm.MBarrierWaitKind.TRY, loc=loc, @@ -188,7 +275,12 @@ def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Bo @dsl_user_op def mbarrier_conditional_try_wait( - cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None + cond: Boolean, + mbar_ptr: Pointer, + phase: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Boolean: """ Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. @@ -202,7 +294,7 @@ def mbarrier_conditional_try_wait( :rtype: Boolean """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - return if_generate( + return if_generate( # type: ignore[return-value] cond, lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip), lambda: Boolean(True).ir_value(loc=loc, ip=ip), @@ -219,8 +311,8 @@ def mbarrier_arrive( peer_cta_rank_in_cluster: Optional[Int] = None, arrive_count: Int = 1, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Arrives on an mbarrier. @@ -231,7 +323,7 @@ def mbarrier_arrive( the mbarrier is converted to a remote address in the peer CTA's SMEM. """ - mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) # type: ignore[attr-defined] if peer_cta_rank_in_cluster is not None: BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) @@ -260,7 +352,12 @@ def mbarrier_arrive( @dsl_user_op -def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: +def cp_async_mbarrier_arrive_noinc( + mbar_ptr: Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Arrives on an mbarrier for async load **without incrementing** the arrival count (`cp.async.mbarrier.arrive.shared ..., noinc=1`). @@ -272,5 +369,5 @@ def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> N """ BaseDSL._get_dsl().check_arch(lambda arch: arch >= Arch.sm_90) - mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) + mbar_llvm_ptr = mbar_ptr.to_llvm_ptr(loc=loc, ip=ip) # type: ignore[attr-defined] nvvm.cp_async_mbarrier_arrive_shared(mbar_llvm_ptr, noinc=True, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py b/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py index 3d1a8cc83..ed8f4d540 100644 --- a/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py +++ b/python/CuTeDSL/cutlass/cute/arch/numeric_conversion.py @@ -9,12 +9,14 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from typing import Optional + from cutlass.base_dsl.arch import Arch from cutlass.base_dsl.common import DSLRuntimeError from cutlass.cutlass_dsl import BaseDSL, dsl_user_op from cutlass._mlir import ir -from cutlass._mlir.dialects import arith, llvm, vector +from cutlass._mlir.dialects import arith, vector from .nvvm_wrappers import ( cvt_i8_bf16, @@ -34,7 +36,13 @@ from ..typing import Int4, Int8, Float32, BFloat16, Int32 @dsl_user_op -def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): +def cvt_i8_bf16_intrinsic( + vec_i8: ir.Value, + length: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Fast conversion from int8 to bfloat16. It converts a vector of int8 to a vector of bfloat16. @@ -46,14 +54,20 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): :rtype: 1D vector of bfloat16 """ arch = BaseDSL._get_dsl().get_arch_enum() - if not arch in cvt_i8_bf16_intrinsic.supported_archs: + if arch not in cvt_i8_bf16_intrinsic.supported_archs: # type: ignore[attr-defined] raise DSLRuntimeError(f"cvt_i8_bf16_intrinsic is not supported on {arch}") src_pos = 0 vec_i8x4_type = ir.VectorType.get([4], Int8.mlir_type, loc=loc) vec_i8x2_type = ir.VectorType.get([2], Int8.mlir_type, loc=loc) vec_f32x2_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) vec_dst_type = ir.VectorType.get([length], BFloat16.mlir_type, loc=loc) - vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip) + zero_attr = ir.FloatAttr.get(BFloat16.mlir_type, 0.0) + vec_dst = arith.ConstantOp( + vec_dst_type, + ir.DenseElementsAttr.get_splat(vec_dst_type, zero_attr), + loc=loc, + ip=ip, + ).result arch = BaseDSL._get_dsl().get_arch_enum() # try to use vectorized version if length >= 4: @@ -62,7 +76,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): vec_i8x4 = vector.extract_strided_slice( vec_i8x4_type, vec_i8, [src_pos], [4], [1], loc=loc, ip=ip ) - if arch in cvt_i8_bf16_intrinsic.s26_bf16_supported_archs: + if arch in cvt_i8_bf16_intrinsic.s26_bf16_supported_archs: # type: ignore[attr-defined] vec_bf16x4 = cvt_i8x4_to_bf16x4(vec_i8x4, loc=loc, ip=ip) vec_dst = vector.insert_strided_slice( vec_bf16x4, vec_dst, [src_pos], [1], loc=loc, ip=ip @@ -90,7 +104,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): vec_i8x2 = vector.extract_strided_slice( vec_i8x2_type, vec_i8, [src_pos], [2], [1], loc=loc, ip=ip ) - if arch in cvt_i8_bf16_intrinsic.s26_bf16_supported_archs: + if arch in cvt_i8_bf16_intrinsic.s26_bf16_supported_archs: # type: ignore[attr-defined] vec_bf16x2 = cvt_i8x2_to_bf16x2(vec_i8x2, loc=loc, ip=ip) else: vec_f32x2 = cvt_i8x2_to_f32x2(vec_i8x2, loc=loc, ip=ip) @@ -101,7 +115,7 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): src_pos += 2 length -= 2 if length >= 1: - if arch in cvt_i8_bf16_intrinsic.s26_bf16_supported_archs: + if arch in cvt_i8_bf16_intrinsic.s26_bf16_supported_archs: # type: ignore[attr-defined] val_bf16 = cvt_i8_bf16( vector.extractelement( vec_i8, @@ -119,8 +133,8 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): loc=loc, ip=ip, ) - src_i32 = llvm.sext(Int32.mlir_type, src_i8, loc=loc, ip=ip) - src_f32 = llvm.sitofp(Float32.mlir_type, src_i32, loc=loc, ip=ip) + src_i32 = arith.ExtSIOp(Int32.mlir_type, src_i8, loc=loc, ip=ip) + src_f32 = arith.SIToFPOp(Float32.mlir_type, src_i32, loc=loc, ip=ip) val_bf16 = cvt_f32_bf16(src_f32, loc=loc, ip=ip) vec_dst = vector.insertelement( val_bf16, @@ -133,7 +147,14 @@ def cvt_i8_bf16_intrinsic(vec_i8, length, *, loc=None, ip=None): @dsl_user_op -def cvt_i4_bf16_intrinsic(vec_i4, length, *, with_shuffle=False, loc=None, ip=None): +def cvt_i4_bf16_intrinsic( + vec_i4: ir.Value, + length: int, + *, + with_shuffle: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Fast conversion from int4 to bfloat16. It converts a vector of int4 to a vector of bfloat16. @@ -152,14 +173,20 @@ def cvt_i4_bf16_intrinsic(vec_i4, length, *, with_shuffle=False, loc=None, ip=No :rtype: 1D vector of bfloat16 """ arch = BaseDSL._get_dsl().get_arch_enum() - if not arch in cvt_i4_bf16_intrinsic.supported_archs: + if arch not in cvt_i4_bf16_intrinsic.supported_archs: # type: ignore[attr-defined] raise DSLRuntimeError(f"cvt_i4_bf16_intrinsic is not supported on {arch}") src_pos = 0 vec_i4x8_type = ir.VectorType.get([8], Int4.mlir_type, loc=loc) vec_i4x4_type = ir.VectorType.get([4], Int4.mlir_type, loc=loc) vec_i4x2_type = ir.VectorType.get([2], Int4.mlir_type, loc=loc) vec_dst_type = ir.VectorType.get([length], BFloat16.mlir_type, loc=loc) - vec_dst = llvm.mlir_zero(vec_dst_type, loc=loc, ip=ip) + zero_attr = ir.FloatAttr.get(BFloat16.mlir_type, 0.0) + vec_dst = arith.ConstantOp( + vec_dst_type, + ir.DenseElementsAttr.get_splat(vec_dst_type, zero_attr), + loc=loc, + ip=ip, + ).result # try to use vectorized version if length >= 8: @@ -222,7 +249,13 @@ def cvt_i4_bf16_intrinsic(vec_i4, length, *, with_shuffle=False, loc=None, ip=No @dsl_user_op -def sext_unpacked_i4_i8_intrinsic(vec_unpacked_i4, length, *, loc=None, ip=None): +def sext_unpacked_i4_i8_intrinsic( + vec_unpacked_i4: ir.Value, + length: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Sign extend vector of int4 unpacked in 8b containers to packed int8 @@ -237,7 +270,13 @@ def sext_unpacked_i4_i8_intrinsic(vec_unpacked_i4, length, *, loc=None, ip=None) vec_i8x4_type = ir.VectorType.get([4], Int8.mlir_type, loc=loc) vec_i8_type = ir.VectorType.get([length], Int8.mlir_type, loc=loc) - vec_i8 = llvm.mlir_zero(vec_i8_type, loc=loc, ip=ip) + zero_attr = ir.IntegerAttr.get(Int8.mlir_type, 0) + vec_i8 = arith.ConstantOp( + vec_i8_type, + ir.DenseElementsAttr.get_splat(vec_i8_type, zero_attr), + loc=loc, + ip=ip, + ).result for pos in range(0, length, 4): vec_unpacked_i4x4 = vector.extract_strided_slice( @@ -252,19 +291,19 @@ def sext_unpacked_i4_i8_intrinsic(vec_unpacked_i4, length, *, loc=None, ip=None) # Expose supported architectures via the intrinsic symbol -cvt_i8_bf16_intrinsic.supported_archs = ( +cvt_i8_bf16_intrinsic.supported_archs = ( # type: ignore[attr-defined] *Arch.AmpereArchs(), *Arch.AdaArchs(), *Arch.HopperArchs(), *Arch.BlackwellArchs(), ) -cvt_i8_bf16_intrinsic.s26_bf16_supported_archs = ( +cvt_i8_bf16_intrinsic.s26_bf16_supported_archs = ( # type: ignore[attr-defined] Arch.sm_100a, Arch.sm_110a, Arch.sm_120a, Arch.sm_121a, ) -cvt_i4_bf16_intrinsic.supported_archs = ( +cvt_i4_bf16_intrinsic.supported_archs = ( # type: ignore[attr-defined] Arch.sm_100a, Arch.sm_110a, Arch.sm_120a, diff --git a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py index 7835a7ca9..8a9cba11e 100644 --- a/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py +++ b/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -10,21 +10,23 @@ # is strictly prohibited. from functools import partial -from typing import Any, Optional, Tuple, Union, Callable, Literal +from typing import Any, Optional, Tuple, Union, Callable, Literal, Type, overload from typing_extensions import deprecated -from cutlass.cutlass_dsl import T, dsl_user_op, target_version +from cutlass.cutlass_dsl import T, dsl_user_op import cutlass.cutlass_dsl as cutlass_dsl from cutlass._mlir import ir -from cutlass._mlir.dialects import arith, llvm, nvvm, vector +from cutlass._mlir.dialects import arith, builtin, llvm, math, nvvm as _nvvm_raw, vector +from cutlass.base_dsl._mlir_helpers.dialect_proxy import DialectAutoConvertProxy from ..core import size from ..typing import ( Int, Boolean, + Integer, Int8, Int16, Uint16, @@ -35,27 +37,24 @@ from ..typing import ( Float32, BFloat16, Numeric, + Uint64, + Pointer, as_numeric, ) WARP_SIZE = 32 FULL_MASK = 0xFFFFFFFF +# Create the proxy instance to replace the raw nvvm module +nvvm = DialectAutoConvertProxy(_nvvm_raw) + # ============================================================================ -# Enum String Mapping Helper +# Helper # ============================================================================ -# This section provides a helper to convert string literals to NVVM enum types -# by introspecting the enum's __str__() method. Each function imports and -# enhances only the enums it needs, avoiding namespace pollution. -# -# Usage within functions: -# MemOrderKind = _enhance_enum_with_str_mapping(MemOrderKind) -# sem = MemOrderKind.from_str("relaxed") -## ============================================================================ -def _enhance_enum_with_str_mapping(enum_class): +def _enhance_enum_with_str_mapping(enum_class: Any) -> Any: """ Enhance an IntEnum class with automatic string-to-enum conversion. @@ -81,8 +80,8 @@ def _enhance_enum_with_str_mapping(enum_class): str_to_enum_map[str_repr] = member # Add from_str class method - @classmethod - def from_str(cls, s): + @classmethod # type: ignore[misc] + def from_str(cls: Any, s: Any) -> Any: """ Convert a string literal to the corresponding enum member. @@ -91,18 +90,32 @@ def _enhance_enum_with_str_mapping(enum_class): :raises ValueError: If the string is not a valid enum member :raises TypeError: If an enum is passed instead of a string """ + if s is None: return None + # Check if user passed an enum (should be a string literal instead) - # This catches cases where user passes e.g., RoundingModeKind.RN instead of "rn" + # This catches cases where user passes e.g., FPRoundingMode.RN instead of "rn" from enum import Enum if isinstance(s, Enum): - raise TypeError( - f"Expected a string literal for {cls.__name__}, but got enum '{type(s).__name__}.{s.name}'. " - f"Please pass a string instead (e.g., '{str(s)}' instead of {type(s).__name__}.{s.name}). " - f"Valid string options are: {sorted(str_to_enum_map.keys())}" - ) + if cutlass_dsl.target_version(exact_version="12.9"): + import warnings + + warnings.warn( + f"Passing enum member directly to {cls.__name__}.from_str() is deprecated. " + f"Please use string literals instead (e.g., '{str(s)}' instead of {cls.__name__}.{s.name}).", + DeprecationWarning, + stacklevel=2, + ) + return s + else: + raise TypeError( + f"Expected a string literal for {cls.__name__}, but got enum '{type(s).__name__}.{s.name}'. " + f"Please pass a string instead (e.g., '{str(s)}' instead of {type(s).__name__}.{s.name}). " + f"Valid string options are: {sorted(str_to_enum_map.keys())}" + ) + if s not in str_to_enum_map: valid_options = sorted(str_to_enum_map.keys()) raise ValueError( @@ -115,8 +128,110 @@ def _enhance_enum_with_str_mapping(enum_class): return enum_class +def _cutlass_dtype_to_reduction_type_str(cutlass_dtype: type[Numeric]) -> str: + """ + Convert cutlass data type to ReductionType string literal. + + :param cutlass_dtype: Cutlass data type (e.g., cutlass.Uint32, cutlass.Float32) + :type cutlass_dtype: type[Numeric] + :return: ReductionType string literal (e.g., "u32", "f32") + :rtype: str + :raises ValueError: If cutlass_dtype is not supported for reduction operations + + Supported conversions: + cutlass.Uint32 -> "u32" + cutlass.Uint64 -> "u64" + cutlass.Int32 -> "s32" + cutlass.Int64 -> "s64" + cutlass.Float32 -> "f32" + cutlass.Float64 -> "f64" + cutlass.Float16 -> "f16" + cutlass.BFloat16 -> "bf16" + """ + import cutlass + + # Mapping from cutlass types to ReductionType string literals + dtype_map = { + cutlass.Uint32: "u32", + cutlass.Uint64: "u64", + cutlass.Int32: "s32", + cutlass.Int64: "s64", + cutlass.Float32: "f32", + cutlass.Float64: "f64", + cutlass.Float16: "f16", + cutlass.BFloat16: "bf16", + } + + if cutlass_dtype not in dtype_map: + valid_types = ", ".join([t.__name__ for t in dtype_map.keys()]) + raise ValueError( + f"Invalid cutlass dtype for reduction: '{cutlass_dtype.__name__}'. " + f"Valid types are: {valid_types}" + ) + + return dtype_map[cutlass_dtype] + + +def _reduction_type_str_to_cutlass_dtype(dtype_str: str) -> type[Numeric]: + """ + Convert ReductionType string literal to cutlass data type. + + :param dtype_str: ReductionType string literal (e.g., "u32", "f32") + :type dtype_str: str + :return: Cutlass data type (e.g., cutlass.Uint32, cutlass.Float32) + :rtype: type[Numeric] + :raises ValueError: If dtype_str is not supported for reduction operations + + Supported conversions: + "b32" -> cutlass.Uint32 (bitwise operations use unsigned) + "b64" -> cutlass.Uint64 (bitwise operations use unsigned) + "u32" -> cutlass.Uint32 + "u64" -> cutlass.Uint64 + "s32" -> cutlass.Int32 + "s64" -> cutlass.Int64 + "f32" -> cutlass.Float32 + "f64" -> cutlass.Float64 + "f16" -> cutlass.Float16 + "bf16" -> cutlass.BFloat16 + """ + import cutlass + + # Mapping from ReductionType string literals to cutlass types + str_to_dtype_map = { + "b32": cutlass.Uint32, # Bitwise operations use unsigned + "b64": cutlass.Uint64, + "u32": cutlass.Uint32, + "u64": cutlass.Uint64, + "s32": cutlass.Int32, + "s64": cutlass.Int64, + "f32": cutlass.Float32, + "f64": cutlass.Float64, + "f16": cutlass.Float16, + "bf16": cutlass.BFloat16, + # Vector types - map to base scalar type + "f16x2": cutlass.Float16, + "bf16x2": cutlass.BFloat16, + } + + if dtype_str not in str_to_dtype_map: + valid_options = sorted(str_to_dtype_map.keys()) + raise ValueError( + f"Invalid ReductionType string: '{dtype_str}'. " + f"Valid options are: {valid_options}" + ) + + return str_to_dtype_map[dtype_str] + + +# ============================================================================ +# Function +# ============================================================================ + + @dsl_user_op -def lane_idx(*, loc=None, ip=None) -> Int32: +def lane_idx( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: """ Returns the lane index of the current thread within the warp. """ @@ -124,7 +239,9 @@ def lane_idx(*, loc=None, ip=None) -> Int32: @dsl_user_op -def warp_idx(*, loc=None, ip=None) -> Int32: +def warp_idx( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: """ Returns the warp index within a CTA. """ @@ -139,7 +256,21 @@ def warp_idx(*, loc=None, ip=None) -> Int32: @dsl_user_op -def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def physical_warp_id( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: + """ + Returns the warp identifier. + + See the `PTX documentation `__. + """ + return Int32(nvvm.read_ptx_sreg_warpid(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def thread_idx( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the thread index within a CTA. """ @@ -151,7 +282,9 @@ def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def block_dim( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the number of threads in each dimension of the CTA. """ @@ -163,7 +296,9 @@ def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def block_idx( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the CTA identifier within a grid. """ @@ -175,7 +310,9 @@ def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def grid_dim( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the number of CTAs in each dimension of the grid. """ @@ -187,7 +324,9 @@ def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def cluster_idx( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the cluster identifier within a grid. """ @@ -199,7 +338,9 @@ def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def cluster_dim( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the number of clusters in each dimension of the grid. """ @@ -211,7 +352,9 @@ def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def block_in_cluster_idx( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the CTA index within a cluster across all dimensions. """ @@ -223,7 +366,9 @@ def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: +def block_in_cluster_dim( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Tuple[Int32, Int32, Int32]: """ Returns the dimensions of the cluster. """ @@ -235,7 +380,9 @@ def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: @dsl_user_op -def cluster_size(*, loc=None, ip=None) -> Int32: +def cluster_size( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: """ Returns the number of CTA within the cluster. """ @@ -243,7 +390,9 @@ def cluster_size(*, loc=None, ip=None) -> Int32: @dsl_user_op -def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: +def block_idx_in_cluster( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: """ Returns the linearized identifier of the CTA within the cluster. """ @@ -251,7 +400,9 @@ def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: @dsl_user_op -def dynamic_smem_size(*, loc=None, ip=None) -> Int32: +def dynamic_smem_size( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: """ Returns the launch dynamic smem size. """ @@ -272,15 +423,15 @@ def dynamic_smem_size(*, loc=None, ip=None) -> Int32: @dsl_user_op def shuffle_sync_op( - value: Union[Numeric, "TensorSSA"], + value: Union[Numeric, "TensorSSA"], # type: ignore[name-defined] # noqa: F821 offset: Int, mask: Int = FULL_MASK, mask_and_clamp: Int = WARP_SIZE - 1, - kind: nvvm.ShflKind = nvvm.ShflKind.idx, + kind: nvvm.ShflKind = nvvm.ShflKind.idx, # type: ignore[name-defined] *, - loc=None, - ip=None, -) -> Union[Numeric, "TensorSSA"]: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Numeric, "TensorSSA"]: # type: ignore[name-defined] # noqa: F821 """ Shuffles a value within the threads of a warp. @@ -325,16 +476,16 @@ def shuffle_sync_op( if not isinstance(value, Numeric): value = as_numeric(value) - if value.width > 64: + if value.width > 64: # type: ignore[attr-defined] raise ValueError("shuffle_sync only supports values up to 64 bits") orig_type = type(value) - if value.width < 32: + if value.width < 32: # type: ignore[attr-defined] if value.dtype.is_float: value = value.to(Float32) else: - if value.signed: + if value.signed: # type: ignore[attr-defined] value = value.to(Int32) else: value = value.to(Uint32) @@ -350,7 +501,7 @@ def shuffle_sync_op( ip=ip, ) ) - elif value.width == 32: + elif value.width == 32: # type: ignore[attr-defined] return orig_type( nvvm.shfl_sync( type(value).mlir_type, @@ -364,7 +515,7 @@ def shuffle_sync_op( ) ) else: - if value.width != 64: + if value.width != 64: # type: ignore[attr-defined] raise ValueError( "shuffle_sync only supports 64 bits values when the bit width is larger than 32" ) @@ -372,16 +523,12 @@ def shuffle_sync_op( T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip ) # extract low 32 bits - low_32_bits = llvm.trunc( - T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip - ) + low_32_bits = arith.trunci(T.i32(), value, loc=loc, ip=ip) # extract high 32 bits - high_32_bits = llvm.lshr( + high_32_bits = arith.shrui( value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip ) - high_32_bits = llvm.trunc( - T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip - ) + high_32_bits = arith.trunci(T.i32(), high_32_bits, loc=loc, ip=ip) low_32_bits_shfl = nvvm.shfl_sync( T.i32(), @@ -405,16 +552,12 @@ def shuffle_sync_op( ) # combine low and high 32 bits - low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip) - high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip) - shlf_res = llvm.shl( - high_64_bit, - Int64(32).ir_value(loc=loc, ip=ip), - llvm.IntegerOverflowFlags.none, - loc=loc, - ip=ip, + low_64_bit = arith.extui(T.i64(), low_32_bits_shfl, loc=loc, ip=ip) + high_64_bit = arith.extui(T.i64(), high_32_bits_shfl, loc=loc, ip=ip) + shlf_res = arith.shli( + high_64_bit, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip ) - shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip) + shlf_res = arith.ori(shlf_res, low_64_bit, loc=loc, ip=ip) shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip) return orig_type(shlf_res) @@ -427,7 +570,12 @@ shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly) @dsl_user_op def warp_reduction( - val: Numeric, op: Callable, *, threads_in_group: int = 32, loc=None, ip=None + val: Numeric, + op: Callable, + *, + threads_in_group: int = 32, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """warp reduction of a Numeric value(e.g.Float32) by shuffle_sync_bfly, accepts custom binary operator. The threads_in_group is the number of threads reduction group in a warp. @@ -464,7 +612,13 @@ warp_reduction_sum = partial(warp_reduction, op=lambda x, y: x + y) @dsl_user_op -def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None: +def barrier( + *, + barrier_id: Optional[Int] = None, + number_of_threads: Optional[Int] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Creates a barrier, optionally named. """ @@ -474,17 +628,44 @@ def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> No if number_of_threads is not None: number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) - nvvm.barrier( - barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip - ) + if cutlass_dsl.target_version(exact_version="12.9"): + if barrier_id is None: + barrier_id = Int32(0).ir_value(loc=loc, ip=ip) + has_count = number_of_threads is not None + operands = [barrier_id, number_of_threads] if has_count else [barrier_id] + llvm.inline_asm( + None, + operands, + f"bar.sync {'$0, $1' if has_count else '$0'};", + "r,r" if has_count else "r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + # TODO: support barrier with reduction result + nvvm.barrier( + barrier_id=barrier_id, + number_of_threads=number_of_threads, + loc=loc, + ip=ip, + ) @dsl_user_op def barrier_arrive( - *, barrier_id=None, number_of_threads=None, loc=None, ip=None + *, + barrier_id: Optional[Int] = None, + number_of_threads: Optional[Int] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: if barrier_id is not None: barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) + else: + barrier_id = Int32(0).ir_value(loc=loc, ip=ip) if number_of_threads is None: raise ValueError( @@ -492,13 +673,31 @@ def barrier_arrive( ) number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) - nvvm.barrier_arrive( - barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip - ) + if cutlass_dsl.target_version(exact_version="12.9"): + llvm.inline_asm( + None, + [barrier_id, number_of_threads], + "bar.arrive $0, $1;", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + nvvm.barrier_arrive( + barrier_id=barrier_id, + number_of_threads=number_of_threads, + loc=loc, + ip=ip, + ) @dsl_user_op -def sync_threads(*, loc=None, ip=None) -> None: +def sync_threads( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ Synchronizes all threads within a CTA. """ @@ -506,7 +705,12 @@ def sync_threads(*, loc=None, ip=None) -> None: @dsl_user_op -def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None: +def sync_warp( + mask: Int = FULL_MASK, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Performs a warp-wide sync with an optional mask. """ @@ -514,9 +718,11 @@ def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None: @dsl_user_op -def fence_acq_rel_cta(*, loc=None, ip=None) -> None: +def fence_acq_rel_cta( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ - Fence operation with acquire-release semantics. + Fence operation with acquire-release semantics at CTA (block) scope. See the `PTX documentation `__. """ @@ -524,9 +730,11 @@ def fence_acq_rel_cta(*, loc=None, ip=None) -> None: @dsl_user_op -def fence_acq_rel_cluster(*, loc=None, ip=None) -> None: +def fence_acq_rel_cluster( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ - Fence operation with acquire-release semantics. + Fence operation with acquire-release semantics at cluster scope. See the `PTX documentation `__. """ @@ -534,9 +742,11 @@ def fence_acq_rel_cluster(*, loc=None, ip=None) -> None: @dsl_user_op -def fence_acq_rel_gpu(*, loc=None, ip=None) -> None: +def fence_acq_rel_gpu( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ - Fence operation with acquire-release semantics. + Fence operation with acquire-release semantics at GPU (device) scope. See the `PTX documentation `__. """ @@ -544,9 +754,11 @@ def fence_acq_rel_gpu(*, loc=None, ip=None) -> None: @dsl_user_op -def fence_acq_rel_sys(*, loc=None, ip=None) -> None: +def fence_acq_rel_sys( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ - Fence operation with acquire-release semantics. + Fence operation with acquire-release semantics at system scope. See the `PTX documentation `__. """ @@ -554,7 +766,9 @@ def fence_acq_rel_sys(*, loc=None, ip=None) -> None: @dsl_user_op -def cp_async_commit_group(*, loc=None, ip=None) -> None: +def cp_async_commit_group( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ Commits all prior initiated but uncommitted cp.async instructions. @@ -564,7 +778,9 @@ def cp_async_commit_group(*, loc=None, ip=None) -> None: @dsl_user_op -def cp_async_wait_group(n, *, loc=None, ip=None) -> None: +def cp_async_wait_group( + n: Int, *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ Waits till only a specified numbers of cp.async groups are pending. @@ -574,7 +790,9 @@ def cp_async_wait_group(n, *, loc=None, ip=None) -> None: @dsl_user_op -def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None: +def cp_async_bulk_commit_group( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ Commits all prior initiated but uncommitted cp.async.bulk instructions. @@ -584,7 +802,13 @@ def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None: @dsl_user_op -def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None: +def cp_async_bulk_wait_group( + group: Int, + *, + read: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Waits till only a specified numbers of cp.async.bulk groups are pending. @@ -594,7 +818,9 @@ def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None: @dsl_user_op -def cluster_wait(*, loc=None, ip=None) -> None: +def cluster_wait( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ A cluster-wide wait operation. """ @@ -602,7 +828,12 @@ def cluster_wait(*, loc=None, ip=None) -> None: @dsl_user_op -def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None: +def cluster_arrive( + *, + aligned: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ A cluster-wide arrive operation. """ @@ -610,7 +841,12 @@ def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None: @dsl_user_op -def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None: +def cluster_arrive_relaxed( + *, + aligned: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ A cluster-wide arrive operation with relaxed semantics. """ @@ -624,8 +860,9 @@ def fence_proxy( ], *, space: Optional[Literal["cta", "cluster"]] = None, - loc=None, - ip=None, + use_intrinsic: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Fence operation to ensure memory consistency between proxies. @@ -642,6 +879,7 @@ def fence_proxy( - "cta" : CTA (Cooperative Thread Array) scope - "cluster" : Cluster scope :type space: Optional[Literal["cta", "cluster"]] + :param use_intrinsic: Whether to use intrinsic version """ from cutlass._mlir.dialects.nvvm import ( SharedSpace, @@ -665,18 +903,61 @@ def fence_proxy( @dsl_user_op def vote_sync_op( - pred: Boolean, kind: nvvm.VoteSyncKind, mask: Int = FULL_MASK, *, loc=None, ip=None + pred: Boolean, + kind: Literal["any", "all", "uni", "ballot"], + mask: Int = FULL_MASK, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Int32, Boolean]: """ Performs a vote operation across the warp. """ - return_type = Int32 if kind == nvvm.VoteSyncKind.ballot else Boolean + return_type = Int32 if kind == "ballot" else Boolean + if cutlass_dsl.target_version(exact_version="12.9"): + if kind == "ballot": + return return_type( + nvvm.vote_ballot_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + Boolean(pred).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + else: + return return_type( + llvm.inline_asm( + T.bool(), + [ + Boolean(pred).ir_value(loc=loc, ip=ip), + Int32(mask).ir_value(loc=loc, ip=ip), + ], + f"""{{\n\t + .reg .pred ps;\n\t + .reg .pred pd;\n\t + setp.ne.b32 ps, $1, 0;\n\t + vote.sync.{kind}.pred pd, ps, $2;\n\t + selp.b32 $0, 1, 0, pd;\n\t + }}""", + "=r,r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + ) + from cutlass._mlir.dialects.nvvm import VoteSyncKind + + VoteSyncKind = _enhance_enum_with_str_mapping(VoteSyncKind) + return return_type( nvvm.vote_sync( - T.i32() if kind == nvvm.VoteSyncKind.ballot else T.bool(), + T.i32() if kind == "ballot" else T.bool(), Int32(mask).ir_value(loc=loc, ip=ip), Boolean(pred).ir_value(loc=loc, ip=ip), - kind, + VoteSyncKind.from_str(kind), loc=loc, ip=ip, ) @@ -684,7 +965,11 @@ def vote_sync_op( def vote_ballot_sync( - pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None + pred: Boolean, + mask: Int = FULL_MASK, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Int32: """Performs a ballot operation across the warp. @@ -700,12 +985,16 @@ def vote_ballot_sync( See the `PTX documentation `__. """ - return vote_sync_op(pred, nvvm.VoteSyncKind.ballot, mask, loc=loc, ip=ip) + return vote_sync_op(pred, "ballot", mask, loc=loc, ip=ip) @dsl_user_op def vote_any_sync( - pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None + pred: Boolean, + mask: Int = FULL_MASK, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Boolean: """True if source predicate is True for any non-exited threads in mask. Negate the source predicate to compute .none. @@ -721,12 +1010,16 @@ def vote_any_sync( See the `PTX documentation `__. """ - return vote_sync_op(pred, nvvm.VoteSyncKind.any, mask, loc=loc, ip=ip) + return vote_sync_op(pred, "any", mask, loc=loc, ip=ip) @dsl_user_op def vote_all_sync( - pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None + pred: Boolean, + mask: Int = FULL_MASK, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Boolean: """True if source predicate is True for all non-exited threads in mask. Negate the source predicate to compute .none. @@ -742,12 +1035,16 @@ def vote_all_sync( See the `PTX documentation `__. """ - return vote_sync_op(pred, nvvm.VoteSyncKind.all, mask, loc=loc, ip=ip) + return vote_sync_op(pred, "all", mask, loc=loc, ip=ip) @dsl_user_op def vote_uni_sync( - pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None + pred: Boolean, + mask: Int = FULL_MASK, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Boolean: """True f source predicate has the same value in all non-exited threads in mask. Negating the source predicate also computes .uni @@ -761,25 +1058,30 @@ def vote_uni_sync( threads in mask :rtype: Boolean """ - return vote_sync_op(pred, nvvm.VoteSyncKind.uni, mask, loc=loc, ip=ip) + return vote_sync_op(pred, "uni", mask, loc=loc, ip=ip) @dsl_user_op -def popc(value: Numeric, *, loc=None, ip=None) -> Numeric: +def popc( + value: Numeric, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Numeric: """ Performs a population count operation. """ if not isinstance(value, Numeric): value = as_numeric(value) - return type(value)(llvm.intr_ctpop(value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return type(value)(math.ctpop(value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @dsl_user_op def fence_view_async_tmem_op( kind: Literal["load", "store"], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Perform a fence operation on the async TMEM load or store. @@ -828,8 +1130,8 @@ fence_view_async_tmem_store = partial(fence_view_async_tmem_op, kind="store") @dsl_user_op def fence_view_async_shared( *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Perform a fence operation on the async shared memory load or store. @@ -849,10 +1151,11 @@ def fence_view_async_shared( def setmaxregister_increase( reg_count: int, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: from cutlass._mlir.dialects.nvvm import SetMaxRegisterAction + return nvvm.setmaxregister(reg_count, SetMaxRegisterAction.increase, loc=loc, ip=ip) @@ -860,10 +1163,11 @@ def setmaxregister_increase( def setmaxregister_decrease( reg_count: int, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: from cutlass._mlir.dialects.nvvm import SetMaxRegisterAction + return nvvm.setmaxregister(reg_count, SetMaxRegisterAction.decrease, loc=loc, ip=ip) @@ -872,10 +1176,11 @@ def setmaxregister_decrease( def warpgroup_reg_alloc( reg_count: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: from cutlass._mlir.dialects.nvvm import SetMaxRegisterAction + nvvm.setmaxregister(reg_count, SetMaxRegisterAction.increase, loc=loc, ip=ip) @@ -884,10 +1189,11 @@ def warpgroup_reg_alloc( def warpgroup_reg_dealloc( reg_count: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: from cutlass._mlir.dialects.nvvm import SetMaxRegisterAction + nvvm.setmaxregister(reg_count, SetMaxRegisterAction.decrease, loc=loc, ip=ip) @@ -899,15 +1205,15 @@ def calc_packed_f32x2_op( calc_func: Callable, *, rnd: Optional[Literal["rn", "rz", "rm", "rp", "none"]] = "rn", - ftz=None, - loc=None, - ip=None, + ftz: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Float32, Float32]: - from cutlass._mlir.dialects.nvvm import RoundingModeKind + from cutlass._mlir.dialects.nvvm import FPRoundingMode # Enhance enum and convert string literal to enum type - RoundingModeKind = _enhance_enum_with_str_mapping(RoundingModeKind) - rnd = RoundingModeKind.from_str(rnd) + FPRoundingMode = _enhance_enum_with_str_mapping(FPRoundingMode) + rnd = FPRoundingMode.from_str(rnd) vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) vec_src_a = vector.from_elements( @@ -957,24 +1263,76 @@ mul_packed_f32x2 = partial( add_packed_f32x2 = partial( calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2 ) +sub_packed_f32x2 = partial( + calc_packed_f32x2_op, src_c=None, calc_func=nvvm.sub_packed_f32x2 +) @dsl_user_op def fmax( - a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None + a: Union[float, Float32], + b: Union[float, Float32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Float32: - return Float32( - nvvm.fmax( - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, + if cutlass_dsl.target_version(exact_version="12.9"): + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + else: + return Float32( + nvvm.fmax( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) ) - ) @dsl_user_op -def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None): +def fmin( + a: Union[float, Float32], + b: Union[float, Float32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Float32: + if cutlass_dsl.target_version(exact_version="12.9"): + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + else: + return Float32( + nvvm.fmin( + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def rcp_approx( + a: Union[float, Float32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Float32: return Float32( nvvm.rcp_approx_ftz_f(Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) ) @@ -984,7 +1342,12 @@ def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None): @deprecated( "cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead" ) -def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: +def exp2( + a: Union[float, Float32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Float32: return Float32( llvm.inline_asm( T.f32(), @@ -1000,8 +1363,13 @@ def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: # Convert 1 int8 value to 1 bfloat16 value @dsl_user_op -def cvt_i8_bf16(src_i8, *, loc=None, ip=None): - src_i16 = llvm.zext(Int16.mlir_type, src_i8, loc=loc, ip=ip) +def cvt_i8_bf16( + src_i8: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: + src_i16 = arith.extui(Int16.mlir_type, src_i8, loc=loc, ip=ip) val_i16 = llvm.inline_asm( Uint16.mlir_type, [ @@ -1021,7 +1389,12 @@ def cvt_i8_bf16(src_i8, *, loc=None, ip=None): @dsl_user_op -def cvt_i8x2_to_bf16x2(src_vec2, *, loc=None, ip=None): +def cvt_i8x2_to_bf16x2( + src_vec2: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # pack 2 int8 into 1 int16 value src_i16 = llvm.bitcast(Int16.mlir_type, src_vec2, loc=loc, ip=ip) val_i32 = llvm.inline_asm( @@ -1043,7 +1416,12 @@ def cvt_i8x2_to_bf16x2(src_vec2, *, loc=None, ip=None): @dsl_user_op -def cvt_i8x4_to_bf16x4(src_vec4, *, loc=None, ip=None): +def cvt_i8x4_to_bf16x4( + src_vec4: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # pack 4 int8 into 1 int32 value src_i32 = llvm.bitcast(Int32.mlir_type, src_vec4, loc=loc, ip=ip) rst01 = llvm.inline_asm( @@ -1084,7 +1462,12 @@ def cvt_i8x4_to_bf16x4(src_vec4, *, loc=None, ip=None): # Convert vector of 2 float values to vector of 2 bfloat16 values with satfinite rounding @dsl_user_op -def cvt_f32x2_bf16x2(src_vec2, *, loc=None, ip=None): +def cvt_f32x2_bf16x2( + src_vec2: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: src0 = vector.extractelement( src_vec2, position=arith.constant(Int32.mlir_type, 0, loc=loc, ip=ip) ) @@ -1110,7 +1493,12 @@ def cvt_f32x2_bf16x2(src_vec2, *, loc=None, ip=None): # Convert 1 float32 value to 1 bfloat16 value @dsl_user_op -def cvt_f32_bf16(src_f32, *, loc=None, ip=None): +def cvt_f32_bf16( + src_f32: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: bf16_val = llvm.inline_asm( BFloat16.mlir_type, [ @@ -1124,7 +1512,12 @@ def cvt_f32_bf16(src_f32, *, loc=None, ip=None): # Convert vector of 4 int8 values to vector of 4 float32 values @dsl_user_op -def cvt_i8x4_to_f32x4(src_vec4, *, loc=None, ip=None): +def cvt_i8x4_to_f32x4( + src_vec4: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: zero = arith.constant(Int32.mlir_type, 0, loc=loc, ip=ip) mask4 = ( arith.constant(Int32.mlir_type, 0x00000001, loc=loc, ip=ip), @@ -1238,14 +1631,19 @@ def cvt_i8x4_to_f32x4(src_vec4, *, loc=None, ip=None): # Convert vector of 2 int8 values to vector of 2 float32 values @dsl_user_op -def cvt_i8x2_to_f32x2(src_vec2, *, loc=None, ip=None): +def cvt_i8x2_to_f32x2( + src_vec2: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: zero = arith.constant(Int32.mlir_type, 0, loc=loc, ip=ip) mask2 = ( arith.constant(Int32.mlir_type, 0x00000001, loc=loc, ip=ip), arith.constant(Int32.mlir_type, 0x00000100, loc=loc, ip=ip), ) src_i16 = llvm.bitcast(Int16.mlir_type, src_vec2, loc=loc, ip=ip) - src_i32_pad16b = llvm.zext(Int32.mlir_type, src_i16, loc=loc, ip=ip) + src_i32_pad16b = arith.extui(Int32.mlir_type, src_i16, loc=loc, ip=ip) rst0 = llvm.inline_asm( Int32.mlir_type, [ @@ -1301,7 +1699,14 @@ def cvt_i8x2_to_f32x2(src_vec2, *, loc=None, ip=None): # Permute bytes from register pair. @dsl_user_op -def prmt(src, src_reg_shifted, prmt_indices, *, loc=None, ip=None): +def prmt( + src: Int, + src_reg_shifted: Int, + prmt_indices: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: return llvm.inline_asm( T.i32(), [ @@ -1319,10 +1724,15 @@ def prmt(src, src_reg_shifted, prmt_indices, *, loc=None, ip=None): # Convert 1 int4 value to 1 bfloat16 value @dsl_user_op -def cvt_i4_bf16(src_i4, *, loc=None, ip=None): +def cvt_i4_bf16( + src_i4: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # i4 -> i32 -> f32 -> bf - src_i32 = llvm.sext(Int32.mlir_type, src_i4, loc=loc, ip=ip) - src_f32 = llvm.sitofp(Float32.mlir_type, src_i32, loc=loc, ip=ip) + src_i32 = arith.extsi(Int32.mlir_type, src_i4, loc=loc, ip=ip) + src_f32 = arith.sitofp(Float32.mlir_type, src_i32, loc=loc, ip=ip) bf16_val = cvt_f32_bf16(src_f32, loc=loc, ip=ip) return bf16_val @@ -1338,9 +1748,14 @@ def cvt_i4_bf16(src_i4, *, loc=None, ip=None): # Int4 values are packed into int32 values with upper bits filled with 0 if there are less than 4 int4 values. # Results bfloat16 values are also packed into int32 values. @dsl_user_op -def cvt_i4_to_bf16_with_shuffle_impl(src_i32, num_elts, *, loc=None, ip=None): - from cutlass import CUDA_VERSION - if CUDA_VERSION.major < 13: +def cvt_i4_to_bf16_with_shuffle_impl( + src_i32: ir.Value, + num_elts: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: + if not cutlass_dsl.target_version(min_version="13.1"): raise cutlass_dsl.DSLCudaVerNotImplemented( feature="cvt_i4_to_bf16_with_shuffle_impl", required_version="13.1" ) @@ -1436,9 +1851,15 @@ def cvt_i4_to_bf16_with_shuffle_impl(src_i32, num_elts, *, loc=None, ip=None): # Int4 values are packed into int32 values with upper bits filled with 0 if there are less than 4 int4 values. # Results bfloat16 values are also packed into int32 values. @dsl_user_op -def cvt_i4_to_bf16_impl(src_i32, num_elts, *, loc=None, ip=None): +def cvt_i4_to_bf16_impl( + src_i32: ir.Value, + num_elts: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: c4 = arith.constant(Int32.mlir_type, 4, loc=loc, ip=ip) - src_shr4 = llvm.lshr(src_i32, c4, loc=loc, ip=ip) + src_shr4 = arith.shrui(src_i32, c4, loc=loc, ip=ip) xor_mask0 = arith.constant(Int32.mlir_type, 0x08080808, loc=loc, ip=ip) and_mask = arith.constant(Int32.mlir_type, 0x0F0F0F0F, loc=loc, ip=ip) imm_lut = arith.constant(Int32.mlir_type, 0x0000006A, loc=loc, ip=ip) @@ -1519,11 +1940,17 @@ def cvt_i4_to_bf16_impl(src_i32, num_elts, *, loc=None, ip=None): # Convert 2 int4 values to 2 bfloat16 values @dsl_user_op -def cvt_i4x2_to_bf16x2(src_vec2, *, with_shuffle=False, loc=None, ip=None): +def cvt_i4x2_to_bf16x2( + src_vec2: ir.Value, + *, + with_shuffle: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: cvt_func = cvt_i4_to_bf16_with_shuffle_impl if with_shuffle else cvt_i4_to_bf16_impl # pack 2 int4 into 1 int32 value and fill upper bits with 0 src_i8 = llvm.bitcast(Int8.mlir_type, src_vec2, loc=loc, ip=ip) - src_i32 = llvm.zext(Int32.mlir_type, src_i8, loc=loc, ip=ip) + src_i32 = arith.extui(Int32.mlir_type, src_i8, loc=loc, ip=ip) rst_i32 = cvt_func(src_i32, 2, loc=loc, ip=ip) vec_bf16x2_type = ir.VectorType.get([2], BFloat16.mlir_type, loc=loc) vec_bf16x2 = llvm.bitcast(vec_bf16x2_type, rst_i32, loc=loc, ip=ip) @@ -1532,11 +1959,17 @@ def cvt_i4x2_to_bf16x2(src_vec2, *, with_shuffle=False, loc=None, ip=None): # Convert 4 int4 values to 4 bfloat16 values @dsl_user_op -def cvt_i4x4_to_bf16x4(src_vec4, *, with_shuffle=False, loc=None, ip=None): +def cvt_i4x4_to_bf16x4( + src_vec4: ir.Value, + *, + with_shuffle: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: cvt_func = cvt_i4_to_bf16_with_shuffle_impl if with_shuffle else cvt_i4_to_bf16_impl # pack 4 int4 into 1 int32 value and fill upper bits with 0 src_i16 = llvm.bitcast(Int16.mlir_type, src_vec4, loc=loc, ip=ip) - src_i32 = llvm.zext(Int32.mlir_type, src_i16, loc=loc, ip=ip) + src_i32 = arith.extui(Int32.mlir_type, src_i16, loc=loc, ip=ip) rst_i32 = cvt_func(src_i32, 4, loc=loc, ip=ip) vec_bf16x4_type = ir.VectorType.get([4], BFloat16.mlir_type, loc=loc) vec_bf16x4 = llvm.bitcast(vec_bf16x4_type, rst_i32, loc=loc, ip=ip) @@ -1545,7 +1978,13 @@ def cvt_i4x4_to_bf16x4(src_vec4, *, with_shuffle=False, loc=None, ip=None): # Convert 8 int4 values to 8 bfloat16 values @dsl_user_op -def cvt_i4x8_to_bf16x8(src_vec8, *, with_shuffle=False, loc=None, ip=None): +def cvt_i4x8_to_bf16x8( + src_vec8: ir.Value, + *, + with_shuffle: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: cvt_func = cvt_i4_to_bf16_with_shuffle_impl if with_shuffle else cvt_i4_to_bf16_impl # pack 8 int4 into 1 int32 value and fill upper bits with 0 src_i32 = llvm.bitcast(Int32.mlir_type, src_vec8, loc=loc, ip=ip) @@ -1557,7 +1996,12 @@ def cvt_i4x8_to_bf16x8(src_vec8, *, with_shuffle=False, loc=None, ip=None): # Sign extend 4 int4 unpacked in 8b containers @dsl_user_op -def sext_unpacked_i4x4_to_i8x4(src_vec4, *, loc=None, ip=None): +def sext_unpacked_i4x4_to_i8x4( + src_vec4: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: imm_u32 = arith.constant(Uint32.mlir_type, 0x78787878, loc=loc, ip=ip) src_u32 = llvm.bitcast(Uint32.mlir_type, src_vec4, loc=loc, ip=ip) dst_u32 = arith.addi(src_u32, imm_u32, loc=loc, ip=ip) @@ -1566,7 +2010,12 @@ def sext_unpacked_i4x4_to_i8x4(src_vec4, *, loc=None, ip=None): @dsl_user_op -def log2_of_pow2_int(a: Int32, *, loc=None, ip=None) -> Int32: +def log2_of_pow2_int( + a: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: tmp = llvm.inline_asm( Int32.mlir_type, [a.ir_value(loc=loc, ip=ip)], @@ -1593,7 +2042,12 @@ def log2_of_pow2_int(a: Int32, *, loc=None, ip=None) -> Int32: @deprecated( "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead" ) -def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: +def exp( + a: Union[float, Float32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Float32: LOG2_E = 1.4426950408889634 return exp2(a * LOG2_E, loc=loc, ip=ip) @@ -1603,7 +2057,10 @@ def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: "cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead" ) def exp_packed_f32x2( - a: Tuple[Float32, Float32], *, loc=None, ip=None + a: Tuple[Float32, Float32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Float32, Float32]: LOG2_E = Float32(1.4426950408889634) b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip) @@ -1611,7 +2068,9 @@ def exp_packed_f32x2( @dsl_user_op -def griddepcontrol_wait(*, loc=None, ip=None) -> None: +def griddepcontrol_wait( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ This instruction is used to wait for the previous kernel's grid ending (all blocks of the previous kernel have finished and memflushed), i.e., @@ -1631,7 +2090,9 @@ def griddepcontrol_wait(*, loc=None, ip=None) -> None: @dsl_user_op -def griddepcontrol_launch_dependents(*, loc=None, ip=None) -> None: +def griddepcontrol_launch_dependents( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ Issuing the launch_dependents instruction hints a dependent kernel to launch earlier. launch_dependents doesn't impact the functionality but the performance: @@ -1659,6 +2120,8 @@ def _warp_redux_sync_nvvm( "fmin", "max", "min", + "umax", + "umin", "add", "xor", "or", @@ -1666,10 +2129,10 @@ def _warp_redux_sync_nvvm( ], mask_and_clamp: Int = FULL_MASK, abs: bool = False, - nan: bool = None, + nan: Optional[bool] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: from cutlass._mlir.dialects.nvvm import ReduxKind @@ -1678,6 +2141,12 @@ def _warp_redux_sync_nvvm( kind = ReduxKind.from_str(kind) value_type = type(value) + if value_type.is_integer and not value_type.signed: # type: ignore[attr-defined] + if kind == ReduxKind.MAX: + kind = ReduxKind.UMAX + elif kind == ReduxKind.MIN: + kind = ReduxKind.UMIN + value_ir = value.ir_value(loc=loc, ip=ip) return value_type( @@ -1704,12 +2173,15 @@ def _warp_redux_sync_ptx( "min", ], mask_and_clamp: Int = FULL_MASK, - abs: bool = None, - nan: bool = None, + abs: Optional[bool] = None, + nan: Optional[bool] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: + """ + **ONLY** support f32 as nvvm compatability + """ value_type = type(value) value_ir = value.ir_value(loc=loc, ip=ip) mlir_type = value_type.mlir_type @@ -1735,7 +2207,7 @@ def _warp_redux_sync_ptx( mlir_type, [value_ir, mask_ir], f"{ptx_instr}", - f"=f,f,i", + "=f,f,i", has_side_effects=True, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT, @@ -1751,6 +2223,8 @@ def warp_redux_sync( "fmin", "max", "min", + "umax", + "umin", "add", "xor", "or", @@ -1758,10 +2232,10 @@ def warp_redux_sync( ], mask_and_clamp: Int = FULL_MASK, *, - abs: bool = None, - nan: bool = None, - loc=None, - ip=None, + abs: Optional[bool] = None, + nan: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Perform warp-level reduction operation across threads. @@ -1772,9 +2246,11 @@ def warp_redux_sync( :param value: Input value to reduce :type value: Numeric :param kind: Reduction operation. Supported operations: - - Integer types (Int32/Uint32): "add", "and", "max", "min", "or", "xor" + + - Integer types (Int32/Uint32): "add", "and", "max", "min", "umax", "umin", "or", "xor" + "max"/"min" auto-promote to "umax"/"umin" for unsigned types (Uint32/Uint64). - Float types (Float32): "fmax", "fmin" (or "max"/"min" which auto-convert to "fmax"/"fmin") - :type kind: Literal["add", "and", "max", "min", "or", "xor", "fmin", "fmax"] + :type kind: Literal["add", "and", "max", "min", "umax", "umin", "or", "xor", "fmin", "fmax"] :param mask_and_clamp: Warp participation mask (default: FULL_MASK = 0xFFFFFFFF) :type mask_and_clamp: Int :param abs: Apply absolute value before reduction (float types only) @@ -1803,45 +2279,12 @@ def warp_redux_sync( ) -@dsl_user_op -def atomic_max_float32( - ptr, - value: Float32, +def _normalize_ptr( + addr: Union[ir.Value, Pointer], *, - positive_only: bool = True, - loc=None, - ip=None, -) -> Float32: - """ - Performs an atomic max operation on a float32 value in global memory. - - This implementation works correctly for non-negative values (>= 0) using direct bitcast. - - :param ptr: Pointer to the memory location - :param value: The float32 value to compare and potentially store (should be >= 0 for correct results) - :type value: Float32 - :param positive_only: If True (default), assumes input values are non-negative. - This parameter is provided for API compatibility and future extensions. - :type positive_only: bool - :return: The old value at the memory location - :rtype: Float32 - """ - from cutlass._mlir.dialects.nvvm import AtomicOpKind - - value_int = llvm.bitcast(T.i32(), value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - - old_value_int = nvvm.atomicrmw( - AtomicOpKind.MAX, - ptr, - value_int, - loc=loc, - ip=ip, - ) - - return Float32(llvm.bitcast(T.f32(), old_value_int, loc=loc, ip=ip)) - - -def _normalize_ptr(addr, *, loc=None, ip=None) -> ir.Value: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Helper function to normalize pointer types to MLIR ir.Value. @@ -1867,7 +2310,7 @@ def _normalize_ptr(addr, *, loc=None, ip=None) -> ir.Value: def _atomic( - ptr, + ptr: Union[ir.Value, Pointer], val: Union[Numeric, ir.Value], *, op: Literal[ @@ -1875,15 +2318,17 @@ def _atomic( "fadd", "max", "min", + "umax", + "umin", "and", "or", "xor", "exch", ], sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Numeric, ir.Value]: """ General atomic operation function. @@ -1896,16 +2341,16 @@ def _atomic( :param val: Value to add (scalar Numeric or vector ir.Value) :type val: Union[Numeric, ir.Value] :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") - :param op: Atomic operation ("add", "fadd", "max", "min", "and", "or", "xor", "exch") - :type op: Literal["add", "fadd", "max", "min", "and", "or", "xor", "exch"] + :param op: Atomic operation ("add", "fadd", "max", "min", "umax", "umin", "and", "or", "xor", "exch"). + "max"/"min" auto-promote to "umax"/"umin" for unsigned types (Uint32/Uint64). + :type op: Literal["add", "fadd", "max", "min", "umax", "umin", "and", "or", "xor", "exch"] :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Union[Numeric, ir.Value] """ from cutlass._mlir.dialects.nvvm import AtomicOpKind, MemOrderKind, MemScopeKind - from cutlass import CUDA_VERSION # Enhance enums and convert string literals to enum types AtomicOpKind = _enhance_enum_with_str_mapping(AtomicOpKind) @@ -1925,9 +2370,9 @@ def _atomic( if is_vector: # Vector type atomic - val is already an ir.Value val_ir = val - val_type = val.type + val_type = val.type # type: ignore[union-attr] # Check if it's a floating-point vector type - elem_type = val.type.element_type + elem_type = val.type.element_type # type: ignore[union-attr] is_float_vector = ( elem_type == Float16.mlir_type or elem_type == BFloat16.mlir_type @@ -1948,13 +2393,16 @@ def _atomic( # For .f32, .f64, .f16, .bf16, .f16x2, .bf16x2, only .add (FADD) is supported # For .u32 .u64, .s32, .s64, .add .and .or .xor .cas .exch .min .max are supported if val_type.is_float: - # For floating-point types, only ADD is supported if op == AtomicOpKind.ADD: - # Convert ADD to FADD for floating-point types op = AtomicOpKind.FADD + elif val_type.is_integer and not val_type.signed: # type: ignore[attr-defined] + if op == AtomicOpKind.MAX: + op = AtomicOpKind.UMAX + elif op == AtomicOpKind.MIN: + op = AtomicOpKind.UMIN # * NVVM call based on nvvm version - if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + if cutlass_dsl.target_version(exact_version="12.9"): # Old API: requires explicit result type as first positional argument # For vectors: pass val_type (ir.VectorType), for scalars: pass val_type.mlir_type result_type = val_type if is_vector else val_type.mlir_type @@ -1984,13 +2432,13 @@ def _atomic( def atomic_add( - ptr, + ptr: Union[ir.Value, Pointer], val: Union[Numeric, ir.Value], *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Numeric, ir.Value]: """ Performs an atomic addition operation. @@ -2002,8 +2450,8 @@ def atomic_add( :type val: Union[Numeric, ir.Value] :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Union[Numeric, ir.Value] """ @@ -2011,13 +2459,13 @@ def atomic_add( def atomic_and( - ptr, + ptr: Union[ir.Value, Pointer], val: Numeric, *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic bitwise AND operation. @@ -2029,8 +2477,8 @@ def atomic_and( :type val: Numeric :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Numeric """ @@ -2038,13 +2486,13 @@ def atomic_and( def atomic_or( - ptr, + ptr: Union[ir.Value, Pointer], val: Numeric, *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic bitwise OR operation. @@ -2056,8 +2504,8 @@ def atomic_or( :type val: Numeric :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Numeric """ @@ -2065,13 +2513,13 @@ def atomic_or( def atomic_xor( - ptr, + ptr: Union[ir.Value, Pointer], val: Numeric, *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic bitwise XOR operation. @@ -2083,8 +2531,8 @@ def atomic_xor( :type val: Numeric :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Numeric """ @@ -2092,13 +2540,13 @@ def atomic_xor( def atomic_max( - ptr, + ptr: Union[ir.Value, Pointer], val: Numeric, *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic maximum operation. @@ -2110,8 +2558,8 @@ def atomic_max( :type val: Numeric :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Numeric """ @@ -2119,13 +2567,13 @@ def atomic_max( def atomic_min( - ptr, + ptr: Union[ir.Value, Pointer], val: Numeric, *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic minimum operation. @@ -2146,13 +2594,13 @@ def atomic_min( def atomic_exch( - ptr, + ptr: Union[ir.Value, Pointer], val: Numeric, *, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic exchange operation. @@ -2164,24 +2612,109 @@ def atomic_exch( :type val: Numeric :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Numeric """ - return _atomic(ptr, val, op="exch", sem=sem, scope=scope, loc=loc, ip=ip) + return _atomic(ptr, val, op="exch", sem=sem, scope=scope, loc=loc, ip=ip) # type: ignore[arg-type] + + +@dsl_user_op +def atomic_fmax( + ptr: Union[ir.Value, Pointer], + val: Float32, + *, + sign_bit: Optional[bool] = None, + sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, + scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Float32: + """ + Implementation of atomic fmax using integer bitcast. + + Works for +inf, -inf, and signbit-0 nans including canonical nan. + Atomically maxes `val` to the value at memory location `ptr` and returns the old value. + + :param ptr: Pointer to memory location. Supports: + - ir.Value (LLVM pointer) + - cute.ptr (_Pointer instance) + :param val: value to max + :type val: Float32 + :param sign_bit: Indicates the sign bit of `val` if known beforehand, e.g. abs vals + :type sign_bit: Optional[bool] + :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") + :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] + :param scope: Memory scope ("gpu", "cta", "cluster", "sys") + :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :return: Old value at memory location + :rtype: Float32 + """ + intval = llvm.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + then_body = lambda: atomic_min( + ptr, + Uint32(intval), + sem=sem, + scope=scope, # type: ignore[arg-type] + loc=loc, + ip=ip, + ) + else_body = lambda: atomic_max( + ptr, + Int32(intval), + sem=sem, + scope=scope, # type: ignore[arg-type] + loc=loc, + ip=ip, + ) + + if sign_bit is None: + old_intval = cutlass_dsl.if_generate( + Int32(intval) < 0, + then_body, + else_body, + [], + [Int32], + loc=loc, + ip=ip, + ) + elif sign_bit: + old_intval = then_body() + else: + old_intval = else_body() + + assert not isinstance(old_intval, list) + return Float32( + llvm.bitcast(T.f32(), old_intval.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + + +@deprecated("atomic_max_float32 is deprecated, use atomic_fmax instead") +def atomic_max_float32( + ptr: Union[ir.Value, Pointer], + value: Float32, + *, + positive_only: bool = True, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Float32: + """Deprecated: use atomic_fmax instead.""" + return atomic_fmax( + ptr, value, sign_bit=False if positive_only else None, loc=loc, ip=ip + ) @dsl_user_op def atomic_cas( - ptr, + ptr: Union[ir.Value, Pointer], *, cmp: Numeric, val: Numeric, sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] = None, - scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + scope: Optional[Literal["gpu", "cta", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Numeric: """ Performs an atomic compare-and-swap (CAS) operation. @@ -2198,13 +2731,12 @@ def atomic_cas( :type val: Numeric :param sem: Memory semantic ("relaxed", "release", "acquire", "acq_rel") :type sem: Optional[Literal["relaxed", "release", "acquire", "acq_rel"]] - :param scope: Memory scope ("gpu", "cta", "cluster", "sys") - :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :param scope: Memory scope ("gpu", "cta", "sys") + :type scope: Optional[Literal["gpu", "cta", "sys"]] :return: Old value at memory location :rtype: Numeric """ from cutlass._mlir.dialects.nvvm import AtomicOpKind, MemOrderKind, MemScopeKind - from cutlass import CUDA_VERSION # Enhance enums and convert string literals to enum types MemOrderKind = _enhance_enum_with_str_mapping(MemOrderKind) @@ -2226,7 +2758,7 @@ def atomic_cas( val_ir = val.ir_value(loc=loc, ip=ip) # * NVVM call based on nvvm version - if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9: + if cutlass_dsl.target_version(exact_version="12.9"): result = nvvm.atomicrmw( cmp_type.mlir_type, op=AtomicOpKind.CAS, @@ -2238,23 +2770,12 @@ def atomic_cas( loc=loc, ip=ip, ) - elif CUDA_VERSION.major == 13 and CUDA_VERSION.minor == 1: - result = nvvm.atomicrmw( - op=AtomicOpKind.CAS, - ptr=ptr, - a=val_ir, - b=cmp_ir, - mem_order=sem, - syncscope=scope, - loc=loc, - ip=ip, - ) else: result = nvvm.atomicrmw( op=AtomicOpKind.CAS, ptr=ptr, - a=cmp_ir, - b=val_ir, + a=val_ir, + b=cmp_ir, mem_order=sem, syncscope=scope, loc=loc, @@ -2265,7 +2786,7 @@ def atomic_cas( @dsl_user_op def store( - ptr, + ptr: Union[ir.Value, Pointer], val: Union[Numeric, ir.Value], *, level1_eviction_priority: Optional[ @@ -2281,8 +2802,8 @@ def store( ss: Optional[Literal["cta", "cluster"]] = None, sem: Optional[Literal["relaxed", "release"]] = None, scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Store a value to a memory location. @@ -2358,7 +2879,7 @@ def store( @dsl_user_op def load( - ptr, + ptr: Union[ir.Value, Pointer], dtype: Union[type[Numeric], ir.VectorType], *, sem: Optional[Literal["relaxed", "acquire"]] = None, @@ -2375,8 +2896,8 @@ def load( cop: Optional[Literal["ca", "cg", "cs", "lu", "cv"]] = None, ss: Optional[Literal["cta", "cluster"]] = None, level_prefetch_size: Optional[Literal["size_64b", "size_128b", "size_256b"]] = None, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Numeric, ir.Value]: """ Load a value from a memory location. @@ -2464,11 +2985,167 @@ def load( if is_vector: return result else: + assert scalar_dtype is not None return scalar_dtype(result) @dsl_user_op -def cvt_f4e2m1_f16(src, *, loc=None, ip=None): +def red( + ptr: Union[ir.Value, Pointer], + val: Union[Numeric, ir.Value], + *, + op: Literal["add", "min", "max", "umin", "umax", "and", "or", "xor"], + dtype: Union[ + Literal[ + "b32", + "b64", + "u32", + "u64", + "s32", + "s64", + "f32", + "f64", + "f16", + "f16x2", + "bf16", + "bf16x2", + ], + type[Numeric], + ], + sem: Optional[Literal["relaxed", "release"]] = None, + scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """ + Perform an atomic reduction operation on a memory location. + + Atomically computes: ptr = ptr x val, where x is the reduction operation. + + :param ptr: Pointer to memory location (global or shared). Supports: + - ir.Value (LLVM pointer) + - cute.ptr (_Pointer instance) + :param val: Value to reduce with the memory location (scalar Numeric or vector ir.Value) + :type val: Union[Numeric, ir.Value] + :param op: Reduction operation string literal: + "add" : Addition + "min" : Minimum (signedness determined by dtype) + "max" : Maximum (signedness determined by dtype) + "umin" : Unsigned minimum (alias for "min", forces dtype to unsigned) + "umax" : Unsigned maximum (alias for "max", forces dtype to unsigned) + "and" : Bitwise AND + "or" : Bitwise OR + "xor" : Bitwise XOR + :type op: Literal["add", "min", "max", "umin", "umax", "and", "or", "xor"] + :param dtype: Data type. Supports string literals ("b32", "b64", "u32", "u64", "s32", "s64", + "f32", "f64", "f16", "f16x2", "bf16", "bf16x2") or cutlass types (Uint32, Uint64, + Int32, Int64, Float32, Float64, Float16, BFloat16) + :type dtype: Union[str, type[Numeric]] + :param sem: Memory ordering semantics string literal: + "relaxed" : Relaxed memory ordering + "release" : Release memory ordering + None : No memory ordering specified + :type sem: Optional[Literal["relaxed", "release"]] + :param scope: Memory scope string literal: + "gpu" : GPU scope + "cta" : CTA/block scope + "cluster" : Cluster scope + "sys" : System scope + None : No scope specified + :type scope: Optional[Literal["gpu", "cta", "cluster", "sys"]] + :return: None (operation modifies memory in-place) + :rtype: None + + .. note:: + This operation modifies memory in-place and returns None. + The old value is NOT returned (unlike atomic_add, atomic_max, etc.). + For operations that need the old value, use the atomic_* functions instead. + """ + from cutlass._mlir.dialects.nvvm import ( + ReductionOp, + ReductionType, + MemOrderKind, + MemScopeKind, + ) + + # Enhance enums and convert string literals to enum types + ReductionOp = _enhance_enum_with_str_mapping(ReductionOp) + ReductionType = _enhance_enum_with_str_mapping(ReductionType) + MemOrderKind = _enhance_enum_with_str_mapping(MemOrderKind) + MemScopeKind = _enhance_enum_with_str_mapping(MemScopeKind) + + # Handle "umax"/"umin" aliases: map to "max"/"min" and ensure unsigned dtype + _unsigned_op_aliases: dict[str, Literal["max", "min"]] = { + "umax": "max", + "umin": "min", + } + if op in _unsigned_op_aliases: + op = _unsigned_op_aliases[op] + # If dtype is a cutlass signed type, promote to its unsigned counterpart + _signed_to_unsigned: dict[type[Numeric], type[Numeric]] = { + Int32: Uint32, + Int64: Uint64, + } + if not isinstance(dtype, str) and dtype in _signed_to_unsigned: + dtype = _signed_to_unsigned[dtype] + # If dtype is a signed string literal, promote to unsigned + _signed_str_to_unsigned: dict[str, Literal["u32", "u64"]] = { + "s32": "u32", + "s64": "u64", + } + if isinstance(dtype, str) and dtype in _signed_str_to_unsigned: + dtype = _signed_str_to_unsigned[dtype] + + # Process dtype parameter: normalize to string literal and numeric type + dtype_str: str + if isinstance(dtype, str): + # dtype is already a string literal (e.g., "u32") + dtype_str = dtype + dtype_numeric = _reduction_type_str_to_cutlass_dtype(dtype_str) + else: + # dtype is a cutlass numeric type (e.g., cutlass.Uint32) + dtype_numeric = dtype + dtype_str = _cutlass_dtype_to_reduction_type_str(dtype_numeric) + + # Convert string literals to enum types + op_enum = ReductionOp.from_str(op) + dtype_enum = ReductionType.from_str(dtype_str) + sem = MemOrderKind.from_str(sem) + scope = MemScopeKind.from_str(scope) + + # Normalize pointer type to MLIR ir.Value + ptr = _normalize_ptr(ptr, loc=loc, ip=ip) + + # Determine if val is a vector type or scalar type + is_vector = isinstance(val, ir.Value) and isinstance(val.type, ir.VectorType) + + if is_vector: + val_ir = val + else: + if not isinstance(val, Numeric): + # Use dtype_numeric to convert the value to the specified type + val = dtype_numeric(val) + val_ir = val.ir_value(loc=loc, ip=ip) + + nvvm.red( + op=op_enum, + type_=dtype_enum, + a=ptr, + b=val_ir, + mem_order=sem, + mem_scope=scope, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def cvt_f4e2m1_f16( + src: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # 0 padding for upper 4 bits zero = arith.constant(src.type, 0, loc=loc, ip=ip) vec2 = vector.from_elements( @@ -2484,10 +3161,15 @@ def cvt_f4e2m1_f16(src, *, loc=None, ip=None): # Convert 2 float4e2m1 values to 2 float16 values @dsl_user_op -def cvt_f4e2m1x2_to_f16x2(src_vec2, *, loc=None, ip=None): +def cvt_f4e2m1x2_to_f16x2( + src_vec2: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # pack 2 float4e2m1 into 1 int8 value and fill upper bits with 0 src_i8 = llvm.bitcast(Int8.mlir_type, src_vec2, loc=loc, ip=ip) - src_i16 = llvm.zext(Int16.mlir_type, src_i8, loc=loc, ip=ip) + src_i16 = arith.extui(Int16.mlir_type, src_i8, loc=loc, ip=ip) rst_i32 = llvm.inline_asm( Int32.mlir_type, [src_i16], @@ -2505,7 +3187,12 @@ def cvt_f4e2m1x2_to_f16x2(src_vec2, *, loc=None, ip=None): # Convert 4 float4e2m1 values to 4 float16 values @dsl_user_op -def cvt_f4e2m1x4_to_f16x4(src_vec4, *, loc=None, ip=None): +def cvt_f4e2m1x4_to_f16x4( + src_vec4: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # pack 4 float4e2m1 into 1 int16 value src_i16 = llvm.bitcast(Int16.mlir_type, src_vec4, loc=loc, ip=ip) rst_i32x2 = llvm.inline_asm( @@ -2530,7 +3217,12 @@ def cvt_f4e2m1x4_to_f16x4(src_vec4, *, loc=None, ip=None): # Convert 8 float4e2m1 values to 8 float16 values @dsl_user_op -def cvt_f4e2m1x8_to_f16x8(src_vec8, *, loc=None, ip=None): +def cvt_f4e2m1x8_to_f16x8( + src_vec8: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: # pack 8 float4e2m1 into 1 int32 value and fill upper bits with 0 src_i32 = llvm.bitcast(Int32.mlir_type, src_vec8, loc=loc, ip=ip) rst_i32x4 = llvm.inline_asm( @@ -2559,40 +3251,1896 @@ def cvt_f4e2m1x8_to_f16x8(src_vec8, *, loc=None, ip=None): return vec_f16x8 + @dsl_user_op -def mapa(ptr, cta_rank_in_cluster=0, *, loc=None, ip=None): +def smid( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: """ - Map a pointer to distributed shared memory across cluster. + Returns the SM (Streaming Multiprocessor) ID of the current thread. - Portable wrapper that uses the appropriate NVVM API based on CUDA version: - - CUDA 13.1+: Uses nvvm.mapa with dsmem address space - - CUDA 12.9: Uses nvvm.mapa_shared_cluster + The SM ID is a unique identifier for the streaming multiprocessor executing + the current thread. Valid range is 0 to nsmid() - 1. - Args: - ptr: Pointer to shared memory (llvm_ptr attribute will be used) - cta_rank_in_cluster: CTA rank within the cluster (default 0) + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-smid - Returns: - Mapped LLVM pointer to shared memory + :return: SM ID of the current thread + :rtype: Int32 """ - if target_version(min_version="13.1"): - dsmem_ptr_ty = llvm.PointerType.get(7) # dsmem - smem_ptr_ty = llvm.PointerType.get(3) # smem + return Int32(nvvm.read_ptx_sreg_smid(T.i32(), loc=loc, ip=ip)) - llvm_ptr = nvvm.mapa( - dsmem_ptr_ty, - ptr.llvm_ptr, - Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - return llvm.addrspacecast(smem_ptr_ty, llvm_ptr, loc=loc, ip=ip) + +@dsl_user_op +def nsmid( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: + """ + Returns the number of SMs (Streaming Multiprocessors) on the device. + + This returns the total count of SMs available on the GPU, which defines + the valid range for smid() as [0, nsmid() - 1]. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-nsmid + + :return: Total number of SMs on the device + :rtype: Int32 + """ + return Int32(nvvm.read_ptx_sreg_nsmid(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def clock( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int32: + """ + Returns a 32-bit clock counter value. + + Reads the per-SM clock counter, which can be used for timing and profiling. + The counter wraps around on overflow. For extended range, use clock64(). + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-clock + + :return: 32-bit clock counter value + :rtype: Int32 + """ + return Int32(nvvm.read_ptx_sreg_clock(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def clock64( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Int64: + """ + Returns a 64-bit clock counter value. + + Reads the per-SM 64-bit clock counter, providing extended range compared + to the 32-bit clock(). Useful for timing longer operations without overflow. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-clock64 + + :return: 64-bit clock counter value + :rtype: Int64 + """ + return Int64(nvvm.read_ptx_sreg_clock64(T.i64(), loc=loc, ip=ip)) + + +@dsl_user_op +def match_sync( + mask: Union[int, Int32, Uint32, Int64, Uint64], + value: Union[int, Int32, Uint32, Int64, Uint64], + kind: Literal["any", "all"] = "any", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: + """ + Finds threads in a warp with matching values using warp-synchronous matching. + + Performs a broadcast and compare of the operand value across threads specified + by the mask. Returns a mask indicating which threads have matching values. + + - "any" mode: Returns mask of threads that have the same value as any other thread + - "all" mode: Returns mask of threads where all active threads have the same value + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-match-sync + + :param mask: Mask of participating threads (typically 0xFFFFFFFF for full warp) + :type mask: Union[int, Int32, Uint32] + :param value: Value to match across threads + :type value: Union[int, Int32, Uint32] + :param kind: Match mode - "any" or "all" + :type kind: Literal["any", "all"] + :return: Mask of threads with matching values + :rtype: Uint32 + """ + # Convert kind string to MatchSyncKind enum + if kind == "any": + kind_enum = nvvm.MatchSyncKind.any + elif kind == "all": + kind_enum = nvvm.MatchSyncKind.all else: - llvm_ptr = ptr.llvm_ptr - return nvvm.mapa_shared_cluster( - llvm_ptr.type, - llvm_ptr, - Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + raise ValueError(f"Invalid kind '{kind}', must be 'any' or 'all'") + + mask_ir = Int32(mask).ir_value(loc=loc, ip=ip) + if isinstance(value, (Int64, Uint64)): + value_ir = value.ir_value(loc=loc, ip=ip) + else: + value_ir = Int32(value).ir_value(loc=loc, ip=ip) + + if kind_enum == nvvm.MatchSyncKind.all: + result = nvvm.match_sync( + llvm.StructType.get_literal([T.i32(), Boolean.mlir_type]), + mask_ir, + value_ir, + kind_enum, loc=loc, ip=ip, ) + return Uint32(llvm.extractvalue(T.i32(), result, [0], loc=loc, ip=ip)) + else: + result = nvvm.match_sync( + T.i32(), + mask_ir, + value_ir, + kind_enum, + loc=loc, + ip=ip, + ) + return Uint32(result) + + +@dsl_user_op +def clz( + value: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: + """ + Counts the number of leading zero bits (count leading zeros). + + https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-clz + + Returns the number of consecutive zero bits starting from the most significant bit. + For a 32-bit value, returns a value in range [0, 32]. For 64-bit, range is [0, 64]. + + :param value: Input value (32-bit or 64-bit integer) + :type value: Union[Int32, Uint32, Int64, Uint64] + :return: Count of leading zero bits (same bit width as input) + :rtype: Union[Int32, Int64] + """ + + # Determine instruction and result type based on input type + if isinstance(value, (Int32, Uint32)): + asm_str = "clz.b32 $0, $1;" + constraints = "=r,r" + elif isinstance(value, (Int64, Uint64)): # Int64 or Uint64 + asm_str = "clz.b64 $0, $1;" + constraints = "=r,l" + else: + raise TypeError(f"Invalid value type for clz: {type(value)}") + + value_ir = value.ir_value(loc=loc, ip=ip) + + result = llvm.inline_asm( + T.i32(), + [value_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def bfind( + value: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: + """ + Finds the bit position of the most significant non-sign bit. + + For unsigned, finds the most significant 1 bit. For signed, finds the most + significant bit that differs from the sign bit. Returns 0xFFFFFFFF if not found. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-bfind + + :param value: Input value (32-bit or 64-bit integer) + :type value: Union[Int32, Uint32, Int64, Uint64] + :return: Bit position (0-31 or 0-63) or 0xFFFFFFFF if not found + :rtype: Union[Int32, Int64] + """ + + if not isinstance(value, (Int32, Uint32, Int64, Uint64)): + raise TypeError(f"Invalid value type for bfind: {type(value)}") + + value_ir = value.ir_value(loc=loc, ip=ip) + + if isinstance(value, Int32): + asm_str = "bfind.s32 $0, $1;" + constraints = "=r,r" + result_type = T.i32() + return_type = Uint32 + elif isinstance(value, Int64): + asm_str = "bfind.s64 $0, $1;" + constraints = "=r,l" + result_type = T.i32() + return_type = Uint32 + elif isinstance(value, Uint32): + asm_str = "bfind.u32 $0, $1;" + constraints = "=r,r" + result_type = T.i32() + return_type = Int32 # type: ignore[assignment] + elif isinstance(value, Uint64): # Uint64 + asm_str = "bfind.u64 $0, $1;" + constraints = "=r,l" + result_type = T.i32() + return_type = Uint32 + + result = llvm.inline_asm( + result_type, + [value_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return return_type(result) + + +@overload +def brev( + value: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def brev( + value: Uint32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@overload +def brev( + value: Int64, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int64: ... + + +@overload +def brev( + value: Uint64, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint64: ... + + +@dsl_user_op +def brev( + value: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Reverses the bits in the value. + + Returns the input value with bits reversed. Bit 0 becomes bit 31 (or 63), + bit 1 becomes bit 30 (or 62), etc. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-brev + + :param value: Input value (32-bit or 64-bit integer) + :type value: Union[Int32, Uint32, Int64, Uint64] + :return: Bit-reversed value (same type as input) + :rtype: Union[Int32, Uint32, Int64, Uint64] + + """ + if not isinstance(value, (Int32, Uint32, Int64, Uint64)): + raise TypeError(f"Invalid value type for brev: {type(value)}") + + value_ir = value.ir_value(loc=loc, ip=ip) + value_type = type(value) + + # Determine instruction based on input type + if isinstance(value, (Int32, Uint32)): + asm_str = "brev.b32 $0, $1;" + constraints = "=r,r" + result_type = T.i32() + else: # Int64 or Uint64 + asm_str = "brev.b64 $0, $1;" + constraints = "=l,l" + result_type = T.i64() + + result = llvm.inline_asm( + result_type, + [value_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@overload +def bfe( + value: Int32, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def bfe( + value: Uint32, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@overload +def bfe( + value: Int64, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int64: ... + + +@overload +def bfe( + value: Uint64, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint64: ... + + +@dsl_user_op +def bfe( + value: Union[Int32, Uint32, Int64, Uint64], + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + + Extract bit field from value and place the zero or sign-extended result. + Source start gives the bit field starting bit position, and source length gives the + bit field length in bits. + + The result and value must have the same type. + + Start and length are 32 bits, but are restricted to the 8-bit value range 0..255. + + The sign bit of the extracted field is defined as: + + Uint32 or Uint64 value: zero + + Int32 or Int64 value: + Most significant bit (msb) of input value if the extracted field extends beyond the + msb of the input value, otherwise if the bit field length is zero, the result is zero. + + The result is padded with the sign bit of the extracted field. + If the start position is beyond the msb of the input, the result is filled with the + replicated sign bit of the extracted field. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-bfe + + :param value: Source value to extract from + :type value: Union[Int32, Uint32] + :param start: Starting bit position (0-31) + :type start: Union[int, Int32, Uint32] + :param length: Number of bits to extract (0-32) + :type length: Union[int, Int32, Uint32] + :return: Extracted bit field (right-justified) + :rtype: Union[Int32, Uint32] + """ + + if not isinstance(value, (Int32, Uint32, Int64, Uint64)): + raise TypeError(f"Invalid value type for bfe: {type(value)}") + + value_ir = value.ir_value(loc=loc, ip=ip) + start_ir = Int32(start).ir_value(loc=loc, ip=ip) + length_ir = Int32(length).ir_value(loc=loc, ip=ip) + + return_type: Type[Integer] + if isinstance(value, (Int32, Uint32)): + result_type = T.i32() + constraints = "=r,r,r,r" + if isinstance(value, Int32): + asm_str = "bfe.s32 $0, $1, $2, $3;" + return_type = Int32 + else: + asm_str = "bfe.u32 $0, $1, $2, $3;" + return_type = Uint32 + else: # Int64 or Uint64 + result_type = T.i64() + constraints = "=l,l,r,r" + if isinstance(value, Int64): + asm_str = "bfe.s64 $0, $1, $2, $3;" + return_type = Int64 + else: + asm_str = "bfe.u64 $0, $1, $2, $3;" + return_type = Uint64 + + result = llvm.inline_asm( + result_type, + [value_ir, start_ir, length_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return return_type(result) + + +@overload +def bfi( + replacement: Int32, + value: Int32, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def bfi( + replacement: Uint32, + value: Uint32, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@overload +def bfi( + replacement: Int64, + value: Int64, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int64: ... + + +@overload +def bfi( + replacement: Uint64, + value: Uint64, + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint64: ... + + +@dsl_user_op +def bfi( + replacement: Union[Int32, Uint32, Int64, Uint64], + value: Union[Int32, Uint32, Int64, Uint64], + start: Union[int, Int32, Uint32], + length: Union[int, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Inserts a bit field into a value (bit field insert). + + Replaces a contiguous sequence of bits in the value with bits from the + replacement operand. Bits outside the specified field are preserved from + the original value. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-bfi + + :param value: Original value to insert into + :type value: Union[Int32, Uint32] + :param replacement: Value containing bits to insert + :type replacement: Union[Int32, Uint32] + :param start: Starting bit position (0-31) + :type start: Union[int, Int32, Uint32] + :param length: Number of bits to insert (0-32) + :type length: Union[int, Int32, Uint32] + :return: Value with bit field replaced + :rtype: Union[Int32, Uint32] + + **Architecture**: SM 20+ + + **Example**:: + + # Insert 0xF into bits [11:8] of 0x12345678 + result = bfi(Uint32(0x12345678), Uint32(0xF), start=8, length=4) + # Returns 0x12345F78 + """ + if type(value) is not type(replacement): + raise TypeError( + "bfi requires value and replacement to have the same integer type" + ) + + if isinstance(value, (Int32, Uint32)): + result_type = T.i32() + asm_str = "bfi.b32 $0, $1, $2, $3, $4;" + constraints = "=r,r,r,r,r" + elif isinstance(value, (Int64, Uint64)): + result_type = T.i64() + asm_str = "bfi.b64 $0, $1, $2, $3, $4;" + constraints = "=l,l,l,r,r" + else: + raise TypeError(f"Invalid value type for bfi: {type(value)}") + + replacement_ir = replacement.ir_value(loc=loc, ip=ip) + value_ir = value.ir_value(loc=loc, ip=ip) + start_ir = Int32(start).ir_value(loc=loc, ip=ip) + length_ir = Int32(length).ir_value(loc=loc, ip=ip) + + result = llvm.inline_asm( + result_type, + [replacement_ir, value_ir, start_ir, length_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return type(value)(result) + + +@overload +def mul_hi( + a: Int32, + b: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def mul_hi( + a: Uint32, + b: Uint32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@overload +def mul_hi( + a: Int64, + b: Int64, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int64: ... + + +@overload +def mul_hi( + a: Uint64, + b: Uint64, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint64: ... + + +@dsl_user_op +def mul_hi( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Multiplies two values and returns the high-order bits of the result. + + Performs a full-width multiplication and returns the upper half of the result. + For 32-bit inputs, returns bits [63:32]. For 64-bit inputs, returns bits [127:64]. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-mul-hi + + :param a: First multiplicand + :type a: Union[Int32, Uint32, Int64, Uint64] + :param b: Second multiplicand + :type b: Union[Int32, Uint32, Int64, Uint64] + :return: High-order bits of the product (same type as inputs) + :rtype: Union[Int32, Uint32, Int64, Uint64] + """ + + if type(a) is not type(b) or not isinstance(a, (Int32, Uint32, Int64, Uint64)): + raise TypeError( + "Invalid value types for mul_hi: a and b must be the same type " + f"(both Int32, Uint32, Int64, or Uint64), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + value_type = type(a) + + # Determine instruction based on type + if isinstance(a, Int32): + asm_str = "mul.hi.s32 $0, $1, $2;" + constraints = "=r,r,r" + result_type = T.i32() + elif isinstance(a, Uint32): + asm_str = "mul.hi.u32 $0, $1, $2;" + constraints = "=r,r,r" + result_type = T.i32() + elif isinstance(a, Int64): + asm_str = "mul.hi.s64 $0, $1, $2;" + constraints = "=l,l,l" + result_type = T.i64() + else: # Uint64 + asm_str = "mul.hi.u64 $0, $1, $2;" + constraints = "=l,l,l" + result_type = T.i64() + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@overload +def mul_wide( + a: Int16, + b: Int16, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def mul_wide( + a: Uint16, + b: Uint16, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@overload +def mul_wide( + a: Int32, + b: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int64: ... + + +@overload +def mul_wide( + a: Uint32, + b: Uint32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint64: ... + + +@dsl_user_op +def mul_wide( + a: Union[Int16, Uint16, Int32, Uint32], + b: Union[Int16, Uint16, Int32, Uint32], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Multiplies two narrow values and returns a wide result. + + Performs multiplication with automatic widening of the result type. + 16-bit inputs produce 32-bit result. 32-bit inputs produce 64-bit result. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-mul + + :param a: First multiplicand (16-bit or 32-bit) + :type a: Union[Int16, Uint16, Int32, Uint32] + :param b: Second multiplicand (must match signedness of a) + :type b: Union[Int16, Uint16, Int32, Uint32] + :return: Wide product (32-bit for 16-bit inputs, 64-bit for 32-bit inputs) + :rtype: Union[Int32, Uint32, Int64, Uint64] + """ + if type(a) is not type(b) or not isinstance(a, (Int16, Uint16, Int32, Uint32)): + raise TypeError( + "Invalid value types for mul_wide: a and b must be the same type " + f"(both Int16, Uint16, Int32, or Uint32), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + + # Determine instruction and return type based on input type + return_type: Type[Integer] + if isinstance(a, Int16): + asm_str = "mul.wide.s16 $0, $1, $2;" + constraints = "=r,h,h" + result_type = T.i32() + return_type = Int32 + elif isinstance(a, Uint16): + asm_str = "mul.wide.u16 $0, $1, $2;" + constraints = "=r,h,h" + result_type = T.i32() + return_type = Uint32 + elif isinstance(a, Int32): + asm_str = "mul.wide.s32 $0, $1, $2;" + constraints = "=l,r,r" + result_type = T.i64() + return_type = Int64 + else: # Uint32 + asm_str = "mul.wide.u32 $0, $1, $2;" + constraints = "=l,r,r" + result_type = T.i64() + return_type = Uint64 + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir], + asm_str, + constraints, + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return return_type(result) + + +@overload +def mul24( + a: Int32, + b: Int32, + hi: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def mul24( + a: Uint32, + b: Uint32, + hi: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@dsl_user_op +def mul24( + a: Union[Int32, Uint32], + b: Union[Int32, Uint32], + hi: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32]: + """ + Fast 24-bit integer multiplication. + + Multiplies the low 24 bits of each operand. Bits [31:24] are ignored. + Result can be either low 32 bits (hi=False) or high 32 bits (hi=True). + + t = a * b; + d = t<47..16> # for .hi variant (if hi is True) + d = t<31..0> # for .lo variant (if hi is False) + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-mul24 + + :param a: First operand (only low 24 bits used) + :type a: Union[Int32, Uint32] + :param b: Second operand (only low 24 bits used) + :type b: Union[Int32, Uint32] + :param hi: If True, return high 32 bits; if False, return low 32 bits + :type hi: bool + :return: Product of low 24 bits + :rtype: Union[Int32, Uint32] + """ + + if type(a) is not type(b) or not isinstance(a, (Int32, Uint32)): + raise TypeError( + "Invalid value types for mul24: a and b must be the same type " + f"(both Int32, Uint32), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + value_type = type(a) + + # Build instruction string + lohi = "hi" if hi else "lo" + if isinstance(a, Int32): + asm_str = f"mul24.{lohi}.s32 $0, $1, $2;" + else: # Uint32 + asm_str = f"mul24.{lohi}.u32 $0, $1, $2;" + + result = llvm.inline_asm( + T.i32(), + [a_ir, b_ir], + asm_str, + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@overload +def mad24( + a: Int32, + b: Int32, + c: Int32, + hi: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def mad24( + a: Uint32, + b: Uint32, + c: Uint32, + hi: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@dsl_user_op +def mad24( + a: Union[Int32, Uint32], + b: Union[Int32, Uint32], + c: Union[Int32, Uint32], + hi: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32]: + """ + Fast 24-bit integer multiply-add. + + Computes (a * b) + c using only the low 24 bits of a and b. + Result can be either low 32 bits (hi=False) or high 32 bits (hi=True). + + t = a * b + d = t<47..16> + c # for .hi variant (if hi is True) + d = t<31..0> + c # for .lo variant (if hi is False) + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-mad24 + + :param a: First multiplicand (only low 24 bits used) + :type a: Union[Int32, Uint32] + :param b: Second multiplicand (only low 24 bits used) + :type b: Union[Int32, Uint32] + :param c: Addend (all 32 bits used) + :type c: Union[Int32, Uint32] + :param hi: If True, return high 32 bits; if False, return low 32 bits + :type hi: bool + :return: (a * b) + c + :rtype: Union[Int32, Uint32] + """ + if not isinstance(a, (Int32, Uint32)): + raise TypeError("mad24 requires a to be an Int32 or Uint32") + + if type(a) is not type(b) or type(a) is not type(c): + raise TypeError("mad24 requires a, b, and c to have the same integer type") + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + c_ir = c.ir_value(loc=loc, ip=ip) + value_type = type(a) + + # Build instruction string + lohi = "hi" if hi else "lo" + if isinstance(a, Int32): + asm_str = f"mad24.{lohi}.s32 $0, $1, $2, $3;" + else: # Uint32 + asm_str = f"mad24.{lohi}.u32 $0, $1, $2, $3;" + + result = llvm.inline_asm( + T.i32(), + [a_ir, b_ir, c_ir], + asm_str, + "=r,r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def add_cc( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Addition with carry-out (sets carry flag). + + Performs addition and sets the carry flag for use by subsequent addc() operations. + This is the first operation in a multi-precision addition chain. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#extended-precision-arithmetic-instructions-add-cc + + :param a: First operand + :type a: Union[Int32, Uint32] + :param b: Second operand + :type b: Union[Int32, Uint32] + :return: Sum (a + b) + :rtype: Union[Int32, Uint32] + """ + if type(a) is not type(b) or not isinstance(a, (Int32, Uint32, Int64, Uint64)): + raise TypeError( + "Invalid value types for add_cc: a and b must be the same type " + f"(both Int32, Uint32, Int64, or Uint64), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + value_type = type(a) + + if isinstance(a, (Int32, Uint32)): + result_type = T.i32() + asm_str = "add.cc.u32 $0, $1, $2;" + constraints = "=r,r,r" + else: # Int64 or Uint64 + result_type = T.i64() + asm_str = "add.cc.u64 $0, $1, $2;" + constraints = "=l,l,l" + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir], + asm_str, + constraints, + has_side_effects=True, # Modifies carry flag + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def addc( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Addition with carry-in (reads carry flag). + + Performs addition including the carry flag set by add_cc() or previous addc(). + This continues a multi-precision addition chain. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#extended-precision-arithmetic-instructions-addc + + :param a: First operand + :type a: Union[Int32, Uint32] + :param b: Second operand + :type b: Union[Int32, Uint32] + :return: Sum (a + b + carry_flag) + :rtype: Union[Int32, Uint32] + """ + if type(a) is not type(b) or not isinstance(a, (Int32, Uint32, Int64, Uint64)): + raise TypeError( + "Invalid value types for addc: a and b must be the same type " + f"(both Int32, Uint32, Int64, or Uint64), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + value_type = type(a) + + if isinstance(a, (Int32, Uint32)): + result_type = T.i32() + asm_str = "addc.u32 $0, $1, $2;" + constraints = "=r,r,r" + else: # Int64 or Uint64 + result_type = T.i64() + asm_str = "addc.u64 $0, $1, $2;" + constraints = "=l,l,l" + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir], + asm_str, + constraints, + has_side_effects=True, # Reads and may modify carry flag + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def sub_cc( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Subtraction with carry-out (sets carry/borrow flag). + + Performs subtraction and sets the carry flag for use by subsequent subc() operations. + This is the first operation in a multi-precision subtraction chain. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#extended-precision-arithmetic-instructions-sub-cc + + :param a: Value to subtract from + :type a: Union[Int32, Uint32] + :param b: Value to subtract + :type b: Union[Int32, Uint32] + :return: Difference (a - b) + :rtype: Union[Int32, Uint32] + """ + if type(a) is not type(b) or not isinstance(a, (Int32, Uint32, Int64, Uint64)): + raise TypeError( + "Invalid value types for sub_cc: a and b must be the same type " + f"(both Int32, Uint32, Int64, or Uint64), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + value_type = type(a) + + if isinstance(a, (Int32, Uint32)): + result_type = T.i32() + asm_str = "sub.cc.u32 $0, $1, $2;" + constraints = "=r,r,r" + else: # Int64 or Uint64 + result_type = T.i64() + asm_str = "sub.cc.u64 $0, $1, $2;" + constraints = "=l,l,l" + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir], + asm_str, + constraints, + has_side_effects=True, # Modifies carry flag + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def subc( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Subtraction with carry-in (reads carry/borrow flag). + + Performs subtraction including the carry flag set by sub_cc() or previous subc(). + This continues a multi-precision subtraction chain. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#extended-precision-arithmetic-instructions-subc + + :param a: Value to subtract from + :type a: Union[Int32, Uint32] + :param b: Value to subtract + :type b: Union[Int32, Uint32] + :return: Difference (a - b - carry_flag) + :rtype: Union[Int32, Uint32] + """ + if type(a) is not type(b) or not isinstance(a, (Int32, Uint32, Int64, Uint64)): + raise TypeError( + "Invalid value types for subc: a and b must be the same type " + f"(both Int32, Uint32, Int64, or Uint64), got {type(a)} and {type(b)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + value_type = type(a) + + if isinstance(a, (Int32, Uint32)): + result_type = T.i32() + asm_str = "subc.u32 $0, $1, $2;" + constraints = "=r,r,r" + else: # Int64 or Uint64 + result_type = T.i64() + asm_str = "subc.u64 $0, $1, $2;" + constraints = "=l,l,l" + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir], + asm_str, + constraints, + has_side_effects=True, # Reads and may modify carry flag + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def mad_cc( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + c: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Multiply-add with carry-out (sets carry flag). + + Performs (a * b) + c and sets the carry flag for use by subsequent madc() operations. + This starts a multi-precision multiply-add chain. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#extended-precision-arithmetic-instructions-mad-cc + + :param a: First multiplicand + :type a: Union[Int32, Uint32] + :param b: Second multiplicand + :type b: Union[Int32, Uint32] + :param c: Addend + :type c: Union[Int32, Uint32] + :return: Low 32 bits of (a * b) + c + :rtype: Union[Int32, Uint32] + """ + if ( + not isinstance(a, (Int32, Uint32, Int64, Uint64)) + or not isinstance(b, (Int32, Uint32, Int64, Uint64)) + or not isinstance(c, (Int32, Uint32, Int64, Uint64)) + ): + raise TypeError( + "mad_cc requires Int32/Uint32/Int64/Uint64 operands for a, b, and c" + ) + + if type(a) is not type(b) or type(a) is not type(c): + raise TypeError("mad_cc requires a, b, and c to have the same integer type") + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + c_ir = c.ir_value(loc=loc, ip=ip) + value_type = type(a) + + # Use .lo variant for low half of the product. + if isinstance(a, Int32): + asm_str = "mad.lo.cc.s32 $0, $1, $2, $3;" + constraints = "=r,r,r,r" + result_type = T.i32() + elif isinstance(a, Uint32): + asm_str = "mad.lo.cc.u32 $0, $1, $2, $3;" + constraints = "=r,r,r,r" + result_type = T.i32() + elif isinstance(a, Int64): + asm_str = "mad.lo.cc.s64 $0, $1, $2, $3;" + constraints = "=l,l,l,l" + result_type = T.i64() + else: # Uint64 + asm_str = "mad.lo.cc.u64 $0, $1, $2, $3;" + constraints = "=l,l,l,l" + result_type = T.i64() + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir, c_ir], + asm_str, + constraints, + has_side_effects=True, # Modifies carry flag + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def madc( + a: Union[Int32, Uint32, Int64, Uint64], + b: Union[Int32, Uint32, Int64, Uint64], + c: Union[Int32, Uint32, Int64, Uint64], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32, Int64, Uint64]: + """ + Multiply-add with carry-in (reads carry flag). + + Performs (a * b) + c + carry_flag. This continues a multi-precision multiply-add chain. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#extended-precision-arithmetic-instructions-madc + :param a: First multiplicand + :type a: Union[Int32, Uint32] + :param b: Second multiplicand + :type b: Union[Int32, Uint32] + :param c: Addend + :type c: Union[Int32, Uint32] + :return: Low 32 bits of (a * b) + c + carry_flag + :rtype: Union[Int32, Uint32] + """ + if ( + type(a) is not type(b) + or type(a) is not type(c) + or not isinstance(a, (Int32, Uint32, Int64, Uint64)) + ): + raise TypeError( + "Invalid value types for madc: a, b, and c must be the same type " + f"(both Int32, Uint32, Int64, or Uint64), got {type(a)}, {type(b)}, and {type(c)}" + ) + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + c_ir = c.ir_value(loc=loc, ip=ip) + value_type = type(a) + + # Use .lo variant for low half of the product. + if isinstance(a, Int32): + asm_str = "madc.lo.s32 $0, $1, $2, $3;" + constraints = "=r,r,r,r" + result_type = T.i32() + elif isinstance(a, Uint32): + asm_str = "madc.lo.u32 $0, $1, $2, $3;" + constraints = "=r,r,r,r" + result_type = T.i32() + elif isinstance(a, Int64): + asm_str = "madc.lo.s64 $0, $1, $2, $3;" + constraints = "=l,l,l,l" + result_type = T.i64() + else: # Uint64 + asm_str = "madc.lo.u64 $0, $1, $2, $3;" + constraints = "=l,l,l,l" + result_type = T.i64() + + result = llvm.inline_asm( + result_type, + [a_ir, b_ir, c_ir], + asm_str, + constraints, + has_side_effects=True, # Reads and may modify carry flag + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def activemask( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Uint32: + """ + Returns the mask of currently active threads in the warp. + + Returns a 32-bit mask where bit N is set if thread N in the warp is active + (not exited or diverged away). This reflects the current execution state. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-activemask + + :return: Mask of active threads in warp + :rtype: Uint32 + """ + result = llvm.inline_asm( + T.i32(), + [], + "activemask.b32 $0;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def lanemask_lt( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Uint32: + """ + Returns mask of lanes with ID less than current lane. + + Returns a 32-bit mask where bit N is set if N < current_lane_id. + For lane 0, returns 0x00000000. For lane 31, returns 0x7FFFFFFF. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-lanemask-lt + + :return: Mask of lanes with index < current lane + :rtype: Uint32 + + """ + result = llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %lanemask_lt;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def lanemask_le( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Uint32: + """ + Returns mask of lanes with ID less than or equal to current lane. + + Returns a 32-bit mask where bit N is set if N <= current_lane_id. + For lane 0, returns 0x00000001. For lane 31, returns 0xFFFFFFFF. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-lanemask-le + + :return: Mask of lanes with index <= current lane + :rtype: Uint32 + + """ + result = llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %lanemask_le;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def lanemask_eq( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Uint32: + """ + Returns mask with only the current lane's bit set. + + Returns a 32-bit mask where only bit current_lane_id is set. + Equivalent to (1 << lane_idx()). + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-lanemask-eq + + :return: Mask with only current lane bit set + :rtype: Uint32 + + """ + result = llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %lanemask_eq;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def lanemask_ge( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Uint32: + """ + Returns mask of lanes with ID greater than or equal to current lane. + + Returns a 32-bit mask where bit N is set if N >= current_lane_id. + For lane 0, returns 0xFFFFFFFF. For lane 31, returns 0x80000000. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-lanemask-ge + + :return: Mask of lanes with index >= current lane + :rtype: Uint32 + + """ + result = llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %lanemask_ge;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def lanemask_gt( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> Uint32: + """ + Returns mask of lanes with ID greater than current lane. + + Returns a 32-bit mask where bit N is set if N > current_lane_id. + For lane 0, returns 0xFFFFFFFE. For lane 31, returns 0x00000000. + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#special-registers-lanemask-gt + + :return: Mask of lanes with index > current lane + :rtype: Uint32 + + """ + result = llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %lanemask_gt;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Uint32(result) + + +@dsl_user_op +def add_sat_int( + a: Int32, + b: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: + """ + Saturating signed 32-bit addition. + + Performs addition with saturation. If the result overflows, it saturates to + INT32_MAX (0x7FFFFFFF). If it underflows, saturates to INT32_MIN (0x80000000). + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add + + :param a: First operand + :type a: Int32 + :param b: Second operand + :type b: Int32 + :return: Saturated sum + :rtype: Int32 + """ + if not isinstance(a, Int32) or not isinstance(b, Int32): + raise TypeError("add_sat expects Int32 operands") + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + + result = llvm.inline_asm( + T.i32(), + [a_ir, b_ir], + "add.sat.s32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Int32(result) + + +@dsl_user_op +def sub_sat_int( + a: Int32, + b: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: + """ + Saturating signed 32-bit subtraction. + + Performs subtraction with saturation. If the result overflows, it saturates to + INT32_MAX (0x7FFFFFFF). If it underflows, saturates to INT32_MIN (0x80000000). + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-sub + + :param a: Minuend + :type a: Int32 + :param b: Subtrahend + :type b: Int32 + :return: Saturated difference + :rtype: Int32 + """ + if not isinstance(a, Int32) or not isinstance(b, Int32): + raise TypeError("sub_sat expects Int32 operands") + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + + result = llvm.inline_asm( + T.i32(), + [a_ir, b_ir], + "sub.sat.s32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return Int32(result) + + +@overload +def lop3( + a: Int32, + b: Int32, + c: Int32, + lut: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def lop3( + a: Uint32, + b: Uint32, + c: Uint32, + lut: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@dsl_user_op +def lop3( + a: Union[Int32, Uint32], + b: Union[Int32, Uint32], + c: Union[Int32, Uint32], + lut: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32]: + """ + Three-input logic operation with lookup table. + + Performs an arbitrary 3-input boolean function defined by an 8-bit lookup table. + Each bit of the LUT corresponds to one combination of input bits (a, b, c). + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-lop3 + + :param a: First input + :type a: Union[Int32, Uint32] + :param b: Second input + :type b: Union[Int32, Uint32] + :param c: Third input + :type c: Union[Int32, Uint32] + :param lut: 8-bit lookup table defining the boolean function + :type lut: Union[int, Int32, Uint32] + :return: Result of the 3-input logic operation + :rtype: Union[Int32, Uint32] + + """ + if not isinstance(a, (Int32, Uint32)): + raise TypeError("lop3 expects Int32/Uint32 operands") + if type(a) is not type(b) or type(a) is not type(c): + raise TypeError("lop3 requires a, b, and c to have the same integer type") + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + c_ir = c.ir_value(loc=loc, ip=ip) + value_type = type(a) + + # LUT must be a constant + if not isinstance(lut, int): + raise TypeError("lut parameter must be an integer constant") + + lut = lut & 0xFF + + result = llvm.inline_asm( + T.i32(), + [a_ir, b_ir, c_ir], + f"lop3.b32 $0, $1, $2, $3, {lut};", + "=r,r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@overload +def shf( + a: Int32, + b: Int32, + shift: Union[int, Int32, Uint32], + kind: Literal["l", "r", "clamp"] = "l", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: ... + + +@overload +def shf( + a: Uint32, + b: Uint32, + shift: Union[int, Int32, Uint32], + kind: Literal["l", "r", "clamp"] = "l", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Uint32: ... + + +@dsl_user_op +def shf( + a: Union[Int32, Uint32], + b: Union[Int32, Uint32], + shift: Union[int, Int32, Uint32], + kind: Literal["l", "r", "clamp_left", "clamp_right"] = "l", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Int32, Uint32]: + """ + Funnel shift operation. + + Concatenates two 32-bit values into a 64-bit value and shifts/extracts a 32-bit result. + + - "l" (left): Shift left, extract high 32 bits + - "r" (right): Shift right, extract low 32 bits + - "clamp_left": Clamp shift left amount to [0, 32] + - "clamp_right": Clamp shift right amount to [0, 32] + + See https://docs.nvidia.com/cuda/parallel-thread-execution/#logic-and-shift-instructions-shf + + :param a: First 32-bit value (high part of concatenation) + :type a: Union[Int32, Uint32] + :param b: Second 32-bit value (low part of concatenation) + :type b: Union[Int32, Uint32] + :param shift: Shift amount + :type shift: Union[int, Int32, Uint32] + :param kind: Shift direction - "l" (left), "r" (right), or "clamp" + :type kind: Literal["l", "r", "clamp"] + :return: 32-bit result after funnel shift + :rtype: Union[Int32, Uint32] + """ + if not isinstance(a, (Int32, Uint32)) or not isinstance(b, (Int32, Uint32)): + raise TypeError("shf expects Int32/Uint32 inputs for a and b") + if type(a) is not type(b): + raise TypeError("shf requires a and b to have the same integer type") + + a_ir = a.ir_value(loc=loc, ip=ip) + b_ir = b.ir_value(loc=loc, ip=ip) + shift_ir = Int32(shift).ir_value(loc=loc, ip=ip) + value_type = type(a) + + # Build instruction string + if kind == "l": + direction = "l" + wrap_mode = "wrap" + elif kind == "r": + direction = "r" + wrap_mode = "wrap" + elif kind == "clamp_left": + direction = "l" + wrap_mode = "clamp" + elif kind == "clamp_right": + direction = "r" + wrap_mode = "clamp" + else: + raise ValueError( + f"Invalid kind '{kind}', must be 'l', 'r', 'clamp_left', or 'clamp_right'" + ) + + asm_str = f"shf.{direction}.{wrap_mode}.b32 $0, $1, $2, $3;" + + result = llvm.inline_asm( + T.i32(), + [a_ir, b_ir, shift_ir], + asm_str, + "=r,r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + return value_type(result) + + +@dsl_user_op +def prefetch( + addr: Any, + *, + cache_level: Optional[Literal["L1", "L2"]] = None, + evict_priority: Optional[Any] = None, + predicate: Optional[Any] = None, + tensormap: Optional[bool] = None, + uniform: Optional[bool] = None, + in_param_space: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """ + Prefetch data or TMA descriptor to cache. + + :param addr: LLVM pointer to prefetch. + :param cache_level: Prefetch cache level string ("L1", "L2"). Mutually exclusive with tensormap. + :param evict_priority: Cache eviction priority. + :param predicate: Optional predicate for conditional execution. + :param tensormap: If True, prefetch a tensormap descriptor. + :param uniform: If True, use uniform prefetch. + :param in_param_space: If True, address is in parameter space. + """ + if cache_level is not None and tensormap is not None: + raise ValueError("prefetch: cache_level and tensormap are mutually exclusive") + + # Default to L1 when neither cache_level nor tensormap is specified + if cache_level is None and tensormap is None: + cache_level = "L1" + + if cache_level is not None: + PrefetchCacheLevel = _enhance_enum_with_str_mapping(nvvm.PrefetchCacheLevel) + cache_level = PrefetchCacheLevel.from_str(cache_level) + + if cutlass_dsl.target_version(min_version="13.2"): + nvvm.prefetch( + addr, + cache_level=cache_level, + evict_priority=evict_priority, + predicate=predicate, + tensormap=tensormap, + uniform=uniform, + in_param_space=in_param_space, + loc=loc, + ip=ip, + ) + else: + # Fallback: inline PTX for builds without nvvm.prefetch op + if tensormap: + ptr_as_i64 = llvm.ptrtoint(T.i64(), addr, loc=loc, ip=ip) + llvm.inline_asm( + None, + [ptr_as_i64], + "prefetch.tensormap [$0];", + "l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + else: + level = "L1" + if cache_level is not None: + level = str(cache_level) + ptr_as_i64 = llvm.ptrtoint(T.i64(), addr, loc=loc, ip=ip) + llvm.inline_asm( + None, + [ptr_as_i64], + f"prefetch.global.{level} [$0];", + "l", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) diff --git a/python/CuTeDSL/cutlass/cute/arch/smem.py b/python/CuTeDSL/cutlass/cute/arch/smem.py index ea5886a34..77ba50e2d 100644 --- a/python/CuTeDSL/cutlass/cute/arch/smem.py +++ b/python/CuTeDSL/cutlass/cute/arch/smem.py @@ -16,8 +16,9 @@ from cutlass.cutlass_dsl import T, dsl_user_op import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir +from cutlass._mlir.dialects import nvvm, llvm -from ..typing import Pointer, Numeric, NumericMeta, Layout +from ..typing import Int, Int32, Pointer, Numeric, NumericMeta @dsl_user_op @@ -26,8 +27,8 @@ def alloc_smem( size_in_elems: int, alignment: Optional[int] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Pointer: """ Statically allocates SMEM. @@ -66,8 +67,8 @@ def get_dyn_smem( element_type: Type[Numeric], alignment: Optional[int] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Pointer: """ Retrieves a pointer to a dynamic SMEM allocation. @@ -97,7 +98,9 @@ def get_dyn_smem( @dsl_user_op -def get_dyn_smem_size(*, loc=None, ip=None) -> int: +def get_dyn_smem_size( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> int: """ Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time. This can be used for bounds checking during shared memory allocation. @@ -106,3 +109,38 @@ def get_dyn_smem_size(*, loc=None, ip=None) -> int: :rtype: int """ return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip) + + +@dsl_user_op +def map_dsmem_ptr( + smem_ptr: Pointer, + cta_rank_in_cluster: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Pointer: + """ + Maps a shared memory pointer to a remote CTA's distributed shared memory. + + :param smem_ptr: A pointer in SMEM + :type smem_ptr: Pointer + :param cta_rank_in_cluster: The CTA in cluster to map to + :type cta_rank_in_cluster: Int + + :return: The remote shared memory CuTe pointer + :rtype: Pointer + + """ + dsmem_llvm_ptr = nvvm.mapa( + llvm.PointerType.get(_cute_ir.AddressSpace.dsmem), + smem_ptr.to_llvm_ptr(loc=loc, ip=ip), # type: ignore[attr-defined] + Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + intptr = llvm.ptrtoint(T.i32(), dsmem_llvm_ptr, loc=loc, ip=ip) + aligned_ty = _cute_ir.ConstrainedIntType.get(smem_ptr.alignment, 32) # type: ignore[attr-defined] + aligned_intptr = _cute_ir.assume(aligned_ty, intptr, loc=loc, ip=ip) + + return _cute_ir.inttoptr(smem_ptr.type, aligned_intptr, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/arch/tmem.py b/python/CuTeDSL/cutlass/cute/arch/tmem.py index 2b100aae9..28e37e6fc 100644 --- a/python/CuTeDSL/cutlass/cute/arch/tmem.py +++ b/python/CuTeDSL/cutlass/cute/arch/tmem.py @@ -9,14 +9,16 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Type +from typing import Optional, Type from cutlass.cutlass_dsl import dsl_user_op +from cutlass.base_dsl.arch import Arch +from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from ..typing import Pointer, Int, Int32, Numeric, NumericMeta, Tensor +from ..typing import Pointer, Int, Int32, Numeric, NumericMeta SM100_TMEM_CAPACITY_COLUMNS = ( 512 # deprecated; use get_max_tmem_alloc_cols(arch="sm_100") instead @@ -72,14 +74,15 @@ def get_min_tmem_alloc_cols(compute_capability: str) -> int: return TMEM_MIN_ALLOC_COLUMNS_MAP[compute_capability] + @dsl_user_op def retrieve_tmem_ptr( element_type: Type[Numeric], alignment: int, ptr_to_buffer_holding_addr: Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Pointer: """ Retrieves a pointer to TMEM with the provided element type and alignment. @@ -103,7 +106,10 @@ def retrieve_tmem_ptr( element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment ) return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr( - res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip + res_ty, + ptr_to_buffer_holding_addr.value, + loc=loc, + ip=ip, ) @@ -111,11 +117,11 @@ def retrieve_tmem_ptr( def alloc_tmem( num_columns: Int, smem_ptr_to_write_address: Pointer, - is_two_cta=None, + is_two_cta: Optional[bool] = None, *, arch: str = "sm_100", - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Allocates TMEM. @@ -135,11 +141,15 @@ def alloc_tmem( if ( num_columns < tmem_min_alloc_cols or num_columns > tmem_max_alloc_cols - or not (num_columns & (num_columns - 1) == 0) - ): - raise ValueError( - f"num_columns must be between {tmem_min_alloc_cols} and {tmem_max_alloc_cols}, and must be pow of 2, but got {num_columns}" + or not ( + (num_columns & (num_columns - 1) == 0) ) + ): + err_msg = f"num_columns must be between {tmem_min_alloc_cols} and {tmem_max_alloc_cols}, " + err_msg += "and must be pow of 2" + err_msg += f", but got {num_columns}." + raise ValueError(err_msg) + _cute_nvgpu_ir.arch_sm100_alloc_tmem( Int32(num_columns).ir_value(loc=loc, ip=ip), smem_ptr_to_write_address.value, @@ -150,7 +160,12 @@ def alloc_tmem( @dsl_user_op -def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None: +def relinquish_tmem_alloc_permit( + is_two_cta: Optional[bool] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can allocate. @@ -164,11 +179,11 @@ def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None: def dealloc_tmem( tmem_ptr: Pointer, num_columns: Int, - is_two_cta=None, + is_two_cta: Optional[bool] = None, *, arch: str = "sm_100", - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Deallocates TMEM using the provided pointer and number of columns. @@ -178,6 +193,8 @@ def dealloc_tmem( :param num_columns: The number of columns in the TMEM allocation :type num_columns: Int :param is_two_cta: Optional boolean parameter for 2-CTA MMAs + :param arch: The architecture of the GPU. + :type arch: str """ tmem_min_alloc_cols = get_min_tmem_alloc_cols(arch) tmem_max_alloc_cols = get_max_tmem_alloc_cols(arch) @@ -185,11 +202,15 @@ def dealloc_tmem( if ( num_columns < tmem_min_alloc_cols or num_columns > tmem_max_alloc_cols - or not (num_columns & (num_columns - 1) == 0) - ): - raise ValueError( - f"num_columns must be between {tmem_min_alloc_cols} and {tmem_max_alloc_cols}, and must be pow of 2, but got {num_columns}" + or not ( + (num_columns & (num_columns - 1) == 0) ) + ): + err_msg = f"num_columns must be between {tmem_min_alloc_cols} and {tmem_max_alloc_cols}, " + err_msg += "and must be pow of 2" + err_msg += f", but got {num_columns}." + raise ValueError(err_msg) + _cute_nvgpu_ir.arch_sm100_dealloc_tmem( tmem_ptr.value, Int32(num_columns).ir_value(loc=loc, ip=ip), diff --git a/python/CuTeDSL/cutlass/cute/atom.py b/python/CuTeDSL/cutlass/cute/atom.py index 0cda9aacf..6bfea228b 100644 --- a/python/CuTeDSL/cutlass/cute/atom.py +++ b/python/CuTeDSL/cutlass/cute/atom.py @@ -10,7 +10,7 @@ # is strictly prohibited. from abc import ABC, ABCMeta, abstractmethod -from typing import Type, Union, Optional, Any, List, Tuple, overload +from typing import Type, Union, Optional, Any, overload, List, Tuple from .typing import Shape, Layout, Tile, Tensor, Numeric, Int32 from .core import ( @@ -53,7 +53,13 @@ class MmaOp(Op, metaclass=ABCMeta): """ @abstractmethod - def _make_trait(self, *, loc=None, ip=None, **kwargs): + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "Trait": pass @@ -64,8 +70,13 @@ class CopyOp(Op, metaclass=ABCMeta): @abstractmethod def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs - ): + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "Trait": pass @@ -80,30 +91,61 @@ class Trait(ABC): def __init__(self, value: ir.Value) -> None: self.value = value - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> List[ir.Value]: return [self.value] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "Trait": return self.__class__(values[0]) - def set(self, field, value, *, loc=None, ip=None) -> None: + def set( + self, + field: Any, + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: raise NotImplementedError( "set not implemented, the requesting Atom has likely no runtime state" ) - def get(self, field, *, loc=None, ip=None) -> Any: + def get( + self, + field: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: raise NotImplementedError( "get not implemented, the requesting Atom has likely no runtime state" ) - def unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: + def unpack( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> ir.Value: return self.value - def with_(self, *, loc=None, ip=None, **kwargs) -> "Trait": + def with_( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "Trait": return self.__class__(self.unpack(loc=loc, ip=ip, **kwargs)) -def make_atom(ty, values=None, *, loc=None, ip=None): +def make_atom( + ty: ir.Type, + values: Optional[List[ir.Value]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.OpResult: """ This is a wrapper around the _cute_ir.make_atom operation, providing default value for the values argument. """ @@ -135,10 +177,10 @@ class Atom(ABC): self._op = op self._trait = trait - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> List[ir.Value]: return extract_mlir_values(self._trait) + extract_mlir_values(self._op) - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "Atom": traits_value = values[: len(extract_mlir_values(self._trait))] op_value = values[len(extract_mlir_values(self._trait)) :] @@ -151,11 +193,18 @@ class Atom(ABC): return self._op @property - def type(self): + def type(self) -> ir.Type: return self._trait.value.type @dsl_user_op - def set(self, modifier, value, *, loc=None, ip=None) -> None: + def set( + self, + modifier: Any, + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Sets runtime fields of the Atom. @@ -175,7 +224,13 @@ class Atom(ABC): self._trait.set(modifier, value, loc=loc, ip=ip) @dsl_user_op - def get(self, field, *, loc=None, ip=None) -> Any: + def get( + self, + field: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: """ Gets runtime fields of the Atom. @@ -193,7 +248,13 @@ class Atom(ABC): """ return self._trait.get(field, loc=loc, ip=ip) - def with_(self, *, loc=None, ip=None, **kwargs) -> "Atom": + def with_( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "Atom": """ Returns a new Atom with the new Operation and Trait with the given runtime state. The runtime state is provided as keyword arguments and it is Atom-specific. @@ -208,7 +269,13 @@ class Atom(ABC): """ return self.__class__(self.op, self._trait.with_(loc=loc, ip=ip, **kwargs)) - def _unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value: + def _unpack( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> ir.Value: return self._trait.unpack(loc=loc, ip=ip, **kwargs) @@ -239,27 +306,52 @@ class MmaAtom(Atom): @property @dsl_user_op - def thr_id(self, *, loc=None, ip=None) -> Layout: + def thr_id( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.thr_id, loc=loc, ip=ip) @property @dsl_user_op - def shape_mnk(self, *, loc=None, ip=None) -> Shape: + def shape_mnk( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: return _unpack_x_tuple(self._trait.value.type.shape_mnk, loc=loc, ip=ip) @property @dsl_user_op - def tv_layout_A(self, *, loc=None, ip=None) -> Layout: + def tv_layout_A( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.layout_a_tv, loc=loc, ip=ip) @property @dsl_user_op - def tv_layout_B(self, *, loc=None, ip=None) -> Layout: + def tv_layout_B( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.layout_b_tv, loc=loc, ip=ip) @property @dsl_user_op - def tv_layout_C(self, *, loc=None, ip=None) -> Layout: + def tv_layout_C( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.layout_c_tv) # @@ -267,11 +359,17 @@ class MmaAtom(Atom): # @dsl_user_op - def make_fragment_A(self, input, *, loc=None, ip=None): + def make_fragment_A( + self, + input: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.OpResult: # input could be memref/shape/layout for tmem based fragment if isinstance(input, _Tensor): if self.op is not None: - self.op._verify_fragment_A(input, loc=loc, ip=ip) + self.op._verify_fragment_A(input, loc=loc, ip=ip) # type: ignore[attr-defined] input = input.value if isinstance(input, tuple): input = _pack_shape(input, loc=loc, ip=ip) @@ -280,10 +378,16 @@ class MmaAtom(Atom): ) @dsl_user_op - def make_fragment_B(self, input, *, loc=None, ip=None): + def make_fragment_B( + self, + input: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.OpResult: if isinstance(input, _Tensor): if self.op is not None: - self.op._verify_fragment_B(input, loc=loc, ip=ip) + self.op._verify_fragment_B(input, loc=loc, ip=ip) # type: ignore[attr-defined] input = input.value if isinstance(input, tuple): input = _pack_shape(input, loc=loc, ip=ip) @@ -292,7 +396,13 @@ class MmaAtom(Atom): ) @dsl_user_op - def make_fragment_C(self, input, *, loc=None, ip=None): + def make_fragment_C( + self, + input: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.OpResult: # input could be memref/shape/layout for tmem based fragment if isinstance(input, _Tensor): input = input.value @@ -326,27 +436,52 @@ class TiledMma(MmaAtom): @property @dsl_user_op - def tv_layout_A_tiled(self, *, loc=None, ip=None) -> Layout: + def tv_layout_A_tiled( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.layout_a_tv_tiled, loc=loc, ip=ip) @property @dsl_user_op - def tv_layout_B_tiled(self, *, loc=None, ip=None) -> Layout: + def tv_layout_B_tiled( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.layout_b_tv_tiled, loc=loc, ip=ip) @property @dsl_user_op - def tv_layout_C_tiled(self, *, loc=None, ip=None) -> Layout: + def tv_layout_C_tiled( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.layout_c_tv_tiled, loc=loc, ip=ip) @property @dsl_user_op - def permutation_mnk(self, *, loc=None, ip=None) -> Tile: + def permutation_mnk( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tile: return _unpack_x_tuple(self._trait.value.type.permutation_mnk, loc=loc, ip=ip) @property @dsl_user_op - def thr_layout_vmnk(self, *, loc=None, ip=None) -> Layout: + def thr_layout_vmnk( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return static(self._trait.value.type.thr_layout_vmnk, loc=loc, ip=ip) @property @@ -380,7 +515,14 @@ class TiledMma(MmaAtom): # partition_shape # - def _partition_shape(self, operand_id, shape, *, loc=None, ip=None): + def _partition_shape( + self, + operand_id: Any, + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: shape = _pack_shape(shape, loc=loc, ip=ip) return _unpack_x_tuple( _cute_ir.tiled_mma_partition_shape( @@ -391,15 +533,33 @@ class TiledMma(MmaAtom): ) @dsl_user_op - def partition_shape_A(self, shape_mk, *, loc=None, ip=None): + def partition_shape_A( + self, + shape_mk: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: return self._partition_shape(_cute_ir.MmaOperand.A, shape_mk, loc=loc, ip=ip) @dsl_user_op - def partition_shape_B(self, shape_nk, *, loc=None, ip=None): + def partition_shape_B( + self, + shape_nk: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: return self._partition_shape(_cute_ir.MmaOperand.B, shape_nk, loc=loc, ip=ip) @dsl_user_op - def partition_shape_C(self, shape_mn, *, loc=None, ip=None): + def partition_shape_C( + self, + shape_mn: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: return self._partition_shape(_cute_ir.MmaOperand.C, shape_mn, loc=loc, ip=ip) # @@ -407,16 +567,37 @@ class TiledMma(MmaAtom): # @overload - def _thrfrg(self, operand_id, input: Layout, *, loc=None, ip=None) -> Layout: ... + def _thrfrg( + self, + operand_id: Any, + input: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: ... @overload - def _thrfrg(self, operand_id, input: Tensor, *, loc=None, ip=None) -> Tensor: ... + def _thrfrg( + self, + operand_id: Any, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: ... - def _thrfrg(self, operand_id, input, *, loc=None, ip=None) -> Union[Tensor, Layout]: + def _thrfrg( + self, + operand_id: Any, + input: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Union[Tensor, Layout]: if isinstance(input, Tensor): return make_tensor( input.iterator, - self._thrfrg(operand_id, input.layout, loc=loc, ip=ip), + self._thrfrg(operand_id, input.layout, loc=loc, ip=ip), # type: ignore[arg-type] loc=loc, ip=ip, ) @@ -432,17 +613,29 @@ class TiledMma(MmaAtom): ) def _thrfrg_A( - self, input: Union[Layout, Tensor], *, loc=None, ip=None + self, + input: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Layout, Tensor]: return self._thrfrg(_cute_ir.MmaOperand.A, input, loc=loc, ip=ip) def _thrfrg_B( - self, input: Union[Layout, Tensor], *, loc=None, ip=None + self, + input: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Layout, Tensor]: return self._thrfrg(_cute_ir.MmaOperand.B, input, loc=loc, ip=ip) def _thrfrg_C( - self, input: Union[Layout, Tensor], *, loc=None, ip=None + self, + input: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Layout, Tensor]: return self._thrfrg(_cute_ir.MmaOperand.C, input, loc=loc, ip=ip) @@ -456,17 +649,23 @@ class ThrMma(TiledMma): super().__init__(op, trait) self._thr_idx = thr_idx - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "ThrMma": return self.__class__( self.op, new_from_mlir_values(self._trait, values), self.thr_idx ) @property - def thr_idx(self): + def thr_idx(self) -> Union[int, Int32]: return self._thr_idx @dsl_user_op - def partition_A(self, input_mk: Tensor, *, loc=None, ip=None) -> Tensor: + def partition_A( + self, + input_mk: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) return _cute_ir.tiled_mma_partition( _cute_ir.MmaOperand.A, @@ -478,7 +677,13 @@ class ThrMma(TiledMma): ) @dsl_user_op - def partition_B(self, input_nk: Tensor, *, loc=None, ip=None) -> Tensor: + def partition_B( + self, + input_nk: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) return _cute_ir.tiled_mma_partition( _cute_ir.MmaOperand.B, @@ -490,7 +695,13 @@ class ThrMma(TiledMma): ) @dsl_user_op - def partition_C(self, input_mn: Tensor, *, loc=None, ip=None) -> Tensor: + def partition_C( + self, + input_mn: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) return _cute_ir.tiled_mma_partition( _cute_ir.MmaOperand.C, @@ -503,7 +714,13 @@ class ThrMma(TiledMma): @dsl_user_op -def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom: +def make_mma_atom( + op: MmaOp, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, +) -> MmaAtom: """ Makes an MMA Atom from an MMA Operation. @@ -522,12 +739,12 @@ def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom: @dsl_user_op def make_tiled_mma( op_or_atom: Union[Op, MmaAtom], - atom_layout_mnk=(1, 1, 1), - permutation_mnk=None, + atom_layout_mnk: Any = (1, 1, 1), + permutation_mnk: Any = None, *, - loc=None, - ip=None, - **kwargs, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> TiledMma: """ Makes a tiled MMA from an MMA Operation or an MMA Atom. @@ -610,33 +827,6 @@ class CopyAtom(Atom): def layout_dst_tv(self) -> Layout: return static(self._trait.value.type.layout_dst_tv) - @property - def smem_layout(self): - """ - Convenience property to access the SMEM layout for TMA copy atoms. - - This is a shortcut for ``atom.op.smem_layout`` that checks if the operation - is a TMA operation and provides a clearer error message if not. - - :return: The SMEM layout - :rtype: Layout or ComposedLayout - :raises TypeError: If the operation is not a TMA operation - :raises ValueError: If the SMEM layout is not set - - Example: - >>> layout = tma_atom.smem_layout # Instead of tma_atom.op.smem_layout - """ - # Import here to avoid circular dependency - from .nvgpu.cpasync.copy import TmaCopyOp - - if not isinstance(self.op, TmaCopyOp): - raise TypeError( - f"smem_layout is only available for TMA copy operations, " - f"but this atom uses {type(self.op).__name__}" - ) - - return self.op.smem_layout - class TiledCopy(CopyAtom): """ @@ -686,9 +876,18 @@ class TiledCopy(CopyAtom): return ThrCopy(self.op, self._trait, thr_idx) @dsl_user_op - def retile(self, src, *, loc=None, ip=None): + def retile( + self, + src: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: return _cute_ir.tiled_copy_retile( - tiled_copy=self._trait.value, input=src.value, loc=loc, ip=ip + tiled_copy=self._trait.value, + input=src.value, + loc=loc, + ip=ip, ) @@ -701,33 +900,58 @@ class ThrCopy(TiledCopy): super().__init__(op, trait) self._thr_idx = thr_idx - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "ThrCopy": return self.__class__( self.op, new_from_mlir_values(self._trait, values), self.thr_idx ) @property - def thr_idx(self): + def thr_idx(self) -> Union[int, Int32]: return self._thr_idx @dsl_user_op - def partition_S(self, src: Tensor, *, loc=None, ip=None) -> Tensor: + def partition_S( + self, + src: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) return _cute_ir.tiled_copy_partition_S( - self._trait.value, src.value, thr_idx, loc=loc, ip=ip + self._trait.value, + src.value, + thr_idx, + loc=loc, + ip=ip, ) @dsl_user_op - def partition_D(self, dst: Tensor, *, loc=None, ip=None) -> Tensor: + def partition_D( + self, + dst: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) return _cute_ir.tiled_copy_partition_D( - self._trait.value, dst.value, thr_idx, loc=loc, ip=ip + self._trait.value, + dst.value, + thr_idx, + loc=loc, + ip=ip, ) @dsl_user_op def make_copy_atom( - op: CopyOp, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + op: CopyOp, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> CopyAtom: """ Makes a Copy Atom from a Copy Operation. @@ -753,7 +977,14 @@ def make_copy_atom( return CopyAtom(op, trait) -def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): +def _make_tiled_copy( + atom: Any, + layout_tv: Any, + tiler_mn: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": if type(tiler_mn) is tuple: tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) @@ -774,7 +1005,14 @@ def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): return TiledCopy(atom.op, trait) -def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): +def make_tiled_copy( + atom: Any, + layout_tv: Any, + tiler_mn: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create a tiled type given a TV partitioner and tiler. :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. @@ -796,7 +1034,12 @@ def make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): @dsl_user_op def make_tiled_copy_tv( - atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None + atom: CopyAtom, + thr_layout: Layout, + val_layout: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> TiledCopy: """Create a tiled copy given separate thread and value layouts. @@ -825,7 +1068,12 @@ def make_tiled_copy_tv( @dsl_user_op def make_cotiled_copy( - atom: CopyAtom, atom_layout_tv: Layout, data_layout: Layout, *, loc=None, ip=None + atom: CopyAtom, + atom_layout_tv: Layout, + data_layout: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> TiledCopy: """ Produce a TiledCopy from thread and value offset maps. @@ -861,7 +1109,10 @@ def make_cotiled_copy( # check validity atom_layout_v_to_check = coalesce( make_layout( - atom_layout_tv.shape[1], stride=atom_layout_tv.stride[1], loc=loc, ip=ip + atom_layout_tv.shape[1], # type: ignore[index] + stride=atom_layout_tv.stride[1], # type: ignore[index] + loc=loc, + ip=ip, ), loc=loc, ip=ip, @@ -915,7 +1166,13 @@ def make_cotiled_copy( @dsl_user_op -def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): +def make_tiled_copy_A( + atom: Any, + tiled_mma: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. :param atom: Copy atom @@ -941,7 +1198,13 @@ def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): @dsl_user_op -def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): +def make_tiled_copy_B( + atom: Any, + tiled_mma: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. :param atom: Copy atom @@ -967,7 +1230,13 @@ def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): @dsl_user_op -def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): +def make_tiled_copy_C( + atom: Any, + tiled_mma: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. :param atom: Copy atom @@ -993,7 +1262,13 @@ def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): @dsl_user_op -def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): +def make_tiled_copy_S( + atom: Any, + tiled_copy: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. :param atom: Copy atom @@ -1015,7 +1290,13 @@ def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): @dsl_user_op -def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): +def make_tiled_copy_D( + atom: Any, + tiled_copy: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. :param atom: Copy atom @@ -1037,7 +1318,13 @@ def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): @dsl_user_op -def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): +def make_tiled_copy_C_atom( + atom: CopyAtom, + mma: TiledMma, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "TiledCopy": """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. :param atom: Copy atom @@ -1127,7 +1414,7 @@ def _normalize_variadic_tensor_operand( raise ValueError(f"`{name}` must contain at least one Tensor") if not all(isinstance(t, Tensor) for t in x): raise TypeError(f"All elements of `{name}` must be Tensor") - return list(x) # type: ignore + return list(x) raise TypeError(f"`{name}` must be a Tensor or a sequence of Tensors") @@ -1138,9 +1425,9 @@ def copy_atom_call( dst: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], *, pred: Optional[Tensor] = None, - loc=None, - ip=None, - **kwargs, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> None: """ Execute a single copy atom operation. @@ -1165,6 +1452,10 @@ def copy_atom_call( - For copy with auxiliary operands, they contain the main tensor followed by auxiliary tensors. For example: + - For static load from tensor memory, ``dst`` = [data, stat]. + - For TMA gather4, ``src`` = [coord0, coord1, coord2, coord3] (four 2D coordinate tensors). + - For TMA scatter4, ``dst`` = [coord0, coord1, coord2, coord3] (four 2D coordinate tensors). + :param atom: Copy atom specifying the transfer operation :type atom: CopyAtom :param src: Source tensor(s) with layout profile ``(V)``. Can be a single Tensor @@ -1195,16 +1486,26 @@ def copy_atom_call( # Predicated copy atom operation cute.copy_atom_call(copy_atom, src, dst, pred=pred) + # Static load from tensor memory: load with row-wise reduction (MAX, MIN, MAXABS, MINABS) + cute.copy_atom_call(loadtm_stat_atom, src, [data, stat]) + + # TMA gather4: combine four 2D coordinate tensors into single destination + cute.copy_atom_call(tma_gather4_atom, [coord0, coord1, coord2, coord3], dst) + """ # Normalize src/dst to lists for variadic IR operands, while keeping old API working. src_list = _normalize_variadic_tensor_operand(src, "src") dst_list = _normalize_variadic_tensor_operand(dst, "dst") - # Validate first src/dst for element type width check - if isinstance(src_list[0].type, _cute_ir.MemRefType) and isinstance( - dst_list[0].type, _cute_ir.MemRefType + # Validate first src/dst for element type width check. + if isinstance(src_list[0].type, _cute_ir.MemRefType) and isinstance( # type: ignore[attr-defined] + dst_list[0].type, # type: ignore[attr-defined] + _cute_ir.MemRefType, ): - if src_list[0].element_type.width != dst_list[0].element_type.width: + if ( + len(dst_list) == 1 + and src_list[0].element_type.width != dst_list[0].element_type.width # type: ignore[union-attr] + ): raise TypeError( "`copy_atom_call` currently only supports equal source and destination " "element type bit width" @@ -1228,13 +1529,13 @@ def copy_atom_call( def mma_atom_call( atom: MmaAtom, d: Tensor, - a: Tensor, - b: Tensor, + a: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], + b: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], c: Tensor, *, - loc=None, - ip=None, - **kwargs, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> None: """ Execute a single MMA atom operation. @@ -1245,15 +1546,23 @@ def mma_atom_call( Note: The tensors 'd', 'a', 'b', and 'c' must only have a single fragment. + The operands `a` and `b` are variadic, each containing a variable number of tensors: + + - For regular MMA, `a` and `b` contain the MMA A and B tensors respectively. + - For MMA with auxiliary operands, `a` and `b` contain the MMA A and B tensors followed by + their respective auxiliary tensors. For example: + + - For BlockScaledMMA, `a` = [A, SFA] and `b` = [B, SFB]. + :param atom: The MMA atom to execute :type atom: MmaAtom :param d: Destination tensor (output accumulator) :type d: Tensor - :param a: First source tensor (matrix A) - :type a: Tensor - :param b: Second source tensor (matrix B) - :type b: Tensor - :param c: Third source tensor (input accumulator C) + :param a: A tensor or list of tensors containing the MMA A tensor and optional auxiliary tensors + :type a: Union[Tensor, List[Tensor], Tuple[Tensor, ...]] + :param b: B tensor or list of tensors containing the MMA B tensor and optional auxiliary tensors + :type b: Union[Tensor, List[Tensor], Tuple[Tensor, ...]] + :param c: Input accumulator tensor :type c: Tensor :param loc: Source location for MLIR, defaults to None :type loc: Optional[Location], optional @@ -1264,10 +1573,26 @@ def mma_atom_call( .. code-block:: python - # Call an MMA atom operation + # Regular MMA atom call cute.mma_atom_call(mma_atom, d_tensor, a_tensor, b_tensor, c_tensor) + + # Block-scaled MMA atom call + cute.mma_atom_call(mma_atom, d_tensor, [a_tensor, sfa_tensor], + [b_tensor, sfb_tensor], c_tensor) """ + # Normalize A/B to lists for variadic IR operands, while keeping old API working. + a_list = _normalize_variadic_tensor_operand(a, "a") + b_list = _normalize_variadic_tensor_operand(b, "b") + value = atom._unpack(loc=loc, ip=ip, **kwargs) + a_vals = [t.value for t in a_list] + b_vals = [t.value for t in b_list] return _cute_ir.mma_atom_call( - value, d.value, a.value, b.value, c.value, loc=loc, ip=ip + value, + d.value, + a_vals, + b_vals, + c.value, + loc=loc, + ip=ip, ) diff --git a/python/CuTeDSL/cutlass/cute/core.py b/python/CuTeDSL/cutlass/cute/core.py index e37152321..84387e62d 100644 --- a/python/CuTeDSL/cutlass/cute/core.py +++ b/python/CuTeDSL/cutlass/cute/core.py @@ -12,24 +12,21 @@ from functools import partial, reduce import inspect from inspect import isclass -from typing import Any, Dict, List, Optional, Tuple, Type, Union, overload -from types import MethodType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, overload -from cutlass import const_expr from typing_extensions import deprecated from cutlass._mlir import ir -from cutlass._mlir.dialects import builtin, llvm, vector, arith, nvvm -from cutlass._mlir.dialects import cute as _cute_ir -from cutlass._mlir.dialects.cute import ( - Ratio as _Ratio, -) +from cutlass._mlir.dialects import builtin, llvm, vector, arith +from cutlass._mlir.dialects import cute as _cute_ir, cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir.dialects.cute import ( ReductionOp as ReductionOp, ) from cutlass._mlir.dialects.cute import ( + Ratio as _Ratio, ScaledBasis as _ScaledBasis, ) +from cutlass._mlir.extras.types import MemRefType as BuiltinMemRefType from cutlass.cutlass_dsl import ( T, const, @@ -43,7 +40,7 @@ from cutlass.cutlass_dsl import ( not_, ) -from .tuple import find_if, flatten_to_tuple, product_each, transform_leaf, wrap +from .tuple import find_if, flatten_to_tuple, product_each, transform_leaf, unwrap, wrap from .typing import ( AddressSpace, Boolean, @@ -124,12 +121,12 @@ __all__ = [ "coalesce", "crd2idx", "idx2crd", + "increment_coord", "recast_layout", "slice_and_offset", "shape", "recast_ptr", "make_ptr", - "get_remote_smem_ptr_in_cluster", "composition", "complement", "right_inverse", @@ -154,6 +151,7 @@ __all__ = [ "make_layout_tv", "get_nonswizzle_portion", "get_swizzle_portion", + "nullspace", ] #################################################################################################### @@ -163,7 +161,7 @@ __all__ = [ #################################################################################################### -def _get_typed_value(x): +def _get_typed_value(x: Any) -> Any: if isinstance(x, Integer): x = x.ir_value() @@ -173,7 +171,14 @@ def _get_typed_value(x): return x -def _pack_x(x, packer, op, *, loc=None, ip=None) -> ir.Value: +def _pack_x( + x: Any, + packer: Callable[..., Any], + op: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: x = transform_leaf(_get_typed_value, x) res_ty, dyn_elems = packer(x) # <"0"> is deduced from type inference which should be removed for make_... operations @@ -181,12 +186,22 @@ def _pack_x(x, packer, op, *, loc=None, ip=None) -> ir.Value: return op(res_ty, dyn_elems, loc=loc, ip=ip).result -def _pack_shape(shape: Shape, *, loc=None, ip=None) -> ir.Value: +def _pack_shape( + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: _check_shape(shape) return _pack_x(shape, _cute_ir.pack_shape, _cute_ir.MakeShapeOp, loc=loc, ip=ip) -def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: +def _pack_stride( + stride: Stride, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: _check_stride(stride) dyn_elems = map(_get_typed_value, extract_mlir_values(stride)) # Convert basis elements to the base class before _pack_x @@ -202,22 +217,37 @@ def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: return _cute_ir.MakeStrideOp(res_ty, dyn_elems, loc=loc, ip=ip).result -def _pack_coord(coord: Coord, *, loc=None, ip=None) -> ir.Value: +def _pack_coord( + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: _check_coord(coord) return _pack_x(coord, _cute_ir.pack_coord, _cute_ir.MakeCoordOp, loc=loc, ip=ip) -def _pack_int_tuple(int_tuple: IntTuple, *, loc=None, ip=None) -> ir.Value: +def _pack_int_tuple( + int_tuple: IntTuple, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: _check_int_tuple(int_tuple) return _pack_x( int_tuple, _cute_ir.pack_int_tuple, _cute_ir.MakeIntTupleOp, loc=loc, ip=ip ) -def _pack_tile(tile: Tile, *, loc=None, ip=None) -> ir.Value: +def _pack_tile( + tile: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: _check_tile(tile) - def expand_leaves(tile) -> list: + def expand_leaves(tile: Any) -> list: leaves = [] for e in tile: if isinstance(e, _Layout): @@ -238,7 +268,12 @@ def _pack_tile(tile: Tile, *, loc=None, ip=None) -> ir.Value: return _cute_ir.make_tile(res_ty, dyn_elems, loc=loc, ip=ip) -def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple: +def _unpack_x_tuple( + t: Union[ir.Type, ir.Value], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> XTuple: # If t is an MLIR type, make sure it's static and make a Value if isinstance(t, ir.Type): if not _cute_ir.is_static(t): @@ -260,7 +295,7 @@ def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple # CuTe IR only supports Int32 for now. Need to support detection of other types res = _cute_ir.unpack_x_tuple(input_ty, vals, loc=loc) - def post_process(x): + def post_process(x: Any) -> Any: if isinstance(x, _cute_ir.ScaledBasis): return ScaledBasis(post_process(x.get_value()), x.get_mode()) elif isinstance(x, _cute_ir.Ratio): @@ -361,7 +396,14 @@ class IntValue(cutlass_arith.ArithValue): """ @dsl_user_op - def __init__(self, v, signed=True, *, loc=None, ip=None): + def __init__( + self, + v: Any, + signed: bool = True, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: # Cute Constrained Int Type is always signed if isinstance(v, int): v = _pack_int_tuple(v, loc=loc, ip=ip) @@ -373,7 +415,12 @@ class IntValue(cutlass_arith.ArithValue): super().__init__(v, True, loc=loc, ip=ip) @dsl_user_op - def get_typed_value(self, *, loc=None, ip=None): + def get_typed_value( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: if isinstance(self.type, ir.IntegerType): def_op = self.owner.operation if def_op.name == "cute.get_scalars": @@ -387,28 +434,29 @@ class IntValue(cutlass_arith.ArithValue): return _cute_ir.MakeIntTupleOp(res_ty, [self], loc=loc, ip=ip).result @property - def divisibility(self): - assert isinstance(self.get_typed_value().type, _cute_ir.IntTupleType), ( - f"expected self.get_typed_value() to be int_tuple type, but got {self.get_typed_value().type}" + def divisibility(self) -> int: + typed_value = self.get_typed_value() + assert isinstance(typed_value.type, _cute_ir.IntTupleType), ( + f"expected self.get_typed_value() to be int_tuple type, but got {typed_value.type}" ) - return self.get_typed_value().type.get_divisibility([0]) + return typed_value.type.get_divisibility([0]) - def __str__(self): + def __str__(self) -> str: if self.divisibility == 1: return "?" elif self.type.width == 32: return f"?{{div={self.divisibility}}}" else: return f"?{{i{self.type.width} div={self.divisibility}}}" - def __repr__(self): + def __repr__(self) -> str: parent_name = cutlass_arith.ArithValue.__name__ return super().__str__().replace(parent_name, IntValue.__name__) - def pretty_str(self): + def pretty_str(self) -> str: return self.__str__() - def _binary_op(op): - def wrapper(self, other, **kwargs): + def _binary_op(op: Any) -> Any: + def wrapper(self: "IntValue", other: Any, **kwargs: Any) -> "IntValue": if isinstance(other, IntValue): other_val = other.get_typed_value() elif isinstance(other, ir.Value) and isinstance( @@ -433,70 +481,130 @@ class IntValue(cutlass_arith.ArithValue): @dsl_user_op @_binary_op - def __add__(self, other, *, loc=None, ip=None): + def __add__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_add( self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __sub__(self, other, *, loc=None, ip=None): + def __sub__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_sub( self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __mul__(self, other, *, loc=None, ip=None): + def __mul__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_mul( self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __floordiv__(self, other, *, loc=None, ip=None) -> "IntValue": + def __floordiv__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_div( self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __mod__(self, other, *, loc=None, ip=None) -> cutlass_arith.ArithValue: + def __mod__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_mod( self.get_typed_value(loc=loc, ip=ip), other, loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __radd__(self, other, *, loc=None, ip=None) -> "IntValue": + def __radd__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_add( other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __rsub__(self, other, *, loc=None, ip=None) -> "IntValue": + def __rsub__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_sub( other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __rmul__(self, other, *, loc=None, ip=None): + def __rmul__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_mul( other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "IntValue": + def __rfloordiv__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_div( other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip ) @dsl_user_op @_binary_op - def __rmod__(self, other, *, loc=None, ip=None) -> "IntValue": + def __rmod__( + self, + other: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "IntValue": return _cute_ir.tuple_mod( other, self.get_typed_value(loc=loc, ip=ip), loc=loc, ip=ip ) @@ -540,7 +648,7 @@ class Ratio(_Ratio): res = super().reduced() return Ratio(res.numerator, res.denominator) - def __mul__(self, other): + def __mul__(self, other: Union["Ratio", int]) -> "Ratio": """Multiply this ratio by another ratio or an integer. :param other: The value to multiply by @@ -559,7 +667,7 @@ class Ratio(_Ratio): else: raise TypeError(f"Cannot multiply Ratio with {type(other)}") - def __rmul__(self, other): + def __rmul__(self, other: Union["Ratio", int]) -> "Ratio": """Right multiplication operation. :param other: The value to multiply by @@ -569,7 +677,7 @@ class Ratio(_Ratio): """ return self.__mul__(other) - def __str__(self): + def __str__(self) -> str: """String representation of the ratio. :return: String in the format "numerator/denominator" @@ -577,7 +685,7 @@ class Ratio(_Ratio): """ return super().__str__() - def to(self, dtype): + def to(self, dtype: type) -> ir.Value: """Convert the ratio to another type. :param dtype: The target type for conversion @@ -634,7 +742,7 @@ class ScaledBasis: idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3) """ - def __init__(self, value, mode) -> None: + def __init__(self, value: Any, mode: Union[int, List[int]]) -> None: if isinstance(mode, int): self._mode = [mode] else: @@ -653,7 +761,13 @@ class ScaledBasis: return not is_dynamic_expression(self._value) @dsl_user_op - def to(self, dtype, *, loc=None, ip=None): + def to( + self, + dtype: type, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: """Convert to another type. :param dtype: The target type for conversion @@ -683,14 +797,14 @@ class ScaledBasis: else: raise TypeError(f"Cannot convert ScaledBasis to {dtype}") - def __str__(self): + def __str__(self) -> str: return f"{self.to(_ScaledBasis).__str__()}" - def __hash__(self): + def __hash__(self) -> int: return hash((self.value, tuple(self.mode))) @property - def value(self): + def value(self) -> Any: """Get the scale value. :return: The scale value @@ -706,14 +820,18 @@ class ScaledBasis: """ return self._mode - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, ScaledBasis): - return and_(self.mode == other.mode, self.value == other.value) + return and_(self.mode == other.mode, self.value == other.value) # type: ignore[return-value] else: return False def __rmul__( - self, scale: Union[Int, ir.Value, Ratio], *, loc=None, ip=None + self, + scale: Union[Int, ir.Value, Ratio], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> "ScaledBasis": """Right multiplication by a scale factor. @@ -735,13 +853,14 @@ class ScaledBasis: raise TypeError( f"scale must be an integer or a ratio, but got {type(scale)}" ) - if isinstance(self.value, Ratio): + + value = self.value + + if isinstance(value, Ratio): raise NotImplementedError( "scaling a basis element having a ratio is not supported" ) - value = self.value - if not isinstance(value, (Integer, Ratio, int, cutlass_arith.ArithValue)): raise TypeError(f"Don't support {type(value)} for ScaledBasis") @@ -754,10 +873,14 @@ class ScaledBasis: elif isinstance(value, Integer): value = value.ir_value(loc=loc, ip=ip) - return ScaledBasis(scale * value, self.mode) # type: ignore + return ScaledBasis(scale * value, self.mode) def __mul__( - self, scale: Union[Int, ir.Value, Ratio], *, loc=None, ip=None + self, + scale: Union[Int, ir.Value, Ratio], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> "ScaledBasis": """Multiplication by a scale factor. This operation is used in layout algebra to scale basis elements, @@ -775,9 +898,9 @@ class ScaledBasis: :raises NotImplementedError: If scaling a basis element with a ratio value """ - return self.__rmul__(scale, loc=loc, ip=ip) + return self.__rmul__(scale, loc=loc, ip=ip) # type: ignore[call-arg] - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> List[ir.Value]: if isinstance(self.value, Ratio): # Ratio is always static return [] @@ -785,7 +908,7 @@ class ScaledBasis: return extract_mlir_values(self.value) -def E(mode: Union[int, List[int]]) -> ScaledBasis: +def E(mode: Union[int, List[int]]) -> Union[ScaledBasis, int]: """Create a unit ScaledBasis element with the specified mode. This function creates a ScaledBasis with value 1 and the given mode. @@ -822,12 +945,12 @@ def E(mode: Union[int, List[int]]) -> ScaledBasis: return ScaledBasis(1, mode) -def get_divisibility(x: Union[int, Integer]) -> int: +def get_divisibility(x: Int) -> int: if isinstance(x, int): return x if isinstance(x, Integer): - x = x.value + x = x.value # type: ignore[assignment] if isinstance(x, IntValue): return x.divisibility @@ -866,8 +989,8 @@ def basis_get( basis: Union[ScaledBasis, Numeric, int], t: Union[XTuple, Layout, ComposedLayout], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[XTuple, Layout, ComposedLayout]: """Apply the mode indices from a ScaledBasis to get an element from a tuple, layout, or composed layout. @@ -939,11 +1062,11 @@ class Swizzle(ir.Value): """ - def __str__(self): + def __str__(self) -> str: # Cut off the MLIR type's string for making pretty_str more concise return self.type.__str__()[15 : 15 + 8] - def __eq__(self, other) -> Union[bool, Boolean]: + def __eq__(self, other: object) -> Union[bool, Boolean]: # type: ignore[override] """Check if this Swizzle is equal to another Swizzle. Since num_bits, num_base, and num_shift are static, this is a constant expression. @@ -1014,17 +1137,27 @@ class _Layout(Layout): idx = cute.crd2idx((2, 3), layout) """ - def __init__(self, op_result) -> None: + def __init__(self, op_result: ir.Value) -> None: """Initialize a Layout object. :param op_result: The operation result value to wrap. """ super().__init__(op_result) - def __repr__(self, *, loc=None, ip=None) -> str: - return self.__str__(loc=loc, ip=ip) + def __repr__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> str: + return self.__str__(loc=loc, ip=ip) # type: ignore[call-arg] - def __str__(self, *, loc=None, ip=None) -> str: + def __str__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> str: """Return a string representation of the layout. :return: A string in the format "shape:stride". @@ -1033,11 +1166,21 @@ class _Layout(Layout): return type_str[type_str.find("<") + 2 : type_str.rfind(">") - 1] @lru_cache_ir() - def shape_method(self, *, loc=None, ip=None) -> Shape: + def shape_method( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) @lru_cache_ir() - def stride_method(self, *, loc=None, ip=None) -> Stride: + def stride_method( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Stride: return _unpack_x_tuple( _cute_ir.get_stride(self, loc=loc, ip=ip), loc=loc, ip=ip ) @@ -1045,7 +1188,12 @@ class _Layout(Layout): @property @dsl_user_op @lru_cache_ir() - def shape(self, *, loc=None, ip=None) -> Shape: + def shape( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: """Get the shape of the layout. The shape defines the dimensions and structure of the layout's @@ -1058,7 +1206,12 @@ class _Layout(Layout): @property @dsl_user_op @lru_cache_ir() - def stride(self, *, loc=None, ip=None) -> Stride: + def stride( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Stride: """Get the stride of the layout. The stride defines how coordinates map to linear indices in memory. @@ -1075,7 +1228,7 @@ class _Layout(Layout): """ return self.type.max_alignment - def __eq__(self, other) -> Union[bool, Boolean]: + def __eq__(self, other: object) -> Union[bool, Boolean]: # type: ignore[override] """Check if this layout is equal to another layout. Two layouts are equal if they have the same shape and stride. @@ -1091,7 +1244,7 @@ class _Layout(Layout): else: return False - def __req__(self, other) -> Union[bool, Boolean]: + def __req__(self, other: object) -> Union[bool, Boolean]: """Reflected equality check. :param other: The layout to compare with. @@ -1101,7 +1254,7 @@ class _Layout(Layout): return other.__eq__(self) return False - def __ne__(self, other) -> Union[bool, Boolean]: + def __ne__(self, other: object) -> Union[bool, Boolean]: # type: ignore[override] """Check if this layout is not equal to another layout. :param other: The layout to compare with. @@ -1114,7 +1267,7 @@ class _Layout(Layout): else: return True - def __rne__(self, other) -> Union[bool, Boolean]: + def __rne__(self, other: object) -> Union[bool, Boolean]: """Reflected inequality check. :param other: The layout to compare with. @@ -1131,7 +1284,12 @@ class _Layout(Layout): return get(self, mode=[idx]) @dsl_user_op - def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: + def __call__( + self, + coord: Coord, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> IntTuple: if has_underscore(coord): crd_val = _pack_coord(coord, loc=loc, ip=ip) return _cute_ir.slice(self, crd_val, loc=loc, ip=ip) @@ -1139,7 +1297,13 @@ class _Layout(Layout): return crd2idx(coord, self, loc=loc, ip=ip) @dsl_user_op - def get_hier_coord(self, idx, *, loc=None, ip=None) -> Coord: + def get_hier_coord( + self, + idx: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Coord: """Get the hierarchical coordinate corresponding to a linear index. This method maps from a linear index back to the logical coordinate @@ -1162,7 +1326,13 @@ class _Layout(Layout): return _unpack_x_tuple(crd, loc=loc, ip=ip) @dsl_user_op - def get_flat_coord(self, idx, *, loc=None, ip=None) -> Coord: + def get_flat_coord( + self, + idx: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Coord: idx_val = Int32(idx).ir_value(loc=loc, ip=ip) res = _cute_ir.get_flat_coord(idx_val, self, loc=loc, ip=ip) return _unpack_x_tuple(res, loc=loc, ip=ip) @@ -1178,7 +1348,7 @@ class _ComposedLayout(ComposedLayout): to coordinate as inner layout. """ - def __init__(self, value) -> None: + def __init__(self, value: ir.Value) -> None: """Initialize a ComposedLayout object. :param value: The operation result value to wrap. @@ -1198,28 +1368,53 @@ class _ComposedLayout(ComposedLayout): @property @dsl_user_op - def inner(self, *, loc=None, ip=None) -> Union[Swizzle, Layout]: + def inner( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Union[Swizzle, Layout]: return _cute_ir.composed_get_inner(self.value, loc=loc, ip=ip) @property @dsl_user_op - def offset(self, *, loc=None, ip=None) -> IntTuple: + def offset( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> IntTuple: return _unpack_x_tuple( _cute_ir.composed_get_offset(self.value, loc=loc, ip=ip), loc=loc, ip=ip ) @property @dsl_user_op - def outer(self, *, loc=None, ip=None) -> Layout: + def outer( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return _cute_ir.composed_get_outer(self.value, loc=loc, ip=ip) @property @dsl_user_op - def shape(self, *, loc=None, ip=None) -> Shape: + def shape( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: return self.shape_method(loc=loc, ip=ip) @dsl_user_op - def shape_method(self, *, loc=None, ip=None) -> Shape: + def shape_method( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: return _unpack_x_tuple( _cute_ir.get_shape(self.value, loc=loc, ip=ip), loc=loc, ip=ip ) @@ -1228,7 +1423,7 @@ class _ComposedLayout(ComposedLayout): def max_alignment(self) -> int: return self.type.max_alignment - def __eq__(self, other) -> Union[bool, Boolean]: + def __eq__(self, other: object) -> Union[bool, Boolean]: # type: ignore[override] if isinstance(other, _ComposedLayout): if is_static(self.type) and is_static(other.type): return self.type == other.type @@ -1239,34 +1434,45 @@ class _ComposedLayout(ComposedLayout): else: return False - def __req__(self, other) -> Union[bool, Boolean]: + def __req__(self, other: object) -> Union[bool, Boolean]: if isinstance(other, _ComposedLayout): return Boolean(other.__eq__(self)) return False - def __ne__(self, other) -> Union[bool, Boolean]: + def __ne__(self, other: object) -> Union[bool, Boolean]: # type: ignore[override] return not self.__eq__(other) - def __rne__(self, other) -> Union[bool, Boolean]: + def __rne__(self, other: object) -> Union[bool, Boolean]: if isinstance(other, _ComposedLayout): return other.__ne__(self) return True @dsl_user_op - def __getitem__(self, idx: int, *, loc=None, ip=None) -> "_ComposedLayout": + def __getitem__( + self, + idx: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "_ComposedLayout": """ Top-level `get` to provide a syntax similar to `tuple`. """ - return get(self, mode=[idx], loc=loc, ip=ip) + return get(self, mode=[idx], loc=loc, ip=ip) # type: ignore[return-value] @dsl_user_op - def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: + def __call__( + self, + coord: Coord, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> IntTuple: return crd2idx(coord, self, loc=loc, ip=ip) - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> List[ir.Value]: return [self.value] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "_ComposedLayout": # Only expecting single value of _ComposedLayout or ir.Value # In this context, a _ComposedLayout instance is an encapsulated ir.Value which is automatically created # by value caster for ComposedLayout typed values @@ -1303,21 +1509,25 @@ class _Pointer(Pointer): T(c) = (E ∘ L)(c) = *(E + L(c)) """ - def __init__(self, value) -> None: - assert isinstance(value, ir.Value) - self.value = ir.Value(value) + def __init__(self, value: ir.Value, dtype: Optional[Type[Numeric]] = None) -> None: + assert isinstance(value, ir.Value), f"Expected ir.Value, but got {type(value)}" + self.value = value + + if isinstance(value.type.value_type, _cute_nvgpu_ir.TmaDescriptorTiledType): + dtype = value.type.value_type + self._dtype = dtype or Numeric.from_mlir_type(value.type.value_type) def __str__(self) -> str: # Cut off the MLIR type's string for making pretty_str more concise return self.type.__str__()[6:] - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> List[ir.Type]: return [self.value.type] - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> List[ir.Value]: return [self.value] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "_Pointer": # Only expecting single value of _Pointer instance or ir.Value # In this context, a _Pointer instance is an encapsulated ir.Value which is automatically created # by value caster for cute.ptr typed values @@ -1330,16 +1540,12 @@ class _Pointer(Pointer): ) @property - @lru_cache_ir() def dtype( self, ) -> Union[ Type[Numeric], ]: - ret_type = None - if ret_type is None: - ret_type = Numeric.from_mlir_type(self.value.type.value_type) - return ret_type + return self._dtype @property def alignment(self) -> int: @@ -1361,7 +1567,12 @@ class _Pointer(Pointer): return self.value.type @dsl_user_op - def load(self, *, loc=None, ip=None) -> Numeric: + def load( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Numeric: # LLVM doesn't support load/store narrow precision per element tmp_ty = self.dtype.mlir_type if self.dtype is Boolean or self.dtype.width == 8: @@ -1383,9 +1594,9 @@ class _Pointer(Pointer): self, value: Union[Numeric, cutlass_arith.ArithValue, int, float, bool], *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: if isinstance(value, (int, float, bool, cutlass_arith.ArithValue)): value = self.dtype(value, loc=loc, ip=ip) elif isinstance(value, Numeric): @@ -1406,19 +1617,37 @@ class _Pointer(Pointer): return llvm.store(tmp_val, llvm_ptr, loc=loc, ip=ip) @dsl_user_op - def __getitem__(self, idx: Int, *, loc=None, ip=None) -> Pointer: + def __getitem__( + self, + idx: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: return (self + idx).load() @dsl_user_op - def __setitem__(self, idx: Int, value: Numeric, *, loc=None, ip=None) -> Pointer: + def __setitem__( + self, + idx: Int, + value: Numeric, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: (self + idx).store(value, loc=loc, ip=ip) - return value + return value # type: ignore[return-value] # Only use if you absolutely need to get the LLVM pointer Value @property @dsl_user_op @lru_cache_ir() - def llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + def llvm_ptr( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: """ Get the LLVM pointer representation of this pointer. @@ -1433,7 +1662,12 @@ class _Pointer(Pointer): @dsl_user_op @lru_cache_ir() - def to_llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + def to_llvm_ptr( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: """ Get the LLVM pointer representation of this pointer. (Used by internal API to propagate loc and ip) @@ -1452,7 +1686,55 @@ class _Pointer(Pointer): ) @dsl_user_op - def __add__(self, offset: Int, *, loc=None, ip=None) -> Pointer: + @lru_cache_ir() + def _to_builtin_memref( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: + """ + Convert this pointer to a builtin memref (without any layout information). + + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint] + :return: The builtin memref representation + :rtype: ir.Value + """ + + memref_ty = BuiltinMemRefType.get( + shape=[], + element_type=self.type.value_type, + layout=None, + memory_space=ir.Attribute.parse( + str(self.memspace.value if self.memspace != AddressSpace.rmem else 0) + ), + loc=loc, + ) + idx_ty = Int64.mlir_type + offset = Int64(0).ir_value(loc=loc, ip=ip) + + memref_desc_ty = llvm.StructType.get_literal( + [self.llvm_ptr.type, self.llvm_ptr.type, idx_ty] + ) + memref_desc = llvm.mlir_undef(memref_desc_ty, loc=loc, ip=ip) + memref_desc = llvm.insertvalue(memref_desc, self.llvm_ptr, [0], loc=loc, ip=ip) + memref_desc = llvm.insertvalue(memref_desc, self.llvm_ptr, [1], loc=loc, ip=ip) + memref_desc = llvm.insertvalue(memref_desc, offset, [2], loc=loc, ip=ip) + return builtin.unrealized_conversion_cast( + [memref_ty], [memref_desc], loc=loc, ip=ip + ) + + @dsl_user_op + def __add__( # type: ignore[override] + self, + offset: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: """ Offset the pointer by elements of a layout's codomain. @@ -1461,20 +1743,38 @@ class _Pointer(Pointer): :return: A new pointer offset by the specified amount :rtype: ir.Value """ - offset = _pack_int_tuple(offset, loc=loc, ip=ip) # type: ignore + offset = _pack_int_tuple(offset, loc=loc, ip=ip) return _cute_ir.add_offset(self.value, offset=offset, loc=loc, ip=ip) @dsl_user_op - def __radd__(self, offset: Int, *, loc=None, ip=None) -> Pointer: + def __radd__( + self, + offset: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: return self.__add__(offset, loc=loc, ip=ip) @dsl_user_op - def __sub__(self, offset: Int, *, loc=None, ip=None) -> Pointer: - return self.__add__(-offset, loc=loc, ip=ip) # type: ignore + def __sub__( + self, + offset: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: + return self.__add__(-offset, loc=loc, ip=ip) @dsl_user_op @lru_cache_ir() - def toint(self, *, loc=None, ip=None): + def toint( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Numeric: + res_type: Type[Integer] if self.memspace in (AddressSpace.gmem, AddressSpace.generic): res_type = Int64 else: @@ -1485,7 +1785,13 @@ class _Pointer(Pointer): ) @dsl_user_op - def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + def align( + self, + min_align: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: """ Align a pointer to a specified byte alignment. @@ -1534,7 +1840,13 @@ class _Pointer(Pointer): #################################################################################################### -def _op_wrapper(op_fn, input, *, loc=None, ip=None): +def _op_wrapper( + op_fn: Any, + input: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: from .tensor import _Tensor if isinstance(input, Tensor): @@ -1546,12 +1858,69 @@ def _op_wrapper(op_fn, input, *, loc=None, ip=None): return op_fn(input, loc=loc, ip=ip) +def ModeOpDecorator(func: Any) -> Any: + class ModeOp: + """ + A generic class for operations that support mode indexing. + + This enables syntax like: + op(obj) <==> op(obj, mode=[]) # Apply op to obj with no mode filtering + op[0](obj) <==> op(obj, mode=[0]) # Apply op to obj after getting mode 0 + op[0,1](obj) <==> op(obj, mode=[0,1]) # Apply op to obj after getting modes (0,1) + """ + + def __init__(self, func: Any, mode: Union[Tuple[int, ...], int] = ()) -> None: + """ + Initialize ModeOp. + """ + self.func = func + # Functions like cute.size are written to take Lists. + # ModeOp works better with tuples. + # For now, handle the conversion internally. + self.mode = ( + tuple(mode) + if isinstance(mode, list) + else wrap(mode) + if mode is not None + else () + ) + + def __call__( + self, + obj: Any, + mode: Union[Tuple[int, ...], List[int], int, None] = (), + **kwargs: Any, + ) -> Any: + """Apply the function with optional mode specification.""" + mode = ( + tuple(mode) + if isinstance(mode, list) + else wrap(mode) + if mode is not None + else () + ) + return self.func(obj, mode=list(self.mode + mode), **kwargs) + + def __getitem__(self, mode: Union[Tuple[int, ...], int]) -> "ModeOp": + """Return a new instance with new modes appended to existing modes.""" + mode = ( + tuple(mode) + if isinstance(mode, list) + else wrap(mode) + if mode is not None + else () + ) + return ModeOp(self.func, self.mode + mode) + + return ModeOp(func) + + # # Utilities # -def is_valid_leaf(a) -> bool: +def is_valid_leaf(a: object) -> bool: """ Returns whether `a` has a type that is valid for a CuTe tuple's leaf. """ @@ -1562,7 +1931,7 @@ def is_valid_leaf(a) -> bool: ) -def is_static(x: Any) -> bool: +def is_static(x: object) -> bool: """Check if a value is statically known at compile time. In CuTe, static values are those whose values are known at compile time, @@ -1601,7 +1970,7 @@ def is_static(x: Any) -> bool: elif isinstance(x, _ComposedLayout): return _cute_ir.is_static(x.type) elif is_dynamic_expression(x): - return _cute_ir.is_static(x.type) + return _cute_ir.is_static(x.type) # type: ignore[attr-defined] elif isinstance(x, (bool, int, float)) or x is None: return True else: @@ -1637,7 +2006,7 @@ def _tuple_str(t: Tuple[Any, ...]) -> str: Constructs a string representation of a python tuple without calling __repr__ on its elements. """ - def construct_inner_str(t) -> str: + def construct_inner_str(t: Any) -> str: if not isinstance(t, tuple): return pretty_str(t) res = "" @@ -1652,7 +2021,7 @@ def _tuple_str(t: Tuple[Any, ...]) -> str: return res -def pretty_str(arg) -> str: +def pretty_str(arg: object) -> str: """ Constructs a concise readable pretty string. """ @@ -1668,7 +2037,12 @@ def pretty_str(arg) -> str: @dsl_user_op -def printf(*args, loc=None, ip=None, end="\n") -> None: +def printf( + *args: Any, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + end: str = "\n", +) -> None: """ Print one or more values with optional formatting. @@ -1724,7 +2098,7 @@ def printf(*args, loc=None, ip=None, end="\n") -> None: else: fmt = "{}" + ", {}" * (len(args) - 1) + end - def process_arg(arg): + def process_arg(arg: Any) -> Any: arg0 = arg.value if isinstance(arg, Numeric) else arg if isinstance(arg0, ir.Value): @@ -1737,12 +2111,12 @@ def printf(*args, loc=None, ip=None, end="\n") -> None: return const(arg0, Float32) elif has_underscore(arg0): # Assume it's a coordinate - return _pack_coord(arg0) + return _pack_coord(arg0) # type: ignore[arg-type] elif has_scaled_basis(arg0): # Assume it's a stride - return _pack_stride(arg0) + return _pack_stride(arg0) # type: ignore[arg-type] elif is_int_tuple(arg0): - return _pack_int_tuple(arg0) + return _pack_int_tuple(arg0) # type: ignore[arg-type] elif isinstance(arg0, tuple): # Assume it's a tile return _pack_tile(arg0) @@ -1760,12 +2134,17 @@ def printf(*args, loc=None, ip=None, end="\n") -> None: else: raise TypeError(f"unsupported argument type in printf, got {type(arg)}") - args = [process_arg(a) for a in args] - _cute_ir.print_(args, fmt=fmt, loc=loc, ip=ip) + processed_args = [process_arg(a) for a in args] + _cute_ir.print_(processed_args, fmt=fmt, loc=loc, ip=ip) @dsl_user_op -def front(input, *, loc=None, ip=None): +def front( + input: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: """Recursively get the first element of input. This function traverses a hierarchical structure (like a layout or tensor) @@ -1789,7 +2168,13 @@ def front(input, *, loc=None, ip=None): @dsl_user_op -def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: +def is_major( + mode: Union[int, List[int]], + stride: Stride, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> bool: """ Check whether a mode in stride is the major mode. """ @@ -1800,10 +2185,26 @@ def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: @dsl_user_op -def assume(src, divby=None, *, loc=None, ip=None): +def assume( + src: Any, + divby: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: if divby is None: return src + if not isinstance(divby, int) or divby <= 0: + raise ValueError(f"Expected `divby` to be a positive integer, got {divby}") + + if isinstance(src, int): + if src % divby != 0: + raise ValueError( + f"Expected {src} to be divisible by {divby}, got {src % divby}" + ) + return src + if isinstance(src, Integer): width = type(src).width src_val = src.ir_value(loc=loc, ip=ip) @@ -1817,7 +2218,14 @@ def assume(src, divby=None, *, loc=None, ip=None): @dsl_user_op -def make_swizzle(b, m, s, *, loc=None, ip=None): +def make_swizzle( + b: int, + m: int, + s: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Swizzle: # canonicalize to <0, 4, 3> for identity swizzle (as compiler assumes <0, 4, 3>) if not isinstance(b, int) or not isinstance(m, int) or not isinstance(s, int): raise ValueError("b, m, and s must be int") @@ -1829,12 +2237,22 @@ def make_swizzle(b, m, s, *, loc=None, ip=None): @dsl_user_op -def static(value, *, loc=None, ip=None): +def static( + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: return _cute_ir.static(value, loc=loc, ip=ip) @dsl_user_op -def get_leaves(value, *, loc=None, ip=None): +def get_leaves( + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: return _cute_ir.get_leaves(value, loc=loc, ip=ip) @@ -1860,12 +2278,9 @@ def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: .. code-block:: python - >>> depth(1) - 0 - >>> depth((1, 2)) - 1 - >>> depth(((1, 2), (3, 4))) - 2 + depth(1) # 0 + depth((1, 2)) # 1 + depth(((1, 2), (3, 4))) # 2 """ if type(a) is tuple: if not a: @@ -1877,8 +2292,9 @@ def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: return 0 +@ModeOpDecorator @lru_cache_ir() -def rank(a: Union[XTuple, Layout, "ComposedLayout"], mode: List[int] = []) -> int: # type: ignore +def rank(a: Union[XTuple, Layout, "ComposedLayout"], mode: List[int] = []) -> int: """Returns the rank (dimensionality) of a tuple, layout, or tensor. The rank of a tuple is its length. For layouts and tensors, the rank is @@ -1896,7 +2312,9 @@ def rank(a: Union[XTuple, Layout, "ComposedLayout"], mode: List[int] = []) -> in if isinstance(a, (Layout, ComposedLayout, Tensor)): return rank(a.shape, mode) - if (not isinstance(mode, list)) or any(not isinstance(m, int) for m in mode): + # Guaranteed by ModeOpDecorator + assert isinstance(mode, list) + if any(not isinstance(m, int) for m in mode): raise ValueError(f"Expected 'mode' to be a list of int, but got {mode}") if mode: @@ -1984,14 +2402,39 @@ def is_weakly_congruent( @overload -def get(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... +def get( + input: Layout, + mode: Any = ..., + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def get(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... +def get( + input: ComposedLayout, + mode: Any = ..., + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ComposedLayout: ... @overload -def get(input: XTuple, mode, *, loc=None, ip=None) -> XTuple: ... +def get( + input: XTuple, + mode: Any = ..., + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> XTuple: ... -def get(input, mode: List[int], *, loc=None, ip=None): +@ModeOpDecorator +def get( + input: Any, + mode: List[int] = [], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: """Extract a specific element or sub-layout from a layout or tuple. This function recursively traverses the input according to the mode indices, @@ -2001,7 +2444,7 @@ def get(input, mode: List[int], *, loc=None, ip=None): :param input: The input layout or tuple to extract from :type input: Layout, ComposedLayout, tuple :param mode: Indices specifying the path to traverse for extraction - :type mode: List[int] + :type mode: int or list of ints :param loc: Source location for MLIR, defaults to None :type loc: optional :param ip: Insertion point, defaults to None @@ -2044,26 +2487,51 @@ def get(input, mode: List[int], *, loc=None, ip=None): if isinstance(input, _ComposedLayout): input = input.value - res_ty = input.type.get_op_res_type(mode=mode) # type: ignore + res_ty = input.type.get_op_res_type(mode=mode) return _cute_ir.get(res_ty, input, mode=mode, loc=loc, ip=ip) @overload -def select(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... +def select( + input: Layout, + mode: Any = ..., + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def select(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... +def select( + input: ComposedLayout, + mode: Any = ..., + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ComposedLayout: ... @overload -def select(input: XTuple, mode, *, loc=None, ip=None) -> XTuple: ... +def select( + input: XTuple, + mode: Any = ..., + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> XTuple: ... +@ModeOpDecorator @dsl_user_op -def select(input, mode: List[int], *, loc=None, ip=None): +def select( + input: Any, + mode: List[int] = [], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: """Select modes from input. :param input: Input to select from :type input: Layout, ComposedLayout, tuple :param mode: Indices specifying which dimensions or elements to select - :type mode: List[int] + :type mode: int or list of ints :param loc: Source location for MLIR, defaults to None :type loc: optional :param ip: Insertion point, defaults to None @@ -2106,24 +2574,51 @@ def select(input, mode: List[int], *, loc=None, ip=None): @overload def group_modes( - input: Layout, begin: int, end: int, *, loc=None, ip=None + input: Layout, + begin: int, + end: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def group_modes( - input: ComposedLayout, begin: int, end: int, *, loc=None, ip=None + input: ComposedLayout, + begin: int, + end: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @overload def group_modes( - input: Tensor, begin: int, end: int, *, loc=None, ip=None + input: Tensor, + begin: int, + end: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: ... @overload def group_modes( - input: XTuple, begin: int, end: int, *, loc=None, ip=None + input: XTuple, + begin: int, + end: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> XTuple: ... @dsl_user_op -def group_modes(input, begin: int, end: Optional[int] = None, *, loc=None, ip=None): +def group_modes( + input: Union[Layout, ComposedLayout, Tensor, XTuple], + begin: int, + end: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, Tensor, XTuple]: """Group modes of a hierarchical tuple or layout into a single mode. This function groups a range of modes from the input object into a single mode, @@ -2182,19 +2677,47 @@ def group_modes(input, begin: int, end: Optional[int] = None, *, loc=None, ip=No @overload -def slice_(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... +def slice_( + src: Layout, + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def slice_( - src: _ComposedLayout, coord: Coord, *, loc=None, ip=None + src: _ComposedLayout, + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> _ComposedLayout: ... @overload -def slice_(src: Tensor, coord: Coord, *, loc=None, ip=None) -> Tensor: ... +def slice_( + src: Tensor, + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: ... @overload -def slice_(src: XTuple, coord: Coord, *, loc=None, ip=None) -> XTuple: ... +def slice_( + src: XTuple, + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> XTuple: ... @dsl_user_op -def slice_(src, coord: Coord, *, loc=None, ip=None): +def slice_( + src: Union[Layout, _ComposedLayout, Tensor, XTuple], + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, _ComposedLayout, Tensor, XTuple]: """Perform a slice operation on a source object using the given coordinate. This function implements CuTe's slicing operation which extracts a subset of elements @@ -2245,7 +2768,7 @@ def slice_(src, coord: Coord, *, loc=None, ip=None): * Selecting specific patterns of elements """ - def lift_slice(a, b): + def lift_slice(a: Any, b: Any) -> tuple: if isinstance(a, tuple): if (not isinstance(b, tuple)) or (len(a) != len(b)): raise ValueError("coord must be weakly congruent to src in slice_") @@ -2274,16 +2797,40 @@ def slice_(src, coord: Coord, *, loc=None, ip=None): @overload -def dice(src: Layout, dicer: Coord, *, loc=None, ip=None) -> Layout: ... +def dice( + src: Layout, + dicer: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def dice(src: ComposedLayout, dicer: Coord, *, loc=None, ip=None) -> ComposedLayout: ... +def dice( + src: ComposedLayout, + dicer: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ComposedLayout: ... @overload -def dice(src: XTuple, dicer: Coord, *, loc=None, ip=None) -> XTuple: ... +def dice( + src: XTuple, + dicer: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> XTuple: ... @dsl_user_op @lru_cache_ir() -def dice(src, dicer, *, loc=None, ip=None): +def dice( + src: Union[Layout, ComposedLayout, XTuple], + dicer: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, XTuple]: """Keep modes in input when it is paired with an integer in dicer. This function performs dicing operation on the input based on the dicer coordinate. @@ -2320,7 +2867,7 @@ def dice(src, dicer, *, loc=None, ip=None): if not is_static(dicer): raise ValueError(f"expects dicer to be static, but got {dicer}") - def lift_dice(a, b): + def lift_dice(a: Any, b: Any) -> tuple: if isinstance(a, tuple): if (not isinstance(b, tuple)) or (len(a) != len(b)): raise ValueError("dicer must be weakly congruent to input in dice") @@ -2350,7 +2897,14 @@ def dice(src, dicer, *, loc=None, ip=None): ) -def _extend(func, input, elem, up_to_rank, loc, ip): +def _extend( + func: Any, + input: Any, + elem: Any, + up_to_rank: Optional[int], + loc: Optional[ir.Location], + ip: Optional[ir.InsertionPoint], +) -> Any: if input is None: raise ValueError("No input provided for input") @@ -2386,20 +2940,42 @@ def _extend(func, input, elem, up_to_rank, loc, ip): @overload def prepend( - input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None + input: Layout, + elem: Layout, + up_to_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def prepend( - input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None + input: ComposedLayout, + elem: Layout, + up_to_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @overload def prepend( - input: XTuple, elem: XTuple, up_to_rank=None, *, loc=None, ip=None + input: XTuple, + elem: XTuple, + up_to_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> XTuple: ... @dsl_user_op -def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): +def prepend( + input: Union[Layout, ComposedLayout, XTuple], + elem: Any, + up_to_rank: Union[None, int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, XTuple]: """Extend input to rank up_to_rank by prepending elem in front of input. This function extends the input object by prepending elements to reach a desired rank. @@ -2441,20 +3017,42 @@ def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=No @overload def append( - input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None + input: Layout, + elem: Layout, + up_to_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def append( - input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None + input: ComposedLayout, + elem: Layout, + up_to_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @overload def append( - input: XTuple, elem: XTuple, up_to_rank=None, *, loc=None, ip=None + input: XTuple, + elem: XTuple, + up_to_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> XTuple: ... @dsl_user_op -def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): +def append( + input: Union[Layout, ComposedLayout, XTuple], + elem: Any, + up_to_rank: Union[None, int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, XTuple]: """Extend input to rank up_to_rank by appending elem to the end of input. This function extends the input object by appending elements to reach a desired rank. @@ -2502,7 +3100,11 @@ def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=Non @dsl_user_op def prepend_ones( - t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None + t: Tensor, + up_to_rank: Union[None, int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: from .tensor import make_tensor @@ -2513,18 +3115,32 @@ def prepend_ones( @overload def append_ones( - t: Layout, up_to_rank: Union[None, int] = None, *, loc=None, ip=None + t: Layout, + up_to_rank: Union[None, int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def append_ones( - t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None + t: Tensor, + up_to_rank: Union[None, int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: ... @dsl_user_op -def append_ones(t, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): +def append_ones( + t: Union[Layout, Tensor], + up_to_rank: Union[None, int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: from .tensor import make_tensor if isinstance(t, Tensor): @@ -2537,7 +3153,7 @@ def append_ones(t, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): raise TypeError(f"expects Tensor or Layout, but got {type(t)}") -def repeat_as_tuple(x, n) -> tuple: +def repeat_as_tuple(x: Any, n: int) -> tuple: """Creates a tuple with x repeated n times. This function creates a tuple by repeating the input value x n times. @@ -2563,7 +3179,7 @@ def repeat_as_tuple(x, n) -> tuple: return (x,) * n -def repeat(x, n): +def repeat(x: Any, n: int) -> Any: """Creates an object by repeating x n times. This function creates an object by repeating the input value x n times. @@ -2591,7 +3207,7 @@ def repeat(x, n): return x if n == 1 else (x,) * n -def repeat_like(x, target): +def repeat_like(x: Any, target: Any) -> Any: """Creates an object congruent to target and filled with x. This function recursively creates a nested tuple structure that matches the structure @@ -2631,7 +3247,7 @@ def flatten(a: Tensor) -> Tensor: ... def flatten(a: XTuple) -> XTuple: ... -def flatten(a): +def flatten(a: Union[Layout, Tensor, XTuple]) -> Union[Layout, Tensor, XTuple]: """Flattens a CuTe data structure into a simpler form. For tuples, this function flattens the structure into a single-level tuple. @@ -2669,16 +3285,30 @@ def flatten(a): @overload def filter_zeros( - input: Layout, *, target_profile=None, loc=None, ip=None + input: Layout, + *, + target_profile: Optional[Stride] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def filter_zeros( - input: Tensor, *, target_profile=None, loc=None, ip=None + input: Tensor, + *, + target_profile: Optional[Stride] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: ... @dsl_user_op -def filter_zeros(input, *, target_profile=None, loc=None, ip=None): +def filter_zeros( + input: Union[Layout, Tensor], + *, + target_profile: Optional[Stride] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: """Filter out zeros from a layout or tensor. This function removes zero-stride dimensions from a layout or tensor. @@ -2705,15 +3335,35 @@ def filter_zeros(input, *, target_profile=None, loc=None, ip=None): @overload -def filter(input: Layout, *, loc=None, ip=None) -> Layout: ... +def filter( + input: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def filter(input: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... +def filter( + input: ComposedLayout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ComposedLayout: ... @overload -def filter(input: Tensor, *, loc=None, ip=None) -> Tensor: ... +def filter( + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: ... @dsl_user_op -def filter(input, *, loc=None, ip=None): +def filter( + input: Union[Layout, ComposedLayout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, Tensor]: """Filter a layout or tensor. This function filters a layout or tensor according to CuTe's filtering rules. @@ -2743,13 +3393,14 @@ def filter(input, *, loc=None, ip=None): return _cute_ir.filter(input, loc=loc, ip=ip) +@ModeOpDecorator @dsl_user_op def size( a: Union[IntTuple, Shape, Layout, ComposedLayout, Tensor], mode: List[int] = [], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Int: """Return size of domain of layout or tensor. @@ -2761,7 +3412,7 @@ def size( :param a: The input object whose size to compute :type a: IntTuple, Shape, Layout, ComposedLayout or Tensor :param mode: List of mode(s) for size calculation. If empty, computes total size, defaults to [] - :type mode: list of int, optional + :type mode: int or list of ints, optional :param loc: Source location for MLIR, defaults to None :type loc: optional :param ip: Insertion point, defaults to None @@ -2781,16 +3432,22 @@ def size( if not isinstance(a, (Layout, ComposedLayout, Tensor)): a_val = _pack_int_tuple(a, loc=loc, ip=ip) elif isinstance(a, (ComposedLayout, Tensor)): - a_val = a.value + a_val = a.value # type: ignore[union-attr] else: a_val = a res = _cute_ir.size(a_val, mode=mode, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore + return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore[return-value] @dsl_user_op -def shape_div(lhs: Shape, rhs: Shape, *, loc=None, ip=None) -> Shape: +def shape_div( + lhs: Shape, + rhs: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Shape: """Perform element-wise division of shapes. This function performs element-wise division between two shapes. @@ -2813,7 +3470,13 @@ def shape_div(lhs: Shape, rhs: Shape, *, loc=None, ip=None) -> Shape: @dsl_user_op -def ceil_div(input: Shape, tiler: Tiler, *, loc=None, ip=None) -> Shape: +def ceil_div( + input: Shape, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Shape: """ Compute the ceiling division of a target shape by a tiling specification. @@ -2883,7 +3546,11 @@ def round_up(a: IntTuple, b: IntTuple) -> IntTuple: @dsl_user_op def make_layout( - shape: Shape, *, stride: Union[Stride, None] = None, loc=None, ip=None + shape: Shape, + *, + stride: Union[Stride, None] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: """Create a CuTe Layout object from shape and optional stride information. @@ -2936,10 +3603,10 @@ def make_layout( shape_val = _pack_shape(shape, loc=loc, ip=ip) if stride is not None: stride_val = _pack_stride(stride, loc=loc, ip=ip) - layout_ty = _cute_ir.LayoutType.get(shape_val, stride_val) + layout_ty = _cute_ir.LayoutType.get(shape_val.type, stride_val.type) else: stride_val = None - layout_ty = _cute_ir.LayoutType.get(shape_val) + layout_ty = _cute_ir.LayoutType.get(shape_val.type) return _cute_ir.make_layout( layout_ty, shape=shape_val, stride=stride_val, loc=loc, ip=ip @@ -2947,7 +3614,12 @@ def make_layout( @dsl_user_op -def make_identity_layout(shape: Shape, *, loc=None, ip=None) -> Layout: +def make_identity_layout( + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: """Create an identity layout with the given shape. An identity layout maps logical coordinates directly to themselves without any transformation. @@ -2983,7 +3655,13 @@ def make_identity_layout(shape: Shape, *, loc=None, ip=None) -> Layout: @dsl_user_op -def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Layout: +def make_ordered_layout( + shape: Shape, + order: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: """Create a layout with a specific ordering of dimensions. This function creates a layout where the dimensions are ordered according to the @@ -3026,7 +3704,12 @@ def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Lay @dsl_user_op -def make_layout_like(input: Union[Layout, Tensor], *, loc=None, ip=None) -> Layout: +def make_layout_like( + input: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: if isinstance(input, Tensor): layout = input.layout else: @@ -3036,7 +3719,15 @@ def make_layout_like(input: Union[Layout, Tensor], *, loc=None, ip=None) -> Layo class _ComposedLayoutWithInnerFunc(ComposedLayout): @dsl_user_op - def __init__(self, inner, offset, outer, *, loc=None, ip=None): + def __init__( + self, + inner: Any, + offset: IntTuple, + outer: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self._inner = inner self._offset = offset self._outer = outer @@ -3044,7 +3735,13 @@ class _ComposedLayoutWithInnerFunc(ComposedLayout): self._offset_val = _pack_int_tuple(offset, loc=loc, ip=ip) @dsl_user_op - def __call__(self, coord, *, loc=None, ip=None): + def __call__( + self, + coord: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: delta = self._outer(coord) delta_val = _pack_int_tuple(delta, loc=loc, ip=ip) @@ -3053,37 +3750,62 @@ class _ComposedLayoutWithInnerFunc(ComposedLayout): return self._inner(offset_new) - def __str__(self): + def __str__(self) -> str: return f"({self._inner} o {self._offset} o {self._outer})" @property - def type(self): + def type(self) -> Any: raise ValueError("type is not supported for customized composed layouts") @property - def is_normal(self): + def is_normal(self) -> bool: return False @property - def inner(self, *, loc=None, ip=None): + def inner( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: return self._inner @property - def offset(self, *, loc=None, ip=None): + def offset( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> IntTuple: return self._offset @property - def outer(self, *, loc=None, ip=None): + def outer( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return self._outer @property - def shape(self, *, loc=None, ip=None): + def shape( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: return self._outer.shape @dsl_user_op def make_composed_layout( - inner, offset: IntTuple, outer: Layout, *, loc=None, ip=None + inner: Any, + offset: IntTuple, + outer: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: """Create a composed layout by composing an inner transformation with an outer layout. @@ -3135,10 +3857,15 @@ def make_composed_layout( return _ComposedLayoutWithInnerFunc(inner, offset, outer, loc=loc, ip=ip) +@ModeOpDecorator @dsl_user_op def cosize( - a: Union[Layout, ComposedLayout, Tensor], mode: List[int] = [], *, loc=None, ip=None -): + a: Union[Layout, ComposedLayout, Tensor], + mode: List[int] = [], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int: """Return size of codomain of layout or tensor. Return static value if type is static. For a layout ``L = S:D`` where ``S`` is the shape and ``D`` is the stride, the codomain size is the @@ -3162,7 +3889,7 @@ def cosize( :type a: Union[Layout, ComposedLayout, Tensor] :param mode: List of mode(s) for cosize calculation. If empty, calculates over all modes. If specified, calculates cosize only for the given modes. - :type mode: List[int], optional + :type mode: int or list of ints, optional :param loc: Location information for diagnostics, defaults to None :type loc: optional :param ip: Instruction pointer for diagnostics, defaults to None @@ -3179,21 +3906,24 @@ def cosize( res = _cute_ir.cosize(a.value, mode=mode, loc=loc, ip=ip) else: res = _cute_ir.cosize(a, mode=mode, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore[return-value] @dsl_user_op def size_in_bytes( - dtype: Type[Numeric], + dtype: Union[ + Type[Numeric], + ], layout: Union[Layout, ComposedLayout, None], *, - loc=None, - ip=None, -) -> Union[int, Integer]: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int: """Calculate the size in bytes based on its data type and layout. The result is rounded up to the nearest byte. + Supports both regular Numeric types. :param dtype: The DSL numeric data type - :type dtype: Type[Numeric] + :type dtype: Union[Type[Numeric]] :param layout: The layout of the elements. If None, the function returns 0 :type layout: Layout, optional :param loc: Location information for diagnostics, defaults to None @@ -3203,7 +3933,12 @@ def size_in_bytes( :return: The total size in bytes. Returns 0 if the layout is None :rtype: int """ - if not isinstance(dtype, NumericMeta): + if not isinstance( + dtype, + ( + NumericMeta, + ), + ): raise TypeError(f"dtype must be a Numeric, but got {dtype}") size_in_elem = 0 @@ -3224,11 +3959,17 @@ def size_in_bytes( else: size_in_elem = cosize(layout, loc=loc, ip=ip) - return ceil_div(size_in_elem * dtype.width, 8, loc=loc, ip=ip) # type: ignore + return ceil_div(size_in_elem * dtype.width, 8, loc=loc, ip=ip) @dsl_user_op -def coalesce(input, *, target_profile: Coord = None, loc=None, ip=None): +def coalesce( + input: Union[Layout, ComposedLayout, Tensor], + *, + target_profile: Optional[Coord] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, Tensor]: if target_profile: profile_val = _pack_coord(target_profile, loc=loc, ip=ip) else: @@ -3240,7 +3981,13 @@ def coalesce(input, *, target_profile: Coord = None, loc=None, ip=None): @dsl_user_op -def crd2idx(coord: Coord, layout, *, loc=None, ip=None): +def crd2idx( + coord: Coord, + layout: Union[Layout, ComposedLayout, tuple, int], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int: """ Convert a multi-dimensional coordinate into a value using the specified layout. @@ -3281,23 +4028,41 @@ def crd2idx(coord: Coord, layout, *, loc=None, ip=None): layout = layout.value res = _cute_ir.crd2idx(crd_val, layout, loc=loc, ip=ip) - return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore + return _unpack_x_tuple(res, loc=loc, ip=ip) # type: ignore[return-value] @overload -def idx2crd(idx: Int, shape: Int, *, loc=None, ip=None) -> Int: ... +def idx2crd( + idx: Int, + shape: Int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int: ... @overload -def idx2crd(idx: IntTuple, shape: Tuple, *, loc=None, ip=None) -> Tuple: ... +def idx2crd( + idx: IntTuple, + shape: Tuple, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple: ... @dsl_user_op -def idx2crd(idx, shape, *, loc=None, ip=None): +def idx2crd( + idx: IntTuple, + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> IntTuple: """ - Convert a linear index back into a multi-dimensional coordinate using the specified layout. + Convert a linear index back into a nested coordinate using the specified layout. - Mapping from a linear index to the corresponding multi-dimensional coordinate in the layout's coordinate space. + Mapping from a linear index to the corresponding nested coordinate in the layout's coordinate space. It essentially "unfolds" a linear index into its constituent coordinate components. :param idx: The linear index to convert back to coordinates. @@ -3319,10 +4084,10 @@ def idx2crd(idx, shape, *, loc=None, ip=None): @cute.jit def foo(): coord = cute.idx2crd(11, (5, 4)) - # idx2crd is always col-major + # idx2crd is always lexicographical ordering (left-to-right) # For shape (m, n, l, ...), coord = (idx % m, idx // m % n, idx // m // n % l, ... # Computed as: (11 % 5, 11 // 5 % 4) = (1, 2) - print(coord) + cute.printf("coord: {}", coord) foo() # Expected output: (1, 2) """ @@ -3334,15 +4099,64 @@ def idx2crd(idx, shape, *, loc=None, ip=None): return _unpack_x_tuple(res, loc=loc, ip=ip) +@dsl_user_op +def increment_coord( + coord: Coord, + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Coord: + """ + Colexicographically increment a coordinate within a coordinate space defined by a shape. + + Increments the leftmost mode first. When a mode reaches its + shape limit, it wraps to 0 and carries to the next mode. + + :param coord: The coordinate to increment. + :type coord: Coord + :param shape: The shape defining the coordinate space bounds. + :type shape: Shape + :param loc: Optional location information for IR diagnostics. + :type loc: optional + :param ip: Optional instruction pointer or context for underlying IR functions. + :type ip: optional + :returns: The incremented coordinate. + :rtype: Coord + :raises ValueError: If the coordinate and shape are not congruent or if the coordinate contains an underscore. + + **Example:** + + .. code-block:: python + + import cutlass.cute as cute + @cute.jit + def foo(): + coord = cute.increment_coord((2, 0, 0), (3, 3, 3)) + # Increments colexicographically: (2,0,0) -> (0,1,0) + cute.printf("coord: {}", coord) + foo() # Expected output: coord: (0, 1, 0) + """ + if has_underscore(coord): + raise ValueError("coord cannot contain underscores") + if not is_congruent(coord, shape): + raise ValueError("coord and shape must be congruent") + + coord_val = _pack_coord(coord, loc=loc, ip=ip) + shape_val = _pack_shape(shape, loc=loc, ip=ip) + res = _cute_ir.increment_coord(coord_val, shape_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + @dsl_user_op def recast_layout( new_type_bits: int, old_type_bits: int, src_layout: Union[Layout, ComposedLayout], *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: """ Recast a layout from one data type to another. @@ -3389,11 +4203,17 @@ def recast_layout( src_layout = src_layout.value return _cute_ir.recast_layout( new_type_bits, old_type_bits, src_layout, loc=loc, ip=ip - ) # type: ignore + ) @dsl_user_op -def slice_and_offset(coord, src, *, loc=None, ip=None): +def slice_and_offset( + coord: Coord, + src: Union[Layout, ComposedLayout], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple: layout = slice_(src, coord, loc=loc, ip=ip) offset = crd2idx(coord, src, loc=loc, ip=ip) return layout, offset @@ -3402,7 +4222,11 @@ def slice_and_offset(coord, src, *, loc=None, ip=None): @dsl_user_op @lru_cache_ir() def shape( - input: Union[Shape, Tensor, Layout, Tile], *, mode=None, loc=None, ip=None + input: Union[Shape, Tensor, Layout, Tile], + *, + mode: Optional[int] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Shape: """Returns the shape of a tensor, layout or tiler. @@ -3457,117 +4281,84 @@ def shape( @dsl_user_op def recast_ptr( ptr: Pointer, - swizzle_=None, + swizzle_: Optional[Swizzle] = None, dtype: Optional[Type[Numeric]] = None, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Pointer: - cvt_type = None + cvt_ty = None if dtype is not None: - if cvt_type is None: + if cvt_ty is None: if not isclass(dtype) or not issubclass(dtype, Numeric): raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - cvt_type = T.i8() if dtype is Boolean else dtype.mlir_type + cvt_ty = T.i8() if dtype is Boolean else dtype.mlir_type - dtype = cvt_type - value_type = ptr.type.value_type if dtype is None else dtype - swizzle = swizzle_.type.attribute if swizzle_ is not None else None - res_ty = _cute_ir.PtrType.get(value_type, ptr.memspace, ptr.alignment, swizzle) + value_ty = cvt_ty or ptr.type.value_type + swizzle_attr = swizzle_.type.attribute if swizzle_ is not None else None + res_ty = _cute_ir.PtrType.get(value_ty, ptr.memspace, ptr.alignment, swizzle_attr) # type: ignore[attr-defined] return _cute_ir.recast_iter(res_ty, ptr.value, loc=loc, ip=ip) @dsl_user_op def make_ptr( - dtype: Union[Type[Numeric], None], - value, - mem_space: AddressSpace = AddressSpace.generic, + dtype: Union[ + Type[Numeric], + None, + ], + value: Union[int, Integer, ir.Value], + mem_space: Optional[AddressSpace] = None, *, - assumed_align=None, - loc=None, - ip=None, + assumed_align: Optional[int] = None, + swizzle_: Optional[Swizzle] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Pointer: - # Perform checks - if dtype is None or not isinstance(dtype, NumericMeta): - raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}") - if not isinstance(mem_space, AddressSpace): - raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}") + cvt_type = None + if dtype is not None: + if cvt_type is None: + if not isinstance(dtype, NumericMeta): + raise TypeError("expects dtype to be a type of Numeric") + cvt_type = dtype.mlir_type if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type): + llvm_ptr_ty = llvm.PointerType(value.type) + mem_space = AddressSpace(llvm_ptr_ty.address_space) value = llvm.ptrtoint(T.i64(), value) + if not is_integer(value): raise TypeError(f"expects integer value, but got {type(value)}") + if mem_space is None: + mem_space = AddressSpace.generic + if not isinstance(mem_space, AddressSpace): + raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}") + # TMEM addresses are 32b wide is_tmem = mem_space == AddressSpace.tmem - value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value) + value = Int32(value) if is_tmem else Int64(value) # Set the alignment of the pointer - bytes_per_elt = max(1, dtype.width // 8) + bytes_per_elt = max(1, dtype.width // 8) # type: ignore[union-attr] if assumed_align is None: assumed_align = bytes_per_elt + if bytes_per_elt % assumed_align != 0 and assumed_align % bytes_per_elt != 0: raise ValueError( f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa." ) + aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width) aligned_intptr = _cute_ir.assume( aligned_ty, value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip ) # Construct the pointer Type - data_ty = T.i8() if dtype is None else dtype.mlir_type - ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align) - return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip) + data_ty = T.i8() if dtype is None else cvt_type + swizzle = swizzle_.type.attribute if swizzle_ is not None else None - -@dsl_user_op -def get_remote_smem_ptr_in_cluster( - smem_ptr: Pointer, - cta_rank_in_cluster: Int, - *, - loc=None, - ip=None, -) -> Pointer: - """ - Get the remote shared memory CuTe pointer in a cluster. - - :param smem_ptr: The current shared memory pointer - :type smem_ptr: Pointer - :param cta_rank_in_cluster: The peer CTA rank in cluster to get the remote pointer for - :type cta_rank_in_cluster: Int - :param loc: Source location for MLIR, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point, defaults to None - :type ip: Optional[InsertionPoint] - - :return: The remote shared memory CuTe pointer - :rtype: Pointer - - """ - cur_llvm_ptr = smem_ptr.llvm_ptr - remote_llvm_ptr = nvvm.mapa( - llvm.PointerType.get(7), # LLVM dsmem address space - cur_llvm_ptr, - Int32(cta_rank_in_cluster).ir_value(loc=loc, ip=ip), - loc=loc, - ip=ip, - ) - remote_llvm_ptr_cast = llvm.addrspacecast( - llvm.PointerType.get(AddressSpace.smem), remote_llvm_ptr, loc=loc, ip=ip - ) - remote_ptr = make_ptr( - smem_ptr.dtype, - remote_llvm_ptr_cast, - AddressSpace.smem, - assumed_align=smem_ptr.alignment, - loc=loc, - ip=ip, - ) - if const_expr(smem_ptr.value.type.is_swizzled): - sw = Swizzle(static(smem_ptr.value.type.swizzle_type)) - remote_ptr = recast_ptr( - remote_ptr, swizzle_=sw, dtype=smem_ptr.dtype, loc=loc, ip=ip - ) - return remote_ptr + ptr_ty = _cute_ir.PtrType.get(data_ty, mem_space, assumed_align, swizzle) + ptr = _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip) + ptr._dtype = dtype + return ptr # @@ -3577,20 +4368,38 @@ def get_remote_smem_ptr_in_cluster( @overload def composition( - lhs: Layout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None + lhs: Layout, + rhs: Union[Layout, Shape, Tile], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def composition( - lhs: ComposedLayout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None + lhs: ComposedLayout, + rhs: Union[Layout, Shape, Tile], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @overload def composition( - lhs: Tensor, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None + lhs: Tensor, + rhs: Union[Layout, Shape, Tile], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: ... @dsl_user_op -def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None): +def composition( + lhs: Union[Layout, ComposedLayout, Tensor], + rhs: Union[Layout, Shape, Tile], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout, Tensor]: """ Compose two layout representations using the CuTe layout algebra. @@ -3646,7 +4455,11 @@ def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None): @dsl_user_op def complement( - input: Layout, cotarget: Union[Layout, Shape], *, loc=None, ip=None + input: Layout, + cotarget: Union[Layout, Shape], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: """ Compute the complement layout of the input layout with respect to the cotarget. @@ -3690,7 +4503,12 @@ def complement( @dsl_user_op -def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout: +def right_inverse( + input: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: if not isinstance(input, Layout): raise TypeError(f"Expected input of type Layout, but got {type(input)}") @@ -3698,7 +4516,12 @@ def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout: @dsl_user_op -def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout: +def left_inverse( + input: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: if not isinstance(input, Layout): raise TypeError(f"Expected input of type Layout, but got {type(input)}") @@ -3706,15 +4529,31 @@ def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout: @overload -def logical_product(block: Layout, tiler: Tile, *, loc=None, ip=None) -> Layout: ... +def logical_product( + block: Layout, + tiler: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def logical_product( - block: ComposedLayout, tiler: Tile, *, loc=None, ip=None + block: ComposedLayout, + tiler: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def logical_product(block, tiler: Tile, *, loc=None, ip=None): +def logical_product( + block: Union[Layout, ComposedLayout], + tiler: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: if isinstance(block, _ComposedLayout): block = block.value @@ -3731,7 +4570,7 @@ def logical_product(block, tiler: Tile, *, loc=None, ip=None): tiler_rank = rank(tiler_val) block_rank = rank(block) res = tuple( - logical_product(block[i], tiler_val[i]) if i < tiler_rank else block[i] + logical_product(block[i], tiler_val[i]) if i < tiler_rank else block[i] # type: ignore[index] for i in range(block_rank) ) @@ -3741,15 +4580,31 @@ def logical_product(block, tiler: Tile, *, loc=None, ip=None): @overload -def zipped_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +def zipped_product( + block: Layout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def zipped_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None + block: ComposedLayout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def zipped_product(block, tiler: Layout, *, loc=None, ip=None): +def zipped_product( + block: Union[Layout, ComposedLayout], + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: if isinstance(block, _ComposedLayout): return _cute_ir.zipped_product(input=block.value, tiler=tiler, loc=loc, ip=ip) else: @@ -3757,15 +4612,31 @@ def zipped_product(block, tiler: Layout, *, loc=None, ip=None): @overload -def tiled_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +def tiled_product( + block: Layout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def tiled_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None + block: ComposedLayout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def tiled_product(block, tiler: Layout, *, loc=None, ip=None): +def tiled_product( + block: Union[Layout, ComposedLayout], + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: if isinstance(block, _ComposedLayout): return _cute_ir.tiled_product(input=block.value, tiler=tiler, loc=loc, ip=ip) else: @@ -3773,15 +4644,31 @@ def tiled_product(block, tiler: Layout, *, loc=None, ip=None): @overload -def flat_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +def flat_product( + block: Layout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def flat_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None + block: ComposedLayout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def flat_product(block, tiler: Layout, *, loc=None, ip=None): +def flat_product( + block: Union[Layout, ComposedLayout], + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: if isinstance(block, _ComposedLayout): return _cute_ir.flat_product(input=block.value, tiler=tiler, loc=loc, ip=ip) else: @@ -3789,15 +4676,31 @@ def flat_product(block, tiler: Layout, *, loc=None, ip=None): @overload -def raked_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +def raked_product( + block: Layout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def raked_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None + block: ComposedLayout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def raked_product(block, tiler: Layout, *, loc=None, ip=None): +def raked_product( + block: Union[Layout, ComposedLayout], + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: if isinstance(block, _ComposedLayout): return _cute_ir.raked_product(input=block.value, tiler=tiler, loc=loc, ip=ip) else: @@ -3805,15 +4708,31 @@ def raked_product(block, tiler: Layout, *, loc=None, ip=None): @overload -def blocked_product(block: Layout, tiler: Layout, *, loc=None, ip=None) -> Layout: ... +def blocked_product( + block: Layout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload def blocked_product( - block: ComposedLayout, tiler: Layout, *, loc=None, ip=None + block: ComposedLayout, + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def blocked_product(block, tiler: Layout, *, loc=None, ip=None): +def blocked_product( + block: Union[Layout, ComposedLayout], + tiler: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: if isinstance(block, _ComposedLayout): return _cute_ir.blocked_product(input=block.value, tiler=tiler, loc=loc, ip=ip) else: @@ -3821,28 +4740,64 @@ def blocked_product(block, tiler: Layout, *, loc=None, ip=None): @overload -def logical_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +def logical_divide( + target: Layout, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def logical_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... +def logical_divide( + target: Tensor, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: ... @dsl_user_op -def logical_divide(target, tiler: Tiler, *, loc=None, ip=None): - if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore +def logical_divide( + target: Union[Layout, Tensor], + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: + if isinstance(tiler, (int, tuple)): + tiler = _pack_tile(tiler, loc=loc, ip=ip) return _op_wrapper( partial(_cute_ir.logical_divide, tiler=tiler), target, loc=loc, ip=ip ) @overload -def zipped_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +def zipped_divide( + target: Layout, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def zipped_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... +def zipped_divide( + target: Tensor, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: ... @dsl_user_op -def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None): +def zipped_divide( + target: Union[Layout, Tensor], + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: """ ``zipped_divide`` is ``logical_divide`` with Tiler modes and Rest modes gathered together: ``(Tiler,Rest)`` @@ -3872,21 +4827,44 @@ def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None): tiler = (8, 8) result = cute.zipped_divide(layout, tiler) # result shape: ((8, 8), (16, 8)) """ + if not isinstance(tiler, Layout) and rank(target) < rank(tiler): + raise ValueError( + f"Expected rank(target) >= rank(tiler), but got rank(target)={rank(target)} and rank(tiler)={rank(tiler)}" + ) + if isinstance(tiler, tuple): - tiler = _pack_tile(tiler, loc=loc, ip=ip) # type: ignore + tiler = _pack_tile(tiler, loc=loc, ip=ip) return _op_wrapper( partial(_cute_ir.zipped_divide, tiler=tiler), target, loc=loc, ip=ip ) @overload -def tiled_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ... +def tiled_divide( + target: Layout, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def tiled_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ... +def tiled_divide( + target: Tensor, + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: ... @dsl_user_op -def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None): +def tiled_divide( + target: Union[Layout, Tensor], + tiler: Tiler, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) return _op_wrapper( @@ -3895,13 +4873,31 @@ def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None): @overload -def flat_divide(target: Layout, tiler: Tile, *, loc=None, ip=None) -> Layout: ... +def flat_divide( + target: Layout, + tiler: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def flat_divide(target: Tensor, tiler: Tile, *, loc=None, ip=None) -> Tensor: ... +def flat_divide( + target: Tensor, + tiler: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: ... @dsl_user_op -def flat_divide(target, tiler: Tile, *, loc=None, ip=None): +def flat_divide( + target: Union[Layout, Tensor], + tiler: Tile, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: if isinstance(tiler, tuple): tiler = _pack_tile(tiler, loc=loc, ip=ip) return _op_wrapper( @@ -3910,13 +4906,17 @@ def flat_divide(target, tiler: Tile, *, loc=None, ip=None): # -# Higher-level utilties +# Higher-level utilities # @dsl_user_op def max_common_layout( - a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None + a: Union[Layout, Tensor], + b: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: from .tensor import _Tensor @@ -3939,7 +4939,11 @@ def max_common_layout( @dsl_user_op def max_common_vector( - a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None + a: Union[Layout, Tensor], + b: Union[Layout, Tensor], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> int: from .tensor import _Tensor @@ -3962,18 +4966,35 @@ def max_common_vector( @overload def tile_to_shape( - atom: Layout, trg_shape: Shape, order: Shape, *, loc=None, ip=None + atom: Layout, + trg_shape: Shape, + order: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @overload def tile_to_shape( - atom: ComposedLayout, trg_shape: Shape, order: Shape, *, loc=None, ip=None + atom: ComposedLayout, + trg_shape: Shape, + order: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op -def tile_to_shape(atom, trg_shape: Shape, order: Shape, *, loc=None, ip=None): - trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) # type: ignore - order = _pack_int_tuple(order, loc=loc, ip=ip) # type: ignore +def tile_to_shape( + atom: Union[Layout, ComposedLayout], + trg_shape: Shape, + order: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: + trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip) + order = _pack_int_tuple(order, loc=loc, ip=ip) if isinstance(atom, _ComposedLayout): return _cute_ir.tile_to_shape(atom.value, trg_shape, order, loc=loc, ip=ip) @@ -3988,19 +5009,23 @@ def local_partition( index: Union[int, Numeric], proj: XTuple = 1, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: if isinstance(index, cutlass_arith.ArithValue): index_val = index else: - index_val = index.ir_value(loc=loc, ip=ip) + index_val = index.ir_value(loc=loc, ip=ip) # type: ignore[union-attr] if index_val.type.width > 32: raise NotImplementedError( f"Index value should be 32-bit or smaller integer type, but got {index_val.type}" ) return _cute_ir.local_partition( - input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip + input=target.value, + tiler=dice(tiler, proj), + index=index_val, + loc=loc, + ip=ip, ) @@ -4009,10 +5034,10 @@ def local_tile( input: Tensor, tiler: Tiler, coord: Coord, - proj: XTuple = None, # type: ignore + proj: XTuple = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: """ Partition a tensor into tiles using a tiler and extract a single tile at the provided coordinate. @@ -4101,7 +5126,12 @@ def local_tile( @dsl_user_op def make_layout_image_mask( - lay: Layout, coord: Coord, mode: int, *, loc=None, ip=None + lay: Layout, + coord: Coord, + mode: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Int16: """ Makes a 16-bit integer mask of the image of a layout sliced at a given mode @@ -4124,7 +5154,7 @@ def make_layout_image_mask( raise ValueError("the mask may not fit into a 16-bit integer") # Replace the mode to keep with _ in the coordinate - slicer = tuple(None if idx == mode else x for idx, x in enumerate(coord)) + slicer = tuple(None if idx == mode else x for idx, x in enumerate(coord)) # type: ignore[arg-type] # Slice the layout with the slicer above and keep track of the offset sliced_lay, offset = slice_and_offset(slicer, lay, loc=loc, ip=ip) # Given that we replace only one mode with _, the rank of the slice should be 1 @@ -4135,13 +5165,13 @@ def make_layout_image_mask( # Create the mask of the image mcast_mask = Int16(0) - for i in range(size(sliced_lay)): # type: ignore + for i in range(size(sliced_lay)): mcast_mask = mcast_mask | (1 << sliced_lay(i)) mcast_mask <<= offset return Int16(mcast_mask) -def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: # type: ignore +def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: """ Find the leading dimension of a shape and stride. @@ -4159,7 +5189,7 @@ def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], Non * If no leading dimension is found, returns None """ - def pred_fn(val, pos): + def pred_fn(val: object, pos: Union[int, tuple]) -> bool: # skip dynamic values which can't be compared # find the candidate target val, stride at this position is 1 if (not is_dynamic_expression(val)) and (val == 1): @@ -4177,7 +5207,11 @@ def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], Non @dsl_user_op def make_layout_tv( - thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None + thr_layout: Layout, + val_layout: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Shape, Layout]: """Create a thread-value layout by repeating the val_layout over the thr_layout. @@ -4259,7 +5293,10 @@ def make_layout_tv( @dsl_user_op def get_nonswizzle_portion( - layout: Union[Layout, ComposedLayout], *, loc=None, ip=None + layout: Union[Layout, ComposedLayout], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Layout, ComposedLayout]: """ Extract the non-swizzle portion from a layout. @@ -4289,7 +5326,10 @@ def get_nonswizzle_portion( @dsl_user_op def get_swizzle_portion( - layout: Union[Layout, ComposedLayout], *, loc=None, ip=None + layout: Union[Layout, ComposedLayout], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Swizzle: """ Extract or create the swizzle portion from a layout. @@ -4322,6 +5362,61 @@ def get_swizzle_portion( raise TypeError(f"expects a Layout or ComposedLayout, but got {type(layout)}") +@dsl_user_op +def nullspace( + layout: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: + """ + Computes the nullspace (kernel) of a layout. + + Returns a layout l such that layout(l(i)) == 0 for all i < size(l), + nullspace(l) == make_layout(1, stride=0), + and size(l) == size(layout) / size(filter_zeros(layout)) + + :param layout: The layout to compute the nullspace of. + :type layout: Layout + :param loc: Optional location information for IR diagnostics. + :type loc: optional + :param ip: Optional + :type ip: optional + :returns: The nullspace of the layout + :rtype: Layout + :raises TypeError: If the layout is not a Layout. + """ + + if not isinstance(layout, Layout): + raise TypeError(f"expects a Layout, but got {type(layout)}") + + # Select all indices corresponds to stride 0 + flat_stride = wrap(flatten(layout.stride)) + + # Transform to get tuple of zeros and get the indices that are non zero + nullspace_indices = [] + for i in range(len(flat_stride)): + if is_static(flat_stride[i]) and flat_stride[i] == 0: + nullspace_indices.append(i) + + if len(nullspace_indices) == 0: + return make_layout(1, stride=0, loc=loc, ip=ip) + else: + flat_shape = flatten(shape(layout)) + # create a compact major left stride based on the flat shape + rstride = [1] * len(flat_shape) + for i in range(1, len(flat_shape)): + rstride[i] = flat_shape[i - 1] * rstride[i - 1] + + # Select all indices that map to 0 + return make_layout( + unwrap(tuple(flat_shape[i] for i in nullspace_indices)), + stride=unwrap(tuple(rstride[i] for i in nullspace_indices)), + loc=loc, + ip=ip, + ) + + ############################################################################## # User defined struct ############################################################################## @@ -4392,11 +5487,13 @@ class struct: _dtype: Optional[Numeric] = None _size: Optional[int] = None - def __new__(cls, name, bases, dct): + def __new__( + cls, name: str, bases: tuple[type, ...], dct: Dict[str, Any] + ) -> "struct._MemRangeMeta": new_cls = super().__new__(cls, name, bases, dct) return new_cls - def __getitem__(cls, params) -> Type["struct.MemRange"]: + def __getitem__(cls, params: tuple[Any, ...]) -> "Type[struct.MemRange]": # get params from syntax: struct.MemRange[dtype, size] if len(params) == 2: dtype, size = params @@ -4415,16 +5512,16 @@ class struct: return new_cls @property - def size(cls): + def size(cls) -> Optional[int]: return cls._size @property - def elem_width(cls): - return cls._dtype.width if cls._dtype is not Boolean else 8 + def elem_width(cls) -> int: + return cls._dtype.width if cls._dtype is not Boolean else 8 # type: ignore[union-attr] @property - def size_in_bytes(cls): - return cls.size * cls.elem_width // 8 + def size_in_bytes(cls) -> int: + return cls.size * cls.elem_width // 8 # type: ignore[operator] class MemRange(metaclass=_MemRangeMeta): """ @@ -4442,7 +5539,9 @@ class struct: :param base: The base address of the memory range. """ - def __init__(self, dtype, size, base): + def __init__( + self, dtype: Optional[Numeric], size: Optional[int], base: Optional[Pointer] + ) -> None: """ Initializes a new memory range. @@ -4455,23 +5554,34 @@ class struct: self._size: Optional[int] = size self._base: Optional[Pointer] = base - def __repr__(self): + def __repr__(self) -> str: return f"{object.__repr__(self)} " @dsl_user_op - def data_ptr(self, *, loc=None, ip=None) -> Pointer: + def data_ptr( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: """ Returns start pointer to the data in this memory range. :return: A pointer to the start of the memory range. :raises AssertionError: If the size of the memory range is negative. """ - assert self._size >= 0 + assert self._size is not None and self._size >= 0 return recast_ptr(self._base, dtype=self._dtype, loc=loc, ip=ip) @dsl_user_op def get_tensor( - self, layout, swizzle=None, dtype=None, *, loc=None, ip=None + self, + layout: Union[Layout, ComposedLayout], + swizzle: Optional[Swizzle] = None, + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: """ Creates a tensor from the memory range. @@ -4485,7 +5595,7 @@ class struct: """ from .tensor import make_tensor - assert self._size > 0 + assert self._size is not None and self._size > 0 # make tensor if isinstance(layout, ComposedLayout) and (swizzle is not None): raise TypeError("incompatible layout with swizzle") @@ -4495,7 +5605,13 @@ class struct: return type(res)(res, dtype=elem_type, loc=loc, ip=ip) @dsl_user_op - def __getitem__(self, index: int, *, loc=None, ip=None) -> Any: + def __getitem__( + self, + index: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: """ Returns the element at the specified index in the memory range. @@ -4503,12 +5619,19 @@ class struct: :return: The element at the specified index. :raises AssertionError: If the index is out of range. """ - assert (index >= 0) and (index < self._size) + assert self._size is not None and (index >= 0) and (index < self._size) ptr = self.data_ptr() + index return ptr.load(loc=loc, ip=ip) @dsl_user_op - def __setitem__(self, index: int, val, *, loc=None, ip=None): + def __setitem__( + self, + index: int, + val: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Set element value at the specified index in the memory range. @@ -4516,9 +5639,13 @@ class struct: :val: The element value at the specified index. :raises AssertionError: If the index is out of range. """ - assert (index >= 0) and (index < self._size) + assert self._size is not None and (index >= 0) and (index < self._size) ptr = self.data_ptr() + index - ptr.store(as_numeric(val).to(self._dtype), loc=loc, ip=ip) + ptr.store( + as_numeric(val).to(self._dtype), # type: ignore[call-overload] + loc=loc, + ip=ip, + ) # inner class for aligning a member type class _AlignMeta(type): @@ -4536,10 +5663,12 @@ class struct: _dtype: Optional[Any] = None _align: Optional[int] = None - def __new__(cls, name, bases, dct): + def __new__( + cls, name: str, bases: tuple[type, ...], dct: Dict[str, Any] + ) -> "struct._AlignMeta": return super().__new__(cls, name, bases, dct) - def __getitem__(cls, params) -> Any: + def __getitem__(cls, params: tuple[Any, ...]) -> Any: if len(params) == 2: dtype, align = params assert align > 0 @@ -4562,11 +5691,11 @@ class struct: return new_cls @property - def dtype(cls): + def dtype(cls) -> Optional[Any]: return cls._dtype @property - def align(cls): + def align(cls) -> Optional[int]: return cls._align class Align(metaclass=_AlignMeta): @@ -4586,10 +5715,10 @@ class struct: :ivar _ptr: The underlying pointer to the scalar value. """ - def __init__(self, ptr): - self._ptr: Optional[_Pointer] = ptr + def __init__(self, ptr: _Pointer) -> None: + self._ptr: _Pointer = ptr - def __repr__(self): + def __repr__(self) -> str: return f"{object.__repr__(self)} <{self.dtype}> " def __get_mlir_types__(self) -> List[ir.Type]: @@ -4598,14 +5727,19 @@ class struct: def __extract_mlir_values__(self) -> List[ir.Value]: return [self.value] - def __new_from_mlir_values__(self, values) -> Pointer: + def __new_from_mlir_values__(self, values: List[ir.Value]) -> Pointer: # type: ignore[override] ptr = _Pointer( values[0] if isinstance(values[0], ir.Value) else values[0].value ) return self.__class__(ptr) @dsl_user_op - def to_llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + def to_llvm_ptr( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: llvm_ptr_ty = llvm.PointerType.get( self._ptr.memspace.value if self._ptr.memspace != AddressSpace.rmem @@ -4626,18 +5760,18 @@ class struct: return self._ptr @property - def dtype(self) -> Numeric: + def dtype(self) -> Type[Numeric]: """ Get the data type of the scalar value. :return: The numeric data type of the underlying pointer. - :rtype: Numeric + :rtype: Type[Numeric] """ return self._ptr.dtype @property @deprecated("Using `struct.scalar` as pointer is deprecated.") - def value(self): + def value(self) -> ir.Value: """ Get the raw MLIR value of the underlying pointer. @@ -4660,7 +5794,7 @@ class struct: # util func for base dsl scalar types @staticmethod - def _is_scalar_type(dtype): + def _is_scalar_type(dtype: Any) -> bool: """ Checks if the given type is a scalar numeric type. @@ -4670,7 +5804,7 @@ class struct: return isinstance(dtype, type) and issubclass(dtype, Numeric) # calculate size and alignment - def __init__(self, cls): + def __init__(self, cls: type) -> None: """ Initializes a new struct decorator instance. @@ -4685,21 +5819,21 @@ class struct: self._offsets: Dict[str, int] = {} # Override `setattr` function for struct to assign scalar properly - def struct_setattr(self, name, value): + def struct_setattr(self: Any, name: str, value: Any) -> None: attr = getattr(self, name, None) if isinstance(attr, struct._ScalarData): value = as_numeric(value).to(attr.dtype) - attr.ptr.store(value) + attr.ptr.store(value) # type: ignore[attr-defined] else: raise ValueError(f"cannot assign value to `{name}` in {self.__name__}") type.__setattr__(self._cls, "__setattr__", struct_setattr) # Override `__repr__` function for struct info - def struct_repr(self): + def struct_repr(self: Any) -> str: return f"{object.__repr__(self)} <{self.__name__}> " - self._cls.__repr__ = struct_repr + type.__setattr__(self._cls, "__repr__", struct_repr) # Calculate the offsets and alignment offset = 0 @@ -4710,11 +5844,11 @@ class struct: # get alignment of member sub_align = 1 if isinstance(member, struct._AlignMeta): - sub_align = member.align + sub_align = member.align # type: ignore[assignment] member = member.dtype # switch addition order to support dynamic size - def add_offset(val): + def add_offset(val: Any) -> Any: return val + offset if isinstance(val, ir.Value) else offset + val # size of scalar @@ -4753,7 +5887,13 @@ class struct: # create the __init__ method for decorated struct @dsl_user_op - def __call__(self, base: Any, *, loc=None, ip=None) -> None: + def __call__( + self, + base: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: """ Creates a new instance of the decorated struct. @@ -4774,7 +5914,7 @@ class struct: obj = obj.dtype if struct._is_scalar_type(obj): ptr = recast_ptr(base + off, dtype=obj, loc=loc, ip=ip) - new_obj = struct._ScalarData(ptr) + new_obj: Any = struct._ScalarData(ptr) object.__setattr__(cls, name, new_obj) elif isinstance(obj, struct._MemRangeMeta): new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off) @@ -4808,7 +5948,7 @@ class struct: # util func for aligning offset @staticmethod - def align_offset(offset, align): + def align_offset(offset: Any, align: int) -> Any: """ Return the round-up offset up to the next multiple of align. """ @@ -4832,7 +5972,7 @@ class union(struct): - The alignment is the maximum alignment of all objects - The size is the maximum size of all objects - **Usage:**Expand commentComment on line R4131 + **Usage:** .. code-block:: python @@ -4874,7 +6014,7 @@ class union(struct): :return: The decorated union class. """ - def __init__(self, cls): + def __init__(self, cls: type) -> None: """ Initializes a new cute.union decorator instance. @@ -4889,18 +6029,18 @@ class union(struct): object.__setattr__(self, "_offsets", {}) # Override `setattr` function for struct to assign scalar properly - def union_setattr(self, name, value): + def union_setattr(self: Any, name: str, value: Any) -> None: attr = getattr(self, name, None) if isinstance(attr, struct._ScalarData): value = as_numeric(value).to(attr.dtype) - attr.ptr.store(value) + attr.ptr.store(value) # type: ignore[attr-defined] else: raise ValueError(f"cannot assign value to `{name}` in {self.__name__}") type.__setattr__(self._cls, "__setattr__", union_setattr) # Override `__repr__` function for struct info - def union_repr(self): + def union_repr(self: Any) -> str: return f"{object.__repr__(self)} <{self.__name__}> " type.__setattr__(self._cls, "__repr__", union_repr) @@ -4917,7 +6057,7 @@ class union(struct): # Get alignment of object sub_align = 1 if isinstance(item, struct._AlignMeta): - sub_align = item.align + sub_align = item.align # type: ignore[assignment] item = item.dtype # Calculate size and alignment based on object type @@ -4946,7 +6086,13 @@ class union(struct): ) @dsl_user_op - def __call__(self, base: Any, *, loc=None, ip=None) -> None: + def __call__( + self, + base: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: """ Creates a new instance of the decorated union. @@ -4966,7 +6112,7 @@ class union(struct): obj = obj.dtype if struct._is_scalar_type(obj): ptr = recast_ptr(base + off, dtype=obj, loc=loc, ip=ip) - new_obj = struct._ScalarData(ptr) + new_obj: Any = struct._ScalarData(ptr) object.__setattr__(cls, name, new_obj) elif isinstance(obj, struct._MemRangeMeta): new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off) @@ -4981,7 +6127,7 @@ class union(struct): ) return cls - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: raise TypeError("Cannot add a new field after initialization") def size_in_bytes(self) -> int: """ @@ -5031,23 +6177,29 @@ class FastDivmodDivisor: """ First-class FastDivmod divisor with operator overloading support. - This class wraps a FastDivmod divisor and enables natural Python operator syntax: + This class wraps a FastDivmod divisor and enables natural Python operator syntax. + + :ivar divisor: The original divisor value (publicly accessible) + :ivar _divisor_mlir: The FastDivmod divisor MLIR value (internal) + + **Example:** + + .. code-block:: python + quotient, remainder = divmod(dividend, divisor) quotient = dividend // divisor remainder = dividend % divisor - - :ivar _divisor: The FastDivmod divisor MLIR value """ @dsl_user_op def __init__( self, divisor: Integer, - is_power_of_2: bool = None, + is_power_of_2: Optional[bool] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Create a FastDivmod divisor for optimized division operations. @@ -5055,6 +6207,9 @@ class FastDivmodDivisor: :param is_power_of_2: Whether divisor is known to be a power of 2. Defaults to False. """ + # Store the original divisor value for public access + self._original_divisor = divisor + # Convert divisor to ir.Value for MLIR operation if isinstance(divisor, ir.Value): divisor_val = divisor @@ -5068,13 +6223,17 @@ class FastDivmodDivisor: # Create FastDivmod divisor fast_divmod_divisor_type = _cute_ir.FastDivmodDivisorType.get(32, is_power_of_2) - self._divisor = _cute_ir.fast_divmod_create_divisor( + self._divisor_mlir = _cute_ir.fast_divmod_create_divisor( fast_divmod_divisor_type, divisor_val, loc=loc, ip=ip ) @dsl_user_op def __rdivmod__( - self, dividend: Integer, *, loc=None, ip=None + self, + dividend: Integer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Integer, Integer]: """ Overload for: divmod(dividend, self) @@ -5098,14 +6257,20 @@ class FastDivmodDivisor: quotient_type, remainder_type, dividend_val, - self._divisor, + self._divisor_mlir, loc=loc, ip=ip, ) return (IntValue(results[0]), IntValue(results[1])) @dsl_user_op - def __rfloordiv__(self, dividend: Integer, *, loc=None, ip=None) -> Integer: + def __rfloordiv__( + self, + dividend: Integer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Integer: """ Overload for: dividend // self Returns quotient only. @@ -5119,7 +6284,13 @@ class FastDivmodDivisor: return quotient @dsl_user_op - def __rmod__(self, dividend: Integer, *, loc=None, ip=None) -> Integer: + def __rmod__( + self, + dividend: Integer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Integer: """ Overload for: dividend % self Returns remainder only. @@ -5132,22 +6303,79 @@ class FastDivmodDivisor: _, remainder = self.__rdivmod__(dividend, loc=loc, ip=ip) return remainder - def __extract_mlir_values__(self): - """Extract MLIR values for Host->Device transfer.""" - return [self._divisor] + @property + def divisor(self) -> Integer: + """ + Get the original divisor value. - def __new_from_mlir_values__(self, values): + This allows users to access the divisor value that was used to create + this FastDivmodDivisor object. This is useful for passing the divisor + value to other functions or for storing it in data structures without + needing to manually track the divisor value separately. + + :return: The original divisor value + :rtype: Integer + + **Example:** + + .. code-block:: python + + batch_size = 32 + batch_fdd = cute.fast_divmod_create_divisor(batch_size) + print(f"Divisor: {batch_fdd.divisor}") # Access the divisor value + some_function(divisor=batch_fdd.divisor) # Pass to other functions + """ + return self._original_divisor + + @divisor.setter + def divisor(self, value: Integer) -> None: + self._original_divisor = value + + # Backward compatibility: _divisor was renamed to _divisor_mlir in 4.5 + @property + def _divisor(self) -> ir.Value: + return self._divisor_mlir + + @_divisor.setter + def _divisor(self, value: ir.Value) -> None: + self._divisor_mlir = value + + def __extract_mlir_values__(self) -> List[ir.Value]: + """Extract MLIR values for Host->Device transfer.""" + # CRITICAL: Extract the FastDivmodDivisor MLIR value directly. + # + # This allows GridInvariantCodeMotionPass to: + # 1. Recognize FastDivmodCreateDivisorOp in the IR + # 2. Hoist it to the host side before kernel launch + # 3. Pass the pre-computed divisor as a kernel argument + # + # We only extract the _divisor_mlir to maintain compatibility with + # other code that assumes each FastDivmodDivisor has exactly 1 MLIR value. + # The _original_divisor is preserved in the object structure. + return [self._divisor_mlir] + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "FastDivmodDivisor": """Reconstruct FastDivmodDivisor from MLIR values.""" + # Directly use the passed FastDivmodDivisor value without recreating it. + # This is critical to avoid generating new create_divisor ops on device side, + # which would bypass GridInvariantCodeMotionPass optimization. new_obj = object.__new__(FastDivmodDivisor) - new_obj._divisor = values[0] + new_obj._divisor_mlir = values[0] + + # Preserve the original divisor to support the public divisor property. + # Note: After host-device transfer, _original_divisor will reference + # the same value as before transfer for constants, or the reconstructed + # value for dynamic expressions. + new_obj._original_divisor = self._original_divisor + return new_obj - def __repr__(self): - return f"FastDivmodDivisor({self._divisor.type})" + def __repr__(self) -> str: + return f"FastDivmodDivisor(divisor={self._original_divisor}, type={self._divisor_mlir.type})" # Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator -FastDivmodDivisor.__init__.__signature__ = inspect.Signature( +FastDivmodDivisor.__init__.__signature__ = inspect.Signature( # type: ignore[attr-defined] [ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter( @@ -5165,22 +6393,30 @@ FastDivmodDivisor.__init__.__signature__ = inspect.Signature( @dsl_user_op def fast_divmod_create_divisor( - divisor: Integer, *, loc=None, ip=None + divisor: Integer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> FastDivmodDivisor: """Create a FastDivmod divisor for optimized division operations. This function creates a FastDivmod divisor that precomputes auxiliary values to enable fast division and modulus operations without using division instructions. - The returned FastDivmodDivisor object supports natural Python operator syntax: - divisor = fast_divmod_create_divisor(batch_size) - quotient, remainder = divmod(linear_idx, divisor) - quotient = linear_idx // divisor - remainder = linear_idx % divisor + The returned FastDivmodDivisor object supports natural Python operator syntax. :param divisor: The divisor value (should be runtime-dynamic value) :type divisor: Integer :return: FastDivmodDivisor object with operator overloading support :rtype: FastDivmodDivisor + + **Example:** + + .. code-block:: python + + divisor = fast_divmod_create_divisor(batch_size) + quotient, remainder = divmod(linear_idx, divisor) + quotient = linear_idx // divisor + remainder = linear_idx % divisor """ return FastDivmodDivisor(divisor, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/experimental/README.md b/python/CuTeDSL/cutlass/cute/experimental/README.md deleted file mode 100644 index 914330ab7..000000000 --- a/python/CuTeDSL/cutlass/cute/experimental/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# CuTe Experimental APIs - -> **Note:** APIs in this module are experimental and subject to change. -> -> This module serves as a staging area for new CuTe functionality that is still under active development. Performance, compile time, and interoperability with CuTe are works in progress. API signatures, behavior, and naming conventions may change without notice between releases. -> -> Once these APIs are stabilized, they will be migrated to the main `cute` submodules. -> -> Users are encouraged to experiment with these APIs but should be prepared to update their code as the interfaces evolve. - -## Core APIs (`core.py`) - -- `elect_sync` — Elects one thread within a warp -- `get_mbarrier` — Returns the mbarrier pointer for a given stage token -- `create_pipeline` — Creates a circular buffer of synchronization primitives indexed by stage count -- `create_pipeline_with_mask` — Creates a pipeline with an arrival mask for cluster-scoped synchronization -- `pipeline_advance_iterator` — Advances a pipeline iterator to the next stage -- `producer_acquire` / `producer_commit` — Producer-side pipeline synchronization -- `consumer_wait` / `consumer_release` / `consumer_tail` — Consumer-side pipeline synchronization -- `get_pipeline_produce_stage` / `get_pipeline_consume_stage` — Gets pipeline stage tokens - -## Memory APIs (`memory.py`) - -- `allocate` — Allocate a buffer with given type, layout, and address space -- `tma_load` — Copy tensor from global memory to shared memory using TMA -- `tma_load_multicast` — Copy tensor from global memory to shared memory using TMA with multicast -- `tma_store` — Copy tensor from shared memory to global memory using TMA -- `copy` — Copy tensor from src to dst using a given copy atom - -## Algorithm APIs (`algorithm.py`) - -- `simt_auto_vec_copy` — Copies a tensor between buffers with single thread (auto-vectorized) -- `partition` — Partition a buffer into a given layout and tiler -- `partition_and_copy` — Combines partitioning and copying in a single operation - -## Math APIs (`math.py`) - -- `dot` — Computes a dot product of two tensors using an MMA atom -- `dot_block_scaled` — Computes a block-scaled dot product with scale factors - -## Pipeline Classes (`pipeline.py`) - -- `GenericPipeline` — Generic pipeline for any producer/consumer combination -- `TMAToUMMAPipeline` — Pipeline for TMA load to UMMA consumption -- `TMAToAsyncPipeline` — Pipeline for TMA load to async consumer -- `AsyncToUMMAPipeline` — Pipeline for async producer to UMMA consumption -- `UMMAtoAsyncPipeline` — Pipeline for UMMA producer to async consumer -- `TMAStorePipeline` — Pipeline for SMEM producer to TMA store consumer - -## Utilities (`utils.py`) - -- `get_cta_v_map_ab` — Compute CTA-V map for A/B operands -- `get_cta_v_map_c` — Compute CTA-V map for C operand -- `make_tmem_layout_acc` — Derive TMEM accumulator buffer layout from a tiled MMA -- `make_tmem_layout_a` — Derive TMEM A-operand buffer layout from a tiled MMA -- `make_t2r_rmem_layout` — Derive per-thread RMEM buffer layout for the T2R epilogue copy - diff --git a/python/CuTeDSL/cutlass/cute/experimental/__init__.py b/python/CuTeDSL/cutlass/cute/experimental/__init__.py index 571629656..2e3b59a87 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/__init__.py +++ b/python/CuTeDSL/cutlass/cute/experimental/__init__.py @@ -17,6 +17,7 @@ compile = _dsl.CompileCallable() from .algorithm import * from .core import * +from .host_runtime import * from .math import * from .memory import * from .pipeline import * diff --git a/python/CuTeDSL/cutlass/cute/experimental/algorithm.py b/python/CuTeDSL/cutlass/cute/experimental/algorithm.py index 90766ca98..8977a0271 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/algorithm.py +++ b/python/CuTeDSL/cutlass/cute/experimental/algorithm.py @@ -9,8 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from typing import Optional + from cutlass import cute from cutlass.cutlass_dsl import dsl_user_op +from cutlass._mlir import ir from cutlass._mlir.dialects import lir as cutlass_lir from .memory import copy @@ -18,17 +21,14 @@ from .memory import copy @dsl_user_op def simt_auto_vec_copy( - src: cute.Tensor, dst: cute.Tensor, async_op=False, loc=None, ip=None -): + src: cute.Tensor, + dst: cute.Tensor, + async_op: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ - Copies a tensor between two cute.memref buffers with single thread. - - :param src: Source tensor - :type src: cute.Tensor - :param dst: Destination tensor - :type dst: cute.Tensor - :param async_op: Whether to use asynchronous operation, defaults to False - :type async_op: bool, optional + Copies a tensor between two cute.memref buffers with single thread """ if async_op: cutlass_lir.SimtAutoVecCopyOp( @@ -40,19 +40,16 @@ def simt_auto_vec_copy( @dsl_user_op def partition( - buffer: cute.Tensor, agent_id: cute.Int32, *, layout_tv, tiler, loc=None, ip=None + buffer: cute.Tensor, + agent_id: cute.Int32, + *, + layout_tv: cute.Layout, + tiler: cute.Layout, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Tensor: """ Partition a buffer into a given layout and tiler. - - :param buffer: Buffer to partition - :type buffer: cute.Tensor - :param agent_id: Agent ID - :type agent_id: cute.Int32 - :param layout_tv: Layout tensor - :type layout_tv: cute.Tensor - :param tiler: Tiler - :type tiler: cute.Tensor """ assert isinstance(agent_id, cute.Int32), ( f"Expected agent_id to be cute.Int32, got {type(agent_id)}" @@ -74,18 +71,11 @@ def partition_and_copy( src: cute.Tensor, dst: cute.Tensor, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Copies a tensor between two cute.memref buffer - - :param tiled_copy: Tiled copy - :type tiled_copy: cute.core.ThrCopy - :param src: Source tensor - :type src: cute.Tensor - :param dst: Destination tensor - :type dst: cute.Tensor """ src_partitioned = src dst_partitioned = dst @@ -144,7 +134,7 @@ def partition_and_copy( copy( src_partitioned, dst_partitioned, - copy_atom=tiled_copy, + copy_atom=cute.make_copy_atom(tiled_copy.op, src.element_type), loc=loc, ip=ip, ) diff --git a/python/CuTeDSL/cutlass/cute/experimental/core.py b/python/CuTeDSL/cutlass/cute/experimental/core.py index 42159d015..bedf22a2a 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/core.py +++ b/python/CuTeDSL/cutlass/cute/experimental/core.py @@ -1,12 +1,15 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Protocol, TypeAlias from cutlass.cutlass_dsl import dsl_user_op from cutlass._mlir.dialects import lir as cutlass_lir_ir, nvvm as _nvvm @@ -14,18 +17,34 @@ from cutlass._mlir import ir from cutlass.cutlass_dsl import lru_cache_ir from cutlass._mlir.dialects.core import OperationTypeEnum from cutlass import cute +from cutlass.cute.typing import Boolean + + +class _SupportsIrValue(Protocol): + def ir_value( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: ... + + +SkipWaitToken: TypeAlias = bool | ir.Value | _SupportsIrValue @dsl_user_op -def elect_sync(loc=None, ip=None): - """ - Elects one predicated thread within a warp. - """ +def elect_sync( + loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> ir.Value: return _nvvm.elect_sync(loc=loc, ip=ip) @dsl_user_op -def get_mbarrier(stage_token, loc=None, ip=None): +def get_mbarrier( + stage_token: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Returns the mbarrier pointer for a given stage token. """ @@ -34,7 +53,7 @@ def get_mbarrier(stage_token, loc=None, ip=None): @ir.register_value_caster(cutlass_lir_ir.PipelineStateType.get_static_typeid()) class PipelineState(ir.Value): - def __init__(self, value): + def __init__(self, value: ir.Value) -> None: if isinstance(value, ir.Value): self.value = value else: @@ -47,48 +66,90 @@ class PipelineState(ir.Value): return self.value.type @classmethod - def __new_from_mlir_values__(cls, values): + def __new_from_mlir_values__(cls, values: list[ir.Value]) -> "PipelineState": assert len(values) == 1, f"Expected 1 value, but got {len(values)}" return PipelineState(values[0]) -@dsl_user_op -def create_pipeline( +def _normalize_create_pipeline_arrival_mask( + arrival_mask: Optional[cute.Int16], + compat_kwargs: dict[str, object], +) -> Optional[cute.Int16]: + # Legacy source compatibility: older callers used `multicast` as the sixth + # argument. Keep `False` working, but force `True` callers onto the explicit + # mask APIs because the legacy path produced incorrect IR. + # Remove this shim once the team is comfortable breaking low-level + # create_pipeline() callers and dropping the legacy multicast spelling. + multicast = compat_kwargs.pop("multicast", None) + if compat_kwargs: + unexpected_arg = next(iter(compat_kwargs)) + raise TypeError( + f"create_pipeline() got an unexpected keyword argument '{unexpected_arg}'" + ) + + if multicast is not None: + if not isinstance(multicast, bool): + raise TypeError(f"Expected `multicast` to be a bool, got {type(multicast)}") + if arrival_mask is not None: + raise ValueError( + "create_pipeline() does not accept both `arrival_mask` and legacy `multicast`." + ) + if multicast: + raise ValueError( + "create_pipeline(multicast=True) is no longer supported; " + "use create_pipeline(..., arrival_mask=...) or " + "create_pipeline_with_mask(...)." + ) + return None + + if isinstance(arrival_mask, bool): + if arrival_mask: + raise ValueError( + "create_pipeline(True) no longer supports the legacy multicast " + "form; use create_pipeline(..., arrival_mask=...) or " + "create_pipeline_with_mask(...)." + ) + return None + + return arrival_mask + + +def _build_pipeline( stage: cute.Int32, producer: OperationTypeEnum, consumer: OperationTypeEnum, producer_arv_count: cute.Int32, consumer_arv_count: cute.Int32, - loc=None, - ip=None, + arrival_mask: Optional[cute.Int16], + loc: Optional[ir.Location], + ip: Optional[ir.InsertionPoint], ) -> tuple[PipelineState, PipelineState, PipelineState]: - """ - Creates an abstraction for a circular buffer of synchronizatoin primitives - indexed by stage count. - - :param stage: Stage count - :type stage: cute.Int32 - :param producer: Producer operation type - :type producer: OperationTypeEnum - :param consumer: Consumer operation type - :type consumer: OperationTypeEnum - :param producer_arv_count: Producer arrival count - :type producer_arv_count: cute.Int32 - :param consumer_arv_count: Consumer arrival count - :type consumer_arv_count: cute.Int32 - """ if isinstance(producer_arv_count, int): producer_arv_count = cute.Int32(producer_arv_count) if isinstance(consumer_arv_count, int): consumer_arv_count = cute.Int32(consumer_arv_count) - result = ir.Type.parse(f"!lir.pipeline<{stage}, {producer} -> {consumer}>") - op = cutlass_lir_ir.CreatePipelineOp( - result, - producer_arv_count.ir_value(), - consumer_arv_count.ir_value(), - loc=loc, - ip=ip, - ) + + if arrival_mask is not None: + if isinstance(arrival_mask, int): + arrival_mask = cute.Int16(arrival_mask) + result = ir.Type.parse(f"!lir.pipeline<{stage}, {producer} -> {consumer}>") + op = cutlass_lir_ir.CreatePipelineWithMaskOp( + result, + producer_arv_count.ir_value(), + consumer_arv_count.ir_value(), + arrival_mask.ir_value(), + loc=loc, + ip=ip, + ) + else: + result = ir.Type.parse(f"!lir.pipeline<{stage}, {producer} -> {consumer}>") + op = cutlass_lir_ir.CreatePipelineOp( + result, + producer_arv_count.ir_value(), + consumer_arv_count.ir_value(), + loc=loc, + ip=ip, + ) pipeline = op.result result = ir.Type.parse(f"!lir.pipeline_state<{stage}>") @@ -102,6 +163,45 @@ def create_pipeline( return pipeline, producer_state, consumer_state +@dsl_user_op +def create_pipeline( + stage: cute.Int32, + producer: OperationTypeEnum, + consumer: OperationTypeEnum, + producer_arv_count: cute.Int32, + consumer_arv_count: cute.Int32, + arrival_mask: Optional[cute.Int16] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **compat_kwargs: object, +) -> tuple[PipelineState, PipelineState, PipelineState]: + """ + Creates an abstraction for a circular buffer of synchronization primitives + indexed by stage count. + + Args: + stage: Number of pipeline stages. + producer: Producer operation type. + consumer: Consumer operation type. + producer_arv_count: Number of producer arrivals. + consumer_arv_count: Number of consumer arrivals. + arrival_mask: Optional arrival mask for multi-CTA synchronization + (2SM or multicast). When provided, creates the pipeline with + explicit mask-based barrier configuration. + """ + arrival_mask = _normalize_create_pipeline_arrival_mask(arrival_mask, compat_kwargs) + return _build_pipeline( + stage, + producer, + consumer, + producer_arv_count, + consumer_arv_count, + arrival_mask, + loc, + ip, + ) + + @dsl_user_op def create_pipeline_with_mask( stage: cute.Int32, @@ -110,53 +210,29 @@ def create_pipeline_with_mask( producer_arv_count: cute.Int32, consumer_arv_count: cute.Int32, arrival_mask: cute.Int16, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> tuple[PipelineState, PipelineState, PipelineState]: - """ - Creates a pipeline with an arrival mask for cluster-scoped synchronization. - - :param stage: Pipeline stage count. - :param producer: Producer operation type (e.g. SM90_TMA_LOAD_MULTICAST). - :param consumer: Consumer operation type (e.g. SM100_MMA_2SM_SS). - :param producer_arv_count: Producer arrival count for the pipeline barriers. - :param consumer_arv_count: Consumer arrival count for the pipeline barriers. - :param arrival_mask: Bitmask that selects participating peers (e.g. CTAs in a - cluster). This is attached to the pipeline value and is consulted by some - pipeline lowerings to generate cluster-scoped synchronization - """ - if isinstance(producer_arv_count, int): - producer_arv_count = cute.Int32(producer_arv_count) - if isinstance(consumer_arv_count, int): - consumer_arv_count = cute.Int32(consumer_arv_count) - if isinstance(arrival_mask, int): - arrival_mask = cute.Int16(arrival_mask) - - result = ir.Type.parse(f"!lir.pipeline<{stage}, {producer} -> {consumer}>") - op = cutlass_lir_ir.CreatePipelineWithMaskOp( - result, - producer_arv_count.ir_value(), - consumer_arv_count.ir_value(), - arrival_mask.ir_value(), - loc=loc, - ip=ip, + """Backward-compatible wrapper. Prefer create_pipeline(..., arrival_mask=...).""" + return _build_pipeline( + stage, + producer, + consumer, + producer_arv_count, + consumer_arv_count, + arrival_mask, + loc, + ip, ) - pipeline = op.result - - result = ir.Type.parse(f"!lir.pipeline_state<{stage}>") - op = cutlass_lir_ir.CreatePipelineStateOp(result, pipeline, loc=loc, ip=ip) - producer_state = op.result - - result = ir.Type.parse(f"!lir.pipeline_state<{stage}>") - op = cutlass_lir_ir.CreatePipelineStateOp(result, pipeline, loc=loc, ip=ip) - consumer_state = op.result - - return pipeline, producer_state, consumer_state - @dsl_user_op -def pipeline_advance_iterator(pipe, state, loc=None, ip=None): +def pipeline_advance_iterator( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Advances a pipeline iterator to the next stage. """ @@ -165,7 +241,12 @@ def pipeline_advance_iterator(pipe, state, loc=None, ip=None): @dsl_user_op -def producer_acquire(pipe, state, loc=None, ip=None): +def producer_acquire( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Acquires exclusive access to a pipeline. """ @@ -174,7 +255,12 @@ def producer_acquire(pipe, state, loc=None, ip=None): @dsl_user_op -def producer_commit(pipe, state, loc=None, ip=None): +def producer_commit( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Commits results to a pipeline. """ @@ -183,7 +269,12 @@ def producer_commit(pipe, state, loc=None, ip=None): @dsl_user_op -def consumer_wait(pipe, state, loc=None, ip=None): +def consumer_wait( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Waits for a pipeline to transition to `full`. """ @@ -192,7 +283,12 @@ def consumer_wait(pipe, state, loc=None, ip=None): @dsl_user_op -def consumer_release(pipe, state, loc=None, ip=None): +def consumer_release( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Releases a pipeline that has been consumed. """ @@ -201,7 +297,28 @@ def consumer_release(pipe, state, loc=None, ip=None): @dsl_user_op -def consumer_tail(pipe, state, loc=None, ip=None): +def consumer_release_elect_one_sync( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: + """ + Releases a pipeline that has been consumed. + """ + op = cutlass_lir_ir.ConsumerReleaseOp( + pipe, state, elect_one_sync=True, loc=loc, ip=ip + ) + return op.result + + +@dsl_user_op +def consumer_tail( + pipe: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: """ Called by the consumer to block until asynchronous tasks have completed. """ @@ -210,7 +327,12 @@ def consumer_tail(pipe, state, loc=None, ip=None): @dsl_user_op -def get_pipeline_produce_stage(pipeline, state, loc=None, ip=None): +def get_pipeline_produce_stage( + pipeline: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple[ir.Value, ir.Value]: """ Gets a pipeline produce stage. """ @@ -228,7 +350,12 @@ def get_pipeline_produce_stage(pipeline, state, loc=None, ip=None): @dsl_user_op -def get_pipeline_consume_stage(pipeline, state, loc=None, ip=None): +def get_pipeline_consume_stage( + pipeline: ir.Value, + state: ir.Value, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> tuple[ir.Value, ir.Value]: """ Creates a pipeline consume stage. """ @@ -243,3 +370,244 @@ def get_pipeline_consume_stage(pipeline, state, loc=None, ip=None): ip=ip, ) return op.stage_token, op.stage_index + + +@ir.register_value_caster( + cutlass_lir_ir.CircularBufferPipelineStateType.get_static_typeid() +) +class CircularBufferPipelineState(ir.Value): + def __init__(self, value: ir.Value) -> None: + if isinstance(value, ir.Value): + self.value = value + else: + raise TypeError(f"Expected ir.Value, got {type(value)}") + super().__init__(value) + + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @classmethod + def __new_from_mlir_values__( + cls, values: list[ir.Value] + ) -> "CircularBufferPipelineState": + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + return CircularBufferPipelineState(values[0]) + + +@dsl_user_op +def create_circular_buffer_pipeline( + pipeline: ir.Value, + pipeline_state: PipelineState, + stages: int, + count_per_stage: int, + count_per_iteration: int, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> CircularBufferPipelineState: + """ + Creates a circular buffer abstraction layered on top of a lir.pipeline. + + Each pipeline stage is subdivided into `count_per_stage` units. + Operations can advance the circular buffer position by `count_per_iteration` units + at a time in a FIFO manner. The abstraction provides synchronized access to + pipeline stages given the circular buffer position. + + Args: + pipeline: The underlying pipeline object + pipeline_state: Initial pipeline state + stages: Number of pipeline stages + count_per_stage: Number of units per pipeline stage + count_per_iteration: Number of units per iteration (chunk size) + loc: Source location + ip: Insertion point + + Returns: + CircularBufferPipelineState: The circular buffer pipeline state + """ + result_type = ir.Type.parse( + f"!lir.circular_buffer_pipeline_state<{stages}, {count_per_stage}, {count_per_iteration}>" + ) + op = cutlass_lir_ir.CreateCircularBufferPipelineOp( + result_type, + pipeline, + pipeline_state, + loc=loc, + ip=ip, + ) + return op.result + + +@dsl_user_op +def circular_buffer_pipeline_consume( + pipeline: ir.Value, + circular_buffer_pipeline_state: CircularBufferPipelineState, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """ + Synchronize pipeline stages needed for circular buffer consumption. + + This operation performs synchronization for the circular buffer consumer. + Based on the current circular buffer position and `count_per_iteration`, it + determines which pipeline stages need to be synchronized and waits for them + to transition to full before consumption can proceed. + + Args: + pipeline: The underlying pipeline object + circular_buffer_pipeline_state: Current circular buffer pipeline state + loc: Source location + ip: Insertion point + """ + cutlass_lir_ir.CircularBufferPipelineConsumeOp( + pipeline, + circular_buffer_pipeline_state, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def circular_buffer_pipeline_consumer_release( + pipeline: ir.Value, + circular_buffer_pipeline_state: CircularBufferPipelineState, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """ + Release pipeline stages after circular buffer consumption. + + This operation releases pipeline stages after circular buffer consumption. + Based on the current circular buffer position and `count_per_iteration`, it + determines which pipeline stages have been fully consumed and transitions them + to empty. + + Args: + pipeline: The underlying pipeline object + circular_buffer_pipeline_state: Current circular buffer pipeline state + loc: Source location + ip: Insertion point + """ + cutlass_lir_ir.CircularBufferPipelineConsumerReleaseOp( + pipeline, + circular_buffer_pipeline_state, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def circular_buffer_pipeline_advance_iterator( + pipeline: ir.Value, + circular_buffer_pipeline_state: CircularBufferPipelineState, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> CircularBufferPipelineState: + """ + Advance the circular buffer position. + + This operation advances the circular buffer position by `count_per_iteration` + units. + + Args: + pipeline: The underlying pipeline + circular_buffer_pipeline_state: Current circular buffer pipeline state + loc: Source location + ip: Insertion point + + Returns: + CircularBufferPipelineState: Updated circular buffer pipeline state with advanced offset + """ + op = cutlass_lir_ir.CircularBufferPipelineAdvanceIteratorOp( + pipeline, + circular_buffer_pipeline_state, + loc=loc, + ip=ip, + ) + return op.result + + +@dsl_user_op +def mbarrier_expect_tx( + mbarPtr: ir.Value, + txBytes: cute.Int32, + ctaId: Optional[ir.Value] = None, + elect_one_sync: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """ + Called by the consumer to block until asynchronous tasks have completed. Supports optional broadcast. + """ + if isinstance(txBytes, int): + txBytes = cute.Int32(txBytes) + if ctaId != None: + ctaId = ctaId.value + _op = cutlass_lir_ir.MBarrierExpectTxOp( + mbarPtr.value, + txBytes.ir_value(), + ctaId=ctaId, + elect_one_sync=elect_one_sync, + loc=loc, + ip=ip, + ) + return + + +def normalize_skip_wait_token( + token: Optional[SkipWaitToken], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Optional[ir.Value]: + """ + Normalizes a skip wait token to an ir.Value. + """ + if token is None: + return None + if isinstance(token, bool): + return Boolean(token).ir_value(loc=loc, ip=ip) + if isinstance(token, ir.Value): + return token + if hasattr(token, "ir_value"): + return token.ir_value(loc=loc, ip=ip) + raise TypeError(f"skipWait token must lower to ir.Value, got {type(token)}") + + +@dsl_user_op +def producer_try_acquire( + pipe: ir.Value, + state: ir.Value, + *, + token: Optional[SkipWaitToken] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Boolean: + """ + Tries to acquire a producer stage, non-blocking. + """ + skip_wait = normalize_skip_wait_token(token, loc=loc, ip=ip) + token_value = cutlass_lir_ir.ProducerTryAcquireOp( + pipe, state, skipWait=skip_wait, loc=loc, ip=ip + ).token + return Boolean(token_value, loc=loc, ip=ip) + + +@dsl_user_op +def consumer_try_wait( + pipe: ir.Value, + state: ir.Value, + *, + token: Optional[SkipWaitToken] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Boolean: + """ + Tries to wait for a consumer stage, non-blocking. + """ + skip_wait = normalize_skip_wait_token(token, loc=loc, ip=ip) + token_value = cutlass_lir_ir.ConsumerTryWaitOp( + pipe, state, skipWait=skip_wait, loc=loc, ip=ip + ).token + return Boolean(token_value, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/cute/experimental/host_runtime.py b/python/CuTeDSL/cutlass/cute/experimental/host_runtime.py new file mode 100644 index 000000000..373095fe4 --- /dev/null +++ b/python/CuTeDSL/cutlass/cute/experimental/host_runtime.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +"""Host-side runtime helpers for querying and driving JIT-compiled kernels.""" + +import ctypes +from typing import Any + +from cutlass.base_dsl.jit_executor import AuxRuntimeFunc, DSLRuntimeError, ExecutionArgs + +__all__ = [ + "AllocationRequirement", + "QueryDeviceWorkspaceFunc", +] + + +class AllocationRequirement(ctypes.Structure): + """Mirrors the ``AllocationRequirement`` struct produced by ``queryDeviceWorkspace``. + + .. code-block:: c + + struct AllocationRequirement { int64_t sizeInBytes; int64_t alignment; }; + """ + + _fields_ = [ + ("size_in_bytes", ctypes.c_int64), + ("alignment", ctypes.c_int64), + ] + + def __repr__(self) -> str: + return ( + f"AllocationRequirement(size_in_bytes={self.size_in_bytes}, " + f"alignment={self.alignment})" + ) + + +class QueryDeviceWorkspaceFunc(AuxRuntimeFunc): + """Callable wrapper for the ``queryDeviceWorkspace`` symbol of a single kernel. + + Usage:: + + query = compiled_fn.get_aux_func(QueryDeviceWorkspaceFunc, kernel=my_kernel) + req = query(*kernel_args) + workspace = torch.empty(req.size_in_bytes, dtype=torch.uint8, device="cuda") + compiled_fn(*kernel_args, from_dlpack(workspace)) + + :param func_ptr: Raw integer address of the packed-args wrapper + returned by ``engine.raw_lookup``. + :param args_spec: The :class:`~cutlass.base_dsl.jit_executor.ExecutionArgs` + instance used to build logical-arg arrays from Python tensors. + """ + + name = "queryDeviceWorkspace" + + def __init__(self, func_ptr: int, args_spec: ExecutionArgs) -> None: + self._raw_fn = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(func_ptr) + self._args_spec = args_spec + + def __call__(self, *args: Any, **kwargs: Any) -> AllocationRequirement: + """Query workspace requirements for the given kernel arguments. + + Accepts the same positional/keyword arguments as the kernel call. + Returns the :class:`AllocationRequirement` reported by the kernel. + + :raises DSLRuntimeError: If the host function returns a non-zero error code. + """ + exe_args, adapted_args = self._args_spec.generate_execution_args(args, kwargs) + logical_args = (ctypes.c_void_p * len(exe_args))(*exe_args) + arg0 = ctypes.c_void_p(ctypes.addressof(logical_args)) + + # MLIR packed-args calling convention: void f(void **all_args) where + # all_args[0] = &arg0 (logical_args_ptr value) + # all_args[1] = &arg1 (AllocationRequirement* value) + # all_args[2] = &retval (i32 return value storage) + req = AllocationRequirement() + arg1 = ctypes.c_void_p(ctypes.addressof(req)) + retval = ctypes.c_int32(0) + all_args = (ctypes.c_void_p * 3)( + ctypes.addressof(arg0), + ctypes.addressof(arg1), + ctypes.addressof(retval), + ) + self._raw_fn(all_args) + ret = retval.value + if ret != 0: + raise DSLRuntimeError(f"queryDeviceWorkspace failed with return code {ret}") + + del adapted_args # keep alive until after the call + return req diff --git a/python/CuTeDSL/cutlass/cute/experimental/math.py b/python/CuTeDSL/cutlass/cute/experimental/math.py index 5eedee9c5..99b5a785d 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/math.py +++ b/python/CuTeDSL/cutlass/cute/experimental/math.py @@ -9,8 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from typing import Optional + from cutlass import cute from cutlass.cutlass_dsl import dsl_user_op +from cutlass._mlir import ir from cutlass._mlir.dialects import lir as cutlass_lir @@ -22,25 +25,9 @@ def dot_block_scaled( b: cute.Tensor, sfb: cute.Tensor, c: cute.Tensor, - loc=None, - ip=None, -): - """ - Computes the dot product of two tensors with block scaling and accumulates the result into a third tensor. - - :param mma_atom: MMA atom - :type mma_atom: cute.MmaAtom - :param a: First tensor - :type a: cute.Tensor - :param sfa: First scale factor tensor - :type sfa: cute.Tensor - :param b: Second tensor - :type b: cute.Tensor - :param sfb: Second scale factor tensor - :type sfb: cute.Tensor - :param c: Result tensor - :type c: cute.Tensor - """ + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: cutlass_lir.DotBlockScaledOp( mma_atom._unpack(), a.value, @@ -59,21 +46,9 @@ def dot( a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, - loc=None, - ip=None, -): - """ - Computes the dot product of two tensors and accumulates the result into a third tensor. - - :param mma_atom: MMA atom - :type mma_atom: cute.MmaAtom - :param a: First tensor - :type a: cute.Tensor - :param b: Second tensor - :type b: cute.Tensor - :param c: Result tensor - :type c: cute.Tensor - """ + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: cutlass_lir.DotOp( mma_atom._unpack(), a.value, @@ -82,3 +57,5 @@ def dot( loc=loc, ip=ip, ) + + diff --git a/python/CuTeDSL/cutlass/cute/experimental/memory.py b/python/CuTeDSL/cutlass/cute/experimental/memory.py index 9d36437d1..a6f84d0db 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/memory.py +++ b/python/CuTeDSL/cutlass/cute/experimental/memory.py @@ -21,7 +21,9 @@ from cutlass._mlir.dialects.core import OperationTypeEnum from cutlass import cute -def _get_tma_load_kind(tma_operation_type: OperationTypeEnum): +def _get_tma_load_kind( + tma_operation_type: OperationTypeEnum, +) -> _cute_ir.TiledTmaLoadEnum: """Convert OperationTypeEnum to TiledTmaLoadEnum.""" if tma_operation_type == OperationTypeEnum.SM100_TMA_LOAD_2SM_MULTICAST: return _cute_ir.TiledTmaLoadEnum.sm_100_2sm_multicast @@ -41,8 +43,8 @@ def allocate( layout: cute.Layout | cute.ComposedLayout, alignment: cute.Int32, is2cta: bool = False, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Tensor: """ Allocate a buffer of the given type and layout. @@ -63,10 +65,10 @@ def allocate( swizzle = layout.inner layout = layout.outer - # Handle SparseElemType (pass through) vs regular types (get mlir_type) - if isinstance(type, _cute_ir.SparseElemType): - pass - else: + bit_layout = None + + _is_passthrough_type = False + if not _is_passthrough_type: type = type.mlir_type ptr_ty = _cute_ir.PtrType.get( @@ -74,6 +76,8 @@ def allocate( address_space, alignment, swizzle.type.attribute if swizzle else None, + None, + bit_layout.type.attribute if bit_layout else None, ) buffer_type = _cute_ir.MemRefType.get(ptr_ty, layout.type) @@ -90,15 +94,15 @@ def allocate( def tma_load( src: cute.Tensor, dst: cute.Tensor, - mbar, + mbar: ir.Value, *, - cta_v_map, + cta_v_map: Optional[cute.Layout] = None, tma_operation_type: Optional[OperationTypeEnum] = None, - internal_type=None, + internal_type: Optional[Type[cute.Numeric]] = None, update_expect_tx: bool = True, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Copies a tensor pointed by a !cute.memref into a Buffer using TMA. @@ -106,23 +110,6 @@ def tma_load( When used with Cute DSL pipelines, it must be set to False as the pipeline already initializes the mbarrier's transaction bytes. tma_operation_type (optional): specifies the TMA operation type (SM90_TMA_LOAD, SM100_TMA_LOAD_2SM, etc.) internal_type (optional): selects the TMA transfer's internal element encoding used by hardware. - Does not change src/dst memref types. For structured sparsity, use base storage types: - Float16 for 2:4 FP16 sparse element type, Uint8 for 8:1 uint8 sparse element type. - - :param src: Source tensor in global memory - :type src: cute.Tensor - :param dst: Destination tensor in shared memory - :type dst: cute.Tensor - :param mbar: Memory barrier for synchronization - :type mbar: cute.core.Mbarrier - :param cta_v_map: CTA V-map for the tensor - :type cta_v_map: cute.core.CtaVMap - :param tma_operation_type: TMA operation type (e.g., SM90_TMA_LOAD, SM100_TMA_LOAD_2SM, etc.) - :type tma_operation_type: OperationTypeEnum - :param internal_type: Internal type of the TMA transfer - :type internal_type: cute.core.InternalType - :param update_expect_tx: Whether to update expected transaction bytes - :type update_expect_tx: bool """ if tma_operation_type is not None: kind = _get_tma_load_kind(tma_operation_type) @@ -130,11 +117,12 @@ def tma_load( kind = _cute_ir.TiledTmaLoadEnum.sm_90 kwargs = { - "cta_v_map": cta_v_map.type.attribute, "kind": kind, "loc": loc, "ip": ip, } + if cta_v_map is not None: + kwargs["cta_v_map"] = cta_v_map.type.attribute # Map internal_type to tma_format per updated API if internal_type is not None: internal_mlir_ty = ( @@ -148,7 +136,6 @@ def tma_load( if update_expect_tx: kwargs["update_expect_tx"] = True - cutlass_lir.TmaLoadOp(src.value, dst.value, mbar, **kwargs) @@ -156,16 +143,16 @@ def tma_load( def tma_load_multicast( src: cute.Tensor, dst: cute.Tensor, - mbar, + mbar: ir.Value, *, vmnk_layout: cute.Layout, - cta_v_map, + cta_v_map: Optional[cute.Layout] = None, tma_operation_type: OperationTypeEnum, multicast_mode: int, update_expect_tx: bool = True, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Copies a tensor pointed by a !cute.memref into a Buffer using TMA with multicast. @@ -180,17 +167,17 @@ def tma_load_multicast( """ kind = _get_tma_load_kind(tma_operation_type) kwargs = { - "cta_v_map": cta_v_map.type.attribute, "kind": kind, "vmnk_layout": vmnk_layout, "multicast_mode": multicast_mode, "loc": loc, "ip": ip, } + if cta_v_map is not None: + kwargs["cta_v_map"] = cta_v_map.type.attribute if update_expect_tx: kwargs["update_expect_tx"] = True - cutlass_lir.TmaLoadMulticastOp( src.value, dst.value, @@ -204,34 +191,23 @@ def tma_store( src: cute.Tensor, dst: cute.Tensor, *, - cta_v_map, - internal_type=None, - loc=None, - ip=None, -): + cta_v_map: Optional[cute.Layout] = None, + internal_type: Optional[Type[cute.Numeric]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Copies a tensor from a Buffer to a tensor pointed to by a !cute.memref. internal_type (optional): selects the TMA transfer's internal element encoding used by hardware. - Does not change src/dst memref types. For structured sparsity, use base storage types: - Float16 for 2:4 FP16 sparse element type, Uint8 for 8:1 uint8 sparse element type. - - - :param src: Source tensor in shared memory - :type src: cute.Tensor - :param dst: Destination tensor in global memory - :type dst: cute.Tensor - :param cta_v_map: CTA V-map for the tensor - :type cta_v_map: cute.core.CtaVMap - :param internal_type: Internal type of the TMA transfer - :type internal_type: cute.core.InternalType """ kwargs = { - "cta_v_map": cta_v_map.type.attribute, "loc": loc, "ip": ip, } + if cta_v_map is not None: + kwargs["cta_v_map"] = cta_v_map.type.attribute # Map internal_type to tma_format per updated API if internal_type is not None: @@ -248,7 +224,14 @@ def tma_store( @dsl_user_op -def copy(src: cute.Tensor, dst: cute.Tensor, *, copy_atom, loc=None, ip=None): +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + copy_atom: cute.CopyAtom, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Copy a tensor from src to dst using a given copy atom. """ diff --git a/python/CuTeDSL/cutlass/cute/experimental/pipeline.py b/python/CuTeDSL/cutlass/cute/experimental/pipeline.py index 20c1f3eec..0850823e2 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/pipeline.py +++ b/python/CuTeDSL/cutlass/cute/experimental/pipeline.py @@ -1,59 +1,73 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. """ Convenience pipeline classes that hide elect_one synchronization complexity """ from dataclasses import dataclass -from typing import Optional +from typing import NoReturn, Optional import cutlass import cutlass.cute as cute -from cutlass._mlir.dialects import lir as cutlass_lir_ir from cutlass.base_dsl.typing import Int32 +from cutlass._mlir import ir +from cutlass._mlir.dialects import lir as cutlass_lir_ir from cutlass._mlir.dialects.core import OperationTypeEnum +from cutlass.cute.typing import Boolean from cutlass.cute.experimental.core import ( create_pipeline, - create_pipeline_with_mask, producer_acquire, get_pipeline_produce_stage, get_pipeline_consume_stage, producer_commit, consumer_release, pipeline_advance_iterator, + PipelineState, consumer_wait, consumer_tail, + create_circular_buffer_pipeline, + circular_buffer_pipeline_consume, + circular_buffer_pipeline_consumer_release, + circular_buffer_pipeline_advance_iterator, + mbarrier_expect_tx, + normalize_skip_wait_token, + producer_try_acquire, + consumer_try_wait, + SkipWaitToken, ) from cutlass.cutlass_dsl import CuteExperimentalDSL +from ..typing import Pointer + class GenericPipelineBase: """Base class for pipeline convenience wrappers""" def __init__( self, - raw_pipeline, - num_stages, - producer_state, - consumer_state, - ): + raw_pipeline: ir.Value, + num_stages: cute.Int32, + producer_state: ir.Value, + consumer_state: ir.Value, + ) -> None: self.raw_pipeline = raw_pipeline self.num_stages = num_stages # For convenience class, we always manage state internally self.producer_state = producer_state self.consumer_state = consumer_state - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list: """Extract MLIR values for DynamicExpression protocol.""" # raw_pipeline is always ir.OpResult from create_pipeline (no __extract_mlir_values__) pipeline_values = [self.raw_pipeline] @@ -68,14 +82,14 @@ class GenericPipelineBase: return ( pipeline_values + [ - num_stages_dsl.__extract_mlir_values__()[0], + num_stages_dsl.__extract_mlir_values__()[0], # type: ignore[attr-defined] ] + producer_state_values + consumer_state_values ) @classmethod - def __new_from_mlir_values__(cls, values): + def __new_from_mlir_values__(cls, values: list) -> "GenericPipelineBase": """Reconstruct object from MLIR values.""" # Parse the known structure: [pipeline] + [num_stages, producer_flag, consumer_flag] + [producer_state] + [consumer_state] # All lir_* objects are single MLIR values @@ -84,53 +98,73 @@ class GenericPipelineBase: producer_state = values[2] # Always single PipelineState consumer_state = values[3] # Always single PipelineState - # Create temporary DSL objects and extract Python values - temp_num_stages = Int32(0) - - num_stages_dsl = temp_num_stages.__new_from_mlir_values__([num_stages_val]) + num_stages_dsl = Int32(0).__new_from_mlir_values__([num_stages_val]) # type: ignore[attr-defined] return cls( raw_pipeline, - ( - num_stages_dsl.value - if hasattr(num_stages_dsl, "value") - else int(num_stages_dsl) - ), + num_stages_dsl, producer_state, consumer_state, ) - def producer_acquire(self): + def producer_acquire(self) -> "GenericPipelineBase": """Acquire producer state.""" producer_acquire(self.raw_pipeline, self.producer_state) return self - def get_producer_stage(self): + def producer_try_acquire(self, *, token: Optional[SkipWaitToken] = None) -> Boolean: + """Try to acquire the next producer stage without blocking.""" + return producer_try_acquire(self.raw_pipeline, self.producer_state, token=token) + + def get_producer_stage(self) -> ir.Value: """Get producer stage.""" return get_pipeline_produce_stage(self.raw_pipeline, self.producer_state) - def get_consumer_stage(self): + def get_consumer_stage(self) -> ir.Value: """Get consumer stage.""" return get_pipeline_consume_stage(self.raw_pipeline, self.consumer_state) # Instance methods that can now be used directly in kernel context - def producer_acquire_and_get_stage(self): - """Combined producer acquire + get_stage with automatic elect_one using internal state.""" + def producer_acquire_and_get_stage( + self, *, token: Optional[SkipWaitToken] = None + ) -> tuple[ir.Value, ir.Value]: + """Acquire a producer stage and return its stage token/index. - self.producer_acquire() - return get_pipeline_produce_stage(self.raw_pipeline, self.producer_state) + When `token` is provided, reuse the preceding `producer_try_acquire()` + result and keep the internal state at the acquired stage so a following + `producer_commit_and_advance()` retires the same stage. + """ + if token is None: + self.producer_acquire() + return get_pipeline_produce_stage(self.raw_pipeline, self.producer_state) - def producer_commit(self): + skip_wait = normalize_skip_wait_token(token) + stage_state = cutlass_lir_ir.ProducerAcquireOp( + self.raw_pipeline, + self.producer_state, + skipWait=skip_wait, + ).outState + self.producer_state = stage_state + stage_token, stage_idx = get_pipeline_produce_stage( + self.raw_pipeline, stage_state + ) + return stage_token, stage_idx + + def producer_commit(self) -> "GenericPipelineBase": """Commit producer state.""" producer_commit(self.raw_pipeline, self.producer_state) return self - def consumer_release(self): + def consumer_try_wait(self, *, token: Optional[SkipWaitToken] = None) -> Boolean: + """Try to wait for the next consumer stage without blocking.""" + return consumer_try_wait(self.raw_pipeline, self.consumer_state, token=token) + + def consumer_release(self) -> "GenericPipelineBase": """Release consumer state.""" consumer_release(self.raw_pipeline, self.consumer_state) return self - def producer_commit_and_advance(self): + def producer_commit_and_advance(self) -> "GenericPipelineBase": """Combined producer commit + advance with automatic elect_one using internal state.""" self.producer_commit() # Update internal state in-place for better performance @@ -139,17 +173,42 @@ class GenericPipelineBase: ) return self - def consumer_wait_and_get_stage(self): - """Combined consumer wait + get_stage with automatic elect_one using internal state.""" - self.consumer_wait() - return get_pipeline_consume_stage(self.raw_pipeline, self.consumer_state) + def consumer_wait_and_get_stage( + self, *, token: Optional[SkipWaitToken] = None + ) -> tuple[ir.Value, ir.Value]: + """Wait for a consumer stage and return its stage token/index. - def consumer_wait(self): + When `token` is provided, reuse the preceding `consumer_try_wait()` + result and keep the internal state at the consumed stage so a following + `consumer_release_and_advance()` retires the same stage. + """ + if token is None: + self.consumer_wait() + return get_pipeline_consume_stage(self.raw_pipeline, self.consumer_state) + + skip_wait = normalize_skip_wait_token(token) + stage_state = cutlass_lir_ir.ConsumerWaitOp( + self.raw_pipeline, + self.consumer_state, + skipWait=skip_wait, + ).outState + self.consumer_state = stage_state + stage_token, stage_idx = get_pipeline_consume_stage( + self.raw_pipeline, stage_state + ) + return stage_token, stage_idx + + def consumer_wait( + self, state: Optional[PipelineState] = None + ) -> "GenericPipelineBase": """Wait for consumer to be ready.""" - consumer_wait(self.raw_pipeline, self.consumer_state) + if state: + consumer_wait(self.raw_pipeline, state) + else: + consumer_wait(self.raw_pipeline, self.consumer_state) return self - def consumer_release_and_advance(self): + def consumer_release_and_advance(self) -> "GenericPipelineBase": """Combined consumer release + advance with automatic elect_one using internal state.""" self.consumer_release() # Update internal state in-place for better performance @@ -158,11 +217,15 @@ class GenericPipelineBase: ) return self - def consumer_tail(self): + def consumer_tail(self) -> "GenericPipelineBase": """Combined consumer tail with automatic elect_one using internal state.""" consumer_tail(self.raw_pipeline, self.consumer_state) return self + def increment_state(self, state: PipelineState) -> ir.Value: + """Advance the input state w/o modifying current pipeline""" + return pipeline_advance_iterator(self.raw_pipeline, state) + class GenericPipeline(GenericPipelineBase): """ @@ -177,7 +240,7 @@ class GenericPipeline(GenericPipelineBase): producer_arv_count: cute.Int32, consumer_arv_count: cute.Int32, num_stages: cute.Int32, - ): + ) -> "GenericPipeline": """ Create a generic pipeline with parameterized producer and consumer. @@ -204,7 +267,7 @@ class GenericPipeline(GenericPipelineBase): ) -def _validate_umma_operation_type(operation_type: OperationTypeEnum): +def _validate_umma_operation_type(operation_type: OperationTypeEnum) -> None: if operation_type not in [ OperationTypeEnum.SM100_MMA_1SM_SS, OperationTypeEnum.SM100_MMA_1SM_TS, @@ -228,6 +291,14 @@ def _is_2sm_umma_operation_type(operation_type: OperationTypeEnum) -> bool: ] +def _is_multicast_tma_operation_type(operation_type: OperationTypeEnum) -> bool: + """Check if the operation type is a multicast TMA load.""" + return operation_type in [ + OperationTypeEnum.SM90_TMA_LOAD_MULTICAST, + OperationTypeEnum.SM100_TMA_LOAD_2SM_MULTICAST, + ] + + class TMAToUMMAPipeline(GenericPipelineBase): """ Pipeline for TMA to UMMA. @@ -240,27 +311,45 @@ class TMAToUMMAPipeline(GenericPipelineBase): mma_operation_type: OperationTypeEnum, tma_operation_type: Optional[OperationTypeEnum] = None, cluster_layout_vmnk: Optional[cute.Layout] = None, - ): + ) -> "TMAToUMMAPipeline": """ Create a TMA to UMMA pipeline. - For 2SM MMA with TMA_LOAD_2SM, provide cluster_layout_vmnk for proper mask computation. + Args: + num_stages: Number of pipeline stages. + mma_operation_type: UMMA operation type (e.g., SM100_MMA_1SM_SS). + tma_operation_type: TMA operation type. Defaults to SM90_TMA_LOAD. + cluster_layout_vmnk: Cluster layout in (v, m, n, k) order. Required + whenever the selected TMA load spans more than one CTA, i.e. + for any 2SM or multicast `tma_operation_type`. This layout is + the source of truth for CTA identity and v-pair membership. The + m/n/k dimensions can be dynamic. """ _validate_umma_operation_type( mma_operation_type, ) - # Default to SM90_TMA_LOAD if not specified if tma_operation_type is None: tma_operation_type = OperationTypeEnum.SM90_TMA_LOAD + if _is_multicast_tma_operation_type(tma_operation_type): + if cluster_layout_vmnk is None: + raise ValueError( + "cluster_layout_vmnk is required when using multicast TMA loads" + ) + return TMAToUMMAPipeline._create_with_multicast_mask( + num_stages=num_stages, + tma_operation_type=tma_operation_type, + mma_operation_type=mma_operation_type, + cluster_layout_vmnk=cluster_layout_vmnk, + ) + if tma_operation_type == OperationTypeEnum.SM100_TMA_LOAD_2SM: if cluster_layout_vmnk is None: raise ValueError( "cluster_layout_vmnk is required if using 2CTA MMA with TMA" ) - # If using 2CTA MMA, need consumer_mask == local_cta | peer_cta cta_rank_in_cluster = cute.arch.make_warp_uniform( cute.arch.block_idx_in_cluster() ) @@ -271,7 +360,7 @@ class TMAToUMMAPipeline(GenericPipelineBase): cluster_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0 ) - raw_pipeline, producer_state, consumer_state = create_pipeline_with_mask( + raw_pipeline, producer_state, consumer_state = create_pipeline( num_stages, tma_operation_type, mma_operation_type, @@ -279,7 +368,7 @@ class TMAToUMMAPipeline(GenericPipelineBase): consumer_arv_count=1, arrival_mask=arrival_mask, ) - else: + elif tma_operation_type == OperationTypeEnum.SM90_TMA_LOAD: raw_pipeline, producer_state, consumer_state = create_pipeline( num_stages, tma_operation_type, @@ -287,6 +376,8 @@ class TMAToUMMAPipeline(GenericPipelineBase): producer_arv_count=1, consumer_arv_count=1, ) + else: + raise ValueError(f"Invalid tma_operation_type: {tma_operation_type}") return TMAToUMMAPipeline( raw_pipeline, num_stages, @@ -301,15 +392,24 @@ class TMAToUMMAPipeline(GenericPipelineBase): tma_operation_type: OperationTypeEnum, mma_operation_type: OperationTypeEnum, cluster_layout_vmnk: cute.Layout, - ): - """ - Create a TMA to UMMA pipeline with multicast mask for 2CTA operations. - """ - _validate_umma_operation_type( - mma_operation_type, + ) -> "TMAToUMMAPipeline": + """Backward-compatible alias. Prefer create(tma_operation_type=...MULTICAST).""" + return TMAToUMMAPipeline._create_with_multicast_mask( + num_stages=num_stages, + tma_operation_type=tma_operation_type, + mma_operation_type=mma_operation_type, + cluster_layout_vmnk=cluster_layout_vmnk, ) - # Calculate TMA multicasting masks + @staticmethod + def _create_with_multicast_mask( + *, + num_stages: cute.Int32, + tma_operation_type: OperationTypeEnum, + mma_operation_type: OperationTypeEnum, + cluster_layout_vmnk: cute.Layout, + ) -> "TMAToUMMAPipeline": + """Internal: compute TMA multicast masks from cluster layout.""" tma_mcast_proj_A = 2 # multicast across CTAs in same row tma_mcast_proj_B = 1 # multicast across CTAs in same column @@ -322,7 +422,7 @@ class TMAToUMMAPipeline(GenericPipelineBase): # For 2CTA MMA (v-size==2), the peer CTA is the other v-slice (xor 1). # For 1CTA MMA (v-size==1), the peer is the local CTA (no flip). - v_size = cute.size(cluster_layout_vmnk.shape[0]) + v_size = cute.size(cluster_layout_vmnk.shape[0]) # type: ignore[index] peer_v = ( (cta_in_cluster_coord_vmnk[0] ^ 1) if cutlass.const_expr(v_size > 1) @@ -356,11 +456,11 @@ class TMAToUMMAPipeline(GenericPipelineBase): arrival_mask_a | arrival_mask_a_peer | arrival_mask_b | arrival_mask_b_peer ) - num_mcast_ctas_a = cute.size(cluster_layout_vmnk.shape[2]) - num_mcast_ctas_b = cute.size(cluster_layout_vmnk.shape[1]) + num_mcast_ctas_a = cute.size(cluster_layout_vmnk.shape[2]) # type: ignore[index] + num_mcast_ctas_b = cute.size(cluster_layout_vmnk.shape[1]) # type: ignore[index] num_mcast_participants = num_mcast_ctas_a + num_mcast_ctas_b - 1 - raw_pipeline, producer_state, consumer_state = create_pipeline_with_mask( + raw_pipeline, producer_state, consumer_state = create_pipeline( num_stages, tma_operation_type, mma_operation_type, @@ -372,19 +472,189 @@ class TMAToUMMAPipeline(GenericPipelineBase): raw_pipeline, num_stages, producer_state, consumer_state ) - def producer_commit(self): - """Commit producer state.""" + def producer_commit(self) -> "TMAToUMMAPipeline": + """ + Commit producer state. + + For 2SM MMA, only leader CTA commits during production as MMA + is issued by leader. Compiler generates the if-leader-cta-branch + internally to preserve a symmetric acquire-commit pattern. + """ with cute.arch.elect_one(): super().producer_commit() return self - def consumer_release(self): + def consumer_release(self) -> "TMAToUMMAPipeline": """Release consumer state.""" with cute.arch.elect_one(): super().consumer_release() return self +class TMAToUMMACircularPipeline(TMAToUMMAPipeline): + """ + Circular Buffer Pipeline for TMA to UMMA. + + This class wraps a TMAToUMMAPipeline and adds circular buffer semantics, + allowing fine-grained control over chunk-wise consumption within pipeline stages. + """ + + def __init__( + self, + raw_pipeline: ir.Value, + num_stages: cute.Int32, + producer_state: ir.Value, + consumer_state: ir.Value, + circular_buffer_state: ir.Value, + count_per_stage: int, + count_per_iteration: int, + ) -> None: + super().__init__(raw_pipeline, num_stages, producer_state, consumer_state) + self.circular_buffer_state = circular_buffer_state + self.count_per_stage = count_per_stage + self.count_per_iteration = count_per_iteration + + @staticmethod + def create( # type: ignore[override] + *, + num_stages: cute.Int32, + mma_operation_type: OperationTypeEnum, + count_per_stage: int, + count_per_iteration: int, + tma_operation_type: Optional[OperationTypeEnum] = None, + cluster_layout_vmnk: Optional[cute.Layout] = None, + ) -> "TMAToUMMACircularPipeline": + """ + Create a TMA to UMMA circular buffer pipeline. + + Args: + num_stages: Number of pipeline stages + mma_operation_type: MMA operation type + count_per_stage: Number of units (chunks) per pipeline stage + count_per_iteration: Number of units (chunks) consumed per iteration + tma_operation_type: TMA operation type (optional, defaults to SM90_TMA_LOAD) + cluster_layout_vmnk: Cluster layout in (v, m, n, k) order. Required + whenever the selected TMA load spans more than one CTA (2SM or + multicast). + + Returns: + TMAToUMMACircularPipeline: A circular buffer pipeline instance + """ + base_pipeline = TMAToUMMAPipeline.create( + num_stages=num_stages, + mma_operation_type=mma_operation_type, + tma_operation_type=tma_operation_type, + cluster_layout_vmnk=cluster_layout_vmnk, + ) + + # Create the circular buffer pipeline state on top + circular_buffer_state = create_circular_buffer_pipeline( + base_pipeline.raw_pipeline, + base_pipeline.consumer_state, + stages=num_stages, + count_per_stage=count_per_stage, + count_per_iteration=count_per_iteration, + ) + + return TMAToUMMACircularPipeline( + base_pipeline.raw_pipeline, + num_stages, + base_pipeline.producer_state, + base_pipeline.consumer_state, + circular_buffer_state, + count_per_stage, + count_per_iteration, + ) + + def __extract_mlir_values__(self) -> list: + """Extract MLIR values for DynamicExpression protocol.""" + # Get base values from parent: [pipeline, num_stages, producer_state, consumer_state] + base_values = super().__extract_mlir_values__() + + # Add circular buffer specific values + count_per_stage_dsl = Int32(self.count_per_stage) + count_per_iteration_dsl = Int32(self.count_per_iteration) + + return ( + base_values + + [count_per_stage_dsl.__extract_mlir_values__()[0]] # type: ignore[attr-defined] + + [count_per_iteration_dsl.__extract_mlir_values__()[0]] # type: ignore[attr-defined] + + [self.circular_buffer_state] + ) + + @classmethod + def __new_from_mlir_values__(cls, values: list) -> "TMAToUMMACircularPipeline": + """Reconstruct object from MLIR values.""" + # Parse: [pipeline, num_stages, producer_state, consumer_state, count_per_stage, count_per_iteration, circular_buffer_state] + raw_pipeline = values[0] + num_stages_val = values[1] + producer_state = values[2] + consumer_state = values[3] + count_per_stage_val = values[4] + count_per_iteration_val = values[5] + circular_buffer_state = values[6] + + # Extract Python values from DSL objects + temp_int = Int32(0) + num_stages = temp_int.__new_from_mlir_values__([num_stages_val]).value # type: ignore[attr-defined] + count_per_stage = temp_int.__new_from_mlir_values__([count_per_stage_val]).value # type: ignore[attr-defined] + count_per_iteration = temp_int.__new_from_mlir_values__( # type: ignore[attr-defined] + [count_per_iteration_val] + ).value + + return cls( + raw_pipeline, + num_stages, + producer_state, + consumer_state, + circular_buffer_state, + count_per_stage, + count_per_iteration, + ) + + def consumer_wait(self) -> "TMAToUMMACircularPipeline": # type: ignore[override] + """Wait for consumer to be ready (uses circular buffer consume).""" + circular_buffer_pipeline_consume(self.raw_pipeline, self.circular_buffer_state) + return self + + def consumer_release(self) -> "TMAToUMMACircularPipeline": + """Release consumer state (uses circular buffer consumer release).""" + with cute.arch.elect_one(): + circular_buffer_pipeline_consumer_release( + self.raw_pipeline, self.circular_buffer_state + ) + return self + + def consumer_release_and_advance(self) -> "TMAToUMMACircularPipeline": + """Combined consumer release + advance using circular buffer semantics.""" + self.consumer_release() + # Update circular buffer state + self.circular_buffer_state = circular_buffer_pipeline_advance_iterator( + self.raw_pipeline, self.circular_buffer_state + ) + return self + + def get_consumer_stage(self) -> None: + """Get consumer stage - unsupported for circular buffer pipeline.""" + raise NotImplementedError( + "get_consumer_stage() is not supported for TMAToUMMACircularPipeline." + ) + + def consumer_wait_and_get_stage( + self, *, token: Optional[SkipWaitToken] = None + ) -> NoReturn: + """Combined consumer wait + get_stage - unsupported for circular buffer pipeline.""" + raise NotImplementedError( + "consumer_wait_and_get_stage() is not supported for TMAToUMMACircularPipeline." + ) + + def consumer_tail(self) -> NoReturn: + """Consumer tail - unsupported for circular buffer pipeline.""" + raise NotImplementedError( + "consumer_tail() is not supported for TMAToUMMACircularPipeline." + ) + + class TMAToAsyncPipeline(GenericPipelineBase): """ Pipeline for TMA to * (except UMMA). @@ -396,7 +666,7 @@ class TMAToAsyncPipeline(GenericPipelineBase): num_stages: cute.Int32, consumer: OperationTypeEnum, consumer_arv_count: cute.Int32, - ): + ) -> "TMAToAsyncPipeline": """ Create a TMA to * (except UMMA) pipeline. """ @@ -415,7 +685,7 @@ class TMAToAsyncPipeline(GenericPipelineBase): consumer_state, ) - def producer_commit(self): + def producer_commit(self) -> "TMAToAsyncPipeline": """Commit producer state.""" with cute.arch.elect_one(): super().producer_commit() @@ -434,7 +704,7 @@ class AsyncToUMMAPipeline(GenericPipelineBase): producer: OperationTypeEnum, producer_arv_count: cute.Int32, mma_operation_type: OperationTypeEnum, - ): + ) -> "AsyncToUMMAPipeline": """ Create a * (except TMA) to UMMA pipeline. """ @@ -459,7 +729,7 @@ class AsyncToUMMAPipeline(GenericPipelineBase): consumer_state, ) - def consumer_release(self): + def consumer_release(self) -> "AsyncToUMMAPipeline": """Release consumer state.""" with cute.arch.elect_one(): super().consumer_release() @@ -479,7 +749,7 @@ class UMMAtoAsyncPipeline(GenericPipelineBase): consumer_arv_count: cute.Int32, mma_operation_type: OperationTypeEnum, cluster_layout_vmnk: Optional[cute.Layout] = None, - ): + ) -> "UMMAtoAsyncPipeline": """ Create a UMMA to * (except TMA) pipeline. @@ -525,14 +795,14 @@ class UMMAtoAsyncPipeline(GenericPipelineBase): consumer_arv_count: cute.Int32, mma_operation_type: OperationTypeEnum, cluster_layout_vmnk: cute.Layout, - ): + ) -> "UMMAtoAsyncPipeline": """ Create a UMMA to * pipeline with arrival mask for 2CTA operations. """ tmem_sync_mask = cutlass.pipeline.PipelineUmmaAsync._compute_tmem_sync_mask( cta_layout_vmnk=cluster_layout_vmnk ) - raw_pipeline, producer_state, consumer_state = create_pipeline_with_mask( + raw_pipeline, producer_state, consumer_state = create_pipeline( num_stages, mma_operation_type, consumer_type, @@ -547,7 +817,7 @@ class UMMAtoAsyncPipeline(GenericPipelineBase): consumer_state, ) - def producer_commit(self): + def producer_commit(self) -> "UMMAtoAsyncPipeline": """Commit producer state.""" with cute.arch.elect_one(): super().producer_commit() @@ -557,7 +827,7 @@ class UMMAtoAsyncPipeline(GenericPipelineBase): @dataclass class TMAStorePipeline: """ - TMA Store Pipeline modeling SMEM producer to TMA consumer pipeline. + TMA Store Pipeline modeling SMEM producer (store to smem operations) to TMA consumer (TMA store to global) pipeline. A number of epilogue warps participate in the pipeline as producers, and one of them is designated as the consumer to perform TMA store. Named barrier is used to synchronize all warps so that producers write SMEM after the pipeline stage is available, and the consumer waits for all producers before issuing TMA store. The canonical pipeline flow is: @@ -581,10 +851,10 @@ class TMAStorePipeline: tma_warp_id: int index: int = 0 - def get_num_stages(self): - return self.stages + def get_num_stages(self) -> int: + return self.stages # type: ignore[return-value] - def acquire_sync(self): + def acquire_sync(self) -> "TMAStorePipeline": """ Acquire pipeline stage and synchronize all warps. @@ -593,7 +863,7 @@ class TMAStorePipeline: """ @CuteExperimentalDSL.jit - def acquire_sync_impl(): + def acquire_sync_impl() -> "TMAStorePipeline": # Only TMA warp needs to wait for bulk async operations warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) @@ -612,7 +882,7 @@ class TMAStorePipeline: return acquire_sync_impl() - def commit_sync(self): + def commit_sync(self) -> "TMAStorePipeline": """ Fence SMEM writes and synchronize all warps. @@ -626,7 +896,7 @@ class TMAStorePipeline: self._barrier() return self - def release_advance(self): + def release_advance(self) -> "TMAStorePipeline": """ Release current stage and advance to next stage. @@ -635,7 +905,7 @@ class TMAStorePipeline: """ @CuteExperimentalDSL.jit - def release_advance_impl(): + def release_advance_impl() -> "TMAStorePipeline": # Only TMA warp commits the TMA operations warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) @@ -650,11 +920,11 @@ class TMAStorePipeline: return release_advance_impl() - def get_index(self): + def get_index(self) -> int: """Get current pipeline stage index.""" return self.index - def tail(self): + def tail(self) -> "TMAStorePipeline": """ Wait for all remaining TMA operations to complete. @@ -662,7 +932,7 @@ class TMAStorePipeline: """ @CuteExperimentalDSL.jit - def tail_impl(): + def tail_impl() -> "TMAStorePipeline": warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) @@ -676,9 +946,105 @@ class TMAStorePipeline: return tail_impl() - def _barrier(self): + def _barrier(self) -> None: """Internal barrier synchronization.""" cute.arch.barrier( barrier_id=self.barrier_id, number_of_threads=self.arv_count, ) + + +class GroupedGemmSchedulerPipeline(GenericPipelineBase): + """ + Pipeline for a dedicated scheduler warp producing tile info into SMEM, + consumed by all other warps. + """ + + @staticmethod + def create( + *, + num_stages: cute.Int32, + producer_arv_count: cute.Int32, + consumer_arv_count: cute.Int32, + ) -> "GroupedGemmSchedulerPipeline": + """ + Create a grouped gemm scheduler pipeline. + """ + raw_pipeline, producer_state, consumer_state = create_pipeline( + num_stages, + OperationTypeEnum.SW_STATIC_PERSISTENT_TILE_SCHEDULER, + OperationTypeEnum.LDS, + producer_arv_count=producer_arv_count, + consumer_arv_count=consumer_arv_count, + ) + return GroupedGemmSchedulerPipeline( + raw_pipeline, + num_stages, + producer_state, + consumer_state, + ) + + def consumer_wait(self) -> "GroupedGemmSchedulerPipeline": # type: ignore[override] + """Wait for consumer to be ready.""" + consumer_wait(self.raw_pipeline, self.consumer_state) + return self + + def consumer_release(self) -> "GroupedGemmSchedulerPipeline": + """Release consumer state.""" + consumer_release(self.raw_pipeline, self.consumer_state) + return self + + def producer_commit_and_advance(self) -> "GroupedGemmSchedulerPipeline": + """Commit producer state and advance to next stage.""" + super().producer_commit_and_advance() + return self + + +class CLCPipeline(GenericPipelineBase): + """ + Pipeline for tile scheduling (using CLC) to all warps. + """ + + @staticmethod + def create( + *, + num_stages: cute.Int32, + consumer_arv_count: cute.Int32, + ) -> "CLCPipeline": + """ + Create a CLC to consumer pipeline. + + The consumer includes mma, tma, epilogue, and scheduler. + """ + + raw_pipeline, producer_state, consumer_state = create_pipeline( + num_stages, + OperationTypeEnum.SM100_LAUNCH_CONTROL, + OperationTypeEnum.LDS, + producer_arv_count=1, + consumer_arv_count=consumer_arv_count, + ) + return CLCPipeline( + raw_pipeline, + num_stages, + producer_state, + consumer_state, + ) + + def producer_commit(self) -> "CLCPipeline": + """Commit producer state.""" + producer_commit(self.raw_pipeline, self.producer_state) + return self + + @staticmethod + def get_response_size() -> int: + """ + Returns the size in bytes of a CLC response. + """ + return 16 + + def expect_response(self, mbar_ptr: Pointer) -> None: + """ + Increments the expected transaction count of a CLC response. + """ + mbarrier_expect_tx(mbar_ptr, self.get_response_size(), cute.arch.lane_idx()) diff --git a/python/CuTeDSL/cutlass/cute/experimental/utils.py b/python/CuTeDSL/cutlass/cute/experimental/utils.py index cfcce7288..c2c28420f 100644 --- a/python/CuTeDSL/cutlass/cute/experimental/utils.py +++ b/python/CuTeDSL/cutlass/cute/experimental/utils.py @@ -9,123 +9,93 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from typing import Callable, Optional, Tuple, Union + +import cutlass from cutlass import cute +from cutlass._mlir import ir + +from ... import cutlass_dsl as _dsl +from .pipeline import TMAStorePipeline, TMAToUMMAPipeline def get_cta_v_map_ab( - gmem_tensor, - mma_tiler_mnk, - tiled_mma, - input_operand, + gmem_tensor: cute.Tensor, + mma_tiler_mnk: cute.Shape, + tiled_mma: cute.TiledMma, + input_operand: str, *, - loc=None, - ip=None, -): - """ - Build the **CTA-to-value map** (aka **CTA V-map**) layout for a TMA load of A/B - (and scale-factor variants SFA/SFB). - - In practice, `cta_v_map` is a `cute.Layout` that tells TMA how this CTA’s - portion of a global tensor tile maps onto the values being transferred into - shared memory. - - :param gmem_tensor: Global-memory tensor being loaded by TMA. - :type gmem_tensor: cute.Tensor - :param mma_tiler_mnk: The (M,N,K,...) tiler describing the CTA tile shape. - :type mma_tiler_mnk: tuple - :param tiled_mma: The tiled MMA object used to derive the per-operand thread/value mapping. - :type tiled_mma: cute.core.TiledMma - :param input_operand: One of {"A","B","SFA","SFB"} selecting which operand mapping to use. - :type input_operand: str - :returns: A layout suitable to pass as `cta_v_map=...` to `tma_load` / `tma_load_multicast`. - :rtype: cute.Layout - """ + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[cute.Layout, cute.ComposedLayout]: ident = cute.core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) mode = 0 if (input_operand in ("A", "SFA")) else 1 - mma_tiler_mk = (mma_tiler_mnk[mode], *mma_tiler_mnk[2:]) + mma_tiler_mk = (mma_tiler_mnk[mode], *mma_tiler_mnk[2:]) # type: ignore[index] g_tile = cute.core.composition(ident, mma_tiler_mk, loc=loc, ip=ip) if input_operand in ("A", "SFA"): cta_v_map = tiled_mma._thrfrg_A(g_tile) if input_operand in ("B", "SFB"): cta_v_map = tiled_mma._thrfrg_B(g_tile) - cta_v_map = cute.core.get(cta_v_map, mode=[1]) - cta_v_map = cute.core.dice(cta_v_map, (1, (1,) * cute.core.rank(g_tile))) - return cta_v_map + cta_v_map = cute.core.get(cta_v_map, mode=[1]) # type: ignore[assignment] + cta_v_map = cute.core.dice(cta_v_map, (1, (1,) * cute.core.rank(g_tile))) # type: ignore[assignment] + return cta_v_map # type: ignore[return-value] + def get_cta_v_map_c( - gmem_tensor, - epi_tile, - *, - loc=None, - ip=None, -): - """ - Build the **CTA-to-value map** (aka **CTA V-map**) layout for a TMA store/load - of the output tensor C/D. - - This returns an identity layout over the global tensor composed with the - epilogue tile, yielding a `cute.Layout` that describes which global indices - this CTA is responsible for. - - :param gmem_tensor: Global-memory tensor being stored/loaded by TMA. - :type gmem_tensor: cute.Tensor - :param epi_tile: Epilogue tile layout describing the CTA's output tile shape. - :type epi_tile: cute.Layout - :returns: A layout suitable to pass as `cta_v_map=...` to `tma_store` / `tma_load`. - :rtype: cute.Layout - """ - ident = cute.core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) - return cute.core.composition(ident, epi_tile, loc=loc, ip=ip) + gmem_tensor: cute.Tensor, + epi_tile: Union[cute.Tile, cute.Shape], +) -> cute.Layout: + return cute.composition(cute.make_identity_layout(gmem_tensor.shape), epi_tile) def make_tmem_layout_acc( - tiled_mma, - mnk_tiler, - acc_stage, + tiled_mma: cute.TiledMma, + mnk_tiler: cute.Shape, + acc_stage: int, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: """Return TMEM accumulator buffer layout for a tiled MMA. This is a small helper around ``tiled_mma.make_fragment_C(...).layout`` to keep example code fragment-free at the call site. - :param tiled_mma: The MMA tiler (``cute.TiledMma``). - :type tiled_mma: cute.TiledMma - :param mnk_tiler: Full MNK tiler; only the MN components are used for C. - :type mnk_tiler: tuple - :param acc_stage: Accumulator pipeline stages. - :param loc: Optional location for DSL ops. - :param ip: Optional insertion point for DSL ops. - :return: Layout for the accumulator TMEM buffer. - :rtype: cute.Layout + Args: + tiled_mma: The MMA tiler (``cute.TiledMma``). + mnk_tiler: Full MNK tiler; only the MN components are used for C. + acc_stage: Accumulator pipeline stages. + loc: Optional location for DSL ops. + ip: Optional insertion point for DSL ops. + + Returns: + ``cute.Layout`` for the accumulator TMEM buffer. """ - acc_shape = tiled_mma.partition_shape_C(mnk_tiler[:2], loc=loc, ip=ip) + acc_shape = tiled_mma.partition_shape_C(mnk_tiler[:2], loc=loc, ip=ip) # type: ignore[index] acc_shape_staged = cute.append(acc_shape, acc_stage, loc=loc, ip=ip) return tiled_mma.make_fragment_C(acc_shape_staged, loc=loc, ip=ip).layout def make_tmem_layout_a( - tiled_mma, - mk_tiler, - stage, + tiled_mma: cute.TiledMma, + mk_tiler: cute.Shape, + stage: int, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: """Return TMEM A operand buffer layout for a tiled MMA. - :param tiled_mma: The MMA tiler (``cute.TiledMma``). - :type tiled_mma: cute.TiledMma - :param mk_tiler: MK tiler used to shape the A operand. - :type mk_tiler: tuple - :param stage: Pipeline stages for the A operand buffer. - :param loc: Optional location for DSL ops. - :param ip: Optional insertion point for DSL ops. - :return: Layout for the A operand TMEM buffer. - :rtype: cute.Layout + Args: + tiled_mma: The MMA tiler (``cute.TiledMma``). + mk_tiler: MK tiler used to shape the A operand. + stage: Pipeline stages for the A operand buffer. + loc: Optional location for DSL ops. + ip: Optional insertion point for DSL ops. + + Returns: + ``cute.Layout`` for the A operand TMEM buffer. """ a_shape = tiled_mma.partition_shape_A(mk_tiler, loc=loc, ip=ip) a_shape_staged = cute.append(a_shape, stage, loc=loc, ip=ip) @@ -133,30 +103,270 @@ def make_tmem_layout_a( def make_t2r_rmem_layout( - tiled_copy_t2r, - gC_mnl_epi, - tidx, + tiled_copy_t2r: cute.TiledCopy, + gC_mnl_epi: cute.Tensor, + tidx: cute.Int32, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[cute.Layout, cute.ComposedLayout]: """Return RMEM buffer layout for the T2R epilogue destination. Computes the per-thread RMEM buffer layout produced by a TMEM->RMEM copy for a single epilogue iteration. - :param tiled_copy_t2r: The TMEM->RMEM tiled copy op (``cute.TiledCopy``). - :type tiled_copy_t2r: cute.TiledCopy - :param gC_mnl_epi: Global C tensor partitioned by epilogue tile. - :type gC_mnl_epi: cute.Tensor - :param tidx: Thread index for the copy slice. - :param loc: Optional location for DSL ops. - :param ip: Optional insertion point for DSL ops. - :return: Layout for the RMEM buffer. - :rtype: cute.Layout + Args: + tiled_copy_t2r: The TMEM->RMEM tiled copy op (``cute.TiledCopy``). + gC_mnl_epi: Global C tensor partitioned by epilogue tile. + tidx: Thread index for the copy slice. + loc: Optional location for DSL ops. + ip: Optional insertion point for DSL ops. + + Returns: + ``cute.Layout`` for the RMEM buffer. """ thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi, loc=loc, ip=ip) return cute.make_fragment_like( tTR_gC[(None, None, None, 0, 0)].layout, loc=loc, ip=ip ) + + +@_dsl.CuteExperimentalDSL.jit +def epilogue_tma_store( + cta_tile_shape_mnk: cute.Shape, + use_2cta_instrs: bool, + tmem_acc_buffer_staged: cute.Tensor, + gmem_d: cute.Tensor, + cta_d_tile_coord: cute.Coord, + tma_store_pipeline: TMAStorePipeline, + tma_store_warp_id: int, + epilogue_op: Callable[[cute.Tensor], cute.Tensor], + d_major_mode: Optional["LayoutEnum"] = None, # type: ignore[name-defined] + tid_x_in_group: Optional[int] = None, +) -> TMAStorePipeline: + """ + Epilogue phase: copy accumulator from TMEM to GMEM via RMEM and TMA store. + + This function implements the epilogue for GEMM on Blackwell (SM100): it consumes + the accumulator produced by the MMA warp and writes the output tile to global + memory. The data flow is: + + TMEM --copy--> RMEM --epilogue op--> RMEM --copy--> SMEM --TMA--> GMEM + + The TMA store pipeline coordinates multiple warps writing to SMEM before a single + warp (tma_store_warp_id) issues the TMA store. Pipeline protocol per sub-tile: + acquire_sync() -> RMEM->SMEM copy -> commit_sync() -> TMA store (TMA warp only) + -> release_advance(). tail() is called at the end to wait for in-flight TMA stores. + + Args: + cta_tile_shape_mnk: Effective (M, N, K) tile shape per CTA for epilogue tiling + use_2cta_instrs: True if using 2-CTA MMA instructions (affects epilogue tile shape) + tmem_acc_buffer_staged: One stage slice from the full accumulator pipeline for + this CTA's tile, should have shape + (cta_tile_shape_mnk[0], cta_tile_shape_mnk[1], 1, 1) + gmem_d: Global output tensor D + cta_d_tile_coord: Coordinate of this CTA's output tile, e.g. (cta_m, cta_n, cta_l) + tma_store_pipeline: TMAStorePipeline instance + tma_store_warp_id: Warp index that issues TMA stores + epilogue_op: Callable applied in registers to accumulator values before store + d_major_mode: LayoutEnum for d_tensor, the function will automatically detect + the d_major_mode from gmem_d if not provided + tid_x_in_group: Thread index in the group of warps that issue TMA stores. For + example, if warps 4-7 are in the same group and calling this function, + tid_x_in_group should be 0-127 instead of 128-255. If not provided, the + function will use cute.arch.thread_idx(). + + Returns: + tma_store_pipeline: The updated TMAStorePipeline instance + """ + from .algorithm import partition_and_copy + from .memory import allocate, tma_store + import cutlass.utils.blackwell_helpers as blackwell_helpers + from cutlass import utils as utils + + if cutlass.const_expr(tid_x_in_group is None): + tid_x_in_group, _, _ = cute.arch.thread_idx() + tid_x_in_group = tid_x_in_group % 128 + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + acc_dtype = tmem_acc_buffer_staged.element_type + d_dtype = gmem_d.element_type + if cutlass.const_expr(d_major_mode is None): + d_major_mode = utils.LayoutEnum.from_tensor(gmem_d) + + epi_tile_shape = blackwell_helpers.compute_epilogue_tile_shape( + cta_tile_shape_mnk, + use_2cta_instrs, + d_major_mode, + d_dtype, + ) + + copy_atom_t2r = blackwell_helpers.get_tmem_load_op( + cta_tile_shape_mnk, + d_major_mode, + d_dtype, + acc_dtype, + epi_tile_shape, + use_2cta_instrs, + ) + + acc_epi_div_tiled = cute.flat_divide(tmem_acc_buffer_staged, epi_tile_shape) + acc_epi_div_slice = acc_epi_div_tiled[None, None, 0, 0] + tiled_copy_t2r = cute.nvgpu.tcgen05.make_tmem_copy(copy_atom_t2r, acc_epi_div_slice) + tiled_copy_r2s = cute.make_tiled_copy_D( + cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), d_dtype), + tiled_copy_t2r, + ) + + tiler_mn = (cta_tile_shape_mnk[0], cta_tile_shape_mnk[1]) # type: ignore[index] + gmem_d_mn_tiled = cute.zipped_divide(gmem_d, tiler_mn) + gmem_d_tile = gmem_d_mn_tiled[(None, None), cta_d_tile_coord] + gmem_d_epi_tma = cute.flat_divide(gmem_d_tile, epi_tile_shape) # type: ignore[arg-type] + epi_subtile_cnt = gmem_d_epi_tma.shape[3] # type: ignore[index] + + acc_d_rmem_layout = make_t2r_rmem_layout( + tiled_copy_t2r, + gmem_d_epi_tma, + tid_x_in_group, # type: ignore[arg-type] + ) + rmem_acc_buffer = allocate( + acc_dtype, + cute.AddressSpace.rmem, + acc_d_rmem_layout, + alignment=32, + ) + rmem_d_buffer = allocate( + d_dtype, + cute.AddressSpace.rmem, + acc_d_rmem_layout, + alignment=32, + ) + + d_smem_layout_staged = blackwell_helpers.make_smem_layout_epi( + d_dtype, + d_major_mode, + epi_tile_shape, + tma_store_pipeline.get_num_stages(), + ) + smem_d_buffer = allocate( + d_dtype, + cute.AddressSpace.smem, + d_smem_layout_staged, + alignment=1024, + ) + + for epi_subtile_idx in range(epi_subtile_cnt): # type: ignore[arg-type] + # TMEM -> RMEM + partition_and_copy( + tiled_copy_t2r.get_slice(tid_x_in_group), + acc_epi_div_tiled[None, None, 0, epi_subtile_idx], + rmem_acc_buffer, + ) + + # RMEM -> RMEM and epilogue Op + acc_vec = rmem_acc_buffer.load() + epilogue_out = epilogue_op(acc_vec.to(d_dtype)) + rmem_d_buffer.store(epilogue_out) + + # RMEM -> SMEM + # The TMA store pipeline coordinates multiple warps writing to SMEM + # before a single warp issues the TMA store. + # acquire_sync(): + # - TMA warp waits for any in-flight TMA ops to complete + # - All warps synchronize via a named barrier + tma_store_pipeline.acquire_sync() + store_idx = tma_store_pipeline.get_index() + partition_and_copy( + tiled_copy_r2s.get_slice(tid_x_in_group), + rmem_d_buffer, + smem_d_buffer[None, None, store_idx], + ) + + # commit_sync(): + # - Fences SMEM writes to ensure visibility for TMA + # - All warps synchronize before TMA store + # This is CRITICAL: TMA must see committed SMEM writes + tma_store_pipeline.commit_sync() + + # SMEM -> GMEM + if warp_idx == tma_store_warp_id: + tma_store( + smem_d_buffer[None, None, store_idx], + gmem_d_epi_tma[None, None, 0, epi_subtile_idx], + ) + + # release_advance(): + # - TMA warp commits TMA ops to bulk group + # - All warps advance to the next pipeline stage + tma_store_pipeline.release_advance() + + tma_store_pipeline.tail() + return tma_store_pipeline + + +@_dsl.CuteExperimentalDSL.jit +def mainloop_mma( + tiled_mma: cute.TiledMma, + a_buffer: cute.Tensor, + b_buffer: cute.Tensor, + acc_buffer: cute.Tensor, + k_tile_start: cute.Int32, + k_tile_end: cute.Int32, + mma_inst_tile_k: cute.Int32, + a_buffer_pipeline: TMAToUMMAPipeline, + b_buffer_pipeline: TMAToUMMAPipeline, + ab_buffer_same_pipeline: bool = True, + accumulate_to_acc: bool = False, +) -> Tuple[TMAToUMMAPipeline, TMAToUMMAPipeline]: + """ + Mainloop MMA phase: consume A/B tiles from the pipeline and compute into TMEM accumulator. + + This function is the consumer side of the TMA load -> MMA pipeline. It waits + for the TMA load warp to fill a pipeline stage, then runs multiple MMA + instructions over the K-tile (inner loop over mma_inst_tile_k), and releases the stage. + + Args: + tiled_mma: Tiled MMA descriptor (e.g. from blackwell_helpers.make_trivial_tiled_mma) + a_buffer: A operand buffer, shape (..., mma_inst_tile_k, num_a_buffer_stages) + b_buffer: B operand buffer, shape (..., mma_inst_tile_k, num_b_buffer_stages) + acc_buffer: Accumulator buffer for this CTA's tile + k_tile_start: Start index of the K-tile to iterate over (outer loop) + k_tile_end: End index of the K-tile to iterate over (outer loop) + mma_inst_tile_k: Number of MMA instructions per K-tile (inner loop) + a_buffer_pipeline: TMAToUMMAPipeline to sync with TMA load producer for A buffer + b_buffer_pipeline: TMAToUMMAPipeline to sync with TMA load producer for B buffer + ab_buffer_same_pipeline: If the TMA load producers for A and B are the same pipeline + accumulate_to_acc: If the first K-tile should accumulate to the accumulator, + otherwise the result will be overwritten. + + Returns: + a_buffer_pipeline: The updated TMAToUMMAPipeline for A buffer + b_buffer_pipeline: The updated TMAToUMMAPipeline for B buffer + """ + from .math import dot + + mma_atom = cute.make_mma_atom(tiled_mma.op) + mma_atom.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, accumulate_to_acc) + for _k_tile in cutlass.range(k_tile_start, k_tile_end, 1, unroll=1): + _, a_buffer_stage_idx = a_buffer_pipeline.consumer_wait_and_get_stage() + if cutlass.const_expr(ab_buffer_same_pipeline): + b_buffer_stage_idx = a_buffer_stage_idx + else: + _, b_buffer_stage_idx = b_buffer_pipeline.consumer_wait_and_get_stage() + for k_instr_tile in cutlass.range(mma_inst_tile_k, unroll_full=True): + a_buffer_sliced = a_buffer[None, None, k_instr_tile, a_buffer_stage_idx] + b_buffer_sliced = b_buffer[None, None, k_instr_tile, b_buffer_stage_idx] + dot( + mma_atom, + cute.append_ones(a_buffer_sliced, up_to_rank=3), # type: ignore[arg-type] + cute.append_ones(b_buffer_sliced, up_to_rank=3), # type: ignore[arg-type] + acc_buffer, + ) + mma_atom.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True) + a_buffer_pipeline.consumer_release_and_advance() + if not cutlass.const_expr(ab_buffer_same_pipeline): + b_buffer_pipeline.consumer_release_and_advance() + + return a_buffer_pipeline, b_buffer_pipeline diff --git a/python/CuTeDSL/cutlass/cute/export/__init__.py b/python/CuTeDSL/cutlass/cute/export/__init__.py index 8b2fe7aea..7b1a6806c 100644 --- a/python/CuTeDSL/cutlass/cute/export/__init__.py +++ b/python/CuTeDSL/cutlass/cute/export/__init__.py @@ -12,7 +12,7 @@ from .c_header_generator import CuteCHeaderGenerator from .export import object_file_version as _object_file_version -from .export import CuteArgsSpecProcessor as _CuteArgsSpecProcessor +from .export import CuteSignatureProcessor as _CuteSignatureProcessor from ...base_dsl.jit_executor import ExportProvider as _ExportProvider from ...cutlass_dsl import CuTeDSL as _CuTeDSL @@ -22,8 +22,8 @@ from ...cutlass_dsl.cuda_jit_executor import ( from ..._mlir._mlir_libs._cutlass_ir import _mlirExecutionEngine _CudaDialectJitCompiledFunction.export_provider = _ExportProvider( - dsl=_CuTeDSL, - arg_spec_processor=_CuteArgsSpecProcessor(), + dsl=_CuTeDSL, # type: ignore[type-abstract] + signature_processor=_CuteSignatureProcessor(), c_header_generator=CuteCHeaderGenerator(), object_file_version=_object_file_version, mlirExecutionEngine=_mlirExecutionEngine, @@ -37,8 +37,8 @@ from ..._mlir._mlir_libs._cutlass_ir._execution_engine import ( ) _ExternalBinaryModule.load_provider = _LoadProvider( - dsl=_CuTeDSL, - args_spec_processor=_CuteArgsSpecProcessor(), + dsl=_CuTeDSL, # type: ignore[type-abstract] + signature_processor=_CuteSignatureProcessor(), version_checker=_version_checker, execution_engine_constructor=_BinaryExecutionEngine, jit_function_constructor=_CudaDialectJitCompiledFunction, diff --git a/python/CuTeDSL/cutlass/cute/export/aot_config.py b/python/CuTeDSL/cutlass/cute/export/aot_config.py index 3c6c62025..bd5d63f39 100644 --- a/python/CuTeDSL/cutlass/cute/export/aot_config.py +++ b/python/CuTeDSL/cutlass/cute/export/aot_config.py @@ -106,7 +106,7 @@ def get_ldflags() -> str: return "" -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="AOT configuration helper for CuTe DSL (similar to tvm-ffi-config)", formatter_class=argparse.RawDescriptionHelpFormatter, diff --git a/python/CuTeDSL/cutlass/cute/export/c_header_generator.py b/python/CuTeDSL/cutlass/cute/export/c_header_generator.py index 762ce1222..bb79a376b 100644 --- a/python/CuTeDSL/cutlass/cute/export/c_header_generator.py +++ b/python/CuTeDSL/cutlass/cute/export/c_header_generator.py @@ -11,14 +11,13 @@ from cutlass.cute.typing import NumericMeta, Integer from cutlass.base_dsl.export import CHeaderGenerator, CHeaderArguments -from cutlass.base_dsl.dsl import is_dynamic_expression from cutlass.base_dsl.common import DSLRuntimeError from cutlass.base_dsl.jit_executor import ExecutionArgs from cutlass.cutlass_dsl.cutlass import is_cute_algebra_type from ..runtime import Tensor, Pointer -from typing import List, Any, Dict -from inspect import isclass +from typing import Any, Union, get_origin, get_args +from inspect import isclass, Parameter import cuda.bindings.driver as cuda # ============================================================================= @@ -75,15 +74,15 @@ class CuteCHeaderGenerator(CHeaderGenerator): return "int32_t " return self.numeric_to_c_type[dyn_type] - def _generate_binary_declaration(self, symbol_prefix: str): + def _generate_binary_declaration(self, symbol_prefix: str) -> str: """ Generate the binary of the compiled function. """ return "" def _generate_kernel_module( - self, symbol_prefix: str, kernel_info: Dict[str, List], dsl_name: str - ): + self, symbol_prefix: str, kernel_info: dict[str, list[Any]], dsl_name: str + ) -> str: """ Generate the kernel module for the compiled function. """ @@ -137,10 +136,10 @@ static inline void {symbol_prefix}_Kernel_Module_Unload({symbol_prefix}_Kernel_M def _generate_arguments( self, symbol_prefix: str, - args_spec: ExecutionArgs, - args: List[Any], - kwargs: Dict[str, Any], - ): + execution_args: ExecutionArgs, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[list[str], list[str], list[str]]: """ Generate the arguments of the wrapper function. """ @@ -148,28 +147,37 @@ static inline void {symbol_prefix}_Kernel_Module_Unload({symbol_prefix}_Kernel_M packed_args = [] declarations = [] # traverse the runtime args_spec and generate the arguments - rectified_args = args_spec.get_rectified_args(args, kwargs) - input_arg_names = args_spec.args_spec.args + args_spec.args_spec.kwonlyargs - for arg_name, arg in zip(input_arg_names, rectified_args): - arg_type = args_spec.args_spec.annotations.get(arg_name, None) + rectified_args = execution_args.get_rectified_args(args, kwargs) + + for param, arg in zip( + execution_args.signature.parameters.values(), rectified_args + ): + arg_type = param.annotation + arg_name = param.name # process optional argument if arg is None: continue + # Unwrap Optional[X] (i.e. Union[X, None]) to X when arg is not None + if get_origin(arg_type) is Union: + inner_types = [t for t in get_args(arg_type) if t is not type(None)] + if len(inner_types) == 1: + arg_type = inner_types[0] + if isinstance(arg, Pointer): arguments.append(f"void *{arg_name}") packed_args.append("&" + arg_name) elif isinstance(arg, Tensor): dynamic_shapes = ( - f"\n int32_t dynamic_shapes[{sum(arg.dynamic_shapes_mask)}];" - if sum(arg.dynamic_shapes_mask) > 0 + f"\n int32_t dynamic_shapes[{sum(arg.dynamic_shapes_mask)}];" # type: ignore[attr-defined] + if sum(arg.dynamic_shapes_mask) > 0 # type: ignore[attr-defined] else "" ) - stride_type = "int32_t" if arg._use_32bit_stride else "int64_t" + stride_type = "int32_t" if arg._use_32bit_stride else "int64_t" # type: ignore[attr-defined] dynamic_strides = ( - f"\n {stride_type} dynamic_strides[{sum(arg.dynamic_strides_mask)}];" - if sum(arg.dynamic_strides_mask) > 0 + f"\n {stride_type} dynamic_strides[{sum(arg.dynamic_strides_mask)}];" # type: ignore[attr-defined] + if sum(arg.dynamic_strides_mask) > 0 # type: ignore[attr-defined] else "" ) declarations.append( @@ -183,7 +191,7 @@ typedef struct {{ packed_args.append(arg_name) # Generate basic numeric types elif isinstance(arg_type, NumericMeta): - arguments.append(self._generate_numeric_argument(arg_name, arg_type)) + arguments.append(self._generate_numeric_argument(arg_name, arg_type)) # type: ignore[arg-type] packed_args.append("&" + arg_name) elif is_cute_algebra_type(arg_type) or isinstance(arg, (tuple, list)): c_type = self._get_cute_algebra_type(arg_type, arg) @@ -204,11 +212,11 @@ typedef struct {{ self, dsl_name: str, symbol_prefix: str, - args_spec: ExecutionArgs, + execution_args: ExecutionArgs, function_name: str, - kernel_info: Dict[str, List], + kernel_info: dict[str, list[Any]], c_header_arguments: CHeaderArguments, - ): + ) -> str: """ Generate the wrapper function for the compiled function which is provided to users as the entry point. It uses the `symbol_prefix` as the function name for identification. The host/device symbols are hidden under the bytecode. @@ -236,16 +244,16 @@ typedef struct {{ # 3. Get the return type of the wrapper function. # Note that this requires the return type to be properly annotated in python. - return_type = args_spec.args_spec.annotations.get("return", None) - if return_type is None: + return_type = execution_args.signature.return_annotation + if return_type is Parameter.empty: return_type = "void" else: return_type = self.numeric_to_c_type[return_type][:-1] - declarations = "\n".join(declarations) + declarations_str = "\n".join(declarations) # 4. Generate the wrapper function function = ( - declarations + declarations_str + f""" #ifdef __cplusplus extern "C" diff --git a/python/CuTeDSL/cutlass/cute/export/export.py b/python/CuTeDSL/cutlass/cute/export/export.py index 9375c954e..d2e71c89d 100644 --- a/python/CuTeDSL/cutlass/cute/export/export.py +++ b/python/CuTeDSL/cutlass/cute/export/export.py @@ -9,14 +9,12 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -import os import pickle -import copy from ..typing import IntTuple, Shape, Stride, Coord, Tile -from inspect import FullArgSpec +from inspect import Signature from cutlass.base_dsl.export import ( - ArgsSpecProcessor, + SignatureProcessor, ) cute_algebra_types_dump = { @@ -35,20 +33,31 @@ cute_algebra_types_load = { } -class CuteArgsSpecProcessor(ArgsSpecProcessor): - def dumps(self, args_spec: FullArgSpec) -> bytes: - new_args_spec = copy.deepcopy(args_spec) - for arg, arg_type in new_args_spec.annotations.items(): +class CuteSignatureProcessor(SignatureProcessor): + def dumps(self, signature: Signature) -> bytes: + params = [] + for param in signature.parameters.values(): + arg_type = param.annotation if arg_type in cute_algebra_types_dump.keys(): - new_args_spec.annotations[arg] = cute_algebra_types_dump[arg_type] - return pickle.dumps(new_args_spec) + params.append( + param.replace(annotation=cute_algebra_types_dump[arg_type]) + ) + else: + params.append(param) + return pickle.dumps(signature.replace(parameters=params)) - def loads(self, args_spec_bytes: bytes) -> FullArgSpec: - args_spec = pickle.loads(args_spec_bytes) - for arg, arg_type in args_spec.annotations.items(): + def loads(self, signature_bytes: bytes) -> Signature: + signature = pickle.loads(signature_bytes) + params = [] + for param in signature.parameters.values(): + arg_type = param.annotation if arg_type in cute_algebra_types_load.keys(): - args_spec.annotations[arg] = cute_algebra_types_load[arg_type] - return args_spec + params.append( + param.replace(annotation=cute_algebra_types_load[arg_type]) + ) + else: + params.append(param) + return signature.replace(parameters=params) # This is the version of the object file. It is used to check the version of the object file is compatible with the current dsl version or not. diff --git a/python/CuTeDSL/cutlass/cute/ffi.py b/python/CuTeDSL/cutlass/cute/ffi.py index 727d1a296..3e045c77b 100644 --- a/python/CuTeDSL/cutlass/cute/ffi.py +++ b/python/CuTeDSL/cutlass/cute/ffi.py @@ -9,198 +9,57 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +import functools +from typing import List, Optional + +from cutlass.base_dsl.ffi import extern as base_extern +from cutlass.base_dsl.ffi import FFI, BitCode, mangle, ConstValue from cutlass._mlir import ir -from cutlass._mlir.dialects import func -from cutlass.base_dsl.typing import get_mlir_types, NumericMeta, Numeric, as_numeric -from cutlass.base_dsl.dsl import extract_mlir_values - -from cutlass import DSLRuntimeError +from cutlass._mlir.dialects import cute, llvm -class ffi: - """ - Foreign Function Interface (FFI) wrapper for external function invocation in the CUTLASS Python DSL. +def _implicit_convert(arg: List[ir.Value], typ: List[ir.Type]) -> List[ir.Value]: + if len(arg) == 1 and len(typ) == 1: + arg_type = arg[0].type + typ_type = typ[0] + # implicitly cast !cute.ptr -> !llvm.ptr + if isinstance(typ_type, llvm.PointerType) and isinstance( + arg_type, cute.PtrType + ): + ptr_value = arg[0] + ptr_as_int = cute.ptrtoint(ir.IntegerType.get_signless(64), ptr_value) + addr_space = cute.PtrType(ptr_value.type).address_space + llvm_ptr_ty = llvm.PointerType.get(addr_space) + llvm_ptr = llvm.inttoptr(llvm_ptr_ty, ptr_as_int) + return [llvm_ptr] - This class enables calling external MLIR function prototypes from Python code, handling type conversion, - prototype registration, and dynamic insertion of function symbols into MLIR modules as needed. + return arg - Parameters - ---------- - name : str - Name of the external function. This will be used as the symbol name when calling or registering a prototype in the MLIR module. - params_types : list, optional - List of argument types for the external function. These can be CUTLASS numeric types, numeric meta types, or types convertible via `get_mlir_types`. - return_type : optional - The return type of the external function. If not specified, the function is assumed to have no return value. - Methods - ------- - __call__(*args) - Calls the external function with the given arguments, ensuring argument and result types match the prototype. - """ +def ffi( + *, + name: str | None = None, + params_types: list | None = None, + return_type: Optional[ir.Type] = None, + inline: bool = True, + source: str | None = None, +) -> FFI: + return FFI( + name=name, + params_types=params_types, + return_type=return_type, + inline=inline, + source=source, + implicit_convert=_implicit_convert, + ) - def __init__(self, *, name: str, params_types: list = [], return_type=None): - self.name = name - self.params_types = params_types - self.return_type = [return_type] if return_type else [] - def _get_prototype_region(self, current_op): - """ - Helper method to determine the appropriate MLIR module and region for inserting a function prototype. +extern = functools.partial(base_extern, implicit_convert=_implicit_convert) - This method recursively traverses the current operation's parent hierarchy to find the correct module - and region where the function prototype should be inserted. It supports both builtin.module and gpu.module. - :param current_op: The current operation to check. - :type current_op: Operation - - :returns: - A tuple containing the module operation and the insertion region. - :rtype: tuple - """ - if current_op is None: - raise DSLRuntimeError("current operation is unknown") - op_name = current_op.name - if op_name in ["builtin.module", "gpu.module"]: - return current_op, current_op.regions[0].blocks[0] - else: - return self._get_prototype_region(current_op.parent) - - @staticmethod - def _to_mlir_types(args): - """ - Helper method to convert a list of arguments to their corresponding MLIR types. - - This method converts CUTLASS numeric types, numeric meta types, and types convertible via `get_mlir_types` - to their corresponding MLIR types. - :param args: The list of arguments to convert to MLIR types. - :type args: list - - :returns: - A list of MLIR types. - :rtype: list - """ - types = [] - for param in args: - if isinstance(param, NumericMeta): - types.append(param.mlir_type) - elif isinstance(param, Numeric): - types.append(param.mlir_type) - else: - types.extend(get_mlir_types(param)) - return types - - @staticmethod - def _type_check(callee, exec_types, returns_types): - """ - Helper method to check if the function prototype types match the expected types. - - This method compares the input and output types of the function prototype with the provided expected types. - :param callee: The function prototype operation to check. - :type callee: func.FuncOp - :param exec_types: The expected input types. - :type exec_types: list - :param returns_types: The expected output types. - :type returns_types: list - """ - if callee.type.inputs != exec_types or callee.type.results != returns_types: - raise DSLRuntimeError( - f"External prototype types mismatch, trying to call with ({exec_types}) -> ({returns_types}), got {callee.type}" - ) - - def _create_prototype_in_region(self, op, region, exec_args): - """ - Helper method to create or retrieve a function prototype in the current module. - - This method checks if a function prototype with the given name already exists in the symbol table of the current module. - If it does, it checks if the prototype's types match the expected types. If it does not, it raises an error. - If it does not exist, it creates a new function prototype and inserts it into the current region. - :param op: The module operation to check. - :type op: Operation - :param region: The region to insert the function prototype into. - :type region: Region - :param exec_args: The arguments to pass to the function prototype. - :type exec_args: list - """ - symbol_table = ir.SymbolTable(op.operation) - - if self.name in symbol_table: - callee = symbol_table[self.name] - else: - with ir.InsertionPoint(region): - callee = func.FuncOp( - self.name, - ( - ffi._to_mlir_types(self.params_types), - ffi._to_mlir_types(self.return_type), - ), - ) - callee.sym_visibility = ir.StringAttr.get("private") - - # Sanity check the function prototype types match the expected types - self._type_check( - callee, - ffi._to_mlir_types(exec_args), - ffi._to_mlir_types(self.return_type), - ) - - return callee - - def __call__(self, *args, **kwargs): - """ - Calls the FFI function prototype with the provided arguments. - - This method ensures that an IR-level function prototype (external declaration) - with the given name and type signature exists in the current module. If it does not - exist, it will be created and inserted into the module. A call operation to this - function is then emitted using the arguments supplied by the caller. - - :param args: - The runtime arguments to pass to the FFI function. These will be converted to - their corresponding numeric types and lowered to MLIR values before being used as arguments. - :type args: tuple - - :returns: - The MLIR call operation created for this invocation. - :rtype: func.CallOp - - :raises DSLRuntimeError: - If there is no active MLIR insertion point or if the current operation - context cannot be determined. - """ - - if kwargs: - raise DSLRuntimeError( - "Keyword arguments are not supported for FFI calls", - suggestion="Use positional arguments only", - ) - - # Get the current insertion point and operation - try: - current_ip = ir.InsertionPoint.current - except Exception: - raise DSLRuntimeError( - "Failed to determine current insertion point", - suggestion="Make sure this is called under a jit context", - ) - current_op = current_ip.block.owner - module_op, insertion_region = self._get_prototype_region(current_op) - - # Extract the arguments to MLIR values - exec_args = [] - for arg in args: - exec_arg = extract_mlir_values(arg) - if not exec_arg: - exec_arg = [as_numeric(arg).ir_value()] - exec_args.extend(exec_arg) - - # Create the function prototype in module, so if it's under kernel function, prototype will be inserted into gpu.module - # If it's under gpu.module, prototype will be inserted into builtin.module - callee = self._create_prototype_in_region( - module_op, insertion_region, exec_args - ) - - # Emit the call operation - result = func.call(callee.type.results, self.name, exec_args) - - if self.return_type: - return result +__all__ = [ + "ffi", + "extern", + "BitCode", + "mangle", + "ConstValue", +] diff --git a/python/CuTeDSL/cutlass/cute/math.py b/python/CuTeDSL/cutlass/cute/math.py index 89acc70ab..ef3674d20 100644 --- a/python/CuTeDSL/cutlass/cute/math.py +++ b/python/CuTeDSL/cutlass/cute/math.py @@ -9,16 +9,22 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Callable, Union +from typing import Callable, Optional, Union from .typing import Numeric from .tensor import TensorSSA +from cutlass._mlir import ir from cutlass._mlir.dialects import math, arith from cutlass.cutlass_dsl import dsl_user_op -def _math_op(func: Callable, fastmath: bool, *args, **kwargs): +def _math_op( + func: Callable[..., ir.Value], + fastmath: bool, + *args: Union[TensorSSA, Numeric], + **kwargs: object, +) -> Union[TensorSSA, ir.Value]: """Dispatch the function to either a TensorSSA or a Numeric(Float). :param func: The function to dispatch @@ -44,40 +50,17 @@ def _math_op(func: Callable, fastmath: bool, *args, **kwargs): func(*args, fastmath=fastmath_flag, **kwargs), args[0].shape, args[0].dtype ) else: - args = [a.ir_value() for a in args] - return func(*args, fastmath=fastmath_flag, **kwargs) + ir_args = [a.ir_value() for a in args] + return func(*ir_args, fastmath=fastmath_flag, **kwargs) @dsl_user_op -def absf( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None -) -> Union[TensorSSA, Numeric]: - """Compute element-wise absolute value of the input tensor. - - :param a: Input tensor - :type a: Union[TensorSSA, Numeric] - :param fastmath: Enable fast math optimizations, defaults to False - :type fastmath: bool, optional - :param loc: Source location information, defaults to None - :type loc: Optional[Location] - :param ip: Insertion point for IR generation, defaults to None - :type ip: Optional[InsertionPoint] - :return: Tensor containing the absolute value of each element in input tensor - :rtype: Union[TensorSSA, Numeric] - - Example: - - .. code-block:: - - x = cute.make_rmem_tensor(layout) # Create tensor - y = x.load() # Load values - z = absf(y) # Compute absolute value - """ - return _math_op(math.absf, fastmath, a, loc=loc, ip=ip) - - def acos( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc cosine of the input tensor. @@ -105,7 +88,11 @@ def acos( @dsl_user_op def asin( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc sine of the input tensor. @@ -133,7 +120,11 @@ def asin( @dsl_user_op def atan( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc tangent of the input tensor. @@ -161,8 +152,12 @@ def atan( @dsl_user_op def atan2( - a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False, - *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + b: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise arc tangent of two tensors. @@ -193,14 +188,46 @@ def atan2( return _math_op(math.atan2, fastmath, a, b, loc=loc, ip=ip) +@dsl_user_op +def absf( + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[TensorSSA, Numeric]: + """Compute element-wise absolute value of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor containing the absolute value of each element in input tensor + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_rmem_tensor(layout) # Create tensor + y = x.load() # Load values + z = absf(y) # Compute absolute value + """ + return _math_op(math.absf, fastmath, a, loc=loc, ip=ip) + + @dsl_user_op def copysign( a: Union[TensorSSA, Numeric], b: Union[TensorSSA, Numeric], fastmath: bool = False, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise copysign of two tensors. @@ -232,7 +259,11 @@ def copysign( @dsl_user_op def cos( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise cosine of the input tensor. @@ -260,7 +291,11 @@ def cos( @dsl_user_op def erf( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise error function of the input tensor. @@ -291,7 +326,11 @@ def erf( @dsl_user_op def exp( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise exponential of the input tensor. @@ -319,7 +358,11 @@ def exp( @dsl_user_op def exp2( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise base-2 exponential of the input tensor. @@ -347,7 +390,11 @@ def exp2( @dsl_user_op def floor( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise floor of the input tensor. @@ -373,8 +420,13 @@ def floor( return _math_op(math.floor, fastmath, a, loc=loc, ip=ip) +@dsl_user_op def log( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise natural logarithm of the input tensor. @@ -402,7 +454,11 @@ def log( @dsl_user_op def log2( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise base-2 logarithm of the input tensor. @@ -430,7 +486,11 @@ def log2( @dsl_user_op def log10( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise base-10 logarithm of the input tensor. @@ -458,7 +518,11 @@ def log10( @dsl_user_op def rsqrt( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise reciprocal square root of the input tensor. @@ -488,7 +552,11 @@ def rsqrt( @dsl_user_op def sin( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise sine of the input tensor. @@ -516,7 +584,11 @@ def sin( @dsl_user_op def sqrt( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise square root of the input tensor. @@ -544,7 +616,11 @@ def sqrt( @dsl_user_op def tan( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise tangent of the input tensor. @@ -572,7 +648,11 @@ def tan( @dsl_user_op def tanh( - a: Union[TensorSSA, Numeric], fastmath: bool = False, *, loc=None, ip=None + a: Union[TensorSSA, Numeric], + fastmath: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[TensorSSA, Numeric]: """Compute element-wise hyperbolic tangent of the input tensor. @@ -604,8 +684,8 @@ __all__ = [ "asin", "atan", "atan2", - "copysign", "cos", + "copysign", "erf", "exp", "exp2", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/python/CuTeDSL/cutlass/cute/nvgpu/common.py index 5cbe508d5..735badbde 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/common.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -10,32 +10,105 @@ # is strictly prohibited. import enum from dataclasses import dataclass -from typing import Type, Optional +from typing import Any, Mapping, Optional, Type, Union +import warnings -from cutlass.cutlass_dsl import DSLBaseError +from cutlass.cutlass_dsl import DSLBaseError, DSLRuntimeError import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from .. import atom -from ..typing import Float16, Float32, Float64, Numeric - +from ..typing import Float16, Float32, Float64, Numeric, Tensor, Int64 +from abc import ABC, abstractmethod __all__ = [ + "OperandMajorMode", + "OutputMajorMode", "OpError", "normalize_field_to_ir_name", "MmaUniversalOp", "MmaUniversalTrait", "CopyUniversalOp", "CopyUniversalTrait", + "CopyG2ROp", + "CopyG2RTrait", + "CopyR2GOp", + "CopyR2GTrait", + "CopyS2ROp", + "CopyS2RTrait", + "CopyR2SOp", + "CopyR2STrait", "MemoryOrder", "MemoryScope", + "L2PrefetchSize", "CacheEvictionPriority", + "LoadCacheMode", + "StoreCacheMode", + "SharedSpace", ] -def normalize_field_to_ir_name(field, admissible_fields) -> str: +class OperandMajorMode(enum.Enum): + """ + An enumeration for the majorness of the input operands of the MMA. + """ + + MN = _cute_ir.MajorMode.mn + K = _cute_ir.MajorMode.k + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def __eq__(self, other: object) -> bool: + if hasattr(other, "_to_ir") and type(other._to_ir()) is type(self._to_ir()): + return self._to_ir() == other._to_ir() + raise DSLRuntimeError( + f"{self.__module__}.{self.__class__.__qualname__} cannot be compared with {other.__module__}.{other.__class__.__qualname__}" + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.value) + + @classmethod + def _missing_(cls, value: Any) -> Optional["OperandMajorMode"]: + if isinstance(value, str): + value = value.upper() + if value == "MN": + return OperandMajorMode.MN + elif value == "K": + return OperandMajorMode.K + return None + + def _to_ir(self) -> _cute_ir.MajorMode: + return self.value + + +class OutputMajorMode(enum.Enum): + """Major mode for the output operand D(M, N). + + M = M-major (column-major): stride=(1, M), contiguous along M. + N = N-major (row-major): stride=(N, 1), contiguous along N. + """ + + M = "m" + N = "n" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + +def normalize_field_to_ir_name(field: Any, admissible_fields: Any) -> str: """ Normalize a field specifier to its IR logical field name. @@ -67,7 +140,10 @@ class OpError(DSLBaseError): """ def __init__( - self, op: atom.Op, message: str, suggestion: Optional[str] = None + self, + op: Union[atom.Op, atom.Trait], + message: str, + suggestion: Optional[str] = None, ) -> None: if suggestion is None: # Default suggestion @@ -113,7 +189,13 @@ class MmaUniversalOp(atom.MmaOp): f"\n A/B/Accumulator data type = {self.abacc_dtype}" ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaUniversalTrait": shape_mnk_attr = ir.Attribute.parse('#cute.shape<"(1,1,1)">') atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get( shape_mnk_attr, @@ -123,10 +205,22 @@ class MmaUniversalOp(atom.MmaOp): ) return MmaUniversalTrait(atom.make_atom(atom_ty, loc=loc, ip=ip)) - def _verify_fragment_A(self, input, *, loc=None, ip=None): + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: pass - def _verify_fragment_B(self, input, *, loc=None, ip=None): + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: pass @@ -178,6 +272,23 @@ class MemoryScope(enum.Enum): return self.value +class L2PrefetchSize(enum.Enum): + NONE = _cute_ir.L2PrefetchSize.NONE + RESERVED = _cute_ir.L2PrefetchSize.RESERVED + SIZE_64B = _cute_ir.L2PrefetchSize.SIZE_64B + SIZE_128B = _cute_ir.L2PrefetchSize.SIZE_128B + SIZE_256B = _cute_ir.L2PrefetchSize.SIZE_256B + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.L2PrefetchSize: + return self.value + + class CacheEvictionPriority(enum.Enum): EVICT_NORMAL = _cute_ir.CacheEvictionPriority.EVICT_NORMAL EVICT_FIRST = _cute_ir.CacheEvictionPriority.EVICT_FIRST @@ -195,11 +306,72 @@ class CacheEvictionPriority(enum.Enum): return self.value +class LoadCacheMode(enum.Enum): + ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always + GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_ + STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming + LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use + NONE = _cute_nvgpu_ir.LoadCacheMode.none + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode: + return self.value + + +class StoreCacheMode(enum.Enum): + WRITE_BACK = _cute_nvgpu_ir.StoreCacheMode.write_back + GLOBAL = _cute_nvgpu_ir.StoreCacheMode.global_ + STREAMING = _cute_nvgpu_ir.StoreCacheMode.streaming + WRITE_THROUGH = _cute_nvgpu_ir.StoreCacheMode.write_through + NONE = _cute_nvgpu_ir.StoreCacheMode.none + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_nvgpu_ir.StoreCacheMode: + return self.value + + +class SharedSpace(enum.Enum): + CTA = _cute_nvgpu_ir.SharedSpace.CTA + CLUSTER = _cute_nvgpu_ir.SharedSpace.CLUSTER + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_nvgpu_ir.SharedSpace: + return self.value + + +COPY_CACHE_POLICY_FIELD_NAME = "cache_policy" + + @dataclass(frozen=True) class CopyUniversalOp(atom.CopyOp): """ The universal Copy Operation. + This operation is equivalent to the ``a = b`` assignment without any extra + memory attributes. For advanced memory features (memory order, memory scope, + cache eviction priority, invariant loads, etc.) please use the specialized copy + operations instead: + + - :class:`CopyG2ROp` -- global memory to register + - :class:`CopyR2GOp` -- register to global memory + - :class:`CopyS2ROp` -- shared memory to register + - :class:`CopyR2SOp` -- register to shared memory + When creating a Copy Atom out of this operation, the expected usage pattern is .. code-block:: python @@ -237,13 +409,31 @@ class CopyUniversalOp(atom.CopyOp): memory_scope: MemoryScope = MemoryScope.CTA, l1c_evict_priority: CacheEvictionPriority = CacheEvictionPriority.EVICT_NORMAL, invariant: bool = False, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyUniversalTrait": if not isinstance(num_bits_per_copy, int) or num_bits_per_copy < 0: raise ValueError( f"'num_bits_per_copy' must be a non-negative int when creating a copy Atom for {self.__class__.__name__!r}" ) + + # CopyUniversalOp is designed to be a universal copy operation that is + # equivalent to the "a = b" assignment without any extra attributes. + # For advanced memory features, such as memory order, please use the + # specialized copy operations (e.g., CopyG2ROp) or their combinations instead. + if ( + memory_order != MemoryOrder.WEAK + or memory_scope != MemoryScope.CTA + or l1c_evict_priority != CacheEvictionPriority.EVICT_NORMAL + or invariant + ): + warnings.warn( + "Using CopyUniversalOp with extra attributes is deprecated. Please use specialized copy ops " + "(e.g., CopyG2ROp) for advanced memory features.", + DeprecationWarning, + ) + atom_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( copy_internal_type.mlir_type, num_bits_per_copy, @@ -257,3 +447,301 @@ class CopyUniversalOp(atom.CopyOp): class CopyUniversalTrait(atom.Trait): pass + + +@dataclass(frozen=True) +class CopyG2ROp(atom.CopyOp): + """ + The G2R copy operation. + + When creating a Copy Atom out of this operation, the expected usage pattern is + + .. code-block:: python + + op = cute.nvgpu.CopyG2ROp() + atom = cute.make_copy_atom( + op, + tensor_dtype, + num_bits_per_copy=64, + memory_order=cute.nvgpu.MemoryOrder.VOLATILE, + memory_scope=cute.nvgpu.MemoryScope.SYS, + l2_prefetch_size=cute.nvgpu.L2PrefetchSize.NONE, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.EVICT_NORMAL, + load_cache_mode=cute.nvgpu.LoadCacheMode.ALWAYS, + shared_space=cute.nvgpu.SharedSpace.CTA, + invariant=False, + ) + """ + + def __str__(self) -> str: + return "G2R copy operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + num_bits_per_copy: int = 0, + memory_order: MemoryOrder = MemoryOrder.WEAK, + memory_scope: MemoryScope = MemoryScope.CTA, + l2_prefetch_size: L2PrefetchSize = L2PrefetchSize.NONE, + l1c_evict_priority: CacheEvictionPriority = CacheEvictionPriority.EVICT_NORMAL, + load_cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS, + shared_space: SharedSpace = SharedSpace.CTA, + invariant: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyG2RTrait": + if not isinstance(num_bits_per_copy, int) or num_bits_per_copy < 0: + raise ValueError( + f"'num_bits_per_copy' must be a non-negative int when creating a copy Atom for {self.__class__.__name__!r}" + ) + atom_type = _cute_nvgpu_ir.CopyAtomG2RType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), + l2_prefetch_size._to_ir(), + l1c_evict_priority._to_ir(), + load_cache_mode._to_ir(), + shared_space._to_ir(), + invariant, + ) + return CopyG2RTrait(atom.make_atom(atom_type, loc=loc, ip=ip)) + + +class CopyG2RTrait(atom.Trait): + def unpack( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + cache_policy: Optional[Int64] = None, + **kwargs: Any, + ) -> ir.Value: + if cache_policy is None: + return self.value + cache_policy_attr_str = ( + f"#cute_nvgpu.atom_copy_field_g2r<{COPY_CACHE_POLICY_FIELD_NAME}>" + ) + cache_policy_attr = ir.Attribute.parse(cache_policy_attr_str) + val = _cute_nvgpu_ir.atom_set_value( + self.value, + cache_policy_attr, + cache_policy.ir_value(), + loc=loc, + ip=ip, + ) + return val + + +@dataclass(frozen=True) +class CopyR2GOp(atom.CopyOp): + """ + The R2G copy operation. + + When creating a Copy Atom out of this operation, the expected usage pattern is + + .. code-block:: python + + op = cute.nvgpu.CopyR2GOp() + atom = cute.make_copy_atom( + op, + tensor_dtype, + num_bits_per_copy=64, + memory_order=cute.nvgpu.MemoryOrder.RELEASE, + memory_scope=cute.nvgpu.MemoryScope.CLUSTER, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.EVICT_NORMAL, + shared_space=cute.nvgpu.SharedSpace.CTA, + ) + """ + + def __str__(self) -> str: + return "R2G copy operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + num_bits_per_copy: int = 0, + memory_order: MemoryOrder = MemoryOrder.WEAK, + memory_scope: MemoryScope = MemoryScope.CTA, + l1c_evict_priority: CacheEvictionPriority = CacheEvictionPriority.EVICT_NORMAL, + store_cache_mode: StoreCacheMode = StoreCacheMode.WRITE_BACK, + shared_space: SharedSpace = SharedSpace.CTA, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyR2GTrait": + if not isinstance(num_bits_per_copy, int) or num_bits_per_copy < 0: + raise ValueError( + f"'num_bits_per_copy' must be a non-negative int when creating a copy Atom for {self.__class__.__name__!r}" + ) + atom_type = _cute_nvgpu_ir.CopyAtomR2GType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), + l1c_evict_priority._to_ir(), + store_cache_mode._to_ir(), + shared_space._to_ir(), + ) + return CopyR2GTrait(atom.make_atom(atom_type, loc=loc, ip=ip)) + + +class CopyR2GTrait(atom.Trait): + def unpack( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + cache_policy: Optional[Int64] = None, + **kwargs: Any, + ) -> ir.Value: + if cache_policy is None: + return self.value + cache_policy_attr_str = ( + f"#cute_nvgpu.atom_copy_field_r2g<{COPY_CACHE_POLICY_FIELD_NAME}>" + ) + cache_policy_attr = ir.Attribute.parse(cache_policy_attr_str) + val = _cute_nvgpu_ir.atom_set_value( + self.value, + cache_policy_attr, + cache_policy.ir_value(), + loc=loc, + ip=ip, + ) + return val + + +def _reject_unknown_copy_trait_kwargs(op: object, kwargs: Mapping[str, Any]) -> None: + """Shared-memory load/store traits do not accept global-only keyword fields.""" + if kwargs: + name = next(iter(kwargs)) + raise TypeError( + f"{type(op).__name__}._make_trait() got an unexpected keyword argument {name!r}" + ) + + +@dataclass(frozen=True) +class CopyS2ROp(atom.CopyOp): + """ + The S2R copy operation. + + When creating a Copy Atom out of this operation, the expected usage pattern is + + .. code-block:: python + + op = cute.nvgpu.CopyS2ROp() + atom = cute.make_copy_atom( + op, + tensor_dtype, + num_bits_per_copy=64, + memory_order=cute.nvgpu.MemoryOrder.WEAK, + memory_scope=cute.nvgpu.MemoryScope.CTA, + shared_space=cute.nvgpu.SharedSpace.CTA, + ) + """ + + def __str__(self) -> str: + return "S2R copy operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + num_bits_per_copy: int = 0, + memory_order: MemoryOrder = MemoryOrder.WEAK, + memory_scope: MemoryScope = MemoryScope.CTA, + shared_space: SharedSpace = SharedSpace.CTA, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyS2RTrait": + _reject_unknown_copy_trait_kwargs(self, kwargs) + if not isinstance(num_bits_per_copy, int) or num_bits_per_copy < 0: + raise ValueError( + f"'num_bits_per_copy' must be a non-negative int when creating a copy Atom for {self.__class__.__name__!r}" + ) + atom_type = _cute_nvgpu_ir.CopyAtomS2RType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), + shared_space._to_ir(), + ) + return CopyS2RTrait(atom.make_atom(atom_type, loc=loc, ip=ip)) + + +class CopyS2RTrait(atom.Trait): + pass + + +@dataclass(frozen=True) +class CopyR2SOp(atom.CopyOp): + """ + The R2S copy operation. + + When creating a Copy Atom out of this operation, the expected usage pattern is + + .. code-block:: python + + op = cute.nvgpu.CopyR2SOp() + atom = cute.make_copy_atom( + op, + tensor_dtype, + num_bits_per_copy=64, + memory_order=cute.nvgpu.MemoryOrder.WEAK, + memory_scope=cute.nvgpu.MemoryScope.CTA, + shared_space=cute.nvgpu.SharedSpace.CTA, + ) + """ + + def __str__(self) -> str: + return "R2S copy operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + num_bits_per_copy: int = 0, + memory_order: MemoryOrder = MemoryOrder.WEAK, + memory_scope: MemoryScope = MemoryScope.CTA, + shared_space: SharedSpace = SharedSpace.CTA, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyR2STrait": + _reject_unknown_copy_trait_kwargs(self, kwargs) + if not isinstance(num_bits_per_copy, int) or num_bits_per_copy < 0: + raise ValueError( + f"'num_bits_per_copy' must be a non-negative int when creating a copy Atom for {self.__class__.__name__!r}" + ) + atom_type = _cute_nvgpu_ir.CopyAtomR2SType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), + shared_space._to_ir(), + ) + return CopyR2STrait(atom.make_atom(atom_type, loc=loc, ip=ip)) + + +class CopyR2STrait(atom.Trait): + pass + + +######################################################## +# Fragment Base Class +######################################################## + + +class FragmentBase(ABC): + @abstractmethod + def make_fragment( + self, + tensor: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tensor: ... diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py index 8337397a1..8d5ba48f3 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -25,9 +25,16 @@ __all__ = [ "CopyBulkTensorTileS2GOp", "CopyReduceBulkTensorTileS2GOp", "CopyDsmemStoreOp", + "CopyBulkG2SOp", + "CopyBulkG2SMulticastOp", + "CopyBulkS2GOp", + "CopyBulkS2GByteMaskOp", + "CopyBulkS2SOp", + "TmaCopyOp", # # helpers.py # + "TmaInfo", "make_tiled_tma_atom", "tma_partition", "create_tma_multicast_mask", @@ -37,4 +44,5 @@ __all__ = [ "fence_tma_desc_acquire", "cp_fence_tma_desc_release", "fence_tma_desc_release", + "group_bulk_copy_modes", ] diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py index e5250886b..8adfccc5b 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -11,18 +11,20 @@ import enum from dataclasses import dataclass -from typing import Optional, Type +from typing import Any, Optional, Type +from typing_extensions import deprecated from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects.cute import ReductionOp as ReductionOp +from cutlass._mlir.dialects.nvvm import ReductionOp as ReductionOp from cutlass._mlir import ir from ...atom import CopyOp, Trait, make_atom from ...typing import Int16, Int32, Int64, Pointer, Integer, Numeric -from ..common import OpError +from ..common import OpError, LoadCacheMode as LoadCacheMode_ + from ..tcgen05.mma import CtaGroup @@ -33,6 +35,9 @@ from ..tcgen05.mma import CtaGroup #################################################################################################### +@deprecated( + "cute.nvgpu.cpasync.LoadCacheMode is deprecated, use cute.nvgpu.LoadCacheMode instead" +) class LoadCacheMode(enum.Enum): """ An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction. @@ -64,11 +69,23 @@ class CopyG2SOp(CopyOp): See the `PTX documentation `__. """ - cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS + cache_mode: LoadCacheMode_ = LoadCacheMode_.ALWAYS + + def __init__( + self, cache_mode: LoadCacheMode_ | LoadCacheMode = LoadCacheMode_.ALWAYS + ): + super().__init__() + object.__setattr__( + self, + "cache_mode", + LoadCacheMode_(cache_mode.value) + if isinstance(cache_mode, LoadCacheMode) + else cache_mode, + ) def __str__(self) -> str: res = "cp.async GMEM -> SMEM copy Operation" - if self.cache_mode != LoadCacheMode.ALWAYS: + if self.cache_mode != LoadCacheMode_.ALWAYS: res += f"\n with cache mode = {self.cache_mode}" return res @@ -76,9 +93,9 @@ class CopyG2SOp(CopyOp): self, copy_internal_type: Type[Numeric], *, - loc=None, - ip=None, - **kwargs, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyG2STrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", None) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0): @@ -87,7 +104,7 @@ class CopyG2SOp(CopyOp): f"when creating a copy Atom for {self.__class__.__name__}" ) # Verify that the user provided enum values - if not isinstance(self.cache_mode, LoadCacheMode): + if not isinstance(self.cache_mode, LoadCacheMode_): raise OpError( self, "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", @@ -121,16 +138,7 @@ class TmaCopyOp(CopyOp): Base class for all TMA copy operations. """ - def __init__(self, smem_layout: Optional[ir.Value] = None) -> None: - self.smem_layout = smem_layout - - def __extract_mlir_values__(self): - return [self.smem_layout] - - def __new_from_mlir_values__(self, values): - res = self.__class__() - res.smem_layout = values[0] - return res + pass # @@ -141,10 +149,48 @@ class TmaCopyOp(CopyOp): @dataclass class CopyBulkTensorTileG2SOp(TmaCopyOp): """ - Bulk tensor asynchrnous GMEM to SMEM Copy Operation using the TMA unit. + Bulk tensor asynchronous GMEM to SMEM Copy Operation using the TMA unit. + + TMA copy operations are issued by a single thread within a warp, but the DSL **automatically handles this** by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: TMA copy without elect_one + cute.copy( + tma_atom, + gmem_tensor, # TMA partition ensures single-thread automatically + smem_tensor, + tma_bar_ptr=barrier_ptr + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(tma_atom, gmem_tensor, smem_tensor, tma_bar_ptr=barrier_ptr) + + While the TMA copy itself does not need ``elect_one()``, barrier initialization and transaction byte setup **must** use ``elect_one()``: + + .. code-block:: python + + # Barrier setup requires elect_one + with cute.arch.elect_one(): + cute.arch.mbarrier_init(barrier_ptr, arrival_count) + cute.arch.mbarrier_expect_tx(barrier_ptr, num_tma_bytes) + + # TMA copy does NOT need elect_one + cute.copy(tma_atom, gmem_tensor, smem_tensor, tma_bar_ptr=barrier_ptr) + + **PTX Programming Model**: In PTX, TMA operations (``cp.async.bulk.tensor``) must be issued + by a single thread. The DSL automatically handles this. See the `PTX documentation `__. This Operation uses TMA in the ``.tile`` mode. + + .. seealso:: + - :func:`cute.arch.elect_one` - **NOT** needed for TMA copy, but needed for barrier setup + - :func:`cute.arch.mbarrier_init` - Requires elect_one + - :func:`cute.arch.mbarrier_expect_tx` - Requires elect_one + - Tutorial example: ``examples/blackwell/tutorial_tma/tma_v0.py`` """ cta_group: CtaGroup = CtaGroup.ONE @@ -176,7 +222,12 @@ class CopyBulkTensorTileG2SOp(TmaCopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkTensorTileG2SNonExecTrait": raise NotImplementedError( "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" @@ -195,19 +246,25 @@ class CopyBulkTensorTileG2SNonExecTrait(Trait): # We allow kw args to be dropped so that the user can write common code for non-multicast # and multicast loads. - def with_(self, *, loc=None, ip=None, **kwargs) -> "CopyBulkTensorTileG2STrait": + def with_( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorTileG2STrait": return CopyBulkTensorTileG2STrait(self.unpack(loc=loc, ip=ip, **kwargs)) def unpack( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, tma_bar_ptr: Optional[Pointer] = None, tma_desc_ptr: Optional[Pointer] = None, cache_policy: Optional[Int64] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ir.Value: """ Custom implementation of unpack for non-executable TMAs. @@ -242,7 +299,7 @@ class CopyBulkTensorTileG2SNonExecTrait(Trait): ) attr = ir.Attribute.parse(attr_str) exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, cache_policy.value, loc=loc, ip=ip + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip ) return exec_value @@ -251,23 +308,152 @@ class CopyBulkTensorTileG2STrait(Trait): pass +@dataclass +class CopyBulkTensorIm2ColG2SOp(TmaCopyOp): + """ + Bulk tensor asynchrnous GMEM to SMEM Copy Operation using the TMA unit in im2col mode. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.im2col`` mode. + """ + + cta_group: CtaGroup = CtaGroup.ONE + + def __post_init__(self) -> None: + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, "expects the 'cta_group' parameter to be a CtaGroup instance" + ) + # Arch verification + arch: Arch = BaseDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: + raise OpError( + self, + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + if (self.cta_group == CtaGroup.TWO) and arch.major == Arch.sm_90.major: + raise OpError( + self, + f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk tensor copy Operation" + if self.cta_group == CtaGroup.TWO: + res += "\n CTA group = 2" + return res + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorIm2ColG2SNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_im2col_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.Im2ColTmaLoadEnum: + if self.cta_group == CtaGroup.ONE: + return _cute_nvgpu_ir.Im2ColTmaLoadEnum.sm_90 + elif self.cta_group == CtaGroup.TWO: + return _cute_nvgpu_ir.Im2ColTmaLoadEnum.sm_100_2sm + else: + assert False, "unrecognized self.cta_group" + + +class CopyBulkTensorIm2ColG2SNonExecTrait(Trait): + # We allow kw args to be dropped so that the user can write common code for non-multicast + # and multicast loads. + + def with_( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorIm2ColG2STrait": + return CopyBulkTensorIm2ColG2STrait(self.unpack(loc=loc, ip=ip, **kwargs)) + + def unpack( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + tma_bar_ptr: Optional[Pointer] = None, + tma_desc_ptr: Optional[Pointer] = None, + cache_policy: Optional[Int64] = None, + **kwargs: Any, + ) -> ir.Value: + """ + Custom implementation of unpack for non-executable TMAs. + + The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when + using `cute.copy`. `cache_policy` keyword argument to be provided to set the l2 cache eviction priority. + Any other kw arguments will be ignored instead of triggering an error. + """ + if not isinstance(tma_bar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" + ) + + exec_value = _cute_nvgpu_ir.atom_make_exec_tma( + self.value, + loc=loc, + ip=ip, + ) + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) + if isinstance(tma_desc_ptr, Pointer): + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + if cache_policy is not None: + if not isinstance(cache_policy, Int64): + raise ValueError( + "expects `Int64` value to be provided via the cache_policy kw argument" + ) + + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_CACHE_POLICY_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip + ) + return exec_value + + +class CopyBulkTensorIm2ColG2STrait(Trait): + pass + + # # TMA GMEM -> SMEM multicast copies # @dataclass -class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp): +class CopyBulkTensorIm2ColG2SMulticastOp(TmaCopyOp): """ Bulk tensor asynchrnous multicast GMEM to SMEM Copy Operation using the TMA unit. See the `PTX documentation `__. - This Operation uses TMA in the ``.tile`` mode. + This Operation uses TMA in the ``.im2col`` mode. """ cta_group: CtaGroup = CtaGroup.ONE - def __post_init__(self): + def __post_init__(self) -> None: if not isinstance(self.cta_group, CtaGroup): raise OpError( self, "expects the 'cta_group' parameter to be a CtaGroup instance" @@ -294,7 +480,265 @@ class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorIm2ColG2SMulticastNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_im2col_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.Im2ColTmaLoadEnum: + if self.cta_group == CtaGroup.ONE: + return _cute_nvgpu_ir.Im2ColTmaLoadEnum.sm_90_multicast + elif self.cta_group == CtaGroup.TWO: + return _cute_nvgpu_ir.Im2ColTmaLoadEnum.sm_100_2sm_multicast + else: + assert False, "unrecognized self.cta_group" + + +class CopyBulkTensorIm2ColG2SMulticastNonExecTrait(Trait): + def with_( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorIm2ColG2SMulticastTrait": + return CopyBulkTensorIm2ColG2SMulticastTrait( + self.unpack(loc=loc, ip=ip, **kwargs) + ) + + def unpack( # type: ignore[override] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + tma_bar_ptr: Optional[Pointer] = None, + mcast_mask: Any = None, + tma_desc_ptr: Any = None, + cache_policy: Optional[Int64] = None, + ) -> ir.Value: + """ + Custom implementation of unpack for non-executable TMAs. + + The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be + provided when using `cute.copy`. `cache_policy` keyword argument to be provided to set the + l2 cache eviction priority. + """ + if not isinstance(tma_bar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" + ) + if not isinstance(mcast_mask, Integer): + raise ValueError( + "expects a multicast mask to be provided via the mcast_mask kw argument" + ) + + exec_value = _cute_nvgpu_ir.atom_make_exec_tma( + self.value, + loc=loc, + ip=ip, + ) + attr_str = "#cute_nvgpu.atom_copy_field_tmaload" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + if isinstance(tma_desc_ptr, Pointer): + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + if cache_policy is not None: + if not isinstance(cache_policy, Int64): + raise ValueError( + "expects `Int64` value to be provided via the cache_policy kw argument" + ) + + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_CACHE_POLICY_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip + ) + # Set the tma_bar_ptr at last to ensure that the atom creation and setting + # operations above can be moved outside the loop + attr_str = "#cute_nvgpu.atom_copy_field_tmaload" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) + return exec_value + + +class CopyBulkTensorIm2ColG2SMulticastTrait(Trait): + pass + + +@dataclass +class CopyBulkTensorIm2ColS2GOp(TmaCopyOp): + """ + Bulk tensor asynchrnous SMEM to GMEM Copy Operation using the TMA unit in im2col mode. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.im2col`` mode. + """ + + def __post_init__(self) -> None: + # Arch verification + arch: Arch = BaseDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: + raise OpError( + self, + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + return "cp.async SMEM -> GMEM bulk tensor copy Operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorIm2ColS2GNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_im2col_tma_atom to obtain a copy Atom for TMA" + ) + + +class CopyBulkTensorIm2ColS2GNonExecTrait(Trait): + def with_( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorIm2ColS2GTrait": + return CopyBulkTensorIm2ColS2GTrait(self.unpack(loc=loc, ip=ip, **kwargs)) + + def unpack( # type: ignore[override] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + tma_desc_ptr: Optional[Pointer] = None, + cache_policy: Optional[Int64] = None, + ) -> ir.Value: + """ + Custom implementation of unpack for non-executable TMAs. + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + if isinstance(tma_desc_ptr, Pointer): + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + if cache_policy is not None: + if not isinstance(cache_policy, Int64): + raise ValueError( + "expects `Int64` value to be provided via the cache_policy kw argument" + ) + + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_CACHE_POLICY_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip + ) + return exec_value + + +class CopyBulkTensorIm2ColS2GTrait(Trait): + pass + + +# +# TMA GMEM -> SMEM multicast copies +# + + +@dataclass +class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp): + """ + Bulk tensor asynchronous multicast GMEM to SMEM Copy Operation using the TMA unit. + + TMA multicast operations are issued by a single thread within a warp, but the DSL **automatically handles this** by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: TMA multicast without elect_one + cute.copy( + tma_atom.with_(mcast_mask=cluster_mask), + gmem_tensor, + smem_tensor, + tma_bar_ptr=barrier_ptr + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(tma_atom.with_(mcast_mask=mask), gmem_tensor, smem_tensor) + + **PTX Programming Model**: In PTX, TMA multicast operations (``cp.async.bulk.tensor.multicast``) + must be issued by a single thread. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. + + .. seealso:: + - :func:`cute.arch.elect_one` - **NOT** needed for TMA copy + - :class:`CopyBulkTensorTileG2SOp` - Non-multicast TMA load + """ + + cta_group: CtaGroup = CtaGroup.ONE + + def __post_init__(self) -> None: + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, "expects the 'cta_group' parameter to be a CtaGroup instance" + ) + # Arch verification + arch = BaseDSL._get_dsl().get_arch_enum() + if not arch >= Arch.sm_90: + raise OpError( + self, + f"expects arch to be at least {Arch.sm_90.name}, but got {arch.name}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + if (self.cta_group == CtaGroup.TWO) and arch.major == Arch.sm_90.major: + raise OpError( + self, + f"CTA group of 2 is tcgen05-specific and is not and is not compatible with {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation" + if self.cta_group == CtaGroup.TWO: + res += "\n CTA group = 2" + return res + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait": raise NotImplementedError( "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" @@ -311,22 +755,26 @@ class CopyBulkTensorTileG2SMulticastOp(TmaCopyOp): class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): def with_( - self, *, loc=None, ip=None, **kwargs + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkTensorTileG2SMulticastTrait": return CopyBulkTensorTileG2SMulticastTrait( self.unpack(loc=loc, ip=ip, **kwargs) ) - def unpack( + def unpack( # type: ignore[override] self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, tma_bar_ptr: Optional[Pointer] = None, - mcast_mask=None, - tma_desc_ptr=None, + mcast_mask: Any = None, + tma_desc_ptr: Any = None, cache_policy: Optional[Int64] = None, - ): + ) -> ir.Value: """ Custom implementation of unpack for non-executable TMAs. @@ -365,7 +813,7 @@ class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): ) attr = ir.Attribute.parse(attr_str) exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, cache_policy.value, loc=loc, ip=ip + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip ) # Set the tma_bar_ptr at last to ensure that the atom creation and setting # operations above can be moved outside the loop @@ -391,11 +839,35 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp): """ Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit. + TMA store operations are issued by a single thread within a warp, but the DSL **automatically handles this** by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: TMA store without elect_one + cute.copy( + tma_atom, + smem_tensor, # Source: shared memory + gmem_tensor, # Destination: global memory + ) + + # WRONG: Do NOT wrap in elect_one (causes deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(tma_atom, smem_tensor, gmem_tensor) + + **PTX Programming Model**: In PTX, TMA store operations must be issued by a single thread. + The DSL automatically handles this. + See the `PTX documentation `__. This Operation uses TMA in the ``.tile`` mode. + + .. seealso:: + - :func:`cute.arch.elect_one` - **NOT** needed for TMA store + - :class:`CopyBulkTensorTileG2SOp` - TMA load operation + - Tutorial example: ``examples/blackwell/tutorial_tma/tma_v0.py`` """ - def __post_init__(self): + def __post_init__(self) -> None: # Arch verification arch = BaseDSL._get_dsl().get_arch_enum() if not arch >= Arch.sm_90: @@ -409,7 +881,12 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp): return "cp.async SMEM -> GMEM bulk tensor copy Operation" def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkTensorTileS2GNonExecTrait": raise NotImplementedError( "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" @@ -417,17 +894,23 @@ class CopyBulkTensorTileS2GOp(TmaCopyOp): class CopyBulkTensorTileS2GNonExecTrait(Trait): - def with_(self, *, loc=None, ip=None, **kwargs) -> "CopyBulkTensorTileS2GTrait": - return CopyBulkTensorTileS2GTrait(self.unpack(loc=loc, ip=ip, **kwargs)) - - def unpack( + def with_( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "CopyBulkTensorTileS2GTrait": + return CopyBulkTensorTileS2GTrait(self.unpack(loc=loc, ip=ip, **kwargs)) + + def unpack( # type: ignore[override] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, tma_desc_ptr: Optional[Pointer] = None, cache_policy: Optional[Int64] = None, - ): + ) -> ir.Value: """ Custom implementation of unpack for non-executable TMAs. """ @@ -451,7 +934,7 @@ class CopyBulkTensorTileS2GNonExecTrait(Trait): ) attr = ir.Attribute.parse(attr_str) exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, cache_policy.value, loc=loc, ip=ip + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip ) return exec_value @@ -471,9 +954,9 @@ class CopyReduceBulkTensorTileS2GOp(TmaCopyOp): reduction_kind: ReductionOp = ReductionOp.ADD - def __post__init__(self): + def __post_init__(self) -> None: # Arch verification - arch = CuTeDSL._get_dsl().get_arch_enum() + arch = BaseDSL._get_dsl().get_arch_enum() if not arch >= Arch.sm_90: raise OpError( self, @@ -485,7 +968,12 @@ class CopyReduceBulkTensorTileS2GOp(TmaCopyOp): return "cp.async SMEM -> GMEM bulk tensor reduction Operation" def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyReduceBulkTensorTileS2GNonExecTrait": raise NotImplementedError( "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" @@ -514,18 +1002,22 @@ class CopyReduceBulkTensorTileS2GOp(TmaCopyOp): class CopyReduceBulkTensorTileS2GNonExecTrait(Trait): def with_( - self, *, loc=None, ip=None, **kwargs + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyReduceBulkTensorTileS2GTrait": return CopyReduceBulkTensorTileS2GTrait(self.unpack(loc=loc, ip=ip, **kwargs)) - def unpack( + def unpack( # type: ignore[override] self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, tma_desc_ptr: Optional[Pointer] = None, cache_policy: Optional[Int64] = None, - ): + ) -> ir.Value: """ Custom implementation of unpack for non-executable TMAs. """ @@ -548,7 +1040,7 @@ class CopyReduceBulkTensorTileS2GNonExecTrait(Trait): ) attr = ir.Attribute.parse(attr_str) exec_value = _cute_nvgpu_ir.atom_set_value( - exec_value, attr, cache_policy.value, loc=loc, ip=ip + exec_value, attr, cache_policy.ir_value(), loc=loc, ip=ip ) return exec_value @@ -585,7 +1077,12 @@ class CopyBulkG2SOp(CopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkG2STrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): @@ -605,12 +1102,12 @@ class CopyBulkG2STrait(Trait): def unpack( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, mbar_ptr: Optional[Pointer] = None, cache_policy: Optional[Int64] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ir.Value: """ Custom implementation of unpack for bulk copy load. @@ -637,7 +1134,7 @@ class CopyBulkG2STrait(Trait): ) attr = ir.Attribute.parse(attr_str) val = _cute_nvgpu_ir.atom_set_value( - val, attr, cache_policy.value, loc=loc, ip=ip + val, attr, cache_policy.ir_value(), loc=loc, ip=ip ) return val @@ -670,7 +1167,12 @@ class CopyBulkG2SMulticastOp(CopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkG2SMulticastTrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): @@ -690,13 +1192,13 @@ class CopyBulkG2SMulticastTrait(Trait): def unpack( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, mbar_ptr: Optional[Pointer] = None, mcast_mask: Optional[Integer] = None, cache_policy: Optional[Int64] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ir.Value: """ Custom implementation of unpack for bulk copy load. @@ -731,7 +1233,7 @@ class CopyBulkG2SMulticastTrait(Trait): ) attr = ir.Attribute.parse(attr_str) val = _cute_nvgpu_ir.atom_set_value( - val, attr, cache_policy.value, loc=loc, ip=ip + val, attr, cache_policy.ir_value(), loc=loc, ip=ip ) return val @@ -764,7 +1266,12 @@ class CopyBulkS2GOp(CopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkS2GTrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): @@ -812,7 +1319,12 @@ class CopyBulkS2GByteMaskOp(CopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkS2GByteMaskTrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): @@ -830,11 +1342,11 @@ class CopyBulkS2GByteMaskTrait(Trait): def unpack( self, *, - loc=None, - ip=None, - byte_mask=None, - **kwargs, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + byte_mask: Any = None, + **kwargs: Any, + ) -> ir.Value: """ Custom implementation of unpack for bulk copy store with mask. @@ -886,7 +1398,12 @@ class CopyBulkS2SOp(CopyOp): return res def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyBulkS2STrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): @@ -904,12 +1421,12 @@ class CopyBulkS2STrait(Trait): def unpack( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, mbar_ptr: Optional[Pointer] = None, cta_rank: Optional[Integer] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ir.Value: """ Custom implementation of unpack for bulk copy cta to cluster. @@ -973,9 +1490,9 @@ class CopyDsmemStoreOp(CopyOp): self, copy_internal_type: Type[Numeric], *, - loc=None, - ip=None, - **kwargs, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "CopyDsmemStoreTrait": num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): @@ -998,11 +1515,11 @@ class CopyDsmemStoreTrait(Trait): def unpack( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, mbar_ptr: Optional[Pointer] = None, - **kwargs, - ): + **kwargs: Any, + ) -> ir.Value: """ Custom implementation of unpack for dsmem async copy. diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py index e02cd1469..4d8cb7536 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -9,11 +9,18 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Optional, Tuple, Type, Union +from typing import Any, Iterator, List, Optional, Tuple, Type, Union, cast +from typing_extensions import deprecated -from cutlass.cutlass_dsl import dsl_user_op +from cutlass.cutlass_dsl import ( + dsl_user_op, + extract_mlir_attributes, + extract_mlir_values, + new_from_mlir_values, +) import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir from cutlass._mlir.dialects import llvm from ...typing import ( @@ -26,27 +33,372 @@ from ...typing import ( Int16, Numeric, NumericMeta, + IntTuple, ) from ... import core, atom from .copy import ( CopyBulkTensorTileG2SOp, + CopyBulkTensorIm2ColG2SOp, CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorIm2ColG2SMulticastOp, CopyBulkTensorTileS2GOp, + CopyBulkTensorIm2ColS2GOp, CopyReduceBulkTensorTileS2GOp, CopyBulkTensorTileG2SNonExecTrait, CopyBulkTensorTileG2SMulticastNonExecTrait, CopyBulkTensorTileS2GNonExecTrait, CopyReduceBulkTensorTileS2GNonExecTrait, + CopyBulkTensorIm2ColG2SNonExecTrait, + CopyBulkTensorIm2ColG2SMulticastNonExecTrait, + CopyBulkTensorIm2ColS2GNonExecTrait, ) + +class TmaInfo: + """ + Container for TMA Copy Atom and related data. + + This class uses software composition to bundle a CopyAtom with the SMEM + layout and TMA tensor. + + Supports tuple unpacking for backward compatibility:: + + atom, tma_tensor = make_tiled_tma_atom(...) + + Access smem_layout via the container:: + + tma_info = make_tiled_tma_atom(...) + layout = tma_info.smem_layout + + :param atom: The TMA Copy Atom + :type atom: CopyAtom + :param tma_tensor: The TMA tensor for coordinate mapping + :param smem_layout: The SMEM layout used to construct the TMA descriptor + """ + + def __init__( + self, copy_atom: atom.CopyAtom, tma_tensor: Any, smem_layout: Any = None + ) -> None: + self._atom = copy_atom + self._tma_tensor = tma_tensor + self._smem_layout = smem_layout + + @property + def atom(self) -> atom.CopyAtom: + """The TMA Copy Atom.""" + return self._atom + + @property + def tma_tensor(self) -> Any: + """The TMA tensor for coordinate mapping.""" + return self._tma_tensor + + @property + def smem_layout(self) -> Any: + """The SMEM layout used to construct the TMA descriptor.""" + return self._smem_layout + + def __iter__(self) -> Iterator[Any]: + """ + Support tuple unpacking: ``atom, tma_tensor = tma_info`` + + This provides backward compatibility with the original return type. + """ + yield self._atom + yield self._tma_tensor + + def __getitem__(self, index: int) -> Any: + """Support indexing for backward compatibility.""" + if index == 0: + return self._atom + if index == 1: + return self._tma_tensor + raise IndexError(f"TmaInfo index out of range: {index}") + + def __len__(self) -> int: + """Return 2 for backward compatibility with tuple unpacking.""" + return 2 + + def __extract_mlir_values__(self) -> List[Any]: + vals = extract_mlir_values(self._atom) + vals += extract_mlir_values(self._tma_tensor) + if self._smem_layout is not None: + vals += extract_mlir_values(self._smem_layout) + return vals + + def __extract_mlir_attributes__(self) -> List[Any]: + attrs = extract_mlir_attributes(self._atom) + attrs += extract_mlir_attributes(self._tma_tensor) + if self._smem_layout is not None: + attrs += extract_mlir_attributes(self._smem_layout) + return attrs + + def __new_from_mlir_values__(self, values: List[Any]) -> "TmaInfo": + atom_len = len(extract_mlir_values(self._atom)) + tensor_len = len(extract_mlir_values(self._tma_tensor)) + smem_len = ( + len(extract_mlir_values(self._smem_layout)) + if self._smem_layout is not None + else 0 + ) + + atom_vals = values[:atom_len] + tensor_vals = values[atom_len : atom_len + tensor_len] + smem_vals = values[atom_len + tensor_len : atom_len + tensor_len + smem_len] + + new_atom = new_from_mlir_values(self._atom, atom_vals) + new_tensor = new_from_mlir_values(self._tma_tensor, tensor_vals) + + new_smem_layout = self._smem_layout + if smem_len: + new_smem_layout = new_from_mlir_values(self._smem_layout, smem_vals) + + return TmaInfo(new_atom, new_tensor, new_smem_layout) + + TMAOp = Union[ CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, + CopyBulkTensorIm2ColG2SOp, + CopyBulkTensorIm2ColS2GOp, CopyReduceBulkTensorTileS2GOp, ] +@dsl_user_op +def make_im2col_tma_atom( + op: TMAOp, + gmem_tensor: Tensor, + smem_layout_: Union[Layout, ComposedLayout], + cta_tiler: Tiler, + lower_corner_whd: Optional[IntTuple] = None, + upper_corner_whd: Optional[IntTuple] = None, + lower_padding_whd: Optional[IntTuple] = None, + upper_padding_whd: Optional[IntTuple] = None, + stride_whd: Optional[IntTuple] = None, + lower_srt: Optional[IntTuple] = None, + stride_srt: Optional[IntTuple] = None, + num_multicast: int = 1, + *, + internal_type: Optional[Type[Numeric]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TmaInfo: + """ + Makes a TMA Copy Atom in the ``.im2col`` mode to copy tiles of a GMEM tensor to/from SMEM + buffer with the given Layout. The im2col descriptor parameters: + - lower_corner_whd + - upper_corner_whd + - lower_padding_whd + - upper_padding_whd + - stride_whd + - lower_srt + - stride_srt + are only needed for the load mode (GMEM -> SMEM). + + Given + + - a GMEM tensor + - a SMEM layout + - a CTA-level Tiler + - a lower corner tuple for w,h,d + - a upper corner tuple for w,h,d + - a lower padding tuple for w,h,d + - a upper padding tuple for w,h,d + - a stride tuple for w,h,d + - a lower corner tuple for s,r,t + - a stride tuple for s,r,t + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided + layout while maintaining consistency with the provided Tiler. + + This function returns two results: + + 1. the Copy Atom + 2. a TMA tensor that maps logical coordinates of the GMEM tensor to coordinates consumed by the \ + TMA unit. TMA tensors contain basis stride elements that enable their associated layout to \ + compute coordinates. Like other CuTe tensors, TMA tensors can be partitioned. + + :param op: The TMA Copy Operation to construct an Atom + :type op: TMAOp + :param gmem_tensor: The GMEM tensor involved in the Copy + :type gmem_tensor: Tensor + :param smem_layout: The SMEM layout to construct the Copy Atom, either w/ or w/o the stage mode + :type smem_layout: Union[Layout, ComposedLayout] + :param cta_tiler: The CTA Tiler to use + :type cta_tiler: Tiler + :param lower_corner_whd: The lower corner of w,h,d involved in the im2col copy + :type lower_corner_whd: IntTuple + :param upper_corner_whd: The uppper corner of w,h,d involved in the im2col copy + :type upper_corner_whd: IntTuple + :param lower_padding_whd: The lower padding of w,h,d involved in the im2col copy + :type lower_padding_whd: IntTuple + :param upper_padding_whd: The upper padding of w,h,d involved in the im2col copy + :type upper_padding_whd: IntTuple + :param stride_whd: The conv stride of w,h,d involved in the im2col copy + :type stride_whd: IntTuple + :param lower_srt: The lower corner of s,r,t involved in the im2col copy for easily reused in fprop and dgrad + :type lower_srt: IntTuple + :param stride_srt: The stride of s,r,t involved in the im2col copy as dilation, and for easily reused in fprop and dgrad + :type stride_srt: IntTuple + :param num_multicast: The multicast factor + :type num_multicast: int + :param internal_type: Optional internal data type to use when the tensor data type is not supported by the TMA unit + :type internal_type: Type[Numeric] + :return: A TmaInfo containing the Copy Atom, TMA tensor, and SMEM layout + :rtype: TmaInfo + """ + smem_rank = core.rank(smem_layout_) + tiler_rank = core.rank(cta_tiler) + assert smem_rank == tiler_rank or smem_rank == tiler_rank + 1, ( + "smem_layout must be non-staged (rank(smem_layout) == rank(cta_tiler)) " + "or staged (rank(smem_layout) == rank(cta_tiler) + 1)" + ) + + # Keep the original SMEM layout object for later retrieval at Python level. + stored_smem_layout = smem_layout_ + + # Set the smem_layout on the operation for later retrieval + cast(Any, op).smem_layout = ( + smem_layout_.value + if isinstance(smem_layout_, core._ComposedLayout) + else smem_layout_ + ) + + # Slice the smem_layout if it is staged + if smem_rank == tiler_rank + 1: + smem_layout = core.select(smem_layout_, mode=list(range(tiler_rank))) + else: + smem_layout = smem_layout_ + + # gmem_tensor is hierarchical form ((w, h, d, n), c) or (k, (c, s, r, t)) + cta_v_map = core.composition( + core.make_identity_layout(core.product_each(gmem_tensor.shape), loc=loc, ip=ip), + cta_tiler, + loc=loc, + ip=ip, + ) + + if isinstance(smem_layout, core._ComposedLayout): + smem_layout = smem_layout.value + + tma_format = None + if internal_type is not None: + itype: Any = internal_type + if not isinstance(internal_type, NumericMeta): + raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") + + use_unpack = ( + itype.width == 8 + and isinstance(gmem_tensor.element_type, NumericMeta) + and gmem_tensor.element_type.width < 8 # type: ignore[union-attr] + ) + internal_mlir_type = ( + gmem_tensor.element_type.mlir_type if use_unpack else itype.mlir_type # type: ignore[union-attr] + ) + tma_format = _cute_nvgpu_ir.TmaDataFormat( + _cute_nvgpu_ir.get_default_tma_format(internal_mlir_type, use_unpack) + ) + + if ( + isinstance(op, (CopyBulkTensorIm2ColG2SOp, CopyBulkTensorIm2ColG2SMulticastOp)) + ) and ( + lower_corner_whd is None + or upper_corner_whd is None + or lower_padding_whd is None + or upper_padding_whd is None + or stride_whd is None + or lower_srt is None + or stride_srt is None + ): + raise ValueError( + f"expects lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt to be provided for load mode (GMEM -> SMEM), but got {lower_corner_whd}, {upper_corner_whd}, {lower_padding_whd}, {upper_padding_whd}, {stride_whd}, {lower_srt}, {stride_srt}" + ) + if isinstance(op, CopyBulkTensorIm2ColG2SOp): + if num_multicast != 1: + raise ValueError( + f"expects num_multicast to be 1 for non multicast G2S copies, " + f"but got {num_multicast}" + ) + # Get the non-exec im2col tma load atom + assert lower_corner_whd is not None + assert upper_corner_whd is not None + assert lower_padding_whd is not None + assert upper_padding_whd is not None + assert stride_whd is not None + assert lower_srt is not None + assert stride_srt is not None + res = _cute_nvgpu_ir.atom_make_non_exec_im2col_tma_load( + cast(Any, gmem_tensor).value, + smem_layout, + cta_v_map, + op._to_ir(), + core._pack_int_tuple(lower_corner_whd, loc=loc, ip=ip), + core._pack_int_tuple(upper_corner_whd, loc=loc, ip=ip), + core._pack_int_tuple(lower_padding_whd, loc=loc, ip=ip), + core._pack_int_tuple(upper_padding_whd, loc=loc, ip=ip), + core._pack_int_tuple(stride_whd, loc=loc, ip=ip), + core._pack_int_tuple(lower_srt, loc=loc, ip=ip), + core._pack_int_tuple(stride_srt, loc=loc, ip=ip), + num_multicast=num_multicast, + tma_format=tma_format, + loc=loc, + ip=ip, + ) + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorIm2ColG2SNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) + + elif isinstance(op, CopyBulkTensorIm2ColG2SMulticastOp): + if num_multicast < 1: + raise ValueError( + f"expects num_multicast to be >= 1 for multicast G2S copies, " + f"but got {num_multicast}" + ) + res = _cute_nvgpu_ir.atom_make_non_exec_im2col_tma_load( + cast(Any, gmem_tensor).value, + smem_layout, + cta_v_map, + op._to_ir(), + core._pack_int_tuple(lower_corner_whd, loc=loc, ip=ip), + core._pack_int_tuple(upper_corner_whd, loc=loc, ip=ip), + core._pack_int_tuple(lower_padding_whd, loc=loc, ip=ip), + core._pack_int_tuple(upper_padding_whd, loc=loc, ip=ip), + core._pack_int_tuple(stride_whd, loc=loc, ip=ip), + core._pack_int_tuple(lower_srt, loc=loc, ip=ip), + core._pack_int_tuple(stride_srt, loc=loc, ip=ip), + num_multicast=num_multicast, + tma_format=tma_format, + loc=loc, + ip=ip, + ) + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorIm2ColG2SMulticastNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) + elif isinstance(op, CopyBulkTensorIm2ColS2GOp): + res = _cute_nvgpu_ir.atom_make_non_exec_im2col_tma_store( + cast(Any, gmem_tensor).value, + smem_layout, + cta_v_map, + tma_format=tma_format, + loc=loc, + ip=ip, + ) + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorIm2ColS2GNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) + else: + raise ValueError(f"expects a bulk tensor (TMA) im2col Copy Op, but got {op}") + + @dsl_user_op def make_tiled_tma_atom( op: TMAOp, @@ -56,9 +408,9 @@ def make_tiled_tma_atom( num_multicast: int = 1, *, internal_type: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Tuple[atom.CopyAtom, Tensor]: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TmaInfo: """ Makes a TMA Copy Atom in the ``.tile`` mode to copy tiles of a GMEM tensor to/from SMEM buffer with the given Layout. @@ -92,22 +444,18 @@ def make_tiled_tma_atom( :type num_multicast: int :param internal_type: Optional internal data type to use when the tensor data type is not supported by the TMA unit :type internal_type: Type[Numeric] - :return: A TMA Copy Atom associated with the TMA tensor - :rtype: Tuple[atom.CopyAtom, Tensor] + :return: A TmaInfo containing the Copy Atom, TMA tensor, and SMEM layout + :rtype: TmaInfo """ smem_rank = core.rank(smem_layout_) tiler_rank = core.rank(cta_tiler) assert smem_rank == tiler_rank or smem_rank == tiler_rank + 1, ( - f"smem_layout must be non-staged (rank(smem_layout) == rank(cta_tiler)) " - f"or staged (rank(smem_layout) == rank(cta_tiler) + 1)" + "smem_layout must be non-staged (rank(smem_layout) == rank(cta_tiler)) " + "or staged (rank(smem_layout) == rank(cta_tiler) + 1)" ) - # Set the smem_layout on the operation for later retrieval - op.smem_layout = ( - smem_layout_.value - if isinstance(smem_layout_, core._ComposedLayout) - else smem_layout_ - ) + # Keep the original SMEM layout object for later retrieval at Python level. + stored_smem_layout = smem_layout_ # Slice the smem_layout if it is staged if smem_rank == tiler_rank + 1: @@ -127,18 +475,17 @@ def make_tiled_tma_atom( tma_format = None if internal_type is not None: + itype: Any = internal_type if not isinstance(internal_type, NumericMeta): raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") use_unpack = ( - internal_type.width == 8 + itype.width == 8 and isinstance(gmem_tensor.element_type, NumericMeta) - and gmem_tensor.element_type.width < 8 + and gmem_tensor.element_type.width < 8 # type: ignore[union-attr] ) internal_mlir_type = ( - gmem_tensor.element_type.mlir_type - if use_unpack - else internal_type.mlir_type + gmem_tensor.element_type.mlir_type if use_unpack else itype.mlir_type # type: ignore[union-attr] ) tma_format = _cute_nvgpu_ir.TmaDataFormat( _cute_nvgpu_ir.get_default_tma_format(internal_mlir_type, use_unpack) @@ -151,7 +498,7 @@ def make_tiled_tma_atom( f"but got {num_multicast}" ) res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, + cast(Any, gmem_tensor).value, smem_layout, cta_v_map, op._to_ir(), @@ -160,7 +507,11 @@ def make_tiled_tma_atom( loc=loc, ip=ip, ) - return atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) elif isinstance(op, CopyBulkTensorTileG2SMulticastOp): if num_multicast < 1: raise ValueError( @@ -168,7 +519,7 @@ def make_tiled_tma_atom( f"but got {num_multicast}" ) res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, + cast(Any, gmem_tensor).value, smem_layout, cta_v_map, op._to_ir(), @@ -177,23 +528,28 @@ def make_tiled_tma_atom( loc=loc, ip=ip, ) - return ( + return TmaInfo( atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), res[1], + stored_smem_layout, ) elif isinstance(op, CopyBulkTensorTileS2GOp): res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store( - gmem_tensor.value, + cast(Any, gmem_tensor).value, smem_layout, cta_v_map, tma_format=tma_format, loc=loc, ip=ip, ) - return atom.CopyAtom(op, CopyBulkTensorTileS2GNonExecTrait(res[0])), res[1] + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorTileS2GNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) elif isinstance(op, CopyReduceBulkTensorTileS2GOp): res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce( - gmem_tensor.value, + cast(Any, gmem_tensor).value, smem_layout, cta_v_map, op._to_ir(), @@ -201,9 +557,10 @@ def make_tiled_tma_atom( loc=loc, ip=ip, ) - return ( + return TmaInfo( atom.CopyAtom(op, CopyReduceBulkTensorTileS2GNonExecTrait(res[0])), res[1], + stored_smem_layout, ) else: raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}") @@ -217,8 +574,8 @@ def tma_partition( smem_tensor: Tensor, gmem_tensor: Tensor, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Tensor, Tensor]: """ Tiles the GMEM and SMEM tensors for the provided TMA Copy Atom. @@ -228,8 +585,8 @@ def tma_partition( atom._trait.value, cta_coord=cta_coord_val, cta_layout=cta_layout, - smem_tensor=smem_tensor.value, - gmem_tensor=gmem_tensor.value, + smem_tensor=cast(Any, smem_tensor).value, + target_tensors=[cast(Any, gmem_tensor).value], loc=loc, ip=ip, ) @@ -242,8 +599,8 @@ def create_tma_multicast_mask( cta_coord_vmnk: Coord, mcast_mode: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Int16: """ Computes a multicast mask for a TMA load Copy. @@ -271,7 +628,12 @@ def create_tma_multicast_mask( @dsl_user_op -def prefetch_descriptor(tma_atom: atom.CopyAtom, *, loc=None, ip=None) -> None: +def prefetch_descriptor( + tma_atom: atom.CopyAtom, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Prefetches the TMA descriptor associated with the TMA Atom. """ @@ -280,7 +642,11 @@ def prefetch_descriptor(tma_atom: atom.CopyAtom, *, loc=None, ip=None) -> None: @dsl_user_op def copy_tensormap( - tma_atom: atom.CopyAtom, tensormap_ptr: Pointer, *, loc=None, ip=None + tma_atom: atom.CopyAtom, + tensormap_ptr: Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Copies the tensormap held by a TMA Copy Atom to the memory location pointed to by the provided @@ -292,7 +658,7 @@ def copy_tensormap( :type tensormap_ptr: Pointer """ _cute_nvgpu_ir.copy_tma_desc( - tma_atom._trait.value, tensormap_ptr.value, loc=loc, ip=ip + tma_atom._trait.value, cast(Any, tensormap_ptr).value, loc=loc, ip=ip ) @@ -302,8 +668,8 @@ def update_tma_descriptor( gmem_tensor: Tensor, tma_desc_ptr: Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Updates the TMA descriptor in the memory location pointed to by the provided pointer using @@ -325,7 +691,11 @@ def update_tma_descriptor( :type tensormap_ptr: Pointer """ _cute_nvgpu_ir.update_tma_desc( - tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip + tma_atom._trait.value, + cast(Any, gmem_tensor).value, + cast(Any, tma_desc_ptr).value, + loc=loc, + ip=ip, ) @@ -333,13 +703,15 @@ def update_tma_descriptor( def fence_tma_desc_acquire( tma_desc_ptr: Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ See the `PTX documentation `__. """ - tma_desc_ptr_i64 = tma_desc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + tma_desc_ptr_i64 = ( + cast(Any, tma_desc_ptr).toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) + ) llvm.inline_asm( None, [tma_desc_ptr_i64], @@ -358,17 +730,17 @@ def cp_fence_tma_desc_release( tma_desc_global_ptr: Pointer, tma_desc_shared_ptr: Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ See the `PTX documentation `__. """ - tma_desc_global_ptr_i64 = tma_desc_global_ptr.toint(loc=loc, ip=ip).ir_value( - loc=loc, ip=ip + tma_desc_global_ptr_i64 = ( + cast(Any, tma_desc_global_ptr).toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) ) - tma_desc_shared_ptr_i32 = tma_desc_shared_ptr.toint(loc=loc, ip=ip).ir_value( - loc=loc, ip=ip + tma_desc_shared_ptr_i32 = ( + cast(Any, tma_desc_shared_ptr).toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) ) llvm.inline_asm( None, @@ -384,7 +756,9 @@ def cp_fence_tma_desc_release( @dsl_user_op -def fence_tma_desc_release(*, loc=None, ip=None) -> None: +def fence_tma_desc_release( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ See the `PTX documentation `__. """ @@ -399,3 +773,19 @@ def fence_tma_desc_release(*, loc=None, ip=None) -> None: loc=loc, ip=ip, ) + + +@dsl_user_op +@deprecated("`group_bulk_copy_modes` is deprecated, use `group_modes` instead") +def group_bulk_copy_modes( + src: Tensor, + dst: Tensor, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tuple[Tensor, Tensor]: + """ + Copy async bulk need group mode 0, acquiring whole tensor for bulk copy + """ + mSrc = core.group_modes(src, 0, core.rank(src), loc=loc, ip=ip) + mDst = core.group_modes(dst, 0, core.rank(dst), loc=loc, ip=ip) + return (mSrc, mDst) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py index 106defcfd..af109b26b 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -9,10 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Optional, Tuple, Type, Union +from typing import Any, Optional, Tuple, Type, Union, cast from cutlass.cutlass_dsl import dsl_user_op +from cutlass._mlir import ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from .. import core, atom @@ -22,12 +23,18 @@ from .cpasync.copy import ( CopyBulkTensorTileG2SNonExecTrait, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SMulticastNonExecTrait, + CopyBulkTensorIm2ColG2SOp, + CopyBulkTensorIm2ColG2SNonExecTrait, + CopyBulkTensorIm2ColG2SMulticastOp, + CopyBulkTensorIm2ColG2SMulticastNonExecTrait, ) +from .cpasync.helpers import TmaInfo __all__ = [ "make_tiled_tma_atom_A", "make_tiled_tma_atom_B", + "make_im2col_tma_atom_A", ] #################################################################################################### @@ -47,9 +54,9 @@ def make_tiled_tma_atom_A( cluster_shape_vmnk: Union[Shape, None] = None, *, internal_type: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Tuple[atom.CopyAtom, Tensor]: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TmaInfo: """ Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation accounting for the MK projections of the TiledMMA for A tensor loads. @@ -90,22 +97,29 @@ def make_tiled_tma_atom_A( :param internal_type: An optional parameter for the internal data type to when element type does not match the copy type :type internal_type: Type[Numeric] - :return: A copy atom for this operation and the associated TMA coord tensor - :rtype: Tuple[atom.CopyAtom, Tensor] + :return: A TmaInfo containing the Copy Atom, TMA tensor, and SMEM layout + :rtype: TmaInfo """ - - # Set the smem_layout on the operation for later retrieval - op.smem_layout = ( - smem_layout.value - if isinstance(smem_layout, core._ComposedLayout) - else smem_layout + smem_rank = core.rank(smem_layout) + assert smem_rank == 3 or smem_rank == 4, ( + "a_smem_layout must be non-staged (atom, rest_m, rest_k) " + "or staged (atom, rest_m, rest_k, stage), " + f"but got rank = {smem_rank}" ) + # Keep the original SMEM layout object for later retrieval at Python level. + stored_smem_layout = smem_layout + + # Slice the smem_layout if it is staged + if smem_rank == 4: + smem_layout = core.select(smem_layout, mode=[0, 1, 2]) + ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) - mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:]) + mma_mnk: Any = mma_tiler_mnk + mma_tiler_mk = (mma_mnk[0], *mma_mnk[2:]) g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip) - cta_v_map = tiled_mma._thrfrg_A(g_tile) + cta_v_map: Any = tiled_mma._thrfrg_A(g_tile) cta_v_map = core.get(cta_v_map, mode=[1]) cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) @@ -120,23 +134,23 @@ def make_tiled_tma_atom_A( ) num_multicast = core.size(cluster_shape_vmnk, mode=[2]) - if isinstance(smem_layout, core._ComposedLayout): - smem_layout = smem_layout.value + smem_for_ir: Any = smem_layout + if isinstance(smem_for_ir, core._ComposedLayout): + smem_for_ir = smem_for_ir.value tma_format = None if internal_type is not None: + itype: Any = internal_type if not isinstance(internal_type, NumericMeta): raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") use_unpack = ( - internal_type.width == 8 + itype.width == 8 and isinstance(gmem_tensor.element_type, NumericMeta) - and gmem_tensor.element_type.width < 8 + and gmem_tensor.element_type.width < 8 # type: ignore[union-attr] ) internal_mlir_type = ( - gmem_tensor.element_type.mlir_type - if use_unpack - else internal_type.mlir_type + gmem_tensor.element_type.mlir_type if use_unpack else itype.mlir_type # type: ignore[union-attr] ) tma_format = _cute_nvgpu_ir.TmaDataFormat( _cute_nvgpu_ir.get_default_tma_format(internal_mlir_type, use_unpack) @@ -145,8 +159,8 @@ def make_tiled_tma_atom_A( # res[0] = the IR Value for the non-executable atom instance # res[1] = the IR Value for the associated TMA tensor res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, - smem_layout, + cast(Any, gmem_tensor).value, + smem_for_ir, cta_v_map, op._to_ir(), num_multicast=num_multicast, @@ -155,14 +169,19 @@ def make_tiled_tma_atom_A( ip=ip, ) if isinstance(op, CopyBulkTensorTileG2SOp): - return atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] - else: - assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) - return ( - atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1], + stored_smem_layout, ) + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) + @dsl_user_op def make_tiled_tma_atom_B( @@ -174,9 +193,9 @@ def make_tiled_tma_atom_B( cluster_shape_vmnk: Union[Shape, None] = None, *, internal_type: Optional[Type[Numeric]] = None, - loc=None, - ip=None, -) -> Tuple[atom.CopyAtom, Tensor]: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TmaInfo: """ Makes a TMA Copy atom mapping to ``.tile`` mode for ``cp.async.bulk.tensor`` PTX operation accounting for the NK projections of the TiledMMA for B tensor loads. @@ -217,22 +236,29 @@ def make_tiled_tma_atom_B( :param internal_type: An optional parameter for the internal data type to when element type does not match the copy type :type internal_type: Type[Numeric] - :return: A Copy Atom for this Operation and the associated TMA tensor - :rtype: Tuple[atom.CopyAtom, Tensor] + :return: A TmaInfo containing the Copy Atom, TMA tensor, and SMEM layout + :rtype: TmaInfo """ - - # Set the smem_layout on the operation for later retrieval - op.smem_layout = ( - smem_layout.value - if isinstance(smem_layout, core._ComposedLayout) - else smem_layout + smem_rank = core.rank(smem_layout) + assert smem_rank == 3 or smem_rank == 4, ( + "b_smem_layout must be non-staged (atom, rest_n, rest_k) " + "or staged (atom, rest_n, rest_k, stage), " + f"but got rank = {smem_rank}" ) + # Keep the original SMEM layout object for later retrieval at Python level. + stored_smem_layout = smem_layout + + # Slice the smem_layout if it is staged + if smem_rank == 4: + smem_layout = core.select(smem_layout, mode=[0, 1, 2]) + ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) - mma_tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:]) + mma_mnk: Any = mma_tiler_mnk + mma_tiler_nk = (mma_mnk[1], *mma_mnk[2:]) g_tile = core.composition(ident, mma_tiler_nk, loc=loc, ip=ip) - cta_v_map = tiled_mma._thrfrg_B(g_tile) + cta_v_map: Any = tiled_mma._thrfrg_B(g_tile) cta_v_map = core.get(cta_v_map, mode=[1]) cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) @@ -247,23 +273,23 @@ def make_tiled_tma_atom_B( ) num_multicast = core.size(cluster_shape_vmnk, mode=[1]) - if isinstance(smem_layout, core._ComposedLayout): - smem_layout = smem_layout.value + smem_for_ir: Any = smem_layout + if isinstance(smem_for_ir, core._ComposedLayout): + smem_for_ir = smem_for_ir.value tma_format = None if internal_type is not None: + itype: Any = internal_type if not isinstance(internal_type, NumericMeta): raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") use_unpack = ( - internal_type.width == 8 + itype.width == 8 and isinstance(gmem_tensor.element_type, NumericMeta) - and gmem_tensor.element_type.width < 8 + and gmem_tensor.element_type.width < 8 # type: ignore[union-attr] ) internal_mlir_type = ( - gmem_tensor.element_type.mlir_type - if use_unpack - else internal_type.mlir_type + gmem_tensor.element_type.mlir_type if use_unpack else itype.mlir_type # type: ignore[union-attr] ) tma_format = _cute_nvgpu_ir.TmaDataFormat( _cute_nvgpu_ir.get_default_tma_format(internal_mlir_type, use_unpack) @@ -272,8 +298,8 @@ def make_tiled_tma_atom_B( # res[0] = the IR Value for the non-executable atom instance # res[1] = the IR Value for the associated TMA tensor res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( - gmem_tensor.value, - smem_layout, + cast(Any, gmem_tensor).value, + smem_for_ir, cta_v_map, op._to_ir(), num_multicast=num_multicast, @@ -282,10 +308,195 @@ def make_tiled_tma_atom_B( ip=ip, ) if isinstance(op, CopyBulkTensorTileG2SOp): - return atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] - else: - assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) - return ( - atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1], + stored_smem_layout, ) + + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) + + +@dsl_user_op +def make_im2col_tma_atom_A( + op: Union[CopyBulkTensorIm2ColG2SOp, CopyBulkTensorIm2ColG2SMulticastOp], + gmem_tensor: Tensor, + smem_layout: Union[Layout, ComposedLayout], + mma_tiler_mnk: Shape, + tiled_mma: atom.TiledMma, + filter_trs: Tuple[int, int, int], + upper_padding_dhw: Tuple[int, int, int], + lower_padding_dhw: Tuple[int, int, int], + stride_dhw: Tuple[int, int, int], + dilation_dhw: Tuple[int, int, int], + cluster_shape_vmnk: Union[Shape, None] = None, + *, + internal_type: Optional[Type[Numeric]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TmaInfo: + """ + Makes a TMA Copy atom mapping to ``.im2col`` mode for ``cp.async.bulk.tensor`` PTX operation accounting for the MK projections of the TiledMMA for A tensor loads. + + Given + + - a GMEM tensor + - a SMEM layout + - a MMA Tiler + - a TiledMma + - a filter shape + - a padding shape + - a stride shape + - a dilation shape + - a Cluster-level shape + + this function figures out the bulk tensor asynchronous copy instruction to use with the maximum + "TMA vector length" to copy tiles of the GMEM tensor to/from an SMEM buffer with the provided + layout while maintaining consistency with the provided Tiler. + + This function returns two results: + + 1. the Copy Atom + 2. the TMA tensor used to map logical coordinates of the GMEM tensor to coordinates + that the TMA unit can consume. TMA tensors have so-called basis stride elements so that the + associated layout can output coordinates. Otherwise, TMA tensors can be partitioned + similarly to any other CuTe tensors using the algebra. + + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorIm2ColG2SOp, CopyBulkTensorIm2ColG2SMulticastOp] + :param gmem_tensor: The GMEM tensor to be loaded by this copy atom + :type gmem_tensor: Tensor + :param smem_layout: Shared memory layout to load the tensor into (PDSL) + :type smem_layout: Union[Layout, ComposedLayout] + :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions + :type mma_tiler_mnk: Shape + :param tiled_mma: The TiledMMA that will consume the load as operands + :type tiled_mma: atom.TiledMma + :param filter_trs: The filter shape (T, R, S) in TRS dimensions + :type filter_trs: Tuple[int, int, int] + :param upper_padding_dhw: The upper padding shape (D, H, W) in DHW dimensions + :type upper_padding_dhw: Tuple[int, int, int] + :param lower_padding_dhw: The lower padding shape (D, H, W) in DHW dimensions + :type lower_padding_dhw: Tuple[int, int, int] + :param stride_dhw: The stride shape (D, H, W) in DHW dimensions + :type stride_dhw: Tuple[int, int, int] + :param dilation_dhw: The dilation shape (D, H, W) in DHW dimensions + :type dilation_dhw: Tuple[int, int, int] + :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions + :type cluster_shape_vmnk: Shape + :param internal_type: An optional parameter for the internal data type to when element + type does not match the copy type + :type internal_type: Type[Numeric] + :return: A TmaInfo containing the Copy Atom, TMA tensor, and SMEM layout + :rtype: TmaInfo + """ + smem_rank = core.rank(smem_layout) + assert smem_rank == 3 or smem_rank == 4, ( + "a_smem_layout must be non-staged (atom, rest_m, rest_k) " + "or staged (atom, rest_m, rest_k, stage), " + f"but got rank = {smem_rank}" + ) + + # Keep the original SMEM layout object for later retrieval at Python level. + stored_smem_layout = smem_layout + + # Slice the smem_layout if it is staged + if smem_rank == 4: + smem_layout = core.select(smem_layout, mode=[0, 1, 2]) + + ident = core.make_identity_layout( + core.product_each(gmem_tensor.shape), loc=loc, ip=ip + ) + mma_mnk: Any = mma_tiler_mnk + mma_tiler_mk = (mma_mnk[0], *mma_mnk[2:]) + g_tile = core.composition(ident, mma_tiler_mk, loc=loc, ip=ip) + cta_v_map: Any = tiled_mma._thrfrg_A(g_tile) + cta_v_map = core.get(cta_v_map, mode=[1]) + cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) + + # Compute im2col descriptor parameters + pad_upper_d, pad_upper_h, pad_upper_w = upper_padding_dhw + pad_lower_d, pad_lower_h, pad_lower_w = lower_padding_dhw + stride_d, stride_h, stride_w = stride_dhw + dilation_d, dilation_h, dilation_w = dilation_dhw + filter_t, filter_r, filter_s = filter_trs + lower_corner_whd = (-pad_lower_w, -pad_lower_h, -pad_lower_d) + upper_corner_whd = ( + pad_upper_w - ((filter_s - 1) * dilation_w), + pad_upper_h - ((filter_r - 1) * dilation_h), + pad_upper_d - ((filter_t - 1) * dilation_d), + ) + lower_padding_whd = (pad_lower_w, pad_lower_h, pad_lower_d) + upper_padding_whd = (pad_upper_w, pad_upper_h, pad_upper_d) + stride_whd = (stride_w, stride_h, stride_d) + lower_srt = (0, 0, 0) + stride_srt = (dilation_w, dilation_h, dilation_d) + + if isinstance(op, CopyBulkTensorIm2ColG2SOp): + num_multicast = 1 + else: + assert isinstance(op, CopyBulkTensorIm2ColG2SMulticastOp) + # multicast across the N-mode since those would share the same tile of A + if cluster_shape_vmnk is None: + raise ValueError( + "cluster_shape_vmnk must be provided for multicast A tensor loads" + ) + num_multicast = core.size(cluster_shape_vmnk, mode=[2]) + + smem_for_ir: Any = smem_layout + if isinstance(smem_for_ir, core._ComposedLayout): + smem_for_ir = smem_for_ir.value + + tma_format = None + if internal_type is not None: + itype: Any = internal_type + if not isinstance(internal_type, NumericMeta): + raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") + + use_unpack = ( + itype.width == 8 + and isinstance(gmem_tensor.element_type, NumericMeta) + and gmem_tensor.element_type.width < 8 # type: ignore[union-attr] + ) + internal_mlir_type = ( + gmem_tensor.element_type.mlir_type if use_unpack else itype.mlir_type # type: ignore[union-attr] + ) + tma_format = _cute_nvgpu_ir.TmaDataFormat( + _cute_nvgpu_ir.get_default_tma_format(internal_mlir_type, use_unpack) + ) + + res = _cute_nvgpu_ir.atom_make_non_exec_im2col_tma_load( + cast(Any, gmem_tensor).value, + smem_for_ir, + cta_v_map, + op._to_ir(), + core._pack_int_tuple(lower_corner_whd, loc=loc, ip=ip), + core._pack_int_tuple(upper_corner_whd, loc=loc, ip=ip), + core._pack_int_tuple(lower_padding_whd, loc=loc, ip=ip), + core._pack_int_tuple(upper_padding_whd, loc=loc, ip=ip), + core._pack_int_tuple(stride_whd, loc=loc, ip=ip), + core._pack_int_tuple(lower_srt, loc=loc, ip=ip), + core._pack_int_tuple(stride_srt, loc=loc, ip=ip), + num_multicast=num_multicast, + tma_format=tma_format, + loc=loc, + ip=ip, + ) + if isinstance(op, CopyBulkTensorIm2ColG2SOp): + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorIm2ColG2SNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) + + assert isinstance(op, CopyBulkTensorIm2ColG2SMulticastOp) + return TmaInfo( + atom.CopyAtom(op, CopyBulkTensorIm2ColG2SMulticastNonExecTrait(res[0])), + res[1], + stored_smem_layout, + ) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py index 9d7d4ec11..9b7fa0018 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py @@ -13,6 +13,26 @@ from .copy import * from .mma import * from .helpers import * +import warnings as _warnings +from typing import Any + +_deprecated_names = { + "OperandMajorMode": ( + OperandMajorMode, + "tcgen05.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead", + ), +} +del OperandMajorMode + + +def __getattr__(name: str) -> Any: + if name in _deprecated_names: + obj, msg = _deprecated_names[name] + _warnings.warn(msg, DeprecationWarning, stacklevel=2) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + # __all__ is required here for documentation generation __all__ = [ # @@ -35,16 +55,17 @@ __all__ = [ # # mma.py # - "OperandMajorMode", + "OperandMajorMode", # deprecated, use cute.nvgpu.OperandMajorMode instead "OperandSource", "CtaGroup", "Field", "MmaTF32Op", "MmaF16BF16Op", - "MmaF16BF16SparseOp", "MmaI8Op", "MmaFP8Op", + "MmaF8F6F4Op", "MmaMXF8Op", + "MmaMXF8F6F4Op", "MmaMXF4Op", "MmaMXF4NVF4Op", "SmemLayoutAtomKind", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py index 562affb1e..0cdb8dd08 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -11,7 +11,7 @@ import enum from dataclasses import dataclass -from typing import Type +from typing import Any, Optional, Type from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL @@ -42,6 +42,7 @@ class TmemLoadRedOp(enum.Enum): def __repr__(self) -> str: return f"<{self.__class__.__name__}.{self.name}>" + class Repetition(enum.Enum): """ An enumeration for the number of repetitions of a given TMEM copy within the instruction. @@ -176,7 +177,12 @@ class Ld16x64bOp(_LdBase): """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Ld16x64bTrait": """ Create a trait object for the 16x64b TMEM load operation. @@ -238,7 +244,12 @@ class Ld16x128bOp(_LdBase): ) def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Ld16x128bTrait": """ Create a trait object for the 16x128b TMEM load operation. @@ -296,7 +307,12 @@ class Ld16x256bOp(_LdBase): ) def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Ld16x256bTrait": """ Create a trait object for the 16x256b TMEM load operation. @@ -336,7 +352,12 @@ class Ld16x32bx2Op(_LdBase): """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Ld16x32bx2Trait": """ Create a trait object for the 16x32bx2 TMEM load operation. @@ -376,7 +397,12 @@ class Ld32x32bOp(_LdBase): """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Ld32x32bTrait": """ Create a trait object for the 32x32b TMEM load operation. @@ -420,7 +446,12 @@ class LdRed16x32bx2Op(_LdBase): half_split_off: int = 0 def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "LdRed16x32bx2Trait": """ Create a trait object for the 16x32bx2 TMEM load Reduce operation. @@ -465,7 +496,12 @@ class LdRed32x32bOp(_LdBase): nan: bool = False def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "LdRed32x32bTrait": """ Create a trait object for the 32x32b TMEM load Reduce operation. @@ -497,6 +533,7 @@ class LdRed32x32bTrait(Trait): pass + @dataclass(frozen=True) class _StBase(CopyOp): """ @@ -561,7 +598,12 @@ class St16x64bOp(_StBase): """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "St16x64bTrait": """ Create a trait object for the 16x64b TMEM store operation. @@ -610,7 +652,12 @@ class St16x128bOp(_StBase): ) def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "St16x128bTrait": ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( copy_internal_type.mlir_type, @@ -645,7 +692,12 @@ class St16x256bOp(_StBase): ) def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "St16x256bTrait": ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( copy_internal_type.mlir_type, @@ -671,7 +723,12 @@ class St16x32bx2Op(_StBase): """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "St16x32bx2Trait": ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( copy_internal_type.mlir_type, @@ -697,7 +754,12 @@ class St32x32bOp(_StBase): """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "St32x32bTrait": ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( copy_internal_type.mlir_type, @@ -733,9 +795,10 @@ class _S2TCopyBase(CopyOp): # Arch verification arch = BaseDSL._get_dsl().get_arch_enum() if not arch.is_family_of(Arch.sm_100f): + supported = Arch.filter(lambda a: a.is_family_of(Arch.sm_100f)) raise OpError( self, - f"expects arch to be one of {Arch.filter(lambda arch: arch.is_family_of(Arch.sm_100f))}, but got {arch}", + f"expects arch to be one of {supported}, but got {arch}", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) # Verify that the user provided enum values @@ -761,10 +824,31 @@ class Cp128x256bOp(_S2TCopyBase): See the `PTX documentation `__. This Operation corresponds to the ``.128x256b`` qualifier. + + SMEM to TMEM copy operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: SMEM to TMEM copy without elect_one + cute.copy( + s2t_atom, + smem_tensor, + tmem_tensor, + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(s2t_atom, smem_tensor, tmem_tensor) """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Cp128x256bTrait": """ Create a trait object for the 128x256b SMEM to TMEM copy operation. @@ -801,10 +885,31 @@ class Cp128x128bOp(_S2TCopyBase): See the `PTX documentation `__. This Operation corresponds to the ``.128x128b`` qualifier. + + SMEM to TMEM copy operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: SMEM to TMEM copy without elect_one + cute.copy( + s2t_atom, + smem_tensor, + tmem_tensor, + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(s2t_atom, smem_tensor, tmem_tensor) """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Cp128x128bTrait": ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( copy_internal_type.mlir_type, @@ -827,10 +932,31 @@ class Cp4x256bOp(_S2TCopyBase): See the `PTX documentation `__. This Operation corresponds to the ``.4x256b`` qualifier. + + SMEM to TMEM copy operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: SMEM to TMEM copy without elect_one + cute.copy( + s2t_atom, + smem_tensor, + tmem_tensor, + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(s2t_atom, smem_tensor, tmem_tensor) """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Cp4x256bTrait": ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( copy_internal_type.mlir_type, @@ -853,10 +979,31 @@ class Cp4x32x128bOp(_S2TCopyBase): See the `PTX documentation `__. This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled. + + SMEM to TMEM copy operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: SMEM to TMEM copy without elect_one + cute.copy( + s2t_atom, + smem_tensor, + tmem_tensor, + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(s2t_atom, smem_tensor, tmem_tensor) """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Cp4x32x128bTrait": ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( copy_internal_type.mlir_type, @@ -879,10 +1026,31 @@ class Cp2x64x128b0213Op(_S2TCopyBase): See the `PTX documentation `__. This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled. + + SMEM to TMEM copy operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: SMEM to TMEM copy without elect_one + cute.copy( + s2t_atom, + smem_tensor, + tmem_tensor, + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(s2t_atom, smem_tensor, tmem_tensor) """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Cp2x64x128b0213Trait": ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( copy_internal_type.mlir_type, @@ -905,10 +1073,32 @@ class Cp2x64x128b0123Op(_S2TCopyBase): See the `PTX documentation `__. This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled. + + SMEM to TMEM copy operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: SMEM to TMEM copy without elect_one + cute.copy( + s2t_atom, + smem_tensor, + tmem_tensor, + ) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.copy(s2t_atom, smem_tensor, tmem_tensor) + """ def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "Cp2x64x128b0123Trait": ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( copy_internal_type.mlir_type, diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py index d9d545b19..b7c6eece3 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py @@ -9,7 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import overload, Type, Tuple, Union, Optional +from typing import Any, overload, Type, Tuple, Union, Optional from cutlass.cutlass_dsl import dsl_user_op @@ -27,6 +27,7 @@ from ...typing import ( Int, Numeric, NumericMeta, + Boolean, Int16, Int32, Int64, @@ -60,7 +61,11 @@ from .copy import ( @dsl_user_op def make_smem_layout_atom( - kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None + kind: SmemLayoutAtomKind, + element_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: """ Makes a SMEM layout Atom. @@ -117,9 +122,15 @@ def make_smem_layout_atom( return core.make_composed_layout(sw, 0, outer, loc=loc, ip=ip) + @overload def tile_to_mma_shape( - atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None + atom: Layout, + mma_tile_shape: Shape, + order: Optional[IntTuple] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Layout: ... @@ -127,17 +138,22 @@ def tile_to_mma_shape( def tile_to_mma_shape( atom: ComposedLayout, mma_tile_shape: Shape, - order: IntTuple = None, + order: Optional[IntTuple] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: ... @dsl_user_op def tile_to_mma_shape( - atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None -): + atom: Union[Layout, ComposedLayout], + mma_tile_shape: Shape, + order: Optional[IntTuple] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, ComposedLayout]: """ Tiles a layout to an MMA shape. """ @@ -170,19 +186,38 @@ def tile_to_mma_shape( @dsl_user_op def commit( mbar_ptr: core.Pointer, - mask=None, + mask: Any = None, cta_group: CtaGroup = CtaGroup.ONE, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Perform an arrive operation on a mbarrier upon completion of previous MMA operations. + **Single-Thread Execution Required - DSL Does NOT Handle Automatically**: This operation + **must** be wrapped in :func:`cute.arch.elect_one`. Without ``elect_one()``, all 32 + threads in the warp will execute the commit, causing 32x redundant ``tcgen05.commit`` PTX instructions. + + .. code-block:: python + + # CORRECT: Wrap tcgen05.commit in elect_one + with cute.arch.elect_one(): + tcgen05.commit(barrier_ptr, None, cta_group) + + # WRONG: Without elect_one, all threads execute (32x redundant) + tcgen05.commit(barrier_ptr, None, cta_group) + :param mbar_ptr: A pointer to the mbarrier in SMEM :type mbar_ptr: Pointer :param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to :type mask: Int + :param cta_group: The CTA group size for the operation (ONE or TWO) + :type cta_group: CtaGroup + + .. seealso:: + - :func:`cute.arch.elect_one` - **REQUIRED** wrapper for single-thread execution + - :func:`cute.arch.mbarrier_arrive` - General barrier arrive operation """ if cta_group == CtaGroup.ONE: group = nvvm.Tcgen05GroupKind.CTA_1 @@ -200,7 +235,9 @@ def commit( @dsl_user_op -def int_to_smem_descriptor(i, *, loc=None, ip=None) -> ir.Value: +def int_to_smem_descriptor( + i: Any, *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> ir.Value: desc_type = _cute_nvgpu_ir.SmemDescType.get() return builtin.unrealized_conversion_cast( [desc_type], [Int64(i).ir_value(loc=loc, ip=ip)], loc=loc, ip=ip @@ -208,7 +245,12 @@ def int_to_smem_descriptor(i, *, loc=None, ip=None) -> ir.Value: @dsl_user_op -def smem_descriptor_to_int(desc: ir.Value, *, loc=None, ip=None) -> Int64: +def smem_descriptor_to_int( + desc: ir.Value, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int64: return Int64( builtin.unrealized_conversion_cast([Int64.mlir_type], [desc], loc=loc, ip=ip) ) @@ -273,14 +315,19 @@ def get_tmem_copy_properties( else: raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}") if is_tmem_load(atom): - return num_dp, num_bits, atom.op.repeat.value, atom.op.pack + return num_dp, num_bits, atom.op.repeat.value, atom.op.pack # type: ignore[union-attr] else: assert is_tmem_store(atom), "atom must be a TMEM store" - return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack + return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack # type: ignore[union-attr] @dsl_user_op -def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int: +def find_tmem_tensor_col_offset( + tmem_tensor: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int: """ Computes the TMEM column offset given a TMEM tensor. @@ -303,7 +350,11 @@ def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> In @dsl_user_op def make_tmem_copy( - atom: CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None + atom: CopyAtom, + tmem_tensor: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> TiledCopy: """ Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. @@ -317,11 +368,16 @@ def make_tmem_copy( @dsl_user_op def make_s2t_copy( - atom: CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None + atom: CopyAtom, + tmem_tensor: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> TiledCopy: """ Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. """ + tmem_tensor = core.filter_zeros(tmem_tensor, loc=loc, ip=ip) tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy( atom._trait.value, tmem_tensor.value, loc=loc, ip=ip ) @@ -331,7 +387,11 @@ def make_s2t_copy( @dsl_user_op def get_s2t_smem_desc_tensor( - atom: CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None + atom: CopyAtom, + smem_tensor: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: """ Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor. @@ -348,9 +408,9 @@ def make_umma_smem_desc( major: str, next_src: Optional[Pointer] = None, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Any: """ Construct shared memory descriptor for UMMA. @@ -392,3 +452,5 @@ def make_umma_smem_desc( loc=loc, ip=ip, ) + + diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index 273253a22..2c9902424 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -10,17 +10,20 @@ # is strictly prohibited. import enum +import warnings from dataclasses import dataclass -from typing import Type, Any +from typing import Type, Any, Union, Optional, cast, Tuple from cutlass.base_dsl.arch import Arch -from cutlass.cutlass_dsl import BaseDSL, T +from cutlass.cutlass_dsl import BaseDSL, T, DSLRuntimeError +from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir import ir +from typing_extensions import deprecated from ..common import OpError, normalize_field_to_ir_name +from ..common import OperandMajorMode as _OperandMajorMode from ... import core, atom from ...core import _pack_shape, rank, depth from ...typing import ( @@ -30,6 +33,8 @@ from ...typing import ( Float8E8M0FNU, Float8E5M2, Float8E4M3FN, + Float6E3M2FN, + Float6E2M3FN, Float16, BFloat16, Float32, @@ -45,6 +50,14 @@ from ...typing import ( from ...atom import Trait, make_atom +_F8F6F4_TYPES = [ + Float8E5M2, + Float8E4M3FN, + Float6E3M2FN, + Float6E2M3FN, + Float4E2M1FN, +] + #################################################################################################### # # MMA Ops and Traits @@ -60,6 +73,9 @@ class Tcgen05MmaOp(atom.MmaOp): pass +@deprecated( + "tcgen05.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead" +) class OperandMajorMode(enum.Enum): """ An enumeration for the majorness of the input operands of the MMA. @@ -74,14 +90,28 @@ class OperandMajorMode(enum.Enum): def __repr__(self) -> str: return f"<{self.__class__.__name__}.{self.name}>" + def __eq__(self, other: object) -> bool: + if hasattr(other, "_to_ir") and type(other._to_ir()) is type(self._to_ir()): + return self._to_ir() == other._to_ir() + raise DSLRuntimeError( + f"{self.__module__}.{self.__class__.__qualname__} cannot be compared with {other.__module__}.{other.__class__.__qualname__}" + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.value) + @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any) -> Optional["OperandMajorMode"]: if isinstance(value, str): value = value.upper() if value == "MN": return OperandMajorMode.MN elif value == "K": return OperandMajorMode.K + return None def _to_ir(self) -> _cute_ir.MajorMode: return self.value @@ -130,7 +160,6 @@ class Field(enum.Enum): ACCUMULATE = "accum_c" SFA = "sf_a" SFB = "sf_b" - def __str__(self) -> str: return f"{self.__class__.__name__}.{self.name}" @@ -151,8 +180,8 @@ class MmaOp(Tcgen05MmaOp): shape_mnk: Shape cta_group: CtaGroup a_src: OperandSource - a_major_mode: OperandMajorMode - b_major_mode: OperandMajorMode + a_major_mode: Union[_OperandMajorMode, OperandMajorMode] + b_major_mode: Union[_OperandMajorMode, OperandMajorMode] admissible_archs = Arch.filter( lambda arch: arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f) @@ -178,28 +207,49 @@ class MmaOp(Tcgen05MmaOp): self, "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", ) - if not isinstance(self.a_major_mode, OperandMajorMode): + if not isinstance(self.a_major_mode, _OperandMajorMode) and not isinstance( + self.a_major_mode, OperandMajorMode + ): raise OpError( self, - "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + "expects the 'a_major_mode' Op parameter to be a cute.nvgpu.OperandMajorMode or tcgen05.OperandMajorMode (deprecated) instance", ) - if not isinstance(self.b_major_mode, OperandMajorMode): + if not isinstance(self.b_major_mode, _OperandMajorMode) and not isinstance( + self.b_major_mode, OperandMajorMode + ): raise OpError( self, - "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + "expects the 'b_major_mode' Op parameter to be a cute.nvgpu.OperandMajorMode or tcgen05.OperandMajorMode (deprecated) instance", + ) + if isinstance(self.a_major_mode, OperandMajorMode) or isinstance( + self.b_major_mode, OperandMajorMode + ): + warnings.warn( + "tcgen05.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead", + DeprecationWarning, + stacklevel=2, + ) + # Normalize the major modes to the new enum type + # Since this is a frozen dataclass, we need to use the object.__setattr__ method to set the attributes + object.__setattr__( + self, "a_major_mode", _OperandMajorMode(self.a_major_mode.value) + ) + object.__setattr__( + self, "b_major_mode", _OperandMajorMode(self.b_major_mode.value) ) # Verify the instruction shape - if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if (rank(shape_mnk_tuple) not in [2, 3]) or (depth(shape_mnk_tuple) != 1): raise OpError( self, f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " f"but got {self.shape_mnk}", ) - m, n = self.shape_mnk[0], self.shape_mnk[1] + m, n = shape_mnk_tuple[0], shape_mnk_tuple[1] if self.cta_group == CtaGroup.ONE: if m not in [64, 128]: raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}") - if self.b_dtype.width == 8 and self.b_major_mode == OperandMajorMode.MN: + if self.b_dtype.width == 8 and (self.b_major_mode == _OperandMajorMode.MN): if (n < 16) or (n > 256) or (n % 16 != 0): raise OpError( self, @@ -214,7 +264,7 @@ class MmaOp(Tcgen05MmaOp): else: if m not in [128, 256]: raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") - if self.b_dtype.width == 8 and self.b_major_mode == OperandMajorMode.MN: + if self.b_dtype.width == 8 and (self.b_major_mode == _OperandMajorMode.MN): if (n < 32) or (n > 256) or (n % 32 != 0): raise OpError( self, @@ -240,7 +290,13 @@ class MmaOp(Tcgen05MmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) - def _verify_fragment_A(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: if input.memspace == AddressSpace.smem and isinstance( input.layout.type, _cute_ir.ComposedLayoutType ): @@ -252,7 +308,13 @@ class MmaOp(Tcgen05MmaOp): ) return True - def _verify_fragment_B(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: if input.memspace == AddressSpace.smem and isinstance( input.layout.type, _cute_ir.ComposedLayoutType ): @@ -268,7 +330,14 @@ class MmaOp(Tcgen05MmaOp): class MmaTraits(Trait): admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] - def set(self, field, value, *, loc=None, ip=None) -> None: + def set( + self, + field: Any, + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: field_ir = normalize_field_to_ir_name(field, self.admissible_fields) bool_val = Boolean(value).ir_value(loc=loc, ip=ip) try: @@ -282,7 +351,13 @@ class MmaTraits(Trait): self.value, attr, bool_val, loc=loc, ip=ip ) - def get(self, field, *, loc=None, ip=None) -> Any: + def get( + self, + field: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: field_ir = normalize_field_to_ir_name(field, self.admissible_fields) try: return _cute_nvgpu_ir.atom_get_value( @@ -300,14 +375,14 @@ class MmaTraits(Trait): class BlockScaledMmaOp(Tcgen05MmaOp): a_dtype: Type[Numeric] b_dtype: Type[Numeric] - acc_dtype: Float32 + acc_dtype: Type[Numeric] sf_dtype: Type[Numeric] sf_vec_size: int shape_mnk: Shape cta_group: CtaGroup a_src: OperandSource - a_major_mode: OperandMajorMode - b_major_mode: OperandMajorMode + a_major_mode: Union[_OperandMajorMode, OperandMajorMode] + b_major_mode: Union[_OperandMajorMode, OperandMajorMode] admissible_archs = [ Arch.sm_100a, @@ -334,24 +409,45 @@ class BlockScaledMmaOp(Tcgen05MmaOp): self, "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", ) - if not isinstance(self.a_major_mode, OperandMajorMode): + if not isinstance(self.a_major_mode, _OperandMajorMode) and not isinstance( + self.a_major_mode, OperandMajorMode + ): raise OpError( self, - "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + "expects the 'a_major_mode' Op parameter to be a cute.nvgpu.OperandMajorMode or tcgen05.OperandMajorMode (deprecated) instance", ) - if not isinstance(self.b_major_mode, OperandMajorMode): + if not isinstance(self.b_major_mode, _OperandMajorMode) and not isinstance( + self.b_major_mode, OperandMajorMode + ): raise OpError( self, - "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + "expects the 'b_major_mode' Op parameter to be a cute.nvgpu.OperandMajorMode or tcgen05.OperandMajorMode (deprecated) instance", + ) + if isinstance(self.a_major_mode, OperandMajorMode) or isinstance( + self.b_major_mode, OperandMajorMode + ): + warnings.warn( + "tcgen05.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead", + DeprecationWarning, + stacklevel=2, + ) + # Normalize the major modes to the new enum type + # Since this is a frozen dataclass, we need to use the object.__setattr__ method to set the attributes + object.__setattr__( + self, "a_major_mode", _OperandMajorMode(self.a_major_mode.value) + ) + object.__setattr__( + self, "b_major_mode", _OperandMajorMode(self.b_major_mode.value) ) # Verify the instruction shape - if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if (rank(shape_mnk_tuple) not in [2, 3]) or (depth(shape_mnk_tuple) != 1): raise OpError( self, f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " f"but got {self.shape_mnk}", ) - m, n = self.shape_mnk[0], self.shape_mnk[1] + m, n = shape_mnk_tuple[0], shape_mnk_tuple[1] if self.cta_group == CtaGroup.ONE: if m != 128: raise OpError(self, f"expects the M-mode to be 128, but got {m}") @@ -390,7 +486,13 @@ class BlockScaledMmaOp(Tcgen05MmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) - def _verify_fragment_A(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: if input.memspace == AddressSpace.smem and isinstance( input.layout.type, _cute_ir.ComposedLayoutType ): @@ -402,7 +504,13 @@ class BlockScaledMmaOp(Tcgen05MmaOp): ) return True - def _verify_fragment_B(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: if input.memspace == AddressSpace.smem and isinstance( input.layout.type, _cute_ir.ComposedLayoutType ): @@ -424,7 +532,14 @@ class BlockScaledMmaTraits(Trait): Field.SFB, ] - def set(self, field, value, *, loc=None, ip=None) -> None: + def set( + self, + field: Any, + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: field_ir = normalize_field_to_ir_name(field, self.admissible_fields) # Derive boolean/pointer IR names from enum values, no hard-coded strings. bool_field_ir = { @@ -460,7 +575,13 @@ class BlockScaledMmaTraits(Trait): self.value, attr, val, loc=loc, ip=ip ) - def get(self, field, *, loc=None, ip=None) -> Any: + def get( + self, + field: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: # Only boolean-returning fields supported for get. Derive from admissible_fields. gettable_fields = [ f for f in self.admissible_fields if f not in (Field.SFA, Field.SFB) @@ -491,6 +612,19 @@ class MmaTF32Op(MmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::tf32`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 TF32 MMA Operation" @@ -500,8 +634,8 @@ class MmaTF32Op(MmaOp): instruction_shape: Shape, cta_group: CtaGroup, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( TFloat32, @@ -518,16 +652,24 @@ class MmaTF32Op(MmaOp): def _verify(self) -> None: # Verify the instruction shape instruction_k = 8 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaTF32Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( shape_mnk.type.attribute, @@ -543,11 +685,11 @@ class MmaTF32Op(MmaOp): return MmaTF32Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), - ), + ], loc=loc, ip=ip, ) @@ -570,6 +712,19 @@ class MmaF16BF16Op(MmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::f16`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 F16/BF16 MMA Operation" @@ -581,8 +736,8 @@ class MmaF16BF16Op(MmaOp): instruction_shape: Shape, cta_group: CtaGroup, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( ab_dtype, @@ -612,16 +767,24 @@ class MmaF16BF16Op(MmaOp): ) # Instruction shape verification instruction_k = 16 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaF16BF16Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( shape_mnk.type.attribute, @@ -637,11 +800,11 @@ class MmaF16BF16Op(MmaOp): return MmaF16BF16Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), - ), + ], loc=loc, ip=ip, ) @@ -664,6 +827,19 @@ class MmaI8Op(MmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::i8`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 I8 MMA Operation" @@ -674,8 +850,8 @@ class MmaI8Op(MmaOp): instruction_shape: Shape, cta_group: CtaGroup, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( ab_dtype, @@ -699,24 +875,32 @@ class MmaI8Op(MmaOp): assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" # Instruction shape verification instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaI8Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( shape_mnk.type.attribute, self.cta_group.value, self.a_major_mode._to_ir(), self.b_major_mode._to_ir(), - (T.si8() if self.a_dtype.signed else T.ui8()), - (T.si8() if self.b_dtype.signed else T.ui8()), + (T.si8() if self.a_dtype.signed else T.ui8()), # type: ignore[attr-defined] + (T.si8() if self.b_dtype.signed else T.ui8()), # type: ignore[attr-defined] T.si32(), self.a_src._to_ir(), 0, @@ -724,11 +908,11 @@ class MmaI8Op(MmaOp): return MmaI8Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), - ), + ], loc=loc, ip=ip, ) @@ -744,12 +928,29 @@ class MmaI8Trait(MmaTraits): # +@deprecated("MmaFP8Op is deprecated, use MmaF8F6F4Op instead") @dataclass(frozen=True) class MmaFP8Op(MmaOp): """ F8 tcgen05 MMA Operation. + .. deprecated:: + Use :class:`MmaF8F6F4Op` instead. + See the `PTX documentation `__. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 F8 MMA Operation" @@ -761,8 +962,8 @@ class MmaFP8Op(MmaOp): instruction_shape: Shape, cta_group: CtaGroup, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( ab_dtype, @@ -778,10 +979,13 @@ class MmaFP8Op(MmaOp): def _verify(self) -> None: # Input data type verification - if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + if self.a_dtype not in [ + Float8E5M2, + Float8E4M3FN, + ]: raise OpError( self, - "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + "expects the 'ab_dtype' Op parameter to be one of Float8E5M2, Float8E4M3FN" ) assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" # Accumulator data type verification @@ -792,16 +996,24 @@ class MmaFP8Op(MmaOp): ) # Instruction shape verification instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaFP8Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( shape_mnk.type.attribute, @@ -817,11 +1029,11 @@ class MmaFP8Op(MmaOp): return MmaFP8Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), - ), + ], loc=loc, ip=ip, ) @@ -832,19 +1044,152 @@ class MmaFP8Trait(MmaTraits): pass +@dataclass(frozen=True) +class MmaF8F6F4Op(MmaOp): + """ + F8F6F4 tcgen05 MMA Operation. + + See the `PTX documentation `__. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + + """ + + descriptive_name = "tcgen05 F8F6F4 MMA Operation" + + def __init__( + self, + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], + ) -> None: + super().__init__( + a_dtype, + b_dtype, + acc_dtype, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in _F8F6F4_TYPES: + raise OpError( + self, + "expects the 'a_dtype' Op parameter to be one of " + "Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, or Float4E2M1FN", + ) + if self.b_dtype not in _F8F6F4_TYPES: + raise OpError( + self, + "expects the 'b_dtype' Op parameter to be one of " + "Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, or Float4E2M1FN", + ) + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Instruction shape verification + instruction_k = 32 + shape_mnk_tuple = cast(Tuple[Any, ...], self.shape_mnk) + if rank(self.shape_mnk) == 2: + shape_mnk_tuple = (*shape_mnk_tuple, instruction_k) + object.__setattr__(self, "shape_mnk", shape_mnk_tuple) + if shape_mnk_tuple[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {shape_mnk_tuple[2]}", + ) + + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaF8F6F4Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaF8F6F4Trait( + make_atom( + ty, + [ + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + ], + loc=loc, + ip=ip, + ) + ) + + +class MmaF8F6F4Trait(MmaTraits): + pass + + # # MXF8F6F4 MMA # +@deprecated("MmaMXF8Op is deprecated, use MmaMXF8F6F4Op instead") @dataclass(frozen=True) class MmaMXF8Op(BlockScaledMmaOp): """ MXF8 tcgen05 BlockScaled MMA Operation. + .. deprecated:: + Use :class:`MmaMXF8F6F4Op` instead. + See the `PTX documentation `__. This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation" @@ -855,8 +1200,8 @@ class MmaMXF8Op(BlockScaledMmaOp): instruction_shape: Shape, cta_group: CtaGroup, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( ab_dtype, @@ -882,16 +1227,24 @@ class MmaMXF8Op(BlockScaledMmaOp): assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" # Instruction shape verification instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF8Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( shape_mnk.type.attribute, @@ -908,7 +1261,7 @@ class MmaMXF8Op(BlockScaledMmaOp): return MmaMXF8Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), @@ -918,7 +1271,7 @@ class MmaMXF8Op(BlockScaledMmaOp): core.make_ptr( self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip ).value, - ), + ], loc=loc, ip=ip, ) @@ -929,6 +1282,126 @@ class MmaMXF8Trait(BlockScaledMmaTraits): pass +@dataclass(frozen=True) +class MmaMXF8F6F4Op(BlockScaledMmaOp): + """ + MXF8F6F4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + + """ + + descriptive_name = "tcgen05 MXF8F6F4 BlockScaled MMA Operation" + + def __init__( + self, + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], + ) -> None: + super().__init__( + a_dtype, + b_dtype, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in _F8F6F4_TYPES: + raise OpError( + self, + "expects the 'a_dtype' Op parameter to be one of " + "Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, or Float4E2M1FN", + ) + if self.b_dtype not in _F8F6F4_TYPES: + raise OpError( + self, + "expects the 'b_dtype' Op parameter to be one of " + "Float8E5M2, Float8E4M3FN, Float6E3M2FN, Float6E2M3FN, or Float4E2M1FN", + ) + + # Instruction shape verification + instruction_k = 32 + shape_mnk_tuple = cast(Tuple[Any, ...], self.shape_mnk) + if rank(self.shape_mnk) == 2: + shape_mnk_tuple = (*shape_mnk_tuple, instruction_k) + object.__setattr__(self, "shape_mnk", shape_mnk_tuple) + if shape_mnk_tuple[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {shape_mnk_tuple[2]}", + ) + + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF8F6F4Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF8F6F4Trait( + make_atom( + ty, + [ + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + core.make_ptr( + self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip + ).value, + ], + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF8F6F4Trait(BlockScaledMmaTraits): + pass + + # # MXF4 MMA # @@ -941,6 +1414,19 @@ class MmaMXF4Op(BlockScaledMmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::mxf4`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation" @@ -960,24 +1446,32 @@ class MmaMXF4Op(BlockScaledMmaOp): instruction_shape, cta_group, a_src, - OperandMajorMode.K, - OperandMajorMode.K, + _OperandMajorMode.K, + _OperandMajorMode.K, ) self._verify() def _verify(self) -> None: # Instruction shape verification instruction_k = 64 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF4Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( shape_mnk.type.attribute, @@ -994,7 +1488,7 @@ class MmaMXF4Op(BlockScaledMmaOp): return MmaMXF4Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), @@ -1004,7 +1498,7 @@ class MmaMXF4Op(BlockScaledMmaOp): core.make_ptr( self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip ).value, - ), + ], loc=loc, ip=ip, ) @@ -1027,6 +1521,19 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation" @@ -1047,8 +1554,8 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): instruction_shape, cta_group, a_src, - OperandMajorMode.K, - OperandMajorMode.K, + _OperandMajorMode.K, + _OperandMajorMode.K, ) self._verify() @@ -1061,16 +1568,24 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): ) # Instruction shape verification instruction_k = 64 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4NVF4Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF4NVF4Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( shape_mnk.type.attribute, @@ -1087,7 +1602,7 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): return MmaMXF4NVF4Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), @@ -1097,7 +1612,7 @@ class MmaMXF4NVF4Op(BlockScaledMmaOp): core.make_ptr( self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip ).value, - ), + ], loc=loc, ip=ip, ) @@ -1121,6 +1636,19 @@ class SM103MmaMXF4Op(BlockScaledMmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::mxf4`` qualifier. This Operation is for SM103. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 SM103 MXF4 BlockScaled MMA Operation" @@ -1140,24 +1668,32 @@ class SM103MmaMXF4Op(BlockScaledMmaOp): instruction_shape, cta_group, a_src, - OperandMajorMode.K, - OperandMajorMode.K, + _OperandMajorMode.K, + _OperandMajorMode.K, ) self._verify() def _verify(self) -> None: # Instruction shape verification instruction_k = 96 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF4Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( shape_mnk.type.attribute, @@ -1175,7 +1711,7 @@ class SM103MmaMXF4Op(BlockScaledMmaOp): return MmaMXF4Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), @@ -1185,7 +1721,7 @@ class SM103MmaMXF4Op(BlockScaledMmaOp): core.make_ptr( self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip ).value, - ), + ], loc=loc, ip=ip, ) @@ -1205,6 +1741,19 @@ class SM103MmaMXF4NVF4Op(BlockScaledMmaOp): See the `PTX documentation `__. This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier. This Operation is for SM103. + + MMA operations should be issued by a single thread. The DSL automatically handles this by + implicitly adding ``elect_one()`` around the copy operation. + + .. code-block:: python + + # CORRECT: MMA without elect_one + cute.gemm(mma_atom, d, a, b, c) + + # WRONG: Do NOT wrap in elect_one (can cause deadlock) + with cute.arch.elect_one(): # INCORRECT + cute.gemm(mma_atom, d, a, b, c) + """ descriptive_name = "tcgen05 SM103 MXF4NVF4 BlockScaled MMA Operation" @@ -1225,8 +1774,8 @@ class SM103MmaMXF4NVF4Op(BlockScaledMmaOp): instruction_shape, cta_group, a_src, - OperandMajorMode.K, - OperandMajorMode.K, + _OperandMajorMode.K, + _OperandMajorMode.K, ) self._verify() @@ -1239,16 +1788,24 @@ class SM103MmaMXF4NVF4Op(BlockScaledMmaOp): ) # Instruction shape verification instruction_k = 96 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4NVF4Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF4NVF4Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( shape_mnk.type.attribute, @@ -1266,7 +1823,7 @@ class SM103MmaMXF4NVF4Op(BlockScaledMmaOp): return MmaMXF4NVF4Trait( make_atom( ty, - ( + [ Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), Boolean(False).ir_value(loc=loc, ip=ip), @@ -1276,7 +1833,7 @@ class SM103MmaMXF4NVF4Op(BlockScaledMmaOp): core.make_ptr( self.sf_dtype, 0, _cute_ir.AddressSpace.tmem, loc=loc, ip=ip ).value, - ), + ], loc=loc, ip=ip, ) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py index 65eacac44..bacbf63ee 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -18,6 +18,7 @@ __all__ = [ # mma.py "Field", "MmaF16BF16Op", + "MmaFP8Op", "MmaMXF4Op", "MmaMXF4NVF4Op", # copy.py diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py index 0f707b28d..ff40e6b26 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py @@ -10,7 +10,7 @@ # is strictly prohibited. from dataclasses import dataclass -from typing import Type +from typing import Any, Optional, Type import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir @@ -23,6 +23,12 @@ from ...atom import CopyOp, Trait, make_atom @dataclass(frozen=True) class BaseOp(CopyOp): + """ + Base class for warp-level matrix copy operations. + + Provides shared validation and string formatting for warp load/store ops. + """ + transpose: bool = False num_matrices: int = 1 unpack_bits: Optional[int] = None @@ -66,7 +72,12 @@ class LdMatrix8x8x16bOp(BaseOp): raise OpError(self, "Op doesn't support unpacking") def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "LdMatrix8x8x16bTrait": mode = _pack_shape((8, 8), loc=loc, ip=ip) ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( @@ -80,6 +91,8 @@ class LdMatrix8x8x16bOp(BaseOp): class LdMatrix8x8x16bTrait(Trait): + """Trait generated by ``LdMatrix8x8x16bOp``.""" + pass @@ -107,7 +120,12 @@ class LdMatrix8x16x8bOp(BaseOp): raise OpError(self, "Op unpack bits must be 4 or 6 or None") def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "LdMatrix8x16x8bTrait": # LdMatrix8x16x8b without unpacking doesn't exist # but is equivalent to LdMatrix8x8x16b @@ -129,6 +147,8 @@ class LdMatrix8x16x8bOp(BaseOp): class LdMatrix8x16x8bTrait(Trait): + """Trait generated by ``LdMatrix8x16x8bOp``.""" + pass @@ -156,7 +176,12 @@ class LdMatrix16x8x8bOp(BaseOp): raise OpError(self, "Op unpack bits must be 4 or 6 or None") def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "LdMatrix16x8x8bTrait": mode = _pack_shape((16, 8), loc=loc, ip=ip) sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u8 @@ -175,6 +200,8 @@ class LdMatrix16x8x8bOp(BaseOp): class LdMatrix16x8x8bTrait(Trait): + """Trait generated by ``LdMatrix16x8x8bOp``.""" + pass @@ -202,7 +229,12 @@ class LdMatrix16x16x8bOp(BaseOp): raise OpError(self, "Op unpack bits must be 4 or 6 or None") def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "LdMatrix16x16x8bTrait": mode = _pack_shape((16, 16), loc=loc, ip=ip) sz_pattern = _cute_nvgpu_ir.LdsmSzPattern.u8 @@ -221,6 +253,8 @@ class LdMatrix16x16x8bOp(BaseOp): class LdMatrix16x16x8bTrait(Trait): + """Trait generated by ``LdMatrix16x16x8bOp``.""" + pass @@ -244,7 +278,12 @@ class StMatrix8x8x16bOp(BaseOp): raise OpError(self, "Op doesn't support unpacking") def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "StMatrix8x8x16bTrait": mode = _pack_shape((8, 8), loc=loc, ip=ip) ty = _cute_nvgpu_ir.CopyAtomStsmType.get( @@ -257,6 +296,8 @@ class StMatrix8x8x16bOp(BaseOp): class StMatrix8x8x16bTrait(Trait): + """Trait generated by ``StMatrix8x8x16bOp``.""" + pass @@ -282,7 +323,12 @@ class StMatrix16x8x8bOp(BaseOp): raise OpError(self, "Op doesn't support unpacking") def _make_trait( - self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + self, + copy_internal_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> "StMatrix16x8x8bTrait": mode = _pack_shape((16, 8), loc=loc, ip=ip) ty = _cute_nvgpu_ir.CopyAtomStsmType.get( @@ -295,4 +341,6 @@ class StMatrix16x8x8bOp(BaseOp): class StMatrix16x8x8bTrait(Trait): + """Trait generated by ``StMatrix16x8x8bOp``.""" + pass diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index c08438527..4c51d7f59 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -10,10 +10,9 @@ # is strictly prohibited. from dataclasses import dataclass -from typing import Type, Any +from typing import Any, Optional, Type import enum -from cutlass import cute from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL @@ -24,10 +23,10 @@ from ...typing import ( Float4E2M1FN, Float8E8M0FNU, Float8E4M3FN, + Float8E5M2, Float16, BFloat16, Float32, - Boolean, Numeric, Pointer, ) @@ -89,7 +88,13 @@ class MmaF16BF16Op(WarpMmaOp): "expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaF16BF16Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM80Type.get( shape_mnk.type.attribute, @@ -107,17 +112,106 @@ class MmaF16BF16Op(WarpMmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) - def _verify_fragment_A(self, input: Tensor, *, loc=None, ip=None): - pass + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: + return True - def _verify_fragment_B(self, input: Tensor, *, loc=None, ip=None): - pass + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: + return True class MmaF16BF16Trait(Trait): pass +@dataclass(frozen=True) +class MmaFP8Op(WarpMmaOp): + """ + FP8 warp-level MMA Operation (SM89). + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands. + """ + + ab_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + + def __post_init__(self) -> None: + if self.ab_dtype not in [Float8E4M3FN, Float8E5M2]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float8E4M3FN or Float8E5M2", + ) + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32 or Float16", + ) + if self.shape_mnk not in [(16, 8, 32), (16, 8, 16)]: + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be (16,8,32) or (16,8,16)", + ) + + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaFP8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM89Type.get( + shape_mnk.type.attribute, + self.ab_dtype.mlir_type, + self.ab_dtype.mlir_type, + self.acc_dtype.mlir_type, + ) + return MmaFP8Trait(make_atom(ty, loc=loc, ip=ip)) + + def __str__(self) -> str: + return ( + "warp-level FP8 MMA Operation (SM89)" + + f"\n A/B data type = {self.ab_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + pass + + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + pass + + +class MmaFP8Trait(Trait): + pass + + # Base class for SM120 Blockscaled MMA Ops @dataclass(frozen=True) class MmaSM120BlockScaledOp(MmaOp): @@ -129,16 +223,20 @@ class MmaSM120BlockScaledOp(MmaOp): use_sf_layout_TV: bool = False admissible_archs = [ - "sm_120a", + Arch.sm_120a, + Arch.sm_121a, ] def __post_init__(self) -> None: # Verify arch arch = BaseDSL._get_dsl().get_arch_enum() - if not arch == Arch.sm_120a: + if arch not in self.admissible_archs: raise OpError( self, - f"expects arch to be one of {self.admissible_archs}, but got {arch}", + f"expects arch to be one of {self.admissible_archs}, but got {arch}" + " - Note: sm_120f is currently not supported, " + " please compile for your local GPU architecture instead with env " + "CUTE_DSL_ARCH set to sm_120a or sm_121a", suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", ) if self.ab_dtype != Float4E2M1FN: @@ -185,10 +283,22 @@ class MmaSM120BlockScaledOp(MmaOp): + f"\n SF data type = {self.sf_type}" ) - def _verify_fragment_A(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: pass - def _verify_fragment_B(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: pass @@ -213,19 +323,23 @@ class Field(enum.Enum): class MmaBlockScaledTrait(Trait): admissible_fields = [ - Field.ACCUMULATE, Field.SFA, Field.SFB, ] - def set(self, field, value, *, loc=None, ip=None) -> None: + def set( + self, + field: Any, + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: if field not in self.admissible_fields: raise ValueError( f"expects field to be one of {self.admissible_fields}, but got {field}" ) - if field == Field.ACCUMULATE: - value = Boolean(value).ir_value(loc=loc, ip=ip) - elif field in [Field.SFA, Field.SFB]: + if field in [Field.SFA, Field.SFB]: if not isinstance(value, Pointer): raise ValueError( f"expects value to be a pointer for {field}, but got {type(value).__name__}" @@ -238,14 +352,14 @@ class MmaBlockScaledTrait(Trait): self.value, attr, value, loc=loc, ip=ip ) - def get(self, field, *, loc=None, ip=None) -> Any: - if field not in [Field.ACCUMULATE]: - raise ValueError(f"the get method for {field} is not supported") - field_name = f"#cute_nvgpu.atom_mma_field_sm120_block_scaled<{field._to_ir_field_name()}>" - attr = ir.Attribute.parse(field_name) - return _cute_nvgpu_ir.atom_get_value( - Boolean.mlir_type, self.value, attr, loc=loc, ip=ip - ) + def get( + self, + field: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: + raise ValueError(f"the get method for {field} is not supported") # @@ -281,7 +395,13 @@ class MmaMXF4Op(MmaSM120BlockScaledOp): 32, ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF4Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get( shape_mnk.type.attribute, @@ -332,7 +452,13 @@ class MmaMXF4NVF4Op(MmaSM120BlockScaledOp): 16, ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF4NVF4Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaMXF4NVF4Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM120BlockScaledType.get( shape_mnk.type.attribute, diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py index 80dc24441..dab9c7310 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py @@ -12,10 +12,30 @@ from .mma import * from .helpers import * +import warnings as _warnings +from typing import Any + +_deprecated_names = { + "OperandMajorMode": ( + OperandMajorMode, + "warpgroup.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead", + ), +} +del OperandMajorMode + + +def __getattr__(name: str) -> Any: + if name in _deprecated_names: + obj, msg = _deprecated_names[name] + _warnings.warn(msg, DeprecationWarning, stacklevel=2) + return obj + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + # __all__ is required here for documentation generation __all__ = [ # mma.py - "OperandMajorMode", + "OperandMajorMode", # deprecated, use cute.nvgpu.OperandMajorMode instead "OperandSource", "Field", "MmaF16BF16Op", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py index 05b8a2d8d..616473a47 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py @@ -9,10 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Type +from typing import Any, Optional, Type from cutlass.cutlass_dsl import dsl_user_op +from cutlass._mlir import ir from cutlass._mlir.dialects import nvvm from ...typing import Numeric, NumericMeta, ComposedLayout @@ -22,7 +23,11 @@ from .mma import SmemLayoutAtomKind @dsl_user_op def make_smem_layout_atom( - kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None + kind: SmemLayoutAtomKind, + element_type: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ComposedLayout: """ Makes a SMEM layout Atom. @@ -86,7 +91,9 @@ def make_smem_layout_atom( @dsl_user_op -def fence(*, loc=None, ip=None) -> None: +def fence( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ See the `PTX documentation `__. """ @@ -94,7 +101,9 @@ def fence(*, loc=None, ip=None) -> None: @dsl_user_op -def commit_group(*, loc=None, ip=None) -> None: +def commit_group( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ See the `PTX documentation `__. """ @@ -102,7 +111,12 @@ def commit_group(*, loc=None, ip=None) -> None: @dsl_user_op -def wait_group(group, *, loc=None, ip=None) -> None: +def wait_group( + group: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ See the `PTX documentation `__. """ diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py index bf5d7110d..6c32191bc 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -11,10 +11,11 @@ import enum from dataclasses import dataclass -from typing import Type, Any +from typing import Any, Optional, Type, Union, cast +import warnings from cutlass.base_dsl.arch import Arch -from cutlass.cutlass_dsl import BaseDSL, T +from cutlass.cutlass_dsl import BaseDSL, T, DSLRuntimeError from typing_extensions import deprecated import cutlass._mlir.dialects.cute as _cute_ir @@ -22,6 +23,7 @@ import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass._mlir import ir from ..common import OpError, normalize_field_to_ir_name +from ..common import OperandMajorMode as _OperandMajorMode from ...core import _pack_shape, rank, depth from ...typing import ( Shape, @@ -38,7 +40,7 @@ from ...typing import ( Numeric, AddressSpace, ) -from ...atom import MmaOp, Trait, make_atom +from ...atom import MmaOp as AtomMmaOp, Trait, make_atom #################################################################################################### @@ -48,7 +50,7 @@ from ...atom import MmaOp, Trait, make_atom #################################################################################################### -class WarpGroupMmaOp(MmaOp): +class WarpGroupMmaOp(AtomMmaOp): """ Base class for all warpgroup-level MMA operations. """ @@ -56,6 +58,9 @@ class WarpGroupMmaOp(MmaOp): pass +@deprecated( + "warpgroup.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead" +) class OperandMajorMode(enum.Enum): """ An enumeration for the majorness of the input operands of the MMA. @@ -70,14 +75,29 @@ class OperandMajorMode(enum.Enum): def __repr__(self) -> str: return f"<{self.__class__.__name__}.{self.name}>" + def __eq__(self, other: object) -> bool: + if hasattr(other, "_to_ir") and type(other._to_ir()) is type(self._to_ir()): + return self._to_ir() == other._to_ir() + raise DSLRuntimeError( + f"{self.__module__}.{self.__class__.__qualname__} cannot be compared with " + f"{getattr(other, '__module__', '?')}.{other.__class__.__qualname__}" + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def __hash__(self) -> int: + return hash(self.value) + @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any) -> Optional["OperandMajorMode"]: if isinstance(value, str): value = value.upper() if value == "MN": return OperandMajorMode.MN elif value == "K": return OperandMajorMode.K + return None def _to_ir(self) -> _cute_ir.MajorMode: return self.value @@ -125,8 +145,8 @@ class MmaOp(WarpGroupMmaOp): acc_dtype: Type[Numeric] shape_mnk: Shape a_src: OperandSource - a_major_mode: OperandMajorMode - b_major_mode: OperandMajorMode + a_major_mode: Union[_OperandMajorMode, OperandMajorMode] + b_major_mode: Union[_OperandMajorMode, OperandMajorMode] def __post_init__(self) -> None: # Verify arch @@ -143,24 +163,45 @@ class MmaOp(WarpGroupMmaOp): self, "expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance", ) - if not isinstance(self.a_major_mode, OperandMajorMode): + if not isinstance(self.a_major_mode, _OperandMajorMode) and not isinstance( + self.a_major_mode, OperandMajorMode + ): raise OpError( self, - "expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", + "expects the 'a_major_mode' Op parameter to be a cute.nvgpu.OperandMajorMode or warpgroup.OperandMajorMode (deprecated) instance", ) - if not isinstance(self.b_major_mode, OperandMajorMode): + if not isinstance(self.b_major_mode, _OperandMajorMode) and not isinstance( + self.b_major_mode, OperandMajorMode + ): raise OpError( self, - "expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", + "expects the 'b_major_mode' Op parameter to be a cute.nvgpu.OperandMajorMode or warpgroup.OperandMajorMode (deprecated) instance", + ) + if isinstance(self.a_major_mode, OperandMajorMode) or isinstance( + self.b_major_mode, OperandMajorMode + ): + warnings.warn( + "warpgroup.OperandMajorMode is deprecated, use cute.nvgpu.OperandMajorMode instead", + DeprecationWarning, + stacklevel=2, + ) + # Normalize the major modes to the new enum type + # Since this is a frozen dataclass, we need to use the object.__setattr__ method to set the attributes + object.__setattr__( + self, "a_major_mode", _OperandMajorMode(self.a_major_mode.value) + ) + object.__setattr__( + self, "b_major_mode", _OperandMajorMode(self.b_major_mode.value) ) # Verify instruction shape - if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if (rank(shape_mnk_tuple) not in [2, 3]) or (depth(shape_mnk_tuple) != 1): raise OpError( self, f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " f"but got {self.shape_mnk}", ) - m, n = self.shape_mnk[0], self.shape_mnk[1] + m, n = shape_mnk_tuple[0], shape_mnk_tuple[1] if m != 64: raise OpError(self, f"expects the M-mode to be 64, but got {m}") if (n < 8) or (n > 256) or (n % 8 != 0): @@ -181,7 +222,13 @@ class MmaOp(WarpGroupMmaOp): + f"\n Instruction shape MNK = {self.shape_mnk}" ) - def _verify_fragment_A(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_A( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: if input.memspace == AddressSpace.smem and isinstance( input.layout.type, _cute_ir.ComposedLayoutType ): @@ -193,7 +240,13 @@ class MmaOp(WarpGroupMmaOp): ) return True - def _verify_fragment_B(self, input: Tensor, *, loc=None, ip=None): + def _verify_fragment_B( + self, + input: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> bool: if input.memspace == AddressSpace.smem and isinstance( input.layout.type, _cute_ir.ComposedLayoutType ): @@ -218,7 +271,14 @@ class MmaTraits(Trait): """ return normalize_field_to_ir_name(field, self.admissible_fields) - def set(self, field, value, *, loc=None, ip=None) -> None: + def set( + self, + field: Any, + value: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: field_ir_name = self._normalize_field_name(field) # Prefer the newer builder that accepts a logical field name, but keep # a fallback for legacy attribute-based construction to avoid breaking changes. @@ -235,7 +295,13 @@ class MmaTraits(Trait): self.value, attr, bool_val, loc=loc, ip=ip ) - def get(self, field, *, loc=None, ip=None) -> Any: + def get( + self, + field: Any, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Any: field_ir_name = self._normalize_field_name(field) try: return _cute_nvgpu_ir.atom_get_value( @@ -266,8 +332,8 @@ class MmaF16BF16Op(MmaOp): acc_dtype: Type[Numeric], instruction_shape: Shape, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( ab_dtype, @@ -301,16 +367,24 @@ class MmaF16BF16Op(MmaOp): ) # Verify the instruction shape instruction_k = 16 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaF16BF16Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( shape_mnk.type.attribute, @@ -322,7 +396,7 @@ class MmaF16BF16Op(MmaOp): self.a_src._to_ir(), ) return MmaF16BF16Trait( - make_atom(ty, (Boolean(False).ir_value(loc=loc, ip=ip),), loc=loc, ip=ip) + make_atom(ty, [Boolean(False).ir_value(loc=loc, ip=ip)], loc=loc, ip=ip) ) @@ -348,8 +422,8 @@ class MmaF8Op(MmaOp): acc_dtype: Type[Numeric], instruction_shape: Shape, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( a_dtype, @@ -362,7 +436,7 @@ class MmaF8Op(MmaOp): ) self._verify() - def _verify(self): + def _verify(self) -> None: # Input data type verification if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: raise OpError( @@ -382,16 +456,24 @@ class MmaF8Op(MmaOp): ) # Verify the instruction shape instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaF8Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( shape_mnk.type.attribute, @@ -403,7 +485,7 @@ class MmaF8Op(MmaOp): self.a_src._to_ir(), ) return MmaF8Trait( - make_atom(ty, (Boolean(False).ir_value(loc=loc, ip=ip),), loc=loc, ip=ip) + make_atom(ty, [Boolean(False).ir_value(loc=loc, ip=ip)], loc=loc, ip=ip) ) @@ -429,8 +511,8 @@ class MmaI8Op(MmaOp): acc_dtype: Type[Numeric], instruction_shape: Shape, a_src: OperandSource, - a_major_mode: OperandMajorMode, - b_major_mode: OperandMajorMode, + a_major_mode: Union[_OperandMajorMode, OperandMajorMode], + b_major_mode: Union[_OperandMajorMode, OperandMajorMode], ) -> None: super().__init__( a_dtype, @@ -443,7 +525,7 @@ class MmaI8Op(MmaOp): ) self._verify() - def _verify(self): + def _verify(self) -> None: # Input data type verification if self.a_dtype not in [Int8, Uint8]: raise OpError( @@ -464,16 +546,18 @@ class MmaI8Op(MmaOp): # Verify the instruction shape instruction_k = 32 - if rank(self.shape_mnk) == 2: - object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) - if self.shape_mnk[2] != instruction_k: + shape_mnk_tuple: Any = cast(Any, self.shape_mnk) + if rank(shape_mnk_tuple) == 2: + object.__setattr__(self, "shape_mnk", (*shape_mnk_tuple, instruction_k)) + shape_mnk_tuple = cast(Any, self.shape_mnk) + if shape_mnk_tuple[2] != instruction_k: raise OpError( self, f"expects the instruction extent in the K-mode to be {instruction_k}, " - f"but got {self.shape_mnk[2]}", + f"but got {shape_mnk_tuple[2]}", ) - n = self.shape_mnk[1] + n = shape_mnk_tuple[1] if not (n >= 8 and n <= 256 and (n == 8 or n == 24 or n % 16 == 0)): raise OpError( self, @@ -481,19 +565,25 @@ class MmaI8Op(MmaOp): f"or N=16*i where i={{3,4,...,15,16}}. But got {n}", ) - def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": + def _make_trait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, + ) -> "MmaI8Trait": shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( shape_mnk.type.attribute, self.a_major_mode._to_ir(), self.b_major_mode._to_ir(), - (T.si8() if self.a_dtype.signed else T.ui8()), - (T.si8() if self.b_dtype.signed else T.ui8()), + (T.si8() if self.a_dtype.signed else T.ui8()), # type: ignore[attr-defined] + (T.si8() if self.b_dtype.signed else T.ui8()), # type: ignore[attr-defined] self.acc_dtype.mlir_type, self.a_src._to_ir(), ) return MmaI8Trait( - make_atom(ty, (Boolean(False).ir_value(loc=loc, ip=ip),), loc=loc, ip=ip) + make_atom(ty, [Boolean(False).ir_value(loc=loc, ip=ip)], loc=loc, ip=ip) ) diff --git a/python/CuTeDSL/cutlass/cute/runtime.py b/python/CuTeDSL/cutlass/cute/runtime.py index da1a34d20..436547ec4 100644 --- a/python/CuTeDSL/cutlass/cute/runtime.py +++ b/python/CuTeDSL/cutlass/cute/runtime.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # # Use of this software is governed by the terms and conditions of the @@ -9,14 +9,15 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. + import ctypes import sys +import math from pathlib import Path from functools import lru_cache import itertools import operator -from typing import Union, Optional, Type, List - +from typing import Any, Union, Optional, Type, List, NoReturn # MLIR modules imports from cutlass._mlir import ir @@ -31,13 +32,15 @@ from cutlass.base_dsl.export import ExternalBinaryModule # Local modules imports from .typing import ( AddressSpace, - Layout, + TypedTensor, Tensor, Pointer, Numeric, SymInt, Float32, TFloat32, + Shape, + Stride, ) from . import core from .tensor import _Tensor as CoreTensor @@ -66,11 +69,11 @@ class _Pointer(Pointer): def __init__( self, - pointer, - dtype, + pointer: int, + dtype: Type[Numeric], mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic, - assumed_align=None, - ): + assumed_align: Optional[int] = None, + ) -> None: self._pointer = pointer self._dtype = dtype self._addr_space = mem_space @@ -92,16 +95,16 @@ class _Pointer(Pointer): self._desc = ctypes.c_void_p(int(self._pointer)) return ctypes.sizeof(self._desc) - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> List[ir.Type]: return [self.mlir_type] - def __tvm_ffi_opaque_ptr__(self): + def __tvm_ffi_opaque_ptr__(self) -> object: return self._pointer - def __c_pointers__(self): + def __c_pointers__(self) -> List[int]: return self._c_pointers_cache - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[object]) -> object: # type: ignore[override] assert len(values) == 1 return values[0] @@ -117,16 +120,32 @@ class _Pointer(Pointer): return self._dtype @property - def memspace(self): + def memspace(self) -> AddressSpace: return self._addr_space - def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + def align( + self, + min_align: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Pointer: raise NotImplementedError("align is not supported in runtime") + def __add__(self, offset: int) -> Pointer: # type: ignore[override] + offset_bytes = offset * self._dtype.width // 8 + assumed_align = math.gcd(offset_bytes, self._assumed_align) + return _Pointer( + self._pointer + offset_bytes, self._dtype, self._addr_space, assumed_align + ) + + def __sub__(self, offset: int) -> Pointer: + return self.__add__(-offset) + def __str__(self) -> str: return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>" - def __repr__(self): + def __repr__(self) -> str: return self.__str__() @property @@ -137,12 +156,12 @@ class _Pointer(Pointer): class _Tensor(Tensor): def __init__( self, - tensor, - assumed_align=None, - use_32bit_stride=False, + tensor: object, + assumed_align: Optional[int] = None, + use_32bit_stride: bool = False, *, - enable_tvm_ffi=False, - ): + enable_tvm_ffi: bool = False, + ) -> None: # If tensor is already a DLPack object, use it directly if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"): self._dlpack_data = tensor.__dlpack_device__() @@ -156,24 +175,24 @@ class _Tensor(Tensor): # we expect no stream sync. Because torch has different default behavior # for stream parameter on different version. # we need to explicitly pass -1 to achieve no sync effects. - self._dlpack_data = tensor.__dlpack__(stream=-1) + self._dlpack_data = tensor.__dlpack__(stream=-1) # type: ignore[attr-defined] except Exception: - self._dlpack_data = tensor.__dlpack__() + self._dlpack_data = tensor.__dlpack__() # type: ignore[attr-defined] - self._dltensor_wrapper = None + self._dltensor_wrapper: Any = None self._assumed_align = assumed_align self._is_dynamic = False - self._memref_desc = None - self._dtype = None + self._memref_desc: Any = None + self._dtype: Any = None self._use_32bit_stride = use_32bit_stride - self._c_pointers_cache = None + self._c_pointers_cache: Optional[List[int]] = None - @property - def __class__(self) -> Type[Tensor]: + @property # type: ignore[misc] + def __class__(self) -> Type[Tensor]: # type: ignore[override] # Cheat to let `type(_Tensor())` to return cute.Tensor return Tensor - def load_dltensor(self): + def load_dltensor(self) -> None: """Lazily load the DLTensorWrapper. This function loads the DLTensorWrapper when needed, @@ -184,23 +203,20 @@ class _Tensor(Tensor): self._dlpack_data, self._use_32bit_stride ) - def mark_layout_dynamic(self, leading_dim: Optional[int] = None): + def mark_layout_dynamic(self, leading_dim: Optional[int] = None) -> "_Tensor": """Marks the tensor layout as dynamic based on the leading dimension. :param leading_dim: The leading dimension of the layout, defaults to None :type leading_dim: int, optional - When ``leading_dim`` is None, the leading dimension is deduced as follows. + When ``leading_dim`` is None, the leading dimension is deduced as follows: - (1) If exactly one dimension has stride 1, that dimension is used. - - (2) If multiple dimensions have stride 1 but exactly one of them has size > 1, - that dimension is used. - - (3) If multiple dimensions have stride 1 but none or more than one has size > 1, - an error is raised. - - (4) If no dimension has stride 1, all strides remain dynamic. + - If exactly one dimension has stride 1, that dimension is used. + - If multiple dimensions have stride 1 but exactly one of them has size > 1, + that dimension is used. + - If multiple dimensions have stride 1 but none or more than one has size > 1, + an error is raised. + - If no dimension has stride 1, all strides remain dynamic. When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent @@ -220,7 +236,7 @@ class _Tensor(Tensor): mode: int, stride_order: Optional[tuple[int, ...]] = None, divisibility: int = 1, - ): + ) -> "_Tensor": """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides. :param mode: The mode of the compact shape, defaults to 0 @@ -264,7 +280,7 @@ class _Tensor(Tensor): return self._dtype @element_type.setter - def element_type(self, new_type): + def element_type(self, new_type: Type[Numeric]) -> None: """Set the element type of the tensor. :warning: This API is added for narrow precision before we have a clean `recast_tensor` story. @@ -294,7 +310,7 @@ class _Tensor(Tensor): self._dtype = new_type @property - def memspace(self): + def memspace(self) -> AddressSpace: self.load_dltensor() return self._dltensor_wrapper.address_space @@ -314,7 +330,7 @@ class _Tensor(Tensor): self.load_dltensor() return f"Tensor<0x{self._dltensor_wrapper.str}>" - def __repr__(self): + def __repr__(self) -> str: return self.__str__() @property @@ -324,14 +340,14 @@ class _Tensor(Tensor): self._dtype = self._dltensor_wrapper.dtype return (self._dtype, self._assumed_align, self._dltensor_wrapper.cache_key()) - def __setitem__(self, crd, value): + def __setitem__(self, crd: object, value: object) -> None: raise TypeError("runtime._Tensor is not indexable") - def __getitem__(self, crd): + def __getitem__(self, crd: object) -> NoReturn: raise TypeError("runtime._Tensor is not indexable") @property - def iterator(self): + def iterator(self) -> _Pointer: self.load_dltensor() return _Pointer( self._dltensor_wrapper.data_ptr, @@ -341,31 +357,33 @@ class _Tensor(Tensor): ) @property - def layout(self): + def layout(self) -> NoReturn: raise NotImplementedError( "layout property is not supported in runtime, support in future" ) @property - def shape(self): + def shape(self) -> Shape: self.load_dltensor() return self._dltensor_wrapper.shape @property - def stride(self): + def stride(self) -> Stride: self.load_dltensor() strides = self._dltensor_wrapper.stride if strides is None: + # support tensor created by the old numpy version strides = itertools.accumulate( - reversed(self.shape), func=operator.mul, initial=1 + reversed(self.shape), # type: ignore[arg-type] + func=operator.mul, + initial=1, ) strides = tuple(reversed(list(strides)[:-1])) - return strides @property @lru_cache(maxsize=128, typed=True) - def leading_dim(self): + def leading_dim(self) -> Union[int, tuple[int, ...], None]: """Get the leading dimension of this Tensor. :return: The leading dimension index or indices @@ -379,27 +397,27 @@ class _Tensor(Tensor): """ return core.leading_dim(self.shape, self.stride) - def fill(self, value: Numeric): + def fill(self, value: Numeric) -> None: raise TypeError("fill function is not supported in runtime") @property - def data_ptr(self): + def data_ptr(self) -> int: self.load_dltensor() return self._dltensor_wrapper.data_ptr @property - def dynamic_shapes_mask(self): + def dynamic_shapes_mask(self) -> tuple[int, ...]: """Get the mask of dynamic shapes in the tensor.""" self.load_dltensor() return self._dltensor_wrapper.get_dynamic_shapes_mask() @property - def dynamic_strides_mask(self): + def dynamic_strides_mask(self) -> tuple[int, ...]: """Get the mask of dynamic strides in the tensor.""" self.load_dltensor() return self._dltensor_wrapper.get_dynamic_strides_mask() - def __c_pointers__(self): + def __c_pointers__(self) -> List[int]: if self._c_pointers_cache is None: self.load_dltensor() self._memref_desc = self._dltensor_wrapper.build_memref_desc( @@ -408,15 +426,15 @@ class _Tensor(Tensor): self._c_pointers_cache = [_cute_ir.pycapsule_get_pointer(self._memref_desc)] return self._c_pointers_cache - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> List[ir.Type]: return [self.mlir_type] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[object]) -> CoreTensor: assert len(values) == 1 assert isinstance(values[0], CoreTensor) return CoreTensor(values[0].value, self._dtype) - def __tvm_ffi_object__(self): + def __tvm_ffi_object__(self) -> object: try: return self._tvm_ffi_tensor except AttributeError: @@ -429,14 +447,6 @@ class _Tensor(Tensor): ) -def _get_cute_type_str(inp): - def _convert_dyn_elem(e): - return f"?{{i{e.width} div={e.divisibility}}}" - - elems = [_convert_dyn_elem(e) if isinstance(e, SymInt) else str(e) for e in inp] - return "(" + ",".join(elems) + ")" - - class _FakeTensor(Tensor): """Fake Tensor implementation as a placeholder. It mimics the interface of Tensor, but does not hold real data or allow indexing. @@ -462,7 +472,6 @@ class _FakeTensor(Tensor): when the dimension is dynamic. :type use_32bit_stride: bool, optional - """ def __init__( @@ -472,132 +481,106 @@ class _FakeTensor(Tensor): *, stride: tuple[Union[int, SymInt], ...], memspace: AddressSpace = AddressSpace.gmem, - assumed_align: int | None = None, + assumed_align: Optional[int] = None, use_32bit_stride: bool = False, - compact: bool = False, - ): - self._dtype = dtype - self._shape = shape - self._stride = stride - self._use_32bit_stride = use_32bit_stride - self._compact = compact - + ) -> None: if not isinstance(shape, (tuple, list)): raise ValueError(f"Expected tuple or list but got {type(shape)}") - if not all(isinstance(s, (int, SymInt)) for s in self._shape): + if isinstance(shape, list): + shape = tuple(shape) + if not all(isinstance(s, (int, SymInt)) for s in shape): raise ValueError("All shape elements must be int or SymInt") - if stride is not None and not all( - isinstance(s, (int, SymInt)) for s in self._stride - ): + if isinstance(stride, list): + stride = tuple(stride) + + if stride is not None and not all(isinstance(s, (int, SymInt)) for s in stride): raise ValueError("All stride elements must be int or SymInt") - self._memspace = memspace - self._assumed_align = assumed_align - if assumed_align is None: - # use the bytes width of the element dtype. The alignment is at least one byte align. - self._assumed_align = (self._dtype.width + 7) // 8 + self._typed_tensor = TypedTensor(dtype, shape, stride, memspace, assumed_align) # type: ignore[arg-type] + self._assumed_align = self._typed_tensor._assumed_align + self._use_32bit_stride = use_32bit_stride @property def mlir_type(self) -> ir.Type: - shape_str = _get_cute_type_str(self._shape) - stride_str = _get_cute_type_str(self._stride) - layout_ty = ir.Type.parse(f'!cute.layout<"{shape_str}:{stride_str}">') + return self._typed_tensor.mlir_type # pragma: no cover - # Boolean types are stored as i8 in memory - elem_type = T.i8() if self._dtype.width == 1 else self._dtype.mlir_type - ptr_ty = _cute_ir.PtrType.get(elem_type, self._memspace, self._assumed_align) - return _cute_ir.MemRefType.get(ptr_ty, layout_ty) + def __get_mlir_types__(self) -> list[ir.Type]: + return self._typed_tensor.__get_mlir_types__() - def __get_mlir_types__(self): - return [self.mlir_type] - - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[object]) -> CoreTensor: assert len(values) == 1 assert isinstance(values[0], CoreTensor) - return CoreTensor(values[0].value, self._dtype) + return CoreTensor(values[0].value, self.element_type) def __str__(self) -> str: - return f"FakeTensor<{self._dtype}, {self._shape}, {self._stride}>" + return f"FakeTensor<{self.element_type}, {self.shape}, {self.stride}>" @property def __cache_key__(self) -> tuple: - # Check if any shape or stride element is a SymInt without a symbol - import warnings - - has_unnamed_symint = False - for dim in self._shape: - if isinstance(dim, SymInt) and dim.symbol is None: - has_unnamed_symint = True - break - if not self._compact: - if not has_unnamed_symint: - for stride in self._stride: - if isinstance(stride, SymInt) and stride.symbol is None: - has_unnamed_symint = True - break - - if has_unnamed_symint: - warnings.warn( - "FakeTensor cache_key contains unnamed symbolic dimensions. " - "Different variables with the same shape/stride pattern will have " - "identical cache keys, which may cause incorrect cache hits. " - "Consider using 'symbol' parameter to distinguish variables: " - "cute.sym_int32(symbol='M'), cute.sym_int32(symbol='N')", - UserWarning, - stacklevel=2, - ) + # Use id() for SymInt elements to match TVM FFI's identity-based + # deduplication (SymIntId). This ensures that different SymInt objects + # produce different cache keys even if they have the same symbol name, + # preventing incorrect cache hits when kernels have different signatures. + def _cache_key_element(e: object) -> object: + return id(e) if isinstance(e, SymInt) else e return ( - self._dtype, - self._memspace, - self._assumed_align, - self._shape, - self._stride, + self.element_type, + self.memspace, + self._typed_tensor.assumed_align, + tuple(_cache_key_element(s) for s in self.shape), # type: ignore[union-attr] + tuple(_cache_key_element(s) for s in self.stride), # type: ignore[union-attr] ) - def __repr__(self): + def __repr__(self) -> str: return self.__str__() - def __setitem__(self, crd, value): + def __setitem__(self, crd: object, value: object) -> None: raise DSLRuntimeError("runtime._FakeTensor is not indexable") - def __getitem__(self, crd): + def __getitem__(self, crd: object) -> NoReturn: raise DSLRuntimeError("runtime._FakeTensor is not indexable") - @property + @property # type: ignore[misc] def element_type(self) -> Type[Numeric]: - return self._dtype + return self._typed_tensor.element_type @property - def memspace(self): - return self._memspace + def memspace(self) -> AddressSpace: + return self._typed_tensor.memspace @property - def iterator(self): + def iterator(self) -> NoReturn: raise DSLRuntimeError("runtime._FakeTensor has dummy iterator") @property - def shape(self): - return self._shape + def shape(self) -> Shape: + return self._typed_tensor.shape @property - def stride(self): - return self._stride + def stride(self) -> Stride: + return self._typed_tensor.stride @property - def leading_dim(self): - return core.leading_dim(self._shape, self._stride) + def leading_dim(self) -> Union[int, tuple[int, ...], None]: + return core.leading_dim(self._typed_tensor.shape, self._typed_tensor.stride) @property - def dynamic_shapes_mask(self): - return tuple(1 if isinstance(e, SymInt) else 0 for e in self._shape) + def dynamic_shapes_mask(self) -> tuple[int, ...]: + return tuple( + 1 if isinstance(e, SymInt) else 0 + for e in self._typed_tensor.shape # type: ignore[union-attr] + ) @property - def dynamic_strides_mask(self): - return tuple(1 if isinstance(e, SymInt) else 0 for e in self._stride) + def dynamic_strides_mask(self) -> tuple[int, ...]: + return tuple( + 1 if isinstance(e, SymInt) else 0 + for e in self._typed_tensor.stride # type: ignore[union-attr] + ) - def fill(self, value: Numeric): + def fill(self, value: Numeric) -> None: raise DSLRuntimeError("runtime._FakeTensor is not writable") @@ -609,14 +592,14 @@ def make_fake_compact_tensor( memspace: AddressSpace = AddressSpace.gmem, assumed_align: Optional[int] = None, use_32bit_stride: bool = False, -): +) -> _FakeTensor: """ Create a fake tensor with the specified shape, element type, and a compact memory layout. :param dtype: Data type of the tensor elements. :type dtype: Type[Numeric] :param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions. - :type shape: tuple[Union[int, SymInt], ...] + :type shape: tuple[int | SymInt, ...] :param stride_order: Order in which strides (memory layout) are assigned to the tensor dimensions. If None, the default layout is left-to-right order (known as column-major order for flatten layout). Otherwise, it should be a permutation order of the dimension indices. @@ -667,27 +650,24 @@ def make_fake_compact_tensor( stride_product = 1 for order in range(len(stride_order)): idx = stride_order.index(order) - stride[idx] = stride_product - stride_product *= shape[idx] + stride[idx] = stride_product # type: ignore[call-overload] + stride_product *= shape[idx] # type: ignore[assignment] stride_width = 32 if use_32bit_stride else 64 - stride = tuple( - ( - SymInt(width=stride_width, divisibility=s.divisibility) - if isinstance(s, SymInt) - else s - ) + stride = tuple( # type: ignore[assignment] + SymInt(width=stride_width, divisibility=s.divisibility) + if isinstance(s, SymInt) + else s for s in stride ) return _FakeTensor( dtype, shape, - stride=stride, + stride=stride, # type: ignore[arg-type] memspace=memspace, assumed_align=assumed_align, use_32bit_stride=use_32bit_stride, - compact=True, ) @@ -697,17 +677,17 @@ def make_fake_tensor( stride: tuple[Union[int, SymInt], ...], *, memspace: AddressSpace = AddressSpace.gmem, - assumed_align: Optional[int] = None, -): + assumed_align: int | None = None, +) -> _FakeTensor: """ Create a fake tensor with the specified element type, shape, and stride. :param dtype: Data type of the tensor elements. :type dtype: Type[Numeric] :param shape: Shape of the tensor, consisting of static (int) or dynamic (SymInt) dimensions. - :type shape: tuple[Union[int, SymInt], ...] + :type shape: tuple[int | SymInt, ...] :param stride: Stride of the tensor, consisting of static (int) or dynamic (SymInt) values. - :type stride: tuple[Union[int, SymInt], ...] + :type stride: tuple[int | SymInt, ...] :param memspace: Memory space where the fake tensor resides. Defaults to AddressSpace.gmem. :type memspace: AddressSpace, optional :param assumed_align: Assumed byte alignment for the tensor data. If None, the default alignment is the dtype width, & at least 1 byte. @@ -731,27 +711,27 @@ class _FakeStream: use_tvm_ffi_env_stream: bool - def __init__(self, *, use_tvm_ffi_env_stream: bool = False): + def __init__(self, *, use_tvm_ffi_env_stream: bool = False) -> None: self.use_tvm_ffi_env_stream = use_tvm_ffi_env_stream def __str__(self) -> str: - return f"FakeStream" + return "FakeStream" - def __repr__(self): + def __repr__(self) -> str: return self.__str__() - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[object]) -> object: assert len(values) == 1 return values[0] - def __c_pointers__(self): + def __c_pointers__(self) -> List[int]: return [0] - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> List[ir.Type]: return [_cuda_dialect.StreamType.get()] -def make_fake_stream(*, use_tvm_ffi_env_stream: bool = False): +def make_fake_stream(*, use_tvm_ffi_env_stream: bool = False) -> _FakeStream: """Create a fake stream that can be used as a placeholder for a stream in compilation. When use_tvm_ffi_env_stream is True and the function is compiled with TVM-FFI, @@ -767,12 +747,12 @@ def make_fake_stream(*, use_tvm_ffi_env_stream: bool = False): def from_dlpack( - tensor_dlpack, - assumed_align=None, - use_32bit_stride=False, + tensor_dlpack: object, + assumed_align: Optional[int] = None, + use_32bit_stride: bool = False, *, - enable_tvm_ffi=False, - force_tf32=False, + enable_tvm_ffi: bool = False, + force_tf32: bool = False, ) -> Tensor: """Convert from tensor object supporting __dlpack__() to a CuTe Tensor. @@ -793,6 +773,12 @@ def from_dlpack( :return: A CuTe Tensor object :rtype: Tensor + For packed subbyte torch dtypes such as ``torch.float4_e2m1fn_x2``, + ``from_dlpack`` returns the logical element layout expected by CuTe instead + of the packed storage layout. For example, a torch tensor with shape + ``(128, 128)`` and dtype ``torch.float4_e2m1fn_x2`` is exposed as a logical + FP4 tensor with shape ``(128, 256)``. + **Examples:** .. code-block:: python @@ -824,7 +810,7 @@ def make_ptr( dtype: Type[Numeric], value: Union[int, ctypes._Pointer], mem_space: AddressSpace = AddressSpace.generic, - assumed_align=None, + assumed_align: Optional[int] = None, ) -> Pointer: """Create a pointer from a memory address @@ -865,7 +851,7 @@ def make_ptr( address_value = value elif isinstance(value, ctypes._Pointer): # get address value - address_value = ctypes.cast(value, ctypes.c_void_p).value + address_value = ctypes.cast(value, ctypes.c_void_p).value # type: ignore[assignment] assert address_value is not None, "Pointer address is None" else: raise TypeError( @@ -878,7 +864,7 @@ def make_ptr( def nullptr( dtype: Type[Numeric], mem_space: AddressSpace = AddressSpace.generic, - assumed_align=None, + assumed_align: Optional[int] = None, ) -> Pointer: """Create a null pointer which is useful for compilation @@ -897,22 +883,22 @@ class TensorAdapter: Convert a DLPack protocol supported tensor/array to a cute tensor. """ - def __init__(self, arg): + def __init__(self, arg: object) -> None: self._arg = from_dlpack(arg).mark_layout_dynamic() - self._c_pointers_cache = None - self._mlir_types_cache = None + self._c_pointers_cache: Optional[list[int]] = None + self._mlir_types_cache: Optional[list[ir.Type]] = None - def __new_from_mlir_values__(self, values): - return self._arg.__new_from_mlir_values__(values) + def __new_from_mlir_values__(self, values: list[object]) -> object: + return self._arg.__new_from_mlir_values__(values) # type: ignore[attr-defined] - def __c_pointers__(self): + def __c_pointers__(self) -> list[int]: if self._c_pointers_cache is None: - self._c_pointers_cache = self._arg.__c_pointers__() + self._c_pointers_cache = self._arg.__c_pointers__() # type: ignore[attr-defined] return self._c_pointers_cache - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> list[ir.Type]: if self._mlir_types_cache is None: - self._mlir_types_cache = self._arg.__get_mlir_types__() + self._mlir_types_cache = self._arg.__get_mlir_types__() # type: ignore[attr-defined] return self._mlir_types_cache @@ -926,16 +912,16 @@ def find_runtime_libraries(*, enable_tvm_ffi: bool = True) -> List[str]: :rtype: list """ - def _get_cute_dsl_runtime_path(): + def _get_cute_dsl_runtime_path() -> Optional[str]: libs = get_prefix_dsl_libs("CUTE_DSL") if libs is None: return None # check if the separator is ; for windows if sys.platform.startswith("win32") and ";" in libs: - libs = libs.split(";") + libs = libs.split(";") # type: ignore[assignment] else: - libs = libs.split(":") + libs = libs.split(":") # type: ignore[assignment] for path in libs: if path.endswith("libcute_dsl_runtime.so"): @@ -956,10 +942,12 @@ def find_runtime_libraries(*, enable_tvm_ffi: bool = True) -> List[str]: return libs # cache to load runtime libraries so they can be found by the DSO loader -_LOAD_MODULE_LIBS_CACHE = [] +_LOAD_MODULE_LIBS_CACHE: list[Any] = [] -def load_module(file_path: str, *, enable_tvm_ffi: bool = False): +def load_module( + file_path: str, *, enable_tvm_ffi: bool = False +) -> ExternalBinaryModule: """Load a module from a file path. :param file_path: The path to the module file diff --git a/python/CuTeDSL/cutlass/cute/tensor.py b/python/CuTeDSL/cutlass/cute/tensor.py index 587ee920b..0627fc870 100644 --- a/python/CuTeDSL/cutlass/cute/tensor.py +++ b/python/CuTeDSL/cutlass/cute/tensor.py @@ -9,7 +9,8 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Optional, Union, Type, Tuple, overload + +from typing import Any, Callable, Optional, Union, Type, Tuple, overload, List from typing_extensions import deprecated from inspect import isclass import operator @@ -27,8 +28,7 @@ from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir from cutlass._mlir.dialects.cute import ReductionOp as ReductionOp import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir -from cutlass._mlir.dialects import vector, arith - +from cutlass._mlir.dialects import vector, arith, llvm from .typing import ( Numeric, Integer, @@ -37,8 +37,8 @@ from .typing import ( Uint8, Int8, Int32, - Int64, BFloat16, + Float32, IntTuple, Coord, Shape, @@ -60,6 +60,8 @@ from .core import ( _pack_shape, _ComposedLayout, _ComposedLayoutWithInnerFunc, + append_ones, + is_major, is_static, is_weakly_congruent, rank, @@ -68,6 +70,7 @@ from .core import ( flatten, has_underscore, make_layout, + select, slice_, crd2idx, size, @@ -77,7 +80,10 @@ from .core import ( ) from .tuple import transform_leaf, product, product_like, flatten_to_tuple -from .arch import cvt_i8_bf16_intrinsic, cvt_i4_bf16_intrinsic +from .arch import ( + cvt_i8_bf16_intrinsic, + cvt_i4_bf16_intrinsic, +) __all__ = [ @@ -137,8 +143,13 @@ class _Tensor(Tensor): @dsl_user_op def __init__( - self, value, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None - ): + self, + value: Union[ir.Value, "_Tensor"], + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Initialize a Tensor from an MLIR value. :param value: The MLIR operation result value or another Tensor to initialize from @@ -157,8 +168,6 @@ class _Tensor(Tensor): self.value = value elif isinstance(value, _Tensor): self.value = value.value - elif isinstance(value, _Tensor): - self.value = value.value else: raise TypeError(f"Expected ir.Value or _Tensor, got {type(value)}") @@ -168,9 +177,9 @@ class _Tensor(Tensor): self._iterator = iter_val elif isinstance(iter_val.type, _cute_ir.ArithTupleIteratorType): itup_val = _cute_ir.deref_arith_tuple_iter(iter_val) - self._iterator = _unpack_x_tuple(itup_val) + self._iterator = _unpack_x_tuple(itup_val) # type: ignore[assignment] elif isinstance(iter_val, ir.Value): - # Example: SMEM descriptor iterator, not well supported today + # SMEM descriptor iterator requires specific vec_mode layout configuration self._iterator = iter_val else: raise TypeError(f"unsupported iterator type, got {type(iter_val)}") @@ -178,27 +187,29 @@ class _Tensor(Tensor): # Set dtype if self._dtype is None: if is_int_tuple(self.iterator): - self._dtype = IntTuple + self._dtype = IntTuple # type: ignore[assignment] elif isinstance(self.iterator, Pointer): self._dtype = self.iterator.value_type elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): - # SmemDescViewType do not need dtype + # SmemDescViewType requires specific vec_mode layout configuration self._dtype = None else: raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") - def __repr__(self): + def __repr__(self) -> str: return self.__str__() - def __str__(self): + def __str__(self) -> str: from .core import pretty_str return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> List[ir.Value]: return [self.value] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__( + self, values: List[Union["_Tensor", ir.Value]] + ) -> "_Tensor": # Only expecting single value of _Tensor or ir.Value # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created # by value caster for MemRef/CoordTensor/SmemDescView typed values @@ -211,20 +222,13 @@ class _Tensor(Tensor): dtype=self.element_type, ) - # Cheat to let `Type(_Tensor())` to return cute.Tensor - @property - def __class__(self) -> Type[Tensor]: - return Tensor - - # Make it behave as if it inherited from ir.Value - @property - @lru_cache_ir() - def type(self) -> ir.Type: - return self.value.type - @dsl_user_op def __getitem__( - self, crd: Coord, *, loc=None, ip=None + self, + crd: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Tensor, Numeric, IntTuple]: """Access or slice tensor elements using coordinates. @@ -293,7 +297,9 @@ class _Tensor(Tensor): return slice_(self, crd, loc=loc, ip=ip) elif isinstance(self.type, _cute_ir.CoordTensorType): res = _cute_ir.get_iter( - slice_(self, crd, loc=loc, ip=ip).value, loc=loc, ip=ip + slice_(self, crd, loc=loc, ip=ip).value, + loc=loc, + ip=ip, ) itup_val = _cute_ir.deref_arith_tuple_iter(res) return _unpack_x_tuple(itup_val) @@ -305,16 +311,22 @@ class _Tensor(Tensor): data_val = _cute_ir.memref_load(self.value, crd_val, loc=loc, ip=ip) return self.element_type(data_val) - def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): + def _cvt_to_dest( + self, + data: Union["TensorSSA", Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: orig_dtype = data.dtype # Implicit upcast to wider type if ( data.dtype.is_same_kind(self.element_type) - and self.element_type.width >= data.dtype.width + and self.element_type.width >= data.dtype.width # type: ignore[union-attr] ): - data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore + data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore[assignment] - if data.dtype.width != self.element_type.width: + if data.dtype.width != self.element_type.width: # type: ignore[union-attr] raise ValueError( f"Type mismatch, store {orig_dtype} (-> {data.dtype}) " f"to Tensor with element type {self.element_type}" @@ -322,7 +334,7 @@ class _Tensor(Tensor): if data.dtype is Boolean and self.element_type is Boolean: # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory - val = data.ir_value_int8(loc=loc, ip=ip) + val = data.ir_value_int8(loc=loc, ip=ip) # type: ignore[union-attr] else: val = data.ir_value(loc=loc, ip=ip) return val @@ -333,8 +345,8 @@ class _Tensor(Tensor): crd: Coord, data: Union[int, float, ir.Value, Numeric, "TensorSSA"], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """Set tensor elements at specified coordinates. @@ -380,7 +392,6 @@ class _Tensor(Tensor): # convert scalar type if not has_underscore(crd): self._check_can_dereference() - # First, convert ir.Value to Numeric if isinstance(data, ir.Value): data = as_numeric(data) elif isinstance(data, (int, float, bool)): @@ -402,10 +413,11 @@ class _Tensor(Tensor): if not isinstance(data, TensorSSA): raise ValueError(f"Expected TensorSSA, but got {data}") - self.__getitem__(crd, loc=loc, ip=ip).store(data, loc=loc, ip=ip) # type: ignore + self.__getitem__(crd, loc=loc, ip=ip).store(data, loc=loc, ip=ip) - @property - def __class__(self) -> Type[Tensor]: + # Cheat to let `Type(_Tensor())` to return cute.Tensor + @property # type: ignore[misc] + def __class__(self) -> Type[Tensor]: # type: ignore[override] return Tensor # Make it behave as if it inherited from ir.Value @@ -422,13 +434,23 @@ class _Tensor(Tensor): @property @dsl_user_op @lru_cache_ir() - def layout(self, *, loc=None, ip=None) -> Layout: + def layout( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: return _cute_ir.get_layout(self.value, loc=loc, ip=ip) @property @dsl_user_op @lru_cache_ir() - def shape(self, *, loc=None, ip=None) -> Shape: + def shape( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: return self.layout.shape_method(loc=loc, ip=ip) @property @@ -451,12 +473,16 @@ class _Tensor(Tensor): :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` """ - return leading_dim(self.shape, self.stride) + return leading_dim(self.shape, self.stride) # type: ignore[return-value] @property + def dtype(self) -> Type[Numeric]: + return self._dtype # type: ignore[return-value] + + @property # type: ignore[misc] @lru_cache_ir() def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: - return self._dtype + return self._dtype # type: ignore[return-value] @property @lru_cache_ir() @@ -472,8 +498,8 @@ class _Tensor(Tensor): *, mask: Optional["TensorSSA"] = None, pass_thru: Optional["TensorSSA"] = None, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> "TensorSSA": """Load tensor elements as a vector. @@ -524,9 +550,9 @@ class _Tensor(Tensor): data: "TensorSSA", *, mask: Optional["TensorSSA"] = None, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Store vector data into tensor. Stores vector data into the tensor, assuming matching shapes and a memory space @@ -560,10 +586,15 @@ class _Tensor(Tensor): ) elem_mlir_type = cutlass_arith.element_type(data.dtype.mlir_type) - if cutlass_arith.is_narrow_precision(elem_mlir_type): + if ( + cutlass_arith.is_narrow_precision(elem_mlir_type) + and elem_mlir_type.width < 8 + ): + n_elems = size(self.shape, loc=loc, ip=ip) if elem_mlir_type.width * n_elems % 32 != 0: raise ValueError( - f"narrow precision type must be 32-bit aligned vector, but got {elem_mlir_type} with {n_elems} elements" + f"narrow precision type must be 32-bit aligned vector, " + f"but got {elem_mlir_type} with {n_elems} elements" ) # Implicit upcast to wider type @@ -576,7 +607,13 @@ class _Tensor(Tensor): ) @dsl_user_op - def fill(self, value: Numeric, *, loc=None, ip=None) -> None: + def fill( + self, + value: Numeric, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Fill tensor with a constant value. Fills all elements of the tensor with the specified value, assuming static size @@ -616,7 +653,7 @@ class _Tensor(Tensor): ) self.store(vect_val, loc=loc, ip=ip) - def _check_can_load_store(self, vectorized: bool = False): + def _check_can_load_store(self, vectorized: bool = False) -> None: if not isinstance(self.type, _cute_ir.MemRefType) or self.memspace not in ( AddressSpace.rmem, AddressSpace.smem, @@ -630,9 +667,14 @@ class _Tensor(Tensor): "vectorized load/store on tensor with composed layout is not supported yet" ) - def _check_can_dereference(self): + def _check_can_dereference(self) -> None: + sub_byte_types = ( + type(Boolean), + ) # Check for sub-byte types and raise error if needed - if self.element_type.width % 8 != 0 and self.element_type is not Boolean: + if self.element_type.width % 8 != 0 and not isinstance( + self.element_type, sub_byte_types + ): raise ValueError( f"Sub-byte scalar dereference not supported for type {self.element_type}" ) @@ -645,7 +687,11 @@ class _Tensor(Tensor): @dsl_user_op def make_tensor( - iterator, layout: Union[Shape, Layout, ComposedLayout], *, loc=None, ip=None + iterator: Union[Pointer, IntTuple, ir.Value], + layout: Union[Shape, Layout, ComposedLayout], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: """Creates a tensor by composing an engine (iterator/pointer) with a layout. @@ -714,19 +760,21 @@ def make_tensor( res_ty = None if is_integer(iterator) or isinstance(iterator, tuple): - itup_val = _pack_int_tuple(iterator, loc=loc, ip=ip) + itup_val = _pack_int_tuple(iterator, loc=loc, ip=ip) # type: ignore[arg-type] iter_ty = _cute_ir.ArithTupleIteratorType.get(itup_val.type) iterator = _cute_ir.make_arith_tuple_iter( iter=iter_ty, value=itup_val, loc=loc, ip=ip ) - res_ty = _cute_ir.CoordTensorType.get(itup_val.type, layout.type) + res_ty = _cute_ir.CoordTensorType.get(itup_val.type, layout.type) # type: ignore[union-attr] elif isinstance(iterator, Pointer): iterator = iterator.value - res_ty = _cute_ir.MemRefType.get(iterator.type, layout.type) + res_ty = _cute_ir.MemRefType.get(iterator.type, layout.type) # type: ignore[union-attr] elif isinstance(iterator, ir.Value) and isinstance( - iterator.type, _cute_nvgpu_ir.SmemDescType + iterator.type, + _cute_nvgpu_ir.SmemDescType, ): - res_ty = _cute_nvgpu_ir.SmemDescViewType.get(layout.type) + # SmemDescType requires specific vec_mode layout configuration + res_ty = _cute_nvgpu_ir.SmemDescViewType.get(layout.type) # type: ignore[union-attr] else: raise TypeError(f"unsupported iterator type, got {type(iterator)}") @@ -738,7 +786,12 @@ def make_tensor( @dsl_user_op -def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: +def make_identity_tensor( + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: """Creates an identity tensor with the given shape. An identity tensor maps each coordinate to itself, effectively creating a counting @@ -780,7 +833,11 @@ def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor: @dsl_user_op def make_rmem_tensor( - layout_or_shape: Union[Layout, Shape], dtype: Type[Numeric], *, loc=None, ip=None + layout_or_shape: Union[Layout, Shape], + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: """Creates a tensor in register memory with the specified layout/shape and data type. @@ -827,6 +884,7 @@ def make_rmem_tensor( if not isinstance(layout_or_shape, Layout): layout = make_layout(layout_or_shape, loc=loc, ip=ip) elif isinstance(layout_or_shape, _ComposedLayout): + # Defensive check: make_rmem_tensor doesn't accept ComposedLayout objects layout = layout_or_shape.value else: layout = layout_or_shape @@ -840,7 +898,11 @@ def make_rmem_tensor( @dsl_user_op @deprecated("`make_fragment` is deprecated, use `make_rmem_tensor` instead") def make_fragment( - layout_or_shape: Union[Layout, Shape], dtype: Type[Numeric], *, loc=None, ip=None + layout_or_shape: Union[Layout, Shape], + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: return make_rmem_tensor(layout_or_shape, dtype, loc=loc, ip=ip) @@ -850,8 +912,8 @@ def make_rmem_tensor_like( src: Union[Layout, ComposedLayout, Tensor, "TensorSSA"], dtype: Optional[Type[Numeric]] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: """Creates a tensor in register memory with the same shape as the input layout but compact col-major strides. This is equivalent to calling `make_rmem_tensor(make_layout_like(tensor))`. @@ -906,7 +968,7 @@ def make_rmem_tensor_like( ) if isinstance(src, Tensor): - if isinstance(src.type, _cute_ir.CoordTensorType): + if isinstance(src.type, _cute_ir.CoordTensorType): # type: ignore[union-attr] if dtype is None: raise ValueError( "dtype must be provided when src is a coordinate tensor" @@ -916,7 +978,7 @@ def make_rmem_tensor_like( compact_layout = make_layout(src.shape, loc=loc, ip=ip) src_layout = _cute_ir.make_layout_like(compact_layout, loc=loc, ip=ip) else: - res_dtype = dtype or src.element_type + res_dtype = dtype or src.element_type # type: ignore[assignment] src_layout = src.layout elif isinstance(src, TensorSSA): res_dtype = dtype or src.element_type @@ -937,16 +999,36 @@ def make_rmem_tensor_like( @overload def make_fragment_like( - src: Tensor, dtype: Optional[Type[Numeric]], *, loc=None, ip=None + src: Tensor, + dtype: Optional[Type[Numeric]], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tensor: ... @overload -def make_fragment_like(src: Layout, *, loc=None, ip=None) -> Layout: ... +def make_fragment_like( + src: Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: ... @overload -def make_fragment_like(src: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ... +def make_fragment_like( + src: ComposedLayout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ComposedLayout: ... @dsl_user_op -def make_fragment_like(src, dtype=None, *, loc=None, ip=None): +def make_fragment_like( + src: Union[Layout, ComposedLayout, Tensor], + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Layout, Tensor]: # Keep code to avoid potential regression if isinstance(src, (Layout, _ComposedLayout)): if isinstance(src, _ComposedLayout): @@ -964,8 +1046,13 @@ def make_fragment_like(src, dtype=None, *, loc=None, ip=None): @dsl_user_op def recast_tensor( - src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None -): + src: Tensor, + dtype: Type[Numeric], + swizzle_: object = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: """Recast a tensor to a different data type by changing the element interpretation. This function reinterprets the memory of a tensor with a different element type, @@ -997,28 +1084,32 @@ def recast_tensor( # Both tensors share the same memory, but interpret it differently """ - if not isclass(dtype) or not issubclass(dtype, Numeric): - raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") - - if dtype is Boolean: - dst_width = 8 - else: - dst_width = dtype.width + dst_width = None + if dst_width is None: + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {dtype}") + dst_width = 8 if dtype is Boolean else dtype.width if src.element_type is Boolean: src_width = 8 else: - src_width = src.element_type.width + src_width = src.element_type.width # type: ignore[union-attr] src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip) src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip) - return type(src)( + return type(src)( # type: ignore[call-arg] make_tensor(src_iter, src_layout, loc=loc, ip=ip), dtype=dtype, loc=loc, ip=ip ) @dsl_user_op -def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: +def domain_offset( + coord: Coord, + tensor: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: """Offset the tensor domain by the given coordinate. This function creates a new tensor by offsetting the iterator/pointer of the input tensor @@ -1052,7 +1143,7 @@ def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip) if isinstance(tensor.iterator, Pointer): return make_tensor( - tensor.iterator.__add__(offset, loc=loc, ip=ip), + tensor.iterator.__add__(offset, loc=loc, ip=ip), # type: ignore[call-arg] tensor.layout, loc=loc, ip=ip, @@ -1071,13 +1162,17 @@ def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor: ip=ip, ) else: + # Defensive check: all valid tensors have Pointer or int/tuple iterators raise ValueError(f"unsupported tensor for domain_offset, got {tensor}") - @dsl_user_op def print_tensor( - tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None -): + tensor: Union[Tensor, "TensorSSA"], + *, + verbose: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """Print content of the tensor in human readable format. Outputs the tensor data in a structured format showing both metadata @@ -1113,22 +1208,32 @@ def print_tensor( tmp.store(tensor) tensor = tmp - if isinstance(tensor.type, _cute_ir.MemRefType): - if tensor.element_type.is_integer: - signed = tensor.element_type.signed + if isinstance(tensor.type, _cute_ir.MemRefType): # type: ignore[union-attr] + if tensor.element_type.is_integer: # type: ignore[union-attr] + signed = tensor.element_type.signed # type: ignore[union-attr] else: signed = False - elif isinstance(tensor.type, _cute_ir.CoordTensorType): + elif isinstance(tensor.type, _cute_ir.CoordTensorType): # type: ignore[union-attr] signed = True else: - raise ValueError(f"unsupported tensor type for print_tensor, got {tensor.type}") - + # Defensive check: all valid tensors are either MemRefType or CoordTensorType + raise ValueError(f"unsupported tensor type for print_tensor, got {tensor.type}") # type: ignore[union-attr] _cute_ir.print_view(tensor.value, verbose=verbose, is_signed=signed, loc=loc, ip=ip) -def _get_row_and_col_map(col_maj_shape_1d: tuple, is_row_to_col: bool): - """ - Create an index mapping mask for converting between row-major and column-major vector ordering. +def _get_row_and_col_map(col_maj_shape_1d: tuple, is_row_to_col: bool) -> list: + """Create an index mapping mask for converting between row-major and column-major vector ordering. + + This helper function generates a permutation array that maps between row-major and + column-major orderings of vector elements. + + :param col_maj_shape_1d: The shape tuple in column-major order + :type col_maj_shape_1d: tuple + :param is_row_to_col: If True, generates row-to-column mapping; if False, column-to-row + :type is_row_to_col: bool + :return: A list representing the index permutation + :rtype: list + :raises ValueError: If col_maj_shape_1d is None """ # create row-major layout with compact row-major stride @@ -1151,7 +1256,7 @@ def _get_row_and_col_map(col_maj_shape_1d: tuple, is_row_to_col: bool): row_maj_stride = tuple(reversed(strides)) else: # Single dimension - row_maj_stride = 1 + row_maj_stride = 1 # type: ignore[assignment] row_maj_lay_1d = make_layout(row_maj_shape_1d, stride=row_maj_stride) @@ -1169,19 +1274,51 @@ def _get_row_and_col_map(col_maj_shape_1d: tuple, is_row_to_col: bool): return mask -def _row2col(vec: ir.Value, *, shape, loc=None, ip=None) -> ir.Value: +def _row2col( + vec: ir.Value, + *, + shape: Shape, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: + """Convert a vector or tensor from row-major order to column-major order. + + :param vec: The input vector in row-major order + :type vec: ir.Value + :param shape: The shape of the vector + :type shape: Shape + :param loc: Source location for MLIR operations, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint] + :return: The vector reordered to column-major layout + :rtype: ir.Value """ - Convert a vector or tensor from row-major order to column-major order. - """ - row_and_col_map = _get_row_and_col_map(shape, is_row_to_col=True) + row_and_col_map = _get_row_and_col_map(shape, is_row_to_col=True) # type: ignore[arg-type] return vector.shuffle(vec, vec, row_and_col_map, loc=loc, ip=ip) -def _col2row(vec: ir.Value, *, shape, loc=None, ip=None) -> ir.Value: +def _col2row( + vec: ir.Value, + *, + shape: Shape, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> ir.Value: + """Convert a vector or tensor from column-major order to row-major order. + + :param vec: The input vector in column-major order + :type vec: ir.Value + :param shape: The shape of the vector + :type shape: Shape + :param loc: Source location for MLIR operations, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint] + :return: The vector reordered to row-major layout + :rtype: ir.Value """ - Convert a vector or tensor from column-major order to row-major order. - """ - row_and_col_map = _get_row_and_col_map(shape, is_row_to_col=False) + row_and_col_map = _get_row_and_col_map(shape, is_row_to_col=False) # type: ignore[arg-type] return vector.shuffle(vec, vec, row_and_col_map, loc=loc, ip=ip) @@ -1201,7 +1338,7 @@ def _infer_broadcast_shape(*shapes: Shape) -> Shape: elif len(shapes) == 1: return shapes[0] - def _broadcast(*values): + def _broadcast(*values: int) -> int: non_one_values = [v for v in values if v != 1] if len(non_one_values) == 0: return 1 @@ -1210,8 +1347,10 @@ def _infer_broadcast_shape(*shapes: Shape) -> Shape: else: raise ValueError(f"cannot broadcast {values}") - max_rank = max(rank(shape) for shape in shapes) - ext_shapes = tuple(append(shape, 1, up_to_rank=max_rank) for shape in shapes) + # Use list comprehension instead of generator to avoid keeping frames on stack + # which can cause recursion issues with @dsl_user_op decorated functions + max_rank = max([rank(shape) for shape in shapes]) + ext_shapes = tuple([append(shape, 1, up_to_rank=max_rank) for shape in shapes]) res_shape = transform_leaf(_broadcast, *ext_shapes) return res_shape @@ -1232,20 +1371,55 @@ class TensorSSA(cutlass_arith.ArithValue): :raises ValueError: If shape is not static """ - def __init__(self, value, shape: Shape, dtype: Type[Numeric]): - """Initialize a new TensorSSA object. - - :param value: Flatten vector as ir.Value holding logic data of SSA Tensor - :type value: ir.Value - :param shape: The nested shape in CuTe of the vector - :type shape: Shape - :param dtype: Data type of the tensor elements - :type dtype: Type[Numeric] - :raises ValueError: If shape is not static + @dsl_user_op + def __init__( + self, + value: ir.Value, + shape: Shape, + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ + Create a :class:`TensorSSA` object: an immutable, thread-local tensor backed by a flattened MLIR vector. + + :param value: A :class:`ir.Value` holding the flattened MLIR vector value of the tensor. + :type value: :class:`ir.Value` + :param shape: The logical (possibly nested) shape of the tensor. If None, + this is inferred from ``value.type.shape``. + :type shape: Shape, optional + :param dtype: The data type of the tensor elements. If None, + this is inferred from the MLIR element type. + :type dtype: Type[Numeric], optional + + :keyword loc: Optional location for op construction. + :keyword ip: Optional insertion point for op construction. + + :raises ValueError: If ``value`` is not an ``ir.Value``, is not of vector type, + or if ``shape`` is not statically known. + + .. note:: + - Instances are immutable and represent per-thread local SSA values using value semantics. + - If ``shape`` is inferred and is multi-dimensional, the provided ``value`` + will be shape-cast to a 1D vector with the same logical product, aligning the + physical and logical shape representations. + - The tensor's broadcast shape and static element type are registered; dynamic shapes are not supported. + """ + if not isinstance(value, ir.Value): + raise ValueError(f"Expected value to be an ir.Value, got {type(value)}") + + if not isinstance(value.type, ir.VectorType): + raise ValueError( + f"Expected value to be a vector type, got {type(value.type)}" + ) + if not is_static(shape): raise ValueError("dynamic shape is not supported") + if dtype is None: + dtype = Numeric.from_mlir_type(value.type.element_type) + signed = dtype.signed if issubclass(dtype, Integer) else False super().__init__(value, signed) @@ -1253,6 +1427,76 @@ class TensorSSA(cutlass_arith.ArithValue): self._dtype = dtype self._layout = None + @staticmethod + @dsl_user_op + def from_vector( + value: ir.Value, + *, + dtype: Optional[Type[Numeric]] = None, + shape: Optional[Shape] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": + """ + Construct a :class:`TensorSSA` from a given MLIR vector value. + + This helper interprets the given 1D or n-D MLIR vector value and returns a TensorSSA view. + If the input is an n-D vector, it shape-casts it into a 1D vector holding the same number of elements. + + :param value: The ir.Value representing an MLIR vector value (1D or n-D). + :param dtype: Optional explicit type of the elements. Deduced from MLIR type if not provided. + :param loc: Optional MLIR location. + :param ip: Optional MLIR insertion point. + :return: A TensorSSA view over the vector value. + """ + if not isinstance(value, ir.Value): + raise ValueError(f"Expected value to be an ir.Value, got {type(value)}") + + if not isinstance(value.type, ir.VectorType): + raise ValueError( + f"Expected value to be a vector type, got {type(value.type)}" + ) + + if dtype is None: + dtype = Numeric.from_mlir_type(value.type.element_type) + + shape = shape or tuple(value.type.shape) + if not is_static(shape): + raise ValueError("dynamic shape is not supported") + + if rank(shape) > 1: + flat_vect_ty = ir.VectorType.get( + [product(shape, loc=loc, ip=ip)], value.type.element_type + ) + value = vector.shape_cast(flat_vect_ty, value, loc=loc, ip=ip) + + value = _row2col(value, shape=shape, loc=loc, ip=ip) + return TensorSSA(value, shape, dtype, loc=loc, ip=ip) + + @dsl_user_op + def to_vector( + self, + *, + force_flatten: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: + """ + Convert the tensor to a MLIR vector value. + """ + if depth(self.shape) > 1: + if not force_flatten: + raise ValueError( + "Cannot convert non-flattened tensor to vector, use force_flatten=True to flatten nested shape" + ) + shape = flatten_to_tuple(self.shape) + else: + shape = self.shape # type: ignore[assignment] + + res_ty = ir.VectorType.get(list(shape), self.dtype.mlir_type) + val = _col2row(self, shape=shape, loc=loc, ip=ip) + return vector.shape_cast(res_ty, val, loc=loc, ip=ip) + @property def dtype(self) -> Type[Numeric]: return self._dtype @@ -1261,35 +1505,28 @@ class TensorSSA(cutlass_arith.ArithValue): def element_type(self) -> Type[Numeric]: return self._dtype - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list: return [self] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list) -> "TensorSSA": return TensorSSA(values[0], self.shape, self.dtype) - def __str__(self): + def __str__(self) -> str: return f"tensor_value<{self.type} o {self.shape}>" @property - def shape(self): + def shape(self) -> Shape: return self._shape - @overload def _apply_op( - self, op, other: "TensorSSA", flip=False, *, loc, ip - ) -> "TensorSSA": ... - - @overload - def _apply_op( - self, op, other: cutlass_arith.ArithValue, flip=False, *, loc, ip - ) -> "TensorSSA": ... - - @overload - def _apply_op( - self, op, other: Union[int, float, bool], flip=False, *, loc, ip - ) -> "TensorSSA": ... - - def _apply_op(self, op, other, flip=False, *, loc=None, ip=None): + self, + op: Callable, + other: object, + flip: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": # Canonicalize into Numeric if isinstance(other, (int, float, bool)) or ( not isinstance(other, TensorSSA) @@ -1297,8 +1534,13 @@ class TensorSSA(cutlass_arith.ArithValue): ): other = as_numeric(other) + assert isinstance(other, (Numeric, TensorSSA)), ( + f"Expected other to be Numeric or TensorSSA after canonicalization, but got {type(other)}" + ) + # Promote types lhs, rhs, res_type = _binary_op_type_promote(self, other) + assert isinstance(lhs, TensorSSA) # Promote scalar to vector if not isinstance(rhs, TensorSSA): @@ -1350,6 +1592,7 @@ class TensorSSA(cutlass_arith.ArithValue): # Use ArithValue's operator method directly to avoid recursion # through TensorSSA's __add__/__sub__/etc. when op() dispatches # back to the subclass method + arith_op: Optional[Callable[..., Any]] = None if op.__name__ == "_min": arith_op = cutlass_arith._min elif op.__name__ == "_max": @@ -1369,42 +1612,87 @@ class TensorSSA(cutlass_arith.ArithValue): return res @dsl_user_op - def apply_op(self, op, other, flip=False, *, loc=None, ip=None) -> "TensorSSA": + def apply_op( + self, + op: Callable, + other: object, + flip: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Apply a binary operation to this tensor and another operand. - This is a public interface to the internal _apply_op method, providing - a stable API for external users who need to apply custom operations. + This public API method wraps the internal ``_apply_op`` for external usage, allowing custom operations to be performed on tensors. - Args: - op: The operation function (e.g., operator.add, operator.mul, etc.) - other: The other operand (TensorSSA, ArithValue, or scalar) - flip: Whether to flip the operands (for right-hand operations) - loc: MLIR location (optional) - ip: MLIR insertion point (optional) + :param op: The operation function (e.g., :obj:`operator.add`, :obj:`operator.mul`, etc.). + :type op: Callable + :param other: The other operand. Can be a :class:`TensorSSA`, ArithValue, or scalar. + :type other: TensorSSA or ArithValue or scalar + :param flip: If ``True``, flips the operands (applies operation as ``op(other, self)``). + :type flip: bool, optional + :param loc: MLIR location, optional. + :type loc: object, optional + :param ip: MLIR insertion point, optional. + :type ip: object, optional - Returns: - TensorSSA: The result of the operation + :return: The result of applying the binary operation. + :rtype: TensorSSA + + **Example** + + .. code-block:: python + + import operator + + tensor1 = cute.Tensor(...) + tensor2 = cute.Tensor(...) + result = tensor1.apply_op(operator.add, tensor2) + # Equivalent to: tensor1 + tensor2 - Example: - >>> tensor1 = cute.Tensor(...) - >>> tensor2 = cute.Tensor(...) - >>> result = tensor1.apply_op(operator.add, tensor2) - >>> # Equivalent to: tensor1 + tensor2 """ return self._apply_op(op, other, flip=flip, loc=loc, ip=ip) @dsl_user_op - def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": - """ - Broadcast the tensor to the target shape. + def broadcast_to( + self, + target_shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": + """Broadcast the tensor to the target shape. + + This method broadcasts the tensor to match a target shape following NumPy-style + broadcasting rules. Dimensions of size 1 can be broadcast to any size, and + missing dimensions are added with size 1. + + :param target_shape: The desired output shape + :type target_shape: Shape + :param loc: Source location for MLIR operations, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new tensor broadcast to the target shape + :rtype: TensorSSA + :raises ValueError: If shapes are incompatible for broadcasting + + **Examples:** + + .. code-block:: python + + # Broadcast a (1, 4) tensor to (3, 4) + src = cute.full((1, 4), 1.0, Float32) + dst = src.broadcast_to((3, 4)) + # dst now has shape (3, 4) with the first row replicated """ # pad source shape to the same rank shape = append(self.shape, 1, up_to_rank=rank(target_shape)) if shape == target_shape: return self - def _check_broadcast(s, t): + def _check_broadcast(s: int, t: int) -> None: if s != t and s != 1: raise ValueError( f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" @@ -1429,7 +1717,13 @@ class TensorSSA(cutlass_arith.ArithValue): ) @dsl_user_op - def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __pow__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the results of tensor^other. @@ -1441,7 +1735,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.pow, other, loc=loc, ip=ip) @dsl_user_op - def __rpow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rpow__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the results of other^tensor. @@ -1453,7 +1753,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.pow, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __add__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __add__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the sum of the tensor and another tensor. @@ -1465,7 +1771,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.add, other, loc=loc, ip=ip) @dsl_user_op - def __radd__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __radd__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the sum of the tensor and another tensor (reverse add) @@ -1477,7 +1789,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.add, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __sub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __sub__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the difference of the tensor and another tensor. @@ -1489,7 +1807,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.sub, other, loc=loc, ip=ip) @dsl_user_op - def __rsub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rsub__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the difference of the tensor and another tensor (reverse subtract) @@ -1501,7 +1825,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.sub, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __mul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __mul__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the multiplication of the tensor and another tensor. @@ -1513,7 +1843,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.mul, other, loc=loc, ip=ip) @dsl_user_op - def __rmul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rmul__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the multiplication of the tensor and another tensor (reverse multiply) @@ -1525,7 +1861,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.mul, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __mod__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __mod__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the modulo of the tensor and another tensor. @@ -1537,7 +1879,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.mod, other, loc=loc, ip=ip) @dsl_user_op - def __rmod__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rmod__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the modulo of the tensor and another tensor (reverse modulo) @@ -1549,7 +1897,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.mod, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __floordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __floordiv__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the floordiv(//) of the tensor and another tensor. @@ -1561,7 +1915,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.floordiv, other, loc=loc, ip=ip) @dsl_user_op - def __rfloordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rfloordiv__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the floordiv(//) of the tensor and another tensor (reverse floordiv) @@ -1573,7 +1933,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.floordiv, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __truediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __truediv__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the truediv(/) of the tensor and another tensor. @@ -1585,7 +1951,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.truediv, other, loc=loc, ip=ip) @dsl_user_op - def __rtruediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rtruediv__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the truediv(/) of the tensor and another tensor (reverse truediv) @@ -1597,7 +1969,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.truediv, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __eq__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __eq__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the comparison of the tensor and another tensor as mask @@ -1609,7 +1987,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.eq, other, loc=loc, ip=ip) @dsl_user_op - def __ne__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __ne__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise not equal comparison of the tensor and another tensor. @@ -1621,7 +2005,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.ne, other, loc=loc, ip=ip) @dsl_user_op - def __lt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __lt__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise less than comparison of the tensor and another tensor. @@ -1633,7 +2023,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.lt, other, loc=loc, ip=ip) @dsl_user_op - def __le__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __le__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise less than or equal comparison of the tensor and another tensor. @@ -1645,7 +2041,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.le, other, loc=loc, ip=ip) @dsl_user_op - def __gt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __gt__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise greater than comparison of the tensor and another tensor. @@ -1657,7 +2059,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.gt, other, loc=loc, ip=ip) @dsl_user_op - def __ge__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __ge__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise greater than or equal comparison of the tensor and another tensor. @@ -1669,7 +2077,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.ge, other, loc=loc, ip=ip) @dsl_user_op - def __xor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __xor__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise XOR of the tensor and another tensor. @@ -1681,7 +2095,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.xor, other, loc=loc, ip=ip) @dsl_user_op - def __rxor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rxor__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the bitwise XOR of the tensor and another tensor. @@ -1693,7 +2113,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.xor, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __or__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __or__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise OR of the tensor and another tensor. @@ -1705,7 +2131,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.or_, other, loc=loc, ip=ip) @dsl_user_op - def __ror__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __ror__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise OR of the tensor and another tensor. @@ -1717,7 +2149,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.or_, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __and__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __and__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise AND of the tensor and another tensor. @@ -1729,7 +2167,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.and_, other, loc=loc, ip=ip) @dsl_user_op - def __rand__(self, other, *, loc=None, ip=None) -> "TensorSSA": + def __rand__( + self, + other: object, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the element-wise AND of the tensor and another tensor. @@ -1741,7 +2185,12 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) @dsl_user_op - def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": + def __neg__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """ Returns the negation of the tensor. @@ -1751,7 +2200,29 @@ class TensorSSA(cutlass_arith.ArithValue): return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) - def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): + @dsl_user_op + def __abs__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": + """ + Returns the element-wise absolute value of the tensor. + + :return: The element-wise absolute value of the tensor + :rtype: TensorSSA + """ + res_vect = abs(self.maybe_downcast()) + return TensorSSA(res_vect, self._shape, self.dtype) + + def _flatten_shape_and_coord( + self, + crd: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tuple[Shape, Coord]: # Coalesce and flatten source layout at terminal of coordinate # (N_0,(N_1,...), ...) -> (N_0,N_1,N_2,...) crd_shp = product_like(self._shape, target_profile=crd, loc=loc, ip=ip) @@ -1765,10 +2236,18 @@ class TensorSSA(cutlass_arith.ArithValue): assert isinstance(flat_crd, tuple) and is_static(flat_crd) return flat_shp, flat_crd - def _build_result(self, res_vect, res_shp, *, row_major=False, loc=None, ip=None): + def _build_result( + self, + res_vect: ir.Value, + res_shp: Shape, + *, + row_major: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": if isinstance(res_shp, ir.Value): + # Defensive check: internal method, public API never passes dynamic shapes raise ValueError(f"Expected static shape, but got {self._shape}") - # cast back to 1D vector res_1d_ty = ir.VectorType.get([size(res_shp)], self.type.element_type) res_1d_vect = vector.shape_cast(res_1d_ty, res_vect, loc=loc, ip=ip) @@ -1779,7 +2258,13 @@ class TensorSSA(cutlass_arith.ArithValue): return TensorSSA(res_1d_vect, res_shp, self.dtype) @dsl_user_op - def reshape(self, shape: Shape, *, loc=None, ip=None) -> "TensorSSA": + def reshape( + self, + shape: Shape, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """Reshape the tensor to a new shape. :param shape: The new shape to reshape to. @@ -1804,7 +2289,11 @@ class TensorSSA(cutlass_arith.ArithValue): @dsl_user_op def __getitem__( - self, crd: Coord, *, loc=None, ip=None + self, + crd: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union["TensorSSA", Numeric]: """Access or slice tensor elements using coordinates. @@ -1859,7 +2348,7 @@ class TensorSSA(cutlass_arith.ArithValue): # convert TensorSSA col-major vec to row-m to be compatible with mlir vector ops row_major_vec = _col2row(self, shape=self._shape, loc=loc, ip=ip) - multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) + multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) # type: ignore[arg-type] # vector -> vector tmp_vect = vector.shape_cast(multi_dim_ty, row_major_vec, loc=loc, ip=ip) @@ -1870,9 +2359,9 @@ class TensorSSA(cutlass_arith.ArithValue): ) # Offsets is index of coordinates if NOT `_` otherwise 0 - offsets = [c if c is not None else 0 for c in flat_crd] + offsets = [c if c is not None else 0 for c in flat_crd] # type: ignore[union-attr] # Sizes is size of shapes if `_` otherwise 1 - sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] + sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] # type: ignore[arg-type] # Logic stride to index vector. Only support stride-1 by vector strides = [1] * rank(flat_shp) @@ -1893,7 +2382,13 @@ class TensorSSA(cutlass_arith.ArithValue): return self._build_result(res_vect, res_shp, row_major=True, loc=loc, ip=ip) @dsl_user_op - def to(self, dtype: Type[Numeric], *, loc=None, ip=None): + def to( + self, + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": """Convert the tensor to a different numeric type. :param dtype: The target numeric type to cast to. @@ -1916,7 +2411,16 @@ class TensorSSA(cutlass_arith.ArithValue): # maybe downcast can lose signedness src = self.maybe_downcast().with_signedness(self.signed) if src_dtype.is_float and dtype.is_float: - res_vect = cutlass_arith.cvtf(src, dtype.mlir_type, loc=loc, ip=ip) + + def convert_fp_to_fp( + src: cutlass_arith.ArithValue, + dst_dtype: Type[Numeric], + loc: Optional[ir.Location], + ip: Optional[ir.InsertionPoint], + ) -> ir.Value: + return cutlass_arith.cvtf(src, dst_dtype.mlir_type, loc=loc, ip=ip) + + res_vect = convert_fp_to_fp(src, dtype, loc, ip) elif src_dtype.is_float and issubclass(dtype, Integer): res_vect = cutlass_arith.fptoi( src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip @@ -1929,7 +2433,7 @@ class TensorSSA(cutlass_arith.ArithValue): elif src_dtype == Int4 and dtype == BFloat16: fast_cvt_func = cvt_i4_bf16_intrinsic arch = BaseDSL._get_dsl().get_arch_enum() - if fast_cvt_func is not None and arch in fast_cvt_func.supported_archs: + if fast_cvt_func is not None and arch in fast_cvt_func.supported_archs: # type: ignore[attr-defined] res_vect = fast_cvt_func(src, size(self.shape), loc=loc, ip=ip) else: res_vect = cutlass_arith.itofp( @@ -1941,11 +2445,52 @@ class TensorSSA(cutlass_arith.ArithValue): return TensorSSA(res_vect, self._shape, dtype) @dsl_user_op - def ir_value(self, *, loc=None, ip=None): + def bitcast( + self, + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": + """Reinterpret the bits of this tensor as a different element type. + + Total bit width is preserved; the element count adjusts proportionally. + For example, a ``TensorSSA`` of shape ``(4,)`` with ``Float32`` bitcast + to ``Float16`` yields a ``TensorSSA`` of shape ``(8,)`` with ``Float16`` + (4 × 32 = 8 × 16 bits). Multi-dimensional shapes are flattened. + + :param dtype: Target DSL element type (e.g. ``Int32``, ``Float16``). + :type dtype: Type[Numeric] + :return: A new :class:`TensorSSA` with bits reinterpreted as ``dtype``. + :rtype: TensorSSA + :raises TypeError: If ``dtype`` is not a subclass of :class:`Numeric`. + """ + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a Numeric type, but got {dtype}") + if dtype is self._dtype: + return self + old_count = size(self._shape) + new_count = old_count * self._dtype.width // dtype.width + target_vec_ty = ir.VectorType.get([new_count], dtype.mlir_type) + res_vec = vector.bitcast(target_vec_ty, self, loc=loc, ip=ip) + return TensorSSA(res_vec, (new_count,), dtype, loc=loc, ip=ip) + + @dsl_user_op + def ir_value( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": return self @dsl_user_op - def ir_value_int8(self, *, loc=None, ip=None): + def ir_value_int8( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: """ Returns int8 ir value of Boolean tensor. When we need to store Boolean tensor ssa, use ir_value_int8(). @@ -1968,7 +2513,15 @@ class TensorSSA(cutlass_arith.ArithValue): return self._value_int8 @dsl_user_op - def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None): + def reduce( + self, + op: ReductionOp, + init_val: object, + reduction_profile: Coord, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Union["TensorSSA", ir.Value]: """ Perform reduce on selected modes with given predefined reduction op. @@ -2015,11 +2568,11 @@ class TensorSSA(cutlass_arith.ArithValue): elif op is ReductionOp.MIN: red_kind = vector.CombiningKind.MINIMUMF else: + # Defensive check: ReductionOp enum only has 4 valid values raise NotImplementedError( f"{op} is not supported, expected one of " f"{ReductionOp.ADD, ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN}" ) - elem_type = self.element_type # Canonicalize to `Numeric` and convert into MLIR value init_val = ( @@ -2042,7 +2595,7 @@ class TensorSSA(cutlass_arith.ArithValue): # convert TensorSSA col-major vec to row-m to be compatible with mlir vector ops row_major_vec = _col2row(self, shape=self._shape, loc=loc, ip=ip) - temp_ty = ir.VectorType.get(list(flat_shp), elem_type.mlir_type) + temp_ty = ir.VectorType.get(list(flat_shp), elem_type.mlir_type) # type: ignore[arg-type] temp_vect = vector.shape_cast(temp_ty, row_major_vec, loc=loc, ip=ip) red_dims = [i for i, x in enumerate(flat_prof) if x is not None] @@ -2061,7 +2614,14 @@ class TensorSSA(cutlass_arith.ArithValue): @dsl_user_op -def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> TensorSSA: +def full( + shape: Shape, + fill_value: Union[ir.Value, int, float, bool, Numeric], + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TensorSSA: """ Return a new TensorSSA of given shape and type, filled with fill_value. @@ -2093,11 +2653,11 @@ def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> Tenso @dsl_user_op def full_like( a: Union[TensorSSA, Tensor], - fill_value, + fill_value: object, dtype: Union[None, Type[Numeric]] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> TensorSSA: """ Return a full TensorSSA with the same shape and type as a given array. @@ -2128,12 +2688,18 @@ def full_like( if not hasattr(a, "shape"): raise TypeError(f"Expected `a` be shaped type, but got {type(a)}") - res_dtype = dtype if dtype is not None else a.dtype # type: ignore + res_dtype = dtype if dtype is not None else a.dtype return full(a.shape, fill_value, res_dtype, loc=loc, ip=ip) @dsl_user_op -def empty_like(a, dtype=None, *, loc=None, ip=None): +def empty_like( + a: Union[TensorSSA, Tensor], + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TensorSSA: """ Return a new TensorSSA with the same shape and type as a given array, without initializing entries. @@ -2148,7 +2714,13 @@ def empty_like(a, dtype=None, *, loc=None, ip=None): @dsl_user_op -def ones_like(a, dtype=None, *, loc=None, ip=None): +def ones_like( + a: Union[TensorSSA, Tensor], + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TensorSSA: """ Return a TensorSSA of ones with the same shape and type as a given array. @@ -2163,7 +2735,13 @@ def ones_like(a, dtype=None, *, loc=None, ip=None): @dsl_user_op -def zeros_like(a, dtype=None, *, loc=None, ip=None): +def zeros_like( + a: Union[TensorSSA, Tensor], + dtype: Optional[Type[Numeric]] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TensorSSA: """ Return a TensorSSA of zeros with the same shape and type as a given array. @@ -2183,8 +2761,8 @@ def where( x: Union[TensorSSA, Numeric], y: Union[TensorSSA, Numeric], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> TensorSSA: """ Return elements chosen from x or y depending on condition; will auto broadcast x or y if needed. @@ -2200,7 +2778,9 @@ def where( """ # Helper function to promote scalars to tensors or broadcast tensors to target shape - def promote_and_broadcast(v, shape): + def promote_and_broadcast( + v: Union[TensorSSA, Numeric, bool, int, float, ir.Value], shape: Shape + ) -> TensorSSA: if isinstance(v, TensorSSA): return v.broadcast_to(shape) elif isinstance(v, (bool, int, float, ir.Value, Numeric)): @@ -2216,8 +2796,8 @@ def where( raise ValueError( f"at least one of x and y must be tensor, but got {type(x)} and {type(y)}" ) - x_shape = x.shape if x_is_tensor else y.shape - y_shape = y.shape if y_is_tensor else x.shape + x_shape = x.shape if x_is_tensor else y.shape # type: ignore[union-attr] + y_shape = y.shape if y_is_tensor else x.shape # type: ignore[union-attr] # Promote both operands to tensors with broadcast shape res_shape = _infer_broadcast_shape(cond.shape, x_shape, y_shape) @@ -2239,7 +2819,12 @@ def where( @dsl_user_op -def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: +def any_( + x: TensorSSA, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Boolean: """ Test whether any tensor element evaluates to True. @@ -2255,7 +2840,12 @@ def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: @dsl_user_op -def all_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: +def all_( + x: TensorSSA, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Boolean: """ Test whether all tensor elements evaluate to True. @@ -2268,3 +2858,295 @@ def all_(x: TensorSSA, *, loc=None, ip=None) -> Boolean: return Boolean( vector.reduction(T.bool(), vector.CombiningKind.AND, is_true, loc=loc, ip=ip) ) + + +@dsl_user_op +def gather( + input: Tensor, + mode: int, + index: TensorSSA, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> TensorSSA: + """ + Gather elements from input tensor along the index specified by mode. + + For each value in the output, its load index is specified by its index in itself + for m != `mode` and by the corresponding value in `index` for m = `mode`. + + E.g., for a 3D case, the result TensorSSA `output` is specified by: + ``` + output[i][j][k] = input[index[i][j][k]][j][k] # if mode == 0 + output[i][j][k] = input[i][index[i][j][k]][k] # if mode == 1 + output[i][j][k] = input[i][j][index[i][j][k]] # if mode == 2 + ``` + + * `input` and `index` must have the same rank and congruent shapes. + * `output` will have the same shape as `index`. + * Regarding the shape of `index`: + * size(index.shape[m]) <= size(input.shape[m]) for all modes m != mode + * all values in `index` must be in the range [0, size(input.shape[mode])), + otherwise, it will result in an undefined behavior + + :param input: The input tensor + :type input: Tensor + :param mode: The mode along which to gather + :type mode: int + :param index: The index tensor + :type index: TensorSSA + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The gathered tensor ssa + :rtype: TensorSSA + """ + + _check_can_gather_scatter(input, mode, index) + + idx_layout = make_layout(index.shape) + src_layout = make_layout(index.shape, stride=input.stride) + + # Split src and index layouts into two parts respectively: + # * gather part: {mode} + # * rest part: [0, mode) ∪ (mode, rank) + # Append ones (i.e., 1:0) to the layouts in case the rest part is empty + idx_layout = append_ones(idx_layout) + src_layout = append_ones(src_layout) + gather_modes = [mode] + rest_modes = [m for m in range(rank(idx_layout)) if m not in gather_modes] + idx_layout_gather = select(idx_layout, gather_modes) + idx_layout_rest = select(idx_layout, rest_modes) + src_layout_gather = select(src_layout, gather_modes) + src_layout_rest = select(src_layout, rest_modes) + + res_elems = [None] * size(index.shape) + res_vect_ty = T.vector(size(index.shape), input.element_type.mlir_type) # type: ignore[union-attr] + + # Optimized path: lower to vector.gather when the tensor is col-major + # and gathering along the left-most mode + if ( + mode == 0 + and is_major(mode, input.stride) + and not input.iterator.value.type.is_swizzled # type: ignore[union-attr] + ): + vect_sz = size(idx_layout_gather) + vect_ty = T.vector(vect_sz, input.element_type.mlir_type) # type: ignore[union-attr] + idx_vect_ty = T.vector(vect_sz, index.element_type.mlir_type) + mask_all_ones = vector.constant_mask( + T.vector(vect_sz, T.bool()), [vect_sz], loc=loc, ip=ip + ) + pass_thru_poison = llvm.mlir_poison(vect_ty, loc=loc, ip=ip) + for rest_crd in range(size(select(idx_layout.shape, rest_modes))): + curr_ptr = input.iterator + src_layout_rest(rest_crd) + idx_vect = vector.extract_strided_slice( + idx_vect_ty, + index.ir_value(loc=loc, ip=ip), + offsets=[rest_crd * vect_sz], + sizes=[vect_sz], + strides=[1], + loc=loc, + ip=ip, + ) + res_vect = vector.gather( + result=vect_ty, + base=curr_ptr._to_builtin_memref(loc=loc, ip=ip), + offsets=[], + indices=idx_vect, + mask=mask_all_ones, + pass_thru=pass_thru_poison, + alignment=input.iterator.alignment, # type: ignore[union-attr] + loc=loc, + ip=ip, + ) + vect_elems = vector.to_elements(res_vect) + res_start_idx = rest_crd * vect_sz + if vect_sz == 1: + res_elems[res_start_idx] = vect_elems + else: + res_elems[res_start_idx : res_start_idx + vect_sz] = vect_elems + res_vect = vector.from_elements(res_vect_ty, res_elems) + return TensorSSA(res_vect, index.shape, input.element_type) + + # Normal path: gather by computing the new index for each element + for gather_crd in range(size(select(idx_layout.shape, gather_modes))): + for rest_crd in range(size(select(idx_layout.shape, rest_modes))): + index_crd = idx_layout_gather(gather_crd) + idx_layout_rest(rest_crd) + src_crd_gather = index[index_crd] + src_crd = src_layout_gather(src_crd_gather) + src_layout_rest(rest_crd) + src_crd_hier = input.layout.get_hier_coord(src_crd, loc=loc, ip=ip) # type: ignore[call-arg, union-attr] + res_elems[index_crd] = input[src_crd_hier].ir_value(loc=loc, ip=ip) # type: ignore[union-attr] + res_vect = vector.from_elements(res_vect_ty, res_elems) + return TensorSSA(res_vect, index.shape, input.element_type) + + +@dsl_user_op +def scatter( + output: Tensor, + mode: int, + index: TensorSSA, + data: TensorSSA, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Scatter elements to output tensor along the index specified by `mode`. + + For each value in `data`, its store index is specified by its index in itself + for m != `mode` and by the corresponding value in `index` for m = `mode`. + + E.g., for a 3D case, the tensor `output` is updated as: + ``` + output[index[i][j][k]][j][k] = data[i][j][k] # if dim == 0 + output[i][index[i][j][k]][k] = data[i][j][k] # if dim == 1 + output[i][j][index[i][j][k]] = data[i][j][k] # if dim == 2 + ``` + + * `output` and `index` must have the same rank and congruent shapes. + * `data` must have the same shape as `index`. + * Regarding the shape of `index`: + * size(index.shape[m]) <= size(output.shape[m]) for all modes m != mode + * all values in `index` must be in the range [0, size(output.shape[mode])), + otherwise, it will result in an undefined behavior + * If the index vector contains two or more duplicate indices, the behavior + is undefined. Underlying implementation may enforce strict col-major + sequential semantics. + + :param output: The output tensor + :type output: Tensor + :param mode: The mode along which to scatter + :type mode: int + :param index: The index tensor + :type index: TensorSSA + :param data: The data tensor + :type data: TensorSSA + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + """ + + _check_can_gather_scatter(output, mode, index, data) + + idx_layout = make_layout(index.shape) + dst_layout = make_layout(index.shape, stride=output.stride) + + # Split dst and index layouts into two parts respectively: + # * scatter part: {mode} + # * rest part: [0, mode) ∪ (mode, rank) + # Append ones (i.e., 1:0) to the layouts in case the rest part is empty + idx_layout = append_ones(idx_layout) + dst_layout = append_ones(dst_layout) + scatter_modes = [mode] + rest_modes = [m for m in range(rank(idx_layout)) if m not in scatter_modes] + idx_layout_scatter = select(idx_layout, scatter_modes) + idx_layout_rest = select(idx_layout, rest_modes) + dst_layout_scatter = select(dst_layout, scatter_modes) + dst_layout_rest = select(dst_layout, rest_modes) + + # Optimized path: lower to vector.scatter when tensor is col-major and + # scattering along the left-most mode + if ( + mode == 0 + and is_major(mode, output.stride) + and not output.iterator.value.type.is_swizzled # type: ignore[union-attr] + ): + vect_sz = size(idx_layout_scatter) + vect_ty = T.vector(vect_sz, output.element_type.mlir_type) # type: ignore[union-attr] + idx_vect_ty = T.vector(vect_sz, index.element_type.mlir_type) + mask_all_ones = vector.constant_mask( + T.vector(vect_sz, T.bool()), [vect_sz], loc=loc, ip=ip + ) + for rest_crd in range(size(select(idx_layout.shape, rest_modes))): + curr_ptr = output.iterator + dst_layout_rest(rest_crd) + idx_vect = vector.extract_strided_slice( + idx_vect_ty, + index.ir_value(loc=loc, ip=ip), + offsets=[rest_crd * vect_sz], + sizes=[vect_sz], + strides=[1], + loc=loc, + ip=ip, + ) + data_vect = vector.extract_strided_slice( + vect_ty, + data.ir_value(loc=loc, ip=ip), + offsets=[rest_crd * vect_sz], + sizes=[vect_sz], + strides=[1], + loc=loc, + ip=ip, + ) + vector.scatter( + result=None, + base=curr_ptr._to_builtin_memref(loc=loc, ip=ip), + offsets=[], + indices=idx_vect, + mask=mask_all_ones, + value_to_store=data_vect, + alignment=output.iterator.alignment, # type: ignore[union-attr] + loc=loc, + ip=ip, + ) + return + + # Normal path: scatter by computing the new index for each element + for scatter_crd in range(size(select(idx_layout.shape, scatter_modes))): + for rest_crd in range(size(select(idx_layout.shape, rest_modes))): + index_crd = idx_layout_scatter(scatter_crd) + idx_layout_rest(rest_crd) + dst_crd_scatter = index[index_crd] + dst_crd = dst_layout_scatter(dst_crd_scatter) + dst_layout_rest(rest_crd) + dst_crd_hier = output.layout.get_hier_coord(dst_crd, loc=loc, ip=ip) # type: ignore[call-arg, union-attr] + output[dst_crd_hier] = data[index_crd] + + +def _check_can_gather_scatter( + tensor: Tensor, mode: int, index: TensorSSA, data: Optional[TensorSSA] = None +) -> None: + # Check static + if not is_static(tensor.shape): + raise ValueError( + f"gather/scatter on tensor with dynamic shape is not supported, got: {tensor.type}" # type: ignore[attr-defined] + ) + + # Check modes + n_modes = rank(tensor.layout) + if mode < 0 or mode >= n_modes: + raise ValueError(f"mode must be in the range [0, {n_modes}), got: {mode}") + if n_modes != rank(index.shape): + raise ValueError( + f"source and index must have the same rank, got: {n_modes} and {rank(index.shape)}" + ) + + # Check layout + if isinstance(tensor.layout, ComposedLayout): + raise NotImplementedError( + f"gather/scatter on tensor with composed layout is not supported, got: {tensor.layout}" + ) + if depth(tensor.layout) > 1 or depth(index.shape) > 1: + raise NotImplementedError( + f"gather/scatter on tensor with nested layout is not supported, got: {tensor.layout} and {index.shape}" + ) + for m in range(n_modes): + if m != mode and size(index.shape[m]) > size(tensor.shape[m]): # type: ignore[index] + raise ValueError( + f"index dimension {m} must be less than or equal to the corresponding source dimension," + f"got: {size(index.shape[m])} and {size(tensor.shape[m])}" # type: ignore[index] + ) + if data is not None and index.shape != data.shape: + raise ValueError( + f"index and data must have the same shape, got: {index.shape} and {data.shape}" + ) + + # Check data type + if not issubclass(index.dtype, Integer): + raise TypeError(f"index must be integer TensorSSA, got {index.dtype}") + if tensor.element_type.width % 8 != 0: # type: ignore[union-attr] + raise TypeError( + f"gather/scatter for sub-byte element type is not supported, got: {tensor.element_type}" + ) + if data is not None and data.dtype != tensor.element_type: + raise TypeError( + f"element type of data must be {tensor.element_type}, got: {data.dtype}" + ) diff --git a/python/CuTeDSL/cutlass/cute/testing.py b/python/CuTeDSL/cutlass/cute/testing.py index 618522a1b..f7f3dfe64 100644 --- a/python/CuTeDSL/cutlass/cute/testing.py +++ b/python/CuTeDSL/cutlass/cute/testing.py @@ -9,10 +9,12 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +import argparse import functools import inspect import logging import os +from dataclasses import dataclass from itertools import product from time import time from typing import Type, Union, Callable, Optional, Dict, List, Any @@ -22,7 +24,7 @@ import cuda.bindings.runtime as cuda_runtime from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, dsl_user_op, const_expr -from .typing import Numeric, Int8, Boolean, Tensor, Layout, Shape +from .typing import Numeric, Int8, Uint8, Boolean, Tensor, Layout, Shape from . import nvgpu from .core import recast_layout, make_layout, composition, get, rank, size @@ -41,9 +43,129 @@ from .runtime import from_dlpack from cutlass._mlir.dialects import builtin, cf, nvvm, vector +from functools import partial +from cutlass._mlir import ir + + +class CuptiProfiler: + """A class for managing CUPTI profiling measurements with start, stop, and duration methods. + + This class provides a clean interface for measuring CUDA kernel execution times + using CUPTI (CUDA Profiling Tools Interface). It encapsulates the complexity + of buffer management, callback registration, and activity tracking. + + Example usage: + profiler = CuptiProfiler() + profiler.start() + # ... run your CUDA kernels ... + profiler.stop() + duration = profiler.get_duration() # Returns total duration in milliseconds + """ + + def __init__(self, buffer_size: int = 8 * 1024 * 1024) -> None: + """Initialize the CUPTI profiler. + + Args: + buffer_size: Size of the CUPTI buffer in bytes (default: 8MB) + + Raises: + ImportError: If the cupti-python package is not installed + """ + try: + from cupti import cupti + + self._cupti = cupti + except ModuleNotFoundError: + raise ModuleNotFoundError( + "CUPTI is not available. Install the 'cupti-python' package to use CuptiProfiler." + ) + self.buffer_size = buffer_size + self.timings: list[tuple[str, float]] = [] + self._is_active = False + self._buffer_requested_callback: Optional[Callable[..., Any]] = None + self._buffer_completed_callback: Optional[Callable[..., Any]] = None + + def _buffer_requested(self) -> tuple[int, int]: + """Internal callback for CUPTI buffer requests.""" + max_num_records = 0 + return self.buffer_size, max_num_records + + def _buffer_completed(self, activities: list[Any]) -> None: + """Internal callback for processing completed CUPTI activities.""" + for activity in activities: + start = activity.start if hasattr(activity, "start") else "nil" + end = activity.end if hasattr(activity, "end") else "nil" + duration = end - start if start != "nil" and end != "nil" else "nil" # type: ignore[operator] + name = activity.name[:100] if hasattr(activity, "name") else "unknown" + # Convert to milliseconds + if duration != "nil": + self.timings.append((name, duration / 1e6)) # type: ignore[operator] + + def start(self) -> None: + """Start CUPTI profiling. + + Enables CUPTI activity tracking for concurrent kernels and registers + the necessary callbacks for buffer management. + + Raises: + ValueError: If CUPTI activity cannot be enabled + """ + if self._is_active: + raise RuntimeError("CUPTI profiler is already active") + + # Clear previous timings + self.timings = [] + + try: + self._cupti.activity_enable(self._cupti.ActivityKind.CONCURRENT_KERNEL) + except self._cupti.cuptiError as e: + raise ValueError( + f"\033[91mError while enabling Activity Kind {self._cupti.ActivityKind.CONCURRENT_KERNEL.name}: {e}. Please disable CUPTI if you using profilers\033[0m" + ) + + # Register callbacks + self._buffer_requested_callback = self._buffer_requested + self._buffer_completed_callback = partial(self._buffer_completed) + + self._cupti.activity_register_callbacks( + self._buffer_requested_callback, self._buffer_completed_callback + ) + + self._is_active = True + + def stop(self) -> None: + """Stop CUPTI profiling. + + Flushes all activities, disables CUPTI tracking, and finalizes the profiler. + This method should be called after the kernels you want to measure have completed. + """ + if not self._is_active: + raise RuntimeError("CUPTI profiler is not active") + + # Flush all activities and cleanup + self._cupti.activity_flush_all(0) + self._cupti.activity_disable(self._cupti.ActivityKind.CONCURRENT_KERNEL) + self._cupti.finalize() + + self._is_active = False + + def get_duration(self) -> float: + """Get the total duration of all measured activities in milliseconds. + + Returns: + Total duration in milliseconds. Returns 0.0 if no activities were recorded. + """ + return sum(timing[1] for timing in self.timings) + @dsl_user_op -def assert_(cond, msg=None, *, loc=None, ip=None): +def assert_( + cond: object, + msg: Optional[str] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: cf.assert_(Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) @@ -74,14 +196,14 @@ class _CompileTimeAssertion(Assertion): def __init__( self, - tensor: Tensor, + tensor: Optional[Tensor], num_assertions: int = 1, - msgs=None, - device=None, + msgs: Optional[list[str]] = None, + device: Optional[str] = None, disable: bool = False, init_value: bool = False, - used_indices: set = None, - ): + used_indices: Optional[set[int]] = None, + ) -> None: """Initialize _CompileTimeAssertion. :param tensor: Tensor to store assertion results @@ -102,7 +224,9 @@ class _CompileTimeAssertion(Assertion): self._init_value = init_value self._used_indices = used_indices - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "_CompileTimeAssertion": if self._disable: return _CompileTimeAssertion( None, @@ -123,14 +247,22 @@ class _CompileTimeAssertion(Assertion): self._used_indices, ) - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: if self._disable: return [] - return self._tensor.__extract_mlir_values__() + return self._tensor.__extract_mlir_values__() # type: ignore[union-attr] @dsl_user_op @CuTeDSL.jit - def store(self, idx: Constexpr, pred: Boolean, msg: str = "", *, loc=None, ip=None): + def store( + self, + idx: Constexpr, + pred: Boolean, + msg: str = "", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Assert a predicate condition. :param idx: Assertion index @@ -148,26 +280,30 @@ class _CompileTimeAssertion(Assertion): return if const_expr(not isinstance(idx, int)): raise ValueError(f"expects idx to be 'int', but got {type(idx)}") - if const_expr(idx >= self._num_assertions): - raise ValueError(f"please increase the number of assertions!!!") + if const_expr(idx >= self._num_assertions): # type: ignore[operator] + raise ValueError("please increase the number of assertions!!!") if const_expr(self._init_value is True): - self._tensor[idx] = pred and self._tensor[idx] + self._tensor[idx] = pred and self._tensor[idx] # type: ignore[index] else: - self._tensor[idx] = pred - self._msgs[idx] = f"{msg}\nAt {loc}" - self._used_indices.add(idx) + self._tensor[idx] = pred # type: ignore[index] + self._msgs[idx] = f"{msg}\nAt {loc}" # type: ignore[call-overload] + self._used_indices.add(idx) # type: ignore[union-attr, arg-type] - def __enter__(self): + def __enter__(self) -> "_CompileTimeAssertion": """Enter context manager.""" return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[object], + ) -> None: """Exit context manager and verify assertions if no exception occurred.""" # Only verify if there was no exception in the with block if exc_type is None and not self._disable: # _CompileTimeAssertion doesn't have verify method as it's checked at compile time pass - return False # Don't suppress exceptions class RuntimeAssertion(Assertion): @@ -195,10 +331,10 @@ class RuntimeAssertion(Assertion): def __init__( self, num_assertions: int = 1, - device=None, + device: Optional[str] = None, disable: bool = False, init_value: bool = False, - ): + ) -> None: """Initialize _RuntimeAssertion. :param num_assertions: Number of assertions to support @@ -211,7 +347,7 @@ class RuntimeAssertion(Assertion): self._disable = disable self._msgs = [""] * num_assertions self._init_value = init_value - self._used_indices = set() + self._used_indices: set[int] = set() if self._disable: return import torch @@ -224,19 +360,19 @@ class RuntimeAssertion(Assertion): ) self._tensor = from_dlpack(self._torch_tensor) - def __c_pointers__(self): + def __c_pointers__(self) -> list[Any]: """Get C pointers for passing to JIT functions.""" if self._disable: return [] - return self._tensor.__c_pointers__() + return self._tensor.__c_pointers__() # type: ignore[attr-defined] - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> list[Any]: """Get MLIR types for code generation.""" if self._disable: return [] - return self._tensor.__get_mlir_types__() + return self._tensor.__get_mlir_types__() # type: ignore[attr-defined] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[ir.Value]) -> _CompileTimeAssertion: """Create new instance from MLIR values (for JIT compilation).""" if self._disable: return _CompileTimeAssertion( @@ -258,7 +394,7 @@ class RuntimeAssertion(Assertion): self._used_indices, ) - def verify(self): + def verify(self) -> None: """Verify all assertions have passed.""" if self._disable: return @@ -272,33 +408,41 @@ class RuntimeAssertion(Assertion): # emit the first assertion error. raise AssertionError(self._msgs[valid_indices[0]]) - def __enter__(self): + def __enter__(self) -> "RuntimeAssertion": """Enter the context manager, returns self for use in 'with' statement.""" return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[object], + ) -> None: """Exit the context manager, automatically calls verify().""" if exc_type is None: # Only verify if no exception occurred in the with block self.verify() - # Return False to propagate any exception that occurred - return False -def _maybe_recast_tensor_from_f4(src: Tensor, tv_layout: Layout): - if src.element_type.width == 4: +def _maybe_recast_tensor_from_f4_f6( + src: Tensor, tv_layout: Layout +) -> tuple[Tensor, Layout]: + if src.element_type.width == 4: # type: ignore[union-attr] tv_layout = recast_layout(8, 4, tv_layout) src = recast_tensor(src, dtype=Int8) + elif src.element_type.width == 6: # type: ignore[union-attr] + tv_layout = recast_layout(8, 6, tv_layout) + src = recast_tensor(src, dtype=Int8) return src, tv_layout -def _maybe_recast_to_f4(input: TensorSSA, dtype: Type[Numeric]): - """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit. +def _maybe_recast_to_f4_f6(input: TensorSSA, dtype: Type[Numeric]) -> TensorSSA: + """Conditionally recasts the tensor to 4-bit or 6-bit type if the destination type is 4-bit or 6-bit. :param input: The input tensor to recast. :param dtype: The target numeric type to potentially recast to. :raises TypeError: If dtype is not a subclass of Numeric. - :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged. + :return: A new tensor recast to 4-bit or 6-bit if dtype is 4-bit or 6-bit, otherwise returns self unchanged. """ if not inspect.isclass(dtype) or not issubclass(dtype, Numeric): raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}") @@ -312,16 +456,26 @@ def _maybe_recast_to_f4(input: TensorSSA, dtype: Type[Numeric]): [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec] ) return TensorSSA(res_vect, recast_shape, dtype) + elif dtype.width == 6: + recast_shape = recast_layout(6, 8, make_layout(input.shape)).shape + n = input.type.shape[0] + assert (n * 8) % 6 == 0, ( + f"N * 8 must be divisible by 6 for fp6 unpacking, got N={n}" + ) + res_vect = vector.bitcast( + T.vector(n * 8 // 6, dtype.mlir_type), input.maybe_downcast() + ) + return TensorSSA(res_vect, recast_shape, dtype) return input -def _maybe_recast_from_f4(input: TensorSSA, src_dtype: Type[Numeric]): - """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit. +def _maybe_recast_from_f4_f6(input: TensorSSA, src_dtype: Type[Numeric]) -> TensorSSA: + """Conditionally recasts the tensor from 4-bit or 6-bit type if the source type is 4-bit or 6-bit. :param input: The input tensor to recast. :param src_dtype: The source numeric type to potentially recast from. :raises TypeError: If src_dtype is not a subclass of Numeric. - :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged. + :return: A new tensor recast from 4-bit or 6-bit if src_dtype is 4-bit or 6-bit, otherwise returns self unchanged. """ if not inspect.isclass(src_dtype) or not issubclass(src_dtype, Numeric): raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}") @@ -333,6 +487,14 @@ def _maybe_recast_from_f4(input: TensorSSA, src_dtype: Type[Numeric]): ) res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec) return TensorSSA(res_vect, recast_shape, Int8) + elif src_dtype.width == 6: + recast_shape = recast_layout(8, 6, make_layout(input.shape)).shape + n = input.type.shape[0] + assert (n * 6) % 8 == 0, ( + f"N * 6 must be divisible by 8 for i8 packing, got N={n}" + ) + res_vect = vector.bitcast(T.vector(n * 6 // 8, T.i8()), input.maybe_downcast()) + return TensorSSA(res_vect, recast_shape, Int8) return input @@ -344,9 +506,9 @@ def _convert_kernel( src_tv_layout: Layout, dst_tv_layout: Layout, src_shape: Shape, - src_ty, - dst_ty, -): + src_ty: Type[Numeric], + dst_ty: Type[Numeric], +) -> None: tidx = nvvm.read_ptx_sreg_tid_x(T.i32()) bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32()) @@ -359,9 +521,9 @@ def _convert_kernel( # compose with CTA TV layout # tid, vid -> address - tidfrgSrc = composition(ctaSrc, src_tv_layout) # (T,V) - tidfrgDst = composition(ctaDst, dst_tv_layout) # (T,V) - tidfrgCSrc = composition(ctaCSrc, src_tv_layout) # (T,V) + tidfrgSrc = composition(ctaSrc, src_tv_layout) # type: ignore[arg-type] # (T,V) + tidfrgDst = composition(ctaDst, dst_tv_layout) # type: ignore[arg-type] # (T,V) + tidfrgCSrc = composition(ctaCSrc, src_tv_layout) # type: ignore[arg-type] # (T,V) # print(f"tidfrgSrc = {tidfrgSrc.type}") # slice for threads @@ -387,9 +549,9 @@ def _convert_kernel( copy(copy_atom_load, thrSrc, frgSrc) vec_src = frgSrc.load() - vec_src = _maybe_recast_to_f4(vec_src, src_ty) + vec_src = _maybe_recast_to_f4_f6(vec_src, src_ty) vec_dst = vec_src.to(dst_ty) - vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty) + vec_dst = _maybe_recast_from_f4_f6(vec_dst, dst_ty) frgDst.store(vec_dst) # Copy the results back to c @@ -403,7 +565,7 @@ def _convert( dst: Tensor, leading_mode: Constexpr, elem_per_copy: Constexpr, -): +) -> None: # Step 1. figure proper tv_layout src_ty = src.element_type dst_ty = dst.element_type @@ -411,8 +573,8 @@ def _convert( tv_layout = make_layout((128, elem_per_copy), stride=(elem_per_copy, 1)) # Step 2. maybe recast from f4 tensor - src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout) - dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout) + src, src_tv_layout = _maybe_recast_tensor_from_f4_f6(src, tv_layout) + dst, dst_tv_layout = _maybe_recast_tensor_from_f4_f6(dst, tv_layout) src_shape = src.shape # predicate tensor idA = make_identity_tensor(src.shape) @@ -421,11 +583,11 @@ def _convert( src_cta_tiler = [ 1, ] * rank(src.layout) - src_cta_tiler[leading_mode] = size(src_tv_layout) # (...,TileV,...) + src_cta_tiler[leading_mode] = size(src_tv_layout) # type: ignore[call-overload] # (...,TileV,...) dst_cta_tiler = [ 1, ] * rank(dst.layout) - dst_cta_tiler[leading_mode] = size(dst_tv_layout) # (...,TileV,...) + dst_cta_tiler[leading_mode] = size(dst_tv_layout) # type: ignore[call-overload] # (...,TileV,...) # Step 4. partition input and output tensor by cta tiler. gS = zipped_divide(src, tuple(src_cta_tiler)) # ((...,TileV,...),(...,RestV,...)) @@ -452,29 +614,31 @@ def _convert( # And when src or dst dtype is narrow precision(Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the shape of # their leading dimension should be 4(fp8)/8(fp4) element align. (nvgpu.cvt_fptrunc/cvt_fpext # needs 32-bits aligned input/output) -def convert(src: Tensor, dst: Tensor): - assert len(src.shape) == len(dst.shape), ( +def convert(src: Tensor, dst: Tensor) -> None: + assert len(src.shape) == len(dst.shape), ( # type: ignore[arg-type] "Shape of src and dst tensors should be the same rank." ) # find leading mode leading_mode = [ idx - for idx, (shape, stride) in enumerate(zip(src.shape, src.stride)) - if shape > 1 and stride == 1 + for idx, (shape, stride) in enumerate(zip(src.shape, src.stride)) # type: ignore[arg-type] + if shape > 1 and stride == 1 # type: ignore[operator] ] if len(leading_mode) != 1: raise ValueError(f"Leading mode should be unique, but got {leading_mode}") - leading_mode = leading_mode[0] + leading_mode = leading_mode[0] # type: ignore[assignment] elem_per_copy = 2 - if src.element_type.width == 4 or dst.element_type.width == 4: + if src.element_type.width == 4 or dst.element_type.width == 4: # type: ignore[union-attr] elem_per_copy = 8 - elif src.element_type.width == 8 or dst.element_type.width == 8: + elif src.element_type.width == 8 or dst.element_type.width == 8: # type: ignore[union-attr] elem_per_copy = 4 + elif src.element_type.width == 6 or dst.element_type.width == 6: # type: ignore[union-attr] + elem_per_copy = 16 # 16*f6 elements per 96 bits(12 bytes) assert ( - src.shape[leading_mode] % elem_per_copy == 0 - and dst.shape[leading_mode] % elem_per_copy == 0 + src.shape[leading_mode] % elem_per_copy == 0 # type: ignore[index, call-overload] + and dst.shape[leading_mode] % elem_per_copy == 0 # type: ignore[index, call-overload] ) _convert(src, dst, leading_mode, elem_per_copy) @@ -485,7 +649,7 @@ def convert(src: Tensor, dst: Tensor): ######################################### -def sample_pytest(rand_cfg=None): +def sample_pytest(rand_cfg: Optional[tuple[int, float]] = None) -> Callable[..., Any]: """ Decorator to randomly sample pytest parametrized tests. rand_cfg: Tuple[int, float] - (random_seed, sample_ratio) @@ -500,12 +664,12 @@ def sample_pytest(rand_cfg=None): import pytest - seed, sample_ratio = rand_cfg + seed, sample_ratio = rand_cfg # type: ignore[misc] random.seed(seed) - def decorator(func): + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ: # Check if test was explicitly selected like ::test_name[param1-param2-...] if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv): @@ -531,10 +695,10 @@ class JitArguments: A type to hold both args and kwargs for passing to a kernel while benchmarking. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.args = args self.kwargs = kwargs - self.references = list() + self.references: list[Any] = list() def add_to_scope(self, references: Any) -> None: """ @@ -545,8 +709,9 @@ class JitArguments: def _cuda_success( - err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str -): + err: Union[tuple[Any, ...], cuda_runtime.cudaError_t, cuda_driver.CUresult], + message: str, +) -> None: """ Helper function to check CUDA API errors. """ @@ -567,8 +732,8 @@ def _cuda_success( def _does_kernel_use_stream( - kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs -): + kernel: Callable[..., Any], stream: cuda_driver.CUstream, *args: Any, **kwargs: Any +) -> bool: """ This function checks if the kernel uses the provided non-default stream. It does this by capturing the stream and then checking if any kernels were launched. @@ -589,7 +754,16 @@ def _does_kernel_use_stream( ) _cuda_success(err, "Error on stream capture") - kernel(*args, **kwargs) + try: + kernel(*args, **kwargs) + except Exception: + # Always end the capture even on failure to avoid zombie capture state + # that would poison all subsequent graph capture operations in the process. + try: + cuda_runtime.cudaStreamEndCapture(stream) + except Exception: + pass + raise err, graph = cuda_runtime.cudaStreamEndCapture(stream) _cuda_success(err, "Error on stream capture") @@ -610,6 +784,7 @@ def benchmark( workspace_generator: Optional[Callable[[], JitArguments]] = None, workspace_count: int = 1, use_cuda_graphs: bool = False, + use_cupti: bool = False, ) -> float: """Benchmarks a callable function with the specified parameters. @@ -677,9 +852,8 @@ def benchmark( :return: The benchmark time in microseconds :rtype: float """ - - import cutlass.base_dsl.jit_executor as jit_executor - import cutlass.cutlass_dsl.cuda_jit_executor as cuda_jit_executor + import cutlass.base_dsl.jit_executor # noqa: F401 + import cutlass.cutlass_dsl.cuda_jit_executor # noqa: F401 if stream is None: stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) @@ -687,7 +861,7 @@ def benchmark( if workspace_count < 1: raise ValueError("workspace_count must be at least 1") - time_us = float("nan") + _time_us = float("nan") if workspace_generator == None: # If no workspace generator is provided, we need a single workspace if workspace_count != 1: @@ -708,7 +882,25 @@ def benchmark( "workspace_generator and/or kernel_arguments should use JitArguments type" ) - def _loop_and_call_kernel(iterations: int, workspace_index: int = 0): + # use memset to flush L2 cache after workspace h2d copies + if workspace_count > 1: + from cutlass.utils import HardwareInfo + + hardware_info = HardwareInfo() + num_l2_cache_bytes = hardware_info.get_l2_cache_size_in_bytes() + l2_flush_bytes = num_l2_cache_bytes * 2 + err, cache_ptr = cuda_driver.cuMemAlloc(int(l2_flush_bytes)) + _cuda_success(err, "Error on allocating memory") + + err = cuda_driver.cuMemsetD32Async( + cache_ptr, 0, int(l2_flush_bytes // 4), stream + ) + _cuda_success(err, "Error on memset") + + err = cuda_driver.cuMemFree(cache_ptr) + _cuda_success(err, "Error on freeing memory") + + def _loop_and_call_kernel(iterations: int, workspace_index: int = 0) -> int: for _ in range(iterations): current_workspace = workspaces[workspace_index] callable(*current_workspace.args, **current_workspace.kwargs) @@ -727,78 +919,47 @@ def benchmark( elapsed_time = float("nan") - if use_cuda_graphs: - # Check if the stream is a non-default stream - if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT): - raise ValueError( - "Measuring with CUDA Graphs requires executing in a non-default stream" + # ========================================================================= + # Helper: Measure kernel execution time using CUPTI profiler + # ========================================================================= + def _measure_with_cupti(kernel_launcher: Callable[[], Any]) -> float: + """ + Measure kernel execution time using NVIDIA CUPTI profiler. + :param kernel_launcher: Callable that launches the kernel(s) to be profiled + :type kernel_launcher: Callable + :return: Elapsed time in milliseconds + :rtype: float + """ + if not hasattr(kernel_launcher, "__call__"): + raise TypeError( + f"kernel_launcher must be callable, got {type(kernel_launcher).__name__}" ) - workspace_index = 0 + cupti_profiler = CuptiProfiler() - # Capture warmup graph - err = cuda_runtime.cudaStreamBeginCapture( - stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal - ) - _cuda_success(err, "Error on stream capture") + cupti_profiler.start() + kernel_launcher() - workspace_index = _loop_and_call_kernel(warmup_iterations) - err, gwarm = cuda_runtime.cudaStreamEndCapture(stream) - _cuda_success(err, "Error on stream capture") + err = cuda_runtime.cudaDeviceSynchronize() + _cuda_success(err, "Error on synchronizing device") - # Get number of nodes in warmup graph to check it matches what is expected - err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm) - _cuda_success(err, "Error on querying graph") - # Assertion is >= since we may launch multiple kernels in one host function - if num_nodes < warmup_iterations: - raise ValueError( - "CUDA stream passed to benchmark does not match the stream the kernel was launched in" + cupti_profiler.stop() + duration_ms = cupti_profiler.get_duration() + return duration_ms + + def _measure_with_cuda_event(kernel_launcher: Callable[[], Any]) -> float: + """ + Measure kernel execution time using CUDA events. + :param kernel_launcher: Callable that launches the kernel(s) to be profiled + :type kernel_launcher: Callable + :return: Elapsed time in milliseconds + :rtype: float + """ + if not hasattr(kernel_launcher, "__call__"): + raise TypeError( + f"kernel_launcher must be callable, got {type(kernel_launcher).__name__}" ) - # Capture profiling graph - err = cuda_runtime.cudaStreamBeginCapture( - stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal - ) - _cuda_success(err, "Error on stream capture") - _loop_and_call_kernel(iterations, workspace_index) - err, gprofile = cuda_runtime.cudaStreamEndCapture(stream) - _cuda_success(err, "Error on stream capture") - - # Instantiate graphs - err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0) - _cuda_success(err, "Error on graph instantiation") - err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0) - _cuda_success(err, "Error on graph instantiation") - - # Launch warmup graph - err = cuda_runtime.cudaGraphLaunch(gwarm, stream) - _cuda_success(err, "Error on graph launch") - - # Record start time - err = cuda_driver.cuEventRecord(start_event, stream) - _cuda_success(err, "Error on recording event") - - # Launch profiling graph - err = cuda_runtime.cudaGraphLaunch(gprofile, stream) - _cuda_success(err, "Error on graph launch") - - # Record end time - err = cuda_driver.cuEventRecord(end_event, stream) - _cuda_success(err, "Error on recording event") - err = cuda_driver.cuEventSynchronize(end_event) - _cuda_success(err, "Error on synchronizing event") - - # Get elapsed time - err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) - _cuda_success(err, "Error on querying event") - - # Destroy graphs - err = cuda_runtime.cudaGraphExecDestroy(gwarm) - _cuda_success(err, "Error on destroying graph") - err = cuda_runtime.cudaGraphExecDestroy(gprofile) - _cuda_success(err, "Error on destroying graph") - - else: if int(stream) != int( cuda_driver.CUstream_flags.CU_STREAM_DEFAULT ) and not _does_kernel_use_stream( @@ -808,21 +969,152 @@ def benchmark( "CUDA stream passed to benchmark does not match the stream the kernel was launched in" ) - # Not using graphs - # Warmup - workspace_index = _loop_and_call_kernel(warmup_iterations) - # Record start event err = cuda_driver.cuEventRecord(start_event, stream) - _cuda_success(err, "Error on recording event") - _loop_and_call_kernel(iterations, workspace_index) - # Record end event + _cuda_success(err, "Error on recording start event") + + kernel_launcher() + err = cuda_driver.cuEventRecord(end_event, stream) - _cuda_success(err, "Error on recording event") - # Synchronize end event + _cuda_success(err, "Error on recording end event") + err = cuda_driver.cuEventSynchronize(end_event) - _cuda_success(err, "Error on synchronizing event") - err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) - _cuda_success(err, "Error on querying event") + _cuda_success(err, "Error on synchronizing end event") + + err, duration_ms = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying elapsed time") + return duration_ms + + # ========================================================================= + # Branch 1: CUDA Graphs mode - Capture and replay kernel execution + # ========================================================================= + if use_cuda_graphs: + if hasattr(callable, "_dsl_cls"): + raise TypeError( + "Uncompiled @cute.jit function cannot be captured into a CUDA Graph. " + "Use cute.compile() first, or wrap compiled calls in a plain function." + ) + + # --------------------------------------------------------------------- + # Step 1: Capture warmup graph + # --------------------------------------------------------------------- + import gc as _gc + + # Disable GC during capture to prevent __del__ methods (e.g., cudaFree) + # from invalidating the capture with a non-capturable CUDA call. + _gc.collect() + _gc.disable() + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on beginning warmup stream capture") + + try: + warmup_workspace_idx = _loop_and_call_kernel(warmup_iterations) + except Exception: + _gc.enable() + try: + cuda_runtime.cudaStreamEndCapture(stream) + except Exception: + pass + raise + + err, warmup_graph = cuda_runtime.cudaStreamEndCapture(stream) + _gc.enable() + _cuda_success(err, "Error on ending warmup stream capture") + + # Validate warmup graph node count + # Each kernel launch should produce at least one graph node + err, _, warmup_node_count = cuda_runtime.cudaGraphGetNodes(warmup_graph) + _cuda_success(err, "Error on querying warmup graph nodes") + # Use >= since one host function may launch multiple kernels + if warmup_node_count < warmup_iterations: + raise ValueError( + "CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # --------------------------------------------------------------------- + # Step 2: Capture profiling graph + # --------------------------------------------------------------------- + _gc.collect() + _gc.disable() + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on beginning profiling stream capture") + + try: + _loop_and_call_kernel(iterations, warmup_workspace_idx) + except Exception: + _gc.enable() + try: + cuda_runtime.cudaStreamEndCapture(stream) + except Exception: + pass + raise + + err, profiling_graph = cuda_runtime.cudaStreamEndCapture(stream) + _gc.enable() + _cuda_success(err, "Error on ending profiling stream capture") + + # --------------------------------------------------------------------- + # Step 3: Instantiate executable graphs + # --------------------------------------------------------------------- + err, warmup_graph_exec = cuda_runtime.cudaGraphInstantiate(warmup_graph, 0) + _cuda_success(err, "Error on instantiating warmup graph") + err, profiling_graph_exec = cuda_runtime.cudaGraphInstantiate( + profiling_graph, 0 + ) + _cuda_success(err, "Error on instantiating profiling graph") + + # --------------------------------------------------------------------- + # Step 4: Execute warmup graph (cache warming) + # --------------------------------------------------------------------- + err = cuda_runtime.cudaGraphLaunch(warmup_graph_exec, stream) + _cuda_success(err, "Error on launching warmup graph") + + # --------------------------------------------------------------------- + # Step 5: Profile execution using selected profiler + # --------------------------------------------------------------------- + def launch_profiling_graph() -> None: + err = cuda_runtime.cudaGraphLaunch(profiling_graph_exec, stream) + _cuda_success(err, "Error on launching profiling graph") + + if use_cupti: + elapsed_time = _measure_with_cupti(launch_profiling_graph) + else: + elapsed_time = _measure_with_cuda_event(launch_profiling_graph) + + # --------------------------------------------------------------------- + # Step 6: Cleanup - Destroy graph executables + # --------------------------------------------------------------------- + err = cuda_runtime.cudaGraphExecDestroy(warmup_graph_exec) + _cuda_success(err, "Error on destroying warmup graph executable") + err = cuda_runtime.cudaGraphExecDestroy(profiling_graph_exec) + _cuda_success(err, "Error on destroying profiling graph executable") + + # ========================================================================= + # Branch 2: CUPTI profiler mode (without CUDA Graphs) + # ========================================================================= + elif use_cupti: + # Warmup iterations to stabilize GPU state + warmup_workspace_idx = _loop_and_call_kernel(warmup_iterations) + + def run_profiling_iterations() -> None: + _loop_and_call_kernel(iterations, warmup_workspace_idx) + + elapsed_time = _measure_with_cupti(run_profiling_iterations) + + # ========================================================================= + # Branch 3: CUDA event profiler mode (default) + # ========================================================================= + else: + # Warmup iterations to stabilize GPU state + warmup_workspace_idx = _loop_and_call_kernel(warmup_iterations) + + def run_profiling_iterations() -> None: + _loop_and_call_kernel(iterations, warmup_workspace_idx) + + elapsed_time = _measure_with_cuda_event(run_profiling_iterations) # Destroy events err = cuda_driver.cuEventDestroy(start_event) @@ -861,14 +1153,14 @@ def get_workspace_count( def _benchmark_for_autotune( - callable: Callable, - *args, + callable: Callable[..., Any], + *args: Any, warmup_iterations: int, iterations: int, use_cold_l2: bool, print_verbose: bool, current_stream: Optional[cuda_driver.CUstream] = None, - **kwargs, + **kwargs: Any, ) -> float: """Benchmarks a callable function with the specified parameters. @@ -932,7 +1224,7 @@ def _benchmark_for_autotune( for _ in range(warmup_iterations): callable(*args, **kwargs) - time = 0 + _time = 0 execution_time_ms = [] for _ in range(iterations): if use_cold_l2: @@ -998,10 +1290,10 @@ class autotune_jit: the autotuner will not recompile the kernel. """ - logger = None + logger: Optional[logging.Logger] = None @classmethod - def _initialize_logger(cls): + def _initialize_logger(cls) -> None: """Ensure the logger is initialized""" if cls.logger is None: cls.logger = logging.getLogger(__name__ + "_Autotune") @@ -1020,8 +1312,12 @@ class autotune_jit: @classmethod def _create_tuning_wrapper( - cls, func, warmup_iterations, iterations, autotune_update_params - ): + cls, + func: Callable[..., Any], + warmup_iterations: int, + iterations: int, + autotune_update_params: list[str], + ) -> Callable[..., Any]: """Create a wrapper function that performs auto-tuning Args: @@ -1034,18 +1330,18 @@ class autotune_jit: # Initialize autotune parameters if not hasattr(func, "_autotune_params"): - func._original_func = func - func._autotune_params = {} - func._autotune_update_params = autotune_update_params - func._best_kernel = dict() - func._best_config = dict() + func._original_func = func # type: ignore[attr-defined] + func._autotune_params = {} # type: ignore[attr-defined] + func._autotune_update_params = autotune_update_params # type: ignore[attr-defined] + func._best_kernel = dict() # type: ignore[attr-defined] + func._best_config = dict() # type: ignore[attr-defined] # Create wrapper function for auto-tuning @functools.wraps(func) - def tuning_wrapper(*args, **kwargs): - parameters = inspect.signature(func._original_func).parameters.keys() - tuning_key = list() - for param_name in func._autotune_update_params: + def tuning_wrapper(*args: Any, **kwargs: Any) -> Any: + parameters = inspect.signature(func._original_func).parameters.keys() # type: ignore[attr-defined] + tuning_key: Any = list() + for param_name in func._autotune_update_params: # type: ignore[attr-defined] if param_name in kwargs.keys(): tuning_key.append(kwargs[param_name]) else: @@ -1053,14 +1349,14 @@ class autotune_jit: if index < len(args): tuning_key.append(args[index]) tuning_key = tuple(tuning_key) - if tuning_key in func._best_kernel.keys(): - cls.logger.info( - f"Using cached best configuration: {func._best_config[tuning_key]}" + if tuning_key in func._best_kernel.keys(): # type: ignore[attr-defined] + cls.logger.info( # type: ignore[union-attr] + f"Using cached best configuration: {func._best_config[tuning_key]}" # type: ignore[attr-defined] ) - return func._best_kernel[tuning_key](*args, **kwargs) + return func._best_kernel[tuning_key](*args, **kwargs) # type: ignore[attr-defined] # Get all parameter configurations - params_dict = func._autotune_params + params_dict = func._autotune_params # type: ignore[attr-defined] keys = list(params_dict.keys()) values = list(params_dict.values()) @@ -1074,7 +1370,7 @@ class autotune_jit: for config_values in product(*values): # Build current configuration current_config = dict(zip(keys, config_values)) - cls.logger.info(f"Tuning configuration: {current_config}") + cls.logger.info(f"Tuning configuration: {current_config}") # type: ignore[union-attr] try: # Call the original function, using current configuration to replace default parameters @@ -1082,21 +1378,23 @@ class autotune_jit: # It will override func's default parameter value merged_kwargs = {**kwargs, **current_config} compiled_func = compile( - func._original_func, *args, **merged_kwargs + func._original_func, # type: ignore[attr-defined] + *args, + **merged_kwargs, ) # Detect which constexpr arguments we need to remove from args and merged_kwargs # This is done because after compiling our function signature will change, removing all constexpr arguments. indexes_to_remove = list() - for arg in compiled_func.args_spec.get_constexpr_args(): + for arg in compiled_func.execution_args.get_constexpr_args(): if arg["argument_name"] in merged_kwargs: del merged_kwargs[arg["argument_name"]] elif arg["argument_index"] is not None: indexes_to_remove.append(arg["argument_index"]) - if arg["argument_name"] not in func._autotune_update_params: + if arg["argument_name"] not in func._autotune_update_params: # type: ignore[attr-defined] # Handle the case where the programmer avoided autotuning over constexpr values, and # recompile in that case - func._autotune_update_params.append( + func._autotune_update_params.append( # type: ignore[attr-defined] arg["argument_name"] ) @@ -1116,7 +1414,7 @@ class autotune_jit: **merged_kwargs, ) - cls.logger.info(f" Execution time: {cur_time} us") + cls.logger.info(f" Execution time: {cur_time} us") # type: ignore[union-attr] # Update best results if cur_time < min_time: @@ -1125,16 +1423,16 @@ class autotune_jit: best_config = current_config except NotImplementedError as e: - cls.logger.info( + cls.logger.info( # type: ignore[union-attr] f" Encountered unimplemented error, abort execution: {e}" ) raise e except (ValueError, TypeError) as e: - cls.logger.info(f" Configuration parameter skipping: {e}") + cls.logger.info(f" Configuration parameter skipping: {e}") # type: ignore[union-attr] raise e continue except Exception as e: - cls.logger.info(f" Execution error skipping: {e}") + cls.logger.info(f" Execution error skipping: {e}") # type: ignore[union-attr] raise e continue @@ -1144,12 +1442,12 @@ class autotune_jit: if best_kernel is None: raise ValueError("No best kernel found") - cls.logger.info( + cls.logger.info( # type: ignore[union-attr] f"Best configuration: {best_config}, execution time: {min_time} us" ) - cls.logger.info(f"Total tuning time: {tuning_time} s") - func._best_kernel[tuning_key] = best_kernel - func._best_config[tuning_key] = best_config + cls.logger.info(f"Total tuning time: {tuning_time} s") # type: ignore[union-attr] + func._best_kernel[tuning_key] = best_kernel # type: ignore[attr-defined] + func._best_config[tuning_key] = best_config # type: ignore[attr-defined] return best_kernel(*args, **kwargs) # Append autotune wrapper to not conflict with the jit kernel names @@ -1162,11 +1460,11 @@ class autotune_jit: def __init__( self, - params_dict: Dict[str, List[Any]] = None, - update_on_change: List[str] = None, - warmup_iterations=10, - iterations=100, - ): + params_dict: Optional[Dict[str, List[Any]]] = None, + update_on_change: Optional[List[str]] = None, + warmup_iterations: int = 10, + iterations: int = 100, + ) -> None: """Initialize the autotune_jit decorator. :param params_dict: Dictionary containing parameter names and their possible values @@ -1189,7 +1487,7 @@ class autotune_jit: self.warmup_iterations = warmup_iterations self.iterations = iterations - def __call__(self, func): + def __call__(self, func: Callable[..., Any]) -> Callable[..., Any]: """Called when class instance is used as a decorator. :param func: Function to be decorated @@ -1215,11 +1513,11 @@ class autotune_jit: def tune( - func: Callable[[Any], Callable[[], Any]], - params_dict: Dict[str, List[Any]] = None, + func: Callable[..., Callable[[], Any]], + params_dict: Optional[Dict[str, List[Any]]] = None, kernel_arguments: JitArguments = JitArguments(), - warmup_iterations=10, - iterations=100, + warmup_iterations: int = 10, + iterations: int = 100, stream: Optional[cuda_driver.CUstream] = None, ) -> Dict[str, Any]: """Tuning tool to suport arbitrary functions. The user must provide a function that returns a callable, which @@ -1266,6 +1564,9 @@ def tune( if stream is None: stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) + if params_dict is None: + raise ValueError("params_dict must be provided") + # Get all parameter configurations keys = list(params_dict.keys()) values = list(params_dict.values()) @@ -1327,12 +1628,102 @@ def tune( class CantImplementError(Exception): """Exception raised when a function is not implemented.""" - def __init__(self, message=None): + def __init__(self, message: Optional[str] = None) -> None: self.message = message or "The current config is invalid/unsupported" super().__init__(self.message) - def __str__(self): + def __str__(self) -> str: return self.message - def __repr__(self): + def __repr__(self) -> str: return self.message + + +######################################### +# Tensor initialization configuration +######################################### + + +@dataclass(frozen=True) +class TensorInitConfig: + """Configuration for tensor initialization policy. + + When init_normal=True, tensors are initialized from a normal distribution + with the specified mean and std. Int8/Uint8 dtypes always use random + integer initialization regardless of this flag. + """ + + init_normal: bool = False + normal_mean: float = 0.0 + normal_std: float = 1.0 + + +def add_tensor_init_args( + parser: argparse.ArgumentParser, + supports_int_dtypes: bool = True, +) -> None: + """Add --init_normal, --normal_mean, --normal_std arguments to a parser. + + :param parser: ArgumentParser to add arguments to. + :param supports_int_dtypes: If True, appends Int8/Uint8 caveat to --init_normal + help text. Set to False for files whose ab_dtype choices do not include + Int8/Uint8 (e.g. grouped_gemm, dense_blockscaled_gemm_persistent). + """ + init_normal_help = ( + "Use normal distribution for tensor initialization instead of random integers." + ) + if supports_int_dtypes: + init_normal_help += ( + " Note: Int8/Uint8 dtypes always use random init regardless of this flag" + ) + parser.add_argument( + "--init_normal", + action="store_true", + help=init_normal_help, + ) + parser.add_argument( + "--normal_mean", + type=float, + default=0.0, + help="Mean for normal distribution initialization", + ) + parser.add_argument( + "--normal_std", + type=float, + default=1.0, + help="Standard deviation for normal distribution initialization (must be >= 0)", + ) + + +def validate_tensor_init_args( + args: argparse.Namespace, + parser: argparse.ArgumentParser, +) -> None: + """Validate tensor init arguments after parse_args(). + + :param args: Parsed arguments namespace. + :param parser: Parser instance (used for error reporting). + """ + if args.normal_std < 0: + parser.error("--normal_std must be non-negative") + + +def tensor_init_config_from_args(args: argparse.Namespace) -> TensorInitConfig: + """Extract a TensorInitConfig from parsed arguments.""" + return TensorInitConfig( + init_normal=args.init_normal, + normal_mean=args.normal_mean, + normal_std=args.normal_std, + ) + + +def should_use_normal_init( + config: TensorInitConfig, + dtype: Type[Numeric], +) -> bool: + """Determine whether normal initialization should be used for the given dtype. + + Returns False if config.init_normal is False or if dtype is Int8/Uint8 + (which do not support normal distribution initialization). + """ + return config.init_normal and dtype not in (Int8, Uint8) diff --git a/python/CuTeDSL/cutlass/cute/tuple.py b/python/CuTeDSL/cutlass/cute/tuple.py index 66a334339..545e9f93b 100644 --- a/python/CuTeDSL/cutlass/cute/tuple.py +++ b/python/CuTeDSL/cutlass/cute/tuple.py @@ -11,16 +11,14 @@ from inspect import signature from itertools import chain -from typing import Any, Callable, Union, Tuple, List, Iterable +from typing import Any, Callable, Optional, Union, Tuple, List, Iterable + +from cutlass._mlir import ir from cutlass.cutlass_dsl import is_dynamic_expression, dsl_user_op -from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir from .typing import ( - ComposedLayout, - Layout, - Stride, XTuple, IntTuple, Shape, @@ -30,7 +28,7 @@ from .typing import ( ) -def wrap(x) -> Tuple[Any, ...]: +def wrap(x: XTuple) -> Tuple[Any, ...]: """ Wraps the input into a tuple if not a tuple. """ @@ -39,6 +37,23 @@ def wrap(x) -> Tuple[Any, ...]: return (x,) +def unwrap(x: XTuple) -> XTuple: + """ + Unwraps the input tuple if it is a single-element tuple, otherwise returns the input. + + Example: + >>> unwrap((1,)) + 1 + >>> unwrap(((1, 2, 3),)) + (1, 2, 3) + >>> unwrap((1, 2, 3)) + (1, 2, 3) + """ + while isinstance(x, tuple) and len(x) == 1: + x = x[0] + return x + + def flatten_to_tuple(a: XTuple) -> Tuple[Any, ...]: """Flattens a potentially nested tuple structure into a flat tuple. @@ -89,7 +104,7 @@ def unflatten( unflatten([1, 2, 3, 4], ((0, 0), (0, 0))) # Returns ((1, 2), (3, 4)) """ - def _make_generator(): + def _make_generator() -> Any: for element in sequence: yield element @@ -98,7 +113,12 @@ def unflatten( @dsl_user_op -def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): +def product( + a: Union[IntTuple, Shape], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> IntTuple: # Local import to avoid circular dependency from .core import _pack_int_tuple, _unpack_x_tuple @@ -129,7 +149,13 @@ def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): @dsl_user_op -def product_like(a: IntTuple, target_profile: XTuple, *, loc=None, ip=None) -> IntTuple: +def product_like( + a: IntTuple, + target_profile: XTuple, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> IntTuple: """Return product of the given IntTuple or Shape at leaves of `target_profile`. This function computes products according to the structure defined by target_profile. @@ -161,7 +187,12 @@ def product_like(a: IntTuple, target_profile: XTuple, *, loc=None, ip=None) -> I @dsl_user_op -def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: +def product_each( + a: IntTuple, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> IntTuple: from .core import _pack_int_tuple, _unpack_x_tuple """Compute products for each component of the input. @@ -193,11 +224,9 @@ def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: def find_if( - t: Union[tuple, ir.Value, int], - pred_fn: Callable[[Union[tuple, ir.Value, int], int], bool], - *, - loc=None, - ip=None, + t: XTuple, + pred_fn: Callable[[XTuple, int], bool], + hierarchical: bool = True, ) -> Union[int, Tuple[int, ...], None]: from .core import rank, get @@ -232,13 +261,13 @@ def find_if( find_if(stride, pred_fn=pred_fn) """ - def _find_if_impl(curr, pos, *, loc=None, ip=None): + def _find_if_impl(curr: Any, pos: Any) -> Any: if isinstance(curr, tuple): # Recursively search nested tuple for i in range(rank(curr)): - sub_curr = get(curr, mode=[i], loc=loc, ip=ip) + sub_curr = get(curr, mode=[i]) sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) - res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) + res_pos = _find_if_impl(sub_curr, sub_pos) if res_pos is not None: return res_pos else: @@ -247,29 +276,33 @@ def find_if( return pos return None - def _check_pred_fn(): - if not callable(pred_fn): - raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") + if not callable(pred_fn): + raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") - sig = signature(pred_fn) - if len(sig.parameters) != 2: - raise ValueError( - f"pred_fn must have two parameters (value, pos), but got {len(sig.parameters)}" - ) + sig = signature(pred_fn) - _check_pred_fn() + if len(sig.parameters) != 2: + raise ValueError( + f"pred_fn must have two parameters (value, pos), but got {len(sig.parameters)}" + ) for i in range(rank(t)): - curr = get(t, mode=[i], loc=loc, ip=ip) - res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) + curr = get(t, mode=[i]) + res_pos = _find_if_impl(curr, i) if res_pos is not None: - return res_pos + if hierarchical: + return res_pos + else: + return ( + res_pos + if not isinstance(res_pos, tuple) + else flatten_to_tuple(res_pos)[0] + ) return None -@dsl_user_op def find( - t: Union[tuple, ir.Value, int], x: int, *, loc=None, ip=None + t: XTuple, x: int, hierarchical: bool = True ) -> Union[int, Tuple[int, ...], None]: """Find the first position of a value ``x`` in a hierarchical structure ``t``. @@ -278,7 +311,7 @@ def find( and returns either a single index or a tuple of indices for nested positions. :param t: The search space - :type t: Union[tuple, ir.Value, int] + :type t: XTuple :param x: The static integer x to search for :type x: int :return: Index if found at top level, tuple of indices showing nested position, or None if not found @@ -287,14 +320,14 @@ def find( if not isinstance(x, int): raise TypeError(f"find() requires a static x to search for, but got {x}") - def pred_fn(val, pos): + def pred_fn(val: Any, pos: Any) -> bool: # Skip dynamic values which can't be compared return not is_dynamic_expression(val) and val == x - return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) + return find_if(t, pred_fn=pred_fn, hierarchical=hierarchical) -def transform_leaf(f, *args): +def transform_leaf(f: Callable[..., XTuple], *args: XTuple) -> XTuple: """ Apply a function to the leaf nodes of nested tuple structures. @@ -331,8 +364,8 @@ def elem_less( lhs: Union[Shape, IntTuple, Coord], rhs: Union[Shape, IntTuple, Coord], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Boolean: from .core import _pack_coord @@ -342,7 +375,7 @@ def elem_less( return Boolean(_cute_ir.elem_less(lhs_val, rhs_val, loc=loc, ip=ip)) -def tuple_cat(*tuples): +def tuple_cat(*tuples: XTuple) -> Tuple[Any, ...]: """Concatenate multiple tuples into a single tuple. This function takes any number of tuples and concatenates them into a single tuple. @@ -364,7 +397,7 @@ def tuple_cat(*tuples): >>> tuple_cat(1, (2, 3)) (1, 2, 3) """ - result = () + result: Tuple[Any, ...] = () for t in tuples: if isinstance(t, tuple): result += t @@ -373,7 +406,9 @@ def tuple_cat(*tuples): return result -def transform_apply(*args, f: Callable, g: Callable): +def transform_apply( + *args: XTuple, f: Callable[..., XTuple], g: Callable[..., XTuple] +) -> XTuple: """Transform elements of tuple(s) with f, then apply g to all results. This function applies f to corresponding elements across input tuple(s), @@ -409,31 +444,40 @@ def transform_apply(*args, f: Callable, g: Callable): >>> transform_apply((1, 2), (3, 4), f=lambda x, y: x + y, g=lambda *args: args) (4, 6) """ - if not isinstance(f, Callable): + if not callable(f): raise TypeError(f"f must be callable, but got {type(f)}") - if not isinstance(g, Callable): + if not callable(g): raise TypeError(f"g must be callable, but got {type(g)}") if not args: raise ValueError("transform_apply requires at least one argument") + def _compatible_xtuples(args: XTuple) -> bool: + if isinstance(args[0], tuple): + if not all(isinstance(arg, tuple) for arg in args): + return False + tuple_length = len(args[0]) + for i, arg in enumerate(args, 1): + if len(arg) != tuple_length: + return False + for i in range(tuple_length): + if not _compatible_xtuples(tuple([arg[i] for arg in args])): + return False + return True + else: + return all(not isinstance(arg, tuple) for arg in args) + + if not _compatible_xtuples(args): + raise ValueError("All arguments must be congruent") + # Check if first argument is a tuple to determine behavior if isinstance(args[0], tuple): - # Verify all args are tuples of the same length - if not all(isinstance(arg, tuple) for arg in args): - raise TypeError("All arguments must be tuples or all must be non-tuples") - - tuple_length = len(args[0]) - for i, arg in enumerate(args[1:], 1): - if len(arg) != tuple_length: - raise ValueError( - f"All tuple arguments must have the same length. " - f"arg[0] has length {tuple_length}, but arg[{i}] has length {len(arg)}" - ) + flat_args = [flatten_to_tuple(arg) for arg in args] + tuple_length = len(flat_args[0]) # Apply f to corresponding elements across all tuples: g(f(args[0][i], args[1][i], ...), ...) transformed_results = tuple( - f(*(arg[i] for arg in args)) for i in range(tuple_length) + f(*(arg[i] for arg in flat_args)) for i in range(tuple_length) ) return g(*transformed_results) else: @@ -442,7 +486,7 @@ def transform_apply(*args, f: Callable, g: Callable): return g(result) -def filter_tuple(*args, f: Callable): +def filter_tuple(*args: XTuple, f: Callable[..., Tuple[Any, ...]]) -> Tuple[Any, ...]: """Filter and flatten tuple elements by applying a function. The function f should return tuples, which are then concatenated together @@ -471,7 +515,7 @@ def filter_tuple(*args, f: Callable): >>> filter_tuple((1, 2, 3), lambda x: (x, x)) (1, 1, 2, 2, 3, 3) """ - if not isinstance(f, Callable): + if not callable(f): raise TypeError(f"f must be callable, but got {type(f)}") return transform_apply(*args, f=f, g=lambda *args: tuple_cat(*args)) @@ -490,4 +534,6 @@ __all__ = [ "tuple_cat", "transform_apply", "filter_tuple", + "unwrap", + "wrap", ] diff --git a/python/CuTeDSL/cutlass/cute/typing.py b/python/CuTeDSL/cutlass/cute/typing.py index 77e171598..59c366146 100644 --- a/python/CuTeDSL/cutlass/cute/typing.py +++ b/python/CuTeDSL/cutlass/cute/typing.py @@ -13,22 +13,66 @@ from __future__ import annotations from abc import ABC, abstractmethod import ctypes -from typing import ForwardRef, Tuple, Union, Any, Type, List, Optional, Literal +from typing import ( + ForwardRef, + Tuple, + Union, + Any, + Type, + List, + Optional, + Literal, + TYPE_CHECKING, +) -from cutlass.base_dsl.typing import * +from cutlass.cutlass_dsl import T +from cutlass.base_dsl.typing import ( + Numeric, + NumericMeta, + Integer, + Boolean, + Int4, + Int8, + Int16, + Int32, + Int64, + Int128, + Uint8, + Uint16, + Uint32, + Uint64, + Float, + Float16, + BFloat16, + TFloat32, + Float32, + Float64, + Float8E5M2, + Float8E4M3FN, + Float8E4M3B11FNUZ, + Float8E4M3, + Float8E8M0FNU, + Float4E2M1FN, + Float6E2M3FN, + Float6E3M2FN, + as_numeric, +) from cutlass._mlir import ir +import cutlass._mlir.dialects.cute as _cute_ir from cutlass._mlir.dialects.cute import AddressSpace, ConstrainedIntType -from cutlass.base_dsl.typing import JitArgument - Int = Union[int, Integer] class SymInt: def __init__( - self, width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None - ): + self, + width: Literal[32, 64] = 32, + *, + divisibility: int = 1, + symbol: str | None = None, + ) -> None: if width not in [32, 64]: raise ValueError(f"Unsupported width: {width}") @@ -36,19 +80,19 @@ class SymInt: self._divisibility = divisibility self._symbol = symbol - def __hash__(self): + def __hash__(self) -> int: return hash((self._width, self._divisibility, self._symbol)) @property - def width(self): + def width(self) -> int: return self._width @property - def divisibility(self): + def divisibility(self) -> int: return self._divisibility @property - def symbol(self): + def symbol(self) -> str | None: return self._symbol def __str__(self) -> str: @@ -61,7 +105,7 @@ class SymInt: def __repr__(self) -> str: return self.__str__() - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: if not isinstance(other, SymInt): return False @@ -111,7 +155,7 @@ class SymInt: def __rmul__(self, other: int | SymInt) -> SymInt: return self.__mul__(other) - def __c_pointers__(self): + def __c_pointers__(self) -> List[int | None]: return [ctypes.c_void_p(0).value] def __get_mlir_types__(self) -> List[ir.Type]: @@ -120,32 +164,36 @@ class SymInt: ) return [res_ty] - def __new_from_mlir_values__(self, values) -> SymInt: + def __new_from_mlir_values__(self, values: List[ir.Value]) -> SymInt: from .core import IntValue if self.width == 32: - return Int32(IntValue(values[0])) + return Int32(IntValue(values[0])) # type: ignore[return-value] elif self.width == 64: - return Int64(IntValue(values[0])) + return Int64(IntValue(values[0])) # type: ignore[return-value] else: assert False, f"Unsupported width: {self.width}" return self def sym_int( - width: Literal[32, 64] = 32, *, divisibility=1, symbol: str | None = None + width: Literal[32, 64] = 32, *, divisibility: int = 1, symbol: str | None = None ) -> SymInt: return SymInt(width, divisibility=divisibility, symbol=symbol) -def sym_int32(divisibility=1, symbol: str | None = None) -> SymInt: +def sym_int32(divisibility: int = 1, symbol: str | None = None) -> SymInt: return sym_int(32, divisibility=divisibility, symbol=symbol) -def sym_int64(divisibility=1, symbol: str | None = None) -> SymInt: +def sym_int64(divisibility: int = 1, symbol: str | None = None) -> SymInt: return sym_int(64, divisibility=divisibility, symbol=symbol) -ScaledBasis = ForwardRef("ScaledBasis") +if TYPE_CHECKING: + from cutlass.cute.core import ScaledBasis, Swizzle + from cutlass.cute.tensor import TensorSSA +else: + ScaledBasis = ForwardRef("ScaledBasis") IntTuple = Union[Int, Tuple["IntTuple", ...]] Shape = Union[Int, Tuple["Shape", ...]] @@ -154,20 +202,35 @@ Coord = Union[Int, None, Tuple["Coord", ...]] class Layout(ir.Value): - def __init__(self, op_result): + def __init__(self, op_result: ir.Value) -> None: super().__init__(op_result) - def __str__(self) -> str: ... + def __str__(self) -> str: + return super().__str__() # pragma: no cover - def get_hier_coord(self, idx) -> Coord: + def get_hier_coord(self, idx: Int) -> Coord: """Return the (hierarchical) ND logical coordinate corresponding to the linear index""" ... @property - def shape(self, *, loc=None, ip=None) -> Shape: ... + def shape( # type: ignore[empty-body] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: + """Implemented by subclasses.""" + ... @property - def stride(self, *, loc=None, ip=None) -> Stride: ... + def stride( # type: ignore[empty-body] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Stride: + """Implemented by subclasses.""" + ... class ComposedLayout(ABC): @@ -235,22 +298,47 @@ class ComposedLayout(ABC): @property @abstractmethod - def inner(self, *, loc=None, ip=None): ... + def inner( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Union[Layout, "Swizzle"]: ... @property @abstractmethod - def offset(self, *, loc=None, ip=None) -> IntTuple: ... + def offset( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> IntTuple: ... @property @abstractmethod - def outer(self, *, loc=None, ip=None) -> Layout: ... + def outer( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Layout: ... @property @abstractmethod - def shape(self, *, loc=None, ip=None): ... + def shape( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Shape: ... @abstractmethod - def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: ... + def __call__( + self, + coord: Coord, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> IntTuple: ... Tile = Union[Int, None, Layout, Tuple["Tile", ...]] @@ -266,22 +354,180 @@ class Pointer(ABC): Abstract base class for CuTe jit function and runtime _Pointer """ + value: ir.Value + + @property + def type(self) -> ir.Type: + """The MLIR type of this pointer. Implemented by subclasses.""" + ... + @property def value_type(self) -> Type[Numeric]: return self.dtype @property - def dtype(self) -> Type[Numeric]: ... + def dtype(self) -> Type[Numeric]: # type: ignore[empty-body] + """Implemented by subclasses.""" + ... - def align(self, min_align: int) -> "Pointer": ... + @property + def memspace(self) -> AddressSpace: + """The memory address space of this pointer. Implemented by subclasses.""" + ... - def __add__(self, other: int, *, loc=None, ip=None) -> "Pointer": ... + @property + def max_alignment(self) -> int: # type: ignore[empty-body] + """Maximum alignment of this pointer in bytes. Implemented by subclasses.""" + ... - def __get_mlir_types__(self) -> List[ir.Type]: ... + @property + def llvm_ptr( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> ir.Value: + """Get the LLVM pointer representation. Implemented by subclasses.""" + ... - def __extract_mlir_values__(self) -> List[ir.Value]: ... + def toint( # type: ignore[empty-body] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Numeric: + """Convert pointer to integer. Implemented by subclasses.""" + ... - def __new_from_mlir_values__(self, values) -> "Pointer": ... + def align(self, min_align: int) -> "Pointer": # type: ignore[empty-body] + """Implemented by subclasses.""" + ... + + def __add__( # type: ignore[empty-body] + self, + other: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "Pointer": + """Implemented by subclasses.""" + ... + + def __get_mlir_types__(self) -> List[ir.Type]: # type: ignore[empty-body] + """Implemented by subclasses.""" + ... + + def __extract_mlir_values__(self) -> List[ir.Value]: # type: ignore[empty-body] + """Implemented by subclasses.""" + ... + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "Pointer": # type: ignore[empty-body] + """Implemented by subclasses.""" + ... + + +class TypedTensor: + r"""A compile-time type descriptor for a statically-typed CuTe tensor. + + ``TypedTensor`` captures the element type, shape, stride, memory space, and + pointer alignment of a tensor at compile time. + + The preferred way to construct a ``TypedTensor`` is via the ``Tensor`` + subscript syntax: + + .. code-block:: python + + # equivalent to TypedTensor(cutlass.Float32, (16,), (1,)) + ty = Tensor[cutlass.Float32, (16,), (1,)] + + @cute.jit + def kernel(a: Tensor[cutlass.Float32, (16,), (1,)]): + ... + + :param dtype: Element type of the tensor + :param shape: Static shape of the tensor + :param stride: Static stride of the tensor + :param memspace: Memory space of the underlying pointer. Defaults to + ``AddressSpace.generic``. + :param assumed_align: Assumed byte alignment of the pointer. + + **Runtime type checking** + + Use :meth:`isinstance` to check whether a live ``Tensor`` value matches + this descriptor: + + .. code-block:: python + + tt = TypedTensor(cutlass.Float32, (16,), (1,), AddressSpace.gmem) + + @cute.jit + def kernel(a: cute.Tensor): + if tt.isinstance(a): + ... + """ + + def __init__( + self, + dtype: Type[Numeric], + shape: Shape, + stride: Stride, + memspace: AddressSpace = AddressSpace.generic, + assumed_align: int | None = None, + ): + self._dtype = dtype + self._shape = shape + self._stride = stride + self._memspace = memspace + self._assumed_align = assumed_align + if assumed_align is None: + # use the bytes width of the element dtype. The alignment is at least one byte align. + self._assumed_align = (self._dtype.width + 7) // 8 + + @property + def element_type(self) -> Type[Numeric]: + return self._dtype + + @property + def shape(self) -> Shape: + return self._shape + + @property + def stride(self) -> Stride: + return self._stride + + @property + def memspace(self) -> AddressSpace: + return self._memspace + + @property + def assumed_align(self) -> int | None: + return self._assumed_align + + def isinstance(self, other: object) -> bool: + if not isinstance(other, Tensor): + return False # pragma: no cover + mlir_type = other.__extract_mlir_values__()[0].type # type: ignore[attr-defined] + return mlir_type == self.mlir_type + + @property + def mlir_type(self) -> ir.Type: + shape_ty = _cute_ir.ShapeType.get_from_x_tuple(ir.Context.current, self._shape) + stride_ty = _cute_ir.StrideType.get_from_x_tuple( + ir.Context.current, self._stride + ) + layout_ty = _cute_ir.LayoutType.get(shape_ty, stride_ty) + + # Boolean types are stored as i8 in memory + elem_type = T.i8() if self._dtype.width == 1 else self._dtype.mlir_type + ptr_ty = _cute_ir.PtrType.get(elem_type, self._memspace, self._assumed_align) + + return _cute_ir.MemRefType.get(ptr_ty, layout_ty) + + def __get_mlir_types__(self) -> List[ir.Type]: + return [self.mlir_type] + + def __str__(self) -> str: + return f"Tensor<{self._dtype}, {self._shape}, {self._stride}>" class Tensor(ABC): @@ -346,21 +592,35 @@ class Tensor(ABC): print(c) # tensor([3, 7, 11], dtype=torch.int32) """ + value: ir.Value + + def __class_getitem__(cls, args: tuple) -> TypedTensor: + return TypedTensor(*args) + @abstractmethod def __str__(self) -> str: ... @abstractmethod - def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ... + def __getitem__( + self, idx: Union[Int, slice, Coord, Tuple] + ) -> Union["Tensor", ir.Value, IntTuple]: ... @abstractmethod - def __setitem__(self, idx, value): ... + def __setitem__( + self, idx: Union[Int, slice, Coord, Tuple], value: Union[Numeric, ir.Value] + ) -> None: ... @property @abstractmethod def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ... @element_type.setter - def element_type(self, new_type): ... + def element_type(self, new_type: Union[Type[Numeric], Type[IntTuple]]) -> None: ... + + @property + def dtype(self) -> Type[Numeric]: # type: ignore[empty-body] + """The element data type. Implemented by subclasses.""" + ... @property @abstractmethod @@ -371,32 +631,71 @@ class Tensor(ABC): def iterator(self) -> Union[Pointer, IntTuple]: ... @property - def layout(self) -> Union[Layout, "ComposedLayout"]: ... + @abstractmethod + def leading_dim(self) -> Union[int, Tuple[int, ...], None]: + """Get the leading dimension of this Tensor + (first mode from left to right with stride==1, shape!=1) + + :return: The leading dimension index or indices + :rtype: int or tuple or None + + The return value depends on the tensor's stride pattern: + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None + """ + ... @property + def layout(self) -> Union[Layout, "ComposedLayout"]: # type: ignore[empty-body] + """Implemented by subclasses.""" + ... + + @property + @abstractmethod def shape(self) -> Shape: ... @property + @abstractmethod def stride(self) -> Stride: ... - def load(self, *, loc=None, ip=None) -> "TensorSSA": ... + def load( # type: ignore[empty-body] + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "TensorSSA": + """Implemented by subclasses.""" + ... - def store(self, data: "TensorSSA", *, loc=None, ip=None): ... + def store( + self, + data: "TensorSSA", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: ... - def mark_layout_dynamic(self, leading_dim: Optional[int] = None) -> "Tensor": ... + def mark_layout_dynamic( # type: ignore[empty-body] + self, leading_dim: Optional[int] = None + ) -> "Tensor": + """Implemented by subclasses.""" + ... - def mark_compact_shape_dynamic( + def mark_compact_shape_dynamic( # type: ignore[empty-body] self, mode: int, stride_order: Optional[tuple[int, ...]] = None, divisibility: int = 1, - ) -> "Tensor": ... + ) -> "Tensor": + """Implemented by subclasses.""" + ... @abstractmethod def fill(self, value: Numeric) -> None: ... -def is_integer(a) -> bool: +def is_integer(a: object) -> bool: """Check if an object is static integer or dynamic integer""" return isinstance(a, (int, Integer)) or ( isinstance(a, ir.Value) @@ -404,7 +703,7 @@ def is_integer(a) -> bool: ) -def is_int_tuple(a) -> bool: +def is_int_tuple(a: object) -> bool: if isinstance(a, tuple): return all([is_int_tuple(x) for x in a]) else: @@ -417,6 +716,7 @@ __all__ = [ "sym_int32", "sym_int64", "Numeric", + "NumericMeta", "Integer", "Boolean", "Int4", @@ -424,6 +724,7 @@ __all__ = [ "Int16", "Int32", "Int64", + "Int128", "Uint8", "Uint16", "Uint32", @@ -451,9 +752,11 @@ __all__ = [ "ComposedLayout", "Pointer", "Tensor", + "TypedTensor", "Tile", "Tiler", "XTuple", + "as_numeric", "is_integer", "is_int_tuple", ] diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py b/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py index 6bc85cd4b..5d23e8971 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/__init__.py @@ -25,8 +25,6 @@ from ..base_dsl.ast_helpers import ( assert_executor, bool_cast, compare_executor, - any_executor, - all_executor, range_value_check, cf_symbol_check, redirect_builtin_function, @@ -38,6 +36,7 @@ from ..base_dsl.ast_helpers import ( ) from ..base_dsl import * +from ..base_dsl.arch import Arch from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values from ..base_dsl.typing import _binary_op_type_promote from ..base_dsl._mlir_helpers.gpu import * @@ -53,9 +52,11 @@ from ..base_dsl.compiler import ( KeepCUBIN, KeepPTX, GPUArch, + LinkLibraries, EnableTVMFFI, ) from ..base_dsl.runtime.jit_arg_adapters import * +from ..base_dsl.native_struct import make_native_struct, native_struct from ..base_dsl.utils.logger import _init_logger_with_client_name diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py index a854ec4e0..c94ef11c7 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_jit_executor.py @@ -15,8 +15,10 @@ This module provides jit executor related classes for CUTLASS. import ctypes import functools +import inspect import weakref import threading +from typing import Any, List, Optional, Tuple, Union import cuda.bindings.runtime as cuda_runtime import cuda.bindings.driver as cuda_driver @@ -33,35 +35,39 @@ from ..base_dsl.common import DSLRuntimeError from ..base_dsl.typing import Int32 from ..base_dsl.runtime.cuda import checkCudaErrors +from .._mlir import ir, execution_engine + class CudaDialectJitModule: """Holds the execution engine and cuda libraries.""" def __init__( self, - engine, - capi_func, - args_spec: ExecutionArgs, + engine: execution_engine.ExecutionEngine, + capi_func: Any, + execution_args: ExecutionArgs, cuda_library: list["cuda_runtime.cudaLibrary_t"], - ): + ) -> None: self.engine = engine self.capi_func = capi_func - self.args_spec = args_spec + self.execution_args = execution_args self.cuda_library = cuda_library self._unloaded = False - def is_unloaded(self): + def is_unloaded(self) -> bool: return self._unloaded - def unload(self): + def unload(self) -> None: try: for library in self.cuda_library: cuda_runtime.cudaLibraryUnload(library) self.cuda_library.clear() + except Exception as e: + pass finally: self._unloaded = True - def __del__(self): + def __del__(self) -> None: self.unload() @@ -70,24 +76,25 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): def __init__( self, - ir_module, - engine, - capi_func, - args_spec, - function_name, - kernel_info, - jit_time_profiling, - jit_function_artifacts, - prefix=None, - load_from_binary=False, - dynamic_args=None, - dynamic_kwargs=None, - ): + ir_module: ir.Module, + engine: Optional[execution_engine.ExecutionEngine], + capi_func: Any, + signature: Optional[inspect.Signature], + function_name: str, + kernel_info: Optional[dict], + jit_time_profiling: bool, + jit_function_artifacts: Optional[JitFunctionArtifacts], + prefix: Optional[str] = None, + load_from_binary: bool = False, + dynamic_args: tuple[Any] = tuple[Any](), + dynamic_kwargs: dict[str, Any] = dict[str, Any](), + has_gpu_module: bool = True, + ) -> None: super().__init__( ir_module, engine, capi_func, - args_spec, + signature, function_name, kernel_info, jit_time_profiling, @@ -96,21 +103,29 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): load_from_binary, dynamic_args, dynamic_kwargs, + has_gpu_module, ) + # Populated from module attributes by CuteExperimentalDSL.compile_and_cache; + # defaults match pre-pass state and non-experimental CUDA JIT functions. + self.kernel_extra_args: dict[str, int] = {} + self.total_added_arguments: int = 0 + # Set cuda result return type. # When execution engine/capi function is None, do not set the return type. if self.capi_func: self.capi_func.restype = ctypes.c_int32 - if self.args_spec: - self.args_spec.args_spec.annotations["return"] = Int32 + if self.execution_args: + self.execution_args.signature = self.execution_args.signature.replace( + return_annotation=Int32 + ) @functools.cached_property - def num_devices(self): + def num_devices(self) -> int: """Returns the number of CUDA devices available.""" return checkCudaErrors(cuda_runtime.cudaGetDeviceCount()) - def _deserializer(self): + def _deserializer(self) -> List["cuda_runtime.cudaLibrary_t"]: """Load the cuda library from the binary execution engine. @return: The list of cuda kernels. """ @@ -137,7 +152,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): cuda_init_args = [pointer_to_pointer_to_library, pointer_to_err] packed_args = (ctypes.c_void_p * len(cuda_init_args))() for i in range(len(cuda_init_args)): - packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p) + packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p) # type: ignore[arg-type] cuda_init(packed_args) checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) @@ -145,14 +160,14 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): cuda_load_args = [pointer_to_library, pointer_to_err] packed_args = (ctypes.c_void_p * len(cuda_load_args))() for i in range(len(cuda_load_args)): - packed_args[i] = ctypes.cast(cuda_load_args[i], ctypes.c_void_p) + packed_args[i] = ctypes.cast(cuda_load_args[i], ctypes.c_void_p) # type: ignore[arg-type] cuda_load(packed_args) checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) return [cuda_runtime.cudaLibrary_t(library.value)] - def _get_cuda_init_and_load(self): + def _get_cuda_init_and_load(self) -> Tuple[Any, Any]: """Returns the cuda init and load functions from the engine.""" # cuda init takes in a pointer to a cudaLibrary_t and returns # a i32 cudaError_t. It initialized (lazy loads) our cudaLibrary_t @@ -194,7 +209,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): return cuda_init, cuda_load_to_device - def _load_cuda_library(self): + def _load_cuda_library(self) -> List["cuda_runtime.cudaLibrary_t"]: """Loads the CUDA library from the engine.""" cuda_init, cuda_load_to_device = self._get_cuda_init_and_load() @@ -208,7 +223,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): cuda_init_args = [pointer_to_pointer_to_library, pointer_to_err] packed_args = (ctypes.c_void_p * len(cuda_init_args))() for i in range(len(cuda_init_args)): - packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p) + packed_args[i] = ctypes.cast(cuda_init_args[i], ctypes.c_void_p) # type: ignore[arg-type] cuda_init(packed_args) checkCudaErrors((cuda_runtime.cudaError_t(err.value),)) @@ -223,7 +238,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): ] packed_args = (ctypes.c_void_p * len(cuda_load_args))() for i, arg in enumerate(cuda_load_args): - packed_args[i] = ctypes.cast(arg, ctypes.c_void_p) + packed_args[i] = ctypes.cast(arg, ctypes.c_void_p) # type: ignore[arg-type] for dev in range(self.num_devices): device_id.value = dev @@ -234,7 +249,7 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): return [cuda_runtime.cudaLibrary_t(library.value)] - def to(self, device=None) -> JitExecutor: + def to(self, device: Optional[int] = None) -> JitExecutor: """Returns an executable function bound to the given device. For multi-device execution this method can be called for each device where @@ -252,12 +267,15 @@ class CudaDialectJitCompiledFunction(JitCompiledFunction): super()._validate_engine() with self._executor_lock: # We need to ensure that the modules are loaded if not already - if self.jit_module is None or self.jit_module.is_unloaded(): - cuda_library = self._load_cuda_library() - self.jit_module = CudaDialectJitModule( + if self.jit_module is None or ( + isinstance(self.jit_module, CudaDialectJitModule) + and self.jit_module.is_unloaded() + ): + cuda_library = self._load_cuda_library() if self.has_gpu_module else [] + self.jit_module = CudaDialectJitModule( # type: ignore[assignment] self.engine, self.capi_func, - self.args_spec, + self.execution_args, cuda_library, ) diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py index c4d41b6c5..d3a37068c 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cuda_stream_adapter.py @@ -13,6 +13,9 @@ This module provides CUDA Python helper functions """ +import ctypes +from typing import List, Tuple + import cuda.bindings.driver as cuda_driver # MLIR imports @@ -29,20 +32,20 @@ class CudaDialectStreamAdapter: Convert a CUDA stream to a stream representation for JIT arg generation. """ - def __init__(self, arg): + def __init__(self, arg: "cuda_driver.CUstream") -> None: self._arg = arg self._c_pointer = self._arg.getPtr() - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: List[ir.Value]) -> ir.Value: assert len(values) == 1 return values[0] - def __c_pointers__(self): + def __c_pointers__(self) -> List[ctypes.c_void_p]: return [self._c_pointer] - def __get_mlir_types__(self): + def __get_mlir_types__(self) -> List[ir.Type]: return [cuda.StreamType.get()] - def __cuda_stream__(self): + def __cuda_stream__(self) -> Tuple[int, int]: # support cuda stream protocol return (0, int(self._arg)) diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py index b62e3a8f1..6bdfc0352 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass.py @@ -19,6 +19,8 @@ from types import GenericAlias, SimpleNamespace, UnionType from typing_extensions import deprecated from typing import ( Callable, + Generator, + Optional, Union, List, Tuple, @@ -30,9 +32,10 @@ from typing import ( get_args, ) import functools +import inspect import pkgutil from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import is_dataclass, fields +from dataclasses import fields from math import ceil from itertools import chain from pathlib import Path @@ -40,14 +43,46 @@ import builtins import ctypes import hashlib import os +import re from ..base_dsl import * from ..base_dsl import compiler -from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values +from ..base_dsl.dsl import ( + is_dynamic_expression, + extract_mlir_values, + BaseDSL, + new_from_mlir_values, + implements_dynamic_expression, +) from ..base_dsl.typing import * -from ..base_dsl.typing import DynamicExpression, get_mlir_types -from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr -from ..base_dsl.jit_executor import ExecutionArgs +from ..base_dsl.typing import ( + DynamicExpression, + get_mlir_types, + Int32, + Int64, + Int8, + Integer, + Boolean, + Numeric, + NumericMeta, + DslType, + as_numeric, + get_c_pointers, + cast, +) +from ..base_dsl.common import DSLRuntimeError, DSLNotImplemented +from ..base_dsl.utils.logger import log +from ..base_dsl.utils.tree_utils import ( + Leaf, + PyTreeDef, + tree_flatten, + tree_unflatten, + DSLTreeFlattenError, + is_constexpr_field, +) +from ..base_dsl.leaf_utils import is_frozen_dataclass +from ..base_dsl.runtime.jit_arg_adapters import is_arg_annotation_constexpr +from ..base_dsl.jit_executor import ExecutionArgs # noqa: F401 from ..base_dsl.runtime import cuda as cuda_helpers from .cuda_stream_adapter import CudaDialectStreamAdapter from .cuda_jit_executor import CudaDialectJitCompiledFunction @@ -67,7 +102,10 @@ from cutlass._mlir.dialects._ods_common import ( get_op_result_or_op_results as _get_op_result_or_op_results, ) -from cutlass._mlir.dialects import lir as cutlass_lir +try: + from cutlass._mlir.dialects import lir as cutlass_lir +except ImportError: + cutlass_lir = None from cutlass._mlir.extras import types as T @@ -89,8 +127,6 @@ from ..base_dsl.ast_helpers import ( dynamic_expr, bool_cast, compare_executor, - any_executor, - all_executor, range_value_check, cf_symbol_check, ) @@ -100,6 +136,7 @@ from .cutlass_ast_decorators import ( _if_execute_dynamic, _while_execute_dynamic, _ifexp_execute_dynamic, + LoopUnroll, ) from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry @@ -128,28 +165,35 @@ SMEM_CAPACITY_MAP = { # ============================================================================= +def _get_max_cpu_threads() -> int: + """Return a safe thread-pool size: half of CPU count, clamped to [1, 16].""" + return max(1, min(16, (os.cpu_count() or 8) // 2)) + + # Return a ctype class that represents the in-memory layout expected # for a CuTe hierarchical tuple type. -def get_sparse_tuple_ctype(dyn): +def get_sparse_tuple_ctype(dyn: Union[int, Sequence[object]]) -> type: # When there is a single dynamic value, the sparse CuTe # representation is a single integer. if isinstance(dyn, int): return ctypes.c_int32 + dyn_seq: Sequence[object] = dyn + # For zero or greater than 1 dynamic values, the tuple # representation will be a struct with a field for each dynamic # value. The representation is flattened, even for hierarchical CuTe # profiles (although we are only dealing with depth 1 inputs here). class TupleDescriptor(ctypes.Structure): - _fields_ = [(f"x{idx}", ctypes.c_int32) for idx in range(len(dyn))] + _fields_ = [(f"x{idx}", ctypes.c_int32) for idx in range(len(dyn_seq))] - def __str__(self): + def __str__(self) -> str: return f"struct<{str(self._fields_)}>" return TupleDescriptor -def is_cute_algebra_type(arg_spec): +def is_cute_algebra_type(arg_spec: object) -> bool: # Walk through the arg_spec to check if it's a cute algebra type _cute_algebra_type_aliases = ( "Shape", @@ -174,9 +218,10 @@ def is_cute_algebra_type(arg_spec): return False -def _build_kernel_attrs(config) -> dict: +def _build_kernel_attrs(config: BaseDSL.LaunchConfig) -> dict: kernel_attrs = {} if config.min_blocks_per_mp > 1: + assert config.smem is not None kernel_attrs = { cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: ceil( config.min_blocks_per_mp @@ -190,53 +235,16 @@ def _build_kernel_attrs(config) -> dict: return kernel_attrs -def _get_c_pointers_cutlass(obj): - """ - This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict. - """ - if hasattr(obj, "__c_pointers__"): - return obj.__c_pointers__() - elif isinstance(obj, (tuple, list)): - return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj)) - elif isinstance(obj, SimpleNamespace): - return list( - chain.from_iterable( - _get_c_pointers_cutlass(x) for x in obj.__dict__.values() - ) - ) - elif isinstance(obj, dict): - return list( - chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values()) - ) - elif is_dataclass(obj): - return list( - chain.from_iterable( - _get_c_pointers_cutlass(getattr(obj, f.name)) - for f in fields(obj) - if not is_constexpr_field(f) - ) - ) - elif isinstance(obj, set): - raise DSLRuntimeError( - "Sets are not supported in get_c_pointers to ensure order preservation", - context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", - suggestion="Consider using a list or tuple instead", - ) - else: - # Try get adapter - adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj)) - if adapter is not None: - return _get_c_pointers_cutlass(adapter(obj)) - return [] - - class CutlassBaseDSL(BaseDSL): """This abstract class provides a DSL for Cutlass.""" + _ALLOWED_EXTRA_KERNEL_VALUE_ATTRS: frozenset[str] = frozenset() + _KERNEL_ATTR_SPEC_FIELD: Optional[str] = None + def __init__( self, name: str, - compiler_provider: Any, + compiler_provider: compiler.Compiler, pass_sm_arch_name: str, device_compilation_only: bool = False, preprocess: bool = False, @@ -249,23 +257,102 @@ class CutlassBaseDSL(BaseDSL): device_compilation_only=device_compilation_only, preprocess=preprocess, ) - self._smem_usage_tracker: tuple = None + self._smem_usage_tracker: Optional[tuple] = None # extra function to convert cute arguments to tvm ffi spec params # this needs to be reverse registered because the arg convention # depends on the runtime type of the DSL arguments self._tvm_ffi_args_spec_converter = None + def _set_smem_tracking( + self, allocator: object, callback: Callable[[object], int] + ) -> None: + self._smem_usage_tracker = (allocator, callback) + + def _reset_smem_tracking(self) -> None: + self._smem_usage_tracker = None + + def _get_smem_usage(self) -> int: + if not self._smem_usage_tracker: + return 0 + allocator, callback = self._smem_usage_tracker + return callback(allocator) + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. - def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: + def _is_tensor_descriptor(self, maybe_tensor_descriptor: object) -> bool: return False # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. def _handle_tensor_descriptor( - self, maybe_tensor, arg_name: str, need_gpu_memory: bool - ) -> Any: + self, maybe_tensor: object, arg_name: str, need_gpu_memory: bool + ) -> bool: return False - def _build_gpu_module(self, attrs, loc=None): + def _collect_raw_kernel_attrs_from_decorator( + self, func_body: Callable[..., None], func_args: tuple + ) -> dict: + field = self._KERNEL_ATTR_SPEC_FIELD + if field is None: + return {} + attr_spec = getattr(func_body, field, None) + if attr_spec is None: + return {} + + if isinstance(attr_spec, dict): + return attr_spec + if callable(attr_spec): + # Resolver signature: (owner, kernel_name) -> dict | None + # *owner* is the first argument passed to the kernel (typically + # ``self``), or ``None`` for free-function kernels. + owner = func_args[0] if func_args else None + resolved = attr_spec(owner, func_body.__name__) + if resolved is None: + return {} + if not isinstance(resolved, dict): + raise DSLRuntimeError( + "Kernel attribute resolver must return a dict or None.", + suggestion="Return a dict[str, str | ir.Attribute], or None for no attributes.", + ) + return resolved + + raise DSLRuntimeError( + f"Unsupported kernel decorator attributes spec type: {type(attr_spec)}", + suggestion="Use a dict or a callable returning dict[str, str | ir.Attribute].", + ) + + def _collect_extra_kernel_value_attrs( + self, func_body: Callable[..., None], func_args: tuple, func_kwargs: dict + ) -> dict[str, ir.Attribute]: + del func_kwargs + raw_attrs = self._collect_raw_kernel_attrs_from_decorator(func_body, func_args) + if not raw_attrs: + return {} + + converted_attrs: dict[str, ir.Attribute] = {} + for key, value in raw_attrs.items(): + if key not in self._ALLOWED_EXTRA_KERNEL_VALUE_ATTRS: + allowed_keys = ", ".join(sorted(self._ALLOWED_EXTRA_KERNEL_VALUE_ATTRS)) + if allowed_keys: + suggestion = f"Use one of the allowed keys: {allowed_keys}." + else: + suggestion = f"No extra kernel function attributes are supported for '{self.name}'." + raise DSLRuntimeError( + f"Unsupported kernel function attribute key '{key}'.", + suggestion=suggestion, + ) + if isinstance(value, ir.Attribute): + converted_attrs[key] = value + elif isinstance(value, str): + converted_attrs[key] = ir.StringAttr.get(value) + else: + raise DSLRuntimeError( + f"Unsupported kernel function attribute value type for '{key}': {type(value)}", + suggestion="Use str or ir.Attribute as the attribute value.", + ) + return converted_attrs + + def _build_gpu_module( + self, attrs: dict[str, str], loc: Optional[ir.Location] = None + ) -> None: log().info(f"self : {self}") log().info(f"Building GPU module for {self.name}") self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels"), loc=loc) @@ -276,7 +363,7 @@ class CutlassBaseDSL(BaseDSL): for attr_name in attrs: self.gpu_module.attributes[attr_name] = ir.Attribute.parse(attrs[attr_name]) - def _get_pipeline(self, pipeline): + def _get_pipeline(self, pipeline: Optional[str]) -> str: pipeline = super()._get_pipeline(pipeline) if pipeline is None: # cubin format is required to be cubin as we launch cuda module at python level. @@ -288,7 +375,7 @@ class CutlassBaseDSL(BaseDSL): return pipeline - def preprocess_pipeline(self, pipeline, arch) -> str: + def preprocess_pipeline(self, pipeline: str, arch: str) -> str: pipeline = super().preprocess_pipeline(pipeline, arch) pipeline = ( pipeline.rstrip("})") @@ -296,33 +383,35 @@ class CutlassBaseDSL(BaseDSL): ) return pipeline - def _enter_gpu_module(self): + def _enter_gpu_module(self) -> ir.InsertionPoint: log().info(f"self: {self}") log().info(f"Entering GPU module for {self.name}") log().info(f"GPU module: {self.gpu_module}") - if not self.gpu_module: - raise DSLRuntimeError( - f"GPU module is not set, probably compilation of a kernel from different DSL decorator", - suggestion=f"Use the same DSL decorator to build the GPU module, DSL: {type(self).__name__}", - ) return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0]) @staticmethod - def generate_func_ret_op(loc=None, ip=None): + def generate_func_ret_op( + loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None + ) -> None: raise NotImplementedError( "generate_func_ret_op() must be implemented by subclasses." ) @staticmethod - def generate_func_op(arg_types, arg_attrs, kernel_name, loc=None): + def generate_func_op( + arg_types: List[ir.Type], + arg_attrs: Optional[List[ir.Attribute]], + kernel_name: str, + loc: Optional[ir.Location] = None, + ) -> ir.Operation: raise NotImplementedError( "generate_func_op() must be implemented by subclasses." ) def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict: - assert isinstance( - config, BaseDSL.LaunchConfig - ), f"Expect LaunchConfig for @kernel, but got {type(config)}" + assert isinstance(config, BaseDSL.LaunchConfig), ( + f"Expect LaunchConfig for @kernel, but got {type(config)}" + ) ret = {} if config.has_max_number_threads(): @@ -356,8 +445,8 @@ class CutlassBaseDSL(BaseDSL): return ret - @lru_cache(maxsize=1) - def get_version(self): + @functools.lru_cache(maxsize=1) + def get_version(self) -> Any: """ Get the version of cutlass dsl, used for computing the hash key of the cache. Including source python files and the shared library. @@ -381,7 +470,7 @@ class CutlassBaseDSL(BaseDSL): ) from e return key, idx, h.digest() - def _iter_jobs(): + def _iter_jobs() -> Generator: """Chunk jobs generator to hash files in parallel""" for key, path, size in files: # empty files still get a deterministic hash from SHA-256 of zero bytes @@ -393,13 +482,29 @@ class CutlassBaseDSL(BaseDSL): dsl_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) files = [] - # Keep large dso file first in the list to reduce tail effect - giant_dso_name = str( - next( - (Path(dsl_path) / "_mlir" / "_mlir_libs").glob("_cutlass_ir.cpython*") - ).name - ) - so_path = os.path.join(dsl_path, "_mlir", "_mlir_libs", giant_dso_name) + mlir_libs_candidates = [ + Path(dsl_path) / "_mlir" / "_mlir_libs", + ] + try: + import cutlass._mlir as _mlir_module + + if hasattr(_mlir_module, "__path__"): + for p in _mlir_module.__path__: + mlir_libs_candidates.append(Path(p) / "_mlir_libs") + except (ImportError, AttributeError): + pass + mlir_libs_path = None + for candidate in mlir_libs_candidates: + if candidate.exists(): + mlir_libs_path = candidate + break + if mlir_libs_path is None: + raise DSLRuntimeError( + "Could not find _mlir/_mlir_libs directory. " + "Please re-install the package." + ) + giant_dso_name = str(next(mlir_libs_path.glob("_cutlass_ir.cpython*")).name) + so_path = str(mlir_libs_path / giant_dso_name) try: # update the version hash of the cutlass shared library so_size = os.path.getsize(so_path) @@ -412,7 +517,7 @@ class CutlassBaseDSL(BaseDSL): files.append((giant_dso_name, so_path, so_size)) for lib in pkgutil.walk_packages([dsl_path], prefix="cutlass."): - spec = lib.module_finder.find_spec(lib.name) + spec = lib.module_finder.find_spec(lib.name) # type: ignore[call-arg] if not spec or not spec.origin: continue path = spec.origin @@ -427,10 +532,8 @@ class CutlassBaseDSL(BaseDSL): # Submit chunks to a job queue chunk_size = 1 << 24 # 16 MB (tuned) - per_file_chunks = {} - # 16 threads max to avoid context switching overhead - # To avoid oversubscription, we use half of cpu_count() - max_workers = min(16, (os.cpu_count() or 8) // 2) + per_file_chunks: dict[str, list] = {} + max_workers = _get_max_cpu_threads() with ThreadPoolExecutor(max_workers=max_workers) as ex: futures = [ex.submit(_hash_chunk, *job) for job in _iter_jobs()] for fut in as_completed(futures): @@ -442,6 +545,7 @@ class CutlassBaseDSL(BaseDSL): # Since files list is in arbitrary order, we sort by key to get deterministic order for key, path, size in sorted(files, key=lambda t: t[0]): chunks = per_file_chunks.get(key) + assert chunks is not None file_hash = hashlib.sha256( b"".join( digest @@ -463,7 +567,9 @@ class CutlassBaseDSL(BaseDSL): i32_ty = ir.IntegerType.get_signless(32) return [i32_ty] - def generate_default_return_values(self, ip=None) -> List[ir.Value]: + def generate_default_return_values( + self, ip: Optional[ir.InsertionPoint] = None + ) -> List[ir.Value]: """ Generate the default return values of the function. With cuda dialect, the default return value is 0 to indicate success. @@ -482,21 +588,22 @@ class CutlassBaseDSL(BaseDSL): def compile_and_cache( self, - module, - module_hash, - function_name, - pipeline, - args_spec, - no_cache, - no_jit_engine, + module: ir.Module, + module_hash: str, + function_name: str, + pipeline: Optional[str], + signature: inspect.Signature, + no_cache: bool, + no_jit_engine: bool, + func_type: type = CudaDialectJitCompiledFunction, *, - full_args=None, - full_kwargs=None, - dynamic_args=None, - dynamic_kwargs=None, - original_function_name=None, - funcBody=None, - ): + full_args: Optional[tuple] = None, + full_kwargs: Optional[dict] = None, + dynamic_args: Optional[list] = None, + dynamic_kwargs: Optional[dict] = None, + original_function_name: Optional[str] = None, + funcBody: Optional[Callable[..., None]] = None, + ) -> CudaDialectJitCompiledFunction: """ Compile the module and cache the result. @@ -504,7 +611,7 @@ class CutlassBaseDSL(BaseDSL): :param module_hash: The hash of the MLIR module. :param function_name: The name of the function to compile. :param pipeline: The pipeline to use for compilation. - :param args_spec: The args spec to use for compilation. + :param signature: The signature of the function to compile. :param no_cache: Whether to cache the result. :param no_jit_engine: Whether to create JIT execution engine. :param full_args: The full arguments to use for compilation. @@ -527,13 +634,15 @@ class CutlassBaseDSL(BaseDSL): assert self._tvm_ffi_args_spec_converter is not None tvm_ffi_spec_params, kwargs_wrapper_spec = ( self._tvm_ffi_args_spec_converter( - function_name, args_spec, full_args, full_kwargs + function_name, signature, full_args, full_kwargs ) ) - tvm_ffi_provider = TVMFFICuteCallProvider(function_name) + tvm_ffi_provider = TVMFFICuteCallProvider( + function_name, has_gpu_module=self.num_kernels > 0 + ) # ensure we run the postprocessor hook after the compiler has run its passes - def post_compile_hook(module: ir.Module): + def post_compile_hook(module: ir.Module) -> None: with module.context, module.operation.location: # attach the tvm ffi function to the mlir module attach_ffi_func( @@ -545,7 +654,7 @@ class CutlassBaseDSL(BaseDSL): ) module.operation.verify() - def _make_compiled_func(*args, **kwargs): + def _make_compiled_func(*args: Any, **kwargs: Any) -> Any: if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults: return TVMFFIJitCompiledFunctionWithKwargs( *args, **kwargs, kwargs_wrapper_spec=kwargs_wrapper_spec @@ -563,7 +672,7 @@ class CutlassBaseDSL(BaseDSL): module_hash, function_name, pipeline, - args_spec, + signature, no_cache, no_jit_engine, _make_compiled_func, @@ -574,12 +683,12 @@ class CutlassBaseDSL(BaseDSL): funcBody=funcBody, ) - return super().compile_and_cache( + return super().compile_and_cache( # type: ignore[return-value] module, module_hash, function_name, pipeline, - args_spec, + signature, no_cache, no_jit_engine, CudaDialectJitCompiledFunction, @@ -592,12 +701,14 @@ class CutlassBaseDSL(BaseDSL): ) @staticmethod - def track_smem_allocator(allocator, callback): + def track_smem_allocator( + allocator: object, callback: Callable[[object], int] + ) -> None: """ Tracks shared memory usage for kernel functions. Find and set allocator to its parent dsl object. """ - frame = inspect.currentframe().f_back + frame = inspect.currentframe().f_back # type: ignore[union-attr] while frame: obj = frame.f_locals.get("self", None) if obj and isinstance(obj, CutlassBaseDSL): @@ -606,15 +717,17 @@ class CutlassBaseDSL(BaseDSL): frame = frame.f_back warnings.warn("Cannot find parent dsl for allocator!", UserWarning) - def _set_smem_tracking(self, allocator, callback): + def _set_smem_tracking( # type: ignore[no-redef] + self, allocator: object, callback: Callable[[object], int] + ) -> None: # Registers an allocator and callback for current dsl self._smem_usage_tracker = (allocator, callback) - def _reset_smem_tracking(self): + def _reset_smem_tracking(self) -> None: # type: ignore[no-redef] # Clear an allocator and callback for current dsl self._smem_usage_tracker = None - def _get_smem_usage(self) -> int: + def _get_smem_usage(self) -> int: # type: ignore[no-redef] # Treat final allocated bytes of allocator as smem usage if not self._smem_usage_tracker: return 0 @@ -624,27 +737,27 @@ class CutlassBaseDSL(BaseDSL): @staticmethod def cuda_launch_func( stream: Union[list, tuple], - kernel, - grid_size_x, - grid_size_y, - grid_size_z, - block_size_x, - block_size_y, - block_size_z, - kernel_operands, + kernel: ir.Value, + grid_size_x: Union[Int32, int], + grid_size_y: Union[Int32, int], + grid_size_z: Union[Int32, int], + block_size_x: Union[Int32, int], + block_size_y: Union[Int32, int], + block_size_z: Union[Int32, int], + kernel_operands: List[ir.Value], *, - cluster_size_x=None, - cluster_size_y=None, - cluster_size_z=None, - preferred_cluster_size_x=None, - preferred_cluster_size_y=None, - preferred_cluster_size_z=None, - dynamic_shared_memory_size=None, - use_pdl=False, - cooperative=False, - loc=None, - ip=None, - ): + cluster_size_x: Optional[Union[Int32, int]] = None, + cluster_size_y: Optional[Union[Int32, int]] = None, + cluster_size_z: Optional[Union[Int32, int]] = None, + preferred_cluster_size_x: Optional[Union[Int32, int]] = None, + preferred_cluster_size_y: Optional[Union[Int32, int]] = None, + preferred_cluster_size_z: Optional[Union[Int32, int]] = None, + dynamic_shared_memory_size: Optional[Union[Int64, int]] = None, + use_pdl: bool = False, + cooperative: bool = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: max_num_attributes = 17 launch_config_type = cuda_dialect.LaunchConfigType.get(max_num_attributes) @@ -735,25 +848,25 @@ class CutlassBaseDSL(BaseDSL): @staticmethod def gpu_launch_func( - async_token, - async_dependencies, - kernel, - grid_size_x, - grid_size_y, - grid_size_z, - block_size_x, - block_size_y, - block_size_z, - kernel_operands, + async_token: ir.Value, + async_dependencies: List[ir.Value], + kernel: ir.Value, + grid_size_x: Union[Int32, int], + grid_size_y: Union[Int32, int], + grid_size_z: Union[Int32, int], + block_size_x: Union[Int32, int], + block_size_y: Union[Int32, int], + block_size_z: Union[Int32, int], + kernel_operands: List[ir.Value], *, - cluster_size_x=None, - cluster_size_y=None, - cluster_size_z=None, - dynamic_shared_memory_size=None, - async_object=None, - use_pdl=False, - loc=None, - ip=None, + cluster_size_x: Optional[Union[Int32, int]] = None, + cluster_size_y: Optional[Union[Int32, int]] = None, + cluster_size_z: Optional[Union[Int32, int]] = None, + dynamic_shared_memory_size: Optional[Union[Int64, int]] = None, + async_object: Optional[ir.Value] = None, + use_pdl: Any = False, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ir.Value: op = gpu.LaunchFuncOp( asyncToken=async_token, @@ -777,14 +890,23 @@ class CutlassBaseDSL(BaseDSL): op.attributes["use_pdl"] = use_pdl.ir_value() return _get_op_result_or_op_results(op) - def _kernel_helper(self, funcBody, *args, **kwargs): + def _kernel_helper( # type: ignore[override] + self, funcBody: Callable[..., None], *args: Any, **kwargs: Any + ) -> "KernelLauncher": class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper): def __init__(self, dsl: CutlassBaseDSL): super().__init__() self.dsl = dsl self.dsl._reset_smem_tracking() - def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): + def generate_func_op( + self, + arg_types: List[ir.Type], + arg_attrs: Optional[List[ir.Attribute]], + kernel_name: str, + loc: Optional[ir.Location] = None, + ) -> ir.Operation: + assert arg_attrs is not None super().generate_func_op(arg_types, arg_attrs, kernel_name) self.func_op = self.dsl.generate_func_op( arg_types, arg_attrs, kernel_name, loc @@ -792,30 +914,34 @@ class CutlassBaseDSL(BaseDSL): self.arg_types = arg_types return self.func_op - def generate_func_ret_op(self, loc=None, ip=None): + def generate_func_ret_op( + self, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self.dsl.generate_func_ret_op(loc, ip) - def get_func_body_start(self): + def get_func_body_start(self) -> ir.Block: assert self.func_op is not None, "Invalid func_op is not expected!" arg_locs = [self.func_op.operation.location for _ in self.arg_types] return self.func_op.add_entry_block(arg_locs=arg_locs) - def generate_launch_op(self, *args, **kwargs): + def generate_launch_op(self, *args: Any, **kwargs: Any) -> None: # Extract args and do validation kernelSym = kwargs.get("kernelSym", None) kernelOperands = kwargs.get("kernelOperands", None) requiredArgs = kwargs.get("requiredArgs", None) loc = kwargs.get("loc", None) assert kernelSym is not None, "kernelSym being None is not expected!" - assert ( - requiredArgs is not None - ), "requiredArgs being None is not expected!" - assert ( - kernelOperands is not None - ), "kernelOperands being None is not expected!" - assert isinstance( - requiredArgs.config, BaseDSL.LaunchConfig - ), f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" + assert requiredArgs is not None, ( + "requiredArgs being None is not expected!" + ) + assert kernelOperands is not None, ( + "kernelOperands being None is not expected!" + ) + assert isinstance(requiredArgs.config, BaseDSL.LaunchConfig), ( + f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" + ) cfg = requiredArgs.config @@ -824,7 +950,7 @@ class CutlassBaseDSL(BaseDSL): pass # cannot compare dynamic value inside kernel to launch op in py elif cfg.auto_smem: cfg.smem = smem_usage - elif smem_usage > cfg.smem: + elif smem_usage > cfg.smem: # type: ignore[operator] warnings.warn( f"Potential error: specified kernel launch smem bytes " f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!", @@ -857,10 +983,11 @@ class CutlassBaseDSL(BaseDSL): if not isinstance(cfg.async_deps, (list, tuple)): async_deps = [cfg.async_deps] - # Prepare launch kwargs launch_kwargs = {} if cfg.has_fallback_cluster: + assert cfg.fallback_cluster is not None + assert cfg.cluster is not None launch_kwargs.update( dict( zip( @@ -882,6 +1009,7 @@ class CutlassBaseDSL(BaseDSL): ) ) elif cfg.has_cluster: + assert cfg.cluster is not None launch_kwargs.update( dict( zip( @@ -891,7 +1019,7 @@ class CutlassBaseDSL(BaseDSL): ) ) - CutlassBaseDSL.cuda_launch_func( + CutlassBaseDSL.cuda_launch_func( # type: ignore[misc] async_deps, kernelSym, *cfg.grid, @@ -909,7 +1037,7 @@ class CutlassBaseDSL(BaseDSL): if custom_name: return KernelLauncher( self, - lambda: _CutlassIrKernelGenHelper(self), + lambda: _CutlassIrKernelGenHelper(self), # type: ignore[arg-type] funcBody, *args, **kwargs, @@ -918,29 +1046,37 @@ class CutlassBaseDSL(BaseDSL): else: return KernelLauncher( self, - lambda: _CutlassIrKernelGenHelper(self), + lambda: _CutlassIrKernelGenHelper(self), # type: ignore[arg-type] funcBody, *args, **kwargs, ) - def _preprocess_launch_config_args(self, args, kwargs): + def _preprocess_launch_config_args(self, args: tuple, kwargs: dict) -> None: """Helper to preprocess args and kwargs for LaunchConfig""" if "stream" in kwargs: kwargs["async_deps"] = kwargs.pop("stream") - def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): + def mangle_name( + self, function_name: str, args: tuple[Any, ...], signature: inspect.Signature + ) -> str: """Mangle the name of the function to avoid conflicts with other functions""" function_name = "cutlass_" + function_name - return super().mangle_name(function_name, args, args_spec) + return super().mangle_name(function_name, args, signature) - def _validate_arg(self, arg, arg_index, arg_name, arg_annotation): + def _validate_arg( + self, + arg: object, + arg_index: int, + arg_name: str, + arg_annotation: object, + ) -> Optional[DSLRuntimeError]: """ Validates if the arg is really of the annotated type. """ if ( - is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None) + is_arg_annotation_constexpr(arg_annotation, arg_name, arg_index, None) or arg_annotation is Any ): pass @@ -989,15 +1125,21 @@ class CutlassBaseDSL(BaseDSL): def _generate_jit_func_args_for_known_types( self, - func, - arg, - arg_name, - arg_spec, - arg_index, + func: Callable[..., None], + arg: Any, + arg_name: str, + arg_spec: object, + arg_index: int, *, - is_host=True, - ): - jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] + is_host: bool = True, + ) -> Tuple[ + Optional[List[object]], + Optional[List[ir.Type]], + Optional[List[ir.Attribute]], + ]: + jit_arg_type: Optional[List[ir.Type]] = [] + jit_arg_attr: Optional[List[ir.Attribute]] = [] + jit_exec_arg: Optional[List[object]] = [] default_attr = ir.DictAttr.get({}) ( @@ -1009,6 +1151,8 @@ class CutlassBaseDSL(BaseDSL): ) if jit_arg_type is not None and len(jit_arg_type) == 0: + assert jit_arg_attr is not None + assert jit_exec_arg is not None # Handle DSL specific types if is_cute_algebra_type(arg_spec): dyn_vals = extract_mlir_values(arg) @@ -1022,7 +1166,7 @@ class CutlassBaseDSL(BaseDSL): jit_exec_arg.extend( [ tvm_ffi.Shape( - [ + [ # type: ignore[arg-type] v.value if isinstance(v, Numeric) else v for v in arg ] @@ -1034,10 +1178,8 @@ class CutlassBaseDSL(BaseDSL): get_c_pointers(arg) if is_host else dyn_vals ) else: - jit_exec_arg = jit_arg_type = jit_arg_attr = None - elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( - arg, "__new_from_mlir_values__" - ): + return None, None, None + elif not is_host and not implements_dynamic_expression(arg, partial=True): # Try tree_flatten try: dyn_vals, attr_vals, _ = tree_flatten(arg) @@ -1048,18 +1190,20 @@ class CutlassBaseDSL(BaseDSL): if dyn_vals: jit_arg_type.extend([v.type for v in dyn_vals]) jit_arg_attr.extend(attr_vals) - jit_exec_arg.extend( - _get_c_pointers_cutlass(arg) if is_host else dyn_vals - ) + jit_exec_arg.extend(dyn_vals) else: - # If tree flatten yields empty list, treat it as a constexpr thing - # Like a dataclass with all fields are constexpr, or an empty tuple or list - jit_exec_arg = jit_arg_type = jit_arg_attr = None + return None, None, None return jit_exec_arg, jit_arg_type, jit_arg_attr def _generate_execution_arguments_for_known_types( - self, arg, arg_spec, arg_name, i, fop_args, iv_block_args - ): + self, + arg: object, + arg_spec: object, + arg_name: str, + i: int, + fop_args: List[ir.Value], + iv_block_args: int, + ) -> Tuple[List[object], int]: ir_arg, iv_block_args = super()._generate_execution_arguments_for_known_types( arg, arg_spec, arg_name, i, fop_args, iv_block_args ) @@ -1070,15 +1214,18 @@ class CutlassBaseDSL(BaseDSL): blk_args = fop_args[iv_block_args : iv_block_args + n_args] ir_arg.append(new_from_mlir_values(arg, blk_args)) iv_block_args += n_args - elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( - arg, "__new_from_mlir_values__" - ): + elif not implements_dynamic_expression(arg, partial=True): # Try tree_unflatten try: - dyn_vals, _, tree_def = tree_flatten(arg) - block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)] - ir_arg.append(tree_unflatten(tree_def, block_args)) - iv_block_args += len(dyn_vals) + # we just need the length of flattened values, + # and we don't expect to emit arith.constant ops + # to get ir values from python literals. + flat_vals, _, tree_def = tree_flatten(arg, return_ir_values=False) + block_args = fop_args[ + iv_block_args : iv_block_args + len(flat_vals) + ] + ir_arg.append(tree_unflatten(tree_def, block_args)) # type: ignore[arg-type] + iv_block_args += len(flat_vals) except DSLTreeFlattenError: return ir_arg, iv_block_args @@ -1095,7 +1242,7 @@ class CuTeDSL(CutlassBaseDSL): This is a concrete DSL subclass for the CuTe dialect. """ - def __init__(self): + def __init__(self) -> None: name = "CUTE_DSL" compiler_provider = compiler.Compiler(passmanager, execution_engine) pass_sm_arch_name = "cubin-chip" @@ -1103,7 +1250,12 @@ class CuTeDSL(CutlassBaseDSL): super().__init__(name, compiler_provider, pass_sm_arch_name, preprocess=True) @staticmethod - def generate_func_op(arg_types, arg_attrs, kernel_name, loc=None): + def generate_func_op( + arg_types: List[ir.Type], + arg_attrs: Optional[List[ir.Attribute]], + kernel_name: str, + loc: Optional[ir.Location] = None, + ) -> ir.Operation: func_op = cuda_dialect.KernelOp( kernel_name, ir.FunctionType.get(arg_types, []), loc=loc ) @@ -1123,30 +1275,115 @@ class CuTeDSL(CutlassBaseDSL): return func_op @staticmethod - def generate_func_ret_op(loc=None, ip=None): + def generate_func_ret_op( + loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None + ) -> Any: return cuda_dialect.ReturnOp([], loc=loc, ip=ip) +# ============================================================================= +# CuteExperimentalJitCompiledFunction Class +# ============================================================================= + + +class _CuteExperimentalJitCompiledFunction(CudaDialectJitCompiledFunction): + """JitCompiledFunction subclass for CuteExperimentalDSL. + + Overrides ``__call__`` to validate that the caller supplies exactly + ``total_added_arguments`` extra workspace pointer arguments beyond the + original kernel signature. + """ + + def __call__(self, *args: Any, **kwargs: Any) -> int | None: + n = self.execution_args._meta.arg_count + n_extra = builtins.max(0, len(args) - n) + if n_extra != self.total_added_arguments: + raise DSLRuntimeError( + "Wrong number of extra workspace arguments", + context={ + "expected": self.total_added_arguments, + "got": n_extra, + }, + ) + return super().__call__(*args, **kwargs) + + # ============================================================================= # CuteExperimental DSL Class # ============================================================================= class CuteExperimentalDSL(CutlassBaseDSL): - def __init__(self): + _ALLOWED_EXTRA_KERNEL_VALUE_ATTRS: frozenset[str] = frozenset( + {"lir.tma_update_mode"} + ) + _KERNEL_ATTR_SPEC_FIELD: Optional[str] = "_cute_experimental_kernel_attributes" + JitCompiledFunction = _CuteExperimentalJitCompiledFunction + + def __init__(self) -> None: name = "CUTE_EXPERIMENTAL_DSL" compiler_provider = compiler.Compiler(passmanager, execution_engine) pass_sm_arch_name = "cubin-chip" super().__init__(name, compiler_provider, pass_sm_arch_name, preprocess=True) - def _get_pipeline(self, pipeline): + @classmethod + def kernel(cls, *dargs: Any, **dkwargs: Any) -> Any: + attr_spec = dkwargs.pop("attributes", None) + # Capture the caller's frame here rather than delegating to + # super().kernel(), which would record *this* frame instead of + # the user's source location (f_back would land in this override + # rather than in the user file). + current_frame = inspect.currentframe() + assert current_frame is not None + frame = current_frame.f_back + kernel_decorator = BaseDSL.jit_runner( + cls, "_kernel_helper", frame, *dargs, **dkwargs + ) + if attr_spec is None: + return kernel_decorator + + def attach_and_decorate(func: Callable[..., None]) -> Callable[..., None]: + assert cls._KERNEL_ATTR_SPEC_FIELD is not None + setattr(func, cls._KERNEL_ATTR_SPEC_FIELD, attr_spec) + return kernel_decorator(func) + + return attach_and_decorate + + def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict: + import re + + ret = super()._generate_kernel_attrs(config) + + # Add compute capability attribute from the target arch. + # get_arch_enum() validates the arch string; strip the portability + # suffix (a/f) since C++ GpuArchitecture only has base names. + arch_enum = self.get_arch_enum() + sm_match = re.match(r"(sm_\d+)", arch_enum.to_string()) + if sm_match: + sm_name = sm_match.group(1) + ret["cc_attr"] = ir.Attribute.parse( + f"#core.compute_capability" + ) + + return ret + + def _get_pipeline(self, pipeline: Optional[str]) -> str: if pipeline == None: - return "builtin.module(gpu.module(lir-to-cute{enable-cuda-dialect enable-lir-func-finalization=false}), lir-func-finalization{enable-cuda-dialect=true}, cute-to-nvvm{check-inline-asm=false cubin-format=bin enable-cuda-dialect})" + return ( + "builtin.module(gpu.module(lir-to-cute{enable-cuda-dialect enable-lir-func-finalization=false}), lir-func-finalization{enable-cuda-dialect=true}, cute-to-nvvm{cubin-format=bin enable-cuda-dialect " + + self.compile_options.to_str() + + "})" + ) return pipeline @staticmethod - def generate_func_op(arg_types, arg_attrs, kernel_name, loc=None): + def generate_func_op( + arg_types: List[ir.Type], + arg_attrs: Optional[List[ir.Attribute]], + kernel_name: str, + loc: Optional[ir.Location] = None, + ) -> ir.Operation: func_op = cutlass_lir.FuncOp( ir.StringAttr.get(kernel_name), ir.TypeAttr.get(ir.FunctionType.get(arg_types, [])), @@ -1165,7 +1402,7 @@ class CuteExperimentalDSL(CutlassBaseDSL): # Monkey patch FuncOp to add an add_entry_block method, if not already defined. if not hasattr(func_op, "add_entry_block"): - def add_entry_block(arg_locs): + def add_entry_block(arg_locs: List[ir.Location]) -> ir.Block: if len(func_op.body.blocks) != 0: raise RuntimeError("The function already has an entry block.") func_op.body.blocks.append(*arg_types) @@ -1175,9 +1412,62 @@ class CuteExperimentalDSL(CutlassBaseDSL): return func_op @staticmethod - def generate_func_ret_op(loc=None, ip=None): + def generate_func_ret_op( + loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None + ) -> Any: return cutlass_lir.ReturnOp([]) + def compile_and_cache( + self, + module: ir.Module, + module_hash: str, + function_name: str, + pipeline: Optional[str], + signature: inspect.Signature, + no_cache: bool, + no_jit_engine: bool, + func_type: type = CudaDialectJitCompiledFunction, + *, + full_args: Optional[tuple] = None, + full_kwargs: Optional[dict] = None, + dynamic_args: Optional[list] = None, + dynamic_kwargs: Optional[dict] = None, + original_function_name: Optional[str] = None, + funcBody: Optional[Callable[..., None]] = None, + ) -> CudaDialectJitCompiledFunction: + fn = super().compile_and_cache( + module, + module_hash, + function_name, + pipeline, + signature, + no_cache, + no_jit_engine, + func_type, + full_args=full_args, + full_kwargs=full_kwargs, + dynamic_args=dynamic_args, + dynamic_kwargs=dynamic_kwargs, + original_function_name=original_function_name, + funcBody=funcBody, + ) + # Extract the kernel_extra_args attribute written by FuncFinalizationPass + # and store it on the compiled function for later retrieval. + # Maps kernel name → number of extra workspace pointer args added to the + # host entry point signature. + fn.kernel_extra_args = {} + fn.total_added_arguments = 0 + if fn.ir_module is not None: + attrs = fn.ir_module.operation.attributes + if "kernel_extra_args" in attrs: + for named in ir.DictAttr(attrs["kernel_extra_args"]): + fn.kernel_extra_args[named.name] = ir.IntegerAttr(named.attr).value + fn.total_added_arguments = fn.kernel_extra_args.pop( + "total_added_arguments", 0 + ) + fn.__class__ = CuteExperimentalDSL.JitCompiledFunction + return fn + # ============================================================================= # KernelLauncher @@ -1205,10 +1495,10 @@ class KernelLauncher: self, dsl: "CutlassBaseDSL", kernelGenHelper: BaseDSL._KernelGenHelper, - funcBody, - *func_args, - **func_kwargs, - ): + funcBody: Callable[..., None], + *func_args: Any, + **func_kwargs: Any, + ) -> None: self.dsl = dsl self.kernelGenHelper = kernelGenHelper self.funcBody = funcBody @@ -1220,7 +1510,9 @@ class KernelLauncher: self._check_func_args(funcBody, *func_args, **func_kwargs) - def _check_func_args(self, funcBody, *func_args, **func_kwargs): + def _check_func_args( + self, funcBody: Any, *func_args: Any, **func_kwargs: Any + ) -> None: # Get function signature sig = inspect.signature(funcBody) @@ -1241,20 +1533,26 @@ class KernelLauncher: """ Check smem usage for this kernel, only available after `launch` """ - return self.dsl._get_smem_usage() + return self.dsl._get_smem_usage() # type: ignore[return-value] - def launch(self, *args, **kwargs): + def launch(self, *args: Any, **kwargs: Any) -> Any: self.dsl._preprocess_launch_config_args(args, kwargs) config = self.dsl.LaunchConfig(*args, **kwargs) kernel_attrs = _build_kernel_attrs(config) + value_attrs = self.dsl._generate_kernel_attrs(config) + collector = getattr(self.dsl, "_collect_extra_kernel_value_attrs", None) + if callable(collector): + value_attrs.update( + collector(self.funcBody, self.func_args, self.func_kwargs) + ) if hasattr(self, "_name_prefix") and self._name_prefix: - self.dsl._name_prefix = self._name_prefix + self.dsl._name_prefix = self._name_prefix # type: ignore[attr-defined] kernel_generator = self.dsl.kernel_launcher( requiredArgs=["config"], unitAttrNames=["gpu.kernel", "cute.kernel"], - valueAttrDict=self.dsl._generate_kernel_attrs(config), + valueAttrDict=value_attrs, kernelGenHelper=self.kernelGenHelper, )(self.funcBody) @@ -1263,7 +1561,7 @@ class KernelLauncher: self._launch_name = name return ret.launch_op_ret - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.launch(*args, **kwargs) @@ -1347,7 +1645,7 @@ def insert_read_only_frozen_dataclass( remaining_original = original_iter_args[full_write_args_count:] remaining_new = iter_args[full_write_args_count:] - def process_remaining_arg(original_arg, new_arg_iter): + def process_remaining_arg(original_arg: object, new_arg_iter: Any) -> object: """Process a single remaining argument, preserving frozen dataclass if present""" return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter) @@ -1362,11 +1660,51 @@ def insert_read_only_frozen_dataclass( def unpack_to_irvalue( mixed_values: List[Any], body_name: str, full_write_args_count: int -) -> Tuple[List[ir.Value], PyTreeDef]: +) -> Tuple[List[ir.Value], Union[PyTreeDef, Leaf]]: log().debug("===--- Values UNPack") for idx, packed in enumerate(mixed_values): log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) + # DEBUG: Print input values before tree_flatten (only if enabled) + from ..base_dsl.dsl import ( + debug_print_mlir_values, + should_print_dynamic_debug, + get_dynamic_debug_level, + ) + + if should_print_dynamic_debug(): + import traceback + import re + + level = get_dynamic_debug_level() + indent = " " * level # Indent based on level + + print("=" * 80) + print(f"{indent}[Level {level}] DEBUG '{body_name}'") + # Find the source location - look for generated DSL function names in stack + # These are created by the AST transformer: loop_body_N, then_block_N, etc. + generated_func_pattern = re.compile( + r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$" + ) + stack = traceback.extract_stack() + for frame_info in reversed(stack): + if generated_func_pattern.match(frame_info.name): + print(f"{indent} source: {frame_info.filename}:{frame_info.lineno}") + if frame_info.line: + print(f"{indent} {frame_info.line}") + break + print(f"{indent} mixed_values count: {len(mixed_values)}") + print(f"{indent} full_write_args_count: {full_write_args_count}") + for idx, packed in enumerate(mixed_values): + print(f"{indent} [{idx}] type: {type(packed).__name__}") + if hasattr(packed, "__extract_mlir_values__"): + # Add extra indentation for the tree print + tree_str = debug_print_mlir_values( + packed, indent=3 + level, types_only=True + ) + print(tree_str) + print("=" * 80) + try: unpacked_values, _, treedef = tree_flatten( remove_read_only_frozen_dataclass(mixed_values, full_write_args_count) @@ -1387,6 +1725,13 @@ def unpack_to_irvalue( ), ) + # DEBUG: Print unpacked values after tree_flatten + if should_print_dynamic_debug(): + level = get_dynamic_debug_level() + indent = " " * level + print(f"{indent} => flattened to {len(unpacked_values)} ir.Values") + print("=" * 80) + log().debug("------------------ ") for idx, unpacked in enumerate(unpacked_values): log().debug("[%d]: unpacked values: %s", idx, unpacked) @@ -1417,14 +1762,15 @@ def pack_from_irvalue( ) -def to_index(value): +def to_index(value: Union[Numeric, ir.Value, int]) -> ir.Value: """Converts a value to an index, either by casting or coercing to int.""" if is_dynamic_expression(value): if isinstance(value, Numeric): value = value.ir_value() - assert ir.IntegerType.isinstance( - value.type - ), f"expects integer type, but got {value.type}" + assert isinstance(value, ir.Value) + assert isinstance(value.type, ir.IntegerType), ( + f"expects integer type, but got {value.type}" + ) res = arith.index_cast(T.index(), value) else: res = const(int(value), ty=T.index()) @@ -1432,7 +1778,7 @@ def to_index(value): return res -def _validate_iter_args_structure(iter_args, ir_values): +def _validate_iter_args_structure(iter_args: object, ir_values: object) -> bool: """ Validates that iter_args structure contains the same number of atomic values as there are IR values. @@ -1454,7 +1800,7 @@ def _validate_iter_args_structure(iter_args, ir_values): return False # Count all non-sequence values recursively - def count_values(args): + def count_values(args: object) -> int: if not isinstance(args, (tuple, list, set)): return 1 else: @@ -1468,7 +1814,12 @@ def _validate_iter_args_structure(iter_args, ir_values): # ============================================================================= -def _minmax(op, *args, loc=None, ip=None): +def _minmax( + op: Any, + *args: Union[Numeric, ir.Value, int, float, bool], + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Numeric, int, float]: """Computes the minimum or maximum value from the provided arguments.""" from ..base_dsl.typing import _binary_op, _binary_op_type_promote @@ -1506,20 +1857,24 @@ def _minmax(op, *args, loc=None, ip=None): lhs, rhs, promote_bool=True ) + lhs_val: Union[bool, int, float, ir.Value, cutlass_arith.ArithValue] if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance( lhs, Integer ): - lhs_val = lhs.value.with_signedness(lhs.signed) + lhs_val = lhs.value.with_signedness(lhs.signed) # type: ignore[attr-defined] else: lhs_val = lhs.value + rhs_val: Union[bool, int, float, ir.Value, cutlass_arith.ArithValue] if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance( rhs, Integer ): - rhs_val = rhs.value.with_signedness(rhs.signed) + rhs_val = rhs.value.with_signedness(rhs.signed) # type: ignore[attr-defined] else: rhs_val = rhs.value - res = res_type(emitter(lhs_val, rhs_val, loc=loc, ip=ip), loc=loc, ip=ip) + res = res_type( + emitter(lhs_val, rhs_val, loc=loc, ip=ip), loc=loc, ip=ip + ) x = res else: raise DSLNotImplemented(f"{type(args)} is not supported") @@ -1527,7 +1882,11 @@ def _minmax(op, *args, loc=None, ip=None): @dsl_user_op -def min(*args, loc=None, ip=None): +def min( + *args: Union[Numeric, ir.Value, int, float, bool], + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Numeric, int, float]: """Computes the minimum value from the provided arguments. This function differs from Python's built-in min() in that the return type @@ -1586,7 +1945,11 @@ def min(*args, loc=None, ip=None): @dsl_user_op -def max(*args, loc=None, ip=None): +def max( + *args: Union[Numeric, ir.Value, int, float, bool], + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Numeric, int, float]: """Computes the maximum value from the provided arguments. This function differs from Python's built-in max() in that the return type @@ -1644,7 +2007,11 @@ def max(*args, loc=None, ip=None): return _minmax(max, *args, loc=loc, ip=ip) -def and_(*args, loc=None, ip=None): +def and_( + *args: Union[Numeric, ir.Value, int, float, bool], + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Numeric, int, float, bool]: """AND operation for value in DSL numeric types. :param *args: One or more numeric values to AND together @@ -1685,7 +2052,10 @@ def and_(*args, loc=None, ip=None): if len(args) == 1: return args[0] - def and_op(lhs, rhs): + def and_op( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], + ) -> Union[Numeric, int, float, bool]: if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): raise DSLNotImplemented(f"{type(lhs)} is not supported") elif isinstance(lhs, (int, float, bool)) and isinstance( @@ -1698,7 +2068,11 @@ def and_(*args, loc=None, ip=None): return functools.reduce(and_op, args[1:], args[0]) -def or_(*args, loc=None, ip=None): +def or_( + *args: Union[Numeric, ir.Value, int, float, bool], + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Numeric, int, float, bool]: """Logical OR operation for DSL numeric types. :param *args: One or more numeric values to OR together @@ -1737,7 +2111,10 @@ def or_(*args, loc=None, ip=None): if len(args) == 1: return args[0] - def or_op(lhs, rhs): + def or_op( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], + ) -> Union[Numeric, int, float, bool]: if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): raise DSLNotImplemented(f"{type(lhs)} is not supported") elif isinstance(lhs, (int, float, bool)) and isinstance( @@ -1750,7 +2127,7 @@ def or_(*args, loc=None, ip=None): return functools.reduce(or_op, args[1:], args[0]) -def all_(iterable): +def all_(iterable: Iterable[Union[Numeric, ir.Value, bool]]) -> Boolean: """Logical AND operation for all elements in an iterable. Returns True if all elements in the iterable are truthy, otherwise False. @@ -1774,14 +2151,13 @@ def all_(iterable): result = all_(conditions) # Returns True if all conditions are met """ bool_iterable = [Boolean(i) for i in iterable] - return functools.reduce( - lambda lhs, rhs: lhs.__dsl_and__(rhs) if hasattr(lhs, "__dsl_and__") else lhs, - bool_iterable, - Boolean(True), - ) + reducer = lambda lhs, rhs: ( + lhs.__dsl_and__(rhs) if hasattr(lhs, "__dsl_and__") else lhs + ) # noqa: E731 + return functools.reduce(reducer, bool_iterable, Boolean(True)) -def any_(iterable): +def any_(iterable: Iterable[Union[Numeric, ir.Value, bool]]) -> Boolean: """Logical OR operation for any element in an iterable. Returns True if any element in the iterable is truthy, otherwise False. @@ -1805,11 +2181,10 @@ def any_(iterable): result = any_(conditions) # Returns True if any condition is met """ bool_iterable = [Boolean(i) for i in iterable] - return functools.reduce( - lambda lhs, rhs: lhs.__dsl_or__(rhs) if hasattr(lhs, "__dsl_or__") else lhs, - bool_iterable, - Boolean(False), - ) + reducer = lambda lhs, rhs: ( + lhs.__dsl_or__(rhs) if hasattr(lhs, "__dsl_or__") else lhs + ) # noqa: E731 + return functools.reduce(reducer, bool_iterable, Boolean(False)) # ============================================================================= @@ -1817,8 +2192,14 @@ def any_(iterable): # ============================================================================= -def select_(cond, if_value, else_value): - def _as_scalar(value): +def select_( + cond: Union[Boolean, ir.Value, bool], + if_value: Union[Numeric, ir.Value, int, float, bool], + else_value: Union[Numeric, ir.Value, int, float, bool], +) -> ir.Value: + def _as_scalar( + value: Union[Numeric, ir.Value, int, float, List[ir.Value]], + ) -> Union[ir.Value, Numeric]: if isinstance(value, list): if len(value) == 1: return value[0] @@ -1852,7 +2233,11 @@ def select_(cond, if_value, else_value): # ============================================================================= -def yield_out(args=[], loc=None, ip=None): +def yield_out( + args: Union[List[ir.Value], List[Numeric]] = [], + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Generate a yield operation. It it used to return values from a loop, if-else, or while region. """ @@ -1864,39 +2249,17 @@ def yield_out(args=[], loc=None, ip=None): # ============================================================================= -class LoopUnroll(ir.Attribute): - def __init__(self, **kwargs): - valid_keys = set(["count", "full"]) - def to_mlir_attr(val): - if isinstance(val, bool): - return "true" if val else "false" - elif isinstance(val, int): - return f"{val} : i32" - else: - raise DSLNotImplemented(f"{type(val)} is not supported") - - cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs} - if kwargs.get("count", None) == 1: - cfg["disable"] = "true" - - unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">" - - super().__init__( - ir.Attribute.parse(f"#llvm.loop_annotation") - ) - - def for_generate( - start, - stop=None, - step=None, + start: Union[Int32, int], + stop: Optional[Union[Int32, int]] = None, + step: Optional[Union[Int32, int]] = None, iter_args: Optional[Sequence[ir.Value]] = None, *, - unroll: LoopUnroll = None, - prefetch_stages=None, - loc=None, - ip=None, -): + unroll: Optional[LoopUnroll] = None, + prefetch_stages: Optional[int] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Generator: """ scf.for with yield support """ @@ -1919,7 +2282,7 @@ def for_generate( start, stop, step = params - def _createI32Attr(value): + def _createI32Attr(value: Union[Int32, int]) -> ir.IntegerAttr: if not isinstance(value, int): raise DSLRuntimeError(f"value must be int.") return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value) @@ -1953,7 +2316,12 @@ def for_generate( # ============================================================================= -def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None): +def not_( + lhs: Union[ir.Value, bool], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[Boolean, bool, ir.Value]: """ Logical Not """ @@ -1978,15 +2346,15 @@ def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None): def if_generate( - cond: Boolean, + cond: Union[Boolean, ir.Value, bool], then_body: Callable, else_body: Optional[Callable] = None, - input_args: List[DslType] = None, - return_types: List[DslType] = None, + input_args: Optional[List[DslType]] = None, + return_types: Optional[List[DslType]] = None, *, - loc=None, - ip=None, -) -> List: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Union[List[Numeric], Numeric]: """ Generate an IfOp with optional else branch and return values. @@ -2010,7 +2378,7 @@ def if_generate( for t in return_types: if not isinstance(t, DslType): raise DSLRuntimeError(f"{t=} must be a DslType.") - mlir_return_types.append(t.mlir_type) + mlir_return_types.append(t.mlir_type) # type: ignore[attr-defined] # Determine whether there's an else branch. has_else = else_body is not None @@ -2020,7 +2388,7 @@ def if_generate( Boolean(cond).ir_value(), mlir_return_types, hasElse=has_else, loc=loc, ip=ip ) - def _execute_and_yield_out(body, input_args): + def _execute_and_yield_out(body: Callable, input_args: List[DslType]) -> None: yield_vals = body(*input_args) if return_types is not None: if not isinstance(yield_vals, Iterable): @@ -2036,6 +2404,7 @@ def if_generate( # Generate the body for 'else' if provided. if has_else: + assert else_body is not None with ir.InsertionPoint(if_op.else_block): _execute_and_yield_out(else_body, input_args) @@ -2074,9 +2443,9 @@ class WhileLoopContext: inputs: Sequence[Union[ir.Value, Numeric]], condition: Callable[[Sequence[ir.Value]], ir.Value], *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: # Keep original inputs and allow recover original type information self.inputs = inputs @@ -2100,7 +2469,7 @@ class WhileLoopContext: self.after_region.blocks.append(*self.input_ir_types) self.after_block = self.after_region.blocks[0] - def __enter__(self): + def __enter__(self) -> List[Numeric]: with ir.InsertionPoint(self.before_block): args = new_from_mlir_values(self.inputs, self.before_block.arguments) cond = self.condition(*args) @@ -2110,11 +2479,16 @@ class WhileLoopContext: self.ipoint_op.__enter__() return new_from_mlir_values(self.inputs, self.after_block.arguments) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[type], + exc_value: Optional[BaseException], + traceback: object, + ) -> None: self.ipoint_op.__exit__(exc_type, exc_value, traceback) @property - def results(self): + def results(self) -> List[Numeric]: return new_from_mlir_values(self.inputs, self.while_op.results_) @@ -2122,8 +2496,8 @@ def while_generate( inputs: Sequence[Union[ir.Value, Numeric]], condition: Callable[[Sequence[Union[ir.Value, Numeric]]], Union[ir.Value, Numeric]], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> WhileLoopContext: """ Generate a WhileLoopContext for a dynamic loop. @@ -2131,7 +2505,10 @@ def while_generate( return WhileLoopContext(inputs, condition, loc=loc, ip=ip) -def equal(lhs, rhs): +def equal( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], +) -> Union[Boolean, bool]: if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): return lhs == rhs @@ -2144,7 +2521,10 @@ def equal(lhs, rhs): return lhs == rhs -def not_equal(lhs, rhs): +def not_equal( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], +) -> Union[Boolean, bool]: if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): return lhs != rhs @@ -2163,7 +2543,7 @@ def not_equal(lhs, rhs): return not_(equal(lhs, rhs)) -def in_(lhs, rhs): +def in_(lhs: object, rhs: Any) -> Union[bool, Boolean]: if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): return lhs in rhs @@ -2175,8 +2555,16 @@ def in_(lhs, rhs): return any_(equal(lhs, r) for r in rhs) -def _lte_gte(lhs, rhs, op): - def native_lte_gte(lhs, rhs, op): +def _lte_gte( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], + op: str, +) -> Union[Boolean, bool]: + def native_lte_gte( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], + op: str, + ) -> Union[Boolean, bool]: if op == "<": return lhs < rhs elif op == "<=": @@ -2203,7 +2591,7 @@ def _lte_gte(lhs, rhs, op): and isinstance(rhs, Sequence) and type(lhs) == type(rhs) ): - unequal_found = False + unequal_found: Union[Numeric, bool] = False comp_results = [] mask = [] for l, r in zip(lhs, rhs): @@ -2240,23 +2628,39 @@ def _lte_gte(lhs, rhs, op): return native_lte_gte(lhs, rhs, op) -def greater_than(lhs, rhs): +def greater_than( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], +) -> Union[Boolean, bool]: return _lte_gte(lhs, rhs, ">") -def greater_equal(lhs, rhs): +def greater_equal( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], +) -> Union[Boolean, bool]: return _lte_gte(lhs, rhs, ">=") -def less_than(lhs, rhs): +def less_than( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], +) -> Union[Boolean, bool]: return _lte_gte(lhs, rhs, "<") -def less_equal(lhs, rhs): +def less_equal( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], +) -> Union[Boolean, bool]: return _lte_gte(lhs, rhs, "<=") -def _compare_dispatch(lhs, rhs, op): +def _compare_dispatch( + lhs: Union[Numeric, ir.Value, int, float, bool], + rhs: Union[Numeric, ir.Value, int, float, bool], + op: str, +) -> Union[Boolean, bool]: """ Dispatches the comparison operation between lhs and rhs based on the given operator. @@ -2295,13 +2699,17 @@ def _compare_dispatch(lhs, rhs, op): raise DSLRuntimeError(f"Unsupported comparison operator: {op}") -def _compare_executor(left, comparators, ops): +def _compare_executor( + left: Union[Numeric, ir.Value, int, float, bool], + comparators: List[Union[Numeric, ir.Value, int, float, bool]], + ops: List[str], +) -> Union[Numeric, int, float, bool]: # Fast path for single comparison if len(comparators) == 1: return _compare_dispatch(left, comparators[0], ops[0]) # Chain comparison, dispatch in a loop - result = True + result: Union[Numeric, int, float, bool] = True current = left for comparator, op in zip(comparators, ops): cmp_result = _compare_dispatch(current, comparator, op) @@ -2311,17 +2719,28 @@ def _compare_executor(left, comparators, ops): return result -def _builtin_redirector(fcn): - if fcn == builtins.max: - return max - elif fcn == builtins.min: - return min - elif fcn == builtins.any: - return any_ - elif fcn == builtins.all: - return all_ - else: - raise DSLRuntimeError(f"Unsupported built-in function: {fcn}") +def _builtin_redirector(fcn: Callable[..., object]) -> Callable[..., object]: + def builtin_wrapper(fcn: Any, *args: Any, **kwargs: Any) -> Any: + if is_dynamic_expression(args): + if kwargs: + # Redirected built-ins do not support keyword arguments + raise DSLRuntimeError( + f"Unsupported keyword arguments for built-in function: {fcn}" + ) + if fcn is builtins.max: + return max(*args) + elif fcn is builtins.min: + return min(*args) + elif fcn is builtins.any: + return any_(*args) + elif fcn is builtins.all: + return all_(*args) + else: + raise DSLRuntimeError(f"Unsupported built-in function: {fcn}") + else: + return fcn(*args, **kwargs) + + return functools.partial(builtin_wrapper, fcn) # ============================================================================= @@ -2335,8 +2754,6 @@ executor.set_functions( if_dynamic=_if_execute_dynamic, while_dynamic=_while_execute_dynamic, compare_executor=_compare_executor, - any_executor=any_, - all_executor=all_, builtin_redirector=_builtin_redirector, ifexp_dynamic=_ifexp_execute_dynamic, ) diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py index 26cc159f9..cba47634d 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/cutlass_ast_decorators.py @@ -9,16 +9,19 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import List +import builtins +from typing import Any, Callable, Dict, List, Optional, Union from cutlass._mlir import ir from cutlass._mlir.dialects import scf from collections.abc import Sequence +from ..base_dsl.common import DSLRuntimeError, DSLNotImplemented from ..base_dsl.dsl import is_dynamic_expression -from ..base_dsl.ast_helpers import * +from ..base_dsl._mlir_helpers.arith import ArithValue +from ..base_dsl.ast_helpers import * # noqa: F401,F403 from ..base_dsl.utils.logger import log -from ..base_dsl import typing as t, Arch +from ..base_dsl import typing as t from ..base_dsl.typing import Boolean, Numeric, as_numeric from ..base_dsl.utils.tree_utils import PyTreeDef, check_tree_equal from . import cutlass as cutlass_dsl @@ -31,9 +34,9 @@ NoneType = type(None) class LoopUnroll(ir.Attribute): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Union[int, bool]) -> None: valid_keys = set(["count", "full"]) - def to_mlir_attr(val): + def to_mlir_attr(val: Union[int, bool]) -> str: if isinstance(val, bool): return "true" if val else "false" elif isinstance(val, int): @@ -57,7 +60,7 @@ class ScfGenerator: Encapsulates common scf dialect functionality: pack, unpack, and SCF execution. """ - def __init__(self): + def __init__(self) -> None: pass @staticmethod @@ -77,7 +80,9 @@ class ScfGenerator: return region_result_list @staticmethod - def _check_region_result(original_value, region_value, arg_name, op_type_name): + def _check_region_result( + original_value: object, region_value: object, arg_name: str, op_type_name: str + ) -> None: """ Validate that a region result maintains the same type as the original value. @@ -101,7 +106,7 @@ class ScfGenerator: or different numeric types) are not allowed in dynamic SCF operations. """ - def get_type_name(value): + def get_type_name(value: object) -> str: if isinstance(value, NoneType): return "None" elif isinstance(value, Sequence): @@ -167,31 +172,18 @@ class ScfGenerator: def scf_execute_dynamic( self, op_type_name: str, - mix_iter_args: List[Any], + mix_iter_args: List[object], full_write_args_count: int, mix_iter_arg_names: List[str], create_op_func: Callable[[List[ir.Value]], ir.Operation], - region_builders: List[ - Callable[ - [ - "ir.Operation", - List["ir.Value"], # block_args - List["ir.Value"], # dyn_yield_ops - PyTreeDef, - List[Any], - int, - ], - Any, - ] - ], - # block_term_op_builder[region_builder] = scf_op_builder - # e.g. scf.ConditionOp for while loop - block_term_op_builder: Dict[Callable, Callable] = {}, + region_builders: List[Callable[..., Any]], + block_term_op_builder: Dict[Callable[..., Any], Callable[..., Any]] = {}, ) -> Any: # 1) Unpack ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue( mix_iter_args, op_type_name, full_write_args_count ) + # 2) Create the SCF op op = create_op_func(ir_values) log().debug("Generated scf.%s \n[%s]", op_type_name, op) @@ -235,6 +227,8 @@ class ScfGenerator: region_result_list, op_type_name, full_write_args_count ) + assert isinstance(pytree_def, PyTreeDef) + assert isinstance(yield_pytree_def, PyTreeDef) mismatch = check_tree_equal(pytree_def, yield_pytree_def) if mismatch != -1: # Get arg name @@ -258,6 +252,7 @@ class ScfGenerator: log().debug("Completed scf.%s \n[%s]", op_type_name, op) # 4) Pack final results + assert isinstance(pytree_def, PyTreeDef) final_results = cutlass_dsl.pack_from_irvalue( op.results, pytree_def, mix_iter_args, full_write_args_count ) @@ -270,7 +265,7 @@ class ScfGenerator: return final_results -def _attr_const_check(attr, expected_type, attr_name): +def _attr_const_check(attr: object, expected_type: type, attr_name: str) -> None: # Use strict type equality to prevent `bool` being accepted where `int` is required. if is_dynamic_expression(attr) or type(attr) is not expected_type: raise DSLRuntimeError( @@ -279,24 +274,27 @@ def _attr_const_check(attr, expected_type, attr_name): def _loop_execute_range_dynamic( - func: Callable, + func: Callable[..., Any], start: Any, stop: Any, step: Any, - mix_iter_args: List[Any] = [], + *, + write_args: List[Any] = [], full_write_args_count: int = 0, - mix_iter_arg_names: List[str] = [], + write_args_names: List[str] = [], unroll: int = -1, unroll_full: bool = False, - prefetch_stages: int = None, - vectorize: bool = None, -): + prefetch_stages: Optional[int] = None, + vectorize: Optional[bool] = None, + at_least_once: bool = False, + **kwargs: Any, +) -> Any: """ Example: build an scf.for with optional unroll, using our universal helper. """ scf_gen = ScfGenerator() - def create_for_op(dyn_yield_ops: List[ir.Value]): + def create_for_op(dyn_yield_ops: List[ir.Value]) -> ir.Operation: for d in dyn_yield_ops: if not isinstance(d, ir.Value): raise DSLRuntimeError( @@ -344,6 +342,8 @@ def _loop_execute_range_dynamic( vectorize_attr = None if vectorize: + from ..base_dsl.arch import Arch + arch = cutlass_dsl.CuTeDSL._get_dsl().get_arch_enum() if arch < Arch.sm_100: raise DSLRuntimeError( @@ -362,6 +362,7 @@ def _loop_execute_range_dynamic( step_, type(step_), ) + # Create scf.ForOp, passing iteration args if any try: if not dyn_yield_ops: @@ -388,16 +389,19 @@ def _loop_execute_range_dynamic( if vectorize_attr is not None: for_op.attributes["cutlass.vectorize"] = vectorize_attr + if at_least_once: + for_op.attributes["at_least_once"] = ir.UnitAttr.get() + return for_op def for_body_builder( - op, - block_args, - _, - pytree_def, - mix_iter_args, - full_write_args_count, - ): + op: ir.Operation, + block_args: List[ir.Value], + _: List[ir.Value], + pytree_def: Optional[PyTreeDef], + mix_iter_args: List[object], + full_write_args_count: int, + ) -> object: # scf.ForOp block_args are typically [induction_var, iter_args...] # But MLIR also gives you op.induction_variable iv = t.as_numeric(op.induction_variable) @@ -411,7 +415,10 @@ def _loop_execute_range_dynamic( func_args = [] func_args.extend( cutlass_dsl.pack_from_irvalue( - block_args[1:], pytree_def, mix_iter_args, full_write_args_count + block_args[1:], + pytree_def, # type: ignore[arg-type] + mix_iter_args, + full_write_args_count, ) ) if not func_args: @@ -425,9 +432,9 @@ def _loop_execute_range_dynamic( # Now call the universal SCF executor with a single region builder return scf_gen.scf_execute_dynamic( op_type_name="for", - mix_iter_args=mix_iter_args, + mix_iter_args=write_args, full_write_args_count=full_write_args_count, - mix_iter_arg_names=mix_iter_arg_names, + mix_iter_arg_names=write_args_names, create_op_func=create_for_op, region_builders=[for_body_builder], ) @@ -435,19 +442,19 @@ def _loop_execute_range_dynamic( def _if_execute_dynamic( pred: "ir.Value", - then_block: Callable, - else_block: Callable = None, - mix_yield_args: List[Any] = [], + then_block: Callable[..., object], + else_block: Optional[Callable[..., object]] = None, + mix_yield_args: List[object] = [], full_write_args_count: int = 0, mix_yield_arg_names: List[str] = [], - if_constexpr=None, # ignoring for brevity -): + if_constexpr: Optional[bool] = None, +) -> object: """ Build an scf.if with optional else, using our universal helper. """ scf_gen = ScfGenerator() - def create_if_op(dyn_yield_ops: List[ir.Value]): + def create_if_op(dyn_yield_ops: List[ir.Value]) -> ir.Operation: # Assume final result types match the dynamic yields result_types = [arg.type for arg in dyn_yield_ops] @@ -466,17 +473,19 @@ def _if_execute_dynamic( return if_op def then_builder( - if_op, - _, - dyn_yield_ops, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - flat_args = [] - flat_args.extend( + if_op: ir.Operation, + _: List[ir.Value], + dyn_yield_ops: List[ir.Value], + pytree_def: Optional[PyTreeDef], + mix_iter_args: List[object], + full_write_args_count: int, + ) -> object: + flat_args = list( cutlass_dsl.pack_from_irvalue( - dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count + dyn_yield_ops, + pytree_def, # type: ignore[arg-type] + mix_iter_args, + full_write_args_count, ) ) return then_block(*flat_args) @@ -486,17 +495,19 @@ def _if_execute_dynamic( if else_block is not None: def else_builder( - if_op, - _, - dyn_yield_ops, - pytree_def, - mix_iter_args, - full_write_args_count, - ): - flat_args = [] - flat_args.extend( + if_op: ir.Operation, + _: List[ir.Value], + dyn_yield_ops: List[ir.Value], + pytree_def: Optional[PyTreeDef], + mix_iter_args: List[object], + full_write_args_count: int, + ) -> object: + flat_args = list( cutlass_dsl.pack_from_irvalue( - dyn_yield_ops, pytree_def, mix_iter_args, full_write_args_count + dyn_yield_ops, + pytree_def, # type: ignore[arg-type] + mix_iter_args, + full_write_args_count, ) ) return else_block(*flat_args) @@ -514,12 +525,12 @@ def _if_execute_dynamic( def _while_execute_dynamic( - while_before_block: Callable, - while_after_block: Callable = None, - write_args=[], - full_write_args_count=0, - write_args_names=[], -): + while_before_block: Callable[..., Any], + while_after_block: Optional[Callable[..., Any]] = None, + write_args: List[Any] = [], + full_write_args_count: int = 0, + write_args_names: List[str] = [], +) -> Any: """ Create and return an SCF WhileOp for dynamic loops. Generate the dynamic loop body using SCF WhileOp. @@ -535,9 +546,10 @@ def _while_execute_dynamic( while_op_type_name = "while" scf_gen = ScfGenerator() - def create_while_op(dyn_yield_ops: List[ir.Value]): + def create_while_op(dyn_yield_ops: List[ir.Value]) -> ir.Operation: # Create the while operation with the types from yield_args result_types = [arg.type for arg in dyn_yield_ops] + try: while_op = scf.WhileOp(result_types, dyn_yield_ops) while_op.before.blocks.append(*result_types) @@ -554,18 +566,21 @@ def _while_execute_dynamic( ) from e def before_block_builder( - op, - block_args, - _, - pytree_def, - mix_iter_args, - full_write_args_count, - ): + op: ir.Operation, + block_args: List[ir.Value], + _: List[ir.Value], + pytree_def: Optional[PyTreeDef], + mix_iter_args: List[Any], + full_write_args_count: int, + ) -> Any: # Build the before (condition) block flat_args = [] flat_args.extend( cutlass_dsl.pack_from_irvalue( - block_args, pytree_def, mix_iter_args, full_write_args_count + block_args, + pytree_def, # type: ignore[arg-type] + mix_iter_args, + full_write_args_count, ) ) @@ -584,8 +599,10 @@ def _while_execute_dynamic( return cond, before_results - def before_block_terminator(cond_and_results, full_write_args_count): - # Generate a condition op instead of yield op + def before_block_terminator( + cond_and_results: Any, full_write_args_count: int + ) -> None: + # Generate a condition op instead of yield op. cond = cond_and_results[0] before_result_list = ScfGenerator._normalize_region_result_to_list( cond_and_results[1] @@ -602,23 +619,27 @@ def _while_execute_dynamic( scf.ConditionOp(ir_cond, ir_results_list) def after_block_builder( - op, - block_args, - _, - pytree_def, - mix_iter_args, - full_write_args_count, - ): + op: ir.Operation, + block_args: List[ir.Value], + _: List[ir.Value], + pytree_def: Optional[PyTreeDef], + mix_iter_args: List[object], + full_write_args_count: int, + ) -> object: # Build the after (body) block flat_args = [] flat_args.extend( cutlass_dsl.pack_from_irvalue( - block_args, pytree_def, mix_iter_args, full_write_args_count + block_args, + pytree_def, # type: ignore[arg-type] + mix_iter_args, + full_write_args_count, ) ) log().debug("after block args: %s", flat_args) + assert while_after_block is not None after_results = while_after_block(*flat_args) if not isinstance(after_results, (list, ir.OpResultList)): @@ -648,9 +669,9 @@ def _while_execute_dynamic( def _ifexp_execute_dynamic( pred: "ir.Value", block_args: tuple, - then_block: Callable, - else_block: Callable, -): + then_block: Callable[..., object], + else_block: Callable[..., object], +) -> object: """ Dynamically execute a Python inline if-expression (ternary) as a runtime-dispatched control flow op. @@ -708,7 +729,8 @@ def _ifexp_execute_dynamic( ) _, else_tree = cutlass_dsl.unpack_to_irvalue(else_results, "ifexp", 0) - # Check that both branches are structurally and type compatible + assert isinstance(then_tree, PyTreeDef) + assert isinstance(else_tree, PyTreeDef) if check_tree_equal(then_tree, else_tree) != -1: raise DSLRuntimeError( "Then and else blocks of ifexp return different types" @@ -722,7 +744,7 @@ def _ifexp_execute_dynamic( scf_gen = ScfGenerator() # Function to create the IfOp with correct predicate and result types - def create_if_op(_): + def create_if_op(_: List[ir.Value]) -> ir.Operation: pred_ = Boolean(pred) try: if_op = scf.IfOp( @@ -736,13 +758,10 @@ def _ifexp_execute_dynamic( ) from e return if_op - # SCF region builder for then block - def then_builder(*args): - # Just call the then_block as no arguments are passed to it + def then_builder(*args: object) -> object: return then_block(*block_args) - # SCF region builder for else block - def else_builder(*args): + def else_builder(*args: object) -> object: return else_block(*block_args) # Prepare the list of region builders for the SCF IfOp: first for "then", then for "else" diff --git a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py b/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py index b63c863be..3d2a1ed29 100644 --- a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py +++ b/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py @@ -9,6 +9,10 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from dataclasses import is_dataclass, fields as dataclass_fields +from typing import Any, Callable, List, Optional, cast + +from cutlass.base_dsl.utils.tree_utils import is_constexpr_field from cutlass.base_dsl.tvm_ffi_builder import ( DynamicParamPackCallProvider, CallContext, @@ -20,8 +24,8 @@ from cutlass._mlir import ir from cutlass._mlir.dialects import llvm from cutlass._mlir._mlir_libs._cutlass_ir import _aot_support from cutlass.cutlass_dsl.cuda_jit_executor import CudaDialectJitCompiledFunction +from cutlass.base_dsl.jit_executor import JitExecutor from cutlass.base_dsl.common import DSLRuntimeError -from typing import Optional, Callable import tvm_ffi @@ -31,11 +35,12 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): cuda_device_index: Optional[ir.Value] cuda_error_handle_block: Optional[ir.Block] - def __init__(self, target_func: str): + def __init__(self, target_func: str, has_gpu_module: bool = True): super().__init__(target_func, struct_call=True) self.cuda_global_state_symbol = f"__{target_func}_cuda_state" self.cuda_device_index = None self.cuda_error_handle_block = None + self.has_gpu_module = has_gpu_module def get_callee_struct_for_param_tensor( self, @@ -86,8 +91,11 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): arg_types.append(context.matched_var_binding[dim].type) return tuple(arg_types), tuple(allocas) - def declare_extern_funcs(self, current_block: ir.Block, context: CallContext): + def declare_extern_funcs( + self, current_block: ir.Block, context: CallContext + ) -> ir.Block: """Append the error handling function to the current block.""" + assert context.builder is not None with ir.InsertionPoint(context.module.body): context.builder.find_or_declare_extern_func( "cuda_dialect_get_error_name", @@ -116,8 +124,11 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): ) return current_block - def insert_lazy_init_cuda(self, current_block: ir.Block, context: CallContext): + def insert_lazy_init_cuda( + self, current_block: ir.Block, context: CallContext + ) -> ir.Block: """Insert the lazy init cuda function.""" + assert context.builder is not None # create global private static that is initialized to nullptr with ir.InsertionPoint(context.module.body): parsed_op = ir.Operation.parse( @@ -130,15 +141,14 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): self.cuda_global_state_symbol, self.ptr_type ) - cuda_init_ptr = context.builder.get_or_load_global_func_ptr_from_text( - current_block, "cuda_init" - ) - cuda_load_to_device_ptr = context.builder.get_or_load_global_func_ptr_from_text( - current_block, "cuda_load_to_device" - ) - set_error_ptr = context.builder.get_or_load_global_func_ptr_from_text( - current_block, "TVMFFIErrorSetRaisedFromCStr" - ) + with ir.InsertionPoint(current_block): + cuda_init_ptr = self.address_of("cuda_init", self.ptr_type) + cuda_load_to_device_ptr = self.address_of( + "cuda_load_to_device", self.ptr_type + ) + set_error_ptr = self.address_of( + "TVMFFIErrorSetRaisedFromCStr", self.ptr_type + ) with ir.InsertionPoint(current_block): # Call the callback function with the loaded ptr value @@ -231,7 +241,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): def check_cuda_error( self, code: ir.Value, current_block: ir.Block, context: CallContext - ): + ) -> ir.Block: """Check if the CUDA error is raised and return the error string if so. Uses a shared error handling block to avoid code duplication. The error code @@ -258,8 +268,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): current_device: Optional[ir.Value], target_device: Optional[ir.Value], ) -> ir.Block: - """Set the CUDA device index if it differs from the target device. - """ + """Set the CUDA device index if it differs from the target device.""" # If either device is None, no switching needed if current_device is None: assert target_device is None @@ -277,7 +286,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): self.cond_br( cond=devices_differ, true_block=switch_device_block, - false_block=continuation_block + false_block=continuation_block, ) # Switch device block: call cudaSetDevice @@ -291,7 +300,9 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): ) # Check for errors and branch to continuation - switch_device_block = self.check_cuda_error(result, switch_device_block, context) + switch_device_block = self.check_cuda_error( + result, switch_device_block, context + ) with ir.InsertionPoint(switch_device_block): self.br(continuation_block) @@ -304,6 +315,7 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): context: CallContext, ) -> ir.Block: """Generate the LLVM call operation and check if the call is successful.""" + assert context.builder is not None old_cuda_device_index: Optional[ir.Value] = None # If we need to manage CUDA device context @@ -322,7 +334,9 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): op_bundle_sizes=[], op_bundle_operands=[], ) - current_block = self.check_cuda_error(get_device_result, current_block, context) + current_block = self.check_cuda_error( + get_device_result, current_block, context + ) # Load the current device index from the alloca with ir.InsertionPoint(current_block): @@ -354,8 +368,9 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): return current_block - - def find_cuda_device_index_from_params(self, context: CallContext): + def find_cuda_device_index_from_params( + self, context: CallContext + ) -> Optional[ir.Value]: """Find the CUDA device index from tensor parameters.""" for param in context.params: if ( @@ -366,12 +381,10 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): return None def create_shared_cuda_error_block( - self, - current_block: ir.Block, - context: CallContext + self, current_block: ir.Block, context: CallContext ) -> ir.Block: - """Create a shared error handling block for all CUDA errors. - """ + """Create a shared error handling block for all CUDA errors.""" + assert context.builder is not None # Create the shared error block after the current block (setup phase) # This block will be branched to from multiple error checking sites # It accepts the error code as a block argument @@ -397,11 +410,14 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): def __call__(self, current_block: ir.Block, context: CallContext) -> ir.Block: current_block = self.declare_extern_funcs(current_block, context) - current_block = self.insert_lazy_init_cuda(current_block, context) - current_block = self.append_unload_to_global_dtors(current_block, context) + if self.has_gpu_module: + current_block = self.insert_lazy_init_cuda(current_block, context) + current_block = self.append_unload_to_global_dtors(current_block, context) # Create shared CUDA error handling block after the setup blocks # This reduces code duplication - all CUDA errors branch to this single block - self.cuda_error_handle_block = self.create_shared_cuda_error_block(current_block, context) + self.cuda_error_handle_block = self.create_shared_cuda_error_block( + current_block, context + ) # setup device index, will be set around the call to the target function self.cuda_device_index = self.find_cuda_device_index_from_params(context) current_block = super().__call__(current_block, context) @@ -411,16 +427,18 @@ class TVMFFICuteCallProvider(DynamicParamPackCallProvider): return current_block -def _inplace_hide_symbols(ir_module: ir.Module, hide_check: Callable[[str], bool]): +def _inplace_hide_symbols( + ir_module: ir.Module, hide_check: Callable[[str], bool] +) -> None: """Walk through the IRModule, hide functions that do not yet have linkage set. @param ir_module: The ir module to hide the symbols. @param hide_check: The callback to check if the symbol should be hidden. @return: The ir module with the symbols hidden. """ - defined_symbols = set() + defined_symbols: set[str] = set() - def walk_llvm_func_op(op): + def walk_llvm_func_op(op: ir.Operation) -> ir.WalkResult: # not a declaration if ( op.name == "llvm.func" @@ -432,7 +450,7 @@ def _inplace_hide_symbols(ir_module: ir.Module, hide_check: Callable[[str], bool return ir.WalkResult.ADVANCE - def walk_and_hide_symbols(op): + def walk_and_hide_symbols(op: ir.Operation) -> ir.WalkResult: # Handle llvm.func operations if op.name == "llvm.func": func_name = op.attributes["sym_name"].value @@ -454,32 +472,48 @@ def _get_format_from_object_file_path(object_file_path: str) -> str: return format +def _flatten_dataclass_arg(arg: Any) -> Any: + """Recursively flatten a dataclass argument into a tuple for TVM FFI runtime. + + TVM FFI expects tuple/array for TupleParam specs. NamedTuples work because + they are tuples, but dataclass instances need explicit flattening. + """ + if is_dataclass(arg) and not isinstance(arg, type): + values = [] + for f in dataclass_fields(arg): + if is_constexpr_field(f): + continue + values.append(_flatten_dataclass_arg(getattr(arg, f.name))) + return tuple(values) + return arg + + class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction): """Base class for TVM FFI compiled function.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) # use direct call to the tvm_ffi.Function.__call__ # to avoid most of python overhead __call__ = tvm_ffi.Function.__call__ - def to(self, device=None): + def to(self, device: Optional[int] = None) -> JitExecutor: """TVM FFI function itself is already support all devices.""" - return self + return cast(JitExecutor, self) - def run_compiled_program(self, exe_args: list[ir.Value]): + def run_compiled_program(self, exe_args: list[Any]) -> int | None: """Run the compiled program. This override is needed for implicit compile and execution.""" - return self.__call__(*exe_args) + return cast(int | None, self.__call__(*exe_args)) # type: ignore[misc] - def export_to_c( + def export_to_c( # type: ignore[override] self, object_file_path: str, - function_name: str = None, + function_name: Optional[str] = None, *, enable_pic: bool = True, export_only_tvm_ffi_symbols: bool = False, - ): + ) -> None: """Export the TVM FFI function to an object file. :param object_file_path: The path to the object file. @@ -488,16 +522,15 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction): :param export_only_tvm_ffi_symbols: Only export TVM FFI symbols (hide all others). :param host_target_triple: If not provided, the current host target is used. """ - # prefix internal function by function name - internal_symbol_prefix = "__cute_internal_" + function_name + internal_symbol_prefix = "__cute_internal_" + function_name # type: ignore[operator] mod = self.ir_module mod = get_export_module( self.ir_module, internal_symbol_prefix, - preserve_symbols=[f"__tvm_ffi_{self.function_name}"], + preserve_symbols={f"__tvm_ffi_{self.function_name}"}, ) - rename_tvm_ffi_function(mod, self.function_name, function_name) + rename_tvm_ffi_function(mod, self.function_name, function_name) # type: ignore[arg-type] if export_only_tvm_ffi_symbols: _inplace_hide_symbols(mod, lambda x: not x.startswith("__tvm_ffi")) @@ -509,20 +542,33 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction): with open(object_file_path, "wb") as f: f.write(out_bytes) - def _create_tvm_ffi_function(self): - """Create the tvm_ffi.Function from the current execution engine.""" + def _create_tvm_ffi_function(self) -> Optional["tvm_ffi.Function"]: + """Create the tvm_ffi.Function from the current execution engine. + + When the base class hands us an MlirExecutionEngine (MCJIT), we + replace it with a BinaryExecutionEngine (JITLink) to avoid + non-deterministic SIGSEGV with duplicate .text ELF sections in + multi-process torchrun workloads. + """ if self.engine is not None: - # trigger eager compile of init callbacks - cuda_init = self.engine.raw_lookup("cuda_init") - cuda_load_to_device = self.engine.raw_lookup("cuda_load_to_device") - if cuda_init is None: - raise DSLRuntimeError("cuda_init not found") - if cuda_load_to_device is None: - raise DSLRuntimeError("cuda_load_to_device not found") - tvm_ffi_function_ptr = self.engine.raw_lookup( - "__tvm_ffi_" + self.function_name + from cutlass._mlir._mlir_libs._cutlass_ir._execution_engine import ( + BinaryExecutionEngine, ) - tvm_ffi_function = tvm_ffi.Function.__from_mlir_packed_safe_call__( + from cutlass.base_dsl.env_manager import get_prefix_dsl_libs + + obj = _aot_support.export_module_to_bytes( + self.ir_module, format="o", opt_level=3, enable_pic=True + ) + libs_str = get_prefix_dsl_libs("CUTE_DSL") + shared_libs = libs_str.split(":") if libs_str else [] + self.engine = BinaryExecutionEngine( + obj, + shared_libs, + True, # useJitLink + ) + + tvm_ffi_function_ptr = self.engine.lookup("__tvm_ffi_" + self.function_name) + tvm_ffi_function = tvm_ffi.Function.__from_extern_c__( tvm_ffi_function_ptr, keep_alive_object=self.engine ) return tvm_ffi_function @@ -532,7 +578,7 @@ class TVMFFIJitCompiledFunctionBase(CudaDialectJitCompiledFunction): class TVMFFIJitCompiledFunction(tvm_ffi.Function, TVMFFIJitCompiledFunctionBase): """TVM FFI Function that directly subclasses the tvm_ffi.Function for pos only arguments.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: TVMFFIJitCompiledFunctionBase.__init__(self, *args, **kwargs) # initialize the tvm_ffi.Function from the current execution engine if self.__chandle__() != 0: @@ -542,23 +588,24 @@ class TVMFFIJitCompiledFunction(tvm_ffi.Function, TVMFFIJitCompiledFunctionBase) # move the handle from the tvm_ffi.Function to the current instance self.__move_handle_from__(tvm_ffi_function) - # use direct call to the tvm_ffi.Function.__call__ - # to avoid most of python overhead - __call__ = tvm_ffi.Function.__call__ + def __call__(self, *args: Any) -> Any: + args = tuple(_flatten_dataclass_arg(a) for a in args) + return tvm_ffi.Function.__call__(self, *args) class TVMFFIJitCompiledFunctionWithKwargs(TVMFFIJitCompiledFunctionBase): """TVM FFI Function with kwargs wrapper support""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: assert "kwargs_wrapper_spec" in kwargs, "kwargs_wrapper_spec is required" kwargs_wrapper_spec = kwargs.pop("kwargs_wrapper_spec") super().__init__(*args, **kwargs) # initialize the tvm_ffi.Function from the current execution engine self._tvm_ffi_function = self._create_tvm_ffi_function() + assert self._tvm_ffi_function is not None if kwargs_wrapper_spec.kwonly_names or kwargs_wrapper_spec.arg_defaults: try: - from tvm_ffi.utils import kwargs_wrapper # type: ignore + from tvm_ffi.utils import kwargs_wrapper self._kwargs_wrapper = kwargs_wrapper.make_kwargs_wrapper( self._tvm_ffi_function, @@ -575,18 +622,20 @@ class TVMFFIJitCompiledFunctionWithKwargs(TVMFFIJitCompiledFunctionBase): # positional only is probably fine self._kwargs_wrapper = self._tvm_ffi_function - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: """Call the TVM FFI function with kwargs wrapper.""" + args = tuple(_flatten_dataclass_arg(a) for a in args) + kwargs = {k: _flatten_dataclass_arg(v) for k, v in kwargs.items()} return self._kwargs_wrapper(*args, **kwargs) - def __tvm_ffi_object__(self): + def __tvm_ffi_object__(self) -> Optional["tvm_ffi.Function"]: return self._tvm_ffi_function def supports_kwargs_wrapper() -> bool: """Check if the kwargs wrapper is supported.""" try: - from tvm_ffi.utils import kwargs_wrapper # type: ignore + from tvm_ffi.utils import kwargs_wrapper return True except ImportError: diff --git a/python/CuTeDSL/cutlass/impl_utils.py b/python/CuTeDSL/cutlass/impl_utils.py index 29bd22786..59268d295 100644 --- a/python/CuTeDSL/cutlass/impl_utils.py +++ b/python/CuTeDSL/cutlass/impl_utils.py @@ -9,9 +9,11 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +from typing import Any + def check_value_in( - value, possible_values: list, value_description: str, prefix="" + value: Any, possible_values: list, value_description: str, prefix: str = "" ) -> None: if value not in possible_values: err_msg = prefix @@ -21,7 +23,9 @@ def check_value_in( raise ValueError(err_msg) -def check_type_in(ty, possible_types: list, type_description: str, prefix="") -> None: +def check_type_in( + ty: Any, possible_types: list, type_description: str, prefix: str = "" +) -> None: if not isinstance(ty, type): ty = type(ty) if ty not in possible_types: diff --git a/python/CuTeDSL/cutlass/jax/compile.py b/python/CuTeDSL/cutlass/jax/compile.py index 609180cbf..e58fa20d2 100644 --- a/python/CuTeDSL/cutlass/jax/compile.py +++ b/python/CuTeDSL/cutlass/jax/compile.py @@ -111,12 +111,12 @@ def jit_wrapper( spec: cutlass.Constexpr, ): # split buffer argument into inputs and outputs and return to tree - ins, outs = args[: len(spec.in_args)], args[(len(spec.in_args)) :] - ins = [x.get_tensor() for x in ins] - outs = [x.get_tensor() for x in outs] - ins = jax.tree.unflatten(spec.input_tree, ins) - outs = jax.tree.unflatten(spec.output_tree, outs) - wrapped_fn(stream, *ins, *outs, **dict(spec.kwargs)) + ins, outs = args[: len(spec.in_args)], args[(len(spec.in_args)) :] # type: ignore[attr-defined] + ins = [x.get_tensor() for x in ins] # type: ignore[assignment, attr-defined] + outs = [x.get_tensor() for x in outs] # type: ignore[assignment, attr-defined] + ins = jax.tree.unflatten(spec.input_tree, ins) # type: ignore[attr-defined] + outs = jax.tree.unflatten(spec.output_tree, outs) # type: ignore[attr-defined] + wrapped_fn(stream, *ins, *outs, **dict(spec.kwargs)) # type: ignore[operator, attr-defined] @dataclass @@ -228,7 +228,7 @@ def get_or_compile_kernel(fn, spec): try: cute_compile = cutlass.cute.compile if spec.compile_options: - cute_compile = partial(cute_compile, options=spec.compile_options) + cute_compile = partial(cute_compile, options=spec.compile_options) # type: ignore[assignment] compiled_fn = cute_compile( jit_wrapper, diff --git a/python/CuTeDSL/cutlass/jax/ffi.py b/python/CuTeDSL/cutlass/jax/ffi.py index 5aee47fcc..8bef8aa38 100644 --- a/python/CuTeDSL/cutlass/jax/ffi.py +++ b/python/CuTeDSL/cutlass/jax/ffi.py @@ -165,7 +165,7 @@ def register_ffi(ffi_version: int = get_cutlass_call_ffi_version()): fn.restype = ctypes.c_void_p type_dict[field] = jax.ffi.pycapsule(fn()) logger.debug(f"Registering ffi type: {type_name}, {type_dict}") - jax.ffi.register_ffi_type(type_name, type_dict, platform="CUDA") + jax.ffi.register_ffi_type(type_name, type_dict, platform="CUDA") # type: ignore[arg-type] # Register the custom FFI targets. match ffi_version: diff --git a/python/CuTeDSL/cutlass/jax/testing.py b/python/CuTeDSL/cutlass/jax/testing.py index 9c5d9a0ef..8ba31a3e5 100644 --- a/python/CuTeDSL/cutlass/jax/testing.py +++ b/python/CuTeDSL/cutlass/jax/testing.py @@ -15,6 +15,8 @@ import jax.numpy as jnp import cutlass.cute as cute from cutlass.cutlass_dsl import dsl_user_op +from typing import Optional +from cutlass._mlir import ir def reorder_modes(src: str, target: str) -> tuple[int, ...]: @@ -80,13 +82,17 @@ def gemm_c_shape(l, m, n, major) -> tuple[int, ...]: @dsl_user_op def get_gemm_shape_from_tensors( - a: cute.Tensor, b: cute.Tensor, *, loc=None, ip=None + a: cute.Tensor, + b: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> tuple[int, int, int, int]: """Returns a tuple of (M, N, K, L) from A/B gemm tensors.""" # mkl, nkl - m, k, l = a.shape[:] - n = b.shape[0] - return (m, n, k, l) + m, k, l = a.shape[:] # type: ignore[index] + n = b.shape[0] # type: ignore[index] + return (m, n, k, l) # type: ignore[return-value] def create_tensor( diff --git a/python/CuTeDSL/cutlass/jax/types.py b/python/CuTeDSL/cutlass/jax/types.py index 56e5d2aca..3f22dfb1d 100644 --- a/python/CuTeDSL/cutlass/jax/types.py +++ b/python/CuTeDSL/cutlass/jax/types.py @@ -9,7 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Sequence +from typing import Optional, Sequence from dataclasses import dataclass, field @@ -153,7 +153,7 @@ def default_tensor_spec(shaped) -> TensorSpec: This is appropriate for standard row-major (C-contiguous) JAX arrays that do not require dimension reordering inside the kernel. - Divisibility hints are inferred only for concrete integer dimensions. + Divisibility hints are inferred only for concrete integer dimensions. Symbolic dimensions always produce ``None`` for their slot; pass an explicit ``TensorSpec`` with ``divisibility`` set if you need alignment hints for symbolic shapes. @@ -373,7 +373,12 @@ class JaxArrayValue(JaxArray): return str(self) def _make_ordered_layout_dynamic_strides( - self, shape, order: tuple[int, ...], *, loc=None, ip=None + self, + shape, + order: tuple[int, ...], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ): i32 = ir.IntegerType.get_signless(32) pairs = sorted(zip(shape, order), key=lambda x: x[1]) @@ -415,7 +420,13 @@ class JaxArrayValue(JaxArray): return cute.make_layout(shape_i32, stride=tuple(strides_ordered)) - def _load_dynamic_shapes(self, ffi_buffer, *, loc=None, ip=None): + def _load_dynamic_shapes( + self, + ffi_buffer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ): i64 = ir.IntegerType.get_signless(64) shape_array = llvm.extractvalue( llvm.PointerType.get(), @@ -441,7 +452,13 @@ class JaxArrayValue(JaxArray): return tuple(shape_i64) - def _load_pointer(self, ffi_buffer, *, loc=None, ip=None): + def _load_pointer( + self, + ffi_buffer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ): raw_ptr = llvm.extractvalue( llvm.PointerType.get(), ffi_buffer, @@ -458,7 +475,12 @@ class JaxArrayValue(JaxArray): ip=ip, ) - def get_tensor(self, *, loc=None, ip=None): + def get_tensor( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ): ffi_buffer_type = llvm.StructType.get_literal( [llvm.PointerType.get(), llvm.PointerType.get()] ) diff --git a/python/CuTeDSL/cutlass/pipeline/helpers.py b/python/CuTeDSL/cutlass/pipeline/helpers.py index 495d819b4..7250f4686 100644 --- a/python/CuTeDSL/cutlass/pipeline/helpers.py +++ b/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -13,11 +13,13 @@ import enum import inspect from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional, Union +from typing import Any, Optional, Union import warnings import cutlass.cute as cute -from cutlass.cutlass_dsl import Boolean, Int32, if_generate, dsl_user_op +from cutlass._mlir import ir +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import CuTeDSL, Boolean, Int32, if_generate, dsl_user_op ############################################################################## @@ -32,6 +34,8 @@ class Agent(enum.Enum): # Arbitrary grouping of N threads Thread = enum.auto() + # A collection of 32 threads executing in lockstep + Warp = enum.auto() # Same as AsyncThread, but includes all threads in the block ThreadBlock = enum.auto() # Same as AsyncThread, but includes all threads in the cluster @@ -48,7 +52,7 @@ class CooperativeGroup: CooperativeGroup contains size and alignment restrictions for an Agent. """ - def __init__(self, agent: Agent, size: int = 1, alignment=None): + def __init__(self, agent: Agent, size: int = 1, alignment: Optional[int] = None): if alignment is not None: warnings.warn( "The 'alignment' parameter of CooperativeGroup's constructor is deprecated and " @@ -103,7 +107,7 @@ class PipelineOp(enum.Enum): # Async load without TMA AsyncLoad = enum.auto() -def _get_pipeline_op(type_str): +def _get_pipeline_op(type_str: int | PipelineOp) -> PipelineOp: return PipelineOp(type_str) @@ -162,8 +166,8 @@ class MbarrierArray(SyncObject): agent: tuple[PipelineOp, CooperativeGroup], tx_count: int = 0, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: self.barrier_storage = barrier_storage self.tx_count = tx_count @@ -205,20 +209,26 @@ class MbarrierArray(SyncObject): # Mbarrier initialization @dsl_user_op - def mbarrier_init(self, *, loc=None, ip=None) -> None: + def mbarrier_init( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Initializes an array of mbarriers using warp 0. """ - def then_body(): - for index in range(self.num_stages): - cute.arch.mbarrier_init( - self.get_barrier(index, loc=loc, ip=ip), - self.arrive_count, - loc=loc, - ip=ip, - ) - + def then_body() -> None: + use_uniform_mbarrier_init = True + if use_uniform_mbarrier_init: + for index in range(self.num_stages): + cute.arch.mbarrier_init( + self.get_barrier(index, loc=loc, ip=ip), + self.arrive_count, + loc=loc, + ip=ip, + ) warp_idx = cute.arch.warp_idx(loc=loc, ip=ip) warp_idx = cute.arch.make_warp_uniform(warp_idx, loc=loc, ip=ip) @@ -231,8 +241,8 @@ class MbarrierArray(SyncObject): dst: int, cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """Select the arrive corresponding to this MbarrierArray's PipelineOp. @@ -259,7 +269,6 @@ class MbarrierArray(SyncObject): # TMA operation signals local mbarrier only self.arrive_and_expect_tx(index, self.tx_count, loc=loc, ip=ip) elif self.op_type in [PipelineOp.ClcLoad]: - # Multiple threads in CTA 0 each signal a different remote CTA in cluster's mbarrier self.arrive_and_expect_tx_with_dst( index, self.tx_count, dst, loc=loc, ip=ip ) @@ -272,7 +281,12 @@ class MbarrierArray(SyncObject): @dsl_user_op def arrive_mbarrier( - self, index: int, dst_rank: Optional[int] = None, *, loc=None, ip=None + self, + index: int, + dst_rank: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: if dst_rank is None: cute.arch.mbarrier_arrive( @@ -284,7 +298,13 @@ class MbarrierArray(SyncObject): ) @dsl_user_op - def arrive_cp_async_mbarrier(self, index: int, *, loc=None, ip=None): + def arrive_cp_async_mbarrier( + self, + index: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: cute.arch.cp_async_mbarrier_arrive_noinc( self.get_barrier(index, loc=loc, ip=ip), loc=loc, ip=ip ) @@ -296,8 +316,8 @@ class MbarrierArray(SyncObject): mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: if mask is None: with cute.arch.elect_one(loc=loc, ip=ip): @@ -316,7 +336,12 @@ class MbarrierArray(SyncObject): @dsl_user_op def arrive_and_expect_tx( - self, index: int, tx_count: int, *, loc=None, ip=None + self, + index: int, + tx_count: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: with cute.arch.elect_one(loc=loc, ip=ip): cute.arch.mbarrier_arrive_and_expect_tx( @@ -325,23 +350,44 @@ class MbarrierArray(SyncObject): @dsl_user_op def arrive_and_expect_tx_with_dst( - self, index: int, tx_count: int, dst: Optional[int] = None, *, loc=None, ip=None + self, + index: int, + tx_count: int, + dst: Optional[int] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: cute.arch.mbarrier_arrive_and_expect_tx( self.get_barrier(index, loc=loc, ip=ip), tx_count, dst, loc=loc, ip=ip ) @dsl_user_op - def try_wait(self, index: int, phase: int, *, loc=None, ip=None) -> Boolean: + def try_wait( + self, + index: int, + phase: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: return cute.arch.mbarrier_try_wait( self.get_barrier(index, loc=loc, ip=ip), phase, loc=loc, ip=ip ) @dsl_user_op - def wait(self, index: int, phase: int, *, loc=None, ip=None) -> None: + def wait( + self, + index: int, + phase: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Optional[tuple]: cute.arch.mbarrier_wait( self.get_barrier(index, loc=loc, ip=ip), phase, loc=loc, ip=ip ) + return None @dsl_user_op def arrive_and_wait( @@ -351,18 +397,29 @@ class MbarrierArray(SyncObject): dst: int, cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: arrive(index, dst, cta_group, loc=loc, ip=ip) wait(index, phase, loc=loc, ip=ip) @dsl_user_op - def arrive_and_drop(self, *, loc=None, ip=None) -> None: + def arrive_and_drop( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: raise NotImplementedError("Error: Not yet supported.") @dsl_user_op - def get_barrier(self, index: int, *, loc=None, ip=None) -> cute.Pointer: + def get_barrier( + self, + index: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Pointer: return self.mbarrier_base + index def max(self) -> int: @@ -370,23 +427,24 @@ class MbarrierArray(SyncObject): # Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1). return 511 - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[object]: return [self.barrier_storage] - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[object]) -> "MbarrierArray": return MbarrierArray( values[0], self.num_stages, (self.op_type, self.cg), self.tx_count ) # Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator -MbarrierArray.__init__.__signature__ = inspect.Signature( +MbarrierArray.__init__.__signature__ = inspect.Signature( # type: ignore[attr-defined] [ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), ] ) + ############################################################################## # NamedBarrier class ############################################################################## @@ -401,19 +459,23 @@ class NamedBarrier(SyncObject): See the `PTX documentation `__. """ - barrier_id: int - num_threads: int + barrier_id: Union[int, Int32] + num_threads: Union[int, Int32] def __post_init__(self) -> None: - if self.barrier_id < 0 or self.barrier_id >= 16: - raise ValueError("Error: NamedBarrier ID must be between 0 and 16.") - if self.barrier_id == 0: - warnings.warn( - "NamedBarrier ID 0 is by other driver APIs (i.e. sync_threads()) and should not be used." - ) + if isinstance(self.barrier_id, int): + if self.barrier_id < 0 or self.barrier_id >= 16: + raise ValueError("Error: NamedBarrier ID must be in [0,15].") + if self.barrier_id == 0: + warnings.warn("NamedBarrier ID 0 is used by sync_threads, avoid using.") @dsl_user_op - def arrive(self, *, loc=None, ip=None) -> None: + def arrive( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ The aligned flavor of arrive is used when all threads in the CTA will execute the same instruction. See PTX documentation. @@ -426,7 +488,12 @@ class NamedBarrier(SyncObject): ) @dsl_user_op - def arrive_unaligned(self, *, loc=None, ip=None) -> None: + def arrive_unaligned( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. """ @@ -438,7 +505,12 @@ class NamedBarrier(SyncObject): ) @dsl_user_op - def wait(self, *, loc=None, ip=None) -> None: + def wait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. If synchronizing two warps in a producer/consumer pairing, the arrive count would be @@ -452,7 +524,12 @@ class NamedBarrier(SyncObject): self.arrive_and_wait(loc=loc, ip=ip) @dsl_user_op - def wait_unaligned(self, *, loc=None, ip=None) -> None: + def wait_unaligned( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: cute.arch.barrier( barrier_id=self.barrier_id, number_of_threads=self.num_threads, @@ -461,7 +538,12 @@ class NamedBarrier(SyncObject): ) @dsl_user_op - def arrive_and_wait(self, *, loc=None, ip=None) -> None: + def arrive_and_wait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: cute.arch.barrier( barrier_id=self.barrier_id, number_of_threads=self.num_threads, @@ -470,15 +552,30 @@ class NamedBarrier(SyncObject): ) @dsl_user_op - def arrive_and_drop(self, *, loc=None, ip=None) -> None: + def arrive_and_drop( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: raise NotImplementedError("Error: Not supported.") @dsl_user_op - def sync(self, *, loc=None, ip=None) -> None: + def sync( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self.arrive_and_wait() @dsl_user_op - def get_barrier(self, *, loc=None, ip=None) -> int: + def get_barrier( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Union[int, Int32]: return self.barrier_id def max(self) -> int: @@ -503,27 +600,52 @@ class TmaStoreFence(SyncObject): self.num_stages = num_stages @dsl_user_op - def arrive(self, *, loc=None, ip=None) -> None: + def arrive( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: cute.arch.cp_async_bulk_commit_group(loc=loc, ip=ip) @dsl_user_op - def wait(self, *, loc=None, ip=None) -> None: + def wait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: cute.arch.cp_async_bulk_wait_group( self.num_stages - 1, read=True, loc=loc, ip=ip ) @dsl_user_op - def arrive_and_wait(self, *, loc=None, ip=None) -> None: + def arrive_and_wait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self.arrive(loc=loc, ip=ip) self.wait(loc=loc, ip=ip) @dsl_user_op - def arrive_and_drop(self, *, loc=None, ip=None) -> None: + def arrive_and_drop( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: raise NotImplementedError("Error: Not supported.") # TmaStoreFence doesn't have mbarriers @dsl_user_op - def get_barrier(self, *, loc=None, ip=None) -> None: + def get_barrier( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: assert False, ( "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier." ) @@ -532,7 +654,12 @@ class TmaStoreFence(SyncObject): raise NotImplementedError("Error: Not supported.") @dsl_user_op - def tail(self, *, loc=None, ip=None) -> None: + def tail( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: cute.arch.cp_async_bulk_wait_group(0, read=True, loc=loc, ip=ip) @@ -557,7 +684,7 @@ class PipelineState: Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer. """ - def __init__(self, stages: int, count, index, phase): + def __init__(self, stages: int, count: Int32, index: Int32, phase: Int32): self._stages = stages self._count = count self._index = index @@ -583,73 +710,94 @@ class PipelineState: return self._phase @dsl_user_op - def reset_count(self, *, loc=None, ip=None): + def reset_count( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self._count = Int32(0, loc=loc, ip=ip) @dsl_user_op - def advance(self, *, loc=None, ip=None) -> None: + def advance( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self._index += 1 self._count += 1 - def then_body(index, phase): + def then_body(index: Int32, phase: Int32) -> tuple[Int32, Int32]: new_index = Int32(0, loc=loc, ip=ip) new_phase = phase ^ 1 - return new_index, new_phase + return new_index, new_phase # type: ignore[return-value] - def else_body(index, phase): + def else_body(index: Int32, phase: Int32) -> tuple[Int32, Int32]: return index, phase - self._index, self._phase = if_generate( + self._index, self._phase = if_generate( # type: ignore[assignment, misc] self._index == self.stages, then_body, else_body, - [self.index, self.phase], + [self.index, self.phase], # type: ignore[list-item] [Int32, Int32], loc=loc, ip=ip, ) @dsl_user_op - def reverse(self, *, loc=None, ip=None): + def reverse( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self._index -= 1 self._count -= 1 - def then_body(index, phase): + def then_body(index: Int32, phase: Int32) -> tuple[Int32, Int32]: new_index = Int32(self.stages - 1, loc=loc, ip=ip) new_phase = phase ^ 1 - return new_index, new_phase + return new_index, new_phase # type: ignore[return-value] - def else_body(index, phase): + def else_body(index: Int32, phase: Int32) -> tuple[Int32, Int32]: return index, phase - self._index, self._phase = if_generate( + self._index, self._phase = if_generate( # type: ignore[assignment, misc] self._index == -1, then_body, else_body, - [self.index, self.phase], + [self.index, self.phase], # type: ignore[list-item] [Int32, Int32], loc=loc, ip=ip, ) - def __get_mlir_types__(self): - return [self._count.type, self._index.type, self._phase.type] + def __get_mlir_types__(self) -> list[ir.Type]: + return [self._count.type, self._index.type, self._phase.type] # type: ignore[attr-defined] - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: count = self._count index = self._index phase = self._phase return [count.ir_value(), index.ir_value(), phase.ir_value()] # This can be overridden by derived classes - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "PipelineState": return PipelineState( self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2]) ) @dsl_user_op -def make_pipeline_state(type: PipelineUserType, stages: int, *, loc=None, ip=None): +def make_pipeline_state( + type: PipelineUserType, + stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> PipelineState: """ Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. """ @@ -673,6 +821,7 @@ def make_pipeline_state(type: PipelineUserType, stages: int, *, loc=None, ip=Non ) + ############################################################################## # Helper functions ############################################################################## @@ -683,9 +832,9 @@ def pipeline_init_arrive( cluster_shape_mn: Optional[cute.Layout] = None, is_relaxed: bool = False, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Fences the mbarrier_init and sends an arrive if using clusters. """ @@ -706,8 +855,11 @@ def pipeline_init_arrive( @dsl_user_op def pipeline_init_wait( - cluster_shape_mn: Optional[cute.Layout] = None, *, loc=None, ip=None -): + cluster_shape_mn: Optional[cute.Layout] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Syncs the threadblock or cluster """ @@ -721,13 +873,25 @@ def pipeline_init_wait( @dsl_user_op -def _sync(group: Agent, is_relaxed: bool = False, *, loc=None, ip=None): +def _sync( + group: Agent, + is_relaxed: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: warnings.warn("_sync is deprecated. Please use agent_sync instead.") agent_sync(group, is_relaxed, loc=loc, ip=ip) @dsl_user_op -def agent_sync(group: Agent, is_relaxed: bool = False, *, loc=None, ip=None): +def agent_sync( + group: Agent, + is_relaxed: bool = False, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ Syncs all threads within an agent. """ @@ -749,7 +913,13 @@ def agent_sync(group: Agent, is_relaxed: bool = False, *, loc=None, ip=None): # NamedBarrier free functions @dsl_user_op -def arrive(barrier_id: int, num_threads: int, *, loc=None, ip=None): +def arrive( + barrier_id: int, + num_threads: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ The aligned flavor of arrive is used when all threads in the CTA will execute the same instruction. See PTX documentation. @@ -760,7 +930,13 @@ def arrive(barrier_id: int, num_threads: int, *, loc=None, ip=None): @dsl_user_op -def arrive_unaligned(barrier_id: int, num_threads: int, *, loc=None, ip=None): +def arrive_unaligned( + barrier_id: int, + num_threads: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: """ The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA. """ @@ -770,7 +946,9 @@ def arrive_unaligned(barrier_id: int, num_threads: int, *, loc=None, ip=None): @dsl_user_op -def wait(*, loc=None, ip=None): +def wait( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None +) -> None: """ NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait. If synchronizing two warps in a producer/consumer pairing, the arrive count would be @@ -785,7 +963,13 @@ def wait(*, loc=None, ip=None): @dsl_user_op -def wait_unaligned(barrier_id: int, num_threads: int, *, loc=None, ip=None): +def wait_unaligned( + barrier_id: int, + num_threads: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: warnings.warn( "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()." ) @@ -795,12 +979,23 @@ def wait_unaligned(barrier_id: int, num_threads: int, *, loc=None, ip=None): @dsl_user_op -def arrive_and_wait(barrier_id: int, num_threads: int, *, loc=None, ip=None): +def arrive_and_wait( + barrier_id: int, + num_threads: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: cute.arch.barrier( barrier_id=barrier_id, number_of_threads=num_threads, loc=loc, ip=ip ) @dsl_user_op -def sync(barrier_id: int = 0, *, loc=None, ip=None): +def sync( + barrier_id: int = 0, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: cute.arch.barrier(barrier_id=barrier_id, loc=loc, ip=ip) diff --git a/python/CuTeDSL/cutlass/pipeline/sm100.py b/python/CuTeDSL/cutlass/pipeline/sm100.py index f9c8eb5b2..26b4c5687 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm100.py +++ b/python/CuTeDSL/cutlass/pipeline/sm100.py @@ -12,6 +12,8 @@ from dataclasses import dataclass from typing import Optional +from cutlass._mlir import ir + import cutlass import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, Int32, if_generate, dsl_user_op @@ -50,8 +52,8 @@ class PipelineTmaUmma(PipelineAsync): agent: tuple[PipelineOp, CooperativeGroup], tx_count: int = 0, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> SyncObject: """ Returns a SyncObject corresponding to an agent's PipelineOp. @@ -84,9 +86,9 @@ class PipelineTmaUmma(PipelineAsync): cta_layout_vmnk: cute.Layout, mcast_mode_mn: tuple[int, int], *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Int32: """ Computes a mask for signaling arrivals to multicasting threadblocks. """ @@ -138,7 +140,12 @@ class PipelineTmaUmma(PipelineAsync): @dsl_user_op @staticmethod - def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout, *, loc=None, ip=None): + def _compute_is_leader_cta( + cta_layout_vmnk: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: """ Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. """ @@ -160,13 +167,13 @@ class PipelineTmaUmma(PipelineAsync): producer_group: CooperativeGroup, consumer_group: CooperativeGroup, tx_count: int, - barrier_storage: cute.Pointer = None, + barrier_storage: Optional[cute.Pointer] = None, cta_layout_vmnk: Optional[cute.Layout] = None, mcast_mode_mn: tuple[int, int] = (1, 1), defer_sync: bool = False, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "PipelineTmaUmma": """Creates and initializes a new PipelineTmaUmma instance. :param num_stages: Number of buffer stages for this pipeline @@ -257,11 +264,17 @@ class PipelineTmaUmma(PipelineAsync): ) @dsl_user_op - def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + def consumer_release( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ UMMA consumer release buffer empty, cta_group needs to be provided. """ - self.sync_object_empty.arrive( + self.sync_object_empty.arrive( # type: ignore[call-arg] state.index, self.consumer_mask, self.cta_group, loc=loc, ip=ip ) @@ -270,15 +283,15 @@ class PipelineTmaUmma(PipelineAsync): state: PipelineState, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( + lambda: self.sync_object_empty.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, @@ -286,20 +299,21 @@ class PipelineTmaUmma(PipelineAsync): ) if_generate( self.is_leader_cta, - lambda: self.sync_object_full.arrive( + lambda: self.sync_object_full.arrive( # type: ignore[call-arg] state.index, self.producer_mask, loc=loc, ip=ip ), loc=loc, ip=ip, ) - def producer_commit(self, state: PipelineState): + def producer_commit(self, state: PipelineState) -> None: """ TMA producer commit is a noop since TMA instruction itself updates the transaction count. """ pass + @dataclass(frozen=True) class PipelineAsyncUmma(PipelineAsync): """ @@ -310,7 +324,12 @@ class PipelineAsyncUmma(PipelineAsync): @dsl_user_op @staticmethod - def _compute_leading_cta_rank(cta_v_size, *, loc=None, ip=None): + def _compute_leading_cta_rank( + cta_v_size: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Int32: """ Computes the leading CTA rank. """ @@ -323,7 +342,12 @@ class PipelineAsyncUmma(PipelineAsync): @dsl_user_op @staticmethod - def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout, *, loc=None, ip=None): + def _compute_is_leader_cta( + cta_layout_vmnk: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: """ Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. """ @@ -338,7 +362,12 @@ class PipelineAsyncUmma(PipelineAsync): @dsl_user_op @staticmethod - def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout, *, loc=None, ip=None): + def _compute_peer_cta_mask( + cta_layout_vmnk: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Int32: """ Computes a mask for signaling arrivals to multicasting threadblocks. """ @@ -373,12 +402,12 @@ class PipelineAsyncUmma(PipelineAsync): num_stages: int, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, - barrier_storage: cute.Pointer = None, + barrier_storage: Optional[cute.Pointer] = None, cta_layout_vmnk: Optional[cute.Layout] = None, defer_sync: bool = False, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "PipelineAsyncUmma": """Creates and initializes a new PipelineAsyncUmma instance. :param num_stages: Number of buffer stages for this pipeline @@ -470,11 +499,17 @@ class PipelineAsyncUmma(PipelineAsync): ) @dsl_user_op - def consumer_release(self, state: PipelineState, *, loc=None, ip=None): + def consumer_release( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ UMMA consumer release buffer empty, cta_group needs to be provided. """ - self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) # type: ignore[call-arg] @dataclass(frozen=True) @@ -487,7 +522,12 @@ class PipelineUmmaAsync(PipelineAsync): @dsl_user_op @staticmethod - def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout, *, loc=None, ip=None): + def _compute_tmem_sync_mask( + cta_layout_vmnk: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Int32: """ Computes a mask to signal completion of tmem buffers for 2CTA kernels. """ @@ -505,7 +545,9 @@ class PipelineUmmaAsync(PipelineAsync): @dsl_user_op @staticmethod - def _compute_peer_cta_rank(*, loc=None, ip=None): + def _compute_peer_cta_rank( + *, loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None + ) -> Int32: """ Computes a mask to signal release of tmem buffers for 2CTA kernels. """ @@ -523,12 +565,12 @@ class PipelineUmmaAsync(PipelineAsync): num_stages: int, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, - barrier_storage: cute.Pointer = None, + barrier_storage: Optional[cute.Pointer] = None, cta_layout_vmnk: Optional[cute.Layout] = None, defer_sync: bool = False, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "PipelineUmmaAsync": """Creates an instance of PipelineUmmaAsync with computed attributes. :param num_stages: Number of buffer stages for this pipeline @@ -611,17 +653,29 @@ class PipelineUmmaAsync(PipelineAsync): ) @dsl_user_op - def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + def producer_commit( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ UMMA producer commit buffer full, cta_group needs to be provided. """ - self.sync_object_full.arrive( + self.sync_object_full.arrive( # type: ignore[call-arg] state.index, self.producer_mask, self.cta_group, loc=loc, ip=ip ) @dsl_user_op @cute.jit - def producer_tail(self, state: PipelineState, *, loc=None, ip=None): + def producer_tail( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Make sure the last used buffer empty signal is visible to producer. Producer tail is usually executed by producer before exit, to avoid dangling @@ -639,7 +693,7 @@ class PipelineUmmaAsync(PipelineAsync): if is_leader_cta: # Assume state contains that next useful buffer # So we only need to advance to num_stages - 1 times to last used buffer - for i in cutlass.range_constexpr(self.num_stages - 1): + for i in cutlass.range_constexpr(self.num_stages - 1): # type: ignore[func-returns-value] state.advance(loc=loc, ip=ip) self.producer_acquire(state, loc=loc, ip=ip) @@ -667,7 +721,9 @@ class PipelineClcFetchAsync: @staticmethod @cute.jit - def _init_full_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32): + def _init_full_barrier_arrive_signal( + cta_layout_vmnk: cute.Layout, tidx: Int32 + ) -> tuple: """ Computes producer barrier signaling parameters, returns destination CTA rank (0 to cluster_size-1) based on thread ID, and a boolean flag indicating if @@ -687,12 +743,12 @@ class PipelineClcFetchAsync: producer_group: CooperativeGroup, consumer_group: CooperativeGroup, tx_count: int, - barrier_storage: cute.Pointer = None, - producer_mask: Int32 = None, - consumer_mask: Int32 = None, + barrier_storage: Optional[cute.Pointer] = None, + producer_mask: Optional[Int32] = None, + consumer_mask: Optional[Int32] = None, cta_layout_vmnk: Optional[cute.Layout] = None, defer_sync: bool = False, - ): + ) -> "PipelineClcFetchAsync": """ This helper function computes any necessary attributes and returns an instance of PipelineClcFetchAsync. :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers @@ -741,7 +797,7 @@ class PipelineClcFetchAsync: # The producer (sched warp) runs ONLY in CTA 0, all consumers # across the cluster must arrive at CTA 0's empty barrier - consumer_mask = 0 + consumer_mask = 0 # type: ignore[assignment] if not defer_sync: cute.arch.mbarrier_init_fence() @@ -765,9 +821,9 @@ class PipelineClcFetchAsync: state: PipelineState, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Producer acquire waits for empty buffer and sets transaction expectation on full barrier. @@ -776,7 +832,7 @@ class PipelineClcFetchAsync: """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( + lambda: self.sync_object_empty.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, @@ -784,7 +840,7 @@ class PipelineClcFetchAsync: ) if_generate( self.is_signalling_thread, - lambda: self.sync_object_full.arrive( + lambda: self.sync_object_full.arrive( # type: ignore[call-arg] state.index, self.producer_mask, loc=loc, ip=ip ), loc=loc, @@ -797,9 +853,9 @@ class PipelineClcFetchAsync: state: PipelineState, try_wait_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Consumer waits for full barrier to be signaled by hardware multicast. @@ -808,7 +864,7 @@ class PipelineClcFetchAsync: """ if_generate( try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait( + lambda: self.sync_object_full.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, @@ -816,14 +872,34 @@ class PipelineClcFetchAsync: ) @dsl_user_op - def consumer_release(self, state: PipelineState, *, loc=None, ip=None): - self.sync_object_empty.arrive(state.index, self.consumer_mask, loc=loc, ip=ip) + def consumer_release( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + self.sync_object_empty.arrive(state.index, self.consumer_mask, loc=loc, ip=ip) # type: ignore[call-arg] @dsl_user_op def producer_get_barrier( - self, state: PipelineState, *, loc=None, ip=None + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: - return self.sync_object_full.get_barrier(state.index, loc=loc, ip=ip) + return self.sync_object_full.get_barrier(state.index, loc=loc, ip=ip) # type: ignore[call-arg, return-value] + + @dsl_user_op + def consumer_get_barrier( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Pointer: + return self.sync_object_empty.get_barrier(state.index, loc=loc, ip=ip) # type: ignore[call-arg, return-value] @dsl_user_op def producer_tail( @@ -831,26 +907,10 @@ class PipelineClcFetchAsync: state: PipelineState, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): - """ - Ensures all in-flight buffers are released before producer exits. - - :param state: Pipeline state with current position in the buffer - :param try_acquire_token: Optional token to skip the empty barrier waits - - """ - for i in range(self.num_stages): - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( - state.index, state.phase, loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - state.advance(loc=loc, ip=ip) + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Pointer: + return self.sync_object_empty.get_barrier(state.index, loc=loc, ip=ip) # type: ignore[call-arg, return-value] @dataclass(frozen=True) @@ -863,19 +923,21 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): sync_object_empty_umma: SyncObject sync_object_empty_async: SyncObject cta_group: cute.nvgpu.tcgen05.CtaGroup + consumer_dst_rank_async: Optional[Int32] = None + is_signalling_thread: Boolean = True # type: ignore[assignment] @staticmethod - def create( + def create( # type: ignore[override] *, num_stages: int, producer_group: CooperativeGroup, consumer_group_umma: CooperativeGroup, consumer_group_async: CooperativeGroup, tx_count: int, - barrier_storage: cute.Pointer = None, + barrier_storage: Optional[cute.Pointer] = None, cta_layout_vmnk: Optional[cute.Layout] = None, defer_sync: bool = False, - ): + ) -> "PipelineTmaMultiConsumersAsync": """ This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync. :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers @@ -922,7 +984,10 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): consumer = (consumer_type, consumer_group) sync_object_full = PipelineTmaUmma._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count + barrier_storage.align(min_align=8), + num_stages, + producer, + tx_count, ) sync_object_empty = PipelineTmaUmma._make_sync_object( barrier_storage.align(min_align=8) + num_stages, num_stages, consumer @@ -970,15 +1035,15 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): state: PipelineState, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks. """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( + lambda: self.sync_object_empty.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, @@ -986,13 +1051,19 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): ) if_generate( self.is_leader_cta, - lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), # type: ignore[call-arg] loc=loc, ip=ip, ) @dsl_user_op - def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + def producer_commit( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ TMA producer commit is a noop since TMA instruction itself updates the transaction count. """ @@ -1000,14 +1071,19 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): @dsl_user_op def consumer_release( - self, state: PipelineState, op_type: PipelineOp, *, loc=None, ip=None - ): + self, + state: PipelineState, + op_type: PipelineOp, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: if op_type == PipelineOp.TCGen05Mma: - self.sync_object_empty_umma.arrive( + self.sync_object_empty_umma.arrive( # type: ignore[call-arg] state.index, self.consumer_mask, self.cta_group, loc=loc, ip=ip ) elif op_type == PipelineOp.AsyncThread: - self.sync_object_empty_async.arrive( + self.sync_object_empty_async.arrive( # type: ignore[call-arg] state.index, self.consumer_mask, loc=loc, ip=ip ) else: diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py index 9a5b3ebb3..181866088 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm90.py +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -10,7 +10,7 @@ # is strictly prohibited. from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional import cutlass.cute as cute from cutlass.cutlass_dsl import Boolean, Int32, if_generate, dsl_user_op @@ -27,6 +27,7 @@ from cutlass.pipeline import ( make_pipeline_state, agent_sync, ) +from cutlass._mlir import ir ############################################################################## # Pipeline classes @@ -155,11 +156,11 @@ class PipelineAsync: num_stages: int, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, - barrier_storage: cute.Pointer = None, - producer_mask: Int32 = None, - consumer_mask: Int32 = None, + barrier_storage: Optional[cute.Pointer] = None, + producer_mask: Optional[Int32] = None, + consumer_mask: Optional[Int32] = None, defer_sync: bool = False, - ): + ) -> "PipelineAsync": """Creates and initializes a new PipelineAsync instance. This helper function computes necessary attributes and returns an instance of PipelineAsync @@ -217,12 +218,12 @@ class PipelineAsync: state: PipelineState, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( + lambda: self.sync_object_empty.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, @@ -230,12 +231,24 @@ class PipelineAsync: ) @dsl_user_op - def producer_try_acquire(self, state: PipelineState, *, loc=None, ip=None): - return self.sync_object_empty.try_wait(state.index, state.phase, loc=loc, ip=ip) + def producer_try_acquire( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: + return self.sync_object_empty.try_wait(state.index, state.phase, loc=loc, ip=ip) # type: ignore[attr-defined] @dsl_user_op - def producer_commit(self, state: PipelineState, *, loc=None, ip=None): - self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) + def producer_commit( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) # type: ignore[call-arg] @dsl_user_op def consumer_wait( @@ -243,12 +256,12 @@ class PipelineAsync: state: PipelineState, try_wait_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: if_generate( try_wait_token is None or try_wait_token == 0, - lambda: self.sync_object_full.wait( + lambda: self.sync_object_full.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, @@ -256,21 +269,53 @@ class PipelineAsync: ) @dsl_user_op - def consumer_try_wait(self, state: PipelineState, *, loc=None, ip=None): - return self.sync_object_full.try_wait(state.index, state.phase, loc=loc, ip=ip) + def consumer_try_wait( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: + return self.sync_object_full.try_wait(state.index, state.phase, loc=loc, ip=ip) # type: ignore[attr-defined] @dsl_user_op - def consumer_release(self, state: PipelineState, *, loc=None, ip=None): - self.sync_object_empty.arrive(state.index, self.consumer_mask, loc=loc, ip=ip) + def consumer_release( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + self.sync_object_empty.arrive(state.index, self.consumer_mask, loc=loc, ip=ip) # type: ignore[call-arg] + + @dsl_user_op + def consumer_get_barrier( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Pointer: + return self.sync_object_empty.get_barrier(state.index, loc=loc, ip=ip) # type: ignore[call-arg, return-value] @dsl_user_op def producer_get_barrier( - self, state: PipelineState, *, loc=None, ip=None + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: - return self.sync_object_full.get_barrier(state.index, loc=loc, ip=ip) + return self.sync_object_full.get_barrier(state.index, loc=loc, ip=ip) # type: ignore[call-arg, return-value] @dsl_user_op - def producer_tail(self, state: PipelineState, *, loc=None, ip=None): + def producer_tail( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Make sure the last used buffer empty signal is visible to producer. Producer tail is usually executed by producer before exit, to avoid dangling @@ -287,21 +332,36 @@ class PipelineAsync: # Util methods to manage producer and consumer @dsl_user_op - def make_producer(self, *, loc=None, ip=None): + def make_producer( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "PipelineProducer": state = make_pipeline_state( PipelineUserType.Producer, self.num_stages, loc=loc, ip=ip ) - return PipelineProducer(self, state, self.sync_object_full.cg) + return PipelineProducer(self, state, self.sync_object_full.cg) # type: ignore[attr-defined] @dsl_user_op - def make_consumer(self, *, loc=None, ip=None): + def make_consumer( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "PipelineConsumer": state = make_pipeline_state( PipelineUserType.Consumer, self.num_stages, loc=loc, ip=ip ) - return PipelineConsumer(self, state, self.sync_object_empty.cg) + return PipelineConsumer(self, state, self.sync_object_empty.cg) # type: ignore[attr-defined] @dsl_user_op - def make_participants(self, *, loc=None, ip=None): + def make_participants( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "tuple[PipelineProducer, PipelineConsumer]": return self.make_producer(loc=loc, ip=ip), self.make_consumer(loc=loc, ip=ip) @@ -312,15 +372,15 @@ class PipelineCpAsync(PipelineAsync): """ @staticmethod - def create( + def create( # type: ignore[override] barrier_storage: cute.Pointer, num_stages: Int32, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, - producer_mask: Int32 = None, - consumer_mask: Int32 = None, + producer_mask: Optional[Int32] = None, + consumer_mask: Optional[Int32] = None, defer_sync: bool = False, - ): + ) -> "PipelineCpAsync": """Helper function that computes necessary attributes and returns a ``PipelineCpAsync`` instance. :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers @@ -345,10 +405,14 @@ class PipelineCpAsync(PipelineAsync): consumer = (consumer_type, consumer_group) sync_object_array_full = PipelineCpAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer + barrier_storage.align(min_align=8), + num_stages, # type: ignore[arg-type] + producer, ) sync_object_array_empty = PipelineCpAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + barrier_storage.align(min_align=8) + num_stages, + num_stages, # type: ignore[arg-type] + consumer, ) if not defer_sync: @@ -358,7 +422,7 @@ class PipelineCpAsync(PipelineAsync): return PipelineCpAsync( sync_object_array_full, sync_object_array_empty, - num_stages, + num_stages, # type: ignore[arg-type] producer_mask, consumer_mask, ) @@ -378,7 +442,7 @@ class PipelineTmaAsync(PipelineAsync): cta_layout_vmnk: cute.Layout, tidx: Int32, mcast_mode_mn: tuple[int, int] = (1, 1), - ): + ) -> tuple[Int32, Boolean]: """Initialize the empty barrier arrive signal. This function determines which threads should signal empty barrier arrives based on the cluster layout @@ -409,14 +473,14 @@ class PipelineTmaAsync(PipelineAsync): cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster) is_mcast_mode_m = ( - dst_cta_coord[0] == cur_cta_coord[0] - and dst_cta_coord[1] == cur_cta_coord[1] - and dst_cta_coord[3] == cur_cta_coord[3] + dst_cta_coord[0] == cur_cta_coord[0] # type: ignore[index] + and dst_cta_coord[1] == cur_cta_coord[1] # type: ignore[index] + and dst_cta_coord[3] == cur_cta_coord[3] # type: ignore[index] ) is_mcast_mode_n = ( - dst_cta_coord[0] == cur_cta_coord[0] - and dst_cta_coord[2] == cur_cta_coord[2] - and dst_cta_coord[3] == cur_cta_coord[3] + dst_cta_coord[0] == cur_cta_coord[0] # type: ignore[index] + and dst_cta_coord[2] == cur_cta_coord[2] # type: ignore[index] + and dst_cta_coord[3] == cur_cta_coord[3] # type: ignore[index] ) assert not (mcast_mode_mn[0] == 0 and mcast_mode_mn[1] == 0) @@ -431,18 +495,18 @@ class PipelineTmaAsync(PipelineAsync): return dst_rank, is_signalling_thread @staticmethod - def create( + def create( # type: ignore[override] *, num_stages: int, producer_group: CooperativeGroup, consumer_group: CooperativeGroup, tx_count: int, - barrier_storage: cute.Pointer = None, + barrier_storage: Optional[cute.Pointer] = None, cta_layout_vmnk: Optional[cute.Layout] = None, tidx: Optional[Int32] = None, mcast_mode_mn: tuple[int, int] = (1, 1), defer_sync: bool = False, - ): + ) -> "PipelineTmaAsync": """Create a new ``PipelineTmaAsync`` instance. :param num_stages: Number of buffer stages for this pipeline @@ -521,190 +585,30 @@ class PipelineTmaAsync(PipelineAsync): state: PipelineState, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. """ if_generate( try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( + lambda: self.sync_object_empty.wait( # type: ignore[call-arg] state.index, state.phase, loc=loc, ip=ip ), loc=loc, ip=ip, ) - self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) + self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) # type: ignore[call-arg] @dsl_user_op - def producer_commit(self, state: PipelineState, *, loc=None, ip=None): - """ - TMA producer commit is a noop since TMA instruction itself updates the transaction count. - """ - pass - - @dsl_user_op - def consumer_release(self, state: PipelineState, *, loc=None, ip=None): - """ - TMA consumer release conditionally signals the empty buffer to the producer. - """ - if_generate( - self.is_signalling_thread, - lambda: self.sync_object_empty.arrive( - state.index, self.consumer_mask, loc=loc, ip=ip - ), - ) - - -@dataclass(frozen=True) -class PipelineTmaMultiConsumersAsync(PipelineAsync): - """ - PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers. - """ - - is_leader_cta: bool - sync_object_empty_umma: SyncObject - sync_object_empty_async: SyncObject - cta_group: cute.nvgpu.tcgen05.CtaGroup - - @staticmethod - def create( - *, - num_stages: int, - producer_group: CooperativeGroup, - consumer_group_umma: CooperativeGroup, - consumer_group_async: CooperativeGroup, - tx_count: int, - barrier_storage: cute.Pointer = None, - cta_layout_vmnk: Optional[cute.Layout] = None, - defer_sync: bool = False, - ): - """Creates an instance of PipelineTmaMultiConsumersAsync with computed attributes. - - :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers - :type barrier_storage: cute.Pointer - :param num_stages: Number of buffer stages for this pipeline - :type num_stages: int - :param producer_group: ``CooperativeGroup`` for the producer agent - :type producer_group: CooperativeGroup - :param consumer_group_umma: ``CooperativeGroup`` for the UMMA consumer agent - :type consumer_group_umma: CooperativeGroup - :param consumer_group_async: ``CooperativeGroup`` for the AsyncThread consumer agent - :type consumer_group_async: CooperativeGroup - :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage - :type tx_count: int - :param cta_layout_vmnk: Layout of the cluster shape, defaults to None - :type cta_layout_vmnk: Optional[cute.Layout] - :raises ValueError: If ``barrier_storage`` is not a ``cute.Pointer`` instance - :raises ValueError: If ``UMMA`` and ``AsyncThread`` consumer groups are not the same agent - :raises ValueError: If ``cta_layout_vmnk`` size is not 1 - :return: New instance of ``PipelineTmaMultiConsumersAsync`` - :rtype: PipelineTmaMultiConsumersAsync - """ - if not isinstance(barrier_storage, cute.Pointer): - raise TypeError( - f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" - ) - - producer_type = PipelineOp.TmaLoad - consumer_type = PipelineOp.Composite - consumer_type_umma = PipelineOp.TCGen05Mma - consumer_type_async = PipelineOp.AsyncThread - - if consumer_group_umma.agent != consumer_group_async.agent: - raise ValueError( - "UMMA and AsyncThread consumer groups must be the same agent" - ) - - if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1: - raise ValueError( - "PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, " - f"cta_layout_vmnk:{cta_layout_vmnk}" - ) - - consumer_group = CooperativeGroup( - consumer_group_umma.agent, - consumer_group_umma.size + consumer_group_async.size, - ) - - producer = (producer_type, producer_group) - consumer = (consumer_type, consumer_group) - - sync_object_full = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8), num_stages, producer, tx_count - ) - sync_object_empty = PipelineAsync._make_sync_object( - barrier_storage.align(min_align=8) + num_stages, num_stages, consumer - ) - sync_object_empty_umma = sync_object_empty.recast_to_new_op_type( - consumer_type_umma - ) - sync_object_empty_async = sync_object_empty.recast_to_new_op_type( - consumer_type_async - ) - - # No mcast mask if not using clusters - producer_mask = None - consumer_mask = None - # All thread-blocks are leaders if not using clusters - is_leader_cta = True - cta_group = ( - cute.nvgpu.tcgen05.CtaGroup.ONE - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 - else cute.nvgpu.tcgen05.CtaGroup.TWO - ) - - if not defer_sync: - cute.arch.mbarrier_init_fence() - if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: - agent_sync(Agent.ThreadBlock) - else: - agent_sync(Agent.ThreadBlockCluster, is_relaxed=True) - - return PipelineTmaMultiConsumersAsync( - sync_object_full, - sync_object_empty, - num_stages, - producer_mask, - consumer_mask, - is_leader_cta, - sync_object_empty_umma, - sync_object_empty_async, - cta_group, - ) - - @dsl_user_op - def producer_acquire( + def producer_commit( self, state: PipelineState, - try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, - ): - """ - TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks. - """ - if_generate( - try_acquire_token is None or try_acquire_token == 0, - lambda: self.sync_object_empty.wait( - state.index, state.phase, loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - if_generate( - self.is_leader_cta, - lambda: self.sync_object_full.arrive( - state.index, self.producer_mask, loc=loc, ip=ip - ), - loc=loc, - ip=ip, - ) - - @dsl_user_op - def producer_commit(self, state: PipelineState, *, loc=None, ip=None): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ TMA producer commit is a noop since TMA instruction itself updates the transaction count. """ @@ -712,18 +616,21 @@ class PipelineTmaMultiConsumersAsync(PipelineAsync): @dsl_user_op def consumer_release( - self, state: PipelineState, op_type: PipelineOp, *, loc=None, ip=None - ): - if op_type == PipelineOp.TCGen05Mma: - self.sync_object_empty_umma.arrive( - state.index, self.consumer_mask, self.cta_group, loc=loc, ip=ip - ) - elif op_type == PipelineOp.AsyncThread: - self.sync_object_empty_async.arrive( + self, + state: PipelineState, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + """ + TMA consumer release conditionally signals the empty buffer to the producer. + """ + if_generate( + self.is_signalling_thread, + lambda: self.sync_object_empty.arrive( # type: ignore[call-arg] state.index, self.consumer_mask, loc=loc, ip=ip - ) - else: - raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}") + ), + ) @dataclass(frozen=True) @@ -733,11 +640,11 @@ class PipelineTmaStore(PipelineAsync): """ @staticmethod - def create( + def create( # type: ignore[override] *, num_stages: int, producer_group: CooperativeGroup, - ): + ) -> "PipelineTmaStore": """This helper function computes any necessary attributes and returns an instance of ``PipelineTmaStore``. :param num_stages: Number of buffer stages for this pipeline @@ -751,29 +658,54 @@ class PipelineTmaStore(PipelineAsync): producer = (producer_type, producer_group) - sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer) + sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer) # type: ignore[arg-type] - return PipelineTmaStore(sync_object_full, None, num_stages, None, None) + return PipelineTmaStore(sync_object_full, None, num_stages, None, None) # type: ignore[arg-type] @dsl_user_op - def producer_acquire(self, *, loc=None, ip=None): - self.sync_object_full.wait(loc=loc, ip=ip) + def producer_acquire( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + self.sync_object_full.wait(loc=loc, ip=ip) # type: ignore[call-arg] @dsl_user_op - def producer_commit(self, *, loc=None, ip=None): - self.sync_object_full.arrive(loc=loc, ip=ip) + def producer_commit( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + self.sync_object_full.arrive(loc=loc, ip=ip) # type: ignore[call-arg] @dsl_user_op - def consumer_wait(self, *, loc=None, ip=None): + def consumer_wait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: assert False, "Error: PipelineTmaStore does not have a consumer agent." @dsl_user_op - def consumer_release(self, *, loc=None, ip=None): + def consumer_release( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: assert False, "Error: PipelineTmaStore does not have a consumer agent." @dsl_user_op - def producer_tail(self, *, loc=None, ip=None): - self.sync_object_full.tail(loc=loc, ip=ip) + def producer_tail( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + self.sync_object_full.tail(loc=loc, ip=ip) # type: ignore[attr-defined] @dataclass(frozen=True) @@ -825,7 +757,7 @@ class PipelineOrder: group_id: int, producer_group: CooperativeGroup, defer_sync: bool = False, - ): + ) -> "PipelineOrder": if not isinstance(barrier_storage, cute.Pointer): raise TypeError( f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" @@ -858,24 +790,43 @@ class PipelineOrder: ), ) - def get_barrier_for_current_stage_idx(self, group_id): - return self.state.index * self.length + group_id + def get_barrier_for_current_stage_idx( + self, group_id: int, state: Optional[PipelineState] = None + ) -> Int32: + state = self.state if state is None else state + return state.index * self.length + group_id @dsl_user_op - def arrive(self, *, loc=None, ip=None): + def arrive( # type: ignore[return] + self, + state: Optional[PipelineState] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Optional[PipelineState]: + state = self.state if state is None else state signalling_id = (self.group_id + 1) % self.length - idx = self.get_barrier_for_current_stage_idx(signalling_id) + idx = self.get_barrier_for_current_stage_idx(signalling_id, state) cute.arch.mbarrier_arrive( - self.sync_object_full.get_barrier(idx, loc=loc, ip=ip), loc=loc, ip=ip + self.sync_object_full.get_barrier(idx, loc=loc, ip=ip), # type: ignore[call-arg] + loc=loc, + ip=ip, ) - self.state.advance(loc=loc, ip=ip) + state.advance(loc=loc, ip=ip) + if state is not self.state: + return state @dsl_user_op - def wait(self, *, loc=None, ip=None): - idx = self.get_barrier_for_current_stage_idx(self.group_id) + def wait( + self, + state: Optional[PipelineState] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + state = self.state if state is None else state + idx = self.get_barrier_for_current_stage_idx(self.group_id, state) cute.arch.mbarrier_wait( - self.sync_object_full.get_barrier(idx, loc=loc, ip=ip), - self.state.phase, + self.sync_object_full.get_barrier(idx, loc=loc, ip=ip), # type: ignore[call-arg] + state.phase, loc=loc, ip=ip, ) @@ -892,26 +843,26 @@ class ImmutableResourceHandle: __immutable_state: PipelineState def __init__(self, origin: PipelineAsync, immutable_state: PipelineState): - self.__origin = origin - self.__immutable_state = immutable_state + self.__origin = origin # type: ignore[misc] + self.__immutable_state = immutable_state # type: ignore[misc] @property - def index(self): + def index(self) -> Int32: """Get the index of the current pipeline stage.""" return self.__immutable_state.index @property - def count(self): + def count(self) -> Int32: """Get the count of how many handles this producer has committed. This is useful for tracking the number of blocks that have been loaded from gmem. """ return self.__immutable_state.count - def get_origin(self): + def get_origin(self) -> PipelineAsync: """Get the original pipeline this resource handle belongs to.""" return self.__origin - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list: """Extract MLIR values from the current state. :return: List of MLIR values representing the current state @@ -920,7 +871,7 @@ class ImmutableResourceHandle: # TODO: need to handle pipeline as well return self.__immutable_state.__extract_mlir_values__() - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: "list") -> "ImmutableResourceHandle": """Create a new Producer instance from MLIR values. :param values: MLIR values to initialize the state @@ -976,27 +927,36 @@ class PipelineProducer: @dataclass(frozen=True) class ImmutableResourceHandle(ImmutableResourceHandle): @property - def barrier(self): + def barrier(self) -> cute.Pointer: """Get the barrier pointer for the current pipeline stage. :return: Pointer to the barrier for the current stage :rtype: cute.Pointer """ return self.get_origin().producer_get_barrier( - self._ImmutableResourceHandle__immutable_state + self._ImmutableResourceHandle__immutable_state # type: ignore[attr-defined] ) @dsl_user_op - def commit(self, *, loc=None, ip=None): + def commit( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Signal that data production is complete for the current stage. This allows consumers to start processing the data. """ self.get_origin().producer_commit( - self._ImmutableResourceHandle__immutable_state, loc=loc, ip=ip + self._ImmutableResourceHandle__immutable_state, # type: ignore[attr-defined] + loc=loc, + ip=ip, ) - def __init__(self, pipeline, state, group: CooperativeGroup): + def __init__( + self, pipeline: PipelineAsync, state: PipelineState, group: CooperativeGroup + ): """Initialize a new Producer instance. :param pipeline: The pipeline this producer belongs to @@ -1010,22 +970,32 @@ class PipelineProducer: self.__state = state self.__group = group - def clone(self): + def clone(self) -> "PipelineProducer": """Create a new Producer instance with the same state.""" return PipelineProducer(self.__pipeline, self.__state.clone(), self.__group) @dsl_user_op - def reset(self, *, loc=None, ip=None): + def reset( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Reset the count of how many handles this producer has committed.""" self.__state.reset_count(loc=loc, ip=ip) + def current_handle(self) -> ImmutableResourceHandle: + """Get the current handle for the producer.""" + return PipelineProducer.ImmutableResourceHandle(self.__pipeline, self.__state) + @dsl_user_op def acquire( self, try_acquire_token: Optional[Boolean] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> ImmutableResourceHandle: """Wait for the current buffer to be empty before producing data. This is a blocking operation. @@ -1036,7 +1006,7 @@ class PipelineProducer: :rtype: ImmutableResourceHandle """ self.__pipeline.producer_acquire( - self.__state, try_acquire_token, loc=loc, ip=ip + self.__state, try_acquire_token, loc=loc, ip=ip, **kwargs ) handle = PipelineProducer.ImmutableResourceHandle( self.__pipeline, self.__state.clone() @@ -1044,13 +1014,23 @@ class PipelineProducer: return handle @dsl_user_op - def advance(self, *, loc=None, ip=None): + def advance( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Move to the next pipeline stage.""" self.__state.advance(loc=loc, ip=ip) @dsl_user_op def acquire_and_advance( - self, try_acquire_token: Optional[Boolean] = None, *, loc=None, ip=None + self, + try_acquire_token: Optional[Boolean] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> ImmutableResourceHandle: """Acquire the current buffer and advance to the next pipeline stage. @@ -1066,12 +1046,17 @@ class PipelineProducer: acquired buffer stage :rtype: ImmutableResourceHandle """ - handle = self.acquire(try_acquire_token, loc=loc, ip=ip) + handle = self.acquire(try_acquire_token, loc=loc, ip=ip, **kwargs) self.advance(loc=loc, ip=ip) return handle @dsl_user_op - def try_acquire(self, *, loc=None, ip=None) -> Boolean: + def try_acquire( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: """Attempt to acquire the current buffer without blocking. This method tries to acquire the current buffer stage for producing data @@ -1085,8 +1070,12 @@ class PipelineProducer: @dsl_user_op def commit( - self, handle: Optional[ImmutableResourceHandle] = None, *, loc=None, ip=None - ): + self, + handle: Optional[ImmutableResourceHandle] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Signal that data production is complete for the current stage. This allows consumers to start processing the data. @@ -1104,14 +1093,19 @@ class PipelineProducer: self.__pipeline.producer_commit(self.__state, loc=loc, ip=ip) @dsl_user_op - def tail(self, *, loc=None, ip=None): + def tail( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Ensure all used buffers are properly synchronized before producer exit. This should be called before the producer finishes to avoid dangling signals. """ self.__pipeline.producer_tail(self.__state, loc=loc, ip=ip) - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: """Extract MLIR values from the current state. :return: List of MLIR values representing the current state @@ -1120,7 +1114,7 @@ class PipelineProducer: # TODO: need to handle pipeline as well return self.__state.__extract_mlir_values__() - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "PipelineProducer": """Create a new Producer instance from MLIR values. :param values: MLIR values to initialize the state @@ -1178,16 +1172,36 @@ class PipelineConsumer: @dataclass(frozen=True) class ImmutableResourceHandle(ImmutableResourceHandle): + @property + def barrier(self) -> cute.Pointer: + """Get the barrier pointer for the current pipeline stage. + + :return: Pointer to the barrier for the current stage + :rtype: cute.Pointer + """ + return self.get_origin().consumer_get_barrier( + self._ImmutableResourceHandle__immutable_state # type: ignore[attr-defined] + ) + @dsl_user_op - def release(self, *, loc=None, ip=None): + def release( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Signal that data production is complete for the current stage. This allows consumers to start processing the data. """ self.get_origin().consumer_release( - self._ImmutableResourceHandle__immutable_state, loc=loc, ip=ip + self._ImmutableResourceHandle__immutable_state, # type: ignore[attr-defined] + loc=loc, + ip=ip, ) - def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): + def __init__( + self, pipeline: PipelineAsync, state: PipelineState, group: CooperativeGroup + ): """Initialize a new Consumer instance. :param pipeline: The pipeline this consumer belongs to @@ -1201,18 +1215,31 @@ class PipelineConsumer: self.__group = group self.__state = state - def clone(self): + def clone(self) -> "PipelineConsumer": """Create a new Consumer instance with the same state.""" return PipelineConsumer(self.__pipeline, self.__state.clone(), self.__group) @dsl_user_op - def reset(self, *, loc=None, ip=None): + def reset( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Reset the count of how many handles this consumer has consumed.""" self.__state.reset_count(loc=loc, ip=ip) + def current_handle(self) -> ImmutableResourceHandle: + """Get the current handle for the consumer.""" + return PipelineConsumer.ImmutableResourceHandle(self.__pipeline, self.__state) + @dsl_user_op def wait( - self, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None + self, + try_wait_token: Optional[Boolean] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ImmutableResourceHandle: """Wait for data to be ready in the current buffer. This is a blocking operation that will not return until data is available. @@ -1231,7 +1258,12 @@ class PipelineConsumer: return handle @dsl_user_op - def advance(self, *, loc=None, ip=None): + def advance( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Advance the consumer to the next pipeline stage. This updates the internal state to point to the next buffer in the pipeline. @@ -1241,7 +1273,11 @@ class PipelineConsumer: @dsl_user_op def wait_and_advance( - self, try_wait_token: Optional[Boolean] = None, *, loc=None, ip=None + self, + try_wait_token: Optional[Boolean] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> ImmutableResourceHandle: """Atomically wait for data and advance to next pipeline stage. @@ -1261,7 +1297,12 @@ class PipelineConsumer: return handle @dsl_user_op - def try_wait(self, *, loc=None, ip=None) -> Boolean: + def try_wait( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Boolean: """Non-blocking check if data is ready in the current buffer. This method provides a way to test if data is available without blocking. @@ -1274,8 +1315,12 @@ class PipelineConsumer: @dsl_user_op def release( - self, handle: Optional[ImmutableResourceHandle] = None, *, loc=None, ip=None - ): + self, + handle: Optional[ImmutableResourceHandle] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Signal that data consumption is complete for the current stage. This allows producers to start producing new data. """ @@ -1287,7 +1332,7 @@ class PipelineConsumer: else: self.__pipeline.consumer_release(self.__state, loc=loc, ip=ip) - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: """Extract MLIR values from the current state. :return: List of MLIR values representing the current state @@ -1295,7 +1340,7 @@ class PipelineConsumer: """ return self.__state.__extract_mlir_values__() - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "PipelineConsumer": """Create a new Consumer instance from MLIR values. :param values: MLIR values to initialize the state diff --git a/python/CuTeDSL/cutlass/torch.py b/python/CuTeDSL/cutlass/torch.py index 63757f9e6..f2334e482 100644 --- a/python/CuTeDSL/cutlass/torch.py +++ b/python/CuTeDSL/cutlass/torch.py @@ -13,7 +13,7 @@ import ctypes from math import prod from dataclasses import dataclass from enum import Enum -from typing import Optional, Type, Union, Tuple, Literal +from typing import Any, Optional, Type, Union, Tuple from cutlass.cute.typing import ( Numeric, @@ -23,6 +23,8 @@ from cutlass.cute.typing import ( Float8E4M3FN, Float8E5M2, Float8E8M0FNU, + Float6E3M2FN, + Float6E2M3FN, Float4E2M1FN, Int4, Tensor, @@ -33,7 +35,7 @@ import torch import cuda.bindings.driver as cuda -def dtype(ty: Type[Numeric]): +def dtype(ty: Type[Numeric]) -> "torch.dtype": """ Return the corresponding torch.dtype per the given DSL type """ @@ -60,10 +62,10 @@ def dtype(ty: Type[Numeric]): return torch_dtype -def as_tensor(pointer, shape, torch_type): +def as_tensor(pointer: Any, shape: Any, torch_type: "torch.dtype") -> "torch.Tensor": """Convert a pointer to a torch tensor""" if torch_type.itemsize == 1: - cytype = ctypes.c_uint8 + cytype: type = ctypes.c_uint8 elif torch_type.itemsize == 2: cytype = ctypes.c_uint16 elif torch_type.itemsize == 4: @@ -72,7 +74,7 @@ def as_tensor(pointer, shape, torch_type): cytype = ctypes.c_uint64 else: raise ValueError(f"Unsupported torch dtype: {torch_type}") - cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype)) + cpointer: Any = ctypes.cast(pointer, ctypes.POINTER(cytype)) arr = (cpointer._type_ * prod(shape)).from_address( ctypes.addressof(cpointer.contents) ) @@ -113,9 +115,9 @@ class TensorInitType(Enum): def create_and_permute_torch_tensor( - shape, + shape: Tuple[int, ...], dtype: "torch.dtype", - permute_order=None, + permute_order: Optional[Tuple[int, ...]] = None, init_type: TensorInitType = TensorInitType.RANDOM, init_config: Optional[ Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] @@ -155,7 +157,7 @@ def create_and_permute_torch_tensor( f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std) f32_torch_tensor = f32_torch_tensor * init_config.scale else: - raise ValueError(f"Invalid init type: {init_type}") + raise ValueError(f"Invalid init type: {init_type} ({type(init_type)})") if permute_order is not None: f32_torch_tensor = f32_torch_tensor.permute(permute_order) @@ -172,7 +174,7 @@ def get_leading_dim(torch_tensor: torch.Tensor) -> int: for i, stride in enumerate(torch_tensor.stride()): if stride == 1: return i - return None + return None # type: ignore[return-value] def convert_cute_tensor( @@ -195,6 +197,8 @@ def convert_cute_tensor( Float8E5M2, Float8E4M3FN, Float8E8M0FNU, + Float6E3M2FN, + Float6E2M3FN, Float4E2M1FN, }: fp32_cute_tensor = from_dlpack(f32_torch_tensor) diff --git a/python/CuTeDSL/cutlass/utils/README.md b/python/CuTeDSL/cutlass/utils/README.md deleted file mode 100644 index b4a84681d..000000000 --- a/python/CuTeDSL/cutlass/utils/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Utilities - -This folder contains various utilties for kernel authoring. Specifically, the implementation of the -followings can be considered experimental and subject to breaking changes: - -- static persistent tile scheduler defined in [`static_persistent_tile_scheduler.py`](./static_persistent_tile_scheduler.py) -- dynamic persistent tile scheduler defined in [`dynamic_persistent_tile_scheduler.py`](./dynamic_persistent_tile_scheduler.py) -- pipeline abstractions defined in [`pipeline.py`](./pipeline.py) -- grouped GEMM utilties defined [`grouped_gemm_tile_scheduler_helper.py`](./grouped_gemm_tile_scheduler_helper.py) - and [`tensormap_manager.py`](./tensormap_manager.py) diff --git a/python/CuTeDSL/cutlass/utils/__init__.py b/python/CuTeDSL/cutlass/utils/__init__.py index 74783822f..5929b4520 100644 --- a/python/CuTeDSL/cutlass/utils/__init__.py +++ b/python/CuTeDSL/cutlass/utils/__init__.py @@ -24,9 +24,16 @@ from .dynamic_persistent_tile_scheduler import ( from .hardware_info import HardwareInfo from .blackwell_helpers import ( + cluster_shape_to_tma_atom_A, + cluster_shape_to_tma_atom_B, + cluster_shape_to_tma_atom_SFB, compute_epilogue_tile_shape, + get_permutation_mnk, + get_smem_layout_atom_ab, + get_smem_layout_atom_epi, get_smem_store_op, get_tmem_load_op, + make_smem_layout, make_smem_layout_a, make_smem_layout_b, make_smem_layout_epi, @@ -36,9 +43,11 @@ from .blackwell_helpers import ( from .hopper_helpers import ( sm90_get_smem_store_op, + get_smem_layout_atom as sm90_get_smem_layout_atom, make_smem_layout_a as sm90_make_smem_layout_a, make_smem_layout_b as sm90_make_smem_layout_b, make_smem_layout_epi as sm90_make_smem_layout_epi, + make_trivial_tiled_mma as sm90_make_trivial_tiled_mma, compute_tile_shape_or_override, ) @@ -66,12 +75,15 @@ from .tensormap_manager import ( ) from .smem_allocator import SmemAllocator, get_smem_capacity_in_bytes -from .tmem_allocator import TmemAllocator, get_num_tmem_alloc_cols +from .tmem_allocator import ( + TmemAllocator, + TmemBufferPool, + get_num_tmem_alloc_cols, + compute_tmem_cols_from_layout, +) from .layout import LayoutEnum -from . import distributed - from .mixed_input_helpers import ( TransformMode, scale_tma_partition, @@ -99,6 +111,7 @@ from .mixed_input_helpers import ( ) from . import gemm +from . import distributed from . import hopper_helpers as sm90 from . import blackwell_helpers as sm100 @@ -113,7 +126,9 @@ __all__ = [ "get_smem_capacity_in_bytes", "SmemAllocator", "TmemAllocator", + "TmemBufferPool", "get_num_tmem_alloc_cols", + "compute_tmem_cols_from_layout", "LayoutEnum", "WorkTileInfo", "PersistentTileSchedulerParams", @@ -141,14 +156,23 @@ __all__ = [ "epilogue_tma_store", "epilogue", "create_tensor_a", + "cluster_shape_to_tma_atom_A", + "cluster_shape_to_tma_atom_B", + "cluster_shape_to_tma_atom_SFB", "compute_epilogue_tile_shape", + "get_permutation_mnk", + "get_smem_layout_atom_ab", + "get_smem_layout_atom_epi", "get_smem_store_op", "get_tmem_load_op", + "make_smem_layout", "make_smem_layout_a", "make_smem_layout_b", "make_smem_layout_epi", "make_trivial_tiled_mma", "make_blockscaled_trivial_tiled_mma", + "sm90_get_smem_layout_atom", + "sm90_make_trivial_tiled_mma", "sm90", "sm100", "gemm", diff --git a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py index f1c009ea5..bc75656fa 100644 --- a/python/CuTeDSL/cutlass/utils/blackwell_helpers.py +++ b/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -9,7 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import List, Type, Union, Tuple +from typing import Any, List, Optional, Type, Union, Tuple, overload from typing_extensions import deprecated from cutlass.cutlass_dsl import ( @@ -21,24 +21,29 @@ from cutlass.cutlass_dsl import ( Int8, Float8E4M3FN, Float8E5M2, + Float6E3M2FN, + Float6E2M3FN, Float4E2M1FN, Numeric, NumericMeta, dsl_user_op, ) + +from cutlass._mlir import ir import cutlass.cute as cute -from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.nvgpu.common import CopyUniversalOp, OperandMajorMode from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp from cutlass.cute.nvgpu.tcgen05 import ( MmaF16BF16Op, MmaTF32Op, MmaI8Op, - MmaFP8Op, - MmaMXF8Op, + MmaF8F6F4Op, + MmaMXF8F6F4Op, MmaMXF4Op, MmaMXF4NVF4Op, + SM103MmaMXF4Op, + SM103MmaMXF4NVF4Op, OperandSource as Tcgen05OperandSource, - OperandMajorMode, CtaGroup, Ld16x64bOp, Ld16x128bOp, @@ -58,21 +63,103 @@ from cutlass.cute.nvgpu.cpasync import ( CopyBulkTensorTileG2SOp, ) from cutlass.utils.layout import LayoutEnum +import cutlass.cute.testing as testing # Type alias for documentation clarity OperandSource = Tcgen05OperandSource +TMA_ALIGNMENT_BYTES = 16 + + +def get_tma_aligned_contiguous_elements(elem_ty: Type[Numeric]) -> int: + """Return the minimum contiguous element count for 16B TMA alignment.""" + if not isinstance(elem_ty, NumericMeta): + raise TypeError(f"elem_ty must be Numeric, but got {elem_ty}") + if elem_ty.width == 6: + raise testing.CantImplementError( + "[alignment] FP6 TMA alignment is not validated by the experimental " + "Blackwell GEMM preflight checks yet" + ) + return TMA_ALIGNMENT_BYTES * 8 // elem_ty.width + + +def check_tma_tensor_alignment( + tensor_name: str, + elem_ty: Type[Numeric], + contiguous_dim_size: int, + contiguous_dim_name: str, +) -> None: + """Fail fast when a TMA operand's contiguous dimension is not 16B aligned.""" + required_elements = get_tma_aligned_contiguous_elements(elem_ty) + if contiguous_dim_size % required_elements != 0: + raise testing.CantImplementError( + f"[alignment] {tensor_name} requires the contiguous " + f"{contiguous_dim_name} dimension to be {TMA_ALIGNMENT_BYTES}B aligned, " + f"but got {contiguous_dim_name}={contiguous_dim_size} with dtype {elem_ty}; " + f"expected a multiple of {required_elements} elements" + ) + + +def check_gemm_tma_alignment( + m: int, + n: int, + k: int, + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + d_dtype: Type[Numeric] | None, + a_major: str, + b_major: str, + d_major: str | None, + *, + output_tensor_name: str = "D", +) -> None: + """Validate GEMM operand alignment for TMA loads and an optional TMA store.""" + if a_major not in ["m", "k"]: + raise testing.CantImplementError( + f"[alignment] Invalid a_major: {a_major}, expected 'm' or 'k'" + ) + if b_major not in ["n", "k"]: + raise testing.CantImplementError( + f"[alignment] Invalid b_major: {b_major}, expected 'n' or 'k'" + ) + if d_dtype is not None and d_major not in ["m", "n"]: + raise testing.CantImplementError( + f"[alignment] Invalid d_major: {d_major}, expected 'm' or 'n'" + ) + + check_tma_tensor_alignment( + "A TMA load", + a_dtype, + m if a_major == "m" else k, + "M" if a_major == "m" else "K", + ) + check_tma_tensor_alignment( + "B TMA load", + b_dtype, + n if b_major == "n" else k, + "N" if b_major == "n" else "K", + ) + + if d_dtype is not None: + check_tma_tensor_alignment( + f"{output_tensor_name} TMA store", + d_dtype, + m if d_major == "m" else n, + "M" if d_major == "m" else "N", + ) + @dsl_user_op @deprecated("API is deprecated, use cutlass.utils.get_num_tmem_alloc_cols instead") def get_num_tmem_alloc_cols( tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], - rounding=True, + rounding: bool = True, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> int: import cutlass.utils as utils + return utils.get_num_tmem_alloc_cols( tmem_tensors, rounding, arch="sm_100", loc=loc, ip=ip ) @@ -85,10 +172,10 @@ def compute_epilogue_tile_shape( layout_d: LayoutEnum, elem_ty_d: Type[Numeric], *, - layout_c: LayoutEnum = None, + layout_c: Optional[LayoutEnum] = None, elem_ty_c: Union[Type[Numeric], None] = None, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Tile: """Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. @@ -113,7 +200,7 @@ def compute_epilogue_tile_shape( :raises ValueError: If the computed tile cute.size does not meet minimum requirements based on CTA dimensions. """ - def validate_type(ty, ty_name): + def validate_type(ty: Type[Numeric], ty_name: str) -> None: if not isinstance(ty, NumericMeta): raise TypeError(f"{ty_name} must be Numeric, but got {ty}") @@ -121,33 +208,26 @@ def compute_epilogue_tile_shape( if elem_ty_c is not None: validate_type(elem_ty_c, "elem_ty_c") - cta_m, cta_n = cta_tile_shape[:2] - (warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1) - disable_source = elem_ty_c == None - max_bits = ( - elem_ty_d.width if disable_source else max(elem_ty_c.width, elem_ty_d.width) - ) - - dp_full = 32 - tile_m = min(cta_m, dp_full * warp_m) - n_perf = 0 - if disable_source: - if max_bits == 4: - compute_elts = 8192 - else: - compute_elts = 4096 - n_perf = compute_elts // tile_m - else: - if max_bits == 32: - n_perf = 16 if (cta_m > 64 and cta_n <= 128) else 32 - elif max_bits == 16: - n_perf = 32 if cta_n <= 128 else 64 - else: - n_perf = 64 - + cta_m, cta_n = cta_tile_shape[:2] # type: ignore[index] d_is_m_major = layout_d.is_m_major_c() c_is_m_major = True if layout_c is None else layout_c.is_m_major_c() + # Pure-Python helper (no MLIR context needed) + tile_m, tile_n = compute_epilogue_tile_size( + cta_m, # type: ignore[arg-type] + cta_n, # type: ignore[arg-type] + use_2cta_instrs, + elem_ty_d.width, + elem_ty_c.width if elem_ty_c is not None else None, + d_is_m_major, + c_is_m_major, + ) + + # Compute warp layout parameters (needed for CuTe layout creation) + (warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1) + + # Validate minimum tile requirements + disable_source = elem_ty_c is None n_min_d = ( 8 * warp_n if d_is_m_major @@ -156,17 +236,18 @@ def compute_epilogue_tile_shape( n_min_c = ( 8 * warp_n if (c_is_m_major or disable_source) - else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n) + else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n) # type: ignore[union-attr] ) - tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d)) - - if cta_n < n_min_c or cta_n < n_min_d: + if cta_n < n_min_c or cta_n < n_min_d: # type: ignore[operator] raise ValueError(f"CTA tile too small: {cta_tile_shape=}") # stride by tmem warp layout and return a by-mode tiler tile_m_layout = cute.make_layout(tile_m, loc=loc, ip=ip) tile_n_layout = cute.make_layout( - (tile_n // warp_n, warp_n), stride=(1, cta_n // warp_n), loc=loc, ip=ip + (tile_n // warp_n, warp_n), + stride=(1, cta_n // warp_n), # type: ignore[operator] + loc=loc, + ip=ip, ) return (tile_m_layout, cute.coalesce(tile_n_layout, loc=loc, ip=ip)) @@ -178,8 +259,8 @@ def get_smem_store_op( elem_ty_acc: Type[Numeric], tiled_tmem_load: cute.TiledCopy, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.CopyAtom: """Selects the largest vectorized smem store atom available subject to constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership. @@ -197,7 +278,7 @@ def get_smem_store_op( :rtype: cute.CopyAtom """ - def validate_type(ty, ty_name): + def validate_type(ty: Type[Numeric], ty_name: str) -> None: if not isinstance(ty, NumericMeta): raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") @@ -312,6 +393,7 @@ def get_smem_store_op( ] ) + op: Any if use_stmatrix_m8n8_4x: op = StMatrix8x8x16bOp(is_m_major, 4) return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) @@ -341,8 +423,8 @@ def get_tmem_load_op( epi_tile: cute.Tile, use_2cta_instrs: bool, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.CopyAtom: """Finds a performant TMEM_LOAD copy op for the selected epilogue tile (epi_tile), element types, and tcgen05.mma instruction used. @@ -373,7 +455,7 @@ def get_tmem_load_op( d_bits = elem_ty_d.width tmem_warp_shape_mn = ( - (2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1) + (2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1) # type: ignore[index] ) epilog_tile_shape_mn = cute.product_each( cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip @@ -496,6 +578,7 @@ def get_tmem_load_op( else: raise ValueError("Can not pick tmem_rep based on cta tile shape and tmem atom.") + op: Any if tmem_dp == 16 and tmem_bit == 64: op = Ld16x64bOp( Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE @@ -531,21 +614,20 @@ def get_smem_layout_atom_ab( element_type: Type[Numeric], smem_shape_mn_k: Tuple[int, int], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> SmemLayoutAtomKind: """Simple heuristics to select the optimal SMEM layout atom based on the majorness, the data type, and the major mode size. :param major_mode: The major mode for the SMEM tensor is K major. - :type major_mode: OperandMajorMode + :type major_mode: cutlass.cute.nvgpu.OperandMajorMode :param element_type: The element type for the SMEM tensor. :type element_type: Type[Numeric] :param smem_shape_mn_k: The shape of the SMEM tensor. :type smem_shape_mn_k: Tuple[int, int] - :return: The SMEM layout atom kind - :rtype: SmemLayoutAtomKind + :rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind """ is_k_major = major_mode == OperandMajorMode.K major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0] @@ -556,6 +638,7 @@ def get_smem_layout_atom_ab( sw32_num_contiguous_bits = 256 inter_num_contiguous_bits = 128 major_mode_size_bits = major_mode_size * element_type.width + assert major_mode_size_bits % inter_num_contiguous_bits == 0 if not is_k_major: @@ -586,8 +669,8 @@ def make_smem_layout( a_dtype: Type[Numeric], num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """Construct a staged SMEM layout for an operand given its major mode and tile shape. @@ -599,7 +682,7 @@ def make_smem_layout( 3. Orders the ``(M, N, stage)`` axes so the major dimension is contiguous, then coalesces. :param leading_mode: Operand major mode (``MN`` or ``K``) of the staged operand. - :type leading_mode: cute.nvgpu.tcgen05.OperandMajorMode + :type leading_mode: cutlass.cute.nvgpu.OperandMajorMode :param smem_tile_shape: 2D SMEM tile shape to stage (before the staging dimension is appended). :type smem_tile_shape: cute.Tile :param a_dtype: Element type of the staged operand. @@ -612,7 +695,11 @@ def make_smem_layout( """ smem_layout_atom_kind = get_smem_layout_atom_ab( - leading_mode, a_dtype, smem_tile_shape, loc=loc, ip=ip + leading_mode, + a_dtype, + smem_tile_shape, # type: ignore[arg-type] + loc=loc, + ip=ip, ) smem_layout_atom = make_smem_layout_atom( smem_layout_atom_kind, a_dtype, loc=loc, ip=ip @@ -634,9 +721,9 @@ def make_smem_layout_a( a_dtype: Type[Numeric], num_stages: int, *, - is_k_major=None, - loc=None, - ip=None, + is_k_major: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps with: @@ -659,7 +746,7 @@ def make_smem_layout_a( """ is_k_major = ( - (tiled_mma.op.a_major_mode == OperandMajorMode.K) + (tiled_mma.op.a_major_mode == OperandMajorMode.K) # type: ignore[attr-defined] if is_k_major is None else is_k_major ) @@ -692,9 +779,9 @@ def make_smem_layout_b( b_dtype: Type[Numeric], num_stages: int, *, - is_k_major=None, - loc=None, - ip=None, + is_k_major: Optional[bool] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps: @@ -717,7 +804,7 @@ def make_smem_layout_b( """ is_k_major = ( - (tiled_mma.op.b_major_mode == OperandMajorMode.K) + (tiled_mma.op.b_major_mode == OperandMajorMode.K) # type: ignore[attr-defined] if is_k_major is None else is_k_major ) @@ -750,8 +837,8 @@ def get_smem_layout_atom_epi( element_type: Type[Numeric], epi_tile: cute.Tile, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> SmemLayoutAtomKind: """Simple heuristics to select the optimal SMEM layout atom for epilog tensors. @@ -763,7 +850,7 @@ def get_smem_layout_atom_epi( :type epi_tile: cute.Tile :return: The SMEM layout atom kind - :rtype: SmemLayoutAtomKind + :rtype: cutlass.cute.nvgpu.tcgen05.SmemLayoutAtomKind """ # Get the max contiguous tile usable by TMA tma_shape = tuple( @@ -773,18 +860,26 @@ def get_smem_layout_atom_epi( if isinstance(x, cute.Layout) else x ) - for x in epi_tile + for x in epi_tile # type: ignore[union-attr] ) if layout.is_m_major_c(): # ColMajor C/D (M-major) return get_smem_layout_atom_ab( - OperandMajorMode.MN, element_type, tma_shape, loc=loc, ip=ip + OperandMajorMode.MN, + element_type, + tma_shape, # type: ignore[arg-type] + loc=loc, + ip=ip, ) else: # RowMajor C/D (N-major) return get_smem_layout_atom_ab( - OperandMajorMode.K, element_type, tma_shape, loc=loc, ip=ip + OperandMajorMode.K, + element_type, + tma_shape, # type: ignore[arg-type] + loc=loc, + ip=ip, ) @@ -795,8 +890,8 @@ def make_smem_layout_epi( epi_tile: cute.Tile, epi_stage: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps: @@ -841,7 +936,91 @@ def make_smem_layout_epi( return epi_smem_layout_staged -@dsl_user_op +_F8F6F4_TYPES = { + Float8E4M3FN, + Float8E5M2, + Float6E3M2FN, + Float6E2M3FN, + Float4E2M1FN, +} + + +def _bind_mma_args( + func_name: str, + args: Tuple[Any, ...], + kwargs: dict, + new_params: Tuple[str, ...], + legacy_params: Tuple[str, ...], +) -> Tuple[bool, dict]: + """Bind positional args and kwargs to either the new or legacy parameter list. + + The last entry of each parameter list is treated as optional; all others are + required. Returns ``(is_new_api, bound)`` where ``bound`` maps parameter + names to values. + """ + if "ab_dtype" in kwargs and ("a_dtype" in kwargs or "b_dtype" in kwargs): + raise TypeError( + f"{func_name}() cannot mix legacy 'ab_dtype' with new " + f"'a_dtype'/'b_dtype' keyword arguments" + ) + + if "ab_dtype" in kwargs: + is_new_api = False + elif "a_dtype" in kwargs or "b_dtype" in kwargs: + is_new_api = True + elif len(args) >= 2: + is_new_api = isinstance(args[1], NumericMeta) + else: + # Single (or zero) positional with no dtype kwargs — legacy is deprecated, + # so default to new API; missing-arg validation below produces a clear error. + is_new_api = True + + params = new_params if is_new_api else legacy_params + + if len(args) > len(params): + raise TypeError( + f"{func_name}() takes at most {len(params)} positional arguments but " + f"{len(args)} were given" + ) + + bound: dict = {} + for i, val in enumerate(args): + bound[params[i]] = val + for key, val in kwargs.items(): + if key not in params: + raise TypeError(f"{func_name}() got an unexpected keyword argument '{key}'") + if key in bound: + raise TypeError(f"{func_name}() got multiple values for argument '{key}'") + bound[key] = val + + required = params[:-1] + missing = [p for p in required if p not in bound] + if missing: + raise TypeError( + f"{func_name}() missing required argument(s): {', '.join(missing)}" + ) + + return is_new_api, bound + + +@overload +def make_trivial_tiled_mma( + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + acc_dtype: Type[Numeric], + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: ... + + +@overload +@deprecated("use make_trivial_tiled_mma with separate a_dtype and b_dtype instead") def make_trivial_tiled_mma( ab_dtype: Type[Numeric], a_leading_mode: OperandMajorMode, @@ -851,36 +1030,106 @@ def make_trivial_tiled_mma( mma_tiler_mn: Tuple[int, int], a_source: OperandSource = OperandSource.SMEM, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: ... + + +@dsl_user_op +def make_trivial_tiled_mma( + *args: Any, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> cute.TiledMma: """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. By default, the MMA atom is created with SMEM operand source for A. - :param ab_dtype: Data type of operands A and B. - :type ab_dtype: type[Numeric] - :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). - :type a_leading_mode: tcgen05.OperandMajorMode - :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). - :type b_leading_mode: tcgen05.OperandMajorMode - :param acc_dtype: Data type of the accumulator. - :type acc_dtype: type[Numeric] - :param cta_group: The CTA group to use. - :type cta_group: tcgen05.CtaGroup - :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. - :type mma_tiler_mn: Tuple[int, int] - :param a_source: The source of operand A (SMEM by default or TMEM). - :type a_source: cutlass.cute.nvgpu.tcgen05.OperandSource + Supports two calling conventions: - :return: A tiled MMA atom. - :rtype: cute.TiledMma + **New (recommended):** separate ``a_dtype`` and ``b_dtype``:: - :raises TypeError: If the data type is not supported. + make_trivial_tiled_mma( + a_dtype, b_dtype, a_leading_mode, b_leading_mode, + acc_dtype, cta_group, mma_tiler_mn, [a_source]) + + **Legacy (deprecated):** single ``ab_dtype``:: + + make_trivial_tiled_mma( + ab_dtype, a_leading_mode, b_leading_mode, + acc_dtype, cta_group, mma_tiler_mn, [a_source]) """ + import warnings - if ab_dtype in {Float16, BFloat16}: + new_params = ( + "a_dtype", + "b_dtype", + "a_leading_mode", + "b_leading_mode", + "acc_dtype", + "cta_group", + "mma_tiler_mn", + "a_source", + ) + legacy_params = ( + "ab_dtype", + "a_leading_mode", + "b_leading_mode", + "acc_dtype", + "cta_group", + "mma_tiler_mn", + "a_source", + ) + + is_new_api, bound = _bind_mma_args( + "make_trivial_tiled_mma", args, kwargs, new_params, legacy_params + ) + bound.setdefault("a_source", OperandSource.SMEM) + + if not is_new_api: + warnings.warn( + "make_trivial_tiled_mma with ab_dtype is deprecated, " + "use the overload with separate a_dtype and b_dtype instead", + DeprecationWarning, + stacklevel=2, + ) + a_dtype = bound["ab_dtype"] + b_dtype = bound["ab_dtype"] + else: + a_dtype = bound["a_dtype"] + b_dtype = bound["b_dtype"] + + return _make_trivial_tiled_mma_impl( + a_dtype, + b_dtype, + bound["a_leading_mode"], + bound["b_leading_mode"], + bound["acc_dtype"], + bound["cta_group"], + bound["mma_tiler_mn"], + bound["a_source"], + loc=loc, + ip=ip, + ) + + +def _make_trivial_tiled_mma_impl( + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + acc_dtype: Type[Numeric], + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: + mma_op: Any + if a_dtype in {Float16, BFloat16} and b_dtype == a_dtype: mma_op = MmaF16BF16Op( - ab_dtype, + a_dtype, acc_dtype, (*mma_tiler_mn, 16), cta_group, @@ -888,7 +1137,7 @@ def make_trivial_tiled_mma( a_leading_mode, b_leading_mode, ) - elif ab_dtype in {TFloat32, Float32}: + elif a_dtype in {TFloat32, Float32} and b_dtype == a_dtype: mma_op = MmaTF32Op( (*mma_tiler_mn, 8), cta_group, @@ -896,21 +1145,19 @@ def make_trivial_tiled_mma( a_leading_mode, b_leading_mode, ) - elif ab_dtype in { - Uint8, - Int8, - }: + elif a_dtype in {Uint8, Int8} and b_dtype == a_dtype: mma_op = MmaI8Op( - ab_dtype, + a_dtype, (*mma_tiler_mn, 32), cta_group, a_source, a_leading_mode, b_leading_mode, ) - elif ab_dtype in {Float8E4M3FN, Float8E5M2}: - mma_op = MmaFP8Op( - ab_dtype, + elif a_dtype in _F8F6F4_TYPES and b_dtype in _F8F6F4_TYPES: + mma_op = MmaF8F6F4Op( + a_dtype, + b_dtype, acc_dtype, (*mma_tiler_mn, 32), cta_group, @@ -919,14 +1166,36 @@ def make_trivial_tiled_mma( b_leading_mode, ) else: - raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + raise TypeError( + f"unsupported a_dtype/b_dtype, got a_dtype: {a_dtype}, b_dtype: {b_dtype}" + ) return cute.make_tiled_mma( cute.make_mma_atom(mma_op, loc=loc, ip=ip), loc=loc, ip=ip ) -@dsl_user_op +@overload +def make_blockscaled_trivial_tiled_mma( + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + sf_dtype: Type[Numeric], + sf_vec_size: int, + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: ... + + +@deprecated( + "use make_blockscaled_trivial_tiled_mma with separate a_dtype and b_dtype instead" +) +@overload def make_blockscaled_trivial_tiled_mma( ab_dtype: Type[Numeric], a_leading_mode: OperandMajorMode, @@ -937,44 +1206,112 @@ def make_blockscaled_trivial_tiled_mma( mma_tiler_mn: Tuple[int, int], a_source: OperandSource = OperandSource.SMEM, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: ... + + +@dsl_user_op +def make_blockscaled_trivial_tiled_mma( + *args: Any, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + **kwargs: Any, ) -> cute.TiledMma: """Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. By default, the MMA atom is created with SMEM operand source for A. - :param ab_dtype: Data type of operands A and B. - :type ab_dtype: type[Numeric] - :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). - :type a_leading_mode: tcgen05.OperandMajorMode - :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). - :type b_leading_mode: tcgen05.OperandMajorMode - :param sf_dtype: Data type of the Scale Factor. - :type sf_dtype: type[Numeric] - :param sf_vec_size: The vector size of the Scale Factor. - :type sf_vec_size: int - :param cta_group: The CTA group to use. - :type cta_group: tcgen05.CtaGroup - :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. - :type mma_tiler_mn: Tuple[int, int] - :param a_source: The source of operand A (SMEM by default or TMEM). - :type a_source: cutlass.cute.nvgpu.tcgen05.OperandSource + Supports two calling conventions: - :return: A tiled MMA atom. - :rtype: cute.TiledMma + **New (recommended):** separate ``a_dtype`` and ``b_dtype``:: - :raises TypeError: If the data type is not supported. + make_blockscaled_trivial_tiled_mma( + a_dtype, b_dtype, a_leading_mode, b_leading_mode, + sf_dtype, sf_vec_size, cta_group, mma_tiler_mn, [a_source]) + + **Legacy (deprecated):** single ``ab_dtype``:: + + make_blockscaled_trivial_tiled_mma( + ab_dtype, a_leading_mode, b_leading_mode, + sf_dtype, sf_vec_size, cta_group, mma_tiler_mn, [a_source]) """ - if ab_dtype in {Float8E4M3FN, Float8E5M2}: - mma_op = MmaMXF8Op( - ab_dtype, - (*mma_tiler_mn, 32), - cta_group, - a_source, - a_leading_mode, - b_leading_mode, + import warnings + + new_params = ( + "a_dtype", + "b_dtype", + "a_leading_mode", + "b_leading_mode", + "sf_dtype", + "sf_vec_size", + "cta_group", + "mma_tiler_mn", + "a_source", + ) + legacy_params = ( + "ab_dtype", + "a_leading_mode", + "b_leading_mode", + "sf_dtype", + "sf_vec_size", + "cta_group", + "mma_tiler_mn", + "a_source", + ) + + is_new_api, bound = _bind_mma_args( + "make_blockscaled_trivial_tiled_mma", + args, + kwargs, + new_params, + legacy_params, + ) + bound.setdefault("a_source", OperandSource.SMEM) + + if not is_new_api: + warnings.warn( + "make_blockscaled_trivial_tiled_mma with ab_dtype is deprecated, " + "use the overload with separate a_dtype and b_dtype instead", + DeprecationWarning, + stacklevel=2, ) - elif ab_dtype == Float4E2M1FN: + a_dtype = bound["ab_dtype"] + b_dtype = bound["ab_dtype"] + else: + a_dtype = bound["a_dtype"] + b_dtype = bound["b_dtype"] + + return _make_blockscaled_trivial_tiled_mma_impl( + a_dtype, + b_dtype, + bound["a_leading_mode"], + bound["b_leading_mode"], + bound["sf_dtype"], + bound["sf_vec_size"], + bound["cta_group"], + bound["mma_tiler_mn"], + bound["a_source"], + loc=loc, + ip=ip, + ) + + +def _make_blockscaled_trivial_tiled_mma_impl( + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + sf_dtype: Type[Numeric], + sf_vec_size: int, + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: + mma_op: Any + if a_dtype == Float4E2M1FN and b_dtype == Float4E2M1FN: if sf_vec_size == 32: mma_op = MmaMXF4Op( (*mma_tiler_mn, 64), @@ -990,8 +1327,20 @@ def make_blockscaled_trivial_tiled_mma( ) else: raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}") + elif a_dtype in _F8F6F4_TYPES and b_dtype in _F8F6F4_TYPES: + mma_op = MmaMXF8F6F4Op( + a_dtype, + b_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) else: - raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + raise TypeError( + f"unsupported a_dtype/b_dtype, got a_dtype: {a_dtype}, b_dtype: {b_dtype}" + ) return cute.make_tiled_mma( cute.make_mma_atom(mma_op, loc=loc, ip=ip), loc=loc, ip=ip @@ -1000,7 +1349,11 @@ def make_blockscaled_trivial_tiled_mma( @dsl_user_op def cluster_shape_to_tma_atom_A( - cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None + cluster_shape_mnk: cute.Shape, + atom_thr_id: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: """ Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag. @@ -1046,7 +1399,11 @@ def cluster_shape_to_tma_atom_A( @dsl_user_op def cluster_shape_to_tma_atom_B( - cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None + cluster_shape_mnk: cute.Shape, + atom_thr_id: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: """ Select the appropriate TMA copy atom for Bbased on the number of SMs and the multicast flag. @@ -1092,7 +1449,11 @@ def cluster_shape_to_tma_atom_B( @dsl_user_op def cluster_shape_to_tma_atom_SFB( - cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None + cluster_shape_mnk: cute.Shape, + atom_thr_id: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: """ Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag. @@ -1134,14 +1495,59 @@ def cluster_shape_to_tma_atom_SFB( ) +@dsl_user_op +def sm120_get_smem_store_op( + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem store atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_d : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_d : Type[Numeric] + The element type for output tensor D. + + elem_ty_acc : Type[Numeric] + The element type for accumulator. + + Returns: + -------- + Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. + """ + + def validate_type(ty: Type[Numeric], ty_name: str) -> None: + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + validate_type(elem_ty_acc, "elem_ty_acc") + + is_m_major = layout_d.is_m_major_c() + + if elem_ty_d.width == 16: + return cute.make_copy_atom( + StMatrix8x8x16bOp(is_m_major, 2), elem_ty_d, loc=loc, ip=ip + ) + else: + return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) + + @dsl_user_op def get_permutation_mnk( tile_shape_mnk: cute.Shape, sf_vec_size: int, use_mxf8f6f4: bool, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[int, int, int]: """ Get the permutation of M, N, K for the tiled MMA. @@ -1158,7 +1564,7 @@ def get_permutation_mnk( :raise ValueError: If the tile shape is not divisible by the sf_vec_size """ - perm_m = min(tile_shape_mnk[0], 128) + perm_m = min(tile_shape_mnk[0], 128) # type: ignore[index] # refer to C++ code: # /include/cutlass/gemm/collective/builders/sm120_common.inl?ref_type=heads#L158 if sf_vec_size == 32 or sf_vec_size == 16: @@ -1175,10 +1581,573 @@ def get_permutation_mnk( perm_k, ) - return permutation_mnk + return permutation_mnk # type: ignore[return-value] + + +def sm103_make_blockscaled_trivial_tiled_mma( + sf_dtype: Type[Numeric], + sf_vec_size: int, + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, +) -> cute.TiledMma: + """Create a blockscaled trivial tiled MMA for SM103 (Ultra FP4), K fixed to 96. + + Returns a tcgen05 MMA configured for the given (M, N) tiler and CTA group. + + :param sf_dtype: Data type of the scale factor (typically 8-bit) + :type sf_dtype: Type[Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param cta_group: The CTA group configuration + :type cta_group: CtaGroup + :param mma_tiler_mn: The MMA tiler dimensions (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param a_source: Source location for operand A (SMEM by default) + :type a_source: OperandSource + + :return: A tiled MMA atom configured for SM103 blockscaled operations + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + :raises ValueError: If the sf_vec_size is not supported. + """ + mma_op: Any + if sf_vec_size == 32: + mma_op = SM103MmaMXF4Op( + (*mma_tiler_mn, 96), + cta_group, + a_source, + ) + elif sf_vec_size == 16: + mma_op = SM103MmaMXF4NVF4Op( + sf_dtype, + (*mma_tiler_mn, 96), + cta_group, + a_source, + ) + else: + raise ValueError(f"Unsupported sf_vec_size: {sf_vec_size}. Expected 16 or 32.") + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + +@dsl_user_op # type: ignore[no-redef] +def sm120_get_smem_store_op( + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem store atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_d : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_d : Type[Numeric] + The element type for output tensor D. + + elem_ty_acc : Type[Numeric] + The element type for accumulator. + + Returns: + -------- + Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. + """ + + def validate_type(ty: Type[Numeric], ty_name: str) -> None: + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + validate_type(elem_ty_acc, "elem_ty_acc") + + is_m_major = layout_d.is_m_major_c() + + if elem_ty_d.width == 16: + return cute.make_copy_atom( + StMatrix8x8x16bOp(is_m_major, 2), elem_ty_d, loc=loc, ip=ip + ) + else: + return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) + + + +def compute_epilogue_tile_size( + cta_tile_m: int, + cta_tile_n: int, + use_2cta: bool, + elem_width_d: int, + elem_width_c: int | None = None, + d_is_m_major: bool = True, + c_is_m_major: bool = True, +) -> tuple[int, int]: + """Compute epilogue subtile dimensions ``(tile_m, tile_n)`` (pure Python, no MLIR). + + Used by :func:`compute_epilogue_tile_shape` and at kernel-discovery time + for SMEM capacity estimation. Must match the C++ logic in + ``cutlass/include/cutlass/epilogue/collective/builders/sm100_builder.inl`` + (the ``EpilogueTileAuto`` branch) and + ``cutlass_ir/compiler/lib/Collective/SM100.cpp`` + (``sm100_compute_tile_shape_or_override``). + + Background – SM100 epilogue flow + --------------------------------- + After the MMA, the accumulator lives in **TMEM** (Tensor Memory, 128 + datapaths x N columns). The epilogue: + + 1. Loads a subtile from TMEM into registers (``tcgen05.tmem_load``). + 2. Optionally loads a source tile C from GMEM -> SMEM -> registers. + 3. Applies the fusion (bias, activation, ...) in registers. + 4. Stores the result D to SMEM, then to GMEM via TMA. + + The full CTA output (``cta_tile_m x cta_tile_n``) is processed in + multiple **epilogue iterations**, each covering one subtile of shape + ``(tile_m, tile_n)``. Each subtile is worked on by 4 warps arranged + in a ``(warp_m, warp_n)`` grid. + + This function picks ``(tile_m, tile_n)`` to balance three constraints: + + * **TMEM load** only supports 16 or 32 datapaths per warp, which + caps ``tile_m``. + * **TMA store alignment** requires a minimum contiguous transaction + size, which sets ``n_min`` floors. + * **SMEM budget** -- the epilogue subtile lives in SMEM alongside the + mainloop's A/B pipeline buffers. A larger subtile means fewer + epilogue iterations (good) but steals SMEM from the mainloop, + reducing pipeline depth (bad). ``n_perf`` targets the sweet spot + found by benchmarking. + + Algorithm + --------- + 1. **Warp grid** ``(warp_m, warp_n)``: + + * ``(2, 2)`` when ``use_2cta and cta_tile_m == 64`` -- each of + the 2 M-warps gets 32 datapaths. + * ``(4, 1)`` otherwise -- 4 warps split M, each gets + ``tile_m / 4`` datapaths (16 or 32). + + 2. **tile_m** ``= min(cta_tile_m, 32 * warp_m)`` + + ``32`` is ``dp_full``, the number of TMEM datapaths in one + subpartition (hardware constant). Capping here ensures each + warp owns at most 32 datapaths, which is the widest mode the + ``tcgen05.tmem_load`` instruction supports. + + 3. **n_perf** -- performance target for N: + + * **Without source C** (elementwise-only epilogue): SMEM pressure + is low because there is no C tile to stage. Target a fixed + element count per iteration so that each iteration does enough + work to amortise the epilogue overhead. The constants are:: + + 4096 elements (general) -> e.g. tile_m=128, n=32 + 8192 elements (4-bit) -> e.g. tile_m=128, n=64 + + 4-bit elements are half a byte each, so doubling the count + keeps roughly the same SMEM footprint while cutting epilogue + iterations in half (experimentally best for 4-bit types). + + * **With source C** (residual-load epilogue): the source tile C + also occupies SMEM, so the epilogue tile must be smaller to + leave room for mainloop A/B pipeline stages. Targets are + chosen by element width and CTA shape to balance SMEM + partitioning:: + + 32-bit elements: n=16 when M>64 and N<=128, else n=32 + 16-bit elements: n=32 when N<=128, else n=64 + <=8-bit elements: n=64 + + Wider elements consume more bytes per element, so N is + reduced to stay within the SMEM budget. When CTA N is + large (>128), N is increased because the mainloop tile is + also large and SMEM is more abundant. + + After choosing, ``n_perf`` is halved until it evenly divides + ``cta_tile_n`` (ensures the CTA output tiles evenly into + subtiles with no ragged remainder). + + 4. **n_min_d, n_min_c** -- hard minimums from TMA store alignment: + + * **M-major** (contiguous dim is M): ``8 * warp_n``. N is the + strided dimension so the minimum is small (8 elements per + warp is enough for the store to issue). + * **N-major** (contiguous dim is N): each TMA store transaction + is 128 bits wide, so the minimum contiguous N per warp is + ``128 / elem_width`` elements, times ``warp_n`` warps. + * **FP6 special case**: TMA store only supports the SW128B + swizzle mode for 6-bit types, requiring 128 contiguous + elements per warp, i.e. ``128 * warp_n``. + + 5. **tile_n** ``= min(cta_tile_n, max(n_perf, n_min_c, n_min_d))``. + + If the chosen N doesn't evenly divide ``cta_tile_n``, fall back + to ``cta_tile_n`` (process the full N in one iteration). + + :param cta_tile_m: Per-CTA tile size in M. + :param cta_tile_n: Per-CTA tile size in N. + :param use_2cta: True when 2-CTA (2-SM) MMA instructions are used. + :param elem_width_d: Bit-width of output element type D (e.g. 16). + :param elem_width_c: Bit-width of source element type C, or ``None`` + if the epilogue has no source (elementwise-only). + :param d_is_m_major: ``True`` if D is column-major (M-contiguous). + :param c_is_m_major: ``True`` if C is column-major (M-contiguous). + :return: ``(tile_m, tile_n)`` -- epilogue subtile dimensions. + """ + # -- Step 1: warp grid ------------------------------------------------ + # (2,2) for 2CTA+M64 so each warp gets 32 rows; else (4,1). + (warp_m, warp_n) = (2, 2) if (cta_tile_m == 64 and use_2cta) else (4, 1) + disable_source = elem_width_c is None + max_bits = elem_width_d if disable_source else max(elem_width_c, elem_width_d) # type: ignore[type-var] + + # -- Step 2: tile_m --------------------------------------------------- + # 32 datapaths per subpartition (hardware constant); cap so each warp + # owns <=32 datapaths, the widest tcgen05.tmem_load mode. + dp_full = 32 + tile_m = min(cta_tile_m, dp_full * warp_m) + + # -- Step 3: n_perf (performance target for N) ------------------------ + if disable_source: + # No source C: SMEM pressure is low. Target a fixed element count + # per iteration to amortise epilogue overhead. + # 4-bit: doubled to 8192 to cut iterations while keeping same SMEM bytes. + compute_elts = 8192 if max_bits == 4 else 4096 + n_perf = compute_elts // tile_m + else: + # With source C: SMEM must also hold the C tile. Smaller N leaves + # more SMEM for mainloop A/B pipeline stages. + if max_bits == 32: + n_perf = 16 if (cta_tile_m > 64 and cta_tile_n <= 128) else 32 + elif max_bits == 16: + n_perf = 32 if cta_tile_n <= 128 else 64 + else: + n_perf = 64 + + # Ensure n_perf evenly divides cta_tile_n (no ragged last subtile). + while cta_tile_n % n_perf != 0: + n_perf //= 2 + + # -- Step 4: n_min (TMA store alignment floors) ----------------------- + # M-major: N is strided -> small minimum (8 elems * warp_n). + # N-major: N is contiguous -> 128-bit TMA granularity -> 128/width * warp_n. + # FP6: TMA store only supports SW128B -> 128 * warp_n. + n_min_d = ( + 8 * warp_n + if d_is_m_major + else (128 * warp_n if elem_width_d == 6 else 128 // elem_width_d * warp_n) + ) + n_min_c = ( + 8 * warp_n + if (c_is_m_major or disable_source) + else (128 * warp_n if elem_width_c == 6 else 128 // elem_width_c * warp_n) # type: ignore[operator] + ) + + # -- Step 5: tile_n --------------------------------------------------- + tile_n = min(cta_tile_n, max(n_perf, n_min_c, n_min_d)) + # If the final tile_n doesn't divide cta_tile_n, fall back to full N. + if cta_tile_n % tile_n != 0: + tile_n = cta_tile_n + + return (tile_m, tile_n) + + +def compute_acc_tmem_cols_per_stage( + cta_tile_m: int, + cta_tile_n: int, + use_2cta: bool, + mma_n: int, + transform_a_source_is_tmem: bool, +) -> int: + """ + Compute the accumulator TMEM column footprint for one pipeline stage. + + Returns the **raw** layout footprint — the caller must enforce hardware + allocation constraints (min 32 columns, power-of-2 total) at the final + ``alloc_tmem`` call site. See ``TmemAllocator.check_valid_num_columns`` + in ``cutlass/utils/tmem_allocator.py``. + + Replicates the C++ logic from without requiring an + MLIR context, so it can be used at kernel discovery time. When an + MLIR context is available, prefer + ``cutlass.cute.nvgpu.tcgen05.find_tmem_tensor_col_offset`` which + computes the column count directly from the compiler-generated + TMEM layout. + + **How TMEM packing works** + + TMEM has 128 datapaths (rows). M maps to datapaths, N maps to + columns. When fewer than 128 DPs are needed, multiple N-values can + share a column by occupying different DP rows:: + + NonInterleaved (each N-tile owns its columns): + + columns 0..N-1 columns N..2N-1 + ┌────────────────┐ ┌────────────────┐ + DP │ ████ tile 0 │ │ ████ tile 1 │ ← 64 DPs used + 0- │ │ │ │ + 127 │ ···· unused │ │ ···· unused │ ← 64 DPs wasted + └────────────────┘ └────────────────┘ + Total: 2N columns + + Interleaved (pairs of N-tiles share columns): + + columns 0..N-1 + ┌────────────────┐ + DP │ ████ tile 0 │ ← DPs 0-15, 32-47, 64-79, 96-111 + 0- │ ▓▓▓▓ tile 1 │ ← DPs 16-31, 48-63, 80-95, 112-127 + 127 │ │ + └────────────────┘ + Total: N columns (halved) + + For **1CTA**, the accumulator uses Interleaved when A is from SMEM + *and* ``cta_tile_m`` == 64. NonInterleaved is forced when + ``cta_tile_m`` == 128 (all datapaths already occupied) or when A is + from TMEM (each datapath can only access its own row, so A and C + must share the same M-to-datapath mapping). + + For **2CTA**, each SM has its own 128-DP TMEM and the fragment layout + is computed for the per-CTA shape (M = ``cta_tile_m``, which the + caller must set to the per-CTA value). Because ``cta_tile_m`` only + occupies part of the 128 DPs, the remaining rows can hold additional + N-values in the same column. The number of N-values that share a + column is 128 / ``cta_tile_m``, so the columns needed are + ``cta_tile_n`` / (128 / ``cta_tile_m``): + + - ``cta_tile_m`` = 32 → 128/32 = 4 per column → ``cta_tile_n`` / 4 + - ``cta_tile_m`` = 64 → 128/64 = 2 per column → ``cta_tile_n`` / 2 + - ``cta_tile_m`` = 128 → 128/128 = 1 (no sharing) → ``cta_tile_n`` + + :param cta_tile_m: Per-CTA tile size in M dimension (for 2CTA the + caller divides the full tile M by 2). + :param cta_tile_n: CTA tile size in N dimension. + :param use_2cta: Whether 2CTA MMA instructions are used. + :param mma_n: MMA atom size in N dimension. + :param transform_a_source_is_tmem: Whether operand A is sourced from + TMEM (forces NonInterleaved allocation). + :return: TMEM columns per accumulator stage (before HW constraints). + """ + if use_2cta: + if cta_tile_m <= 32: + return cta_tile_n // 4 + elif cta_tile_m <= 64: + return cta_tile_n // 2 + else: + return cta_tile_n + else: + # Interleaved when (A from SMEM) and (M == 64) + # NonInterleaved otherwise + is_interleaved = (not transform_a_source_is_tmem) and (cta_tile_m == 64) + if is_interleaved: + num_n_tiles = cta_tile_n // mma_n + packed_groups = (num_n_tiles + 1) // 2 + return packed_groups * mma_n + else: + return cta_tile_n + + +def thrfrg_SFA( + sfa_tensor: cute.Tensor, + tiled_mma: cute.TiledMma, +) -> cute.Tensor: + """Thread-fragment scale factor A tensor for SM120 block-scaled MMA. + + Implements the ThrFrg partitioning for scale factor A according to the + corresponding C++ code. + """ + assert cute.rank(sfa_tensor) >= 2 + + atom_shape_mnk = tiled_mma.shape_mnk + atom_sfa_layout = cute.make_layout(shape=((2, 2, 8), 64), stride=((8, 0, 1), 16)) + permutation_mnk = tiled_mma.permutation_mnk + thr_layout_vmnk = tiled_mma.thr_layout_vmnk + + # Reorder the tensor for TiledAtom + t_tile = (permutation_mnk[0], permutation_mnk[2]) + t_tensor = cute.logical_divide(sfa_tensor, t_tile) + + # Tile the tensor for the Atom + a_tile = ( + cute.make_layout((atom_shape_mnk[0])), + cute.make_layout((atom_shape_mnk[2])), + ) + a_tensor = cute.zipped_divide(t_tensor, a_tile) + + # Transform the Atom mode from (M,K) to (Thr,Val) + tv_tensor = cute.composition(a_tensor, (atom_sfa_layout, None)) + + # Tile the tensor for the Thread + thr_tile = ( + None, + ( + cute.make_layout(cute.size(thr_layout_vmnk[1])), + cute.make_layout(cute.size(thr_layout_vmnk[3])), + ), + ) + + thr_tensor = cute.zipped_divide(tv_tensor, thr_tile) + + return thr_tensor + + +def thrfrg_SFB( + sfb_tensor: cute.Tensor, + tiled_mma: cute.TiledMma, +) -> cute.Tensor: + """Thread-fragment scale factor B tensor for SM120 block-scaled MMA. + + Implements the ThrFrg partitioning for scale factor B according to the + corresponding C++ code. + """ + assert cute.rank(sfb_tensor) >= 2 + + atom_shape_mnk = tiled_mma.shape_mnk + atom_sfb_layout = cute.make_layout(shape=((4, 8), 64), stride=((0, 1), 8)) + permutation_mnk = tiled_mma.permutation_mnk + thr_layout_vmnk = tiled_mma.thr_layout_vmnk + + # Reorder the tensor for TiledAtom + t_tile = (permutation_mnk[1], permutation_mnk[2]) + t_tensor = cute.logical_divide(sfb_tensor, t_tile) + + # Tile the tensor for the Atom + a_tile = ( + cute.make_layout((atom_shape_mnk[1])), + cute.make_layout((atom_shape_mnk[2])), + ) + a_tensor = cute.zipped_divide(t_tensor, a_tile) + + # Transform the Atom mode from (M,K) to (Thr,Val) + tv_tensor = cute.composition(a_tensor, (atom_sfb_layout, None)) + + # Tile the tensor for the Thread + thr_tile = ( + None, + ( + cute.make_layout(cute.size(thr_layout_vmnk[2])), + cute.make_layout(cute.size(thr_layout_vmnk[3])), + ), + ) + + thr_tensor = cute.zipped_divide(tv_tensor, thr_tile) + + return thr_tensor + + +def partition_fragment_SFA( + sfa_tensor: cute.Tensor, + thr_mma: cute.ThrMma, + tidx: int, +) -> cute.Tensor: + """Partition and create a register fragment for scale factor A.""" + thrfrg_sfa_layout = thrfrg_SFA(sfa_tensor.layout, thr_mma) # type: ignore[arg-type] + thr_tensor = cute.make_tensor(sfa_tensor.iterator, thrfrg_sfa_layout) + thr_vmnk = thr_mma.thr_layout_vmnk.get_flat_coord(tidx) + thr_vmk = (thr_vmnk[0], (thr_vmnk[1], thr_vmnk[3])) + partitioned_sfa = thr_tensor[thr_vmk, (None, None)] + partitioned_sfa = cute.group_modes(cute.flatten(partitioned_sfa), 0, 2) + return cute.make_fragment_like(partitioned_sfa) + + +def partition_fragment_SFB( + sfb_tensor: cute.Tensor, + thr_mma: cute.ThrMma, + tidx: int, +) -> cute.Tensor: + """Partition and create a register fragment for scale factor B.""" + thrfrg_sfb_layout = thrfrg_SFB(sfb_tensor.layout, thr_mma) # type: ignore[arg-type] + thr_tensor = cute.make_tensor(sfb_tensor.iterator, thrfrg_sfb_layout) + thr_vmnk = thr_mma.thr_layout_vmnk.get_flat_coord(tidx) + thr_vnk = (thr_vmnk[0], (thr_vmnk[2], thr_vmnk[3])) + partitioned_sfb = thr_tensor[thr_vnk, (None, None)] + partitioned_sfb = cute.group_modes(cute.flatten(partitioned_sfb), 0, 2) + partitioned_sfb = cute.group_modes(partitioned_sfb, 1, 3) + return cute.make_fragment_like(partitioned_sfb) + + +def get_layoutSFA_TV(tiled_mma: cute.TiledMma) -> cute.Layout: + """Get the Thread-Value layout for scale factor A.""" + if tiled_mma.permutation_mnk is not None: + perm_m = tiled_mma.permutation_mnk[0] + perm_k = tiled_mma.permutation_mnk[2] + tile_m = cute.size(perm_m) + tile_k = cute.size(perm_k) + else: + tile_shape_mnk = tiled_mma.shape_mnk * tiled_mma.thr_layout_vmnk + tile_m = cute.size(tile_shape_mnk[0]) + tile_k = cute.size(tile_shape_mnk[2]) + + ref_A = cute.make_layout((tile_m, tile_k)) + thr_layout_vmnk = tiled_mma.thr_layout_vmnk + + # (ThrV, (ThrM, ThrK)) -> (ThrV, (ThrM, ThrN, ThrK)) + atile = ( + None, + ( + cute.make_layout( + shape=( + cute.size(thr_layout_vmnk[1]), + cute.size(thr_layout_vmnk[2]), + ), + stride=(1, 0), + ), + None, + ), + ) + + # thr_idx -> (ThrV,ThrM,ThrN,ThrK) + thridx_2_thrid = cute.right_inverse(thr_layout_vmnk) + thrfrg_sfa = thrfrg_SFA(ref_A, tiled_mma) + layout_tv_1 = cute.composition(thrfrg_sfa, (atile, None)) + layout_tv = cute.composition(layout_tv_1, (thridx_2_thrid, None)) + + return layout_tv # type: ignore[return-value] + + +def get_layoutSFB_TV(tiled_mma: cute.TiledMma) -> cute.Layout: + """Get the Thread-Value layout for scale factor B.""" + if tiled_mma.permutation_mnk is not None: + perm_n_layout = tiled_mma.permutation_mnk[1] + perm_k = tiled_mma.permutation_mnk[2] + tile_n = cute.size(perm_n_layout) + tile_k = cute.size(perm_k) + else: + tile_shape_mnk = tiled_mma.shape_mnk * tiled_mma.thr_layout_vmnk + tile_n = cute.size(tile_shape_mnk[1]) + tile_k = cute.size(tile_shape_mnk[2]) + + ref_B = cute.make_layout((tile_n, tile_k)) + thr_layout_vmnk = tiled_mma.thr_layout_vmnk + + # (ThrV, (ThrM, ThrK)) -> (ThrV, (ThrM, ThrN, ThrK)) + atile = ( + None, + ( + cute.make_layout( + shape=( + cute.size(thr_layout_vmnk[1]), + cute.size(thr_layout_vmnk[2]), + ), + stride=(0, 1), + ), + None, + ), + ) + + # thr_idx -> (ThrV,ThrM,ThrN,ThrK) + thridx_2_thrid = cute.right_inverse(thr_layout_vmnk) + thrfrg_sfb = thrfrg_SFB(ref_B, tiled_mma) + layout_tv = cute.composition(thrfrg_sfb, (atile, None)) + layout_tv = cute.composition(layout_tv, (thridx_2_thrid, None)) + return layout_tv # type: ignore[return-value] __all__ = [ + "compute_epilogue_tile_size", + "compute_acc_tmem_cols_per_stage", "compute_epilogue_tile_shape", "get_smem_store_op", "get_tmem_load_op", @@ -1192,4 +2161,10 @@ __all__ = [ "cluster_shape_to_tma_atom_SFB", "get_permutation_mnk", "get_num_tmem_alloc_cols", # deprecated; use cutlass.utils.get_num_tmem_alloc_cols instead + "thrfrg_SFA", + "thrfrg_SFB", + "partition_fragment_SFA", + "partition_fragment_SFB", + "get_layoutSFA_TV", + "get_layoutSFB_TV", ] diff --git a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py index b87c701db..7bc4ee6f4 100644 --- a/python/CuTeDSL/cutlass/utils/blockscaled_layout.py +++ b/python/CuTeDSL/cutlass/utils/blockscaled_layout.py @@ -10,12 +10,14 @@ # is strictly prohibited. from dataclasses import dataclass, field +from typing import Optional from cutlass.cutlass_dsl import dsl_user_op import cutlass.cute as cute -from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +from cutlass.cute.nvgpu import OperandMajorMode +from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir @@ -65,8 +67,8 @@ def tile_atom_to_shape_SF( Shape: cute.Shape, sf_vec_size: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """ A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. @@ -90,8 +92,8 @@ def make_smem_layout_sf( sf_vec_size: int, num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """ A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. @@ -105,13 +107,22 @@ def make_smem_layout_sf( """ smem_layout = cute.tile_to_shape( - BlockScaledBasicChunk(sf_vec_size).layout, tile_shape, (2, 1) + BlockScaledBasicChunk(sf_vec_size).layout, + tile_shape, # type: ignore[arg-type] + (2, 1), + loc=loc, + ip=ip, ) smem_layout_staged = cute.append( smem_layout, cute.make_layout( - num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + num_stages, + stride=cute.cosize(cute.filter_zeros(smem_layout)), + loc=loc, + ip=ip, ), + loc=loc, + ip=ip, ) return smem_layout_staged @@ -123,8 +134,8 @@ def make_smem_layout_sfa( sf_vec_size: int, num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """ Make smem layout for SFA based on: @@ -148,20 +159,23 @@ def make_smem_layout_sfa( """ # (CTA_Tile_Shape_M, MMA_Tile_Shape_K) sfa_tile_shape = ( - mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), - mma_tiler_mnk[2], + mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), # type: ignore[index] + mma_tiler_mnk[2], # type: ignore[index] ) # ((Atom_M, Rest_M),(Atom_K, Rest_K)) smem_layout = cute.tile_to_shape( BlockScaledBasicChunk(sf_vec_size).layout, - sfa_tile_shape, + sfa_tile_shape, # type: ignore[arg-type] (2, 1), ) - mma_tile_inst_k = 4 + # Number of MMA instructions to cover all k-tiles + mma_tile_inst_m = mma_tiler_mnk[0] // cute.size(tiled_mma.shape_mnk, mode=[0]) # type: ignore[index] + mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index] + # (CTA_Tile_Shape_M, MMA_Inst_Shape_K) - sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k)) + sfa_tile_shape = cute.shape_div(sfa_tile_shape, (mma_tile_inst_m, mma_tile_inst_k)) # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)) smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape) @@ -188,8 +202,8 @@ def make_smem_layout_sfb( sf_vec_size: int, num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """ Make smem layout for SFB based on: @@ -213,20 +227,23 @@ def make_smem_layout_sfb( """ # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K) sfb_tile_shape = ( - cute.round_up(mma_tiler_mnk[1], 128), - mma_tiler_mnk[2], + cute.round_up(mma_tiler_mnk[1], 128), # type: ignore[index, arg-type] + mma_tiler_mnk[2], # type: ignore[index] ) # ((Atom_N, Rest_N),(Atom_K, Rest_K)) smem_layout = cute.tile_to_shape( BlockScaledBasicChunk(sf_vec_size).layout, - sfb_tile_shape, + sfb_tile_shape, # type: ignore[arg-type] (2, 1), ) - mma_tile_inst_k = 4 + # Number of MMA instructions to cover all k-tiles + mma_tile_inst_n = mma_tiler_mnk[1] // cute.size(tiled_mma.shape_mnk, mode=[1]) # type: ignore[index] + mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index] + # (CTA_Tile_Shape_N, MMA_Inst_Shape_K) - sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k)) + sfb_tile_shape = cute.shape_div(sfb_tile_shape, (mma_tile_inst_n, mma_tile_inst_k)) # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K) smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape) @@ -253,8 +270,8 @@ def sm120_make_smem_layout_sfa( sf_vec_size: int, num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """ Make smem layout for SFA based on: @@ -288,26 +305,26 @@ def sm120_make_smem_layout_sfa( k_basic_block_shape = (sf_vec_size, mma_nsf) k_basic_block_stride = (0, 1) - assert tile_shape_mnk[0] % blk_mn == 0, ( + assert tile_shape_mnk[0] % blk_mn == 0, ( # type: ignore[index, operator] "tile_shape_mnk[0] must be divisible by blk_mn" ) - sSFA_shapeM = (mn_basic_block_shape, tile_shape_mnk[0] // blk_mn) + sSFA_shapeM = (mn_basic_block_shape, tile_shape_mnk[0] // blk_mn) # type: ignore[index, operator] sSF_strideM = (mn_basic_block_stride, blk_elems) - assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( + assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index] "tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf" ) sSFA_shapeK = ( k_basic_block_shape, blk_sf // mma_nsf, - tile_shape_mnk[2] // sf_vec_size // blk_sf, + tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator] ) sSF_strideK = ( k_basic_block_stride, mma_nsf, - tile_shape_mnk[0] // blk_mn * blk_elems, + tile_shape_mnk[0] // blk_mn * blk_elems, # type: ignore[index, operator] ) sSFA_shape = (sSFA_shapeM, sSFA_shapeK) @@ -333,8 +350,8 @@ def sm120_make_smem_layout_sfb( sf_vec_size: int, num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """ Make smem layout for SFB based on: @@ -364,11 +381,11 @@ def sm120_make_smem_layout_sfb( assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32" - assert tile_shape_mnk[1] % blk_mn == 0, ( + assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator] "tile_shape_mnk[1] must be divisible by blk_mn" ) - assert tile_shape_mnk[2] % sf_vec_size == 0, ( + assert tile_shape_mnk[2] % sf_vec_size == 0, ( # type: ignore[index, operator] "tile_shape_mnk[2] must be divisible by sf_vec_size" ) @@ -379,26 +396,26 @@ def sm120_make_smem_layout_sfb( k_basic_block_shape = (sf_vec_size, mma_nsf) k_basic_block_stride = (0, 1) - assert tile_shape_mnk[1] % blk_mn == 0, ( + assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator] "tile_shape_mnk[1] must be divisible by blk_mn" ) - sSFA_shapeN = (mn_basic_block_shape, tile_shape_mnk[1] // blk_mn) + sSFA_shapeN = (mn_basic_block_shape, tile_shape_mnk[1] // blk_mn) # type: ignore[index, operator] sSF_strideN = (mn_basic_block_stride, blk_elems) - assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( + assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index] "tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf" ) sSFA_shapeK = ( k_basic_block_shape, blk_sf // mma_nsf, - tile_shape_mnk[2] // sf_vec_size // blk_sf, + tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator] ) sSF_strideK = ( k_basic_block_stride, mma_nsf, - tile_shape_mnk[1] // blk_mn * blk_elems, + tile_shape_mnk[1] // blk_mn * blk_elems, # type: ignore[index, operator] ) sSFA_shape = (sSFA_shapeN, sSFA_shapeK) @@ -424,8 +441,8 @@ def make_tmem_layout_sfa( sf_vec_size: int, smem_layout: cute.Layout, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """Make tmem layout for SFA based on: @@ -447,7 +464,7 @@ def make_tmem_layout_sfa( :rtype: cute.Layout """ atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip) - cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index] sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa( smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size @@ -462,8 +479,8 @@ def make_tmem_layout_sfb( sf_vec_size: int, smem_layout: cute.Layout, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Layout: """Make tmem layout for SFB based on: @@ -485,9 +502,156 @@ def make_tmem_layout_sfb( :rtype: cute.Layout """ atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip) - cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index] sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb( smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size ) return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class Sm103BlockScaledBasicChunk: + """ + Basic scale-factor atom layout decided by tcgen05 BlockScaled MMA Ops on SM103. + + Represents the fixed layout pattern for scale factors used by tcgen05 + BlockScaled MMA Ops on SM103. The layout is determined by the instruction + specification and is not configurable. + """ + + sf_vec_size: int + major_mode: OperandMajorMode = OperandMajorMode.K + _layout: cute.Layout = field(init=False, repr=False) + + def __post_init__(self) -> None: + atom_shape: cute.Shape + atom_stride: cute.Stride + if self.major_mode == OperandMajorMode.K: + atom_shape = ((8, 4, 4), (self.sf_vec_size, 4)) + atom_stride = ((16, 128, 4), (0, 1)) + else: + atom_shape = ((self.sf_vec_size, 4), (8, 4, 4)) + atom_stride = ((0, 1), (16, 128, 4)) + + object.__setattr__( + self, "_layout", cute.make_layout(shape=atom_shape, stride=atom_stride) + ) + + @property + def layout(self) -> cute.Layout: + return self._layout + + +@dsl_user_op +def sm103_make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make SMEM layout for SFA based on: + 1) Sm103BlockScaledBasicChunk, 2) MMA tiler, 3) sf_vec_size, 4) stages. + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler: The mma tiler shape + :type mma_tiler: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + mma_shape_mk = tiled_mma.partition_shape_A((mma_tiler[0], mma_tiler[2])) # type: ignore[index] + sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.a_major_mode).layout # type: ignore[attr-defined] + k_divisor = 4 if sf_vec_size == 16 else 2 + mma_sfa_tiler = ( + mma_shape_mk[0][0] * mma_shape_mk[1], + mma_shape_mk[0][1] * mma_shape_mk[2] // k_divisor, + ) + sfa_smem_atom_layout = cute.tiled_product( + sf_atom, + cute.make_layout( + cute.shape_div(mma_sfa_tiler, cute.product_each(sf_atom.shape)) + ), + ) + sfa_smem_layout_staged = cute.make_layout( + shape=cute.append(sfa_smem_atom_layout.shape, num_stages), + stride=cute.append( + sfa_smem_atom_layout.stride, + cute.size(cute.filter_zeros(sfa_smem_atom_layout)), + ), + ) + return sfa_smem_layout_staged + + +@dsl_user_op +def sm103_make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make SMEM layout for SFB based on the basic chunk, MMA tiler, sf_vec_size, stages. + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler: The mma tiler shape + :type mma_tiler: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFB + :rtype: cute.Layout + """ + sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.b_major_mode).layout # type: ignore[attr-defined] + k_divisor = 4 if sf_vec_size == 16 else 2 + mma_sfb_tiler = (mma_tiler[1], mma_tiler[2] // k_divisor) # type: ignore[index, operator] + if mma_sfb_tiler[0] == 128: + sfb_smem_atom_layout = cute.tiled_product( + sf_atom, + cute.make_layout( + cute.shape_div(mma_sfb_tiler, cute.product_each(sf_atom.shape)) + ), + ) + else: + sf_k_major_atom256 = cute.make_layout( + shape=( + (32, 4, 2), + (sf_vec_size, 4), + ), + stride=( + (16, 4, mma_sfb_tiler[1] // sf_vec_size // 4 * 512), + (0, 1), + ), + ) + sfb_smem_atom_layout = cute.tiled_product( + sf_k_major_atom256, + cute.make_layout( + cute.shape_div( + mma_sfb_tiler, cute.product_each(sf_k_major_atom256.shape) + ) + ), + ) + + sfb_smem_layout_staged = cute.make_layout( + shape=cute.append(sfb_smem_atom_layout.shape, num_stages), + stride=cute.append( + sfb_smem_atom_layout.stride, + cute.size(cute.filter_zeros(sfb_smem_atom_layout)), + ), + ) + return sfb_smem_layout_staged diff --git a/python/CuTeDSL/cutlass/utils/distributed.py b/python/CuTeDSL/cutlass/utils/distributed.py index 725bf7cb2..2adcc47a8 100644 --- a/python/CuTeDSL/cutlass/utils/distributed.py +++ b/python/CuTeDSL/cutlass/utils/distributed.py @@ -10,20 +10,17 @@ # is strictly prohibited. from functools import partial -from typing import Tuple, Union +from typing import Literal, Optional, Tuple, Type, Union import cutlass import cutlass.cute as cute -from cutlass.cute.typing import Pointer, Int32, Literal -from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass.cute.typing import Pointer, Int32 +from cutlass.cutlass_dsl import Numeric, T, dsl_user_op from cutlass._mlir import ir from cutlass._mlir.dialects import llvm from typing_extensions import deprecated __all__ = [ - # Deprecated - "atomicAdd", - "ld_bypass", # Message Passing Lock & Unlock "multimem_red_add1", "red_add1", @@ -64,8 +61,14 @@ __all__ = [ @deprecated("atomicAdd is deprecated, use cute.arch.atomic_add instead") @dsl_user_op -def atomicAdd(dst_ptr: Pointer, val: Int32, *, loc=None, ip=None) -> Int32: - return cute.arch.atomic_add( +def atomicAdd( + dst_ptr: Pointer, + val: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Int32: + return cute.arch.atomic_add( # type: ignore[return-value] ptr=dst_ptr.llvm_ptr, val=val, sem="relaxed", @@ -79,7 +82,7 @@ def atomicAdd(dst_ptr: Pointer, val: Int32, *, loc=None, ip=None) -> Int32: "ld_bypass is deprecated, use cute.arch.load with cop='cv' directly instead" ) @cute.jit -def ld_bypass(input_tensor: cute.Tensor): +def ld_bypass(input_tensor: cute.Tensor) -> cute.Tensor: fragment = cute.make_rmem_tensor(input_tensor.layout, input_tensor.element_type) copy_atom = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), @@ -101,8 +104,8 @@ def ld_bypass(input_tensor: cute.Tensor): def multimem_red_release_gpu_add1( lock_ptr: Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: llvm.inline_asm( None, @@ -120,8 +123,8 @@ def multimem_red_release_gpu_add1( def multimem_red_release_sys_add1( lock_ptr: Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: llvm.inline_asm( None, @@ -138,8 +141,8 @@ def multimem_red_release_sys_add1( @dsl_user_op def multimem_red_relaxed_gpu_add1( lock_ptr: Pointer, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: llvm.inline_asm( None, @@ -156,8 +159,8 @@ def multimem_red_relaxed_gpu_add1( @dsl_user_op def multimem_red_relaxed_sys_add1( lock_ptr: Pointer, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: llvm.inline_asm( None, @@ -177,8 +180,8 @@ def multimem_red_add1( *, order: str, scope: str, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ add 1 to multicast ptr @@ -201,8 +204,8 @@ def red_add1( *, order: str, scope: str, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ add 1 to unicast ptr @@ -226,8 +229,8 @@ def spin_lock_atom_cas_relaxed_wait( expected_val: Int32, reset_val: Int32, scope: Literal["gpu", "sys"], - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ wait on a spin lock until the expected count is reached. Reset flag to reset_val if the expected count is reached. @@ -252,8 +255,8 @@ def spin_lock_atom_cas_acquire_wait( expected_val: Int32, reset_val: Int32, scope: Literal["gpu", "sys"], - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ wait on a spin lock until the expected count is reached. Reset flag to reset_val if the expected count is reached. @@ -277,8 +280,8 @@ def spin_lock_ld_lt_relaxed_wait( *, expected_val: Int32, scope: Literal["gpu", "sys"], - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ wait on a spin lock until the expected count is reached. @@ -306,8 +309,8 @@ def multimem_ld_reduce_128bit_base( mc_ptr: Pointer, *, ptx_string: str = "", - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Int32, Int32, Int32, Int32]: mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) return_struct = llvm.inline_asm( @@ -330,8 +333,8 @@ def multimem_ld_reduce_64bit_base( mc_ptr: Pointer, *, ptx_string: str = "", - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Int32, Int32]: mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) return_struct = llvm.inline_asm( @@ -354,8 +357,8 @@ def multimem_ld_reduce_32bit_base( mc_ptr: Pointer, *, ptx_string: str = "", - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Int32]: mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) return_struct = llvm.inline_asm( @@ -448,8 +451,8 @@ def multimem_st_4xb32( z: Int32, w: Int32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) llvm.inline_asm( @@ -471,8 +474,8 @@ def multimem_st_2xb32( x: Int32, y: Int32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) llvm.inline_asm( @@ -493,8 +496,8 @@ def multimem_st_1xb32( mc_ptr: Pointer, x: Int32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value(loc=loc, ip=ip) llvm.inline_asm( @@ -518,10 +521,10 @@ def multimem_st_1xb32( def multimem_ld_reduce( mc_ptr: Pointer, *, - dtype, + dtype: Type[Numeric], num_elements: int, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[Tuple[Int32, Int32, Int32, Int32], Tuple[Int32, Int32], Tuple[Int32]]: """ Dispatch to appropriate multimem_ld_reduce variant based on dtype and num_elements. @@ -576,8 +579,8 @@ def multimem_ld_reduce( def multimem_st( mc_ptr: Pointer, *regs: Int32, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: """ Dispatch to appropriate multimem_st variant based on number of registers. diff --git a/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py index 2e043951d..05122e3f4 100644 --- a/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/dynamic_persistent_tile_scheduler.py @@ -9,18 +9,19 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Tuple +import inspect +from typing import Optional, Tuple +import cutlass from cutlass.cutlass_dsl import ( Boolean, Integer, Int32, - min, extract_mlir_values, new_from_mlir_values, dsl_user_op, - T, ) + from cutlass._mlir import ir from cutlass.utils.static_persistent_tile_scheduler import ( WorkTileInfo, @@ -38,14 +39,17 @@ class ClcDynamicPersistentTileSchedulerParams: :type cluster_shape_mn: tuple """ + @dsl_user_op def __init__( self, problem_shape_ntile_mnl: cute.Shape, cluster_shape_mnk: cute.Shape, + swizzle_size: int = 1, + raster_along_m: bool = True, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Initializes the ClcDynamicPersistentTileSchedulerParams with the given parameters. @@ -54,40 +58,177 @@ class ClcDynamicPersistentTileSchedulerParams: :type problem_shape_ntile_mnl: cute.Shape :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions. :type cluster_shape_mnk: cute.Shape + :param swizzle_size: Swizzling size in the unit of cluster. 1 means no swizzle + :type swizzle_size: int + :param raster_along_m: Rasterization order of clusters. Only used when swizzle_size > 1. + True means along M, false means along N. + :type raster_along_m: bool :raises ValueError: If cluster_shape_k is not 1. """ - if cluster_shape_mnk[2] != 1: - raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") + if cluster_shape_mnk[2] != 1: # type: ignore[index] + raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") # type: ignore[index] + if swizzle_size < 1: + raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") self.problem_shape_ntile_mnl = problem_shape_ntile_mnl # cluster_shape_mnk is kept for reconstruction self._cluster_shape_mnk = cluster_shape_mnk - self.cluster_shape_mn = cluster_shape_mnk[:2] + self.cluster_shape_mn = cluster_shape_mnk[:2] # type: ignore[index] + self.swizzle_size = swizzle_size + self._raster_along_m = raster_along_m + self.cluster_shape_major_fdd = None + self.cluster_shape_minor_fdd = None self._loc = loc - def __extract_mlir_values__(self): + # By default, we follow m major (col-major) raster order, so make a col-major layout + self.problem_layout_ncluster_mnl = cute.make_layout( + cute.ceil_div( + self.problem_shape_ntile_mnl, + cluster_shape_mnk[:2], # type: ignore[index] + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + + # Apply swizzle if swizzle_size > 1 + if swizzle_size > 1: + problem_shape_ncluster_mnl = cute.round_up( + self.problem_layout_ncluster_mnl.shape, + (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1), + ) + + if raster_along_m: + self.problem_layout_ncluster_mnl = cute.make_layout( + ( + problem_shape_ncluster_mnl[0], # type: ignore[index] + (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), # type: ignore[index, operator] + problem_shape_ncluster_mnl[2], # type: ignore[index] + ), + stride=( + swizzle_size, + (1, swizzle_size * problem_shape_ncluster_mnl[0]), # type: ignore[index] + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], # type: ignore[index, operator] + ), + loc=loc, + ip=ip, + ) + else: + self.problem_layout_ncluster_mnl = cute.make_layout( + ( + (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), # type: ignore[index, operator] + problem_shape_ncluster_mnl[1], # type: ignore[index] + problem_shape_ncluster_mnl[2], # type: ignore[index] + ), + stride=( + (1, swizzle_size * problem_shape_ncluster_mnl[1]), # type: ignore[index] + swizzle_size, + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], # type: ignore[index, operator] + ), + loc=loc, + ip=ip, + ) + elif not raster_along_m: + cluster_count_major = self.problem_layout_ncluster_mnl.shape[1] + cluster_count_minor = self.problem_layout_ncluster_mnl.shape[0] + self.cluster_shape_major_fdd = cute.fast_divmod_create_divisor( + cluster_count_major, loc=loc, ip=ip + ) + self.cluster_shape_minor_fdd = cute.fast_divmod_create_divisor( + cluster_count_minor, loc=loc, ip=ip + ) + + def __extract_mlir_values__(self) -> list[ir.Value]: values, self._values_pos = [], [] - for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]: + for obj in [ + self.problem_shape_ntile_mnl, + self._cluster_shape_mnk, + self.swizzle_size, + self._raster_along_m, + ]: obj_values = extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) + + # Add FastDivmod divisors to MLIR values for Host->Device transfer + # Only add non-None values to avoid MLIR type errors + fastdivmod_values = [] + fastdivmod_indices = [] # Track which FastDivmod objects are present + + for i, (fdd_name, fdd_obj) in enumerate( + [ + ("cluster_shape_major_fdd", self.cluster_shape_major_fdd), + ("cluster_shape_minor_fdd", self.cluster_shape_minor_fdd), + ] + ): + if fdd_obj is not None: + # Extract MLIR values from FastDivmodDivisor objects + fdd_values = extract_mlir_values(fdd_obj) + fastdivmod_values.extend(fdd_values) + fastdivmod_indices.append(i) + + values += fastdivmod_values + self._values_pos.append( + len(fastdivmod_indices) + ) # Store count of FastDivmod objects, not values + self._fastdivmod_indices = fastdivmod_indices # Store for reconstruction + return values - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "ClcDynamicPersistentTileSchedulerParams": obj_list = [] + values_copy = list(values) # Make a copy to avoid modifying original + + # Reconstruct original objects from MLIR values for obj, n_items in zip( - [self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos + [ + self.problem_shape_ntile_mnl, + self._cluster_shape_mnk, + self.swizzle_size, + self._raster_along_m, + ], + self._values_pos[:-1], # Exclude FastDivmod count ): - obj_list.append(new_from_mlir_values(obj, values[:n_items])) - values = values[n_items:] - return ClcDynamicPersistentTileSchedulerParams( + obj_list.append(new_from_mlir_values(obj, values_copy[:n_items])) + values_copy = values_copy[n_items:] + + new_params = ClcDynamicPersistentTileSchedulerParams( *(tuple(obj_list)), loc=self._loc ) + # Restore FastDivmod divisors from remaining values + fdd_names = [ + "cluster_shape_major_fdd", + "cluster_shape_minor_fdd", + ] + + if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0: + # Override the FastDivmod divisors created by __init__ with reconstructed ones + for j, original_index in enumerate(self._fastdivmod_indices): + fdd_name = fdd_names[original_index] + # Get the original FastDivmodDivisor object + original_fdd = getattr(self, fdd_name) + if original_fdd is not None and j < len(values_copy): + # Each FastDivmodDivisor has 1 MLIR value + reconstructed_fdd = new_from_mlir_values( + original_fdd, [values_copy[j]] + ) + setattr(new_params, fdd_name, reconstructed_fdd) + + return new_params + @dsl_user_op - def get_grid_shape(self, *, loc=None, ip=None) -> Tuple[Integer, Integer, Integer]: + def get_grid_shape( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tuple[Integer, Integer, Integer]: """ Computes the grid shape based on the problem shape and cluster shape. @@ -97,7 +238,26 @@ class ClcDynamicPersistentTileSchedulerParams: problem_ceiling_cta_mnl = cute.round_up( self.problem_shape_ntile_mnl, self._cluster_shape_mnk ) - return problem_ceiling_cta_mnl + + if self.swizzle_size == 1 and self._raster_along_m: + return problem_ceiling_cta_mnl # type: ignore[return-value] + else: + # If swizzling is enabled or raster_along_n, + # we are going to map from a linear idx to tile id manually. + num_problem_cta_count = cute.size(problem_ceiling_cta_mnl, loc=loc, ip=ip) + num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip) + return ( + *self.cluster_shape_mn, + num_problem_cta_count // num_ctas_per_cluster, + ) + + +# Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator +ClcDynamicPersistentTileSchedulerParams.__init__.__signature__ = inspect.Signature( # type: ignore[attr-defined] + [ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + ] +) class ClcDynamicPersistentTileScheduler: @@ -174,9 +334,9 @@ class ClcDynamicPersistentTileScheduler: grid_dim: Tuple[Integer, Integer, Integer], clc_response_ptr: cute.Pointer, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "ClcDynamicPersistentTileScheduler": """Initialize the dynamic persistent tile scheduler. :param params: Parameters for the persistent @@ -221,8 +381,8 @@ class ClcDynamicPersistentTileScheduler: def get_grid_shape( params: ClcDynamicPersistentTileSchedulerParams, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Integer, Integer, Integer]: """Calculates the grid shape to be launched on GPU using problem shape, threadblock shape, and active cluster size. @@ -236,9 +396,58 @@ class ClcDynamicPersistentTileScheduler: return params.get_grid_shape(loc=loc, ip=ip) + @cute.jit + def _swizzle_and_rasterize( + self, + x_idx: Int32, + y_idx: Int32, + z_idx: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tuple[Int32, Int32, Int32]: + """Swizzle and rasterize the given coordinates for leader CTA of the cluster. + x_idx, y_idx, and z_idx must be divisible by cluster shape x, y, and z respectively. They should not be offset + by the ID of the CTA in the cluster. + """ + if cutlass.const_expr(self.params.swizzle_size == 1): + if cutlass.const_expr(self.params._raster_along_m): + return x_idx, y_idx, z_idx + else: + # Decode linear index using FastDivmod objects. + # First, get cluster_major using cluster_shape_major_fdd + cluster_minor_batch, cluster_major = divmod( # type: ignore[operator] + z_idx, self.params.cluster_shape_major_fdd + ) + # Then decode cluster_minor_batch to get cluster_minor and batch_l using FastDivmod + batch_l, cluster_minor = divmod( + cluster_minor_batch, self.params.cluster_shape_minor_fdd + ) + cluster_m = cluster_minor + cluster_n = cluster_major + + return ( + cluster_m * self.params.cluster_shape_mn[0], + cluster_n * self.params.cluster_shape_mn[1], + batch_l, + ) + else: + cluster_coord = self.params.problem_layout_ncluster_mnl.get_flat_coord( + z_idx, loc=loc, ip=ip + ) + return ( + cluster_coord[0] * self.params.cluster_shape_mn[0], + cluster_coord[1] * self.params.cluster_shape_mn[1], + cluster_coord[2], + ) + @dsl_user_op def work_tile_info_from_clc_response( - self, result_addr: cute.Pointer, *, loc=None, ip=None + self, + result_addr: cute.Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> WorkTileInfo: """ Simulates parsing CLC response data in Python. @@ -249,23 +458,53 @@ class ClcDynamicPersistentTileScheduler: "async.shared", space="cta", ) - cta_idx_in_cluster, cta_idy_in_cluster, _ = self.cta_id_in_cluster + m_idx, n_idx, l_idx = self._swizzle_and_rasterize( + m_idx, n_idx, l_idx, loc=loc, ip=ip + ) + cta_idx_in_cluster, cta_idy_in_cluster, _ = self.cta_id_in_cluster # type: ignore[misc] cur_tile_coord = (m_idx + cta_idx_in_cluster, n_idx + cta_idy_in_cluster, l_idx) + return WorkTileInfo(cur_tile_coord, vld) @dsl_user_op - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def get_current_work( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> WorkTileInfo: smem_addr = self._clc_response_ptr work_tile = self.work_tile_info_from_clc_response(smem_addr) return work_tile @dsl_user_op - def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + def initial_work_tile_info( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> WorkTileInfo: bidx, bidy, bidz = self._block_idx - return WorkTileInfo((bidx, bidy, bidz), True) + # Subtract cta_id_in_cluster from block_idx because swizzle_and_rasterize expects coordinates to be + # those of the leader CTA in the cluster. + cta_idx_in_cluster, cta_idy_in_cluster, _ = self.cta_id_in_cluster # type: ignore[misc] + m_idx = bidx - cta_idx_in_cluster + n_idx = bidy - cta_idy_in_cluster + l_idx = bidz + m_idx, n_idx, l_idx = self._swizzle_and_rasterize( + m_idx, n_idx, l_idx, loc=loc, ip=ip + ) + cur_tile_coord = (m_idx + cta_idx_in_cluster, n_idx + cta_idy_in_cluster, l_idx) + return WorkTileInfo(cur_tile_coord, Boolean(True)) @dsl_user_op - def advance_to_next_work(self, mbarrier_addr, loc=None, ip=None): + def advance_to_next_work( + self, + mbarrier_addr: cute.Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: # Query new work tile with cute.arch.elect_one(): cute.arch.issue_clc_query( diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm100.py b/python/CuTeDSL/cutlass/utils/gemm/sm100.py index 9f180b10d..ead18ca60 100644 --- a/python/CuTeDSL/cutlass/utils/gemm/sm100.py +++ b/python/CuTeDSL/cutlass/utils/gemm/sm100.py @@ -9,11 +9,10 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited -from typing import Tuple, Union +from typing import Any, Optional, Tuple, Union import cutlass.cute as cute from cutlass.cutlass_dsl import Int32, Boolean, Constexpr, const_expr import cutlass.pipeline as pipeline -from cutlass.utils.static_persistent_tile_scheduler import StaticPersistentTileScheduler from cutlass.utils.blackwell_helpers import get_tmem_load_op, get_smem_store_op from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass.cute.nvgpu.common import CacheEvictionPriority @@ -21,8 +20,6 @@ from cutlass.cute.nvgpu.common import CacheEvictionPriority __all__ = [ "epilogue_tma_store", "epilogue", - "epilogue_tma_store_release_flag", - "epilogue_release_flag", ] @@ -39,21 +36,35 @@ def transform_partitioned_tensor_layout(tensor: cute.Tensor) -> cute.Tensor: :rtype: cute.Tensor """ layout = tensor.layout + # Save original layout in case it is a composed layout + stored_layout = layout + + if isinstance(stored_layout, cute.ComposedLayout): + # For composed layouts, we only modify the outer layout + layout = layout.outer + shape = layout.shape - stride = layout.stride + stride = layout.stride # type: ignore[union-attr] # Build new shape: ((shape[0][0], shape[1]), (shape[0][1], shape[2]), ...rest) - new_shape = ((shape[0][0], shape[1]), (shape[0][1], shape[2]), *shape[3:]) + new_shape = ((shape[0][0], shape[1]), (shape[0][1], shape[2]), *shape[3:]) # type: ignore[index] # Build new stride: ((stride[0][0], stride[1]), (stride[0][1], stride[2]), ...rest) - new_stride = ((stride[0][0], stride[1]), (stride[0][1], stride[2]), *stride[3:]) + new_stride = ((stride[0][0], stride[1]), (stride[0][1], stride[2]), *stride[3:]) # type: ignore[index] new_layout = cute.make_layout(shape=new_shape, stride=new_stride) + + if isinstance(stored_layout, cute.ComposedLayout): + # Recreate the composed layout + new_layout = cute.make_composed_layout( + stored_layout.inner, stored_layout.offset, new_layout + ) + return cute.make_tensor(tensor.iterator, new_layout) def epilogue_tmem_copy_and_partition( - gemm_kernel, + gemm_kernel: Any, tidx: Int32, tAcc: cute.Tensor, tCgC: cute.Tensor, @@ -117,7 +128,7 @@ def epilogue_tmem_copy_and_partition( def epilogue_smem_copy_and_partition( - gemm_kernel, + gemm_kernel: Any, tiled_copy_t2r: cute.TiledCopy, tTR_rC: cute.Tensor, tidx: Int32, @@ -155,7 +166,7 @@ def epilogue_smem_copy_and_partition( @cute.jit def epilogue_tma_store( - gemm_kernel, + gemm_kernel: Any, epi_tidx: Int32, warp_idx: Int32, tma_atom_c: cute.CopyAtom, @@ -216,7 +227,7 @@ def epilogue_tma_store( bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] # Set tensor memory buffer for current tile - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] # @@ -230,20 +241,20 @@ def epilogue_tma_store( # # Store accumulator to global memory in subtiles # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) # type: ignore[union-attr] num_prev_subtiles = num_tiles_executed * subtile_cnt for subtile_idx in range(subtile_cnt): # # Load accumulator from tensor memory buffer to register # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] # type: ignore[call-overload] cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) # # Convert to C type # acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - acc_vec = epilogue_op(acc_vec.to(gemm_kernel.c_dtype)) + acc_vec = epilogue_op(acc_vec.to(gemm_kernel.c_dtype)) # type: ignore[operator] tRS_rC.store(acc_vec) # @@ -252,7 +263,10 @@ def epilogue_tma_store( c_buffer = (num_prev_subtiles + subtile_idx) % gemm_kernel.num_c_stage cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy("async.shared", space="cta") + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) epilog_sync_barrier.arrive_and_wait() # @@ -282,7 +296,7 @@ def epilogue_tma_store( @cute.jit def epilogue( - gemm_kernel, + gemm_kernel: Any, epi_tidx: Int32, tCtAcc_base: cute.Tensor, tCgC_base: cute.Tensor, @@ -291,9 +305,9 @@ def epilogue( mma_tile_coord_mnl: Tuple[Int32, Int32, Int32], acc_consumer_state: pipeline.PipelineState, acc_pipeline: pipeline.PipelineAsync, - tCcC_base: cute.Tensor = None, - mC_mnl: cute.Tensor = None, - overlapping_accum: Constexpr = False, + tCcC_base: Optional[cute.Tensor] = None, + mC_mnl: Optional[cute.Tensor] = None, + overlapping_accum: Constexpr = False, # type: ignore[assignment] ) -> pipeline.PipelineState: """ Epilogue function that stores accumulator results directly to global memory. @@ -365,7 +379,7 @@ def epilogue( ) simt_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), + cute.nvgpu.CopyR2GOp(), gemm_kernel.c_dtype, num_bits_per_copy=num_bits_per_copy, l1c_evict_priority=CacheEvictionPriority.NO_ALLOCATE, @@ -376,7 +390,7 @@ def epilogue( # Layout transformation for tCcC_base # ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, TILE_M, TILE_N, TILE_K) # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), TILE_M, TILE_N, TILE_K) - tCcC = transform_partitioned_tensor_layout(tCcC_base) + tCcC = transform_partitioned_tensor_layout(tCcC_base) # type: ignore[arg-type] cC_epi = cute.flat_divide(tCcC, epi_tile) tTR_cC_partitioned = thr_copy_t2r.partition_D(cC_epi) @@ -427,7 +441,7 @@ def epilogue( # # Store accumulator to global memory in subtiles # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) # type: ignore[union-attr] for subtile_idx in range(subtile_cnt): # Compute the actual subtile index real_subtile_idx = subtile_idx @@ -441,7 +455,7 @@ def epilogue( # # Load accumulator from tensor memory buffer to register # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] # type: ignore[call-overload] cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) # @@ -466,7 +480,7 @@ def epilogue( # Convert to C type # acc_vec = tTR_rAcc.load() - acc_vec = epilogue_op(acc_vec.to(gemm_kernel.c_dtype)) + acc_vec = epilogue_op(acc_vec.to(gemm_kernel.c_dtype)) # type: ignore[operator] tTR_rC.store(acc_vec) if const_expr(use_predication): @@ -478,7 +492,8 @@ def epilogue( for n_idx in range(tTR_cC_subtile.shape[2]): vector_first_coord = tTR_cC_subtile[(0, m_idx, n_idx)] pred_C[(0, m_idx, n_idx)] = cute.elem_less( - vector_first_coord, mC_mnl.shape + vector_first_coord, + mC_mnl.shape, # type: ignore[union-attr] ) # Store C to global memory with predication cute.copy(simt_atom, tTR_rC, tTR_gC_subtile, pred=pred_C) @@ -487,363 +502,3 @@ def epilogue( cute.copy(simt_atom, tTR_rC, tTR_gC_subtile) return acc_consumer_state - - -@cute.jit -def epilogue_tma_store_release_flag( - gemm_kernel, - epi_tidx: Int32, - warp_idx: Int32, - acc_pipeline: pipeline.PipelineAsync, - tiled_mma: cute.TiledMma, - tma_atom_c: cute.CopyAtom, - # Input of epilogue - tCtAcc_base: cute.Tensor, - # Staging of epilogue - sC: cute.Tensor, - # Output of epilogue - tCgC_base: cute.Tensor, - epi_tile: cute.Tile, - tile_sched: StaticPersistentTileScheduler, - epilogue_op: Constexpr, - flag_base: cute.Tensor, - flag_mem_scope: str, -) -> None: - # Layout transformation for tCgC_base - # ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, TILE_M, TILE_N, TILE_K) - # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), TILE_M, TILE_N, TILE_K) - tCgC = transform_partitioned_tensor_layout(tCgC_base) - # Layout transformation for tCtAcc_base - # ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE) - # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE) - tCtAcc = transform_partitioned_tensor_layout(tCtAcc_base) - tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = epilogue_tmem_copy_and_partition( - gemm_kernel, epi_tidx, tCtAcc, tCgC, epi_tile, gemm_kernel.use_2cta_instrs - ) - - tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, gemm_kernel.c_dtype) - tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( - gemm_kernel, tiled_copy_t2r, tTR_rC, epi_tidx, sC - ) - - # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) - tCgC_epi = cute.flat_divide(tCgC, epi_tile) - # ((ATOM_V, REST_V), EPI_M, EPI_N) - # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) - bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( - tma_atom_c, - 0, - cute.make_layout(1), - cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), - ) - - acc_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, gemm_kernel.num_acc_stage - ) - - # Threads/warps participating in tma store pipeline - c_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, - 32 * len(gemm_kernel.epilogue_warp_id), - ) - c_pipeline = pipeline.PipelineTmaStore.create( - num_stages=gemm_kernel.num_c_stage, producer_group=c_producer_group - ) - - epilog_sync_barrier = pipeline.NamedBarrier( - barrier_id=gemm_kernel.epilog_sync_bar_id, - num_threads=32 * len(gemm_kernel.epilogue_warp_id), - ) - - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler - cur_tile_coord = work_tile.tile_idx - mma_tile_coord_mnl = ( - cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), - cur_tile_coord[1], - cur_tile_coord[2], - ) - - # - # Slice to per mma tile index - # - # ((ATOM_V, REST_V), EPI_M, EPI_N) - bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] - - # Set tensor memory buffer for current tile - # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) - ] - - # - # Wait for accumulator buffer full - # - acc_pipeline.consumer_wait(acc_consumer_state) - - tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) - bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) - - # - # Store accumulator to global memory in subtiles - # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt - for subtile_idx in range(subtile_cnt): - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - - # - # Convert to C type - # - acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() - acc_vec = epilogue_op(acc_vec.to(gemm_kernel.c_dtype)) - tRS_rC.store(acc_vec) - - # - # Store C to shared memory - # - c_buffer = (num_prev_subtiles + subtile_idx) % gemm_kernel.num_c_stage - cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) - # Fence and barrier to make sure shared memory store is visible to TMA store - cute.arch.fence_proxy("async.shared", space="cta") - epilog_sync_barrier.arrive_and_wait() - - # - # TMA store C to global memory - # - if warp_idx == gemm_kernel.epilogue_warp_id[0]: - cute.copy( - tma_atom_c, - bSG_sC[(None, c_buffer)], - bSG_gC[(None, subtile_idx)], - ) - # Fence and barrier to make sure shared memory store is visible to TMA store - c_pipeline.producer_commit() - c_pipeline.producer_acquire() - epilog_sync_barrier.arrive_and_wait() - - epilog_sync_barrier.arrive_and_wait() - - # - # Async arrive accumulator buffer empty - # - with cute.arch.elect_one(): - acc_pipeline.consumer_release(acc_consumer_state) - acc_consumer_state.advance() - - # - # Set Per Output Tile Flag with Release - # - import cutlass.utils as utils - from cutlass._mlir.dialects.nvvm import ( - MemOrderKind, - MemScopeKind, - ) - - # 1D linear index of current output tile - tile_id_linear = Int32( - tile_sched._current_work_linear_idx - * cute.size(gemm_kernel.cluster_shape_mn) - + cute.arch.block_idx_in_cluster() - ) - # Wait for C store complete - # Unlike regular epilogue where we only wait C store complete once at end of each kernel. - # Here we need to wait for C store complete for each output tile before we set the release flag. - c_pipeline.producer_tail() - # Update flag with release semantic with GPU scope - if warp_idx == gemm_kernel.epilogue_warp_id[0]: - with cute.arch.elect_one(): - flag_curr_tile = flag_base.iterator + tile_id_linear - utils.distributed.multimem_red_add1( - lock_ptr=flag_curr_tile, - order="release", - scope=flag_mem_scope, - ) - - # - # Advance to next tile - # - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - -@cute.jit -def epilogue_release_flag( - gemm_kernel, - epi_tidx: Int32, - acc_pipeline: pipeline.PipelineAsync, - tiled_mma: cute.TiledMma, - tCtAcc_base: cute.Tensor, - tCgC_base: cute.Tensor, - epi_tile: cute.Tile, - tile_sched: StaticPersistentTileScheduler, - epilogue_op: Constexpr, - tmem_dealloc_barrier: pipeline.NamedBarrier, - flag_base: cute.Tensor, - flag_mem_scope: str, -) -> None: - """ - Epilogue function that stores accumulator results directly to global memory. - Used when TMA store is not enabled. - - :param gemm_kernel: The kernel instance - :type gemm_kernel: Any - :param epi_tidx: Thread index in epilogue warp groups - :type epi_tidx: Int32 - :param acc_pipeline: Accumulator pipeline for async operations - :type acc_pipeline: pipeline.PipelineAsync - :param tiled_mma: The tiled MMA configuration - :type tiled_mma: cute.TiledMma - :param tCtAcc_base: Base accumulator tensor in tensor memory - :type tCtAcc_base: cute.Tensor - :param tCgC_base: The global memory tensor C to be copied and partitioned - :type tCgC_base: cute.Tensor - :param epi_tile: Epilogue tile configuration - :type epi_tile: cute.Tile - :param tile_sched: Tile scheduler for persistent scheduling - :type tile_sched: StaticPersistentTileScheduler - :param epilogue_op: Optional elementwise operation to apply - :type epilogue_op: Constexpr - :param tmem_dealloc_barrier: Barrier for tensor memory deallocation - :type tmem_dealloc_barrier: pipeline.NamedBarrier - :param flag_base: Base flag tensor - :type flag_base: cute.Tensor - :param flag_mem_scope: Memory scope for flag - :type flag_mem_scope: str - """ - # Layout transformation for tCgC_base - # ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, TILE_M, TILE_N, TILE_K) - # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), TILE_M, TILE_N, TILE_K) - tCgC = transform_partitioned_tensor_layout(tCgC_base) - # Layout transformation for tCtAcc_base - # ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE) - # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE) - tCtAcc = transform_partitioned_tensor_layout(tCtAcc_base) - # - # Partition for epilogue - # - ( - tiled_copy_t2r, - tTR_tAcc_base, - tTR_rAcc, - ) = epilogue_tmem_copy_and_partition( - gemm_kernel, epi_tidx, tCtAcc, tCgC, epi_tile, gemm_kernel.use_2cta_instrs - ) - - gC_epi = cute.flat_divide(tCgC, epi_tile) - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) - thr_copy_t2r = tiled_copy_t2r.get_slice(epi_tidx) - tTR_gC_partitioned = thr_copy_t2r.partition_D(gC_epi) - # (T2R, T2R_M, T2R_N) - tTR_rC = cute.make_rmem_tensor( - tTR_gC_partitioned[(None, None, None, 0, 0, 0, 0, 0)].shape, gemm_kernel.c_dtype - ) - simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gemm_kernel.c_dtype) - - acc_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, gemm_kernel.num_acc_stage - ) - - work_tile = tile_sched.initial_work_tile_info() - while work_tile.is_valid_tile: - # Get tile coord from tile scheduler - cur_tile_coord = work_tile.tile_idx - mma_tile_coord_mnl = ( - cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), - cur_tile_coord[1], - cur_tile_coord[2], - ) - - # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) - tTR_gC = tTR_gC_partitioned[ - ( - None, - None, - None, - None, - None, - *mma_tile_coord_mnl, - ) - ] - - # Set tensor memory buffer for current tile - # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) - tTR_tAcc = tTR_tAcc_base[ - (None, None, None, None, None, acc_consumer_state.index) - ] - - # - # Wait for accumulator buffer full - # - acc_pipeline.consumer_wait(acc_consumer_state) - - tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) - tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) - # - # Store accumulator to global memory in subtiles - # - subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) - for subtile_idx in range(subtile_cnt): - # - # Load accumulator from tensor memory buffer to register - # - tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] - cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) - # Async arrive accumulator buffer empty - # Release early for perf - if subtile_idx == subtile_cnt - 1: - with cute.arch.elect_one(): - acc_pipeline.consumer_release(acc_consumer_state) - acc_consumer_state.advance() - - # - # Convert to C type - # - acc_vec = tTR_rAcc.load() - acc_vec = epilogue_op(acc_vec.to(gemm_kernel.c_dtype)) - tTR_rC.store(acc_vec) - - # - # Store C directly to global memory - # - cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) - - # - # Set Per Output Tile Flag with Release - # - import cutlass.utils as utils - - # 1D linear index of current output tile - tile_id_linear = Int32( - tile_sched._current_work_linear_idx - * cute.size(gemm_kernel.cluster_shape_mn) - + cute.arch.block_idx_in_cluster() - ) - # Wait for C store complete - # Unlike regular epilogue where we only wait C store complete once at end of each kernel. - # Here we need to wait for C store complete for each output tile before we set the release flag. - c_pipeline.producer_tail() - # Update flag with release semantic with GPU scope - if warp_idx == gemm_kernel.epilogue_warp_id[0]: - with cute.arch.elect_one(): - flag_curr_tile = flag_base.iterator + tile_id_linear - utils.distributed.multimem_red_add1( - lock_ptr=flag_curr_tile, - order="release", - scope=flag_mem_scope, - ) - - # - # Advance to next tile - # - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - # Synchronize before TMEM dealloc (done by the caller) - tmem_dealloc_barrier.arrive_and_wait() diff --git a/python/CuTeDSL/cutlass/utils/gemm/tensor_utils.py b/python/CuTeDSL/cutlass/utils/gemm/tensor_utils.py new file mode 100644 index 000000000..a01ceb71d --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/gemm/tensor_utils.py @@ -0,0 +1,375 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited + +""" +GEMM Tensor Utilities for CuTe DSL + +This module provides end-to-end helpers for creating and wrapping GEMM operands +when using CuTe DSL with PyTorch integration. It consolidates tensor creation, +layout management, and DLPack-based wrapping for GPU kernels. + +Key functions: +- create_gemm_tensor_torch: Allocate PyTorch CUDA tensors with MN-major or K-major layouts +- get_gemm_tensor: Wrap PyTorch tensors as CuTe tensors via DLPack +- get_gemm_tensors: Convenience wrapper for complete GEMM problem (A, B, D) +- create_scale_factor_tensor: Generate block-scaled GEMM scale factors +- decode_float4e2m1fn: Decode packed FP4 tensors to float32 + +Supports FP8/FP4 types with workarounds for DLPack limitations. +""" + +import math +from typing import Tuple, Type + +import torch + +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack +from cutlass.cutlass_dsl import Numeric +import cutlass.torch as cutlass_torch + +__all__ = [ + "create_gemm_tensor_torch", + "get_gemm_tensor", + "get_gemm_tensors", + "create_scale_factor_tensor", + "decode_float4e2m1fn", +] + + +def create_gemm_tensor_torch( + M_or_N: int, + K_or_N: int, + L: int, + major_mode: cute.nvgpu.OperandMajorMode, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Allocate a random GEMM operand as a PyTorch CUDA tensor. + + Returns a tensor of shape ``(M_or_N, K_or_N, L)`` with a physical layout + determined by ``major_mode``: + + * ``MN``-major: stride ``(1, M_or_N, M_or_N * K_or_N)`` + * ``K``-major: stride ``(K_or_N, 1, M_or_N * K_or_N)`` + + Elements are initialised randomly from ``{-1, 0, 1}``. + + :param M_or_N: Size of the M (for A/D) or N (for B) dimension. + :type M_or_N: int + :param K_or_N: Size of the K (for A/B) or N (for D) dimension. + :type K_or_N: int + :param L: Batch dimension (number of independent GEMM problems). + :type L: int + :param major_mode: ``MN`` for MN-major layout, ``K`` for K-major layout. + :type major_mode: cute.nvgpu.OperandMajorMode + :param dtype: Element type of the returned tensor. + :type dtype: torch.dtype + :return: A CUDA tensor of shape ``(M_or_N, K_or_N, L)`` with the + requested layout and dtype. + :rtype: torch.Tensor + """ + + if major_mode == cute.nvgpu.OperandMajorMode.MN: + result = torch.empty(L, K_or_N, M_or_N).permute(2, 1, 0) + elif major_mode == cute.nvgpu.OperandMajorMode.K: + result = torch.empty(L, M_or_N, K_or_N).permute(1, 2, 0) + + if dtype == torch.float4_e2m1fn_x2: + values = torch.tensor( + [ + 0x00, # { 0, 0} + 0x02, # { 0, 1} + 0x0A, # { 0, -1} + 0x20, # { 1, 0} + 0x22, # { 1, 1} + 0x2A, # { 1, -1} + 0xA0, # {-1, 0} + 0xA2, # {-1, 1} + 0xAA, # {-1, -1} + ], + dtype=torch.uint8, + ) + + # Note [dkb 16 Jan '26] we are consciously over-allocating by 2x + # because it makes the code much simpler and terser. + result = result.to(torch.uint8) + result[:] = values[torch.randint(0, len(values), result.size())] + + else: + result = result.random_(-1, 2).to(dtype) + + return result.cuda() + + +def get_gemm_tensor( + torch_tensor: torch.Tensor, + major_mode: cute.nvgpu.OperandMajorMode, + dtype: torch.dtype, +) -> cute.Tensor: + """ + Wrap a PyTorch tensor as a CuTe tensor for passing to a GPU kernel. + + Converts ``torch_tensor`` to a CuTe ``Tensor`` via DLPack and marks the + leading dimension as dynamic so the compiler treats the corresponding + stride as a runtime value: + + * ``K``-major -> leading dim is 1 (K dimension has unit stride) + * ``MN``-major -> leading dim is 0 (MN dimension has unit stride) + + :param torch_tensor: Source PyTorch CUDA tensor, typically created by + :func:`create_gemm_tensor_torch`. + :type torch_tensor: torch.Tensor + :param major_mode: ``MN`` or ``K``, indicating which dimension has unit + stride. + :type major_mode: cute.nvgpu.OperandMajorMode + :param dtype: Logical element type of the tensor. + :type dtype: torch.dtype + :return: A CuTe tensor backed by the same GPU memory, with its leading + dimension marked dynamic and element type set to the CuTe equivalent + of ``dtype``. + :rtype: cutlass.cute.Tensor + """ + + _TYPES_NOT_SUPPORTED_BY_DLPACK = { + torch.float8_e4m3fn: cute.Float8E4M3FN, + torch.float8_e5m2: cute.Float8E5M2, + torch.float8_e4m3fnuz: cute.Float8E4M3B11FNUZ, + torch.float4_e2m1fn_x2: cute.Float4E2M1FN, + } + + dlpack_tensor = torch_tensor + if dtype in _TYPES_NOT_SUPPORTED_BY_DLPACK: + dlpack_tensor = torch_tensor.view(dtype=torch.uint8) + + result = from_dlpack(dlpack_tensor, assumed_align=16).mark_layout_dynamic( + leading_dim=1 if major_mode == cute.nvgpu.OperandMajorMode.K else 0 + ) + + if dtype in _TYPES_NOT_SUPPORTED_BY_DLPACK: + result.element_type = _TYPES_NOT_SUPPORTED_BY_DLPACK[dtype] + + return result + + +def get_gemm_tensors( + M: int, + N: int, + K: int, + L: int, + majors: tuple[ + cute.nvgpu.OperandMajorMode, + cute.nvgpu.OperandMajorMode, + cute.nvgpu.OperandMajorMode, + ], + dtypes: tuple[torch.dtype, torch.dtype, torch.dtype], +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, cute.Tensor, cute.Tensor, cute.Tensor +]: + """ + Allocate all three GEMM operands (A, B, D) as paired PyTorch / CuTe tensors. + + Convenience wrapper around :func:`create_gemm_tensor_torch` and + :func:`get_gemm_tensor` for a complete GEMM problem ``D = A @ B``. + + :param M: Number of rows of A and D. + :type M: int + :param N: Number of columns of B and D. + :type N: int + :param K: Shared (contraction) dimension of A and B. + :type K: int + :param L: Batch dimension (number of independent GEMM problems). + :type L: int + :param majors: ``(major_A, major_B, major_D)`` specifying the physical + layout of each operand. + :type majors: tuple[OperandMajorMode, OperandMajorMode, OperandMajorMode] + :param dtypes: ``(dtype_A, dtype_B, dtype_D)`` specifying the element type + of each operand. + :type dtypes: tuple[torch.dtype, torch.dtype, torch.dtype] + :return: ``(A, B, D, A_cute, B_cute, D_cute)`` where ``A``, ``B``, ``D`` + are PyTorch CUDA tensors and ``A_cute``, ``B_cute``, ``D_cute`` are + CuTe tensors wrapping the same memory. + :rtype: tuple[torch.Tensor, torch.Tensor, torch.Tensor, + cutlass.cute.Tensor, cutlass.cute.Tensor, cutlass.cute.Tensor] + """ + + A = create_gemm_tensor_torch(M, K, L, majors[0], dtypes[0]) + B = create_gemm_tensor_torch(N, K, L, majors[1], dtypes[1]) + D = create_gemm_tensor_torch(M, N, L, majors[2], dtypes[2]) + + A_cute = get_gemm_tensor(A, majors[0], dtypes[0]) + B_cute = get_gemm_tensor(B, majors[1], dtypes[1]) + D_cute = get_gemm_tensor(D, majors[2], dtypes[2]) + + return A, B, D, A_cute, B_cute, D_cute + + +def create_scale_factor_tensor( + MN: int, K: int, L: int, sf_vec_size: int, sf_dtype: Type[Numeric] +) -> Tuple[torch.Tensor, cute.Tensor]: + """ + Create a random scale-factor tensor in BlockScaledBasicChunk layout. + + Allocates a scale-factor tensor for block-scaled GEMM, where each scale + factor covers ``sf_vec_size`` contiguous elements along K. The MN and K + dimensions are padded up to the required atom boundaries. + + Scale factor values are drawn uniformly from ``{1.0, 2.0, 4.0}``. + + Two tensors are returned: a logical FP32 tensor for host-side reference + computation, and a CuTe tensor in the on-device packed layout for passing + to the GPU kernel. + + :param MN: Size of the MN dimension of the operand to be scaled. + :type MN: int + :param K: Size of the K dimension of the operand to be scaled. + :type K: int + :param L: Batch dimension (number of independent GEMM problems). + :type L: int + :param sf_vec_size: Number of contiguous K-elements sharing a single + scale factor (block-scaling granularity along K). + :type sf_vec_size: int + :param sf_dtype: CuTe element type for the on-device scale factors + (e.g. ``cute.Float8E4M3FN``, ``cute.Float8E8M0FNU``). + :return: ``(sf_torch, sf_cute)`` where ``sf_torch`` is an FP32 CPU tensor + of shape ``(MN, K, L)`` with scale factors unpacked into a dense + layout suitable for element-wise multiplication with A or B, and + ``sf_cute`` is a CuTe CUDA tensor in BlockScaledBasicChunk layout + with ``element_type`` set to ``sf_dtype``. + :rtype: tuple[torch.Tensor, cutlass.cute.Tensor] + """ + + def unpack_scale_factors( + sf: torch.Tensor, sf_vec_size: int, MN: int, K: int, L: int + ) -> torch.Tensor: + """ + Unpack a scale-factor tensor from BlockScaledBasicChunk layout to a + dense ``(MN, K, L)`` tensor. + + The on-device SF layout packs scale factors into 512-byte atoms with + a specific index mapping: 128-row MN tiles, 4-element K groups, and + an interleaved ``(mn0, mn1, k1)`` addressing within each atom. This + function inverts that mapping so that the output element at position + ``(m, k, l)`` holds the scale factor that applies to element + ``(m, k, l)`` in A/B, ready for direct element-wise multiplication. + + :param sf: Scale-factor tensor in packed layout, shape + ``(L, m_padded, k_padded)``. + :type sf: torch.Tensor + :param sf_vec_size: Number of K-elements per scale-factor block. + :type sf_vec_size: int + :param MN: Logical (unpadded) MN dimension. + :type MN: int + :param K: Logical (unpadded) K dimension. + :type K: int + :param L: Batch dimension. + :type L: int + :return: Dense FP32 tensor of shape ``(MN, K, L)`` with unpacked + scale factors. + :rtype: torch.Tensor + """ + + def index_map() -> torch.Tensor: + ATOM = 512 + ATOM_MN = 128 + ATOM_K = 4 + DATA_PATHS = 32 + DATA_PATH_STRIDE = ATOM // DATA_PATHS # 16 + K_TILE = ATOM_K * sf_vec_size + + k_tiles = (K + K_TILE - 1) // K_TILE + mn_tiles = (MN + ATOM_MN - 1) // ATOM_MN + sf_per_l = ATOM * mn_tiles * k_tiles + + m, k = torch.meshgrid( + torch.arange(MN, device=sf.device), + torch.arange(K, device=sf.device), + indexing="ij", + ) + + base = ( + (m // ATOM_MN) * (ATOM * k_tiles) + + (k // K_TILE) * ATOM + + DATA_PATH_STRIDE * (m % DATA_PATHS) + + ATOM_K * ((m % ATOM_MN) // DATA_PATHS) + + ((k // sf_vec_size) % ATOM_K) + ) + + l_offsets = torch.arange(L, device=sf.device)[:, None, None] * sf_per_l + return base.unsqueeze(0) + l_offsets + + return sf.flatten()[index_map()].permute(1, 2, 0) + + ATOM_MN = 128 + ATOM_K = 4 + + m_padded = math.ceil(MN / ATOM_MN) * ATOM_MN + k_padded = math.ceil(math.ceil(K / sf_vec_size) / ATOM_K) * ATOM_K + + lut = torch.tensor([1.0, 2.0, 4.0]) # subset of numbers supported by all SF dtypes + sf_torch = lut[torch.randint(0, lut.numel(), (L, m_padded, k_padded))].to( + cutlass_torch.dtype(sf_dtype) + ) + + sf_cute = from_dlpack(sf_torch.cuda().view(dtype=torch.uint8), assumed_align=16) + sf_cute.element_type = sf_dtype + sf_torch = unpack_scale_factors(sf_torch.to(torch.float32), sf_vec_size, MN, K, L) + + return sf_torch, sf_cute + + +def decode_float4e2m1fn(u8: torch.Tensor) -> torch.Tensor: + """ + Decode a packed FP4 (E2M1) tensor into float32. + + Each byte in the input encodes two FP4 values: the low nibble holds the + even-indexed element and the high nibble holds the odd-indexed element. + Because ``create_gemm_tensor_torch`` intentionally over-allocates FP4 + tensors by 2x (one byte per logical element instead of one nibble), only + the first half of the input bytes contain data. This function unpacks + those bytes via a 16-entry LUT covering all representable E2M1 values. + + :param u8: Packed FP4 tensor of shape ``(MN, K, L)`` with dtype + ``torch.uint8``, as produced by :func:`create_gemm_tensor_torch` + with ``dtype=torch.float4_e2m1fn_x2``. + :type u8: torch.Tensor + :return: Decoded float32 tensor of shape ``(MN, K, L)``. + :rtype: torch.Tensor + """ + + lut = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + ) + + MN, K, L = u8.shape + flat = u8.permute(2, 0, 1).flatten() + idx = torch.arange(u8.numel()) + byte_idx = idx // 2 + shift = (idx % 2) * 4 + return lut[(flat[byte_idx] >> shift) & 0xF].view(L, MN, K).permute(1, 2, 0) diff --git a/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py index 9746d304f..efd683cf7 100644 --- a/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/grouped_gemm_persistent_tile_scheduler.py @@ -9,7 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import List, Tuple +from typing import List, Optional, Tuple, Union import cutlass.cute as cute from cutlass.cutlass_dsl import ( @@ -20,7 +20,6 @@ from cutlass.cutlass_dsl import ( new_from_mlir_values, const_expr, dsl_user_op, - min, ) from cutlass._mlir import ir from typing_extensions import deprecated @@ -226,6 +225,8 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): search_state: GroupedGemmGroupSearchState, group_count: int, problem_shape_mnkl: cute.Tensor, + cached_problem_shape_0: cute.Tensor, + cached_problem_shape_1: cute.Tensor, ): StaticPersistentTileScheduler.__init__( self, @@ -241,6 +242,9 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): self.search_state = search_state self.problem_shape_mnkl = problem_shape_mnkl + self.cached_problem_shape_0 = cached_problem_shape_0 + self.cached_problem_shape_1 = cached_problem_shape_1 + def __extract_mlir_values__(self) -> list[ir.Value]: values = extract_mlir_values(self.num_persistent_clusters) values.extend(extract_mlir_values(self._current_work_linear_idx)) @@ -248,13 +252,15 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): values.extend(extract_mlir_values(self._num_tiles_executed)) values.extend(extract_mlir_values(self.search_state)) values.extend(extract_mlir_values(self.problem_shape_mnkl)) + values.extend(extract_mlir_values(self.cached_problem_shape_0)) + values.extend(extract_mlir_values(self.cached_problem_shape_1)) values.extend(extract_mlir_values(self.params)) return values def __new_from_mlir_values__( self, values: list[ir.Value] ) -> "StaticPersistentGroupTileScheduler": - if len(values) < 11: + if len(values) < 13: raise ValueError("Length of mlir values extracted is incorrect.") new_num_persistent_clusters = new_from_mlir_values( self.num_persistent_clusters, [values[0]] @@ -270,7 +276,13 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): ) search_state = new_from_mlir_values(self.search_state, values[6:10]) problem_shape_mnkl = new_from_mlir_values(self.problem_shape_mnkl, [values[10]]) - params = new_from_mlir_values(self.params, values[11:]) + cached_problem_shape_0 = new_from_mlir_values( + self.cached_problem_shape_0, [values[11]] + ) + cached_problem_shape_1 = new_from_mlir_values( + self.cached_problem_shape_1, [values[12]] + ) + params = new_from_mlir_values(self.params, values[13:]) return StaticPersistentGroupTileScheduler( params, @@ -282,6 +294,8 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): search_state, self.group_count, problem_shape_mnkl, + cached_problem_shape_0, + cached_problem_shape_1, ) @staticmethod @@ -295,9 +309,9 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): group_count: int, problem_shape_mnkl: cute.Tensor, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "StaticPersistentGroupTileScheduler": """Initialize the static persistent group-based tile scheduler. :param params: Parameters for the persistent @@ -338,6 +352,13 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): Int32(0), ) + cached_problem_shape_0 = cute.make_rmem_tensor( + cute.make_layout(4), problem_shape_mnkl.element_type + ) + cached_problem_shape_1 = cute.make_rmem_tensor( + cute.make_layout(4), problem_shape_mnkl.element_type + ) + # Initialize number of tiles executed to zero num_tiles_executed = Int32(0) return StaticPersistentGroupTileScheduler( @@ -350,11 +371,28 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): initial_search_state, group_count, problem_shape_mnkl, + cached_problem_shape_0, + cached_problem_shape_1, ) + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed + + # This setter is the main way to prevent the Attribute error right now + @num_tiles_executed.setter + def num_tiles_executed(self, value: Int32) -> None: + self._num_tiles_executed = value + @dsl_user_op @cute.jit - def _prefix_sum(self, value_per_thread: Int32, *, loc=None, ip=None) -> Int32: + def _prefix_sum( + self, + value_per_thread: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Int32: """ Perform prefix sum within a full warp. @@ -377,7 +415,12 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): @dsl_user_op def _get_problem_for_group( - self, problem_shape_mnkl: cute.Tensor, group_idx: Int32, *, loc=None, ip=None + self, + problem_shape_mnkl: cute.Tensor, + group_idx: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Tensor: """ Load gemm problem (m,n,k,l) for the specified group from global memory to register. @@ -397,9 +440,27 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): ) return cur_problem_mnkl + @dsl_user_op + @cute.jit + def prefetch_problem_shapes( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + if self.lane_idx < self.group_count: + cur_problem_mnkl = self._get_problem_for_group( + self.problem_shape_mnkl, self.lane_idx + ) + self.cached_problem_shape_1 = cur_problem_mnkl + @dsl_user_op def _get_cluster_tile_count_mn( - self, problem_shape: cute.Tensor, *, loc=None, ip=None + self, + problem_shape: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Int32: """ Compute total cluster count. @@ -410,13 +471,13 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): :rtype: Int32 """ cur_ntile_m = ( - problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1 + problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1 # type: ignore[operator] ) // self.cluster_tile_shape_mnk[0] cur_ntile_n = ( - problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1 + problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1 # type: ignore[operator] ) // self.cluster_tile_shape_mnk[1] cur_ntile_mn = cur_ntile_m * cur_ntile_n - return cur_ntile_mn + return cur_ntile_mn # type: ignore[return-value] @dsl_user_op def _compute_cta_tile_coord( @@ -426,8 +487,8 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): cluster_tile_count_m: Int32, cluster_tile_count_n: Int32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> tuple: """ Compute CTA tile indices along M and N dimensions based on the linear index within a group. @@ -466,8 +527,8 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): init_group_idx: Int32, init_tile_count_searched: Int32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> GroupedGemmGroupSearchState: """ Search which group the linear index belongs to. @@ -491,6 +552,7 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): not_found = linear_idx >= tile_count_searched start_not_found = not_found tile_count_prev_group = self.search_state.tile_count_prev_group + tidx, _, _ = cute.arch.thread_idx() while not_found and start_group_idx < self.group_count: # get group to search for current lane @@ -498,6 +560,18 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): # check if the group to be checked is out of range inside_group_bound = cur_group_idx < self.group_count + # Rotate cache + self.cached_problem_shape_0 = self.cached_problem_shape_1 + + # Prefetch problem shape for next while iteration + next_prefetch_group_idx = ( + start_group_idx + cute.arch.WARP_SIZE + self.lane_idx + ) + if next_prefetch_group_idx < self.group_count: + self.cached_problem_shape_1 = self._get_problem_for_group( + problem_shape_mnkl, next_prefetch_group_idx + ) + cur_ntile_mn = c_0 if inside_group_bound: # get problem size of current group @@ -538,6 +612,7 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): # If no matched group, then get new_cluster_tile_count_end from last lane # Otherwise, get new_cluster_tile_count_end from the hitted group lane_idx_for_cluster_tile_count_end = hitted_group_idx_in_search_window + if not_found: lane_idx_for_cluster_tile_count_end = last_lane_idx tile_count_searched = cute.arch.shuffle_sync( @@ -545,6 +620,13 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): lane_idx_for_cluster_tile_count_end, ) + # Prefetch problem shape for next wave + if not not_found: + if start_group_idx + self.lane_idx < self.group_count: + self.cached_problem_shape_1 = self._get_problem_for_group( + problem_shape_mnkl, start_group_idx + self.lane_idx + ) + # The tile is invalid if not_found doesn't change before and after the while loop. end_not_found = not_found is_valid = start_not_found != end_not_found @@ -565,9 +647,9 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): start_group_idx: Int32, tile_count_searched: Int32, *, - loc=None, - ip=None, - ) -> Tuple[Int32, cute.Tensor]: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> Tuple[Boolean, Union[Int32, int], cute.Tensor]: """ Perform group search and load problem shape for the matched group. @@ -590,16 +672,16 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): loc=loc, ip=ip, ) + tidx, _, _ = cute.arch.thread_idx() # get final group search state found = self.search_state.found - final_group_idx = -1 + final_group_idx: Union[Int32, int] = -1 problem_mnkl = cute.make_rmem_tensor( cute.make_layout(4), problem_shape_mnkl.element_type, loc=loc, ip=ip ) if found: final_group_idx = self.search_state.start_group_idx - # let's revisit if it's better to broadcast problem_shape_mnk in group_search problem_mnkl = self._get_problem_for_group( problem_shape_mnkl, final_group_idx, loc=loc, ip=ip ) @@ -610,9 +692,9 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): self, cta_tile_coord: tuple, *, - loc=None, - ip=None, - ) -> GroupSearchResult: + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "GroupedWorkTileInfo": """ Delinearize the linear z index and return GroupSearchResult. @@ -680,14 +762,22 @@ class StaticPersistentGroupTileScheduler(StaticPersistentTileScheduler): return GroupedWorkTileInfo(cta_tile_coord, is_valid, group_search_result) @dsl_user_op - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def get_current_work( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> WorkTileInfo: work_tile = self._get_current_work_for_linear_idx( self._current_work_linear_idx, loc=loc, ip=ip ) grouped_work_tile = self.delinearize_z(work_tile.tile_idx, loc=loc, ip=ip) return grouped_work_tile -@deprecated("API is deprecated, use cutlass.utils.StaticPersistentGroupTileScheduler instead") + +@deprecated( + "API is deprecated, use cutlass.utils.StaticPersistentGroupTileScheduler instead" +) class GroupedGemmTileSchedulerHelper: """ A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm. @@ -790,9 +880,9 @@ class GroupedGemmTileSchedulerHelper: group_idx, cta_tile_idx_m, cta_tile_idx_n, - problem_mnkl[0], - problem_mnkl[1], - problem_mnkl[2], + problem_mnkl[0], # type: ignore[arg-type] + problem_mnkl[1], # type: ignore[arg-type] + problem_mnkl[2], # type: ignore[arg-type] cluster_count_k, ) @@ -820,9 +910,9 @@ class GroupedGemmTileSchedulerHelper: self.search_state.tile_count_prev_group, ) cluster_count_k = ( - problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1 + problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1 # type: ignore[operator] ) // self.cluster_tile_shape_mnk[2] - return cluster_count_k, group_idx + return cluster_count_k, group_idx # type: ignore[return-value] @cute.jit def _prefix_sum(self, value_per_thread: Int32) -> Int32: @@ -875,13 +965,13 @@ class GroupedGemmTileSchedulerHelper: :rtype: Int32 """ cur_ntile_m = ( - problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1 + problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1 # type: ignore[operator] ) // self.cluster_tile_shape_mnk[0] cur_ntile_n = ( - problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1 + problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1 # type: ignore[operator] ) // self.cluster_tile_shape_mnk[1] cur_ntile_mn = cur_ntile_m * cur_ntile_n - return cur_ntile_mn + return cur_ntile_mn # type: ignore[return-value] def _compute_cta_tile_coord( self, @@ -997,7 +1087,7 @@ class GroupedGemmTileSchedulerHelper: start_group_idx, tile_count_prev_group, tile_count_searched, - 1 # found will always be 1 for old api + Boolean(1), # found will always be 1 for old api ) def _group_search_and_load_problem_shape( @@ -1032,4 +1122,3 @@ class GroupedGemmTileSchedulerHelper: # let's revisit if it's better to broadcast problem_shape_mnk in group_search problem_mnkl = self._get_problem_for_group(problem_shape_mnkl, final_group_idx) return final_group_idx, problem_mnkl - diff --git a/python/CuTeDSL/cutlass/utils/hardware_info.py b/python/CuTeDSL/cutlass/utils/hardware_info.py index 5549d0e3e..d95c32979 100644 --- a/python/CuTeDSL/cutlass/utils/hardware_info.py +++ b/python/CuTeDSL/cutlass/utils/hardware_info.py @@ -8,8 +8,9 @@ # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from cuda.bindings import driver, runtime -from cutlass.base_dsl.common import DSLRuntimeError +from typing import Any + +from cuda.bindings import driver from cutlass import cute import tempfile @@ -154,7 +155,7 @@ class HardwareInfo: ) ) - def _checkCudaErrors(self, result) -> None: + def _checkCudaErrors(self, result: Any) -> Any: if result[0].value: raise RuntimeError( "CUDA error code={}({})".format( @@ -169,7 +170,7 @@ class HardwareInfo: else: return result[1:] - def _cudaGetErrorEnum(self, error) -> str: + def _cudaGetErrorEnum(self, error: Any) -> str: if isinstance(error, driver.CUresult): err, name = driver.cuGetErrorName(error) return name if err == driver.CUresult.CUDA_SUCCESS else "" @@ -183,11 +184,11 @@ class HardwareInfo: return not self._cuda_driver_version_ge(major, minor) @cute.kernel - def _empty_kernel(self): + def _empty_kernel(self) -> None: return @cute.jit - def _host_function(self): + def _host_function(self) -> None: self._empty_kernel().launch( grid=[1, 1, 1], block=[1, 1, 1], diff --git a/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/python/CuTeDSL/cutlass/utils/hopper_helpers.py index 5268de7db..d6b19f40c 100644 --- a/python/CuTeDSL/cutlass/utils/hopper_helpers.py +++ b/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -11,6 +11,7 @@ from typing import Type, Union, Tuple, Optional +from cutlass._mlir import ir from cutlass.utils.layout import LayoutEnum from cutlass.cutlass_dsl import ( Float16, @@ -25,13 +26,12 @@ from cutlass.cutlass_dsl import ( ) import cutlass.cute as cute -from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.nvgpu.common import CopyUniversalOp, OperandMajorMode from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp from cutlass.cute.nvgpu.warpgroup import ( MmaF16BF16Op, MmaF8Op, MmaI8Op, - OperandMajorMode, OperandSource as WarpgroupOperandSource, make_smem_layout_atom, ) @@ -46,8 +46,8 @@ def get_smem_store_op( elem_ty_d: Type[Numeric], elem_ty_acc: Type[Numeric], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.CopyAtom: """ Selects the largest vectorized smem store atom available subject to constraint of gmem layout. @@ -68,7 +68,7 @@ def get_smem_store_op( Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. """ - def validate_type(ty, ty_name): + def validate_type(ty: Type[Numeric], ty_name: str) -> None: if not isinstance(ty, NumericMeta): raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") @@ -99,8 +99,8 @@ def make_trivial_tiled_mma( tiler_mn: Tuple[int, int], a_source: OperandSource = OperandSource.SMEM, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.TiledMma: """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. By default, the MMA atom is created with SMEM operand source for A. @@ -110,9 +110,9 @@ def make_trivial_tiled_mma( :param b_dtype: Data type of operand B. :type b_dtype: type[Numeric] :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). - :type a_leading_mode: warpgroup.OperandMajorMode + :type a_leading_mode: cutlass.cute.nvgpu.OperandMajorMode :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). - :type b_leading_mode: warpgroup.OperandMajorMode + :type b_leading_mode: cutlass.cute.nvgpu.OperandMajorMode :param acc_dtype: Data type of the accumulator. :type acc_dtype: type[Numeric] :param atom_layout_mnk: A integer tuple describing the tiling of Atom across threads. @@ -144,7 +144,7 @@ def make_trivial_tiled_mma( Float8E4M3FN, Float8E5M2, }: - mma_op = MmaF8Op( + mma_op = MmaF8Op( # type: ignore[assignment] a_dtype, b_dtype, acc_dtype, @@ -154,7 +154,7 @@ def make_trivial_tiled_mma( b_leading_mode, ) elif a_dtype in {Int8, Uint8} and b_dtype in {Int8, Uint8}: - mma_op = MmaI8Op( + mma_op = MmaI8Op( # type: ignore[assignment] a_dtype, b_dtype, acc_dtype, @@ -175,9 +175,9 @@ def get_smem_layout_atom( element_type: Type[Numeric], major_mode_size: int, *, - loc=None, - ip=None, -): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> "cute.nvgpu.warpgroup.SmemLayoutAtomKind": """Select the optimal shared memory layout atom based on parameters. :param layout: Layout enum of the tensor @@ -188,7 +188,7 @@ def get_smem_layout_atom( :type major_mode_size: int :return: Selected shared memory layout atom kind - :rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind + :rtype: cutlass.cute.nvgpu.warpgroup.SmemLayoutAtomKind """ assert major_mode_size % 8 == 0 sw128_num_contiguous_bits = 1024 @@ -219,8 +219,8 @@ def make_smem_layout_a( a_dtype: Type[Numeric], num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps with: @@ -247,7 +247,7 @@ def make_smem_layout_a( # Determine if K is the major mode and get the major mode size is_k_major = a_layout.is_k_major_a() - a_major_mode_size = a_tile_shape_mnk[2] if is_k_major else a_tile_shape_mnk[0] + a_major_mode_size = a_tile_shape_mnk[2] if is_k_major else a_tile_shape_mnk[0] # type: ignore[index] # Create SMEM layout atom for A tensor based on major mode and data type a_smem_layout_atom = make_smem_layout_atom( @@ -276,8 +276,8 @@ def make_smem_layout_b( b_dtype: Type[Numeric], num_stages: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps with: @@ -303,7 +303,7 @@ def make_smem_layout_b( # Determine if K is the major mode and get the major mode size is_k_major = b_layout.is_k_major_b() - b_major_mode_size = mma_tiler_mnk[2] if is_k_major else mma_tiler_mnk[1] + b_major_mode_size = mma_tiler_mnk[2] if is_k_major else mma_tiler_mnk[1] # type: ignore[index] # Create SMEM layout atom for B tensor based on major mode and data type b_smem_layout_atom = make_smem_layout_atom( @@ -334,8 +334,8 @@ def make_smem_layout_epi( smem_trg_shape: Optional[cute.Layout] = None, smem_order: Optional[tuple] = None, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Union[cute.Layout, cute.ComposedLayout]: """This function helps: @@ -364,7 +364,7 @@ def make_smem_layout_epi( o_smem_shape = epi_tile # Determine major mode size based on layout (M or N major) - o_major_mode_size = epi_tile[1] if epi_layout.is_n_major_c() else epi_tile[0] + o_major_mode_size = epi_tile[1] if epi_layout.is_n_major_c() else epi_tile[0] # type: ignore[index] # Create SMEM layout atom for output tensor based on layout and data type o_smem_layout_atom = make_smem_layout_atom( diff --git a/python/CuTeDSL/cutlass/utils/layout.py b/python/CuTeDSL/cutlass/utils/layout.py index 53985b7be..984f9a2a6 100644 --- a/python/CuTeDSL/cutlass/utils/layout.py +++ b/python/CuTeDSL/cutlass/utils/layout.py @@ -12,44 +12,39 @@ from enum import Enum import cutlass.cute as cute -from cutlass.cute.nvgpu import warpgroup -from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu import OperandMajorMode class LayoutEnum(Enum): ROW_MAJOR = "row_major" COL_MAJOR = "col_major" - def mma_major_mode(self): + def mma_major_mode(self) -> OperandMajorMode: return ( - tcgen05.OperandMajorMode.K - if self == LayoutEnum.ROW_MAJOR - else tcgen05.OperandMajorMode.MN + OperandMajorMode.K if self == LayoutEnum.ROW_MAJOR else OperandMajorMode.MN ) - def sm90_mma_major_mode(self): + def sm90_mma_major_mode(self) -> OperandMajorMode: return ( - warpgroup.OperandMajorMode.K - if self == LayoutEnum.ROW_MAJOR - else warpgroup.OperandMajorMode.MN + OperandMajorMode.K if self == LayoutEnum.ROW_MAJOR else OperandMajorMode.MN ) - def is_k_major_a(self): + def is_k_major_a(self) -> bool: return self == LayoutEnum.ROW_MAJOR - def is_m_major_a(self): + def is_m_major_a(self) -> bool: return self == LayoutEnum.COL_MAJOR - def is_n_major_b(self): + def is_n_major_b(self) -> bool: return self == LayoutEnum.COL_MAJOR - def is_k_major_b(self): + def is_k_major_b(self) -> bool: return self == LayoutEnum.ROW_MAJOR - def is_n_major_c(self): + def is_n_major_c(self) -> bool: return self == LayoutEnum.ROW_MAJOR - def is_m_major_c(self): + def is_m_major_c(self) -> bool: return self == LayoutEnum.COL_MAJOR @staticmethod diff --git a/python/CuTeDSL/cutlass/utils/mixed_input_helpers.py b/python/CuTeDSL/cutlass/utils/mixed_input_helpers.py index 477fcb945..869995a60 100644 --- a/python/CuTeDSL/cutlass/utils/mixed_input_helpers.py +++ b/python/CuTeDSL/cutlass/utils/mixed_input_helpers.py @@ -3,7 +3,7 @@ # # Use of this software is governed by the terms and conditions of the # NVIDIA End User License Agreement (EULA), available at: -# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html # # Any use, reproduction, disclosure, or distribution of this software # and related documentation outside the scope permitted by the EULA @@ -17,7 +17,9 @@ from typing import Optional, Union import cutlass import cutlass.cute as cute +from cutlass._mlir import ir from cutlass.cutlass_dsl import ( + Boolean, extract_mlir_values, new_from_mlir_values, ) @@ -68,7 +70,7 @@ def scale_tma_partition( """ tSsS, tSgS = cpasync.tma_partition( tma_atom_s, - block_in_cluster_coord_vmnk[2], + block_in_cluster_coord_vmnk[2], # type: ignore[index] scale_cta_layout, cute.group_modes(tCsS, 0, 3), cute.group_modes(tCgS, 0, 3), @@ -170,7 +172,7 @@ def transform_partition( reg2smem_tiled_copy = cute.make_cotiled_copy( copy_atom_a_transform, cute.make_layout((128, 8), stride=(8, 1)), - A_transform[(None, None, None, 0)].layout, + A_transform[(None, None, None, 0)].layout, # type: ignore[union-attr] ) thr_reg2smem_tiled_copy = reg2smem_tiled_copy.get_slice(transform_local_tidx) partitioned_tensor_input = thr_reg2smem_tiled_copy.partition_S(sA_input) @@ -281,7 +283,8 @@ def epilog_gmem_copy_and_partition( tTR_gC = None if tma_atom_c is not None: gC_epi_tma = cute.flat_divide( - gC_mnl_tma[((None, None), 0, 0, None, None, None)], epi_tile + gC_mnl_tma[((None, None), 0, 0, None, None, None)], # type: ignore[index, arg-type] + epi_tile, ) # TMA store sC_for_tma_partition = cute.group_modes(sC, 0, 2) @@ -298,7 +301,8 @@ def epilog_gmem_copy_and_partition( if tiled_copy_t2r is not None: # SIMT Store gC_epi_simt = cute.flat_divide( - gC_mnl_simt[((None, None), 0, 0, None, None, None)], epi_tile + gC_mnl_simt[((None, None), 0, 0, None, None, None)], # type: ignore[index, arg-type] + epi_tile, ) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) @@ -389,7 +393,7 @@ def epilog_tmem_copy_and_partition( ) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) tAcc_epi = cute.flat_divide( - tAcc[((None, None), 0, 0, None)], + tAcc[((None, None), 0, 0, None)], # type: ignore[arg-type] epi_tile, ) # (EPI_TILE_M, EPI_TILE_N) @@ -401,7 +405,8 @@ def epilog_tmem_copy_and_partition( tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) gC_mnl_epi = cute.flat_divide( - gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + gC_mnl[((None, None), 0, 0, None, None, None)], # type: ignore[arg-type] + epi_tile, ) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) @@ -416,7 +421,7 @@ def get_gmem_layout_scale( scale_shape_mkl: tuple[int, int, int], scale_granularity_m: int, scale_granularity_k: int, - scale_major_mode: tcgen05.OperandMajorMode, + scale_major_mode: cutlass.cute.nvgpu.OperandMajorMode, ) -> cute.Layout: """ Get the layout of the scale tensor in global memory. @@ -430,7 +435,7 @@ def get_gmem_layout_scale( (scale_granularity_m, cute.ceil_div(m, scale_granularity_m)), (scale_granularity_k, cute.ceil_div(k, scale_granularity_k)), ) - if cutlass.const_expr(scale_major_mode == tcgen05.OperandMajorMode.MN): + if cutlass.const_expr(scale_major_mode == cutlass.cute.nvgpu.OperandMajorMode.MN): layout_mk = cute.make_layout( shape_scale, stride=( @@ -457,7 +462,7 @@ def get_smem_layout_scale( use_2cta_instrs: bool, scale_granularity_m: int, scale_granularity_k: int, - scale_major_mode: tcgen05.OperandMajorMode, + scale_major_mode: cutlass.cute.nvgpu.OperandMajorMode, a_scale_dtype: type[cutlass.Numeric], num_scale_load2trans_stage: int, ) -> tuple[tuple[int, int], cute.ComposedLayout, cute.ComposedLayout]: @@ -485,7 +490,7 @@ def get_smem_layout_scale( (smem_size_mn, div_mn), (smem_size_k, div_k), ) - if cutlass.const_expr(scale_major_mode == tcgen05.OperandMajorMode.MN): + if cutlass.const_expr(scale_major_mode == cutlass.cute.nvgpu.OperandMajorMode.MN): outer_layout = cute.make_layout( smem_atom_shape, stride=( @@ -570,7 +575,7 @@ def compute_smem_layout( smem_layout_a_transform = sm100_utils.make_smem_layout_a( tiled_mma, mma_tiler_mnk, - tiled_mma.op.a_dtype, + tiled_mma.op.a_dtype, # type: ignore[attr-defined] trans2mma_stage_count, ) smem_layout_b = sm100_utils.make_smem_layout_b( @@ -583,12 +588,12 @@ def compute_smem_layout( def get_transform_a_source( - a_major_mode: tcgen05.OperandMajorMode, + a_major_mode: cutlass.cute.nvgpu.OperandMajorMode, ) -> tcgen05.OperandSource: """ Determine the operand source for transformed A tensor based on the operand major mode. """ - if cutlass.const_expr(a_major_mode == tcgen05.OperandMajorMode.K): + if cutlass.const_expr(a_major_mode == cutlass.cute.nvgpu.OperandMajorMode.K): return tcgen05.OperandSource.TMEM else: return tcgen05.OperandSource.SMEM @@ -625,13 +630,13 @@ def get_copy_atom_a_transform( """ if cutlass.const_expr(transform_a_source == tcgen05.OperandSource.TMEM): if cutlass.const_expr( - cute.size(a_smem_shape[0][0]) == 64 and (not use_2cta_instrs) + cute.size(a_smem_shape[0][0]) == 64 and (not use_2cta_instrs) # type: ignore[index] ): copy_op_r2t = tcgen05.St16x256bOp( tcgen05.Repetition(1), tcgen05.Unpack.NONE ) else: - copy_op_r2t = tcgen05.St32x32bOp(tcgen05.Repetition(8), tcgen05.Unpack.NONE) + copy_op_r2t = tcgen05.St32x32bOp(tcgen05.Repetition(8), tcgen05.Unpack.NONE) # type: ignore[assignment] return cute.make_copy_atom(copy_op_r2t, mma_dtype) else: return cute.make_copy_atom( @@ -706,7 +711,11 @@ def is_valid_tensor_alignment( Check if the tensor alignments are valid for the given problem size and data types. """ - def check_contiguous_16B_alignment(dtype, is_mode0_major, tensor_shape): + def check_contiguous_16B_alignment( + dtype: type[cutlass.Numeric], + is_mode0_major: bool, + tensor_shape: tuple[int, int], + ) -> bool: major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] num_contiguous_elements = 16 * 8 // dtype.width @@ -797,7 +806,7 @@ class ContiguousGGSearchState: cur_group_idx: cutlass.Int32, cur_offset: cutlass.Int32, cur_start: cutlass.Int32, - ): + ) -> None: self.last_tile_count = last_tile_count self.cur_boundary = cur_boundary self.cur_tile_count = cur_tile_count @@ -805,7 +814,7 @@ class ContiguousGGSearchState: self.cur_offset = cur_offset self.cur_start = cur_start - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: values = extract_mlir_values(self.last_tile_count) values.extend(extract_mlir_values(self.cur_boundary)) values.extend(extract_mlir_values(self.cur_tile_count)) @@ -814,7 +823,9 @@ class ContiguousGGSearchState: values.extend(extract_mlir_values(self.cur_start)) return values - def __new_from_mlir_values__(self, values) -> "ContiguousGGSearchState": + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "ContiguousGGSearchState": last_tile_count = new_from_mlir_values(self.last_tile_count, [values[0]]) cur_boundary = new_from_mlir_values(self.cur_boundary, [values[1]]) cur_tile_count = new_from_mlir_values(self.cur_tile_count, [values[2]]) @@ -869,21 +880,23 @@ class ContiguousGroupWorkTileInfo: coord_n: cutlass.Int32, group_idx: cutlass.Int32, distance_to_boundary: cutlass.Int32, - ): + ) -> None: self.cta_coord_m = cta_coord_m self.coord_n = coord_n self.group_idx = group_idx self.distance_to_boundary = distance_to_boundary self.group_count = group_count - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: values = extract_mlir_values(self.cta_coord_m) values.extend(extract_mlir_values(self.coord_n)) values.extend(extract_mlir_values(self.group_idx)) values.extend(extract_mlir_values(self.distance_to_boundary)) return values - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "ContiguousGroupWorkTileInfo": assert len(values) == 4 new_cta_coord_m = new_from_mlir_values(self.cta_coord_m, [values[0]]) new_coord_n = new_from_mlir_values(self.coord_n, [values[1]]) @@ -900,7 +913,7 @@ class ContiguousGroupWorkTileInfo: ) @property - def is_valid_tile(self): + def is_valid_tile(self) -> Boolean: return self.group_idx < self.group_count @@ -926,7 +939,7 @@ def contiguous_group_search( if not_found: cur_group_idx = cur_group_idx + 1 while not_found and cur_group_idx <= group_count: - next_boundary = cumsum[cur_group_idx] + next_boundary = cumsum[cur_group_idx] # type: ignore[assignment] num_m_blocks = cute.ceil_div( (next_boundary - cur_boundary), cluster_tile_shape_mnk[search_mode], @@ -953,7 +966,9 @@ def contiguous_group_search( ) -def make_contiguous_group_work_tile_info(group_count: int, sTile_info: cute.Tensor): +def make_contiguous_group_work_tile_info( + group_count: int, sTile_info: cute.Tensor +) -> ContiguousGroupWorkTileInfo: """ Generate ContiguousGroupWorkTileInfo from tile_info tensor generated by contiguous_group_search """ @@ -971,11 +986,9 @@ def cvt_tensor_a( Convert tensor src to the given data type. If shuffle is True, use shuffle intrinsic for int4-to-bf16 conversion. """ - from cutlass import CUDA_VERSION + # shuffle is supported since CUDA 13.1 - shuffle_supported = True - if CUDA_VERSION.major < 13 or (CUDA_VERSION.major == 13 and CUDA_VERSION.minor < 1): - shuffle_supported = False + shuffle_supported = cutlass.target_version(min_version="13.1") shuffle = shuffle and shuffle_supported rst = src.load() if cutlass.const_expr(shuffle): diff --git a/python/CuTeDSL/cutlass/utils/print_latex.py b/python/CuTeDSL/cutlass/utils/print_latex.py index 6f6b320b7..c32fd8caf 100644 --- a/python/CuTeDSL/cutlass/utils/print_latex.py +++ b/python/CuTeDSL/cutlass/utils/print_latex.py @@ -26,7 +26,7 @@ from ..cute.typing import IntTuple __all__ = ["print_latex", "print_latex_tv"] -def tikz_color_bwx8(idx: int): +def tikz_color_bwx8(idx: int) -> str: color_map = [ "black!00", "black!40", @@ -40,11 +40,11 @@ def tikz_color_bwx8(idx: int): return color_map[idx % 8] -def tikz_color_white(idx: int): +def tikz_color_white(idx: int) -> str: return "white" -def tikz_color_tv(tid: int, vid: int): +def tikz_color_tv(tid: int, vid: int) -> str: color_map = [ "{rgb,255:red,175;green,175;blue,255}", "{rgb,255:red,175;green,255;blue,175}", @@ -58,7 +58,9 @@ def tikz_color_tv(tid: int, vid: int): return color_map[tid % 8] -def print_latex(x: Union[Layout, ComposedLayout], *, color: Callable = tikz_color_bwx8): +def print_latex( + x: Union[Layout, ComposedLayout], *, color: Callable = tikz_color_bwx8 +) -> None: """ Prints a layout. :param x: A layout @@ -111,7 +113,7 @@ def print_latex_tv( tile_mn: Union[IntTuple, Layout], *, color: Callable = tikz_color_tv, -): +) -> None: """ Prints a tv layout for a tile M N. Everything must be static. :param layout_tv: A static thread value layout @@ -137,14 +139,14 @@ def print_latex_tv( if not isinstance(tile_mn, Layout): tile_mn = make_layout(tile_mn) - M, N = product_each(tile_mn.shape) + M, N = product_each(tile_mn.shape) # type: ignore[union-attr] filled = [[False for n in range(N)] for m in range(M)] for tid in range(size(layout_tv, mode=[0])): for vid in range(size(layout_tv, mode=[1])): idx = layout_tv((tid, vid)) - m = (idx // tile_mn.stride[0]) % tile_mn.shape[0] - n = (idx // tile_mn.stride[1]) % tile_mn.shape[1] + m = (idx // tile_mn.stride[0]) % tile_mn.shape[0] # type: ignore[operator, union-attr, index] + n = (idx // tile_mn.stride[1]) % tile_mn.shape[1] # type: ignore[operator, union-attr, index] if not filled[m][n]: filled[m][n] = True print( diff --git a/python/CuTeDSL/cutlass/utils/smem_allocator.py b/python/CuTeDSL/cutlass/utils/smem_allocator.py index be86709eb..ce9b4c67c 100644 --- a/python/CuTeDSL/cutlass/utils/smem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -9,8 +9,7 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from typing import Optional, Type, Union, overload -from typing_extensions import deprecated +from typing import Any, Optional, Type, Union, overload import inspect import cutlass.cute as cute @@ -25,6 +24,7 @@ from cutlass.cutlass_dsl import ( NumericMeta, dsl_user_op, ) +from cutlass._mlir import ir from cutlass._mlir.dialects import cute as _cute_ir @@ -72,6 +72,7 @@ class SmemAllocator: # Allocate tensor layout = cute.make_layout((16, 16)) tensor = smem.allocate_tensor(Int8, layout) # 256 bytes + """ @staticmethod @@ -95,7 +96,12 @@ class SmemAllocator: return SMEM_CAPACITY_MAP[compute_capability] @dsl_user_op - def __init__(self, *, loc=None, ip=None): + def __init__( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Initialize a new SmemAllocator instance. Creates a new shared memory allocator with a base pointer aligned to 1024 bytes. @@ -108,26 +114,46 @@ class SmemAllocator: """ self._base = get_dyn_smem(Int8, alignment=1024, loc=loc, ip=ip) self._allocated_bytes = 0 - CuTeDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) + CuTeDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) # type: ignore[attr-defined] @overload def allocate( - self, size_or_type: int, byte_alignment: int, *, loc=None, ip=None + self, + size_or_type: int, + byte_alignment: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: ... @overload def allocate( - self, size_or_type: Type[Numeric], byte_alignment: int, *, loc=None, ip=None + self, + size_or_type: Type[Numeric], + byte_alignment: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: ... @overload def allocate( - self, size_or_type: cute.struct, byte_alignment: int, *, loc=None, ip=None + self, + size_or_type: cute.struct, + byte_alignment: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: ... @dsl_user_op def allocate( - self, size_or_type, byte_alignment: int = 1, *, loc=None, ip=None + self, + size_or_type: Any, + byte_alignment: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: """Allocate a block of memory with specified size and alignment. @@ -210,13 +236,13 @@ class SmemAllocator: num_elems: int = 1, *, byte_alignment: int = 1, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Pointer: """Allocate an array of elements in shared memory. :param element_type: The type of elements to allocate - :type element_type: Type[Numeric] + :type element_type: Union[Type[Numeric]] :param num_elems: Number of elements to allocate, defaults to 1 :type num_elems: int, optional :return: Pointer to the start of the allocated array @@ -226,7 +252,12 @@ class SmemAllocator: """ if cute.is_static(num_elems) and num_elems < 1: raise ValueError("num_elems must be at least 1") - if not isinstance(element_type, NumericMeta): + if not isinstance( + element_type, + ( + NumericMeta, + ), + ): raise TypeError( f"value_ty must be a type of Numeric, but got {element_type}" ) @@ -242,20 +273,20 @@ class SmemAllocator: @dsl_user_op def allocate_tensor( self, - element_type: Type[Numeric], + element_type: Union[Type[Numeric],], layout: Union[int, cute.Layout, cute.ComposedLayout], byte_alignment: int = 1, swizzle: Optional[cute.Swizzle] = None, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Tensor: """Allocate a tensor in shared memory. Note: Currently only supports static layouts. Dynamic layouts are not supported. :param element_type: The type of elements in the tensor - :type element_type: Type[Numeric] + :type element_type: Union[Type[Numeric]] :param layout: The layout specification for the tensor. Must be a static layout. :type layout: Union[int, cute.Layout, cute.ComposedLayout] :param byte_alignment: The byte alignment requirement, defaults to 1 @@ -264,11 +295,16 @@ class SmemAllocator: :type swizzle: cute.Swizzle, optional :return: The allocated tensor with specified properties :rtype: cute.Tensor - :raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout + :raises TypeError: If element_type is not a Numeric type, or if swizzle conflicts with layout :raises ValueError: If allocation is not byte-aligned :raises NotImplementedError: If dynamic layout is specified """ - if not isinstance(element_type, NumericMeta): + if not isinstance( + element_type, + ( + NumericMeta, + ), + ): raise TypeError( f"value_ty must be a type of Numeric, but got {element_type}" ) @@ -284,7 +320,7 @@ class SmemAllocator: if isinstance(layout, int): layout = cute.make_layout(layout) - profile = layout(0, loc=loc, ip=ip) + profile = layout(0, loc=loc, ip=ip) # type: ignore[operator] if isinstance(profile, tuple): raise TypeError( "cannot allocate a shared memory tensor with a non-integer iterator" @@ -312,7 +348,7 @@ class SmemAllocator: # Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator -SmemAllocator.__init__.__signature__ = inspect.Signature( +SmemAllocator.__init__.__signature__ = inspect.Signature( # type: ignore[attr-defined] [ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), ] diff --git a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py index 73aa23dfb..b9762a18d 100644 --- a/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py +++ b/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py @@ -10,7 +10,7 @@ # is strictly prohibited. import inspect -from typing import Tuple +from typing import Optional, Tuple from cutlass.cutlass_dsl import ( Boolean, @@ -96,9 +96,9 @@ class PersistentTileSchedulerParams: swizzle_size: int = 1, raster_along_m: bool = True, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Initializes the PersistentTileSchedulerParams with the given parameters. @@ -116,15 +116,15 @@ class PersistentTileSchedulerParams: :raises ValueError: If cluster_shape_k is not 1. """ - if cluster_shape_mnk[2] != 1: - raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") + if cluster_shape_mnk[2] != 1: # type: ignore[index] + raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") # type: ignore[index] if swizzle_size < 1: raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}") self.problem_shape_ntile_mnl = problem_shape_ntile_mnl # cluster_shape_mnk is kept for reconstruction self._cluster_shape_mnk = cluster_shape_mnk - self.cluster_shape_mn = cluster_shape_mnk[:2] + self.cluster_shape_mn = cluster_shape_mnk[:2] # type: ignore[index] self.swizzle_size = swizzle_size self.raster_along_m = raster_along_m self._loc = loc @@ -132,7 +132,10 @@ class PersistentTileSchedulerParams: # By default, we follow m major (col-major) raster order, so make a col-major layout self.problem_layout_ncluster_mnl = cute.make_layout( cute.ceil_div( - self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip + self.problem_shape_ntile_mnl, + cluster_shape_mnk[:2], # type: ignore[index] + loc=loc, + ip=ip, ), loc=loc, ip=ip, @@ -148,14 +151,14 @@ class PersistentTileSchedulerParams: if raster_along_m: self.problem_layout_ncluster_mnl = cute.make_layout( ( - problem_shape_ncluster_mnl[0], - (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), - problem_shape_ncluster_mnl[2], + problem_shape_ncluster_mnl[0], # type: ignore[index] + (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size), # type: ignore[index, operator] + problem_shape_ncluster_mnl[2], # type: ignore[index] ), stride=( swizzle_size, - (1, swizzle_size * problem_shape_ncluster_mnl[0]), - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], + (1, swizzle_size * problem_shape_ncluster_mnl[0]), # type: ignore[index] + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], # type: ignore[index, operator] ), loc=loc, ip=ip, @@ -163,14 +166,14 @@ class PersistentTileSchedulerParams: else: self.problem_layout_ncluster_mnl = cute.make_layout( ( - (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), - problem_shape_ncluster_mnl[1], - problem_shape_ncluster_mnl[2], + (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size), # type: ignore[index, operator] + problem_shape_ncluster_mnl[1], # type: ignore[index] + problem_shape_ncluster_mnl[2], # type: ignore[index] ), stride=( - (1, swizzle_size * problem_shape_ncluster_mnl[1]), + (1, swizzle_size * problem_shape_ncluster_mnl[1]), # type: ignore[index] swizzle_size, - problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], + problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1], # type: ignore[index, operator] ), loc=loc, ip=ip, @@ -179,17 +182,12 @@ class PersistentTileSchedulerParams: # Create FastDivmod divisors (only when swizzle_size == 1 for correctness) # FastDivmod assumes simple col-major layout, incompatible with swizzled layouts if swizzle_size == 1: - problem_layout_size = cute.size( + _problem_layout_size = cute.size( self.problem_layout_ncluster_mnl, loc=loc, ip=ip ) cluster_count_m = self.problem_layout_ncluster_mnl.shape[0] cluster_count_n = self.problem_layout_ncluster_mnl.shape[1] - # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling) - self.batch_fdd = cute.fast_divmod_create_divisor( - problem_layout_size, loc=loc, ip=ip - ) - if raster_along_m: cluster_count_major = cluster_count_m cluster_count_minor = cluster_count_n @@ -208,11 +206,10 @@ class PersistentTileSchedulerParams: ) else: # FastDivmod not applicable with swizzling, set to None - self.batch_fdd = None self.cluster_shape_major_fdd = None self.cluster_shape_minor_fdd = None - def __extract_mlir_values__(self): + def __extract_mlir_values__(self) -> list[ir.Value]: values, self._values_pos = [], [] for obj in [ self.problem_shape_ntile_mnl, @@ -231,7 +228,6 @@ class PersistentTileSchedulerParams: for i, (fdd_name, fdd_obj) in enumerate( [ - ("batch_fdd", self.batch_fdd), ("cluster_shape_major_fdd", self.cluster_shape_major_fdd), ("cluster_shape_minor_fdd", self.cluster_shape_minor_fdd), ] @@ -250,7 +246,9 @@ class PersistentTileSchedulerParams: return values - def __new_from_mlir_values__(self, values): + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "PersistentTileSchedulerParams": obj_list = [] values_copy = list(values) # Make a copy to avoid modifying original @@ -272,7 +270,7 @@ class PersistentTileSchedulerParams: new_params = PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) # Restore FastDivmod divisors from remaining values - fdd_names = ["batch_fdd", "cluster_shape_major_fdd", "cluster_shape_minor_fdd"] + fdd_names = ["cluster_shape_major_fdd", "cluster_shape_minor_fdd"] if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0: # Override the FastDivmod divisors created by __init__ with reconstructed ones @@ -291,7 +289,11 @@ class PersistentTileSchedulerParams: @dsl_user_op def get_grid_shape( - self, max_active_clusters: Int32, *, loc=None, ip=None + self, + max_active_clusters: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Integer, Integer, Integer]: """ Computes the grid shape based on the maximum active clusters allowed. @@ -327,7 +329,7 @@ class PersistentTileSchedulerParams: # Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator -PersistentTileSchedulerParams.__init__.__signature__ = inspect.Signature( +PersistentTileSchedulerParams.__init__.__signature__ = inspect.Signature( # type: ignore[attr-defined] [ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), ] @@ -424,9 +426,9 @@ class StaticPersistentTileScheduler: block_idx: Tuple[Integer, Integer, Integer], grid_dim: Tuple[Integer, Integer, Integer], *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "StaticPersistentTileScheduler": """Initialize the static persistent tile scheduler. :param params: Parameters for the persistent @@ -474,8 +476,8 @@ class StaticPersistentTileScheduler: params: PersistentTileSchedulerParams, max_active_clusters: Int32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Integer, Integer, Integer]: """Calculates the grid shape to be launched on GPU using problem shape, threadblock shape, and active cluster size. @@ -493,7 +495,11 @@ class StaticPersistentTileScheduler: # private method def _get_current_work_for_linear_idx( - self, current_work_linear_idx: Int32, *, loc=None, ip=None + self, + current_work_linear_idx: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> WorkTileInfo: """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. @@ -521,12 +527,11 @@ class StaticPersistentTileScheduler: current_work_linear_idx, loc=loc, ip=ip ) - # cur_tile_coord is a tuple of i32 values cur_tile_coord = tuple( Int32(x) * Int32(z) + Int32(y) for x, y, z in zip( cur_cluster_coord, - self.cta_id_in_cluster, + self.cta_id_in_cluster, # type: ignore[arg-type] (*self.params.cluster_shape_mn, Int32(1)), ) ) @@ -534,7 +539,11 @@ class StaticPersistentTileScheduler: return WorkTileInfo(cur_tile_coord, is_valid) def _get_cluster_work_idx_with_fastdivmod( - self, current_work_linear_idx: Int32, *, loc=None, ip=None + self, + current_work_linear_idx: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> Tuple[Int32, Int32, Int32]: """ FastDivmod optimized CLUSTER coordinate calculation. @@ -548,18 +557,13 @@ class StaticPersistentTileScheduler: :rtype: Tuple[Int32, Int32, Int32] or None """ - # Step 1: Handle persistent scheduling - map linear_idx to work_unit_id - work_iteration, work_unit_id = divmod( - current_work_linear_idx, self.params.batch_fdd - ) - - # Step 2: Decode work_unit_id using FastDivmod objects + # Step 1: Decode current_work_linear_idx using FastDivmod objects # The layout structure is: problem_layout_ncluster_mnl has shape (cluster_count_m, cluster_count_n, batch_count) - # work_unit_id needs to be decomposed into (batch_l, cluster_minor, cluster_major) in little-endian order + # current_work_linear_idx needs to be decomposed into (batch_l, cluster_minor, cluster_major) in little-endian order # First, get cluster_major using cluster_shape_major_fdd cluster_minor_batch, cluster_major = divmod( - work_unit_id, self.params.cluster_shape_major_fdd + current_work_linear_idx, self.params.cluster_shape_major_fdd ) # Then decode cluster_minor_batch to get cluster_minor and batch_l using FastDivmod @@ -577,17 +581,33 @@ class StaticPersistentTileScheduler: return (cluster_m, cluster_n, batch_l) @dsl_user_op - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def get_current_work( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> WorkTileInfo: return self._get_current_work_for_linear_idx( self._current_work_linear_idx, loc=loc, ip=ip ) @dsl_user_op - def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + def initial_work_tile_info( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> WorkTileInfo: return self.get_current_work(loc=loc, ip=ip) @dsl_user_op - def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + def advance_to_next_work( + self, + *, + advance_count: int = 1, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: self._current_work_linear_idx += Int32(advance_count) * Int32( self.num_persistent_clusters ) @@ -690,9 +710,9 @@ class StaticPersistentRuntimeTileScheduler(StaticPersistentTileScheduler): grid_dim: Tuple[Integer, Integer, Integer], inner_mode: int = 1, *, - loc=None, - ip=None, - ): + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> "StaticPersistentRuntimeTileScheduler": """Initialize the static persistent tile scheduler. :param params: Parameters for the persistent @@ -739,7 +759,11 @@ class StaticPersistentRuntimeTileScheduler(StaticPersistentTileScheduler): # private method def _get_current_work_for_linear_idx( - self, current_work_linear_idx: Int32, *, loc=None, ip=None + self, + current_work_linear_idx: Int32, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> WorkTileInfo: """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. @@ -768,6 +792,6 @@ class StaticPersistentRuntimeTileScheduler(StaticPersistentTileScheduler): ) # it is determined by kernel implementation - is_valid = True + is_valid = Boolean(True) return WorkTileInfo(cur_tile_coord, is_valid) diff --git a/python/CuTeDSL/cutlass/utils/tensor_helpers.py b/python/CuTeDSL/cutlass/utils/tensor_helpers.py index e28afb658..df5f5710c 100644 --- a/python/CuTeDSL/cutlass/utils/tensor_helpers.py +++ b/python/CuTeDSL/cutlass/utils/tensor_helpers.py @@ -11,7 +11,9 @@ """Utility functions for tensor creation and type handling.""" -from typing import Type, Optional +from typing import Any, Optional, Type + +import cutlass.cute as cute # Import only the specific types needed to avoid circular import with cutlass module from cutlass.cute.typing import Float8E5M2, Float8E4M3FN, TFloat32, Numeric @@ -28,11 +30,11 @@ def is_fp8_dtype(dtype: Type[Numeric]) -> bool: def create_cute_tensor_for_fp8( - storage_tensor, + storage_tensor: Any, dtype: Type[Numeric], leading_dim: int, - source_f32_tensor=None, -): + source_f32_tensor: Optional[Any] = None, +) -> cute.Tensor: """Create cute tensor, handling float8 types that don't support dlpack. For float8 types, the storage_tensor should be uint8 (for DLPack compatibility). diff --git a/python/CuTeDSL/cutlass/utils/tensormap_manager.py b/python/CuTeDSL/cutlass/utils/tensormap_manager.py index 2e3e9f82f..cd19dd3ad 100644 --- a/python/CuTeDSL/cutlass/utils/tensormap_manager.py +++ b/python/CuTeDSL/cutlass/utils/tensormap_manager.py @@ -11,14 +11,17 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import Tuple +from typing import Optional, Tuple +from cutlass._mlir import ir import cutlass._mlir.dialects.cute as _cute_ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir from cutlass.cutlass_dsl import dsl_user_op import cutlass.cute as cute from cutlass import const_expr +from cutlass.cute.core import AddressSpace as _CuteAddressSpace +from cutlass.cute.core import make_ptr as _cute_make_ptr class TensorMapUpdateMode(Enum): @@ -51,10 +54,10 @@ class TensorMapManager: def get_tensormap_ptr( self, ptr: cute.Pointer, - address_space=_cute_ir.AddressSpace.gmem, + address_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.gmem, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: if address_space not in [ _cute_ir.AddressSpace.gmem, @@ -87,8 +90,8 @@ class TensorMapManager: dst_ptr: cute.Pointer, warp_id: int, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: warp_idx = cute.arch.warp_idx(loc=loc, ip=ip) warp_idx = cute.arch.make_warp_uniform(warp_idx, loc=loc, ip=ip) @@ -103,8 +106,8 @@ class TensorMapManager: def fence_tensormap_initialization( self, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: if self.tensormap_update_mode == TensorMapUpdateMode.GMEM: cute.arch.fence_acq_rel_cta(loc=loc, ip=ip) @@ -116,8 +119,8 @@ class TensorMapManager: self, tensormap_ptr: cute.Pointer, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: cute.nvgpu.cpasync.fence_tma_desc_acquire(tensormap_ptr, loc=loc, ip=ip) return @@ -132,17 +135,31 @@ class TensorMapManager: warp_id: int, tensormap_smem_ptr: Tuple[cute.Pointer, ...], *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> None: warp_idx = cute.arch.make_warp_uniform( cute.arch.warp_idx(loc=loc, ip=ip), loc=loc, ip=ip ) + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + # Hoist SMEM pointer integer values into warp-uniform registers before + # entering predicated blocks. This avoids predicated R2UR lowering on sm_90a. + uniform_smem_ptrs = tuple( + _cute_make_ptr( + p.dtype, + cute.arch.make_warp_uniform(p.toint(), loc=loc, ip=ip), + mem_space=_CuteAddressSpace.smem, + assumed_align=p.alignment, # type: ignore[attr-defined] + ) + for p in tensormap_smem_ptr + ) + else: + uniform_smem_ptrs = tensormap_smem_ptr # updates before touching tensormap in global memory if warp_idx == warp_id: if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): for copy_atom, tensor, smem_ptr in zip( - tma_copy_atom, tensor_gmem, tensormap_smem_ptr + tma_copy_atom, tensor_gmem, uniform_smem_ptrs ): cute.nvgpu.cpasync.update_tma_descriptor( copy_atom, tensor, smem_ptr, loc=loc, ip=ip @@ -154,7 +171,7 @@ class TensorMapManager: cute.arch.sync_warp(loc=loc, ip=ip) # updates to tensormap in global memory if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): - for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): + for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, uniform_smem_ptrs): cute.nvgpu.cpasync.cp_fence_tma_desc_release( gmem_ptr, smem_ptr, loc=loc, ip=ip ) diff --git a/python/CuTeDSL/cutlass/utils/tmem_allocator.py b/python/CuTeDSL/cutlass/utils/tmem_allocator.py index f1b9b1654..dd747aa50 100644 --- a/python/CuTeDSL/cutlass/utils/tmem_allocator.py +++ b/python/CuTeDSL/cutlass/utils/tmem_allocator.py @@ -14,9 +14,11 @@ from typing import Optional, Type, Union, List import inspect from cutlass import const_expr +from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import ( Numeric, Float32, + Boolean, extract_mlir_values, new_from_mlir_values, dsl_user_op, @@ -28,6 +30,224 @@ from cutlass.cute.nvgpu.tcgen05 import find_tmem_tensor_col_offset from cutlass.cute.arch import get_max_tmem_alloc_cols, get_min_tmem_alloc_cols +_TMEM_COL_MASK = 0x0000FFFF + + +@dsl_user_op +def compute_tmem_cols_from_layout( + layout: cute.Layout, + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> int: + """Compute the number of TMEM columns required for a layout with a given dtype. + + This function calculates the column offset by recasting the layout to Int32 + and computing its cosize, similar to how find_tmem_tensor_col_offset works + but without requiring a tensor. + + :param layout: The TMEM layout to compute columns for. + :type layout: cute.Layout + :param dtype: The data type of the elements in the layout. + :type dtype: Type[Numeric] + :return: The number of TMEM columns (always a Python int). + :rtype: int + + :raises ValueError: If the layout size cannot be determined at compile time. + """ + # Get source width from dtype + if dtype is Boolean: + src_width = 8 + else: + src_width = dtype.width + + # Recast layout to Int32 (32-bit width) as done in find_tmem_tensor_col_offset + dst_width = 32 # Int32.width + recasted_layout = cute.recast_layout(dst_width, src_width, layout, loc=loc, ip=ip) + + # Compute cosize and mask + offset = cute.cosize(recasted_layout, loc=loc, ip=ip) & _TMEM_COL_MASK + + # Ensure we return a Python int + if isinstance(offset, int): + return offset + + # Try to fold the DSL value to a Python int + try: + return const_expr(offset) + except Exception: + raise ValueError( + "Dynamic TMEM layout size not supported; " + "the layout size must be determinable at compile time." + ) + + +class TmemBufferPool: + """A pool for sub-allocating from a reserved chunk of tensor memory. + + This class enables sub-allocation from a pre-reserved TMEM region, + eliminating the need for manual offset calculations when allocating + multiple tensors in TMEM. + + Example usage:: + + tmem_pool = tmem_allocator.reserve(tmem_total_size) + + # Allocate and create tensors in one call + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_layout, cutlass.Float32) + tCtSFA = tmem_pool.allocate_tensor(tCtSFA_layout, sf_dtype) + + # Or allocate pointer only, then create tensor manually + sfb_ptr = tmem_pool.allocate(tCtSFB_layout, sf_dtype) + tCtSFB = cute.make_tensor(sfb_ptr, tCtSFB_layout) + + :ivar _base_ptr: The base pointer to the reserved TMEM region. + :type _base_ptr: cute.Pointer + :ivar _total_cols: The total number of columns in the pool. + :type _total_cols: int + :ivar _current_offset: The current offset within the pool (in columns). + :type _current_offset: int + """ + + def __init__( + self, + base_ptr: cute.Pointer, + total_cols: int, + ): + """ + Initialize a TmemBufferPool instance. + + :param base_ptr: The base pointer to the reserved TMEM region. + :type base_ptr: cute.Pointer + :param total_cols: The total number of columns in the pool. + :type total_cols: int + """ + self._base_ptr = base_ptr + self._total_cols = total_cols + self._current_offset = 0 + + def __extract_mlir_values__(self) -> list[ir.Value]: + return extract_mlir_values(self._base_ptr) + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "TmemBufferPool": + assert len(values) == 1 + new_base_ptr = new_from_mlir_values(self._base_ptr, [values[0]]) + pool = TmemBufferPool(new_base_ptr, self._total_cols) + pool._current_offset = self._current_offset + return pool + + @property + def base_ptr(self) -> cute.Pointer: + """Return the base pointer of the pool.""" + return self._base_ptr + + @property + def total_cols(self) -> int: + """Return the total number of columns in the pool.""" + return self._total_cols + + @property + def current_offset(self) -> int: + """Return the current offset within the pool.""" + return self._current_offset + + @property + def remaining_cols(self) -> int: + """Return the number of remaining columns available for allocation.""" + return self._total_cols - self._current_offset + + @dsl_user_op + def allocate( + self, + size: Union[int, cute.Layout], + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Pointer: + """Allocate a sub-region from the pool and return a pointer. + + This method allocates a contiguous region of TMEM columns from the pool + and returns a pointer to the start of that region. + + :param size: The allocation size, which can be: + - int: explicit number of columns to allocate + - cute.Layout: a TMEM layout that, combined with dtype, determines the size + :type size: Union[int, cute.Layout] + :param dtype: The data type for the returned pointer and for computing + layout size (when size is a Layout). + :type dtype: Type[Numeric] + :return: A pointer to the allocated region with the specified dtype. + :rtype: cute.Pointer + + :raises AssertionError: If there are not enough columns remaining in the pool. + + Example usage:: + + # Allocate with explicit column count + acc_ptr = pool.allocate(64, cutlass.Float32) + + # Allocate based on layout and dtype + sfa_ptr = pool.allocate(tCtSFA_layout, sf_dtype) + """ + # Determine number of columns from size argument + if isinstance(size, cute.Layout): + num_cols = compute_tmem_cols_from_layout(size, dtype, loc=loc, ip=ip) + else: + num_cols = size + + assert self._current_offset + num_cols <= self._total_cols, ( + f"Cannot allocate {num_cols} columns, only {self.remaining_cols} remaining" + ) + + if self._current_offset == 0: + # First allocation - return base pointer with recast + ptr = cute.recast_ptr(self._base_ptr, dtype=dtype, loc=loc, ip=ip) + else: + # Subsequent allocations - offset from base + ptr = cute.recast_ptr( + self._base_ptr + self._current_offset, + dtype=dtype, + loc=loc, + ip=ip, + ) + + self._current_offset += num_cols + return ptr + + @dsl_user_op + def allocate_tensor( + self, + layout: cute.Layout, + dtype: Type[Numeric], + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> cute.Tensor: + """Allocate a sub-region from the pool and return a tensor. + + This is a convenience method that combines allocate() and cute.make_tensor() + into a single call. + + :param layout: The TMEM layout for the tensor. + :type layout: cute.Layout + :param dtype: The data type for the tensor elements. + :type dtype: Type[Numeric] + :return: A tensor backed by the allocated TMEM region. + :rtype: cute.Tensor + + :raises AssertionError: If there are not enough columns remaining in the pool. + + Example usage:: + + tCtAcc = pool.allocate_tensor(tCtAcc_layout, cutlass.Float32) + tCtSFA = pool.allocate_tensor(tCtSFA_layout, sf_dtype) + """ + ptr = self.allocate(layout, dtype, loc=loc, ip=ip) + return cute.make_tensor(ptr, layout, loc=loc, ip=ip) + + class TmemAllocator: """A class for managing tensor memory allocation on GPUs. @@ -52,7 +272,12 @@ class TmemAllocator: @dsl_user_op @cute.jit - def _init_dealloc_mbarrier(self, *, loc=None, ip=None): + def _init_dealloc_mbarrier( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: assert self._two_cta_tmem_dealloc_mbar_ptr is not None, ( "two_cta_tmem_dealloc_mbar_ptr is required for two cta" ) @@ -81,10 +306,10 @@ class TmemAllocator: two_cta_tmem_dealloc_mbar_ptr: Optional[cute.Pointer] = None, *, arch: str = "sm_100", - dealloc_mbarrier_initialized: bool = False, - loc=None, - ip=None, - ): + initialize_mbarrier: bool = True, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """ Initialize a TmemAllocator instance for managing tensor memory on Blackwell GPUs. @@ -109,6 +334,8 @@ class TmemAllocator: :type num_allocated_columns: int, optional :param two_cta_tmem_dealloc_mbar_ptr: The mbarrier pointer required for two-CTA tensor memory deallocation, optional. :type two_cta_tmem_dealloc_mbar_ptr: cute.Pointer, optional + :param initialize_mbarrier: Whether to initialize the mbarrier for two cta, defaults to True. + :type initialize_mbarrier: bool, optional :param loc: Optional codegen location for debugging and error reporting. :type loc: Any, optional :param ip: Optional insertion point for codegen. @@ -127,7 +354,7 @@ class TmemAllocator: self._max_tmem_columns = get_max_tmem_alloc_cols(arch) # Init tmem dealloc mbarrier if two cta - if not dealloc_mbarrier_initialized and const_expr(self._is_two_cta): + if const_expr(self._is_two_cta and initialize_mbarrier): self._init_dealloc_mbarrier(loc=loc, ip=ip) def __extract_mlir_values__(self) -> list[ir.Value]: @@ -160,11 +387,11 @@ class TmemAllocator: self._num_allocated_columns, new_two_cta_tmem_dealloc_mbar_ptr, arch=self._arch, # Preserve the architecture parameter - dealloc_mbarrier_initialized=True, + initialize_mbarrier=False, ) @cute.jit - def check_valid_num_columns(self, num_columns: int): + def check_valid_num_columns(self, num_columns: int) -> bool: """Check if the number of columns is valid. This method checks if the number of columns is valid. @@ -180,13 +407,21 @@ class TmemAllocator: if const_expr(num_columns % 32 != 0): return False # power of two - if const_expr(num_columns & (num_columns - 1) != 0): + if const_expr( + (num_columns & (num_columns - 1) != 0) + ): return False return True @dsl_user_op @cute.jit - def allocate(self, num_columns: int, *, loc=None, ip=None): + def allocate( + self, + num_columns: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Allocate a block of tensor memory. This method allocates a block of tensor memory from allocator warp and returns a handle to retrieve @@ -215,7 +450,12 @@ class TmemAllocator: self._num_allocated_columns += num_columns @dsl_user_op - def wait_for_alloc(self, *, loc=None, ip=None): + def wait_for_alloc( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Wait for the allocator warp to finish allocation. This method is used to synchronize the allocator warp with the other warps before retrieving tmem ptr. @@ -227,8 +467,8 @@ class TmemAllocator: self, dtype: Type[Numeric] = Float32, *, - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> cute.Pointer: """Retrieve the pointer to the allocated tensor memory. @@ -243,9 +483,50 @@ class TmemAllocator: ip=ip, ) + @dsl_user_op + def reserve( + self, + num_columns: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> TmemBufferPool: + """Reserve a block of tensor memory and return a pool for sub-allocation. + + This method allocates a block of tensor memory, waits for the allocation + to complete, and returns a TmemBufferPool that can be used to sub-allocate + regions within that block without manual offset calculations. + + Example usage:: + + tmem_pool = tmem_allocator.reserve(tmem_total_size) + + # Allocate and create tensors in one call + tCtAcc = tmem_pool.allocate_tensor(tCtAcc_layout, cutlass.Float32) + tCtSFA = tmem_pool.allocate_tensor(tCtSFA_layout, sf_dtype) + + # Or allocate pointer only, then create tensor manually + sfb_ptr = tmem_pool.allocate(tCtSFB_layout, sf_dtype) + tCtSFB = cute.make_tensor(sfb_ptr, tCtSFB_layout) + + :param num_columns: The total number of columns to reserve. + :type num_columns: int + :return: A TmemBufferPool for sub-allocating within the reserved region. + :rtype: TmemBufferPool + """ + self.allocate(num_columns, loc=loc, ip=ip) + self.wait_for_alloc(loc=loc, ip=ip) + base_ptr = self.retrieve_ptr(loc=loc, ip=ip) + return TmemBufferPool(base_ptr, num_columns) + @dsl_user_op @cute.jit - def relinquish_alloc_permit(self, *, loc=None, ip=None): + def relinquish_alloc_permit( + self, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Relinquish the tensor memory allocation permit. This method relinquishes the tensor memory allocation permit for the allocator warp, promising @@ -261,7 +542,14 @@ class TmemAllocator: @dsl_user_op @cute.jit - def free(self, tmem_ptr: cute.Pointer, num_columns: int = 0, *, loc=None, ip=None): + def free( + self, + tmem_ptr: cute.Pointer, + num_columns: int = 0, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: """Deallocate the tensor memory. This method sync on mbarrier (for two cta use case) and deallocates the tensor memory from the allocator warp. @@ -309,7 +597,7 @@ class TmemAllocator: # Set explicit signature for Sphinx documentation to avoid issues with @dsl_user_op decorator -TmemAllocator.__init__.__signature__ = inspect.Signature( +TmemAllocator.__init__.__signature__ = inspect.Signature( # type: ignore[attr-defined] [ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter( @@ -352,11 +640,11 @@ TmemAllocator.__init__.__signature__ = inspect.Signature( def get_num_tmem_alloc_cols( tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], - rounding=True, + rounding: bool = True, *, arch: str = "sm_100", - loc=None, - ip=None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, ) -> int: """Get the total number of TMEM allocation columns for the given TMEM tensors. diff --git a/python/CuTeDSL/prep_editable_install.py b/python/CuTeDSL/prep_editable_install.py index ac7d258f9..d3064f7bb 100644 --- a/python/CuTeDSL/prep_editable_install.py +++ b/python/CuTeDSL/prep_editable_install.py @@ -158,6 +158,7 @@ def extract_version_from_wheel(wheel_path: Path) -> str: else: return "9.9.9.dev0" + def extract_wheel_contents(wheel_path: Path, extract_dir: Path) -> None: """ Extract wheel contents to specified directory. diff --git a/python/CuTeDSL/requirements-cu13.txt b/python/CuTeDSL/requirements-cu13.txt index 3b3f7463f..4fcd9996b 100644 --- a/python/CuTeDSL/requirements-cu13.txt +++ b/python/CuTeDSL/requirements-cu13.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements-cu13.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl[cu13]==4.5.0 +nvidia-cutlass-dsl[cu13]==4.4.2 diff --git a/python/CuTeDSL/requirements.txt b/python/CuTeDSL/requirements.txt index 80d81892e..2238c3db3 100644 --- a/python/CuTeDSL/requirements.txt +++ b/python/CuTeDSL/requirements.txt @@ -1,3 +1,3 @@ # Use `pip install -r requirements.txt` with the present file to install a # wheel consistent with the present state of the github repository -nvidia-cutlass-dsl==4.5.0.dev0 +nvidia-cutlass-dsl==4.4.2 diff --git a/python/CuTeDSL/setup.sh b/python/CuTeDSL/setup.sh index 5428ece93..f148f6050 100755 --- a/python/CuTeDSL/setup.sh +++ b/python/CuTeDSL/setup.sh @@ -37,49 +37,118 @@ set -e # Get the directory where this script is located SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -# Default to requirements.txt +# Default mode +MODE="wheel" REQUIREMENTS_FILE="requirements.txt" # Parse command line arguments -if [ $# -gt 0 ]; then +while [ $# -gt 0 ]; do case "$1" in + --editable|-e) + MODE="editable" + shift + ;; --cu12) REQUIREMENTS_FILE="requirements.txt" - echo "Installing CUDA 12 requirements..." + shift ;; --cu13) REQUIREMENTS_FILE="requirements-cu13.txt" - echo "Installing CUDA 13 requirements..." + shift ;; --help|-h) - echo "Usage: $0 [--cu12|--cu13]" - echo " --cu12 Install requirements for CUDA 12 (default)" - echo " --cu13 Install requirements for CUDA 13" + echo "CUTLASS IR Python DSL Setup Script" + echo "" + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " -e, --editable Install in editable/development mode" + echo " --cu12 Use CUDA 12 requirements (default)" + echo " --cu13 Use CUDA 13 requirements" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " # Install from wheel (CUDA 12)" + echo " $0" + echo "" + echo " # Install from wheel (CUDA 13)" + echo " $0 --cu13" + echo "" + echo " # Install in editable mode (requires CUTLASS_IR_BUILD_DIR)" + echo " export CUTLASS_IR_BUILD_DIR=/path/to/build" + echo " $0 --editable" + echo "" + echo " # Install in editable mode with CUDA 13 dev deps" + echo " export CUTLASS_IR_BUILD_DIR=../../../build" + echo " $0 --editable --cu13" + echo "" + echo "For more details on editable install, see README.md" exit 0 ;; *) echo "Error: Unknown argument '$1'" - echo "Usage: $0 [--cu12|--cu13]" + echo "Use --help for usage information" exit 1 ;; esac +done + +if [ "$MODE" = "editable" ]; then + echo "=====================================================================" + echo "Installing CUTLASS IR Python DSL in EDITABLE mode" + echo "=====================================================================" + echo "" + + # Check if CUTLASS_IR_BUILD_DIR is set + if [ -z "$CUTLASS_IR_BUILD_DIR" ]; then + echo "ERROR: CUTLASS_IR_BUILD_DIR environment variable is required for editable install" + echo "" + echo "Please set it to your CMake build directory:" + echo " export CUTLASS_IR_BUILD_DIR=/path/to/build" + echo "" + echo "Or use a relative path:" + echo " export CUTLASS_IR_BUILD_DIR=../../../build" + echo "" + echo "Then run this script again." + exit 1 + fi + + echo "Build directory: $CUTLASS_IR_BUILD_DIR" + echo "" + + # Install in editable mode with dev dependencies + echo "Installing with dev dependencies..." + pip install -e ".[dev]" + + echo "" + echo "=====================================================================" + echo "Editable installation complete!" + echo "=====================================================================" + echo "" + echo "You can now import cutlass from anywhere:" + echo " python -c 'import cutlass; print(cutlass.__version__)'" + echo "" + echo "Runtime environment (CUTE_DSL_LIBS) is automatically configured." + echo "Changes to Python code will be immediately available without reinstall." + echo "" + echo "See README.md for development workflow details." + else - echo "Installing default requirements (CUDA 12)..." + # Wheel installation mode + echo "Installing CUTLASS IR Python DSL from wheel" + echo "" + + # Check if requirements file exists + REQUIREMENTS_PATH="${SCRIPT_DIR}/${REQUIREMENTS_FILE}" + if [ ! -f "$REQUIREMENTS_PATH" ]; then + echo "Error: Requirements file not found: $REQUIREMENTS_PATH" + exit 1 + fi + + # Install requirements + echo "Installing from: $REQUIREMENTS_FILE" + pip install -r "$REQUIREMENTS_PATH" + + echo "" + echo "Installation complete!" fi - -# Check if requirements file exists -REQUIREMENTS_PATH="${SCRIPT_DIR}/${REQUIREMENTS_FILE}" -if [ ! -f "$REQUIREMENTS_PATH" ]; then - echo "Error: Requirements file not found: $REQUIREMENTS_PATH" - exit 1 -fi - -# Uninstall previous version of the CUTLASS DSL -echo "Trying to uninstall previous version of the CUTLASS DSL..." -pip uninstall nvidia-cutlass-dsl nvidia-cutlass-dsl-libs-base nvidia-cutlass-dsl-libs-cu13 -y - -# Install requirements -echo "Installing from: $REQUIREMENTS_FILE" -pip install -r "$REQUIREMENTS_PATH" - -echo "Installation complete!" diff --git a/python/cutlass_cppgen/__init__.py b/python/cutlass_cppgen/__init__.py index 0cbf25180..889b6f453 100644 --- a/python/cutlass_cppgen/__init__.py +++ b/python/cutlass_cppgen/__init__.py @@ -133,7 +133,7 @@ def get_option_registry(): this._option_registry = OptionRegistry(device_cc()) return this._option_registry -this.__version__ = '4.5.0' +this.__version__ = '4.4.2' from cutlass_cppgen.backend import create_memory_pool from cutlass_cppgen.emit.pytorch import pytorch diff --git a/python/cutlass_library/emit_kernel_listing.py b/python/cutlass_library/emit_kernel_listing.py index 4041caeeb..531091597 100755 --- a/python/cutlass_library/emit_kernel_listing.py +++ b/python/cutlass_library/emit_kernel_listing.py @@ -789,9 +789,11 @@ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode } } - cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m - cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n - cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k + # Fallback cluster shape cannot differ from preferred cluster shape in stream-K kernels. + enable_fallback_cluster = dynamic_cluster and 'stream_k' not in kernel_name + cluster_m_fallback = ctas_per_mma_instruction if enable_fallback_cluster else cluster_shape_m + cluster_n_fallback = 1 if enable_fallback_cluster else cluster_shape_n + cluster_k_fallback = 1 if enable_fallback_cluster else cluster_shape_k if dynamic_datatype: diff --git a/python/setup_cutlass.py b/python/setup_cutlass.py index 98d2e077c..0d1abbb32 100644 --- a/python/setup_cutlass.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup_pycute.perform_setup() setup( name='cutlass_cppgen', - version='4.5.0', + version='4.4.2', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index c88e3320c..84edebc8c 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='cutlass_library', - version='4.5.0', + version='4.4.2', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index 7892f866c..c4ae36938 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ from setuptools import setup def perform_setup(): setup( name='pycute', - version='4.5.0', + version='4.4.2', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu index 8a0b8e7d2..e7de7e62e 100644 --- a/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +++ b/test/unit/conv/device_3x/dgrad/sm100_conv3d_dgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -296,9 +296,9 @@ TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op // // CTA tile shape 64x64x64 // preferred cluster shape 2x4x1 -// fallback cluster shape 2x2x1 +// fallback cluster shape 2x4x1 // -TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { +TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x4x1) { using ElementAct = cutlass::half_t; using ElementFlt = cutlass::half_t; using ElementOut = cutlass::half_t; @@ -338,7 +338,7 @@ TEST(SM100_device_conv3d_dgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op using Conv = cutlass::conv::device::ConvUniversalAdapter; - EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,4,1))); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu index b820602e1..4c9b707a5 100644 --- a/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +++ b/test/unit/conv/device_3x/fprop/sm100_conv3d_fprop_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -301,9 +301,9 @@ TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op // // CTA tile shape 64x64x64 // preferred cluster shape 2x4x1 -// fallback cluster shape 2x2x1 +// fallback cluster shape 2x4x1 // -TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { +TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x4x1) { using ElementAct = cutlass::half_t; using ElementFlt = cutlass::half_t; using ElementOut = cutlass::half_t; @@ -344,7 +344,7 @@ TEST(SM100_device_conv3d_fprop_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op using Conv = cutlass::conv::device::ConvUniversalAdapter; - EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,4,1))); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu index efb394ca4..253d6b162 100644 --- a/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +++ b/test/unit/conv/device_3x/wgrad/sm100_conv1d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -296,9 +296,9 @@ TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, // // CTA tile shape 64x64x64 // preferred cluster shape 2x4x1 -// fallback cluster shape 2x2x1 +// fallback cluster shape 2x4x1 // -TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { +TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x4x1) { using ElementAct = cutlass::half_t; using ElementFlt = cutlass::half_t; using ElementOut = cutlass::half_t; @@ -338,7 +338,7 @@ TEST(SM100_device_conv1d_wgrad_implicitgemm_f16nwc_f16nwc_f16nwc_tensor_op_f16, using Conv = cutlass::conv::device::ConvUniversalAdapter; - EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,4,1))); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu index 88d74d021..9766f6fd7 100644 --- a/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +++ b/test/unit/conv/device_3x/wgrad/sm100_conv2d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -296,9 +296,9 @@ TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f1 // // CTA tile shape 64x64x64 // preferred cluster shape 2x4x1 -// fallback cluster shape 2x2x1 +// fallback cluster shape 2x4x1 // -TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { +TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x4x1) { using ElementAct = cutlass::half_t; using ElementFlt = cutlass::half_t; using ElementOut = cutlass::half_t; @@ -338,7 +338,7 @@ TEST(SM100_device_conv2d_wgrad_implicitgemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f1 using Conv = cutlass::conv::device::ConvUniversalAdapter; - EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,4,1))); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu index 7659f222f..320d3e6f8 100644 --- a/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu +++ b/test/unit/conv/device_3x/wgrad/sm100_conv3d_wgrad_implicit_gemm_f16_f16_f16_tensorop_f16_streamk.cu @@ -202,9 +202,9 @@ TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op // // CTA tile shape 64x64x64 // preferred cluster shape 2x4x1 -// fallback cluster shape 2x2x1 +// fallback cluster shape 2x4x1 // -TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x2x1) { +TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op_f16, 64x64x64_preferred_2x4x1_fallback_2x4x1) { using ElementAct = cutlass::half_t; using ElementFlt = cutlass::half_t; using ElementOut = cutlass::half_t; @@ -244,7 +244,7 @@ TEST(SM100_device_conv3d_wgrad_implicitgemm_f16ndhwc_f16ndhwc_f16ndhwc_tensor_op using Conv = cutlass::conv::device::ConvUniversalAdapter; - EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,2,1))); + EXPECT_TRUE(test::conv::device::TestAllConv(1.0, 0.0, 0.0f, dim3(2,4,1), dim3(2,4,1))); } #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp index 3fed2b727..5eb2776f3 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -620,7 +620,8 @@ template < typename ElementReduce, bool FinalReduction = true, // Should match the FinalReduction in Device type typename CtaTileShapeMNK = cute::Shape, - typename ElementCompute = float + typename ElementCompute = float, + int ReduceIdentity = 0 > class HostRowReduce: public HostEVTNodeBase { public: @@ -674,7 +675,9 @@ public: reduce_buffer_.resize(shape); } - cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + cutlass::reference::host::TensorFill(tensor_row_reduce_.host_view(), ElementDst(ReduceIdentity)); + tensor_row_reduce_.sync_device(); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view(), ElementCompute(ReduceIdentity)); } template @@ -725,7 +728,7 @@ public: } Arguments get_arguments() { - return {tensor_row_reduce_.device_data()}; + return {{tensor_row_reduce_.device_data(), ElementCompute(ReduceIdentity)}}; } }; @@ -738,7 +741,8 @@ template < typename ElementReduce, bool FinalReduction = true, // Should match the FinalReduction in Device type typename CtaTileShapeMNK = cute::Shape, - typename ElementCompute = float + typename ElementCompute = float, + int ReduceIdentity = 0 > class HostColumnReduce: public HostEVTNodeBase { public: @@ -793,7 +797,9 @@ public: reduce_buffer_.resize(shape); } - cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + cutlass::reference::host::TensorFill(tensor_column_reduce_.host_view(), ElementDst(ReduceIdentity)); + tensor_column_reduce_.sync_device(); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view(), ElementCompute(ReduceIdentity)); } template @@ -844,7 +850,7 @@ public: } Arguments get_arguments() { - return {tensor_column_reduce_.device_data()}; + return {{tensor_column_reduce_.device_data(), ElementCompute(ReduceIdentity)}}; } }; @@ -856,7 +862,8 @@ template < template class ReduceFn, typename ElementReduce, typename ElementCompute = float, - bool enabled = true + bool enabled = true, + int ReduceIdentity = 0 > class HostScalarReduce: public HostEVTNodeBase { public: @@ -886,8 +893,9 @@ public: reference_scalar_reduce_.resize(cutlass::Coord<1>(1)); reduce_buffer_.resize(cutlass::Coord<1>(1)); + cutlass::reference::host::TensorFill(tensor_scalar_reduce_.host_view(), ElementReduce(ReduceIdentity)); tensor_scalar_reduce_.sync_device(); - cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view(), ElementCompute(ReduceIdentity)); } template @@ -929,7 +937,7 @@ public: } Arguments get_arguments() { - return {tensor_scalar_reduce_.device_data()}; + return {{tensor_scalar_reduce_.device_data(), ElementCompute(ReduceIdentity)}}; } auto get_flatten_arguments() { diff --git a/tools/library/src/reference/blockwise_gemm_reference_operation.h b/tools/library/src/reference/blockwise_gemm_reference_operation.h index c40ac1bef..a03310149 100644 --- a/tools/library/src/reference/blockwise_gemm_reference_operation.h +++ b/tools/library/src/reference/blockwise_gemm_reference_operation.h @@ -598,201 +598,354 @@ void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int template void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { + // E4M3 FP8 variants make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); + // E5M2 FP8 variants + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + // Mixed E4M3 x E5M2 variants make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1 , 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); + // Mixed E5M2 x E4M3 variants make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 1, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); + // E2M3 FP6 variants make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 1 , 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 128, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 1, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 128, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ - >(manifest, 64, 1 , 128); + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 64, 128, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 32, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 32, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 64, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 64, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 128, 256, 128); make_blockwise_gemm< - float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, - ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + float_e2m3_t /*A*/, float /*SFA*/, float_e2m3_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + // E3M2 FP6 variants + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e3m2_t /*A*/, float /*SFA*/, float_e3m2_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + // E2M1 FP4 variants + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e2m1_t /*A*/, float /*SFA*/, float_e2m1_t /*B*/, float /*SFB*/, + ElementC /*C*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ >(manifest, 1, 256, 128); }